List of usage examples for weka.core Instances setClassIndex
public void setClassIndex(int classIndex)
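Before the project-specific examples, here is a minimal, self-contained sketch of the call itself (the file name and class placement are illustrative assumptions, not taken from the examples below). setClassIndex takes a 0-based attribute index; the common idiom for datasets whose class is the last attribute is numAttributes() - 1.

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class SetClassIndexExample {
    public static void main(String[] args) throws Exception {
        // "iris.arff" is a placeholder path; any ARFF file works.
        Instances data = DataSource.read("iris.arff");
        // classIndex() returns -1 while no class attribute is set;
        // setClassIndex expects a 0-based attribute index.
        if (data.classIndex() == -1) {
            data.setClassIndex(data.numAttributes() - 1);
        }
        System.out.println("Class attribute: " + data.classAttribute().name());
    }
}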
From source file:de.fub.maps.project.detector.model.inference.processhandler.CrossValidationProcessHandler.java
License:Open Source License
@Override
protected void handle() {
    Collection<Attribute> attributeList = getInferenceModel().getAttributes();
    Instances trainingSet = new Instances("Classes", new ArrayList<Attribute>(attributeList), 9);
    trainingSet.setClassIndex(0);
    HashMap<String, HashSet<TrackSegment>> dataset = getInferenceModel().getInput().getTrainingsSet();
    for (Entry<String, HashSet<TrackSegment>> entry : dataset.entrySet()) {
        for (TrackSegment trackSegment : entry.getValue()) {
            Instance instance = getInstance(entry.getKey(), trackSegment);
            trainingSet.add(instance);
        }
    }
    assert trainingSet.numInstances() > 0 : "Training set is empty and has no instances"; //NO18N
    evaluate(trainingSet);
}
From source file:de.fub.maps.project.detector.model.inference.processhandler.InferenceDataProcessHandler.java
License:Open Source License
@Override
protected void handle() {
    clearResults();
    Classifier classifier = getInferenceModel().getClassifier();
    HashSet<TrackSegment> inferenceDataSet = getInferenceDataSet();
    Collection<Attribute> attributeList = getInferenceModel().getAttributes();
    if (!attributeList.isEmpty()) {
        Set<String> keySet = getInferenceModel().getInput().getTrainingsSet().keySet();
        setClassesToView(keySet);

        Instances unlabeledInstances = new Instances("Unlabeled Tracks", new ArrayList<Attribute>(attributeList), 0); //NO18N
        unlabeledInstances.setClassIndex(0);

        ArrayList<TrackSegment> segmentList = new ArrayList<TrackSegment>();
        for (TrackSegment segment : inferenceDataSet) {
            Instance instance = getInstance(segment);
            unlabeledInstances.add(instance);
            segmentList.add(segment);
        }

        // create copy
        Instances labeledInstances = new Instances(unlabeledInstances);

        for (int index = 0; index < labeledInstances.numInstances(); index++) {
            try {
                Instance instance = labeledInstances.instance(index);
                // classify instance
                double classifyed = classifier.classifyInstance(instance);
                instance.setClassValue(classifyed);
                // get class label
                String value = unlabeledInstances.classAttribute().value((int) classifyed);
                if (index < segmentList.size()) {
                    instanceToTrackSegmentMap.put(instance, segmentList.get(index));
                }
                // put label and instance to result map
                put(value, instance);
            } catch (Exception ex) {
                Exceptions.printStackTrace(ex);
            }
        }

        // update view
        updateVisualRepresentation();

        // update result set of the inferenceModel
        for (Entry<String, List<Instance>> entry : resultMap.entrySet()) {
            HashSet<TrackSegment> trackSegmentList = new HashSet<TrackSegment>();
            for (Instance instance : entry.getValue()) {
                TrackSegment trackSegment = instanceToTrackSegmentMap.get(instance);
                if (trackSegment != null) {
                    trackSegmentList.add(trackSegment);
                }
            }
            // only non-empty classes are put into the result data set
            if (!trackSegmentList.isEmpty()) {
                getInferenceModel().getResult().put(entry.getKey(), trackSegmentList);
            }
        }
    } else {
        throw new InferenceModelClassifyException(MessageFormat
                .format("No attributes available. Attribute list length == {0}", attributeList.size()));
    }
    resultMap.clear();
    instanceToTrackSegmentMap.clear();
}
From source file:de.fub.maps.project.detector.model.inference.processhandler.SpecialInferenceDataProcessHandler.java
License:Open Source License
@Override
protected void handle() {
    clearResults();
    Classifier classifier = getInferenceModel().getClassifier();
    Collection<Attribute> attributeList = getInferenceModel().getAttributes();
    if (!attributeList.isEmpty()) {
        Set<String> keySet = getInferenceModel().getInput().getTrainingsSet().keySet();
        setClassesToView(keySet);

        Instances unlabeledInstances = new Instances("Unlabeled Tracks", new ArrayList<Attribute>(attributeList), 0); //NO18N
        unlabeledInstances.setClassIndex(0);

        ArrayList<TrackSegment> segmentList = new ArrayList<TrackSegment>();
        for (Entry<String, HashSet<TrackSegment>> entry : getInferenceModel().getInput().getTrainingsSet()
                .entrySet()) {
            for (TrackSegment segment : entry.getValue()) {
                segment.setLabel(entry.getKey());
                Instance instance = getInstance(segment);
                unlabeledInstances.add(instance);
                segmentList.add(segment);
            }
        }

        // create copy
        Instances labeledInstances = new Instances(unlabeledInstances);

        for (int index = 0; index < labeledInstances.numInstances(); index++) {
            try {
                Instance instance = labeledInstances.instance(index);
                // classify instance
                double classifyed = classifier.classifyInstance(instance);
                instance.setClassValue(classifyed);
                // get class label
                String value = unlabeledInstances.classAttribute().value((int) classifyed);
                if (index < segmentList.size()) {
                    instanceToTrackSegmentMap.put(instance, segmentList.get(index));
                }
                // put label and instance to result map
                put(value, instance);
            } catch (Exception ex) {
                Exceptions.printStackTrace(ex);
            }
        }

        // update view
        updateVisualRepresentation();

        // update result set of the inferenceModel
        for (Map.Entry<String, List<Instance>> entry : resultMap.entrySet()) {
            HashSet<TrackSegment> trackSegmentList = new HashSet<TrackSegment>();
            for (Instance instance : entry.getValue()) {
                TrackSegment trackSegment = instanceToTrackSegmentMap.get(instance);
                if (trackSegment != null) {
                    trackSegmentList.add(trackSegment);
                }
            }
            // only non-empty classes are put into the result data set
            if (!trackSegmentList.isEmpty()) {
                getInferenceModel().getResult().put(entry.getKey(), trackSegmentList);
            }
        }
    } else {
        throw new InferenceModelClassifyException(MessageFormat
                .format("No attributes available. Attribute list length == {0}", attributeList.size()));
    }
    resultMap.clear();
    instanceToTrackSegmentMap.clear();
}
From source file:de.fub.maps.project.detector.model.inference.processhandler.TrainingsDataProcessHandler.java
License:Open Source License
@Override
protected void handle() {
    final ProgressHandle handle = ProgressHandleFactory.createHandle("Trainings");
    try {
        handle.start();
        Collection<Attribute> attributeCollection = getInferenceModel().getAttributes();
        ArrayList<Attribute> arrayList = new ArrayList<Attribute>(attributeCollection);
        Instances trainingSet = new Instances("Classes", arrayList, 0);
        trainingSet.setClassIndex(0);
        Instances testingSet = new Instances("Classes", arrayList, 0);
        testingSet.setClassIndex(0);

        HashMap<String, HashSet<TrackSegment>> dataset = getInferenceModel().getInput().getTrainingsSet();
        int datasetCount = 0;
        for (HashSet<TrackSegment> list : dataset.values()) {
            for (TrackSegment trackSegment : list) {
                datasetCount += trackSegment.getWayPointList().size();
            }
        }

        handle.switchToDeterminate(datasetCount);
        int trackCount = 0;
        for (Entry<String, HashSet<TrackSegment>> entry : dataset.entrySet()) {
            int trainingsSetSize = (int) Math.ceil(entry.getValue().size() * getTrainingsSetRatioParameter());
            int index = 0;
            for (TrackSegment trackSegment : entry.getValue()) {
                Instance instance = getInstance(entry.getKey(), trackSegment);
                if (index < trainingsSetSize) {
                    trainingSet.add(instance);
                } else {
                    testingSet.add(instance);
                }
                handle.progress(trackCount++);
                index++;
            }
        }

        assert trainingSet.numInstances() > 0 : "Training set is empty and has no instances"; //NO18N
        assert testingSet.numInstances() > 0 : "Testing set is empty and has no instances"; //NO18N

        handle.switchToIndeterminate();
        evaluate(trainingSet, testingSet);
    } finally {
        handle.finish();
    }
}
From source file:de.tudarmstadt.ukp.alignment.framework.combined.WekaMachineLearning.java
License:Apache License
/**
 * This method creates a serialized WEKA model file from an .arff file containing the annotated gold standard.
 *
 * @param gs_arff the annotated gold standard in an .arff file
 * @param model output file for the model
 * @param output_eval if true, the evaluation of the trained classifier is printed (10-fold cross validation)
 * @throws Exception
 */
public static void createModelFromGoldstandard(String gs_arff, String model, boolean output_eval)
        throws Exception {
    DataSource source = new DataSource(gs_arff);
    Instances data = source.getDataSet();
    if (data.classIndex() == -1) {
        data.setClassIndex(data.numAttributes() - 1);
    }

    Remove rm = new Remove();
    rm.setAttributeIndices("1"); // remove ID attribute

    BayesNet bn = new BayesNet(); // standard classifier; BNs proved most robust, but of course other classifiers are possible

    // meta-classifier
    FilteredClassifier fc = new FilteredClassifier();
    fc.setFilter(rm);
    fc.setClassifier(bn);

    fc.buildClassifier(data); // build classifier
    SerializationHelper.write(model, fc);

    if (output_eval) {
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(fc, data, 10, new Random(1));
        System.out.println(eval.toSummaryString());
        System.out.println(eval.toMatrixString());
        System.out.println(eval.toClassDetailsString());
    }
}
From source file:de.tudarmstadt.ukp.alignment.framework.combined.WekaMachineLearning.java
License:Apache License
/**
 * This method applies a serialized WEKA model file to an unlabeled .arff file for classification.
 *
 * @param input_arff the unlabeled instances in an .arff file
 * @param model the serialized model file to apply
 * @param output output file for the labeled data
 * @throws Exception
 */
public static void applyModelToUnlabeledArff(String input_arff, String model, String output)
        throws Exception {
    DataSource source = new DataSource(input_arff);
    Instances unlabeled = source.getDataSet();
    if (unlabeled.classIndex() == -1) {
        unlabeled.setClassIndex(unlabeled.numAttributes() - 1);
    }

    Remove rm = new Remove();
    rm.setAttributeIndices("1"); // remove ID attribute

    ObjectInputStream ois = new ObjectInputStream(new FileInputStream(model));
    Classifier cls = (Classifier) ois.readObject();
    ois.close();

    // create copy
    Instances labeled = new Instances(unlabeled);

    // label instances
    for (int i = 0; i < unlabeled.numInstances(); i++) {
        double clsLabel = cls.classifyInstance(unlabeled.instance(i));
        labeled.instance(i).setClassValue(clsLabel);
    }

    // save labeled data
    BufferedWriter writer = new BufferedWriter(new FileWriter(output));
    writer.write(labeled.toString());
    writer.newLine();
    writer.flush();
    writer.close();
}
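Usage note: the two WekaMachineLearning methods above are designed to be used in sequence. A minimal hypothetical driver might look like the sketch below; all file paths are placeholder assumptions, not taken from the original project.

// Hypothetical driver for the two methods above; paths are illustrative.
public class AlignmentPipeline {
    public static void main(String[] args) throws Exception {
        String goldStandardArff = "gold_standard.arff"; // annotated training data (assumed path)
        String modelFile = "alignment.model";           // serialized classifier (assumed path)
        String unlabeledArff = "unlabeled.arff";        // instances to classify (assumed path)
        String labeledArff = "labeled.arff";            // classification output (assumed path)

        // train and serialize the classifier, printing the 10-fold CV evaluation
        WekaMachineLearning.createModelFromGoldstandard(goldStandardArff, modelFile, true);
        // apply the serialized model to unseen data
        WekaMachineLearning.applyModelToUnlabeledArff(unlabeledArff, modelFile, labeledArff);
    }
}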
From source file:de.tudarmstadt.ukp.dkpro.spelling.experiments.hoo2012.featureextraction.AllFeaturesExtractor.java
License:Apache License
private Instances getInstances(File instancesFile) throws FileNotFoundException, IOException {
    Instances trainData = null;
    Reader reader;
    if (instancesFile.getAbsolutePath().endsWith(".gz")) {
        reader = new BufferedReader(
                new InputStreamReader(new GZIPInputStream(new FileInputStream(instancesFile))));
    } else {
        reader = new BufferedReader(new FileReader(instancesFile));
    }
    try {
        trainData = new Instances(reader);
        trainData.setClassIndex(trainData.numAttributes() - 1);
    } finally {
        reader.close();
    }
    return trainData;
}
From source file:de.tudarmstadt.ukp.similarity.experiments.coling2012.util.Evaluator.java
License:Open Source License
public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset) throws Exception {
    // Set parameters
    int folds = 10;
    Classifier baseClassifier = getClassifier(wekaClassifier);

    // Set up the random number generator
    long seed = new Date().getTime();
    Random random = new Random(seed);

    // Add IDs to the instances
    AddID.main(new String[] { "-i", MODELS_DIR + "/" + dataset.toString() + ".arff", "-o",
            MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff" });
    Instances data = DataSource.read(MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff");
    data.setClassIndex(data.numAttributes() - 1);

    // Instantiate the Remove filter
    Remove removeIDFilter = new Remove();
    removeIDFilter.setAttributeIndices("first");

    // Randomize the data
    data.randomize(random);

    // Perform cross-validation
    Instances predictedData = null;
    Evaluation eval = new Evaluation(data);

    for (int n = 0; n < folds; n++) {
        Instances train = data.trainCV(folds, n, random);
        Instances test = data.testCV(folds, n);

        // Apply log filter
        // Filter logFilter = new LogFilter();
        // logFilter.setInputFormat(train);
        // train = Filter.useFilter(train, logFilter);
        // logFilter.setInputFormat(test);
        // test = Filter.useFilter(test, logFilter);

        // Copy the classifier
        Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);

        // Instantiate the FilteredClassifier
        FilteredClassifier filteredClassifier = new FilteredClassifier();
        filteredClassifier.setFilter(removeIDFilter);
        filteredClassifier.setClassifier(classifier);

        // Build the classifier
        filteredClassifier.buildClassifier(train);

        // Evaluate
        eval.evaluateModel(filteredClassifier, test);

        // Add predictions
        AddClassification filter = new AddClassification();
        filter.setClassifier(filteredClassifier);
        filter.setOutputClassification(true);
        filter.setOutputDistribution(false);
        filter.setOutputErrorFlag(true);
        filter.setInputFormat(train);
        Filter.useFilter(train, filter); // trains the classifier
        Instances pred = Filter.useFilter(test, filter); // performs predictions on test set
        if (predictedData == null)
            predictedData = new Instances(pred, 0);
        for (int j = 0; j < pred.numInstances(); j++)
            predictedData.add(pred.instance(j));
    }

    // Prepare output classification
    String[] scores = new String[predictedData.numInstances()];

    for (Instance predInst : predictedData) {
        int id = new Double(predInst.value(predInst.attribute(0))).intValue() - 1;
        int valueIdx = predictedData.numAttributes() - 2;
        String value = predInst.stringValue(predInst.attribute(valueIdx));
        scores[id] = value;
    }

    // Output
    StringBuilder sb = new StringBuilder();
    for (String score : scores)
        sb.append(score.toString() + LF);
    FileUtils.writeStringToFile(
            new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/output.csv"),
            sb.toString());
}
From source file:de.ugoe.cs.cpdp.dataselection.DecisionTreeSelection.java
License:Apache License
@Override
public void apply(Instances testdata, SetUniqueList<Instances> traindataSet) {
    final Instances data = characteristicInstances(testdata, traindataSet);

    final ArrayList<String> attVals = new ArrayList<String>();
    attVals.add("same");
    attVals.add("more");
    attVals.add("less");
    final ArrayList<Attribute> atts = new ArrayList<Attribute>();
    for (int j = 0; j < data.numAttributes(); j++) {
        atts.add(new Attribute(data.attribute(j).name(), attVals));
    }
    atts.add(new Attribute("score"));
    Instances similarityData = new Instances("similarity", atts, 0);
    similarityData.setClassIndex(similarityData.numAttributes() - 1);

    try {
        Classifier classifier = new J48();
        for (int i = 0; i < traindataSet.size(); i++) {
            classifier.buildClassifier(traindataSet.get(i));
            for (int j = 0; j < traindataSet.size(); j++) {
                if (i != j) {
                    double[] similarity = new double[data.numAttributes() + 1];
                    for (int k = 0; k < data.numAttributes(); k++) {
                        if (0.9 * data.get(i + 1).value(k) > data.get(j + 1).value(k)) {
                            similarity[k] = 2.0;
                        } else if (1.1 * data.get(i + 1).value(k) < data.get(j + 1).value(k)) {
                            similarity[k] = 1.0;
                        } else {
                            similarity[k] = 0.0;
                        }
                    }

                    Evaluation eval = new Evaluation(traindataSet.get(j));
                    eval.evaluateModel(classifier, traindataSet.get(j));
                    similarity[data.numAttributes()] = eval.fMeasure(1);
                    similarityData.add(new DenseInstance(1.0, similarity));
                }
            }
        }

        REPTree repTree = new REPTree();
        if (repTree.getNumFolds() > similarityData.size()) {
            repTree.setNumFolds(similarityData.size());
        }
        repTree.setNumFolds(2); // note: this unconditionally overrides the fold count set above
        repTree.buildClassifier(similarityData);

        Instances testTrainSimilarity = new Instances(similarityData);
        testTrainSimilarity.clear();
        for (int i = 0; i < traindataSet.size(); i++) {
            double[] similarity = new double[data.numAttributes() + 1];
            for (int k = 0; k < data.numAttributes(); k++) {
                if (0.9 * data.get(0).value(k) > data.get(i + 1).value(k)) {
                    similarity[k] = 2.0;
                } else if (1.1 * data.get(0).value(k) < data.get(i + 1).value(k)) {
                    similarity[k] = 1.0;
                } else {
                    similarity[k] = 0.0;
                }
            }
            testTrainSimilarity.add(new DenseInstance(1.0, similarity));
        }

        int bestScoringProductIndex = -1;
        double maxScore = Double.MIN_VALUE;
        for (int i = 0; i < traindataSet.size(); i++) {
            double score = repTree.classifyInstance(testTrainSimilarity.get(i));
            if (score > maxScore) {
                maxScore = score;
                bestScoringProductIndex = i;
            }
        }
        Instances bestScoringProduct = traindataSet.get(bestScoringProductIndex);
        traindataSet.clear();
        traindataSet.add(bestScoringProduct);
    } catch (Exception e) {
        Console.printerr("failure during DecisionTreeSelection: " + e.getMessage());
        throw new RuntimeException(e);
    }
}
From source file:de.ugoe.cs.cpdp.loader.NetgeneLoader.java
License:Apache License
@Override
public Instances load(File fileMetricsFile) {
    // first determine all files
    String path = fileMetricsFile.getParentFile().getAbsolutePath();
    String project = fileMetricsFile.getName().split("_")[0];
    File bugsFile = new File(path + "/" + project + "_bugs_per_file.csv");
    File networkMetrics = new File(path + "/" + project + "_network_metrics.csv");

    Instances metricsData = null;
    try {
        CSVLoader wekaCsvLoader = new CSVLoader();
        wekaCsvLoader.setSource(fileMetricsFile);
        metricsData = wekaCsvLoader.getDataSet();
        wekaCsvLoader.setSource(bugsFile);
        Instances bugsData = wekaCsvLoader.getDataSet();
        wekaCsvLoader.setSource(networkMetrics);
        Instances networkData = wekaCsvLoader.getDataSet();

        metricsData.setRelationName(project);

        // fix nominal attributes (i.e., NA values)
        for (int j = 2; j < networkData.numAttributes(); j++) {
            if (networkData.attribute(j).isNominal()) {
                String attributeName = networkData.attribute(j).name();
                double[] tmpVals = new double[networkData.size()];
                // get temporary values
                for (int i = 0; i < networkData.size(); i++) {
                    Instance inst = networkData.instance(i);
                    if (!inst.isMissing(j)) {
                        String val = networkData.instance(i).stringValue(j);
                        try {
                            tmpVals[i] = Double.parseDouble(val);
                        } catch (NumberFormatException e) {
                            // not a number, using 0.0;
                            tmpVals[i] = 0.0;
                        }
                    } else {
                        tmpVals[i] = 0.0;
                    }
                }
                // replace attribute
                networkData.deleteAttributeAt(j);
                networkData.insertAttributeAt(new Attribute(attributeName), j);
                for (int i = 0; i < networkData.size(); i++) {
                    networkData.instance(i).setValue(j, tmpVals[i]);
                }
            }
        }

        // fix string attributes
        for (int j = 2; j < networkData.numAttributes(); j++) {
            if (networkData.attribute(j).isString()) {
                String attributeName = networkData.attribute(j).name();
                double[] tmpVals = new double[networkData.size()];
                // get temporary values
                for (int i = 0; i < networkData.size(); i++) {
                    Instance inst = networkData.instance(i);
                    if (!inst.isMissing(j)) {
                        String val = networkData.instance(i).stringValue(j);
                        try {
                            tmpVals[i] = Double.parseDouble(val);
                        } catch (NumberFormatException e) {
                            // not a number, using 0.0;
                            tmpVals[i] = 0.0;
                        }
                    } else {
                        tmpVals[i] = 0.0;
                    }
                }
                // replace attribute
                networkData.deleteAttributeAt(j);
                networkData.insertAttributeAt(new Attribute(attributeName), j);
                for (int i = 0; i < networkData.size(); i++) {
                    networkData.instance(i).setValue(j, tmpVals[i]);
                }
            }
        }

        Map<String, Integer> filenames = new HashMap<>();
        for (int j = 0; j < metricsData.size(); j++) {
            filenames.put(metricsData.instance(j).stringValue(0), j);
        }

        // merge with network data
        int attributeIndex;
        for (int j = 2; j < networkData.numAttributes(); j++) {
            attributeIndex = metricsData.numAttributes();
            metricsData.insertAttributeAt(networkData.attribute(j), attributeIndex);
            for (int i = 0; i < networkData.size(); i++) {
                Integer instanceIndex = filenames.get(networkData.instance(i).stringValue(1));
                if (instanceIndex != null) {
                    metricsData.instance(instanceIndex).setValue(attributeIndex,
                            networkData.instance(i).value(j));
                }
            }
        }

        // add bug information
        attributeIndex = metricsData.numAttributes();
        final ArrayList<String> classAttVals = new ArrayList<String>();
        classAttVals.add("0");
        classAttVals.add("1");
        final Attribute classAtt = new Attribute("bug", classAttVals);
        metricsData.insertAttributeAt(classAtt, attributeIndex);
        for (int i = 0; i < bugsData.size(); i++) {
            if (bugsData.instance(i).value(2) > 0.0d) {
                Integer instanceIndex = filenames.get(bugsData.instance(i).stringValue(1));
                if (instanceIndex != null) {
                    metricsData.instance(instanceIndex).setValue(attributeIndex, 1.0);
                }
            }
        }

        // remove filenames
        metricsData.deleteAttributeAt(0);
        Attribute eigenvector = metricsData.attribute("eigenvector");
        if (eigenvector != null) {
            for (int j = 0; j < metricsData.numAttributes(); j++) {
                if (metricsData.attribute(j) == eigenvector) {
                    metricsData.deleteAttributeAt(j);
                }
            }
        }

        metricsData.setClassIndex(metricsData.numAttributes() - 1);

        // set all missing values to 0
        for (int i = 0; i < metricsData.size(); i++) {
            for (int j = 0; j < metricsData.numAttributes(); j++) {
                if (metricsData.instance(i).isMissing(j)) {
                    metricsData.instance(i).setValue(j, 0.0d);
                }
            }
        }
    } catch (IOException e) {
        Console.traceln(Level.SEVERE, "failure reading file: " + e.getMessage());
        metricsData = null;
    }
    return metricsData;
}