org.dkpro.tc.core.util.ReportUtils.java Source code

Java tutorial

Introduction

Here is the source code for org.dkpro.tc.core.util.ReportUtils.java

Source

/*******************************************************************************
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/
package org.dkpro.tc.core.util;

import static org.dkpro.tc.core.util.ReportConstants.CORRELATION;
import static org.dkpro.tc.core.util.ReportConstants.PCT_CORRECT;
import static org.dkpro.tc.core.util.ReportConstants.PCT_INCORRECT;

import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeSet;
import java.util.regex.Pattern;

//import meka.core.Result;
//import mulan.evaluation.measure.MicroPrecision;
//import mulan.evaluation.measure.MicroRecall;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.text.StrTokenizer;
import org.dkpro.lab.engine.TaskContext;
import org.dkpro.lab.reporting.ChartUtil;
import org.dkpro.lab.reporting.FlexTable;
import org.dkpro.lab.storage.StorageService;
import org.dkpro.lab.storage.StreamWriter;
import org.dkpro.lab.storage.impl.PropertiesAdapter;
import org.dkpro.tc.api.exception.TextClassificationException;
import org.dkpro.tc.core.Constants;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.renderer.xy.XYSplineRenderer;
import org.jfree.data.xy.DefaultXYDataset;

/**
 * Utility methods needed in reports
 */
public class ReportUtils {

    /**
     * Number of interpolated recall levels (0.00, 0.10, ..., 1.00) averaged in
     * {@link #createXYDataset(List)}.
     */
    private static final int RECALL_LEVELS = 11;

    /** Tables with more columns than this are not written as Excel files. */
    private static final int MAX_EXCEL_COLUMNS = 255;

    /**
     * Creates a confusion matrix by collecting the results from the overall CV run stored in
     * {@code tempM}.
     * 
     * @param tempM
     *            counts keyed first by actual label, then by predicted label
     * @param actualLabelsList
     *            the label powerset transformed list of actual/true labels; defines the row order
     * @param predictedLabelsList
     *            the label powerset transformed list of predicted labels; defines the column order
     * @return a {@code actualLabelsList.size() x predictedLabelsList.size()} matrix; cells without
     *         an entry in {@code tempM} are 0
     * @throws IllegalArgumentException
     *             if {@code tempM} contains a label that is missing from the corresponding list
     */
    public static double[][] createConfusionMatrix(HashMap<String, Map<String, Integer>> tempM,
            List<String> actualLabelsList, List<String> predictedLabelsList) {
        double[][] matrix = new double[actualLabelsList.size()][predictedLabelsList.size()];

        for (Entry<String, Map<String, Integer>> actualEntry : tempM.entrySet()) {
            Map<String, Integer> predictions = actualEntry.getValue();
            if (predictions == null || predictions.isEmpty()) {
                // nothing to place in the matrix for this actual label
                continue;
            }
            String actual = actualEntry.getKey();
            int a = actualLabelsList.indexOf(actual);
            if (a < 0) {
                throw new IllegalArgumentException("Actual label [" + actual
                        + "] is not contained in the list of actual labels " + actualLabelsList);
            }
            for (Entry<String, Integer> predEntry : predictions.entrySet()) {
                String pred = predEntry.getKey();
                int p = predictedLabelsList.indexOf(pred);
                if (p < 0) {
                    throw new IllegalArgumentException("Predicted label [" + pred
                            + "] is not contained in the list of predicted labels "
                            + predictedLabelsList);
                }
                matrix[a][p] = predEntry.getValue();
            }
        }
        return matrix;
    }

    /**
     * Converts a bipartition array into a list of class names. Parameter arrays must have the same
     * length.
     * 
     * @param labels
     *            bipartition; a value of {@code 1} marks the class at the same index as selected
     * @param classNames
     *            class names, index-aligned with {@code labels}
     * @param separatorChar
     *            character placed between consecutive selected class names
     * @return the selected class names joined by {@code separatorChar}, or {@code ""} if none is
     *         selected
     */
    public static String doubleArrayToClassNames(int[] labels, String[] classNames, Character separatorChar) {
        StringBuilder builder = new StringBuilder();
        boolean first = true;

        for (int y = 0; y < labels.length; y++) {
            if (labels[y] == 1) {
                if (!first) {
                    builder.append(separatorChar);
                }
                builder.append(classNames[y]);
                first = false;
            }
        }
        return builder.toString();
    }

    /**
     * Adds results from one fold to the overall CV results.
     * 
     * @param results
     *            measure name to value for a single fold
     * @param cvResults
     *            measure name to the values collected so far; updated in place, and a new list is
     *            created for measures not seen before
     */
    public static void addToResults(Map<String, Double> results, Map<String, List<Double>> cvResults) {
        for (Entry<String, Double> entry : results.entrySet()) {
            List<Double> values = cvResults.get(entry.getKey());
            if (values == null) {
                values = new ArrayList<Double>();
                cvResults.put(entry.getKey(), values);
            }
            values.add(entry.getValue());
        }
    }

    /**
     * Averages several precision/recall curves point-wise into a single "PR-Curve" series.
     * 
     * @param prcData
     *            curves, each given as {@code {recall[], precision[]}}; every array must have at
     *            most {@value #RECALL_LEVELS} entries
     * @return a dataset whose x values are the averaged recall and whose y values are the averaged
     *         precision; all values are NaN if {@code prcData} is empty
     */
    public static DefaultXYDataset createXYDataset(List<double[][]> prcData) {
        double[] avRec = new double[RECALL_LEVELS];
        double[] avPrec = new double[RECALL_LEVELS];

        for (double[][] curve : prcData) {
            addInto(avRec, curve[0]);
            addInto(avPrec, curve[1]);
        }

        int curveCount = prcData.size();
        for (int i = 0; i < RECALL_LEVELS; i++) {
            avPrec[i] = avPrec[i] / curveCount;
            avRec[i] = avRec[i] / curveCount;
        }

        DefaultXYDataset dataset = new DefaultXYDataset();
        dataset.addSeries("PR-Curve", new double[][] { avRec, avPrec });
        return dataset;
    }

    /** Adds {@code values} element-wise into {@code sums}. */
    private static void addInto(double[] sums, double[] values) {
        for (int j = 0; j < values.length; j++) {
            sums[j] += values[j];
        }
    }

    /**
     * From TrecTool README:
     * 
     * Interpolated Recall - Precision Averages: at 0.00 at 0.10 ... at 1.00 See any standard IR
     * text (especially by Salton) for more details of recall-precision evaluation. Measures
     * precision (percent of retrieved docs that are relevant) at various recall levels (after a
     * certain percentage of all the relevant docs for that query have been retrieved).
     * 'Interpolated' means that, for example, precision at recall 0.10 (ie, after 10% of rel docs
     * for a query have been retrieved) is taken to be MAXIMUM of precision at all recall points 
     * &gt;= 0.10. Values are averaged over all queries (for each of the 11 recall levels). These 
     * values are used for Recall-Precision graphs.
     * <p>
     * This writer renders the given dataset as a 400x400 SVG spline chart with both axes fixed to
     * [0, 1].
     */
    public static class PrecisionRecallDiagramRenderer implements StreamWriter {
        private final DefaultXYDataset dataset;

        public PrecisionRecallDiagramRenderer(DefaultXYDataset aDataset) {
            dataset = aDataset;
        }

        @Override
        public void write(OutputStream aStream) throws IOException {
            JFreeChart chart = ChartFactory.createXYLineChart(null, "Recall", "Precision", dataset,
                    PlotOrientation.VERTICAL, false, false, false);
            chart.getXYPlot().setRenderer(new XYSplineRenderer());
            chart.getXYPlot().getRangeAxis().setRange(0.0, 1.0);
            chart.getXYPlot().getDomainAxis().setRange(0.0, 1.0);
            ChartUtil.writeChartAsSVG(aStream, chart, 400, 400);
        }
    }

    /**
     * Checks whether any of the given regular expressions occurs somewhere in {@code string}.
     * 
     * @param string
     *            the text to test
     * @param patterns
     *            regular expressions, matched with {@link java.util.regex.Matcher#find()}
     * @return {@code true} if at least one pattern is found
     */
    public static boolean containsExcludePattern(String string, List<String> patterns) {
        for (String pattern : patterns) {
            if (Pattern.compile(pattern).matcher(string).find()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Looks into the {@link FlexTable} and outputs general performance numbers if available.
     * <p>
     * If the table has both {@code PCT_CORRECT} and {@code PCT_INCORRECT} columns, those are
     * listed. Otherwise, if it has a {@code CORRELATION} column, that column is listed. Otherwise
     * only a newline is returned.
     */
    public static String getPerformanceOverview(FlexTable<String> table) {
        // output some general performance figures
        // TODO this is a bit of a hack. Is there a better way?
        Set<String> columnIds = new HashSet<String>(Arrays.asList(table.getColumnIds()));
        StringBuilder buffer = new StringBuilder("\n");
        if (columnIds.contains(PCT_CORRECT) && columnIds.contains(PCT_INCORRECT)) {
            int i = 0;
            buffer.append("ID\t% CORRECT\t% INCORRECT\n");
            for (String id : table.getRowIds()) {
                String correct = table.getValueAsString(id, PCT_CORRECT);
                String incorrect = table.getValueAsString(id, PCT_INCORRECT);
                buffer.append(i + "\t" + correct + "\t" + incorrect + "\n");
                i++;
            }
            buffer.append("\n");
        } else if (columnIds.contains(CORRELATION)) {
            int i = 0;
            buffer.append("ID\tCORRELATION\n");
            for (String id : table.getRowIds()) {
                String correlation = table.getValueAsString(id, CORRELATION);
                buffer.append(i + "\t" + correlation + "\n");
                i++;
            }
            buffer.append("\n");
        }
        return buffer.toString();
    }

    /**
     * Adds results from a serialized matrix to a map.
     * 
     * @param aggregateMap
     *            map from {@code [predictedHeader, actualHeader]} to the accumulated count; updated
     *            in place
     * @param matrix
     *            a csv matrix with the class names in the first row and first column
     * @return the updated map (the same instance as {@code aggregateMap}); returned unchanged if
     *         the file is empty
     * @throws IOException
     *             if the file cannot be read, or a data row has fewer cells than the header row
     */
    public static Map<List<String>, Double> updateAggregateMatrix(Map<List<String>, Double> aggregateMap,
            File matrix) throws IOException {
        List<String> confMatrixLines = FileUtils.readLines(matrix);
        if (confMatrixLines.isEmpty()) {
            return aggregateMap;
        }
        String[] headline = tokenizeCsvLine(confMatrixLines.get(0));

        for (int i = 1; i < confMatrixLines.size(); i++) {
            String rawLine = confMatrixLines.get(i);
            if (rawLine.trim().length() == 0) {
                // tolerate blank lines, e.g. trailing ones
                continue;
            }
            // tokenize each row once instead of once per column
            String[] cells = tokenizeCsvLine(rawLine);
            if (headline.length > 1 && cells.length < headline.length) {
                throw new IOException("Malformed confusion matrix row " + i + " in " + matrix
                        + ": expected " + headline.length + " cells but found " + cells.length);
            }
            for (int j = 1; j < headline.length; j++) {
                String pred = headline[j];
                String act = cells[0];
                double value = Double.valueOf(cells[j]);

                List<String> key = new ArrayList<String>(Arrays.asList(pred, act));

                Double previous = aggregateMap.get(key);
                aggregateMap.put(key, previous == null ? value : previous + value);
            }
        }
        return aggregateMap;
    }

    /** Splits one line of the serialized matrix into its comma-separated (CSV-quoted) cells. */
    private static String[] tokenizeCsvLine(String line) {
        StrTokenizer tokenizer = StrTokenizer.getCSVInstance(line);
        tokenizer.setDelimiterChar(',');
        return tokenizer.getTokenArray();
    }

    /**
     * Converts a map containing a matrix into a matrix.
     * 
     * @param aggregateMap
     *            a map created with {@link ReportUtils#updateAggregateMatrix(Map, File)}
     * @see ReportUtils#updateAggregateMatrix(Map, File)
     * @return a table with the matrix; rows are actual labels, columns are predicted labels, both
     *         sorted by label
     * @throws IllegalArgumentException
     *             if a key does not contain {@link Constants#CM_PREDICTED} or
     *             {@link Constants#CM_ACTUAL} where expected
     */
    public static FlexTable<String> createOverallConfusionMatrix(Map<List<String>, Double> aggregateMap) {
        FlexTable<String> cMTable = FlexTable.forClass(String.class);
        cMTable.setSortRows(false);

        Set<String> labelsPred = new TreeSet<String>();
        Set<String> labelsAct = new TreeSet<String>();

        // sorting rows/columns
        for (List<String> key : aggregateMap.keySet()) {
            labelsPred.add(labelBeforeMarker(key.get(0), Constants.CM_PREDICTED));
            labelsAct.add(labelBeforeMarker(key.get(1), Constants.CM_ACTUAL));
        }

        List<String> labelsPredL = new ArrayList<String>(labelsPred);
        List<String> labelsActL = new ArrayList<String>(labelsAct);

        // create temporary matrix
        double[][] tempM = new double[labelsActL.size()][labelsPredL.size()];
        for (Entry<List<String>, Double> entry : aggregateMap.entrySet()) {
            List<String> key = entry.getKey();
            int c = labelsPredL.indexOf(labelBeforeMarker(key.get(0), Constants.CM_PREDICTED));
            int r = labelsActL.indexOf(labelBeforeMarker(key.get(1), Constants.CM_ACTUAL));
            tempM[r][c] = entry.getValue();
        }

        // convert to FlexTable
        for (int row = 0; row < tempM.length; row++) {
            LinkedHashMap<String, String> cells = new LinkedHashMap<String, String>();
            for (int col = 0; col < tempM[row].length; col++) {
                cells.put(labelsPredL.get(col) + " " + Constants.CM_PREDICTED, String.valueOf(tempM[row][col]));
            }
            cMTable.addRow(labelsActL.get(row) + " " + Constants.CM_ACTUAL, cells);
        }

        return cMTable;
    }

    /** Returns the part of {@code cell} before the first occurrence of {@code marker}. */
    private static String labelBeforeMarker(String cell, String marker) {
        int idx = cell.indexOf(marker);
        if (idx < 0) {
            throw new IllegalArgumentException("Confusion matrix entry [" + cell
                    + "] does not contain the expected marker [" + marker + "]");
        }
        return cell.substring(0, idx);
    }

    /**
     * Find a specific discriminator value given a discriminator key.
     * <p>
     * A key matches when its second {@code |}-separated segment equals {@code discriminatorName}.
     * Keys without a {@code |} are ignored.
     * 
     * @param discriminatorsMap The map to search
     * @param discriminatorName The name of the discriminator
     * @return The discriminator value for the given key
     * @throws TextClassificationException if no key matches {@code discriminatorName}
     */
    public static String getDiscriminatorValue(Map<String, String> discriminatorsMap, String discriminatorName)
            throws TextClassificationException {
        for (Entry<String, String> entry : discriminatorsMap.entrySet()) {
            String[] parts = entry.getKey().split("\\|");
            if (parts.length > 1 && parts[1].equals(discriminatorName)) {
                return entry.getValue();
            }
        }
        throw new TextClassificationException(discriminatorName + " not found in discriminators set.");
    }

    /**
     * Loads the discriminators stored under {@code discriminatorsKey} for the given context.
     * 
     * @return the stored properties as a map
     */
    public static Map<String, String> getDiscriminatorsForContext(StorageService store, String contextId,
            String discriminatorsKey) {
        return store.retrieveBinary(contextId, discriminatorsKey, new PropertiesAdapter()).getMap();
    }

    /**
     * Logs a performance overview of {@code table} and stores it in both a compact and a full
     * version, each as CSV and, if the table is narrow enough, as Excel.
     * <p>
     * Note: this switches {@code table} to non-compact mode as a side effect.
     * 
     * @param context
     *            task context used for logging and storage
     * @param contextLabel
     *            label used for log messages
     * @param table
     *            the results table
     * @param evalFileName
     *            base file name; the compact versions get a {@code _compact} infix
     * @param suffixExcel
     *            file suffix for the Excel files
     * @param suffixCsv
     *            file suffix for the CSV files
     */
    public static void writeExcelAndCSV(TaskContext context, String contextLabel, FlexTable<String> table,
            String evalFileName, String suffixExcel, String suffixCsv) {
        StorageService store = context.getStorageService();
        context.getLoggingService().message(contextLabel, ReportUtils.getPerformanceOverview(table));

        if (fitsIntoExcel(table)) {
            context.storeBinary(evalFileName + "_compact" + suffixExcel, table.getExcelWriter());
        }
        context.storeBinary(evalFileName + "_compact" + suffixCsv, table.getCsvWriter());

        table.setCompact(false);
        // re-check after leaving compact mode
        if (fitsIntoExcel(table)) {
            context.storeBinary(evalFileName + suffixExcel, table.getExcelWriter());
        }
        context.storeBinary(evalFileName + suffixCsv, table.getCsvWriter());

        // output the location of the batch evaluation folder
        // otherwise it might be hard for novice users to locate this
        File dummyFolder = store.locateKey(context.getId(), "dummy");
        // TODO can we also do this without creating and deleting the dummy folder?
        context.getLoggingService().message(contextLabel,
                "Storing detailed results in:\n" + dummyFolder.getParent() + "\n");
        // best effort cleanup; failure to delete is harmless
        dummyFolder.delete();
    }

    /** Excel cannot cope with more than {@value #MAX_EXCEL_COLUMNS} columns. */
    private static boolean fitsIntoExcel(FlexTable<String> table) {
        return table.getColumnIds().length <= MAX_EXCEL_COLUMNS;
    }

    /**
     * Returns a copy of {@code discriminatorsMap} without the entries whose key contains any of
     * the given exclude patterns.
     * 
     * @see #containsExcludePattern(String, List)
     */
    public static Map<String, String> clearDiscriminatorsByExcludePattern(Map<String, String> discriminatorsMap,
            List<String> discriminatorsToExclude) {
        Map<String, String> cleanedDiscriminatorsMap = new HashMap<String, String>();

        for (Entry<String, String> entry : discriminatorsMap.entrySet()) {
            if (!ReportUtils.containsExcludePattern(entry.getKey(), discriminatorsToExclude)) {
                cleanedDiscriminatorsMap.put(entry.getKey(), entry.getValue());
            }
        }
        return cleanedDiscriminatorsMap;
    }

}