List of usage examples for weka.classifiers.lazy.IBk#setNearestNeighbourSearchAlgorithm
public void setNearestNeighbourSearchAlgorithm(NearestNeighbourSearch nearestNeighbourSearchAlgorithm)
From source file:au.edu.usyd.it.yangpy.snp.Ensemble.java
License:Open Source License
public void ensemble(String mode) throws Exception { numInstances = test.numInstances();/*from ww w .ja v a 2 s. c o m*/ numClasses = test.numClasses(); givenValue = new double[numInstances]; predictDistribution = new double[numClassifiers][numInstances][numClasses]; predictValue = new double[numClassifiers][numInstances]; voteValue = new double[numInstances][numClasses]; // Setting the given class values of the test instances. for (int i = 0; i < numInstances; i++) { givenValue[i] = test.instance(i).classValue(); } // Calculating the predicted class values using each classifier respectively. // J48 coverTree1NN KStar coverTree3NN coverTree5NN J48 tree = new J48(); tree.setUnpruned(true); aucClassifiers[0] = classify(tree, 0); KStar kstar = new KStar(); aucClassifiers[1] = classify(kstar, 1); IBk ctnn1 = new IBk(1); CoverTree search = new CoverTree(); ctnn1.setNearestNeighbourSearchAlgorithm(search); aucClassifiers[2] = classify(ctnn1, 2); IBk ctnn3 = new IBk(3); ctnn3.setNearestNeighbourSearchAlgorithm(search); aucClassifiers[3] = classify(ctnn3, 3); IBk ctnn5 = new IBk(5); ctnn5.setNearestNeighbourSearchAlgorithm(search); aucClassifiers[4] = classify(ctnn5, 4); // Print the classification results if in print mode. if (mode.equals("v")) { System.out.println("J48 AUC: " + aucClassifiers[0]); System.out.println("KStar AUC: " + aucClassifiers[1]); System.out.println("CTNN1 AUC: " + aucClassifiers[2]); System.out.println("CTNN3 AUC: " + aucClassifiers[3]); System.out.println("CTNN5 AUC: " + aucClassifiers[4]); System.out.println(" - - "); } }
From source file:com.edwardraff.WekaMNIST.java
License:Open Source License
public static void main(String[] args) throws IOException, Exception { String folder = args[0];/* w w w .j a v a 2 s . com*/ String trainPath = folder + "MNISTtrain.arff"; String testPath = folder + "MNISTtest.arff"; System.out.println("Weka Timings"); Instances mnistTrainWeka = new Instances(new BufferedReader(new FileReader(new File(trainPath)))); mnistTrainWeka.setClassIndex(mnistTrainWeka.numAttributes() - 1); Instances mnistTestWeka = new Instances(new BufferedReader(new FileReader(new File(testPath)))); mnistTestWeka.setClassIndex(mnistTestWeka.numAttributes() - 1); //normalize range like into [0, 1] Normalize normalizeFilter = new Normalize(); normalizeFilter.setInputFormat(mnistTrainWeka); mnistTestWeka = Normalize.useFilter(mnistTestWeka, normalizeFilter); mnistTrainWeka = Normalize.useFilter(mnistTrainWeka, normalizeFilter); long start, end; System.out.println("RBF SVM (Full Cache)"); SMO smo = new SMO(); smo.setKernel(new RBFKernel(mnistTrainWeka, 0/*0 causes Weka to cache the whole matrix...*/, 0.015625)); smo.setC(8.0); smo.setBuildLogisticModels(false); evalModel(smo, mnistTrainWeka, mnistTestWeka); System.out.println("RBF SVM (No Cache)"); smo = new SMO(); smo.setKernel(new RBFKernel(mnistTrainWeka, 1, 0.015625)); smo.setC(8.0); smo.setBuildLogisticModels(false); evalModel(smo, mnistTrainWeka, mnistTestWeka); System.out.println("Decision Tree C45"); J48 wekaC45 = new J48(); wekaC45.setUseLaplace(false); wekaC45.setCollapseTree(false); wekaC45.setUnpruned(true); wekaC45.setMinNumObj(2); wekaC45.setUseMDLcorrection(true); evalModel(wekaC45, mnistTrainWeka, mnistTestWeka); System.out.println("Random Forest 50 trees"); int featuresToUse = (int) Math.sqrt(28 * 28);//Weka uses different defaults, so lets make sure they both use the published way RandomForest wekaRF = new RandomForest(); wekaRF.setNumExecutionSlots(1); wekaRF.setMaxDepth(0/*0 for unlimited*/); wekaRF.setNumFeatures(featuresToUse); wekaRF.setNumTrees(50); evalModel(wekaRF, mnistTrainWeka, 
mnistTestWeka); System.out.println("1-NN (brute)"); IBk wekaNN = new IBk(1); wekaNN.setNearestNeighbourSearchAlgorithm(new LinearNNSearch()); wekaNN.setCrossValidate(false); evalModel(wekaNN, mnistTrainWeka, mnistTestWeka); System.out.println("1-NN (Ball Tree)"); wekaNN = new IBk(1); wekaNN.setNearestNeighbourSearchAlgorithm(new BallTree()); wekaNN.setCrossValidate(false); evalModel(wekaNN, mnistTrainWeka, mnistTestWeka); System.out.println("1-NN (Cover Tree)"); wekaNN = new IBk(1); wekaNN.setNearestNeighbourSearchAlgorithm(new CoverTree()); wekaNN.setCrossValidate(false); evalModel(wekaNN, mnistTrainWeka, mnistTestWeka); System.out.println("Logistic Regression LBFGS lambda = 1e-4"); Logistic logisticLBFGS = new Logistic(); logisticLBFGS.setRidge(1e-4); logisticLBFGS.setMaxIts(500); evalModel(logisticLBFGS, mnistTrainWeka, mnistTestWeka); System.out.println("k-means (Loyd)"); int origClassIndex = mnistTrainWeka.classIndex(); mnistTrainWeka.setClassIndex(-1); mnistTrainWeka.deleteAttributeAt(origClassIndex); { long totalTime = 0; for (int i = 0; i < 10; i++) { SimpleKMeans wekaKMeans = new SimpleKMeans(); wekaKMeans.setNumClusters(10); wekaKMeans.setNumExecutionSlots(1); wekaKMeans.setFastDistanceCalc(true); start = System.currentTimeMillis(); wekaKMeans.buildClusterer(mnistTrainWeka); end = System.currentTimeMillis(); totalTime += (end - start); } System.out.println("\tClustering took: " + (totalTime / 10.0) / 1000.0 + " on average"); } }
From source file:hurtowniedanych.FXMLController.java
public void trainAndTestKNN() throws FileNotFoundException, IOException, Exception { InstanceQuery instanceQuery = new InstanceQuery(); instanceQuery.setUsername("postgres"); instanceQuery.setPassword("szupek"); instanceQuery.setCustomPropsFile(new File("./src/data/DatabaseUtils.props")); // Wskazanie pliku z ustawieniami dla PostgreSQL String query = "select ks.wydawnictwo,ks.gatunek, kl.mia-sto\n" + "from zakupy z,ksiazki ks,klienci kl\n" + "where ks.id_ksiazka=z.id_ksiazka and kl.id_klient=z.id_klient"; instanceQuery.setQuery(query);/* w w w. j a va 2 s. c om*/ Instances data = instanceQuery.retrieveInstances(); data.setClassIndex(data.numAttributes() - 1); data.randomize(new Random()); double percent = 70.0; int trainSize = (int) Math.round(data.numInstances() * percent / 100); int testSize = data.numInstances() - trainSize; Instances trainData = new Instances(data, 0, trainSize); Instances testData = new Instances(data, trainSize, testSize); int lSasiadow = Integer.parseInt(textFieldKnn.getText()); System.out.println(lSasiadow); IBk ibk = new IBk(lSasiadow); // Ustawienie odleglosci EuclideanDistance euclidean = new EuclideanDistance(); // euklidesowej ManhattanDistance manhatan = new ManhattanDistance(); // miejska LinearNNSearch linearNN = new LinearNNSearch(); if (comboboxOdleglosc.getSelectionModel().getSelectedItem().equals("Manhatan")) { linearNN.setDistanceFunction(manhatan); } else { linearNN.setDistanceFunction(euclidean); } ibk.setNearestNeighbourSearchAlgorithm(linearNN); // ustawienie sposobu szukania sasiadow // Tworzenie klasyfikatora ibk.buildClassifier(trainData); Evaluation eval = new Evaluation(trainData); eval.evaluateModel(ibk, testData); spr.setVisible(true); labelKnn.setVisible(true); labelOdleglosc.setVisible(true); labelKnn.setText(textFieldKnn.getText()); labelOdleglosc.setText(comboboxOdleglosc.getSelectionModel().getSelectedItem().toString()); spr.setText(eval.toSummaryString("Wynik:", true)); }
From source file:jjj.asap.sas.models1.job.BuildCosineModels.java
License:Open Source License
@Override protected void run() throws Exception { // validate args if (!Bucket.isBucket("datasets", inputBucket)) { throw new FileNotFoundException(inputBucket); }// ww w . j a v a2 s. c o m if (!Bucket.isBucket("models", outputBucket)) { throw new FileNotFoundException(outputBucket); } // init multi-threading Job.startService(); final Queue<Future<Object>> queue = new LinkedList<Future<Object>>(); // get the input from the bucket List<String> names = Bucket.getBucketItems("datasets", this.inputBucket); for (String dsn : names) { int essaySet = Contest.getEssaySet(dsn); int k = -1; switch (essaySet) { case 3: k = 13; break; case 5: case 7: k = 55; break; case 2: case 6: case 10: k = 21; break; case 1: case 4: case 8: case 9: k = 34; break; } if (k == -1) { throw new IllegalArgumentException("not k defined for " + essaySet); } LinearNNSearch search = new LinearNNSearch(); search.setDistanceFunction(new CosineDistance()); search.setSkipIdentical(false); IBk knn = new IBk(); knn.setKNN(k); knn.setDistanceWeighting(INVERSE); knn.setNearestNeighbourSearchAlgorithm(search); queue.add(Job.submit(new ModelBuilder(dsn, "KNN-" + k, knn, this.outputBucket))); } // wait on complete Progress progress = new Progress(queue.size(), this.getClass().getSimpleName()); while (!queue.isEmpty()) { try { queue.remove().get(); } catch (Exception e) { Job.log("ERROR", e.toString()); e.printStackTrace(System.err); } progress.tick(); } progress.done(); Job.stopService(); }