Example usage for weka.classifiers.lazy IBk setCrossValidate

Introduction

On this page you can find example usage for weka.classifiers.lazy IBk setCrossValidate, collected from open source projects.

Prototype

public void setCrossValidate(boolean newCrossValidate) 

Document

Sets whether hold-one-out cross-validation will be used to select the best k value.
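
For illustration, here is a minimal sketch of how this setter is typically combined with setKNN (the file name train.arff and the upper bound of 20 are assumed values, not taken from the examples below). With cross-validation enabled, IBk treats the configured k as an upper bound and selects the best number of neighbours between 1 and that value using hold-one-out evaluation on the training data.

import weka.classifiers.lazy.IBk;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class IBkCrossValidateSketch {
    public static void main(String[] args) throws Exception {
        // hypothetical ARFF file; any data set with a nominal class attribute works
        Instances train = DataSource.read("train.arff");
        train.setClassIndex(train.numAttributes() - 1);

        IBk knn = new IBk();
        knn.setKNN(20);             // assumed upper bound for the k search
        knn.setCrossValidate(true); // pick the best k in [1, 20] by hold-one-out CV
        knn.buildClassifier(train);

        // IBk is lazy: the cross-validated choice of k is made when the
        // first prediction is requested
        double firstPrediction = knn.classifyInstance(train.instance(0));
        System.out.println("predicted class index: " + firstPrediction);
    }
}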

Usage

From source file: com.edwardraff.WekaMNIST.java

License: Open Source License

public static void main(String[] args) throws IOException, Exception {
    String folder = args[0];
    String trainPath = folder + "MNISTtrain.arff";
    String testPath = folder + "MNISTtest.arff";

    System.out.println("Weka Timings");
    Instances mnistTrainWeka = new Instances(new BufferedReader(new FileReader(new File(trainPath))));
    mnistTrainWeka.setClassIndex(mnistTrainWeka.numAttributes() - 1);
    Instances mnistTestWeka = new Instances(new BufferedReader(new FileReader(new File(testPath))));
    mnistTestWeka.setClassIndex(mnistTestWeka.numAttributes() - 1);

    // normalize attribute ranges into [0, 1]
    Normalize normalizeFilter = new Normalize();
    normalizeFilter.setInputFormat(mnistTrainWeka);

    // filter the training set first so the scaling statistics come from the training data
    mnistTrainWeka = Filter.useFilter(mnistTrainWeka, normalizeFilter);
    mnistTestWeka = Filter.useFilter(mnistTestWeka, normalizeFilter);

    long start, end;

    System.out.println("RBF SVM (Full Cache)");
    SMO smo = new SMO();
    smo.setKernel(new RBFKernel(mnistTrainWeka, 0/*0 causes Weka to cache the whole matrix...*/, 0.015625));
    smo.setC(8.0);
    smo.setBuildLogisticModels(false);
    evalModel(smo, mnistTrainWeka, mnistTestWeka);

    System.out.println("RBF SVM (No Cache)");
    smo = new SMO();
    smo.setKernel(new RBFKernel(mnistTrainWeka, 1, 0.015625));
    smo.setC(8.0);
    smo.setBuildLogisticModels(false);
    evalModel(smo, mnistTrainWeka, mnistTestWeka);

    System.out.println("Decision Tree C45");
    J48 wekaC45 = new J48();
    wekaC45.setUseLaplace(false);
    wekaC45.setCollapseTree(false);
    wekaC45.setUnpruned(true);
    wekaC45.setMinNumObj(2);
    wekaC45.setUseMDLcorrection(true);

    evalModel(wekaC45, mnistTrainWeka, mnistTestWeka);

    System.out.println("Random Forest 50 trees");
    int featuresToUse = (int) Math.sqrt(28 * 28);//Weka uses different defaults, so let's make sure both use the published value

    RandomForest wekaRF = new RandomForest();
    wekaRF.setNumExecutionSlots(1);
    wekaRF.setMaxDepth(0/*0 for unlimited*/);
    wekaRF.setNumFeatures(featuresToUse);
    wekaRF.setNumTrees(50);

    evalModel(wekaRF, mnistTrainWeka, mnistTestWeka);

    System.out.println("1-NN (brute)");
    IBk wekaNN = new IBk(1);
    wekaNN.setNearestNeighbourSearchAlgorithm(new LinearNNSearch());
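    // k is fixed at 1, so there is nothing to select; leave cross-validation off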
    wekaNN.setCrossValidate(false);

    evalModel(wekaNN, mnistTrainWeka, mnistTestWeka);

    System.out.println("1-NN (Ball Tree)");
    wekaNN = new IBk(1);
    wekaNN.setNearestNeighbourSearchAlgorithm(new BallTree());
    wekaNN.setCrossValidate(false);

    evalModel(wekaNN, mnistTrainWeka, mnistTestWeka);

    System.out.println("1-NN (Cover Tree)");
    wekaNN = new IBk(1);
    wekaNN.setNearestNeighbourSearchAlgorithm(new CoverTree());
    wekaNN.setCrossValidate(false);

    evalModel(wekaNN, mnistTrainWeka, mnistTestWeka);

    System.out.println("Logistic Regression LBFGS lambda = 1e-4");
    Logistic logisticLBFGS = new Logistic();
    logisticLBFGS.setRidge(1e-4);
    logisticLBFGS.setMaxIts(500);

    evalModel(logisticLBFGS, mnistTrainWeka, mnistTestWeka);

    System.out.println("k-means (Loyd)");
    int origClassIndex = mnistTrainWeka.classIndex();
    mnistTrainWeka.setClassIndex(-1);
    mnistTrainWeka.deleteAttributeAt(origClassIndex);
    {
        long totalTime = 0;
        for (int i = 0; i < 10; i++) {
            SimpleKMeans wekaKMeans = new SimpleKMeans();
            wekaKMeans.setNumClusters(10);
            wekaKMeans.setNumExecutionSlots(1);
            wekaKMeans.setFastDistanceCalc(true);

            start = System.currentTimeMillis();
            wekaKMeans.buildClusterer(mnistTrainWeka);
            end = System.currentTimeMillis();
            totalTime += (end - start);
        }
        System.out.println("\tClustering took: " + (totalTime / 10.0) / 1000.0 + " on average");
    }
}

From source file: development.GoodHonoursPrediction.java

public static void main(String[] args) {
    Instances data = ClassifierTools.loadData("C:\\Admin\\Perfomance Analysis\\GoodHonsClassification");
    RandomForest rf = new RandomForest();
    double[][] a = ClassifierTools.crossValidationWithStats(rf, data, data.numInstances());
    System.out.println(" Random forest LOOCV accuracy =" + a[0][0]);
    J48 tree = new J48();
    a = ClassifierTools.crossValidationWithStats(tree, data, data.numInstances());
    System.out.println(" C4.5 LOOCV accuracy =" + a[0][0]);
    IBk knn = new IBk(11);
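    // with cross-validation enabled, IBk picks the best k in [1, 11] by hold-one-out CV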
    knn.setCrossValidate(true);
    a = ClassifierTools.crossValidationWithStats(knn, data, data.numInstances());
    System.out.println(" KNN LOOCV accuracy =" + a[0][0]);
    NaiveBayes nb = new NaiveBayes();
    a = ClassifierTools.crossValidationWithStats(nb, data, data.numInstances());
    System.out.println(" Naive Bayes LOOCV accuracy =" + a[0][0]);

    /*
    try {
        tree.buildClassifier(data);
        System.out.println(" Tree = " + tree);
        Classifier cls = new J48();
        Evaluation eval = new Evaluation(data);
        Random rand = new Random(1); // using seed = 1
        int folds = data.numInstances();
        eval.crossValidateModel(cls, data, folds, rand);
        System.out.println(eval.toSummaryString());

        tree.getTechnicalInformation();
    } catch (Exception ex) {
        Logger.getLogger(GoodHonoursPrediction.class.getName()).log(Level.SEVERE, null, ex);
    }
    */

}