List of usage examples for weka.core Instances testCV
public Instances testCV(int numFolds, int numFold)
From source file:CrossValidationMultipleRuns.java
License:Open Source License
/** * Performs the cross-validation. See Javadoc of class for information * on command-line parameters.//from w w w.jav a 2s . c o m * * @param args the command-line parameters * @throws Exception if something goes wrong */ public static void main(String[] args) throws Exception { // loads data and set class index Instances data = DataSource.read(Utils.getOption("t", args)); String clsIndex = Utils.getOption("c", args); if (clsIndex.length() == 0) clsIndex = "last"; if (clsIndex.equals("first")) data.setClassIndex(0); else if (clsIndex.equals("last")) data.setClassIndex(data.numAttributes() - 1); else data.setClassIndex(Integer.parseInt(clsIndex) - 1); // classifier String[] tmpOptions; String classname; tmpOptions = Utils.splitOptions(Utils.getOption("W", args)); classname = tmpOptions[0]; tmpOptions[0] = ""; Classifier cls = (Classifier) Utils.forName(Classifier.class, classname, tmpOptions); // other options int runs = Integer.parseInt(Utils.getOption("r", args)); int folds = Integer.parseInt(Utils.getOption("x", args)); // perform cross-validation for (int i = 0; i < runs; i++) { // randomize data int seed = i + 1; Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); //if (randData.classAttribute().isNominal()) // randData.stratify(folds); Evaluation eval = new Evaluation(randData); StringBuilder optionsString = new StringBuilder(); for (String s : cls.getOptions()) { optionsString.append(s); optionsString.append(" "); } // output evaluation System.out.println(); System.out.println("=== Setup run " + (i + 1) + " ==="); System.out.println("Classifier: " + optionsString.toString()); System.out.println("Dataset: " + data.relationName()); System.out.println("Folds: " + folds); System.out.println("Seed: " + seed); System.out.println(); for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n); Instances test = randData.testCV(folds, n); // build and evaluate classifier Classifier clsCopy = 
Classifier.makeCopy(cls); clsCopy.buildClassifier(train); eval.evaluateModel(clsCopy, test); System.out.println(eval.toClassDetailsString()); } System.out.println( eval.toSummaryString("=== " + folds + "-fold Cross-validation run " + (i + 1) + " ===", false)); } }
From source file:REPTree.java
License:Open Source License
/** * Builds classifier.// w w w . j av a 2 s . c om * * @param data the data to train with * @throws Exception if building fails */ public void buildClassifier(Instances data) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(data); // remove instances with missing class data = new Instances(data); data.deleteWithMissingClass(); Random random = new Random(m_Seed); m_zeroR = null; if (data.numAttributes() == 1) { m_zeroR = new ZeroR(); m_zeroR.buildClassifier(data); return; } // Randomize and stratify data.randomize(random); if (data.classAttribute().isNominal()) { data.stratify(m_NumFolds); } // Split data into training and pruning set Instances train = null; Instances prune = null; if (!m_NoPruning) { train = data.trainCV(m_NumFolds, 0, random); prune = data.testCV(m_NumFolds, 0); } else { train = data; } // Create array of sorted indices and weights int[][][] sortedIndices = new int[1][train.numAttributes()][0]; double[][][] weights = new double[1][train.numAttributes()][0]; double[] vals = new double[train.numInstances()]; for (int j = 0; j < train.numAttributes(); j++) { if (j != train.classIndex()) { weights[0][j] = new double[train.numInstances()]; if (train.attribute(j).isNominal()) { // Handling nominal attributes. Putting indices of // instances with missing values at the end. 
sortedIndices[0][j] = new int[train.numInstances()]; int count = 0; for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); if (!inst.isMissing(j)) { sortedIndices[0][j][count] = i; weights[0][j][count] = inst.weight(); count++; } } for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); if (inst.isMissing(j)) { sortedIndices[0][j][count] = i; weights[0][j][count] = inst.weight(); count++; } } } else { // Sorted indices are computed for numeric attributes for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); vals[i] = inst.value(j); } sortedIndices[0][j] = Utils.sort(vals); for (int i = 0; i < train.numInstances(); i++) { weights[0][j][i] = train.instance(sortedIndices[0][j][i]).weight(); } } } } // Compute initial class counts double[] classProbs = new double[train.numClasses()]; double totalWeight = 0, totalSumSquared = 0; for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); if (data.classAttribute().isNominal()) { classProbs[(int) inst.classValue()] += inst.weight(); totalWeight += inst.weight(); } else { classProbs[0] += inst.classValue() * inst.weight(); totalSumSquared += inst.classValue() * inst.classValue() * inst.weight(); totalWeight += inst.weight(); } } m_Tree = new Tree(); double trainVariance = 0; if (data.classAttribute().isNumeric()) { trainVariance = m_Tree.singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight; classProbs[0] /= totalWeight; } // Build tree m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs, new Instances(train, 0), m_MinNum, m_MinVarianceProp * trainVariance, 0, m_MaxDepth); // Insert pruning data and perform reduced error pruning if (!m_NoPruning) { m_Tree.insertHoldOutSet(prune); m_Tree.reducedErrorPrune(); m_Tree.backfitHoldOutSet(); } }
From source file:CopiaSeg3.java
public static Instances[] split(Instances data, int numberOfFolds) { Instances[] split = new Instances[2]; Random semilla = new Random(); int seed = semilla.nextInt(20); // Genera una semilla aleatorio entre 0 y 20 Random rand = new Random(seed); // Create seeded number generator Instances randData = new Instances(data); // Crea una copia de los datos originales randData.randomize(rand); // Ordena los datos de forma aleatoria split[0] = randData.trainCV(numberOfFolds, 0); split[1] = randData.testCV(numberOfFolds, 0); return split; }
From source file:REPRandomTree.java
License:Open Source License
/** * Builds classifier.//from w w w. java 2 s . c o m * * @param data the data to train with * @throws Exception if building fails */ public void buildClassifier(Instances data) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(data); // remove instances with missing class data = new Instances(data); data.deleteWithMissingClass(); Random random = new Random(m_Seed); m_zeroR = null; if (data.numAttributes() == 1) { m_zeroR = new ZeroR(); m_zeroR.buildClassifier(data); return; } // Randomize and stratify data.randomize(random); if (data.classAttribute().isNominal()) { data.stratify(m_NumFolds); } // Split data into training and pruning set Instances train = null; Instances prune = null; if (!m_NoPruning) { train = data.trainCV(m_NumFolds, 0, random); prune = data.testCV(m_NumFolds, 0); } else { train = data; } // Create array of sorted indices and weights int[][][] sortedIndices = new int[1][train.numAttributes()][0]; double[][][] weights = new double[1][train.numAttributes()][0]; double[] vals = new double[train.numInstances()]; for (int j = 0; j < train.numAttributes(); j++) { if (j != train.classIndex()) { weights[0][j] = new double[train.numInstances()]; if (train.attribute(j).isNominal()) { // Handling nominal attributes. Putting indices of // instances with missing values at the end. 
sortedIndices[0][j] = new int[train.numInstances()]; int count = 0; for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); if (!inst.isMissing(j)) { sortedIndices[0][j][count] = i; weights[0][j][count] = inst.weight(); count++; } } for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); if (inst.isMissing(j)) { sortedIndices[0][j][count] = i; weights[0][j][count] = inst.weight(); count++; } } } else { // Sorted indices are computed for numeric attributes for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); vals[i] = inst.value(j); } sortedIndices[0][j] = Utils.sort(vals); for (int i = 0; i < train.numInstances(); i++) { weights[0][j][i] = train.instance(sortedIndices[0][j][i]).weight(); } } } } // Compute initial class counts double[] classProbs = new double[train.numClasses()]; double totalWeight = 0, totalSumSquared = 0; for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); if (data.classAttribute().isNominal()) { classProbs[(int) inst.classValue()] += inst.weight(); totalWeight += inst.weight(); } else { classProbs[0] += inst.classValue() * inst.weight(); totalSumSquared += inst.classValue() * inst.classValue() * inst.weight(); totalWeight += inst.weight(); } } m_Tree = new Tree(); double trainVariance = 0; if (data.classAttribute().isNumeric()) { trainVariance = m_Tree.singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight; classProbs[0] /= totalWeight; } // Build tree m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs, new Instances(train, 0), m_MinNum, m_MinVarianceProp * trainVariance, 0, m_MaxDepth, m_FeatureFrac, random); // Insert pruning data and perform reduced error pruning if (!m_NoPruning) { m_Tree.insertHoldOutSet(prune); m_Tree.reducedErrorPrune(); m_Tree.backfitHoldOutSet(); } }
From source file:algoritmogeneticocluster.NewClass.java
/**
 * Produces the training and test part of every cross-validation fold.
 * The data is used in its current order; no shuffling is done here.
 *
 * @param data the data set to split
 * @param numberOfFolds number of folds
 * @return a two-row array: row 0 holds the training set of each fold,
 *         row 1 the matching test set, both indexed by fold number
 */
public static Instances[][] crossValidationSplit(Instances data, int numberOfFolds) {
    Instances[] trainFolds = new Instances[numberOfFolds];
    Instances[] testFolds = new Instances[numberOfFolds];
    int fold = 0;
    while (fold < numberOfFolds) {
        trainFolds[fold] = data.trainCV(numberOfFolds, fold);
        testFolds[fold] = data.testCV(numberOfFolds, fold);
        fold++;
    }
    return new Instances[][] { trainFolds, testFolds };
}
From source file:ann.MyANN.java
/** * Mengevaluasi model dengan membagi instances menjadi trainSet dan testSet sebanyak numFold * @param instances data yang akan diuji * @param numFold/*from ww w.ja v a 2 s . c o m*/ * @param rand * @return confusion matrix */ public int[][] crossValidation(Instances instances, int numFold, Random rand) { int[][] totalResult = null; instances = new Instances(instances); instances.randomize(rand); if (instances.classAttribute().isNominal()) { instances.stratify(numFold); } for (int i = 0; i < numFold; i++) { try { // membagi instance berdasarkan jumlah fold Instances train = instances.trainCV(numFold, i, rand); Instances test = instances.testCV(numFold, i); MyANN cc = new MyANN(this); cc.buildClassifier(train); int[][] result = cc.evaluate(test); if (i == 0) { totalResult = cc.evaluate(test); } else { result = cc.evaluate(test); for (int j = 0; j < totalResult.length; j++) { for (int k = 0; k < totalResult[0].length; k++) { totalResult[j][k] += result[j][k]; } } } } catch (Exception ex) { Logger.getLogger(MyANN.class.getName()).log(Level.SEVERE, null, ex); } } return totalResult; }
From source file:asap.CrossValidation.java
/** * * @param dataInput/*from w w w . jav a2s .c o m*/ * @param classIndex * @param removeIndices * @param cls * @param seed * @param folds * @param modelOutputFile * @return * @throws Exception */ public static String performCrossValidation(String dataInput, String classIndex, String removeIndices, AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception { PerformanceCounters.startTimer("cross-validation ST"); PerformanceCounters.startTimer("cross-validation init ST"); // loads data and set class index Instances data = DataSource.read(dataInput); String clsIndex = classIndex; switch (clsIndex) { case "first": data.setClassIndex(0); break; case "last": data.setClassIndex(data.numAttributes() - 1); break; default: try { data.setClassIndex(Integer.parseInt(clsIndex) - 1); } catch (NumberFormatException e) { data.setClassIndex(data.attribute(clsIndex).index()); } break; } Remove removeFilter = new Remove(); removeFilter.setAttributeIndices(removeIndices); removeFilter.setInputFormat(data); data = Filter.useFilter(data, removeFilter); // randomize data Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) { randData.stratify(folds); } // perform cross-validation and add predictions Evaluation eval = new Evaluation(randData); Instances trainSets[] = new Instances[folds]; Instances testSets[] = new Instances[folds]; Classifier foldCls[] = new Classifier[folds]; for (int n = 0; n < folds; n++) { trainSets[n] = randData.trainCV(folds, n); testSets[n] = randData.testCV(folds, n); foldCls[n] = AbstractClassifier.makeCopy(cls); } PerformanceCounters.stopTimer("cross-validation init ST"); PerformanceCounters.startTimer("cross-validation folds+train ST"); //paralelize!!:-------------------------------------------------------------- for (int n = 0; n < folds; n++) { Instances train = trainSets[n]; Instances test = testSets[n]; // the above code is used by the 
StratifiedRemoveFolds filter, the // code below by the Explorer/Experimenter: // Instances train = randData.trainCV(folds, n, rand); // build and evaluate classifier Classifier clsCopy = foldCls[n]; clsCopy.buildClassifier(train); eval.evaluateModel(clsCopy, test); } cls.buildClassifier(data); //until here!----------------------------------------------------------------- PerformanceCounters.stopTimer("cross-validation folds+train ST"); PerformanceCounters.startTimer("cross-validation post ST"); // output evaluation String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: " + folds + "\n" + "Seed: " + seed + "\n" + "\n" + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n"; if (!modelOutputFile.isEmpty()) { SerializationHelper.write(modelOutputFile, cls); } PerformanceCounters.stopTimer("cross-validation post ST"); PerformanceCounters.stopTimer("cross-validation ST"); return out; }
From source file:asap.CrossValidation.java
/** * * @param dataInput// www .jav a2 s . c o m * @param classIndex * @param removeIndices * @param cls * @param seed * @param folds * @param modelOutputFile * @return * @throws Exception */ public static String performCrossValidationMT(String dataInput, String classIndex, String removeIndices, AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception { PerformanceCounters.startTimer("cross-validation MT"); PerformanceCounters.startTimer("cross-validation init MT"); // loads data and set class index Instances data = DataSource.read(dataInput); String clsIndex = classIndex; switch (clsIndex) { case "first": data.setClassIndex(0); break; case "last": data.setClassIndex(data.numAttributes() - 1); break; default: try { data.setClassIndex(Integer.parseInt(clsIndex) - 1); } catch (NumberFormatException e) { data.setClassIndex(data.attribute(clsIndex).index()); } break; } Remove removeFilter = new Remove(); removeFilter.setAttributeIndices(removeIndices); removeFilter.setInputFormat(data); data = Filter.useFilter(data, removeFilter); // randomize data Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) { randData.stratify(folds); } // perform cross-validation and add predictions Evaluation eval = new Evaluation(randData); List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>()); List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>()); for (int n = 0; n < folds; n++) { foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n), AbstractClassifier.makeCopy(cls))); if (n < Config.getNumThreads() - 1) { Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval)); foldThreads.add(foldThread); } } PerformanceCounters.stopTimer("cross-validation init MT"); PerformanceCounters.startTimer("cross-validation folds+train MT"); 
//paralelize!!:-------------------------------------------------------------- if (Config.getNumThreads() > 1) { for (Thread foldThread : foldThreads) { foldThread.start(); } } else { //use the current thread to run the cross-validation instead of using the Thread instance created here: new CrossValidationFoldThread(0, foldSets, eval).run(); } cls.buildClassifier(data); for (Thread foldThread : foldThreads) { foldThread.join(); } //until here!----------------------------------------------------------------- PerformanceCounters.stopTimer("cross-validation folds+train MT"); PerformanceCounters.startTimer("cross-validation post MT"); // evaluation for output: String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: " + folds + "\n" + "Seed: " + seed + "\n" + "\n" + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n"; if (!modelOutputFile.isEmpty()) { SerializationHelper.write(modelOutputFile, cls); } PerformanceCounters.stopTimer("cross-validation post MT"); PerformanceCounters.stopTimer("cross-validation MT"); return out; }
From source file:asap.CrossValidation.java
static String performCrossValidationMT(Instances data, AbstractClassifier cls, int seed, int folds, String modelOutputFile) { PerformanceCounters.startTimer("cross-validation MT"); PerformanceCounters.startTimer("cross-validation init MT"); // randomize data Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand);//from ww w .j av a 2s .com if (randData.classAttribute().isNominal()) { randData.stratify(folds); } // perform cross-validation and add predictions Evaluation eval; try { eval = new Evaluation(randData); } catch (Exception ex) { Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex); return "Error creating evaluation instance for given data!"; } List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>()); List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>()); for (int n = 0; n < folds; n++) { try { foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n), AbstractClassifier.makeCopy(cls))); } catch (Exception ex) { Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex); } //TODO: use Config.getNumThreads() for limiting these:: if (n < Config.getNumThreads() - 1) { Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval)); foldThreads.add(foldThread); } } PerformanceCounters.stopTimer("cross-validation init MT"); PerformanceCounters.startTimer("cross-validation folds+train MT"); //paralelize!!:-------------------------------------------------------------- if (Config.getNumThreads() > 1) { for (Thread foldThread : foldThreads) { foldThread.start(); } } else { new CrossValidationFoldThread(0, foldSets, eval).run(); } try { cls.buildClassifier(data); } catch (Exception ex) { Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex); } for (Thread foldThread : foldThreads) { try { foldThread.join(); } catch (InterruptedException 
ex) { Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex); } } //until here!----------------------------------------------------------------- PerformanceCounters.stopTimer("cross-validation folds+train MT"); PerformanceCounters.startTimer("cross-validation post MT"); // evaluation for output: String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: " + folds + "\n" + "Seed: " + seed + "\n" + "\n" + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n"; if (modelOutputFile != null) { if (!modelOutputFile.isEmpty()) { try { SerializationHelper.write(modelOutputFile, cls); } catch (Exception ex) { Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex); } } } PerformanceCounters.stopTimer("cross-validation post MT"); PerformanceCounters.stopTimer("cross-validation MT"); return out; }
From source file:asap.NLPSystem.java
private String crossValidate(int seed, int folds, String modelOutputFile) { PerformanceCounters.startTimer("cross-validation"); PerformanceCounters.startTimer("cross-validation init"); AbstractClassifier abstractClassifier = (AbstractClassifier) classifier; // randomize data Random rand = new Random(seed); Instances randData = new Instances(trainingSet); randData.randomize(rand);/* w w w .j a va 2 s . c o m*/ if (randData.classAttribute().isNominal()) { randData.stratify(folds); } // perform cross-validation and add predictions Evaluation eval; try { eval = new Evaluation(randData); } catch (Exception ex) { Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex); return "Error creating evaluation instance for given data!"; } List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>()); List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>()); for (int n = 0; n < folds; n++) { try { foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n), AbstractClassifier.makeCopy(abstractClassifier))); } catch (Exception ex) { Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex); } if (n < Config.getNumThreads() - 1) { Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval)); foldThreads.add(foldThread); } } PerformanceCounters.stopTimer("cross-validation init"); PerformanceCounters.startTimer("cross-validation folds+train"); if (Config.getNumThreads() > 1) { for (Thread foldThread : foldThreads) { foldThread.start(); } } else { new CrossValidationFoldThread(0, foldSets, eval).run(); } for (Thread foldThread : foldThreads) { while (foldThread.isAlive()) { try { foldThread.join(); } catch (InterruptedException ex) { Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex); } } } PerformanceCounters.stopTimer("cross-validation folds+train"); PerformanceCounters.startTimer("cross-validation post"); // evaluation 
for output: String out = String.format( "\n=== Setup ===\nClassifier: %s %s\n" + "Dataset: %s\nFolds: %s\nSeed: %s\n\n%s\n", abstractClassifier.getClass().getName(), Utils.joinOptions(abstractClassifier.getOptions()), trainingSet.relationName(), folds, seed, eval.toSummaryString(String.format("=== %s-fold Cross-validation ===", folds), false)); try { crossValidationPearsonsCorrelation = eval.correlationCoefficient(); } catch (Exception ex) { Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex); } if (modelOutputFile != null) { if (!modelOutputFile.isEmpty()) { try { SerializationHelper.write(modelOutputFile, abstractClassifier); } catch (Exception ex) { Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex); } } } classifierBuiltWithCrossValidation = true; PerformanceCounters.stopTimer("cross-validation post"); PerformanceCounters.stopTimer("cross-validation"); return out; }