List of usage examples for weka.classifiers.Evaluation.setPriors
public void setPriors(Instances train) throws Exception
From source file: adams.multiprocess.WekaCrossValidationExecution.java
License: Open Source License
/** * Executes the flow item./* w w w . j a v a 2 s. co m*/ * * @return null if everything is fine, otherwise error message */ public String execute() { MessageCollection result; Evaluation eval; AggregateEvaluations evalAgg; int folds; CrossValidationFoldGenerator generator; JobList<WekaCrossValidationJob> list; WekaCrossValidationJob job; WekaTrainTestSetContainer cont; int i; int current; int[] indices; Instances train; Instances test; Classifier cls; result = new MessageCollection(); indices = null; m_Evaluation = null; m_Evaluations = null; try { // evaluate classifier if (m_Classifier == null) throw new IllegalStateException("Classifier '" + getClassifier() + "' not found!"); if (isLoggingEnabled()) getLogger().info(OptionUtils.getCommandLine(m_Classifier)); m_ActualNumThreads = Performance.determineNumThreads(m_NumThreads); generator = (CrossValidationFoldGenerator) OptionUtils.shallowCopy(m_Generator); generator.setData(m_Data); generator.setNumFolds(m_Folds); generator.setSeed(m_Seed); generator.setStratify(true); generator.setUseViews(m_UseViews); generator.initializeIterator(); folds = generator.getActualNumFolds(); if ((m_ActualNumThreads == 1) && !m_SeparateFolds) { initOutputBuffer(); if (m_Output != null) { m_Output.setHeader(m_Data); m_Output.printHeader(); } eval = new Evaluation(m_Data); eval.setDiscardPredictions(m_DiscardPredictions); current = 0; while (generator.hasNext()) { if (isStopped()) break; if (m_StatusMessageHandler != null) m_StatusMessageHandler.showStatus("Fold " + current + "/" + folds + ": '" + m_Data.relationName() + "' using " + OptionUtils.getCommandLine(m_Classifier)); cont = generator.next(); train = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TRAIN); test = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TEST); cls = (Classifier) OptionUtils.shallowCopy(m_Classifier); cls.buildClassifier(train); eval.setPriors(train); eval.evaluateModel(cls, test, m_Output); current++; } if (m_Output != null) 
m_Output.printFooter(); if (!isStopped()) m_Evaluation = eval; } else { if (m_DiscardPredictions) throw new IllegalStateException( "Cannot discard predictions in parallel mode, as they are used for aggregating the statistics!"); if (m_JobRunnerSetup == null) m_JobRunner = new LocalJobRunner<WekaCrossValidationJob>(); else m_JobRunner = m_JobRunnerSetup.newInstance(); if (m_JobRunner instanceof ThreadLimiter) ((ThreadLimiter) m_JobRunner).setNumThreads(m_NumThreads); list = new JobList<>(); while (generator.hasNext()) { cont = generator.next(); job = new WekaCrossValidationJob((Classifier) OptionUtils.shallowCopy(m_Classifier), (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TRAIN), (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TEST), (Integer) cont.getValue(WekaTrainTestSetContainer.VALUE_FOLD_NUMBER), m_DiscardPredictions, m_StatusMessageHandler); list.add(job); } m_JobRunner.add(list); m_JobRunner.start(); m_JobRunner.stop(); // aggregate data if (!isStopped()) { evalAgg = new AggregateEvaluations(); m_Evaluations = new Evaluation[m_JobRunner.getJobs().size()]; for (i = 0; i < m_JobRunner.getJobs().size(); i++) { job = (WekaCrossValidationJob) m_JobRunner.getJobs().get(i); if (job.getEvaluation() == null) { result.add("Fold #" + (i + 1) + " failed to evaluate" + (job.hasExecutionError() ? job.getExecutionError() : "?")); break; } evalAgg.add(job.getEvaluation()); m_Evaluations[i] = job.getEvaluation(); job.cleanUp(); } m_Evaluation = evalAgg.aggregated(); if (m_Evaluation == null) { if (evalAgg.hasLastError()) result.add(evalAgg.getLastError()); else result.add("Failed to aggregate evaluations!"); } } list.cleanUp(); m_JobRunner.cleanUp(); m_JobRunner = null; } if (!m_DiscardPredictions) indices = generator.crossValidationIndices(); } catch (Exception e) { result.add(Utils.handleException(this, "Failed to cross-validate classifier: ", e)); } m_OriginalIndices = indices; if (result.isEmpty()) return null; else return result.toString(); }