Example usage for weka.core Instances setClassIndex

List of usage examples for weka.core Instances setClassIndex

Introduction

On this page you can find example usages of weka.core.Instances.setClassIndex.

Prototype

public void setClassIndex(int classIndex) 

Source Link

Document

Sets the class index of the set.

Usage

From source file:de.fub.maps.project.detector.model.inference.processhandler.CrossValidationProcessHandler.java

License:Open Source License

/**
 * Collects one instance per labelled track segment of the inference model's training input
 * into a single data set and passes that set to {@code evaluate}.
 */
@Override
protected void handle() {
    ArrayList<Attribute> attributes = new ArrayList<Attribute>(getInferenceModel().getAttributes());
    Instances trainingSet = new Instances("Classes", attributes, 9);
    // The first attribute is used as the class attribute.
    trainingSet.setClassIndex(0);

    HashMap<String, HashSet<TrackSegment>> segmentsByLabel = getInferenceModel().getInput().getTrainingsSet();
    for (Entry<String, HashSet<TrackSegment>> labelled : segmentsByLabel.entrySet()) {
        String label = labelled.getKey();
        for (TrackSegment segment : labelled.getValue()) {
            trainingSet.add(getInstance(label, segment));
        }
    }

    assert trainingSet.numInstances() > 0 : "Training set is empty and has no instances"; //NO18N

    evaluate(trainingSet);
}

From source file:de.fub.maps.project.detector.model.inference.processhandler.InferenceDataProcessHandler.java

License:Open Source License

/**
 * Classifies every track segment of the inference data set with the model's classifier,
 * records the predicted label per instance, refreshes the visual representation and writes
 * the non-empty classes into the inference model's result.
 *
 * @throws InferenceModelClassifyException if the inference model provides no attributes
 */
@Override
protected void handle() {
    clearResults();

    Classifier classifier = getInferenceModel().getClassifier();
    HashSet<TrackSegment> inferenceDataSet = getInferenceDataSet();
    Collection<Attribute> attributes = getInferenceModel().getAttributes();

    // Nothing can be classified without attributes, so bail out early.
    if (attributes.isEmpty()) {
        throw new InferenceModelClassifyException(MessageFormat
                .format("No attributes available. Attribute list lengeth == {0}", attributes.size()));
    }

    setClassesToView(getInferenceModel().getInput().getTrainingsSet().keySet());

    Instances unlabeled = new Instances("Unlabeld Tracks", new ArrayList<Attribute>(attributes), 0); //NO18N
    unlabeled.setClassIndex(0);

    // segments.get(i) is the segment that produced instance i.
    ArrayList<TrackSegment> segments = new ArrayList<TrackSegment>();
    for (TrackSegment segment : inferenceDataSet) {
        unlabeled.add(getInstance(segment));
        segments.add(segment);
    }

    // Work on a copy so that the unlabeled data set keeps its original class values.
    Instances labeled = new Instances(unlabeled);

    for (int i = 0; i < labeled.numInstances(); i++) {
        try {
            Instance current = labeled.instance(i);

            double predicted = classifier.classifyInstance(current);
            current.setClassValue(predicted);

            // Translate the predicted class index into its label.
            String label = unlabeled.classAttribute().value((int) predicted);

            if (i < segments.size()) {
                instanceToTrackSegmentMap.put(current, segments.get(i));
            }

            put(label, current);
        } catch (Exception ex) {
            // A failure for one instance is reported but does not stop the remaining ones.
            Exceptions.printStackTrace(ex);
        }
    }

    updateVisualRepresentation();

    // Transfer the classification result into the inference model.
    for (Entry<String, List<Instance>> classified : resultMap.entrySet()) {
        HashSet<TrackSegment> segmentsOfClass = new HashSet<TrackSegment>();
        for (Instance instance : classified.getValue()) {
            TrackSegment segment = instanceToTrackSegmentMap.get(instance);
            if (segment == null) {
                continue;
            }
            segmentsOfClass.add(segment);
        }

        // Classes without any segment are left out of the result.
        if (!segmentsOfClass.isEmpty()) {
            getInferenceModel().getResult().put(classified.getKey(), segmentsOfClass);
        }
    }

    resultMap.clear();
    instanceToTrackSegmentMap.clear();
}

From source file:de.fub.maps.project.detector.model.inference.processhandler.SpecialInferenceDataProcessHandler.java

License:Open Source License

/**
 * Classifies every track segment of the model's training input (after labelling each segment
 * with the key it is stored under), records the predicted label per instance, refreshes the
 * visual representation and writes the non-empty classes into the inference model's result.
 *
 * @throws InferenceModelClassifyException if the inference model provides no attributes
 */
@Override
protected void handle() {
    clearResults();

    Classifier classifier = getInferenceModel().getClassifier();
    Collection<Attribute> attributes = getInferenceModel().getAttributes();

    // Nothing can be classified without attributes, so bail out early.
    if (attributes.isEmpty()) {
        throw new InferenceModelClassifyException(MessageFormat
                .format("No attributes available. Attribute list lengeth == {0}", attributes.size()));
    }

    setClassesToView(getInferenceModel().getInput().getTrainingsSet().keySet());

    Instances unlabeled = new Instances("Unlabeld Tracks", new ArrayList<Attribute>(attributes), 0); //NO18N
    unlabeled.setClassIndex(0);

    // segments.get(i) is the segment that produced instance i.
    ArrayList<TrackSegment> segments = new ArrayList<TrackSegment>();
    for (Entry<String, HashSet<TrackSegment>> labelled : getInferenceModel().getInput().getTrainingsSet()
            .entrySet()) {
        for (TrackSegment segment : labelled.getValue()) {
            segment.setLabel(labelled.getKey());
            unlabeled.add(getInstance(segment));
            segments.add(segment);
        }
    }

    // Work on a copy so that the unlabeled data set keeps its original class values.
    Instances labeled = new Instances(unlabeled);

    for (int i = 0; i < labeled.numInstances(); i++) {
        try {
            Instance current = labeled.instance(i);

            double predicted = classifier.classifyInstance(current);
            current.setClassValue(predicted);

            // Translate the predicted class index into its label.
            String label = unlabeled.classAttribute().value((int) predicted);

            if (i < segments.size()) {
                instanceToTrackSegmentMap.put(current, segments.get(i));
            }

            put(label, current);
        } catch (Exception ex) {
            // A failure for one instance is reported but does not stop the remaining ones.
            Exceptions.printStackTrace(ex);
        }
    }

    updateVisualRepresentation();

    // Transfer the classification result into the inference model.
    for (Map.Entry<String, List<Instance>> classified : resultMap.entrySet()) {
        HashSet<TrackSegment> segmentsOfClass = new HashSet<TrackSegment>();
        for (Instance instance : classified.getValue()) {
            TrackSegment segment = instanceToTrackSegmentMap.get(instance);
            if (segment == null) {
                continue;
            }
            segmentsOfClass.add(segment);
        }

        // Classes without any segment are left out of the result.
        if (!segmentsOfClass.isEmpty()) {
            getInferenceModel().getResult().put(classified.getKey(), segmentsOfClass);
        }
    }

    resultMap.clear();
    instanceToTrackSegmentMap.clear();
}

From source file:de.fub.maps.project.detector.model.inference.processhandler.TrainingsDataProcessHandler.java

License:Open Source License

/**
 * Splits the labelled track segments of the inference model's input into a training and a
 * testing set (per class, according to {@code getTrainingsSetRatioParameter()}) and passes
 * both sets to {@code evaluate}. Progress is reported through a {@code ProgressHandle}.
 */
@Override
protected void handle() {
    final ProgressHandle handle = ProgressHandleFactory.createHandle("Trainings");
    try {
        handle.start();
        Collection<Attribute> attributeCollection = getInferenceModel().getAttributes();
        ArrayList<Attribute> arrayList = new ArrayList<Attribute>(attributeCollection);
        Instances trainingSet = new Instances("Classes", arrayList, 0);
        trainingSet.setClassIndex(0);

        Instances testingSet = new Instances("Classes", arrayList, 0);
        testingSet.setClassIndex(0);

        HashMap<String, HashSet<TrackSegment>> dataset = getInferenceModel().getInput().getTrainingsSet();

        // The total amount of work is measured in way points.
        int totalWayPoints = 0;
        for (HashSet<TrackSegment> list : dataset.values()) {
            for (TrackSegment trackSegment : list) {
                totalWayPoints += trackSegment.getWayPointList().size();
            }
        }
        handle.switchToDeterminate(totalWayPoints);

        int processedWayPoints = 0;
        for (Entry<String, HashSet<TrackSegment>> entry : dataset.entrySet()) {

            // The first ceil(size * ratio) segments of each class go into the training set.
            int trainingsSetSize = (int) Math.ceil(entry.getValue().size() * getTrainingsSetRatioParameter());
            int index = 0;
            for (TrackSegment trackSegment : entry.getValue()) {
                Instance instance = getInstance(entry.getKey(), trackSegment);

                if (index < trainingsSetSize) {
                    trainingSet.add(instance);
                } else {
                    testingSet.add(instance);
                }
                // Report progress in the same unit as the total (way points, not segments),
                // so the bar reaches the total and can never exceed it.
                processedWayPoints += trackSegment.getWayPointList().size();
                handle.progress(processedWayPoints);
                index++;
            }
        }

        assert trainingSet.numInstances() > 0 : "Training set is empty and has no instances"; //NO18N
        assert testingSet.numInstances() > 0 : "Testing set is empty and has no instances"; //NO18N
        handle.switchToIndeterminate();
        evaluate(trainingSet, testingSet);
    } finally {
        handle.finish();
    }
}

From source file:de.tudarmstadt.ukp.alignment.framework.combined.WekaMachineLearning.java

License:Apache License

/**
 * Trains a classifier on an annotated gold standard and stores it as a serialized WEKA model.
 *
 * The first attribute of the data is removed by a filter before training; the classifier is
 * a {@code BayesNet} wrapped in a {@code FilteredClassifier}.
 *
 * @param gs_arff the annotated gold standard in an .arff file
 * @param model output file for the model
 * @param output_eval if true, the evaluation of the trained classifier is printed (10-fold cross validation)
 * @throws Exception if loading the data, training, serializing or evaluating fails
 */

public static void createModelFromGoldstandard(String gs_arff, String model, boolean output_eval)
        throws Exception {
    Instances goldStandard = new DataSource(gs_arff).getDataSet();
    // Fall back to the last attribute when the file does not declare a class attribute.
    if (goldStandard.classIndex() == -1) {
        goldStandard.setClassIndex(goldStandard.numAttributes() - 1);
    }

    // Filter that drops the ID attribute (first column).
    Remove idRemover = new Remove();
    idRemover.setAttributeIndices("1");

    // Standard classifier; BNs proved most robust, but other classifiers are possible.
    BayesNet bayesNet = new BayesNet();

    FilteredClassifier filtered = new FilteredClassifier();
    filtered.setFilter(idRemover);
    filtered.setClassifier(bayesNet);
    filtered.buildClassifier(goldStandard);
    SerializationHelper.write(model, filtered);

    if (!output_eval) {
        return;
    }

    Evaluation evaluation = new Evaluation(goldStandard);
    evaluation.crossValidateModel(filtered, goldStandard, 10, new Random(1));
    System.out.println(evaluation.toSummaryString());
    System.out.println(evaluation.toMatrixString());
    System.out.println(evaluation.toClassDetailsString());
}

From source file:de.tudarmstadt.ukp.alignment.framework.combined.WekaMachineLearning.java

License:Apache License

/**
 *
 * This method applies a serialized WEKA model file to an unlabeld .arff file for classification
 *
 *
 * @param input_arff the annotated gold standard in an .arff file
 * @param model output file for the model
 * @param output output file for evaluation of trained classifier (10-fold cross validation)
 * @throws Exception/*  w w  w  .jav a 2 s  . c  o  m*/
 */

public static void applyModelToUnlabeledArff(String input_arff, String model, String output) throws Exception {
    DataSource source = new DataSource(input_arff);
    Instances unlabeled = source.getDataSet();
    if (unlabeled.classIndex() == -1) {
        unlabeled.setClassIndex(unlabeled.numAttributes() - 1);
    }

    Remove rm = new Remove();
    rm.setAttributeIndices("1"); // remove ID  attribute

    ObjectInputStream ois = new ObjectInputStream(new FileInputStream(model));
    Classifier cls = (Classifier) ois.readObject();
    ois.close();
    // create copy
    Instances labeled = new Instances(unlabeled);

    // label instances
    for (int i = 0; i < unlabeled.numInstances(); i++) {
        double clsLabel = cls.classifyInstance(unlabeled.instance(i));
        labeled.instance(i).setClassValue(clsLabel);
    }
    // save labeled data
    BufferedWriter writer = new BufferedWriter(new FileWriter(output));
    writer.write(labeled.toString());
    writer.newLine();
    writer.flush();
    writer.close();

}

From source file:de.tudarmstadt.ukp.dkpro.spelling.experiments.hoo2012.featureextraction.AllFeaturesExtractor.java

License:Apache License

/**
 * Reads a data set from the given file and uses its last attribute as the class attribute.
 * Files whose absolute path ends with ".gz" are decompressed with GZIP while reading.
 *
 * @param instancesFile the (optionally gzipped) file to read the instances from
 * @return the loaded instances with the class index set to the last attribute
 * @throws FileNotFoundException if the file cannot be opened
 * @throws IOException if reading or decompressing the file fails
 */
private Instances getInstances(File instancesFile) throws FileNotFoundException, IOException {
    boolean gzipped = instancesFile.getAbsolutePath().endsWith(".gz");

    // The file stream is opened first and closed in its own finally block, so it is
    // released even when wrapping it (e.g. in a GZIPInputStream) throws.
    FileInputStream fileStream = new FileInputStream(instancesFile);
    try {
        Reader reader = new BufferedReader(
                new InputStreamReader(gzipped ? new GZIPInputStream(fileStream) : fileStream));
        try {
            Instances trainData = new Instances(reader);
            trainData.setClassIndex(trainData.numAttributes() - 1);
            return trainData;
        } finally {
            reader.close();
        }
    } finally {
        // Closing again after reader.close() is harmless.
        fileStream.close();
    }
}

From source file:de.tudarmstadt.ukp.similarity.experiments.coling2012.util.Evaluator.java

License:Open Source License

/**
 * Runs a 10-fold cross-validation of the given classifier on the given dataset and writes
 * one predicted value per instance, ordered by instance ID, to
 * {@code OUTPUT_DIR/<dataset>/<classifier>/output.csv}.
 *
 * An ID attribute is added to the data first and hidden from the classifier by a
 * {@code Remove} filter, so the shuffled predictions can be put back into ID order.
 *
 * @param wekaClassifier the classifier to evaluate
 * @param dataset the dataset whose .arff file under {@code MODELS_DIR} is used
 * @throws Exception if any of the Weka or file operations fails
 */
public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset) throws Exception {
    final int numFolds = 10;
    Classifier prototype = getClassifier(wekaClassifier);

    // The random number generator is seeded with the current time.
    Random rng = new Random(new Date().getTime());

    // Write a copy of the data with an additional ID attribute and load it.
    String plainArff = MODELS_DIR + "/" + dataset.toString() + ".arff";
    String arffWithIds = MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff";
    AddID.main(new String[] { "-i", plainArff, "-o", arffWithIds });

    Instances data = DataSource.read(arffWithIds);
    data.setClassIndex(data.numAttributes() - 1);

    // Filter that removes the first attribute (the ID) before the classifier sees the data.
    Remove idRemover = new Remove();
    idRemover.setAttributeIndices("first");

    data.randomize(rng);

    Instances allPredictions = null;
    Evaluation eval = new Evaluation(data);

    for (int fold = 0; fold < numFolds; fold++) {
        Instances train = data.trainCV(numFolds, fold, rng);
        Instances test = data.testCV(numFolds, fold);

        // Note: a LogFilter step on train/test used to be applied here and is currently disabled.

        // Fresh copy of the classifier for this fold, wrapped so that the ID is ignored.
        FilteredClassifier foldClassifier = new FilteredClassifier();
        foldClassifier.setFilter(idRemover);
        foldClassifier.setClassifier(AbstractClassifier.makeCopy(prototype));
        foldClassifier.buildClassifier(train);

        eval.evaluateModel(foldClassifier, test);

        // Append the classifier's output to the test instances.
        AddClassification predictor = new AddClassification();
        predictor.setClassifier(foldClassifier);
        predictor.setOutputClassification(true);
        predictor.setOutputDistribution(false);
        predictor.setOutputErrorFlag(true);
        predictor.setInputFormat(train);
        Filter.useFilter(train, predictor); // trains the classifier

        Instances foldPredictions = Filter.useFilter(test, predictor); // performs predictions on test set
        if (allPredictions == null) {
            allPredictions = new Instances(foldPredictions, 0);
        }
        for (Instance prediction : foldPredictions) {
            allPredictions.add(prediction);
        }
    }

    // Collect the value of the second-to-last attribute, indexed by (ID - 1).
    String[] scores = new String[allPredictions.numInstances()];
    int valueIdx = allPredictions.numAttributes() - 2;

    for (Instance predInst : allPredictions) {
        int id = ((int) predInst.value(predInst.attribute(0))) - 1;
        scores[id] = predInst.stringValue(predInst.attribute(valueIdx));
    }

    // One score per line.
    StringBuilder sb = new StringBuilder();
    for (String score : scores) {
        sb.append(score.toString() + LF);
    }

    File outputFile = new File(
            OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/output.csv");
    FileUtils.writeStringToFile(outputFile, sb.toString());
}

From source file:de.ugoe.cs.cpdp.dataselection.DecisionTreeSelection.java

License:Apache License

/**
 * Reduces {@code traindataSet} to the single training product for which a regression tree,
 * trained on pairwise similarity vectors between the training products, predicts the
 * highest score when compared against the test data.
 *
 * @param testdata the test data
 * @param traindataSet candidate training data sets; on return it contains only the selected one
 * @throws RuntimeException if any step fails or no training product can be selected
 */
@Override
public void apply(Instances testdata, SetUniqueList<Instances> traindataSet) {
    // NOTE(review): row 0 of data is used as the test data's characteristics and row i + 1 as
    // those of training product i - confirm against characteristicInstances.
    final Instances data = characteristicInstances(testdata, traindataSet);

    final ArrayList<String> attVals = new ArrayList<String>();
    attVals.add("same");
    attVals.add("more");
    attVals.add("less");
    final ArrayList<Attribute> atts = new ArrayList<Attribute>();
    for (int j = 0; j < data.numAttributes(); j++) {
        atts.add(new Attribute(data.attribute(j).name(), attVals));
    }
    atts.add(new Attribute("score"));
    Instances similarityData = new Instances("similarity", atts, 0);
    similarityData.setClassIndex(similarityData.numAttributes() - 1);

    try {
        // Train on product i, evaluate on every other product j, and record the
        // similarity vector of (i, j) together with the resulting F-measure as score.
        Classifier classifier = new J48();
        for (int i = 0; i < traindataSet.size(); i++) {
            classifier.buildClassifier(traindataSet.get(i));
            for (int j = 0; j < traindataSet.size(); j++) {
                if (i == j) {
                    continue;
                }
                double[] similarity = similarityVector(data, i + 1, j + 1);

                Evaluation eval = new Evaluation(traindataSet.get(j));
                eval.evaluateModel(classifier, traindataSet.get(j));
                similarity[data.numAttributes()] = eval.fMeasure(1);
                similarityData.add(new DenseInstance(1.0, similarity));
            }
        }

        // The number of folds is always 2 (the former size-based adjustment was
        // unconditionally overwritten and therefore had no effect).
        REPTree repTree = new REPTree();
        repTree.setNumFolds(2);
        repTree.buildClassifier(similarityData);

        // Similarity vectors between the test data (row 0) and every training product.
        Instances testTrainSimilarity = new Instances(similarityData);
        testTrainSimilarity.clear();
        for (int i = 0; i < traindataSet.size(); i++) {
            testTrainSimilarity.add(new DenseInstance(1.0, similarityVector(data, 0, i + 1)));
        }

        // Double.MIN_VALUE is the smallest POSITIVE double, so it must not be used as the
        // initial maximum: scores <= 0 would never be selected.
        int bestScoringProductIndex = -1;
        double maxScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < traindataSet.size(); i++) {
            double score = repTree.classifyInstance(testTrainSimilarity.get(i));
            if (score > maxScore) {
                maxScore = score;
                bestScoringProductIndex = i;
            }
        }
        if (bestScoringProductIndex < 0) {
            throw new IllegalStateException("no training product with a valid score available");
        }
        Instances bestScoringProduct = traindataSet.get(bestScoringProductIndex);
        traindataSet.clear();
        traindataSet.add(bestScoringProduct);
    } catch (Exception e) {
        Console.printerr("failure during DecisionTreeSelection: " + e.getMessage());
        throw new RuntimeException(e);
    }
}

/**
 * Compares two rows of {@code data} attribute by attribute. Entry k is 2.0 ("less") if the
 * other value is below 90% of the reference value, 1.0 ("more") if it is above 110% of it,
 * and 0.0 ("same") otherwise. The returned array has one additional trailing slot (0.0)
 * for the score.
 */
private static double[] similarityVector(Instances data, int referenceIndex, int otherIndex) {
    double[] similarity = new double[data.numAttributes() + 1];
    for (int k = 0; k < data.numAttributes(); k++) {
        double reference = data.get(referenceIndex).value(k);
        double other = data.get(otherIndex).value(k);
        if (0.9 * reference > other) {
            similarity[k] = 2.0;
        } else if (1.1 * reference < other) {
            similarity[k] = 1.0;
        } else {
            similarity[k] = 0.0;
        }
    }
    return similarity;
}

From source file:de.ugoe.cs.cpdp.loader.NetgeneLoader.java

License:Apache License

/**
 * Loads file metrics from the given CSV file and merges them with the network metrics
 * ({@code <project>_network_metrics.csv}) and bug counts ({@code <project>_bugs_per_file.csv})
 * located in the same directory. The project name is the part of the file name before the
 * first underscore.
 *
 * The result has the file name attribute and any "eigenvector" attribute removed, a nominal
 * "bug" attribute (values "0"/"1") as last attribute and class, and all missing values
 * replaced by 0.
 *
 * @param fileMetricsFile CSV file with the file metrics
 * @return the merged data, or {@code null} if one of the files could not be read
 */
@Override
public Instances load(File fileMetricsFile) {
    // first determine all files
    // Resolving the absolute file first avoids a NullPointerException for a relative
    // path without a parent component (getParentFile() would return null).
    String path = fileMetricsFile.getAbsoluteFile().getParentFile().getAbsolutePath();
    String project = fileMetricsFile.getName().split("_")[0];
    File bugsFile = new File(path + "/" + project + "_bugs_per_file.csv");
    File networkMetrics = new File(path + "/" + project + "_network_metrics.csv");
    Instances metricsData = null;

    try {
        CSVLoader wekaCsvLoader = new CSVLoader();
        wekaCsvLoader.setSource(fileMetricsFile);
        metricsData = wekaCsvLoader.getDataSet();
        wekaCsvLoader.setSource(bugsFile);
        Instances bugsData = wekaCsvLoader.getDataSet();
        wekaCsvLoader.setSource(networkMetrics);
        Instances networkData = wekaCsvLoader.getDataSet();

        metricsData.setRelationName(project);

        // fix nominal and string attributes (e.g., NA values) by making them numeric;
        // the first two network attributes are left untouched
        for (int j = 2; j < networkData.numAttributes(); j++) {
            if (networkData.attribute(j).isNominal() || networkData.attribute(j).isString()) {
                replaceWithNumericAttribute(networkData, j);
            }
        }

        // map file name (first metrics attribute) to the index of its instance
        Map<String, Integer> filenames = new HashMap<>();
        for (int j = 0; j < metricsData.size(); j++) {
            filenames.put(metricsData.instance(j).stringValue(0), j);
        }
        // merge with network data; rows are matched on the file name in network attribute 1
        int attributeIndex;
        for (int j = 2; j < networkData.numAttributes(); j++) {
            attributeIndex = metricsData.numAttributes();
            metricsData.insertAttributeAt(networkData.attribute(j), attributeIndex);
            for (int i = 0; i < networkData.size(); i++) {
                Integer instanceIndex = filenames.get(networkData.instance(i).stringValue(1));
                if (instanceIndex != null) {
                    metricsData.instance(instanceIndex).setValue(attributeIndex,
                            networkData.instance(i).value(j));
                }
            }
        }

        // add bug information
        attributeIndex = metricsData.numAttributes();
        final ArrayList<String> classAttVals = new ArrayList<String>();
        classAttVals.add("0");
        classAttVals.add("1");
        final Attribute classAtt = new Attribute("bug", classAttVals);
        metricsData.insertAttributeAt(classAtt, attributeIndex);
        for (int i = 0; i < bugsData.size(); i++) {
            // a file counts as buggy if attribute 2 of the bug data is positive
            if (bugsData.instance(i).value(2) > 0.0d) {
                Integer instanceIndex = filenames.get(bugsData.instance(i).stringValue(1));
                if (instanceIndex != null) {
                    metricsData.instance(instanceIndex).setValue(attributeIndex, 1.0);
                }
            }
        }

        // remove filenames
        metricsData.deleteAttributeAt(0);
        Attribute eigenvector = metricsData.attribute("eigenvector");
        if (eigenvector != null) {
            for (int j = 0; j < metricsData.numAttributes(); j++) {
                if (metricsData.attribute(j) == eigenvector) {
                    metricsData.deleteAttributeAt(j);
                }
            }
        }

        metricsData.setClassIndex(metricsData.numAttributes() - 1);

        // set all missing values to 0 (for the class attribute this selects the value "0")
        for (int i = 0; i < metricsData.size(); i++) {
            for (int j = 0; j < metricsData.numAttributes(); j++) {
                if (metricsData.instance(i).isMissing(j)) {
                    metricsData.instance(i).setValue(j, 0.0d);
                }
            }
        }
    } catch (IOException e) {
        Console.traceln(Level.SEVERE, "failure reading file: " + e.getMessage());
        metricsData = null;
    }
    return metricsData;
}

/**
 * Replaces the attribute at {@code attIndex} with a numeric attribute of the same name.
 * Each value is parsed from its string representation; missing or non-numeric values
 * become 0.0.
 */
private static void replaceWithNumericAttribute(Instances data, int attIndex) {
    String attributeName = data.attribute(attIndex).name();
    // get temporary values; entries that are not set below keep the array default 0.0
    double[] tmpVals = new double[data.size()];
    for (int i = 0; i < data.size(); i++) {
        Instance inst = data.instance(i);
        if (!inst.isMissing(attIndex)) {
            try {
                tmpVals[i] = Double.parseDouble(inst.stringValue(attIndex));
            } catch (NumberFormatException e) {
                // not a number, using 0.0;
                tmpVals[i] = 0.0;
            }
        }
    }
    // replace attribute
    data.deleteAttributeAt(attIndex);
    data.insertAttributeAt(new Attribute(attributeName), attIndex);
    for (int i = 0; i < data.size(); i++) {
        data.instance(i).setValue(attIndex, tmpVals[i]);
    }
}