dkpro.similarity.experiments.rte.util.Evaluator.java Source code

Java tutorial

Introduction

Here is the source code for dkpro.similarity.experiments.rte.util.Evaluator.java

Source

/*******************************************************************************
 * Copyright 2013
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the GNU Public License v3.0
 * which accompanies this distribution, and is available at
 * http://www.gnu.org/licenses/gpl-3.0.txt
 ******************************************************************************/
package dkpro.similarity.experiments.rte.util;

import static dkpro.similarity.experiments.rte.Pipeline.GOLD_DIR;
import static dkpro.similarity.experiments.rte.Pipeline.MODELS_DIR;
import static dkpro.similarity.experiments.rte.Pipeline.OUTPUT_DIR;
import static dkpro.similarity.experiments.rte.Pipeline.EvaluationMetric.Accuracy;
import static dkpro.similarity.experiments.rte.Pipeline.EvaluationMetric.AveragePrecision;
import static dkpro.similarity.experiments.rte.Pipeline.EvaluationMetric.CWS;

import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Random;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.filefilter.FileFilterUtils;
import org.springframework.util.CollectionUtils;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.output.prediction.AbstractOutput;
import weka.classifiers.evaluation.output.prediction.PlainText;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSink;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.supervised.attribute.AddClassification;
import weka.filters.unsupervised.attribute.AddID;
import weka.filters.unsupervised.attribute.Remove;
import dkpro.similarity.algorithms.ml.ClassifierSimilarityMeasure;
import dkpro.similarity.algorithms.ml.ClassifierSimilarityMeasure.WekaClassifier;
import dkpro.similarity.experiments.rte.Pipeline.Dataset;
import dkpro.similarity.experiments.rte.Pipeline.EvaluationMetric;
//import de.tudarmstadt.ukp.similarity.experiments.rte.Pipeline.EvaluationMetric;
//import de.tudarmstadt.ukp.similarity.experiments.rte.Pipeline.Mode;
//import de.tudarmstadt.ukp.similarity.experiments.rte.filter.LogFilter;

/**
 * Runs Weka classifiers on the ARFF files in {@code MODELS_DIR} and scores their output.
 *
 * <p>{@link #runClassifier} (train/test) and {@link #runClassifierCV} (10-fold cross-validation) write
 * per-instance predictions below {@code OUTPUT_DIR/<dataset>/<classifier>/}; {@link #runEvaluationMetric}
 * compares those predictions with the gold labels in {@code GOLD_DIR}.
 */
public class Evaluator {
    public static final String LF = System.getProperty("line.separator");

    /**
     * Gold and predicted labels are compared on at most this many leading characters: labels parsed from
     * Weka's plain-text prediction listing are at most 8 characters long (e.g. "ENTAILME", "CONTRADI"),
     * while gold files may contain the full label (e.g. "ENTAILMENT").
     */
    private static final int MAX_LABEL_LENGTH = 8;

    /**
     * Trains {@code wekaClassifier} on {@code trainDataset}, classifies {@code testDataset} and writes to
     * {@code OUTPUT_DIR/<test>/<classifier>/<test>}: the predicted labels ({@code .csv}), their probabilities
     * ({@code .probabilities.csv}), Weka's raw prediction listing ({@code .predictions.txt}) and the trained
     * model plus evaluation summary ({@code .meta.txt}).
     *
     * <p>An ID attribute is prepended to both ARFF files (written as {@code <dataset>-plusIDs.arff} in
     * {@code MODELS_DIR}) so predictions can be written back in the original instance order; the ID is
     * removed again before the classifier sees the data.
     *
     * @throws Exception if reading, filtering, training, evaluation or writing fails
     */
    public static void runClassifier(WekaClassifier wekaClassifier, Dataset trainDataset, Dataset testDataset)
            throws Exception {
        Classifier baseClassifier = ClassifierSimilarityMeasure.getClassifier(wekaClassifier);

        // Time-based seed, so the test-set order differs between runs
        Random random = new Random(System.currentTimeMillis());

        Instances train = loadWithIds(trainDataset);
        Instances test = loadWithIds(testDataset);

        // Shuffle the test set; label and probability files are still written back in ID order
        test.randomize(random);

        Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);
        FilteredClassifier filteredClassifier = createIdRemovingClassifier(classifier);
        filteredClassifier.buildClassifier(train);

        // Collect Weka's per-instance prediction listing, including the ID attribute
        AbstractOutput output = new PlainText();
        output.setBuffer(new StringBuffer());
        output.setHeader(test);
        output.setAttributes("first");

        Evaluation eval = new Evaluation(train);
        eval.evaluateModel(filteredClassifier, test, output);
        String predictionListing = output.getBuffer().toString();

        // Size by the test set itself: IDs cover every test instance, whereas eval.numInstances() is a
        // weighted count of instances with a known class value.
        String[] scores = new String[test.numInstances()];
        double[] probabilities = new double[test.numInstances()];
        parsePredictions(predictionListing, scores, probabilities);

        System.out.println(eval.toSummaryString());
        System.out.println(eval.toMatrixString());

        // Output classifications
        FileUtils.writeStringToFile(new File(resultPath(testDataset, wekaClassifier, ".csv")), joinLines(scores));

        // Output probabilities
        StringBuilder probabilityLines = new StringBuilder();
        for (double probability : probabilities) {
            probabilityLines.append(probability).append(LF);
        }
        FileUtils.writeStringToFile(new File(resultPath(testDataset, wekaClassifier, ".probabilities.csv")),
                probabilityLines.toString());

        // Output raw predictions
        FileUtils.writeStringToFile(new File(resultPath(testDataset, wekaClassifier, ".predictions.txt")),
                predictionListing);

        // Output meta information
        writeMeta(new File(resultPath(testDataset, wekaClassifier, ".meta.txt")), classifier, eval);
    }

    /**
     * Runs 10-fold cross-validation of {@code wekaClassifier} on {@code dataset} and writes to
     * {@code OUTPUT_DIR/<dataset>/<classifier>/<dataset>}: the predicted labels in original instance order
     * ({@code .csv}), the data with added classification and error-flag attributes ({@code .predicted.arff}),
     * and the untrained classifier configuration plus evaluation summary ({@code .meta.txt}).
     *
     * @throws Exception if reading, filtering, training, evaluation or writing fails
     */
    public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset) throws Exception {
        int folds = 10;
        Classifier baseClassifier = ClassifierSimilarityMeasure.getClassifier(wekaClassifier);

        // Time-based seed, so the fold assignment differs between runs
        Random random = new Random(System.currentTimeMillis());

        Instances data = loadWithIds(dataset);
        data.randomize(random);

        Instances predictedData = null;
        Evaluation eval = new Evaluation(data);

        for (int n = 0; n < folds; n++) {
            Instances train = data.trainCV(folds, n, random);
            Instances test = data.testCV(folds, n);

            FilteredClassifier filteredClassifier =
                    createIdRemovingClassifier(AbstractClassifier.makeCopy(baseClassifier));
            filteredClassifier.buildClassifier(train);
            eval.evaluateModel(filteredClassifier, test);

            // AddClassification trains the classifier it is given on the first batch it filters (train).
            // Hand it the ID-removing wrapper, not the bare classifier, so the ID attribute is not used as a
            // feature by the model whose predictions are written out.
            AddClassification addClassification = new AddClassification();
            addClassification.setClassifier(filteredClassifier);
            addClassification.setOutputClassification(true);
            addClassification.setOutputDistribution(false);
            addClassification.setOutputErrorFlag(true);
            addClassification.setInputFormat(train);
            Filter.useFilter(train, addClassification); // trains the classifier

            Instances pred = Filter.useFilter(test, addClassification); // predicts the test fold
            if (predictedData == null) {
                predictedData = new Instances(pred, 0);
            }
            for (int j = 0; j < pred.numInstances(); j++) {
                predictedData.add(pred.instance(j));
            }
        }

        System.out.println(eval.toSummaryString());
        System.out.println(eval.toMatrixString());

        // Put each prediction back at its original position via the 1-based ID in the first attribute;
        // the predicted label is the attribute just before the trailing error flag.
        String[] scores = new String[predictedData.numInstances()];
        int labelIndex = predictedData.numAttributes() - 2;
        for (Instance predInst : predictedData) {
            int id = (int) predInst.value(0);
            scores[id - 1] = predInst.stringValue(labelIndex);
        }

        // Output classifications
        FileUtils.writeStringToFile(new File(resultPath(dataset, wekaClassifier, ".csv")), joinLines(scores));

        // Output prediction arff
        DataSink.write(resultPath(dataset, wekaClassifier, ".predicted.arff"), predictedData);

        // Output meta information
        writeMeta(new File(resultPath(dataset, wekaClassifier, ".meta.txt")), baseClassifier, eval);
    }

    /**
     * Evaluates {@code metric} for every classifier that has a result directory below
     * {@code OUTPUT_DIR/<dataset>/}. Hidden directories (such as {@code .svn}) are skipped; every other
     * directory name must be a {@link WekaClassifier} constant. Does nothing if the directory is missing.
     */
    public static void runEvaluationMetric(EvaluationMetric metric, Dataset dataset) throws IOException {
        File outputDir = new File(OUTPUT_DIR + "/" + dataset.toString() + "/");
        File[] dirs = outputDir.listFiles((FileFilter) FileFilterUtils.directoryFileFilter());
        if (dirs == null) {
            // Output directory missing or unreadable: nothing to evaluate
            return;
        }

        // Skip hidden dirs while iterating; the listFiles array must not be shrunk in place
        for (File dir : dirs) {
            if (dir.getName().startsWith(".")) {
                continue;
            }
            runEvaluationMetric(WekaClassifier.valueOf(dir.getName()), metric, dataset);
        }
    }

    /**
     * Computes {@code metric} for the predictions of {@code wekaClassifier} on {@code dataset}, writes the
     * value to {@code OUTPUT_DIR/<dataset>/<classifier>/<dataset>_<metric>.txt} and prints it.
     *
     * <p>Labels are compared on their first {@value #MAX_LABEL_LENGTH} characters. CWS and average precision
     * rank the decisions by the probabilities written by {@link #runClassifier}, most confident first. For
     * datasets with 3-way labels, average precision maps CONTRADICTION and UNKNOWN (and NO/FALSE) to FALSE.
     *
     * @throws IOException if a gold or result file cannot be read or the metric file cannot be written
     */
    public static void runEvaluationMetric(WekaClassifier wekaClassifier, EvaluationMetric metric, Dataset dataset)
            throws IOException {
        StringBuilder sb = new StringBuilder();

        if (metric == Accuracy) {
            List<String> goldScores = readTruncatedLabels(goldFile(dataset));
            List<String> expScores = readTruncatedLabels(new File(resultPath(dataset, wekaClassifier, ".csv")));

            int correct = 0;
            for (int i = 0; i < goldScores.size(); i++) {
                if (goldScores.get(i).equals(expScores.get(i))) {
                    correct++;
                }
            }
            sb.append((double) correct / goldScores.size());
        }
        if (metric == CWS) {
            List<String> goldScores = readTruncatedLabels(goldFile(dataset));
            List<String> expScores = readTruncatedLabels(new File(resultPath(dataset, wekaClassifier, ".csv")));
            List<String> probabilities = FileUtils.readLines(
                    new File(resultPath(dataset, wekaClassifier, ".probabilities.csv")));

            List<CwsData> ranking = rankByConfidence(goldScores, expScores, probabilities);

            // CWS: mean over ranks i of the fraction of correct decisions among the i most confident ones
            double cwsScore = 0.0;
            int correctSoFar = 0;
            for (int i = 0; i < ranking.size(); i++) {
                if (ranking.get(i).isCorrect()) {
                    correctSoFar++;
                }
                cwsScore += (double) correctSoFar / (i + 1);
            }
            cwsScore /= ranking.size();

            sb.append(cwsScore);
        }
        if (metric == AveragePrecision) {
            List<String> goldScores = readTruncatedLabels(goldFile(dataset));
            List<String> expScores = readTruncatedLabels(new File(resultPath(dataset, wekaClassifier, ".csv")));
            List<String> probabilities = FileUtils.readLines(
                    new File(resultPath(dataset, wekaClassifier, ".probabilities.csv")));

            // Conflate UNKNOWN + CONTRADICTION classes for 3-way classifications
            if (RteUtil.hasThreeWayClassification(dataset)) {
                conflateNegativeLabels(goldScores);
                conflateNegativeLabels(expScores);
            }

            List<CwsData> ranking = rankByConfidence(goldScores, expScores, probabilities);

            // Sum the fraction of correct decisions up to each positive (gold) pair, averaged over positives
            double avgPrec = 0.0;
            int numPositive = 0;
            int correctSoFar = 0;
            for (int i = 0; i < ranking.size(); i++) {
                CwsData pair = ranking.get(i);
                if (pair.isCorrect()) {
                    correctSoFar++;
                }
                if (pair.isPositivePair()) {
                    numPositive++;
                    avgPrec += (double) correctSoFar / (i + 1);
                }
            }
            avgPrec /= numPositive;

            sb.append(avgPrec);
        }

        FileUtils.writeStringToFile(new File(resultPath(dataset, wekaClassifier, "_" + metric.toString() + ".txt")),
                sb.toString());

        System.out.println("[" + wekaClassifier.toString() + "] " + metric.toString() + ": " + sb.toString());
    }

    /** Prepends an ID attribute to {@code MODELS_DIR/<dataset>.arff} and loads the result (class = last attribute). */
    private static Instances loadWithIds(Dataset dataset) throws Exception {
        String plainPath = MODELS_DIR + "/" + dataset.toString() + ".arff";
        String idPath = MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff";
        AddID.main(new String[] { "-i", plainPath, "-o", idPath });

        Instances instances = DataSource.read(idPath);
        instances.setClassIndex(instances.numAttributes() - 1);
        return instances;
    }

    /** Wraps {@code classifier} so the first (ID) attribute is removed before training and prediction. */
    private static FilteredClassifier createIdRemovingClassifier(Classifier classifier) {
        Remove removeIdFilter = new Remove();
        removeIdFilter.setAttributeIndices("first");

        FilteredClassifier filteredClassifier = new FilteredClassifier();
        filteredClassifier.setFilter(removeIdFilter);
        filteredClassifier.setClassifier(classifier);
        return filteredClassifier;
    }

    /**
     * Parses Weka's {@link PlainText} prediction listing into labels and probabilities indexed by ID - 1.
     * Prediction rows look like {@code inst# actual predicted [+] probability (ID)}, where the optional
     * {@code +} flags a misclassification; since it shifts the later columns, those are indexed from the end.
     * Rows that do not start with an instance number (blank lines, a header row) are skipped.
     */
    private static void parsePredictions(String listing, String[] scores, double[] probabilities) {
        for (String line : listing.split("\n")) {
            String[] tokens = line.trim().split("\\s+");
            if (tokens.length < 5 || !isDigits(tokens[0])) {
                continue;
            }

            String idToken = tokens[tokens.length - 1];
            int id = Integer.parseInt(idToken.substring(1, idToken.length() - 1));

            // The predicted column is "<classIndex>:<label>"
            String predicted = tokens[2];
            scores[id - 1] = predicted.substring(predicted.indexOf(':') + 1);
            probabilities[id - 1] = Double.parseDouble(tokens[tokens.length - 2]);
        }
    }

    /** Returns true if {@code s} is non-empty and consists only of digits. */
    private static boolean isDigits(String s) {
        if (s.isEmpty()) {
            return false;
        }
        for (int i = 0; i < s.length(); i++) {
            if (!Character.isDigit(s.charAt(i))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Joins {@code values} with {@link #LF} after each entry.
     *
     * @throws IllegalStateException if an entry is missing, i.e. no prediction was found for that ID
     */
    private static String joinLines(String[] values) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < values.length; i++) {
            if (values[i] == null) {
                throw new IllegalStateException("No prediction found for instance with ID " + (i + 1));
            }
            sb.append(values[i]).append(LF);
        }
        return sb.toString();
    }

    /** Writes the classifier description and the evaluation summary and confusion matrix to {@code file}. */
    private static void writeMeta(File file, Classifier classifier, Evaluation eval) throws Exception {
        StringBuilder sb = new StringBuilder();
        sb.append(classifier.toString()).append(LF);
        sb.append(eval.toSummaryString()).append(LF);
        sb.append(eval.toMatrixString()).append(LF);
        FileUtils.writeStringToFile(file, sb.toString());
    }

    /** Path of a result file: {@code OUTPUT_DIR/<dataset>/<classifier>/<dataset><suffix>}. */
    private static String resultPath(Dataset dataset, WekaClassifier wekaClassifier, String suffix) {
        return OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString()
                + suffix;
    }

    /** Gold label file: {@code GOLD_DIR/<dataset>.txt}. */
    private static File goldFile(Dataset dataset) {
        return new File(GOLD_DIR + "/" + dataset.toString() + ".txt");
    }

    /** Reads one label per line into a new mutable list, truncating each to {@link #MAX_LABEL_LENGTH} chars. */
    private static List<String> readTruncatedLabels(File file) throws IOException {
        List<String> lines = FileUtils.readLines(file);
        List<String> labels = new ArrayList<String>(lines.size());
        for (String line : lines) {
            labels.add(line.length() > MAX_LABEL_LENGTH ? line.substring(0, MAX_LABEL_LENGTH) : line);
        }
        return labels;
    }

    /** Maps the (truncated) non-entailment labels CONTRADICTION, UNKNOWN, NO and FALSE to FALSE, in place. */
    private static void conflateNegativeLabels(List<String> labels) {
        for (int i = 0; i < labels.size(); i++) {
            String label = labels.get(i);
            if (label.equals("CONTRADI") || label.equals("UNKNOWN") || label.equals("NO") || label.equals("FALSE")) {
                labels.set(i, "FALSE");
            }
        }
    }

    /** Combines the parallel lists into decisions sorted by descending confidence (stable for ties). */
    private static List<CwsData> rankByConfidence(List<String> goldScores, List<String> expScores,
            List<String> probabilities) {
        List<CwsData> ranking = new ArrayList<CwsData>(goldScores.size());
        for (int i = 0; i < goldScores.size(); i++) {
            ranking.add(new CwsData(Double.parseDouble(probabilities.get(i)), goldScores.get(i), expScores.get(i)));
        }
        Collections.sort(ranking, Collections.<CwsData>reverseOrder());
        return ranking;
    }

    /** One classification decision: its confidence together with the gold and predicted labels. */
    private static class CwsData implements Comparable<CwsData> {
        private final double confidence;
        private final String goldScore;
        private final String expScore;

        public CwsData(double confidence, String goldScore, String expScore) {
            this.confidence = confidence;
            this.goldScore = goldScore;
            this.expScore = expScore;
        }

        /** True if the predicted label equals the gold label. */
        public boolean isCorrect() {
            return goldScore.equals(expScore);
        }

        /** Orders by confidence; Double.compare gives a consistent total order (also for NaN). */
        @Override
        public int compareTo(CwsData other) {
            return Double.compare(this.confidence, other.confidence);
        }

        /** True if the gold label denotes entailment (in any of the label spellings used by the datasets). */
        public boolean isPositivePair() {
            return goldScore.equals("TRUE") || goldScore.equals("YES")
                    || goldScore.equals("ENTAILMENT") || goldScore.equals("ENTAILME");
        }

        public double getConfidence() {
            return confidence;
        }
    }
}