Example usage for weka.classifiers.trees SimpleCart buildClassifier

List of usage examples for weka.classifiers.trees SimpleCart buildClassifier

Introduction

On this page you can find example usages of the weka.classifiers.trees.SimpleCart method buildClassifier.

Prototype

@Override
public void buildClassifier(Instances data) throws Exception 

Source Link

Document

Build the classifier.

Usage

From source file:org.uclab.mm.kcl.ddkat.modellearner.ModelLearner.java

License:Apache License

/**
* Method to compute the classification accuracy.
*
* @param algo the algorithm name/* w  w  w  .  j  a  v  a 2  s .  co m*/
* @param data the data instances
* @param datanature the dataset nature (i.e. original or processed data)
* @throws Exception the exception
*/
protected String[] modelAccuracy(String algo, Instances data, String datanature) throws Exception {

    String modelResultSet[] = new String[4];
    String modelStr = "";
    Classifier classifier = null;

    // setting class attribute if the data format does not provide this information           
    if (data.classIndex() == -1)
        data.setClassIndex(data.numAttributes() - 1);

    String decisionAttribute = data.attribute(data.numAttributes() - 1).toString();
    String res[] = decisionAttribute.split("\\s+");
    decisionAttribute = res[1];

    if (algo.equals("BFTree")) {

        // Use BFTree classifiers
        BFTree BFTreeclassifier = new BFTree();
        BFTreeclassifier.buildClassifier(data);
        modelStr = BFTreeclassifier.toString();
        classifier = BFTreeclassifier;

    } else if (algo.equals("FT")) {

        // Use FT classifiers
        FT FTclassifier = new FT();
        FTclassifier.buildClassifier(data);
        modelStr = FTclassifier.toString();
        classifier = FTclassifier;

    } else if (algo.equals("J48")) {

        // Use J48 classifiers
        J48 J48classifier = new J48();
        J48classifier.buildClassifier(data);
        modelStr = J48classifier.toString();
        classifier = J48classifier;
        System.out.println("Model String: " + modelStr);

    } else if (algo.equals("J48graft")) {

        // Use J48graft classifiers
        J48graft J48graftclassifier = new J48graft();
        J48graftclassifier.buildClassifier(data);
        modelStr = J48graftclassifier.toString();
        classifier = J48graftclassifier;

    } else if (algo.equals("RandomTree")) {

        // Use RandomTree classifiers
        RandomTree RandomTreeclassifier = new RandomTree();
        RandomTreeclassifier.buildClassifier(data);
        modelStr = RandomTreeclassifier.toString();
        classifier = RandomTreeclassifier;

    } else if (algo.equals("REPTree")) {

        // Use REPTree classifiers
        REPTree REPTreeclassifier = new REPTree();
        REPTreeclassifier.buildClassifier(data);
        modelStr = REPTreeclassifier.toString();
        classifier = REPTreeclassifier;

    } else if (algo.equals("SimpleCart")) {

        // Use SimpleCart classifiers
        SimpleCart SimpleCartclassifier = new SimpleCart();
        SimpleCartclassifier.buildClassifier(data);
        modelStr = SimpleCartclassifier.toString();
        classifier = SimpleCartclassifier;

    }

    modelResultSet[0] = algo;
    modelResultSet[1] = decisionAttribute;
    modelResultSet[2] = modelStr;

    // Collect every group of predictions for J48 model in a FastVector
    FastVector predictions = new FastVector();

    Evaluation evaluation = new Evaluation(data);
    int folds = 10; // cross fold validation = 10
    evaluation.crossValidateModel(classifier, data, folds, new Random(1));
    // System.out.println("Evaluatuion"+evaluation.toSummaryString());
    System.out.println("\n\n" + datanature + " Evaluatuion " + evaluation.toMatrixString());

    // ArrayList<Prediction> predictions = evaluation.predictions();
    predictions.appendElements(evaluation.predictions());

    System.out.println("\n\n 11111");
    // Calculate overall accuracy of current classifier on all splits
    double correct = 0;

    for (int i = 0; i < predictions.size(); i++) {
        NominalPrediction np = (NominalPrediction) predictions.elementAt(i);
        if (np.predicted() == np.actual()) {
            correct++;
        }
    }

    System.out.println("\n\n 22222");
    double accuracy = 100 * correct / predictions.size();
    String accString = String.format("%.2f%%", accuracy);
    modelResultSet[3] = accString;
    System.out.println(datanature + " Accuracy " + accString);

    String modelFileName = algo + "-DDKA.model";

    System.out.println("\n\n 33333");

    ObjectOutputStream oos = new ObjectOutputStream(
            new FileOutputStream("D:\\DDKAResources\\" + modelFileName));
    oos.writeObject(classifier);
    oos.flush();
    oos.close();

    return modelResultSet;

}

From source file:SpamDetector.SpamDetector.java

/**
 * Builds a SimpleCart tree from the spam / non-spam CSV data and prints two
 * evaluations of it: one on the training data itself and one from 10-fold
 * cross-validation.
 *
 * <p>Reads {@code notspam.csv} and {@code spam.csv} via {@code processCSV}, hands both
 * to {@code FeatureExtraction.generateArff}, then loads {@code data.arff} and uses its
 * last attribute as the class.
 *
 * @param args the command line arguments (unused)
 * @throws Exception if reading the input files, training or evaluation fails
 */
public static void main(String[] args) throws Exception {
    ArrayList<ArrayList<String>> notSpam = processCSV("notspam.csv");
    ArrayList<ArrayList<String>> spam = processCSV("spam.csv");

    // Generate the attributes and data.
    // NOTE(review): presumably this is what writes data.arff read below - verify.
    FeatureExtraction fe = new FeatureExtraction();
    fe.generateArff(spam, notSpam);

    // Load the ARFF file; the reader is closed as soon as the data is in memory.
    Instances data;
    try (BufferedReader br = new BufferedReader(new FileReader("data.arff"))) {
        ArffReader arff = new ArffReader(br);
        data = arff.getData();
    }
    data.setClassIndex(data.numAttributes() - 1);

    // Train CART on the full data set.
    SimpleCart tree = new SimpleCart();
    tree.buildClassifier(data);
    System.out.println(tree.toString());

    // Evaluation on the training data.
    Evaluation trainingEval = new Evaluation(data);
    trainingEval.evaluateModel(tree, data);
    System.out.println(trainingEval.toSummaryString("\n\n\n\nResults\n======\n", false));

    // Use a separate Evaluation for cross-validation: an Evaluation accumulates
    // statistics across calls, so reusing the one above would mix the training-set
    // predictions into the 10-fold summary.
    Evaluation crossValidationEval = new Evaluation(data);
    crossValidationEval.crossValidateModel(tree, data, 10, new Random());
    System.out.println(crossValidationEval.toSummaryString("\n\n\n\n10-Fold\n======\n", false));

}

From source file:util.FeatureExtract.java

/**
 * Converts a directory of text files into a word-vector data set, saves it as an
 * ARFF file, then trains a SimpleCart classifier on it and prints the model and a
 * 10-fold cross-validation summary.
 *
 * <p>Any exception raised along the way is logged at SEVERE level and not rethrown.
 *
 * @param directory path of the directory handed to the text directory loader
 */
public static void createArff(String directory) {
    TextDirectoryLoader dirLoader = new TextDirectoryLoader();
    try {
        // Step 1: load the directory as a raw data set.
        File sourceDir = new File(directory);
        dirLoader.setDirectory(sourceDir);
        Instances rawInstances = dirLoader.getDataSet();

        // Step 2: turn the text into word vectors with the IDF transform enabled.
        StringToWordVector wordVectorFilter = new StringToWordVector();
        wordVectorFilter.setIDFTransform(true);
        wordVectorFilter.setInputFormat(rawInstances);
        Instances vectorized = Filter.useFilter(rawInstances, wordVectorFilter);

        // Step 3: write the filtered data to the configured ARFF path.
        File arffTarget = new File(SpamFilterConfig.getArffFilePath());
        ArffSaver arffWriter = new ArffSaver();
        arffWriter.setInstances(vectorized);
        arffWriter.setFile(arffTarget);
        arffWriter.writeBatch();

        // Step 4: fit a SimpleCart tree on the filtered data and show it.
        SimpleCart cart = new SimpleCart();
        cart.buildClassifier(vectorized);
        String modelText = cart.toString();
        System.out.println("\n\nClassifier model:\n\n" + modelText);

        // Step 5: 10-fold cross-validation with a fixed seed of 1.
        Evaluation crossValidation = new Evaluation(vectorized);
        crossValidation.crossValidateModel(cart, vectorized, 10, new Random(1));
        String summary = crossValidation.toSummaryString();
        System.out.println("\n\nCross fold:\n\n" + summary);
    } catch (Exception e) {
        Logger logger = Logger.getLogger(FeatureExtract.class.getName());
        logger.log(Level.SEVERE, null, e);
    }
}