Example usage for weka.classifiers Classifier distributionForInstance

List of usage examples for weka.classifiers Classifier distributionForInstance

Introduction

On this page you can find example usages of weka.classifiers Classifier distributionForInstance.

Prototype

public double[] distributionForInstance(Instance instance) throws Exception;

Source Link

Document

Predicts the class memberships for a given instance.

Usage

From source file:csav2.Weka_additive.java

/**
 * Classifies a test set of additive feature vectors (feature set 6: autosentiment,
 * polar-word matches, POS counts, dependency counts, BL and VS lexicon scores)
 * with a previously trained, serialized model and writes per-instance
 * predictions ("id\tlabel: probability") to Answers_additive_Test6.txt.
 *
 * @param input whitespace-separated records; each record is 15 feature tokens
 *              (the last being the nominal answer) followed by one id token
 * @throws Exception if the ARFF round-trip, model deserialization, or
 *                   evaluation fails
 */
public void classifyTestSet6(String input) throws Exception {
    String ids = "";
    ReaderWriter rw = new ReaderWriter();

    // Numeric feature attributes (indices 0..13) plus the nominal class (14).
    Attribute[] attr = new Attribute[15];
    attr[0] = new Attribute("Autosentiment");
    attr[1] = new Attribute("PositiveMatch");
    attr[2] = new Attribute("NegativeMatch");
    attr[3] = new Attribute("FW");
    attr[4] = new Attribute("JJ");
    attr[5] = new Attribute("RB");
    attr[6] = new Attribute("RB_JJ");
    attr[7] = new Attribute("amod");
    attr[8] = new Attribute("acomp");
    attr[9] = new Attribute("advmod");
    attr[10] = new Attribute("BLPos");
    attr[11] = new Attribute("BLNeg");
    attr[12] = new Attribute("VSPos");
    attr[13] = new Attribute("VSNeg");

    // Class attribute: p(ositive), n(egative), o(ther).
    FastVector classValue = new FastVector(3);
    classValue.addElement("p");
    classValue.addElement("n");
    classValue.addElement("o");
    attr[14] = new Attribute("answer", classValue);

    FastVector attrs = new FastVector();
    for (int i = 0; i < attr.length; i++) {
        attrs.addElement(attr[i]);
    }

    Instances dataset = new Instances("my_dataset", attrs, 0);

    // Each record is 15 feature tokens followed by one trailing id token.
    StringTokenizer tokenizer = new StringTokenizer(input);
    while (tokenizer.hasMoreTokens()) {
        Instance example = new Instance(15);
        for (int j = 0; j < 15; j++) {
            String st = tokenizer.nextToken();
            System.out.println(j + " " + st);
            if (j == 0)
                example.setValue(attr[j], Float.parseFloat(st));
            else if (j == 14)
                example.setValue(attr[j], st);
            else
                example.setValue(attr[j], Integer.parseInt(st));
        }
        ids += tokenizer.nextToken() + "\t"; // id token for this record
        dataset.add(example);
    }

    // Persist the dataset, then reload it so Weka rebuilds attribute metadata.
    String file = "Classifier\\featurefile_additive_test6.arff";
    ArffSaver saver = new ArffSaver();
    saver.setInstances(dataset);
    saver.setFile(new File(file));
    saver.writeBatch();

    ArffLoader loader = new ArffLoader();
    loader.setFile(new File(file));
    dataset = loader.getDataSet();
    dataset.setClassIndex(14);

    // Deserialize the trained classifier. try-with-resources closes the stream
    // even if readObject() throws (the original leaked it).
    String file1 = "Classifier\\classifier_asAndpolarwordsAndposAnddepAndblAndvs.model";
    Classifier classifier;
    try (ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(file1))) {
        classifier = (Classifier) objectInputStream.readObject();
    }

    // Evaluate on the whole (re)loaded set.
    Instances test = new Instances(dataset, 0, dataset.numInstances());
    test.setClassIndex(14);

    Evaluation eval = new Evaluation(test);
    eval.evaluateModel(classifier, test);
    System.out.println(eval.toSummaryString());
    System.out.println("WEIGHTED F-MEASURE:" + eval.weightedFMeasure());
    System.out.println("WEIGHTED PRECISION:" + eval.weightedPrecision());
    System.out.println("WEIGHTED RECALL:" + eval.weightedRecall());

    // Emit "<id>\t<label>: <probability rounded to 3 decimals>" per instance.
    // Distribution index 0 -> "p", 1 -> "n", 2 -> "o"; ">=" makes ties resolve
    // to the later index, exactly like the original nested-if chain.
    String[] labels = { "p", "n", "o" };
    StringBuilder optest = new StringBuilder();
    StringTokenizer op = new StringTokenizer(ids);
    int count = 0;
    while (op.hasMoreTokens()) {
        double[] prediction = classifier.distributionForInstance(test.instance(count));
        count += 1;
        int best = 0;
        for (int k = 1; k < 3; k++) {
            if (prediction[k] >= prediction[best]) {
                best = k;
            }
        }
        String val = labels[best] + ": "
                + Double.toString((double) Math.round(prediction[best] * 1000) / 1000);
        optest.append(op.nextToken()).append("\t").append(val).append("\n");
    }
    rw.writeToFile(optest.toString(), "Answers_additive_Test6", "txt");
}

From source file:edu.oregonstate.eecs.mcplan.abstraction.WekaUtil.java

License:Open Source License

/**
 * Classify a feature vector that is not part of an Instances object.
 * @param classifier/*  w  w  w .j  ava  2s.  c om*/
 * @param attributes
 * @param features
 * @return
 */
public static double[] distribution(final Classifier classifier, final List<Attribute> attributes,
        final double[] features) {
    final Instances x = createSingletonInstances(attributes, features);
    try {
        return classifier.distributionForInstance(x.get(0));
    } catch (final Exception ex) {
        throw new RuntimeException(ex);
    }
}

From source file:fk.stardust.localizer.machinelearn.WekaFaultLocalizer.java

License:Open Source License

/**
 * Localizes faults by training a classifier on the trace spectra and then
 * ranking each component by the predicted failure probability when only that
 * single component is marked as involved.
 *
 * @param spectra the execution spectra (components x traces) to analyze
 * @return a ranking of components; a higher score means more suspicious
 */
@Override
public Ranking<T> localize(final ISpectra<T> spectra) {

    // == 1. Create Weka training instance

    final List<INode<T>> nodes = new ArrayList<>(spectra.getNodes());

    // nominal true/false values shared as the domain of every attribute
    final List<String> tf = new ArrayList<String>();
    tf.add("t");
    tf.add("f");

    // create an attribute for each component
    final Map<INode<T>, Attribute> attributeMap = new HashMap<INode<T>, Attribute>();
    final ArrayList<Attribute> attributeList = new ArrayList<Attribute>(); // NOCS: Weka needs ArrayList..
    for (final INode<T> node : nodes) {
        final Attribute attribute = new Attribute(node.toString(), tf);
        attributeList.add(attribute);
        attributeMap.put(node, attribute);
    }

    // create class attribute (trace success); last position = class index below
    final Attribute successAttribute = new Attribute("success", tf);
    attributeList.add(successAttribute);

    // create weka training instance
    final Instances trainingSet = new Instances("TraceInfoInstances", attributeList, 1);
    trainingSet.setClassIndex(attributeList.size() - 1);

    // == 2. add traces to training set

    // add an instance for each trace: involvement flag per component plus success flag
    for (final ITrace<T> trace : spectra.getTraces()) {
        final Instance instance = new DenseInstance(nodes.size() + 1);
        instance.setDataset(trainingSet);
        for (final INode<T> node : nodes) {
            instance.setValue(attributeMap.get(node), trace.isInvolved(node) ? "t" : "f");
        }
        instance.setValue(successAttribute, trace.isSuccessful() ? "t" : "f");
        trainingSet.add(instance);
    }

    // == 3. use prediction to localize faults

    // build classifier
    try {
        final Classifier classifier = this.buildClassifier(this.classifierName, this.classifierOptions,
                trainingSet);
        final Ranking<T> ranking = new Ranking<>();

        System.out.println("begin classifying");
        int classified = 0;

        // One reusable probe instance: every component "f" (not involved),
        // class preset to "f". Each loop iteration flips exactly one node to
        // "t" and flips it back afterwards.
        final Instance instance = new DenseInstance(nodes.size() + 1);
        instance.setDataset(trainingSet);
        for (final INode<T> node : nodes) {
            instance.setValue(attributeMap.get(node), "f");
        }
        instance.setValue(successAttribute, "f");

        for (final INode<T> node : nodes) {
            classified++;
            if (classified % 1000 == 0) {
                System.out.println(String.format("Classified %d nodes.", classified));
            }

            // contain only the current node in the network
            instance.setValue(attributeMap.get(node), "t");

            // predict with which probability this setup leads to a failing network;
            // index 1 is the probability of class value "f" (tf = ["t", "f"])
            final double[] distribution = classifier.distributionForInstance(instance);
            ranking.rank(node, distribution[1]);

            // reset involvement for node so the probe is all-"f" again
            instance.setValue(attributeMap.get(node), "f");
        }
        return ranking;
    } catch (final Exception e) { // NOCS: Weka throws only raw exceptions
        throw new RuntimeException(e);
    }
}

From source file:gate.plugin.learningframework.engines.EngineWeka.java

/**
 * Classifies (or regresses) each instance annotation with the trained Weka
 * model and returns one GateClassification per annotation, in document order.
 *
 * If the mallet pipe has no target alphabet the model is treated as a
 * regression model, otherwise as a classifier.
 *
 * @param instanceAS annotation set holding the instance annotations
 * @param inputAS annotation set holding the feature source annotations
 * @param sequenceAS not used by this engine
 * @param parms not used by this engine
 * @return the classifications for all instance annotations
 */
@Override
public List<GateClassification> classify(AnnotationSet instanceAS, AnnotationSet inputAS,
        AnnotationSet sequenceAS, String parms) {

    Instances instances = crWeka.getRepresentationWeka();
    CorpusRepresentationMalletTarget data = (CorpusRepresentationMalletTarget) corpusRepresentationMallet;
    // Freeze the feature alphabet so unseen features cannot grow the model's space;
    // re-enabled via startGrowth() at the end of the method.
    data.stopGrowth();
    List<GateClassification> gcs = new ArrayList<GateClassification>();
    LFPipe pipe = (LFPipe) data.getRepresentationMallet().getPipe();
    Classifier wekaClassifier = (Classifier) model;
    // iterate over the instance annotations and create mallet instances 
    for (Annotation instAnn : instanceAS.inDocumentOrder()) {
        Instance inst = data.extractIndependentFeatures(instAnn, inputAS);
        inst = pipe.instanceFrom(inst);
        // Convert to weka Instance
        weka.core.Instance wekaInstance = CorpusRepresentationWeka.wekaInstanceFromMalletInstance(instances,
                inst);
        // classify with the weka classifier or predict the numeric value: if the mallet pipe does have
        // a target alphabet we assume classification, otherwise we assume regression
        GateClassification gc = null;
        if (pipe.getTargetAlphabet() == null) {
            // regression
            double result = Double.NaN;
            try {
                result = wekaClassifier.classifyInstance(wekaInstance);
            } catch (Exception ex) {
                // Deliberate best-effort: log and keep NaN so the document run continues.
                ex.printStackTrace(System.err);
                Logger.getLogger(EngineWeka.class.getName()).log(Level.SEVERE, null, ex);
            }
            gc = new GateClassification(instAnn, result);
        } else {
            // classification

            // Weka AbstractClassifier already handles the situation correctly when 
            // distributionForInstance is not implemented by the classifier: in that case
            // it calls classifyInstance and returns an array of size numClasses where
            // the entry of the target class is set to 1.0 except when the classification is a missing
            // value, then all class probabilities will be 0.0
            // If distributionForInstance is implemented for the algorithm, we should get
            // the probabilities or all zeros for missing class from the algorithm.
            double[] predictionDistribution = new double[0];
            try {
                predictionDistribution = wekaClassifier.distributionForInstance(wekaInstance);
            } catch (Exception ex) {
                throw new RuntimeException(
                        "Weka classifier error in document " + instanceAS.getDocument().getName(), ex);
            }
            // This is classification, we should always get a distribution list > 1
            if (predictionDistribution.length < 2) {
                throw new RuntimeException("Classifier returned less than 2 probabilities: "
                        + predictionDistribution.length + "for instance" + wekaInstance);
            }
            // Arg-max over the distribution while also collecting every
            // label/confidence pair for the GateClassification.
            double bestprob = 0.0;
            int bestlabel = 0;
            List<String> classList = new ArrayList<String>();
            List<Double> confidenceList = new ArrayList<Double>();
            for (int i = 0; i < predictionDistribution.length; i++) {
                int thislabel = i;
                double thisprob = predictionDistribution[i];
                String labelstr = (String) pipe.getTargetAlphabet().lookupObject(thislabel);
                classList.add(labelstr);
                confidenceList.add(thisprob);
                if (thisprob > bestprob) {
                    bestlabel = thislabel;
                    bestprob = thisprob;
                }
            } // end for i < predictionDistribution.length

            String cl = (String) pipe.getTargetAlphabet().lookupObject(bestlabel);

            gc = new GateClassification(instAnn, cl, bestprob, classList, confidenceList);
        }
        gcs.add(gc);
    }
    // Allow the feature alphabet to grow again now that prediction is done.
    data.startGrowth();
    return gcs;
}

From source file:GClass.EvaluationInternal.java

License:Open Source License

/**
 * Prints the predictions for the given dataset into a String variable.
 *
 * @param classifier the trained classifier used for prediction
 * @param train the training set (unused here; kept for signature compatibility)
 * @param testFileName path of the ARFF test file; an empty string yields ""
 * @param classIndex 1-based class index, or -1 to use the last attribute
 * @param attributesToOutput attribute range echoed after each prediction
 * @return one line per instance: index, predicted value/label (and class
 *         probability for nominal classes), actual value, attribute values
 * @throws Exception if the test file cannot be opened or prediction fails
 */
protected static String printClassifications(Classifier classifier, Instances train, String testFileName,
        int classIndex, Range attributesToOutput) throws Exception {

    StringBuilder text = new StringBuilder();
    if (testFileName.length() != 0) {
        BufferedReader testReader;
        try {
            testReader = new BufferedReader(new FileReader(testFileName));
        } catch (Exception e) {
            // Same message as before, but keep the cause for debugging.
            throw new Exception("Can't open file " + e.getMessage() + '.', e);
        }
        // finally-close: the original leaked the reader if prediction threw mid-loop.
        try {
            Instances test = new Instances(testReader, 1);
            if (classIndex != -1) {
                test.setClassIndex(classIndex - 1);
            } else {
                test.setClassIndex(test.numAttributes() - 1);
            }
            int i = 0;
            // Stream one instance at a time to keep memory usage flat.
            while (test.readInstance(testReader)) {
                Instance instance = test.instance(0);
                Instance withMissing = (Instance) instance.copy();
                withMissing.setDataset(test);
                double predValue = classifier.classifyInstance(withMissing);
                boolean predMissing = Instance.isMissingValue(predValue);
                if (test.classAttribute().isNumeric()) {
                    if (predMissing) {
                        text.append(i + " missing ");
                    } else {
                        text.append(i + " " + predValue + " ");
                    }
                    if (instance.classIsMissing()) {
                        text.append("missing");
                    } else {
                        text.append(instance.classValue());
                    }
                    text.append(" " + attributeValuesString(withMissing, attributesToOutput) + "\n");
                } else {
                    // Nominal class: print the label and its predicted probability,
                    // or "missing missing" when no prediction could be made.
                    if (predMissing) {
                        text.append(i + " missing missing ");
                    } else {
                        text.append(i + " " + test.classAttribute().value((int) predValue) + " ");
                        text.append(classifier.distributionForInstance(withMissing)[(int) predValue] + " ");
                    }
                    text.append(instance.toString(instance.classIndex()) + " "
                            + attributeValuesString(withMissing, attributesToOutput) + "\n");
                }
                test.delete(0);
                i++;
            }
        } finally {
            testReader.close();
        }
    }
    return text.toString();
}

From source file:gnusmail.learning.ClassifierManager.java

License:Open Source License

/**
 * Classifies a single document with the persisted model, printing the
 * probability of every label and finally the most probable destination folder.
 *
 * @param document the document to convert to a Weka instance and classify
 * @throws Exception if the model cannot be read or classification fails
 */
public void classifyDocument(Document document) throws Exception {
    Instance inst = document.toWekaInstance(filterManager);
    System.out.println(inst);

    // Train and persist a model first if none exists yet.
    if (!ConfigManager.MODEL_FILE.exists()) {
        trainModel();
    }

    // try-with-resources closes both streams even when readObject() throws
    // (the original leaked the FileInputStream and ObjectInputStream).
    Classifier model;
    try (ObjectInputStream fie = new ObjectInputStream(new FileInputStream(ConfigManager.MODEL_FILE))) {
        model = (Classifier) fie.readObject();
    }

    System.out.println("\nClassifying...\n");
    double[] res = model.distributionForInstance(inst);
    Attribute att = dataSet.attribute("Label");
    // Track the arg-max while printing every label's probability.
    double biggest = 0;
    int biggest_index = 0;
    for (int i = 0; i < res.length; i++) {
        System.out.println("\nDestination folder will be " + att.value(i) + " with probability: " + res[i]);
        if (res[i] > biggest) {
            biggest_index = i;
            biggest = res[i];
        }
    }
    System.out.println("------------------------------");
    System.out.println("\nThe most probable folder is: " + att.value(biggest_index));
}

From source file:GroupProject.DMChartUI.java

/**
 * Action handler for the "Generate" button.
 *
 * Reads the user input from the table and the selected radio button, then
 * classifies that input with linear regression, a naive Bayes classifier,
 * or a J48 tree, and writes the predicted value into the UI.
 *
 * NOTE(review): every catch below only logs; if the CSV conversion or the
 * ARFF load fails, {@code students} stays null and the subsequent calls will
 * throw NullPointerException — confirm this is acceptable for the UI flow.
 */
private void generateButtonActionPerformed(java.awt.event.ActionEvent evt) {//GEN-FIRST:event_generateButtonActionPerformed
    // Convert the working CSV into ARFF so Weka can load it.
    CSVtoArff converter = new CSVtoArff();
    Instances students = null;
    Instances students2 = null;
    try {
        converter.convert("studentTemp.csv", "studentTemp.arff");
    } catch (IOException ex) {
        Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
    }

    // Load two copies of the data set (both get the same class index below).
    try {
        students = new Instances(new BufferedReader(new FileReader("studentTemp.arff")));
        students2 = new Instances(new BufferedReader(new FileReader("studentTemp.arff")));
    } catch (IOException ex) {
        Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
    }

    // Column to predict values for, chosen in the UI (combo index + 1).
    int target = dataSelector.getSelectedIndex() + 1;
    System.out.printf("this is the target: %d\n", target);
    // set target 
    students.setClassIndex(target);
    students2.setClassIndex(target);

    // Dispatch on which radio button is selected.
    // Linear Regression
    if (LRB.isSelected()) {

        // Reuse the cached model if one was already built.
        LinearRegression model = null;
        if (Lmodel != null) {
            model = Lmodel;
        } else {
            buildLinearModel();
            model = Lmodel;
        }

        System.out.println("im doing linear regression");

        equationDisplayArea.setText(model.toString());

        System.out.println("im going to get the instance");

        Instance prediction2 = getInstance(true);

        // Drop the attributes the regression model was not trained on.
        Remove remove = new Remove();
        int[] toremove = { 0, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 16, 17 };
        remove.setAttributeIndicesArray(toremove);

        try {
            remove.setInputFormat(students);
        } catch (Exception ex) {
            Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
        }

        Instances instNew = null;
        try {
            instNew = Filter.useFilter(students, remove);
        } catch (Exception ex) {
            Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
        }

        prediction2.setDataset(instNew);
        System.err.print("i got the instance");
        double result = 0;
        try {
            result = model.classifyInstance(prediction2);
        } catch (Exception ex) {
            Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
        }

        System.out.printf("the result : %f \n ", result);
        predictValue.setText(Double.toString(result));
        System.out.println("I'm done with Linear Regression");
    }

    // Naive Bayes
    else if (NBB.isSelected()) {
        Classifier cModel = null;

        // Reuse the cached model if one was already built.
        if (NBmodel != null) {
            cModel = NBmodel;
        } else {
            buildNBClassifier();
            cModel = NBmodel;
        }

        System.out.println("im doing NB");

        // Evaluate on the training data to show a summary in the console.
        Evaluation eTest = null;
        try {
            eTest = new Evaluation(students);
        } catch (Exception ex) {
            Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
        }
        System.out.println("Using NB");

        try {
            eTest.evaluateModel(cModel, students);
        } catch (Exception ex) {
            Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
        }

        // display the test results to console 
        String strSummary = eTest.toSummaryString();
        System.out.println(strSummary);

        // build the instance to predict from the UI table
        System.out.println("im going to get the instance");

        Instance prediction2 = getInstance(false);

        prediction2.setDataset(students);
        System.err.print("i got the instance");

        double pred = 0;
        try {
            pred = cModel.classifyInstance(prediction2);
            prediction2.setClassValue(pred);
        } catch (Exception ex) {
            Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
        }
        // get the predicted value and set predictValue to it 
        predictValue.setText(prediction2.classAttribute().value((int) pred));

        System.out.println("I'm done with Naive Bayes");

        // Print the full class distribution and re-derive the arg-max label.
        double[] fDistribution2 = null;
        try {
            fDistribution2 = cModel.distributionForInstance(prediction2);
        } catch (Exception ex) {
            Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
        }

        double max = 0;
        int maxindex = 0;
        max = fDistribution2[0];
        for (int i = 0; i < fDistribution2.length; i++) {
            if (fDistribution2[i] > max) {
                maxindex = i;
                max = fDistribution2[i];
            }
            System.out.println("the value at " + i + " : " + fDistribution2[i]);
            System.out.println("the label at " + i + prediction2.classAttribute().value(i));
        }
        prediction2.setClassValue(maxindex);
        predictValue.setText(prediction2.classAttribute().value(maxindex));

    }
    // J48 Tree
    else if (JB.isSelected()) {

        System.out.println("im doing j48 ");

        // Reuse the cached model if one was already built.
        Classifier jModel = null;
        if (Jmodel != null) {
            jModel = Jmodel;
        } else {
            buildJClassifier();
            jModel = Jmodel;
        }
        // Evaluate on the training data to show a summary in the console.
        Evaluation eTest2 = null;
        try {
            eTest2 = new Evaluation(students);
        } catch (Exception ex) {
            Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
        }
        System.out.println("Using J48 test");
        try {
            eTest2.evaluateModel(jModel, students);
        } catch (Exception ex) {
            Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
        }
        String strSummary2 = eTest2.toSummaryString();
        System.out.println(strSummary2);

        System.out.println("im going to get the instance");

        Instance prediction2 = getInstance(false);

        prediction2.setDataset(students);
        System.err.print("i got the instance\n");

        double pred = 0;
        try {
            pred = jModel.classifyInstance(prediction2);
            prediction2.setClassValue(pred);
            System.out.println("i did a prediction");
        } catch (Exception ex) {
            Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
        }

        // get the predicted value and set predictValue to it 
        System.out.println("this was pred:" + pred);
        predictValue.setText(prediction2.classAttribute().value((int) pred));

        System.out.println("I'm done with J48");
        // Print the full class distribution and re-derive the arg-max label.
        double[] fDistribution2 = null;
        try {
            fDistribution2 = jModel.distributionForInstance(prediction2);
        } catch (Exception ex) {
            Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
        }

        double max = 0;
        int maxindex = 0;
        max = fDistribution2[0];
        for (int i = 0; i < fDistribution2.length; i++) {
            if (fDistribution2[i] > max) {
                maxindex = i;
                max = fDistribution2[i];
            }
            System.out.println("the value at " + i + " : " + fDistribution2[i]);
            System.out.println("the label at " + i + " " + prediction2.classAttribute().value(i));
        }
        prediction2.setClassValue(maxindex);
        predictValue.setText(prediction2.classAttribute().value(maxindex));

    }

}

From source file:mlflex.WekaInMemoryLearner.java

License:Open Source License

/**
 * Trains the configured Weka classifier on the training collection and
 * predicts a class (with per-class probabilities) for every test instance.
 *
 * Instances consisting only of missing values receive a default prediction:
 * a random class with uniform 0.5 probabilities.
 *
 * @param classificationParameters parameters selecting/configuring the classifier
 * @param trainData training instances
 * @param testData instances to predict
 * @param dependentVariableInstances holds the actual class of each instance
 * @return the model's predictions for the whole test collection
 * @throws Exception if Weka training or prediction fails
 */
@Override
protected ModelPredictions TrainTest(ArrayList<String> classificationParameters,
        DataInstanceCollection trainData, DataInstanceCollection testData,
        DataInstanceCollection dependentVariableInstances) throws Exception {
    // Shared attribute vector and Weka views over both collections.
    ArrayList<String> sortedPointNames = Lists.SortStringList(trainData.GetDataPointNames());
    FastVector attributes = GetAttributeVector(dependentVariableInstances, sortedPointNames, trainData,
            testData);

    Instances trainInstances = GetInstances(dependentVariableInstances, attributes, trainData);
    Instances testInstances = GetInstances(dependentVariableInstances, attributes, testData);

    ArrayList<String> classLabels = Utilities.ProcessorVault.DependentVariableDataProcessor
            .GetUniqueDependentVariableValues();

    // Train the classifier selected by the parameters.
    Classifier classifier = GetClassifier(classificationParameters);
    classifier.buildClassifier(trainInstances);

    Predictions predictions = new Predictions();

    for (DataValues testInstance : testData) {
        String actualClass = dependentVariableInstances.Get(testInstance.GetID()).GetDataPointValue(0);

        // Default used when no real prediction can be made for this instance.
        Prediction prediction = new Prediction(testInstance.GetID(), actualClass,
                Lists.PickRandomValue(classLabels),
                Lists.CreateDoubleList(0.5, classLabels.size()));

        if (!testInstance.HasOnlyMissingValues()) {
            Instance wekaInstance = GetInstance(testInstances, attributes, testInstance, null);

            double labelIndex = classifier.classifyInstance(wekaInstance);
            String predictedClass = wekaInstance.classAttribute().value((int) labelIndex);

            double[] probabilities = classifier.distributionForInstance(wekaInstance);
            ArrayList<Double> classProbabilities = Lists.CreateDoubleList(probabilities);

            prediction = new Prediction(testInstance.GetID(), actualClass, predictedClass,
                    classProbabilities);
        }

        predictions.Add(prediction);
    }

    return new ModelPredictions("", predictions);
}

From source file:nl.bioinf.roelen.thema11.classifier_tools.ClassifierUser.java

License:Open Source License

/**
 * Uses the classifier to test the sequences in a GenBank or FASTA file for
 * intron/exon boundaries.
 *
 * @param fileLocation the location of the GenBank or FASTA file
 * @param classifier the classifier to use
 * @return one ClassifiedNucleotide per tested position, carrying the
 *         boundary / non-boundary probabilities; empty on failure
 */
public static ArrayList<ClassifiedNucleotide> getPossibleBoundaries(String fileLocation,
        Classifier classifier) {
    ArrayList<Gene> genesFromFile = new ArrayList<>();
    ArrayList<ClassifiedNucleotide> classifiedNucleotides = new ArrayList<>();
    // Dispatch on the file extension: FASTA variants or GenBank.
    if (fileLocation.toUpperCase().endsWith(".FASTA") || fileLocation.toUpperCase().endsWith(".FA")
            || fileLocation.toUpperCase().endsWith(".FAN")) {
        genesFromFile.addAll(readFasta(fileLocation));
    } else if (fileLocation.toUpperCase().endsWith(".GENBANK") || fileLocation.toUpperCase().endsWith(".GB")) {
        GenBankReader gbr = new GenBankReader();
        gbr.readFile(fileLocation);
        GenbankResult gbresult = gbr.getResult();
        genesFromFile = gbresult.getGenes();
    }
    // Generate boundary test results for every candidate position.
    HashMap<String, ArrayList<IntronExonBoundaryTesterResult>> geneTestResults;
    geneTestResults = TestGenes.testForIntronExonBoundaries(genesFromFile, 1);
    ArrayList<InstanceToClassify> instanceNucs = new ArrayList<>();
    try {
        // Round-trip through a temporary ARFF file so Weka builds the
        // attribute metadata for us.
        File tempArrf = File.createTempFile("realSet", ".arff");
        ArffWriter.write(tempArrf.getAbsolutePath(), geneTestResults, null);
        ConverterUtils.DataSource source = new ConverterUtils.DataSource(tempArrf.getAbsolutePath());
        Instances data = source.getDataSet();
        // Class attribute is always the last one; set it once instead of per instance.
        data.setClassIndex(data.numAttributes() - 1);
        for (int i = 0; i < data.numInstances(); i++) {
            Instance in = data.instance(i);
            // name of the gene or sequence tested
            String nameOfInstance = in.stringValue(in.numAttributes() - 3);
            // the tested position
            int testedPosition = (int) in.value(in.numAttributes() - 2);
            // mark the class as missing: that is what we want to predict
            in.setMissing((in.numAttributes() - 1));

            Instance instanceNoExtras = new Instance(in);
            // drop name and position; they are irrelevant for classifying
            instanceNoExtras.deleteAttributeAt(instanceNoExtras.numAttributes() - 2);
            instanceNoExtras.deleteAttributeAt(instanceNoExtras.numAttributes() - 2);
            instanceNucs.add(new InstanceToClassify(instanceNoExtras, testedPosition, nameOfInstance));
        }
        for (InstanceToClassify ic : instanceNucs) {
            Instance in = ic.getInstance();
            in.setDataset(data);
            // Single distribution call per instance: the original called
            // distributionForInstance twice plus a discarded classifyInstance,
            // tripling the prediction work for identical results.
            double[] distribution = classifier.distributionForInstance(in);
            double likelyhoodBoundary = distribution[0];
            double likelyhoodNotBoundary = distribution[1];

            // create a classified nucleotide and give it the added data
            ClassifiedNucleotide cn = new ClassifiedNucleotide(likelyhoodBoundary, likelyhoodNotBoundary,
                    ic.getName(), ic.getPosition());
            classifiedNucleotides.add(cn);
        }

    } catch (Exception ex) {
        // Covers IOException too; the original's separate IOException catch
        // performed the exact same logging and was redundant.
        Logger.getLogger(ClassifierUser.class.getName()).log(Level.SEVERE, null, ex);
    }
    return classifiedNucleotides;
}

From source file:org.opentox.jaqpot3.qsar.predictor.WekaPredictor.java

License:Open Source License

/**
 * Produces predictions for the given input set: sorts it into the attribute
 * order the PMML model expects, appends a prediction (class) attribute, fills
 * it with the model's first class probability per instance, strips the PMML
 * transformation fields, and merges the compound identifiers back in.
 *
 * @param inputSet the instances to predict
 * @return the input merged with its predictions
 * @throws JaqpotException if the prediction attribute cannot be added
 */
@Override
public Instances predict(Instances inputSet) throws JaqpotException {

    // THE OBJECT predictions WILL HOST THE PREDICTIONS...
    Instances sorted = InstancesUtil.sortForPMMLModel(model.getIndependentFeatures(), trFieldsAttrIndex,
            inputSet, -1);

    // Append the prediction feature as the last attribute and make it the class.
    String predictedFeatureUri = model.getPredictedFeatures().iterator().next().getUri().toString();
    Add attributeAdder = new Add();
    attributeAdder.setAttributeIndex("last");
    attributeAdder.setAttributeName(predictedFeatureUri);
    Instances predictions = null;
    try {
        attributeAdder.setInputFormat(sorted);
        predictions = Filter.useFilter(sorted, attributeAdder);
        predictions.setClass(predictions.attribute(predictedFeatureUri));
    } catch (Exception ex) {
        String message = "Exception while trying to add prediction feature to Instances";
        logger.debug(message, ex);
        throw new JaqpotException(message, ex);
    }

    if (predictions != null) {
        Classifier classifier = (Classifier) model.getActualModel().getSerializableActualModel();

        // Fill the class attribute of every row with the model's prediction;
        // failures are logged per instance and leave that row's class untouched.
        int numInstances = predictions.numInstances();
        for (int i = 0; i < numInstances; i++) {
            try {
                double predictionValue = classifier.distributionForInstance(predictions.instance(i))[0];
                predictions.instance(i).setClassValue(predictionValue);
            } catch (Exception ex) {
                logger.warn("Prediction failed :-(", ex);
            }
        }
    }

    // Strip the PMML transformation fields, then merge the compounds back in.
    List<Integer> transformationFields = WekaInstancesProcess.getTransformationFieldsAttrIndex(predictions,
            pmmlObject);
    predictions = WekaInstancesProcess.removeInstancesAttributes(predictions, transformationFields);
    return Instances.mergeInstances(justCompounds, predictions);
}