List of usage examples for weka.classifiers.AbstractClassifier.makeCopy
public static Classifier makeCopy(Classifier model) throws Exception
From source file:asap.CrossValidation.java
/** * * @param dataInput/* w ww . j a v a2 s . c o m*/ * @param classIndex * @param removeIndices * @param cls * @param seed * @param folds * @param modelOutputFile * @return * @throws Exception */ public static String performCrossValidation(String dataInput, String classIndex, String removeIndices, AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception { PerformanceCounters.startTimer("cross-validation ST"); PerformanceCounters.startTimer("cross-validation init ST"); // loads data and set class index Instances data = DataSource.read(dataInput); String clsIndex = classIndex; switch (clsIndex) { case "first": data.setClassIndex(0); break; case "last": data.setClassIndex(data.numAttributes() - 1); break; default: try { data.setClassIndex(Integer.parseInt(clsIndex) - 1); } catch (NumberFormatException e) { data.setClassIndex(data.attribute(clsIndex).index()); } break; } Remove removeFilter = new Remove(); removeFilter.setAttributeIndices(removeIndices); removeFilter.setInputFormat(data); data = Filter.useFilter(data, removeFilter); // randomize data Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) { randData.stratify(folds); } // perform cross-validation and add predictions Evaluation eval = new Evaluation(randData); Instances trainSets[] = new Instances[folds]; Instances testSets[] = new Instances[folds]; Classifier foldCls[] = new Classifier[folds]; for (int n = 0; n < folds; n++) { trainSets[n] = randData.trainCV(folds, n); testSets[n] = randData.testCV(folds, n); foldCls[n] = AbstractClassifier.makeCopy(cls); } PerformanceCounters.stopTimer("cross-validation init ST"); PerformanceCounters.startTimer("cross-validation folds+train ST"); //paralelize!!:-------------------------------------------------------------- for (int n = 0; n < folds; n++) { Instances train = trainSets[n]; Instances test = testSets[n]; // the above code is used by the 
StratifiedRemoveFolds filter, the // code below by the Explorer/Experimenter: // Instances train = randData.trainCV(folds, n, rand); // build and evaluate classifier Classifier clsCopy = foldCls[n]; clsCopy.buildClassifier(train); eval.evaluateModel(clsCopy, test); } cls.buildClassifier(data); //until here!----------------------------------------------------------------- PerformanceCounters.stopTimer("cross-validation folds+train ST"); PerformanceCounters.startTimer("cross-validation post ST"); // output evaluation String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: " + folds + "\n" + "Seed: " + seed + "\n" + "\n" + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n"; if (!modelOutputFile.isEmpty()) { SerializationHelper.write(modelOutputFile, cls); } PerformanceCounters.stopTimer("cross-validation post ST"); PerformanceCounters.stopTimer("cross-validation ST"); return out; }
From source file:asap.CrossValidation.java
/** * * @param dataInput/*from ww w.ja v a2 s . co m*/ * @param classIndex * @param removeIndices * @param cls * @param seed * @param folds * @param modelOutputFile * @return * @throws Exception */ public static String performCrossValidationMT(String dataInput, String classIndex, String removeIndices, AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception { PerformanceCounters.startTimer("cross-validation MT"); PerformanceCounters.startTimer("cross-validation init MT"); // loads data and set class index Instances data = DataSource.read(dataInput); String clsIndex = classIndex; switch (clsIndex) { case "first": data.setClassIndex(0); break; case "last": data.setClassIndex(data.numAttributes() - 1); break; default: try { data.setClassIndex(Integer.parseInt(clsIndex) - 1); } catch (NumberFormatException e) { data.setClassIndex(data.attribute(clsIndex).index()); } break; } Remove removeFilter = new Remove(); removeFilter.setAttributeIndices(removeIndices); removeFilter.setInputFormat(data); data = Filter.useFilter(data, removeFilter); // randomize data Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) { randData.stratify(folds); } // perform cross-validation and add predictions Evaluation eval = new Evaluation(randData); List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>()); List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>()); for (int n = 0; n < folds; n++) { foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n), AbstractClassifier.makeCopy(cls))); if (n < Config.getNumThreads() - 1) { Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval)); foldThreads.add(foldThread); } } PerformanceCounters.stopTimer("cross-validation init MT"); PerformanceCounters.startTimer("cross-validation folds+train MT"); 
//paralelize!!:-------------------------------------------------------------- if (Config.getNumThreads() > 1) { for (Thread foldThread : foldThreads) { foldThread.start(); } } else { //use the current thread to run the cross-validation instead of using the Thread instance created here: new CrossValidationFoldThread(0, foldSets, eval).run(); } cls.buildClassifier(data); for (Thread foldThread : foldThreads) { foldThread.join(); } //until here!----------------------------------------------------------------- PerformanceCounters.stopTimer("cross-validation folds+train MT"); PerformanceCounters.startTimer("cross-validation post MT"); // evaluation for output: String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: " + folds + "\n" + "Seed: " + seed + "\n" + "\n" + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n"; if (!modelOutputFile.isEmpty()) { SerializationHelper.write(modelOutputFile, cls); } PerformanceCounters.stopTimer("cross-validation post MT"); PerformanceCounters.stopTimer("cross-validation MT"); return out; }
From source file:asap.CrossValidation.java
static String performCrossValidationMT(Instances data, AbstractClassifier cls, int seed, int folds, String modelOutputFile) { PerformanceCounters.startTimer("cross-validation MT"); PerformanceCounters.startTimer("cross-validation init MT"); // randomize data Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand);/* w ww . j a v a2s.c o m*/ if (randData.classAttribute().isNominal()) { randData.stratify(folds); } // perform cross-validation and add predictions Evaluation eval; try { eval = new Evaluation(randData); } catch (Exception ex) { Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex); return "Error creating evaluation instance for given data!"; } List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>()); List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>()); for (int n = 0; n < folds; n++) { try { foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n), AbstractClassifier.makeCopy(cls))); } catch (Exception ex) { Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex); } //TODO: use Config.getNumThreads() for limiting these:: if (n < Config.getNumThreads() - 1) { Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval)); foldThreads.add(foldThread); } } PerformanceCounters.stopTimer("cross-validation init MT"); PerformanceCounters.startTimer("cross-validation folds+train MT"); //paralelize!!:-------------------------------------------------------------- if (Config.getNumThreads() > 1) { for (Thread foldThread : foldThreads) { foldThread.start(); } } else { new CrossValidationFoldThread(0, foldSets, eval).run(); } try { cls.buildClassifier(data); } catch (Exception ex) { Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex); } for (Thread foldThread : foldThreads) { try { foldThread.join(); } catch (InterruptedException 
ex) { Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex); } } //until here!----------------------------------------------------------------- PerformanceCounters.stopTimer("cross-validation folds+train MT"); PerformanceCounters.startTimer("cross-validation post MT"); // evaluation for output: String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: " + folds + "\n" + "Seed: " + seed + "\n" + "\n" + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n"; if (modelOutputFile != null) { if (!modelOutputFile.isEmpty()) { try { SerializationHelper.write(modelOutputFile, cls); } catch (Exception ex) { Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex); } } } PerformanceCounters.stopTimer("cross-validation post MT"); PerformanceCounters.stopTimer("cross-validation MT"); return out; }
From source file:asap.NLPSystem.java
private String crossValidate(int seed, int folds, String modelOutputFile) { PerformanceCounters.startTimer("cross-validation"); PerformanceCounters.startTimer("cross-validation init"); AbstractClassifier abstractClassifier = (AbstractClassifier) classifier; // randomize data Random rand = new Random(seed); Instances randData = new Instances(trainingSet); randData.randomize(rand);/*from ww w. j av a2s. co m*/ if (randData.classAttribute().isNominal()) { randData.stratify(folds); } // perform cross-validation and add predictions Evaluation eval; try { eval = new Evaluation(randData); } catch (Exception ex) { Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex); return "Error creating evaluation instance for given data!"; } List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>()); List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>()); for (int n = 0; n < folds; n++) { try { foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n), AbstractClassifier.makeCopy(abstractClassifier))); } catch (Exception ex) { Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex); } if (n < Config.getNumThreads() - 1) { Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval)); foldThreads.add(foldThread); } } PerformanceCounters.stopTimer("cross-validation init"); PerformanceCounters.startTimer("cross-validation folds+train"); if (Config.getNumThreads() > 1) { for (Thread foldThread : foldThreads) { foldThread.start(); } } else { new CrossValidationFoldThread(0, foldSets, eval).run(); } for (Thread foldThread : foldThreads) { while (foldThread.isAlive()) { try { foldThread.join(); } catch (InterruptedException ex) { Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex); } } } PerformanceCounters.stopTimer("cross-validation folds+train"); PerformanceCounters.startTimer("cross-validation post"); // evaluation 
for output: String out = String.format( "\n=== Setup ===\nClassifier: %s %s\n" + "Dataset: %s\nFolds: %s\nSeed: %s\n\n%s\n", abstractClassifier.getClass().getName(), Utils.joinOptions(abstractClassifier.getOptions()), trainingSet.relationName(), folds, seed, eval.toSummaryString(String.format("=== %s-fold Cross-validation ===", folds), false)); try { crossValidationPearsonsCorrelation = eval.correlationCoefficient(); } catch (Exception ex) { Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex); } if (modelOutputFile != null) { if (!modelOutputFile.isEmpty()) { try { SerializationHelper.write(modelOutputFile, abstractClassifier); } catch (Exception ex) { Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex); } } } classifierBuiltWithCrossValidation = true; PerformanceCounters.stopTimer("cross-validation post"); PerformanceCounters.stopTimer("cross-validation"); return out; }
From source file:com.actelion.research.orbit.imageAnalysis.models.ClassifierWrapper.java
License:Open Source License
public static ClassifierWrapper makeCopy(ClassifierWrapper classifierWrapper) throws Exception { if (classifierWrapper == null) return null; if (classifierWrapper.getWrapperType() == WRAPPERTYPE_CLASSIFIER) { //return new ClassifierWrapper(Classifier.makeCopy(classifierWrapper.getClassifier())); ClassifierWrapper cw = new ClassifierWrapper( (AbstractClassifier.makeCopy(classifierWrapper.getClassifier()))); // weka 3.7.1 cw.setBinaryClassification(classifierWrapper.getBinaryClassification()); cw.setBuild(classifierWrapper.isBuild); return cw; } else {//from w w w . java 2 s . c o m ClassifierWrapper cw = new ClassifierWrapper( AbstractClusterer.makeCopy(classifierWrapper.getClusterer())); cw.setBinaryClassification(classifierWrapper.getBinaryClassification()); cw.setBuild(classifierWrapper.isBuild); return cw; } }
From source file:com.deafgoat.ml.prognosticator.AppClassifier.java
License:Apache License
/** * Perform cross-validation on data set/builds model * /*from ww w .j a va 2 s. c o m*/ * @throws Exception */ public void crossValidate() throws Exception { // stratify nominal target class if (_trainInstances.classAttribute().isNominal()) { _trainInstances.stratify(_folds); } _eval = new Evaluation(_trainInstances); for (int n = 0; n < _folds; n++) { if (_logger.isDebugEnabled()) { _logger.debug("Cross validation fold: " + (n + 1)); } _train = _trainInstances.trainCV(_folds, n); _test = _trainInstances.testCV(_folds, n); _clsCopy = AbstractClassifier.makeCopy(_cls); try { _clsCopy.buildClassifier(_train); } catch (Exception e) { _logger.debug(_config._classifier + " can not handle " + getAttributeType(_test.classAttribute()) + " class attributes"); } try { _eval.evaluateModel(_clsCopy, _test); } catch (Exception e) { _logger.debug("Can not evaluate model"); } } if (_config._writeToMongoDB) { _logger.info("Writing model to mongoDB"); // save the trained model saveModel(); // save CV performance of trained model writeToMongoDB(_eval); } if (_config._writeToFile) { _logger.info("Writing model to file"); SerializationHelper.write(_config._modelFile, _clsCopy); } }
From source file:de.tudarmstadt.ukp.similarity.experiments.coling2012.util.Evaluator.java
License:Open Source License
public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset) throws Exception { // Set parameters int folds = 10; Classifier baseClassifier = getClassifier(wekaClassifier); // Set up the random number generator long seed = new Date().getTime(); Random random = new Random(seed); // Add IDs to the instances AddID.main(new String[] { "-i", MODELS_DIR + "/" + dataset.toString() + ".arff", "-o", MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff" }); Instances data = DataSource.read(MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff"); data.setClassIndex(data.numAttributes() - 1); // Instantiate the Remove filter Remove removeIDFilter = new Remove(); removeIDFilter.setAttributeIndices("first"); // Randomize the data data.randomize(random);/*www .j a v a2s . co m*/ // Perform cross-validation Instances predictedData = null; Evaluation eval = new Evaluation(data); for (int n = 0; n < folds; n++) { Instances train = data.trainCV(folds, n, random); Instances test = data.testCV(folds, n); // Apply log filter // Filter logFilter = new LogFilter(); // logFilter.setInputFormat(train); // train = Filter.useFilter(train, logFilter); // logFilter.setInputFormat(test); // test = Filter.useFilter(test, logFilter); // Copy the classifier Classifier classifier = AbstractClassifier.makeCopy(baseClassifier); // Instantiate the FilteredClassifier FilteredClassifier filteredClassifier = new FilteredClassifier(); filteredClassifier.setFilter(removeIDFilter); filteredClassifier.setClassifier(classifier); // Build the classifier filteredClassifier.buildClassifier(train); // Evaluate eval.evaluateModel(filteredClassifier, test); // Add predictions AddClassification filter = new AddClassification(); filter.setClassifier(filteredClassifier); filter.setOutputClassification(true); filter.setOutputDistribution(false); filter.setOutputErrorFlag(true); filter.setInputFormat(train); Filter.useFilter(train, filter); // trains the classifier Instances pred = 
Filter.useFilter(test, filter); // performs predictions on test set if (predictedData == null) predictedData = new Instances(pred, 0); for (int j = 0; j < pred.numInstances(); j++) predictedData.add(pred.instance(j)); } // Prepare output classification String[] scores = new String[predictedData.numInstances()]; for (Instance predInst : predictedData) { int id = new Double(predInst.value(predInst.attribute(0))).intValue() - 1; int valueIdx = predictedData.numAttributes() - 2; String value = predInst.stringValue(predInst.attribute(valueIdx)); scores[id] = value; } // Output StringBuilder sb = new StringBuilder(); for (String score : scores) sb.append(score.toString() + LF); FileUtils.writeStringToFile( new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/output.csv"), sb.toString()); }
From source file:dkpro.similarity.experiments.rte.util.Evaluator.java
License:Open Source License
public static void runClassifier(WekaClassifier wekaClassifier, Dataset trainDataset, Dataset testDataset) throws Exception { Classifier baseClassifier = ClassifierSimilarityMeasure.getClassifier(wekaClassifier); // Set up the random number generator long seed = new Date().getTime(); Random random = new Random(seed); // Add IDs to the train instances and get the instances AddID.main(new String[] { "-i", MODELS_DIR + "/" + trainDataset.toString() + ".arff", "-o", MODELS_DIR + "/" + trainDataset.toString() + "-plusIDs.arff" }); Instances train = DataSource.read(MODELS_DIR + "/" + trainDataset.toString() + "-plusIDs.arff"); train.setClassIndex(train.numAttributes() - 1); // Add IDs to the test instances and get the instances AddID.main(new String[] { "-i", MODELS_DIR + "/" + testDataset.toString() + ".arff", "-o", MODELS_DIR + "/" + testDataset.toString() + "-plusIDs.arff" }); Instances test = DataSource.read(MODELS_DIR + "/" + testDataset.toString() + "-plusIDs.arff"); test.setClassIndex(test.numAttributes() - 1); // Instantiate the Remove filter Remove removeIDFilter = new Remove(); removeIDFilter.setAttributeIndices("first"); // Randomize the data test.randomize(random);//from ww w . ja va 2 s . 
co m // Apply log filter // Filter logFilter = new LogFilter(); // logFilter.setInputFormat(train); // train = Filter.useFilter(train, logFilter); // logFilter.setInputFormat(test); // test = Filter.useFilter(test, logFilter); // Copy the classifier Classifier classifier = AbstractClassifier.makeCopy(baseClassifier); // Instantiate the FilteredClassifier FilteredClassifier filteredClassifier = new FilteredClassifier(); filteredClassifier.setFilter(removeIDFilter); filteredClassifier.setClassifier(classifier); // Build the classifier filteredClassifier.buildClassifier(train); // Prepare the output buffer AbstractOutput output = new PlainText(); output.setBuffer(new StringBuffer()); output.setHeader(test); output.setAttributes("first"); Evaluation eval = new Evaluation(train); eval.evaluateModel(filteredClassifier, test, output); // Convert predictions to CSV // Format: inst#, actual, predicted, error, probability, (ID) String[] scores = new String[new Double(eval.numInstances()).intValue()]; double[] probabilities = new double[new Double(eval.numInstances()).intValue()]; for (String line : output.getBuffer().toString().split("\n")) { String[] linesplit = line.split("\\s+"); // If there's been an error, the length of linesplit is 6, otherwise 5, // due to the error flag "+" int id; String expectedValue, classification; double probability; if (line.contains("+")) { id = Integer.parseInt(linesplit[6].substring(1, linesplit[6].length() - 1)); expectedValue = linesplit[2].substring(2); classification = linesplit[3].substring(2); probability = Double.parseDouble(linesplit[5]); } else { id = Integer.parseInt(linesplit[5].substring(1, linesplit[5].length() - 1)); expectedValue = linesplit[2].substring(2); classification = linesplit[3].substring(2); probability = Double.parseDouble(linesplit[4]); } scores[id - 1] = classification; probabilities[id - 1] = probability; } System.out.println(eval.toSummaryString()); System.out.println(eval.toMatrixString()); // Output 
classifications StringBuilder sb = new StringBuilder(); for (String score : scores) sb.append(score.toString() + LF); FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".csv"), sb.toString()); // Output probabilities sb = new StringBuilder(); for (Double probability : probabilities) sb.append(probability.toString() + LF); FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".probabilities.csv"), sb.toString()); // Output predictions FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".predictions.txt"), output.getBuffer().toString()); // Output meta information sb = new StringBuilder(); sb.append(classifier.toString() + LF); sb.append(eval.toSummaryString() + LF); sb.append(eval.toMatrixString() + LF); FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".meta.txt"), sb.toString()); }
From source file:dkpro.similarity.experiments.rte.util.Evaluator.java
License:Open Source License
public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset) throws Exception { // Set parameters int folds = 10; Classifier baseClassifier = ClassifierSimilarityMeasure.getClassifier(wekaClassifier); // Set up the random number generator long seed = new Date().getTime(); Random random = new Random(seed); // Add IDs to the instances AddID.main(new String[] { "-i", MODELS_DIR + "/" + dataset.toString() + ".arff", "-o", MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff" }); Instances data = DataSource.read(MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff"); data.setClassIndex(data.numAttributes() - 1); // Instantiate the Remove filter Remove removeIDFilter = new Remove(); removeIDFilter.setAttributeIndices("first"); // Randomize the data data.randomize(random);/*from www. ja va 2s .c om*/ // Perform cross-validation Instances predictedData = null; Evaluation eval = new Evaluation(data); for (int n = 0; n < folds; n++) { Instances train = data.trainCV(folds, n, random); Instances test = data.testCV(folds, n); // Apply log filter // Filter logFilter = new LogFilter(); // logFilter.setInputFormat(train); // train = Filter.useFilter(train, logFilter); // logFilter.setInputFormat(test); // test = Filter.useFilter(test, logFilter); // Copy the classifier Classifier classifier = AbstractClassifier.makeCopy(baseClassifier); // Instantiate the FilteredClassifier FilteredClassifier filteredClassifier = new FilteredClassifier(); filteredClassifier.setFilter(removeIDFilter); filteredClassifier.setClassifier(classifier); // Build the classifier filteredClassifier.buildClassifier(train); // Evaluate eval.evaluateModel(filteredClassifier, test); // Add predictions AddClassification filter = new AddClassification(); filter.setClassifier(classifier); filter.setOutputClassification(true); filter.setOutputDistribution(false); filter.setOutputErrorFlag(true); filter.setInputFormat(train); Filter.useFilter(train, filter); // trains the classifier 
Instances pred = Filter.useFilter(test, filter); // performs predictions on test set if (predictedData == null) predictedData = new Instances(pred, 0); for (int j = 0; j < pred.numInstances(); j++) predictedData.add(pred.instance(j)); } System.out.println(eval.toSummaryString()); System.out.println(eval.toMatrixString()); // Prepare output scores String[] scores = new String[predictedData.numInstances()]; for (Instance predInst : predictedData) { int id = new Double(predInst.value(predInst.attribute(0))).intValue() - 1; int valueIdx = predictedData.numAttributes() - 2; String value = predInst.stringValue(predInst.attribute(valueIdx)); scores[id] = value; } // Output classifications StringBuilder sb = new StringBuilder(); for (String score : scores) sb.append(score.toString() + LF); FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".csv"), sb.toString()); // Output prediction arff DataSink.write(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".predicted.arff", predictedData); // Output meta information sb = new StringBuilder(); sb.append(baseClassifier.toString() + LF); sb.append(eval.toSummaryString() + LF); sb.append(eval.toMatrixString() + LF); FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".meta.txt"), sb.toString()); }
From source file:dkpro.similarity.experiments.sts2013.util.Evaluator.java
License:Open Source License
public static void runLinearRegressionCV(Mode mode, Dataset... datasets) throws Exception { for (Dataset dataset : datasets) { // Set parameters int folds = 10; Classifier baseClassifier = new LinearRegression(); // Set up the random number generator long seed = new Date().getTime(); Random random = new Random(seed); // Add IDs to the instances AddID.main(new String[] { "-i", MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".arff", "-o", MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + "-plusIDs.arff" }); Instances data = DataSource.read( MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + "-plusIDs.arff"); data.setClassIndex(data.numAttributes() - 1); // Instantiate the Remove filter Remove removeIDFilter = new Remove(); removeIDFilter.setAttributeIndices("first"); // Randomize the data data.randomize(random);// ww w.j a va 2s .co m // Perform cross-validation Instances predictedData = null; Evaluation eval = new Evaluation(data); for (int n = 0; n < folds; n++) { Instances train = data.trainCV(folds, n, random); Instances test = data.testCV(folds, n); // Apply log filter Filter logFilter = new LogFilter(); logFilter.setInputFormat(train); train = Filter.useFilter(train, logFilter); logFilter.setInputFormat(test); test = Filter.useFilter(test, logFilter); // Copy the classifier Classifier classifier = AbstractClassifier.makeCopy(baseClassifier); // Instantiate the FilteredClassifier FilteredClassifier filteredClassifier = new FilteredClassifier(); filteredClassifier.setFilter(removeIDFilter); filteredClassifier.setClassifier(classifier); // Build the classifier filteredClassifier.buildClassifier(train); // Evaluate eval.evaluateModel(classifier, test); // Add predictions AddClassification filter = new AddClassification(); filter.setClassifier(classifier); filter.setOutputClassification(true); filter.setOutputDistribution(false); filter.setOutputErrorFlag(true); 
filter.setInputFormat(train); Filter.useFilter(train, filter); // trains the classifier Instances pred = Filter.useFilter(test, filter); // performs predictions on test set if (predictedData == null) { predictedData = new Instances(pred, 0); } for (int j = 0; j < pred.numInstances(); j++) { predictedData.add(pred.instance(j)); } } // Prepare output scores double[] scores = new double[predictedData.numInstances()]; for (Instance predInst : predictedData) { int id = new Double(predInst.value(predInst.attribute(0))).intValue() - 1; int valueIdx = predictedData.numAttributes() - 2; double value = predInst.value(predInst.attribute(valueIdx)); scores[id] = value; // Limit to interval [0;5] if (scores[id] > 5.0) { scores[id] = 5.0; } if (scores[id] < 0.0) { scores[id] = 0.0; } } // Output StringBuilder sb = new StringBuilder(); for (Double score : scores) { sb.append(score.toString() + LF); } FileUtils.writeStringToFile( new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".csv"), sb.toString()); } }