Java tutorial
/* * Copyright 2015 * Ubiquitous Knowledge Processing (UKP) Lab * Technische Universitt Darmstadt * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package de.tudarmstadt.ukp.dkpro.argumentation.sequence.report; import de.tudarmstadt.ukp.dkpro.tc.svmhmm.util.ConfusionMatrix; import org.apache.commons.csv.CSVFormat; import org.apache.commons.csv.CSVParser; import org.apache.commons.csv.CSVRecord; import org.apache.commons.io.FileUtils; import org.apache.commons.lang.StringUtils; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.util.*; /** * @author Ivan Habernal */ public class ConfusionMatrixTools { private static final String GLUE = "\t"; public static ConfusionMatrix tokenLevelPredictionsToConfusionMatrix(File predictionsFile) throws IOException { ConfusionMatrix cm = new ConfusionMatrix(); CSVParser csvParser = new CSVParser(new FileReader(predictionsFile), CSVFormat.DEFAULT.withCommentMarker('#')); for (CSVRecord csvRecord : csvParser) { // update confusion matrix cm.increaseValue(csvRecord.get(0), csvRecord.get(1)); } return cm; } public static String prettyPrintConfusionMatrixResults(ConfusionMatrix cm) { cm.printNiceResults(); String f = "%.3f"; List<String> header = new ArrayList<>(); List<String> row = new ArrayList<>(); header.add("Macro F1"); header.add("Accuracy"); header.add("Acc CI@95"); row.add(String.format(Locale.ENGLISH, f, cm.getMacroFMeasure())); row.add(String.format(Locale.ENGLISH, f, cm.getAccuracy())); row.add(String.format(Locale.ENGLISH, f, cm.getConfidence95Accuracy())); Map<String, Double> precisionForLabels = cm.getPrecisionForLabels(); Map<String, Double> recallForLabels = cm.getRecallForLabels(); Map<String, Double> fMForLabels = cm.getFMeasureForLabels(); SortedSet<String> labels = new TreeSet<>(precisionForLabels.keySet()); for (String label : labels) { header.add(label + " P"); row.add(String.format(Locale.ENGLISH, f, precisionForLabels.get(label))); header.add(label + " R"); row.add(String.format(Locale.ENGLISH, f, recallForLabels.get(label))); header.add(label + " F1"); row.add(String.format(Locale.ENGLISH, f, fMForLabels.get(label))); } return StringUtils.join(header, GLUE) + "\n" + StringUtils.join(row, GLUE); } public static void generateNiceTable(File predictionsFile) throws IOException { ConfusionMatrix cm = tokenLevelPredictionsToConfusionMatrix(predictionsFile); File outFile = new File(predictionsFile.getParent(), "niceResults.csv"); FileUtils.writeStringToFile(outFile, prettyPrintConfusionMatrixResults(cm)); System.out.println("Writing " + outFile); } public static void main(String[] args) throws IOException { String path = args[0]; for (File file : FileUtils.listFiles(new File(path), new String[] { "csv" }, true)) { if (file.getName().startsWith("tokenLevelPredictions")) { generateNiceTable(file); } } } }