Usage examples for weka.classifiers.CostMatrix#toString()
@Override
public String toString()
From source file: cotraining.copy.Evaluation_D.java
License: Open Source License
/** * Evaluates a classifier with the options given in an array of * strings. <p/>/* w w w.j a va 2 s.c om*/ * * Valid options are: <p/> * * -t name of training file <br/> * Name of the file with the training data. (required) <p/> * * -T name of test file <br/> * Name of the file with the test data. If missing a cross-validation * is performed. <p/> * * -c class index <br/> * Index of the class attribute (1, 2, ...; default: last). <p/> * * -x number of folds <br/> * The number of folds for the cross-validation (default: 10). <p/> * * -no-cv <br/> * No cross validation. If no test file is provided, no evaluation * is done. <p/> * * -split-percentage percentage <br/> * Sets the percentage for the train/test set split, e.g., 66. <p/> * * -preserve-order <br/> * Preserves the order in the percentage split instead of randomizing * the data first with the seed value ('-s'). <p/> * * -s seed <br/> * Random number seed for the cross-validation and percentage split * (default: 1). <p/> * * -m file with cost matrix <br/> * The name of a file containing a cost matrix. <p/> * * -l filename <br/> * Loads classifier from the given file. In case the filename ends with * ".xml",a PMML file is loaded or, if that fails, options are loaded from XML. <p/> * * -d filename <br/> * Saves classifier built from the training data into the given file. In case * the filename ends with ".xml" the options are saved XML, not the model. <p/> * * -v <br/> * Outputs no statistics for the training data. <p/> * * -o <br/> * Outputs statistics only, not the classifier. <p/> * * -i <br/> * Outputs detailed information-retrieval statistics per class. <p/> * * -k <br/> * Outputs information-theoretic statistics. <p/> * * -p range <br/> * Outputs predictions for test instances (or the train instances if no test * instances provided and -no-cv is used), along with the attributes in the specified range * (and nothing else). Use '-p 0' if no attributes are desired. 
<p/> * * -distribution <br/> * Outputs the distribution instead of only the prediction * in conjunction with the '-p' option (only nominal classes). <p/> * * -r <br/> * Outputs cumulative margin distribution (and nothing else). <p/> * * -g <br/> * Only for classifiers that implement "Graphable." Outputs * the graph representation of the classifier (and nothing * else). <p/> * * -xml filename | xml-string <br/> * Retrieves the options from the XML-data instead of the command line. <p/> * * @param classifier machine learning classifier * @param options the array of string containing the options * @throws Exception if model could not be evaluated successfully * @return a string describing the results */ public static String evaluateModel(Classifier classifier, String[] options) throws Exception { Instances train = null, tempTrain, test = null, template = null; int seed = 1, folds = 10, classIndex = -1; boolean noCrossValidation = false; String trainFileName, testFileName, sourceClass, classIndexString, seedString, foldsString, objectInputFileName, objectOutputFileName, attributeRangeString; boolean noOutput = false, printClassifications = false, trainStatistics = true, printMargins = false, printComplexityStatistics = false, printGraph = false, classStatistics = false, printSource = false; StringBuffer text = new StringBuffer(); DataSource trainSource = null, testSource = null; ObjectInputStream objectInputStream = null; BufferedInputStream xmlInputStream = null; CostMatrix costMatrix = null; StringBuffer schemeOptionsText = null; Range attributesToOutput = null; long trainTimeStart = 0, trainTimeElapsed = 0, testTimeStart = 0, testTimeElapsed = 0; String xml = ""; String[] optionsTmp = null; Classifier classifierBackup; Classifier classifierClassifications = null; boolean printDistribution = false; int actualClassIndex = -1; // 0-based class index String splitPercentageString = ""; int splitPercentage = -1; boolean preserveOrder = false; boolean trainSetPresent = 
false; boolean testSetPresent = false; String thresholdFile; String thresholdLabel; StringBuffer predsBuff = null; // predictions from cross-validation // help requested? if (Utils.getFlag("h", options) || Utils.getFlag("help", options)) { // global info requested as well? boolean globalInfo = Utils.getFlag("synopsis", options) || Utils.getFlag("info", options); throw new Exception("\nHelp requested." + makeOptionString(classifier, globalInfo)); } try { // do we get the input from XML instead of normal parameters? xml = Utils.getOption("xml", options); if (!xml.equals("")) options = new XMLOptions(xml).toArray(); // is the input model only the XML-Options, i.e. w/o built model? optionsTmp = new String[options.length]; for (int i = 0; i < options.length; i++) optionsTmp[i] = options[i]; String tmpO = Utils.getOption('l', optionsTmp); //if (Utils.getOption('l', optionsTmp).toLowerCase().endsWith(".xml")) { if (tmpO.endsWith(".xml")) { // try to load file as PMML first boolean success = false; try { //PMMLModel pmmlModel = PMMLFactory.getPMMLModel(tmpO); //if (pmmlModel instanceof PMMLClassifier) { //classifier = ((PMMLClassifier)pmmlModel); // success = true; //} } catch (IllegalArgumentException ex) { success = false; } if (!success) { // load options from serialized data ('-l' is automatically erased!) 
XMLClassifier xmlserial = new XMLClassifier(); Classifier cl = (Classifier) xmlserial.read(Utils.getOption('l', options)); // merge options optionsTmp = new String[options.length + cl.getOptions().length]; System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl.getOptions().length); System.arraycopy(options, 0, optionsTmp, cl.getOptions().length, options.length); options = optionsTmp; } } noCrossValidation = Utils.getFlag("no-cv", options); // Get basic options (options the same for all schemes) classIndexString = Utils.getOption('c', options); if (classIndexString.length() != 0) { if (classIndexString.equals("first")) classIndex = 1; else if (classIndexString.equals("last")) classIndex = -1; else classIndex = Integer.parseInt(classIndexString); } trainFileName = Utils.getOption('t', options); objectInputFileName = Utils.getOption('l', options); objectOutputFileName = Utils.getOption('d', options); testFileName = Utils.getOption('T', options); foldsString = Utils.getOption('x', options); if (foldsString.length() != 0) { folds = Integer.parseInt(foldsString); } seedString = Utils.getOption('s', options); if (seedString.length() != 0) { seed = Integer.parseInt(seedString); } if (trainFileName.length() == 0) { if (objectInputFileName.length() == 0) { throw new Exception("No training file and no object " + "input file given."); } if (testFileName.length() == 0) { throw new Exception("No training file and no test " + "file given."); } } else if ((objectInputFileName.length() != 0) && ((!(classifier instanceof UpdateableClassifier)) || (testFileName.length() == 0))) { throw new Exception("Classifier not incremental, or no " + "test file provided: can't " + "use both train and model file."); } try { if (trainFileName.length() != 0) { trainSetPresent = true; trainSource = new DataSource(trainFileName); } if (testFileName.length() != 0) { testSetPresent = true; testSource = new DataSource(testFileName); } if (objectInputFileName.length() != 0) { if 
(objectInputFileName.endsWith(".xml")) { // if this is the case then it means that a PMML classifier was // successfully loaded earlier in the code objectInputStream = null; xmlInputStream = null; } else { InputStream is = new FileInputStream(objectInputFileName); if (objectInputFileName.endsWith(".gz")) { is = new GZIPInputStream(is); } // load from KOML? if (!(objectInputFileName.endsWith(".koml") && KOML.isPresent())) { objectInputStream = new ObjectInputStream(is); xmlInputStream = null; } else { objectInputStream = null; xmlInputStream = new BufferedInputStream(is); } } } } catch (Exception e) { throw new Exception("Can't open file " + e.getMessage() + '.'); } if (testSetPresent) { template = test = testSource.getStructure(); if (classIndex != -1) { test.setClassIndex(classIndex - 1); } else { if ((test.classIndex() == -1) || (classIndexString.length() != 0)) test.setClassIndex(test.numAttributes() - 1); } actualClassIndex = test.classIndex(); } else { // percentage split splitPercentageString = Utils.getOption("split-percentage", options); if (splitPercentageString.length() != 0) { if (foldsString.length() != 0) throw new Exception("Percentage split cannot be used in conjunction with " + "cross-validation ('-x')."); splitPercentage = Integer.parseInt(splitPercentageString); if ((splitPercentage <= 0) || (splitPercentage >= 100)) throw new Exception("Percentage split value needs be >0 and <100."); } else { splitPercentage = -1; } preserveOrder = Utils.getFlag("preserve-order", options); if (preserveOrder) { if (splitPercentage == -1) throw new Exception("Percentage split ('-percentage-split') is missing."); } // create new train/test sources if (splitPercentage > 0) { testSetPresent = true; Instances tmpInst = trainSource.getDataSet(actualClassIndex); if (!preserveOrder) tmpInst.randomize(new Random(seed)); int trainSize = tmpInst.numInstances() * splitPercentage / 100; int testSize = tmpInst.numInstances() - trainSize; Instances trainInst = new 
Instances(tmpInst, 0, trainSize); Instances testInst = new Instances(tmpInst, trainSize, testSize); trainSource = new DataSource(trainInst); testSource = new DataSource(testInst); template = test = testSource.getStructure(); if (classIndex != -1) { test.setClassIndex(classIndex - 1); } else { if ((test.classIndex() == -1) || (classIndexString.length() != 0)) test.setClassIndex(test.numAttributes() - 1); } actualClassIndex = test.classIndex(); } } if (trainSetPresent) { template = train = trainSource.getStructure(); if (classIndex != -1) { train.setClassIndex(classIndex - 1); } else { if ((train.classIndex() == -1) || (classIndexString.length() != 0)) train.setClassIndex(train.numAttributes() - 1); } actualClassIndex = train.classIndex(); if ((testSetPresent) && !test.equalHeaders(train)) { throw new IllegalArgumentException("Train and test file not compatible!"); } } if (template == null) { throw new Exception("No actual dataset provided to use as template"); } costMatrix = handleCostOption(Utils.getOption('m', options), template.numClasses()); classStatistics = Utils.getFlag('i', options); noOutput = Utils.getFlag('o', options); trainStatistics = !Utils.getFlag('v', options); printComplexityStatistics = Utils.getFlag('k', options); printMargins = Utils.getFlag('r', options); printGraph = Utils.getFlag('g', options); sourceClass = Utils.getOption('z', options); printSource = (sourceClass.length() != 0); printDistribution = Utils.getFlag("distribution", options); thresholdFile = Utils.getOption("threshold-file", options); thresholdLabel = Utils.getOption("threshold-label", options); // Check -p option try { attributeRangeString = Utils.getOption('p', options); } catch (Exception e) { throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " + "It now expects a parameter specifying a range of attributes " + "to list with the predictions. 
Use '-p 0' for none."); } if (attributeRangeString.length() != 0) { printClassifications = true; noOutput = true; if (!attributeRangeString.equals("0")) attributesToOutput = new Range(attributeRangeString); } if (!printClassifications && printDistribution) throw new Exception("Cannot print distribution without '-p' option!"); // if no training file given, we don't have any priors if ((!trainSetPresent) && (printComplexityStatistics)) throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!"); // If a model file is given, we can't process // scheme-specific options if (objectInputFileName.length() != 0) { Utils.checkForRemainingOptions(options); } else { // Set options for classifier if (classifier instanceof OptionHandler) { for (int i = 0; i < options.length; i++) { if (options[i].length() != 0) { if (schemeOptionsText == null) { schemeOptionsText = new StringBuffer(); } if (options[i].indexOf(' ') != -1) { schemeOptionsText.append('"' + options[i] + "\" "); } else { schemeOptionsText.append(options[i] + " "); } } } ((OptionHandler) classifier).setOptions(options); } } Utils.checkForRemainingOptions(options); } catch (Exception e) { throw new Exception("\nWeka exception: " + e.getMessage() + makeOptionString(classifier, false)); } // Setup up evaluation objects Evaluation_D trainingEvaluation = new Evaluation_D(new Instances(template, 0), costMatrix); Evaluation_D testingEvaluation = new Evaluation_D(new Instances(template, 0), costMatrix); // disable use of priors if no training file given if (!trainSetPresent) testingEvaluation.useNoPriors(); if (objectInputFileName.length() != 0) { // Load classifier from file if (objectInputStream != null) { classifier = (Classifier) objectInputStream.readObject(); // try and read a header (if present) Instances savedStructure = null; try { savedStructure = (Instances) objectInputStream.readObject(); } catch (Exception ex) { // don't make a fuss } if (savedStructure != null) { // test for 
compatibility with template if (!template.equalHeaders(savedStructure)) { throw new Exception("training and test set are not compatible"); } } objectInputStream.close(); } else if (xmlInputStream != null) { // whether KOML is available has already been checked (objectInputStream would null otherwise)! classifier = (Classifier) KOML.read(xmlInputStream); xmlInputStream.close(); } } // backup of fully setup classifier for cross-validation classifierBackup = Classifier.makeCopy(classifier); // Build the classifier if no object file provided if ((classifier instanceof UpdateableClassifier) && (testSetPresent || noCrossValidation) && (costMatrix == null) && (trainSetPresent)) { // Build classifier incrementally trainingEvaluation.setPriors(train); testingEvaluation.setPriors(train); trainTimeStart = System.currentTimeMillis(); if (objectInputFileName.length() == 0) { classifier.buildClassifier(train); } Instance trainInst; while (trainSource.hasMoreElements(train)) { trainInst = trainSource.nextElement(train); trainingEvaluation.updatePriors(trainInst); testingEvaluation.updatePriors(trainInst); ((UpdateableClassifier) classifier).updateClassifier(trainInst); } trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; } else if (objectInputFileName.length() == 0) { // Build classifier in one go tempTrain = trainSource.getDataSet(actualClassIndex); trainingEvaluation.setPriors(tempTrain); testingEvaluation.setPriors(tempTrain); trainTimeStart = System.currentTimeMillis(); classifier.buildClassifier(tempTrain); trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; } // backup of fully trained classifier for printing the classifications if (printClassifications) classifierClassifications = Classifier.makeCopy(classifier); // Save the classifier if an object output file is provided if (objectOutputFileName.length() != 0) { OutputStream os = new FileOutputStream(objectOutputFileName); // binary if (!(objectOutputFileName.endsWith(".xml") || 
(objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) { if (objectOutputFileName.endsWith(".gz")) { os = new GZIPOutputStream(os); } ObjectOutputStream objectOutputStream = new ObjectOutputStream(os); objectOutputStream.writeObject(classifier); if (template != null) { objectOutputStream.writeObject(template); } objectOutputStream.flush(); objectOutputStream.close(); } // KOML/XML else { BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os); if (objectOutputFileName.endsWith(".xml")) { XMLSerialization xmlSerial = new XMLClassifier(); xmlSerial.write(xmlOutputStream, classifier); } else // whether KOML is present has already been checked // if not present -> ".koml" is interpreted as binary - see above if (objectOutputFileName.endsWith(".koml")) { KOML.write(xmlOutputStream, classifier); } xmlOutputStream.close(); } } // If classifier is drawable output string describing graph if ((classifier instanceof Drawable) && (printGraph)) { return ((Drawable) classifier).graph(); } // Output the classifier as equivalent source if ((classifier instanceof Sourcable) && (printSource)) { return wekaStaticWrapper((Sourcable) classifier, sourceClass); } // Output model if (!(noOutput || printMargins)) { if (classifier instanceof OptionHandler) { if (schemeOptionsText != null) { text.append("\nOptions: " + schemeOptionsText); text.append("\n"); } } text.append("\n" + classifier.toString() + "\n"); } if (!printMargins && (costMatrix != null)) { text.append("\n=== Evaluation Cost Matrix ===\n\n"); text.append(costMatrix.toString()); } // Output test instance predictions only if (printClassifications) { DataSource source = testSource; predsBuff = new StringBuffer(); // no test set -> use train set if (source == null && noCrossValidation) { source = trainSource; predsBuff.append("\n=== Predictions on training data ===\n\n"); } else { predsBuff.append("\n=== Predictions on test data ===\n\n"); } if (source != null) { /* return 
printClassifications(classifierClassifications, new Instances(template, 0), source, actualClassIndex + 1, attributesToOutput, printDistribution); */ printClassifications(classifierClassifications, new Instances(template, 0), source, actualClassIndex + 1, attributesToOutput, printDistribution, predsBuff); // return predsText.toString(); } } // Compute error estimate from training data if ((trainStatistics) && (trainSetPresent)) { if ((classifier instanceof UpdateableClassifier) && (testSetPresent) && (costMatrix == null)) { // Classifier was trained incrementally, so we have to // reset the source. trainSource.reset(); // Incremental testing train = trainSource.getStructure(actualClassIndex); testTimeStart = System.currentTimeMillis(); Instance trainInst; while (trainSource.hasMoreElements(train)) { trainInst = trainSource.nextElement(train); trainingEvaluation.evaluateModelOnce((Classifier) classifier, trainInst); } testTimeElapsed = System.currentTimeMillis() - testTimeStart; } else { testTimeStart = System.currentTimeMillis(); trainingEvaluation.evaluateModel(classifier, trainSource.getDataSet(actualClassIndex)); testTimeElapsed = System.currentTimeMillis() - testTimeStart; } // Print the results of the training evaluation if (printMargins) { return trainingEvaluation.toCumulativeMarginDistributionString(); } else { if (!printClassifications) { text.append("\nTime taken to build model: " + Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds"); if (splitPercentage > 0) text.append("\nTime taken to test model on training split: "); else text.append("\nTime taken to test model on training data: "); text.append(Utils.doubleToString(testTimeElapsed / 1000.0, 2) + " seconds"); if (splitPercentage > 0) text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training" + " split ===\n", printComplexityStatistics)); else text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training" + " data ===\n", printComplexityStatistics)); if 
(template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + trainingEvaluation.toClassDetailsString()); } if (!noCrossValidation) text.append("\n\n" + trainingEvaluation.toMatrixString()); } } } } // Compute proper error estimates if (testSource != null) { // Testing is on the supplied test data testSource.reset(); test = testSource.getStructure(test.classIndex()); Instance testInst; while (testSource.hasMoreElements(test)) { testInst = testSource.nextElement(test); testingEvaluation.evaluateModelOnceAndRecordPrediction((Classifier) classifier, testInst); } if (splitPercentage > 0) { if (!printClassifications) { text.append("\n\n" + testingEvaluation.toSummaryString("=== Error on test split ===\n", printComplexityStatistics)); } } else { if (!printClassifications) { text.append("\n\n" + testingEvaluation.toSummaryString("=== Error on test data ===\n", printComplexityStatistics)); } } } else if (trainSource != null) { if (!noCrossValidation) { // Testing is via cross-validation on training data Random random = new Random(seed); // use untrained (!) classifier for cross-validation classifier = Classifier.makeCopy(classifierBackup); if (!printClassifications) { testingEvaluation.crossValidateModel(classifier, trainSource.getDataSet(actualClassIndex), folds, random); if (template.classAttribute().isNumeric()) { text.append("\n\n\n" + testingEvaluation.toSummaryString("=== Cross-validation ===\n", printComplexityStatistics)); } else { text.append("\n\n\n" + testingEvaluation.toSummaryString( "=== Stratified " + "cross-validation ===\n", printComplexityStatistics)); } } else { predsBuff = new StringBuffer(); predsBuff.append("\n=== Predictions under cross-validation ===\n\n"); testingEvaluation.crossValidateModel(classifier, trainSource.getDataSet(actualClassIndex), folds, random, predsBuff, attributesToOutput, new Boolean(printDistribution)); /* if (template.classAttribute().isNumeric()) { text.append("\n\n\n" + testingEvaluation. 
toSummaryString("=== Cross-validation ===\n", printComplexityStatistics)); } else { text.append("\n\n\n" + testingEvaluation. toSummaryString("=== Stratified " + "cross-validation ===\n", printComplexityStatistics)); } */ } } } if (template.classAttribute().isNominal()) { if (classStatistics && !noCrossValidation && !printClassifications) { text.append("\n\n" + testingEvaluation.toClassDetailsString()); } if (!noCrossValidation && !printClassifications) text.append("\n\n" + testingEvaluation.toMatrixString()); } // predictions from cross-validation? if (predsBuff != null) { text.append("\n" + predsBuff); } if ((thresholdFile.length() != 0) && template.classAttribute().isNominal()) { int labelIndex = 0; if (thresholdLabel.length() != 0) labelIndex = template.classAttribute().indexOfValue(thresholdLabel); if (labelIndex == -1) throw new IllegalArgumentException("Class label '" + thresholdLabel + "' is unknown!"); ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(testingEvaluation.predictions(), labelIndex); DataSink.write(thresholdFile, result); } return text.toString(); }
From source file: GClass.EvaluationInternal.java
License: Open Source License
/** * Evaluates a classifier with the options given in an array of * strings. <p>/*from ww w . j av a 2 s .c om*/ * * Valid options are: <p> * * -t name of training file <br> * Name of the file with the training data. (required) <p> * * -T name of test file <br> * Name of the file with the test data. If missing a cross-validation * is performed. <p> * * -c class index <br> * Index of the class attribute (1, 2, ...; default: last). <p> * * -x number of folds <br> * The number of folds for the cross-validation (default: 10). <p> * * -s random number seed <br> * Random number seed for the cross-validation (default: 1). <p> * * -m file with cost matrix <br> * The name of a file containing a cost matrix. <p> * * -l name of model input file <br> * Loads classifier from the given file. <p> * * -d name of model output file <br> * Saves classifier built from the training data into the given file. <p> * * -v <br> * Outputs no statistics for the training data. <p> * * -o <br> * Outputs statistics only, not the classifier. <p> * * -i <br> * Outputs detailed information-retrieval statistics per class. <p> * * -k <br> * Outputs information-theoretic statistics. <p> * * -p <br> * Outputs predictions for test instances (and nothing else). <p> * * -r <br> * Outputs cumulative margin distribution (and nothing else). <p> * * -g <br> * Only for classifiers that implement "Graphable." Outputs * the graph representation of the classifier (and nothing * else). 
<p> * * @param classifier machine learning classifier * @param options the array of string containing the options * @exception Exception if model could not be evaluated successfully * @return a string describing the results */ public static String[] evaluateModel(Classifier classifier, String trainFileName, String objectOutputFileName) throws Exception { Instances train = null, tempTrain, test = null, template = null; int seed = 1, folds = 10, classIndex = -1; String testFileName, sourceClass, classIndexString, seedString, foldsString, objectInputFileName, attributeRangeString; boolean IRstatistics = false, noOutput = false, printClassifications = false, trainStatistics = true, printMargins = false, printComplexityStatistics = false, printGraph = false, classStatistics = false, printSource = false; StringBuffer text = new StringBuffer(); BufferedReader trainReader = null, testReader = null; ObjectInputStream objectInputStream = null; CostMatrix costMatrix = null; StringBuffer schemeOptionsText = null; Range attributesToOutput = null; long trainTimeStart = 0, trainTimeElapsed = 0, testTimeStart = 0, testTimeElapsed = 0; try { String[] options = null; // Get basic options (options the same for all schemes) classIndexString = Utils.getOption('c', options); if (classIndexString.length() != 0) { classIndex = Integer.parseInt(classIndexString); } // trainFileName = Utils.getOption('t', options); objectInputFileName = Utils.getOption('l', options); // objectOutputFileName = Utils.getOption('d', options); testFileName = Utils.getOption('T', options); if (trainFileName.length() == 0) { if (objectInputFileName.length() == 0) { throw new Exception("No training file and no object " + "input file given."); } if (testFileName.length() == 0) { throw new Exception("No training file and no test " + "file given."); } } else if ((objectInputFileName.length() != 0) && ((!(classifier instanceof UpdateableClassifier)) || (testFileName.length() == 0))) { throw new Exception("Classifier 
not incremental, or no " + "test file provided: can't " + "use both train and model file."); } try { if (trainFileName.length() != 0) { trainReader = new BufferedReader(new FileReader(trainFileName)); } if (testFileName.length() != 0) { testReader = new BufferedReader(new FileReader(testFileName)); } if (objectInputFileName.length() != 0) { InputStream is = new FileInputStream(objectInputFileName); if (objectInputFileName.endsWith(".gz")) { is = new GZIPInputStream(is); } objectInputStream = new ObjectInputStream(is); } } catch (Exception e) { throw new Exception("Can't open file " + e.getMessage() + '.'); } if (testFileName.length() != 0) { template = test = new Instances(testReader, 1); if (classIndex != -1) { test.setClassIndex(classIndex - 1); } else { test.setClassIndex(test.numAttributes() - 1); } if (classIndex > test.numAttributes()) { throw new Exception("Index of class attribute too large."); } } if (trainFileName.length() != 0) { if ((classifier instanceof UpdateableClassifier) && (testFileName.length() != 0)) { train = new Instances(trainReader, 1); } else { train = new Instances(trainReader); } template = train; if (classIndex != -1) { train.setClassIndex(classIndex - 1); } else { train.setClassIndex(train.numAttributes() - 1); } if ((testFileName.length() != 0) && !test.equalHeaders(train)) { throw new IllegalArgumentException("Train and test file not compatible!"); } if (classIndex > train.numAttributes()) { throw new Exception("Index of class attribute too large."); } //train = new Instances(train); } if (template == null) { throw new Exception("No actual dataset provided to use as template"); } seedString = Utils.getOption('s', options); if (seedString.length() != 0) { seed = Integer.parseInt(seedString); } foldsString = Utils.getOption('x', options); if (foldsString.length() != 0) { folds = Integer.parseInt(foldsString); } costMatrix = handleCostOption(Utils.getOption('m', options), template.numClasses()); classStatistics = Utils.getFlag('i', 
options); noOutput = Utils.getFlag('o', options); trainStatistics = !Utils.getFlag('v', options); printComplexityStatistics = Utils.getFlag('k', options); printMargins = Utils.getFlag('r', options); printGraph = Utils.getFlag('g', options); sourceClass = Utils.getOption('z', options); printSource = (sourceClass.length() != 0); // Check -p option try { attributeRangeString = Utils.getOption('p', options); } catch (Exception e) { throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " + "It now expects a parameter specifying a range of attributes " + "to list with the predictions. Use '-p 0' for none."); } if (attributeRangeString.length() != 0) { printClassifications = true; if (!attributeRangeString.equals("0")) { attributesToOutput = new Range(attributeRangeString); } } // If a model file is given, we can't process // scheme-specific options if (objectInputFileName.length() != 0) { Utils.checkForRemainingOptions(options); } else { // Set options for classifier if (classifier instanceof OptionHandler) { /* for (int i = 0; i < options.length; i++) { if (options[i].length() != 0) { if (schemeOptionsText == null) { schemeOptionsText = new StringBuffer(); } if (options[i].indexOf(' ') != -1) { schemeOptionsText.append('"' + options[i] + "\" "); } else { schemeOptionsText.append(options[i] + " "); } } } */ ((OptionHandler) classifier).setOptions(options); } } Utils.checkForRemainingOptions(options); } catch (Exception e) { throw new Exception("\nWeka exception: " + e.getMessage() + makeOptionString(classifier)); } // Setup up evaluation objects EvaluationInternal trainingEvaluation = new EvaluationInternal(new Instances(template, 0), costMatrix); EvaluationInternal testingEvaluation = new EvaluationInternal(new Instances(template, 0), costMatrix); if (objectInputFileName.length() != 0) { // Load classifier from file classifier = (Classifier) objectInputStream.readObject(); objectInputStream.close(); } // Build the classifier if no object file provided 
if ((classifier instanceof UpdateableClassifier) && (testFileName.length() != 0) && (costMatrix == null) && (trainFileName.length() != 0)) { // Build classifier incrementally trainingEvaluation.setPriors(train); testingEvaluation.setPriors(train); trainTimeStart = System.currentTimeMillis(); if (objectInputFileName.length() == 0) { classifier.buildClassifier(train); } while (train.readInstance(trainReader)) { trainingEvaluation.updatePriors(train.instance(0)); testingEvaluation.updatePriors(train.instance(0)); ((UpdateableClassifier) classifier).updateClassifier(train.instance(0)); train.delete(0); } trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; trainReader.close(); } else if (objectInputFileName.length() == 0) { // Build classifier in one go tempTrain = new Instances(train); trainingEvaluation.setPriors(tempTrain); testingEvaluation.setPriors(tempTrain); trainTimeStart = System.currentTimeMillis(); classifier.buildClassifier(tempTrain); trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; } // Save the classifier if an object output file is provided if (objectOutputFileName.length() != 0) { OutputStream os = new FileOutputStream(objectOutputFileName); if (objectOutputFileName.endsWith(".gz")) { os = new GZIPOutputStream(os); } ObjectOutputStream objectOutputStream = new ObjectOutputStream(os); objectOutputStream.writeObject(classifier); objectOutputStream.flush(); objectOutputStream.close(); } /* // If classifier is drawable output string describing graph if ((classifier instanceof Drawable) && (printGraph)) { return ((Drawable) classifier).graph(); } // Output the classifier as equivalent source if ((classifier instanceof Sourcable) && (printSource)) { return wekaStaticWrapper((Sourcable) classifier, sourceClass); } // Output test instance predictions only if (printClassifications) { return printClassifications(classifier, new Instances(template, 0), testFileName, classIndex, attributesToOutput); } */ // Output model if (!(noOutput || 
printMargins)) { if (classifier instanceof OptionHandler) { if (schemeOptionsText != null) { text.append("\nOptions: " + schemeOptionsText); text.append("\n"); } } text.append("\n" + classifier.toString() + "\n"); } if (!printMargins && (costMatrix != null)) { text.append("\n=== Evaluation Cost Matrix ===\n\n").append(costMatrix.toString()); } // Compute error estimate from training data if ((trainStatistics) && (trainFileName.length() != 0)) { if ((classifier instanceof UpdateableClassifier) && (testFileName.length() != 0) && (costMatrix == null)) { // Classifier was trained incrementally, so we have to // reopen the training data in order to test on it. trainReader = new BufferedReader(new FileReader(trainFileName)); // Incremental testing train = new Instances(trainReader, 1); if (classIndex != -1) { train.setClassIndex(classIndex - 1); } else { train.setClassIndex(train.numAttributes() - 1); } testTimeStart = System.currentTimeMillis(); while (train.readInstance(trainReader)) { trainingEvaluation.evaluateModelOnce((Classifier) classifier, train.instance(0)); train.delete(0); } testTimeElapsed = System.currentTimeMillis() - testTimeStart; trainReader.close(); } else { testTimeStart = System.currentTimeMillis(); trainingEvaluation.evaluateModel(classifier, train); testTimeElapsed = System.currentTimeMillis() - testTimeStart; } // Print the results of the training evaluation // if (printMargins) { // return trainingEvaluation.toCumulativeMarginDistributionString(); // } else { text.append("\nTime taken to build model: " + Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds"); text.append("\nTime taken to test model on training data: " + Utils.doubleToString(testTimeElapsed / 1000.0, 2) + " seconds"); text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training" + " data ===\n", printComplexityStatistics)); if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + trainingEvaluation.toClassDetailsString()); } 
text.append("\n\n" + trainingEvaluation.toMatrixString()); } // } } // Compute proper error estimates if (testFileName.length() != 0) { // Testing is on the supplied test data while (test.readInstance(testReader)) { testingEvaluation.evaluateModelOnce((Classifier) classifier, test.instance(0)); test.delete(0); } testReader.close(); text.append("\n\n" + testingEvaluation.toSummaryString("=== Error on test data ===\n", printComplexityStatistics)); } else if (trainFileName.length() != 0) { // Testing is via cross-validation on training data Random random = new Random(seed); testingEvaluation.crossValidateModel(classifier, train, folds, random); if (template.classAttribute().isNumeric()) { text.append("\n\n\n" + testingEvaluation.toSummaryString("=== Cross-validation ===\n", printComplexityStatistics)); } else { text.append("\n\n\n" + testingEvaluation .toSummaryString("=== Stratified " + "cross-validation ===\n", printComplexityStatistics)); } } if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + testingEvaluation.toClassDetailsString()); } text.append("\n\n" + testingEvaluation.toMatrixString()); } String result = "\t" + Utils.doubleToString(trainingEvaluation.pctCorrect(), 12, 4) + " %"; result += " " + Utils.doubleToString(testingEvaluation.pctCorrect(), 12, 4) + " %"; String[] returnString = { text.toString(), result }; return returnString; }
From source file: milk.classifiers.MIEvaluation.java
License: Open Source License
/** * Evaluates a classifier with the options given in an array of * strings. <p>//from ww w . j a v a2 s . c o m * * Valid options are: <p> * * -t filename <br> * Name of the file with the training data. (required) <p> * * -T filename <br> * Name of the file with the test data. If missing a cross-validation * is performed. <p> * * -c index <br> * Index of the class attribute (1, 2, ...; default: last). <p> * * -I index <br> * Index of the ID attribute (0, 1, 2, ...; default: first). <p> * * -x number <br> * The number of folds for the cross-validation (default: 10). <p> * * -s seed <br> * Random number seed for the cross-validation (default: 1). <p> * * -m filename <br> * The name of a file containing a cost matrix. <p> * * -l filename <br> * Loads classifier from the given file. <p> * * -g <br> * Only for classifiers that implement "Graphable." Outputs * the graph representation of the classifier (and nothing * else). <p> * * -L <br> * Whether use "Leave-One-Out" cross-validation. <p> * * -d filename <br> * Saves classifier built from the training data into the given file. <p> * * -v <br> * Outputs no statistics for the training data. <p> * * -o <br> * Outputs statistics only, not the classifier. 
<p> * * @param classifier machine learning classifier * @param options the array of string containing the options * @exception Exception if model could not be evaluated successfully * @return a string describing the results */ public static String evaluateModel(MIClassifier classifier, String[] options) throws Exception { Exemplars train = null, tempTrain, test = null, template = null; int seed = 1, folds = 10, classIndex = -1, idIndex = -1; String trainFileName, testFileName, sourceClass, classIndexString, idIndexString, seedString, foldsString, objectInputFileName, objectOutputFileName, attributeRangeString; boolean IRstatistics = false, noOutput = false, leaveOneOut = false, printClassifications = false, trainStatistics = true, printMargins = false, printComplexityStatistics = false, classStatistics = true, printSource = false, printGraph = false; StringBuffer text = new StringBuffer(); BufferedReader trainReader = null, testReader = null; ObjectInputStream objectInputStream = null; Random random = null; CostMatrix costMatrix = null; StringBuffer schemeOptionsText = null; Range attributesToOutput = null; long trainTimeStart = 0, trainTimeElapsed = 0, testTimeStart = 0, testTimeElapsed = 0; Instances data = null; try { // Get basic options (options the same for all schemes) classIndexString = Utils.getOption('c', options); if (classIndexString.length() != 0) classIndex = Integer.parseInt(classIndexString); idIndexString = Utils.getOption('I', options); if (idIndexString.length() != 0) idIndex = Integer.parseInt(idIndexString); trainFileName = Utils.getOption('t', options); objectInputFileName = Utils.getOption('l', options); objectOutputFileName = Utils.getOption('d', options); testFileName = Utils.getOption('T', options); if (trainFileName.length() == 0) { if (objectInputFileName.length() == 0) { throw new Exception("No training file and no object " + "input file given."); } if (testFileName.length() == 0) { throw new Exception("No training file and no test " + 
"file given."); } } else if ((objectInputFileName.length() != 0) && ((!(classifier instanceof MIUpdateableClassifier)) || (testFileName.length() == 0))) { throw new Exception("Classifier not incremental, or no " + "test file provided: can't " + "use both train and model file."); } try { if (trainFileName.length() != 0) { trainReader = new BufferedReader(new FileReader(trainFileName)); } if (testFileName.length() != 0) testReader = new BufferedReader(new FileReader(testFileName)); if (objectInputFileName.length() != 0) { InputStream is = new FileInputStream(objectInputFileName); if (objectInputFileName.endsWith(".gz")) { is = new GZIPInputStream(is); } objectInputStream = new ObjectInputStream(is); } } catch (Exception e) { throw new Exception("Can't open file " + e.getMessage() + '.'); } if (testFileName.length() != 0) { Instances insts = new Instances(testReader); if (classIndex != -1) insts.setClassIndex(classIndex - 1); else insts.setClassIndex(insts.numAttributes() - 1); if (classIndex > insts.numAttributes()) throw new Exception("Index of class attribute too large."); if (idIndex != -1) test = new Exemplars(insts, idIndex); else test = new Exemplars(insts, 0); template = test; testReader.close(); } if (trainFileName.length() != 0) { data = new Instances(trainReader); if (classIndex != -1) data.setClassIndex(classIndex - 1); else data.setClassIndex(data.numAttributes() - 1); if (classIndex > data.numAttributes()) throw new Exception("Index of class attribute too large."); Instances tmp = new Instances(data); if (idIndex != -1) train = new Exemplars(tmp, idIndex); else train = new Exemplars(tmp, 0); template = train; trainReader.close(); } if (template == null) throw new Exception("No actual dataset provided to use as template"); seedString = Utils.getOption('s', options); if (seedString.length() != 0) { seed = Integer.parseInt(seedString); } foldsString = Utils.getOption('x', options); if (foldsString.length() != 0) { folds = Integer.parseInt(foldsString); } 
costMatrix = handleCostOption(Utils.getOption('m', options), template.numClasses()); printGraph = Utils.getFlag('g', options); sourceClass = Utils.getOption('z', options); printMargins = Utils.getFlag('r', options); printSource = (sourceClass.length() != 0); classStatistics = Utils.getFlag('i', options); leaveOneOut = Utils.getFlag('L', options); if (leaveOneOut) // Leave-one-out folds = template.numExemplars(); // If a model file is given, we can't process // scheme-specific options if (objectInputFileName.length() != 0) { Utils.checkForRemainingOptions(options); } else { // Set options for classifier if (classifier instanceof OptionHandler) { for (int i = 0; i < options.length; i++) { if (options[i].length() != 0) { if (schemeOptionsText == null) { schemeOptionsText = new StringBuffer(); } if (options[i].indexOf(' ') != -1) { schemeOptionsText.append('"' + options[i] + "\" "); } else { schemeOptionsText.append(options[i] + " "); } } } ((OptionHandler) classifier).setOptions(options); } } Utils.checkForRemainingOptions(options); } catch (Exception e) { e.printStackTrace(); throw new Exception("\nWeka exception: " + e.getMessage() + makeOptionString(classifier)); } // Setup up evaluation objects MIEvaluation trainingEvaluation = new MIEvaluation(new Exemplars(template), costMatrix); MIEvaluation testingEvaluation = new MIEvaluation(new Exemplars(template), costMatrix); if (objectInputFileName.length() != 0) { // Load classifier from file classifier = (MIClassifier) objectInputStream.readObject(); objectInputStream.close(); } // Build the classifier if no object file provided if ((classifier instanceof MIUpdateableClassifier) && (testFileName.length() != 0) && (costMatrix == null) && (trainFileName.length() != 0)) { // Build classifier incrementally int x = 0; Exemplars traineg = new Exemplars(train.exemplar(x++).getInstances(), train.idIndex()); trainingEvaluation.setPriors(traineg); testingEvaluation.setPriors(traineg); trainTimeStart = System.currentTimeMillis(); 
if (objectInputFileName.length() == 0) { classifier.buildClassifier(traineg); } while (x < train.numExemplars()) { trainingEvaluation.updatePriors(train.exemplar(x)); testingEvaluation.updatePriors(train.exemplar(x)); ((MIUpdateableClassifier) classifier).updateClassifier(train.exemplar(x)); x++; } trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; } else if (objectInputFileName.length() == 0) { // Build classifier in one go tempTrain = new Exemplars(train); trainingEvaluation.setPriors(tempTrain); testingEvaluation.setPriors(tempTrain); trainTimeStart = System.currentTimeMillis(); classifier.buildClassifier(tempTrain); trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; } // Save the classifier if an object output file is provided if (objectOutputFileName.length() != 0) { OutputStream os = new FileOutputStream(objectOutputFileName); if (objectOutputFileName.endsWith(".gz")) { os = new GZIPOutputStream(os); } ObjectOutputStream objectOutputStream = new ObjectOutputStream(os); objectOutputStream.writeObject(classifier); objectOutputStream.flush(); objectOutputStream.close(); } // If classifier is drawable output string describing graph if ((classifier instanceof Drawable) && (printGraph)) { return ((Drawable) classifier).graph(); } // Output the classifier as equivalent source if ((classifier instanceof Sourcable) && (printSource)) { return wekaStaticWrapper((Sourcable) classifier, sourceClass); } // Output model if (classifier instanceof OptionHandler) { if (schemeOptionsText != null) { text.append("\nOptions: " + schemeOptionsText); text.append("\n"); } } text.append("\n" + classifier.toString() + "\n"); if (costMatrix != null) { text.append("\n=== Evaluation Cost Matrix ===\n\n").append(costMatrix.toString()); } // Compute error estimate from training data if (trainFileName.length() != 0) { if ((classifier instanceof MIUpdateableClassifier) && (testFileName.length() != 0) && (costMatrix == null)) { // Classifier was trained incrementally, 
so we have to // reopen the training data in order to test on it. trainReader = new BufferedReader(new FileReader(trainFileName)); // Incremental testing Instances trn = new Instances(trainReader); if (classIndex != -1) { trn.setClassIndex(classIndex - 1); } else { trn.setClassIndex(trn.numAttributes() - 1); } testTimeStart = System.currentTimeMillis(); if (idIndex != -1) train = new Exemplars(trn, idIndex); else train = new Exemplars(trn, 0); for (int y = 0; y < train.numExemplars(); y++) { trainingEvaluation.evaluateModelOnce((MIClassifier) classifier, train.exemplar(y)); } testTimeElapsed = System.currentTimeMillis() - testTimeStart; trainReader.close(); } else { testTimeStart = System.currentTimeMillis(); trainingEvaluation.evaluateModel(classifier, train); testTimeElapsed = System.currentTimeMillis() - testTimeStart; } // Print the results of the training evaluation text.append("\nTime taken to build model: " + Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds"); text.append("\nTime taken to test model on training data: " + Utils.doubleToString(testTimeElapsed / 1000.0, 2) + " seconds"); text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training" + " data ===\n", printComplexityStatistics)); if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + trainingEvaluation.toClassDetailsString()); } text.append("\n\n" + trainingEvaluation.toMatrixString()); } } // Compute proper error estimates if (testFileName.length() != 0) { // Testing is on the supplied test data for (int z = 0; z < test.numExemplars(); z++) testingEvaluation.evaluateModelOnce((MIClassifier) classifier, test.exemplar(z)); text.append("\n\n" + testingEvaluation.toSummaryString("=== Error on test data ===\n", printComplexityStatistics)); } else if (trainFileName.length() != 0) { // Testing is via cross-validation on training data if (random == null) random = new Random(seed); random.setSeed(seed); // In case exemplars are changed by 
classifier if (idIndex != -1) train = new Exemplars(data, idIndex); else train = new Exemplars(data, 0); train.randomize(random); testingEvaluation.crossValidateModel(classifier, train, folds); if (leaveOneOut) text.append("\n\n\n" + testingEvaluation.toSummaryString("=== Leave One Out Error ===\n", printComplexityStatistics)); else text.append("\n\n\n" + testingEvaluation .toSummaryString("=== Stratified " + "cross-validation ===\n", printComplexityStatistics)); } if (template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + testingEvaluation.toClassDetailsString()); } text.append("\n\n" + testingEvaluation.toMatrixString()); } return text.toString(); }