Example usage for weka.core Instances testCV

List of usage examples for weka.core Instances testCV

Introduction

On this page you can find example usages of weka.core.Instances.testCV.

Prototype



public Instances testCV(int numFolds, int numFold) 

Source Link

Document

Creates the test set for one fold of a cross-validation on the dataset.

Usage

From source file:au.edu.usyd.it.yangpy.sampling.BPSO.java

License:Open Source License

/**
 * Starts the under-sampling procedure.
 * <p>
 * A copy of the data set is stratified into three folds. For each fold the
 * internal training and test splits are taken, the majority-class instances
 * in the internal training split are counted, the swarm state is re-allocated
 * for that size, and {@code initialization()}, {@code findMaxFit()} and
 * {@code saveResults()} are run. After all folds, {@code createBalanceData()}
 * is called; an {@link IOException} from it is printed, not rethrown.
 */
public void underSampling() {
    // create a copy of original data set for cross validation
    Instances randData = new Instances(dataset);

    // dividing the data set to 3 folds
    randData.stratify(3);

    for (int fold = 0; fold < 3; fold++) {
        // two folds form the internal training set, the remaining fold is the internal test set
        internalTrain = randData.trainCV(3, fold);
        internalTest = randData.testCV(3, fold);

        // calculate the number of the major class samples in the internal training set
        majorSize = 0;

        for (int i = 0; i < internalTrain.numInstances(); i++) {
            if (internalTrain.instance(i).classValue() == majorLabel) {
                majorSize++;
            }
        }

        // class variable initialization (every particle has one slot per majority-class sample)
        dec = new DecimalFormat("##.####");
        localBest = new double[popSize];
        localBestParticles = new int[popSize][majorSize];
        // Double.MIN_VALUE is the smallest POSITIVE double, so a fitness of 0 could never
        // replace it; start the maximisation from negative infinity instead.
        globalBest = Double.NEGATIVE_INFINITY;
        globalBestParticle = new int[majorSize];
        velocity = new double[popSize][majorSize];
        particles = new int[popSize][majorSize];
        searchSpace = new double[popSize][majorSize];

        System.out.println("-------------------- parameters ----------------------");
        System.out.println("CV fold = " + fold);
        System.out.println("inertia weight = " + w);
        // only c1 is printed under the "c1,c2" label
        System.out.println("c1,c2 = " + c1);
        System.out.println("iteration time = " + iteration);
        System.out.println("population size = " + popSize);

        // initialize BPSO
        initialization();

        // perform optimization process
        findMaxFit();

        // save optimization results to array list
        saveResults();
    }

    // rank the selected samples and build the balanced dataset
    try {
        createBalanceData();
    } catch (IOException ioe) {
        ioe.printStackTrace();
    }

}

From source file:au.edu.usyd.it.yangpy.snp.ParallelGenetic.java

License:Open Source License

/**
 * Prepares the train/test split for the current cross-validation fold.
 * <p>
 * A fresh copy of {@code data} is stratified into {@code foldSize} folds, the
 * split for {@code foldIndex} is stored in {@code cvTrain}/{@code cvTest}, and
 * {@code foldIndex} is advanced, wrapping back to 0 after the last fold.
 * A failure while splitting is reported and leaves {@code foldIndex} unchanged.
 */
public void crossValidate() {
    // create a copy of original training set for CV
    Instances randData = new Instances(data);

    // divide the data set with x-fold stratify measure
    randData.stratify(foldSize);

    try {

        cvTrain = randData.trainCV(foldSize, foldIndex);
        cvTest = randData.testCV(foldSize, foldIndex);

        foldIndex++;

        if (foldIndex >= foldSize) {
            foldIndex = 0;
        }

    } catch (Exception e) {
        // Report the actual failure instead of discarding it.
        e.printStackTrace();
        // cvTest may never have been assigned (or may be left over from an earlier
        // call), so only dump it when present to avoid an NPE inside the handler.
        if (cvTest != null) {
            System.out.println(cvTest.toString());
        }
    }
}

From source file:br.ufrn.ia.core.clustering.EMIaProject.java

License:Open Source License

/**
 * Chooses the number of clusters by cross-validated log likelihood and stores
 * it in {@code m_num_clusters}.
 * <p>
 * Starting with one cluster, each round shuffles a copy of
 * {@code m_theInstances}, trains on every training fold ({@code EM_Init} +
 * {@code iterate}) and sums the value returned by {@code E} on the matching
 * test fold. While the mean over the folds improves on the best mean so far,
 * the cluster count is increased and another round is run. The final count is
 * the last candidate minus one.
 *
 * @throws Exception if a call outside the two inner try blocks fails
 *                   (e.g. {@code EM_Init}) - TODO confirm which callees throw
 */
private void CVClusters() throws Exception {
    // best mean log likelihood seen so far; starts below any finite value
    double CVLogLikely = -Double.MAX_VALUE;
    double templl, tll;
    boolean CVincreased = true;
    m_num_clusters = 1;
    int num_clusters = m_num_clusters;
    int i;
    Random cvr;
    Instances trainCopy;
    // at most 10 folds; fewer when there are fewer than 10 instances
    int numFolds = (m_theInstances.numInstances() < 10) ? m_theInstances.numInstances() : 10;

    // NOTE(review): ok is only ever set to false below and never back to true.
    // After a caught failure with restartCount <= 5 the inner loop breaks, the
    // "if (ok)" section is skipped, CVincreased stays false and the search ends -
    // so the incremented seed is never actually used for a restart. Confirm
    // whether a retry was intended.
    boolean ok = true;
    int seed = getSeed();
    int restartCount = 0;
    CLUSTER_SEARCH: while (CVincreased) {
        // theInstances.stratify(10);

        CVincreased = false;
        cvr = new Random(getSeed());
        trainCopy = new Instances(m_theInstances);
        trainCopy.randomize(cvr);
        // accumulates the per-fold test log likelihoods for this cluster count
        templl = 0.0;
        for (i = 0; i < numFolds; i++) {
            Instances cvTrain = trainCopy.trainCV(numFolds, i, cvr);
            // cannot fit more clusters than there are training instances: stop searching
            if (num_clusters > cvTrain.numInstances()) {
                break CLUSTER_SEARCH;
            }
            Instances cvTest = trainCopy.testCV(numFolds, i);
            m_rr = new Random(seed);
            // discard the first 10 draws of the freshly seeded generator
            for (int z = 0; z < 10; z++)
                m_rr.nextDouble();
            m_num_clusters = num_clusters;
            EM_Init(cvTrain);
            try {
                iterate(cvTrain, false);
            } catch (Exception ex) {
                // catch any problems - i.e. empty clusters occurring
                ex.printStackTrace();
                // System.err.println("Restarting after CV training failure
                // ("+num_clusters+" clusters");
                seed++;
                restartCount++;
                ok = false;
                // give up entirely after more than 5 failures
                if (restartCount > 5) {
                    break CLUSTER_SEARCH;
                }
                break;
            }
            try {
                tll = E(cvTest, false);
            } catch (Exception ex) {
                // catch any problems - i.e. empty clusters occurring
                // ex.printStackTrace();
                ex.printStackTrace();
                // System.err.println("Restarting after CV testing failure
                // ("+num_clusters+" clusters");
                // throw new Exception(ex);
                seed++;
                restartCount++;
                ok = false;
                // give up entirely after more than 5 failures
                if (restartCount > 5) {
                    break CLUSTER_SEARCH;
                }
                break;
            }

            if (m_verbose) {
                System.out.println("# clust: " + num_clusters + " Fold: " + i + " Loglikely: " + tll);
            }
            templl += tll;
        }

        if (ok) {
            restartCount = 0;
            seed = getSeed();
            // mean log likelihood over all folds
            templl /= (double) numFolds;

            if (m_verbose) {
                System.out.println("===================================" + "==============\n# clust: "
                        + num_clusters + " Mean Loglikely: " + templl + "\n================================"
                        + "=================");
            }

            // improvement: remember it and try one more cluster in the next round
            if (templl > CVLogLikely) {
                CVLogLikely = templl;
                CVincreased = true;
                num_clusters++;
            }
        }
    }

    if (m_verbose) {
        System.out.println("Number of clusters: " + (num_clusters - 1));
    }

    // num_clusters is the candidate that was not accepted (or never evaluated); step back by one
    m_num_clusters = num_clusters - 1;
}

From source file:br.unicamp.ic.recod.gpsi.gp.gpsiJGAPRoiFitnessFunction.java

/**
 * Scores a GP program by the mean accuracy of a {@code SimpleLogistic}
 * classifier over a 5-fold cross-validation of the training entities.
 * <p>
 * The training entities are loaded through a {@code gpsiMLDataset}, converted
 * to a Weka {@code Instances} set with one numeric attribute per feature plus
 * a nominal class attribute, shuffled with an unseeded {@link Random}, and
 * cross-validated. Exceptions while loading or evaluating are logged and
 * evaluation continues with whatever was accumulated.
 * <p>
 * NOTE: the band combiner built from {@code igpp} is constructed but not yet
 * applied (see the TODO below).
 *
 * @param igpp the GP program being evaluated
 * @return the sum of the per-fold percent-correct values divided by
 *         {@code folds * 100}
 */
@Override
protected double evaluate(IGPProgram igpp) {

    double mean_accuracy = 0.0;

    gpsiRoiBandCombiner roiBandCombinator = new gpsiRoiBandCombiner(new gpsiJGAPVoxelCombiner(super.b, igpp));
    // TODO: The ROI descriptors must combine the images first
    //roiBandCombinator.combineEntity(this.dataset.getTrainingEntities());

    gpsiMLDataset mlDataset = new gpsiMLDataset(this.descriptor);
    try {
        mlDataset.loadWholeDataset(this.dataset, true);
    } catch (Exception ex) {
        Logger.getLogger(gpsiJGAPRoiFitnessFunction.class.getName()).log(Level.SEVERE, null, ex);
    }

    int dimensionality = mlDataset.getDimensionality();
    int n_classes = mlDataset.getTrainingEntities().keySet().size();
    int n_entities = mlDataset.getNumberOfTrainingEntities();
    ArrayList<Byte> listOfClasses = new ArrayList<>(mlDataset.getTrainingEntities().keySet());

    Attribute[] attributes = new Attribute[dimensionality];
    FastVector fvClassVal = new FastVector(n_classes);

    int i, j;
    for (i = 0; i < dimensionality; i++)
        attributes[i] = new Attribute("f" + Integer.toString(i));
    for (i = 0; i < n_classes; i++)
        fvClassVal.addElement(Integer.toString(listOfClasses.get(i)));

    Attribute classes = new Attribute("class", fvClassVal);

    FastVector fvWekaAttributes = new FastVector(dimensionality + 1);

    for (i = 0; i < dimensionality; i++)
        fvWekaAttributes.addElement(attributes[i]);
    fvWekaAttributes.addElement(classes);

    Instances instances = new Instances("Rel", fvWekaAttributes, n_entities);
    instances.setClassIndex(dimensionality);

    Instance iExample;
    for (byte label : mlDataset.getTrainingEntities().keySet()) {
        // The nominal class value is the position of the label in the value list
        // built above, not the raw label byte (labels need not be 0..n-1).
        int classValueIndex = listOfClasses.indexOf(label);
        for (double[] featureVector : mlDataset.getTrainingEntities().get(label)) {
            iExample = new Instance(dimensionality + 1);
            // Index with the loop variable j; i is left at 'dimensionality' by the
            // loops above and would read past the end of the feature vector.
            for (j = 0; j < dimensionality; j++)
                iExample.setValue(j, featureVector[j]);
            iExample.setValue(dimensionality, classValueIndex);
            instances.add(iExample);
        }
    }

    int folds = 5;
    Random rand = new Random();
    Instances randData = new Instances(instances);
    randData.randomize(rand);

    Instances trainingSet, testingSet;
    Classifier cModel;
    Evaluation eTest;
    try {
        for (i = 0; i < folds; i++) {
            cModel = (Classifier) new SimpleLogistic();
            trainingSet = randData.trainCV(folds, i);
            testingSet = randData.testCV(folds, i);

            cModel.buildClassifier(trainingSet);

            // a fresh Evaluation per fold, so pctCorrect() covers this fold only
            eTest = new Evaluation(trainingSet);
            eTest.evaluateModel(cModel, testingSet);

            mean_accuracy += eTest.pctCorrect();

        }
    } catch (Exception ex) {
        Logger.getLogger(gpsiJGAPRoiFitnessFunction.class.getName()).log(Level.SEVERE, null, ex);
    }

    // percentages summed over the folds -> mean fraction
    mean_accuracy /= (folds * 100);

    return mean_accuracy;

}

From source file:cezeri.evaluater.FactoryEvaluation.java

/**
 * Cross-validates {@code model} on a shuffled copy of {@code datax} using a
 * {@link Random} seeded with 1.
 * <p>
 * The copy is stratified when its class attribute is nominal. For every fold
 * a copy of the model is trained on the training split and evaluated on the
 * test split; the values returned by the evaluation and the test split's
 * class values are appended to the {@code simulated} and {@code observed}
 * arrays. Optionally plots observed vs. simulated and prints a summary.
 * Any exception is logged and the evaluation built so far is returned.
 *
 * @param model     classifier to copy and train for each fold
 * @param datax     data to cross-validate on (left unmodified; a copy is shuffled)
 * @param folds     number of folds
 * @param show_text whether to print the setup and summary to standard output
 * @param show_plot whether to plot observed against simulated values
 * @param attr      figure attributes used for the plot; its caption is overwritten
 * @return the evaluation, or {@code null} if it could not be created
 */
public static Evaluation performCrossValidate(Classifier model, Instances datax, int folds, boolean show_text,
        boolean show_plot, TFigureAttribute attr) {
    Random rng = new Random(1);
    Instances shuffled = new Instances(datax);
    shuffled.randomize(rng);
    boolean nominalClass = shuffled.classAttribute().isNominal();
    if (nominalClass) {
        shuffled.stratify(folds);
    }

    Evaluation eval = null;
    try {
        eval = new Evaluation(shuffled);

        int fold = 0;
        while (fold < folds) {
            Instances foldTrain = shuffled.trainCV(folds, fold, rng);
            Instances foldTest = shuffled.testCV(folds, fold);

            // train an independent copy so the caller's model stays untouched
            Classifier foldModel = Classifier.makeCopy(model);
            foldModel.buildClassifier(foldTrain);

            simulated = FactoryUtils.concatenate(simulated, eval.evaluateModel(foldModel, foldTest));
            int classColumn = foldTest.classIndex();
            observed = FactoryUtils.concatenate(observed, foldTest.attributeToDoubleArray(classColumn));

            fold++;
        }

        if (show_plot) {
            // row 0: observed class values, row 1: values returned by the evaluation
            double[][] rows = new double[2][simulated.length];
            rows[0] = observed;
            rows[1] = simulated;
            CMatrix matrix = CMatrix.getInstance(rows);
            attr.figureCaption = "overall performance";
            matrix.transpose().plot(attr);
        }

        if (show_text) {
            System.out.println();
            System.out.println("=== Setup for Overall Cross Validation===");
            System.out.println(
                    "Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions()));
            System.out.println("Dataset: " + shuffled.relationName());
            System.out.println("Folds: " + folds);
            System.out.println("Seed: " + 1);
            System.out.println();
            System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false));
        }
    } catch (Exception ex) {
        Logger.getLogger(FactoryEvaluation.class.getName()).log(Level.SEVERE, null, ex);
    }
    return eval;
}

From source file:cezeri.evaluater.FactoryEvaluation.java

/**
 * Runs a cross-validation of {@code model} with as many folds as there are
 * instances, on a copy of {@code datax} shuffled with a {@link Random} seeded
 * with 1.
 * <p>
 * For every fold a copy of the model is trained on the training split and
 * evaluated on the test split; the values returned by the evaluation and the
 * test split's class values are appended to the {@code simulated} and
 * {@code observed} arrays. Optionally plots observed vs. simulated and prints
 * a summary. Any exception is logged and the evaluation built so far is
 * returned.
 *
 * @param model     classifier to copy and train for each fold
 * @param datax     data to cross-validate on (a copy is shuffled)
 * @param test      not used by this implementation
 * @param show_text whether to print the setup and summary to standard output
 * @param show_plot whether to plot observed against simulated values
 * @return the evaluation, or {@code null} if it could not be created
 */
public static Evaluation performCrossValidateTestAlso(Classifier model, Instances datax, Instances test,
        boolean show_text, boolean show_plot) {
    TFigureAttribute attr = new TFigureAttribute();
    Random rng = new Random(1);
    Instances shuffled = new Instances(datax);
    shuffled.randomize(rng);

    Evaluation eval = null;
    // one fold per instance
    int folds = shuffled.numInstances();
    try {
        eval = new Evaluation(shuffled);

        int fold = 0;
        while (fold < folds) {
            Instances foldTrain = shuffled.trainCV(folds, fold);

            // train an independent copy so the caller's model stays untouched
            Classifier foldModel = Classifier.makeCopy(model);
            foldModel.buildClassifier(foldTrain);

            Instances foldTest = shuffled.testCV(folds, fold);

            simulated = FactoryUtils.concatenate(simulated, eval.evaluateModel(foldModel, foldTest));
            int classColumn = foldTest.classIndex();
            observed = FactoryUtils.concatenate(observed, foldTest.attributeToDoubleArray(classColumn));

            fold++;
        }

        if (show_plot) {
            // row 0: observed class values, row 1: values returned by the evaluation
            double[][] rows = new double[2][simulated.length];
            rows[0] = observed;
            rows[1] = simulated;
            CMatrix matrix = CMatrix.getInstance(rows);
            attr.figureCaption = "overall performance";
            matrix.transpose().plot(attr);
        }

        if (show_text) {
            System.out.println();
            System.out.println("=== Setup for Overall Cross Validation===");
            System.out.println(
                    "Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions()));
            System.out.println("Dataset: " + shuffled.relationName());
            System.out.println("Folds: " + folds);
            System.out.println("Seed: " + 1);
            System.out.println();
            System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false));
        }
    } catch (Exception ex) {
        Logger.getLogger(FactoryEvaluation.class.getName()).log(Level.SEVERE, null, ex);
    }
    return eval;
}

From source file:cezeri.feature.selection.FeatureSelectionInfluence.java

/**
 * Cross-validates {@code model} on {@code randData} as given (no shuffling or
 * stratification is done here) and prints the setup, summary, per-class
 * details, confusion matrix and accuracy to standard output.
 * <p>
 * Any exception is logged and the evaluation built so far is returned.
 *
 * @param randData data to split into folds
 * @param model    classifier to copy and train for each fold
 * @param folds    number of folds
 * @return the evaluation, or {@code null} if it could not be created
 */
public static Evaluation getEvaluation(Instances randData, Classifier model, int folds) {
    Evaluation eval = null;
    try {
        eval = new Evaluation(randData);

        int fold = 0;
        while (fold < folds) {
            Instances foldTrain = randData.trainCV(folds, fold);
            Instances foldTest = randData.testCV(folds, fold);

            // train an independent copy so the caller's model stays untouched
            Classifier foldModel = Classifier.makeCopy(model);
            foldModel.buildClassifier(foldTrain);
            eval.evaluateModel(foldModel, foldTest);

            fold++;
        }

        // report
        System.out.println();
        System.out.println("=== Setup ===");
        System.out.println(
                "Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions()));
        System.out.println("Dataset: " + randData.relationName());
        System.out.println("Folds: " + folds);
        System.out.println();
        System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false));
        System.out.println(eval.toClassDetailsString("=== Detailed Accuracy By Class ==="));
        System.out.println(eval.toMatrixString("Confusion Matrix"));

        // share of correctly classified instances, in percent
        double acc = eval.correct() / eval.numInstances() * 100;
        System.out.println("correct:" + eval.correct() + "  " + acc + "%");
    } catch (Exception ex) {
        Logger.getLogger(FeatureSelectionInfluence.class.getName()).log(Level.SEVERE, null, ex);
    }
    return eval;
}

From source file:com.mycompany.id3classifier.ID3Shell.java

/**
 * Loads {@code lensesData.csv}, applies the {@code Discretize} and
 * {@code Standardize} filters, uses the last attribute as the class, shuffles
 * with a fixed seed and prints the summary of a 10-fold cross-validation of
 * {@code ID3Classifier}.
 *
 * @param args not used
 * @throws Exception if loading, filtering, training or evaluation fails
 */
public static void main(String[] args) throws Exception {
    ConverterUtils.DataSource source = new ConverterUtils.DataSource("lensesData.csv");
    Instances dataSet = source.getDataSet();

    // Both filters run before the class index is set below.
    Discretize filter = new Discretize();
    filter.setInputFormat(dataSet);
    dataSet = Filter.useFilter(dataSet, filter);

    Standardize standardize = new Standardize();
    standardize.setInputFormat(dataSet);
    dataSet = Filter.useFilter(dataSet, standardize);

    dataSet.setClassIndex(dataSet.numAttributes() - 1);
    dataSet.randomize(new Random(9001)); //It's over 9000!!

    int folds = 10;
    //Perform crossvalidation
    Evaluation eval = new Evaluation(dataSet);
    for (int n = 0; n < folds; n++) {
        // The split sizes come from trainCV/testCV; the previously computed
        // 70/30 sizes were never used and have been removed.
        Instances trainingData = dataSet.trainCV(folds, n);
        Instances testData = dataSet.testCV(folds, n);

        ID3Classifier classifier = new ID3Classifier();
        // Id3 classifier = new Id3();
        classifier.buildClassifier(trainingData);

        // the shared Evaluation accumulates the statistics of all folds
        eval.evaluateModel(classifier, testData);
    }
    System.out.println(eval.toSummaryString("\nResults:\n", false));
}

From source file:com.reactivetechnologies.analytics.core.eval.StackingWithBuiltClassifiers.java

License:Open Source License

/**
 * Generates the meta data/*from  ww w . j a va2s. com*/
 *
 * @param newData the data to work on
 * @param random the random number generator to use for cross-validation
 * @throws Exception if generation fails
 */
@Override
protected void generateMetaLevel(Instances newData, Random random) throws Exception {

    Instances metaData = metaFormat(newData);
    m_MetaFormat = new Instances(metaData, 0);
    for (int j = 0; j < m_NumFolds; j++) {

        /** Changed here */
        //Instances train = newData.trainCV(m_NumFolds, j, random);
        // DO NOT Build base classifiers
        /*for (int i = 0; i < m_Classifiers.length; i++) {
            getClassifier(i).buildClassifier(train);
        }*/
        /** End change */

        // Classify test instances and add to meta data
        Instances test = newData.testCV(m_NumFolds, j);
        for (int i = 0; i < test.numInstances(); i++) {
            metaData.add(metaInstance(test.instance(i)));
        }
    }

    m_MetaClassifier.buildClassifier(metaData);
}

From source file:Control.Classificador.java

/**
 * Loads the data set from {@code arq}, resamples it and cross-validates every
 * classifier enabled in {@code plano}, appending one {@code Resultado} per
 * classifier to {@code this.resultados}.
 * <p>
 * With hold-out enabled the first of three folds is used as the working set
 * and evaluated with 4 folds; otherwise the whole resampled set is used with
 * one fold per instance. Each classifier gets its own {@code Evaluation} so
 * that its metrics are not mixed with those of classifiers run before it.
 * Failures are logged (or shown through {@code Mensagem.erro}) and the results
 * collected so far are returned.
 *
 * @param plano which classifiers to run and whether to use hold-out
 * @param arq   wrapper of the file to read the data set from
 * @return {@code this.resultados}, including any results added by this call
 */
public ArrayList<Resultado> classificar(Plano plano, Arquivo arq) {
    try {
        // Read the data set and always release the file handle afterwards.
        FileReader leitor = new FileReader(arq.arquivo);
        Instances conjunto;
        try {
            conjunto = new Instances(leitor);
        } finally {
            leitor.close();
        }
        conjunto.setClassIndex(conjunto.numAttributes() - 1);
        conjunto = conjunto.resample(new Random());

        Evaluation avaliacao;
        Instances baseTreino = null, baseTeste = null;
        Random rand = new Random(1);

        if (plano.eHoldOut) {
            baseTeste = conjunto.testCV(3, 0);
            // NOTE(review): baseTreino is assigned but never read below; every
            // classifier is built and cross-validated on baseTeste - confirm intent.
            baseTreino = conjunto.trainCV(3, 0);
        } else {
            baseTeste = baseTreino = conjunto;
        }

        if (plano.IBK) {
            try {
                IB1 vizinho = new IB1();
                vizinho.buildClassifier(baseTeste);
                // fresh evaluation: statistics would otherwise accumulate across classifiers
                avaliacao = new Evaluation(baseTeste);
                avaliacao.crossValidateModel(vizinho, baseTeste,
                        (plano.eHoldOut) ? 4 : baseTeste.numInstances(), rand);
                Resultado resultado = new Resultado("NN",
                        avaliacao.toMatrixString("Algortmo Vizinho Mais Prximo - Matriz de Confuso"),
                        avaliacao.toClassDetailsString("kNN"));
                resultado.setTaxaErro(avaliacao.errorRate());
                resultado.setTaxaAcerto(1 - avaliacao.errorRate());
                resultado.setRevocacao(recallToDouble(avaliacao, baseTeste));
                resultado.setPrecisao(precisionToDouble(avaliacao, baseTeste));
                this.resultados.add(resultado);
            } catch (UnsupportedAttributeTypeException ex) {
                Mensagem.erro("Algortmo IB1 no suporta atributos numricos!", "MTCS - ERRO");
            }
        }
        if (plano.J48) {
            try {
                J48 j48 = new J48();
                j48.buildClassifier(baseTeste);
                // fresh evaluation: statistics would otherwise accumulate across classifiers
                avaliacao = new Evaluation(baseTeste);
                avaliacao.crossValidateModel(j48, baseTeste, (plano.eHoldOut) ? 4 : baseTeste.numInstances(),
                        rand);
                Resultado resultado = new Resultado("J48",
                        avaliacao.toMatrixString("Algortmo J48 - Matriz de Confuso"),
                        avaliacao.toClassDetailsString("J48"));
                resultado.setTaxaErro(avaliacao.errorRate());
                resultado.setTaxaAcerto(1 - avaliacao.errorRate());
                resultado.setRevocacao(recallToDouble(avaliacao, baseTeste));
                resultado.setPrecisao(precisionToDouble(avaliacao, baseTeste));
                this.resultados.add(resultado);
            } catch (UnsupportedAttributeTypeException ex) {
                Mensagem.erro("Algortmo J48 no suporta atributos nominais!", "MTCS - ERRO");
            }
        }
        if (plano.KNN) {
            try {
                IBk knn = new IBk(3);
                knn.buildClassifier(baseTeste);
                // fresh evaluation: statistics would otherwise accumulate across classifiers
                avaliacao = new Evaluation(baseTeste);
                avaliacao.crossValidateModel(knn, baseTeste, (plano.eHoldOut) ? 4 : baseTeste.numInstances(),
                        rand);
                Resultado resultado = new Resultado("KNN",
                        avaliacao.toMatrixString("Algortmo KNN - Matriz de Confuso"),
                        avaliacao.toClassDetailsString("kNN"));
                resultado.setTaxaErro(avaliacao.errorRate());
                resultado.setTaxaAcerto(1 - avaliacao.errorRate());
                resultado.setRevocacao(recallToDouble(avaliacao, baseTeste));
                resultado.setPrecisao(precisionToDouble(avaliacao, baseTeste));
                this.resultados.add(resultado);
            } catch (UnsupportedAttributeTypeException ex) {
                Mensagem.erro("Algortmo KNN no suporta atributos numricos!", "MTCS - ERRO");
            }

        }
        if (plano.Naive) {
            NaiveBayes naive = new NaiveBayes();
            naive.buildClassifier(baseTeste);
            // fresh evaluation: statistics would otherwise accumulate across classifiers
            avaliacao = new Evaluation(baseTeste);
            avaliacao.crossValidateModel(naive, baseTeste, (plano.eHoldOut) ? 4 : baseTeste.numInstances(),
                    rand);
            // details title now names this classifier (was a copy-pasted "kNN")
            Resultado resultado = new Resultado("Naive",
                    avaliacao.toMatrixString("Algortmo NaiveBayes - Matriz de Confuso"),
                    avaliacao.toClassDetailsString("NaiveBayes"));
            resultado.setTaxaErro(avaliacao.errorRate());
            resultado.setTaxaAcerto(1 - avaliacao.errorRate());
            resultado.setRevocacao(recallToDouble(avaliacao, baseTeste));
            resultado.setPrecisao(precisionToDouble(avaliacao, baseTeste));
            this.resultados.add(resultado);
        }
        if (plano.Tree) {
            try {
                Id3 id3 = new Id3();
                id3.buildClassifier(baseTeste);
                // fresh evaluation: statistics would otherwise accumulate across classifiers
                avaliacao = new Evaluation(baseTeste);
                avaliacao.crossValidateModel(id3, baseTeste, (plano.eHoldOut) ? 4 : baseTeste.numInstances(),
                        rand);
                // details title now names this classifier (was a copy-pasted "kNN")
                Resultado resultado = new Resultado("ID3",
                        avaliacao.toMatrixString("Algortmo ID3 - Matriz de Confuso"),
                        avaliacao.toClassDetailsString("ID3"));
                resultado.setTaxaErro(avaliacao.errorRate());
                resultado.setTaxaAcerto(1 - avaliacao.errorRate());
                resultado.setRevocacao(recallToDouble(avaliacao, baseTeste));
                resultado.setPrecisao(precisionToDouble(avaliacao, baseTeste));
                this.resultados.add(resultado);
            } catch (UnsupportedAttributeTypeException ex) {
                Mensagem.erro("Algortmo Arvore de Deciso no suporta atributos numricos!",
                        "MTCS - ERRO");

            }
        }

    } catch (FileNotFoundException ex) {

        Logger.getLogger(Classificador.class.getName()).log(Level.SEVERE, null, ex);
    } catch (IOException ex) {
        Logger.getLogger(Classificador.class.getName()).log(Level.SEVERE, null, ex);
    } catch (NullPointerException ex) {
        // e.g. no file was selected (arq or arq.arquivo is null)
        Mensagem.erro("Selecione um arquivo para comear!", "MTCS - ERRO");
        Logger.getLogger(Classificador.class.getName()).log(Level.SEVERE, null, ex);
    } catch (Exception ex) {
        Logger.getLogger(Classificador.class.getName()).log(Level.SEVERE, null, ex);
    }

    return this.resultados;
}