List of usage examples for weka.classifiers Evaluation crossValidateModel
public void crossValidateModel(Classifier classifier, Instances data, int numFolds, Random random) throws Exception
From source file:binarizer.LayoutAnalysis.java
/**
 * Runs 10-fold cross-validation of a NaiveBayes classifier (with
 * supervised discretization enabled) over the data set in the given
 * ARFF file and prints the summary to stdout.
 *
 * @param arffFile path of the ARFF file holding the data
 * @return the cross-validated error rate
 * @throws Exception if the file cannot be read or evaluation fails
 */
public double crossValidation(String arffFile) throws Exception {
    DataSource dataSource = new DataSource(arffFile);
    Instances instances = dataSource.getDataSet();
    // Default the class attribute to the last column when none is set.
    if (instances.classIndex() == -1) {
        instances.setClassIndex(instances.numAttributes() - 1);
    }
    NaiveBayes bayes = new NaiveBayes();
    bayes.setUseSupervisedDiscretization(true);
    Evaluation eval = new Evaluation(instances);
    eval.crossValidateModel(bayes, instances, 10, new Random(1));
    System.out.println(eval.toSummaryString());
    return eval.errorRate();
}
From source file:c4.pkg5crossv.Classifier.java
public static void trainAndTest() throws FileNotFoundException, IOException, Exception { Instances data = DataLoad.loadData("./src/data/irysy.arff"); data.setClassIndex(data.numAttributes() - 1); //Losowy podzial tablicy data.randomize(new Random()); double percent = 60.0; int trainSize = (int) Math.round(data.numInstances() * percent / 100); int testSize = data.numInstances() - trainSize; Instances trainData = new Instances(data, 0, trainSize); Instances testData = new Instances(data, trainSize, testSize); String[] options = Utils.splitOptions("-U -M 10"); J48 tree = new J48(); tree.setOptions(options);//from w w w . j a va 2s . c o m tree.buildClassifier(trainData); Evaluation eval2 = new Evaluation(trainData); eval2.crossValidateModel(tree, testData, 10, new Random(1)); // 5 - fold System.out.println(eval2.toSummaryString("Wyniki:", false)); //Wypisanie testovania cross validation }
From source file:ca.uottawa.balie.WekaLearner.java
License:Open Source License
/** * Approximate training set error.//from w ww. j a va 2 s . c om * * @return evaluation module from which many types of errors are exposed (e.g.: mean absolute error) */ public Evaluation EstimateConfidence() { Evaluation evaluation = null; try { evaluation = new Evaluation(m_TrainingSet); evaluation.crossValidateModel(m_Scheme, m_TrainingSet, 10, new Random()); } catch (Exception e) { System.out.println(e.getMessage()); } // which error is the best? depends on the application. return evaluation; }
From source file:ca.uqac.florentinth.speakerauthentication.Learning.Learning.java
License:Apache License
public void trainClassifier(Classifier classifier, FileReader trainingDataset, FileOutputStream trainingModel, Integer crossValidationFoldNumber) throws Exception { Instances instances = new Instances(new BufferedReader(trainingDataset)); switch (classifier) { case KNN://from w ww .j a va 2s . co m int K = (int) Math.ceil(Math.sqrt(instances.numInstances())); this.classifier = new IBk(K); break; case NB: this.classifier = new NaiveBayes(); } if (instances.classIndex() == -1) { instances.setClassIndex(instances.numAttributes() - 1); } this.classifier.buildClassifier(instances); if (crossValidationFoldNumber > 0) { Evaluation evaluation = new Evaluation(instances); evaluation.crossValidateModel(this.classifier, instances, crossValidationFoldNumber, new Random(1)); kappa = evaluation.kappa(); fMeasure = evaluation.weightedFMeasure(); confusionMatrix = evaluation.toMatrixString("Confusion matrix: "); } ObjectOutputStream outputStream = new ObjectOutputStream(trainingModel); outputStream.writeObject(this.classifier); outputStream.flush(); outputStream.close(); }
From source file:classif.ExperimentsLauncher.java
License:Open Source License
/**
 * Greedy per-class prototype-count search for Unbalancecluster.
 *
 * For each global starting count j (1..nbPrototypesMax) it starts with j
 * prototypes in every class, then greedily increments the count of one
 * class at a time ("flag" walks the class indices): while 10-fold CV
 * error keeps improving, the current class gets one more prototype;
 * when it stops improving, the last increment is undone and the search
 * moves to the next class. The best configuration per j is then scored
 * as the average test-set error over nbExp independently built
 * classifiers, and the overall best configuration/error is printed.
 *
 * NOTE(review): relies on instance fields train, test, nbExp and
 * nbPrototypesMax — assumed initialized by the enclosing class before
 * this is called; confirm against callers.
 */
public void launchseq() {
    try {
        nbPrototypesMax = 10;
        int[] bestprototypes = new int[train.numClasses()];
        double lowerror = 1.0;
        for (int j = 1; j <= nbPrototypesMax; j++) {
            // Start every class at j prototypes.
            int[] nbPrototypesPerClass = new int[train.numClasses()];
            for (int i = 0; i < train.numClasses(); i++) {
                nbPrototypesPerClass[i] = j;
            }
            double errorBefore = 1;
            double errorNow = 1;
            int flag = 0; // index of the class currently being grown
            do {
                Unbalancecluster classifierseq = new Unbalancecluster();
                classifierseq.setNbPrototypesPerClass(nbPrototypesPerClass);
                System.out.println(Arrays.toString(nbPrototypesPerClass));
                // classifierseq.buildClassifier(train);
                // Seeded CV so fold assignment is identical across candidates.
                Evaluation evalcv = new Evaluation(train);
                Random rand = new Random(1);
                evalcv.crossValidateModel(classifierseq, train, 10, rand);
                // errorNow = classifierseq.predictAccuracyXVal(10);
                errorNow = evalcv.errorRate();
                System.out.println("errorBefore " + errorBefore);
                System.out.println("errorNow " + errorNow);
                if (errorNow < errorBefore) {
                    // Improvement: keep growing the current class.
                    nbPrototypesPerClass[flag]++;
                    errorBefore = errorNow;
                } else {
                    // No improvement: revert the last increment and move on
                    // to the next class.
                    nbPrototypesPerClass[flag]--;
                    flag++;
                    if (flag >= nbPrototypesPerClass.length)
                        break;
                    nbPrototypesPerClass[flag]++;
                }
            } while (flag < nbPrototypesPerClass.length);
            // System.out.println("\nbest nbPrototypesPerClass " + Arrays.toString(nbPrototypesPerClass));
            // Score the tuned configuration: average test error over nbExp
            // freshly built classifiers.
            double testError = 0;
            for (int n = 0; n < nbExp; n++) {
                Unbalancecluster classifier = new Unbalancecluster();
                classifier.setNbPrototypesPerClass(nbPrototypesPerClass);
                classifier.buildClassifier(train);
                Evaluation evaltest = new Evaluation(train);
                evaltest.evaluateModel(classifier, test);
                testError += evaltest.errorRate();
            }
            double avgTestError = testError / nbExp;
            System.out.println(avgTestError);
            if (avgTestError < lowerror) {
                bestprototypes = nbPrototypesPerClass;
                lowerror = avgTestError;
            }
        }
        System.out.println("Best prototypes:" + Arrays.toString(bestprototypes) + "\n");
        System.out.println("Best errorRate:" + lowerror + "\n");
    } catch (Exception e) {
        e.printStackTrace();
    }
}
From source file:classify.Classifier.java
/** * @param args the command line arguments */// w ww . ja va 2 s . c om public static void main(String[] args) { //read in data try { DataSource input = new DataSource("no_missing_values.csv"); Instances data = input.getDataSet(); //Instances data = readFile("newfixed.txt"); missingValuesRows(data); setAttributeValues(data); data.setClassIndex(data.numAttributes() - 1); //boosting AdaBoostM1 boosting = new AdaBoostM1(); boosting.setNumIterations(25); boosting.setClassifier(new DecisionStump()); //build the classifier boosting.buildClassifier(data); //evaluate using 10-fold cross validation Evaluation e1 = new Evaluation(data); e1.crossValidateModel(boosting, data, 10, new Random(1)); DecimalFormat nf = new DecimalFormat("0.000"); System.out.println("Results of Boosting with Decision Stumps:"); System.out.println(boosting.toString()); System.out.println("Results of Cross Validation:"); System.out.println("Number of correctly classified instances: " + e1.correct() + " (" + nf.format(e1.pctCorrect()) + "%)"); System.out.println("Number of incorrectly classified instances: " + e1.incorrect() + " (" + nf.format(e1.pctIncorrect()) + "%)"); System.out.println("TP Rate: " + nf.format(e1.weightedTruePositiveRate() * 100) + "%"); System.out.println("FP Rate: " + nf.format(e1.weightedFalsePositiveRate() * 100) + "%"); System.out.println("Precision: " + nf.format(e1.weightedPrecision() * 100) + "%"); System.out.println("Recall: " + nf.format(e1.weightedRecall() * 100) + "%"); System.out.println(); System.out.println("Confusion Matrix:"); for (int i = 0; i < e1.confusionMatrix().length; i++) { for (int j = 0; j < e1.confusionMatrix()[0].length; j++) { System.out.print(e1.confusionMatrix()[i][j] + " "); } System.out.println(); } System.out.println(); System.out.println(); System.out.println(); //logistic regression Logistic l = new Logistic(); l.buildClassifier(data); e1 = new Evaluation(data); e1.crossValidateModel(l, data, 10, new Random(1)); System.out.println("Results of 
Logistic Regression:"); System.out.println(l.toString()); System.out.println("Results of Cross Validation:"); System.out.println("Number of correctly classified instances: " + e1.correct() + " (" + nf.format(e1.pctCorrect()) + "%)"); System.out.println("Number of incorrectly classified instances: " + e1.incorrect() + " (" + nf.format(e1.pctIncorrect()) + "%)"); System.out.println("TP Rate: " + nf.format(e1.weightedTruePositiveRate() * 100) + "%"); System.out.println("FP Rate: " + nf.format(e1.weightedFalsePositiveRate() * 100) + "%"); System.out.println("Precision: " + nf.format(e1.weightedPrecision() * 100) + "%"); System.out.println("Recall: " + nf.format(e1.weightedRecall() * 100) + "%"); System.out.println(); System.out.println("Confusion Matrix:"); for (int i = 0; i < e1.confusionMatrix().length; i++) { for (int j = 0; j < e1.confusionMatrix()[0].length; j++) { System.out.print(e1.confusionMatrix()[i][j] + " "); } System.out.println(); } } catch (Exception ex) { //data couldn't be read, so end program System.out.println("Exception thrown, program ending."); } }
From source file:com.daniel.convert.IncrementalClassifier.java
License:Open Source License
/** * Expects an ARFF file as first argument (class attribute is assumed to be * the last attribute).//from www .j a v a 2 s. co m * * @param args * the commandline arguments * @throws Exception * if something goes wrong */ public static BayesNet treinar(String[] args) throws Exception { // load data ArffLoader loader = new ArffLoader(); loader.setFile(new File(args[0])); Instances structure = loader.getStructure(); structure.setClassIndex(structure.numAttributes() - 1); // train NaiveBayes BayesNet BayesNet = new BayesNet(); Instance current; while ((current = loader.getNextInstance(structure)) != null) { structure.add(current); } BayesNet.buildClassifier(structure); // output generated model // System.out.println(nb); // test set BayesNet BayesNetTest = new BayesNet(); // test the model Evaluation eTest = new Evaluation(structure); // eTest.evaluateModel(nb, structure); eTest.crossValidateModel(BayesNetTest, structure, 15, new Random(1)); // Print the result la Weka explorer: String strSummary = eTest.toSummaryString(); System.out.println(strSummary); return BayesNet; }
From source file:com.guidefreitas.locator.services.PredictionService.java
public Evaluation train() { try {//w w w . java2 s . co m String arffData = this.generateTrainData(); InputStream stream = new ByteArrayInputStream(arffData.getBytes(StandardCharsets.UTF_8)); DataSource source = new DataSource(stream); Instances data = source.getDataSet(); data.setClassIndex(data.numAttributes() - 1); this.classifier = new LibSVM(); this.classifier.setKernelType(new SelectedTag(LibSVM.KERNELTYPE_POLYNOMIAL, LibSVM.TAGS_KERNELTYPE)); this.classifier.setSVMType(new SelectedTag(LibSVM.SVMTYPE_C_SVC, LibSVM.TAGS_SVMTYPE)); Evaluation eval = new Evaluation(data); eval.crossValidateModel(this.classifier, data, 10, new Random(1)); this.classifier.buildClassifier(data); return eval; } catch (Exception ex) { Logger.getLogger(PredictionService.class.getName()).log(Level.SEVERE, null, ex); } return null; }
From source file:com.ivanrf.smsspam.SpamClassifier.java
License:Apache License
public static void evaluate(int wordsToKeep, String tokenizerOp, boolean useAttributeSelection, String classifierOp, boolean boosting, JTextArea log) { try {//from ww w. j a va 2 s .com long start = System.currentTimeMillis(); String modelName = getModelName(wordsToKeep, tokenizerOp, useAttributeSelection, classifierOp, boosting); showEstimatedTime(false, modelName, log); Instances trainData = loadDataset("SMSSpamCollection.arff", log); trainData.setClassIndex(0); FilteredClassifier classifier = initFilterClassifier(wordsToKeep, tokenizerOp, useAttributeSelection, classifierOp, boosting); publishEstado("=== Performing cross-validation ===", log); Evaluation eval = new Evaluation(trainData); // eval.evaluateModel(classifier, trainData); eval.crossValidateModel(classifier, trainData, 10, new Random(1)); publishEstado(eval.toSummaryString(), log); publishEstado(eval.toClassDetailsString(), log); publishEstado(eval.toMatrixString(), log); publishEstado("=== Evaluation finished ===", log); publishEstado("Elapsed time: " + Utils.getDateHsMinSegString(System.currentTimeMillis() - start), log); } catch (Exception e) { e.printStackTrace(); publishEstado("Error found when evaluating", log); } }
From source file:com.sliit.views.DataVisualizerPanel.java
void getRocCurve() { try {// w ww. j a v a 2 s . c o m Instances data; data = new Instances(new BufferedReader(new FileReader(datasetPathText.getText()))); data.setClassIndex(data.numAttributes() - 1); // train classifier Classifier cl = new NaiveBayes(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(cl, data, 10, new Random(1)); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) { cp[n] = true; } tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); // display curve String plotName = vmc.getName(); final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName); jf.setSize(500, 400); jf.getContentPane().setLayout(new BorderLayout()); jf.getContentPane().add(vmc, BorderLayout.CENTER); jf.addWindowListener(new java.awt.event.WindowAdapter() { public void windowClosing(java.awt.event.WindowEvent e) { jf.dispose(); } }); jf.setVisible(true); } catch (IOException ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } catch (Exception ex) { Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex); } }