Example usage for weka.classifiers.evaluation ThresholdCurve getCurve

List of usage examples for weka.classifiers.evaluation ThresholdCurve getCurve

Introduction

On this page you can find example usages of weka.classifiers.evaluation.ThresholdCurve.getCurve.

Prototype

public Instances getCurve(ArrayList<Prediction> predictions, int classIndex) 

Source Link

Document

Calculates the performance statistics for the desired class and returns the results as a set of Instances.

Usage

From source file:cotraining.copy.Evaluation_D.java

License:Open Source License

/**
 * Computes the area under the ROC curve from the predictions recorded by the
 * evaluateClassifier(Classifier, Instances) method.
 *
 * @param classIndex index of the class treated as the "positive" class
 * @return the ROC area, or Instance.missingValue() when no predictions
 *         have been recorded yet
 */
public double areaUnderROC(int classIndex) {
    // Guard: a curve can only be built once predictions exist.
    if (m_Predictions != null) {
        Instances rocCurve = new ThresholdCurve().getCurve(m_Predictions, classIndex);
        return ThresholdCurve.getROCArea(rocCurve);
    }
    return Instance.missingValue();
}

From source file:cotraining.copy.Evaluation_D.java

License:Open Source License

/**
 * Evaluates a classifier with the options given in an array of strings and
 * returns the textual evaluation report. <p/>
 *
 * Valid options are: <p/>
 *
 * -h, -help <br/>
 * Requests help: an exception carrying the option listing is thrown. Add
 * -synopsis or -info to include the classifier's global info. <p/>
 *
 * -t name of training file <br/>
 * Name of the file with the training data. (required) <p/>
 *
 * -T name of test file <br/>
 * Name of the file with the test data. If missing a cross-validation 
 * is performed. <p/>
 *
 * -c class index <br/>
 * Index of the class attribute (1, 2, ...; "first"; "last"; default: last). <p/>
 *
 * -x number of folds <br/>
 * The number of folds for the cross-validation (default: 10). <p/>
 *
 * -no-cv <br/>
 * No cross validation.  If no test file is provided, no evaluation
 * is done. <p/>
 * 
 * -split-percentage percentage <br/>
 * Sets the percentage for the train/test set split, e.g., 66. Cannot be
 * combined with '-x'. <p/>
 * 
 * -preserve-order <br/>
 * Preserves the order in the percentage split instead of randomizing
 * the data first with the seed value ('-s'). Requires '-split-percentage'. <p/>
 *
 * -s seed <br/>
 * Random number seed for the cross-validation and percentage split
 * (default: 1). <p/>
 *
 * -m file with cost matrix <br/>
 * The name of a file containing a cost matrix. <p/>
 *
 * -l filename <br/>
 * Loads classifier from the given file. In case the filename ends with
 * ".xml", options are loaded from XML (PMML loading is commented out in
 * this version). Files ending with ".gz" are decompressed; ".koml" files
 * are read with KOML if it is available. <p/>
 *
 * -d filename <br/>
 * Saves classifier built from the training data into the given file. In case 
 * the filename ends with ".xml" the options are saved XML, not the model.
 * Files ending with ".gz" are compressed. <p/>
 *
 * -v <br/>
 * Outputs no statistics for the training data. <p/>
 *
 * -o <br/>
 * Outputs statistics only, not the classifier. <p/>
 * 
 * -i <br/>
 * Outputs detailed information-retrieval statistics per class. <p/>
 *
 * -k <br/>
 * Outputs information-theoretic statistics (requires '-t'). <p/>
 *
 * -p range <br/>
 * Outputs predictions for test instances (or the train instances if no test
 * instances provided and -no-cv is used), along with the attributes in the specified range 
 * (and nothing else). Use '-p 0' if no attributes are desired. <p/>
 *
 * -distribution <br/>
 * Outputs the distribution instead of only the prediction
 * in conjunction with the '-p' option (only nominal classes). <p/>
 *
 * -r <br/>
 * Outputs cumulative margin distribution (and nothing else). <p/>
 *
 * -g <br/> 
 * Only for classifiers that implement "Graphable." Outputs
 * the graph representation of the classifier (and nothing
 * else). <p/>
 *
 * -z class name <br/>
 * Only for classifiers that implement "Sourcable". Outputs the classifier
 * as source code using the given class name (and nothing else). <p/>
 *
 * -threshold-file filename <br/>
 * Saves the threshold curve computed from the test predictions to the
 * given file (nominal classes only). <p/>
 *
 * -threshold-label label <br/>
 * The class label for which the threshold curve is computed
 * (default: the first label). <p/>
 *
 * -xml filename | xml-string <br/>
 * Retrieves the options from the XML-data instead of the command line. <p/>
 *
 * @param classifier machine learning classifier
 * @param options the array of string containing the options
 * @throws Exception if help was requested, the options are invalid, or the
 *                   model could not be evaluated successfully
 * @return a string describing the results 
 */
public static String evaluateModel(Classifier classifier, String[] options) throws Exception {

    Instances train = null, tempTrain, test = null, template = null;
    int seed = 1, folds = 10, classIndex = -1;
    boolean noCrossValidation = false;
    String trainFileName, testFileName, sourceClass, classIndexString, seedString, foldsString,
            objectInputFileName, objectOutputFileName, attributeRangeString;
    boolean noOutput = false, printClassifications = false, trainStatistics = true, printMargins = false,
            printComplexityStatistics = false, printGraph = false, classStatistics = false, printSource = false;
    StringBuffer text = new StringBuffer();
    DataSource trainSource = null, testSource = null;
    ObjectInputStream objectInputStream = null;
    BufferedInputStream xmlInputStream = null;
    CostMatrix costMatrix = null;
    StringBuffer schemeOptionsText = null;
    Range attributesToOutput = null;
    long trainTimeStart = 0, trainTimeElapsed = 0, testTimeStart = 0, testTimeElapsed = 0;
    String xml = "";
    String[] optionsTmp = null;
    Classifier classifierBackup;
    Classifier classifierClassifications = null;
    boolean printDistribution = false;
    int actualClassIndex = -1; // 0-based class index
    String splitPercentageString = "";
    int splitPercentage = -1;
    boolean preserveOrder = false;
    boolean trainSetPresent = false;
    boolean testSetPresent = false;
    String thresholdFile;
    String thresholdLabel;
    StringBuffer predsBuff = null; // predictions from cross-validation

    // help requested?
    if (Utils.getFlag("h", options) || Utils.getFlag("help", options)) {

        // global info requested as well?
        boolean globalInfo = Utils.getFlag("synopsis", options) || Utils.getFlag("info", options);

        throw new Exception("\nHelp requested." + makeOptionString(classifier, globalInfo));
    }

    try {
        // do we get the input from XML instead of normal parameters?
        xml = Utils.getOption("xml", options);
        if (!xml.equals(""))
            options = new XMLOptions(xml).toArray();

        // is the input model only the XML-Options, i.e. w/o built model?
        // Work on a copy so that peeking at '-l' does not consume it from 'options'.
        optionsTmp = new String[options.length];
        for (int i = 0; i < options.length; i++)
            optionsTmp[i] = options[i];

        String tmpO = Utils.getOption('l', optionsTmp);
        // NOTE: this check is case-sensitive (".XML" is not matched).
        if (tmpO.endsWith(".xml")) {
            // try to load file as PMML first
            // NOTE(review): the PMML attempt is commented out, so 'success' is
            // always false and the XML-options branch below always runs.
            boolean success = false;
            try {
                //PMMLModel pmmlModel = PMMLFactory.getPMMLModel(tmpO);
                //if (pmmlModel instanceof PMMLClassifier) {
                //classifier = ((PMMLClassifier)pmmlModel);
                // success = true;
                //}
            } catch (IllegalArgumentException ex) {
                success = false;
            }
            if (!success) {
                // load options from serialized data  ('-l' is automatically erased!)
                XMLClassifier xmlserial = new XMLClassifier();
                Classifier cl = (Classifier) xmlserial.read(Utils.getOption('l', options));

                // merge options: the stored classifier options come first,
                // followed by the remaining command-line options
                optionsTmp = new String[options.length + cl.getOptions().length];
                System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl.getOptions().length);
                System.arraycopy(options, 0, optionsTmp, cl.getOptions().length, options.length);
                options = optionsTmp;
            }
        }

        noCrossValidation = Utils.getFlag("no-cv", options);
        // Get basic options (options the same for all schemes)
        classIndexString = Utils.getOption('c', options);
        if (classIndexString.length() != 0) {
            if (classIndexString.equals("first"))
                classIndex = 1;
            else if (classIndexString.equals("last"))
                classIndex = -1;
            else
                classIndex = Integer.parseInt(classIndexString);
        }
        trainFileName = Utils.getOption('t', options);
        objectInputFileName = Utils.getOption('l', options);
        objectOutputFileName = Utils.getOption('d', options);
        testFileName = Utils.getOption('T', options);
        foldsString = Utils.getOption('x', options);
        if (foldsString.length() != 0) {
            folds = Integer.parseInt(foldsString);
        }
        seedString = Utils.getOption('s', options);
        if (seedString.length() != 0) {
            seed = Integer.parseInt(seedString);
        }
        if (trainFileName.length() == 0) {
            if (objectInputFileName.length() == 0) {
                throw new Exception("No training file and no object " + "input file given.");
            }
            if (testFileName.length() == 0) {
                throw new Exception("No training file and no test " + "file given.");
            }
        } else if ((objectInputFileName.length() != 0)
                && ((!(classifier instanceof UpdateableClassifier)) || (testFileName.length() == 0))) {
            throw new Exception("Classifier not incremental, or no " + "test file provided: can't "
                    + "use both train and model file.");
        }
        try {
            if (trainFileName.length() != 0) {
                trainSetPresent = true;
                trainSource = new DataSource(trainFileName);
            }
            if (testFileName.length() != 0) {
                testSetPresent = true;
                testSource = new DataSource(testFileName);
            }
            if (objectInputFileName.length() != 0) {
                if (objectInputFileName.endsWith(".xml")) {
                    // if this is the case then it means that a PMML classifier was
                    // successfully loaded earlier in the code
                    objectInputStream = null;
                    xmlInputStream = null;
                } else {
                    InputStream is = new FileInputStream(objectInputFileName);
                    if (objectInputFileName.endsWith(".gz")) {
                        is = new GZIPInputStream(is);
                    }
                    // load from KOML?
                    if (!(objectInputFileName.endsWith(".koml") && KOML.isPresent())) {
                        objectInputStream = new ObjectInputStream(is);
                        xmlInputStream = null;
                    } else {
                        objectInputStream = null;
                        xmlInputStream = new BufferedInputStream(is);
                    }
                }
            }
        } catch (Exception e) {
            // keep the original exception as the cause for diagnosis
            throw new Exception("Can't open file " + e.getMessage() + '.', e);
        }
        if (testSetPresent) {
            template = test = testSource.getStructure();
            if (classIndex != -1) {
                test.setClassIndex(classIndex - 1);
            } else {
                if ((test.classIndex() == -1) || (classIndexString.length() != 0))
                    test.setClassIndex(test.numAttributes() - 1);
            }
            actualClassIndex = test.classIndex();
        } else {
            // percentage split
            splitPercentageString = Utils.getOption("split-percentage", options);
            if (splitPercentageString.length() != 0) {
                if (foldsString.length() != 0)
                    throw new Exception("Percentage split cannot be used in conjunction with "
                            + "cross-validation ('-x').");
                splitPercentage = Integer.parseInt(splitPercentageString);
                if ((splitPercentage <= 0) || (splitPercentage >= 100))
                    throw new Exception("Percentage split value needs be >0 and <100.");
            } else {
                splitPercentage = -1;
            }
            preserveOrder = Utils.getFlag("preserve-order", options);
            if (preserveOrder) {
                if (splitPercentage == -1)
                    // name the option exactly as it is parsed above
                    throw new Exception("Percentage split ('-split-percentage') is missing.");
            }
            // create new train/test sources
            if (splitPercentage > 0) {
                testSetPresent = true;
                Instances tmpInst = trainSource.getDataSet(actualClassIndex);
                if (!preserveOrder)
                    tmpInst.randomize(new Random(seed));
                int trainSize = tmpInst.numInstances() * splitPercentage / 100;
                int testSize = tmpInst.numInstances() - trainSize;
                Instances trainInst = new Instances(tmpInst, 0, trainSize);
                Instances testInst = new Instances(tmpInst, trainSize, testSize);
                trainSource = new DataSource(trainInst);
                testSource = new DataSource(testInst);
                template = test = testSource.getStructure();
                if (classIndex != -1) {
                    test.setClassIndex(classIndex - 1);
                } else {
                    if ((test.classIndex() == -1) || (classIndexString.length() != 0))
                        test.setClassIndex(test.numAttributes() - 1);
                }
                actualClassIndex = test.classIndex();
            }
        }
        if (trainSetPresent) {
            template = train = trainSource.getStructure();
            if (classIndex != -1) {
                train.setClassIndex(classIndex - 1);
            } else {
                if ((train.classIndex() == -1) || (classIndexString.length() != 0))
                    train.setClassIndex(train.numAttributes() - 1);
            }
            actualClassIndex = train.classIndex();
            if ((testSetPresent) && !test.equalHeaders(train)) {
                throw new IllegalArgumentException("Train and test file not compatible!");
            }
        }
        if (template == null) {
            throw new Exception("No actual dataset provided to use as template");
        }
        costMatrix = handleCostOption(Utils.getOption('m', options), template.numClasses());

        classStatistics = Utils.getFlag('i', options);
        noOutput = Utils.getFlag('o', options);
        trainStatistics = !Utils.getFlag('v', options);
        printComplexityStatistics = Utils.getFlag('k', options);
        printMargins = Utils.getFlag('r', options);
        printGraph = Utils.getFlag('g', options);
        sourceClass = Utils.getOption('z', options);
        printSource = (sourceClass.length() != 0);
        printDistribution = Utils.getFlag("distribution", options);
        thresholdFile = Utils.getOption("threshold-file", options);
        thresholdLabel = Utils.getOption("threshold-label", options);

        // Check -p option
        try {
            attributeRangeString = Utils.getOption('p', options);
        } catch (Exception e) {
            throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. "
                    + "It now expects a parameter specifying a range of attributes "
                    + "to list with the predictions. Use '-p 0' for none.");
        }
        if (attributeRangeString.length() != 0) {
            printClassifications = true;
            noOutput = true;
            if (!attributeRangeString.equals("0"))
                attributesToOutput = new Range(attributeRangeString);
        }

        if (!printClassifications && printDistribution)
            throw new Exception("Cannot print distribution without '-p' option!");

        // if no training file given, we don't have any priors
        if ((!trainSetPresent) && (printComplexityStatistics))
            throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!");

        // If a model file is given, we can't process 
        // scheme-specific options
        if (objectInputFileName.length() != 0) {
            Utils.checkForRemainingOptions(options);
        } else {

            // Set options for classifier
            if (classifier instanceof OptionHandler) {
                // Record the scheme options (quoting any containing spaces)
                // before setOptions() consumes them, for the report header.
                for (int i = 0; i < options.length; i++) {
                    if (options[i].length() != 0) {
                        if (schemeOptionsText == null) {
                            schemeOptionsText = new StringBuffer();
                        }
                        if (options[i].indexOf(' ') != -1) {
                            schemeOptionsText.append('"' + options[i] + "\" ");
                        } else {
                            schemeOptionsText.append(options[i] + " ");
                        }
                    }
                }
                ((OptionHandler) classifier).setOptions(options);
            }
        }
        Utils.checkForRemainingOptions(options);
    } catch (Exception e) {
        // keep the original exception as the cause for diagnosis
        throw new Exception("\nWeka exception: " + e.getMessage() + makeOptionString(classifier, false), e);
    }

    // Setup up evaluation objects
    Evaluation_D trainingEvaluation = new Evaluation_D(new Instances(template, 0), costMatrix);
    Evaluation_D testingEvaluation = new Evaluation_D(new Instances(template, 0), costMatrix);

    // disable use of priors if no training file given
    if (!trainSetPresent)
        testingEvaluation.useNoPriors();

    if (objectInputFileName.length() != 0) {
        // Load classifier from file; the streams are closed even if reading
        // or the header-compatibility check fails.
        if (objectInputStream != null) {
            try {
                classifier = (Classifier) objectInputStream.readObject();
                // try and read a header (if present)
                Instances savedStructure = null;
                try {
                    savedStructure = (Instances) objectInputStream.readObject();
                } catch (Exception ex) {
                    // don't make a fuss: older model files carry no header
                }
                if (savedStructure != null) {
                    // test for compatibility with template
                    if (!template.equalHeaders(savedStructure)) {
                        throw new Exception("training and test set are not compatible");
                    }
                }
            } finally {
                objectInputStream.close();
            }
        } else if (xmlInputStream != null) {
            // whether KOML is available has already been checked (objectInputStream would be null otherwise)!
            try {
                classifier = (Classifier) KOML.read(xmlInputStream);
            } finally {
                xmlInputStream.close();
            }
        }
    }

    // backup of fully setup classifier for cross-validation
    classifierBackup = Classifier.makeCopy(classifier);

    // Build the classifier if no object file provided
    if ((classifier instanceof UpdateableClassifier) && (testSetPresent || noCrossValidation)
            && (costMatrix == null) && (trainSetPresent)) {
        // Build classifier incrementally
        trainingEvaluation.setPriors(train);
        testingEvaluation.setPriors(train);
        trainTimeStart = System.currentTimeMillis();
        if (objectInputFileName.length() == 0) {
            classifier.buildClassifier(train);
        }
        Instance trainInst;
        while (trainSource.hasMoreElements(train)) {
            trainInst = trainSource.nextElement(train);
            trainingEvaluation.updatePriors(trainInst);
            testingEvaluation.updatePriors(trainInst);
            ((UpdateableClassifier) classifier).updateClassifier(trainInst);
        }
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    } else if (objectInputFileName.length() == 0) {
        // Build classifier in one go
        tempTrain = trainSource.getDataSet(actualClassIndex);
        trainingEvaluation.setPriors(tempTrain);
        testingEvaluation.setPriors(tempTrain);
        trainTimeStart = System.currentTimeMillis();
        classifier.buildClassifier(tempTrain);
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    }

    // backup of fully trained classifier for printing the classifications
    if (printClassifications)
        classifierClassifications = Classifier.makeCopy(classifier);

    // Save the classifier if an object output file is provided
    if (objectOutputFileName.length() != 0) {
        OutputStream os = new FileOutputStream(objectOutputFileName);
        // 'os' always refers to the outermost stream opened so far, so the
        // finally clause closes the whole chain even if serialization fails.
        try {
            // binary
            if (!(objectOutputFileName.endsWith(".xml")
                    || (objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) {
                if (objectOutputFileName.endsWith(".gz")) {
                    os = new GZIPOutputStream(os);
                }
                ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
                os = objectOutputStream;
                objectOutputStream.writeObject(classifier);
                if (template != null) {
                    objectOutputStream.writeObject(template);
                }
                objectOutputStream.flush();
            }
            // KOML/XML
            else {
                BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
                os = xmlOutputStream;
                if (objectOutputFileName.endsWith(".xml")) {
                    XMLSerialization xmlSerial = new XMLClassifier();
                    xmlSerial.write(xmlOutputStream, classifier);
                } else if (objectOutputFileName.endsWith(".koml")) {
                    // whether KOML is present has already been checked
                    // if not present -> ".koml" is interpreted as binary - see above
                    KOML.write(xmlOutputStream, classifier);
                }
            }
        } finally {
            os.close();
        }
    }

    // If classifier is drawable output string describing graph
    if ((classifier instanceof Drawable) && (printGraph)) {
        return ((Drawable) classifier).graph();
    }

    // Output the classifier as equivalent source
    if ((classifier instanceof Sourcable) && (printSource)) {
        return wekaStaticWrapper((Sourcable) classifier, sourceClass);
    }

    // Output model
    if (!(noOutput || printMargins)) {
        if (classifier instanceof OptionHandler) {
            if (schemeOptionsText != null) {
                text.append("\nOptions: " + schemeOptionsText);
                text.append("\n");
            }
        }
        text.append("\n" + classifier.toString() + "\n");
    }

    if (!printMargins && (costMatrix != null)) {
        text.append("\n=== Evaluation Cost Matrix ===\n\n");
        text.append(costMatrix.toString());
    }

    // Output test instance predictions only
    if (printClassifications) {
        DataSource source = testSource;
        predsBuff = new StringBuffer();
        // no test set -> use train set
        if (source == null && noCrossValidation) {
            source = trainSource;
            predsBuff.append("\n=== Predictions on training data ===\n\n");
        } else {
            predsBuff.append("\n=== Predictions on test data ===\n\n");
        }
        if (source != null) {
            printClassifications(classifierClassifications, new Instances(template, 0), source,
                    actualClassIndex + 1, attributesToOutput, printDistribution, predsBuff);
        }
    }

    // Compute error estimate from training data
    if ((trainStatistics) && (trainSetPresent)) {

        if ((classifier instanceof UpdateableClassifier) && (testSetPresent) && (costMatrix == null)) {

            // Classifier was trained incrementally, so we have to 
            // reset the source.
            trainSource.reset();

            // Incremental testing
            train = trainSource.getStructure(actualClassIndex);
            testTimeStart = System.currentTimeMillis();
            Instance trainInst;
            while (trainSource.hasMoreElements(train)) {
                trainInst = trainSource.nextElement(train);
                trainingEvaluation.evaluateModelOnce((Classifier) classifier, trainInst);
            }
            testTimeElapsed = System.currentTimeMillis() - testTimeStart;
        } else {
            testTimeStart = System.currentTimeMillis();
            trainingEvaluation.evaluateModel(classifier, trainSource.getDataSet(actualClassIndex));
            testTimeElapsed = System.currentTimeMillis() - testTimeStart;
        }

        // Print the results of the training evaluation
        if (printMargins) {
            return trainingEvaluation.toCumulativeMarginDistributionString();
        } else {
            if (!printClassifications) {
                text.append("\nTime taken to build model: " + Utils.doubleToString(trainTimeElapsed / 1000.0, 2)
                        + " seconds");

                if (splitPercentage > 0)
                    text.append("\nTime taken to test model on training split: ");
                else
                    text.append("\nTime taken to test model on training data: ");
                text.append(Utils.doubleToString(testTimeElapsed / 1000.0, 2) + " seconds");

                if (splitPercentage > 0)
                    text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training" + " split ===\n",
                            printComplexityStatistics));
                else
                    text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training" + " data ===\n",
                            printComplexityStatistics));

                if (template.classAttribute().isNominal()) {
                    if (classStatistics) {
                        text.append("\n\n" + trainingEvaluation.toClassDetailsString());
                    }
                    if (!noCrossValidation)
                        text.append("\n\n" + trainingEvaluation.toMatrixString());
                }
            }
        }
    }

    // Compute proper error estimates
    if (testSource != null) {
        // Testing is on the supplied test data; predictions are recorded so
        // that the threshold curve below can be computed from them.
        testSource.reset();
        test = testSource.getStructure(test.classIndex());
        Instance testInst;
        while (testSource.hasMoreElements(test)) {
            testInst = testSource.nextElement(test);
            testingEvaluation.evaluateModelOnceAndRecordPrediction((Classifier) classifier, testInst);
        }

        if (splitPercentage > 0) {
            if (!printClassifications) {
                text.append("\n\n" + testingEvaluation.toSummaryString("=== Error on test split ===\n",
                        printComplexityStatistics));
            }
        } else {
            if (!printClassifications) {
                text.append("\n\n" + testingEvaluation.toSummaryString("=== Error on test data ===\n",
                        printComplexityStatistics));
            }
        }

    } else if (trainSource != null) {
        if (!noCrossValidation) {
            // Testing is via cross-validation on training data
            Random random = new Random(seed);
            // use untrained (!) classifier for cross-validation
            classifier = Classifier.makeCopy(classifierBackup);
            if (!printClassifications) {
                testingEvaluation.crossValidateModel(classifier, trainSource.getDataSet(actualClassIndex),
                        folds, random);
                if (template.classAttribute().isNumeric()) {
                    text.append("\n\n\n" + testingEvaluation.toSummaryString("=== Cross-validation ===\n",
                            printComplexityStatistics));
                } else {
                    text.append("\n\n\n" + testingEvaluation.toSummaryString(
                            "=== Stratified " + "cross-validation ===\n", printComplexityStatistics));
                }
            } else {
                predsBuff = new StringBuffer();
                predsBuff.append("\n=== Predictions under cross-validation ===\n\n");
                testingEvaluation.crossValidateModel(classifier, trainSource.getDataSet(actualClassIndex),
                        folds, random, predsBuff, attributesToOutput, Boolean.valueOf(printDistribution));
            }
        }
    }
    if (template.classAttribute().isNominal()) {
        if (classStatistics && !noCrossValidation && !printClassifications) {
            text.append("\n\n" + testingEvaluation.toClassDetailsString());
        }
        if (!noCrossValidation && !printClassifications)
            text.append("\n\n" + testingEvaluation.toMatrixString());

    }

    // predictions from cross-validation?
    if (predsBuff != null) {
        text.append("\n" + predsBuff);
    }

    // Optionally write the threshold curve for the chosen class label
    // (first label by default) to the requested file.
    if ((thresholdFile.length() != 0) && template.classAttribute().isNominal()) {
        int labelIndex = 0;
        if (thresholdLabel.length() != 0)
            labelIndex = template.classAttribute().indexOfValue(thresholdLabel);
        if (labelIndex == -1)
            throw new IllegalArgumentException("Class label '" + thresholdLabel + "' is unknown!");
        ThresholdCurve tc = new ThresholdCurve();
        Instances result = tc.getCurve(testingEvaluation.predictions(), labelIndex);
        DataSink.write(thresholdFile, result);
    }

    return text.toString();
}

From source file:cotraining.copy.Evaluation_D.java

License:Open Source License

/**
 * Summarizes the precision/recall curve for the predictions collected in the
 * evaluateClassifier(Classifier, Instances) method. Returns
 * Instance.missingValue() if no predictions have been collected.
 * <p/>
 * Note: despite the method name, the value returned is NOT the exact area
 * under the precision/recall curve. It is the result of
 * ThresholdCurve.getNPointPrecision with 11 points, i.e. precision averaged
 * over 11 evenly spaced recall levels.
 *
 * @param classIndex the index of the class to consider as "positive"
 * @return the 11-point average precision, or a missing value if no
 *         predictions are available
 * @author doina
 */
public double areaUnderPRC(int classIndex) {
    // Number of evenly spaced recall levels at which precision is sampled.
    final int recallPoints = 11;

    // Check if any predictions have been collected
    if (m_Predictions == null) {
        return Instance.missingValue();
    }
    ThresholdCurve tc = new ThresholdCurve();
    Instances result = tc.getCurve(m_Predictions, classIndex);
    return ThresholdCurve.getNPointPrecision(result, recallPoints);
}

From source file:cs.man.ac.uk.classifiers.GetAUC.java

License:Open Source License

/**
 * Estimates the AUC of a J48 tree (options {@code -U -A}) on the shared {@code data} set using
 * 5x2 cross-validation.
 *
 * <p>There are five runs with seeds 1..5. Each run randomizes a copy of the data, stratifies it
 * when the class is nominal, and splits it into two folds.
 *
 * <p>A fresh {@link Evaluation} is used for every fold. An {@code Evaluation} accumulates
 * predictions across calls to {@code evaluateModel}, so reusing one would make later "per-fold"
 * AUCs pooled over earlier folds. Each fold's AUC is therefore computed from that fold's test
 * predictions only. The result is the mean of the ten per-fold AUCs, with class index 0 treated
 * as the positive class.
 *
 * @return the mean per-fold AUC, or 0 if any step throws an exception
 */
@SuppressWarnings("unused")
private static double validate5x2CV() {
    try {
        final int runs = 5;
        final int folds = 2;
        final int classIndex = 0;
        double aucSum = 0;

        // The shared panel is reset to a fresh, empty panel as before; no curve is plotted here.
        vmc = new ThresholdVisualizePanel();

        // perform cross-validation
        for (int i = 0; i < runs; i++) {
            // randomize data
            int seed = i + 1;
            Random rand = new Random(seed);
            Instances randData = new Instances(data);
            randData.randomize(rand);

            if (randData.classAttribute().isNominal()) {
                System.out.println("Stratifying...");
                randData.stratify(folds);
            }

            for (int n = 0; n < folds; n++) {
                Instances train = randData.trainCV(folds, n);
                Instances test = randData.testCV(folds, n);

                // the above code is used by the StratifiedRemoveFolds filter, the
                // code below by the Explorer/Experimenter:
                // Instances train = randData.trainCV(folds, n, rand);

                // Build the classifier. The options array is created fresh for every fold
                // because Weka's setOptions blanks out the flags it consumes.
                String[] options = { "-U", "-A" };
                J48 classifier = new J48();
                classifier.setOptions(options);
                classifier.buildClassifier(train);

                // One Evaluation per fold, so predictions() holds only this fold's test predictions.
                Evaluation eval = new Evaluation(randData);
                eval.evaluateModel(classifier, test);

                // generate curve and accumulate this fold's AUC
                ThresholdCurve tc = new ThresholdCurve();
                Instances result = tc.getCurve(eval.predictions(), classIndex);
                double foldAuc = ThresholdCurve.getROCArea(result);
                aucSum += foldAuc;
                System.out.println("AUC: " + foldAuc + " \tAUC SUM: " + aucSum);
            }
        }

        return aucSum / ((double) runs * (double) folds);
    } catch (Exception e) {
        System.out.println("Exception validating data!");
        // Keep the best-effort 0 result, but do not lose the cause.
        e.printStackTrace();
        return 0;
    }
}

From source file:cs.man.ac.uk.classifiers.GetAUC.java

License:Open Source License

/**
 * Computes the AUC of the supplied learner on the shared {@code data} set.
 *
 * <p>The learner is evaluated with a single 2-fold cross-validation (random seed 1), and class
 * index 0 is treated as the positive class.
 *
 * <p>As a side effect, the shared {@code vmc} field is replaced by a new
 * {@link ThresholdVisualizePanel} containing the resulting ROC curve. The panel is not displayed
 * here.
 *
 * @param learner the learning algorithm to use.
 * @return the AUC as a double value, or 0 if any step throws an exception.
 */
@SuppressWarnings("unused")
private static double validate(Classifier learner) {
    try {
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(learner, data, 2, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // Compute the area once and reuse it for both the plot label and the return value.
        double auc = ThresholdCurve.getROCArea(result);

        // build the plot panel
        vmc = new ThresholdVisualizePanel();
        vmc.setROCString("(Area under ROC = " + Utils.doubleToString(auc, 9) + ")");
        vmc.setName(result.relationName());

        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();

        // connect every point to its predecessor (the first point has none)
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);

        // add plot
        vmc.addPlot(tempd);

        return auc;
    } catch (Exception e) {
        System.out.println("Exception validating data!");
        // Keep the best-effort 0 result, but do not lose the cause.
        e.printStackTrace();
        return 0;
    }
}

From source file:miRdup.WekaModule.java

License:Open Source License

/**
 * Trains an AdaBoostM1-over-RandomForest model on an ARFF file and evaluates it with 10-fold
 * cross-validation.
 *
 * <p>The method writes a textual report and a threshold curve to disk, then serializes the
 * trained base classifier.
 *
 * <p>Files written (names prefixed with {@code keyword + Main.modelExtension}):
 * <ul>
 *   <li>{@code ...Output}: model options, cross-validation prediction output, summary, class
 *       details, derived se/sp/ACC/MCC/AUC values and the confusion matrix;</li>
 *   <li>{@code ...roc.arff}: the threshold curve for class index 0;</li>
 *   <li>{@code keyword + Main.modelExtension}: the serialized classifier returned by
 *       {@code fc.getClassifier()}, i.e. without the ID-removing filter wrapper.</li>
 * </ul>
 *
 * <p>NOTE(review): the saved model is the inner classifier, trained on data with the first
 * attribute removed. Presumably callers must drop the ID column themselves before
 * classification; confirm.
 *
 * <p>If the data set has no class attribute set, the last attribute is used. Any exception is
 * printed and swallowed, and the output writers are always closed.
 *
 * @param arff    the ARFF training file
 * @param keyword prefix for all output files
 */
public static void trainModel(File arff, String keyword) {
    dec.setMaximumFractionDigits(3);
    System.out.println("\nTraining model on file " + arff);
    // Declared outside the try so the finally block can close them on every path.
    PrintWriter pwout = null;
    PrintWriter pwroc = null;
    try {
        // load data
        DataSource source = new DataSource(arff.toString());
        Instances data = source.getDataSet();
        if (data.classIndex() == -1) {
            data.setClassIndex(data.numAttributes() - 1);
        }

        pwout = new PrintWriter(new FileWriter(keyword + Main.modelExtension + "Output"));
        pwroc = new PrintWriter(new FileWriter(keyword + Main.modelExtension + "roc.arff"));

        // remove the first attribute (ID column) before the data reaches the base classifier
        Remove rm = new Remove();
        rm.setAttributeIndices("1");
        FilteredClassifier fc = new FilteredClassifier();
        fc.setFilter(rm);

        // Alternative base models tried previously (kept for reference):
        //            // train model svm
        //            weka.classifiers.functions.LibSVM model = new weka.classifiers.functions.LibSVM();
        //            model.setOptions(weka.core.Utils.splitOptions("-S 0 -K 2 -D 3 -G 0.0 -R 0.0 -N 0.5 -M 40.0 -C 1.0 -E 0.0010 -P 0.1 -B"));
        // train model MultilayerPerceptron
        //            weka.classifiers.functions.MultilayerPerceptron model = new weka.classifiers.functions.MultilayerPerceptron();
        //            model.setOptions(weka.core.Utils.splitOptions("-L 0.3 -M 0.2 -N 500 -V 0 -S 0 -E 20 -H a"));
        // train model Adaboost on RIPPER
        //            weka.classifiers.meta.AdaBoostM1 model = new weka.classifiers.meta.AdaBoostM1();
        //            model.setOptions(weka.core.Utils.splitOptions("weka.classifiers.meta.AdaBoostM1 -P 100 -S 1 -I 10 -W weka.classifiers.rules.JRip -- -F 10 -N 2.0 -O 5 -S 1"));
        // train model Adaboost on FURIA
        //            weka.classifiers.meta.AdaBoostM1 model = new weka.classifiers.meta.AdaBoostM1();
        //            model.setOptions(weka.core.Utils.splitOptions("weka.classifiers.meta.AdaBoostM1 -P 100 -S 1 -I 10 -W weka.classifiers.rules.FURIA -- -F 10 -N 2.0 -O 5 -S 1 -p 0 -s 0"));
        //train model Adaboot on J48 trees
        //             weka.classifiers.meta.AdaBoostM1 model = new weka.classifiers.meta.AdaBoostM1();
        //             model.setOptions(
        //                     weka.core.Utils.splitOptions(
        //                     "-P 100 -S 1 -I 10 -W weka.classifiers.trees.J48 -- -C 0.25 -M 2"));

        // train model AdaBoost on Random Forest trees
        weka.classifiers.meta.AdaBoostM1 model = new weka.classifiers.meta.AdaBoostM1();
        model.setOptions(weka.core.Utils
                .splitOptions("-P 100 -S 1 -I 10 -W weka.classifiers.trees.RandomForest -- -I 50 -K 0 -S 1"));

        if (Main.debug) {
            System.out.print("Model options: " + model.getClass().getName().trim() + " ");
        }
        System.out.print(model.getClass() + " ");
        for (String s : model.getOptions()) {
            System.out.print(s + " ");
        }

        pwout.print("Model options: " + model.getClass().getName().trim() + " ");
        for (String s : model.getOptions()) {
            pwout.print(s + " ");
        }

        // build the model on the full data set
        fc.setClassifier(model);
        fc.buildClassifier(data);

        // 10-fold cross-validation on the model; prediction output is collected in sb
        Evaluation eval = new Evaluation(data);
        StringBuffer sb = new StringBuffer();
        eval.crossValidateModel(fc, data, 10, new Random(1), sb, new Range("first,last"), false);

        pwout.println(sb);
        pwout.flush();

        // output
        pwout.println("\n" + eval.toSummaryString());
        System.out.println(eval.toSummaryString());

        pwout.println(eval.toClassDetailsString());
        System.out.println(eval.toClassDetailsString());

        // Pull values out of the class-details table. Line 3 is taken as the class-0 row, line 4
        // as the class-1 row and line 5 as the average row, which assumes a binary class.
        // NOTE(review): the first two cells of each class row appear to be the TP Rate / FP Rate
        // columns, i.e. rates rather than counts, so ACC and MCC below are computed from rates.
        // The AUC cell index (7 on the average row) also depends on the Weka version's table
        // layout. Confirm both against the Weka version in use.
        String[] ev = eval.toClassDetailsString().split("\n");
        ArrayList<String> p = splitDetailsRow(ev[3]);
        ArrayList<String> n = splitDetailsRow(ev[4]);
        ArrayList<String> avg = splitDetailsRow(ev[5]);

        double tp = Double.parseDouble(p.get(0));
        double fp = Double.parseDouble(p.get(1));
        double tn = Double.parseDouble(n.get(0));
        double fn = Double.parseDouble(n.get(1));
        double auc = Double.parseDouble(avg.get(7));

        pwout.println("\nTP=" + tp + "\nFP=" + fp + "\nTN=" + tn + "\nFN=" + fn);
        System.out.println("\nTP=" + tp + "\nFP=" + fp + "\nTN=" + tn + "\nFN=" + fn);

        // specificity, sensitivity, prediction accuracy, Matthews correlation coefficient
        double sp = ((tn) / (tn + fp));
        double se = ((tp) / (tp + fn));
        double acc = ((tp + tn) / (tp + tn + fp + fn));
        // MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TN+FN)(TP+FN)(TN+FP)).
        // The last factor was previously "* tn + fp", a precedence bug.
        double mcc = ((tp * tn) - (fp * fn)) / Math.sqrt((tp + fp) * (tn + fn) * (tp + fn) * (tn + fp));

        String output = "\nse=" + dec.format(se).replace(",", ".") + "\nsp=" + dec.format(sp).replace(",", ".")
                + "\nACC=" + dec.format(acc).replace(",", ".") + "\nMCC=" + dec.format(mcc).replace(",", ".")
                + "\nAUC=" + dec.format(auc).replace(",", ".");

        pwout.println(output);
        System.out.println(output);

        pwout.println(eval.toMatrixString());
        System.out.println(eval.toMatrixString());

        pwout.flush();
        pwout.close();

        // Saving model
        System.out.println("Model saved: " + keyword + Main.modelExtension);
        weka.core.SerializationHelper.write(keyword + Main.modelExtension, fc.getClassifier() /*model*/);

        // get curve (class index 0 as positive) and store it as ARFF
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);
        pwroc.print(result.toString());
        pwroc.flush();
        pwroc.close();

        // To display the curve interactively: rocCurve(eval);
    } catch (Exception e) {
        e.printStackTrace();
    } finally {
        // Release file handles even when training or evaluation fails part-way.
        // PrintWriter.close() has no effect on an already-closed writer, so this is safe on the normal path.
        if (pwout != null) {
            pwout.close();
        }
        if (pwroc != null) {
            pwroc.close();
        }
    }
}

/**
 * Splits one row of {@code Evaluation.toClassDetailsString()} into its non-empty cells.
 * The row is split on single spaces, and the empty fragments produced by runs of padding
 * spaces are dropped.
 *
 * @param row one line of the class-details table
 * @return the row's non-empty cells, in order
 */
private static ArrayList<String> splitDetailsRow(String row) {
    ArrayList<String> cells = new ArrayList<String>();
    for (String s : row.trim().split(" ")) {
        if (!s.trim().isEmpty()) {
            cells.add(s);
        }
    }
    return cells;
}

From source file:miRdup.WekaModule.java

License:Open Source License

/**
 * Opens a Swing window plotting the ROC curve of the predictions collected in {@code eval}.
 *
 * <p>Class index 0 is treated as the positive class. The window is disposed when closed. Any
 * exception is printed and swallowed.
 *
 * @param eval the evaluation whose collected predictions are plotted
 */
public static void rocCurve(Evaluation eval) {
    try {
        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve (getROCArea is static, so it is called through the class)
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString("(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();
        // connect every point to its predecessor (the first point has none)
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);
        // add plot
        vmc.addPlot(tempd);

        // display curve
        String plotName = vmc.getName();
        final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName);
        jf.setSize(500, 400);
        jf.getContentPane().setLayout(new BorderLayout());
        jf.getContentPane().add(vmc, BorderLayout.CENTER);
        jf.addWindowListener(new java.awt.event.WindowAdapter() {
            public void windowClosing(java.awt.event.WindowEvent e) {
                jf.dispose();
            }
        });

        jf.setVisible(true);
        System.out.println("");
    } catch (Exception e) {
        e.printStackTrace();
    }
}

From source file:mulan.evaluation.measure.MacroAUC.java

License:Open Source License

/**
 * Returns the macro-averaged AUC.
 *
 * <p>A separate threshold curve is built from each label's own predictions, its ROC area is
 * taken, and the per-label areas are averaged with {@code Utils.mean}. Class value index 1 is
 * treated as the positive class (presumably "label relevant"; confirm).
 *
 * @return the mean of the per-label ROC areas
 */
public double getValue() {
    double[] areas = new double[numOfLabels];
    for (int label = 0; label < areas.length; label++) {
        Instances curve = new ThresholdCurve().getCurve(m_Predictions[label], 1);
        areas[label] = ThresholdCurve.getROCArea(curve);
    }
    return Utils.mean(areas);
}

From source file:mulan.evaluation.measure.MicroAUC.java

License:Open Source License

/**
 * Returns the micro-averaged AUC.
 *
 * <p>A single threshold curve is built from {@code all_Predictions} (presumably the pooled
 * predictions of every label; confirm) and its ROC area is returned. Class value index 1 is
 * treated as the positive class.
 *
 * @return the ROC area of the pooled curve
 */
public double getValue() {
    final int positiveClass = 1;
    Instances pooledCurve = new ThresholdCurve().getCurve(all_Predictions, positiveClass);
    return ThresholdCurve.getROCArea(pooledCurve);
}