Java tutorial
/******************************************************************************* * Copyright 2013 * Ubiquitous Knowledge Processing (UKP) Lab * Technische Universitt Darmstadt * * All rights reserved. This program and the accompanying materials * are made available under the terms of the GNU Public License v3.0 * which accompanies this distribution, and is available at * http://www.gnu.org/licenses/gpl-3.0.txt ******************************************************************************/ package dkpro.similarity.experiments.rte.util; import static dkpro.similarity.experiments.rte.Pipeline.GOLD_DIR; import static dkpro.similarity.experiments.rte.Pipeline.MODELS_DIR; import static dkpro.similarity.experiments.rte.Pipeline.OUTPUT_DIR; import static dkpro.similarity.experiments.rte.Pipeline.EvaluationMetric.Accuracy; import static dkpro.similarity.experiments.rte.Pipeline.EvaluationMetric.AveragePrecision; import static dkpro.similarity.experiments.rte.Pipeline.EvaluationMetric.CWS; import java.io.File; import java.io.FileFilter; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.Date; import java.util.List; import java.util.Random; import org.apache.commons.io.FileUtils; import org.apache.commons.io.filefilter.FileFilterUtils; import org.springframework.util.CollectionUtils; import weka.classifiers.AbstractClassifier; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import weka.classifiers.evaluation.output.prediction.AbstractOutput; import weka.classifiers.evaluation.output.prediction.PlainText; import weka.classifiers.meta.FilteredClassifier; import weka.core.Instance; import weka.core.Instances; import weka.core.converters.ConverterUtils.DataSink; import weka.core.converters.ConverterUtils.DataSource; import weka.filters.Filter; import weka.filters.supervised.attribute.AddClassification; import weka.filters.unsupervised.attribute.AddID; import weka.filters.unsupervised.attribute.Remove; import dkpro.similarity.algorithms.ml.ClassifierSimilarityMeasure; import dkpro.similarity.algorithms.ml.ClassifierSimilarityMeasure.WekaClassifier; import dkpro.similarity.experiments.rte.Pipeline.Dataset; import dkpro.similarity.experiments.rte.Pipeline.EvaluationMetric; //import de.tudarmstadt.ukp.similarity.experiments.rte.Pipeline.EvaluationMetric; //import de.tudarmstadt.ukp.similarity.experiments.rte.Pipeline.Mode; //import de.tudarmstadt.ukp.similarity.experiments.rte.filter.LogFilter; public class Evaluator { public static final String LF = System.getProperty("line.separator"); // public static void runClassifier(Dataset train, Dataset test) // throws UIMAException, IOException // { // CollectionReader reader = createCollectionReader( // RTECorpusReader.class, // RTECorpusReader.PARAM_INPUT_FILE, RteUtil.getInputFilePathForDataset(DATASET_DIR, test), // RTECorpusReader.PARAM_COMBINATION_STRATEGY, CombinationStrategy.SAME_ROW_ONLY.toString()); // // AnalysisEngineDescription seg = createPrimitiveDescription( // BreakIteratorSegmenter.class); // // AggregateBuilder builder = new AggregateBuilder(); // builder.add(seg, CombinationReader.INITIAL_VIEW, CombinationReader.VIEW_1); // builder.add(seg, CombinationReader.INITIAL_VIEW, CombinationReader.VIEW_2); // AnalysisEngine aggr_seg = builder.createAggregate(); // // AnalysisEngine scorer = createPrimitive( // SimilarityScorer.class, // SimilarityScorer.PARAM_NAME_VIEW_1, CombinationReader.VIEW_1, // SimilarityScorer.PARAM_NAME_VIEW_2, CombinationReader.VIEW_2, // SimilarityScorer.PARAM_SEGMENT_FEATURE_PATH, Document.class.getName(), // SimilarityScorer.PARAM_TEXT_SIMILARITY_RESOURCE, createExternalResourceDescription( // ClassifierResource.class, // ClassifierResource.PARAM_CLASSIFIER, wekaClassifier.toString(), // ClassifierResource.PARAM_TRAIN_ARFF, MODELS_DIR + "/" + train.toString() + ".arff", // ClassifierResource.PARAM_TEST_ARFF, MODELS_DIR + "/" + test.toString() + ".arff") // ); // // AnalysisEngine writer = createPrimitive( // SimilarityScoreWriter.class, // SimilarityScoreWriter.PARAM_OUTPUT_FILE, OUTPUT_DIR + "/" + test.toString() + ".csv", // SimilarityScoreWriter.PARAM_OUTPUT_SCORES_ONLY, true, // SimilarityScoreWriter.PARAM_OUTPUT_GOLD_SCORES, false); // // SimplePipeline.runPipeline(reader, aggr_seg, scorer, writer); // } public static void runClassifier(WekaClassifier wekaClassifier, Dataset trainDataset, Dataset testDataset) throws Exception { Classifier baseClassifier = ClassifierSimilarityMeasure.getClassifier(wekaClassifier); // Set up the random number generator long seed = new Date().getTime(); Random random = new Random(seed); // Add IDs to the train instances and get the instances AddID.main(new String[] { "-i", MODELS_DIR + "/" + trainDataset.toString() + ".arff", "-o", MODELS_DIR + "/" + trainDataset.toString() + "-plusIDs.arff" }); Instances train = DataSource.read(MODELS_DIR + "/" + trainDataset.toString() + "-plusIDs.arff"); train.setClassIndex(train.numAttributes() - 1); // Add IDs to the test instances and get the instances AddID.main(new String[] { "-i", MODELS_DIR + "/" + testDataset.toString() + ".arff", "-o", MODELS_DIR + "/" + testDataset.toString() + "-plusIDs.arff" }); Instances test = DataSource.read(MODELS_DIR + "/" + testDataset.toString() + "-plusIDs.arff"); test.setClassIndex(test.numAttributes() - 1); // Instantiate the Remove filter Remove removeIDFilter = new Remove(); removeIDFilter.setAttributeIndices("first"); // Randomize the data test.randomize(random); // Apply log filter // Filter logFilter = new LogFilter(); // logFilter.setInputFormat(train); // train = Filter.useFilter(train, logFilter); // logFilter.setInputFormat(test); // test = Filter.useFilter(test, logFilter); // Copy the classifier Classifier classifier = AbstractClassifier.makeCopy(baseClassifier); // Instantiate the FilteredClassifier FilteredClassifier filteredClassifier = new FilteredClassifier(); filteredClassifier.setFilter(removeIDFilter); filteredClassifier.setClassifier(classifier); // Build the classifier filteredClassifier.buildClassifier(train); // Prepare the output buffer AbstractOutput output = new PlainText(); output.setBuffer(new StringBuffer()); output.setHeader(test); output.setAttributes("first"); Evaluation eval = new Evaluation(train); eval.evaluateModel(filteredClassifier, test, output); // Convert predictions to CSV // Format: inst#, actual, predicted, error, probability, (ID) String[] scores = new String[new Double(eval.numInstances()).intValue()]; double[] probabilities = new double[new Double(eval.numInstances()).intValue()]; for (String line : output.getBuffer().toString().split("\n")) { String[] linesplit = line.split("\\s+"); // If there's been an error, the length of linesplit is 6, otherwise 5, // due to the error flag "+" int id; String expectedValue, classification; double probability; if (line.contains("+")) { id = Integer.parseInt(linesplit[6].substring(1, linesplit[6].length() - 1)); expectedValue = linesplit[2].substring(2); classification = linesplit[3].substring(2); probability = Double.parseDouble(linesplit[5]); } else { id = Integer.parseInt(linesplit[5].substring(1, linesplit[5].length() - 1)); expectedValue = linesplit[2].substring(2); classification = linesplit[3].substring(2); probability = Double.parseDouble(linesplit[4]); } scores[id - 1] = classification; probabilities[id - 1] = probability; } System.out.println(eval.toSummaryString()); System.out.println(eval.toMatrixString()); // Output classifications StringBuilder sb = new StringBuilder(); for (String score : scores) sb.append(score.toString() + LF); FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".csv"), sb.toString()); // Output probabilities sb = new StringBuilder(); for (Double probability : probabilities) sb.append(probability.toString() + LF); FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".probabilities.csv"), sb.toString()); // Output predictions FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".predictions.txt"), output.getBuffer().toString()); // Output meta information sb = new StringBuilder(); sb.append(classifier.toString() + LF); sb.append(eval.toSummaryString() + LF); sb.append(eval.toMatrixString() + LF); FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".meta.txt"), sb.toString()); } public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset) throws Exception { // Set parameters int folds = 10; Classifier baseClassifier = ClassifierSimilarityMeasure.getClassifier(wekaClassifier); // Set up the random number generator long seed = new Date().getTime(); Random random = new Random(seed); // Add IDs to the instances AddID.main(new String[] { "-i", MODELS_DIR + "/" + dataset.toString() + ".arff", "-o", MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff" }); Instances data = DataSource.read(MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff"); data.setClassIndex(data.numAttributes() - 1); // Instantiate the Remove filter Remove removeIDFilter = new Remove(); removeIDFilter.setAttributeIndices("first"); // Randomize the data data.randomize(random); // Perform cross-validation Instances predictedData = null; Evaluation eval = new Evaluation(data); for (int n = 0; n < folds; n++) { Instances train = data.trainCV(folds, n, random); Instances test = data.testCV(folds, n); // Apply log filter // Filter logFilter = new LogFilter(); // logFilter.setInputFormat(train); // train = Filter.useFilter(train, logFilter); // logFilter.setInputFormat(test); // test = Filter.useFilter(test, logFilter); // Copy the classifier Classifier classifier = AbstractClassifier.makeCopy(baseClassifier); // Instantiate the FilteredClassifier FilteredClassifier filteredClassifier = new FilteredClassifier(); filteredClassifier.setFilter(removeIDFilter); filteredClassifier.setClassifier(classifier); // Build the classifier filteredClassifier.buildClassifier(train); // Evaluate eval.evaluateModel(filteredClassifier, test); // Add predictions AddClassification filter = new AddClassification(); filter.setClassifier(classifier); filter.setOutputClassification(true); filter.setOutputDistribution(false); filter.setOutputErrorFlag(true); filter.setInputFormat(train); Filter.useFilter(train, filter); // trains the classifier Instances pred = Filter.useFilter(test, filter); // performs predictions on test set if (predictedData == null) predictedData = new Instances(pred, 0); for (int j = 0; j < pred.numInstances(); j++) predictedData.add(pred.instance(j)); } System.out.println(eval.toSummaryString()); System.out.println(eval.toMatrixString()); // Prepare output scores String[] scores = new String[predictedData.numInstances()]; for (Instance predInst : predictedData) { int id = new Double(predInst.value(predInst.attribute(0))).intValue() - 1; int valueIdx = predictedData.numAttributes() - 2; String value = predInst.stringValue(predInst.attribute(valueIdx)); scores[id] = value; } // Output classifications StringBuilder sb = new StringBuilder(); for (String score : scores) sb.append(score.toString() + LF); FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".csv"), sb.toString()); // Output prediction arff DataSink.write(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".predicted.arff", predictedData); // Output meta information sb = new StringBuilder(); sb.append(baseClassifier.toString() + LF); sb.append(eval.toSummaryString() + LF); sb.append(eval.toMatrixString() + LF); FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".meta.txt"), sb.toString()); } @SuppressWarnings("unchecked") public static void runEvaluationMetric(EvaluationMetric metric, Dataset dataset) throws IOException { // Get all subdirectories (i.e. all classifiers) File outputDir = new File(OUTPUT_DIR + "/" + dataset.toString() + "/"); File[] dirsArray = outputDir.listFiles((FileFilter) FileFilterUtils.directoryFileFilter()); List<File> dirs = CollectionUtils.arrayToList(dirsArray); // Don't list hidden dirs (such as .svn) for (int i = dirs.size() - 1; i >= 0; i--) if (dirs.get(i).getName().startsWith(".")) dirs.remove(i); // Iteratively evaluate all classifiers' results for (File dir : dirs) runEvaluationMetric(WekaClassifier.valueOf(dir.getName()), metric, dataset); } public static void runEvaluationMetric(WekaClassifier wekaClassifier, EvaluationMetric metric, Dataset dataset) throws IOException { StringBuilder sb = new StringBuilder(); if (metric == Accuracy) { // Read gold scores List<String> goldScores = FileUtils.readLines(new File(GOLD_DIR + "/" + dataset.toString() + ".txt")); // Read the experimental scores List<String> expScores = FileUtils.readLines(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".csv")); // Compute the accuracy double acc = 0.0; for (int i = 0; i < goldScores.size(); i++) { // The predictions have a max length of 8 characters... if (goldScores.get(i).substring(0, Math.min(goldScores.get(i).length(), 8)) .equals(expScores.get(i).substring(0, Math.min(expScores.get(i).length(), 8)))) acc++; } acc = acc / goldScores.size(); sb.append(acc); } if (metric == CWS) { // Read gold scores List<String> goldScores = FileUtils.readLines(new File(GOLD_DIR + "/" + dataset.toString() + ".txt")); // Read the experimental scores List<String> expScores = FileUtils.readLines(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".csv")); // Read the confidence scores List<String> probabilities = FileUtils.readLines(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".probabilities.csv")); // Combine the data List<CwsData> data = new ArrayList<CwsData>(); for (int i = 0; i < goldScores.size(); i++) { CwsData cws = (new Evaluator()).new CwsData(Double.parseDouble(probabilities.get(i)), goldScores.get(i), expScores.get(i)); data.add(cws); } // Sort in descending order Collections.sort(data, Collections.reverseOrder()); // Compute the CWS score double cwsScore = 0.0; for (int i = 0; i < data.size(); i++) { double cws_sub = 0.0; for (int j = 0; j <= i; j++) { if (data.get(j).isCorrect()) cws_sub++; } cws_sub /= (i + 1); cwsScore += cws_sub; } cwsScore /= data.size(); sb.append(cwsScore); } if (metric == AveragePrecision) { // Read gold scores List<String> goldScores = FileUtils.readLines(new File(GOLD_DIR + "/" + dataset.toString() + ".txt")); // Trim to 8 characters for (int i = 0; i < goldScores.size(); i++) if (goldScores.get(i).length() > 8) goldScores.set(i, goldScores.get(i).substring(0, 8)); // Read the experimental scores List<String> expScores = FileUtils.readLines(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".csv")); // Trim to 8 characters for (int i = 0; i < expScores.size(); i++) if (expScores.get(i).length() > 8) expScores.set(i, expScores.get(i).substring(0, 8)); // Read the confidence scores List<String> probabilities = FileUtils.readLines(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".probabilities.csv")); // Conflate UNKONWN + CONTRADICTION classes for 3-way classifications if (RteUtil.hasThreeWayClassification(dataset)) { // Gold for (int i = 0; i < goldScores.size(); i++) if (goldScores.get(i).equals("CONTRADI") || goldScores.get(i).equals("NO") || goldScores.get(i).equals("FALSE")) goldScores.set(i, "FALSE"); // Experimental for (int i = 0; i < expScores.size(); i++) if (expScores.get(i).equals("CONTRADI") || expScores.get(i).equals("NO") || expScores.get(i).equals("FALSE")) expScores.set(i, "FALSE"); } // Combine the data List<CwsData> data = new ArrayList<CwsData>(); for (int i = 0; i < goldScores.size(); i++) { CwsData cws = (new Evaluator()).new CwsData(Double.parseDouble(probabilities.get(i)), goldScores.get(i), expScores.get(i)); data.add(cws); } // Sort in descending order Collections.sort(data, Collections.reverseOrder()); // Compute the average precision double avgPrec = 0.0; int numPositive = 0; for (int i = 0; i < data.size(); i++) { double ap_sub = 0.0; if (data.get(i).isPositivePair()) { numPositive++; for (int j = 0; j <= i; j++) { if (data.get(j).isCorrect()) ap_sub++; } ap_sub /= (i + 1); } avgPrec += ap_sub; } avgPrec /= numPositive; sb.append(avgPrec); } FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + "_" + metric.toString() + ".txt"), sb.toString()); System.out.println("[" + wekaClassifier.toString() + "] " + metric.toString() + ": " + sb.toString()); } private class CwsData implements Comparable { private double confidence; private String goldScore; private String expScore; public CwsData(double confidence, String goldScore, String expScore) { this.confidence = confidence; this.goldScore = goldScore; this.expScore = expScore; } public boolean isCorrect() { return goldScore.equals(expScore); } public int compareTo(Object other) { CwsData otherObj = (CwsData) other; if (this.getConfidence() == otherObj.getConfidence()) { return 0; } else if (this.getConfidence() > otherObj.getConfidence()) { return 1; } else { return -1; } } public boolean isPositivePair() { return this.goldScore.equals("TRUE") || this.goldScore.equals("YES") || this.goldScore.equals("ENTAILMENT") || this.goldScore.equals("ENTAILME"); } public double getConfidence() { return confidence; } public String getGoldScore() { return goldScore; } public String getExpScore() { return expScore; } } // // @SuppressWarnings("unchecked") // private static void computePearsonCorrelation(Mode mode, Dataset dataset) // throws IOException // { // File expScoresFile = new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".csv"); // // String gsScoresFilePath = GOLDSTANDARD_DIR + "/" + mode.toString().toLowerCase() + "/" + // "STS.gs." + dataset.toString() + ".txt"; // // PathMatchingResourcePatternResolver r = new PathMatchingResourcePatternResolver(); // Resource res = r.getResource(gsScoresFilePath); // File gsScoresFile = res.getFile(); // // List<Double> expScores = new ArrayList<Double>(); // List<Double> gsScores = new ArrayList<Double>(); // // List<String> expLines = FileUtils.readLines(expScoresFile); // List<String> gsLines = FileUtils.readLines(gsScoresFile); // // for (int i = 0; i < expLines.size(); i++) // { // expScores.add(Double.parseDouble(expLines.get(i))); // gsScores.add(Double.parseDouble(gsLines.get(i))); // } // // double[] expArray = ArrayUtils.toPrimitive(expScores.toArray(new Double[expScores.size()])); // double[] gsArray = ArrayUtils.toPrimitive(gsScores.toArray(new Double[gsScores.size()])); // // PearsonsCorrelation pearson = new PearsonsCorrelation(); // Double correl = pearson.correlation(expArray, gsArray); // // FileUtils.writeStringToFile( // new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".txt"), // correl.toString()); // } }