List of usage examples for the weka.core.Instances copy constructor
public Instances(Instances dataset)
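The copy constructor returns a new Instances object that shares the header (attribute) information with the source but keeps its own list of instances, so the copy can be filtered, shuffled, or cleared without disturbing the original. A minimal self-contained sketch of that defensive-copy idiom, which most of the examples below open with (the file name iris.arff is a placeholder):

    import java.util.Random;
    import weka.core.Instances;
    import weka.core.converters.ConverterUtils.DataSource;

    public class CopySketch {
        public static void main(String[] args) throws Exception {
            Instances data = DataSource.read("iris.arff"); // placeholder path
            data.setClassIndex(data.numAttributes() - 1);

            Instances copy = new Instances(data); // defensive copy
            copy.deleteWithMissingClass();        // mutates only the copy
            copy.randomize(new Random(42));

            System.out.println("original: " + data.numInstances()
                    + ", copy: " + copy.numInstances());
        }
    }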
From source file:ann.MyANN.java
    /**
     * Performs training with the given data.
     * @param instances training data
     * @throws Exception any exception that causes training to fail
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        // check that the classifier can handle the input data
        getCapabilities().testWithFail(instances);

        // copy the data and remove all instances with a missing class
        instances = new Instances(instances);
        instances.deleteWithMissingClass();

        // filter
        NumericToBinary ntb = new NumericToBinary();
        ntb.setInputFormat(instances);
        instances = Filter.useFilter(instances, ntb);

        // convert the instances to the internal Data representation
        instancesToDatas(instances);

        // build the ANN according to nbLayers
        // create the layers
        ArrayList<ArrayList<Node>> layers = new ArrayList<>();
        for (int i = 0; i < nbLayers.length; i++) {
            layers.add(new ArrayList<>());
        }
        // initialize the input layer
        for (int i = 0; i < nbLayers[0]; i++) {
            // set id, prevLayer = null, nextLayer = layers[1]
            layers.get(0).add(new Node("node-0" + "-" + i, null, layers.get(1)));
        }
        // initialize the hidden layers
        for (int i = 1; i < nbLayers.length - 1; i++) {
            for (int j = 0; j < nbLayers[i]; j++) {
                // set id, prevLayer = layers[i-1], nextLayer = layers[i+1]
                layers.get(i).add(new Node("node-" + i + "-" + j, layers.get(i - 1), layers.get(i + 1)));
            }
        }
        // initialize the output layer
        for (int i = 0; i < nbLayers[nbLayers.length - 1]; i++) {
            // set id, prevLayer = layers[n-1], nextLayer = null
            layers.get(nbLayers.length - 1).add(
                    new Node("node-" + (nbLayers.length - 1) + "-" + i, layers.get(nbLayers.length - 2), null));
        }

        // attach a weight to every neuron
        // prepare the bias weights; there are nbLayers - 1 bias layers
        ArrayList<Double> bias = new ArrayList<>();
        for (int i = 0; i < nbLayers.length - 1; i++) {
            bias.add(1.0);
        }
        // each layer has as many bias weights as it has nodes
        double[][] biasWeight = new double[nbLayers.length - 1][];
        for (int i = 1; i < biasWeight.length; i++) {
            biasWeight[i] = new double[nbLayers[i]];
        }

        // initialize every weight with a random number
        //Random rand = new Random(System.currentTimeMillis());
        Random rand = new Random(1);

        // fill in the bias weights
        int j = 0;
        Map<Integer, Map<Node, Double>> biasesWeight = new HashMap<>();
        for (int i = 0; i < nbLayers.length - 1; i++) {
            ArrayList<Node> arrNode = layers.get(i + 1);
            Map<Node, Double> map = new HashMap<>();
            for (Node node : arrNode) {
                if (isInitialWeightSet) {
                    map.put(node, weights[1][j]);
                } else {
                    map.put(node, rand.nextDouble());
                }
                j++;
            }
            biasesWeight.put(i, map);
        }

        j = 0;
        // fill in the weights of each neuron
        Map<Node, Map<Node, Double>> mapWeight = new HashMap<>();
        for (int i = 0; i < nbLayers.length - 1; i++) {
            ArrayList<Node> arrNode = layers.get(i);
            for (Node node : arrNode) {
                Map<Node, Double> map = new HashMap<>();
                for (Node nextNode : node.getNextNodes()) {
                    if (isInitialWeightSet) {
                        map.put(nextNode, weights[0][j]);
                    } else {
                        map.put(nextNode, rand.nextDouble());
                    }
                    j++;
                }
                mapWeight.put(node, map);
            }
        }

        // create the ANN model from the values above
        annModel = new ANNModel(layers, mapWeight, bias, biasesWeight);

        // set the model's initial configuration
        annModel.setDataSet(datas);
        annModel.setLearningRate(learningRate);
        annModel.setMomentum(momentum);
        switch (activationFunction) {
        case SIGMOID_FUNCTION:
            annModel.setActivationFunction(ANNModel.SIGMOID);
            break;
        case SIGN_FUNCTION:
            // remap the targets to -1 and 1; write through the list, since
            // reassigning the loop variable would leave the stored values unchanged
            for (Data d : datas) {
                for (int k = 0; k < d.target.size(); k++) {
                    if (d.target.get(k) == 0.0) {
                        d.target.set(k, -1.0);
                    }
                }
            }
            annModel.setActivationFunction(ANNModel.SIGN);
            break;
        case STEP_FUNCTION:
            annModel.setActivationFunction(ANNModel.STEP);
            break;
        default:
            break;
        }
        if (learningRule == BATCH_GRADIENT_DESCENT || learningRule == DELTA_RULE)
            annModel.setActivationFunction(ANNModel.NO_FUNC);
        if (topology == MULTILAYER_PERCEPTRON) {
            annModel.setActivationFunction(ANNModel.SIGMOID);
        }
        annModel.setThreshold(threshold);

        // run the training algorithm
        boolean stop = false;
        iteration = 0;
        annModel.resetDeltaWeight();
        do {
            if (topology == ONE_PERCEPTRON) {
                switch (learningRule) {
                case PERCEPTRON_TRAINING_RULE:
                    annModel.perceptronTrainingRule();
                    break;
                case BATCH_GRADIENT_DESCENT:
                    annModel.batchGradienDescent();
                    break;
                case DELTA_RULE:
                    annModel.deltaRule();
                    break;
                default:
                    break;
                }
            } else if (topology == MULTILAYER_PERCEPTRON) {
                annModel.backProp();
            }
            iteration++;
            // stop once the termination condition is met
            switch (terminationCondition) {
            case TERMINATE_MAX_ITERATION:
                if (iteration >= maxIteration)
                    stop = true;
                break;
            case TERMINATE_MSE:
                if (annModel.error < deltaMSE)
                    stop = true;
                break;
            case TERMINATE_BOTH:
                if (iteration > maxIteration || annModel.error < deltaMSE)
                    stop = true;
                break;
            default:
                break;
            }
        } while (!stop);
    }
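A hedged usage sketch for the method above: assuming MyANN offers a usable default configuration through a no-argument constructor (not shown in this listing), training follows the standard Weka classifier workflow, using the same DataSource helper as in the first sketch.

    // Sketch; MyANN's constructor/options are assumptions, the Weka calls are standard.
    Instances data = DataSource.read("train.arff"); // placeholder path
    data.setClassIndex(data.numAttributes() - 1);
    MyANN ann = new MyANN();   // hypothetical default configuration
    ann.buildClassifier(data); // runs the training loop shown above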
From source file:ann.MyANN.java
    /**
     * Evaluates the model by splitting the instances into train and test sets across numFold folds.
     * @param instances the data to evaluate
     * @param numFold the number of folds
     * @param rand random number generator used to shuffle the data
     * @return confusion matrix
     */
    public int[][] crossValidation(Instances instances, int numFold, Random rand) {
        int[][] totalResult = null;
        instances = new Instances(instances);
        instances.randomize(rand);
        if (instances.classAttribute().isNominal()) {
            instances.stratify(numFold);
        }
        for (int i = 0; i < numFold; i++) {
            try {
                // split the instances according to the current fold
                Instances train = instances.trainCV(numFold, i, rand);
                Instances test = instances.testCV(numFold, i);
                MyANN cc = new MyANN(this);
                cc.buildClassifier(train);
                // evaluate once per fold and accumulate the counts
                int[][] result = cc.evaluate(test);
                if (i == 0) {
                    totalResult = result;
                } else {
                    for (int j = 0; j < totalResult.length; j++) {
                        for (int k = 0; k < totalResult[0].length; k++) {
                            totalResult[j][k] += result[j][k];
                        }
                    }
                }
            } catch (Exception ex) {
                Logger.getLogger(MyANN.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        return totalResult;
    }
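Because the returned matrix accumulates counts over all folds, overall accuracy is the trace divided by the grand total. A small follow-up sketch (assumes a square matrix indexed [actual][predicted], matching how the counts are summed above):

    int[][] cm = ann.crossValidation(data, 10, new Random(1));
    long correct = 0, total = 0;
    for (int a = 0; a < cm.length; a++) {
        for (int p = 0; p < cm[a].length; p++) {
            total += cm[a][p];
            if (a == p) {
                correct += cm[a][p]; // diagonal = correctly classified
            }
        }
    }
    System.out.println("accuracy = " + (double) correct / total);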
From source file:ann.SingleLayerPerceptron.java
    @Override
    public void buildClassifier(Instances data) throws Exception {
        // can classifier handle the data?
        getCapabilities().testWithFail(data);

        annOptions = new ANNOptions();
        annOptions = annOptions.loadConfiguration();
        output = new ArrayList<Neuron>();
        normalize = new Normalize();
        ntb = new NominalToBinary();
        output = annOptions.output;

        // remove instances with missing class
        data = new Instances(data);
        data.deleteWithMissingClass();

        // nominal-to-binary filter
        ntb.setInputFormat(data);
        data = new Instances(Filter.useFilter(data, ntb));

        // normalize filter
        normalize.setInputFormat(data);
        data = new Instances(Filter.useFilter(data, normalize));

        // do main function
        doPerceptron(data);
    }
From source file:ann.SingleLayerPerceptron.java
    public int[] classifyInstances(Instances data) throws Exception {
        int[] classValue = new int[data.numInstances()];

        // remove instances with missing class
        data = new Instances(data);
        data.deleteWithMissingClass();

        // nominal-to-binary filter
        ntb.setInputFormat(data);
        data = new Instances(Filter.useFilter(data, ntb));

        int right = 0;
        for (int i = 0; i < data.numInstances(); i++) {
            int outputSize = output.size();
            double[] result = new double[outputSize];
            for (int j = 0; j < outputSize; j++) {
                result[j] = 0.0;
                for (int k = 0; k < data.numAttributes(); k++) {
                    double input = 1; // bias input in the last weight position
                    if (k < data.numAttributes() - 1) {
                        input = data.instance(i).value(k);
                    }
                    result[j] += output.get(j).weights.get(k) * input;
                }
                result[j] = Util.activationFunction(result[j], annOptions);
            }
            if (outputSize >= 2) {
                // pick the output neuron with the highest activation
                for (int j = 0; j < outputSize; j++) {
                    if (result[j] > result[classValue[i]]) {
                        classValue[i] = j;
                    }
                }
            } else {
                classValue[i] = (int) result[0];
            }
            double target = data.instance(i).classValue();
            double predicted = classValue[i]; // renamed from 'output' to avoid shadowing the neuron list
            System.out.println("Instance-" + i + " target: " + target + " output: " + predicted);
            if (target == predicted) {
                right = right + 1;
            }
        }
        System.out.println("Percentage: " + ((double) right / (double) data.numInstances()));
        return classValue;
    }
From source file:ant.Game.java
    public Game(int dim, boolean first, int fov) throws FileNotFoundException, IOException {
        m_dim = dim;
        m_score = 0;
        m_moves = m_dim * 2;
        initGrid();
        initFood();
        initAnt(fov);
        if (first) {
            initFile();
        }

        // load the dataset from m_fileName
        BufferedReader reader = new BufferedReader(new FileReader(m_fileName));
        m_data = new Instances(reader);
        m_data.setClassIndex(m_data.numAttributes() - 1);

        m_wrapper1x1 = new WekaWrapper1x1();
        m_wrapper5x1 = new WekaWrapper5x1();
        m_wrapper10x1 = new WekaWrapper10x1();
        m_wrapper1x2 = new WekaWrapper1x2();
        m_wrapper5x2 = new WekaWrapper5x2();
        m_wrapper10x2 = new WekaWrapper10x2();
        m_wrapper50x2 = new WekaWrapper50x2();
    }
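Note that the BufferedReader above is never closed. A sketch of an equivalent load via Weka's ConverterUtils.DataSource helper, which keeps the reader handling internal (it declares a broader throws Exception, so the surrounding method would need to accommodate that):

    // Equivalent load using Weka's converter utilities.
    m_data = DataSource.read(m_fileName);
    m_data.setClassIndex(m_data.numAttributes() - 1);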
From source file:app.RunApp.java
    /**
     * Preprocess dataset
     *
     * @return Positive number if successful and negative otherwise
     */
    private int preprocess() {
        trainDatasets = new ArrayList();
        testDatasets = new ArrayList();
        Instances train, test;

        if (dataset == null) {
            JOptionPane.showMessageDialog(null, "You must load a dataset.", "alert",
                    JOptionPane.ERROR_MESSAGE);
            return -1;
        }

        MultiLabelInstances preprocessDataset = dataset.clone();

        if (!radioNoIS.isSelected()) {
            // Do Instance Selection
            if (radioRandomIS.isSelected()) {
                int nInstances = Integer.parseInt(textRandomIS.getText());
                if (nInstances < 1) {
                    JOptionPane.showMessageDialog(null,
                            "The number of instances must be a positive natural number.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                } else if (nInstances > dataset.getNumInstances()) {
                    JOptionPane.showMessageDialog(null,
                            "The number of instances to select must be less than the original.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }

                Instances dataIS;
                try {
                    Randomize randomize = new Randomize();
                    dataIS = dataset.getDataSet();
                    randomize.setInputFormat(dataIS);
                    dataIS = Filter.useFilter(dataIS, randomize);
                    randomize.batchFinished();

                    RemoveRange removeRange = new RemoveRange();
                    removeRange.setInputFormat(dataIS);
                    removeRange.setInstancesIndices((nInstances + 1) + "-last");
                    dataIS = Filter.useFilter(dataIS, removeRange);
                    removeRange.batchFinished();

                    preprocessDataset = dataset.reintegrateModifiedDataSet(dataIS);
                } catch (Exception ex) {
                    Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
                }

                if (preprocessDataset == null) {
                    JOptionPane.showMessageDialog(null, "Error when selecting instances.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }
                preprocessedDataset = preprocessDataset;
            }
        }

        if (!radioNoFS.isSelected()) {
            // FS_BR
            if (radioBRFS.isSelected()) {
                int nFeatures = Integer.parseInt(textBRFS.getText());
                if (nFeatures < 1) {
                    JOptionPane.showMessageDialog(null,
                            "The number of features must be a positive natural number.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                } else if (nFeatures > dataset.getFeatureIndices().length) {
                    JOptionPane.showMessageDialog(null,
                            "The number of features to select must be less than the original.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }

                String combination = jComboBoxBRFSComb.getSelectedItem().toString();
                String normalization = jComboBoxBRFSNorm.getSelectedItem().toString();
                String output = jComboBoxBRFSOut.getSelectedItem().toString();

                FeatureSelector fs;
                if (radioNoIS.isSelected()) {
                    fs = new FeatureSelector(dataset, nFeatures);
                } else {
                    // If IS has been done
                    fs = new FeatureSelector(preprocessDataset, nFeatures);
                }

                preprocessedDataset = fs.select(combination, normalization, output);

                if (preprocessedDataset == null) {
                    JOptionPane.showMessageDialog(null, "Error when selecting features.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }
                preprocessDataset = preprocessedDataset;
            } else if (radioRandomFS.isSelected()) {
                int nFeatures = Integer.parseInt(textRandomFS.getText());
                if (nFeatures < 1) {
                    JOptionPane.showMessageDialog(null,
                            "The number of features must be a positive natural number.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                } else if (nFeatures > dataset.getFeatureIndices().length) {
                    JOptionPane.showMessageDialog(null,
                            "The number of features to select must be less than the original.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }

                FeatureSelector fs;
                if (radioNoIS.isSelected()) {
                    fs = new FeatureSelector(dataset, nFeatures);
                } else {
                    // If IS has been done
                    fs = new FeatureSelector(preprocessDataset, nFeatures);
                }

                preprocessedDataset = fs.randomSelect();

                if (preprocessedDataset == null) {
                    JOptionPane.showMessageDialog(null, "Error when selecting features.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }
                preprocessDataset = preprocessedDataset;
            }
        }

        if (!radioNoSplit.isSelected()) {
            // Random Holdout
            if (radioRandomHoldout.isSelected()) {
                String split = textRandomHoldout.getText();
                double percentage = Double.parseDouble(split);
                if ((percentage <= 0) || (percentage >= 100)) {
                    JOptionPane.showMessageDialog(null,
                            "The percentage must be a number in the range (0, 100).", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }

                try {
                    RandomTrainTest pre = new RandomTrainTest();
                    MultiLabelInstances[] partitions = pre.split(preprocessDataset, percentage);
                    trainDataset = partitions[0];
                    testDataset = partitions[1];
                } catch (InvalidDataFormatException ex) {
                    Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
                } catch (Exception ex) {
                    Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
                }
            }
            // Random CV
            else if (radioRandomCV.isSelected()) {
                String split = textRandomCV.getText();
                if (split.equals("")) {
                    JOptionPane.showMessageDialog(null, "You must enter the number of folds.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }

                int nFolds;
                try {
                    nFolds = Integer.parseInt(split);
                } catch (Exception e) {
                    JOptionPane.showMessageDialog(null, "Introduce a correct number of folds.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }

                if (nFolds < 2) {
                    JOptionPane.showMessageDialog(null, "The number of folds must be greater or equal to 2.",
                            "alert", JOptionPane.ERROR_MESSAGE);
                    return -1;
                } else if (nFolds > preprocessDataset.getNumInstances()) {
                    JOptionPane.showMessageDialog(null,
                            "The number of folds can not be greater than the number of instances.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }

                try {
                    MultiLabelInstances temp = preprocessDataset.clone();
                    Instances dataTemp = temp.getDataSet();

                    int seed = (int) (Math.random() * 100) + 100;
                    Random rand = new Random(seed);
                    dataTemp.randomize(rand);

                    Instances[] foldsCV = new Instances[nFolds];
                    for (int i = 0; i < nFolds; i++) {
                        foldsCV[i] = new Instances(dataTemp);
                        foldsCV[i].clear();
                    }
                    for (int i = 0; i < dataTemp.numInstances(); i++) {
                        foldsCV[i % nFolds].add(dataTemp.get(i));
                    }

                    train = new Instances(dataTemp);
                    test = new Instances(dataTemp);
                    for (int i = 0; i < nFolds; i++) {
                        train.clear();
                        test.clear();
                        for (int j = 0; j < nFolds; j++) {
                            if (i != j) {
                                System.out.println("Add fold " + j + " to train");
                                train.addAll(foldsCV[j]);
                            }
                        }
                        System.out.println("Add fold " + i + " to test");
                        test.addAll(foldsCV[i]);

                        System.out.println(train.get(0).toString());
                        System.out.println(test.get(0).toString());

                        trainDatasets.add(new MultiLabelInstances(new Instances(train),
                                preprocessDataset.getLabelsMetaData()));
                        testDatasets.add(new MultiLabelInstances(new Instances(test),
                                preprocessDataset.getLabelsMetaData()));

                        System.out.println(trainDatasets.get(i).getDataSet().get(0).toString());
                        System.out.println(testDatasets.get(i).getDataSet().get(0).toString());
                        System.out.println("---");
                    }
                } catch (Exception ex) {
                    Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
                }
            }
            // Iterative stratified holdout
            else if (radioIterativeStratifiedHoldout.isSelected()) {
                String split = textIterativeStratifiedHoldout.getText();
                double percentage = Double.parseDouble(split);
                if ((percentage <= 0) || (percentage >= 100)) {
                    JOptionPane.showMessageDialog(null,
                            "The percentage must be a number in the range (0, 100).", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }

                try {
                    IterativeTrainTest pre = new IterativeTrainTest();
                    MultiLabelInstances[] partitions = pre.split(preprocessDataset, percentage);
                    trainDataset = partitions[0];
                    testDataset = partitions[1];
                } catch (Exception ex) {
                    Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
                }
            }
            // Iterative stratified CV
            else if (radioIterativeStratifiedCV.isSelected()) {
                String split = textIterativeStratifiedCV.getText();
                if (split.equals("")) {
                    JOptionPane.showMessageDialog(null, "You must enter the number of folds.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }

                int nFolds = 0;
                try {
                    nFolds = Integer.parseInt(split);
                } catch (Exception e) {
                    JOptionPane.showMessageDialog(null, "Introduce a correct number of folds.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }

                if (nFolds < 2) {
                    JOptionPane.showMessageDialog(null, "The number of folds must be greater or equal to 2.",
                            "alert", JOptionPane.ERROR_MESSAGE);
                    return -1;
                } else if (nFolds > preprocessDataset.getNumInstances()) {
                    JOptionPane.showMessageDialog(null,
                            "The number of folds can not be greater than the number of instances.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }

                IterativeStratification strat = new IterativeStratification();
                MultiLabelInstances folds[] = strat.stratify(preprocessDataset, nFolds);

                for (int i = 0; i < nFolds; i++) {
                    try {
                        int trainSize = 0, testSize = 0;
                        for (int j = 0; j < nFolds; j++) {
                            if (i != j) {
                                trainSize += folds[j].getNumInstances();
                            }
                        }
                        testSize += folds[i].getNumInstances();

                        train = new Instances(preprocessDataset.getDataSet(), trainSize);
                        test = new Instances(preprocessDataset.getDataSet(), testSize);
                        for (int j = 0; j < nFolds; j++) {
                            if (i != j) {
                                train.addAll(folds[j].getDataSet());
                            }
                        }
                        test.addAll(folds[i].getDataSet());

                        trainDatasets.add(new MultiLabelInstances(train, preprocessDataset.getLabelsMetaData()));
                        testDatasets.add(new MultiLabelInstances(test, preprocessDataset.getLabelsMetaData()));
                    } catch (InvalidDataFormatException ex) {
                        Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
                    }
                }
            }
            // LP stratified holdout
            else if (radioLPStratifiedHoldout.isSelected()) {
                String split = textLPStratifiedHoldout.getText();
                double percentage = Double.parseDouble(split);
                if ((percentage <= 0) || (percentage >= 100)) {
                    JOptionPane.showMessageDialog(null,
                            "The percentage must be a number in the range (0, 100).", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }

                try {
                    IterativeTrainTest pre = new IterativeTrainTest();
                    MultiLabelInstances[] partitions = pre.split(preprocessDataset, percentage);
                    trainDataset = partitions[0];
                    testDataset = partitions[1];
                } catch (Exception ex) {
                    Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
                }
            }
            // LP stratified CV
            else if (radioLPStratifiedCV.isSelected()) {
                String split = textLPStratifiedCV.getText();
                if (split.equals("")) {
                    JOptionPane.showMessageDialog(null, "You must enter the number of folds.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }

                int nFolds = 0;
                try {
                    nFolds = Integer.parseInt(split);
                } catch (Exception e) {
                    JOptionPane.showMessageDialog(null, "Introduce a correct number of folds.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }

                if (nFolds < 2) {
                    JOptionPane.showMessageDialog(null, "The number of folds must be greater or equal to 2.",
                            "alert", JOptionPane.ERROR_MESSAGE);
                    return -1;
                } else if (nFolds > preprocessDataset.getNumInstances()) {
                    JOptionPane.showMessageDialog(null,
                            "The number of folds can not be greater than the number of instances.", "alert",
                            JOptionPane.ERROR_MESSAGE);
                    return -1;
                }

                LabelPowersetTrainTest strat = new LabelPowersetTrainTest();
                MultiLabelInstances folds[] = strat.stratify(preprocessDataset, nFolds);

                for (int i = 0; i < nFolds; i++) {
                    try {
                        train = new Instances(preprocessDataset.getDataSet(), 0);
                        test = new Instances(preprocessDataset.getDataSet(), 0);
                        for (int j = 0; j < nFolds; j++) {
                            if (i != j) {
                                train.addAll(folds[j].getDataSet());
                            }
                        }
                        test.addAll(folds[i].getDataSet());

                        trainDatasets.add(new MultiLabelInstances(train, preprocessDataset.getLabelsMetaData()));
                        testDatasets.add(new MultiLabelInstances(test, preprocessDataset.getLabelsMetaData()));
                    } catch (InvalidDataFormatException ex) {
                        Logger.getLogger(RunApp.class.getName()).log(Level.SEVERE, null, ex);
                    }
                }
            }
        }

        jButtonSaveDatasets.setEnabled(true);
        jComboBoxSaveFormat.setEnabled(true);
        return 1;
    }
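Both cross-validation branches above lean on the same Instances idiom: build per-fold sets that share the source header, then re-fill them. Distilled into a minimal sketch (data stands for any Instances set):

    // Round-robin fold assignment with header-sharing copies.
    int nFolds = 5;
    Instances[] fold = new Instances[nFolds];
    for (int i = 0; i < nFolds; i++) {
        fold[i] = new Instances(data, 0); // empty set with data's header
    }
    for (int i = 0; i < data.numInstances(); i++) {
        fold[i % nFolds].add(data.get(i));
    }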
From source file:asap.CrossValidation.java
    /**
     * @param dataInput path of the dataset, readable by Weka's DataSource
     * @param classIndex "first", "last", a 1-based index, or an attribute name
     * @param removeIndices attribute indices to remove before training
     * @param cls the classifier to evaluate
     * @param seed randomization seed
     * @param folds number of cross-validation folds
     * @param modelOutputFile if non-empty, where to serialize the final model
     * @return the evaluation summary
     * @throws Exception
     */
    public static String performCrossValidation(String dataInput, String classIndex, String removeIndices,
            AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

        PerformanceCounters.startTimer("cross-validation ST");
        PerformanceCounters.startTimer("cross-validation init ST");

        // load data and set class index
        Instances data = DataSource.read(dataInput);
        String clsIndex = classIndex;
        switch (clsIndex) {
        case "first":
            data.setClassIndex(0);
            break;
        case "last":
            data.setClassIndex(data.numAttributes() - 1);
            break;
        default:
            try {
                data.setClassIndex(Integer.parseInt(clsIndex) - 1);
            } catch (NumberFormatException e) {
                data.setClassIndex(data.attribute(clsIndex).index());
            }
            break;
        }

        Remove removeFilter = new Remove();
        removeFilter.setAttributeIndices(removeIndices);
        removeFilter.setInputFormat(data);
        data = Filter.useFilter(data, removeFilter);

        // randomize data
        Random rand = new Random(seed);
        Instances randData = new Instances(data);
        randData.randomize(rand);
        if (randData.classAttribute().isNominal()) {
            randData.stratify(folds);
        }

        // perform cross-validation and add predictions
        Evaluation eval = new Evaluation(randData);
        Instances trainSets[] = new Instances[folds];
        Instances testSets[] = new Instances[folds];
        Classifier foldCls[] = new Classifier[folds];

        for (int n = 0; n < folds; n++) {
            trainSets[n] = randData.trainCV(folds, n);
            testSets[n] = randData.testCV(folds, n);
            foldCls[n] = AbstractClassifier.makeCopy(cls);
        }

        PerformanceCounters.stopTimer("cross-validation init ST");
        PerformanceCounters.startTimer("cross-validation folds+train ST");

        // parallelize!!:--------------------------------------------------------------
        for (int n = 0; n < folds; n++) {
            Instances train = trainSets[n];
            Instances test = testSets[n];

            // the above code is used by the StratifiedRemoveFolds filter, the
            // code below by the Explorer/Experimenter:
            // Instances train = randData.trainCV(folds, n, rand);

            // build and evaluate classifier
            Classifier clsCopy = foldCls[n];
            clsCopy.buildClassifier(train);
            eval.evaluateModel(clsCopy, test);
        }
        cls.buildClassifier(data);
        // until here!-----------------------------------------------------------------

        PerformanceCounters.stopTimer("cross-validation folds+train ST");
        PerformanceCounters.startTimer("cross-validation post ST");

        // output evaluation
        String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
                + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n"
                + "Folds: " + folds + "\n" + "Seed: " + seed + "\n" + "\n"
                + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

        if (!modelOutputFile.isEmpty()) {
            SerializationHelper.write(modelOutputFile, cls);
        }

        PerformanceCounters.stopTimer("cross-validation post ST");
        PerformanceCounters.stopTimer("cross-validation ST");
        return out;
    }
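For comparison, Weka's Evaluation can run the same stratified cross-validation in a single call; the manual loop above exists mainly so the fold sets and per-fold classifier copies stay accessible. A minimal sketch of the built-in route:

    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(cls, data, folds, new Random(seed));
    System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false));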
From source file:asap.CrossValidation.java
    /**
     * Multi-threaded variant of performCrossValidation; parameters as above.
     * @param dataInput
     * @param classIndex
     * @param removeIndices
     * @param cls
     * @param seed
     * @param folds
     * @param modelOutputFile
     * @return
     * @throws Exception
     */
    public static String performCrossValidationMT(String dataInput, String classIndex, String removeIndices,
            AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

        PerformanceCounters.startTimer("cross-validation MT");
        PerformanceCounters.startTimer("cross-validation init MT");

        // load data and set class index
        Instances data = DataSource.read(dataInput);
        String clsIndex = classIndex;
        switch (clsIndex) {
        case "first":
            data.setClassIndex(0);
            break;
        case "last":
            data.setClassIndex(data.numAttributes() - 1);
            break;
        default:
            try {
                data.setClassIndex(Integer.parseInt(clsIndex) - 1);
            } catch (NumberFormatException e) {
                data.setClassIndex(data.attribute(clsIndex).index());
            }
            break;
        }

        Remove removeFilter = new Remove();
        removeFilter.setAttributeIndices(removeIndices);
        removeFilter.setInputFormat(data);
        data = Filter.useFilter(data, removeFilter);

        // randomize data
        Random rand = new Random(seed);
        Instances randData = new Instances(data);
        randData.randomize(rand);
        if (randData.classAttribute().isNominal()) {
            randData.stratify(folds);
        }

        // perform cross-validation and add predictions
        Evaluation eval = new Evaluation(randData);
        List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>());
        List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>());

        for (int n = 0; n < folds; n++) {
            foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                    AbstractClassifier.makeCopy(cls)));
            if (n < Config.getNumThreads() - 1) {
                Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
                foldThreads.add(foldThread);
            }
        }

        PerformanceCounters.stopTimer("cross-validation init MT");
        PerformanceCounters.startTimer("cross-validation folds+train MT");

        // parallelize!!:--------------------------------------------------------------
        if (Config.getNumThreads() > 1) {
            for (Thread foldThread : foldThreads) {
                foldThread.start();
            }
        } else {
            // use the current thread to run the cross-validation instead of the Thread instances created above:
            new CrossValidationFoldThread(0, foldSets, eval).run();
        }

        cls.buildClassifier(data);

        for (Thread foldThread : foldThreads) {
            foldThread.join();
        }
        // until here!-----------------------------------------------------------------

        PerformanceCounters.stopTimer("cross-validation folds+train MT");
        PerformanceCounters.startTimer("cross-validation post MT");

        // evaluation for output:
        String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
                + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n"
                + "Folds: " + folds + "\n" + "Seed: " + seed + "\n" + "\n"
                + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

        if (!modelOutputFile.isEmpty()) {
            SerializationHelper.write(modelOutputFile, cls);
        }

        PerformanceCounters.stopTimer("cross-validation post MT");
        PerformanceCounters.stopTimer("cross-validation MT");
        return out;
    }
From source file:asap.CrossValidation.java
    static String performCrossValidationMT(Instances data, AbstractClassifier cls, int seed, int folds,
            String modelOutputFile) {

        PerformanceCounters.startTimer("cross-validation MT");
        PerformanceCounters.startTimer("cross-validation init MT");

        // randomize data
        Random rand = new Random(seed);
        Instances randData = new Instances(data);
        randData.randomize(rand);
        if (randData.classAttribute().isNominal()) {
            randData.stratify(folds);
        }

        // perform cross-validation and add predictions
        Evaluation eval;
        try {
            eval = new Evaluation(randData);
        } catch (Exception ex) {
            Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
            return "Error creating evaluation instance for given data!";
        }

        List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>());
        List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>());

        for (int n = 0; n < folds; n++) {
            try {
                foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                        AbstractClassifier.makeCopy(cls)));
            } catch (Exception ex) {
                Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
            }

            // TODO: use Config.getNumThreads() for limiting these:
            if (n < Config.getNumThreads() - 1) {
                Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
                foldThreads.add(foldThread);
            }
        }

        PerformanceCounters.stopTimer("cross-validation init MT");
        PerformanceCounters.startTimer("cross-validation folds+train MT");

        // parallelize!!:--------------------------------------------------------------
        if (Config.getNumThreads() > 1) {
            for (Thread foldThread : foldThreads) {
                foldThread.start();
            }
        } else {
            new CrossValidationFoldThread(0, foldSets, eval).run();
        }

        try {
            cls.buildClassifier(data);
        } catch (Exception ex) {
            Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        }

        for (Thread foldThread : foldThreads) {
            try {
                foldThread.join();
            } catch (InterruptedException ex) {
                Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        // until here!-----------------------------------------------------------------

        PerformanceCounters.stopTimer("cross-validation folds+train MT");
        PerformanceCounters.startTimer("cross-validation post MT");

        // evaluation for output:
        String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
                + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n"
                + "Folds: " + folds + "\n" + "Seed: " + seed + "\n" + "\n"
                + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

        if (modelOutputFile != null) {
            if (!modelOutputFile.isEmpty()) {
                try {
                    SerializationHelper.write(modelOutputFile, cls);
                } catch (Exception ex) {
                    Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
                }
            }
        }

        PerformanceCounters.stopTimer("cross-validation post MT");
        PerformanceCounters.stopTimer("cross-validation MT");
        return out;
    }
From source file:asap.NLPSystem.java
    private String crossValidate(int seed, int folds, String modelOutputFile) {
        PerformanceCounters.startTimer("cross-validation");
        PerformanceCounters.startTimer("cross-validation init");

        AbstractClassifier abstractClassifier = (AbstractClassifier) classifier;

        // randomize data
        Random rand = new Random(seed);
        Instances randData = new Instances(trainingSet);
        randData.randomize(rand);
        if (randData.classAttribute().isNominal()) {
            randData.stratify(folds);
        }

        // perform cross-validation and add predictions
        Evaluation eval;
        try {
            eval = new Evaluation(randData);
        } catch (Exception ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            return "Error creating evaluation instance for given data!";
        }

        List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>());
        List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>());

        for (int n = 0; n < folds; n++) {
            try {
                foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                        AbstractClassifier.makeCopy(abstractClassifier)));
            } catch (Exception ex) {
                Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            }

            if (n < Config.getNumThreads() - 1) {
                Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
                foldThreads.add(foldThread);
            }
        }

        PerformanceCounters.stopTimer("cross-validation init");
        PerformanceCounters.startTimer("cross-validation folds+train");

        if (Config.getNumThreads() > 1) {
            for (Thread foldThread : foldThreads) {
                foldThread.start();
            }
        } else {
            new CrossValidationFoldThread(0, foldSets, eval).run();
        }

        for (Thread foldThread : foldThreads) {
            while (foldThread.isAlive()) {
                try {
                    foldThread.join();
                } catch (InterruptedException ex) {
                    Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
                }
            }
        }

        PerformanceCounters.stopTimer("cross-validation folds+train");
        PerformanceCounters.startTimer("cross-validation post");

        // evaluation for output:
        String out = String.format(
                "\n=== Setup ===\nClassifier: %s %s\n" + "Dataset: %s\nFolds: %s\nSeed: %s\n\n%s\n",
                abstractClassifier.getClass().getName(), Utils.joinOptions(abstractClassifier.getOptions()),
                trainingSet.relationName(), folds, seed,
                eval.toSummaryString(String.format("=== %s-fold Cross-validation ===", folds), false));

        try {
            crossValidationPearsonsCorrelation = eval.correlationCoefficient();
        } catch (Exception ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        }

        if (modelOutputFile != null) {
            if (!modelOutputFile.isEmpty()) {
                try {
                    SerializationHelper.write(modelOutputFile, abstractClassifier);
                } catch (Exception ex) {
                    Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
                }
            }
        }

        classifierBuiltWithCrossValidation = true;
        PerformanceCounters.stopTimer("cross-validation post");
        PerformanceCounters.stopTimer("cross-validation");
        return out;
    }
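CrossValidationFoldThread itself is not part of this listing. A plausible minimal shape, assuming each worker drains the shared FoldSet queue and that FoldSet exposes the (train, test, classifier) triple passed to its constructor above; Evaluation is not thread-safe, so the sketch synchronizes on it:

    // Hypothetical reconstruction; only the constructor call sites above are given.
    class CrossValidationFoldThread implements Runnable {
        private final List<FoldSet> foldSets;
        private final Evaluation eval;

        CrossValidationFoldThread(int id, List<FoldSet> foldSets, Evaluation eval) {
            this.foldSets = foldSets;
            this.eval = eval;
        }

        @Override
        public void run() {
            FoldSet fs;
            while ((fs = poll()) != null) {
                try {
                    fs.classifier.buildClassifier(fs.train); // assumed field names
                    synchronized (eval) { // Evaluation is not thread-safe
                        eval.evaluateModel(fs.classifier, fs.test);
                    }
                } catch (Exception ex) {
                    Logger.getLogger(CrossValidationFoldThread.class.getName()).log(Level.SEVERE, null, ex);
                }
            }
        }

        private FoldSet poll() {
            synchronized (foldSets) {
                return foldSets.isEmpty() ? null : foldSets.remove(0);
            }
        }
    }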