List of usage examples for the weka.classifiers.Evaluation constructor
public Evaluation(Instances data) throws Exception
From source file:com.actelion.research.orbit.imageAnalysis.tasks.ObjectTrainWorker.java
License:Open Source License
@Override protected void doWork() { if (dontRun) { dontRun = false;//from w ww.jav a2 s . c o m return; } trainSet = null; if (modelToBuild != null && modelToBuild.getClassifier() != null) modelToBuild.getClassifier().setBuild(false); List<double[]> trainData = new ArrayList<double[]>(); int mipLayer = -1; // used for checking if all iFrames (with trainData) have the same mapLayer (otherwise the model cannot be trained) for (ImageFrame iFrame : iFrames) { int sampleSize = Math.min(3, iFrame.recognitionFrame.bimg.getImage().getSampleModel().getNumBands()); // was always 1 before! (max 3 because alpha should be ignored) for (int i = 0; i < iFrame.recognitionFrame.getClassShapes().size(); i++) { checkPaused(); List<Shape> shapes = iFrame.recognitionFrame.getClassShapes().get(i).getShapeList(); if (shapes != null && shapes.size() > 0) { if (mipLayer < 0) { mipLayer = iFrame.getMipLayer(); logger.trace("iFrame candidate mipLayer {} from iFrame with width {}", mipLayer, iFrame.recognitionFrame.bimg.getWidth()); } else { if (mipLayer != iFrame.getMipLayer()) { logger.error( "Cell classifier cannot be trained on different image layers. Please use only training data of the same image layer."); return; } } if (mipLayer != modelToBuild.getMipLayer()) { // only same layer as segmentation allowed. Otherwise the cell features must be scaled, too (which is not yet the case). logger.error("Cell classifier must be trained on same layer as segmentation"); return; } } trainData.addAll(new ObjectFeatureBuilderTiled(null).buildFeatures(shapes, i + 1, iFrame.recognitionFrame, iFrame.recognitionFrame.getClassImage(), sampleSize, 0, 0)); // classes 1.0, 2.0, ... 
} } logger.trace("train levelNum: {}", mipLayer); if (trainData.size() == 0) { logger.error("trainset is empty, classifier cannot be trained."); trainSet = null; return; } if (isCancelled()) { cleanUp(); return; } timeEst = 1000 * 10L; setProgress(10); logger.debug("trainData contains " + trainData.size() + " samples"); Attribute classAttr = null; // create the first time a new trainSet. All further trainings will append new instances. if (trainSet == null) { // build traindata header double[] firstRowAll = trainData.get(0); double[] firstRow = Arrays.copyOfRange(firstRowAll, 0, firstRowAll.length - ObjectFeatureBuilderTiled.SkipTailForClassification); ArrayList<Attribute> attrInfo = new ArrayList<Attribute>(firstRow.length); for (int a = 0; a < firstRow.length - 1; a++) { Attribute attr = new Attribute("a" + a); // if (a<firstRow.length-2) attr.setWeight(0.1d); else attr.setWeight(1.0d); attrInfo.add(attr); } List<String> classValues = new ArrayList<String>( iFrames.get(0).recognitionFrame.getClassShapes().size()); for (int i = 0; i < iFrames.get(0).recognitionFrame.getClassShapes().size(); i++) { classValues.add((i + 1) + ".0"); // "1.0", "2.0", ... 
} classAttr = new Attribute("class", classValues); attrInfo.add(classAttr); trainSet = new Instances("trainSet pattern classes", attrInfo, trainData.size()); trainSet.setClassIndex(firstRow.length - 1); } else classAttr = trainSet.attribute("class"); // add instances for (double[] valsAll : trainData) { // skip some non relevant attributes like centerX/Y double[] vals = Arrays.copyOfRange(valsAll, 0, valsAll.length - ObjectFeatureBuilderTiled.SkipTailForClassification); vals[vals.length - 1] = valsAll[valsAll.length - 1]; // class value double classV = classAttr.indexOfValue(Double.toString(vals[vals.length - 1])); vals[vals.length - 1] = classV; Instance inst = new DenseInstance(1.0d, vals); trainSet.add(inst); } // trainSet = trainSet.resample(rand); logger.debug("trainSet contains " + trainSet.numInstances() + " instances"); if (logger.isTraceEnabled()) logger.trace(trainSet.toString()); // building classifier if (isCancelled()) { cleanUp(); return; } checkPaused(); timeEst = 1000 * 5L; setProgress(20); logger.info("Start training classifier... 
"); classifier = new ClassifierWrapper(new weka.classifiers.functions.SMO()); try { classifier.buildClassifier(trainSet); classifier.setBuild(true); modelToBuild.setClassifier(classifier); modelToBuild.setStructure(trainSet.stringFreeStructure()); modelToBuild.setCellClassification(true); modelToBuild.setMipLayer(mipLayer); setProgress(85); // evaluation StringBuilder cnamesInfo = new StringBuilder( "Evaluation for object classification model with classes: "); for (int i = 0; i < modelToBuild.getClassShapes().size(); i++) { cnamesInfo.append(modelToBuild.getClassShapes().get(i).getName()); if (i < modelToBuild.getClassShapes().size() - 1) cnamesInfo.append(", "); } logger.info(cnamesInfo.toString()); Evaluation evaluation = new Evaluation(trainSet); evaluation.evaluateModel(classifier.getClassifier(), trainSet); logger.info(evaluation.toSummaryString()); if (evaluation.pctCorrect() < OrbitUtils.ACCURACY_WARNING) { String w = "Warning: The model classifies the training objects only with an accuracy of " + evaluation.pctCorrect() + "%.\nThat means that the marked objects are not diverse enough.\nYou might want to remove some marked objects and mark some more representative ones.\nHowever, you can still use this model if you want (check the object classification)."; logger.warn(w); if (withGUI && !ScaleoutMode.SCALEOUTMODE.get()) { JOptionPane.showMessageDialog(null, w, "Warning: Low accuracy", JOptionPane.WARNING_MESSAGE); } } } catch (Exception e) { classifier = null; logger.error("error training classifier: ", e); } logger.info("training done."); timeEst = 0L; setProgress(100); }
From source file:com.actelion.research.orbit.imageAnalysis.tasks.TrainWorker.java
License:Open Source License
private void trainClassifier() throws OrbitImageServletException { logger.debug("start trainClassifier"); if (modelToBuild != null && modelToBuild.getClassifier() != null) modelToBuild.getClassifier().setBuild(false); trainSet = null;//from w ww . j a v a 2 s . c o m List<double[]> trainData = new ArrayList<double[]>(); int mipLayer = -1; // used for checking if all iFrames (with trainData) have the same mapLayer (otherwise the model cannot be trained) for (ImageFrame iFrame : iFrames) { if (logger.isTraceEnabled()) logger.trace( iFrame.getTitle() + ": #ClassShapes: " + iFrame.recognitionFrame.getClassShapes().size()); for (int i = 0; i < iFrame.recognitionFrame.getClassShapes().size(); i++) { // checkPaused(); if (iFrame.recognitionFrame.getClassShapes().get(i).getShapeList().size() > 0) { // set and check mip level only for iFrames with shapes (training data) if (mipLayer < 0) { mipLayer = iFrame.getMipLayer(); logger.trace("iFrame candidate mipLayer {} from iFrame with width {}", mipLayer, iFrame.recognitionFrame.bimg.getWidth()); } else { if (mipLayer != iFrame.getMipLayer()) { logger.error( "Model cannot be trained on different image layers. Please use only training data of the same image layer."); return; } } } List<Shape> shapes = iFrame.recognitionFrame.getClassShapes().get(i).getShapeList(); trainData.addAll(getFeatures(shapes, i + 1, iFrame.recognitionFrame.bimg)); // classes 1.0, 2.0, ... 
} } logger.trace("train levelNum: {}", mipLayer); if (trainData.size() == 0) { logger.error("trainset is empty, classifier cannot be trained."); trainSet = null; return; } if (isCancelled()) { logger.debug("canceled"); cleanUp(); return; } timeEst = 1000 * 10L; setProgress(10); logger.debug("trainData contains " + trainData.size() + " samples"); // limit training instances if (trainData.size() > MAXINST) { Collections.shuffle(trainData, rand); trainData = trainData.subList(0, MAXINST); logger.debug("trainSet shirked to " + trainData.size() + " instances"); } Attribute classAttr = null; // create the first time a new trainSet. All further trainings will append new instances. if (trainSet == null) { // build traindata header double[] firstRow = trainData.get(0); ArrayList<Attribute> attrInfo = new ArrayList<Attribute>(firstRow.length); for (int a = 0; a < firstRow.length - 1; a++) { Attribute attr = new Attribute("a" + a); // if (a<firstRow.length-2) attr.setWeight(0.1d); else attr.setWeight(1.0d); attrInfo.add(attr); } List<String> classValues = new ArrayList<String>( iFrames.get(0).recognitionFrame.getClassShapes().size()); for (int i = 0; i < iFrames.get(0).recognitionFrame.getClassShapes().size(); i++) { classValues.add((i + 1) + ".0"); // "1.0", "2.0", ... 
} classAttr = new Attribute("class", classValues); attrInfo.add(classAttr); trainSet = new Instances("trainSet pattern classes", attrInfo, trainData.size()); trainSet.setClassIndex(firstRow.length - 1); } else classAttr = trainSet.attribute("class"); // add instances for (double[] vals : trainData) { double classV = classAttr.indexOfValue(Double.toString(vals[vals.length - 1])); vals[vals.length - 1] = classV; //Instance inst = new Instance(1.0d, vals); Instance inst = new DenseInstance(1.0d, vals); trainSet.add(inst); } trainSet = trainSet.resample(rand); logger.debug("trainSet contains " + trainSet.numInstances() + " instances"); // building classifier if (isCancelled()) { cleanUp(); return; } checkPaused(); timeEst = 1000 * 5L; setProgress(20); logger.info("Start training classifier... "); Classifier c; /* // experiments with deep learning... do not use in production. if (AparUtils.DEEPORBIT) { FeatureDescription fd = modelToBuild!=null? modelToBuild.getFeatureDescription(): new FeatureDescription(); TissueFeatures tissueFeaturre = AparUtils.createTissueFeatures(fd, null); int numOutNeurons = modelToBuild.getClassShapes().size(); int numInNeurons = tissueFeaturre.prepareDoubleArray().length-1; logger.debug("numNeuronsIn:"+numInNeurons+" numNeuronsOut:"+numOutNeurons); MultiLayerPerceptron neuralNet = new MultiLayerPerceptron(numInNeurons,100, numOutNeurons); for (int a=0; a<numOutNeurons; a++) { neuralNet.getOutputNeurons()[a].setLabel("class"+a); } neuralNet.connectInputsToOutputs(); MomentumBackpropagation mb = new MomentumBackpropagation(); mb.setLearningRate(0.2d); mb.setMomentum(0.7d); //mb.setMaxIterations(20); mb.setMaxError(0.12); neuralNet.setLearningRule(mb); c = new WekaNeurophClassifier(neuralNet); } else { c = new weka.classifiers.functions.SMO(); } */ c = new weka.classifiers.functions.SMO(); //weka.classifiers.functions.LibSVM c = new weka.classifiers.functions.LibSVM(); //Classifier c = new weka.classifiers.trees.J48(); classifier = new 
ClassifierWrapper(c); //classifier = new weka.classifiers.bayes.BayesNet(); //classifier = new weka.classifiers.functions.MultilayerPerceptron(); //((weka.classifiers.functions.SMO)classifier).setKernel(new weka.classifiers.functions.supportVector.RBFKernel()); try { classifier.buildClassifier(trainSet); classifier.setBuild(true); modelToBuild.setClassifier(classifier); modelToBuild.setStructure(trainSet.stringFreeStructure()); modelToBuild.setCellClassification(false); modelToBuild.setMipLayer(mipLayer); logger.debug("training done"); // evaluation StringBuilder cnamesInfo = new StringBuilder("Evaluation for model with classes: "); for (int i = 0; i < modelToBuild.getClassShapes().size(); i++) { cnamesInfo.append(modelToBuild.getClassShapes().get(i).getName()); if (i < modelToBuild.getClassShapes().size() - 1) cnamesInfo.append(", "); } logger.info(cnamesInfo.toString()); Evaluation evaluation = new Evaluation(trainSet); evaluation.evaluateModel(classifier.getClassifier(), trainSet); logger.info(evaluation.toSummaryString()); if (evaluation.pctCorrect() < OrbitUtils.ACCURACY_WARNING) { final String w = "Warning: The model classifies the training shapes only with an accuracy of " + evaluation.pctCorrect() + "%.\nThat means that the drawn class shapes are not diverse enough.\nYou might want to remove some class shapes and mark some more representative regions.\nHowever, you can still use this model if you want (check the classification)."; logger.warn(w); if (withGUI && !ScaleoutMode.SCALEOUTMODE.get()) { SwingUtilities.invokeLater(new Runnable() { @Override public void run() { JOptionPane.showMessageDialog(null, w, "Warning: Low accuracy", JOptionPane.WARNING_MESSAGE); } }); } } } catch (Exception e) { classifier = null; logger.error("error training classifier", e); } // logger.trace(classifier.toString()); }
From source file:com.daniel.convert.IncrementalClassifier.java
License:Open Source License
/** * Expects an ARFF file as first argument (class attribute is assumed to be * the last attribute).//from www .ja v a2 s . c om * * @param args * the commandline arguments * @throws Exception * if something goes wrong */ public static BayesNet treinar(String[] args) throws Exception { // load data ArffLoader loader = new ArffLoader(); loader.setFile(new File(args[0])); Instances structure = loader.getStructure(); structure.setClassIndex(structure.numAttributes() - 1); // train NaiveBayes BayesNet BayesNet = new BayesNet(); Instance current; while ((current = loader.getNextInstance(structure)) != null) { structure.add(current); } BayesNet.buildClassifier(structure); // output generated model // System.out.println(nb); // test set BayesNet BayesNetTest = new BayesNet(); // test the model Evaluation eTest = new Evaluation(structure); // eTest.evaluateModel(nb, structure); eTest.crossValidateModel(BayesNetTest, structure, 15, new Random(1)); // Print the result la Weka explorer: String strSummary = eTest.toSummaryString(); System.out.println(strSummary); return BayesNet; }
From source file:com.deafgoat.ml.prognosticator.AppClassifier.java
License:Apache License
/** * Perform cross-validation on data set/builds model * /* w w w . j a va 2 s. co m*/ * @throws Exception */ public void crossValidate() throws Exception { // stratify nominal target class if (_trainInstances.classAttribute().isNominal()) { _trainInstances.stratify(_folds); } _eval = new Evaluation(_trainInstances); for (int n = 0; n < _folds; n++) { if (_logger.isDebugEnabled()) { _logger.debug("Cross validation fold: " + (n + 1)); } _train = _trainInstances.trainCV(_folds, n); _test = _trainInstances.testCV(_folds, n); _clsCopy = AbstractClassifier.makeCopy(_cls); try { _clsCopy.buildClassifier(_train); } catch (Exception e) { _logger.debug(_config._classifier + " can not handle " + getAttributeType(_test.classAttribute()) + " class attributes"); } try { _eval.evaluateModel(_clsCopy, _test); } catch (Exception e) { _logger.debug("Can not evaluate model"); } } if (_config._writeToMongoDB) { _logger.info("Writing model to mongoDB"); // save the trained model saveModel(); // save CV performance of trained model writeToMongoDB(_eval); } if (_config._writeToFile) { _logger.info("Writing model to file"); SerializationHelper.write(_config._modelFile, _clsCopy); } }
From source file:com.deafgoat.ml.prognosticator.AppClassifier.java
License:Apache License
/** * Gets details on classified instances according to supplied attribute * /* w w w . j a v a 2 s . co m*/ * @param attribute * The focal attribute for error analysis * @throws Exception * If model can not be evaluated */ public void errorAnalysis(String attribute) throws Exception { readModel(); _logger.info("Performing error analysis"); Evaluation eval = new Evaluation(_testInstances); eval.evaluateModel(_cls, _testInstances); _predictionList = new HashMap<String, List<Prediction>>(); String predicted, actual = null; double[] distribution = null; _predictionList.put(_config._truePositives, new ArrayList<Prediction>()); _predictionList.put(_config._trueNegatives, new ArrayList<Prediction>()); _predictionList.put(_config._falsePositives, new ArrayList<Prediction>()); _predictionList.put(_config._falseNegatives, new ArrayList<Prediction>()); for (int i = 0; i < _testInstances.numInstances(); i++) { distribution = _cls.distributionForInstance(_testInstances.instance(i)); actual = _testInstances.classAttribute().value((int) _testInstances.instance(i).classValue()); predicted = _testInstances.classAttribute() .value((int) _cls.classifyInstance(_testInstances.instance(i))); // 0 is negative, 1 is positive if (!predicted.equals(actual)) { if (actual.equals(_config._negativeClassValue)) { _predictionList.get(_config._falsePositives) .add(new Prediction(i + 1, predicted, distribution, _fullData.instance(i))); } else if (actual.equals(_config._positiveClassValue)) { _predictionList.get(_config._falseNegatives) .add(new Prediction(i + 1, predicted, distribution, _fullData.instance(i))); } } else if (predicted.equals(actual)) { if (actual.equals(_config._negativeClassValue)) { _predictionList.get(_config._trueNegatives) .add(new Prediction(i + 1, predicted, distribution, _fullData.instance(i))); } else if (actual.equals(_config._positiveClassValue)) { _predictionList.get(_config._truePositives) .add(new Prediction(i + 1, predicted, distribution, _fullData.instance(i))); } } 
} BufferedWriter writer = null; String name, prediction = null; for (Entry<String, List<Prediction>> entry : _predictionList.entrySet()) { name = entry.getKey(); Collections.sort(_predictionList.get(name), Collections.reverseOrder()); writer = new BufferedWriter(new FileWriter(name)); List<Prediction> predictions = _predictionList.get(name); for (int count = 0; count < predictions.size(); count++) { if (count < _config._maxCount) { prediction = predictions.get(count).attributeDistribution(attribute); if (Double.parseDouble(prediction.split(_delimeter)[1]) >= _config._minProb) { writer.write(prediction + "\n"); } } else { break; } } writer.close(); } }
From source file:com.deafgoat.ml.prognosticator.AppClassifier.java
License:Apache License
/** * Evaluates model performance on test instances * /* www. j av a2s . com*/ * @throws Exception * If model can not be evaluated. */ public void evaluate() throws Exception { readModel(); _logger.info("Classifying with " + _config._classifier); Evaluation eval = new Evaluation(_testInstances); eval.evaluateModel(_cls, _testInstances); _logger.info("\n" + eval.toSummaryString()); try { _logger.info("\n" + eval.toClassDetailsString()); } catch (Exception e) { _logger.info("Can not create class details" + _config._classifier); } try { _logger.info("\n" + _eval.toMatrixString()); } catch (Exception e) { _logger.info( "Can not create confusion matrix for " + _config._classifier + " using " + _config._classValue); } }
From source file:com.edwardraff.WekaMNIST.java
License:Open Source License
/**
 * Trains the given classifier on {@code train}, evaluates it on {@code test} and prints
 * the wall-clock time of both steps plus the resulting error rate to standard output.
 * A garbage collection is requested before training, before evaluation and at the end.
 *
 * @param wekaModel the classifier to build and evaluate
 * @param train     the training instances
 * @param test      the test instances
 * @throws Exception if building or evaluating the classifier fails
 */
private static void evalModel(Classifier wekaModel, Instances train, Instances test) throws Exception {
    System.gc();
    final long trainStart = System.currentTimeMillis();
    wekaModel.buildClassifier(train);
    final long trainEnd = System.currentTimeMillis();
    System.out.println("\tTraining took: " + (trainEnd - trainStart) / 1000.0);

    System.gc();
    Evaluation evaluation = new Evaluation(train);
    final long evalStart = System.currentTimeMillis();
    evaluation.evaluateModel(wekaModel, test);
    final long evalEnd = System.currentTimeMillis();
    System.out.println(
            "\tEvaluation took " + (evalEnd - evalStart) / 1000.0 + " seconds with an error rate "
                    + evaluation.errorRate());

    System.gc();
}
From source file:com.github.r351574nc3.amex.assignment2.App.java
License:Open Source License
/**
 * Tests/evaluates the trained model. This method assumes that {@link #train()} was previously called to assign a
 * {@link LinearRegression} classifier. If it wasn't, an exception will be thrown.
 *
 * <p>The classifier is evaluated on the test data and the summary plus the percentage of correctly classified
 * instances are reported through {@code info}.
 *
 * @throws IllegalStateException if train wasn't called prior (no classifier is available)
 * @throws Exception if the evaluation itself fails
 */
public void test() throws Exception {
    if (getClassifier() == null) {
        // IllegalStateException is a RuntimeException, so existing callers keep working
        throw new IllegalStateException("Make sure train was run prior to this method call");
    }

    final Evaluation eval = new Evaluation(getTrained());
    eval.evaluateModel(getClassifier(), getTest());

    info("%s", eval.toSummaryString("Results\n\n", false));
    // NOTE(review): pctCorrect may not be meaningful for a regression classifier - confirm
    info("Percent of correctly classified instances: %s", eval.pctCorrect());
}
From source file:com.guidefreitas.locator.services.PredictionService.java
/**
 * Builds the training data, cross-validates a polynomial-kernel C-SVC LibSVM classifier
 * on it (10 folds, seed 1) and finally trains the classifier on the complete data set.
 *
 * @return the cross-validation result, or {@code null} if any step failed (the failure
 *         is logged at SEVERE level)
 */
public Evaluation train() {
    Evaluation result = null;
    try {
        // the generated ARFF text is parsed from memory; the last attribute is the class
        byte[] arffBytes = this.generateTrainData().getBytes(StandardCharsets.UTF_8);
        InputStream arffStream = new ByteArrayInputStream(arffBytes);
        Instances trainingSet = new DataSource(arffStream).getDataSet();
        trainingSet.setClassIndex(trainingSet.numAttributes() - 1);

        this.classifier = new LibSVM();
        SelectedTag polynomialKernel = new SelectedTag(LibSVM.KERNELTYPE_POLYNOMIAL, LibSVM.TAGS_KERNELTYPE);
        this.classifier.setKernelType(polynomialKernel);
        SelectedTag cSvc = new SelectedTag(LibSVM.SVMTYPE_C_SVC, LibSVM.TAGS_SVMTYPE);
        this.classifier.setSVMType(cSvc);

        Evaluation crossValidation = new Evaluation(trainingSet);
        crossValidation.crossValidateModel(this.classifier, trainingSet, 10, new Random(1));
        this.classifier.buildClassifier(trainingSet);
        result = crossValidation;
    } catch (Exception ex) {
        Logger.getLogger(PredictionService.class.getName()).log(Level.SEVERE, null, ex);
    }
    return result;
}
From source file:com.ivanrf.smsspam.SpamClassifier.java
License:Apache License
public static void evaluate(int wordsToKeep, String tokenizerOp, boolean useAttributeSelection, String classifierOp, boolean boosting, JTextArea log) { try {/*from w w w . ja v a2s . c o m*/ long start = System.currentTimeMillis(); String modelName = getModelName(wordsToKeep, tokenizerOp, useAttributeSelection, classifierOp, boosting); showEstimatedTime(false, modelName, log); Instances trainData = loadDataset("SMSSpamCollection.arff", log); trainData.setClassIndex(0); FilteredClassifier classifier = initFilterClassifier(wordsToKeep, tokenizerOp, useAttributeSelection, classifierOp, boosting); publishEstado("=== Performing cross-validation ===", log); Evaluation eval = new Evaluation(trainData); // eval.evaluateModel(classifier, trainData); eval.crossValidateModel(classifier, trainData, 10, new Random(1)); publishEstado(eval.toSummaryString(), log); publishEstado(eval.toClassDetailsString(), log); publishEstado(eval.toMatrixString(), log); publishEstado("=== Evaluation finished ===", log); publishEstado("Elapsed time: " + Utils.getDateHsMinSegString(System.currentTimeMillis() - start), log); } catch (Exception e) { e.printStackTrace(); publishEstado("Error found when evaluating", log); } }