Java tutorial
/*
 * Copyright (C) 2006-2010 Institute for Computational Biomedicine,
 *                         Weill Medical College of Cornell University
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
package edu.cornell.med.icb.learning;

import cern.jet.random.engine.RandomEngine;
import edu.cornell.med.icb.R.RConnectionPool;
import edu.cornell.med.icb.learning.tools.svmlight.EvaluationMeasure;
import edu.cornell.med.icb.stat.AccuracyCalculator;
import edu.cornell.med.icb.stat.AreaUnderTheRocCurveCalculator;
import edu.cornell.med.icb.stat.MatthewsCorrelationCalculator;
import edu.cornell.med.icb.stat.PredictionStatisticCalculator;
import edu.cornell.med.icb.stat.RootMeanSquaredErrorCalculator;
import edu.cornell.med.icb.stat.SensitivityCalculator;
import edu.cornell.med.icb.stat.SpecificityCalculator;
import edu.cornell.med.icb.util.RandomAdapter;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleArraySet;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import it.unimi.dsi.fastutil.doubles.DoubleSet;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntArraySet;
import it.unimi.dsi.fastutil.ints.IntList;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.objects.ObjectArraySet;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import it.unimi.dsi.fastutil.objects.ObjectList;
import it.unimi.dsi.fastutil.objects.ObjectSet;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.rosuda.REngine.REXP;
import org.rosuda.REngine.Rserve.RConnection;

import java.io.File;
import java.util.Collections;

/**
 * Performs cross-validation for a configurable classifier.
 *
 * @author Fabien Campagne
 *         Date: Feb 28, 2006
 *         Time: 3:33:37 PM
 */
public class CrossValidation {
    private static final Log LOG = LogFactory.getLog(CrossValidation.class);

    private ClassificationModel model;
    private final Classifier classifier;
    private final ClassificationProblem problem;
    private final ObjectSet<CharSequence> evaluationMeasureNames = new ObjectArraySet<CharSequence>();
    private int repeatNumber = 1;
    private RandomAdapter randomAdapter;
    private boolean useRServer;
    private Class<? extends FeatureScaler> featureScalerClass;

    /**
     * Request evaluation of the given performance measure.
     *
     * @param measureName Name of a performance measure supported by ROCR. Valid names include:
     *                    acc, err, fpr, fall, tpr, rec, sens, fnr, miss, tnr, spec, ppv, prec,
     *                    npv, pcfall, pcmiss, rpp, rnp, phi, mat, mi, chisq, odds, lift, f, rch,
     *                    auc, prbe, cal, mxe, rmse, sar, ecost, cost. See the ROCR documentation
     *                    for definitions of these measures.
     */
    public void evaluateMeasure(final CharSequence measureName) {
        evaluationMeasureNames.add(measureName);
    }
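    // Illustrative sketch (not part of the original class): how a caller might request
    // measures before running cross-validation. "myClassifier" and "myProblem" are
    // hypothetical placeholders for concrete Classifier and ClassificationProblem
    // instances; MersenneTwister is Colt's RandomEngine implementation.
    //
    //   final CrossValidation cv = new CrossValidation(myClassifier, myProblem,
    //           new cern.jet.random.engine.MersenneTwister());
    //   cv.evaluateMeasure("auc"); // forwarded to ROCR when an R server is available
    //   cv.evaluateMeasure("acc");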
    /**
     * Set the number of cross-validation repeats. When greater than one, cross-validation is
     * repeated with different fold splits and the reported results are averaged over all
     * repeats.
     *
     * @param repeatNumber The number of repeats
     */
    public void setRepeatNumber(final int repeatNumber) {
        assert repeatNumber >= 1 : "Number of repeats must be at least one.";
        this.repeatNumber = repeatNumber;
    }

    public CrossValidation(final Classifier classifier, final ClassificationProblem problem,
                           final RandomEngine randomEngine) {
        super();
        this.classifier = classifier;
        this.problem = problem;
        this.randomAdapter = new RandomAdapter(randomEngine);
        this.useRServer(true);
    }

    public ClassificationModel trainModel() {
        return classifier.train(problem);
    }

    /**
     * Initialize the ClassificationModel with a previously trained model.
     *
     * @param model The ClassificationModel to use from now on.
     */
    public void setModel(final ClassificationModel model) {
        this.model = model;
    }

    /**
     * Train the classifier on the entire training set and report evaluation measures on the
     * training set.
     *
     * @return Evaluation measures obtained on the training set.
     */
    public EvaluationMeasure trainEvaluate() {
        final ClassificationModel trainingModel = classifier.train(problem);
        final ContingencyTable ctable = new ContingencyTable();

        for (int i = 0; i < problem.getSize(); i++) {
            final double decision = classifier.predict(trainingModel, problem, i);
            final double trueLabel = problem.getLabel(i);
            ctable.observeDecision(trueLabel, decision);
        }
        ctable.average();
        return convertToEvalMeasure(ctable);
    }

    /**
     * Transform continuous scores into binary labels: an input < 0 yields -1, an input >= 0
     * yields 1.
     *
     * @param decisions Negative values predict the first class, while positive values predict
     *                  the second class.
     * @return Binary decisions (-1 or 1), one per input score.
     */
    public static IntList convertBinaryLabels(final DoubleList decisions) {
        final IntList binaryDecisions = new IntArrayList();
        for (int i = 0; i < decisions.size(); i++) {
            // convert each continuous decision value to a binary label:
            final double decision = decisions.getDouble(i);
            final int binaryDecision = decision < 0 ? -1 : 1; // group -1 and group 1
            binaryDecisions.add(binaryDecision);
        }
        return binaryDecisions;
    }
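    // Illustrative sketch of the sign convention used throughout this class
    // (hypothetical values, not from the original source):
    //
    //   final DoubleList scores = new DoubleArrayList(new double[] { -2.5, -0.1, 0.0, 1.7 });
    //   final IntList labels = CrossValidation.convertBinaryLabels(scores);
    //   // labels now contains [-1, -1, 1, 1]: zero falls on the positive side.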
    /**
     * Report evaluation measures for predictions on a test set.
     *
     * @param decisions              Negative values predict the first class, while positive
     *                               values (zero included) predict the second class.
     * @param trueLabels             label=0 encodes the first class, label=1 the second class.
     * @param evaluationMeasureNames Names of the performance measures to evaluate.
     * @param useRServer             Whether ROCR measures should be evaluated on an R server.
     * @return Evaluation measures for the test set.
     */
    public static EvaluationMeasure testSetEvaluation(final double[] decisions,
                                                      final double[] trueLabels,
                                                      final ObjectSet<CharSequence> evaluationMeasureNames,
                                                      final boolean useRServer) {
        final ContingencyTable ctable = new ContingencyTable();
        assert decisions.length == trueLabels.length : "decision and label arrays must have the same length.";

        for (int i = 0; i < trueLabels.length; i++) {
            // convert labels to the conventions used by contingency table.
            if (trueLabels[i] == 0) {
                trueLabels[i] = -1;
            }
        }
        final double[] binaryDecisions = new double[decisions.length];
        for (int i = 0; i < decisions.length; i++) {
            // tally each binarized decision against the true label:
            final double decision = decisions[i];
            final double trueLabel = trueLabels[i];
            final int binaryDecision = decision < 0 ? -1 : 1;
            binaryDecisions[i] = binaryDecision;
            ctable.observeDecision(trueLabel, binaryDecision);
        }
        ctable.average();
        final EvaluationMeasure measure = convertToEvalMeasure(ctable);
        try {
            evaluate(decisions, trueLabels, evaluationMeasureNames, measure, "", useRServer);
            evaluate(binaryDecisions, trueLabels, evaluationMeasureNames, measure, "binary-", useRServer);
        } catch (Exception e) {
            LOG.warn("cannot evaluate with ROCR", e);
        }
        return measure;
    }

    /**
     * Report evaluation measures for predictions on a test set, given one decision array and
     * one label array per split.
     *
     * @param decisionList           Negative values predict the first class, while positive
     *                               values predict the second class.
     * @param trueLabelList          label=0 encodes the first class, label=1 the second class.
     * @param evaluationMeasureNames Names of the performance measures to evaluate.
     * @param useRServer             Whether ROCR measures should be evaluated on an R server.
     * @return Evaluation measures for the test set.
     */
    public static EvaluationMeasure testSetEvaluation(final ObjectList<double[]> decisionList,
                                                      final ObjectList<double[]> trueLabelList,
                                                      final ObjectSet<CharSequence> evaluationMeasureNames,
                                                      final boolean useRServer) {
        final ContingencyTable ctable = new ContingencyTable();
        for (int j = 0; j < decisionList.size(); j++) {
            final double[] decisions = decisionList.get(j);
            final double[] trueLabels = trueLabelList.get(j);
            assert decisions.length == trueLabels.length : "decision and label arrays must have the same length.";

            for (int i = 0; i < trueLabels.length; i++) {
                // convert labels to the conventions used by contingency table.
                if (trueLabels[i] == 0) {
                    trueLabels[i] = -1;
                }
            }
            for (int i = 0; i < decisions.length; i++) {
                // tally each binarized decision against the true label:
                final double decision = decisions[i];
                final double trueLabel = trueLabels[i];
                final int binaryDecision = decision < 0 ? -1 : 1;
                ctable.observeDecision(trueLabel, binaryDecision);
            }
        }
        ctable.average();
        final EvaluationMeasure measure = convertToEvalMeasure(ctable);
        try {
            evaluate(decisionList, trueLabelList, evaluationMeasureNames, measure, "", useRServer);
        } catch (Exception e) {
            LOG.warn("cannot evaluate with ROCR", e);
        }
        return measure;
    }

    /**
     * Report leave-one-out evaluation measures for the training set.
     *
     * @return Leave-one-out evaluation measures.
     */
    public EvaluationMeasure leaveOneOutEvaluation() {
        final ContingencyTable ctable = new ContingencyTable();
        final double[] decisionValues = new double[problem.getSize()];
        final double[] labels = new double[problem.getSize()];
        final FeatureScaler scaler = resetScaler();
        final double[] probs = { 0.0d, 0.0d };

        for (int testInstanceIndex = 0; testInstanceIndex < problem.getSize(); testInstanceIndex++) {
            // for each training example, leave it out:
            final ClassificationProblem currentTrainingSet = problem.exclude(testInstanceIndex);
            final ClassificationProblem scaledTrainingSet = currentTrainingSet.scaleTraining(scaler);
            final ClassificationModel looModel = classifier.train(scaledTrainingSet);
            final ClassificationProblem oneScaledTestInstanceProblem =
                    problem.scaleTestSet(scaler, testInstanceIndex);
            final double decision = classifier.predict(looModel, oneScaledTestInstanceProblem, 0, probs);
            final double trueLabel = problem.getLabel(testInstanceIndex);
            decisionValues[testInstanceIndex] = decision;
            labels[testInstanceIndex] = trueLabel;
            final int binaryDecision = decision < 0 ? -1 : 1;
            ctable.observeDecision(trueLabel, binaryDecision);
        }
        ctable.average();
        final EvaluationMeasure measure = convertToEvalMeasure(ctable);
        evaluate(decisionValues, labels, evaluationMeasureNames, measure, "", useRServer);
        return measure;
    }
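    // Illustrative sketch (hypothetical values): evaluating held-out predictions directly,
    // without an R server, so only the natively computed measures are produced.
    //
    //   final ObjectSet<CharSequence> names = new ObjectArraySet<CharSequence>();
    //   names.add("auc");
    //   final double[] decisions = { -1.2, 0.4, 2.0, -0.3 };
    //   final double[] labels = { 0, 1, 1, 0 };
    //   final EvaluationMeasure m =
    //           CrossValidation.testSetEvaluation(decisions, labels, names, false);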
    /**
     * Indicate whether or not the RServe process should be used. Setting this flag to false
     * removes the dependency on the R server.
     *
     * @param useRServer If true, use an RServer to evaluate the area under the ROC curve.
     *                   If false, skip the calculation.
     */
    public void useRServer(final boolean useRServer) {
        this.useRServer = useRServer;
        if (!this.useRServer) {
            evaluationMeasureNames.remove("auc");
        }
    }

    /**
     * Report the area under the Receiver Operating Characteristic (ROC) curve.
     * See <a href="http://pages.cs.wisc.edu/~richm/programs/AUC/">http://pages.cs.wisc.edu/~richm/programs/AUC/</a>
     *
     * @param decisionValues Larger values indicate better confidence that the instance belongs
     *                       to class 1.
     * @param labels         Values of -1 or 0 indicate that the instance belongs to class 0,
     *                       values of 1 indicate that the instance belongs to class 1.
     * @return ROC AUC
     */
    public static double areaUnderRocCurveLOO(final double[] decisionValues, final double[] labels) {
        if (ArrayUtils.isEmpty(decisionValues) || ArrayUtils.isEmpty(labels)) {
            throw new IllegalArgumentException("There must be at least 1 label and prediction."
                    + " Predictions are empty: " + ArrayUtils.isEmpty(decisionValues)
                    + " Labels are empty: " + ArrayUtils.isEmpty(labels));
        }
        if (decisionValues.length != labels.length) {
            throw new IllegalArgumentException("number of predictions (" + decisionValues.length
                    + ") must match number of labels (" + labels.length + ").");
        }
        for (int i = 0; i < labels.length; i++) {
            // map negative labels to 0, the convention expected by ROCR:
            if (labels[i] < 0) {
                labels[i] = 0;
            }
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("decisions: " + ArrayUtils.toString(decisionValues));
            LOG.debug("labels: " + ArrayUtils.toString(labels));
        }
        final Double shortCircuitValue = areaUnderRocCurvShortCircuit(decisionValues, labels);
        if (shortCircuitValue != null) {
            return shortCircuitValue;
        }
        final RConnectionPool connectionPool = RConnectionPool.getInstance();
        RConnection connection = null;
        try {
            // CALL R ROC
            connection = connectionPool.borrowConnection();
            connection.assign("predictions", decisionValues);
            connection.assign("labels", labels);
            // library(ROCR)
            // predictions <- c(1,1,0,1,1,1,1)
            // labels <- c(1,1,1,1,1,0,1)
            // flabels <- factor(labels,c(0,1))
            // pred.svm <- prediction(predictions, flabels)
            // perf.svm <- performance(pred.svm, 'auc')
            // attr(perf.svm,"y.values")[[1]]
            final StringBuilder rCommand = new StringBuilder();
            rCommand.append("library(ROCR)\n");
            rCommand.append("flabels <- factor(labels,c(0,1))\n");
            // note: flabels is assigned above, but the numeric labels vector is what is
            // passed to prediction() below.
            rCommand.append("pred.svm <- prediction(predictions, labels)\n");
            rCommand.append("perf.svm <- performance(pred.svm, 'auc')\n");
            rCommand.append("attr(perf.svm,\"y.values\")[[1]]");

            final REXP expression = connection.eval(rCommand.toString());
            final double valueROC_AUC = expression.asDouble();
            if (LOG.isDebugEnabled()) {
                LOG.debug("result from R: " + valueROC_AUC);
            }
            return valueROC_AUC;
        } catch (Exception e) {
            // connection error or otherwise
            LOG.warn("Cannot calculate area under the ROC curve. Make sure Rserve (R server) "
                    + "is configured and running.", e);
            return Double.NaN;
        } finally {
            if (connection != null) {
                connectionPool.returnConnection(connection);
            }
        }
    }
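    // Illustrative sketch (hypothetical values): a perfectly ranked set of predictions
    // should yield an AUC of 1.0 when Rserve and the ROCR package are available, and
    // Double.NaN when they are not (see the catch block above).
    //
    //   final double[] decisions = { -2.0, -0.5, 0.5, 2.0 };
    //   final double[] labels = { 0, 0, 1, 1 };
    //   final double auc = CrossValidation.areaUnderRocCurveLOO(decisions, labels); // 1.0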
    /**
     * Evaluate a variety of performance measures with
     * <a href="http://rocr.bioinf.mpi-sb.mpg.de/ROCR.pdf">ROCR</a>.
     *
     * @param decisionValues    Larger values indicate better confidence that the instance
     *                          belongs to class 1.
     * @param labels            Values of -1 or 0 indicate that the instance belongs to class 0,
     *                          values of 1 indicate that the instance belongs to class 1.
     * @param measureNames      Names of performance measures to evaluate.
     * @param measure           Where performance values will be stored.
     * @param measureNameSuffix Suffix appended to the name under which each value is stored.
     * @param useRServer        Whether measures not handled natively should be sent to ROCR.
     * @see #evaluateMeasure
     */
    public static void evaluate(final double[] decisionValues, final double[] labels,
                                ObjectSet<CharSequence> measureNames, final EvaluationMeasure measure,
                                final CharSequence measureNameSuffix, final boolean useRServer) {
        measureNames = evaluatePerformanceMeasure(decisionValues, labels, measureNames, measure,
                measureNameSuffix, new MatthewsCorrelationCalculator());
        measureNames = evaluatePerformanceMeasure(decisionValues, labels, measureNames, measure,
                measureNameSuffix, new AreaUnderTheRocCurveCalculator());
        measureNames = evaluatePerformanceMeasure(decisionValues, labels, measureNames, measure,
                measureNameSuffix, new RootMeanSquaredErrorCalculator());
        measureNames = evaluatePerformanceMeasure(decisionValues, labels, measureNames, measure,
                measureNameSuffix, new AccuracyCalculator());
        measureNames = evaluatePerformanceMeasure(decisionValues, labels, measureNames, measure,
                measureNameSuffix, new SensitivityCalculator());
        measureNames = evaluatePerformanceMeasure(decisionValues, labels, measureNames, measure,
                measureNameSuffix, new SpecificityCalculator());
        if (measureNames.size() > 0) {
            // more measures to evaluate, send to ROCR
            if (useRServer) {
                evaluateWithROCR(decisionValues, labels, measureNames, measure, measureNameSuffix);
            }
        }
    }

    public static void evaluate(final ObjectList<double[]> decisionList,
                                final ObjectList<double[]> trueLabelList,
                                ObjectSet<CharSequence> evaluationMeasureNames,
                                final EvaluationMeasure measure, final CharSequence measureNamePrefix,
                                final boolean useRServer) {
        evaluationMeasureNames = evaluatePerformanceMeasure(decisionList, trueLabelList,
                evaluationMeasureNames, measure, new MatthewsCorrelationCalculator());
        evaluationMeasureNames = evaluatePerformanceMeasure(decisionList, trueLabelList,
                evaluationMeasureNames, measure, new AreaUnderTheRocCurveCalculator());
        evaluationMeasureNames = evaluatePerformanceMeasure(decisionList, trueLabelList,
                evaluationMeasureNames, measure, new AccuracyCalculator());
        evaluationMeasureNames = evaluatePerformanceMeasure(decisionList, trueLabelList,
                evaluationMeasureNames, measure, new SensitivityCalculator());
        evaluationMeasureNames = evaluatePerformanceMeasure(decisionList, trueLabelList,
                evaluationMeasureNames, measure, new SpecificityCalculator());
        if (evaluationMeasureNames.size() > 0) {
            // more measures to evaluate, send to ROCR
            if (useRServer) {
                for (int i = 0; i < decisionList.size(); i++) {
                    evaluateWithROCR(decisionList.get(i), trueLabelList.get(i),
                            evaluationMeasureNames, measure, measureNamePrefix);
                }
            }
        }
    }

    private static ObjectSet<CharSequence> evaluateMCC(final ObjectList<double[]> decisionValueList,
                                                       final ObjectList<double[]> trueLabelList,
                                                       final ObjectSet<CharSequence> evaluationMeasureNames,
                                                       final EvaluationMeasure measure) {
        if (evaluationMeasureNames.contains("MCC")) {
            final MatthewsCorrelationCalculator c = new MatthewsCorrelationCalculator();
            // find optimal threshold across all splits:
            c.thresholdIndependentStatistic(decisionValueList, trueLabelList);
            final double optimalThreshold = c.optimalThreshold;
            for (int i = 0; i < decisionValueList.size(); i++) {
                final double mcc = c.evaluateMCC(optimalThreshold, decisionValueList.get(i),
                        trueLabelList.get(i));
                measure.addValue("MCC", mcc);
            }
            final ObjectSet<CharSequence> measureNamesFiltered = new ObjectArraySet<CharSequence>();
            measureNamesFiltered.addAll(evaluationMeasureNames);
            measureNamesFiltered.remove("MCC");
            return measureNamesFiltered;
        } else {
            return evaluationMeasureNames;
        }
    }
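    // Illustrative note (behavior inferred from the dispatch above): each call to
    // evaluatePerformanceMeasure consumes its measure name from the set, so only names
    // that no native calculator claims remain to be forwarded to ROCR. For example,
    // assuming MatthewsCorrelationCalculator reports its name as "MCC":
    //
    //   final ObjectSet<CharSequence> names = new ObjectArraySet<CharSequence>();
    //   names.add("MCC");  // handled natively, removed from the set
    //   names.add("prec"); // left over, sent to evaluateWithROCR when useRServer is true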
    private static ObjectSet<CharSequence> evaluatePerformanceMeasure(
            final ObjectList<double[]> decisionValueList, final ObjectList<double[]> trueLabelList,
            final ObjectSet<CharSequence> evaluationMeasureNames, final EvaluationMeasure measure,
            final PredictionStatisticCalculator calculator) {
        final String measureName = calculator.getMeasureName();
        if (evaluationMeasureNames.contains(measureName)) {
            // find optimal threshold across all splits:
            calculator.thresholdIndependentStatistic(decisionValueList, trueLabelList);
            final double optimalThreshold = calculator.optimalThreshold;
            for (int i = 0; i < decisionValueList.size(); i++) {
                final double statistic = calculator.evaluateStatisticAtThreshold(optimalThreshold,
                        decisionValueList.get(i), trueLabelList.get(i));
                measure.addValue(measureName, statistic);
                measure.addValue(measureName + "-zero", calculator.evaluateStatisticAtThreshold(0,
                        decisionValueList.get(i), trueLabelList.get(i)));
            }
            final ObjectSet<CharSequence> measureNamesFiltered = new ObjectArraySet<CharSequence>();
            measureNamesFiltered.addAll(evaluationMeasureNames);
            measureNamesFiltered.remove(measureName);
            return measureNamesFiltered;
        } else {
            return evaluationMeasureNames;
        }
    }

    private static ObjectSet<CharSequence> evaluatePerformanceMeasure(
            final double[] decisionValueList, final double[] trueLabelList,
            final ObjectSet<CharSequence> evaluationMeasureNames, final EvaluationMeasure measure,
            final CharSequence measureNameSuffix, final PredictionStatisticCalculator calculator) {
        final String measureName = calculator.getMeasureName();
        if (evaluationMeasureNames.contains(measureName)) {
            // find optimal threshold across all splits:
            final double statistic = calculator.thresholdIndependentStatistic(decisionValueList,
                    trueLabelList);
            measure.addValue(measureName + measureNameSuffix, statistic);
            measure.addValue(measureName + "-zero",
                    calculator.evaluateStatisticAtThreshold(0, decisionValueList, trueLabelList));
            final ObjectSet<CharSequence> measureNamesFiltered = new ObjectArraySet<CharSequence>();
            measureNamesFiltered.addAll(evaluationMeasureNames);
            measureNamesFiltered.remove(measureName);
            return measureNamesFiltered;
        } else {
            return evaluationMeasureNames;
        }
    }
    /**
     * Evaluate a variety of performance measures with
     * <a href="http://rocr.bioinf.mpi-sb.mpg.de/ROCR.pdf">ROCR</a>.
     *
     * @param decisionValues    Larger values indicate better confidence that the instance
     *                          belongs to class 1.
     * @param labels            Values of -1 or 0 indicate that the instance belongs to class 0,
     *                          values of 1 indicate that the instance belongs to class 1.
     * @param measureNames      Names of performance measures to evaluate.
     * @param measure           Where performance values will be stored.
     * @param measureNamePrefix Prefix prepended to the name under which each value is stored.
     * @see #evaluateMeasure
     */
    public static void evaluateWithROCR(final double[] decisionValues, final double[] labels,
                                        final ObjectSet<CharSequence> measureNames,
                                        final EvaluationMeasure measure,
                                        final CharSequence measureNamePrefix) {
        assert decisionValues.length == labels.length : "number of predictions must match number of labels.";
        for (int i = 0; i < labels.length; i++) {
            // map negative labels to 0, the convention expected by ROCR:
            if (labels[i] < 0) {
                labels[i] = 0;
            }
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("decisions: " + ArrayUtils.toString(decisionValues));
            LOG.debug("labels: " + ArrayUtils.toString(labels));
        }
        final RConnectionPool connectionPool = RConnectionPool.getInstance();
        RConnection connection = null;
        CharSequence performanceValueName = null;
        try {
            // CALL R ROC
            connection = connectionPool.borrowConnection();
            connection.assign("predictions", decisionValues);
            connection.assign("labels", labels);
            // library(ROCR)
            // predictions <- c(1,1,0,1,1,1,1)
            // labels <- c(1,1,1,1,1,0,1)
            // flabels <- factor(labels,c(0,1))
            // pred.svm <- prediction(predictions, flabels)
            // perf.svm <- performance(pred.svm, 'auc')
            // attr(perf.svm,"y.values")[[1]]
            final StringBuilder rCommand = new StringBuilder();
            rCommand.append("library(ROCR)\n");
            rCommand.append("flabels <- labels\n");
            rCommand.append("pred.svm <- prediction(predictions, labels)\n");
            connection.eval(rCommand.toString());

            for (ObjectIterator<CharSequence> charSequenceObjectIterator = measureNames.iterator();
                 charSequenceObjectIterator.hasNext();) {
                final StringBuilder rCommandMeasure = new StringBuilder();
                performanceValueName = charSequenceObjectIterator.next();
                if (performanceValueName == null) {
                    continue;
                }
                final CharSequence storedPerformanceMeasureName =
                        measureNamePrefix.toString() + performanceValueName.toString();
                rCommandMeasure.append("perf.svm <- performance(pred.svm, '");
                rCommandMeasure.append(performanceValueName);
                rCommandMeasure.append("')\n");
                rCommandMeasure.append("attr(perf.svm,\"y.values\")[[1]]");

                final REXP expressionValue = connection.eval(rCommandMeasure.toString());
                final double[] values = expressionValue.asDoubles();
                if (values.length == 1) {
                    // this performance measure is threshold independent.
                    LOG.debug("result from R (" + performanceValueName + ") : " + values[0]);
                    measure.addValue(storedPerformanceMeasureName, values[0]);
                } else {
                    // we have one performance measure value per decision threshold.
                    final StringBuilder rCommandThresholds = new StringBuilder();
                    rCommandThresholds.append("attr(perf.svm,\"x.values\")[[1]]");
                    final REXP expressionThresholds = connection.eval(rCommandThresholds.toString());
                    final double[] thresholds = expressionThresholds.asDoubles();
                    // find the last index of x.values whose threshold is greater than or
                    // equal to zero (for the decision value):
                    int thresholdGEZero = -1;
                    for (int index = thresholds.length - 1; index >= 0; index--) {
                        if (thresholds[index] >= 0) {
                            thresholdGEZero = index;
                            break;
                        }
                    }
                    if (thresholdGEZero != -1) {
                        if (LOG.isDebugEnabled()) {
                            LOG.debug("result from R (" + performanceValueName + ") : "
                                    + values[thresholdGEZero]);
                        }
                        measure.addValue(storedPerformanceMeasureName, values[thresholdGEZero]);
                    }
                }
            }
        } catch (Exception e) {
            // connection error or otherwise
            LOG.warn("Cannot evaluate performance measure " + performanceValueName
                    + ". Make sure Rserve (R server) is configured and running.", e);
        } finally {
            if (connection != null) {
                connectionPool.returnConnection(connection);
            }
        }
    }
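    // Illustrative note on the method above (hypothetical values): for a threshold-dependent
    // measure such as 'acc', ROCR returns one value per cutoff; the loop keeps the value at
    // the smallest nonnegative cutoff, approximating the measure at decision threshold 0.
    // With prefix "binary-" and measure name "acc", the result is stored as "binary-acc".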
    /**
     * Checks decisionValues and labels and determines if we can short-circuit the value based
     * on pre-defined rules. Returns null if the decision cannot be short-circuited, otherwise
     * the short-circuit value.
     *
     * @param decisionValues the decision values
     * @param labels         the label values
     * @return null or a Double value
     */
    public static Double areaUnderRocCurvShortCircuit(final double[] decisionValues,
                                                      final double[] labels) {
        Double shortCircuitValue = null;
        String debugStr = null;
        if (ArrayUtils.isEmpty(decisionValues) || ArrayUtils.isEmpty(labels)) {
            debugStr = "++SHORTCIRCUIT: No labels or decision values. This will fail ROC.";
        } else {
            final VectorDetails decisionValueDetails = new VectorDetails(decisionValues);
            final VectorDetails labelDetails = new VectorDetails(labels);
            if (labelDetails.isAllZeros()) {
                if (decisionValueDetails.isAllPositive()) {
                    shortCircuitValue = 0.0;
                    debugStr = "++SHORTCIRCUIT: Label all zeros, decision all positive. Returning 0";
                } else if (decisionValueDetails.isAllNegative()) {
                    shortCircuitValue = 1.0;
                    debugStr = "++SHORTCIRCUIT: Label all zeros, decision all negative. Returning 1";
                } else {
                    debugStr = "++SHORTCIRCUIT: Label all zeros, decisions vary. This will fail ROC.";
                }
            } else if (labelDetails.isAllOnes()) {
                if (decisionValueDetails.isAllPositive()) {
                    shortCircuitValue = 1.0;
                    debugStr = "++SHORTCIRCUIT: Label all ones, decision all positive. Returning 1";
                } else if (decisionValueDetails.isAllNegative()) {
                    shortCircuitValue = 0.0;
                    debugStr = "++SHORTCIRCUIT: Label all ones, decision all negative. Returning 0";
                } else {
                    debugStr = "++SHORTCIRCUIT: Label all ones, decisions vary. This will fail ROC.";
                }
            }
        }
        if (LOG.isDebugEnabled() && debugStr != null) {
            LOG.debug(debugStr);
        }
        return shortCircuitValue;
    }
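    // Summary of the short-circuit rules above (degenerate single-class inputs that ROCR
    // cannot score):
    //
    //   labels all 0, decisions all positive  ->  0.0
    //   labels all 0, decisions all negative  ->  1.0
    //   labels all 1, decisions all positive  ->  1.0
    //   labels all 1, decisions all negative  ->  0.0
    //   anything else                         ->  null (delegate to ROCR)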
    /**
     * Plot the Receiver Operating Characteristic (ROC) curve for leave-one-out predictions
     * and write it to a PDF file.
     *
     * @param decisionValues   Decision values output by classifier. Larger values indicate more
     *                         confidence in prediction of a positive label.
     * @param labels           Correct label for item, can be 0 (negative class) or +1 (positive class).
     * @param rocCurvefilename Name of the file to plot the pdf image to
     */
    public static void plotRocCurveLOO(final double[] decisionValues, final double[] labels,
                                       final String rocCurvefilename) {
        assert decisionValues.length == labels.length : "number of predictions must match number of labels.";
        for (int i = 0; i < labels.length; i++) {
            // clamp negative decision values to 0 and map negative labels to 0:
            if (decisionValues[i] < 0) {
                decisionValues[i] = 0;
            }
            if (labels[i] < 0) {
                labels[i] = 0;
            }
        }
        // The R server only understands unix-style paths. Convert windows to unix if needed:
        final String plotFilename = FilenameUtils.separatorsToUnix(rocCurvefilename);
        final File plotFile = new File(plotFilename);
        final RConnectionPool connectionPool = RConnectionPool.getInstance();
        RConnection connection = null;
        // CALL R ROC
        try {
            if (plotFile.exists()) {
                plotFile.delete();
            }
            connection = connectionPool.borrowConnection();
            connection.assign("predictions", decisionValues);
            connection.assign("labels", labels);
            final String cmd = " library(ROCR) \n"
                    + "pred.svm <- prediction(predictions, labels)\n"
                    + "pdf(\"" + plotFilename + "\", height=5, width=5)\n"
                    + "perf <- performance(pred.svm, measure = \"tpr\", x.measure = \"fpr\")\n"
                    + "plot(perf)\n"
                    + "dev.off()";
            final REXP expression = connection.eval(cmd);
            // attr(perf.rocOutAUC,"y.values")[[1]]
            final double valueROC_AUC = expression.asDouble();
            // System.out.println("result from R: " + valueROC_AUC);
        } catch (Exception e) {
            // connection error or otherwise
            LOG.warn("Cannot plot ROC curve to " + plotFilename + ". Make sure Rserve (R server) "
                    + "is configured and running and the owner of the Rserve process has permission "
                    + "to write to the directory \""
                    + FilenameUtils.getFullPath(plotFile.getAbsolutePath()) + "\"", e);
        } finally {
            if (connection != null) {
                connectionPool.returnConnection(connection);
            }
        }
    }

    /*  Legacy leave-one-out implementation, kept for reference:

        ContingencyTable ctable = new ContingencyTable();
        for (int i = 0; i < numberOfTrainingExamples; i++) {
            // for each training example, leave it out:
            final svm_problem looProblem = splitProblem(problem, i);
            final svm_model looModel = svm.svm_train(looProblem, parameters);
            final double decision = svm.svm_predict(looModel, problem.x[i]);
            final double trueLabel = problem.y[i];
            decisionValues[i] = decision;
            labels[i] = trueLabel;
            ctable.observeDecision(trueLabel, decision);
        }
        ctable.average();
        EvaluationMeasure measure = convertToEvalMeasure(ctable);
        measure.setRocAuc(areaUnderRocCurveLOO(decisionValues, labels));
        return measure;
    */

    /**
     * Run cross-validation with k folds.
     *
     * @param k            Number of folds for cross validation. Typical values are 5 or 10.
     * @param randomEngine Random engine to use when splitting the training set into folds.
     * @return Evaluation measures.
     */
    public EvaluationMeasure crossValidation(final int k, final RandomEngine randomEngine) {
        this.randomAdapter = new RandomAdapter(randomEngine);
        return this.crossValidation(k);
    }
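    // Illustrative sketch (hypothetical instance "cv"): a typical k-fold run with a fixed
    // seed for reproducible fold splits.
    //
    //   final RandomEngine engine = new cern.jet.random.engine.MersenneTwister(42);
    //   final EvaluationMeasure measure = cv.crossValidation(10, engine);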
    /**
     * Run cross-validation with k folds.
     *
     * @param k Number of folds for cross validation. Typical values are 5 or 10.
     * @return Evaluation measures.
     */
    public EvaluationMeasure crossValidation(final int k) {
        final ContingencyTable ctable = new ContingencyTable();
        final DoubleList aucValues = new DoubleArrayList();
        final DoubleList f1Values = new DoubleArrayList();
        final EvaluationMeasure measure = new EvaluationMeasure();

        for (int r = 0; r < repeatNumber; r++) {
            assert k <= problem.getSize()
                    : "Number of folds must be less than or equal to the number of training examples.";
            final int[] foldIndices = assignFolds(k);

            for (int f = 0; f < k; ++f) {
                // use each fold as test set while the others are the training set:
                final IntSet trainingSet = new IntArraySet();
                final IntSet testSet = new IntArraySet();
                for (int i = 0; i < problem.getSize(); i++) {
                    // assign each training example to a fold:
                    if (f == foldIndices[i]) {
                        testSet.add(i);
                    } else {
                        trainingSet.add(i);
                    }
                }
                assert testSet.size() + trainingSet.size() == problem.getSize()
                        : "test set and training set size must add to whole problem size.";
                final IntSet intersection = new IntOpenHashSet();
                intersection.addAll(trainingSet);
                intersection.retainAll(testSet);
                assert intersection.size() == 0 : "test set and training set must never overlap";

                final ClassificationProblem currentTrainingSet = problem.filter(trainingSet);
                assert currentTrainingSet.getSize() == trainingSet.size()
                        : "Problem size must match size of training set";
                final FeatureScaler scaler = resetScaler(); // reset the scaler for each test set.
                final ClassificationProblem scaledTrainingSet = currentTrainingSet.scaleTraining(scaler);
                final ClassificationModel looModel = classifier.train(scaledTrainingSet);
                final ContingencyTable ctableMicro = new ContingencyTable();

                final double[] decisionValues = new double[testSet.size()];
                final double[] labels = new double[testSet.size()];
                int index = 0;
                final double[] probs = { 0.0d, 0.0d };
                for (final int testInstanceIndex : testSet) {
                    // for each test example:
                    // ClassificationProblem oneScaledTestInstanceProblem = problem.filter(testInstanceIndex);
                    final ClassificationProblem oneScaledTestInstanceProblem =
                            problem.scaleTestSet(scaler, testInstanceIndex);
                    assert oneScaledTestInstanceProblem.getSize() == 1
                            : "filtered test problem must have one instance left (size was "
                            + oneScaledTestInstanceProblem.getSize() + ").";
                    final double decision = classifier.predict(looModel, oneScaledTestInstanceProblem,
                            0, probs);
                    final double trueLabel = problem.getLabel(testInstanceIndex);
                    final double maxProb = Math.max(probs[0], probs[1]);
                    decisionValues[index] = decision * maxProb;
                    labels[index] = trueLabel;
                    index++;
                    final int binaryDecision = decision < 0 ? -1 : 1;
                    ctable.observeDecision(trueLabel, binaryDecision);
                    ctableMicro.observeDecision(trueLabel, binaryDecision);
                }
                ctableMicro.average();
                f1Values.add(ctableMicro.getF1Measure());
                final double aucForOneFold = Double.NaN;
                evaluate(decisionValues, labels, evaluationMeasureNames, measure, "", useRServer);
                aucValues.add(aucForOneFold);
            }
        }
        ctable.average();
        measure.setContingencyTable(ctable);
        // The below line was previously commented out? KCD 2008-09-29
        measure.setRocAucValues(aucValues);
        measure.setF1Values(f1Values);
        return measure;
    }

    private FeatureScaler resetScaler() {
        FeatureScaler scaler = null;
        try {
            scaler = featureScalerClass.newInstance();
        } catch (InstantiationException e) {
            LOG.error("Cannot instantiate feature scaler", e);
        } catch (IllegalAccessException e) {
            LOG.error("Cannot create feature scaler", e);
        }
        return scaler;
    }
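    // Illustrative note on assignFolds below (hypothetical values): with 7 instances and
    // k = 3, the pre-shuffle assignment i % k is [0, 1, 2, 0, 1, 2, 0]; shuffling then
    // permutes which instances carry each fold index while keeping fold sizes balanced.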
    /**
     * Calculates semi-random fold assignments. Ideally fold assignments would be as random as
     * possible. Because prediction results on test folds are evaluated with ROCR (to calculate
     * ROC AUC), and because ROCR cannot handle situations where all the labels are only one
     * category (i.e., all class 1 or all class 2), we force folds generated by this method to
     * exclude this situation.
     *
     * @param k Number of folds
     * @return An array where each element is the index of the fold to which the given instance
     *         of the training set belongs.
     */
    private int[] assignFolds(final int k) {
        final IntList indices = new IntArrayList();
        do {
            indices.clear();
            for (int i = 0; i < problem.getSize(); ++i) {
                indices.add(i % k);
            }
            Collections.shuffle(indices, randomAdapter);
        } while (invalidFold(indices, k));
        final int[] splitIndex = new int[problem.getSize()];
        indices.toArray(splitIndex);
        return splitIndex;
    }

    /**
     * Determines if a fold split is valid. See the ROCR comment above.
     *
     * @param indices Training instance fold assignments.
     * @param k       Number of folds in the split
     * @return True if the fold is invalid (does not have at least two labels represented)
     * @see #assignFolds
     */
    private boolean invalidFold(final IntList indices, final int k) {
        problem.prepareNative();
        for (int currentFoldInspected = 0; currentFoldInspected < k; currentFoldInspected++) {
            final DoubleSet labels = new DoubleArraySet();
            int instanceIndex = 0;
            for (final int foldAssignment : indices) {
                if (foldAssignment == currentFoldInspected) {
                    labels.add(problem.getLabel(instanceIndex));
                }
                instanceIndex++;
            }
            if (labels.size() < 2) {
                return true;
            }
        }
        return false;
    }

    private static EvaluationMeasure convertToEvalMeasure(final ContingencyTable ctable) {
        return new EvaluationMeasure(ctable);
    }

    public ClassificationModel getModel() {
        return model;
    }

    public void evaluateMeasures(final CharSequence... names) {
        for (final CharSequence name : names) {
            evaluateMeasure(name);
        }
    }

    public void setScalerClass(final Class<? extends FeatureScaler> featureScalerClass) {
        this.featureScalerClass = featureScalerClass;
    }
}
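/*
 * Illustrative end-to-end sketch (hypothetical Classifier/ClassificationProblem
 * implementations; "MyFeatureScaler" is a hypothetical FeatureScaler subclass).
 * Note that resetScaler() dereferences featureScalerClass, so setScalerClass must be
 * called before crossValidation.
 *
 *   final CrossValidation cv = new CrossValidation(myClassifier, myProblem,
 *           new cern.jet.random.engine.MersenneTwister(42));
 *   cv.setScalerClass(MyFeatureScaler.class);
 *   cv.useRServer(false);              // skip ROCR-backed measures
 *   cv.evaluateMeasures("MCC", "acc");
 *   cv.setRepeatNumber(3);
 *   final EvaluationMeasure result = cv.crossValidation(5);
 */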