List of usage examples for weka.core Instances testCV
public Instances testCV(int numFolds, int numFold)
From source file:au.edu.usyd.it.yangpy.sampling.BPSO.java
License:Open Source License
/** * this method starts the under sampling procedure *///from ww w. j av a2s. c o m public void underSampling() { // create a copy of original data set for cross validation Instances randData = new Instances(dataset); // dividing the data set to 3 folds randData.stratify(3); for (int fold = 0; fold < 3; fold++) { // using the first 2 folds as internal training set. the last fold as the internal test set. internalTrain = randData.trainCV(3, fold); internalTest = randData.testCV(3, fold); // calculate the number of the major class samples in the internal training set majorSize = 0; for (int i = 0; i < internalTrain.numInstances(); i++) { if (internalTrain.instance(i).classValue() == majorLabel) { majorSize++; } } // class variable initialization dec = new DecimalFormat("##.####"); localBest = new double[popSize]; localBestParticles = new int[popSize][majorSize]; globalBest = Double.MIN_VALUE; globalBestParticle = new int[majorSize]; velocity = new double[popSize][majorSize]; particles = new int[popSize][majorSize]; searchSpace = new double[popSize][majorSize]; System.out.println("-------------------- parameters ----------------------"); System.out.println("CV fold = " + fold); System.out.println("inertia weight = " + w); System.out.println("c1,c2 = " + c1); System.out.println("iteration time = " + iteration); System.out.println("population size = " + popSize); // initialize BPSO initialization(); // perform optimization process findMaxFit(); // save optimization results to array list saveResults(); } // rank the selected samples and build the balanced dataset try { createBalanceData(); } catch (IOException ioe) { ioe.printStackTrace(); } }
From source file:au.edu.usyd.it.yangpy.snp.ParallelGenetic.java
License:Open Source License
public void crossValidate() { // create a copy of original training set for CV Instances randData = new Instances(data); // divide the data set with x-fold stratify measure randData.stratify(foldSize);//from w w w .ja v a 2s .c om try { cvTrain = randData.trainCV(foldSize, foldIndex); cvTest = randData.testCV(foldSize, foldIndex); foldIndex++; if (foldIndex >= foldSize) { foldIndex = 0; } } catch (Exception e) { System.out.println(cvTest.toString()); } }
From source file:br.ufrn.ia.core.clustering.EMIaProject.java
License:Open Source License
/**
 * Selects the number of clusters by cross-validated log likelihood.
 *
 * <p>Starting from one cluster, the candidate count is increased for as long as the mean
 * log likelihood over the folds improves. On exit {@code m_num_clusters} is set to the
 * candidate count minus one.
 *
 * @throws Exception if an operation outside the two guarded try blocks fails
 */
private void CVClusters() throws Exception {
    double CVLogLikely = -Double.MAX_VALUE; // best mean log likelihood seen so far
    double templl, tll;
    boolean CVincreased = true;
    m_num_clusters = 1;
    int num_clusters = m_num_clusters;
    int i;
    Random cvr;
    Instances trainCopy;
    // Use at most 10 folds; with fewer than 10 instances use one fold per instance.
    int numFolds = (m_theInstances.numInstances() < 10) ? m_theInstances.numInstances() : 10;
    boolean ok = true;
    int seed = getSeed();
    int restartCount = 0;
    CLUSTER_SEARCH: while (CVincreased) {
        // theInstances.stratify(10);
        // Cleared here and only set again below when the mean log likelihood improves.
        CVincreased = false;
        cvr = new Random(getSeed());
        trainCopy = new Instances(m_theInstances);
        trainCopy.randomize(cvr);
        templl = 0.0;
        for (i = 0; i < numFolds; i++) {
            Instances cvTrain = trainCopy.trainCV(numFolds, i, cvr);
            // Stop the whole search once there are more clusters than training instances.
            if (num_clusters > cvTrain.numInstances()) {
                break CLUSTER_SEARCH;
            }
            Instances cvTest = trainCopy.testCV(numFolds, i);
            // Re-seed the generator and discard its first ten values before initializing EM.
            m_rr = new Random(seed);
            for (int z = 0; z < 10; z++)
                m_rr.nextDouble();
            m_num_clusters = num_clusters;
            EM_Init(cvTrain);
            try {
                iterate(cvTrain, false);
            } catch (Exception ex) {
                // Catch any problems, e.g. empty clusters occurring during training.
                ex.printStackTrace();
                // System.err.println("Restarting after CV training failure ("+num_clusters+" clusters");
                seed++;
                restartCount++;
                ok = false;
                if (restartCount > 5) {
                    break CLUSTER_SEARCH;
                }
                break;
            }
            try {
                tll = E(cvTest, false);
            } catch (Exception ex) {
                // Catch any problems, e.g. empty clusters occurring during testing.
                ex.printStackTrace();
                // System.err.println("Restarting after CV testing failure ("+num_clusters+" clusters");
                // throw new Exception(ex);
                seed++;
                restartCount++;
                ok = false;
                if (restartCount > 5) {
                    break CLUSTER_SEARCH;
                }
                break;
            }
            if (m_verbose) {
                System.out.println("# clust: " + num_clusters + " Fold: " + i + " Loglikely: " + tll);
            }
            templl += tll;
        }
        // NOTE(review): `ok` is never set back to true in this method, and after a failed fold
        // CVincreased is still false, so as written the while loop ends after the first failure
        // instead of restarting with the incremented seed. Confirm whether a restart was intended.
        if (ok) {
            restartCount = 0;
            seed = getSeed();
            templl /= (double) numFolds; // mean log likelihood over the folds
            if (m_verbose) {
                System.out.println("===================================" + "==============\n# clust: " + num_clusters + " Mean Loglikely: " + templl + "\n================================" + "=================");
            }
            // Keep going only while adding a cluster improves the mean log likelihood.
            if (templl > CVLogLikely) {
                CVLogLikely = templl;
                CVincreased = true;
                num_clusters++;
            }
        }
    }
    if (m_verbose) {
        System.out.println("Number of clusters: " + (num_clusters - 1));
    }
    // The last candidate did not improve (or the search aborted), so use the count before it.
    m_num_clusters = num_clusters - 1;
}
From source file:br.unicamp.ic.recod.gpsi.gp.gpsiJGAPRoiFitnessFunction.java
/**
 * Computes the fitness of a GP program as the mean 5-fold cross-validated accuracy of a
 * {@code SimpleLogistic} classifier trained on the descriptor features of the training entities.
 *
 * <p>Fixes relative to the previous version:
 * <ul>
 *   <li>The feature-copy loop iterated over {@code j} but indexed with the stale outer
 *       counter {@code i} (equal to {@code dimensionality} at that point), so no feature
 *       value was ever written and {@code featureVector[dimensionality]} was read.</li>
 *   <li>The class value is now stored as the index of the label within the nominal class
 *       attribute (Weka's internal format) rather than the raw byte label.</li>
 * </ul>
 *
 * @param igpp the GP program being evaluated
 * @return the mean accuracy as a fraction (sum of per-fold percent-correct divided by
 *         {@code folds * 100}); folds that were not evaluated because of an exception
 *         contribute zero
 */
@Override
protected double evaluate(IGPProgram igpp) {
    double mean_accuracy = 0.0;

    // Kept as in the original: the combiner is constructed but not yet applied (see TODO).
    gpsiRoiBandCombiner roiBandCombinator = new gpsiRoiBandCombiner(new gpsiJGAPVoxelCombiner(super.b, igpp));
    // TODO: The ROI descriptors must combine the images first
    //roiBandCombinator.combineEntity(this.dataset.getTrainingEntities());

    gpsiMLDataset mlDataset = new gpsiMLDataset(this.descriptor);
    try {
        mlDataset.loadWholeDataset(this.dataset, true);
    } catch (Exception ex) {
        Logger.getLogger(gpsiJGAPRoiFitnessFunction.class.getName()).log(Level.SEVERE, null, ex);
    }

    int dimensionality = mlDataset.getDimensionality();
    int n_classes = mlDataset.getTrainingEntities().keySet().size();
    int n_entities = mlDataset.getNumberOfTrainingEntities();
    ArrayList<Byte> listOfClasses = new ArrayList<>(mlDataset.getTrainingEntities().keySet());

    // Build the Weka header: one numeric attribute per feature plus a nominal class attribute.
    FastVector fvClassVal = new FastVector(n_classes);
    for (int c = 0; c < n_classes; c++) {
        fvClassVal.addElement(Integer.toString(listOfClasses.get(c)));
    }
    Attribute classes = new Attribute("class", fvClassVal);

    FastVector fvWekaAttributes = new FastVector(dimensionality + 1);
    for (int f = 0; f < dimensionality; f++) {
        fvWekaAttributes.addElement(new Attribute("f" + Integer.toString(f)));
    }
    fvWekaAttributes.addElement(classes);

    Instances instances = new Instances("Rel", fvWekaAttributes, n_entities);
    instances.setClassIndex(dimensionality);

    // Fill the data set; the position of a label in listOfClasses is its nominal index.
    for (int c = 0; c < listOfClasses.size(); c++) {
        Byte label = listOfClasses.get(c);
        for (double[] featureVector : mlDataset.getTrainingEntities().get(label)) {
            Instance iExample = new Instance(dimensionality + 1);
            for (int j = 0; j < dimensionality; j++) {
                iExample.setValue(j, featureVector[j]);
            }
            iExample.setValue(dimensionality, c);
            instances.add(iExample);
        }
    }

    int folds = 5;
    Random rand = new Random();
    Instances randData = new Instances(instances);
    randData.randomize(rand);

    try {
        for (int fold = 0; fold < folds; fold++) {
            Classifier cModel = (Classifier) new SimpleLogistic();
            Instances trainingSet = randData.trainCV(folds, fold);
            Instances testingSet = randData.testCV(folds, fold);
            cModel.buildClassifier(trainingSet);
            Evaluation eTest = new Evaluation(trainingSet);
            eTest.evaluateModel(cModel, testingSet);
            mean_accuracy += eTest.pctCorrect();
        }
    } catch (Exception ex) {
        Logger.getLogger(gpsiJGAPRoiFitnessFunction.class.getName()).log(Level.SEVERE, null, ex);
    }

    // pctCorrect() is a percentage; convert the sum into a mean fraction.
    mean_accuracy /= (folds * 100);
    return mean_accuracy;
}
From source file:cezeri.evaluater.FactoryEvaluation.java
public static Evaluation performCrossValidate(Classifier model, Instances datax, int folds, boolean show_text, boolean show_plot, TFigureAttribute attr) { Random rand = new Random(1); Instances randData = new Instances(datax); randData.randomize(rand);/*ww w .j a v a 2s . c o m*/ if (randData.classAttribute().isNominal()) { randData.stratify(folds); } Evaluation eval = null; try { // perform cross-validation eval = new Evaluation(randData); // double[] simulated = new double[0]; // double[] observed = new double[0]; // double[] sim = new double[0]; // double[] obs = new double[0]; for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n, rand); Instances validation = randData.testCV(folds, n); // build and evaluate classifier Classifier clsCopy = Classifier.makeCopy(model); clsCopy.buildClassifier(train); // sim = eval.evaluateModel(clsCopy, validation); // obs = validation.attributeToDoubleArray(validation.classIndex()); // if (show_plot) { // double[][] d = new double[2][sim.length]; // d[0] = obs; // d[1] = sim; // CMatrix f1 = CMatrix.getInstance(d); // f1.transpose().plot(attr); // } // if (show_text) { // // output evaluation // System.out.println(); // System.out.println("=== Setup for each Cross Validation fold==="); // System.out.println("Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions())); // System.out.println("Dataset: " + randData.relationName()); // System.out.println("Folds: " + folds); // System.out.println("Seed: " + 1); // System.out.println(); // System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false)); // } simulated = FactoryUtils.concatenate(simulated, eval.evaluateModel(clsCopy, validation)); observed = FactoryUtils.concatenate(observed, validation.attributeToDoubleArray(validation.classIndex())); // simulated = FactoryUtils.mean(simulated,eval.evaluateModel(clsCopy, validation)); // observed = 
FactoryUtils.mean(observed,validation.attributeToDoubleArray(validation.classIndex())); } if (show_plot) { double[][] d = new double[2][simulated.length]; d[0] = observed; d[1] = simulated; CMatrix f1 = CMatrix.getInstance(d); attr.figureCaption = "overall performance"; f1.transpose().plot(attr); } if (show_text) { // output evaluation System.out.println(); System.out.println("=== Setup for Overall Cross Validation==="); System.out.println( "Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions())); System.out.println("Dataset: " + randData.relationName()); System.out.println("Folds: " + folds); System.out.println("Seed: " + 1); System.out.println(); System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false)); } } catch (Exception ex) { Logger.getLogger(FactoryEvaluation.class.getName()).log(Level.SEVERE, null, ex); } return eval; }
From source file:cezeri.evaluater.FactoryEvaluation.java
public static Evaluation performCrossValidateTestAlso(Classifier model, Instances datax, Instances test, boolean show_text, boolean show_plot) { TFigureAttribute attr = new TFigureAttribute(); Random rand = new Random(1); Instances randData = new Instances(datax); randData.randomize(rand);//from w ww .j a v a 2 s. co m Evaluation eval = null; int folds = randData.numInstances(); try { eval = new Evaluation(randData); for (int n = 0; n < folds; n++) { // randData.randomize(rand); // Instances train = randData; Instances train = randData.trainCV(folds, n); // Instances train = randData.trainCV(folds, n, rand); Classifier clsCopy = Classifier.makeCopy(model); clsCopy.buildClassifier(train); Instances validation = randData.testCV(folds, n); // Instances validation = test.testCV(test.numInstances(), n%test.numInstances()); // CMatrix.fromInstances(train).showDataGrid(); // CMatrix.fromInstances(validation).showDataGrid(); simulated = FactoryUtils.concatenate(simulated, eval.evaluateModel(clsCopy, validation)); observed = FactoryUtils.concatenate(observed, validation.attributeToDoubleArray(validation.classIndex())); } if (show_plot) { double[][] d = new double[2][simulated.length]; d[0] = observed; d[1] = simulated; CMatrix f1 = CMatrix.getInstance(d); attr.figureCaption = "overall performance"; f1.transpose().plot(attr); } if (show_text) { // output evaluation System.out.println(); System.out.println("=== Setup for Overall Cross Validation==="); System.out.println( "Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions())); System.out.println("Dataset: " + randData.relationName()); System.out.println("Folds: " + folds); System.out.println("Seed: " + 1); System.out.println(); System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false)); } } catch (Exception ex) { Logger.getLogger(FactoryEvaluation.class.getName()).log(Level.SEVERE, null, ex); } return eval; }
From source file:cezeri.feature.selection.FeatureSelectionInfluence.java
public static Evaluation getEvaluation(Instances randData, Classifier model, int folds) { Evaluation eval = null;/* w ww .ja v a 2 s. c o m*/ try { eval = new Evaluation(randData); for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n); Instances test = randData.testCV(folds, n); // build and evaluate classifier Classifier clsCopy = Classifier.makeCopy(model); clsCopy.buildClassifier(train); eval.evaluateModel(clsCopy, test); // double[] prediction = eval.evaluateModel(clsCopy, test); // double[] original = getAttributeValues(test); // double[][] d = new double[2][prediction.length]; // d[0] = prediction; // d[1] = original; // CMatrix f1 = new CMatrix(d); } // output evaluation System.out.println(); System.out.println("=== Setup ==="); System.out.println( "Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions())); System.out.println("Dataset: " + randData.relationName()); System.out.println("Folds: " + folds); System.out.println(); System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false)); System.out.println(eval.toClassDetailsString("=== Detailed Accuracy By Class ===")); System.out.println(eval.toMatrixString("Confusion Matrix")); double acc = eval.correct() / eval.numInstances() * 100; System.out.println("correct:" + eval.correct() + " " + acc + "%"); } catch (Exception ex) { Logger.getLogger(FeatureSelectionInfluence.class.getName()).log(Level.SEVERE, null, ex); } return eval; }
From source file:com.mycompany.id3classifier.ID3Shell.java
public static void main(String[] args) throws Exception { ConverterUtils.DataSource source = new ConverterUtils.DataSource("lensesData.csv"); Instances dataSet = source.getDataSet(); Discretize filter = new Discretize(); filter.setInputFormat(dataSet);/*from w ww .j a v a 2 s. c om*/ dataSet = Filter.useFilter(dataSet, filter); Standardize standardize = new Standardize(); standardize.setInputFormat(dataSet); dataSet = Filter.useFilter(dataSet, standardize); dataSet.setClassIndex(dataSet.numAttributes() - 1); dataSet.randomize(new Random(9001)); //It's over 9000!! int folds = 10; //Perform crossvalidation Evaluation eval = new Evaluation(dataSet); for (int n = 0; n < folds; n++) { int trainingSize = (int) Math.round(dataSet.numInstances() * .7); int testSize = dataSet.numInstances() - trainingSize; Instances trainingData = dataSet.trainCV(folds, n); Instances testData = dataSet.testCV(folds, n); ID3Classifier classifier = new ID3Classifier(); // Id3 classifier = new Id3(); classifier.buildClassifier(trainingData); eval.evaluateModel(classifier, testData); } System.out.println(eval.toSummaryString("\nResults:\n", false)); }
From source file:com.reactivetechnologies.analytics.core.eval.StackingWithBuiltClassifiers.java
License:Open Source License
/** * Generates the meta data/*from ww w . j a va2s. com*/ * * @param newData the data to work on * @param random the random number generator to use for cross-validation * @throws Exception if generation fails */ @Override protected void generateMetaLevel(Instances newData, Random random) throws Exception { Instances metaData = metaFormat(newData); m_MetaFormat = new Instances(metaData, 0); for (int j = 0; j < m_NumFolds; j++) { /** Changed here */ //Instances train = newData.trainCV(m_NumFolds, j, random); // DO NOT Build base classifiers /*for (int i = 0; i < m_Classifiers.length; i++) { getClassifier(i).buildClassifier(train); }*/ /** End change */ // Classify test instances and add to meta data Instances test = newData.testCV(m_NumFolds, j); for (int i = 0; i < test.numInstances(); i++) { metaData.add(metaInstance(test.instance(i))); } } m_MetaClassifier.buildClassifier(metaData); }
From source file:Control.Classificador.java
public ArrayList<Resultado> classificar(Plano plano, Arquivo arq) { try {// ww w. j a v a 2 s .c om FileReader leitor = new FileReader(arq.arquivo); Instances conjunto = new Instances(leitor); conjunto.setClassIndex(conjunto.numAttributes() - 1); Evaluation avaliacao = new Evaluation(conjunto); conjunto = conjunto.resample(new Random()); Instances baseTreino = null, baseTeste = null; Random rand = new Random(1); if (plano.eHoldOut) { baseTeste = conjunto.testCV(3, 0); baseTreino = conjunto.trainCV(3, 0); } else { baseTeste = baseTreino = conjunto; } if (plano.IBK) { try { IB1 vizinho = new IB1(); vizinho.buildClassifier(baseTeste); avaliacao.crossValidateModel(vizinho, baseTeste, (plano.eHoldOut) ? 4 : baseTeste.numInstances(), rand); Resultado resultado = new Resultado("NN", avaliacao.toMatrixString("Algortmo Vizinho Mais Prximo - Matriz de Confuso"), avaliacao.toClassDetailsString("kNN")); resultado.setTaxaErro(avaliacao.errorRate()); resultado.setTaxaAcerto(1 - avaliacao.errorRate()); resultado.setRevocacao(recallToDouble(avaliacao, baseTeste)); resultado.setPrecisao(precisionToDouble(avaliacao, baseTeste)); this.resultados.add(resultado); } catch (UnsupportedAttributeTypeException ex) { Mensagem.erro("Algortmo IB1 no suporta atributos numricos!", "MTCS - ERRO"); } } if (plano.J48) { try { J48 j48 = new J48(); j48.buildClassifier(baseTeste); avaliacao.crossValidateModel(j48, baseTeste, (plano.eHoldOut) ? 
4 : baseTeste.numInstances(), rand); Resultado resultado = new Resultado("J48", avaliacao.toMatrixString("Algortmo J48 - Matriz de Confuso"), avaliacao.toClassDetailsString("J48")); resultado.setTaxaErro(avaliacao.errorRate()); resultado.setTaxaAcerto(1 - avaliacao.errorRate()); resultado.setRevocacao(recallToDouble(avaliacao, baseTeste)); resultado.setPrecisao(precisionToDouble(avaliacao, baseTeste)); this.resultados.add(resultado); } catch (UnsupportedAttributeTypeException ex) { Mensagem.erro("Algortmo J48 no suporta atributos nominais!", "MTCS - ERRO"); } } if (plano.KNN) { try { IBk knn = new IBk(3); knn.buildClassifier(baseTeste); avaliacao.crossValidateModel(knn, baseTeste, (plano.eHoldOut) ? 4 : baseTeste.numInstances(), rand); Resultado resultado = new Resultado("KNN", avaliacao.toMatrixString("Algortmo KNN - Matriz de Confuso"), avaliacao.toClassDetailsString("kNN")); resultado.setTaxaErro(avaliacao.errorRate()); resultado.setTaxaAcerto(1 - avaliacao.errorRate()); resultado.setRevocacao(recallToDouble(avaliacao, baseTeste)); resultado.setPrecisao(precisionToDouble(avaliacao, baseTeste)); this.resultados.add(resultado); } catch (UnsupportedAttributeTypeException ex) { Mensagem.erro("Algortmo KNN no suporta atributos numricos!", "MTCS - ERRO"); } } if (plano.Naive) { NaiveBayes naive = new NaiveBayes(); naive.buildClassifier(baseTeste); avaliacao.crossValidateModel(naive, baseTeste, (plano.eHoldOut) ? 
4 : baseTeste.numInstances(), rand); Resultado resultado = new Resultado("Naive", avaliacao.toMatrixString("Algortmo NaiveBayes - Matriz de Confuso"), avaliacao.toClassDetailsString("kNN")); resultado.setTaxaErro(avaliacao.errorRate()); resultado.setTaxaAcerto(1 - avaliacao.errorRate()); resultado.setRevocacao(recallToDouble(avaliacao, baseTeste)); resultado.setPrecisao(precisionToDouble(avaliacao, baseTeste)); this.resultados.add(resultado); } if (plano.Tree) { try { Id3 id3 = new Id3(); id3.buildClassifier(baseTeste); avaliacao.crossValidateModel(id3, baseTeste, (plano.eHoldOut) ? 4 : baseTeste.numInstances(), rand); Resultado resultado = new Resultado("ID3", avaliacao.toMatrixString("Algortmo ID3 - Matriz de Confuso"), avaliacao.toClassDetailsString("kNN")); resultado.setTaxaErro(avaliacao.errorRate()); resultado.setTaxaAcerto(1 - avaliacao.errorRate()); resultado.setRevocacao(recallToDouble(avaliacao, baseTeste)); resultado.setPrecisao(precisionToDouble(avaliacao, baseTeste)); this.resultados.add(resultado); } catch (UnsupportedAttributeTypeException ex) { Mensagem.erro("Algortmo Arvore de Deciso no suporta atributos numricos!", "MTCS - ERRO"); } } } catch (FileNotFoundException ex) { Logger.getLogger(Classificador.class.getName()).log(Level.SEVERE, null, ex); } catch (IOException ex) { Logger.getLogger(Classificador.class.getName()).log(Level.SEVERE, null, ex); } catch (NullPointerException ex) { Mensagem.erro("Selecione um arquivo para comear!", "MTCS - ERRO"); Logger.getLogger(Classificador.class.getName()).log(Level.SEVERE, null, ex); } catch (Exception ex) { Logger.getLogger(Classificador.class.getName()).log(Level.SEVERE, null, ex); } return this.resultados; }