// Java tutorial
/*******************************************************************************
 * Copyright 2013
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the GNU Public License v3.0
 * which accompanies this distribution, and is available at
 * http://www.gnu.org/licenses/gpl-3.0.txt
 ******************************************************************************/
package de.tudarmstadt.ukp.similarity.experiments.coling2012.util;

import static de.tudarmstadt.ukp.similarity.experiments.coling2012.Pipeline.MODELS_DIR;
import static de.tudarmstadt.ukp.similarity.experiments.coling2012.Pipeline.OUTPUT_DIR;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;

import org.apache.commons.io.FileUtils;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.meta.FilteredClassifier;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.supervised.attribute.AddClassification;
import weka.filters.unsupervised.attribute.AddID;
import weka.filters.unsupervised.attribute.Remove;

import de.tudarmstadt.ukp.similarity.experiments.coling2012.Pipeline.Dataset;
import de.tudarmstadt.ukp.similarity.experiments.coling2012.Pipeline.EvaluationMetric;

/**
 * Runs Weka classifiers via cross-validation on the feature files produced by
 * the pipeline, writes the per-instance predictions to
 * {@code OUTPUT_DIR/<dataset>/<classifier>/output.csv}, and scores those
 * predictions against the gold standard with the requested evaluation metric.
 */
public class Evaluator
{
    /** Platform line separator, used when writing the prediction file. */
    public static final String LF = System.getProperty("line.separator");

    /** Classifiers supported by {@link #getClassifier(WekaClassifier)}. */
    public enum WekaClassifier
    {
        NAIVE_BAYES, J48
    }

    /**
     * Runs a 10-fold cross-validation of the given classifier on the ARFF file
     * for the given dataset and writes one predicted class label per line
     * (in original instance order) to
     * {@code OUTPUT_DIR/<dataset>/<classifier>/output.csv}.
     *
     * @param wekaClassifier the classifier to evaluate
     * @param dataset the dataset whose ARFF file (under {@code MODELS_DIR}) is used
     * @throws Exception if reading, filtering, training, or writing fails
     */
    public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset)
        throws Exception
    {
        // Set parameters
        int folds = 10;
        Classifier baseClassifier = getClassifier(wekaClassifier);

        // Set up the random number generator
        long seed = new Date().getTime();
        Random random = new Random(seed);

        // Add IDs to the instances so predictions can be mapped back to the
        // original instance order after the data has been randomized.
        AddID.main(new String[] { "-i",
                MODELS_DIR + "/" + dataset.toString() + ".arff", "-o",
                MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff" });
        Instances data = DataSource.read(
                MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff");
        data.setClassIndex(data.numAttributes() - 1);

        // The ID attribute (first column) must be hidden from the classifier;
        // FilteredClassifier applies this Remove filter before training.
        Remove removeIDFilter = new Remove();
        removeIDFilter.setAttributeIndices("first");

        // Randomize the data
        data.randomize(random);

        // Perform cross-validation
        Instances predictedData = null;
        Evaluation eval = new Evaluation(data);

        for (int n = 0; n < folds; n++)
        {
            Instances train = data.trainCV(folds, n, random);
            Instances test = data.testCV(folds, n);

            // Apply log filter
            // Filter logFilter = new LogFilter();
            // logFilter.setInputFormat(train);
            // train = Filter.useFilter(train, logFilter);
            // logFilter.setInputFormat(test);
            // test = Filter.useFilter(test, logFilter);

            // Copy the classifier so each fold trains from a fresh model
            Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);

            // Instantiate the FilteredClassifier
            FilteredClassifier filteredClassifier = new FilteredClassifier();
            filteredClassifier.setFilter(removeIDFilter);
            filteredClassifier.setClassifier(classifier);

            // Build the classifier
            filteredClassifier.buildClassifier(train);

            // Evaluate
            eval.evaluateModel(filteredClassifier, test);

            // Add predictions
            AddClassification filter = new AddClassification();
            filter.setClassifier(filteredClassifier);
            filter.setOutputClassification(true);
            filter.setOutputDistribution(false);
            filter.setOutputErrorFlag(true);
            filter.setInputFormat(train);
            Filter.useFilter(train, filter); // trains the classifier

            Instances pred = Filter.useFilter(test, filter); // performs predictions on test set
            if (predictedData == null) {
                predictedData = new Instances(pred, 0);
            }
            for (int j = 0; j < pred.numInstances(); j++) {
                predictedData.add(pred.instance(j));
            }
        }

        // Prepare output classification: use the 1-based ID attribute to put
        // each prediction back at its original position.
        String[] scores = new String[predictedData.numInstances()];

        for (Instance predInst : predictedData)
        {
            // FIX: plain cast instead of the deprecated Double(double)
            // boxing constructor; the value is a whole-numbered ID.
            int id = (int) predInst.value(predInst.attribute(0)) - 1;

            // The predicted class label added by AddClassification is the
            // second-to-last attribute (the last one is the error flag).
            int valueIdx = predictedData.numAttributes() - 2;

            String value = predInst.stringValue(predInst.attribute(valueIdx));

            scores[id] = value;
        }

        // Output
        StringBuilder sb = new StringBuilder();
        for (String score : scores) {
            // FIX: removed redundant toString() on a String value
            sb.append(score).append(LF);
        }

        FileUtils.writeStringToFile(
                new File(OUTPUT_DIR + "/" + dataset.toString() + "/"
                        + wekaClassifier.toString() + "/output.csv"),
                sb.toString());
    }

    /**
     * Scores the predictions previously written by
     * {@link #runClassifierCV(WekaClassifier, Dataset)} against the gold
     * standard and writes the single resulting number to
     * {@code OUTPUT_DIR/<dataset>/<classifier>/<metric>.txt}.
     *
     * @param wekaClassifier the classifier whose output file is scored
     * @param metric either {@code Accuracy} or {@code AverageF1} (macro-averaged)
     * @param dataset the dataset providing the gold standard
     * @throws IOException if the gold standard, predictions, or output file
     *             cannot be read or written
     */
    @SuppressWarnings("unchecked")
    public static void runEvaluationMetric(WekaClassifier wekaClassifier, EvaluationMetric metric,
            Dataset dataset)
        throws IOException
    {
        StringBuilder sb = new StringBuilder();

        List<String> gold = ColingUtils.readGoldstandard(dataset);
        List<String> exp = FileUtils.readLines(
                new File(OUTPUT_DIR + "/" + dataset.toString() + "/"
                        + wekaClassifier.toString() + "/output.csv"));

        if (metric.equals(EvaluationMetric.Accuracy))
        {
            double acc = 0.0;
            for (int i = 0; i < gold.size(); i++)
            {
                if (gold.get(i).equals(exp.get(i))) {
                    acc++;
                }
            }
            acc /= gold.size();
            sb.append(acc);
        }
        else if (metric.equals(EvaluationMetric.AverageF1))
        {
            // Get all classes
            Set<String> classesSet = new HashSet<String>();
            for (String cl : gold) {
                classesSet.add(cl);
            }

            // Order the classes
            List<String> classes = new ArrayList<String>(classesSet);

            // Confusion matrix: rows are gold classes, columns are predicted
            // classes (Java zero-initializes the array, so no explicit
            // zeroing loop is needed).
            // gold\pred  A   B
            //        A   x1  x2
            //        B   x3  x4
            int[][] matrix = new int[classes.size()][classes.size()];

            // Construct confusion matrix
            for (int i = 0; i < gold.size(); i++)
            {
                int goldIndex = classes.indexOf(gold.get(i));
                int expIndex = classes.indexOf(exp.get(i));

                matrix[goldIndex][expIndex] += 1;
            }

            // Compute precision and recall per class
            double[] prec = new double[classes.size()];
            double[] rec = new double[classes.size()];

            for (int i = 0; i < classes.size(); i++)
            {
                double tp = matrix[i][i];
                double fp = 0.0;
                double fn = 0.0;

                // FP: everything predicted as class i that is not gold class i
                for (int j = 0; j < classes.size(); j++)
                {
                    if (i == j) {
                        continue;
                    }
                    fp += matrix[j][i];
                }

                // FN: everything of gold class i predicted as something else
                for (int j = 0; j < classes.size(); j++)
                {
                    if (i == j) {
                        continue;
                    }
                    fn += matrix[i][j];
                }

                // FIX: guard against division by zero — a class that is never
                // predicted (or never occurs) would otherwise yield NaN and
                // silently poison the averaged F1 score. By convention an
                // undefined precision/recall counts as 0.
                prec[i] = (tp + fp == 0.0) ? 0.0 : tp / (tp + fp);
                rec[i] = (tp + fn == 0.0) ? 0.0 : tp / (tp + fn);
            }

            // Compute macro-averaged F1 score across all classes
            double f1 = 0.0;
            for (int i = 0; i < classes.size(); i++)
            {
                double denom = prec[i] + rec[i];
                // Same guard as above: F1 of a class with prec = rec = 0 is 0.
                double f1PerClass = (denom == 0.0) ? 0.0 : (2 * prec[i] * rec[i]) / denom;
                f1 += f1PerClass;
            }
            f1 = f1 / classes.size();

            // Output
            sb.append(f1);
        }

        FileUtils.writeStringToFile(
                new File(OUTPUT_DIR + "/" + dataset.toString() + "/"
                        + wekaClassifier.toString() + "/" + metric.toString() + ".txt"),
                sb.toString());
    }

    /**
     * Instantiates and configures the Weka classifier for the given enum value.
     *
     * @param classifier the classifier to instantiate
     * @return a freshly configured {@link Classifier}
     * @throws IllegalArgumentException if the classifier is unknown or its
     *             options cannot be set (the original cause is preserved)
     */
    public static Classifier getClassifier(WekaClassifier classifier)
        throws IllegalArgumentException
    {
        try {
            switch (classifier)
            {
                case NAIVE_BAYES:
                    return new NaiveBayes();
                case J48:
                    J48 j48 = new J48();
                    j48.setOptions(new String[] { "-C", "0.25", "-M", "2" });
                    return j48;
                // case SMO:
                // SMO smo = new SMO();
                // smo.setOptions(Utils.splitOptions("-C 1.0 -L 0.001 -P 1.0E-12 -N 0 -V -1 -W 1 -K \"weka.classifiers.functions.supportVector.PolyKernel -C 250007 -E 1.0\""));
                // return smo;
                // case LOGISTIC:
                // Logistic logistic = new Logistic();
                // logistic.setOptions(Utils.splitOptions("-R 1.0E-8 -M -1"));
                // return logistic;
                default:
                    throw new IllegalArgumentException("Classifier " + classifier + " not found!");
            }
        }
        catch (Exception e) {
            throw new IllegalArgumentException(e);
        }
    }
}