com.github.steveash.jg2p.seq.PhonemeACrfTrainer.java Source code

Java tutorial

Introduction

Here is the source code for com.github.steveash.jg2p.seq.PhonemeACrfTrainer.java.

Source

/*
 * Copyright 2015 Steve Ash
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.steveash.jg2p.seq;

import com.google.common.base.Function;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;

import com.github.steveash.jg2p.align.Alignment;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import cc.mallet.grmm.learning.ACRF;
import cc.mallet.grmm.learning.ACRFTrainer;
import cc.mallet.grmm.learning.DefaultAcrfTrainer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2LabelSequence;
import cc.mallet.pipe.TokenSequence2FeatureVectorSequence;
import cc.mallet.pipe.TokenSequenceLowercase;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;

import static org.apache.commons.lang3.StringUtils.isBlank;

/**
 * Trains a linear-chain ACRF (Mallet GRMM) that labels each grapheme of an alignment with a phoneme.
 * <p>
 * Just the BP version of the linear chain CRF, performs similarly to the CRF class version (albeit 10x slower)
 * @author Steve Ash
 */
public class PhonemeACrfTrainer {

    private static final Logger log = LoggerFactory.getLogger(PhonemeACrfTrainer.class);

    /** Label substituted for every blank phoneme position before training. */
    private static final String EPSILON = "<EPS>";

    /** Value passed to {@code ACRF.setGaussianPriorVariance}. */
    private static final double GAUSSIAN_PRIOR_VARIANCE = 2.0;

    /** Value passed to {@code DefaultAcrfTrainer.LogEvaluator.setNumIterToSkip}. */
    private static final int EVAL_ITERATIONS_TO_SKIP = 2;

    /**
     * Iteration limit passed to the trainer.
     * NOTE(review): presumably chosen large so that convergence, not this cap, ends training; confirm.
     */
    private static final int MAX_ITERATIONS = 9999;

    /**
     * Trains an ACRF on the given alignments and discards the resulting model.
     * <p>
     * Kept for backward compatibility. Call {@link #trainModel(Collection)} to get the trained model back.
     *
     * @param examples grapheme/phoneme alignments to train on
     */
    public void train(Collection<Alignment> examples) {
        trainModel(examples);
    }

    /**
     * Builds the feature pipe, converts the alignments into instances, and trains an ACRF on them.
     *
     * @param examples grapheme/phoneme alignments to train on
     * @return the trained ACRF
     */
    public ACRF trainModel(Collection<Alignment> examples) {
        Pipe pipe = makePipe();
        InstanceList instances = makeExamplesFromAligns(examples, pipe);

        // Only a single bigram template over label chain 0 is used. Also tried and disabled:
        // BigramTemplate(1), PairwiseFactorTemplate(0, 1), CrossTemplate1(0, 1).
        ACRF.Template[] tmpls = new ACRF.Template[] { new ACRF.BigramTemplate(0) };

        ACRF acrf = new ACRF(pipe, tmpls);
        acrf.setSupportedOnly(true);
        acrf.setGaussianPriorVariance(GAUSSIAN_PRIOR_VARIANCE);

        DefaultAcrfTrainer.LogEvaluator eval = new DefaultAcrfTrainer.LogEvaluator();
        eval.setNumIterToSkip(EVAL_ITERATIONS_TO_SKIP);

        ACRFTrainer trainer = new DefaultAcrfTrainer();
        // NOTE(review): the two nulls are presumably the validation and test instance lists
        // (none are supplied here); confirm against Mallet's ACRFTrainer.train signature.
        trainer.train(acrf, instances, null, null, eval, MAX_ITERATIONS);
        return acrf;
    }

    /**
     * Converts each alignment into a Mallet {@link Instance} and pushes it through {@code pipe}.
     * <p>
     * The data is the alignment's X tokens. The target is its Y tokens, with blank entries replaced by
     * the epsilon label.
     *
     * @param alignsToTrain alignments to convert
     * @param pipe          pipe that every instance is passed through
     * @return the piped instances
     */
    private static InstanceList makeExamplesFromAligns(Iterable<Alignment> alignsToTrain, Pipe pipe) {
        int count = 0;
        InstanceList instances = new InstanceList(pipe);
        for (Alignment align : alignsToTrain) {
            // Work on a copy so the epsilon substitution below can never write through to the
            // alignment's own data. It also keeps working if the returned list is unmodifiable.
            List<String> phones = new ArrayList<>(align.getAllYTokensAsList());
            updateEpsilons(phones);
            Instance ii = new Instance(align.getAllXTokensAsList(), phones, null, null);
            instances.addThruPipe(ii);
            count += 1;
        }
        log.info("Read {} instances of training data", count);
        return instances;
    }

    /**
     * Flattens the alignments of every group into one iterable.
     * <p>
     * The result is built with Guava's {@code FluentIterable.transformAndConcat}, so it is a lazy view.
     * Not currently referenced anywhere in this class.
     *
     * @param inputs groups whose {@code alignments} are concatenated in order
     * @return all alignments from all groups
     */
    private Iterable<Alignment> getAlignsFromGroup(List<SeqInputReader.AlignGroup> inputs) {
        return FluentIterable.from(inputs)
                .transformAndConcat(new Function<SeqInputReader.AlignGroup, Iterable<Alignment>>() {
                    @Override
                    public Iterable<Alignment> apply(SeqInputReader.AlignGroup input) {
                        return input.alignments;
                    }
                });
    }

    /**
     * Replaces every blank phoneme in place with the epsilon label.
     * <p>
     * An entry counts as blank if it is null, empty, or whitespace-only, as decided by commons-lang3
     * {@code isBlank}.
     * <p>
     * An earlier variant labelled each blank relative to the preceding phoneme ({@code last + "_" + blankCount}).
     * The current behaviour maps every blank to the same label.
     *
     * @param phones mutable list of phoneme labels; modified in place
     */
    private static void updateEpsilons(List<String> phones) {
        for (int i = 0; i < phones.size(); i++) {
            if (isBlank(phones.get(i))) {
                phones.set(i, EPSILON);
            }
        }
    }

    /**
     * Builds the feature-extraction pipe, which runs the following stages in order:
     * <ol>
     *   <li>convert the token lists into a token sequence;</li>
     *   <li>lowercase the tokens;</li>
     *   <li>add neighbouring-grapheme features;</li>
     *   <li>add neighbouring-shape features;</li>
     *   <li>turn the token text into features;</li>
     *   <li>vectorise the features;</li>
     *   <li>turn the targets into label sequences and then into label assignments.</li>
     * </ol>
     * The feature alphabet and the label alphabet are each shared by the stages that need them.
     *
     * @return the assembled serial pipe
     */
    private static Pipe makePipe() {
        Alphabet alpha = new Alphabet();
        Target2LabelSequence labelPipe = new Target2LabelSequence();
        LabelAlphabet labelAlpha = (LabelAlphabet) labelPipe.getTargetAlphabet();

        return new SerialPipes(ImmutableList.of(new StringListToTokenSequence(alpha, labelAlpha), // convert to token sequence
                new TokenSequenceLowercase(), // make all lowercase
                new NeighborTokenFeature(true, makeNeighbors()), // grab neighboring graphemes
                new NeighborShapeFeature(true, makeShapeNeighs()), // shape features of neighboring graphemes
                new TokenSequenceToFeature(), // convert the strings in the text to features
                new TokenSequence2FeatureVectorSequence(alpha, true, true), labelPipe,
                new LabelSequenceToLabelsAssignment(alpha, labelAlpha)));
    }

    /**
     * Returns the token windows used by the neighbour-shape feature.
     *
     * @return immutable list of windows
     */
    private static List<TokenWindow> makeShapeNeighs() {
        return ImmutableList.of(new TokenWindow(-5, 5), new TokenWindow(-4, 4), new TokenWindow(-3, 3),
                new TokenWindow(-2, 2), new TokenWindow(-1, 1), new TokenWindow(1, 1), new TokenWindow(1, 2),
                new TokenWindow(1, 3), new TokenWindow(1, 4), new TokenWindow(1, 5));
    }

    /**
     * Returns the token windows used by the neighbour-grapheme feature.
     * <p>
     * Windows also tried and disabled: (1, 2), (1, 3), (-3, 3), (-4, 4).
     *
     * @return immutable list of windows
     */
    private static List<TokenWindow> makeNeighbors() {
        return ImmutableList.of(new TokenWindow(1, 1), new TokenWindow(2, 1), new TokenWindow(3, 1),
                new TokenWindow(-1, 1), new TokenWindow(-2, 1), new TokenWindow(-3, 1), new TokenWindow(-2, 2));
    }
}