List of usage examples for the weka.classifiers.trees.J48graft constructor J48graft()
J48graft
From source file: org.uclab.mm.kcl.ddkat.modellearner.ModelLearner.java
License: Apache License
/** * Method to compute the classification accuracy. * * @param algo the algorithm name/*from w ww.ja v a 2 s . c om*/ * @param data the data instances * @param datanature the dataset nature (i.e. original or processed data) * @throws Exception the exception */ protected String[] modelAccuracy(String algo, Instances data, String datanature) throws Exception { String modelResultSet[] = new String[4]; String modelStr = ""; Classifier classifier = null; // setting class attribute if the data format does not provide this information if (data.classIndex() == -1) data.setClassIndex(data.numAttributes() - 1); String decisionAttribute = data.attribute(data.numAttributes() - 1).toString(); String res[] = decisionAttribute.split("\\s+"); decisionAttribute = res[1]; if (algo.equals("BFTree")) { // Use BFTree classifiers BFTree BFTreeclassifier = new BFTree(); BFTreeclassifier.buildClassifier(data); modelStr = BFTreeclassifier.toString(); classifier = BFTreeclassifier; } else if (algo.equals("FT")) { // Use FT classifiers FT FTclassifier = new FT(); FTclassifier.buildClassifier(data); modelStr = FTclassifier.toString(); classifier = FTclassifier; } else if (algo.equals("J48")) { // Use J48 classifiers J48 J48classifier = new J48(); J48classifier.buildClassifier(data); modelStr = J48classifier.toString(); classifier = J48classifier; System.out.println("Model String: " + modelStr); } else if (algo.equals("J48graft")) { // Use J48graft classifiers J48graft J48graftclassifier = new J48graft(); J48graftclassifier.buildClassifier(data); modelStr = J48graftclassifier.toString(); classifier = J48graftclassifier; } else if (algo.equals("RandomTree")) { // Use RandomTree classifiers RandomTree RandomTreeclassifier = new RandomTree(); RandomTreeclassifier.buildClassifier(data); modelStr = RandomTreeclassifier.toString(); classifier = RandomTreeclassifier; } else if (algo.equals("REPTree")) { // Use REPTree classifiers REPTree REPTreeclassifier = new REPTree(); 
REPTreeclassifier.buildClassifier(data); modelStr = REPTreeclassifier.toString(); classifier = REPTreeclassifier; } else if (algo.equals("SimpleCart")) { // Use SimpleCart classifiers SimpleCart SimpleCartclassifier = new SimpleCart(); SimpleCartclassifier.buildClassifier(data); modelStr = SimpleCartclassifier.toString(); classifier = SimpleCartclassifier; } modelResultSet[0] = algo; modelResultSet[1] = decisionAttribute; modelResultSet[2] = modelStr; // Collect every group of predictions for J48 model in a FastVector FastVector predictions = new FastVector(); Evaluation evaluation = new Evaluation(data); int folds = 10; // cross fold validation = 10 evaluation.crossValidateModel(classifier, data, folds, new Random(1)); // System.out.println("Evaluatuion"+evaluation.toSummaryString()); System.out.println("\n\n" + datanature + " Evaluatuion " + evaluation.toMatrixString()); // ArrayList<Prediction> predictions = evaluation.predictions(); predictions.appendElements(evaluation.predictions()); System.out.println("\n\n 11111"); // Calculate overall accuracy of current classifier on all splits double correct = 0; for (int i = 0; i < predictions.size(); i++) { NominalPrediction np = (NominalPrediction) predictions.elementAt(i); if (np.predicted() == np.actual()) { correct++; } } System.out.println("\n\n 22222"); double accuracy = 100 * correct / predictions.size(); String accString = String.format("%.2f%%", accuracy); modelResultSet[3] = accString; System.out.println(datanature + " Accuracy " + accString); String modelFileName = algo + "-DDKA.model"; System.out.println("\n\n 33333"); ObjectOutputStream oos = new ObjectOutputStream( new FileOutputStream("D:\\DDKAResources\\" + modelFileName)); oos.writeObject(classifier); oos.flush(); oos.close(); return modelResultSet; }