List of usage examples for weka.core Instances trainCV

public Instances trainCV(int numFolds, int numFold)

Creates the training set for one fold of a cross-validation on the dataset: the returned set contains every instance except those assigned to fold numFold of a numFolds-fold split. The complementary test fold is obtained with testCV(numFolds, numFold).
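Before the collected examples, here is a minimal self-contained sketch of the usual trainCV/testCV loop they all follow. The dataset path "iris.arff" and the choice of J48 are illustrative assumptions, not taken from any example below:

import java.util.Random;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class TrainCVSketch {
    public static void main(String[] args) throws Exception {
        // Illustrative input file; substitute any ARFF/CSV dataset.
        Instances data = DataSource.read("iris.arff");
        data.setClassIndex(data.numAttributes() - 1);

        int folds = 10;
        Instances randData = new Instances(data);
        randData.randomize(new Random(1));
        if (randData.classAttribute().isNominal()) {
            randData.stratify(folds); // keep class proportions in each fold
        }

        Evaluation eval = new Evaluation(randData);
        for (int n = 0; n < folds; n++) {
            Instances train = randData.trainCV(folds, n); // all folds except n
            Instances test = randData.testCV(folds, n);   // fold n only
            Classifier cls = new J48(); // fresh classifier per fold
            cls.buildClassifier(train);
            eval.evaluateModel(cls, test);
        }
        System.out.println(eval.toSummaryString());
    }
}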
From source file:cezeri.evaluater.FactoryEvaluation.java
public static Evaluation performCrossValidateTestAlso(Classifier model, Instances datax, Instances test,
        boolean show_text, boolean show_plot) {
    // note: the 'test' parameter is unused in this version;
    // 'simulated' and 'observed' are fields of the enclosing class
    TFigureAttribute attr = new TFigureAttribute();
    Random rand = new Random(1);
    Instances randData = new Instances(datax);
    randData.randomize(rand);
    Evaluation eval = null;
    int folds = randData.numInstances(); // one fold per instance: leave-one-out
    try {
        eval = new Evaluation(randData);
        for (int n = 0; n < folds; n++) {
            Instances train = randData.trainCV(folds, n);
            // Instances train = randData.trainCV(folds, n, rand);
            Classifier clsCopy = Classifier.makeCopy(model);
            clsCopy.buildClassifier(train);
            Instances validation = randData.testCV(folds, n);
            simulated = FactoryUtils.concatenate(simulated, eval.evaluateModel(clsCopy, validation));
            observed = FactoryUtils.concatenate(observed,
                    validation.attributeToDoubleArray(validation.classIndex()));
        }
        if (show_plot) {
            double[][] d = new double[2][simulated.length];
            d[0] = observed;
            d[1] = simulated;
            CMatrix f1 = CMatrix.getInstance(d);
            attr.figureCaption = "overall performance";
            f1.transpose().plot(attr);
        }
        if (show_text) {
            // output evaluation
            System.out.println();
            System.out.println("=== Setup for Overall Cross Validation ===");
            System.out.println("Classifier: " + model.getClass().getName() + " "
                    + Utils.joinOptions(model.getOptions()));
            System.out.println("Dataset: " + randData.relationName());
            System.out.println("Folds: " + folds);
            System.out.println("Seed: " + 1);
            System.out.println();
            System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false));
        }
    } catch (Exception ex) {
        Logger.getLogger(FactoryEvaluation.class.getName()).log(Level.SEVERE, null, ex);
    }
    return eval;
}
From source file:cezeri.feature.selection.FeatureSelectionInfluence.java
public static Evaluation getEvaluation(Instances randData, Classifier model, int folds) {
    Evaluation eval = null;
    try {
        eval = new Evaluation(randData);
        for (int n = 0; n < folds; n++) {
            Instances train = randData.trainCV(folds, n);
            Instances test = randData.testCV(folds, n);
            // build and evaluate classifier
            Classifier clsCopy = Classifier.makeCopy(model);
            clsCopy.buildClassifier(train);
            eval.evaluateModel(clsCopy, test);
        }
        // output evaluation
        System.out.println();
        System.out.println("=== Setup ===");
        System.out.println("Classifier: " + model.getClass().getName() + " "
                + Utils.joinOptions(model.getOptions()));
        System.out.println("Dataset: " + randData.relationName());
        System.out.println("Folds: " + folds);
        System.out.println();
        System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false));
        System.out.println(eval.toClassDetailsString("=== Detailed Accuracy By Class ==="));
        System.out.println(eval.toMatrixString("Confusion Matrix"));
        double acc = eval.correct() / eval.numInstances() * 100;
        System.out.println("correct: " + eval.correct() + " " + acc + "%");
    } catch (Exception ex) {
        Logger.getLogger(FeatureSelectionInfluence.class.getName()).log(Level.SEVERE, null, ex);
    }
    return eval;
}
From source file:com.mycompany.id3classifier.ID3Shell.java
public static void main(String[] args) throws Exception {
    ConverterUtils.DataSource source = new ConverterUtils.DataSource("lensesData.csv");
    Instances dataSet = source.getDataSet();

    Discretize filter = new Discretize();
    filter.setInputFormat(dataSet);
    dataSet = Filter.useFilter(dataSet, filter);

    Standardize standardize = new Standardize();
    standardize.setInputFormat(dataSet);
    dataSet = Filter.useFilter(dataSet, standardize);

    dataSet.setClassIndex(dataSet.numAttributes() - 1);
    dataSet.randomize(new Random(9001)); // It's over 9000!!

    int folds = 10;

    // Perform cross-validation
    Evaluation eval = new Evaluation(dataSet);
    for (int n = 0; n < folds; n++) {
        // note: trainingSize/testSize are computed but never used;
        // trainCV/testCV derive the split sizes from the number of folds
        int trainingSize = (int) Math.round(dataSet.numInstances() * .7);
        int testSize = dataSet.numInstances() - trainingSize;
        Instances trainingData = dataSet.trainCV(folds, n);
        Instances testData = dataSet.testCV(folds, n);
        ID3Classifier classifier = new ID3Classifier();
        // Id3 classifier = new Id3();
        classifier.buildClassifier(trainingData);
        eval.evaluateModel(classifier, testData);
    }
    System.out.println(eval.toSummaryString("\nResults:\n", false));
}
From source file:Control.Classificador.java
public ArrayList<Resultado> classificar(Plano plano, Arquivo arq) {
    try {
        FileReader leitor = new FileReader(arq.arquivo);
        Instances conjunto = new Instances(leitor);
        conjunto.setClassIndex(conjunto.numAttributes() - 1);
        Evaluation avaliacao = new Evaluation(conjunto);
        conjunto = conjunto.resample(new Random());
        Instances baseTreino = null, baseTeste = null;
        Random rand = new Random(1);
        if (plano.eHoldOut) {
            // hold-out: one third for testing, two thirds for training
            baseTeste = conjunto.testCV(3, 0);
            baseTreino = conjunto.trainCV(3, 0);
        } else {
            baseTeste = baseTreino = conjunto;
        }
        // note: all classifiers below are built and cross-validated on
        // baseTeste; baseTreino is never used after the split
        if (plano.IBK) {
            try {
                IB1 vizinho = new IB1();
                vizinho.buildClassifier(baseTeste);
                avaliacao.crossValidateModel(vizinho, baseTeste,
                        (plano.eHoldOut) ? 4 : baseTeste.numInstances(), rand);
                Resultado resultado = new Resultado("NN",
                        avaliacao.toMatrixString("Algoritmo Vizinho Mais Próximo - Matriz de Confusão"),
                        avaliacao.toClassDetailsString("kNN"));
                resultado.setTaxaErro(avaliacao.errorRate());
                resultado.setTaxaAcerto(1 - avaliacao.errorRate());
                resultado.setRevocacao(recallToDouble(avaliacao, baseTeste));
                resultado.setPrecisao(precisionToDouble(avaliacao, baseTeste));
                this.resultados.add(resultado);
            } catch (UnsupportedAttributeTypeException ex) {
                Mensagem.erro("Algoritmo IB1 não suporta atributos numéricos!", "MTCS - ERRO");
            }
        }
        if (plano.J48) {
            try {
                J48 j48 = new J48();
                j48.buildClassifier(baseTeste);
                avaliacao.crossValidateModel(j48, baseTeste,
                        (plano.eHoldOut) ? 4 : baseTeste.numInstances(), rand);
                Resultado resultado = new Resultado("J48",
                        avaliacao.toMatrixString("Algoritmo J48 - Matriz de Confusão"),
                        avaliacao.toClassDetailsString("J48"));
                resultado.setTaxaErro(avaliacao.errorRate());
                resultado.setTaxaAcerto(1 - avaliacao.errorRate());
                resultado.setRevocacao(recallToDouble(avaliacao, baseTeste));
                resultado.setPrecisao(precisionToDouble(avaliacao, baseTeste));
                this.resultados.add(resultado);
            } catch (UnsupportedAttributeTypeException ex) {
                Mensagem.erro("Algoritmo J48 não suporta atributos nominais!", "MTCS - ERRO");
            }
        }
        if (plano.KNN) {
            try {
                IBk knn = new IBk(3);
                knn.buildClassifier(baseTeste);
                avaliacao.crossValidateModel(knn, baseTeste,
                        (plano.eHoldOut) ? 4 : baseTeste.numInstances(), rand);
                Resultado resultado = new Resultado("KNN",
                        avaliacao.toMatrixString("Algoritmo KNN - Matriz de Confusão"),
                        avaliacao.toClassDetailsString("kNN"));
                resultado.setTaxaErro(avaliacao.errorRate());
                resultado.setTaxaAcerto(1 - avaliacao.errorRate());
                resultado.setRevocacao(recallToDouble(avaliacao, baseTeste));
                resultado.setPrecisao(precisionToDouble(avaliacao, baseTeste));
                this.resultados.add(resultado);
            } catch (UnsupportedAttributeTypeException ex) {
                Mensagem.erro("Algoritmo KNN não suporta atributos numéricos!", "MTCS - ERRO");
            }
        }
        if (plano.Naive) {
            NaiveBayes naive = new NaiveBayes();
            naive.buildClassifier(baseTeste);
            avaliacao.crossValidateModel(naive, baseTeste,
                    (plano.eHoldOut) ? 4 : baseTeste.numInstances(), rand);
            Resultado resultado = new Resultado("Naive",
                    avaliacao.toMatrixString("Algoritmo NaiveBayes - Matriz de Confusão"),
                    avaliacao.toClassDetailsString("kNN"));
            resultado.setTaxaErro(avaliacao.errorRate());
            resultado.setTaxaAcerto(1 - avaliacao.errorRate());
            resultado.setRevocacao(recallToDouble(avaliacao, baseTeste));
            resultado.setPrecisao(precisionToDouble(avaliacao, baseTeste));
            this.resultados.add(resultado);
        }
        if (plano.Tree) {
            try {
                Id3 id3 = new Id3();
                id3.buildClassifier(baseTeste);
                avaliacao.crossValidateModel(id3, baseTeste,
                        (plano.eHoldOut) ? 4 : baseTeste.numInstances(), rand);
                Resultado resultado = new Resultado("ID3",
                        avaliacao.toMatrixString("Algoritmo ID3 - Matriz de Confusão"),
                        avaliacao.toClassDetailsString("kNN"));
                resultado.setTaxaErro(avaliacao.errorRate());
                resultado.setTaxaAcerto(1 - avaliacao.errorRate());
                resultado.setRevocacao(recallToDouble(avaliacao, baseTeste));
                resultado.setPrecisao(precisionToDouble(avaliacao, baseTeste));
                this.resultados.add(resultado);
            } catch (UnsupportedAttributeTypeException ex) {
                Mensagem.erro("Algoritmo Árvore de Decisão não suporta atributos numéricos!", "MTCS - ERRO");
            }
        }
    } catch (FileNotFoundException ex) {
        Logger.getLogger(Classificador.class.getName()).log(Level.SEVERE, null, ex);
    } catch (IOException ex) {
        Logger.getLogger(Classificador.class.getName()).log(Level.SEVERE, null, ex);
    } catch (NullPointerException ex) {
        Mensagem.erro("Selecione um arquivo para começar!", "MTCS - ERRO");
        Logger.getLogger(Classificador.class.getName()).log(Level.SEVERE, null, ex);
    } catch (Exception ex) {
        Logger.getLogger(Classificador.class.getName()).log(Level.SEVERE, null, ex);
    }
    return this.resultados;
}
From source file:cs.man.ac.uk.classifiers.GetAUC.java
License:Open Source License
/**
 * Computes the AUC for the supplied stream learner.
 *
 * @return the AUC as a double value.
 */
private static double validate5x2CVStream() {
    try {
        // 5x2 cross-validation: 5 runs of 2-fold CV
        int runs = 5;
        int folds = 2;
        double AUC_SUM = 0;

        for (int i = 0; i < runs; i++) {
            // randomize data with a fresh seed per run
            int seed = i + 1;
            Random rand = new Random(seed);
            Instances randData = new Instances(data);
            randData.randomize(rand);
            if (randData.classAttribute().isNominal()) {
                System.out.println("Stratifying...");
                randData.stratify(folds);
            }
            for (int n = 0; n < folds; n++) {
                Instances train = randData.trainCV(folds, n);
                Instances test = randData.testCV(folds, n);
                Distribution testDistribution = new Distribution(test);

                // write the folds out as ARFF files for the stream learner
                ArffSaver trainSaver = new ArffSaver();
                trainSaver.setInstances(train);
                trainSaver.setFile(new File(trainPath));
                trainSaver.writeBatch();

                ArffSaver testSaver = new ArffSaver();
                testSaver.setInstances(test);
                double[][] dist = testDistribution.matrix();
                int negativeClassSize = (int) dist[0][0];
                int positiveClassSize = (int) dist[0][1];
                double balance = (double) positiveClassSize / (double) negativeClassSize;
                // [Test-n-Set-n]_[+]_[-]_[K]_[L]
                String tempTestPath = testPath.replace(".arff",
                        "_" + positiveClassSize + "_" + negativeClassSize + "_" + balance + "_1.0.arff");
                testSaver.setFile(new File(tempTestPath));
                testSaver.writeBatch();

                ARFFFile file = new ARFFFile(tempTestPath, CLASS_INDEX, new DebugLogger(false));
                file.createMetaData();
                HoeffdingTreeTester streamClassifier = new HoeffdingTreeTester(trainPath, tempTestPath,
                        CLASS_INDEX, new String[] { "0", "1" }, new DebugLogger(true));
                streamClassifier.train();
                System.in.read();

                // note: AUC accumulation is commented out in this snapshot,
                // so the method currently returns 0
                //AUC_SUM += streamClassifier.getROCExternalData("",
                //        (int) testDistribution.perClass(1), (int) testDistribution.perClass(0));
                streamClassifier.testStatic(homeDirectory + "/FuckSakeTest.txt");

                String[] files = Common.getFilePaths(scratch);
                for (int j = 0; j < files.length; j++)
                    Common.fileDelete(files[j]);
            }
        }
        return AUC_SUM / ((double) runs * (double) folds);
    } catch (Exception e) {
        System.out.println("Exception validating data!");
        e.printStackTrace();
        return 0;
    }
}
From source file:cs.man.ac.uk.classifiers.GetAUC.java
License:Open Source License
/**
 * Computes the AUC for the supplied learner.
 *
 * @return the AUC as a double value.
 */
@SuppressWarnings("unused")
private static double validate5x2CV() {
    try {
        // 5x2 cross-validation: 5 runs of 2-fold CV
        int runs = 5;
        int folds = 2;
        double AUC_SUM = 0;

        for (int i = 0; i < runs; i++) {
            // randomize data with a fresh seed per run
            int seed = i + 1;
            Random rand = new Random(seed);
            Instances randData = new Instances(data);
            randData.randomize(rand);
            if (randData.classAttribute().isNominal()) {
                System.out.println("Stratifying...");
                randData.stratify(folds);
            }
            Evaluation eval = new Evaluation(randData);
            for (int n = 0; n < folds; n++) {
                Instances train = randData.trainCV(folds, n);
                Instances test = randData.testCV(folds, n);
                // the above code is used by the StratifiedRemoveFolds filter,
                // the code below by the Explorer/Experimenter:
                // Instances train = randData.trainCV(folds, n, rand);

                // build and evaluate classifier
                String[] options = { "-U", "-A" };
                J48 classifier = new J48();
                //HTree classifier = new HTree();
                classifier.setOptions(options);
                classifier.buildClassifier(train);
                eval.evaluateModel(classifier, test);

                // generate ROC curve
                ThresholdCurve tc = new ThresholdCurve();
                int classIndex = 0;
                Instances result = tc.getCurve(eval.predictions(), classIndex);

                // plot curve
                vmc = new ThresholdVisualizePanel();
                AUC_SUM += ThresholdCurve.getROCArea(result);
                System.out.println("AUC: " + ThresholdCurve.getROCArea(result) + " \tAUC SUM: " + AUC_SUM);
            }
        }
        return AUC_SUM / ((double) runs * (double) folds);
    } catch (Exception e) {
        System.out.println("Exception validating data!");
        return 0;
    }
}
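The commented-out line in the loop above points at a real API distinction: trainCV(int, int) keeps the training fold in dataset order, while the three-argument overload trainCV(int, int, Random) additionally shuffles the assembled training fold (the behaviour used by the Explorer/Experimenter); testCV has no such overload. A minimal sketch, reusing the variables from the example above:

// Sketch, assuming the standard Weka 3 API: the Random argument shuffles the
// training fold after it is assembled; fold membership is unchanged.
Random rand = new Random(seed);
Instances train = randData.trainCV(folds, n, rand); // shuffled training fold
Instances test  = randData.testCV(folds, n);        // only testCV(int, int) exists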
From source file:edu.utexas.cs.tactex.utils.RegressionUtils.java
License:Open Source License
public static Double leaveOneOutErrorLinRegLambda(double lambda, Instances data) {
    // MANUAL leave-one-out cross-validation:
    // one fold per instance, i.e. numFolds == data.numInstances().
    // create a linear regression classifier with Xy_polynorm data
    LinearRegression linreg = createLinearRegression();
    linreg.setRidge(lambda);
    double mse = 0;
    for (int i = 0; i < data.numInstances(); ++i) {
        log.info("fold " + i);
        Instances train = data.trainCV(data.numInstances(), i);
        Instances test = data.testCV(data.numInstances(), i);
        double actualY = data.instance(i).classValue();
        try {
            linreg.buildClassifier(train);
        } catch (Exception e) {
            log.error("failed to build classifier in cross validation", e);
            return null;
        }
        double predictedY = 0;
        try {
            predictedY = linreg.classifyInstance(test.instance(0));
        } catch (Exception e) {
            log.error("failed to classify in cross validation", e);
            return null;
        }
        double error = predictedY - actualY;
        log.info("error " + error);
        mse += error * error;
    }
    if (data.numInstances() == 0) {
        log.error("no instances in leave-one-out data");
        return null;
    }
    mse /= data.numInstances();
    log.info("mse " + mse);
    return mse;

    // Alternative USING WEKA's built-in cross-validation:
    // Evaluation eval = new Evaluation(data);
    // LinearRegression linreg = createLinearRegression();
    // linreg.setRidge(lambda);
    // eval.crossValidateModel(linreg, data, data.numInstances(), new Random(0));
    // return eval.errorRate();
}
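The example above works because of how trainCV/testCV behave when numFolds equals data.numInstances(): each test fold contains exactly one instance and, with no prior randomization, fold i is instance i, so the pair implements leave-one-out directly. A sketch of that special case:

// Sketch: leave-one-out via trainCV/testCV (numFolds == numInstances).
int n = data.numInstances();
for (int i = 0; i < n; i++) {
    Instances train = data.trainCV(n, i); // n - 1 instances
    Instances test = data.testCV(n, i);   // exactly one instance: data.instance(i)
}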
From source file:es.upm.dit.gsi.barmas.dataset.utils.DatasetSplitter.java
License:Open Source License
/**
 * Splits the original dataset into test, centralised-training and per-agent
 * training files for each cross-validation fold.
 *
 * @param folds
 * @param minAgents
 * @param maxAgents
 * @param originalDatasetPath
 * @param outputDir
 * @param scenario
 * @param logger
 */
public void splitDataset(int folds, int minAgents, int maxAgents, String originalDatasetPath,
        String outputDir, String scenario, Logger logger) {
    int ratioint = (int) ((1 / (double) folds) * 100);
    double roundedratio = ((double) ratioint) / 100;

    // Look for essentials
    List<String[]> essentials = this.getEssentials(originalDatasetPath, logger);

    for (int fold = 0; fold < folds; fold++) {
        String outputDirWithRatio = outputDir + "/" + roundedratio + "testRatio/iteration-" + fold;
        File dir = new File(outputDirWithRatio);
        if (!dir.exists() || !dir.isDirectory()) {
            dir.mkdirs();
        }
        logger.finer("--> splitDataset()");
        logger.fine("Creating experiment.info...");
        try {
            Instances originalData = this.getDataFromCSV(originalDatasetPath);
            originalData.randomize(new Random());
            originalData.stratify(folds);

            // Test dataset
            Instances testData = originalData.testCV(folds, fold);
            CSVSaver saver = new CSVSaver();
            ArffSaver arffsaver = new ArffSaver();
            File file = new File(outputDirWithRatio + File.separator + "test-dataset.csv");
            if (!file.exists()) {
                saver.resetOptions();
                saver.setInstances(testData);
                saver.setFile(file);
                saver.writeBatch();
            }
            file = new File(outputDirWithRatio + File.separator + "test-dataset.arff");
            if (!file.exists()) {
                arffsaver.resetOptions();
                arffsaver.setInstances(testData);
                arffsaver.setFile(file);
                arffsaver.writeBatch();
            }

            // Bayes central dataset
            Instances trainData = originalData.trainCV(folds, fold);
            file = new File(outputDirWithRatio + File.separator + "bayes-central-dataset.csv");
            if (!file.exists()) {
                saver.resetOptions();
                saver.setInstances(trainData);
                saver.setFile(file);
                saver.writeBatch();
                this.copyFileUsingApacheCommonsIO(file, new File(
                        outputDirWithRatio + File.separator + "bayes-central-dataset-noEssentials.csv"),
                        logger);
                CsvWriter w = new CsvWriter(new FileWriter(file, true), ',');
                for (String[] essential : essentials) {
                    w.writeRecord(essential);
                }
                w.close();
            }
            file = new File(outputDirWithRatio + File.separator + "bayes-central-dataset.arff");
            if (!file.exists()) {
                arffsaver.resetOptions();
                arffsaver.setInstances(trainData);
                arffsaver.setFile(file);
                arffsaver.writeBatch();
                this.copyFileUsingApacheCommonsIO(file, new File(
                        outputDirWithRatio + File.separator + "bayes-central-dataset-noEssentials.arff"),
                        logger);
                CsvWriter w = new CsvWriter(new FileWriter(file, true), ',');
                for (String[] essential : essentials) {
                    w.writeRecord(essential);
                }
                w.close();
            }

            // Agent datasets
            CsvReader csvreader = new CsvReader(new FileReader(new File(originalDatasetPath)));
            csvreader.readHeaders();
            String[] headers = csvreader.getHeaders();
            csvreader.close();

            for (int agents = minAgents; agents <= maxAgents; agents++) {
                this.createExperimentInfoFile(folds, agents, originalDatasetPath, outputDirWithRatio,
                        scenario, logger);
                HashMap<String, CsvWriter> writers = new HashMap<String, CsvWriter>();
                String agentsDatasetsDir = outputDirWithRatio + File.separator + agents + "agents";
                HashMap<String, CsvWriter> arffWriters = new HashMap<String, CsvWriter>();
                File f = new File(agentsDatasetsDir);
                if (!f.isDirectory()) {
                    f.mkdirs();
                }
                Instances copy = new Instances(trainData);
                copy.delete();
                for (int i = 0; i < agents; i++) {
                    String fileName = agentsDatasetsDir + File.separator + "agent-" + i + "-dataset.csv";
                    file = new File(fileName);
                    if (!file.exists()) {
                        CsvWriter writer = new CsvWriter(new FileWriter(fileName), ',');
                        writer.writeRecord(headers);
                        writers.put("AGENT" + i, writer);
                    }
                    fileName = agentsDatasetsDir + File.separator + "agent-" + i + "-dataset.arff";
                    file = new File(fileName);
                    if (!file.exists()) {
                        arffsaver.resetOptions();
                        arffsaver.setInstances(copy);
                        arffsaver.setFile(new File(fileName));
                        arffsaver.writeBatch();
                        CsvWriter arffwriter = new CsvWriter(new FileWriter(fileName, true), ',');
                        arffWriters.put("AGENT" + i, arffwriter);
                    }
                    logger.fine("AGENT" + i + " dataset created in csv and arff formats.");
                }

                // Append essentials to all
                for (String[] essential : essentials) {
                    for (CsvWriter wr : writers.values()) {
                        wr.writeRecord(essential);
                    }
                    for (CsvWriter arffwr : arffWriters.values()) {
                        arffwr.writeRecord(essential);
                    }
                }

                // Deal the training instances round-robin among the agents
                int agentCounter = 0;
                for (int j = 0; j < trainData.numInstances(); j++) {
                    Instance instance = trainData.instance(j);
                    CsvWriter writer = writers.get("AGENT" + agentCounter);
                    CsvWriter arffwriter = arffWriters.get("AGENT" + agentCounter);
                    String[] row = new String[instance.numAttributes()];
                    for (int a = 0; a < instance.numAttributes(); a++) {
                        row[a] = instance.stringValue(a);
                    }
                    if (writer != null) {
                        writer.writeRecord(row);
                    }
                    if (arffwriter != null) {
                        arffwriter.writeRecord(row);
                    }
                    agentCounter++;
                    if (agentCounter == agents) {
                        agentCounter = 0;
                    }
                }
                for (CsvWriter wr : writers.values()) {
                    wr.close();
                }
                for (CsvWriter arffwr : arffWriters.values()) {
                    arffwr.close();
                }
            }
        } catch (Exception e) {
            logger.severe("Exception while splitting dataset. ->");
            logger.severe(e.getMessage());
            System.exit(1);
        }
        logger.finest("Dataset for fold " + fold + " created.");
    }
    logger.finer("<-- splitDataset()");
}
From source file:fr.unice.i3s.rockflows.experiments.main.IntermediateExecutor.java
private boolean checkMinInstances(Instances data, int min) {
    // Verify that every training fold of both a 4-fold and a 10-fold
    // cross-validation contains at least 'min' instances.
    for (int iii = 0; iii < 4; iii++) {
        Instances train4 = data.trainCV(4, iii);
        if (train4.numInstances() < min) {
            return false;
        }
    }
    for (int iii = 0; iii < 10; iii++) {
        Instances train10 = data.trainCV(10, iii);
        if (train10.numInstances() < min) {
            return false;
        }
    }
    return true;
}
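A related invariant, which the check above implicitly relies on: trainCV(k, i) and testCV(k, i) partition the dataset, so a training fold always holds numInstances minus the size of test fold i. A quick sanity-check sketch (the variable data stands for any dataset with at least k instances):

// Sketch: trainCV(k, i) and testCV(k, i) partition the dataset,
// so their sizes always sum to data.numInstances().
int k = 10;
for (int i = 0; i < k; i++) {
    Instances train = data.trainCV(k, i);
    Instances test = data.testCV(k, i);
    if (train.numInstances() + test.numInstances() != data.numInstances()) {
        throw new IllegalStateException("folds do not partition the data");
    }
}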
From source file:gr.auth.ee.lcs.ArffTrainTestLoader.java
License:Open Source License
/**
 * Load instances into the global train store and create test set.
 *
 * @param filename
 *            the .arff filename to be used
 * @param generateTestSet
 *            true if a test set is going to be generated
 * @throws IOException
 *             if the input file is not found
 */
public final void loadInstances(final String filename, final boolean generateTestSet) throws IOException {
    // Open .arff
    final Instances set = InstancesUtility.openInstance(filename);
    if (set.classIndex() < 0) {
        set.setClassIndex(set.numAttributes() - 1);
    }
    set.randomize(new Random());

    if (generateTestSet) {
        final int numOfFolds = (int) SettingsLoader.getNumericSetting("NumberOfFolds", 10);
        final int fold = (int) Math.floor(Math.random() * numOfFolds);
        trainSet = set.trainCV(numOfFolds, fold);
        testSet = set.testCV(numOfFolds, fold);
    } else {
        trainSet = set;
    }

    myLcs.instances = InstancesUtility.convertIntancesToDouble(trainSet);
    myLcs.labelCardinality = InstancesUtility.getLabelCardinality(trainSet);
}