List of usage examples for weka.core Instances attribute
public Attribute attribute(String name)
From source file:asap.CrossValidation.java
/** * * @param dataInput//from w ww. ja va 2 s.c o m * @param classIndex * @param removeIndices * @param cls * @param seed * @param folds * @param modelOutputFile * @return * @throws Exception */ public static String performCrossValidation(String dataInput, String classIndex, String removeIndices, AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception { PerformanceCounters.startTimer("cross-validation ST"); PerformanceCounters.startTimer("cross-validation init ST"); // loads data and set class index Instances data = DataSource.read(dataInput); String clsIndex = classIndex; switch (clsIndex) { case "first": data.setClassIndex(0); break; case "last": data.setClassIndex(data.numAttributes() - 1); break; default: try { data.setClassIndex(Integer.parseInt(clsIndex) - 1); } catch (NumberFormatException e) { data.setClassIndex(data.attribute(clsIndex).index()); } break; } Remove removeFilter = new Remove(); removeFilter.setAttributeIndices(removeIndices); removeFilter.setInputFormat(data); data = Filter.useFilter(data, removeFilter); // randomize data Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) { randData.stratify(folds); } // perform cross-validation and add predictions Evaluation eval = new Evaluation(randData); Instances trainSets[] = new Instances[folds]; Instances testSets[] = new Instances[folds]; Classifier foldCls[] = new Classifier[folds]; for (int n = 0; n < folds; n++) { trainSets[n] = randData.trainCV(folds, n); testSets[n] = randData.testCV(folds, n); foldCls[n] = AbstractClassifier.makeCopy(cls); } PerformanceCounters.stopTimer("cross-validation init ST"); PerformanceCounters.startTimer("cross-validation folds+train ST"); //paralelize!!:-------------------------------------------------------------- for (int n = 0; n < folds; n++) { Instances train = trainSets[n]; Instances test = testSets[n]; // the above code is used by the 
StratifiedRemoveFolds filter, the // code below by the Explorer/Experimenter: // Instances train = randData.trainCV(folds, n, rand); // build and evaluate classifier Classifier clsCopy = foldCls[n]; clsCopy.buildClassifier(train); eval.evaluateModel(clsCopy, test); } cls.buildClassifier(data); //until here!----------------------------------------------------------------- PerformanceCounters.stopTimer("cross-validation folds+train ST"); PerformanceCounters.startTimer("cross-validation post ST"); // output evaluation String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: " + folds + "\n" + "Seed: " + seed + "\n" + "\n" + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n"; if (!modelOutputFile.isEmpty()) { SerializationHelper.write(modelOutputFile, cls); } PerformanceCounters.stopTimer("cross-validation post ST"); PerformanceCounters.stopTimer("cross-validation ST"); return out; }
From source file:asap.CrossValidation.java
/**
 * Multi-threaded variant of {@code performCrossValidation}: fold work items are
 * queued in a shared list and consumed by {@code CrossValidationFoldThread}
 * workers while this thread trains the final classifier on the full data set.
 *
 * @param dataInput       path/identifier of the data set, readable by Weka's DataSource
 * @param classIndex      class attribute spec: "first", "last", a 1-based index, or an attribute name
 * @param removeIndices   attribute index ranges to strip from the data before training
 * @param cls             classifier prototype; copies are trained per fold
 * @param seed            seed for the shuffle that precedes fold assignment
 * @param folds           number of cross-validation folds
 * @param modelOutputFile if non-empty, the fully trained classifier is serialized here
 * @return a human-readable evaluation summary
 * @throws Exception if reading, filtering, training, or evaluation fails
 */
public static String performCrossValidationMT(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {
    PerformanceCounters.startTimer("cross-validation MT");
    PerformanceCounters.startTimer("cross-validation init MT");
    // Load data and resolve the class attribute (keyword, 1-based number, or name).
    Instances data = DataSource.read(dataInput);
    String clsIndex = classIndex;
    switch (clsIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            data.setClassIndex(Integer.parseInt(clsIndex) - 1);
        } catch (NumberFormatException e) {
            // Not numeric: treat the spec as an attribute name.
            data.setClassIndex(data.attribute(clsIndex).index());
        }
        break;
    }
    // Strip the requested attributes before training.
    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);
    // Shuffle deterministically; stratify only for a nominal class.
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }
    // Shared, synchronized work queue plus the worker threads that drain it.
    // The Evaluation instance is shared across workers; presumably
    // CrossValidationFoldThread synchronizes its updates -- verify there.
    Evaluation eval = new Evaluation(randData);
    List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>());
    List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>());
    for (int n = 0; n < folds; n++) {
        foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                AbstractClassifier.makeCopy(cls)));
        // Only numThreads-1 workers are created (this thread is the remaining
        // one); workers presumably pull FoldSets from the shared list until it
        // is empty -- confirm in CrossValidationFoldThread.
        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }
    PerformanceCounters.stopTimer("cross-validation init MT");
    PerformanceCounters.startTimer("cross-validation folds+train MT");
    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        // Single-threaded configuration: run all folds on the current thread.
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }
    // Train the caller's classifier on the full data set while workers run.
    cls.buildClassifier(data);
    // Wait for all fold workers before reading the shared Evaluation.
    for (Thread foldThread : foldThreads) {
        foldThread.join();
    }
    PerformanceCounters.stopTimer("cross-validation folds+train MT");
    PerformanceCounters.startTimer("cross-validation post MT");
    // Build the evaluation summary for output.
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";
    if (!modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }
    PerformanceCounters.stopTimer("cross-validation post MT");
    PerformanceCounters.stopTimer("cross-validation MT");
    return out;
}
From source file:asap.PostProcess.java
/**
 * Writes predictions alongside selected attribute columns as a delimited text file.
 * The predictions column is spliced in before column {@code predictionsColumnIndex}
 * (or appended last when that index equals {@code columnNames.length}).
 *
 * @param dataset column values are read from these instances by attribute name
 * @param predictions one prediction per instance, formatted with up to 3 fraction digits
 * @param columnNames attribute names to emit, in order; names containing "id" are printed as ints
 * @param predictionsColumnIndex position of the predictions column
 * @param predictionsColumnName header text for the predictions column
 * @param columnSeparator separator between columns
 * @param outputFilename destination file (parent directories are created)
 * @param writeColumnsHeaderLine when true, a header line of column names is written first
 */
private void formatPredictions(Instances instances, double[] predictions, String[] columnNames,
        int predictionsColumnIndex, String predictionsColumnName, String columnSeparator, String outputFilename,
        boolean writeColumnsHeaderLine) {
    PerformanceCounters.startTimer("formatPredictions");
    System.out.println("Formatting predictions to file " + outputFilename + "...");
    File outputFile = new File(outputFilename);
    try {
        outputFile.getParentFile().mkdirs();
        outputFile.createNewFile();
    } catch (IOException ex) {
        Logger.getLogger(PostProcess.class.getName()).log(Level.SEVERE, null, ex);
        return;
    }
    DecimalFormat df = new DecimalFormat("#.#", new DecimalFormatSymbols(Locale.US));
    df.setMaximumFractionDigits(3);
    // Fixed: the PrintWriter was never closed if an exception escaped the write
    // loop; try-with-resources guarantees it is closed on every path.
    try (PrintWriter writer = new PrintWriter(outputFile, "UTF-8")) {
        StringBuilder sb = new StringBuilder();
        // Row index -1 emits the header line (column names instead of values).
        int i = writeColumnsHeaderLine ? -1 : 0;
        for (; i < instances.numInstances(); i++) {
            sb.delete(0, sb.length());
            for (int j = 0; j < columnNames.length; j++) {
                if (j > 0) {
                    sb.append(columnSeparator);
                }
                // Predictions column spliced in before data column j.
                if (j == predictionsColumnIndex) {
                    if (i < 0) {
                        sb.append(predictionsColumnName);
                    } else {
                        sb.append(df.format(predictions[i]));
                    }
                    sb.append(columnSeparator);
                }
                if (i < 0) {
                    sb.append(columnNames[j]);
                } else {
                    // Single attribute lookup (previously duplicated in both branches).
                    Attribute attribute = instances.attribute(columnNames[j]);
                    if (columnNames[j].toLowerCase().contains("id")) {
                        // ID-like columns print as integers; a missing attribute prints 0.
                        if (attribute != null) {
                            sb.append((int) instances.instance(i).value(attribute.index()));
                        } else {
                            sb.append(0);
                        }
                    } else if (attribute != null) {
                        sb.append(instances.instance(i).value(attribute.index()));
                    } else {
                        sb.append(df.format(0d));
                    }
                }
            }
            // Predictions column positioned after the last data column.
            if (columnNames.length == predictionsColumnIndex) {
                sb.append(columnSeparator);
                if (i < 0) {
                    sb.append(predictionsColumnName);
                } else {
                    sb.append(df.format(predictions[i]));
                }
            }
            writer.println(sb);
        }
        writer.flush();
    } catch (IOException ex) {
        Logger.getLogger(PostProcess.class.getName()).log(Level.SEVERE, null, ex);
        return;
    }
    System.out.println("\tdone.");
    PerformanceCounters.stopTimer("formatPredictions");
}
From source file:asap.PostProcess.java
private void writePredictionErrors(Instances instances, double[] predictions, String errorsFilename) { TreeSet<PredictionError> errors = new TreeSet<>(); for (int i = 0; i < predictions.length; i++) { double prediction = predictions[i]; double expected = instances.get(i).classValue(); int pairId = (int) instances.get(i).value(instances.attribute("pair_ID")); String sourceFile = instances.get(i).stringValue(instances.attribute("source_file")); PredictionError pe = new PredictionError(prediction, expected, pairId, sourceFile, instances.get(i)); //if (pe.getError()>=0.5d) errors.add(pe);/* w ww . j a va2 s. c o m*/ } StringBuilder sb = new StringBuilder(); for (PredictionError error : errors) { sb.append(error.toString()).append("\n"); } File f = new File(errorsFilename); try (FileOutputStream fos = new FileOutputStream(f)) { fos.write(sb.toString().getBytes()); } catch (IOException ex) { Logger.getLogger(PostProcess.class.getName()).log(Level.SEVERE, null, ex); } }
From source file:asap.PostProcess.java
public void loadTrainingDataStream(PreProcessOutputStream pposTrainingData) { Instances instancesTrainingSet; DataSource source = new DataSource(pposTrainingData); try {/*from w ww . j a v a 2 s . com*/ instancesTrainingSet = source.getDataSet(); } catch (Exception ex) { Logger.getLogger(PostProcess.class.getName()).log(Level.SEVERE, null, ex); return; } // setting class attribute if the data format does not provide this information if (instancesTrainingSet.classIndex() == -1) { instancesTrainingSet.setClass(instancesTrainingSet.attribute("gold_standard")); } for (String wekaModelsCmd : Config.getWekaModelsCmd()) { String[] classifierCmd; try { classifierCmd = Utils.splitOptions(wekaModelsCmd); } catch (Exception ex) { Logger.getLogger(PostProcess.class.getName()).log(Level.SEVERE, null, ex); continue; } String classname = classifierCmd[0]; classifierCmd[0] = ""; try { AbstractClassifier cl = (AbstractClassifier) Utils.forName(Classifier.class, classname, classifierCmd); // String modelName = String.format("%s%s%s%s.model", modelDirectory, File.separatorChar, i, classname); // System.out.println(String.format("\tBuilding model %s (%s) and doing cross-validation...", i++, modelName)); // System.out.println(CrossValidation.performCrossValidationMT(trainSet, cl, Config.getCrossValidationSeed(), Config.getCrossValidationFolds(), modelName)); systems.add(new NLPSystem(cl, instancesTrainingSet, null)); System.out.println("\tAdded system " + systems.get(systems.size() - 1).shortName()); } catch (Exception ex) { Logger.getLogger(PostProcess.class.getName()).log(Level.SEVERE, null, ex); } } }
From source file:asap.PostProcess.java
public void loadEvaluationDataStream(PreProcessOutputStream pposEvaluationData) { Instances instancesEvaluationSet; DataSource source = new DataSource(pposEvaluationData); try {/*from w w w. ja va2 s. c o m*/ instancesEvaluationSet = source.getDataSet(); } catch (Exception ex) { Logger.getLogger(PostProcess.class.getName()).log(Level.SEVERE, null, ex); return; } // setting class attribute if the data format does not provide this information if (instancesEvaluationSet.classIndex() == -1) { instancesEvaluationSet.setClass(instancesEvaluationSet.attribute("gold_standard")); } for (NLPSystem system : systems) { system.setEvaluationSet(instancesEvaluationSet); } }
From source file:at.aictopic1.sentimentanalysis.machinelearning.impl.TwitterClassifer.java
public Instances loadTrainingData() { try {//ww w .j a v a2 s. c o m //DataSource source = new DataSource("C:\\Users\\David\\Documents\\Datalogi\\TU Wien\\2014W_Advanced Internet Computing\\Labs\\aic_group2_topic1\\Other Stuff\\training_dataset.arff"); DataSource source = new DataSource( "C:\\Users\\David\\Documents\\Datalogi\\TU Wien\\2014W_Advanced Internet Computing\\Labs\\Data sets\\labelled.arff"); // System.out.println("Data Structure pre processing: " + source.getStructure()); Instances data = source.getDataSet(); // Get and save the dataStructure of the dataset dataStructure = source.getStructure(); try { // Save the datastructure to file // serialize dataStructure weka.core.SerializationHelper.write(modelDir + algorithm + ".dataStruct", dataStructure); } catch (Exception ex) { Logger.getLogger(TwitterClassifer.class.getName()).log(Level.SEVERE, null, ex); } // Set class index data.setClassIndex(2); // Giving attributes unique names before converting strings data.renameAttribute(2, "class_attr"); data.renameAttribute(0, "twitter_id"); // Convert String attribute to Words using filter StringToWordVector filter = new StringToWordVector(); filter.setInputFormat(data); Instances filteredData = Filter.useFilter(data, filter); System.out.println("filteredData struct: " + filteredData.attribute(0)); System.out.println("filteredData struct: " + filteredData.attribute(1)); System.out.println("filteredData struct: " + filteredData.attribute(2)); return filteredData; } catch (Exception ex) { System.out.println("Error loading training set: " + ex.toString()); return null; //Logger.getLogger(Trainer.class.getName()).log(Level.SEVERE, null, ex); } }
From source file:at.aictopic1.sentimentanalysis.machinelearning.impl.TwitterClassifer.java
public Integer classify(Tweet[] tweets) { // TEST/* w w w .j a v a2s . c om*/ // Generate two tweet examples Tweet exOne = new Tweet("This is good and fantastic"); exOne.setPreprocessedText("This is good and fantastic"); Tweet exTwo = new Tweet("Horribly, terribly bad and more"); exTwo.setPreprocessedText("Horribly, terribly bad and more"); Tweet exThree = new Tweet( "I want to update lj and read my friends list, but I\\'m groggy and sick and blargh."); exThree.setPreprocessedText( "I want to update lj and read my friends list, but I\\'m groggy and sick and blargh."); Tweet exFour = new Tweet("bad hate worst sick"); exFour.setPreprocessedText("bad hate worst sick"); tweets = new Tweet[] { exOne, exTwo, exThree, exFour }; // TEST // Load model // loadModel(); // Convert Tweet to Instance type // Get String Data // Create attributes for the Instances set Attribute twitter_id = new Attribute("twitter_id"); // Attribute body = new Attribute("body"); FastVector classVal = new FastVector(2); classVal.addElement("pos"); classVal.addElement("neg"); Attribute class_attr = new Attribute("class_attr", classVal); // Add them to a list FastVector attrVector = new FastVector(3); // attrVector.addElement(twitter_id); // attrVector.addElement(new Attribute("body", (FastVector) null)); // attrVector.addElement(class_attr); // Get the number of tweets and then create predictSet int numTweets = tweets.length; Enumeration structAttrs = dataStructure.enumerateAttributes(); // ArrayList<Attribute> attrList = new ArrayList<Attribute>(dataStructure.numAttributes()); while (structAttrs.hasMoreElements()) { attrVector.addElement((Attribute) structAttrs.nextElement()); } Instances predictSet = new Instances("predictInstances", attrVector, numTweets); // Instances predictSet = new Instances(dataStructure); predictSet.setClassIndex(2); // init prediction double prediction = -1; System.out.println("PredictSet matches source structure: " + predictSet.equalHeaders(dataStructure)); 
System.out.println("PredSet struct: " + predictSet.attribute(0)); System.out.println("PredSet struct: " + predictSet.attribute(1)); System.out.println("PredSet struct: " + predictSet.attribute(2)); // Array to return predictions //double[] tweetsClassified = new double[2][numTweets]; //List<Integer, Double> tweetsClass = new ArrayList<Integer, Double>(numTweets); for (int i = 0; i < numTweets; i++) { String content = (String) tweets[i].getPreprocessedText(); System.out.println("Tweet content: " + content); // attrList Instance tweetInstance = new Instance(predictSet.numAttributes()); tweetInstance.setDataset(predictSet); tweetInstance.setValue(predictSet.attribute(0), i); tweetInstance.setValue(predictSet.attribute(1), content); tweetInstance.setClassMissing(); predictSet.add(tweetInstance); try { // Apply string filter StringToWordVector filter = new StringToWordVector(); filter.setInputFormat(predictSet); Instances filteredPredictSet = Filter.useFilter(predictSet, filter); // Apply model prediction = trainedModel.classifyInstance(filteredPredictSet.instance(i)); filteredPredictSet.instance(i).setClassValue(prediction); System.out.println("Classification: " + filteredPredictSet.instance(i).toString()); System.out.println("Prediction: " + prediction); } catch (Exception ex) { Logger.getLogger(TwitterClassifer.class.getName()).log(Level.SEVERE, null, ex); } } return 0; }
From source file:bme.mace.logicdomain.Evaluation.java
License:Open Source License
/** * Returns the area under ROC for those predictions that have been collected * in the evaluateClassifier(Classifier, Instances) method. Returns * Instance.missingValue() if the area is not available. * /* w w w .j a v a 2 s . c o m*/ * @param classIndex the index of the class to consider as "positive" * @return the area under the ROC curve or not a number */ public double areaUnderROC(int classIndex) { // Check if any predictions have been collected if (m_Predictions == null) { return Instance.missingValue(); } else { ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(m_Predictions, classIndex); double rocArea = ThresholdCurve.getROCArea(result); if (rocArea < 0.5) { rocArea = 1 - rocArea; } int tpIndex = result.attribute(ThresholdCurve.TP_RATE_NAME).index(); int fpIndex = result.attribute(ThresholdCurve.FP_RATE_NAME).index(); double[] tpRate = result.attributeToDoubleArray(tpIndex); double[] fpRate = result.attributeToDoubleArray(fpIndex); try { FileWriter fw; if (classIndex == 0) fw = new FileWriter("C://1.csv", true); else fw = new FileWriter("C://1.csv", true); BufferedWriter bw = new BufferedWriter(fw); int length = fpRate.length; for (int i = 255; i >= 0; i--) { int index = i * (length - 1) / 255; bw.write(fpRate[index] + ","); } bw.write("\n"); for (int i = 255; i >= 0; i--) { int index = i * (length - 1) / 255; bw.write(tpRate[index] + ","); } bw.write("\n"); bw.close(); fw.close(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } return rocArea; } }
From source file:bme.mace.logicdomain.Evaluation.java
License:Open Source License
/** * Prints the header for the predictions output into a supplied StringBuffer * // w w w .j a v a2s . co m * @param test structure of the test set to print predictions for * @param attributesToOutput indices of the attributes to output * @param printDistribution prints the complete distribution for nominal * attributes, not just the predicted value * @param text the StringBuffer to print to */ protected static void printClassificationsHeader(Instances test, Range attributesToOutput, boolean printDistribution, StringBuffer text) { // print header if (test.classAttribute().isNominal()) { if (printDistribution) { text.append(" inst# actual predicted error distribution"); } else { text.append(" inst# actual predicted error prediction"); } } else { text.append(" inst# actual predicted error"); } if (attributesToOutput != null) { attributesToOutput.setUpper(test.numAttributes() - 1); text.append(" ("); boolean first = true; for (int i = 0; i < test.numAttributes(); i++) { if (i == test.classIndex()) { continue; } if (attributesToOutput.isInRange(i)) { if (!first) { text.append(","); } text.append(test.attribute(i).name()); first = false; } } text.append(")"); } text.append("\n"); }