com.joliciel.talismane.machineLearning.maxent.custom.GISTrainer.java Source code

Java tutorial

Introduction

Here is the source code for com.joliciel.talismane.machineLearning.maxent.custom.GISTrainer.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.joliciel.talismane.machineLearning.maxent.custom;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import opennlp.maxent.GISModel;
import opennlp.model.DataIndexer;
import opennlp.model.EvalParameters;
import opennlp.model.EventStream;
import opennlp.model.MutableContext;
import opennlp.model.OnePassDataIndexer;
import opennlp.model.Prior;
import opennlp.model.UniformPrior;

/**
 * An implementation of Generalized Iterative Scaling.  The reference paper
 * for this implementation was Adwait Ratnaparkhi's tech report at the
 * University of Pennsylvania's Institute for Research in Cognitive Science,
 * and is available at <a href ="ftp://ftp.cis.upenn.edu/pub/ircs/tr/97-08.ps.Z"><code>ftp://ftp.cis.upenn.edu/pub/ircs/tr/97-08.ps.Z</code></a>. 
 *
 * The slack parameter used in the above implementation has been removed by default
 * from the computation and a method for updating with Gaussian smoothing has been
 * added per Investigating GIS and Smoothing for Maximum Entropy Taggers, Clark and Curran (2002).  
 * <a href="http://acl.ldc.upenn.edu/E/E03/E03-1071.pdf"><code>http://acl.ldc.upenn.edu/E/E03/E03-1071.pdf</code></a>
 * The slack parameter can be used by setting <code>useSlackParameter</code> to true.
 * Gaussian smoothing can be enabled by calling <code>setGaussianSigma</code>, which sets <code>useGaussianSmoothing</code> to true.
 * 
 * A prior can be used to train models which converge to the distribution which minimizes the
 * relative entropy between the distribution specified by the empirical constraints of the training
 * data and the specified prior.  By default, the uniform distribution is used as the prior.
 */
public class GISTrainer {
    /** Logger used for progress messages (debug level) and for training errors. */
    private static final Log LOG = LogFactory.getLog(GISTrainer.class);

    /** Buffer in which <code>display</code> accumulates message fragments until a newline arrives. */
    private String currentMessage = "";

    /**
     * Specifies whether simple smoothing is used. When true, every outcome is
     * treated as active for every predicate, and context/outcome pairs that were
     * never observed are given the smoothing observation as their observed count.
     */
    private boolean useSimpleSmoothing = false;

    /** 
     * Specifies whether parameter updates should prefer a distribution of parameters which
     * is gaussian. Switched on by <code>setGaussianSigma</code>.
     */
    private boolean useGaussianSmoothing = false;

    /** Divisor of the penalty term in the gaussian update; defaults to 2.0. */
    private double sigma = 2.0;

    // If we are using smoothing, this is used as the "number" of
    // times we want the trainer to imagine that it saw a feature that it
    // actually didn't see.  Defaulted to 0.1.
    private double _smoothingObservation = 0.1;

    /** Whether <code>display</code> emits progress messages at all. */
    private final boolean printMessages;

    /** 
     * Number of unique events which occurred in the event set. 
     */
    private int numUniqueEvents;

    /** 
     * Number of predicates. 
     */
    private int numPreds;

    /** 
     * Number of outcomes. 
     */
    private int numOutcomes;

    /** 
     * Records the array of predicates seen in each event.
     */
    private int[][] contexts;

    /** 
     * The value associated with each context. If null then context values are assumed to be 1.
     */
    private float[][] values;

    /** 
     * List of outcomes for each event i, in context[i].
     */
    private int[] outcomeList;

    /** 
     * Records the number of times an event has been seen for each event i, in context[i].
     */
    private int[] numTimesEventsSeen;

    /** 
     * The number of times a predicate occurred in the training data.
     */
    private int[] predicateCounts;

    /**
     * Minimum value of <code>predicateCounts</code> a predicate needs in order to
     * contribute to the model expectations and to have active outcomes.
     */
    private int cutoff;

    /**
     * Stores the String names of the outcomes. The GIS only tracks outcomes as
     * ints, and so this array is needed to save the model to disk and thereby
     * allow users to know what the outcome was in human understandable terms.
     */
    private String[] outcomeLabels;

    /**
     * Stores the String names of the predicates. The GIS only tracks predicates
     * as ints, and so this array is needed to save the model to disk and thereby
     * allow users to know what the predicate was in human understandable terms.
     */
    private String[] predLabels;

    /**
     * Stores the observed expected values of the features based on training data.
     */
    private MutableContext[] observedExpects;

    /**
     * Stores the estimated parameter value of each predicate during iteration
     */
    private MutableContext[] params;

    /** 
     * Stores the expected values of the features based on the current model.
     * The first index is the worker thread, the second the predicate; the
     * per-thread values are merged into index 0 after each iteration.
     */
    private MutableContext[][] modelExpects;

    /**
     * This is the prior distribution that the model uses for training.
     */
    private Prior prior;

    /**
     * Training stops once the log-likelihood improves by less than this amount
     * between two consecutive iterations.
     */
    private static final double LLThreshold = 0.0001;

    /**
     * Evaluation parameters wrapping <code>params</code>; handed to
     * <code>GISModel.eval</code> when the model distribution of an event is computed.
     */
    private EvalParameters evalParams;

    /**
     * Creates a new <code>GISTrainer</code> instance which trains silently,
     * i.e. without emitting progress messages about training.
     */
    public GISTrainer() {
        this(false);
    }

    /**
     * Creates a new <code>GISTrainer</code> instance.
     *
     * @param printMessages when true, progress messages about training are
     *                      written to this class's logger at debug level;
     *                      trains silently otherwise.
     */
    public GISTrainer(boolean printMessages) {
        this.printMessages = printMessages;
    }

    /**
     * Switches simple smoothing on or off for subsequent training runs.
     * Smoothing can improve model accuracy, though training will potentially
     * take longer and use more memory, and the resulting model will be larger.
     *
     * @param smooth true if smoothing is desired, false if not
     */
    public void setSmoothing(boolean smooth) {
        this.useSimpleSmoothing = smooth;
    }

    /**
     * Sets the pseudo-count used by simple smoothing: the observed count that
     * is assigned to a context/outcome pair which never occurred in the
     * training data. Only has an effect when smoothing is enabled.
     *
     * @param timesSeen the "number" of times we want the trainer to imagine
     *                  it saw a feature that it actually didn't see
     */
    public void setSmoothingObservation(double timesSeen) {
        this._smoothingObservation = timesSeen;
    }

    /**
     * Enables Gaussian smoothing for subsequent training runs and sets the
     * sigma value used by the Gaussian parameter update.
     *
     * @param sigmaValue the sigma value to use in the Gaussian update
     */
    public void setGaussianSigma(double sigmaValue) {
        this.sigma = sigmaValue;
        this.useGaussianSmoothing = true;
    }

    /**
     * Trains a GIS model on the events in the specified event stream. The
     * events are indexed in a single pass with the given cutoff and training
     * then proceeds with a uniform prior on a single thread.
     *
     * @param eventStream A stream of all events.
     * @param iterations The number of iterations to use for GIS.
     * @param cutoff The number of times a feature must occur to be included.
     * @return A GIS model trained with the specified number of iterations and cutoff.
     * @throws IOException if indexing the event stream fails
     */
    public GISModel trainModel(EventStream eventStream, int iterations, int cutoff) throws IOException {
        final DataIndexer indexer = new OnePassDataIndexer(eventStream, cutoff);
        return trainModel(iterations, indexer, cutoff);
    }

    /**
     * Train a model using the GIS algorithm, with a uniform prior and a
     * single training thread.
     *
     * @param iterations  The number of GIS iterations to perform.
     * @param di The data indexer used to compress events in memory.
     * @param cutoff The number of times a predicate must occur to be included.
     * @return The newly trained model, which can be used immediately or saved
     *         to disk using an opennlp.maxent.io.GISModelWriter object.
     */
    public GISModel trainModel(int iterations, DataIndexer di, int cutoff) {
        final Prior uniformPrior = new UniformPrior();
        final int singleThread = 1;
        return trainModel(iterations, di, uniformPrior, cutoff, singleThread);
    }

    /**
     * Train a model using the GIS algorithm.
     *
     * @param iterations  The number of GIS iterations to perform.
     * @param di The data indexer used to compress events in memory.
     * @param modelPrior The prior distribution used to train this model.
     * @param cutoff The number of times a predicate must occur in the training
     *               data to take part in training.
     * @param threads The number of threads used to compute the model
     *                expectations; must be at least one.
     * @return The newly trained model, which can be used immediately or saved
     *         to disk using an opennlp.maxent.io.GISModelWriter object.
     * @throws IllegalArgumentException if <code>threads</code> is less than one
     */
    public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int cutoff, int threads) {

        if (threads <= 0)
            throw new IllegalArgumentException("threads must be at least one!");

        // one set of model expectations per worker thread
        modelExpects = new MutableContext[threads][];

        /************** Incorporate all of the needed info ******************/
        display("Incorporating indexed data for training...  \n");
        contexts = di.getContexts();
        values = di.getValues();
        this.cutoff = cutoff;
        predicateCounts = di.getPredCounts();
        numTimesEventsSeen = di.getNumTimesEventsSeen();
        numUniqueEvents = contexts.length;
        this.prior = modelPrior;
        //printTable(contexts);

        // determine the correction constant
        double correctionConstant = computeCorrectionConstant(contexts, values);
        display("done.\n");

        outcomeLabels = di.getOutcomeLabels();
        outcomeList = di.getOutcomeList();
        numOutcomes = outcomeLabels.length;

        predLabels = di.getPredLabels();
        prior.setLabels(outcomeLabels, predLabels);
        numPreds = predLabels.length;

        display("\tNumber of Event Tokens: " + numUniqueEvents + "\n");
        display("\t    Number of Outcomes: " + numOutcomes + "\n");
        display("\t  Number of Predicates: " + numPreds + "\n");

        // set up feature arrays: weighted count of each predicate/outcome pair
        float[][] predCount = new float[numPreds][numOutcomes];
        for (int ti = 0; ti < numUniqueEvents; ti++) {
            for (int j = 0; j < contexts[ti].length; j++) {
                if (values != null && values[ti] != null) {
                    predCount[contexts[ti][j]][outcomeList[ti]] += numTimesEventsSeen[ti] * values[ti][j];
                } else {
                    predCount[contexts[ti][j]][outcomeList[ti]] += numTimesEventsSeen[ti];
                }
            }
        }

        //printTable(predCount);
        di = null; // don't need it anymore

        // A fake "observation" to cover features which are not detected in
        // the data.  The default is to assume that we observed "1/10th" of a
        // feature during training.
        final double smoothingObservation = _smoothingObservation;

        // Get the observed expectations of the features. Strictly speaking,
        // we should divide the counts by the number of Tokens, but because of
        // the way the model's expectations are approximated in the
        // implementation, this is cancelled out when we compute the next
        // iteration of a parameter, making the extra divisions wasteful.
        params = new MutableContext[numPreds];
        for (int i = 0; i < modelExpects.length; i++)
            modelExpects[i] = new MutableContext[numPreds];
        observedExpects = new MutableContext[numPreds];

        // The model does not need the correction constant and the correction feature. The correction
        // constant is only needed during training, and the correction feature is not necessary.
        // For compatibility reasons the model contains from now on a correction constant of 1,
        // and a correction param 0.
        evalParams = new EvalParameters(params, 0, 1, numOutcomes);
        int[] activeOutcomes = new int[numOutcomes];
        int[] outcomePattern;
        int[] allOutcomesPattern = new int[numOutcomes];
        for (int oi = 0; oi < numOutcomes; oi++) {
            allOutcomesPattern[oi] = oi;
        }
        int numActiveOutcomes = 0;
        for (int pi = 0; pi < numPreds; pi++) {
            numActiveOutcomes = 0;
            if (useSimpleSmoothing) {
                // with smoothing every outcome is active for every predicate
                numActiveOutcomes = numOutcomes;
                outcomePattern = allOutcomesPattern;
            } else { //determine active outcomes
                for (int oi = 0; oi < numOutcomes; oi++) {
                    if (predCount[pi][oi] > 0 && predicateCounts[pi] >= cutoff) {
                        activeOutcomes[numActiveOutcomes] = oi;
                        numActiveOutcomes++;
                    }
                }
                if (numActiveOutcomes == numOutcomes) {
                    // share the single "all outcomes" pattern instead of copying it
                    outcomePattern = allOutcomesPattern;
                } else {
                    outcomePattern = new int[numActiveOutcomes];
                    System.arraycopy(activeOutcomes, 0, outcomePattern, 0, numActiveOutcomes);
                }
            }
            params[pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]);
            for (int i = 0; i < modelExpects.length; i++)
                modelExpects[i][pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]);
            observedExpects[pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]);
            for (int aoi = 0; aoi < numActiveOutcomes; aoi++) {
                int oi = outcomePattern[aoi];
                params[pi].setParameter(aoi, 0.0);
                for (int i = 0; i < modelExpects.length; i++)
                    modelExpects[i][pi].setParameter(aoi, 0.0);
                if (predCount[pi][oi] > 0) {
                    observedExpects[pi].setParameter(aoi, predCount[pi][oi]);
                } else if (useSimpleSmoothing) {
                    observedExpects[pi].setParameter(aoi, smoothingObservation);
                }
            }
        }

        predCount = null; // don't need it anymore

        display("...done.\n");

        /***************** Find the parameters ************************/
        if (threads == 1)
            display("Computing model parameters ...\n");
        else
            display("Computing model parameters in " + threads + " threads...\n");

        findParameters(iterations, correctionConstant);

        /*************** Create and return the model ******************/
        // To be compatible with old models the correction constant is always 1
        return new GISModel(params, predLabels, outcomeLabels, 1, evalParams.getCorrectionParam());

    }

    /*
     * Returns the correction constant: the largest per-event feature total, which
     * is the number of predicates of the event when it carries no values and the
     * sum of its values otherwise. Returns 0 when there are no events.
     */
    private static double computeCorrectionConstant(int[][] eventContexts, float[][] eventValues) {
        double correctionConstant = 0;
        for (int ci = 0; ci < eventContexts.length; ci++) {
            if (eventValues == null || eventValues[ci] == null) {
                if (eventContexts[ci].length > correctionConstant) {
                    correctionConstant = eventContexts[ci].length;
                }
            } else {
                // Sum from zero rather than seeding with element 0, so that an
                // event with an empty value row cannot cause an index error.
                float cl = 0;
                for (int vi = 0; vi < eventValues[ci].length; vi++) {
                    cl += eventValues[ci][vi];
                }

                if (cl > correctionConstant) {
                    correctionConstant = cl;
                }
            }
        }
        return correctionConstant;
    }

    /*
     * Runs up to the requested number of GIS iterations, updating the model
     * parameters in place. Stops early when the log-likelihood decreases
     * (reported as an error) or improves by less than LLThreshold, and
     * finally releases the large training arrays.
     */
    private void findParameters(int iterations, double correctionConstant) {
        display("Performing " + iterations + " iterations.\n");

        double previousLL = 0.0;
        int iteration = 1;
        while (iteration <= iterations) {
            // right-align the iteration number in a three character column
            final String label;
            if (iteration < 10) {
                label = "  " + iteration + ":  ";
            } else if (iteration < 100) {
                label = " " + iteration + ":  ";
            } else {
                label = iteration + ":  ";
            }
            display(label);

            final double currentLL = nextIteration(correctionConstant);

            // the first iteration has nothing to be compared against
            if (iteration > 1) {
                if (previousLL > currentLL) {
                    LOG.error("Model Diverging: loglikelihood decreased");
                    break;
                }
                if (currentLL - previousLL < LLThreshold) {
                    break;
                }
            }

            previousLL = currentLL;
            iteration++;
        }

        // kill a bunch of these big objects now that we don't need them
        contexts = null;
        numTimesEventsSeen = null;
        modelExpects = null;
        observedExpects = null;
    }

    // Modeled on the implementation in Zhang Le's maxent kit.
    // Computes the parameter update for one predicate/outcome pair with a
    // Newton iteration (at most 50 steps, stopping once a step is smaller than
    // 1e-6 or the derivative is zero). The argument n is not used.
    private double gaussianUpdate(int predicate, int oid, int n, double correctionConstant) {
        final double currentParam = params[predicate].getParameters()[oid];
        final double expected = modelExpects[0][predicate].getParameters()[oid];
        final double observed = observedExpects[predicate].getParameters()[oid];

        double delta = 0.0;
        int stepsLeft = 50;
        while (stepsLeft-- > 0) {
            final double scaled = expected * Math.exp(correctionConstant * delta);
            final double residual = scaled + (currentParam + delta) / sigma - observed;
            final double slope = scaled * correctionConstant + 1 / sigma;
            if (slope == 0) {
                // cannot take a Newton step; keep the current estimate
                return delta;
            }
            final double next = delta - residual / slope;
            final boolean converged = Math.abs(next - delta) < 0.000001;
            delta = next;
            if (converged) {
                break;
            }
        }
        return delta;
    }

    /**
     * Task which computes the model expectations, the log-likelihood and the
     * event statistics for one contiguous slice of the training events. Each
     * task accumulates its expectations into its own row of
     * <code>modelExpects</code>, selected by the thread index.
     */
    private class ModelExpactationComputeTask implements Callable<ModelExpactationComputeTask> {

        private final int threadIndex;
        private final int startIndex;
        private final int length;

        private double loglikelihood = 0;
        private int numEvents = 0;
        private int numCorrect = 0;

        // row of modelExpects to write to, first event to process, number of events to process
        ModelExpactationComputeTask(int threadIndex, int startIndex, int length) {
            this.threadIndex = threadIndex;
            this.startIndex = startIndex;
            this.length = length;
        }

        @Override
        public ModelExpactationComputeTask call() {

            final double[] distribution = new double[numOutcomes];
            final int endIndex = startIndex + length;

            for (int event = startIndex; event < endIndex; event++) {

                // TODO: check interruption status here, if interrupted set a poisoned flag and return

                final int[] eventContext = contexts[event];
                final float[] eventValues = (values != null) ? values[event] : null;

                // evaluate the current model for this event
                if (values != null) {
                    prior.logPrior(distribution, eventContext, eventValues);
                    GISModel.eval(eventContext, eventValues, distribution, evalParams);
                } else {
                    prior.logPrior(distribution, eventContext);
                    GISModel.eval(eventContext, distribution, evalParams);
                }

                final int timesSeen = numTimesEventsSeen[event];

                // add this event's contribution to the expectations of its predicates
                for (int j = 0; j < eventContext.length; j++) {
                    final int pi = eventContext[j];
                    if (predicateCounts[pi] < cutoff) {
                        continue;
                    }
                    final MutableContext expectation = modelExpects[threadIndex][pi];
                    final int[] activeOutcomes = expectation.getOutcomes();
                    for (int aoi = 0; aoi < activeOutcomes.length; aoi++) {
                        final int oi = activeOutcomes[aoi];
                        if (eventValues != null) {
                            expectation.updateParameter(aoi, distribution[oi] * eventValues[j] * timesSeen);
                        } else {
                            expectation.updateParameter(aoi, distribution[oi] * timesSeen);
                        }
                    }
                }

                final int expectedOutcome = outcomeList[event];
                loglikelihood += Math.log(distribution[expectedOutcome]) * timesSeen;
                numEvents += timesSeen;

                // the accuracy is only reported in progress messages
                if (printMessages) {
                    int best = 0;
                    for (int oi = 1; oi < numOutcomes; oi++) {
                        if (distribution[oi] > distribution[best]) {
                            best = oi;
                        }
                    }
                    if (best == expectedOutcome) {
                        numCorrect += timesSeen;
                    }
                }
            }

            return this;
        }

        synchronized int getNumEvents() {
            return numEvents;
        }

        synchronized int getNumCorrect() {
            return numCorrect;
        }

        synchronized double getLoglikelihood() {
            return loglikelihood;
        }
    }

    /*
     * Computes one iteration of GIS and returns the log-likelihood.
     *
     * The events are split into one slice per thread, the model expectations of
     * the slices are computed in parallel and merged, and the parameters are
     * then updated from the observed and the model expectations.
     *
     * Throws IllegalStateException if the calling thread is interrupted while
     * waiting for the workers, and RuntimeException (wrapping the cause) if a
     * worker fails.
     */
    private double nextIteration(double correctionConstant) {
        // compute contribution of p(a|b_i) for each feature and the new
        // correction parameter
        double loglikelihood = 0.0;
        int numEvents = 0;
        int numCorrect = 0;

        int numberOfThreads = modelExpects.length;

        int taskSize = numUniqueEvents / numberOfThreads;

        int leftOver = numUniqueEvents % numberOfThreads;

        ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
        try {
            List<Future<ModelExpactationComputeTask>> futures = new ArrayList<Future<ModelExpactationComputeTask>>(
                    numberOfThreads);

            for (int i = 0; i < numberOfThreads; i++) {
                // the last task also takes the events left over by the integer division
                int length = (i == numberOfThreads - 1) ? taskSize + leftOver : taskSize;
                futures.add(executor.submit(new ModelExpactationComputeTask(i, i * taskSize, length)));
            }

            for (Future<ModelExpactationComputeTask> future : futures) {
                ModelExpactationComputeTask finishedTask;
                try {
                    finishedTask = future.get();
                } catch (InterruptedException e) {
                    // Interruption is not supported by the training: restore the
                    // interrupt flag for the caller and fail hard.
                    Thread.currentThread().interrupt();
                    throw new IllegalStateException("Interruption is not supported!", e);
                } catch (ExecutionException e) {
                    // Only runtime exception can be thrown during training, if one was thrown
                    // it should be re-thrown. That could for example be a NullPointerException
                    // which is caused through a bug in our implementation.
                    throw new RuntimeException(e.getCause());
                }

                // When they are done, retrieve the results ...
                numEvents += finishedTask.getNumEvents();
                numCorrect += finishedTask.getNumCorrect();
                loglikelihood += finishedTask.getLoglikelihood();
            }
        } finally {
            // always release the worker threads, also when a task failed
            executor.shutdown();
        }

        display(".");

        // merge the results of all threads into the expectations of thread 0
        for (int pi = 0; pi < numPreds; pi++) {
            int[] activeOutcomes = params[pi].getOutcomes();

            for (int aoi = 0; aoi < activeOutcomes.length; aoi++) {
                for (int i = 1; i < modelExpects.length; i++) {
                    modelExpects[0][pi].updateParameter(aoi, modelExpects[i][pi].getParameters()[aoi]);
                }
            }
        }

        display(".");

        // compute the new parameter values
        for (int pi = 0; pi < numPreds; pi++) {
            double[] observed = observedExpects[pi].getParameters();
            double[] model = modelExpects[0][pi].getParameters();
            int[] activeOutcomes = params[pi].getOutcomes();
            for (int aoi = 0; aoi < activeOutcomes.length; aoi++) {
                if (useGaussianSmoothing) {
                    params[pi].updateParameter(aoi, gaussianUpdate(pi, aoi, numEvents, correctionConstant));
                } else {
                    if (model[aoi] == 0) {
                        // aoi is a position in the active outcome list; map it to
                        // the outcome id before looking up the label
                        LOG.error("Model expects == 0 for " + predLabels[pi] + " "
                                + outcomeLabels[activeOutcomes[aoi]]);
                    }
                    //params[pi].updateParameter(aoi,(Math.log(observed[aoi]) - Math.log(model[aoi])));
                    params[pi].updateParameter(aoi,
                            ((Math.log(observed[aoi]) - Math.log(model[aoi])) / correctionConstant));
                }

                for (int i = 0; i < modelExpects.length; i++)
                    modelExpects[i][pi].setParameter(aoi, 0.0); // re-initialize to 0.0's

            }
        }

        display(". loglikelihood=" + loglikelihood + "\t" + ((double) numCorrect / numEvents) + "\n");

        return loglikelihood;
    }

    /*
     * Appends the given text to the pending progress message and, once the text
     * ends with a newline, writes the pending message (without that newline) to
     * the logger at debug level. Does nothing unless printMessages is set.
     */
    private void display(String s) {
        if (!printMessages) {
            return;
        }

        currentMessage = currentMessage + s;
        if (!s.endsWith("\n")) {
            // line not complete yet, keep buffering
            return;
        }

        final int end = currentMessage.length() - 1;
        LOG.debug(currentMessage.substring(0, end));
        currentMessage = "";
    }
}