Java tutorial
/* * Copyright 2015 Steve Ash * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.github.steveash.jg2p.seq; import com.google.common.base.Function; import com.google.common.collect.FluentIterable; import com.google.common.collect.ImmutableList; import com.github.steveash.jg2p.align.Alignment; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.Collection; import java.util.List; import cc.mallet.grmm.learning.ACRF; import cc.mallet.grmm.learning.ACRFTrainer; import cc.mallet.grmm.learning.DefaultAcrfTrainer; import cc.mallet.pipe.Pipe; import cc.mallet.pipe.SerialPipes; import cc.mallet.pipe.Target2LabelSequence; import cc.mallet.pipe.TokenSequence2FeatureVectorSequence; import cc.mallet.pipe.TokenSequenceLowercase; import cc.mallet.types.Alphabet; import cc.mallet.types.Instance; import cc.mallet.types.InstanceList; import cc.mallet.types.LabelAlphabet; import static org.apache.commons.lang3.StringUtils.isBlank; /** * Just the BP version of the linear chain CRF, performs similarly to the CRF class version (albeit 10x slower) * @author Steve Ash */ public class PhonemeACrfTrainer { private static final Logger log = LoggerFactory.getLogger(PhonemeACrfTrainer.class); public void train(Collection<Alignment> examples) { Pipe pipe = makePipe(); InstanceList instances = makeExamplesFromAligns(examples, pipe); ACRF.Template[] tmpls = new ACRF.Template[] { new ACRF.BigramTemplate(0) // new ACRF.BigramTemplate (1), // new ACRF.PairwiseFactorTemplate (0,1), // new CrossTemplate1(0,1) }; ACRF acrf = new ACRF(pipe, tmpls); ACRFTrainer trainer = new DefaultAcrfTrainer(); acrf.setSupportedOnly(true); acrf.setGaussianPriorVariance(2.0); DefaultAcrfTrainer.LogEvaluator eval = new DefaultAcrfTrainer.LogEvaluator(); eval.setNumIterToSkip(2); trainer.train(acrf, instances, null, null, eval, 9999); } private static InstanceList makeExamplesFromAligns(Iterable<Alignment> alignsToTrain, Pipe pipe) { int count = 0; InstanceList instances = new InstanceList(pipe); for (Alignment align : alignsToTrain) { List<String> phones = align.getAllYTokensAsList(); updateEpsilons(phones); Instance ii = new Instance(align.getAllXTokensAsList(), phones, null, null); instances.addThruPipe(ii); count += 1; // if (count > 1000) { // break; // } } log.info("Read {} instances of training data", count); return instances; } private Iterable<Alignment> getAlignsFromGroup(List<SeqInputReader.AlignGroup> inputs) { return FluentIterable.from(inputs) .transformAndConcat(new Function<SeqInputReader.AlignGroup, Iterable<Alignment>>() { @Override public Iterable<Alignment> apply(SeqInputReader.AlignGroup input) { return input.alignments; } }); } private static void updateEpsilons(List<String> phones) { String last = "<EPS>"; int blankCount = 0; for (int i = 0; i < phones.size(); i++) { String p = phones.get(i); if (isBlank(p)) { // phones.set(i, last + "_" + blankCount); phones.set(i, "<EPS>"); blankCount += 1; } else { last = p; blankCount = 0; } } } private static Pipe makePipe() { Alphabet alpha = new Alphabet(); Target2LabelSequence labelPipe = new Target2LabelSequence(); LabelAlphabet labelAlpha = (LabelAlphabet) labelPipe.getTargetAlphabet(); return new SerialPipes(ImmutableList.of(new StringListToTokenSequence(alpha, labelAlpha), // convert to token sequence new TokenSequenceLowercase(), // make all lowercase new NeighborTokenFeature(true, makeNeighbors()), // grab neighboring graphemes new NeighborShapeFeature(true, makeShapeNeighs()), new TokenSequenceToFeature(), // convert the strings in the text to features new TokenSequence2FeatureVectorSequence(alpha, true, true), labelPipe, new LabelSequenceToLabelsAssignment(alpha, labelAlpha))); } private static List<TokenWindow> makeShapeNeighs() { return ImmutableList.of(new TokenWindow(-5, 5), new TokenWindow(-4, 4), new TokenWindow(-3, 3), new TokenWindow(-2, 2), new TokenWindow(-1, 1), new TokenWindow(1, 1), new TokenWindow(1, 2), new TokenWindow(1, 3), new TokenWindow(1, 4), new TokenWindow(1, 5)); } private static List<TokenWindow> makeNeighbors() { return ImmutableList.of(new TokenWindow(1, 1), new TokenWindow(2, 1), new TokenWindow(3, 1), // new TokenWindow(1, 2), // new TokenWindow(1, 3), new TokenWindow(-1, 1), new TokenWindow(-2, 1), new TokenWindow(-3, 1), new TokenWindow(-2, 2) // new TokenWindow(-3, 3) // new TokenWindow(-4, 4), ); } }