Example usage for weka.core.Instances.setClassIndex

List of usage examples for weka.core.Instances.setClassIndex

Introduction

On this page you can find example usages of weka.core.Instances.setClassIndex.

Prototype

public void setClassIndex(int classIndex) 

Source Link

Document

Sets the class index of the set. Indices start at 0; a negative index means the set has no class attribute.

Usage

From source file: com.guidefreitas.locator.services.PredictionService.java

/**
 * Trains the room classifier on freshly generated ARFF training data.
 * <p>
 * The last attribute of the generated data is used as the class. A LibSVM
 * configured as C-SVC with a polynomial kernel is stored in
 * {@code this.classifier}, scored with 10-fold cross-validation (seed 1),
 * and finally fitted on the complete data set.
 *
 * @return the cross-validation results, or {@code null} if any step threw
 *         (the exception is logged at SEVERE)
 */
public Evaluation train() {
    try {
        byte[] arffBytes = this.generateTrainData().getBytes(StandardCharsets.UTF_8);
        Instances trainingSet = new DataSource(new ByteArrayInputStream(arffBytes)).getDataSet();
        trainingSet.setClassIndex(trainingSet.numAttributes() - 1);

        // Publish the new model on the field before configuring it, as before.
        LibSVM svm = new LibSVM();
        this.classifier = svm;
        svm.setKernelType(new SelectedTag(LibSVM.KERNELTYPE_POLYNOMIAL, LibSVM.TAGS_KERNELTYPE));
        svm.setSVMType(new SelectedTag(LibSVM.SVMTYPE_C_SVC, LibSVM.TAGS_SVMTYPE));

        Evaluation crossValidation = new Evaluation(trainingSet);
        crossValidation.crossValidateModel(svm, trainingSet, 10, new Random(1));

        // Cross-validation trains copies; the stored model is fitted on all data here.
        svm.buildClassifier(trainingSet);
        return crossValidation;
    } catch (Exception ex) {
        Logger.getLogger(PredictionService.class.getName()).log(Level.SEVERE, null, ex);
        return null;
    }
}

From source file: com.guidefreitas.locator.services.PredictionService.java

/**
 * Predicts the room for a single request using the classifier built by {@code train()}.
 * <p>
 * The request is turned into ARFF test data via {@code generateTestData}. The last attribute
 * is the class, and only the first test instance is classified. The predicted class label is
 * parsed as a room id and looked up through {@code RoomService}.
 *
 * @param request the request to classify
 * @return the predicted room, or {@code null} in any of these cases:
 *         no test instance was generated; no classifier has been trained; the classifier
 *         returned a missing prediction; no room exists for the predicted id; or any step
 *         threw (exceptions are logged at SEVERE)
 */
public Room predict(PredictionRequest request) {
    Logger logger = Logger.getLogger(PredictionService.class.getName());
    try {
        String arffData = this.generateTestData(request);
        Instances unlabeled = new Instances(new StringReader(arffData));
        System.out.println("Test data size: " + unlabeled.size());
        if (unlabeled.numInstances() == 0) {
            logger.log(Level.WARNING, "No test instance was generated for the prediction request");
            return null;
        }
        unlabeled.setClassIndex(unlabeled.numAttributes() - 1);

        if (this.classifier == null) {
            logger.log(Level.WARNING, "predict() was called before a classifier was trained");
            return null;
        }
        double clsLabel = this.classifier.classifyInstance(unlabeled.get(0));
        // Weka reports "no prediction" as a missing value (NaN). Casting NaN to int yields 0,
        // which would silently select the first class label, i.e. an arbitrary room.
        if (Double.isNaN(clsLabel)) {
            logger.log(Level.WARNING, "Classifier returned no prediction for the request");
            return null;
        }
        String roomIdString = unlabeled.classAttribute().value((int) clsLabel);

        Long roomId = Long.parseLong(roomIdString);
        Room predictedRoom = RoomService.getInstance().getById(roomId);
        if (predictedRoom == null) {
            logger.log(Level.WARNING, "No room found for predicted id {0}", roomId);
            return null;
        }
        System.out.println(clsLabel + " -> " + roomIdString + " -> " + predictedRoom.getName());
        return predictedRoom;
    } catch (Exception ex) {
        logger.log(Level.SEVERE, null, ex);
        return null;
    }
}

From source file: com.ivanrf.smsspam.SpamClassifier.java

License: Apache License

/**
 * Trains a spam classifier on {@code SMSSpamCollection.arff} and saves the resulting model.
 * <p>
 * The model name is derived from the configuration arguments. Progress, the textual form of
 * the trained classifier, and the elapsed time are reported to {@code log} via publishEstado.
 * Any exception is printed to stderr and reported as a training error; it is not rethrown.
 *
 * @param wordsToKeep           forwarded to the model name and classifier factory
 * @param tokenizerOp           forwarded to the model name and classifier factory
 * @param useAttributeSelection forwarded to the model name and classifier factory
 * @param classifierOp          forwarded to the model name and classifier factory
 * @param boosting              forwarded to the model name and classifier factory
 * @param log                   text area that receives progress messages
 */
public static void train(int wordsToKeep, String tokenizerOp, boolean useAttributeSelection,
        String classifierOp, boolean boosting, JTextArea log) {
    try {
        final long startMillis = System.currentTimeMillis();

        final String modelName =
                getModelName(wordsToKeep, tokenizerOp, useAttributeSelection, classifierOp, boosting);
        showEstimatedTime(true, modelName, log);

        final Instances dataset = loadDataset("SMSSpamCollection.arff", log);
        // Attribute 0 is the class label.
        dataset.setClassIndex(0);

        final FilteredClassifier model =
                initFilterClassifier(wordsToKeep, tokenizerOp, useAttributeSelection, classifierOp, boosting);

        publishEstado("=== Building the classifier on the filtered data ===", log);
        model.buildClassifier(dataset);

        publishEstado(model.toString(), log);
        publishEstado("=== Training done ===", log);

        saveModel(model, modelName, log);

        final long elapsedMillis = System.currentTimeMillis() - startMillis;
        publishEstado("Elapsed time: " + Utils.getDateHsMinSegString(elapsedMillis), log);
    } catch (Exception e) {
        e.printStackTrace();
        publishEstado("Error found when training", log);
    }
}

From source file: com.ivanrf.smsspam.SpamClassifier.java

License: Apache License

/**
 * Scores a spam-classifier configuration with 10-fold cross-validation on
 * {@code SMSSpamCollection.arff} (random seed 1).
 * <p>
 * Publishes three reports to {@code log}, in this order: the summary, the per-class details,
 * and the confusion matrix. The elapsed time is published afterwards. Any exception is printed
 * to stderr and reported as an evaluation error; it is not rethrown.
 *
 * @param wordsToKeep           forwarded to the model name and classifier factory
 * @param tokenizerOp           forwarded to the model name and classifier factory
 * @param useAttributeSelection forwarded to the model name and classifier factory
 * @param classifierOp          forwarded to the model name and classifier factory
 * @param boosting              forwarded to the model name and classifier factory
 * @param log                   text area that receives progress messages
 */
public static void evaluate(int wordsToKeep, String tokenizerOp, boolean useAttributeSelection,
        String classifierOp, boolean boosting, JTextArea log) {
    try {
        final long startMillis = System.currentTimeMillis();

        final String modelName =
                getModelName(wordsToKeep, tokenizerOp, useAttributeSelection, classifierOp, boosting);
        showEstimatedTime(false, modelName, log);

        final Instances dataset = loadDataset("SMSSpamCollection.arff", log);
        // Attribute 0 is the class label.
        dataset.setClassIndex(0);
        final FilteredClassifier model =
                initFilterClassifier(wordsToKeep, tokenizerOp, useAttributeSelection, classifierOp, boosting);

        publishEstado("=== Performing cross-validation ===", log);
        final Evaluation cv = new Evaluation(dataset);
        cv.crossValidateModel(model, dataset, 10, new Random(1));

        // Each report is published as soon as it is produced.
        final String summary = cv.toSummaryString();
        publishEstado(summary, log);
        final String classDetails = cv.toClassDetailsString();
        publishEstado(classDetails, log);
        final String confusionMatrix = cv.toMatrixString();
        publishEstado(confusionMatrix, log);
        publishEstado("=== Evaluation finished ===", log);

        final long elapsedMillis = System.currentTimeMillis() - startMillis;
        publishEstado("Elapsed time: " + Utils.getDateHsMinSegString(elapsedMillis), log);
    } catch (Exception e) {
        e.printStackTrace();
        publishEstado("Error found when evaluating", log);
    }
}

From source file: com.ivanrf.smsspam.SpamClassifier.java

License: Apache License

/**
 * Classifies a single text message as "ham" or "spam" using a saved model.
 * <p>
 * Builds a one-row dataset with two attributes:
 * a nominal class {@code spam_class} {ham, spam} at index 0, and a string attribute
 * {@code text} at index 1. This presumably mirrors the layout of SMSSpamCollection.arff
 * used in training; confirm.
 * The loaded FilteredClassifier is then asked for a prediction. Progress is reported to
 * {@code log} via publishEstado.
 *
 * @param model name of the saved model, passed to loadModel
 * @param text  message to classify
 * @param log   text area that receives progress messages
 * @return the predicted class label, or {@code null} in any of these cases:
 *         loadModel returned null; the classifier produced no prediction; or
 *         classification threw (the stack trace is printed)
 */
public static String classify(String model, String text, JTextArea log) {
    FilteredClassifier classifier = loadModel(model, log);

    // Create the instance
    ArrayList<String> fvNominalVal = new ArrayList<String>();
    fvNominalVal.add("ham");
    fvNominalVal.add("spam");

    Attribute attribute1 = new Attribute("spam_class", fvNominalVal);
    // A null value list declares a string attribute.
    Attribute attribute2 = new Attribute("text", (List<String>) null);
    ArrayList<Attribute> fvWekaAttributes = new ArrayList<Attribute>();
    fvWekaAttributes.add(attribute1);
    fvWekaAttributes.add(attribute2);

    Instances instances = new Instances("Test relation", fvWekaAttributes, 1);
    instances.setClassIndex(0);

    // Only the text is set; the class value stays missing, since it is what we predict.
    DenseInstance instance = new DenseInstance(2);
    instance.setValue(attribute2, text);
    instances.add(instance);

    publishEstado("=== Instance created ===", log);
    publishEstado(instances.toString(), log);

    //Classify the instance
    try {
        publishEstado("=== Classifying instance ===", log);

        // Check explicitly instead of relying on a NullPointerException being caught below.
        if (classifier == null) {
            publishEstado("Error found when classifying the text", log);
            return null;
        }

        double pred = classifier.classifyInstance(instances.instance(0));

        publishEstado("=== Instance classified  ===", log);

        // A missing prediction is NaN; (int) NaN is 0, which would wrongly report "ham".
        if (Double.isNaN(pred)) {
            publishEstado("Error found when classifying the text", log);
            return null;
        }

        String classPredicted = instances.classAttribute().value((int) pred);
        publishEstado("Class predicted: " + classPredicted, log);

        return classPredicted;
    } catch (Exception e) {
        // Keep the diagnostics, consistent with train() and evaluate().
        e.printStackTrace();
        publishEstado("Error found when classifying the text", log);
        return null;
    }
}

From source file: com.jgaap.util.Instance.java

License: Open Source License

/**
 * Main method for testing this class.
 * <p>
 * Builds a small dataset named "race" with two numeric attributes ("length",
 * "weight") and one nominal class attribute ("position"). It then exercises the
 * Instance API on one instance and on a copy of it, printing each result to
 * standard output. Any exception is printed to stderr.
 * 
 * @param options the commandline options - ignored
 */
//@ requires options != null;
public static void main(String[] options) {

    try {

        // Create numeric attributes "length" and "weight"
        Attribute length = new Attribute("length");
        Attribute weight = new Attribute("weight");

        // Create vector to hold nominal values "first", "second", "third"
        FastVector my_nominal_values = new FastVector(3);
        my_nominal_values.addElement("first");
        my_nominal_values.addElement("second");
        my_nominal_values.addElement("third");

        // Create nominal attribute "position"
        Attribute position = new Attribute("position", my_nominal_values);

        // Create vector of the above attributes (order fixes the attribute indices 0, 1, 2)
        FastVector attributes = new FastVector(3);
        attributes.addElement(length);
        attributes.addElement(weight);
        attributes.addElement(position);

        // Create the empty dataset "race" with above attributes
        Instances race = new Instances("race", attributes, 0);

        // Make position the class attribute
        race.setClassIndex(position.index());

        // Create empty instance with three attribute values
        Instance inst = new Instance(3);

        // Set instance's values for the attributes "length", "weight", and "position"
        inst.setValue(length, 5.3);
        inst.setValue(weight, 300);
        inst.setValue(position, "first");

        // Set instance's dataset to be the dataset "race"; the class-related calls
        // below (classAttribute, classIndex, ...) go through this dataset.
        inst.setDataset(race);

        // Print the instance
        System.out.println("The instance: " + inst);

        // Print the first attribute
        System.out.println("First attribute: " + inst.attribute(0));

        // Print the class attribute
        System.out.println("Class attribute: " + inst.classAttribute());

        // Print the class index
        System.out.println("Class index: " + inst.classIndex());

        // Say if class is missing
        System.out.println("Class is missing: " + inst.classIsMissing());

        // Print the instance's class value in internal format
        System.out.println("Class value (internal format): " + inst.classValue());

        // Print a shallow copy of this instance
        Instance copy = (Instance) inst.copy();
        System.out.println("Shallow copy: " + copy);

        // Set dataset for shallow copy
        copy.setDataset(inst.dataset());
        System.out.println("Shallow copy with dataset set: " + copy);

        // Unset dataset for copy, delete first attribute, and insert it again.
        // NOTE(review): the dataset is detached first, presumably because attribute
        // deletion/insertion is not allowed while the instance references a
        // dataset -- confirm against deleteAttributeAt/insertAttributeAt.
        copy.setDataset(null);
        copy.deleteAttributeAt(0);
        copy.insertAttributeAt(0);
        copy.setDataset(inst.dataset());
        System.out.println("Copy with first attribute deleted and inserted: " + copy);

        // Enumerate attributes (leaving out the class attribute)
        System.out.println("Enumerating attributes (leaving out class):");
        Enumeration enu = inst.enumerateAttributes();
        while (enu.hasMoreElements()) {
            Attribute att = (Attribute) enu.nextElement();
            System.out.println(att);
        }

        // Headers are equivalent?
        System.out.println("Header of original and copy equivalent: " + inst.equalHeaders(copy));

        // Test for missing values, checked three ways (by Attribute, by index, and
        // via the static missing-value test).
        // NOTE(review): the third line reuses the label "Length of copy missing"
        // from the first line; the output string is left unchanged here.
        System.out.println("Length of copy missing: " + copy.isMissing(length));
        System.out.println("Weight of copy missing: " + copy.isMissing(weight.index()));
        System.out.println("Length of copy missing: " + Instance.isMissingValue(copy.value(length)));
        System.out.println("Missing value coded as: " + Instance.missingValue());

        // Prints number of attributes and classes
        System.out.println("Number of attributes: " + copy.numAttributes());
        System.out.println("Number of classes: " + copy.numClasses());

        // Replace missing values; the array holds one entry per attribute
        // (three attributes here), presumably used in attribute order -- confirm.
        double[] meansAndModes = { 2, 3, 0 };
        copy.replaceMissingValues(meansAndModes);
        System.out.println("Copy with missing value replaced: " + copy);

        // Setting and getting values and weights
        copy.setClassMissing();
        System.out.println("Copy with missing class: " + copy);
        copy.setClassValue(0);
        System.out.println("Copy with class value set to first value: " + copy);
        copy.setClassValue("third");
        System.out.println("Copy with class value set to \"third\": " + copy);
        copy.setMissing(1);
        System.out.println("Copy with second attribute set to be missing: " + copy);
        copy.setMissing(length);
        System.out.println("Copy with length set to be missing: " + copy);
        copy.setValue(0, 0);
        System.out.println("Copy with first attribute set to 0: " + copy);
        copy.setValue(weight, 1);
        System.out.println("Copy with weight attribute set to 1: " + copy);
        copy.setValue(position, "second");
        System.out.println("Copy with position set to \"second\": " + copy);
        copy.setValue(2, "first");
        System.out.println("Copy with last attribute set to \"first\": " + copy);
        System.out.println("Current weight of instance copy: " + copy.weight());
        copy.setWeight(2);
        System.out.println("Current weight of instance copy (set to 2): " + copy.weight());
        System.out.println("Last value of copy: " + copy.toString(2));
        System.out.println("Value of position for copy: " + copy.toString(position));
        System.out.println("Last value of copy (internal format): " + copy.value(2));
        System.out.println("Value of position for copy (internal format): " + copy.value(position));
    } catch (Exception e) {
        e.printStackTrace();
    }
}

From source file: com.mechaglot_Alpha2.controller.Calculate.java

License: Creative Commons License

/**
 * /*from  w  w  w  .ja va  2 s.  com*/
 * @param in
 *            String representing the calculated String-metric distances,
 *            comma separated.
 * @return Instance The inputted series of numbers (comma separated) as
 *         Instance.
 */

private Instance instanceMaker(String in) {

    String[] s = in.split(",");
    double[] r = new double[s.length];
    for (int t = 0; t < r.length; t++) {
        r[t] = Double.parseDouble(s[t]);
    }

    int sz = r.length - 1;

    ArrayList<Attribute> atts = new ArrayList<Attribute>(sz);

    for (int t = 0; t < sz + 1; t++) {
        atts.add(new Attribute("number" + t, t));
    }

    Instances dataRaw = new Instances("TestInstances", atts, sz);
    dataRaw.add(new DenseInstance(1.0, r));
    Instance first = dataRaw.firstInstance(); //
    int cIdx = dataRaw.numAttributes() - 1;
    dataRaw.setClassIndex(cIdx);

    return first;

}

From source file: com.mycompany.id3classifier.ID3Shell.java

/**
 * Runs 10-fold cross-validation of the custom ID3 classifier on {@code lensesData.csv}
 * and prints the summary to standard output.
 * <p>
 * The data is passed through Weka's Discretize and then Standardize filters. The last
 * attribute is then made the class, and the data is shuffled with a fixed seed before
 * the folds are taken.
 *
 * @param args ignored
 * @throws Exception if loading, filtering, training, or evaluation fails
 */
public static void main(String[] args) throws Exception {
    ConverterUtils.DataSource source = new ConverterUtils.DataSource("lensesData.csv");
    Instances dataSet = source.getDataSet();

    Discretize filter = new Discretize();
    filter.setInputFormat(dataSet);
    dataSet = Filter.useFilter(dataSet, filter);

    Standardize standardize = new Standardize();
    standardize.setInputFormat(dataSet);
    dataSet = Filter.useFilter(dataSet, standardize);

    // NOTE(review): the class index is assigned only after both filters ran, so the
    // filters could not exclude the class attribute -- confirm this is intended.
    dataSet.setClassIndex(dataSet.numAttributes() - 1);
    // Fixed seed keeps the shuffle, and therefore the folds, reproducible.
    dataSet.randomize(new Random(9001));

    int folds = 10;
    // Accumulate the results of every fold into a single Evaluation.
    Evaluation eval = new Evaluation(dataSet);
    for (int n = 0; n < folds; n++) {
        Instances trainingData = dataSet.trainCV(folds, n);
        Instances testData = dataSet.testCV(folds, n);

        // A fresh model per fold so no fold sees another fold's training.
        ID3Classifier classifier = new ID3Classifier();
        classifier.buildClassifier(trainingData);

        eval.evaluateModel(classifier, testData);
    }
    System.out.println(eval.toSummaryString("\nResults:\n", false));
}

From source file: com.mycompany.knnclassifier.kNNShell.java

/**
 * Evaluates the custom k-nearest-neighbour classifier (constructed with 3) on
 * {@code carData.csv} and prints the summary to standard output.
 * <p>
 * The data goes through Weka's Standardize filter, and the last attribute becomes
 * the class. The rows are shuffled with a fixed seed and split 70/30 into training
 * and test sets.
 *
 * @param args ignored
 * @throws Exception if loading, filtering, training, or evaluation fails
 */
public static void main(String[] args) throws Exception {
    Instances dataSet = new ConverterUtils.DataSource("carData.csv").getDataSet();

    Standardize standardize = new Standardize();
    standardize.setInputFormat(dataSet);
    dataSet = Filter.useFilter(dataSet, standardize);

    dataSet.setClassIndex(dataSet.numAttributes() - 1);
    // Fixed seed keeps the shuffle, and therefore the split, reproducible.
    dataSet.randomize(new Random(9001));

    int total = dataSet.numInstances();
    int trainCount = (int) Math.round(total * .7);
    Instances trainingData = new Instances(dataSet, 0, trainCount);
    Instances testData = new Instances(dataSet, trainCount, total - trainCount);

    // For comparison, Weka's built-in IBk can be substituted for this classifier.
    kNNClassifier classifier = new kNNClassifier(3);
    classifier.buildClassifier(trainingData);

    Evaluation eval = new Evaluation(trainingData);
    eval.evaluateModel(classifier, testData);
    System.out.println(eval.toSummaryString("\nResults:\n", false));
}

From source file: com.mycompany.neuralnetwork.NeuralNetworkShell.java

/**
 * Evaluates the custom neural-network classifier on {@code irisData.csv} and prints
 * the summary to standard output.
 * <p>
 * The data goes through Weka's Standardize filter, and the last attribute becomes
 * the class. The rows are shuffled with a fixed seed and split 70/30 into training
 * and test sets.
 *
 * @param args ignored
 * @throws Exception if loading, filtering, training, or evaluation fails
 */
public static void main(String[] args) throws Exception {
    Instances dataSet = new ConverterUtils.DataSource("irisData.csv").getDataSet();

    Standardize standardize = new Standardize();
    standardize.setInputFormat(dataSet);
    dataSet = Filter.useFilter(dataSet, standardize);
    dataSet.setClassIndex(dataSet.numAttributes() - 1);
    // Fixed seed keeps the shuffle, and therefore the split, reproducible.
    dataSet.randomize(new Random(9001));

    int total = dataSet.numInstances();
    int trainCount = (int) Math.round(total * .7);
    Instances trainingData = new Instances(dataSet, 0, trainCount);
    Instances testData = new Instances(dataSet, trainCount, total - trainCount);

    // For comparison, Weka's MultilayerPerceptron can be substituted for this classifier.
    NeuralNetworkClassifier classifier = new NeuralNetworkClassifier(3, 20000, 0.1);
    classifier.buildClassifier(trainingData);

    Evaluation eval = new Evaluation(trainingData);
    eval.evaluateModel(classifier, testData);
    System.out.println(eval.toSummaryString("\nResults:\n", false));
}