Implement different classifiers in order to get statistical summaries in weka - Java Machine Learning AI

Java examples for Machine Learning AI:weka

Description

Implement different classifiers in order to get statistical summaries in Weka

Demo Code


import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.PrintWriter;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.NominalPrediction;
import weka.classifiers.rules.*;
import weka.classifiers.lazy.*;
import weka.classifiers.functions.*;
import weka.classifiers.meta.*;
//import weka.classifiers.trees.DecisionStump;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.trees.*;
import weka.core.FastVector;
import weka.core.Instances;
//import libsvm.*;

public class algorithms {

  /**
   * Opens a buffered reader on the given data file.
   *
   * @param filename path of the ARFF file to read
   * @return a reader positioned at the start of the file, or {@code null}
   *         if the file does not exist (an error is printed to stderr)
   */
  public static BufferedReader readDataFile(String filename) {
    BufferedReader inputReader = null;

    try {
      inputReader = new BufferedReader(new FileReader(filename));
    } catch (FileNotFoundException ex) {
      System.err.println("File not found: " + filename);
    }

    return inputReader;
  }

  /**
   * Trains {@code model} on {@code trainingSet} and evaluates it on
   * {@code testingSet}.
   *
   * @param model       the classifier to build and evaluate
   * @param trainingSet instances used to build the classifier
   * @param testingSet  held-out instances used for evaluation
   * @return the populated {@link Evaluation} for this train/test pair
   * @throws Exception if Weka fails to build or evaluate the model
   */
  public static Evaluation classify(Classifier model, Instances trainingSet, Instances testingSet) throws Exception {
    Evaluation evaluation = new Evaluation(trainingSet);

    model.buildClassifier(trainingSet);
    evaluation.evaluateModel(model, testingSet);

    return evaluation;
  }

  /**
   * Computes the percentage of predictions whose predicted class index
   * equals the actual class index.
   *
   * @param predictions accumulated {@link NominalPrediction} elements
   * @return accuracy in percent, or 0.0 for an empty prediction list
   *         (guards against a divide-by-zero producing NaN)
   */
  public static double calculateAccuracy(FastVector predictions) {
    if (predictions.size() == 0) {
      return 0.0;
    }

    double correct = 0;
    for (int i = 0; i < predictions.size(); i++) {
      NominalPrediction np = (NominalPrediction) predictions.elementAt(i);
      if (np.predicted() == np.actual()) {
        correct++;
      }
    }

    return 100 * correct / predictions.size();
  }

  /**
   * Splits {@code data} into stratified cross-validation folds.
   *
   * @param data          the full dataset
   * @param numberOfFolds number of CV folds
   * @return a 2 x numberOfFolds array: row 0 holds the training sets,
   *         row 1 the matching test sets
   */
  public static Instances[][] crossValidationSplit(Instances data, int numberOfFolds) {
    Instances[][] split = new Instances[2][numberOfFolds];

    for (int i = 0; i < numberOfFolds; i++) {
      split[0][i] = data.trainCV(numberOfFolds, i);
      split[1][i] = data.testCV(numberOfFolds, i);
    }

    return split;
  }

  /**
   * Runs a battery of Weka classifiers over 10-fold cross-validation on a
   * single ARFF file and writes per-model averaged metrics to a CSV.
   *
   * @param args unused
   * @throws Exception propagated from Weka model building/evaluation
   */
  public static void main(String[] args) throws Exception {
    String outFile = "algorithmOutput.csv";
    String outPath = "c:\\data";

    String inFile = "auto93.arff";
    String inPath = "c:\\data\\numeric";

    String inStore = inPath + "\\" + inFile;
    String outStore = outPath + "\\" + outFile;

    BufferedReader datafile = readDataFile(inStore);
    if (datafile == null) {
      // readDataFile already reported the missing file; without it there
      // is nothing to do (the original code would have thrown an NPE here).
      return;
    }

    Instances data = new Instances(datafile);
    datafile.close();
    data.setClassIndex(data.numAttributes() - 1);

    Instances[][] split = crossValidationSplit(data, 10);

    Instances[] trainingSplits = split[0];
    Instances[] testingSplits = split[1];

    Classifier[] models = { new J48(), // a decision trees
        new DecisionStump(), // one-level decision tree
        new RandomForest(), new PART(), new DecisionTable(), // decision table majority classifier
        new JRip(), new ZeroR(),

        new IBk(), // instance based classifier used K nearest neighbor
        new KStar(), // instance based classifier using entropy based distance

        new NaiveBayes(),

        new AdaBoostM1(), new Bagging(), new Stacking(), new LogitBoost(), new RandomCommittee(),

        new Logistic(), new MultilayerPerceptron(), new SimpleLogistic(), // linear logistic regression models.
        new SMO(), new SMOreg(), // SMOreg implements the support vector machine for regression.
    };

    // try-with-resources guarantees the CSV is flushed and closed even if a
    // model throws mid-run (the original leaked the writer on that path).
    try (PrintWriter out = new PrintWriter(new FileWriter(outStore))) {
      // CSV header
      out.println("Accuracy,RMSE,Fscore,Kappa,PRC,AUC,Dataset,Algorithm");

      // Run for each model
      for (int j = 0; j < models.length; j++) {

        // Collect every group of predictions for the current model
        FastVector predictions = new FastVector();
        float avgRMSE = 0, avgKappa = 0, avgFscore = 0, avgPRC = 0, avgAUC = 0;

        // For each training-testing split pair, train and test the classifier
        for (int i = 0; i < trainingSplits.length; i++) {
          Evaluation validation = classify(models[j], trainingSplits[i], testingSplits[i]);

          predictions.appendElements(validation.predictions());
          // BUG FIX: the original only accumulated RMSE, so the CSV always
          // reported 0.00 for F-score, kappa, PRC and AUC. Accumulate every
          // metric that the header promises.
          avgRMSE += validation.rootMeanSquaredError();
          avgFscore += validation.weightedFMeasure();
          avgKappa += validation.kappa();
          avgPRC += validation.weightedAreaUnderPRC();
          avgAUC += validation.weightedAreaUnderROC();
          System.out.println(validation.toClassDetailsString());
        }

        // Average each metric across the folds
        int folds = trainingSplits.length;
        avgRMSE /= folds;
        avgFscore /= folds;
        avgKappa /= folds;
        avgPRC /= folds;
        avgAUC /= folds;

        double accuracy = calculateAccuracy(predictions);

        System.out.println(" RMSE = " + String.format("%.2f", avgRMSE));

        // One CSV row per model, matching the header column order
        out.println(String.format("%.2f,%.2f,%.2f,%.2f,%.2f,%.2f,%s,%s",
            accuracy, avgRMSE, avgFscore, avgKappa, avgPRC, avgAUC,
            inFile, models[j].getClass().getSimpleName()));

        System.out.println("Accuracy of " + models[j].getClass().getSimpleName() + ": "
            + String.format("%.2f%%", accuracy) + "\n---------------------------------");
      }
    }

    System.out.println(" Number of Classes = " + data.numClasses());
    System.out.println(" Number of Attributes = " + (data.numAttributes() - 1));

  }
}

Related Tutorials