Example usage for weka.core Instances attribute

List of usage examples for weka.core Instances attribute

Introduction

In this page you can find the example usage for weka.core Instances attribute.

Prototype

public Attribute attribute(String name) 

Source Link

Document

Returns an attribute given its name.

Usage

From source file:asap.CrossValidation.java

/**
 * Runs single-threaded k-fold cross-validation on a dataset read from file,
 * then trains the supplied classifier on the complete (filtered) dataset.
 *
 * @param dataInput path to a dataset readable by Weka's DataSource
 * @param classIndex class attribute selector: "first", "last", a 1-based
 *        attribute index, or an attribute name
 * @param removeIndices attribute index specification for the Remove filter
 * @param cls classifier prototype; per-fold copies are evaluated, and this
 *        instance itself is finally trained on all the data
 * @param seed seed used to shuffle (and stratify) the data
 * @param folds number of cross-validation folds
 * @param modelOutputFile if non-empty, the trained classifier is serialized here
 * @return a textual summary of the setup and the evaluation results
 * @throws Exception if reading, filtering, training, or evaluation fails
 */
public static String performCrossValidation(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

    PerformanceCounters.startTimer("cross-validation ST");

    PerformanceCounters.startTimer("cross-validation init ST");

    // Load data and resolve the class attribute from the selector string.
    Instances data = DataSource.read(dataInput);
    String clsIndex = classIndex;

    switch (clsIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            // Numeric selectors are 1-based.
            data.setClassIndex(Integer.parseInt(clsIndex) - 1);
        } catch (NumberFormatException e) {
            // Not a number: treat the selector as an attribute name.
            // Instances.attribute(String) returns null for an unknown name,
            // which previously surfaced as a bare NullPointerException.
            if (data.attribute(clsIndex) == null) {
                throw new IllegalArgumentException("Class attribute not found: " + clsIndex);
            }
            data.setClassIndex(data.attribute(clsIndex).index());
        }
        break;
    }

    // Drop the requested attributes before evaluation.
    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // Randomize the data; stratify only for nominal (classification) targets.
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // Pre-build train/test splits and one classifier copy per fold.
    Evaluation eval = new Evaluation(randData);
    Instances trainSets[] = new Instances[folds];
    Instances testSets[] = new Instances[folds];
    Classifier foldCls[] = new Classifier[folds];

    for (int n = 0; n < folds; n++) {
        trainSets[n] = randData.trainCV(folds, n);
        testSets[n] = randData.testCV(folds, n);
        foldCls[n] = AbstractClassifier.makeCopy(cls);
    }

    PerformanceCounters.stopTimer("cross-validation init ST");
    PerformanceCounters.startTimer("cross-validation folds+train ST");
    // Candidate for parallelization: the folds are independent of each other.
    for (int n = 0; n < folds; n++) {
        Instances train = trainSets[n];
        Instances test = testSets[n];

        // Build and evaluate one classifier copy per fold; results accumulate
        // into the shared Evaluation object.
        Classifier clsCopy = foldCls[n];
        clsCopy.buildClassifier(train);
        eval.evaluateModel(clsCopy, test);
    }

    // Finally train the caller's classifier on the full (filtered) dataset.
    cls.buildClassifier(data);

    PerformanceCounters.stopTimer("cross-validation folds+train ST");
    PerformanceCounters.startTimer("cross-validation post ST");
    // Assemble the human-readable evaluation report.
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (!modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post ST");
    PerformanceCounters.stopTimer("cross-validation ST");

    return out;
}

From source file:asap.CrossValidation.java

/**
 *
 * @param dataInput/*from w w  w  . jav  a  2  s.  c o  m*/
 * @param classIndex
 * @param removeIndices
 * @param cls
 * @param seed
 * @param folds
 * @param modelOutputFile
 * @return
 * @throws Exception
 */
public static String performCrossValidationMT(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

    PerformanceCounters.startTimer("cross-validation MT");

    PerformanceCounters.startTimer("cross-validation init MT");

    // loads data and set class index
    Instances data = DataSource.read(dataInput);
    String clsIndex = classIndex;

    switch (clsIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            data.setClassIndex(Integer.parseInt(clsIndex) - 1);
        } catch (NumberFormatException e) {
            data.setClassIndex(data.attribute(clsIndex).index());
        }
        break;
    }

    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval = new Evaluation(randData);
    List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                AbstractClassifier.makeCopy(cls)));

        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init MT");
    PerformanceCounters.startTimer("cross-validation folds+train MT");
    //paralelize!!:--------------------------------------------------------------
    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        //use the current thread to run the cross-validation instead of using the Thread instance created here:
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    cls.buildClassifier(data);

    for (Thread foldThread : foldThreads) {
        foldThread.join();
    }

    //until here!-----------------------------------------------------------------
    PerformanceCounters.stopTimer("cross-validation folds+train MT");
    PerformanceCounters.startTimer("cross-validation post MT");
    // evaluation for output:
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (!modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post MT");
    PerformanceCounters.stopTimer("cross-validation MT");
    return out;
}

From source file:asap.PostProcess.java

/**
 * Writes instance values plus a predictions column to a delimited text file.
 * Column values are looked up in {@code instances} by column name; columns
 * containing "id" are written as ints, everything else as 3-decimal floats.
 * A header line is emitted when {@code writeColumnsHeaderLine} is true
 * (implemented by starting the row index at -1).
 *
 * @param instances dataset whose attribute values are written
 * @param predictions one prediction per instance, parallel to {@code instances}
 * @param columnNames output column order; names are resolved via
 *        Instances.attribute(String) and missing columns are written as 0
 * @param predictionsColumnIndex position at which the predictions column is
 *        inserted; if equal to columnNames.length it is appended at the end
 * @param predictionsColumnName header text for the predictions column
 * @param columnSeparator delimiter between columns
 * @param outputFilename destination file (parent directories are created)
 * @param writeColumnsHeaderLine whether to emit a header line first
 */
private void formatPredictions(Instances instances, double[] predictions, String[] columnNames,
        int predictionsColumnIndex, String predictionsColumnName, String columnSeparator, String outputFilename,
        boolean writeColumnsHeaderLine) {
    PerformanceCounters.startTimer("formatPredictions");

    System.out.println("Formatting predictions to file " + outputFilename + "...");
    File outputFile = new File(outputFilename);

    try {
        outputFile.getParentFile().mkdirs();
        outputFile.createNewFile();
        // try-with-resources guarantees the writer is closed even if
        // formatting throws (the original leaked it on failure).
        try (PrintWriter writer = new PrintWriter(outputFile, "UTF-8")) {
            StringBuilder sb = new StringBuilder();
            DecimalFormat df = new DecimalFormat("#.#", new DecimalFormatSymbols(Locale.US));
            df.setMaximumFractionDigits(3);

            // Row -1 is the header line; data rows are 0..numInstances-1.
            int i = -1;
            if (!writeColumnsHeaderLine) {
                i = 0;
            }
            for (; i < instances.numInstances(); i++) {
                sb.delete(0, sb.length());

                for (int j = 0; j < columnNames.length; j++) {
                    if (j > 0) {
                        sb.append(columnSeparator);
                    }

                    // Insert the predictions column before column j when requested.
                    if (j == predictionsColumnIndex) {
                        if (i < 0) {
                            sb.append(predictionsColumnName);
                        } else {
                            sb.append(df.format(predictions[i]));
                        }
                        sb.append(columnSeparator);
                    }
                    if (i < 0) {
                        sb.append(columnNames[j]);
                    } else {
                        if (columnNames[j].toLowerCase().contains("id")) {
                            // ID-like columns are written as integers.
                            Attribute attribute = instances.attribute(columnNames[j]);
                            if (attribute != null) {
                                sb.append((int) instances.instance(i).value(attribute.index()));
                            } else {
                                sb.append(0);
                            }
                        } else {
                            Attribute attribute = instances.attribute(columnNames[j]);
                            if (attribute != null) {
                                sb.append(instances.instance(i).value(attribute.index()));
                            } else {
                                sb.append(df.format(0d));
                            }
                        }
                    }
                }

                // Predictions column positioned after the last data column.
                if (columnNames.length == predictionsColumnIndex) {
                    sb.append(columnSeparator);
                    if (i < 0) {
                        sb.append(predictionsColumnName);
                    } else {
                        sb.append(df.format(predictions[i]));
                    }
                }

                writer.println(sb);
            }
            writer.flush();
        }
    } catch (IOException ex) {
        Logger.getLogger(PostProcess.class.getName()).log(Level.SEVERE, null, ex);
        // Keep the timer balanced on the error path (the original returned
        // without stopping it).
        PerformanceCounters.stopTimer("formatPredictions");
        return;
    }

    System.out.println("\tdone.");
    PerformanceCounters.stopTimer("formatPredictions");
}

From source file:asap.PostProcess.java

/**
 * Computes per-instance prediction errors and writes them, in the natural
 * ordering of PredictionError, to the given file.
 *
 * NOTE(review): a TreeSet silently drops elements that compare as equal --
 * if PredictionError.compareTo can return 0 for distinct instances, some
 * errors are lost from the output; verify its ordering is a total order
 * over distinct predictions.
 * NOTE(review): instances.attribute("pair_ID") / attribute("source_file")
 * return null if those attributes are absent, which would NPE here --
 * assumes the dataset always carries both columns; confirm with callers.
 *
 * @param instances evaluated dataset; instance i pairs with predictions[i]
 * @param predictions model outputs, parallel to {@code instances}
 * @param errorsFilename destination file (overwritten)
 */
private void writePredictionErrors(Instances instances, double[] predictions, String errorsFilename) {

    TreeSet<PredictionError> errors = new TreeSet<>();

    for (int i = 0; i < predictions.length; i++) {
        double prediction = predictions[i];
        double expected = instances.get(i).classValue();
        int pairId = (int) instances.get(i).value(instances.attribute("pair_ID"));
        String sourceFile = instances.get(i).stringValue(instances.attribute("source_file"));
        PredictionError pe = new PredictionError(prediction, expected, pairId, sourceFile, instances.get(i));

        // All errors are recorded; a threshold filter (error >= 0.5) was
        // previously considered here and left commented out.
        //if (pe.getError()>=0.5d)
        errors.add(pe);
    }

    StringBuilder sb = new StringBuilder();

    for (PredictionError error : errors) {
        sb.append(error.toString()).append("\n");
    }

    // Write the whole report in one shot; stream is closed by try-with-resources.
    File f = new File(errorsFilename);
    try (FileOutputStream fos = new FileOutputStream(f)) {
        fos.write(sb.toString().getBytes());
    } catch (IOException ex) {
        Logger.getLogger(PostProcess.class.getName()).log(Level.SEVERE, null, ex);
    }
}

From source file:asap.PostProcess.java

/**
 * Loads a training set from the given pre-processing output stream and
 * registers one NLPSystem per classifier command line in the configuration.
 * On a read or parse failure the error is logged and the method returns
 * without adding any systems.
 *
 * NOTE(review): attribute("gold_standard") returns null when that column is
 * missing, and setClass(null) would then throw -- assumes the stream always
 * contains a "gold_standard" attribute; confirm upstream.
 *
 * @param pposTrainingData stream producing the training dataset
 */
public void loadTrainingDataStream(PreProcessOutputStream pposTrainingData) {
    Instances instancesTrainingSet;

    DataSource source = new DataSource(pposTrainingData);
    try {
        instancesTrainingSet = source.getDataSet();

    } catch (Exception ex) {
        Logger.getLogger(PostProcess.class.getName()).log(Level.SEVERE, null, ex);
        return;
    }
    // Setting class attribute if the data format does not provide this information.
    if (instancesTrainingSet.classIndex() == -1) {
        instancesTrainingSet.setClass(instancesTrainingSet.attribute("gold_standard"));
    }

    // One classifier per configured Weka command line; failures are logged
    // and skipped so one bad command does not abort the rest.
    for (String wekaModelsCmd : Config.getWekaModelsCmd()) {
        String[] classifierCmd;
        try {
            classifierCmd = Utils.splitOptions(wekaModelsCmd);
        } catch (Exception ex) {
            Logger.getLogger(PostProcess.class.getName()).log(Level.SEVERE, null, ex);
            continue;
        }
        // First token is the class name; blank it out so the remainder is
        // passed to the classifier as its option array (Weka convention).
        String classname = classifierCmd[0];
        classifierCmd[0] = "";
        try {
            AbstractClassifier cl = (AbstractClassifier) Utils.forName(Classifier.class, classname,
                    classifierCmd);
            systems.add(new NLPSystem(cl, instancesTrainingSet, null));
            System.out.println("\tAdded system " + systems.get(systems.size() - 1).shortName());
        } catch (Exception ex) {
            Logger.getLogger(PostProcess.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

}

From source file:asap.PostProcess.java

/**
 * Loads an evaluation set from the given pre-processing output stream and
 * hands it to every registered NLPSystem. On a read or parse failure the
 * error is logged and the systems are left untouched.
 *
 * @param pposEvaluationData stream producing the evaluation dataset
 */
public void loadEvaluationDataStream(PreProcessOutputStream pposEvaluationData) {

    DataSource source = new DataSource(pposEvaluationData);

    Instances evaluationSet;
    try {
        evaluationSet = source.getDataSet();
    } catch (Exception ex) {
        Logger.getLogger(PostProcess.class.getName()).log(Level.SEVERE, null, ex);
        return;
    }

    // Fall back to the "gold_standard" attribute when the data format does
    // not carry class information itself.
    if (evaluationSet.classIndex() == -1) {
        evaluationSet.setClass(evaluationSet.attribute("gold_standard"));
    }

    for (NLPSystem system : systems) {
        system.setEvaluationSet(evaluationSet);
    }
}

From source file:at.aictopic1.sentimentanalysis.machinelearning.impl.TwitterClassifer.java

/**
 * Loads the labelled training set from a hard-coded ARFF file, saves the
 * dataset structure for later classification, and returns the data after
 * string attributes have been converted with StringToWordVector.
 *
 * NOTE(review): the ARFF path is an absolute Windows path on a developer
 * machine -- this only works in that environment; consider making it
 * configurable.
 * NOTE(review): any failure is reported to stdout and swallowed, returning
 * null -- callers must null-check the result.
 *
 * @return the filtered training instances, or null on any error
 */
public Instances loadTrainingData() {

    try {
        //DataSource source = new DataSource("C:\\Users\\David\\Documents\\Datalogi\\TU Wien\\2014W_Advanced Internet Computing\\Labs\\aic_group2_topic1\\Other Stuff\\training_dataset.arff");
        DataSource source = new DataSource(
                "C:\\Users\\David\\Documents\\Datalogi\\TU Wien\\2014W_Advanced Internet Computing\\Labs\\Data sets\\labelled.arff");

        Instances data = source.getDataSet();

        // Get and save the dataStructure of the dataset, so classify() can
        // later rebuild compatible Instances.
        dataStructure = source.getStructure();
        try {
            // Serialize the data structure to <modelDir><algorithm>.dataStruct.
            weka.core.SerializationHelper.write(modelDir + algorithm + ".dataStruct", dataStructure);
        } catch (Exception ex) {
            Logger.getLogger(TwitterClassifer.class.getName()).log(Level.SEVERE, null, ex);
        }
        // Set class index: attribute 2 holds the sentiment label.
        data.setClassIndex(2);

        // Giving attributes unique names before converting strings.
        data.renameAttribute(2, "class_attr");
        data.renameAttribute(0, "twitter_id");

        // Convert string attributes to word-presence features.
        StringToWordVector filter = new StringToWordVector();

        filter.setInputFormat(data);

        Instances filteredData = Filter.useFilter(data, filter);

        System.out.println("filteredData struct: " + filteredData.attribute(0));
        System.out.println("filteredData struct: " + filteredData.attribute(1));
        System.out.println("filteredData struct: " + filteredData.attribute(2));

        return filteredData;

    } catch (Exception ex) {
        System.out.println("Error loading training set: " + ex.toString());
        return null;
        //Logger.getLogger(Trainer.class.getName()).log(Level.SEVERE, null, ex);
    }

}

From source file:at.aictopic1.sentimentanalysis.machinelearning.impl.TwitterClassifer.java

/**
 * Classifies tweets with the trained model, printing each prediction.
 *
 * NOTE(review): this method is visibly unfinished test code:
 *  - the {@code tweets} parameter is immediately OVERWRITTEN by four
 *    hard-coded sample tweets, so the caller's input is ignored;
 *  - it always returns 0 rather than the per-tweet predictions;
 *  - the StringToWordVector filter is rebuilt and re-applied to the whole
 *    growing predictSet on every loop iteration (quadratic work).
 * Confirm intended behavior before relying on it in production.
 *
 * @param tweets tweets to classify (currently ignored -- see note above)
 * @return always 0
 */
public Integer classify(Tweet[] tweets) {
    // TEST: hard-coded sample tweets below replace the caller's input.

    // Generate two tweet examples
    Tweet exOne = new Tweet("This is good and fantastic");
    exOne.setPreprocessedText("This is good and fantastic");
    Tweet exTwo = new Tweet("Horribly, terribly bad and more");
    exTwo.setPreprocessedText("Horribly, terribly bad and more");
    Tweet exThree = new Tweet(
            "I want to update lj and read my friends list, but I\\'m groggy and sick and blargh.");
    exThree.setPreprocessedText(
            "I want to update lj and read my friends list, but I\\'m groggy and sick and blargh.");
    Tweet exFour = new Tweet("bad hate worst sick");
    exFour.setPreprocessedText("bad hate worst sick");
    tweets = new Tweet[] { exOne, exTwo, exThree, exFour };
    // TEST

    // Load model
    //        loadModel();
    // Convert Tweet to Instance type 
    // Get String Data
    // Create attributes for the Instances set
    Attribute twitter_id = new Attribute("twitter_id");
    //        Attribute body = new Attribute("body");

    // Nominal class attribute with the two sentiment labels.
    FastVector classVal = new FastVector(2);
    classVal.addElement("pos");
    classVal.addElement("neg");

    Attribute class_attr = new Attribute("class_attr", classVal);

    // Attribute vector for the prediction set; populated below by copying
    // the attributes of the saved training-data structure so headers match.
    FastVector attrVector = new FastVector(3);
    //        attrVector.addElement(twitter_id);
    //        attrVector.addElement(new Attribute("body", (FastVector) null));
    //        attrVector.addElement(class_attr);

    // Get the number of tweets and then create predictSet
    int numTweets = tweets.length;
    Enumeration structAttrs = dataStructure.enumerateAttributes();

    //        ArrayList<Attribute> attrList = new ArrayList<Attribute>(dataStructure.numAttributes());
    while (structAttrs.hasMoreElements()) {
        attrVector.addElement((Attribute) structAttrs.nextElement());
    }
    Instances predictSet = new Instances("predictInstances", attrVector, numTweets);
    //        Instances predictSet = new Instances(dataStructure);
    predictSet.setClassIndex(2);

    // init prediction
    double prediction = -1;

    System.out.println("PredictSet matches source structure: " + predictSet.equalHeaders(dataStructure));

    System.out.println("PredSet struct: " + predictSet.attribute(0));
    System.out.println("PredSet struct: " + predictSet.attribute(1));
    System.out.println("PredSet struct: " + predictSet.attribute(2));
    // Array to return predictions 
    //double[] tweetsClassified = new double[2][numTweets];
    //List<Integer, Double> tweetsClass = new ArrayList<Integer, Double>(numTweets);
    for (int i = 0; i < numTweets; i++) {
        String content = (String) tweets[i].getPreprocessedText();

        System.out.println("Tweet content: " + content);

        // Build one instance per tweet: id, text body, missing class.
        Instance tweetInstance = new Instance(predictSet.numAttributes());

        tweetInstance.setDataset(predictSet);

        tweetInstance.setValue(predictSet.attribute(0), i);
        tweetInstance.setValue(predictSet.attribute(1), content);
        tweetInstance.setClassMissing();

        predictSet.add(tweetInstance);

        try {
            // Apply string filter.
            // NOTE(review): the filter is re-fit on the whole predictSet
            // every iteration; hoisting it out of the loop would avoid
            // quadratic work -- confirm instance i's features are unchanged.
            StringToWordVector filter = new StringToWordVector();

            filter.setInputFormat(predictSet);
            Instances filteredPredictSet = Filter.useFilter(predictSet, filter);

            // Apply model
            prediction = trainedModel.classifyInstance(filteredPredictSet.instance(i));
            filteredPredictSet.instance(i).setClassValue(prediction);
            System.out.println("Classification: " + filteredPredictSet.instance(i).toString());
            System.out.println("Prediction: " + prediction);

        } catch (Exception ex) {
            Logger.getLogger(TwitterClassifer.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    return 0;
}

From source file:bme.mace.logicdomain.Evaluation.java

License:Open Source License

/**
 * Returns the area under ROC for those predictions that have been collected
 * in the evaluateClassifier(Classifier, Instances) method. Returns
 * Instance.missingValue() if the area is not available.
 * /* w w  w  .j a  v a  2  s . c  o m*/
 * @param classIndex the index of the class to consider as "positive"
 * @return the area under the ROC curve or not a number
 */
public double areaUnderROC(int classIndex) {

    // Check if any predictions have been collected
    if (m_Predictions == null) {
        return Instance.missingValue();
    } else {
        ThresholdCurve tc = new ThresholdCurve();
        Instances result = tc.getCurve(m_Predictions, classIndex);
        double rocArea = ThresholdCurve.getROCArea(result);
        if (rocArea < 0.5) {
            rocArea = 1 - rocArea;
        }

        int tpIndex = result.attribute(ThresholdCurve.TP_RATE_NAME).index();
        int fpIndex = result.attribute(ThresholdCurve.FP_RATE_NAME).index();
        double[] tpRate = result.attributeToDoubleArray(tpIndex);
        double[] fpRate = result.attributeToDoubleArray(fpIndex);

        try {
            FileWriter fw;
            if (classIndex == 0)
                fw = new FileWriter("C://1.csv", true);
            else
                fw = new FileWriter("C://1.csv", true);

            BufferedWriter bw = new BufferedWriter(fw);
            int length = fpRate.length;
            for (int i = 255; i >= 0; i--) {

                int index = i * (length - 1) / 255;
                bw.write(fpRate[index] + ",");
            }
            bw.write("\n");
            for (int i = 255; i >= 0; i--) {
                int index = i * (length - 1) / 255;
                bw.write(tpRate[index] + ",");
            }
            bw.write("\n");

            bw.close();
            fw.close();
        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }

        return rocArea;
    }
}

From source file:bme.mace.logicdomain.Evaluation.java

License:Open Source License

/**
 * Prints the header for the predictions output into a supplied StringBuffer.
 *
 * @param test structure of the test set to print predictions for
 * @param attributesToOutput indices of the attributes to output
 * @param printDistribution prints the complete distribution for nominal
 *          attributes, not just the predicted value
 * @param text the StringBuffer to print to
 */
protected static void printClassificationsHeader(Instances test, Range attributesToOutput,
        boolean printDistribution, StringBuffer text) {
    // Fixed column headings; nominal classes additionally show either the
    // predicted value or the whole class distribution.
    if (test.classAttribute().isNominal()) {
        text.append(printDistribution ? " inst#     actual  predicted error distribution"
                : " inst#     actual  predicted error prediction");
    } else {
        text.append(" inst#     actual  predicted      error");
    }

    // Optionally list the names of the attributes echoed per prediction,
    // in parentheses, skipping the class attribute.
    if (attributesToOutput != null) {
        attributesToOutput.setUpper(test.numAttributes() - 1);
        text.append(" (");
        String separator = "";
        for (int attrIndex = 0; attrIndex < test.numAttributes(); attrIndex++) {
            if (attrIndex == test.classIndex() || !attributesToOutput.isInRange(attrIndex)) {
                continue;
            }
            text.append(separator).append(test.attribute(attrIndex).name());
            separator = ",";
        }
        text.append(")");
    }
    text.append("\n");
}