Usage examples for weka.classifiers.lazy.IBk.setCrossValidate
public void setCrossValidate(boolean newCrossValidate)
From source file: com.edwardraff.WekaMNIST.java
License: Open Source License
public static void main(String[] args) throws IOException, Exception { String folder = args[0];// ww w. j av a2 s. c o m String trainPath = folder + "MNISTtrain.arff"; String testPath = folder + "MNISTtest.arff"; System.out.println("Weka Timings"); Instances mnistTrainWeka = new Instances(new BufferedReader(new FileReader(new File(trainPath)))); mnistTrainWeka.setClassIndex(mnistTrainWeka.numAttributes() - 1); Instances mnistTestWeka = new Instances(new BufferedReader(new FileReader(new File(testPath)))); mnistTestWeka.setClassIndex(mnistTestWeka.numAttributes() - 1); //normalize range like into [0, 1] Normalize normalizeFilter = new Normalize(); normalizeFilter.setInputFormat(mnistTrainWeka); mnistTestWeka = Normalize.useFilter(mnistTestWeka, normalizeFilter); mnistTrainWeka = Normalize.useFilter(mnistTrainWeka, normalizeFilter); long start, end; System.out.println("RBF SVM (Full Cache)"); SMO smo = new SMO(); smo.setKernel(new RBFKernel(mnistTrainWeka, 0/*0 causes Weka to cache the whole matrix...*/, 0.015625)); smo.setC(8.0); smo.setBuildLogisticModels(false); evalModel(smo, mnistTrainWeka, mnistTestWeka); System.out.println("RBF SVM (No Cache)"); smo = new SMO(); smo.setKernel(new RBFKernel(mnistTrainWeka, 1, 0.015625)); smo.setC(8.0); smo.setBuildLogisticModels(false); evalModel(smo, mnistTrainWeka, mnistTestWeka); System.out.println("Decision Tree C45"); J48 wekaC45 = new J48(); wekaC45.setUseLaplace(false); wekaC45.setCollapseTree(false); wekaC45.setUnpruned(true); wekaC45.setMinNumObj(2); wekaC45.setUseMDLcorrection(true); evalModel(wekaC45, mnistTrainWeka, mnistTestWeka); System.out.println("Random Forest 50 trees"); int featuresToUse = (int) Math.sqrt(28 * 28);//Weka uses different defaults, so lets make sure they both use the published way RandomForest wekaRF = new RandomForest(); wekaRF.setNumExecutionSlots(1); wekaRF.setMaxDepth(0/*0 for unlimited*/); wekaRF.setNumFeatures(featuresToUse); wekaRF.setNumTrees(50); evalModel(wekaRF, mnistTrainWeka, 
mnistTestWeka); System.out.println("1-NN (brute)"); IBk wekaNN = new IBk(1); wekaNN.setNearestNeighbourSearchAlgorithm(new LinearNNSearch()); wekaNN.setCrossValidate(false); evalModel(wekaNN, mnistTrainWeka, mnistTestWeka); System.out.println("1-NN (Ball Tree)"); wekaNN = new IBk(1); wekaNN.setNearestNeighbourSearchAlgorithm(new BallTree()); wekaNN.setCrossValidate(false); evalModel(wekaNN, mnistTrainWeka, mnistTestWeka); System.out.println("1-NN (Cover Tree)"); wekaNN = new IBk(1); wekaNN.setNearestNeighbourSearchAlgorithm(new CoverTree()); wekaNN.setCrossValidate(false); evalModel(wekaNN, mnistTrainWeka, mnistTestWeka); System.out.println("Logistic Regression LBFGS lambda = 1e-4"); Logistic logisticLBFGS = new Logistic(); logisticLBFGS.setRidge(1e-4); logisticLBFGS.setMaxIts(500); evalModel(logisticLBFGS, mnistTrainWeka, mnistTestWeka); System.out.println("k-means (Loyd)"); int origClassIndex = mnistTrainWeka.classIndex(); mnistTrainWeka.setClassIndex(-1); mnistTrainWeka.deleteAttributeAt(origClassIndex); { long totalTime = 0; for (int i = 0; i < 10; i++) { SimpleKMeans wekaKMeans = new SimpleKMeans(); wekaKMeans.setNumClusters(10); wekaKMeans.setNumExecutionSlots(1); wekaKMeans.setFastDistanceCalc(true); start = System.currentTimeMillis(); wekaKMeans.buildClusterer(mnistTrainWeka); end = System.currentTimeMillis(); totalTime += (end - start); } System.out.println("\tClustering took: " + (totalTime / 10.0) / 1000.0 + " on average"); } }
From source file: development.GoodHonoursPrediction.java
/**
 * Prints leave-one-out cross-validation (LOOCV) accuracy of several classifiers on the
 * good-honours dataset.
 *
 * @param args optional; when present, {@code args[0]} replaces the default dataset path
 */
public static void main(String[] args) {
    String dataPath = args.length > 0
            ? args[0]
            : "C:\\Admin\\Perfomance Analysis\\GoodHonsClassification";
    Instances data = ClassifierTools.loadData(dataPath);
    // One fold per instance means each fold holds out a single instance, i.e. LOOCV.
    int folds = data.numInstances();

    RandomForest rf = new RandomForest();
    double[][] a = ClassifierTools.crossValidationWithStats(rf, data, folds);
    System.out.println(" Random forest LOOCV accuracy =" + a[0][0]);

    J48 tree = new J48();
    a = ClassifierTools.crossValidationWithStats(tree, data, folds);
    System.out.println(" C4.5 LOOCV accuracy =" + a[0][0]);

    IBk knn = new IBk(11);
    // Let IBk choose k (up to 11) by hold-one-out cross-validation on its training data.
    knn.setCrossValidate(true);
    a = ClassifierTools.crossValidationWithStats(knn, data, folds);
    System.out.println(" KNN LOOCV accuracy =" + a[0][0]);

    NaiveBayes nb = new NaiveBayes();
    a = ClassifierTools.crossValidationWithStats(nb, data, folds);
    System.out.println(" Naive Bayes LOOCV accuracy =" + a[0][0]);
}