List of usage examples for the weka.classifiers.trees.SimpleCart constructor SimpleCart()
SimpleCart
From source file:org.uclab.mm.kcl.ddkat.modellearner.ModelLearner.java
License:Apache License
/** * Method to compute the classification accuracy. * * @param algo the algorithm name/*from w ww. j av a 2 s.co m*/ * @param data the data instances * @param datanature the dataset nature (i.e. original or processed data) * @throws Exception the exception */ protected String[] modelAccuracy(String algo, Instances data, String datanature) throws Exception { String modelResultSet[] = new String[4]; String modelStr = ""; Classifier classifier = null; // setting class attribute if the data format does not provide this information if (data.classIndex() == -1) data.setClassIndex(data.numAttributes() - 1); String decisionAttribute = data.attribute(data.numAttributes() - 1).toString(); String res[] = decisionAttribute.split("\\s+"); decisionAttribute = res[1]; if (algo.equals("BFTree")) { // Use BFTree classifiers BFTree BFTreeclassifier = new BFTree(); BFTreeclassifier.buildClassifier(data); modelStr = BFTreeclassifier.toString(); classifier = BFTreeclassifier; } else if (algo.equals("FT")) { // Use FT classifiers FT FTclassifier = new FT(); FTclassifier.buildClassifier(data); modelStr = FTclassifier.toString(); classifier = FTclassifier; } else if (algo.equals("J48")) { // Use J48 classifiers J48 J48classifier = new J48(); J48classifier.buildClassifier(data); modelStr = J48classifier.toString(); classifier = J48classifier; System.out.println("Model String: " + modelStr); } else if (algo.equals("J48graft")) { // Use J48graft classifiers J48graft J48graftclassifier = new J48graft(); J48graftclassifier.buildClassifier(data); modelStr = J48graftclassifier.toString(); classifier = J48graftclassifier; } else if (algo.equals("RandomTree")) { // Use RandomTree classifiers RandomTree RandomTreeclassifier = new RandomTree(); RandomTreeclassifier.buildClassifier(data); modelStr = RandomTreeclassifier.toString(); classifier = RandomTreeclassifier; } else if (algo.equals("REPTree")) { // Use REPTree classifiers REPTree REPTreeclassifier = new REPTree(); 
REPTreeclassifier.buildClassifier(data); modelStr = REPTreeclassifier.toString(); classifier = REPTreeclassifier; } else if (algo.equals("SimpleCart")) { // Use SimpleCart classifiers SimpleCart SimpleCartclassifier = new SimpleCart(); SimpleCartclassifier.buildClassifier(data); modelStr = SimpleCartclassifier.toString(); classifier = SimpleCartclassifier; } modelResultSet[0] = algo; modelResultSet[1] = decisionAttribute; modelResultSet[2] = modelStr; // Collect every group of predictions for J48 model in a FastVector FastVector predictions = new FastVector(); Evaluation evaluation = new Evaluation(data); int folds = 10; // cross fold validation = 10 evaluation.crossValidateModel(classifier, data, folds, new Random(1)); // System.out.println("Evaluatuion"+evaluation.toSummaryString()); System.out.println("\n\n" + datanature + " Evaluatuion " + evaluation.toMatrixString()); // ArrayList<Prediction> predictions = evaluation.predictions(); predictions.appendElements(evaluation.predictions()); System.out.println("\n\n 11111"); // Calculate overall accuracy of current classifier on all splits double correct = 0; for (int i = 0; i < predictions.size(); i++) { NominalPrediction np = (NominalPrediction) predictions.elementAt(i); if (np.predicted() == np.actual()) { correct++; } } System.out.println("\n\n 22222"); double accuracy = 100 * correct / predictions.size(); String accString = String.format("%.2f%%", accuracy); modelResultSet[3] = accString; System.out.println(datanature + " Accuracy " + accString); String modelFileName = algo + "-DDKA.model"; System.out.println("\n\n 33333"); ObjectOutputStream oos = new ObjectOutputStream( new FileOutputStream("D:\\DDKAResources\\" + modelFileName)); oos.writeObject(classifier); oos.flush(); oos.close(); return modelResultSet; }
From source file:SpamDetector.SpamDetector.java
/** * @param args the command line arguments *//* ww w . ja v a 2 s .co m*/ public static void main(String[] args) throws IOException, Exception { ArrayList<ArrayList<String>> notSpam = processCSV("notspam.csv"); ArrayList<ArrayList<String>> spam = processCSV("spam.csv"); // Cobain generate attribute & data FeatureExtraction fe = new FeatureExtraction(); fe.generateArff(spam, notSpam); // Cobain CART BufferedReader br = new BufferedReader(new FileReader("data.arff")); ArffReader arff = new ArffReader(br); Instances data = arff.getData(); data.setClassIndex(data.numAttributes() - 1); SimpleCart tree = new SimpleCart(); tree.buildClassifier(data); System.out.println(tree.toString()); Evaluation eval = new Evaluation(data); eval.evaluateModel(tree, data); System.out.println(eval.toSummaryString("\n\n\n\nResults\n======\n", false)); eval.crossValidateModel(tree, data, 10, new Random()); System.out.println(eval.toSummaryString("\n\n\n\n10-Fold\n======\n", false)); }
From source file:util.FeatureExtract.java
public static void createArff(String directory) { TextDirectoryLoader loader = new TextDirectoryLoader(); try {/*www . j a va2 s . c om*/ // convert the directory into a dataset loader.setDirectory(new File(directory)); Instances dataRaw = loader.getDataSet(); // apply the StringToWordVector and tf-idf weighting StringToWordVector filter = new StringToWordVector(); filter.setIDFTransform(true); filter.setInputFormat(dataRaw); Instances dataFiltered = Filter.useFilter(dataRaw, filter); // output the arff file ArffSaver saver = new ArffSaver(); saver.setInstances(dataFiltered); saver.setFile(new File(SpamFilterConfig.getArffFilePath())); saver.writeBatch(); // train with simple cart SimpleCart classifier = new SimpleCart(); classifier.buildClassifier(dataFiltered); System.out.println("\n\nClassifier model:\n\n" + classifier.toString()); // using 10 cross validation Evaluation eval = new Evaluation(dataFiltered); eval.crossValidateModel(classifier, dataFiltered, 10, new Random(1)); System.out.println("\n\nCross fold:\n\n" + eval.toSummaryString()); } catch (Exception ex) { Logger.getLogger(FeatureExtract.class.getName()).log(Level.SEVERE, null, ex); } }