edu.cmu.cs.lti.ark.fn.identification.training.TrainBatch.java Source code

Java tutorial

Introduction

Here is the source code for edu.cmu.cs.lti.ark.fn.identification.training.TrainBatch.java

Source

/*******************************************************************************
 * Copyright (c) 2011 Dipanjan Das 
 * Language Technologies Institute, 
 * Carnegie Mellon University, 
 * All Rights Reserved.
 *
 * TrainBatch.java is part of SEMAFOR 2.0.
 *
 * SEMAFOR 2.0 is free software: you can redistribute it and/or modify  it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or 
 * (at your option) any later version.
 *
 * SEMAFOR 2.0 is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details. 
 *
 * You should have received a copy of the GNU General Public License along
 * with SEMAFOR 2.0.  If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
package edu.cmu.cs.lti.ark.fn.identification.training;

import com.google.common.base.Charsets;
import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.collect.Lists;
import com.google.common.io.Files;
import com.google.common.primitives.Ints;
import edu.cmu.cs.lti.ark.fn.optimization.Lbfgs;
import edu.cmu.cs.lti.ark.fn.utils.FNModelOptions;
import edu.cmu.cs.lti.ark.fn.utils.ThreadPool;
import edu.cmu.cs.lti.ark.util.SerializedObjects;
import edu.cmu.cs.lti.ark.util.ds.Pair;
import gnu.trove.TIntDoubleHashMap;

import java.io.File;
import java.io.FilenameFilter;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.logging.FileHandler;
import java.util.logging.LogManager;
import java.util.logging.Logger;
import java.util.logging.SimpleFormatter;

import static edu.cmu.cs.lti.ark.util.IntRanges.xrange;
import static edu.cmu.cs.lti.ark.util.Math2.sum;

/**
 * Batch trainer for the log-linear frame-identification model.
 *
 * <p>Computes the (optionally softmax-margin) conditional log-likelihood and its
 * gradient over all training examples, parallelized across {@code numThreads}
 * worker threads, and hands the objective to L-BFGS for minimization. Supports
 * optional L1 or L2 regularization.
 */
public class TrainBatch {
    private static final Logger logger = Logger.getLogger(TrainBatch.class.getCanonicalName());

    public static final String FEATURE_FILENAME_PREFIX = "feats_";
    public static final String FEATURE_FILENAME_SUFFIX = ".jobj.gz";
    /** Accepts serialized feature files named like {@code feats_<n>.jobj.gz}. */
    private static final FilenameFilter featureFilenameFilter = new FilenameFilter() {
        public boolean accept(File dir, String name) {
            return name.startsWith(FEATURE_FILENAME_PREFIX) && name.endsWith(FEATURE_FILENAME_SUFFIX);
        }
    };
    /** Orders feature files numerically by the integer embedded between prefix and suffix. */
    private static final Comparator<String> featureFilenameComparator = new Comparator<String>() {
        private final int PREFIX_LEN = FEATURE_FILENAME_PREFIX.length();
        private final int SUFFIX_LEN = FEATURE_FILENAME_SUFFIX.length();

        public int compare(String a, String b) {
            final int aNum = Integer.parseInt(a.substring(PREFIX_LEN, a.length() - SUFFIX_LEN));
            final int bNum = Integer.parseInt(b.substring(PREFIX_LEN, b.length() - SUFFIX_LEN));
            return Ints.compare(aNum, bNum);
        }
    };
    // was a mutable static field; a default constant should be final
    private static final float DEFAULT_COST_MULTIPLE = 5f;

    /** Current model weights; length == alphabet size. */
    private final double[] params;
    /** Absolute paths of the per-target serialized feature files, in numeric order. */
    private final List<String> eventFiles;
    private final String modelFile;
    private final boolean useL1Regularization;
    private final boolean useL2Regularization;
    private final int numThreads;
    private final double lambda;
    /** Merged gradient, reused across L-BFGS iterations. */
    final double[] gradients;
    /** Per-thread partial gradients, merged into {@link #gradients} after each pass. */
    final double[][] tGradients;
    /** Per-thread partial log-likelihoods. */
    final double[] tValues;
    private final boolean usePartialCreditCosts;
    // was non-final; never reassigned after construction
    private final float costMultiple;

    public static void main(String[] args) throws Exception {
        final FNModelOptions options = new FNModelOptions(args);
        LogManager.getLogManager().reset();
        final FileHandler fh = new FileHandler(options.logOutputFile.get(), true);
        fh.setFormatter(new SimpleFormatter());
        logger.addHandler(fh);

        final String restartFile = options.restartFile.get();
        // default to one worker per available core
        final int numThreads = options.numThreads.present() ? options.numThreads.get()
                : Runtime.getRuntime().availableProcessors();
        final TrainBatch tbm = new TrainBatch(options.alphabetFile.get(), options.eventsFile.get(),
                options.modelFile.get(), options.reg.get(), options.lambda.get(),
                restartFile.equals("null") ? Optional.<String>absent() : Optional.of(restartFile), numThreads,
                options.usePartialCredit.get(),
                options.costMultiple.present() ? (float) options.costMultiple.get() : DEFAULT_COST_MULTIPLE);
        tbm.trainModel();
    }

    /**
     * @param alphabetFile         file whose header gives the feature-alphabet size
     * @param eventsDir            directory containing {@code feats_<n>.jobj.gz} files
     * @param modelFile            path the trained model is written to
     * @param reg                  "l1", "l2" (case-insensitive), or anything else for no regularization
     * @param lambda               regularization strength
     * @param restartFile          optional model file to warm-start from
     * @param numThreads           number of worker threads
     * @param usePartialCreditCosts whether to use softmax-margin (cost-augmented) training
     * @param costMultiple         scale applied to per-frame costs when softmax-margin is on
     * @throws IOException if the alphabet or restart file cannot be read
     */
    public TrainBatch(String alphabetFile, String eventsDir, String modelFile, String reg, double lambda,
            Optional<String> restartFile, int numThreads, boolean usePartialCreditCosts, float costMultiple)
            throws IOException {
        final int modelSize = AlphabetCreationThreaded.getAlphabetSize(alphabetFile);
        logger.info(String.format("Number of features: %d", modelSize));
        this.modelFile = modelFile;
        this.eventFiles = getEventFiles(new File(eventsDir));
        // equalsIgnoreCase avoids the locale-sensitivity of toLowerCase()
        this.useL1Regularization = reg.equalsIgnoreCase("l1");
        this.useL2Regularization = reg.equalsIgnoreCase("l2");
        this.lambda = lambda;
        this.numThreads = numThreads;
        this.usePartialCreditCosts = usePartialCreditCosts;
        this.costMultiple = costMultiple;
        this.params = restartFile.isPresent() ? loadModel(restartFile.get()) : new double[modelSize];

        gradients = new double[modelSize];
        tGradients = new double[numThreads][modelSize];
        tValues = new double[numThreads];
    }

    /**
     * Runs L-BFGS on the regularized negative log-likelihood and saves the result to
     * {@link #modelFile}.
     *
     * @return the trained parameter vector
     */
    public double[] trainModel() throws Exception {
        final Function<double[], Pair<Double, double[]>> valueAndGradientFunction = new Function<double[], Pair<Double, double[]>>() {
            @Override
            public Pair<Double, double[]> apply(double[] currentParams) {
                return getValuesAndGradientsThreaded(currentParams);
            }
        };
        return Lbfgs.trainAndSaveModel(params, valueAndGradientFunction, modelFile);
    }

    /**
     * Computes the regularized negative log-likelihood and its gradient at
     * {@code currentParams}, splitting the training examples into one contiguous
     * batch per thread.
     *
     * @return (objective value, gradient); the gradient array is reused across calls
     */
    private Pair<Double, double[]> getValuesAndGradientsThreaded(final double[] currentParams) {
        final long startTime = System.currentTimeMillis();
        final int batchSize = (int) Math.ceil(eventFiles.size() / (double) numThreads);
        final ThreadPool threadPool = new ThreadPool(numThreads);
        int task = 0;
        for (int start = 0; start < eventFiles.size(); start += batchSize) {
            final int taskId = task;
            final Iterable<Integer> targetIdxs = xrange(start, Math.min(start + batchSize, eventFiles.size()));
            threadPool.runTask(new Runnable() {
                public void run() {
                    logger.info(String.format("Task %d : start", taskId));
                    try {
                        tValues[taskId] = processBatch(targetIdxs, currentParams, tGradients[taskId]);
                    } catch (InterruptedException e) {
                        // restore the interrupt flag before propagating
                        Thread.currentThread().interrupt();
                        throw new RuntimeException(e);
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                    logger.info(String.format("Task %d : end", taskId));
                }
            });
            task++;
        }
        threadPool.join();
        // merge the calculated tValues and gradients together
        double logLikelihood = sum(tValues);
        Arrays.fill(gradients, 0.0);
        for (double[] tGrad : tGradients) {
            plusEquals(gradients, tGrad);
        }
        // LBGFS always minimizes, so objective is -log(likelihood)
        double negLogLikelihood = -logLikelihood;
        timesEquals(gradients, -1.0);
        // add the regularization penalty
        if (useL1Regularization) {
            double penalty = 0.0;
            for (double param : currentParams)
                penalty += Math.abs(param);
            negLogLikelihood += lambda * penalty;
            for (int i = 0; i < currentParams.length; ++i) {
                // d/dw |w| = sign(w). (The old code added lambda * w, which is the
                // gradient of (1/2)||w||^2, not of the L1 penalty.) signum(0) == 0
                // is a valid subgradient at the non-differentiable point.
                gradients[i] += lambda * Math.signum(currentParams[i]);
            }
        }
        if (useL2Regularization) {
            negLogLikelihood += lambda * dotProduct(currentParams, currentParams);
            for (int i = 0; i < currentParams.length; ++i) {
                gradients[i] += 2 * lambda * currentParams[i];
            }
        }
        final long endTime = System.currentTimeMillis();
        logger.info(String.format("Finished logLikelihood and gradient computation. Took %s seconds.",
                (endTime - startTime) / 1000.0));
        return Pair.of(negLogLikelihood, gradients);
    }

    /**
     * Processes one contiguous batch of targets.
     * Writes gradients to tGradients as a side-effect.
     *
     * @return the summed log-likelihood of the batch
     * @throws InterruptedException if the worker thread was interrupted
     */
    private double processBatch(Iterable<Integer> targetIdxs, double[] currentParams, double[] tGradients)
            throws Exception {
        Arrays.fill(tGradients, 0.0);
        double logLikelihood = 0.0;
        for (int targetIdx : targetIdxs) {
            if (Thread.currentThread().isInterrupted())
                throw new InterruptedException();
            logLikelihood += addLogLossAndGradientForExample(getFeaturesForTarget(targetIdx), currentParams,
                    tGradients);
            if (targetIdx % 100 == 0)
                logger.info(String.format("target idx: %d", targetIdx));
        }
        return logLikelihood;
    }

    /**
     * Adds one example's contribution to the gradient and returns its log-likelihood.
     * The correct frame is always at index 0 of {@code featuresByFrame}.
     * Writes to gradients as a side-effect.
     */
    private double addLogLossAndGradientForExample(FeaturesAndCost[] featuresByFrame, double[] currentParams,
            double[] gradient) {
        int numFrames = featuresByFrame.length;
        double frameScore[] = new double[numFrames];
        double expdFrameScore[] = new double[numFrames];
        double maxScore = Double.NEGATIVE_INFINITY;
        for (int frameIdx = 0; frameIdx < numFrames; frameIdx++) {
            final TIntDoubleHashMap frameFeatures = featuresByFrame[frameIdx].features;
            frameScore[frameIdx] = dotProduct(currentParams, frameFeatures);
            if (usePartialCreditCosts) {
                // softmax-margin
                frameScore[frameIdx] += costMultiple * featuresByFrame[frameIdx].cost;
            }
            maxScore = Math.max(maxScore, frameScore[frameIdx]);
        }
        // max-shifted log-sum-exp: identical value to log(sum(exp(score))) but
        // immune to overflow of exp() for large scores
        for (int frameIdx = 0; frameIdx < numFrames; frameIdx++) {
            expdFrameScore[frameIdx] = Math.exp(frameScore[frameIdx] - maxScore);
        }
        final double logPartitionFn = maxScore + Math.log(sum(expdFrameScore));

        // the correct frame is always first
        final int correctFrameIdx = 0;
        plusEquals(gradient, featuresByFrame[correctFrameIdx].features);
        for (int frameIdx = 0; frameIdx < numFrames; frameIdx++) {
            // estimate of P(y | x) under the current parameters
            double prob = Math.exp(frameScore[frameIdx] - logPartitionFn);
            for (int featIdx : featuresByFrame[frameIdx].features.keys()) {
                gradient[featIdx] -= prob * featuresByFrame[frameIdx].features.get(featIdx);
            }
        }
        return frameScore[correctFrameIdx] - logPartitionFn;
    }

    /** Performs a dot product of the dense vector a and the sparse vector b (encoded as a map from index to value). */
    private static double dotProduct(double[] a, TIntDoubleHashMap b) {
        double result = 0.0;
        for (int featureIdx : b.keys()) {
            result += a[featureIdx] * b.get(featureIdx);
        }
        return result;
    }

    /** Performs a dot product of the two dense vectors a and b. */
    private static double dotProduct(double[] a, double[] b) {
        double result = 0.0;
        for (int featureIdx : xrange(a.length)) {
            result += a[featureIdx] * b[featureIdx];
        }
        return result;
    }

    /** Adds the sparse vector rhs to the dense vector lhs **/
    private static void plusEquals(double[] lhs, TIntDoubleHashMap rhs) {
        for (int i : rhs.keys()) {
            lhs[i] += rhs.get(i);
        }
    }

    /** Adds the dense vector rhs to the dense vector lhs **/
    private static void plusEquals(double[] lhs, double[] rhs) {
        for (int i = 0; i < lhs.length; i++) {
            lhs[i] += rhs[i];
        }
    }

    /** Multiplies the dense vector <code>vector</code> by the scalar <code>scalar</code> **/
    private static void timesEquals(double[] vector, double scalar) {
        for (int i = 0; i < vector.length; i++) {
            vector[i] *= scalar;
        }
    }

    /**
     * Get features for every frame for given target
     *
     * @param targetIdx the target to get features for
     * @return map from feature index to feature value (for each frame)
     */
    private FeaturesAndCost[] getFeaturesForTarget(int targetIdx) throws Exception {
        return SerializedObjects.readObject(eventFiles.get(targetIdx));
    }

    /**
     * Reads parameters from modelFile, one param per line.
     *
     * @param modelFile the filename to read from
     */
    public static double[] loadModel(String modelFile) throws IOException {
        final List<String> lines = Files.readLines(new File(modelFile), Charsets.UTF_8);
        final double[] params = new double[lines.size()];
        for (int i : xrange(lines.size())) {
            params[i] = Double.parseDouble(lines.get(i).trim());
        }
        return params;
    }

    /**
     * Gets the list of all feature files, sorted numerically, as absolute paths.
     *
     * @throws IllegalArgumentException if {@code eventsDir} is not a readable directory
     */
    private static List<String> getEventFiles(File eventsDir) {
        final String[] files = eventsDir.list(featureFilenameFilter);
        if (files == null) {
            // File.list returns null (instead of throwing) when the path does not
            // name a readable directory; fail with a clear message instead of NPE
            throw new IllegalArgumentException("Not a readable directory: " + eventsDir);
        }
        Arrays.sort(files, featureFilenameComparator);
        final List<String> eventFiles = Lists.newArrayListWithExpectedSize(files.length);
        for (String file : files)
            eventFiles.add(eventsDir.getAbsolutePath() + "/" + file);
        logger.info(String.format("Total number of datapoints: %d", eventFiles.size()));
        return eventFiles;
    }
}