Example usage for weka.classifiers.lazy IBk setNearestNeighbourSearchAlgorithm

List of usage examples for weka.classifiers.lazy IBk setNearestNeighbourSearchAlgorithm

Introduction

In this page you can find the example usage for weka.classifiers.lazy IBk setNearestNeighbourSearchAlgorithm.

Prototype

public void setNearestNeighbourSearchAlgorithm(NearestNeighbourSearch nearestNeighbourSearchAlgorithm) 

Source Link

Document

Sets the nearestNeighbourSearch algorithm to be used for finding nearest neighbour(s).

Usage

From source file:au.edu.usyd.it.yangpy.snp.Ensemble.java

License:Open Source License

/**
 * Runs every base classifier (J48, KStar, and cover-tree kNN with k = 1, 3, 5)
 * over the {@code test} set, storing each classifier's AUC in
 * {@code aucClassifiers} via {@link #classify}.
 *
 * @param mode "v" for verbose mode, which prints the AUC of each classifier
 * @throws Exception propagated from the underlying Weka classifiers
 */
public void ensemble(String mode) throws Exception {

    numInstances = test.numInstances();
    numClasses = test.numClasses();
    givenValue = new double[numInstances];
    predictDistribution = new double[numClassifiers][numInstances][numClasses];
    predictValue = new double[numClassifiers][numInstances];
    voteValue = new double[numInstances][numClasses];

    // Record the true class value of every test instance.
    for (int i = 0; i < numInstances; i++) {
        givenValue[i] = test.instance(i).classValue();
    }

    // Evaluate each base classifier in turn:
    // J48, KStar, then cover-tree kNN with k = 1, 3, 5.

    J48 tree = new J48();
    tree.setUnpruned(true);
    aucClassifiers[0] = classify(tree, 0);

    KStar kstar = new KStar();
    aucClassifiers[1] = classify(kstar, 1);

    // Give each IBk its own CoverTree: a NearestNeighbourSearch caches the
    // instances it was built on, so sharing one object between classifiers
    // silently couples their internal state.
    IBk ctnn1 = new IBk(1);
    ctnn1.setNearestNeighbourSearchAlgorithm(new CoverTree());
    aucClassifiers[2] = classify(ctnn1, 2);

    IBk ctnn3 = new IBk(3);
    ctnn3.setNearestNeighbourSearchAlgorithm(new CoverTree());
    aucClassifiers[3] = classify(ctnn3, 3);

    IBk ctnn5 = new IBk(5);
    ctnn5.setNearestNeighbourSearchAlgorithm(new CoverTree());
    aucClassifiers[4] = classify(ctnn5, 4);

    // Print the per-classifier AUCs when verbose mode is requested.
    // Constant-first equals() avoids an NPE if mode is null.
    if ("v".equals(mode)) {
        System.out.println("J48   AUC: " + aucClassifiers[0]);
        System.out.println("KStar AUC: " + aucClassifiers[1]);
        System.out.println("CTNN1 AUC: " + aucClassifiers[2]);
        System.out.println("CTNN3 AUC: " + aucClassifiers[3]);
        System.out.println("CTNN5 AUC: " + aucClassifiers[4]);
        System.out.println("   -         -   ");
    }
}

From source file:com.edwardraff.WekaMNIST.java

License:Open Source License

/**
 * Benchmarks several Weka learners (SMO/RBF, J48, RandomForest, IBk with
 * three search structures, Logistic, SimpleKMeans) on the MNIST ARFF files
 * located in the directory given as {@code args[0]}.
 *
 * @param args args[0] is the folder containing MNISTtrain.arff and MNISTtest.arff
 * @throws IOException if an ARFF file cannot be read
 * @throws Exception   propagated from the Weka learners
 */
public static void main(String[] args) throws IOException, Exception {
    String folder = args[0];
    String trainPath = folder + "MNISTtrain.arff";
    String testPath = folder + "MNISTtest.arff";

    System.out.println("Weka Timings");
    // try-with-resources so the file handles are released even if parsing fails
    Instances mnistTrainWeka;
    try (BufferedReader trainReader = new BufferedReader(new FileReader(new File(trainPath)))) {
        mnistTrainWeka = new Instances(trainReader);
    }
    mnistTrainWeka.setClassIndex(mnistTrainWeka.numAttributes() - 1);
    Instances mnistTestWeka;
    try (BufferedReader testReader = new BufferedReader(new FileReader(new File(testPath)))) {
        mnistTestWeka = new Instances(testReader);
    }
    mnistTestWeka.setClassIndex(mnistTestWeka.numAttributes() - 1);

    // Normalize attribute ranges into [0, 1]; the filter is fitted on the
    // training set and applied to both splits so they share the same scaling.
    Normalize normalizeFilter = new Normalize();
    normalizeFilter.setInputFormat(mnistTrainWeka);

    // Note: useFilter is the static Filter.useFilter inherited by Normalize.
    mnistTestWeka = Normalize.useFilter(mnistTestWeka, normalizeFilter);
    mnistTrainWeka = Normalize.useFilter(mnistTrainWeka, normalizeFilter);

    long start, end;

    System.out.println("RBF SVM (Full Cache)");
    SMO smo = new SMO();
    // Cache size 0 makes Weka cache the entire kernel matrix.
    smo.setKernel(new RBFKernel(mnistTrainWeka, 0, 0.015625));
    smo.setC(8.0);
    smo.setBuildLogisticModels(false);
    evalModel(smo, mnistTrainWeka, mnistTestWeka);

    System.out.println("RBF SVM (No Cache)");
    smo = new SMO();
    smo.setKernel(new RBFKernel(mnistTrainWeka, 1, 0.015625));
    smo.setC(8.0);
    smo.setBuildLogisticModels(false);
    evalModel(smo, mnistTrainWeka, mnistTestWeka);

    System.out.println("Decision Tree C45");
    J48 wekaC45 = new J48();
    wekaC45.setUseLaplace(false);
    wekaC45.setCollapseTree(false);
    wekaC45.setUnpruned(true);
    wekaC45.setMinNumObj(2);
    wekaC45.setUseMDLcorrection(true);

    evalModel(wekaC45, mnistTrainWeka, mnistTestWeka);

    System.out.println("Random Forest 50 trees");
    // Weka's default feature count differs from the published sqrt(d) rule,
    // so set it explicitly (MNIST has 28x28 = 784 pixel attributes).
    int featuresToUse = (int) Math.sqrt(28 * 28);

    RandomForest wekaRF = new RandomForest();
    wekaRF.setNumExecutionSlots(1);
    wekaRF.setMaxDepth(0 /* 0 means unlimited depth */);
    wekaRF.setNumFeatures(featuresToUse);
    wekaRF.setNumTrees(50);

    evalModel(wekaRF, mnistTrainWeka, mnistTestWeka);

    System.out.println("1-NN (brute)");
    IBk wekaNN = new IBk(1);
    wekaNN.setNearestNeighbourSearchAlgorithm(new LinearNNSearch());
    wekaNN.setCrossValidate(false);

    evalModel(wekaNN, mnistTrainWeka, mnistTestWeka);

    System.out.println("1-NN (Ball Tree)");
    wekaNN = new IBk(1);
    wekaNN.setNearestNeighbourSearchAlgorithm(new BallTree());
    wekaNN.setCrossValidate(false);

    evalModel(wekaNN, mnistTrainWeka, mnistTestWeka);

    System.out.println("1-NN (Cover Tree)");
    wekaNN = new IBk(1);
    wekaNN.setNearestNeighbourSearchAlgorithm(new CoverTree());
    wekaNN.setCrossValidate(false);

    evalModel(wekaNN, mnistTrainWeka, mnistTestWeka);

    System.out.println("Logistic Regression LBFGS lambda = 1e-4");
    Logistic logisticLBFGS = new Logistic();
    logisticLBFGS.setRidge(1e-4);
    logisticLBFGS.setMaxIts(500);

    evalModel(logisticLBFGS, mnistTrainWeka, mnistTestWeka);

    System.out.println("k-means (Loyd)");
    // Clustering is unsupervised: drop the class attribute first.
    int origClassIndex = mnistTrainWeka.classIndex();
    mnistTrainWeka.setClassIndex(-1);
    mnistTrainWeka.deleteAttributeAt(origClassIndex);
    {
        long totalTime = 0;
        // Average the clustering time over 10 runs.
        for (int i = 0; i < 10; i++) {
            SimpleKMeans wekaKMeans = new SimpleKMeans();
            wekaKMeans.setNumClusters(10);
            wekaKMeans.setNumExecutionSlots(1);
            wekaKMeans.setFastDistanceCalc(true);

            start = System.currentTimeMillis();
            wekaKMeans.buildClusterer(mnistTrainWeka);
            end = System.currentTimeMillis();
            totalTime += (end - start);
        }
        System.out.println("\tClustering took: " + (totalTime / 10.0) / 1000.0 + " on average");
    }
}

From source file:hurtowniedanych.FXMLController.java

/**
 * Trains a k-NN (IBk) classifier on data pulled from PostgreSQL, evaluates
 * it on a 70/30 random split, and shows the summary in the UI. The number
 * of neighbours comes from {@code textFieldKnn} and the distance function
 * (Manhattan vs. Euclidean) from {@code comboboxOdleglosc}.
 *
 * @throws Exception propagated from the database query or the classifier
 */
public void trainAndTestKNN() throws FileNotFoundException, IOException, Exception {

    InstanceQuery instanceQuery = new InstanceQuery();
    // SECURITY NOTE(review): credentials are hardcoded in source; move them
    // to configuration or a secrets store.
    instanceQuery.setUsername("postgres");
    instanceQuery.setPassword("szupek");
    // Points at the settings file for PostgreSQL.
    instanceQuery.setCustomPropsFile(new File("./src/data/DatabaseUtils.props"));

    // Fixed: the original read "kl.mia-sto", which is invalid SQL (the hyphen
    // parses as subtraction). Presumably the column is "miasto" (city) —
    // verify against the klienci table schema.
    String query = "select ks.wydawnictwo,ks.gatunek, kl.miasto\n" + "from zakupy z,ksiazki ks,klienci kl\n"
            + "where ks.id_ksiazka=z.id_ksiazka and kl.id_klient=z.id_klient";

    instanceQuery.setQuery(query);
    Instances data = instanceQuery.retrieveInstances();
    data.setClassIndex(data.numAttributes() - 1);

    // 70/30 random train/test split.
    data.randomize(new Random());
    double percent = 70.0;
    int trainSize = (int) Math.round(data.numInstances() * percent / 100);
    int testSize = data.numInstances() - trainSize;
    Instances trainData = new Instances(data, 0, trainSize);
    Instances testData = new Instances(data, trainSize, testSize);

    // Number of neighbours, taken from the UI text field.
    int lSasiadow = Integer.parseInt(textFieldKnn.getText());
    System.out.println(lSasiadow);

    IBk ibk = new IBk(lSasiadow);

    // Distance functions the user can choose between.
    EuclideanDistance euclidean = new EuclideanDistance();
    ManhattanDistance manhatan = new ManhattanDistance(); // city-block distance

    LinearNNSearch linearNN = new LinearNNSearch();

    if (comboboxOdleglosc.getSelectionModel().getSelectedItem().equals("Manhatan")) {
        linearNN.setDistanceFunction(manhatan);
    } else {
        linearNN.setDistanceFunction(euclidean);
    }

    // Set the neighbour-search strategy.
    ibk.setNearestNeighbourSearchAlgorithm(linearNN);

    // Build the classifier on the training split.
    ibk.buildClassifier(trainData);

    // Evaluate on the held-out split and publish the results to the UI.
    Evaluation eval = new Evaluation(trainData);
    eval.evaluateModel(ibk, testData);
    spr.setVisible(true);
    labelKnn.setVisible(true);
    labelOdleglosc.setVisible(true);
    labelKnn.setText(textFieldKnn.getText());
    labelOdleglosc.setText(comboboxOdleglosc.getSelectionModel().getSelectedItem().toString());
    spr.setText(eval.toSummaryString("Wynik:", true));
}

From source file:jjj.asap.sas.models1.job.BuildCosineModels.java

License:Open Source License

/**
 * Builds one inverse-distance-weighted kNN model per dataset in the input
 * bucket, using a linear search over cosine distance, and submits each
 * build as a job; waits for all jobs to finish before returning.
 * k is chosen per essay set from a fixed lookup.
 *
 * @throws FileNotFoundException if either bucket name does not exist
 * @throws Exception             propagated from job construction
 */
@Override
protected void run() throws Exception {

    // Validate the bucket arguments before doing any work.
    if (!Bucket.isBucket("datasets", inputBucket)) {
        throw new FileNotFoundException(inputBucket);
    }
    if (!Bucket.isBucket("models", outputBucket)) {
        throw new FileNotFoundException(outputBucket);
    }

    // Initialize multi-threaded job execution.
    Job.startService();
    final Queue<Future<Object>> queue = new LinkedList<Future<Object>>();

    // One model-building job per dataset in the input bucket.
    List<String> names = Bucket.getBucketItems("datasets", this.inputBucket);
    for (String dsn : names) {

        int essaySet = Contest.getEssaySet(dsn);

        // Pick k for this essay set; -1 marks "no mapping found".
        int k = -1;
        switch (essaySet) {

        case 3:
            k = 13;
            break;
        case 5:
        case 7:
            k = 55;
            break;
        case 2:
        case 6:
        case 10:
            k = 21;
            break;
        case 1:
        case 4:
        case 8:
        case 9:
            k = 34;
            break;
        }

        if (k == -1) {
            // Fixed message grammar (was "not k defined for").
            throw new IllegalArgumentException("no k defined for " + essaySet);
        }

        // Linear neighbour search over cosine distance; identical instances
        // are kept as neighbours (skipIdentical = false).
        LinearNNSearch search = new LinearNNSearch();
        search.setDistanceFunction(new CosineDistance());
        search.setSkipIdentical(false);

        // NOTE(review): INVERSE is presumably a statically imported
        // inverse-distance-weighting constant — confirm at the import site.
        IBk knn = new IBk();
        knn.setKNN(k);
        knn.setDistanceWeighting(INVERSE);
        knn.setNearestNeighbourSearchAlgorithm(search);

        queue.add(Job.submit(new ModelBuilder(dsn, "KNN-" + k, knn, this.outputBucket)));
    }

    // Drain the queue, logging (not swallowing silently) any job failure.
    Progress progress = new Progress(queue.size(), this.getClass().getSimpleName());
    while (!queue.isEmpty()) {
        try {
            queue.remove().get();
        } catch (Exception e) {
            Job.log("ERROR", e.toString());
            e.printStackTrace(System.err);
        }
        progress.tick();
    }
    progress.done();
    Job.stopService();

}