Example usage for weka.classifiers Evaluation setPriors

List of usage examples for weka.classifiers Evaluation setPriors

Introduction

On this page you can find example usage of weka.classifiers.Evaluation.setPriors.

Prototype

public void setPriors(Instances train) throws Exception 

Source Link

Document

Sets the class prior probabilities.

Usage

From source file: adams.multiprocess.WekaCrossValidationExecution.java

License:Open Source License

/**
 * Executes the flow item./*  w w w  . j  a v  a  2  s. co  m*/
 *
 * @return      null if everything is fine, otherwise error message
 */
public String execute() {
    MessageCollection result;
    Evaluation eval;
    AggregateEvaluations evalAgg;
    int folds;
    CrossValidationFoldGenerator generator;
    JobList<WekaCrossValidationJob> list;
    WekaCrossValidationJob job;
    WekaTrainTestSetContainer cont;
    int i;
    int current;
    int[] indices;
    Instances train;
    Instances test;
    Classifier cls;

    result = new MessageCollection();
    indices = null;
    m_Evaluation = null;
    m_Evaluations = null;

    try {
        // evaluate classifier
        if (m_Classifier == null)
            throw new IllegalStateException("Classifier '" + getClassifier() + "' not found!");
        if (isLoggingEnabled())
            getLogger().info(OptionUtils.getCommandLine(m_Classifier));

        m_ActualNumThreads = Performance.determineNumThreads(m_NumThreads);

        generator = (CrossValidationFoldGenerator) OptionUtils.shallowCopy(m_Generator);
        generator.setData(m_Data);
        generator.setNumFolds(m_Folds);
        generator.setSeed(m_Seed);
        generator.setStratify(true);
        generator.setUseViews(m_UseViews);
        generator.initializeIterator();
        folds = generator.getActualNumFolds();
        if ((m_ActualNumThreads == 1) && !m_SeparateFolds) {
            initOutputBuffer();
            if (m_Output != null) {
                m_Output.setHeader(m_Data);
                m_Output.printHeader();
            }
            eval = new Evaluation(m_Data);
            eval.setDiscardPredictions(m_DiscardPredictions);
            current = 0;
            while (generator.hasNext()) {
                if (isStopped())
                    break;
                if (m_StatusMessageHandler != null)
                    m_StatusMessageHandler.showStatus("Fold " + current + "/" + folds + ": '"
                            + m_Data.relationName() + "' using " + OptionUtils.getCommandLine(m_Classifier));
                cont = generator.next();
                train = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TRAIN);
                test = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TEST);
                cls = (Classifier) OptionUtils.shallowCopy(m_Classifier);
                cls.buildClassifier(train);
                eval.setPriors(train);
                eval.evaluateModel(cls, test, m_Output);
                current++;
            }
            if (m_Output != null)
                m_Output.printFooter();
            if (!isStopped())
                m_Evaluation = eval;
        } else {
            if (m_DiscardPredictions)
                throw new IllegalStateException(
                        "Cannot discard predictions in parallel mode, as they are used for aggregating the statistics!");
            if (m_JobRunnerSetup == null)
                m_JobRunner = new LocalJobRunner<WekaCrossValidationJob>();
            else
                m_JobRunner = m_JobRunnerSetup.newInstance();
            if (m_JobRunner instanceof ThreadLimiter)
                ((ThreadLimiter) m_JobRunner).setNumThreads(m_NumThreads);
            list = new JobList<>();
            while (generator.hasNext()) {
                cont = generator.next();
                job = new WekaCrossValidationJob((Classifier) OptionUtils.shallowCopy(m_Classifier),
                        (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TRAIN),
                        (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TEST),
                        (Integer) cont.getValue(WekaTrainTestSetContainer.VALUE_FOLD_NUMBER),
                        m_DiscardPredictions, m_StatusMessageHandler);
                list.add(job);
            }
            m_JobRunner.add(list);
            m_JobRunner.start();
            m_JobRunner.stop();
            // aggregate data
            if (!isStopped()) {
                evalAgg = new AggregateEvaluations();
                m_Evaluations = new Evaluation[m_JobRunner.getJobs().size()];
                for (i = 0; i < m_JobRunner.getJobs().size(); i++) {
                    job = (WekaCrossValidationJob) m_JobRunner.getJobs().get(i);
                    if (job.getEvaluation() == null) {
                        result.add("Fold #" + (i + 1) + " failed to evaluate"
                                + (job.hasExecutionError() ? job.getExecutionError() : "?"));
                        break;
                    }
                    evalAgg.add(job.getEvaluation());
                    m_Evaluations[i] = job.getEvaluation();
                    job.cleanUp();
                }
                m_Evaluation = evalAgg.aggregated();
                if (m_Evaluation == null) {
                    if (evalAgg.hasLastError())
                        result.add(evalAgg.getLastError());
                    else
                        result.add("Failed to aggregate evaluations!");
                }
            }
            list.cleanUp();
            m_JobRunner.cleanUp();
            m_JobRunner = null;
        }

        if (!m_DiscardPredictions)
            indices = generator.crossValidationIndices();
    } catch (Exception e) {
        result.add(Utils.handleException(this, "Failed to cross-validate classifier: ", e));
    }

    m_OriginalIndices = indices;

    if (result.isEmpty())
        return null;
    else
        return result.toString();
}