List of usage examples for weka.classifiers.Evaluation.evaluateModel
public double[] evaluateModel(Classifier classifier, Instances data, Object... forPredictionsPrinting) throws Exception
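Before the project-specific examples, here is a minimal, self-contained sketch of the typical call pattern (an illustration only, not taken from the sources below). It assumes a recent Weka release (3.7+), a J48 classifier, and two hypothetical ARFF files, train.arff and test.arff. The optional varargs argument is an AbstractOutput (here PlainText) that collects the per-instance predictions; older Weka versions instead expect a StringBuffer, a Range and a Boolean, as seen in the FlexDMThread example further down.

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.output.prediction.PlainText;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class EvaluateModelSketch {
    public static void main(String[] args) throws Exception {
        // Hypothetical file names - replace with real ARFF paths
        Instances train = DataSource.read("train.arff");
        Instances test = DataSource.read("test.arff");
        train.setClassIndex(train.numAttributes() - 1);
        test.setClassIndex(test.numAttributes() - 1);

        // Build the classifier on the training data
        Classifier cls = new J48();
        cls.buildClassifier(train);

        // Optional varargs: an AbstractOutput that captures the individual predictions
        PlainText output = new PlainText();
        output.setBuffer(new StringBuffer());
        output.setHeader(test);

        // Evaluate on the test data; the returned array holds the predicted values
        Evaluation eval = new Evaluation(train);
        double[] predictions = eval.evaluateModel(cls, test, output);

        System.out.println(eval.toSummaryString());
        System.out.println(output.getBuffer().toString());
        System.out.println("Number of predictions: " + predictions.length);
    }
}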
From source file:FlexDMThread.java
License:Open Source License
public void run() {
    try {
        //Get the data from the source
        FlexDM.getMainData.acquire();
        Instances data = dataset.getSource().getDataSet();
        FlexDM.getMainData.release();

        //Set class attribute if undefined
        if (data.classIndex() == -1) {
            data.setClassIndex(data.numAttributes() - 1);
        }

        //Process hyperparameters for classifier
        String temp = "";
        for (int i = 0; i < classifier.getNumParams(); i++) {
            temp += classifier.getParameter(i).getName();
            temp += " ";
            if (classifier.getParameter(i).getValue() != null) {
                temp += classifier.getParameter(i).getValue();
                temp += " ";
            }
        }
        String[] options = weka.core.Utils.splitOptions(temp);

        //Print to console - experiment is starting
        if (temp.equals("")) { //no parameters
            temp = "results_no_parameters";
            try {
                System.out.println("STARTING CLASSIFIER " + cNum + " - " + classifier.getName() + " on dataset "
                        + dataset.getName().substring(dataset.getName().lastIndexOf("\\") + 1)
                        + " with no parameters");
            } catch (Exception e) {
                System.out.println("STARTING CLASSIFIER " + cNum + " - " + classifier.getName() + " on dataset "
                        + dataset.getName() + " with no parameters");
            }
        } else { //parameters
            try {
                System.out.println("STARTING CLASSIFIER " + cNum + " - " + classifier.getName() + " on dataset "
                        + dataset.getName().substring(dataset.getName().lastIndexOf("\\") + 1)
                        + " with parameters " + temp);
            } catch (Exception e) {
                System.out.println("STARTING CLASSIFIER " + cNum + " - " + classifier.getName() + " on dataset "
                        + dataset.getName() + " with parameters " + temp);
            }
        }

        //Create classifier, setting parameters
        weka.classifiers.Classifier x = createObject(classifier.getName());
        x.setOptions(options);
        x.buildClassifier(data);

        //Process the test selection
        String[] tempTest = dataset.getTest().split("\\s");

        //Create evaluation object for training and testing classifiers
        Evaluation eval = new Evaluation(data);
        StringBuffer predictions = new StringBuffer();

        //Train and evaluate classifier
        if (tempTest[0].equals("testset")) { //specified test file
            //Build classifier
            x.buildClassifier(data);

            //Open test file, load data
            //DataSource testFile = new DataSource(dataset.getTest().substring(7).trim());
            //Instances testSet = testFile.getDataSet();
            FlexDM.getTestData.acquire();
            Instances testSet = dataset.getTestFile().getDataSet();
            FlexDM.getTestData.release();

            //Set class attribute if undefined
            if (testSet.classIndex() == -1) {
                testSet.setClassIndex(testSet.numAttributes() - 1);
            }

            //Evaluate model
            Object[] array = { predictions, new Range(), new Boolean(true) };
            eval.evaluateModel(x, testSet, array);
        } else if (tempTest[0].equals("xval")) { //Cross validation
            //Build classifier
            x.buildClassifier(data);

            //Cross validate
            eval.crossValidateModel(x, data, Integer.parseInt(tempTest[1]), new Random(1), predictions,
                    new Range(), true);
        } else if (tempTest[0].equals("leavexval")) { //Leave one out cross validation
            //Build classifier
            x.buildClassifier(data);

            //Cross validate
            eval.crossValidateModel(x, data, data.numInstances() - 1, new Random(1), predictions, new Range(),
                    true);
        } else if (tempTest[0].equals("percent")) { //Percentage split of single data set
            //Set training and test sizes from percentage
            int trainSize = (int) Math.round(data.numInstances() * Double.parseDouble(tempTest[1]));
            int testSize = data.numInstances() - trainSize;

            //Load specified data
            Instances train = new Instances(data, 0, trainSize);
            Instances testSet = new Instances(data, trainSize, testSize);

            //Build classifier
            x.buildClassifier(train);

            //Train and evaluate model
            Object[] array = { predictions, new Range(), new Boolean(true) };
            eval.evaluateModel(x, testSet, array);
        } else { //Evaluate on training data
            //Test and evaluate model
            Object[] array = { predictions, new Range(), new Boolean(true) };
            eval.evaluateModel(x, data, array);
        }

        //create datafile for results
        String filename = dataset.getDir() + "/" + classifier.getDirName() + "/" + temp + ".txt";
        PrintWriter writer = new PrintWriter(filename, "UTF-8");

        //Print classifier, dataset, parameters info to file
        try {
            writer.println("CLASSIFIER: " + classifier.getName() + "\n DATASET: " + dataset.getName()
                    + "\n PARAMETERS: " + temp);
        } catch (Exception e) {
            writer.println("CLASSIFIER: " + classifier.getName() + "\n DATASET: " + dataset.getName()
                    + "\n PARAMETERS: " + temp);
        }

        //Add evaluation string to file
        writer.println(eval.toSummaryString());

        //Process result options
        if (checkResults("stats")) { //Classifier statistics
            writer.println(eval.toClassDetailsString());
        }
        if (checkResults("model")) { //The model
            writer.println(x.toString());
        }
        if (checkResults("matrix")) { //Confusion matrix
            writer.println(eval.toMatrixString());
        }
        if (checkResults("entropy")) { //Entropy statistics
            //Set options req'd to get the entropy stats
            String[] opt = new String[4];
            opt[0] = "-t";
            opt[1] = dataset.getName();
            opt[2] = "-k";
            opt[3] = "-v";

            //Evaluate model
            String entropy = Evaluation.evaluateModel(x, opt);

            //Grab the relevant info from the results, print to file
            entropy = entropy.substring(entropy.indexOf("=== Stratified cross-validation ===") + 35,
                    entropy.indexOf("=== Confusion Matrix ==="));
            writer.println("=== Entropy Statistics ===");
            writer.println(entropy);
        }
        if (checkResults("predictions")) { //The models predictions
            writer.println("=== Predictions ===\n");
            if (!dataset.getTest().contains("xval")) { //print header of predictions table if req'd
                writer.println(" inst# actual predicted error distribution ()");
            }
            writer.println(predictions.toString()); //print predictions to file
        }
        writer.close();

        //Summary file is semaphore controlled to ensure quality
        try { //get a permit
            //grab the summary file, write the classifiers details to it
            FlexDM.writeFile.acquire();
            PrintWriter p = new PrintWriter(new FileWriter(summary, true));
            if (temp.equals("results_no_parameters")) { //change output based on parameters
                temp = temp.substring(8);
            }

            //write percent correct, classifier name, dataset name to summary file
            p.write(dataset.getName() + ", " + classifier.getName() + ", " + temp + ", " + eval.correct() + ", "
                    + eval.incorrect() + ", " + eval.unclassified() + ", " + eval.pctCorrect() + ", "
                    + eval.pctIncorrect() + ", " + eval.pctUnclassified() + ", " + eval.kappa() + ", "
                    + eval.meanAbsoluteError() + ", " + eval.rootMeanSquaredError() + ", "
                    + eval.relativeAbsoluteError() + ", " + eval.rootRelativeSquaredError() + ", "
                    + eval.SFPriorEntropy() + ", " + eval.SFSchemeEntropy() + ", " + eval.SFEntropyGain() + ", "
                    + eval.SFMeanPriorEntropy() + ", " + eval.SFMeanSchemeEntropy() + ", "
                    + eval.SFMeanEntropyGain() + ", " + eval.KBInformation() + ", " + eval.KBMeanInformation()
                    + ", " + eval.KBRelativeInformation() + ", " + eval.weightedTruePositiveRate() + ", "
                    + eval.weightedFalsePositiveRate() + ", " + eval.weightedTrueNegativeRate() + ", "
                    + eval.weightedFalseNegativeRate() + ", " + eval.weightedPrecision() + ", "
                    + eval.weightedRecall() + ", " + eval.weightedFMeasure() + ", " + eval.weightedAreaUnderROC()
                    + "\n");
            p.close();

            //release semaphore
            FlexDM.writeFile.release();
        } catch (InterruptedException e) { //bad things happened
            System.err.println("FATAL ERROR OCCURRED: Classifier: " + cNum + " - " + classifier.getName()
                    + " on dataset " + dataset.getName());
        }

        //output we have successfully finished processing classifier
        if (temp.equals("no_parameters")) { //no parameters
            try {
                System.out.println("FINISHED CLASSIFIER " + cNum + " - " + classifier.getName() + " on dataset "
                        + dataset.getName().substring(dataset.getName().lastIndexOf("\\") + 1)
                        + " with no parameters");
            } catch (Exception e) {
                System.out.println("FINISHED CLASSIFIER " + cNum + " - " + classifier.getName() + " on dataset "
                        + dataset.getName() + " with no parameters");
            }
        } else { //with parameters
            try {
                System.out.println("FINISHED CLASSIFIER " + cNum + " - " + classifier.getName() + " on dataset "
                        + dataset.getName().substring(dataset.getName().lastIndexOf("\\") + 1)
                        + " with parameters " + temp);
            } catch (Exception e) {
                System.out.println("FINISHED CLASSIFIER " + cNum + " - " + classifier.getName() + " on dataset "
                        + dataset.getName() + " with parameters " + temp);
            }
        }

        try { //get a permit
            //grab the log file, write the classifiers details to it
            FlexDM.writeLog.acquire();
            PrintWriter p = new PrintWriter(new FileWriter(log, true));
            Date date = new Date();
            Format formatter = new SimpleDateFormat("dd/MM/YYYY HH:mm:ss"); //formatter.format(date)
            if (temp.equals("results_no_parameters")) { //change output based on parameters
                temp = temp.substring(8);
            }

            //write details to log file
            p.write(dataset.getName() + ", " + dataset.getTest() + ", \"" + dataset.getResult_string() + "\", "
                    + classifier.getName() + ", " + temp + ", " + formatter.format(date) + "\n");
            p.close();

            //release semaphore
            FlexDM.writeLog.release();
        } catch (InterruptedException e) { //bad things happened
            System.err.println("FATAL ERROR OCCURRED: Classifier: " + cNum + " - " + classifier.getName()
                    + " on dataset " + dataset.getName());
        }

        s.release();
    } catch (Exception e) { //an error occurred
        System.err.println("FATAL ERROR OCCURRED: " + e.toString() + "\nClassifier: " + cNum + " - "
                + classifier.getName() + " on dataset " + dataset.getName());
        s.release();
    }
}
From source file:adams.flow.transformer.WekaTestSetEvaluator.java
License:Open Source License
/**
 * Executes the flow item.
 *
 * @return		null if everything is fine, otherwise error message
 */
@Override
protected String doExecute() {
    String result;
    Instances test;
    Evaluation eval;
    weka.classifiers.Classifier cls;
    CallableSource gs;
    Token output;

    result = null;
    test = null;

    try {
        // get test set
        test = null;
        gs = new CallableSource();
        gs.setCallableName(m_Testset);
        gs.setParent(getParent());
        gs.setUp();
        gs.execute();
        output = gs.output();
        if (output != null)
            test = (Instances) output.getPayload();
        else
            result = "No test set available!";
        gs.wrapUp();

        // evaluate classifier
        if (result == null) {
            if (m_InputToken.getPayload() instanceof weka.classifiers.Classifier)
                cls = (weka.classifiers.Classifier) m_InputToken.getPayload();
            else
                cls = (weka.classifiers.Classifier) ((WekaModelContainer) m_InputToken.getPayload())
                        .getValue(WekaModelContainer.VALUE_MODEL);
            initOutputBuffer();
            m_Output.setHeader(test);
            eval = new Evaluation(test);
            eval.setDiscardPredictions(m_DiscardPredictions);
            eval.evaluateModel(cls, test, m_Output);

            // broadcast result
            if (m_Output instanceof Null) {
                m_OutputToken = new Token(new WekaEvaluationContainer(eval, cls));
            } else {
                if (m_AlwaysUseContainer)
                    m_OutputToken = new Token(
                            new WekaEvaluationContainer(eval, cls, m_Output.getBuffer().toString()));
                else
                    m_OutputToken = new Token(m_Output.getBuffer().toString());
            }
        }
    } catch (Exception e) {
        m_OutputToken = null;
        result = handleException("Failed to evaluate: ", e);
    }

    if (m_OutputToken != null) {
        if (m_OutputToken.getPayload() instanceof WekaEvaluationContainer) {
            if (test != null)
                ((WekaEvaluationContainer) m_OutputToken.getPayload())
                        .setValue(WekaEvaluationContainer.VALUE_TESTDATA, test);
        }
        updateProvenance(m_OutputToken);
    }

    return result;
}
From source file:adams.flow.transformer.WekaTrainTestSetEvaluator.java
License:Open Source License
/**
 * Executes the flow item.
 *
 * @return		null if everything is fine, otherwise error message
 */
@Override
protected String doExecute() {
    String result;
    Instances train;
    Instances test;
    weka.classifiers.Classifier cls;
    Evaluation eval;
    WekaTrainTestSetContainer cont;

    result = null;
    test = null;

    try {
        // cross-validate classifier
        cls = getClassifierInstance();
        if (cls == null)
            throw new IllegalStateException("Classifier '" + getClassifier() + "' not found!");

        cont = (WekaTrainTestSetContainer) m_InputToken.getPayload();
        train = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TRAIN);
        test = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TEST);
        cls.buildClassifier(train);
        initOutputBuffer();
        m_Output.setHeader(train);
        eval = new Evaluation(train);
        eval.setDiscardPredictions(m_DiscardPredictions);
        eval.evaluateModel(cls, test, m_Output);

        // broadcast result
        if (m_Output instanceof Null) {
            m_OutputToken = new Token(new WekaEvaluationContainer(eval, cls));
        } else {
            if (m_AlwaysUseContainer)
                m_OutputToken = new Token(
                        new WekaEvaluationContainer(eval, cls, m_Output.getBuffer().toString()));
            else
                m_OutputToken = new Token(m_Output.getBuffer().toString());
        }
    } catch (Exception e) {
        m_OutputToken = null;
        result = handleException("Failed to evaluate: ", e);
    }

    if (m_OutputToken != null) {
        if (m_OutputToken.getPayload() instanceof WekaEvaluationContainer) {
            if (test != null)
                ((WekaEvaluationContainer) m_OutputToken.getPayload())
                        .setValue(WekaEvaluationContainer.VALUE_TESTDATA, test);
        }
        updateProvenance(m_OutputToken);
    }

    return result;
}
From source file:adams.multiprocess.WekaCrossValidationExecution.java
License:Open Source License
/**
 * Executes the flow item.
 *
 * @return		null if everything is fine, otherwise error message
 */
public String execute() {
    MessageCollection result;
    Evaluation eval;
    AggregateEvaluations evalAgg;
    int folds;
    CrossValidationFoldGenerator generator;
    JobList<WekaCrossValidationJob> list;
    WekaCrossValidationJob job;
    WekaTrainTestSetContainer cont;
    int i;
    int current;
    int[] indices;
    Instances train;
    Instances test;
    Classifier cls;

    result = new MessageCollection();
    indices = null;
    m_Evaluation = null;
    m_Evaluations = null;

    try {
        // evaluate classifier
        if (m_Classifier == null)
            throw new IllegalStateException("Classifier '" + getClassifier() + "' not found!");
        if (isLoggingEnabled())
            getLogger().info(OptionUtils.getCommandLine(m_Classifier));

        m_ActualNumThreads = Performance.determineNumThreads(m_NumThreads);

        generator = (CrossValidationFoldGenerator) OptionUtils.shallowCopy(m_Generator);
        generator.setData(m_Data);
        generator.setNumFolds(m_Folds);
        generator.setSeed(m_Seed);
        generator.setStratify(true);
        generator.setUseViews(m_UseViews);
        generator.initializeIterator();
        folds = generator.getActualNumFolds();

        if ((m_ActualNumThreads == 1) && !m_SeparateFolds) {
            initOutputBuffer();
            if (m_Output != null) {
                m_Output.setHeader(m_Data);
                m_Output.printHeader();
            }
            eval = new Evaluation(m_Data);
            eval.setDiscardPredictions(m_DiscardPredictions);
            current = 0;
            while (generator.hasNext()) {
                if (isStopped())
                    break;
                if (m_StatusMessageHandler != null)
                    m_StatusMessageHandler.showStatus("Fold " + current + "/" + folds + ": '"
                            + m_Data.relationName() + "' using " + OptionUtils.getCommandLine(m_Classifier));
                cont = generator.next();
                train = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TRAIN);
                test = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TEST);
                cls = (Classifier) OptionUtils.shallowCopy(m_Classifier);
                cls.buildClassifier(train);
                eval.setPriors(train);
                eval.evaluateModel(cls, test, m_Output);
                current++;
            }
            if (m_Output != null)
                m_Output.printFooter();
            if (!isStopped())
                m_Evaluation = eval;
        } else {
            if (m_DiscardPredictions)
                throw new IllegalStateException(
                        "Cannot discard predictions in parallel mode, as they are used for aggregating the statistics!");
            if (m_JobRunnerSetup == null)
                m_JobRunner = new LocalJobRunner<WekaCrossValidationJob>();
            else
                m_JobRunner = m_JobRunnerSetup.newInstance();
            if (m_JobRunner instanceof ThreadLimiter)
                ((ThreadLimiter) m_JobRunner).setNumThreads(m_NumThreads);
            list = new JobList<>();
            while (generator.hasNext()) {
                cont = generator.next();
                job = new WekaCrossValidationJob((Classifier) OptionUtils.shallowCopy(m_Classifier),
                        (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TRAIN),
                        (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TEST),
                        (Integer) cont.getValue(WekaTrainTestSetContainer.VALUE_FOLD_NUMBER),
                        m_DiscardPredictions, m_StatusMessageHandler);
                list.add(job);
            }
            m_JobRunner.add(list);
            m_JobRunner.start();
            m_JobRunner.stop();

            // aggregate data
            if (!isStopped()) {
                evalAgg = new AggregateEvaluations();
                m_Evaluations = new Evaluation[m_JobRunner.getJobs().size()];
                for (i = 0; i < m_JobRunner.getJobs().size(); i++) {
                    job = (WekaCrossValidationJob) m_JobRunner.getJobs().get(i);
                    if (job.getEvaluation() == null) {
                        result.add("Fold #" + (i + 1) + " failed to evaluate"
                                + (job.hasExecutionError() ? job.getExecutionError() : "?"));
                        break;
                    }
                    evalAgg.add(job.getEvaluation());
                    m_Evaluations[i] = job.getEvaluation();
                    job.cleanUp();
                }
                m_Evaluation = evalAgg.aggregated();
                if (m_Evaluation == null) {
                    if (evalAgg.hasLastError())
                        result.add(evalAgg.getLastError());
                    else
                        result.add("Failed to aggregate evaluations!");
                }
            }
            list.cleanUp();
            m_JobRunner.cleanUp();
            m_JobRunner = null;
        }

        if (!m_DiscardPredictions)
            indices = generator.crossValidationIndices();
    } catch (Exception e) {
        result.add(Utils.handleException(this, "Failed to cross-validate classifier: ", e));
    }

    m_OriginalIndices = indices;

    if (result.isEmpty())
        return null;
    else
        return result.toString();
}
From source file:dkpro.similarity.experiments.rte.util.Evaluator.java
License:Open Source License
public static void runClassifier(WekaClassifier wekaClassifier, Dataset trainDataset, Dataset testDataset)
        throws Exception {
    Classifier baseClassifier = ClassifierSimilarityMeasure.getClassifier(wekaClassifier);

    // Set up the random number generator
    long seed = new Date().getTime();
    Random random = new Random(seed);

    // Add IDs to the train instances and get the instances
    AddID.main(new String[] { "-i", MODELS_DIR + "/" + trainDataset.toString() + ".arff", "-o",
            MODELS_DIR + "/" + trainDataset.toString() + "-plusIDs.arff" });
    Instances train = DataSource.read(MODELS_DIR + "/" + trainDataset.toString() + "-plusIDs.arff");
    train.setClassIndex(train.numAttributes() - 1);

    // Add IDs to the test instances and get the instances
    AddID.main(new String[] { "-i", MODELS_DIR + "/" + testDataset.toString() + ".arff", "-o",
            MODELS_DIR + "/" + testDataset.toString() + "-plusIDs.arff" });
    Instances test = DataSource.read(MODELS_DIR + "/" + testDataset.toString() + "-plusIDs.arff");
    test.setClassIndex(test.numAttributes() - 1);

    // Instantiate the Remove filter
    Remove removeIDFilter = new Remove();
    removeIDFilter.setAttributeIndices("first");

    // Randomize the data
    test.randomize(random);

    // Apply log filter
    // Filter logFilter = new LogFilter();
    // logFilter.setInputFormat(train);
    // train = Filter.useFilter(train, logFilter);
    // logFilter.setInputFormat(test);
    // test = Filter.useFilter(test, logFilter);

    // Copy the classifier
    Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);

    // Instantiate the FilteredClassifier
    FilteredClassifier filteredClassifier = new FilteredClassifier();
    filteredClassifier.setFilter(removeIDFilter);
    filteredClassifier.setClassifier(classifier);

    // Build the classifier
    filteredClassifier.buildClassifier(train);

    // Prepare the output buffer
    AbstractOutput output = new PlainText();
    output.setBuffer(new StringBuffer());
    output.setHeader(test);
    output.setAttributes("first");

    Evaluation eval = new Evaluation(train);
    eval.evaluateModel(filteredClassifier, test, output);

    // Convert predictions to CSV
    // Format: inst#, actual, predicted, error, probability, (ID)
    String[] scores = new String[new Double(eval.numInstances()).intValue()];
    double[] probabilities = new double[new Double(eval.numInstances()).intValue()];
    for (String line : output.getBuffer().toString().split("\n")) {
        String[] linesplit = line.split("\\s+");

        // If there's been an error, the length of linesplit is 6, otherwise 5,
        // due to the error flag "+"
        int id;
        String expectedValue, classification;
        double probability;

        if (line.contains("+")) {
            id = Integer.parseInt(linesplit[6].substring(1, linesplit[6].length() - 1));
            expectedValue = linesplit[2].substring(2);
            classification = linesplit[3].substring(2);
            probability = Double.parseDouble(linesplit[5]);
        } else {
            id = Integer.parseInt(linesplit[5].substring(1, linesplit[5].length() - 1));
            expectedValue = linesplit[2].substring(2);
            classification = linesplit[3].substring(2);
            probability = Double.parseDouble(linesplit[4]);
        }

        scores[id - 1] = classification;
        probabilities[id - 1] = probability;
    }

    System.out.println(eval.toSummaryString());
    System.out.println(eval.toMatrixString());

    // Output classifications
    StringBuilder sb = new StringBuilder();
    for (String score : scores)
        sb.append(score.toString() + LF);
    FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + testDataset.toString() + "/"
            + wekaClassifier.toString() + "/" + testDataset.toString() + ".csv"), sb.toString());

    // Output probabilities
    sb = new StringBuilder();
    for (Double probability : probabilities)
        sb.append(probability.toString() + LF);
    FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + testDataset.toString() + "/"
            + wekaClassifier.toString() + "/" + testDataset.toString() + ".probabilities.csv"), sb.toString());

    // Output predictions
    FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + testDataset.toString() + "/"
            + wekaClassifier.toString() + "/" + testDataset.toString() + ".predictions.txt"),
            output.getBuffer().toString());

    // Output meta information
    sb = new StringBuilder();
    sb.append(classifier.toString() + LF);
    sb.append(eval.toSummaryString() + LF);
    sb.append(eval.toMatrixString() + LF);
    FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + testDataset.toString() + "/"
            + wekaClassifier.toString() + "/" + testDataset.toString() + ".meta.txt"), sb.toString());
}