Example usage for weka.classifiers.Evaluation.setDiscardPredictions

List of usage examples for weka.classifiers.Evaluation.setDiscardPredictions

Introduction

On this page you can find example usages of weka.classifiers.Evaluation.setDiscardPredictions.

Prototype

public void setDiscardPredictions(boolean value) 

Source Link

Document

Sets whether to discard predictions, i.e., not to store them for later retrieval via the predictions() method, in order to conserve memory.

Usage

From source file:adams.flow.transformer.WekaTestSetEvaluator.java

License:Open Source License

/**
 * Executes the flow item./*from   w  ww  . ja  va2 s .co m*/
 *
 * @return      null if everything is fine, otherwise error message
 */
@Override
protected String doExecute() {
    String result;
    Instances test;
    Evaluation eval;
    weka.classifiers.Classifier cls;
    CallableSource gs;
    Token output;

    result = null;
    test = null;

    try {
        // get test set
        test = null;
        gs = new CallableSource();
        gs.setCallableName(m_Testset);
        gs.setParent(getParent());
        gs.setUp();
        gs.execute();
        output = gs.output();
        if (output != null)
            test = (Instances) output.getPayload();
        else
            result = "No test set available!";
        gs.wrapUp();

        // evaluate classifier
        if (result == null) {
            if (m_InputToken.getPayload() instanceof weka.classifiers.Classifier)
                cls = (weka.classifiers.Classifier) m_InputToken.getPayload();
            else
                cls = (weka.classifiers.Classifier) ((WekaModelContainer) m_InputToken.getPayload())
                        .getValue(WekaModelContainer.VALUE_MODEL);
            initOutputBuffer();
            m_Output.setHeader(test);
            eval = new Evaluation(test);
            eval.setDiscardPredictions(m_DiscardPredictions);
            eval.evaluateModel(cls, test, m_Output);

            // broadcast result
            if (m_Output instanceof Null) {
                m_OutputToken = new Token(new WekaEvaluationContainer(eval, cls));
            } else {
                if (m_AlwaysUseContainer)
                    m_OutputToken = new Token(
                            new WekaEvaluationContainer(eval, cls, m_Output.getBuffer().toString()));
                else
                    m_OutputToken = new Token(m_Output.getBuffer().toString());
            }
        }
    } catch (Exception e) {
        m_OutputToken = null;
        result = handleException("Failed to evaluate: ", e);
    }

    if (m_OutputToken != null) {
        if (m_OutputToken.getPayload() instanceof WekaEvaluationContainer) {
            if (test != null)
                ((WekaEvaluationContainer) m_OutputToken.getPayload())
                        .setValue(WekaEvaluationContainer.VALUE_TESTDATA, test);
        }
        updateProvenance(m_OutputToken);
    }

    return result;
}

From source file:adams.flow.transformer.WekaTrainTestSetEvaluator.java

License:Open Source License

/**
 * Executes the flow item./*from ww w.  j  a va 2 s .  co m*/
 *
 * @return      null if everything is fine, otherwise error message
 */
@Override
protected String doExecute() {
    String result;
    Instances train;
    Instances test;
    weka.classifiers.Classifier cls;
    Evaluation eval;
    WekaTrainTestSetContainer cont;

    result = null;
    test = null;

    try {
        // cross-validate classifier
        cls = getClassifierInstance();
        if (cls == null)
            throw new IllegalStateException("Classifier '" + getClassifier() + "' not found!");

        cont = (WekaTrainTestSetContainer) m_InputToken.getPayload();
        train = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TRAIN);
        test = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TEST);
        cls.buildClassifier(train);
        initOutputBuffer();
        m_Output.setHeader(train);
        eval = new Evaluation(train);
        eval.setDiscardPredictions(m_DiscardPredictions);
        eval.evaluateModel(cls, test, m_Output);

        // broadcast result
        if (m_Output instanceof Null) {
            m_OutputToken = new Token(new WekaEvaluationContainer(eval, cls));
        } else {
            if (m_AlwaysUseContainer)
                m_OutputToken = new Token(
                        new WekaEvaluationContainer(eval, cls, m_Output.getBuffer().toString()));
            else
                m_OutputToken = new Token(m_Output.getBuffer().toString());
        }
    } catch (Exception e) {
        m_OutputToken = null;
        result = handleException("Failed to evaluate: ", e);
    }

    if (m_OutputToken != null) {
        if (m_OutputToken.getPayload() instanceof WekaEvaluationContainer) {
            if (test != null)
                ((WekaEvaluationContainer) m_OutputToken.getPayload())
                        .setValue(WekaEvaluationContainer.VALUE_TESTDATA, test);
        }
        updateProvenance(m_OutputToken);
    }

    return result;
}

From source file:adams.multiprocess.WekaCrossValidationExecution.java

License:Open Source License

/**
 * Executes the flow item./*from  w w w  . ja va 2  s  . co m*/
 *
 * @return      null if everything is fine, otherwise error message
 */
public String execute() {
    MessageCollection result;
    Evaluation eval;
    AggregateEvaluations evalAgg;
    int folds;
    CrossValidationFoldGenerator generator;
    JobList<WekaCrossValidationJob> list;
    WekaCrossValidationJob job;
    WekaTrainTestSetContainer cont;
    int i;
    int current;
    int[] indices;
    Instances train;
    Instances test;
    Classifier cls;

    result = new MessageCollection();
    indices = null;
    m_Evaluation = null;
    m_Evaluations = null;

    try {
        // evaluate classifier
        if (m_Classifier == null)
            throw new IllegalStateException("Classifier '" + getClassifier() + "' not found!");
        if (isLoggingEnabled())
            getLogger().info(OptionUtils.getCommandLine(m_Classifier));

        m_ActualNumThreads = Performance.determineNumThreads(m_NumThreads);

        generator = (CrossValidationFoldGenerator) OptionUtils.shallowCopy(m_Generator);
        generator.setData(m_Data);
        generator.setNumFolds(m_Folds);
        generator.setSeed(m_Seed);
        generator.setStratify(true);
        generator.setUseViews(m_UseViews);
        generator.initializeIterator();
        folds = generator.getActualNumFolds();
        if ((m_ActualNumThreads == 1) && !m_SeparateFolds) {
            initOutputBuffer();
            if (m_Output != null) {
                m_Output.setHeader(m_Data);
                m_Output.printHeader();
            }
            eval = new Evaluation(m_Data);
            eval.setDiscardPredictions(m_DiscardPredictions);
            current = 0;
            while (generator.hasNext()) {
                if (isStopped())
                    break;
                if (m_StatusMessageHandler != null)
                    m_StatusMessageHandler.showStatus("Fold " + current + "/" + folds + ": '"
                            + m_Data.relationName() + "' using " + OptionUtils.getCommandLine(m_Classifier));
                cont = generator.next();
                train = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TRAIN);
                test = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TEST);
                cls = (Classifier) OptionUtils.shallowCopy(m_Classifier);
                cls.buildClassifier(train);
                eval.setPriors(train);
                eval.evaluateModel(cls, test, m_Output);
                current++;
            }
            if (m_Output != null)
                m_Output.printFooter();
            if (!isStopped())
                m_Evaluation = eval;
        } else {
            if (m_DiscardPredictions)
                throw new IllegalStateException(
                        "Cannot discard predictions in parallel mode, as they are used for aggregating the statistics!");
            if (m_JobRunnerSetup == null)
                m_JobRunner = new LocalJobRunner<WekaCrossValidationJob>();
            else
                m_JobRunner = m_JobRunnerSetup.newInstance();
            if (m_JobRunner instanceof ThreadLimiter)
                ((ThreadLimiter) m_JobRunner).setNumThreads(m_NumThreads);
            list = new JobList<>();
            while (generator.hasNext()) {
                cont = generator.next();
                job = new WekaCrossValidationJob((Classifier) OptionUtils.shallowCopy(m_Classifier),
                        (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TRAIN),
                        (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TEST),
                        (Integer) cont.getValue(WekaTrainTestSetContainer.VALUE_FOLD_NUMBER),
                        m_DiscardPredictions, m_StatusMessageHandler);
                list.add(job);
            }
            m_JobRunner.add(list);
            m_JobRunner.start();
            m_JobRunner.stop();
            // aggregate data
            if (!isStopped()) {
                evalAgg = new AggregateEvaluations();
                m_Evaluations = new Evaluation[m_JobRunner.getJobs().size()];
                for (i = 0; i < m_JobRunner.getJobs().size(); i++) {
                    job = (WekaCrossValidationJob) m_JobRunner.getJobs().get(i);
                    if (job.getEvaluation() == null) {
                        result.add("Fold #" + (i + 1) + " failed to evaluate"
                                + (job.hasExecutionError() ? job.getExecutionError() : "?"));
                        break;
                    }
                    evalAgg.add(job.getEvaluation());
                    m_Evaluations[i] = job.getEvaluation();
                    job.cleanUp();
                }
                m_Evaluation = evalAgg.aggregated();
                if (m_Evaluation == null) {
                    if (evalAgg.hasLastError())
                        result.add(evalAgg.getLastError());
                    else
                        result.add("Failed to aggregate evaluations!");
                }
            }
            list.cleanUp();
            m_JobRunner.cleanUp();
            m_JobRunner = null;
        }

        if (!m_DiscardPredictions)
            indices = generator.crossValidationIndices();
    } catch (Exception e) {
        result.add(Utils.handleException(this, "Failed to cross-validate classifier: ", e));
    }

    m_OriginalIndices = indices;

    if (result.isEmpty())
        return null;
    else
        return result.toString();
}