com.joliciel.talismane.machineLearning.perceptron.PerceptronClassifactionModelTrainerImpl.java Source code

Java tutorial

Introduction

Here is the source code for com.joliciel.talismane.machineLearning.perceptron.PerceptronClassifactionModelTrainerImpl.java

Source

///////////////////////////////////////////////////////////////////////////////
//Copyright (C) 2013 Assaf Urieli
//
//This file is part of Talismane.
//
//Talismane is free software: you can redistribute it and/or modify
//it under the terms of the GNU Affero General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//Talismane is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//GNU Affero General Public License for more details.
//
//You should have received a copy of the GNU Affero General Public License
//along with Talismane.  If not, see <http://www.gnu.org/licenses/>.
//////////////////////////////////////////////////////////////////////////////
package com.joliciel.talismane.machineLearning.perceptron;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.joliciel.talismane.machineLearning.ClassificationModel;
import com.joliciel.talismane.machineLearning.ClassificationEvent;
import com.joliciel.talismane.machineLearning.ClassificationEventStream;
import com.joliciel.talismane.machineLearning.DecisionFactory;
import com.joliciel.talismane.machineLearning.MachineLearningModel;
import com.joliciel.talismane.machineLearning.Outcome;
import com.joliciel.talismane.utils.JolicielException;
import com.joliciel.talismane.utils.LogUtils;

class PerceptronClassifactionModelTrainerImpl<T extends Outcome>
        implements PerceptronClassificationModelTrainer<T> {
    private static final Log LOG = LogFactory.getLog(PerceptronClassifactionModelTrainerImpl.class);
    // Maximum number of passes train() makes over the event file.
    private int iterations = 100;
    // Feature cutoff; prepareData() only applies it when the value is greater than 1.
    private int cutoff = 0;
    // train() stops early once the accuracy differs from each of the three
    // previous iterations by less than this amount.
    private double tolerance = 1e-5;

    // Running sum of the feature weights over the averaged iterations,
    // indexed [feature][outcome]; allocated at the end of prepareData().
    private double[][] totalFeatureWeights;
    // Parameters being trained; replaced by the cutoff-filtered parameters
    // in prepareData() when a cutoff is applied.
    private PerceptronModelParameters params;
    // Temporary file holding the serialised training events, one per line.
    private File eventFile;
    // Used by train() to score each event against the current parameters.
    private PerceptronDecisionMaker<T> decisionMaker;
    private DecisionFactory<T> decisionFactory;
    // Descriptors handed to the model built in getModel().
    private Map<String, List<String>> descriptors;
    // Training corpus; its attributes are copied into the model in getModel().
    private ClassificationEventStream corpusEventStream;
    // Optional observer, notified by train() with an averaged model at each
    // iteration number contained in observationPoints.
    private PerceptronModelTrainerObserver<T> observer;
    private List<Integer> observationPoints;
    // When true, train() only adds the weights to the running average on
    // selected iterations instead of on every iteration.
    private boolean averageAtIntervals = false;

    /**
     * Creates a trainer with the default settings declared above
     * (100 iterations, cutoff 0, tolerance 1e-5, averaging on every iteration).
     */
    public PerceptronClassifactionModelTrainerImpl() {
    }

    /**
     * Serialises the training corpus to a temporary event file, applies the
     * feature cutoff if one is configured, and allocates the weight arrays.
     * <p>
     * Every event of {@code eventStream} is converted to a
     * {@link PerceptronEvent} and written as one line of {@link #eventFile}.
     * When {@code cutoff > 1}, the file is read back to count feature
     * occurrences, a new parameter set is initialised from those counts, and
     * the events are rewritten to a second temporary file using the new
     * feature indexes; the first file is then deleted.
     * <p>
     * All readers and writers are closed in {@code finally} blocks so that
     * no file handle outlives this method, even on failure. An open handle
     * would otherwise prevent the temporary files from being deleted on some
     * platforms.
     *
     * @param eventStream the corpus events to train on
     * @throws RuntimeException wrapping any {@link IOException} raised while
     *         reading or writing the temporary files
     */
    void prepareData(ClassificationEventStream eventStream) {
        try {
            eventFile = File.createTempFile("events", "txt");
            eventFile.deleteOnExit();
            Writer eventWriter = new BufferedWriter(
                    new OutputStreamWriter(new FileOutputStream(eventFile), "UTF-8"));
            try {
                while (eventStream.hasNext()) {
                    ClassificationEvent corpusEvent = eventStream.next();
                    PerceptronEvent event = new PerceptronEvent(corpusEvent, params);
                    event.write(eventWriter);
                }
                eventWriter.flush();
            } finally {
                eventWriter.close();
            }

            if (cutoff > 1) {
                params.initialiseCounts();
                File originalEventFile = eventFile;

                // First pass: count how often each feature occurs.
                Scanner countScanner = new Scanner(new BufferedReader(
                        new InputStreamReader(new FileInputStream(originalEventFile), "UTF-8")));
                try {
                    while (countScanner.hasNextLine()) {
                        PerceptronEvent event = new PerceptronEvent(countScanner.nextLine());
                        for (int featureIndex : event.getFeatureIndexes()) {
                            params.getFeatureCounts()[featureIndex]++;
                        }
                    }
                } finally {
                    countScanner.close();
                }

                if (LOG.isDebugEnabled()) {
                    // cutoffCounts[i] = number of features seen at least i times (i = 1..20)
                    int[] cutoffCounts = new int[21];
                    for (int count : params.getFeatureCounts()) {
                        for (int i = 1; i < 21; i++) {
                            if (count >= i) {
                                cutoffCounts[i]++;
                            }
                        }
                    }
                    LOG.debug("Feature counts:");
                    for (int i = 1; i < 21; i++) {
                        LOG.debug("Cutoff " + i + ": " + cutoffCounts[i]);
                    }
                }

                PerceptronModelParameters cutoffParams = new PerceptronModelParameters();
                int[] newIndexes = cutoffParams.initialise(params, cutoff);
                decisionMaker = new PerceptronDecisionMaker<T>(cutoffParams, decisionFactory);

                // Second pass: rewrite every event with the re-mapped feature indexes.
                File cutoffEventFile = File.createTempFile("eventsCutoff", "txt");
                cutoffEventFile.deleteOnExit();
                boolean rewritten = false;
                Scanner rewriteScanner = new Scanner(new BufferedReader(
                        new InputStreamReader(new FileInputStream(originalEventFile), "UTF-8")));
                try {
                    Writer eventCutoffWriter = new BufferedWriter(
                            new OutputStreamWriter(new FileOutputStream(cutoffEventFile), "UTF-8"));
                    try {
                        while (rewriteScanner.hasNextLine()) {
                            PerceptronEvent oldEvent = new PerceptronEvent(rewriteScanner.nextLine());
                            PerceptronEvent newEvent = new PerceptronEvent(oldEvent, newIndexes);
                            newEvent.write(eventCutoffWriter);
                        }
                        eventCutoffWriter.flush();
                    } finally {
                        eventCutoffWriter.close();
                    }
                    rewritten = true;
                } finally {
                    // Must be closed before the original file is deleted below.
                    rewriteScanner.close();
                    if (!rewritten) {
                        // Don't leave a half-written file behind on failure.
                        cutoffEventFile.delete();
                    }
                }

                eventFile = cutoffEventFile;
                params = cutoffParams;
                originalEventFile.delete();
            }

            params.initialiseWeights();
            totalFeatureWeights = new double[params.getFeatureCount()][params.getOutcomeCount()];
        } catch (IOException e) {
            LogUtils.logError(LOG, e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs perceptron training over the event file and leaves the averaged
     * weights in {@code params}.
     * <p>
     * Each iteration reads every event from {@link #eventFile}, predicts the
     * highest-scoring outcome, and on a wrong prediction adds the feature
     * values to the actual outcome's weights and subtracts them from the
     * predicted outcome's weights. After each pass the current weights are
     * added to a running total, either on every iteration or only on the
     * averaging iterations when {@link #isAverageAtIntervals()} is true. If
     * an observer is registered, it is handed an averaged model at each
     * observation point. Training stops early once the accuracy differs from
     * each of the three previous iterations by less than the tolerance.
     *
     * @throws RuntimeException wrapping any {@link IOException} raised while
     *         reading the event file
     */
    void train() {
        try {
            double prevAccuracy1 = 0.0;
            double prevAccuracy2 = 0.0;
            double prevAccuracy3 = 0.0;
            int averagingCount = 0;
            for (int i = 1; i <= iterations; i++) {
                LOG.debug("Iteration " + i);
                int totalErrors = 0;
                int totalEvents = 0;

                Scanner scanner = new Scanner(
                        new BufferedReader(new InputStreamReader(new FileInputStream(eventFile), "UTF-8")));
                try {
                    while (scanner.hasNextLine()) {
                        PerceptronEvent event = new PerceptronEvent(scanner.nextLine());
                        totalEvents++;

                        List<Integer> featureIndexes = event.getFeatureIndexes();
                        List<Double> featureValues = event.getFeatureValues();

                        // don't normalise unless we calculate the log-likelihood, to avoid mathematical cost of normalising
                        double[] results = decisionMaker.predict(featureIndexes, featureValues);
                        double maxValue = results[0];
                        int predicted = 0;
                        for (int j = 1; j < results.length; j++) {
                            if (results[j] > maxValue) {
                                maxValue = results[j];
                                predicted = j;
                            }
                        }

                        int actual = event.getOutcomeIndex();

                        if (actual != predicted) {
                            for (int j = 0; j < featureIndexes.size(); j++) {
                                double[] classWeights = params.getFeatureWeights()[featureIndexes.get(j)];
                                double value = featureValues.get(j);
                                classWeights[actual] += value;
                                classWeights[predicted] -= value;
                            }
                            totalErrors++;
                        } // correct outcome?
                    } // next event
                } finally {
                    // Release the file handle every iteration; otherwise one
                    // handle leaks per pass and the temp file cannot be
                    // deleted afterwards on some platforms.
                    scanner.close();
                }

                // Add feature weights for this iteration
                boolean addAverage = true;
                if (this.isAverageAtIntervals()) {
                    addAverage = isAveragingIteration(i);
                    if (addAverage) {
                        LOG.debug("Averaging at iteration: " + i);
                    }
                }

                if (addAverage) {
                    for (int j = 0; j < params.getFeatureWeights().length; j++) {
                        double[] totalClassWeights = totalFeatureWeights[j];
                        double[] classWeights = params.getFeatureWeights()[j];
                        for (int k = 0; k < params.getOutcomeCount(); k++) {
                            totalClassWeights[k] += classWeights[k];
                        }
                    }
                    averagingCount++;
                }

                if (observer != null && observationPoints != null && observationPoints.contains(i)) {
                    PerceptronModelParameters cloneParams = params.clone();
                    // average the weights for this model
                    for (int j = 0; j < cloneParams.getFeatureWeights().length; j++) {
                        double[] totalClassWeights = totalFeatureWeights[j];
                        double[] classWeights = cloneParams.getFeatureWeights()[j];
                        for (int k = 0; k < cloneParams.getOutcomeCount(); k++) {
                            classWeights[k] = totalClassWeights[k] / averagingCount;
                        }
                    }
                    ClassificationModel<T> model = this.getModel(cloneParams, i);
                    observer.onNextModel(model, i);
                }

                double accuracy = (double) (totalEvents - totalErrors) / (double) totalEvents;
                LOG.debug("Accuracy: " + accuracy);

                // exit if accuracy hasn't significantly changed in 3 iterations
                if (Math.abs(accuracy - prevAccuracy1) < tolerance && Math.abs(accuracy - prevAccuracy2) < tolerance
                        && Math.abs(accuracy - prevAccuracy3) < tolerance) {
                    LOG.info("Accuracy change < " + tolerance + " for 3 iterations: exiting after " + i
                            + " iterations");
                    break;
                }

                prevAccuracy3 = prevAccuracy2;
                prevAccuracy2 = prevAccuracy1;
                prevAccuracy1 = accuracy;
            } // next iteration

            // average the final weights; skipped when nothing was accumulated
            // (e.g. iterations <= 0), since dividing by zero would turn every
            // weight into NaN
            if (averagingCount > 0) {
                for (int j = 0; j < params.getFeatureWeights().length; j++) {
                    double[] totalClassWeights = totalFeatureWeights[j];
                    double[] classWeights = params.getFeatureWeights()[j];
                    for (int k = 0; k < params.getOutcomeCount(); k++) {
                        classWeights[k] = totalClassWeights[k] / averagingCount;
                    }
                }
            }

        } catch (IOException e) {
            LogUtils.logError(LOG, e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Tells whether the weights should be added to the running average on the
     * given iteration when averaging at intervals: every iteration up to 20,
     * then every perfect square (25, 36, 49, ...). This replaces an
     * enumerated list that stopped at 196, so longer runs keep averaging at
     * 225, 256 and so on.
     */
    private static boolean isAveragingIteration(int iteration) {
        if (iteration <= 20) {
            return true;
        }
        long root = Math.round(Math.sqrt(iteration));
        return root * root == iteration;
    }

    /**
     * A single training event in index form: an outcome index plus parallel
     * lists of feature indexes and feature values.
     * <p>
     * Serialised as one text line: the outcome index followed by
     * space-separated {@code featureIndex featureValue} pairs.
     */
    private static final class PerceptronEvent {
        List<Integer> featureIndexes;
        List<Double> featureValues;
        int outcomeIndex;

        /**
         * Builds an event from a corpus event, letting {@code params} fill
         * in the feature indexes and values and supply the outcome index.
         */
        public PerceptronEvent(ClassificationEvent corpusEvent, PerceptronModelParameters params) {
            featureIndexes = new ArrayList<Integer>();
            featureValues = new ArrayList<Double>();
            params.prepareData(corpusEvent.getFeatureResults(), featureIndexes, featureValues, true);
            outcomeIndex = params.getOrCreateOutcomeIndex(corpusEvent.getClassification());
        }

        /**
         * Parses an event from a line produced by {@link #write(Writer)}.
         *
         * @throws NumberFormatException if a token is not a valid number
         */
        public PerceptronEvent(String line) {
            String[] parts = line.split(" ");
            this.outcomeIndex = Integer.parseInt(parts[0]);
            // Everything after the outcome index comes in (index, value) pairs.
            int featureCount = (parts.length - 1) / 2;
            featureIndexes = new ArrayList<Integer>(featureCount);
            featureValues = new ArrayList<Double>(featureCount);
            int j = 1;
            for (int i = 0; i < featureCount; i++) {
                featureIndexes.add(Integer.parseInt(parts[j++]));
                featureValues.add(Double.parseDouble(parts[j++]));
            }
        }

        /**
         * Copies an event while re-mapping its feature indexes. Features
         * whose entry in {@code newIndexes} is negative are dropped.
         */
        public PerceptronEvent(PerceptronEvent oldEvent, int[] newIndexes) {
            featureIndexes = new ArrayList<Integer>();
            featureValues = new ArrayList<Double>();
            int i = 0;
            for (int oldIndex : oldEvent.featureIndexes) {
                if (newIndexes[oldIndex] >= 0) {
                    featureIndexes.add(newIndexes[oldIndex]);
                    featureValues.add(oldEvent.featureValues.get(i));
                }
                i++;
            }
            outcomeIndex = oldEvent.outcomeIndex;
        }

        public List<Integer> getFeatureIndexes() {
            return featureIndexes;
        }

        public List<Double> getFeatureValues() {
            return featureValues;
        }

        public int getOutcomeIndex() {
            return outcomeIndex;
        }

        /**
         * Writes this event as a single line terminated by {@code \n}.
         * <p>
         * The writer is deliberately not flushed here: flushing after every
         * event would defeat the caller's buffering. Callers flush or close
         * the writer once all events are written.
         */
        public void write(Writer writer) throws IOException {
            writer.write(String.valueOf(outcomeIndex));
            for (int i = 0; i < featureIndexes.size(); i++) {
                writer.write(" ");
                writer.write(String.valueOf(featureIndexes.get(i)));
                writer.write(" ");
                writer.write(String.valueOf(featureValues.get(i)));
            }
            writer.write("\n");
        }

    }

    /**
     * The maximum number of training iterations (default 100).
     */
    public int getIterations() {
        return iterations;
    }

    public void setIterations(int iterations) {
        this.iterations = iterations;
    }

    /**
     * The feature cutoff (default 0). It is only applied during data
     * preparation when greater than 1.
     */
    public int getCutoff() {
        return cutoff;
    }

    public void setCutoff(int cutoff) {
        this.cutoff = cutoff;
    }

    /**
     * The accuracy tolerance (default 1e-5): training stops early when the
     * accuracy differs from each of the three previous iterations by less
     * than this value.
     */
    public double getTolerance() {
        return tolerance;
    }

    public void setTolerance(double tolerance) {
        this.tolerance = tolerance;
    }

    /**
     * If true, the weights are only added to the running average for
     * iterations &lt;= 20 and then at the perfect-square iterations checked
     * during training (25, 36, 49, 64, 81, 100, etc.). If false (the
     * default), they are added on every iteration.
     *
     * @return whether averaging is restricted to those iterations
     */
    public boolean isAverageAtIntervals() {
        return averageAtIntervals;
    }

    public void setAverageAtIntervals(boolean averageAtIntervals) {
        this.averageAtIntervals = averageAtIntervals;
    }

    /**
     * Convenience overload: files the feature descriptors under
     * {@link MachineLearningModel#FEATURE_DESCRIPTOR_KEY} in a fresh map and
     * delegates to the map-based variant.
     */
    @Override
    public void trainModelsWithObserver(ClassificationEventStream corpusEventStream,
            DecisionFactory<T> decisionFactory, List<String> featureDescriptors,
            PerceptronModelTrainerObserver<T> observer, List<Integer> observationPoints) {
        final Map<String, List<String>> descriptorMap = new HashMap<String, List<String>>();
        descriptorMap.put(MachineLearningModel.FEATURE_DESCRIPTOR_KEY, featureDescriptors);

        this.trainModelsWithObserver(
                corpusEventStream,
                decisionFactory,
                descriptorMap,
                observer,
                observationPoints);
    }

    /**
     * Trains on the given corpus, handing an averaged model to
     * {@code observer} at each iteration listed in {@code observationPoints}.
     * <p>
     * The temporary event file is deleted in a {@code finally} block so it
     * is cleaned up even when data preparation or training fails.
     *
     * @param corpusEventStream the training events
     * @param decisionFactory   factory passed to the decision maker and to the models
     * @param descriptors       descriptors stored with each model
     * @param observer          receives the intermediate models
     * @param observationPoints iteration numbers at which the observer is notified
     */
    @Override
    public void trainModelsWithObserver(ClassificationEventStream corpusEventStream,
            DecisionFactory<T> decisionFactory, Map<String, List<String>> descriptors,
            PerceptronModelTrainerObserver<T> observer, List<Integer> observationPoints) {
        params = new PerceptronModelParameters();
        decisionMaker = new PerceptronDecisionMaker<T>(params, decisionFactory);
        this.decisionFactory = decisionFactory;
        this.descriptors = descriptors;
        this.observer = observer;
        this.observationPoints = observationPoints;
        this.corpusEventStream = corpusEventStream;
        try {
            this.prepareData(corpusEventStream);
            this.train();
        } finally {
            if (this.eventFile != null) {
                this.eventFile.delete();
                // Don't keep a reference to a file that no longer exists.
                this.eventFile = null;
            }
        }

    }

    /**
     * Convenience overload: files the feature descriptors under
     * {@link MachineLearningModel#FEATURE_DESCRIPTOR_KEY} in a fresh map and
     * delegates to the map-based variant.
     */
    @Override
    public ClassificationModel<T> trainModel(ClassificationEventStream corpusEventStream,
            DecisionFactory<T> decisionFactory, List<String> featureDescriptors) {
        final Map<String, List<String>> descriptorMap = new HashMap<String, List<String>>();
        descriptorMap.put(MachineLearningModel.FEATURE_DESCRIPTOR_KEY, featureDescriptors);

        ClassificationModel<T> trainedModel = this.trainModel(corpusEventStream, decisionFactory, descriptorMap);
        return trainedModel;
    }

    /**
     * Trains a model on the given corpus and returns it.
     * <p>
     * The temporary event file is deleted in a {@code finally} block so it
     * is cleaned up even when data preparation or training fails.
     *
     * @param corpusEventStream the training events
     * @param decisionFactory   factory passed to the decision maker and to the model
     * @param descriptors       descriptors stored with the model
     * @return the trained model, built from the averaged weights
     */
    @Override
    public ClassificationModel<T> trainModel(ClassificationEventStream corpusEventStream,
            DecisionFactory<T> decisionFactory, Map<String, List<String>> descriptors) {
        params = new PerceptronModelParameters();
        decisionMaker = new PerceptronDecisionMaker<T>(params, decisionFactory);
        this.decisionFactory = decisionFactory;
        this.descriptors = descriptors;
        this.corpusEventStream = corpusEventStream;
        try {
            this.prepareData(corpusEventStream);
            this.train();
            return this.getModel(params, this.getIterations());
        } finally {
            if (this.eventFile != null) {
                this.eventFile.delete();
                // Don't keep a reference to a file that no longer exists.
                this.eventFile = null;
            }
        }
    }

    /**
     * Wraps the given parameters in a classification model and records the
     * training settings as model attributes, followed by the attributes of
     * the corpus event stream.
     *
     * @param params     the parameters the model is built from
     * @param iterations the value stored in the "iterations" attribute
     */
    ClassificationModel<T> getModel(PerceptronModelParameters params, int iterations) {
        PerceptronClassificationModel<T> perceptronModel =
                new PerceptronClassificationModel<T>(params, descriptors, decisionFactory);

        // training settings, stored in their string form
        perceptronModel.addModelAttribute("cutoff", String.valueOf(this.getCutoff()));
        perceptronModel.addModelAttribute("iterations", String.valueOf(iterations));
        perceptronModel.addModelAttribute("tolerance", String.valueOf(this.getTolerance()));
        perceptronModel.addModelAttribute("averageAtIntervals", String.valueOf(this.isAverageAtIntervals()));

        // corpus attributes are added last
        perceptronModel.getModelAttributes().putAll(corpusEventStream.getAttributes());

        return perceptronModel;
    }

    /**
     * Applies training parameters given by name. Each key must be the name
     * of a {@link PerceptronModelParameter} constant and each value must be
     * an instance of that parameter's declared type. A {@code null} map is
     * ignored.
     *
     * @param parameters parameter names mapped to their values, may be null
     * @throws JolicielException if a value is null or of the wrong type
     * @throws IllegalArgumentException if a key does not name a known
     *         parameter (raised by {@code valueOf})
     */
    @Override
    public void setParameters(Map<String, Object> parameters) {
        if (parameters == null) {
            return;
        }
        for (Map.Entry<String, Object> entry : parameters.entrySet()) {
            String parameter = entry.getKey();
            Object value = entry.getValue();
            PerceptronModelParameter modelParameter = PerceptronModelParameter.valueOf(parameter);
            // A null value is reported as a type error rather than as a
            // NullPointerException from value.getClass().
            if (value == null || !modelParameter.getParameterType().isAssignableFrom(value.getClass())) {
                throw new JolicielException("Parameter of wrong type: " + parameter + ". Expected: "
                        + modelParameter.getParameterType().getSimpleName());
            }
            switch (modelParameter) {
            case Iterations:
                this.setIterations((Integer) value);
                break;
            case Cutoff:
                this.setCutoff((Integer) value);
                break;
            case Tolerance:
                this.setTolerance((Double) value);
                break;
            case AverageAtIntervals:
                this.setAverageAtIntervals((Boolean) value);
                break;
            default:
                throw new JolicielException("Unknown parameter type: " + modelParameter);
            }
        }
    }
}