Example usage for weka.classifiers.trees REPTree buildClassifier

List of usage examples for weka.classifiers.trees REPTree buildClassifier

Introduction

On this page you can find example usages of the buildClassifier method of weka.classifiers.trees.REPTree.

Prototype

@Override
public void buildClassifier(Instances data) throws Exception 

Source Link

Documentation

Builds the classifier.

Usage

From source file: de.ugoe.cs.cpdp.dataselection.DecisionTreeSelection.java

License: Apache License

/**
 * Reduces {@code traindataSet} to the single training product that a regression tree predicts to
 * transfer best to {@code testdata}.
 *
 * <p>For every ordered pair (i, j) of distinct training products, a J48 tree trained on product i is
 * evaluated on product j, and the F-measure for class value 1 is stored as the score of a
 * "similarity" row. That row encodes, per characteristic, whether product j is "same", "more" or
 * "less" than product i. A {@code REPTree} is fitted on these rows. It then predicts a score for each
 * training product relative to the test data, and the product with the highest prediction is kept.
 *
 * <p>The code treats row 0 of {@code characteristicInstances(...)} as the test data's
 * characteristics and row {@code i + 1} as those of {@code traindataSet.get(i)}.
 *
 * @param testdata the target (test) data
 * @param traindataSet candidate training products; modified in place to contain only the best one
 * @throws RuntimeException wrapping any failure during training, evaluation or selection
 */
@Override
public void apply(Instances testdata, SetUniqueList<Instances> traindataSet) {
    final Instances data = characteristicInstances(testdata, traindataSet);

    // Nominal values in index order: 0.0 = "same", 1.0 = "more", 2.0 = "less".
    final ArrayList<String> attVals = new ArrayList<String>();
    attVals.add("same");
    attVals.add("more");
    attVals.add("less");
    final ArrayList<Attribute> atts = new ArrayList<Attribute>();
    for (int j = 0; j < data.numAttributes(); j++) {
        atts.add(new Attribute(data.attribute(j).name(), attVals));
    }
    atts.add(new Attribute("score"));
    final Instances similarityData = new Instances("similarity", atts, 0);
    similarityData.setClassIndex(similarityData.numAttributes() - 1);

    try {
        final Classifier classifier = new J48();
        for (int i = 0; i < traindataSet.size(); i++) {
            classifier.buildClassifier(traindataSet.get(i));
            for (int j = 0; j < traindataSet.size(); j++) {
                if (i != j) {
                    final double[] similarity = similarityVector(data, i + 1, j + 1);
                    final Evaluation eval = new Evaluation(traindataSet.get(j));
                    eval.evaluateModel(classifier, traindataSet.get(j));
                    similarity[data.numAttributes()] = eval.fMeasure(1);
                    similarityData.add(new DenseInstance(1.0, similarity));
                }
            }
        }

        final REPTree repTree = new REPTree();
        // The fold count used to be capped at similarityData.size() and then unconditionally
        // overwritten with 2; only the effective setting is kept.
        // NOTE(review): Weka rejects fewer than 2 folds or more folds than instances, so fewer than
        // two similarity rows will still fail here.
        repTree.setNumFolds(2);
        repTree.buildClassifier(similarityData);

        // Header-only copy (same attributes and class index) holding one row per training product,
        // each compared against the test data (row 0). The score slot stays 0.0 and is not used
        // for prediction.
        final Instances testTrainSimilarity = new Instances(similarityData, traindataSet.size());
        for (int i = 0; i < traindataSet.size(); i++) {
            testTrainSimilarity.add(new DenseInstance(1.0, similarityVector(data, 0, i + 1)));
        }

        int bestScoringProductIndex = -1;
        // NEGATIVE_INFINITY, not Double.MIN_VALUE (the smallest *positive* double): a predicted
        // score of 0 or below must still be selectable.
        double maxScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < traindataSet.size(); i++) {
            final double score = repTree.classifyInstance(testTrainSimilarity.get(i));
            if (score > maxScore) {
                maxScore = score;
                bestScoringProductIndex = i;
            }
        }
        if (bestScoringProductIndex < 0) {
            throw new IllegalStateException(
                    "no training product received a comparable score (empty set or only NaN predictions)");
        }
        final Instances bestScoringProduct = traindataSet.get(bestScoringProductIndex);
        traindataSet.clear();
        traindataSet.add(bestScoringProduct);
    } catch (Exception e) {
        Console.printerr("failure during DecisionTreeSelection: " + e.getMessage());
        throw new RuntimeException(e);
    }
}

/**
 * Encodes how the characteristics in row {@code other} of {@code data} relate to those in row
 * {@code reference}.
 *
 * <p>Per attribute k: 2.0 ("less") if {@code 0.9 * reference[k] > other[k]}; otherwise 1.0 ("more")
 * if {@code 1.1 * reference[k] < other[k]}; otherwise 0.0 ("same").
 *
 * @return an array of length {@code data.numAttributes() + 1}; the trailing slot (score) is 0.0
 */
private static double[] similarityVector(Instances data, int reference, int other) {
    final int numAttributes = data.numAttributes();
    final double[] similarity = new double[numAttributes + 1];
    for (int k = 0; k < numAttributes; k++) {
        final double referenceValue = data.get(reference).value(k);
        final double otherValue = data.get(other).value(k);
        if (0.9 * referenceValue > otherValue) {
            similarity[k] = 2.0;
        } else if (1.1 * referenceValue < otherValue) {
            similarity[k] = 1.0;
        } else {
            similarity[k] = 0.0;
        }
    }
    return similarity;
}

From source file: lu.lippmann.cdb.datasetview.tabs.RegressionTreeTabView.java

License: Open Source License

/**
 * {@inheritDoc}/*  w w w.  j  a  va  2 s  .c  o m*/
 */
@SuppressWarnings("unchecked")
@Override
public void update0(final Instances dataSet) throws Exception {
    this.panel.removeAll();

    //final Object[] attrNames=WekaDataStatsUtil.getNumericAttributesNames(dataSet).toArray();
    final Object[] attrNames = WekaDataStatsUtil.getAttributeNames(dataSet).toArray();
    final JComboBox xCombo = new JComboBox(attrNames);
    xCombo.setBorder(new TitledBorder("Attribute to evaluate"));

    final JXPanel comboPanel = new JXPanel();
    comboPanel.setLayout(new GridLayout(1, 2));
    comboPanel.add(xCombo);
    final JXButton jxb = new JXButton("Compute");
    comboPanel.add(jxb);
    this.panel.add(comboPanel, BorderLayout.NORTH);

    jxb.addActionListener(new ActionListener() {
        @Override
        public void actionPerformed(ActionEvent e) {
            try {
                if (gv != null)
                    panel.remove((Component) gv);

                dataSet.setClassIndex(xCombo.getSelectedIndex());

                final REPTree rt = new REPTree();
                rt.setNoPruning(true);
                //rt.setMaxDepth(3);
                rt.buildClassifier(dataSet);

                /*final M5P rt=new M5P();
                rt.buildClassifier(dataSet);*/

                final Evaluation eval = new Evaluation(dataSet);
                double[] d = eval.evaluateModel(rt, dataSet);
                System.out.println("PREDICTED -> " + FormatterUtil.buildStringFromArrayOfDoubles(d));
                System.out.println(eval.errorRate());
                System.out.println(eval.sizeOfPredictedRegions());
                System.out.println(eval.toSummaryString("", true));

                final GraphWithOperations gwo = GraphUtil
                        .buildGraphWithOperationsFromWekaRegressionString(rt.graph());
                final DecisionTree dt = new DecisionTree(gwo, eval.errorRate());

                gv = DecisionTreeToGraphViewHelper.buildGraphView(dt, eventPublisher, commandDispatcher);
                gv.addMetaInfo("Size=" + dt.getSize(), "");
                gv.addMetaInfo("Depth=" + dt.getDepth(), "");

                gv.addMetaInfo("MAE=" + FormatterUtil.DECIMAL_FORMAT.format(eval.meanAbsoluteError()) + "", "");
                gv.addMetaInfo("RMSE=" + FormatterUtil.DECIMAL_FORMAT.format(eval.rootMeanSquaredError()) + "",
                        "");

                final JCheckBox toggleDecisionTreeDetails = new JCheckBox("Toggle details");
                toggleDecisionTreeDetails.addActionListener(new ActionListener() {
                    @Override
                    public void actionPerformed(ActionEvent e) {
                        if (!tweakedGraph) {
                            final Object[] mapRep = WekaDataStatsUtil
                                    .buildNodeAndEdgeRepartitionMap(dt.getGraphWithOperations(), dataSet);
                            gv.updateVertexShapeTransformer((Map<CNode, Map<Object, Integer>>) mapRep[0]);
                            gv.updateEdgeShapeRenderer((Map<CEdge, Float>) mapRep[1]);
                        } else {
                            gv.resetVertexAndEdgeShape();
                        }
                        tweakedGraph = !tweakedGraph;
                    }
                });
                gv.addMetaInfoComponent(toggleDecisionTreeDetails);

                /*final JButton openInEditorButton = new JButton("Open in editor");
                openInEditorButton.addActionListener(new ActionListener() {
                   @Override
                   public void actionPerformed(ActionEvent e) {
                       GraphUtil.importDecisionTreeInEditor(dtFactory, dataSet, applicationContext, eventPublisher, commandDispatcher);
                   }
                });
                this.gv.addMetaInfoComponent(openInEditorButton);*/

                final JButton showTextButton = new JButton("In text");
                showTextButton.addActionListener(new ActionListener() {
                    @Override
                    public void actionPerformed(ActionEvent e) {
                        JOptionPane.showMessageDialog(null, graphDsl.getDslString(dt.getGraphWithOperations()));
                    }
                });
                gv.addMetaInfoComponent(showTextButton);

                panel.add(gv.asComponent(), BorderLayout.CENTER);
            } catch (Exception e1) {
                e1.printStackTrace();
                panel.add(new JXLabel("Error during computation: " + e1.getMessage()), BorderLayout.CENTER);
            }

        }
    });
}

From source file: lu.lippmann.cdb.dt.RegressionTreeFactory.java

License: Open Source License

/**
 * Main method./*from  ww  w  . ja v  a  2  s.  co m*/
 * @param args command line arguments
 */
public static void main(final String[] args) {
    try {
        final String f = "./samples/csv/uci/winequality-red.csv";
        //final String f="./samples/arff/UCI/crimepredict.arff";
        final Instances dataSet = WekaDataAccessUtil.loadInstancesFromARFFOrCSVFile(new File(f));
        System.out.println(dataSet.classAttribute().isNumeric());

        final REPTree rt = new REPTree();
        rt.setMaxDepth(3);
        rt.buildClassifier(dataSet);

        System.out.println(rt);

        //System.out.println(rt.graph());

        final GraphWithOperations gwo = GraphUtil.buildGraphWithOperationsFromWekaRegressionString(rt.graph());
        System.out.println(gwo);
        System.out.println(new ASCIIGraphDsl().getDslString(gwo));

        final Evaluation eval = new Evaluation(dataSet);

        /*Field privateStringField = Evaluation.class.getDeclaredField("m_CoverageStatisticsAvailable");
        privateStringField.setAccessible(true);
        //privateStringField.get
        boolean fieldValue = privateStringField.getBoolean(eval);
        System.out.println("fieldValue = " + fieldValue);*/

        double[] d = eval.evaluateModel(rt, dataSet);
        System.out.println("PREDICTED -> " + FormatterUtil.buildStringFromArrayOfDoubles(d));

        System.out.println(eval.errorRate());
        System.out.println(eval.sizeOfPredictedRegions());

        System.out.println(eval.toSummaryString("", true));

        /*final String f2="./samples/csv/salary.csv";
        final Instances dataSet2=WekaDataAccessUtil.loadInstancesFromARFFOrCSVFile(new File(f2));
                
        final J48 j48=new J48();
        j48.buildClassifier(dataSet2);
        System.out.println(j48.graph());
        final GraphWithOperations gwo2=GraphUtil.buildGraphWithOperationsFromWekaString(j48.graph(),false);
        System.out.println(gwo2);*/

        System.out.println(new DecisionTree(gwo, eval.errorRate()));
    } catch (Exception e) {
        e.printStackTrace();
    }
}

From source file: org.uclab.mm.kcl.ddkat.modellearner.ModelLearner.java

License: Apache License

/**
* Method to compute the classification accuracy.
*
* @param algo the algorithm name/* w  ww .  j a  v a2 s  .c o  m*/
* @param data the data instances
* @param datanature the dataset nature (i.e. original or processed data)
* @throws Exception the exception
*/
protected String[] modelAccuracy(String algo, Instances data, String datanature) throws Exception {

    String modelResultSet[] = new String[4];
    String modelStr = "";
    Classifier classifier = null;

    // setting class attribute if the data format does not provide this information           
    if (data.classIndex() == -1)
        data.setClassIndex(data.numAttributes() - 1);

    String decisionAttribute = data.attribute(data.numAttributes() - 1).toString();
    String res[] = decisionAttribute.split("\\s+");
    decisionAttribute = res[1];

    if (algo.equals("BFTree")) {

        // Use BFTree classifiers
        BFTree BFTreeclassifier = new BFTree();
        BFTreeclassifier.buildClassifier(data);
        modelStr = BFTreeclassifier.toString();
        classifier = BFTreeclassifier;

    } else if (algo.equals("FT")) {

        // Use FT classifiers
        FT FTclassifier = new FT();
        FTclassifier.buildClassifier(data);
        modelStr = FTclassifier.toString();
        classifier = FTclassifier;

    } else if (algo.equals("J48")) {

        // Use J48 classifiers
        J48 J48classifier = new J48();
        J48classifier.buildClassifier(data);
        modelStr = J48classifier.toString();
        classifier = J48classifier;
        System.out.println("Model String: " + modelStr);

    } else if (algo.equals("J48graft")) {

        // Use J48graft classifiers
        J48graft J48graftclassifier = new J48graft();
        J48graftclassifier.buildClassifier(data);
        modelStr = J48graftclassifier.toString();
        classifier = J48graftclassifier;

    } else if (algo.equals("RandomTree")) {

        // Use RandomTree classifiers
        RandomTree RandomTreeclassifier = new RandomTree();
        RandomTreeclassifier.buildClassifier(data);
        modelStr = RandomTreeclassifier.toString();
        classifier = RandomTreeclassifier;

    } else if (algo.equals("REPTree")) {

        // Use REPTree classifiers
        REPTree REPTreeclassifier = new REPTree();
        REPTreeclassifier.buildClassifier(data);
        modelStr = REPTreeclassifier.toString();
        classifier = REPTreeclassifier;

    } else if (algo.equals("SimpleCart")) {

        // Use SimpleCart classifiers
        SimpleCart SimpleCartclassifier = new SimpleCart();
        SimpleCartclassifier.buildClassifier(data);
        modelStr = SimpleCartclassifier.toString();
        classifier = SimpleCartclassifier;

    }

    modelResultSet[0] = algo;
    modelResultSet[1] = decisionAttribute;
    modelResultSet[2] = modelStr;

    // Collect every group of predictions for J48 model in a FastVector
    FastVector predictions = new FastVector();

    Evaluation evaluation = new Evaluation(data);
    int folds = 10; // cross fold validation = 10
    evaluation.crossValidateModel(classifier, data, folds, new Random(1));
    // System.out.println("Evaluatuion"+evaluation.toSummaryString());
    System.out.println("\n\n" + datanature + " Evaluatuion " + evaluation.toMatrixString());

    // ArrayList<Prediction> predictions = evaluation.predictions();
    predictions.appendElements(evaluation.predictions());

    System.out.println("\n\n 11111");
    // Calculate overall accuracy of current classifier on all splits
    double correct = 0;

    for (int i = 0; i < predictions.size(); i++) {
        NominalPrediction np = (NominalPrediction) predictions.elementAt(i);
        if (np.predicted() == np.actual()) {
            correct++;
        }
    }

    System.out.println("\n\n 22222");
    double accuracy = 100 * correct / predictions.size();
    String accString = String.format("%.2f%%", accuracy);
    modelResultSet[3] = accString;
    System.out.println(datanature + " Accuracy " + accString);

    String modelFileName = algo + "-DDKA.model";

    System.out.println("\n\n 33333");

    ObjectOutputStream oos = new ObjectOutputStream(
            new FileOutputStream("D:\\DDKAResources\\" + modelFileName));
    oos.writeObject(classifier);
    oos.flush();
    oos.close();

    return modelResultSet;

}