asap.NLPSystem.java Source code

Java tutorial

Introduction

Here is the source code for asap.NLPSystem.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package asap;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.RemoveType;

/**
 * Full representation of a system, grouping together the classifier, its
 * learning and/or evaluation sets, FeatureCalculators used for generating the
 * sets and correlation data for comparison with other systems.
 *
 *
 * @author David Jorge Vieira Simes (a21210644@alunos.isec.pt) AKA examinus
 */
public class NLPSystem implements Serializable, Comparable<NLPSystem> {

    /** Seed used to shuffle the training data before cross-validation. */
    private static final int SEED = 0;
    /** Number of folds used by the cross-validation run. */
    private static final int NO_FOLDS = 10;

    // NOTE(review): instances of this class are written with Java serialization
    // (see saveSystem) and no serialVersionUID is declared. The instance fields
    // below are therefore deliberately left exactly as they were: changing their
    // names, types or modifiers would alter the default serial version and
    // presumably break loading of previously saved systems - verify before
    // touching them.
    private Classifier classifier;

    //private final FeatureCalculators[] featureCalculators;
    //private final Map<FeatureCalculator, List<Integer>> featuresMap;
    private Instances trainingSet;
    private Instances evaluationSet;

    private Instances trainingOriginalSet;
    private Instances evaluationOriginalSet;

    private double[] trainingPredictions;
    private double[] evaluationPredictions;

    private double trainingPearsonsCorrelation;
    private double crossValidationPearsonsCorrelation;
    private double evaluationPearsonsCorrelation;

    private boolean classifierBuilt;
    private boolean classifierBuiltWithCrossValidation;
    private boolean evaluated;
    private String filename;

    //--------------------------------------------------------------------------
    //-         public members                                                 -
    //--------------------------------------------------------------------------
    //    public NLPSystem(Classifier classifier, Instances trainingSet,
    //        Instances evaluationSet, List<FeatureCalculators> featureCalculators,
    //        Map<FeatureCalculator, List<Integer>> featuresMap) {
    /**
     * Creates a system around the given classifier and data sets.
     *
     * The sets are stored twice: as given ("original") and with all string
     * attributes removed (the sets actually used for training/evaluation).
     *
     * @param classifier the classifier to train and evaluate
     * @param trainingSet the training data; may be {@code null}
     * @param evaluationSet the evaluation data; may be {@code null}
     */
    public NLPSystem(Classifier classifier, Instances trainingSet, Instances evaluationSet) {

        this.classifier = classifier;
        this.trainingOriginalSet = trainingSet;
        this.evaluationOriginalSet = evaluationSet;

        this.trainingSet = getFilteredSet(trainingSet);
        this.evaluationSet = getFilteredSet(evaluationSet);

        //        this.featureCalculators = featureCalculators.toArray(
        //              new FeatureCalculators[featureCalculators.size()]);
        //        this.featuresMap = featuresMap;
        classifierBuilt = false;
        classifierBuiltWithCrossValidation = false;
        evaluated = false;
    }

    /**
     * @return the classifier wrapped by this system
     */
    public Classifier getClassifier() {
        return classifier;
    }

    /**
     * @return the correlation coefficient obtained by cross-validation, or
     * {@code Double.NaN} if no cross-validation has been run (or its
     * coefficient could not be computed)
     */
    public synchronized double getCrossValidationPearsonsCorrelation() {
        if (!classifierBuiltWithCrossValidation) {
            return Double.NaN;
        }
        return crossValidationPearsonsCorrelation;
    }

    /**
     * @return the correlation coefficient obtained on the evaluation set, or
     * {@code Double.NaN} if the system has not been evaluated
     */
    public synchronized double getEvaluationPearsonsCorrelation() {
        if (!evaluated) {
            return Double.NaN;
        }
        return evaluationPearsonsCorrelation;
    }

    /**
     * @return the predictions for the evaluation set (the internal array, not
     * a copy), or {@code null} if the system has not been evaluated
     */
    public synchronized double[] getEvaluationPredictions() {
        if (!evaluated) {
            return null;
        }
        return evaluationPredictions;
    }

    /**
     * @return the evaluation set with string attributes removed
     */
    public synchronized Instances getEvaluationSet() {
        return evaluationSet;
    }

    /**
     * @return the predictions for the training set (the internal array, not a
     * copy), or {@code null} if the classifier has not been built
     */
    public synchronized double[] getTrainingPredictions() {
        if (!classifierBuilt) {
            return null;
        }
        return trainingPredictions;
    }

    /**
     * @return the training set with string attributes removed
     */
    public synchronized Instances getTrainingSet() {
        return trainingSet;
    }

    /**
     * @return the training set exactly as passed to the constructor
     */
    public Instances getTrainingOriginalSet() {
        return trainingOriginalSet;
    }

    /**
     * @return the evaluation set exactly as last supplied (constructor or
     * {@link #setEvaluationSet(Instances)})
     */
    public Instances getEvaluationOriginalSet() {
        return evaluationOriginalSet;
    }

    /**
     * @return {@code true} once the classifier has been trained successfully
     */
    public synchronized boolean isClassifierBuilt() {
        return classifierBuilt;
    }

    /**
     * @return {@code true} once the evaluation set has been evaluated
     * successfully
     */
    public synchronized boolean isEvaluated() {
        return evaluated;
    }

    /**
     * Evaluates the classifier on the evaluation set without printing the
     * evaluation summary.
     */
    public synchronized void evaluate() {
        evaluateModel(false);
    }

    /**
     * Builds the classifier and runs cross-validation.
     *
     * @return the textual report, see {@link #buildClassifier(boolean)}
     */
    public String buildClassifier() {
        return buildClassifier(true);
    }

    /**
     * Trains the classifier on the training set, optionally running a
     * cross-validation as well.
     *
     * When cross-validation is requested and more than one thread is
     * configured, training runs on a separate thread while the calling thread
     * drives the cross-validation; otherwise both steps run sequentially on
     * the calling thread.
     *
     * @param runCrossValidation whether to cross-validate in addition to
     * training
     * @return the cross-validation report (if requested) followed by the
     * training message, or {@code null} if the classifier was already built in
     * the requested mode
     */
    public synchronized String buildClassifier(boolean runCrossValidation) {
        if (classifierBuilt && classifierBuiltWithCrossValidation == runCrossValidation) {
            return null;
        }
        //        checkInstancesFeatures(trainingSet);

        if (!runCrossValidation) {
            // nothing to overlap with, so there is no point in spawning a thread
            return _buildClassifier();
        }

        // Training mutates `classifier` while the cross-validation deep-copies
        // it once per fold, so the copies must come from a snapshot taken
        // before the build thread is started.
        Classifier foldTemplate = null;
        if (Config.getNumThreads() > 1) {
            try {
                foldTemplate = AbstractClassifier.makeCopy(classifier);
            } catch (Exception ex) {
                Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            }
        }

        if (foldTemplate == null) {
            // single-threaded configuration (also covers thread counts below 1)
            // or the snapshot failed: do both steps one after the other
            String sequentialReport = crossValidate(classifier, SEED, NO_FOLDS, null);
            return sequentialReport + _buildClassifier();
        }

        // The build message travels through a holder that is only read after
        // the join, instead of both threads appending to a shared StringBuilder.
        final String[] buildMessage = new String[1];
        Thread buildThread = new Thread(new Runnable() {

            @Override
            public void run() {
                buildMessage[0] = _buildClassifier();
            }

        });
        buildThread.start();
        String report = crossValidate(foldTemplate, SEED, NO_FOLDS, null);
        joinUninterruptibly(buildThread);

        if (buildMessage[0] == null) {
            // the build thread died before producing a message
            return report;
        }
        return report + buildMessage[0];
    }

    /**
     * Replaces the evaluation set and discards any previous evaluation
     * results.
     *
     * @param evaluationSet the new evaluation data; may be {@code null}
     */
    public synchronized void setEvaluationSet(Instances evaluationSet) {
        this.evaluationOriginalSet = evaluationSet;
        this.evaluationSet = getFilteredSet(evaluationSet);
        this.evaluationPredictions = null;
        evaluated = false;
    }

    /**
     * Returns a copy of the given set with all string attributes removed.
     *
     * @param set the set to filter; may be {@code null}
     * @return the filtered set, or {@code null} if {@code set} is {@code null}
     * or filtering failed (the failure is logged)
     */
    private Instances getFilteredSet(Instances set) {
        //TODO: filter all unwanted features
        if (set == null) {
            return null;
        }

        RemoveType removeTypeFilter = new RemoveType();
        String[] removeTypeFilterOptions = { "-T", "string" };
        try {
            // configure the filter first, then hand it the input format
            removeTypeFilter.setOptions(removeTypeFilterOptions);
            removeTypeFilter.setInputFormat(set);
            return Filter.useFilter(set, removeTypeFilter);
        } catch (Exception ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            return null;
        }
    }

    /**
     * Picks the most relevant correlation available: evaluation, then
     * cross-validation, then training; -1 if nothing has been computed yet.
     */
    private double getComparableCorrelation() {
        if (evaluated) {
            return evaluationPearsonsCorrelation;
        }

        if (classifierBuiltWithCrossValidation) {
            return crossValidationPearsonsCorrelation;
        }

        if (classifierBuilt) {
            return trainingPearsonsCorrelation;
        }

        return -1;
    }

    /**
     * Maps a correlation onto a value that is safe to order by: NaN becomes
     * negative infinity and -0.0 is folded into 0.0 so the two zeros tie.
     */
    private static double rankingScore(double correlation) {
        if (Double.isNaN(correlation)) {
            return Double.NEGATIVE_INFINITY;
        }
        return correlation == 0.0 ? 0.0 : correlation;
    }

    /**
     * Orders systems by descending correlation (best system first). Systems
     * whose correlation is NaN are ranked last; a plain subtraction would make
     * them compare equal to everything and break the ordering contract.
     *
     * @param o the system to compare with
     * @return a negative value if this system ranks before {@code o}, a
     * positive value if it ranks after, 0 if they tie
     */
    @Override
    public int compareTo(NLPSystem o) {
        return Double.compare(rankingScore(o.getComparableCorrelation()),
                rankingScore(getComparableCorrelation()));
    }

    /**
     * Serializes this system to a file. I/O failures are logged, not thrown.
     *
     * @param dir the directory to write into
     * @param systemFilename the name of the file to create inside {@code dir}
     */
    public void saveSystem(File dir, String systemFilename) {
        filename = systemFilename;
        File systemFile = new File(dir, systemFilename);
        // try-with-resources flushes and closes the stream
        try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(systemFile))) {
            oos.writeObject(this);
        } catch (IOException ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    /**
     * @return the simple class name of the classifier followed by this
     * object's hash code
     */
    public String shortName() {
        return getClassifier().getClass().getSimpleName() + hashCode();
    }

    /**
     * @return a one-line description with the classifier class name and its
     * options
     */
    @Override
    public String toString() {
        return String.format("Classifier: %s %s\n", classifier.getClass().getName(), optionsOf(classifier));
    }

    /**
     * Returns the joined option string of a classifier, or an empty string for
     * classifiers that are not {@link AbstractClassifier}s.
     */
    private static String optionsOf(Classifier cls) {
        if (cls instanceof AbstractClassifier) {
            return Utils.joinOptions(((AbstractClassifier) cls).getOptions());
        }
        return "";
    }

    /**
     * Waits for the given thread to finish even if the waiting thread gets
     * interrupted, then restores the interrupt flag so callers can still
     * observe the interruption.
     */
    private static void joinUninterruptibly(Thread thread) {
        boolean interrupted = false;
        while (thread.isAlive()) {
            try {
                thread.join();
            } catch (InterruptedException ex) {
                interrupted = true;
                Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Trains the classifier on the training set and records its predictions
     * and correlation coefficient on that same set.
     *
     * @return a human-readable status message; an error message if the
     * evaluation object could not be created or training failed (in which case
     * the classifier is not marked as built)
     */
    private String _buildClassifier() {
        Evaluation eval;
        try {
            eval = new Evaluation(trainingSet);
        } catch (Exception ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            return "Error creating evaluation instance for given data!";
        }

        // a failed (re)build must not leave the system flagged as built
        classifierBuilt = false;
        try {
            classifier.buildClassifier(trainingSet);
        } catch (Exception ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            return "Error building classifier!";
        }

        try {
            trainingPredictions = eval.evaluateModel(classifier, trainingSet);
            trainingPearsonsCorrelation = eval.correlationCoefficient();
        } catch (Exception ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            // do not expose results of an earlier build as if they were current
            trainingPredictions = null;
            trainingPearsonsCorrelation = Double.NaN;
        }

        classifierBuilt = true;
        return "Classifier built (" + trainingPearsonsCorrelation + ").";
    }

    /**
     * Runs a cross-validation over a shuffled copy of the training set and
     * stores the resulting correlation coefficient.
     *
     * Every fold trains its own copy of {@code foldTemplate}. With more than
     * one configured thread the folds are processed by worker threads (one
     * fewer than the configured number, at most one per fold); otherwise they
     * are processed on the calling thread.
     *
     * @param foldTemplate the classifier each fold's classifier is copied from
     * @param seed the seed used to shuffle the data
     * @param folds the number of folds
     * @param modelOutputFile file to write the classifier to; nothing is
     * written when {@code null} or empty
     * @return the cross-validation report, or an error message if the
     * evaluation object could not be created
     */
    private String crossValidate(Classifier foldTemplate, int seed, int folds, String modelOutputFile) {

        PerformanceCounters.startTimer("cross-validation");
        PerformanceCounters.startTimer("cross-validation init");

        // randomize data
        Random rand = new Random(seed);
        Instances randData = new Instances(trainingSet);
        randData.randomize(rand);
        if (randData.classAttribute().isNominal()) {
            randData.stratify(folds);
        }

        // perform cross-validation and add predictions
        Evaluation eval;
        try {
            eval = new Evaluation(randData);
        } catch (Exception ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            // do not leave the timers started above running
            PerformanceCounters.stopTimer("cross-validation init");
            PerformanceCounters.stopTimer("cross-validation");
            return "Error creating evaluation instance for given data!";
        }

        // work queue shared by the workers; each fold is taken from it once
        List<FoldSet> foldSets = Collections.synchronizedList(new LinkedList<FoldSet>());
        for (int n = 0; n < folds; n++) {
            try {
                foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                        AbstractClassifier.makeCopy(foldTemplate)));
            } catch (Exception ex) {
                Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            }
        }

        // only this thread touches the worker list, so it needs no synchronization
        List<Thread> foldThreads = new LinkedList<Thread>();
        int workerCount = Math.min(folds, Config.getNumThreads() - 1);
        for (int n = 0; n < workerCount; n++) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            // lower the priority of the workers only, never of the calling thread
            foldThread.setPriority(Thread.MIN_PRIORITY);
            foldThreads.add(foldThread);
        }

        PerformanceCounters.stopTimer("cross-validation init");
        PerformanceCounters.startTimer("cross-validation folds+train");

        if (foldThreads.isEmpty()) {
            // no worker threads available: process every fold right here
            new CrossValidationFoldThread(0, foldSets, eval).run();
        } else {
            for (Thread foldThread : foldThreads) {
                foldThread.start();
            }
            for (Thread foldThread : foldThreads) {
                joinUninterruptibly(foldThread);
            }
        }

        PerformanceCounters.stopTimer("cross-validation folds+train");
        PerformanceCounters.startTimer("cross-validation post");
        // evaluation for output:
        String out = String.format(
                "\n=== Setup ===\nClassifier: %s %s\n" + "Dataset: %s\nFolds: %s\nSeed: %s\n\n%s\n",
                foldTemplate.getClass().getName(), optionsOf(foldTemplate),
                trainingSet.relationName(), folds, seed,
                eval.toSummaryString(String.format("=== %s-fold Cross-validation ===", folds), false));

        try {
            crossValidationPearsonsCorrelation = eval.correlationCoefficient();
        } catch (Exception ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            // no usable coefficient: report "not available" rather than a stale value
            crossValidationPearsonsCorrelation = Double.NaN;
        }

        if (modelOutputFile != null && !modelOutputFile.isEmpty()) {
            // NOTE(review): this writes the system's own classifier, which may
            // be training on another thread at this point; the callers in this
            // class always pass null, so the path is currently unused - confirm
            // before relying on it.
            try {
                SerializationHelper.write(modelOutputFile, classifier);
            } catch (Exception ex) {
                Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            }
        }

        classifierBuiltWithCrossValidation = true;
        PerformanceCounters.stopTimer("cross-validation post");
        PerformanceCounters.stopTimer("cross-validation");
        return out;
    }

    /**
     * Evaluates the classifier on the evaluation set and stores the
     * predictions and the correlation coefficient. Failures are logged and
     * leave the "evaluated" state unchanged.
     *
     * @param printEvaluation whether to print the evaluation summary to
     * standard output
     */
    private void evaluateModel(boolean printEvaluation) {
        //        checkInstancesFeatures(evaluationSet);
        PerformanceCounters.startTimer("evaluateModel");
        System.out.println("Evaluating model...");
        try {
            // evaluate classifier and print some statistics
            Evaluation eval = new Evaluation(evaluationSet);

            evaluationPredictions = eval.evaluateModel(classifier, evaluationSet);

            if (printEvaluation) {
                System.out.println("\tstats for model:" + classifier.getClass().getName() + " "
                        + optionsOf(classifier));
                System.out.println(eval.toSummaryString());
            }

            evaluationPearsonsCorrelation = eval.correlationCoefficient();
            evaluated = true;
        } catch (Exception ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        }

        System.out.println("\tevaluation done.");
        PerformanceCounters.stopTimer("evaluateModel");
    }

    //--------------------------------------------------------------------------
    //-         Inner classes                                                  -
    //--------------------------------------------------------------------------
    /**
     * Worker that repeatedly takes a fold from the shared list, trains the
     * fold's classifier and adds its test results to the shared evaluation.
     */
    private static class CrossValidationFoldThread implements Runnable {

        final Evaluation eval;
        final List<FoldSet> foldSets;
        final int threadNumber;

        public CrossValidationFoldThread(int threadNumber, List<FoldSet> foldSets, Evaluation eval) {
            this.threadNumber = threadNumber;
            this.foldSets = foldSets;
            this.eval = eval;
        }

        @Override
        public void run() {
            while (true) {
                FoldSet foldSet;
                // The emptiness check and the removal must be one atomic step,
                // otherwise another worker can take the last fold in between.
                synchronized (foldSets) {
                    if (foldSets.isEmpty()) {
                        return;
                    }
                    foldSet = foldSets.remove(0);
                }

                Classifier cls = foldSet.getClassifier();
                try {
                    cls.buildClassifier(foldSet.getTrainSet());
                } catch (Exception ex) {
                    Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
                    // an untrained classifier cannot be evaluated; skip this fold
                    continue;
                }

                try {
                    // the evaluation object is shared by all workers
                    synchronized (eval) {
                        eval.evaluateModel(cls, foldSet.getTestSet());
                    }
                } catch (Exception ex) {
                    Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
                }
            }
        }

    }

    /**
     * Everything needed to process one cross-validation fold: the training
     * split, the test split and the classifier to train on it.
     */
    private static class FoldSet {

        private final Instances trainSet, testSet;
        private final Classifier cls;

        public FoldSet(Instances trainSet, Instances testSet, Classifier cls) {
            this.cls = cls;
            this.trainSet = trainSet;
            this.testSet = testSet;
        }

        public Instances getTrainSet() {
            return trainSet;
        }

        public Instances getTestSet() {
            return testSet;
        }

        public Classifier getClassifier() {
            return cls;
        }

    }

}