org.openscience.cdk.applications.taverna.weka.regression.EvaluateRegressionResultsAsPDFActivity.java Source code

Java tutorial

Introduction

Here is the source code for org.openscience.cdk.applications.taverna.weka.regression.EvaluateRegressionResultsAsPDFActivity.java

Source

/*
 * Copyright (C) 2010 - 2011 by Andreas Truszkowski <ATruszkowski@gmx.de>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 * All we ask is that proper credit is given for our work, which includes
 * - but is not limited to - adding the above copyright notice to the beginning
 * of your source code files, and to any copyright notice that you may distribute
 * with programs based on this work.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
 */
package org.openscience.cdk.applications.taverna.weka.regression;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.UUID;

import net.sf.taverna.t2.reference.ExternalReferenceSPI;
import net.sf.taverna.t2.reference.impl.external.file.FileReference;
import net.sf.taverna.t2.reference.impl.external.object.InlineStringReference;

import org.jfree.chart.JFreeChart;
import org.jfree.data.category.DefaultCategoryDataset;
import org.jfree.data.xy.DefaultXYDataset;
import org.jfree.data.xy.XYSeries;
import org.openscience.cdk.applications.taverna.AbstractCDKActivity;
import org.openscience.cdk.applications.taverna.CDKTavernaConstants;
import org.openscience.cdk.applications.taverna.CDKTavernaException;
import org.openscience.cdk.applications.taverna.basicutilities.ChartTool;
import org.openscience.cdk.applications.taverna.basicutilities.CollectionUtilities;
import org.openscience.cdk.applications.taverna.basicutilities.ErrorLogger;
import org.openscience.cdk.applications.taverna.basicutilities.FileNameGenerator;
import org.openscience.cdk.applications.taverna.weka.utilities.WekaTools;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.filters.Filter;

/**
 * Activity which evaluates regression results and writes them as PDF files.
 * 
 * @author Andreas Truszkowski
 * 
 */
public class EvaluateRegressionResultsAsPDFActivity extends AbstractCDKActivity {

    /** Display name of this activity; returned by getActivityName() and embedded in getDescription(). */
    public static final String EVALUATE_REGRESSION_RESULTS_AS_PDF_ACTIVITY = "Evaluate Regression Results as PDF";
    /**
     * Value of the first scatter plot option which selects separate training and test dataset input ports
     * (compared against options[0] in addInputPorts() and work()).
     */
    public static final int TEST_TRAININGSET_PORT = 0;
    /**
     * Not referenced inside this class. Presumably the option value for the single dataset mode (any first
     * option other than TEST_TRAININGSET_PORT takes that path here) - TODO confirm against the configuration UI.
     */
    public static final int SINGLE_DATASET_PORT = 1;

    /**
     * Creates a new instance.
     */
    public EvaluateRegressionResultsAsPDFActivity() {
        this.INPUT_PORTS = new String[] { "Regression Model Files", "Regression Train Datasets",
                "Regression Test Datasets" };
        this.OUTPUT_PORTS = new String[] { "Files" };
    }

    /**
     * Registers the input ports depending on the first entry of the configured scatter plot options.
     * <p>
     * The model file port is always registered first. If the first option equals
     * {@link #TEST_TRAININGSET_PORT}, the test dataset port is registered and the second port keeps the name
     * "Regression Train Datasets"; otherwise the second port is renamed to "Weka Regression Datasets" and no
     * test dataset port is registered. The second port is always registered last.
     */
    @Override
    protected void addInputPorts() {
        String optionString = (String) this.getConfiguration()
                .getAdditionalProperty(CDKTavernaConstants.PROPERTY_SCATTER_PLOT_OPTIONS);
        String[] options = optionString.split(";");

        // The model files may arrive as file references or as inline strings
        List<Class<? extends ExternalReferenceSPI>> modelReferences = new ArrayList<Class<? extends ExternalReferenceSPI>>();
        modelReferences.add(FileReference.class);
        modelReferences.add(InlineStringReference.class);
        addInput(this.INPUT_PORTS[0], 1, false, modelReferences, null);

        boolean separateTestSets = String.valueOf(TEST_TRAININGSET_PORT).equals(options[0]);
        if (!separateTestSets) {
            this.INPUT_PORTS[1] = "Weka Regression Datasets";
        } else {
            this.INPUT_PORTS[1] = "Regression Train Datasets";
            addInput(this.INPUT_PORTS[2], 1, true, null, byte[].class);
        }
        // Registered after the optional test port; the name depends on the branch taken above
        addInput(this.INPUT_PORTS[1], 1, false, null, byte[].class);
    }

    /**
     * Registers the single output port (first entry of OUTPUT_PORTS), passing 1 as second argument to
     * addOutput.
     */
    @Override
    protected void addOutputPorts() {
        final String filesPortName = this.OUTPUT_PORTS[0];
        addOutput(filesPortName, 1);
    }

    /**
     * Evaluates the given regression models on their datasets and writes the results into PDF files.
     * <p>
     * The scatter plot options are read as "&lt;port mode&gt;;&lt;cross-validation flag&gt;". If the port mode
     * equals {@link #TEST_TRAININGSET_PORT}, separate test datasets are read in addition to the training
     * datasets. A missing cross-validation flag is treated as {@code false}.
     * <p>
     * The model files are consumed in batches of at most as many models as there are datasets; within a
     * batch the model at position j is evaluated on the dataset at position j. For every model a scatter
     * plot, a residue plot, a sorted value curve and a textual summary are produced. Every batch is written
     * to its own PDF together with an RMSE plot. Finally one more PDF is written which contains all batch
     * RMSE plots, one RMSE plot per dataset index and the mean RMSE plot. Results whose root relative squared
     * error exceeds 300 are left out of the RMSE plots but still contribute to the mean RMSE values.
     * <p>
     * The paths of all written PDF files are set on the output port.
     * 
     * @throws CDKTavernaException
     *             if no model files or no datasets are available or if a value for an instance UUID is missing
     * @throws Exception
     *             if reading the input, deserializing a model, filtering, evaluating or writing fails
     */
    @Override
    public void work() throws Exception {
        // Get input
        String[] options = ((String) this.getConfiguration()
                .getAdditionalProperty(CDKTavernaConstants.PROPERTY_SCATTER_PLOT_OPTIONS)).split(";");
        boolean useTestSets = options[0].equals("" + TEST_TRAININGSET_PORT);
        // A missing second option means "no cross-validation" instead of an ArrayIndexOutOfBoundsException
        boolean crossValidate = options.length > 1 && Boolean.parseBoolean(options[1]);
        List<File> modelFiles = this.getInputAsFileList(this.INPUT_PORTS[0]);
        List<Instances> trainDatasets = this.getInputAsList(this.INPUT_PORTS[1], Instances.class);
        List<Instances> testDatasets = null;
        if (useTestSets) {
            testDatasets = this.getInputAsList(this.INPUT_PORTS[2], Instances.class);
        }
        // Without datasets the batch loop below would never consume a model file and therefore never
        // terminate; without model files there is no target directory for the results.
        String inputError = null;
        if (modelFiles == null || modelFiles.isEmpty()) {
            inputError = "No regression model files available!";
        } else if (trainDatasets == null || trainDatasets.isEmpty()) {
            inputError = "No regression datasets available!";
        }
        if (inputError != null) {
            ErrorLogger.getInstance().writeError(inputError, this.getActivityName());
            throw new CDKTavernaException(this.getActivityName(), inputError);
        }
        // All result files are written next to the first model file
        String directory = modelFiles.get(0).getParent();

        // Do work
        ArrayList<String> resultFiles = new ArrayList<String>();
        HashMap<UUID, Double> orgClassMap = new HashMap<UUID, Double>();
        HashMap<UUID, Double> calcClassMap = new HashMap<UUID, Double>();
        WekaTools tools = new WekaTools();
        ChartTool chartTool = new ChartTool();
        List<Object> rmseCharts = new ArrayList<Object>();
        List<Double> trainMeanRMSE = new ArrayList<Double>();
        List<Double> testMeanRMSE = new ArrayList<Double>();
        List<Double> cvMeanRMSE = new ArrayList<Double>();
        DefaultCategoryDataset[] ratioRMSESet = new DefaultCategoryDataset[trainDatasets.size()];
        for (int i = 0; i < trainDatasets.size(); i++) {
            ratioRMSESet[i] = new DefaultCategoryDataset();
        }
        // Training set ratio in percent per dataset index. An entry stays null until a model has been
        // evaluated on that dataset, so the values survive a shorter last batch.
        Double[] trainingSetRatios = new Double[trainDatasets.size()];
        int fileIDX = 1;
        while (!modelFiles.isEmpty()) {
            List<Double> trainRMSE = new ArrayList<Double>();
            HashSet<Integer> trainSkippedRMSE = new HashSet<Integer>();
            List<Double> testRMSE = new ArrayList<Double>();
            HashSet<Integer> testSkippedRMSE = new HashSet<Integer>();
            List<Double> cvRMSE = new ArrayList<Double>();
            HashSet<Integer> cvSkippedRMSE = new HashSet<Integer>();
            List<Object> chartsObjects = new LinkedList<Object>();
            File modelFile = null;
            Classifier classifier = null;
            String name = "";
            for (int j = 0; j < trainDatasets.size(); j++) {
                LinkedList<Double> predictedValues = new LinkedList<Double>();
                LinkedList<Double> orgValues = new LinkedList<Double>();
                LinkedList<Double[]> yResidueValues = new LinkedList<Double[]>();
                LinkedList<String> yResidueNames = new LinkedList<String>();
                if (modelFiles.isEmpty()) {
                    // Last batch holds fewer models than there are datasets
                    break;
                }
                calcClassMap.clear();
                modelFile = modelFiles.remove(0);
                classifier = (Classifier) SerializationHelper.read(modelFile.getPath());
                Instances testset = null;
                if (testDatasets != null) {
                    testset = testDatasets.get(j);
                }
                name = classifier.getClass().getSimpleName();
                String sum = "Method: " + name + " " + tools.getOptionsFromFile(modelFile, name) + "\n\n";
                // Produce training set data
                Instances trainset = trainDatasets.get(j);
                Instances trainUUIDSet = Filter.useFilter(trainset, tools.getIDGetter(trainset));
                trainset = Filter.useFilter(trainset, tools.getIDRemover(trainset));
                double trainingSetRatio = 1.0;
                if (testset != null) {
                    trainingSetRatio = trainset.numInstances()
                            / (double) (trainset.numInstances() + testset.numInstances());
                }
                trainingSetRatios[j] = trainingSetRatio * 100;
                // Predict
                for (int k = 0; k < trainset.numInstances(); k++) {
                    UUID uuid = UUID.fromString(trainUUIDSet.instance(k).stringValue(0));
                    orgClassMap.put(uuid, trainset.instance(k).classValue());
                    calcClassMap.put(uuid, classifier.classifyInstance(trainset.instance(k)));
                }
                // Evaluate
                Evaluation trainEval = new Evaluation(trainset);
                trainEval.evaluateModel(classifier, trainset);
                // Chart data
                DefaultXYDataset xyDataSet = new DefaultXYDataset();
                String trainSeries = "Training Set (RMSE: "
                        + String.format("%.2f", trainEval.rootMeanSquaredError()) + ")";
                XYSeries series = new XYSeries(trainSeries);
                Double[] yTrainResidues = new Double[trainUUIDSet.numInstances()];
                Double[] orgTrain = new Double[trainUUIDSet.numInstances()];
                Double[] calc = new Double[trainUUIDSet.numInstances()];
                for (int k = 0; k < trainUUIDSet.numInstances(); k++) {
                    UUID uuid = UUID.fromString(trainUUIDSet.instance(k).stringValue(0));
                    orgTrain[k] = orgClassMap.get(uuid);
                    calc[k] = calcClassMap.get(uuid);
                    if (calc[k] != null && orgTrain[k] != null) {
                        series.add(orgTrain[k].doubleValue(), calc[k]);
                        yTrainResidues[k] = calc[k].doubleValue() - orgTrain[k].doubleValue();
                    } else {
                        ErrorLogger.getInstance().writeError("Can't find value for UUID: " + uuid.toString(),
                                this.getActivityName());
                        throw new CDKTavernaException(this.getActivityName(),
                                "Can't find value for UUID: " + uuid.toString());
                    }
                }
                // Copy the values before sortTwoArrays reorders the arrays for the residue plot
                orgValues.addAll(Arrays.asList(orgTrain));
                predictedValues.addAll(Arrays.asList(calc));
                CollectionUtilities.sortTwoArrays(orgTrain, yTrainResidues);
                yResidueValues.add(yTrainResidues);
                yResidueNames.add(trainSeries);
                xyDataSet.addSeries(trainSeries, series.toArray());

                // Summary
                sum += "Training Set:\n";
                if (trainEval.rootRelativeSquaredError() > 300) {
                    // Remember the dataset index so this value is left out of the RMSE plots
                    trainSkippedRMSE.add(j);
                }
                trainRMSE.add(trainEval.rootMeanSquaredError());
                sum += trainEval.toSummaryString(true);
                // Produce test set data
                if (testset != null) {
                    Instances testUUIDSet = Filter.useFilter(testset, tools.getIDGetter(testset));
                    testset = Filter.useFilter(testset, tools.getIDRemover(testset));
                    // Predict
                    for (int k = 0; k < testset.numInstances(); k++) {
                        UUID uuid = UUID.fromString(testUUIDSet.instance(k).stringValue(0));
                        orgClassMap.put(uuid, testset.instance(k).classValue());
                        calcClassMap.put(uuid, classifier.classifyInstance(testset.instance(k)));
                    }
                    // Evaluate
                    Evaluation testEval = new Evaluation(testset);
                    testEval.evaluateModel(classifier, testset);
                    // Chart data
                    String testSeries = "Test Set (RMSE: " + String.format("%.2f", testEval.rootMeanSquaredError())
                            + ")";
                    series = new XYSeries(testSeries);
                    Double[] yTestResidues = new Double[testUUIDSet.numInstances()];
                    Double[] orgTest = new Double[testUUIDSet.numInstances()];
                    calc = new Double[testUUIDSet.numInstances()];
                    for (int k = 0; k < testUUIDSet.numInstances(); k++) {
                        UUID uuid = UUID.fromString(testUUIDSet.instance(k).stringValue(0));
                        orgTest[k] = orgClassMap.get(uuid);
                        calc[k] = calcClassMap.get(uuid);
                        if (calc[k] != null && orgTest[k] != null) {
                            series.add(orgTest[k].doubleValue(), calc[k].doubleValue());
                            yTestResidues[k] = calc[k].doubleValue() - orgTest[k].doubleValue();
                        } else {
                            ErrorLogger.getInstance().writeError("Can't find value for UUID: " + uuid.toString(),
                                    this.getActivityName());
                            throw new CDKTavernaException(this.getActivityName(),
                                    "Can't find value for UUID: " + uuid.toString());
                        }
                    }
                    orgValues.addAll(Arrays.asList(orgTest));
                    predictedValues.addAll(Arrays.asList(calc));
                    CollectionUtilities.sortTwoArrays(orgTest, yTestResidues);
                    yResidueValues.add(yTestResidues);
                    yResidueNames.add(testSeries);
                    xyDataSet.addSeries(testSeries, series.toArray());
                    // Create summary
                    sum += "\nTest Set:\n";
                    if (testEval.rootRelativeSquaredError() > 300) {
                        testSkippedRMSE.add(j);
                    }
                    testRMSE.add(testEval.rootMeanSquaredError());
                    sum += testEval.toSummaryString(true);
                }
                // Produce cross validation data
                if (crossValidate) {
                    Evaluation cvEval = new Evaluation(trainset);
                    if (testset != null) {
                        Instances fullSet = tools.getFullSet(trainset, testset);
                        cvEval.crossValidateModel(classifier, fullSet, 10, new Random(1));
                    } else {
                        cvEval.crossValidateModel(classifier, trainset, 10, new Random(1));
                    }
                    sum += "\n10-fold cross-validation:\n";
                    if (cvEval.rootRelativeSquaredError() > 300) {
                        cvSkippedRMSE.add(j);
                    }
                    cvRMSE.add(cvEval.rootMeanSquaredError());
                    sum += cvEval.toSummaryString(true);
                }

                // Create scatter plot
                String header = classifier.getClass().getSimpleName() + "\n Training set ratio: "
                        + String.format("%.2f", trainingSetRatios[j]) + "%" + "\n Model name: "
                        + modelFile.getName();
                chartsObjects
                        .add(chartTool.createScatterPlot(xyDataSet, header, "Original values", "Predicted values"));
                // Create residue plot
                chartsObjects.add(chartTool.createResiduePlot(yResidueValues, header, "Index",
                        "(Predicted - Original)", yResidueNames));
                // Create curve
                Double[] tmpOrg = new Double[orgValues.size()];
                tmpOrg = orgValues.toArray(tmpOrg);
                Double[] tmpPred = new Double[predictedValues.size()];
                tmpPred = predictedValues.toArray(tmpPred);
                CollectionUtilities.sortTwoArrays(tmpOrg, tmpPred);
                DefaultXYDataset dataSet = new DefaultXYDataset();
                String orgName = "Original";
                XYSeries orgSeries = new XYSeries(orgName);
                String predName = "Predicted";
                XYSeries predSeries = new XYSeries(predName);
                for (int k = 0; k < tmpOrg.length; k++) {
                    orgSeries.add((k + 1), tmpOrg[k]);
                    predSeries.add((k + 1), tmpPred[k]);
                }
                dataSet.addSeries(orgName, orgSeries.toArray());
                dataSet.addSeries(predName, predSeries.toArray());
                chartsObjects.add(chartTool.createXYLineChart(header, "Index", "Value", dataSet, true, false));
                // Add summary
                chartsObjects.add(sum);
            }
            // Create RMSE Plot
            DefaultCategoryDataset dataSet = new DefaultCategoryDataset();
            double meanRMSE = 0;
            for (int i = 0; i < trainRMSE.size(); i++) {
                if (!trainSkippedRMSE.contains(i)) {
                    dataSet.addValue(trainRMSE.get(i), "Training Set",
                            "(" + String.format("%.2f", trainingSetRatios[i]) + "%/" + (i + 1) + ")");
                    ratioRMSESet[i].addValue(trainRMSE.get(i), "Training Set",
                            "(" + String.format("%.2f", trainingSetRatios[i]) + "%/" + (i + 1) + "/" + fileIDX
                                    + ")");
                }
                // Skipped values are hidden in the plots only, they still count for the mean
                meanRMSE += trainRMSE.get(i);
            }
            trainMeanRMSE.add(meanRMSE / trainRMSE.size());
            meanRMSE = 0;
            if (!testRMSE.isEmpty()) {
                for (int i = 0; i < testRMSE.size(); i++) {
                    if (!testSkippedRMSE.contains(i)) {
                        dataSet.addValue(testRMSE.get(i), "Test Set",
                                "(" + String.format("%.2f", trainingSetRatios[i]) + "%/" + (i + 1) + ")");
                        ratioRMSESet[i].addValue(testRMSE.get(i), "Test Set",
                                "(" + String.format("%.2f", trainingSetRatios[i]) + "%/" + (i + 1) + "/"
                                        + fileIDX + ")");
                    }
                    meanRMSE += testRMSE.get(i);
                }
                testMeanRMSE.add(meanRMSE / testRMSE.size());
            }
            meanRMSE = 0;
            if (!cvRMSE.isEmpty()) {
                for (int i = 0; i < cvRMSE.size(); i++) {
                    if (!cvSkippedRMSE.contains(i)) {
                        dataSet.addValue(cvRMSE.get(i), "10-fold Cross-validation",
                                "(" + String.format("%.2f", trainingSetRatios[i]) + "%/" + (i + 1) + ")");
                        ratioRMSESet[i].addValue(cvRMSE.get(i), "10-fold Cross-validation",
                                "(" + String.format("%.2f", trainingSetRatios[i]) + "%/" + (i + 1) + "/"
                                        + fileIDX + ")");
                    }
                    meanRMSE += cvRMSE.get(i);
                }
                cvMeanRMSE.add(meanRMSE / cvRMSE.size());
            }
            JFreeChart rmseChart = chartTool.createLineChart(
                    "RMSE Plot\n Classifier:" + name + " " + tools.getOptionsFromFile(modelFile, name),
                    "(Training set ratio/Set Index/File index)", "RMSE", dataSet, false, true);
            chartsObjects.add(rmseChart);
            rmseCharts.add(rmseChart);
            // Write PDF
            File file = FileNameGenerator.getNewFile(directory, ".pdf", "ScatterPlot");
            chartTool.writeChartAsPDF(file, chartsObjects);
            resultFiles.add(file.getPath());
            fileIDX++;
        }
        // Create set ratio RMSE plots
        for (int i = 0; i < ratioRMSESet.length; i++) {
            if (trainingSetRatios[i] == null) {
                // No model was evaluated on this dataset (fewer models than datasets): nothing to plot
                continue;
            }
            JFreeChart rmseChart = chartTool
                    .createLineChart(
                            "Set RMSE plot\n" + "(" + String.format("%.2f", trainingSetRatios[i]) + "%/"
                                    + (i + 1) + ")",
                            "(Training set ratio/Index)", "RMSE", ratioRMSESet[i], false, true);
            rmseCharts.add(rmseChart);
        }
        // Create mean RMSE plot
        DefaultCategoryDataset dataSet = new DefaultCategoryDataset();
        for (int i = 0; i < trainMeanRMSE.size(); i++) {
            dataSet.addValue(trainMeanRMSE.get(i), "Training Set", "" + (i + 1));
        }
        for (int i = 0; i < testMeanRMSE.size(); i++) {
            dataSet.addValue(testMeanRMSE.get(i), "Test Set", "" + (i + 1));
        }
        for (int i = 0; i < cvMeanRMSE.size(); i++) {
            dataSet.addValue(cvMeanRMSE.get(i), "10-fold Cross-validation", "" + (i + 1));
        }
        JFreeChart rmseChart = chartTool.createLineChart("RMSE Mean Plot", "Dataset number", "Mean RMSE", dataSet);
        rmseCharts.add(rmseChart);
        File file = FileNameGenerator.getNewFile(directory, ".pdf", "RMSE-Sum");
        chartTool.writeChartAsPDF(file, rmseCharts);
        resultFiles.add(file.getPath());
        // Set output
        this.setOutputAsStringList(resultFiles, this.OUTPUT_PORTS[0]);
    }

    /**
     * @return the display name of this activity
     */
    @Override
    public String getActivityName() {
        final String activityName = EVALUATE_REGRESSION_RESULTS_AS_PDF_ACTIVITY;
        return activityName;
    }

    /**
     * Provides the default additional properties of this activity.
     * 
     * @return a new map holding the default scatter plot options "0;false" (first entry equals
     *         {@link #TEST_TRAININGSET_PORT}, second entry is parsed as the cross-validation flag in work())
     */
    @Override
    public HashMap<String, Object> getAdditionalProperties() {
        final String defaultScatterPlotOptions = "0;false";
        HashMap<String, Object> defaults = new HashMap<String, Object>();
        defaults.put(CDKTavernaConstants.PROPERTY_SCATTER_PLOT_OPTIONS, defaultScatterPlotOptions);
        return defaults;
    }

    /**
     * @return the activity name prefixed with "Description: "
     */
    @Override
    public String getDescription() {
        StringBuilder description = new StringBuilder("Description: ");
        description.append(EVALUATE_REGRESSION_RESULTS_AS_PDF_ACTIVITY);
        return description.toString();
    }

    /**
     * @return the folder name constant for Weka regression activities
     */
    @Override
    public String getFolderName() {
        final String folderName = CDKTavernaConstants.WEKA_REGRESSION_FOLDER_NAME;
        return folderName;
    }
}