Example usage for weka.classifiers.trees REPTree setMaxDepth

List of usage examples for weka.classifiers.trees REPTree setMaxDepth

Introduction

On this page you can find example usages of the setMaxDepth method of weka.classifiers.trees.REPTree.

Prototype

public void setMaxDepth(int newMaxDepth) 

Source Link

Document

Set the value of MaxDepth.

Usage

From source file:lu.lippmann.cdb.dt.RegressionTreeFactory.java

License:Open Source License

/**
 * Demo entry point: loads the red wine quality CSV sample, fits a {@link REPTree}
 * limited to depth 3 on it, prints the tree in several representations and finally
 * evaluates the tree against the same data it was built from.
 * Any exception raised along the way is printed to standard error and otherwise ignored.
 *
 * @param args command line arguments (not read)
 */
public static void main(final String[] args) {
    try {
        final String samplePath = "./samples/csv/uci/winequality-red.csv";
        final Instances instances = WekaDataAccessUtil.loadInstancesFromARFFOrCSVFile(new File(samplePath));
        final boolean numericClass = instances.classAttribute().isNumeric();
        System.out.println(numericClass);

        // Build the depth-limited tree and show Weka's own textual rendering of it.
        final REPTree tree = new REPTree();
        tree.setMaxDepth(3);
        tree.buildClassifier(instances);
        System.out.println(tree);

        // Convert Weka's graph description into the project's graph model and print
        // both the model itself and its ASCII DSL form.
        final String wekaGraph = tree.graph();
        final GraphWithOperations graph = GraphUtil.buildGraphWithOperationsFromWekaRegressionString(wekaGraph);
        System.out.println(graph);
        final ASCIIGraphDsl dsl = new ASCIIGraphDsl();
        System.out.println(dsl.getDslString(graph));

        // Evaluate on the training data itself (no hold-out set is used here).
        final Evaluation evaluation = new Evaluation(instances);
        final double[] predictions = evaluation.evaluateModel(tree, instances);
        System.out.println("PREDICTED -> " + FormatterUtil.buildStringFromArrayOfDoubles(predictions));

        System.out.println(evaluation.errorRate());
        System.out.println(evaluation.sizeOfPredictedRegions());

        System.out.println(evaluation.toSummaryString("", true));

        final DecisionTree decisionTree = new DecisionTree(graph, evaluation.errorRate());
        System.out.println(decisionTree);
    } catch (Exception failure) {
        failure.printStackTrace();
    }
}

From source file:org.openml.webapplication.fantail.dc.landmarking.REPTreeBasedLandmarker.java

License:Open Source License

/**
 * Computes landmarking qualities for {@code data} by cross-validating REPTrees of
 * maximum depth 1, 2 and 3 (with {@code m_NumFolds} folds and a random seed of 1).
 * <p>
 * For each depth three values are recorded: {@code pctIncorrect()},
 * {@code weightedAreaUnderROC()} and {@code kappa()}. Every value starts out as 0.5 and
 * keeps that default if the evaluation throws before the value is assigned; the
 * exception itself is only printed.
 *
 * @param data the data set to characterise
 * @return a map from the entries of {@code ids} to the recorded scores; for depth d,
 *         {@code ids[2*(d-1)]} is the percentage incorrect, {@code ids[2*(d-1)+1]} the
 *         weighted AUC and {@code ids[6+(d-1)]} the kappa statistic
 */
public Map<String, Double> characterize(Instances data) {
    final int folds = m_NumFolds;
    final int depthCount = 3;

    // One slot per id; 0.5 is the fallback value for anything that cannot be computed.
    final double[] scores = new double[3 * depthCount];
    for (int slot = 0; slot < scores.length; ++slot) {
        scores[slot] = 0.5;
    }

    for (int depth = 1; depth <= depthCount; ++depth) {
        final int errorSlot = 2 * (depth - 1);
        final int aucSlot = errorSlot + 1;
        final int kappaSlot = 2 * depthCount + (depth - 1);

        weka.classifiers.trees.REPTree tree = new weka.classifiers.trees.REPTree();
        tree.setMaxDepth(depth);

        try {
            weka.classifiers.Evaluation evaluation = new weka.classifiers.Evaluation(data);
            evaluation.crossValidateModel(tree, data, folds, new java.util.Random(1));

            // Assigned one by one so that a failure part-way leaves the remaining
            // slots for this depth at their default.
            scores[errorSlot] = evaluation.pctIncorrect();
            scores[aucSlot] = evaluation.weightedAreaUnderROC();
            scores[kappaSlot] = evaluation.kappa();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    Map<String, Double> qualities = new HashMap<String, Double>();
    for (int slot = 0; slot < scores.length; ++slot) {
        qualities.put(ids[slot], scores[slot]);
    }
    return qualities;
}

From source file:statechum.analysis.learning.experiments.PairSelection.PairQualityLearner.java

License:Open Source License

/**
 * Runs the pair-quality experiment for every parameter combination enumerated by the nested
 * loops below (each loop currently lists a single value). Every combination has two phases:
 * <ol>
 * <li><b>Training</b>: {@code LearnerRunner} tasks are submitted to a fixed-size thread pool and
 * populate a {@code WekaDataCollector}. Statistics on how many attribute values differ from
 * {@code WekaDataCollector.ZERO} are printed and drawn as a histogram into
 * {@code attributes_use<selection>.pdf}. A REPTree with maximum depth 4 is then built from the
 * collected data and serialised to {@code <selection>.ser}.</li>
 * <li><b>Evaluation</b>: a second batch of tasks is run whose learners use the trained classifier;
 * results are drawn into {@code percentage_score<selection>.pdf},
 * {@code quality_traces<selection>.pdf} and {@code new_to_orig<selection>.pdf}.</li>
 * </ol>
 * The thread pool is shut down when the method finishes, whether normally or with an exception.
 *
 * @throws Exception if building or storing the classifier fails; a failure while running the
 *         submitted tasks is rethrown as an {@link IllegalArgumentException} carrying the
 *         original exception as its cause
 */
@SuppressWarnings("null")
public static void runExperiment() throws Exception {
    DrawGraphs gr = new DrawGraphs();
    Configuration config = Configuration.getDefaultConfiguration().copy();
    config.setAskQuestions(false);
    config.setDebugMode(false);
    config.setGdLowToHighRatio(0.7);
    config.setRandomPathAttemptFudgeThreshold(1000);
    config.setTransitionMatrixImplType(STATETREE.STATETREE_LINKEDHASH);
    ConvertALabel converter = new Transform.InternStringLabel();
    //gr_NewToOrig.setLimit(7000);
    GlobalConfiguration.getConfiguration().setProperty(G_PROPERTIES.LINEARWARNINGS, "false");
    final int ThreadNumber = ExperimentRunner.getCpuNumber();

    ExecutorService executorService = Executors.newFixedThreadPool(ThreadNumber);
    // Everything that uses the pool lives inside this try so that the pool is shut down even
    // when an exception escapes from code outside the two task-running try blocks
    // (for instance from building or serialising the classifier).
    try {
        final int minStateNumber = 20;
        final int samplesPerFSM = 4;
        final int rangeOfStateNumbers = 4;
        final int stateNumberIncrement = 4;
        final double trainingDataMultiplier = 2;
        // Stores tasks to complete.
        CompletionService<ThreadResult> runner = new ExecutorCompletionService<ThreadResult>(executorService);
        for (final int lengthMultiplier : new int[] { 50 })
            for (final int ifDepth : new int[] { 1 })
                for (final boolean onlyPositives : new boolean[] { true }) {
                    final int traceQuantity = 1;
                    for (final boolean useUnique : new boolean[] { false }) {
                        String selection = "TRUNK;TRAINING;" + "ifDepth=" + ifDepth + ";onlyPositives="
                                + onlyPositives + ";useUnique=" + useUnique + ";traceQuantity=" + traceQuantity
                                + ";lengthMultiplier=" + lengthMultiplier + ";trainingDataMultiplier="
                                + trainingDataMultiplier + ";";

                        // ---- Training phase: collect data from learner runs. ----
                        WekaDataCollector dataCollector = createDataCollector(ifDepth);
                        List<SampleData> samples = new LinkedList<SampleData>();
                        try {
                            int numberOfTasks = 0;
                            for (int states = minStateNumber; states < minStateNumber
                                    + rangeOfStateNumbers; states += stateNumberIncrement)
                                for (int sample = 0; sample < Math
                                        .round(samplesPerFSM * trainingDataMultiplier); ++sample) {
                                    LearnerRunner learnerRunner = new LearnerRunner(dataCollector, states, sample,
                                            1 + numberOfTasks, traceQuantity, config, converter) {
                                        @Override
                                        public LearnerThatCanClassifyPairs createLearner(
                                                LearnerEvaluationConfiguration evalCnf,
                                                LearnerGraph argReferenceGraph, WekaDataCollector argDataCollector,
                                                LearnerGraph argInitialPTA) {
                                            return new LearnerThatUpdatesWekaResults(evalCnf, argReferenceGraph,
                                                    argDataCollector, argInitialPTA);
                                        }
                                    };
                                    learnerRunner.setPickUniqueFromInitial(useUnique);
                                    learnerRunner.setOnlyUsePositives(onlyPositives);
                                    learnerRunner.setIfdepth(ifDepth);
                                    learnerRunner.setLengthMultiplier(lengthMultiplier);
                                    learnerRunner
                                            .setSelectionID(selection + "_states" + states + "_sample" + sample);
                                    runner.submit(learnerRunner);
                                    ++numberOfTasks;
                                }
                            ProgressIndicator progress = new ProgressIndicator(
                                    "running " + numberOfTasks + " tasks for " + selection, numberOfTasks);
                            for (int count = 0; count < numberOfTasks; ++count) {
                                ThreadResult result = runner.take().get();// this will throw an exception if any of the tasks failed.
                                samples.addAll(result.samples);
                                progress.next();
                            }
                        } catch (Exception ex) {
                            IllegalArgumentException e = new IllegalArgumentException(
                                    "failed to compute, the problem is: " + ex);
                            e.initCause(ex);
                            if (executorService != null) {
                                executorService.shutdown();
                                executorService = null;
                            }
                            throw e;
                        }

                        // ---- Statistics on how many attribute values are not ZERO. ----
                        int nonZeroes = 0;
                        long numberOfValues = 0;
                        System.out.println("number of instances: " + dataCollector.trainingData.numInstances());
                        int freqData[] = new int[dataCollector.attributesOfAnInstance.length];
                        for (int i = 0; i < dataCollector.trainingData.numInstances(); ++i)
                            for (int attrNum = 0; attrNum < dataCollector.attributesOfAnInstance.length; ++attrNum) {
                                assert dataCollector.attributesOfAnInstance[attrNum].index() == attrNum;
                                // Compare by value: reference comparison only works when both
                                // sides happen to be the very same String object.
                                if (!WekaDataCollector.ZERO.equals(
                                        dataCollector.trainingData.instance(i).stringValue(attrNum))) {
                                    ++freqData[attrNum];
                                    ++numberOfValues;
                                }
                            }
                        for (int attrNum = 0; attrNum < dataCollector.attributesOfAnInstance.length; ++attrNum)
                            if (freqData[attrNum] > 0)
                                ++nonZeroes;

                        System.out.println("Total instances: " + dataCollector.trainingData.numInstances()
                                + " with " + dataCollector.attributesOfAnInstance.length
                                + " attributes, non-zeroes are " + nonZeroes + " with average of "
                                + ((double) numberOfValues) / nonZeroes);
                        // Sorted ascending, then grouped into numOfcolumns buckets of stepWidth attributes each.
                        Arrays.sort(freqData);
                        int numOfcolumns = 20;
                        int stepWidth = dataCollector.attributesOfAnInstance.length / numOfcolumns;

                        final RBoxPlot<Long> gr_HistogramOfAttributeValues = new RBoxPlot<Long>("Attributes",
                                "Number of values", new File("attributes_use" + selection + ".pdf"));
                        for (int i = 0; i < numOfcolumns; ++i) {
                            int columnData = 0;
                            for (int j = i * stepWidth; j < (i + 1) * stepWidth; ++j)
                                if (j < dataCollector.attributesOfAnInstance.length)
                                    columnData += freqData[j];

                            // The log10 of the bucket total is plotted; an empty bucket is plotted as 0.
                            gr_HistogramOfAttributeValues.add(Long.valueOf(numOfcolumns - i),
                                    Double.valueOf(columnData > 0 ? Math.log10(columnData) : 0));
                        }
                        //gr_HistogramOfAttributeValues.drawInteractive(gr);
                        gr_HistogramOfAttributeValues.drawPdf(gr);

                        // ---- Build the classifier from the collected data. ----
                        final weka.classifiers.trees.REPTree repTree = new weka.classifiers.trees.REPTree();
                        repTree.setMaxDepth(4);
                        //repTree.setNoPruning(true);// since we only use the tree as a classifier (as a conservative extension of what is currently done) and do not actually look at it, elimination of pruning is not a problem. 
                        // As part of learning, we also prune some of the nodes where the ratio of correctly-classified pairs to those incorrectly classified is comparable.
                        // The significant advantage of not pruning is that the result is no longer sensitive to the order of elements in the tree and hence does not depend on the order in which elements have been obtained by concurrent threads.
                        //final weka.classifiers.lazy.IB1 ib1 = new weka.classifiers.lazy.IB1();
                        //final weka.classifiers.trees.J48 classifier = new weka.classifiers.trees.J48();
                        final Classifier classifier = repTree;
                        classifier.buildClassifier(dataCollector.trainingData);
                        System.out
                                .println("Entries in the classifier: " + dataCollector.trainingData.numInstances());
                        System.out.println(classifier);
                        dataCollector = null;// throw all the training data away.

                        {// serialise the classifier, this is the only way to store it.
                            final OutputStream os = new FileOutputStream(selection + ".ser");
                            try {
                                final ObjectOutputStream oo = new ObjectOutputStream(os);
                                oo.writeObject(classifier);
                                // Push anything still buffered in the object stream into the file
                                // before the file is closed.
                                oo.flush();
                            } finally {
                                // Always release the file handle, even if serialisation failed.
                                os.close();
                            }
                        }

                        // ---- Evaluation phase: run learners that use the classifier. ----
                        for (final boolean selectingRed : new boolean[] { false })
                            for (final boolean classifierToBlockAllMergers : new boolean[] { true })
                                //for(final boolean zeroScoringAsRed:(classifierToBlockAllMergers?new boolean[]{true,false}:new boolean[]{false}))// where we are not using classifier to rule out all mergers proposed by pair selection, it does not make sense to use two values configuring this classifier.
                                for (final double threshold : new double[] { 1 }) {
                                    final boolean zeroScoringAsRed = false;
                                    selection = "TRUNK;EVALUATION;" + "ifDepth=" + ifDepth + ";threshold="
                                            + threshold + // ";useUnique="+useUnique+";onlyPositives="+onlyPositives+
                                            ";selectingRed=" + selectingRed + ";classifierToBlockAllMergers="
                                            + classifierToBlockAllMergers + ";zeroScoringAsRed=" + zeroScoringAsRed
                                            + ";traceQuantity=" + traceQuantity + ";lengthMultiplier="
                                            + lengthMultiplier + ";trainingDataMultiplier=" + trainingDataMultiplier
                                            + ";";

                                    final int totalTaskNumber = traceQuantity;
                                    final RBoxPlot<Long> gr_PairQuality = new RBoxPlot<Long>("Correct v.s. wrong",
                                            "%%", new File("percentage_score" + selection + ".pdf"));
                                    final RBoxPlot<String> gr_QualityForNumberOfTraces = new RBoxPlot<String>(
                                            "traces", "%%", new File("quality_traces" + selection + ".pdf"));
                                    SquareBagPlot gr_NewToOrig = new SquareBagPlot("orig score",
                                            "score with learnt selection",
                                            new File("new_to_orig" + selection + ".pdf"), 0, 1, true);
                                    final Map<Long, TrueFalseCounter> pairQualityCounter = new TreeMap<Long, TrueFalseCounter>();
                                    try {
                                        int numberOfTasks = 0;
                                        for (int states = minStateNumber; states < minStateNumber
                                                + rangeOfStateNumbers; states += stateNumberIncrement)
                                            for (int sample = 0; sample < samplesPerFSM; ++sample) {
                                                // dataCollector is null at this point (the training data was
                                                // discarded above); the learner created below does not use it.
                                                LearnerRunner learnerRunner = new LearnerRunner(dataCollector,
                                                        states, sample, totalTaskNumber + numberOfTasks,
                                                        traceQuantity, config, converter) {
                                                    @Override
                                                    public LearnerThatCanClassifyPairs createLearner(
                                                            LearnerEvaluationConfiguration evalCnf,
                                                            LearnerGraph argReferenceGraph,
                                                            @SuppressWarnings("unused") WekaDataCollector argDataCollector,
                                                            LearnerGraph argInitialPTA) {
                                                        LearnerThatUsesWekaResults l = new LearnerThatUsesWekaResults(
                                                                ifDepth, evalCnf, argReferenceGraph, classifier,
                                                                argInitialPTA);
                                                        if (gr_PairQuality != null)
                                                            l.setPairQualityCounter(pairQualityCounter);

                                                        l.setUseClassifierForRed(selectingRed);
                                                        l.setUseClassifierToChooseNextRed(
                                                                classifierToBlockAllMergers);
                                                        l.setBlacklistZeroScoringPairs(zeroScoringAsRed);
                                                        l.setThreshold(threshold);
                                                        return l;
                                                    }

                                                };
                                                learnerRunner.setPickUniqueFromInitial(useUnique);
                                                learnerRunner.setEvaluateAlsoUsingReferenceLearner(true);
                                                learnerRunner.setOnlyUsePositives(onlyPositives);
                                                learnerRunner.setIfdepth(ifDepth);
                                                learnerRunner.setLengthMultiplier(lengthMultiplier);
                                                learnerRunner.setSelectionID(
                                                        selection + "_states" + states + "_sample" + sample);
                                                runner.submit(learnerRunner);
                                                ++numberOfTasks;
                                            }
                                        ProgressIndicator progress = new ProgressIndicator(new Date()
                                                + " evaluating " + numberOfTasks + " tasks for " + selection,
                                                numberOfTasks);
                                        for (int count = 0; count < numberOfTasks; ++count) {
                                            ThreadResult result = runner.take().get();// this will throw an exception if any of the tasks failed.
                                            if (gr_NewToOrig != null) {
                                                for (SampleData sample : result.samples)
                                                    gr_NewToOrig.add(sample.referenceLearner.getValue(),
                                                            sample.actualLearner.getValue());
                                            }

                                            // Ratio of the actual learner's score to the reference learner's
                                            // score, skipped where the reference score is not positive.
                                            for (SampleData sample : result.samples)
                                                if (sample.referenceLearner.getValue() > 0)
                                                    gr_QualityForNumberOfTraces.add(traceQuantity + "",
                                                            sample.actualLearner.getValue()
                                                                    / sample.referenceLearner.getValue());
                                            progress.next();
                                        }
                                        if (gr_PairQuality != null) {
                                            synchronized (pairQualityCounter) {
                                                updateGraph(gr_PairQuality, pairQualityCounter);
                                                //gr_PairQuality.drawInteractive(gr);
                                                //gr_NewToOrig.drawInteractive(gr);
                                                //if (gr_QualityForNumberOfTraces.size() > 0)
                                                //   gr_QualityForNumberOfTraces.drawInteractive(gr);
                                            }
                                        }
                                        if (gr_PairQuality != null)
                                            gr_PairQuality.drawPdf(gr);
                                    } catch (Exception ex) {
                                        IllegalArgumentException e = new IllegalArgumentException(
                                                "failed to compute, the problem is: " + ex);
                                        e.initCause(ex);
                                        if (executorService != null) {
                                            executorService.shutdownNow();
                                            executorService = null;
                                        }
                                        throw e;
                                    }
                                    if (gr_NewToOrig != null)
                                        gr_NewToOrig.drawPdf(gr);
                                    if (gr_QualityForNumberOfTraces != null)
                                        gr_QualityForNumberOfTraces.drawPdf(gr);
                                }
                    }
                }
    } finally {
        // No-op when one of the catch blocks above has already shut the pool down and cleared the reference.
        if (executorService != null) {
            executorService.shutdown();
            executorService = null;
        }
    }
}

From source file:statechum.analysis.learning.experiments.PaperUAS.java

License:Open Source License

/**
 * Creates the classifiers to be trained from a full PTA (by comparing metrics on pairs considered
 * by QSM and checking them against the reference solution) and trains each returned classifier
 * on the supplied ARFF data.
 * <p>
 * Several candidate classifiers are instantiated and configured here, but only those placed in
 * the returned array (currently just IB1) are actually trained.
 *
 * @param arffWithTrainingData passed unchanged to {@code trainClassifierFromArff} for every
 *        returned classifier
 * @return the trained classifiers
 */
protected Classifier[] loadClassifierFromArff(String arffWithTrainingData) {
    // Pruning is switched off for both tree learners. The tree is only used as a classifier (a
    // conservative extension of what is currently done) and is never inspected, so losing pruning
    // is not a problem; learning itself also prunes some nodes where the ratio of correctly to
    // incorrectly classified pairs is comparable. The significant advantage is that the result is
    // no longer sensitive to the order of elements in the tree and hence does not depend on the
    // order in which elements were obtained by concurrent threads.
    weka.classifiers.trees.REPTree repTree = new weka.classifiers.trees.REPTree();
    repTree.setMaxDepth(3);
    repTree.setNoPruning(true);

    weka.classifiers.trees.J48 j48Tree = new weka.classifiers.trees.J48();
    j48Tree.setUnpruned(true);

    weka.classifiers.lazy.IBk kNearest = new weka.classifiers.lazy.IBk(1);
    weka.classifiers.lazy.IB1 nearest = new weka.classifiers.lazy.IB1();
    weka.classifiers.functions.MultilayerPerceptron neuralNet = new weka.classifiers.functions.MultilayerPerceptron();

    // Only the classifiers listed here are trained and handed back to the caller.
    Classifier[] trained = { nearest };
    for (int idx = 0; idx < trained.length; ++idx) {
        trainClassifierFromArff(trained[idx], arffWithTrainingData);
    }
    return trained;
}