Example usage for weka.classifiers Classifier classifyInstance

List of usage examples for weka.classifiers Classifier classifyInstance

Introduction

On this page you can find example usages of weka.classifiers Classifier classifyInstance.

Prototype

public double classifyInstance(Instance instance) throws Exception;

Document

Classifies the given test instance.
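
The return value is the predicted class encoded as a double: for a nominal class attribute it is the index of the predicted class value, and for a numeric class it is the predicted number itself. The following minimal sketch shows typical usage; the J48 classifier and the iris.arff path are illustrative assumptions, not taken from the examples below.

import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class ClassifyInstanceDemo {
    public static void main(String[] args) throws Exception {
        // Load a dataset (the path is a placeholder) and mark the last attribute as the class.
        Instances data = new DataSource("iris.arff").getDataSet();
        data.setClassIndex(data.numAttributes() - 1);

        // Any Classifier implementation works; J48 is chosen here purely as an example.
        Classifier classifier = new J48();
        classifier.buildClassifier(data);

        // classifyInstance returns the predicted class as a double:
        // the index of the predicted nominal value, or the number itself for regression.
        Instance first = data.instance(0);
        double label = classifier.classifyInstance(first);
        System.out.println("Predicted: " + data.classAttribute().value((int) label));
    }
}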

Usage

From source file:myclassifier.wekaCode.java

public static void classifyUnseenData(String[] attributes, Classifier classifier, Instances data)
        throws Exception {
    Instance newInstance = new Instance(data.numAttributes());
    newInstance.setDataset(data);
    for (int i = 0; i < data.numAttributes() - 1; i++) {
        if (Attribute.NUMERIC == data.attribute(i).type()) {
            Double value = Double.valueOf(attributes[i]);
            newInstance.setValue(i, value);
        } else {
            newInstance.setValue(i, attributes[i]);
        }
    }

    // classifyInstance returns the predicted class as a double; for a nominal
    // class attribute this is the index of the predicted class value
    double clsLabel = classifier.classifyInstance(newInstance);
    newInstance.setClassValue(clsLabel);

    String result = data.classAttribute().value((int) clsLabel);

    System.out.println("Unseen data classification result: " + result);
}

From source file:net.sf.bddbddb.FindBestDomainOrder.java

License:LGPL

public TrialGuess tryNewGoodOrder(EpisodeCollection ec, List allVars, InferenceRule ir, int opNum,
        Order chosenOne, boolean returnBest) {

    out.println("Variables: " + allVars);
    TrialDataGroup vDataGroup = this.dataRepository.getVariableDataGroup(ir, allVars);
    TrialDataGroup aDataGroup = dataRepository.getAttribDataGroup(ir, allVars);
    TrialDataGroup dDataGroup = dataRepository.getDomainDataGroup(ir, allVars);

    // Build instances based on the experimental data.
    TrialInstances vData, aData, dData;
    vData = vDataGroup.getTrialInstances();
    aData = aDataGroup.getTrialInstances();
    dData = dDataGroup.getTrialInstances();
    /* 
            TrialInstances vTest = dataRepository.buildVarInstances(ir, allVars);
            
            Assert._assert(vData.numInstances() == vTest.numInstances(),"vGot " + vData.numInstances() + " Wanted: " + vTest.numInstances());
            TrialInstances aTest = dataRepository.buildAttribInstances(ir, allVars);
              
            Assert._assert(aData.numInstances() == aTest.numInstances(), "aGot: " + aData.numInstances() + " Wanted: " + aTest.numInstances());
                
            TrialInstances dTest =dataRepository.buildDomainInstances(ir, allVars);
                  
            Assert._assert(dData.numInstances() == dTest.numInstances(), "dGot: " + dData.numInstances() + " Wanted: " + dTest.numInstances());
            out.println(aData);
            out.println(vData);
            out.println(dData);
    */
    // Readjust the weights using an exponential decay factor.
    adjustWeights(vData, aData, dData);
    Discretization vDis = null, aDis = null, dDis = null;

    /*
    // Discretize the experimental data.  null if there is no data.
    if (DISCRETIZE1) vDis = vData.discretize(.5);
    if (DISCRETIZE2) aDis = aData.discretize(.25);
    if (DISCRETIZE3) dDis = dData.threshold(DOMAIN_THRESHOLD);
    */
    vDis = vDataGroup.discretize(.5);
    aDis = aDataGroup.discretize(.25);
    dDis = dDataGroup.threshold(DOMAIN_THRESHOLD);
    // Calculate the accuracy of each classifier using cv folds.
    long vCTime = System.currentTimeMillis();
    double vConstCV = -1;//constFoldCV(vData, CLASSIFIER1);
    vCTime = System.currentTimeMillis() - vCTime;

    long aCTime = System.currentTimeMillis();
    double aConstCV = -1;//constFoldCV(aData, CLASSIFIER2);
    aCTime = System.currentTimeMillis() - aCTime;

    long dCTime = System.currentTimeMillis();
    double dConstCV = -1;//constFoldCV(dData, CLASSIFIER3);
    dCTime = System.currentTimeMillis() - dCTime;

    long vLTime = System.currentTimeMillis();
    double vLeaveCV = -1; //leaveOneOutCV(vData, CLASSIFIER1);
    vLTime = System.currentTimeMillis() - vLTime;

    long aLTime = System.currentTimeMillis();
    double aLeaveCV = -1; //leaveOneOutCV(aData, CLASSIFIER2);
    aLTime = System.currentTimeMillis() - aLTime;

    long dLTime = System.currentTimeMillis();
    double dLeaveCV = -1; //leaveOneOutCV(dData, CLASSIFIER3);
    dLTime = System.currentTimeMillis() - dLTime;

    if (TRACE > 1) {
        out.println(" Var data points: " + vData.numInstances());
        //out.println(" Var Classifier " + NUM_CV_FOLDS + " fold CV Score: " + vConstCV + " took " + vCTime + " ms");
        // out.println(" Var Classifier leave one out CV Score: " + vLeaveCV + " took " + vLTime + " ms");
        out.println(" Var Classifier Weight: " + varClassWeight);
        //out.println(" Var data points: "+vData);
        out.println(" Attrib data points: " + aData.numInstances());
        // out.println(" Attrib Classifier " + NUM_CV_FOLDS + " fold CV Score : " + aConstCV + " took " + aCTime + " ms");
        //out.println(" Attrib Classifier leave one out CV Score: " + aLeaveCV + " took " + aLTime + " ms");
        out.println(" Attrib Classifier Weight: " + attrClassWeight);
        //out.println(" Attrib data points: "+aData);
        out.println(" Domain data points: " + dData.numInstances());
        //out.println(" Domain Classifier " + NUM_CV_FOLDS + " fold CV Score: " + dConstCV + " took " + dCTime + " ms");
        //out.println(" Attrib Classifier leave one out CV Score: " + dLeaveCV + " took " + dLTime + " ms");
        out.println(" Domain Classifier Weight: " + domClassWeight);
        //out.println(" Domain data points: "+dData);

    }

    Classifier vClassifier = null, aClassifier = null, dClassifier = null;
    // Build the classifiers.
    /*    
         if (vData.numInstances() > 0)
    vClassifier = WekaInterface.buildClassifier(CLASSIFIER1, vData);
         if (aData.numInstances() > 0)
    aClassifier = WekaInterface.buildClassifier(CLASSIFIER2, aData);
         if (dData.numInstances() > 0)
    dClassifier = WekaInterface.buildClassifier(CLASSIFIER3, dData);
    */
    vClassifier = vDataGroup.classify();
    aClassifier = aDataGroup.classify();
    dClassifier = dDataGroup.classify();

    if (DUMP_CLASSIFIER_INFO) {
        String baseName = solver.getBaseName() + "_rule" + ir.id;
        if (vClassifier != null)
            dumpClassifierInfo(baseName + "_vclassifier", vClassifier, vData);
        if (aClassifier != null)
            dumpClassifierInfo(baseName + "_aclassifier", aClassifier, aData);
        if (dClassifier != null)
            dumpClassifierInfo(baseName + "_dclassifier", dClassifier, dData);
        try {
            out_t = new PrintStream(new FileOutputStream(baseName + "_trials"));
        } catch (IOException x) {
            solver.err.println("Error while opening file: " + x);
        }
    } else {
        out_t = null;
    }

    if (TRACE > 2) {
        out.println("Var classifier: " + vClassifier);
        out.println("Attrib classifier: " + aClassifier);
        out.println("Domain classifier: " + dClassifier);
    }

    double[][] bucketmeans = getBucketMeans(vDis, aDis, dDis);

    Collection sel = null;
    Collection candidates = null;
    if (chosenOne == null) {
        Collection triedOrders = returnBest ? new LinkedList() : getTriedOrders((BDDInferenceRule) ir, opNum);
        if (ec != null) {
            triedOrders.addAll(ec.trials.keySet());

        }
        Object object = generateCandidateSet(ir, allVars, bucketmeans, vDataGroup, aDataGroup, dDataGroup,
                triedOrders, returnBest);
        /*vClassifier,
        aClassifier, dClassifier, vData,
        aData, dData, vDis, aDis,
        dDis,*/
        if (object == null)
            return null;
        else if (object instanceof Collection)
            candidates = (Collection) object;
        else if (object instanceof TrialGuess)
            return (TrialGuess) object;
    } else {
        sel = Collections.singleton(chosenOne);
    }
    boolean force = (ec != null && ec.getNumTrials() < 2) || vData.numInstances() < INITIAL_VAR_SET
            || aData.numInstances() < INITIAL_ATTRIB_SET || dData.numInstances() < INITIAL_DOM_SET;

    if (!returnBest)
        sel = selectOrder(candidates, vData, aData, dData, ir, force);

    if (sel == null || sel.isEmpty())
        return null;
    Order o_v = (Order) sel.iterator().next();
    try {
        OrderTranslator v2a = new VarToAttribTranslator(ir);
        OrderTranslator a2d = AttribToDomainTranslator.INSTANCE;
        double vClass = 0, aClass = 0, dClass = 0;
        if (vClassifier != null) {
            OrderInstance vInst = OrderInstance.construct(o_v, vData);
            vClass = vClassifier.classifyInstance(vInst);
        }
        Order o_a = v2a.translate(o_v);
        if (aClassifier != null) {
            OrderInstance aInst = OrderInstance.construct(o_a, aData);
            aClass = aClassifier.classifyInstance(aInst);
        }
        Order o_d = a2d.translate(o_a);
        if (dClassifier != null) {
            OrderInstance dInst = OrderInstance.construct(o_d, dData);
            dClass = dClassifier.classifyInstance(dInst);
        }
        int vi = (int) vClass, ai = (int) aClass, di = (int) dClass;
        double vScore = 0, aScore = 0, dScore = 0;
        if (vi < bucketmeans[VMEAN_INDEX].length)
            vScore = bucketmeans[VMEAN_INDEX][vi];
        if (ai < bucketmeans[AMEAN_INDEX].length)
            aScore = bucketmeans[AMEAN_INDEX][ai];
        if (di < bucketmeans[DMEAN_INDEX].length)
            dScore = bucketmeans[DMEAN_INDEX][di];
        double score = varClassWeight * vScore;
        score += attrClassWeight * aScore;
        score += domClassWeight * dScore;
        return genGuess(o_v, score, vClass, aClass, dClass, vDis, aDis, dDis);
    } catch (Exception x) {
        x.printStackTrace();
        Assert.UNREACHABLE(x.toString());
        return null;
    }
}

From source file:net.sf.bddbddb.order.WekaInterface.java

License:LGPL

public static double cvError(int numFolds, Instances data0, String cClassName) {
    if (data0.numInstances() == 0)
        return 0; //no instances
    if (numFolds == 0)
        return Double.NaN; // no folds
    if (data0.numInstances() < numFolds)
        return Double.NaN; //more folds than elements

    Instances data = new Instances(data0);
    //data.randomize(new Random(System.currentTimeMillis()));
    data.stratify(numFolds);
    Assert._assert(data.classAttribute() != null);
    double[] estimates = new double[numFolds];
    for (int i = 0; i < numFolds; ++i) {
        Instances trainData = data.trainCV(numFolds, i);
        Assert._assert(trainData.classAttribute() != null);
        Assert._assert(trainData.numInstances() != 0, "Cannot train classifier on 0 instances.");

        Instances testData = data.testCV(numFolds, i);
        Assert._assert(testData.classAttribute() != null);
        Assert._assert(testData.numInstances() != 0, "Cannot test classifier on 0 instances.");

        int temp = FindBestDomainOrder.TRACE;
        FindBestDomainOrder.TRACE = 0;
        Classifier classifier = buildClassifier(cClassName, trainData);
        FindBestDomainOrder.TRACE = temp;
        int count = testData.numInstances();
        double loss = 0;
        double sum = 0;
        for (Enumeration e = testData.enumerateInstances(); e.hasMoreElements();) {
            Instance instance = (Instance) e.nextElement();
            Assert._assert(instance != null);
            Assert._assert(instance.classAttribute() != null
                    && instance.classAttribute() == trainData.classAttribute());
            try {
                double testClass = classifier.classifyInstance(instance);
                double weight = instance.weight();
                if (testClass != instance.classValue())
                    loss += weight;
                sum += weight;
            } catch (Exception ex) {
                FindBestDomainOrder.out.println("Exception while classifying: " + instance + "\n" + ex);
            }
        }
        estimates[i] = 1 - loss / sum;
    }
    double average = 0;
    for (int i = 0; i < numFolds; ++i)
        average += estimates[i];

    return average / numFolds;
}

From source file:nl.bioinf.roelen.thema11.classifier_tools.ClassifierUser.java

License:Open Source License

/**
 * Use the classifier to test the sequences in a GenBank or FASTA file for boundaries.
 * @param fileLocation the location of the GenBank or FASTA file
 * @param classifier the classifier to use
 * @return the classified nucleotides with their boundary likelihoods
 */
public static ArrayList<ClassifiedNucleotide> getPossibleBoundaries(String fileLocation,
        Classifier classifier) {
    ArrayList<Gene> genesFromFile = new ArrayList<>();
    ArrayList<ClassifiedNucleotide> classifiedNucleotides = new ArrayList<>();
    //read from fasta
    if (fileLocation.toUpperCase().endsWith(".FASTA") || fileLocation.toUpperCase().endsWith(".FA")
            || fileLocation.toUpperCase().endsWith(".FAN")) {
        genesFromFile.addAll(readFasta(fileLocation));
    }
    //read from genbank
    else if (fileLocation.toUpperCase().endsWith(".GENBANK") || fileLocation.toUpperCase().endsWith(".GB")) {
        GenBankReader gbr = new GenBankReader();
        gbr.readFile(fileLocation);
        GenbankResult gbresult = gbr.getResult();
        genesFromFile = gbresult.getGenes();
    }
    //get the test data
    HashMap<String, ArrayList<IntronExonBoundaryTesterResult>> geneTestResults;
    geneTestResults = TestGenes.testForIntronExonBoundaries(genesFromFile, 1);
    ArrayList<InstanceToClassify> instanceNucs = new ArrayList<>();
    try {
        //write our results to a temporary file
        File tempArff = File.createTempFile("realSet", ".arff");
        ArffWriter.write(tempArff.getAbsolutePath(), geneTestResults, null);
        //get data
        ConverterUtils.DataSource source = new ConverterUtils.DataSource(tempArff.getAbsolutePath());
        //SET DATA AND OPTIONS
        Instances data = source.getDataSet();
        for (int i = 0; i < data.numInstances(); i++) {
            Instance in = data.instance(i);
            //get the name of the gene or sequence tested
            String nameOfInstance = in.stringValue(in.numAttributes() - 3);
            //get the tested position
            int testedPosition = (int) in.value(in.numAttributes() - 2);
            //set the class as missing, because we want to find it
            in.setMissing((in.numAttributes() - 1));

            Instance instanceNoExtras = new Instance(in);

            //delete the name and position, they are irrelevant for classifying
            instanceNoExtras.deleteAttributeAt(instanceNoExtras.numAttributes() - 2);
            instanceNoExtras.deleteAttributeAt(instanceNoExtras.numAttributes() - 2);
            InstanceToClassify ic = new InstanceToClassify(instanceNoExtras, testedPosition, nameOfInstance);
            instanceNucs.add(ic);
        }
        for (InstanceToClassify ic : instanceNucs) {
            Instance in = ic.getInstance();
            in.setDataset(data);
            data.setClassIndex(data.numAttributes() - 1);
            //classify our instance (the returned label is discarded; only the
            //class distribution below is used)
            classifier.classifyInstance(in);
            //save the likelihood that the nucleotide is, or is not, a boundary
            double[] distribution = classifier.distributionForInstance(in);
            double likelihoodBoundary = distribution[0];
            double likelihoodNotBoundary = distribution[1];

            //create a classified nucleotide and give it the added data
            ClassifiedNucleotide cn = new ClassifiedNucleotide(likelihoodBoundary, likelihoodNotBoundary,
                    ic.getName(), ic.getPosition());
            classifiedNucleotides.add(cn);
        }

    } catch (IOException ex) {
        Logger.getLogger(ClassifierUser.class.getName()).log(Level.SEVERE, null, ex);
    } catch (Exception ex) {
        Logger.getLogger(ClassifierUser.class.getName()).log(Level.SEVERE, null, ex);
    }
    return classifiedNucleotides;
}

From source file:nl.uva.sne.commons.ClusterUtils.java

public static Map<String, String> classify(String testDataPath, Classifier classifier) throws Exception {

    Instances testData = createInstancesWithClasses(testDataPath);
    testData.setClassIndex(testData.numAttributes() - 1);

    Map<String, String> classes = new HashMap<>();
    for (int j = 0; j < testData.numInstances(); j++) {
        //                System.err.println(m);
        Instance inst = testData.get(j);
        String id = inst.toString().split(",")[0];
        //                System.err.println(inst);
        //            System.out.print("ID: " + UNdata.instance(j).value(0) + " ");
        int clsLabel = (int) classifier.classifyInstance(inst);
        //            String theClass = testData.classAttribute().value(clsLabel);
        //            System.err.println(id + " " + clsLabel);
        classes.put(testDataPath + File.separator + id, String.valueOf(clsLabel));
    }
    return classes;
}

From source file:nlp.NLP.java

public void calculateRate(String review) throws IOException, Exception {
    double positiveSentences = 0, allSentences = 0;
    String predictedClass = "";
    File writeFile = new File("test.arff");
    PrintWriter pw = new PrintWriter(writeFile);
    pw.println("@relation movie_review");
    pw.println("@attribute 'positive_words' numeric");
    pw.println("@attribute 'negative_words' numeric");
    pw.println("@attribute 'positive_score' numeric");
    pw.println("@attribute 'negative_score' numeric");
    pw.println("@attribute 'strongPositive' numeric");
    pw.println("@attribute 'strongNegative' numeric");
    pw.println("@attribute 'subjective_words' numeric");
    pw.println("@attribute 'neutral_words' numeric");
    pw.println("@attribute 'adj_words' numeric");
    pw.println("@attribute 'adv_words' numeric");
    pw.println("@attribute 'class' {negative, positive}");
    pw.println("@data");

    String[] splitByPoint = review.split("\\.");
    for (int j = 0; j < splitByPoint.length; j++) {
        // String normalized = normalization(splitByPoint[j]);

        if (splitByPoint[j] == null || splitByPoint[j].isEmpty()) {
            continue;
        }
        System.out.println("your review : " + splitByPoint[j]);
        WekaFileGenerator wk = new WekaFileGenerator(splitByPoint[j], pipeline, ra);
        pw.print(wk.getSentence().getPositiveWords() + "," + wk.getSentence().getNegativeWords() + ","
                + wk.getSentence().getSumOfPositiveScore() + "," + wk.getSentence().getSumOfNegativeScore()
                + "," + wk.getSentence().getStrongPositive() + "," + wk.getSentence().getStrongNegative() + ","
                + wk.getSentence().getSubjectiveWords() + "," + wk.getSentence().getNeutralWords() + ","
                + wk.getSentence().getNumOfAdjective() + "," + wk.getSentence().getNumOfAdverb() + ", ? \n");

    }

    pw.close();
    DataSource test = new DataSource("test.arff");
    Instances testData = test.getDataSet();
    testData.setClassIndex(testData.numAttributes() - 1);

    Classifier classifier = (Classifier) weka.core.SerializationHelper.read("movieReview.model");

    for (int i = 0; i < testData.numInstances(); i++) {
        Instance inst = testData.instance(i);
        double predictNum = classifier.classifyInstance(inst);
        predictedClass = testData.classAttribute().value((int) predictNum);
        System.out.println("Class Predicted: " + predictedClass);
        if (predictedClass.equals("positive")) {
            positiveSentences++;
            System.out.println("positiveSentences = " + positiveSentences);
            // note: indexing splitByPoint by i assumes no sentences were skipped
            // when the ARFF file was written above
            if (splitByPoint[i].contains("story")) {
                story = 1;
            }
            if (splitByPoint[i].contains("direction")) {
                direction = 1;
            }
        } else {
            //   positiveSentences--;
            if (splitByPoint[i].contains("story")) {
                story = 0;
            }
            if (splitByPoint[i].contains("direction")) {
                direction = 0;
            }
        }
        allSentences++;
        System.out.println("allSentences = " + allSentences);
    }
    DecimalFormat format = new DecimalFormat("#0.000");
    rate = (positiveSentences / allSentences) * 100;
    if (!Double.isNaN(rate)) {
        rate = Double.parseDouble(format.format(rate));
        if (rate > 0 && rate <= 10) {
            rate = 0.5;
        } else if (rate > 10 && rate <= 20) {
            rate = 1.0;
        } else if (rate > 20 && rate <= 30) {
            rate = 1.5;
        } else if (rate > 30 && rate <= 40) {
            rate = 2;
        } else if (rate > 40 && rate <= 50) {
            rate = 2.5;
        } else if (rate > 50 && rate <= 60) {
            rate = 3;
        } else if (rate > 60 && rate <= 70) {
            rate = 3.5;
        } else if (rate > 70 && rate <= 80) {
            rate = 4;
        } else if (rate > 80 && rate <= 90) {
            rate = 4.5;
        } else if (rate > 90 && rate <= 100) {
            rate = 5;
        }
    }
    System.out.println("rate: " + rate);

}

From source file:org.jaqpot.algorithm.resource.WekaMLR.java

License:Open Source License

@POST
@Path("prediction")
public Response prediction(PredictionRequest request) {

    try {
        if (request.getDataset().getDataEntry().isEmpty()
                || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
            return Response.status(Response.Status.BAD_REQUEST).entity(
                    ErrorReportFactory.badRequest("Dataset is empty", "Cannot make predictions on empty dataset"))
                    .build();
        }

        String base64Model = (String) request.getRawModel();
        byte[] modelBytes = Base64.getDecoder().decode(base64Model);
        ByteArrayInputStream bais = new ByteArrayInputStream(modelBytes);
        ObjectInput in = new ObjectInputStream(bais);
        WekaModel model = (WekaModel) in.readObject();

        Classifier classifier = model.getClassifier();
        Instances data = InstanceUtils.createFromDataset(request.getDataset());
        List<String> additionalInfo = (List) request.getAdditionalInfo();
        String dependentFeature = additionalInfo.get(0);
        String dependentFeatureName = additionalInfo.get(1);
        data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes());
        List<LinkedHashMap<String, Object>> predictions = new ArrayList<>();
        //            data.stream().forEach(instance -> {
        //                try {
        //                    double prediction = classifier.classifyInstance(instance);
        //                    Map<String, Object> predictionMap = new HashMap<>();
        //                    predictionMap.put("Weka MLR prediction of " + dependentFeature, prediction);
        //                    predictions.add(predictionMap);
        //                } catch (Exception ex) {
        //                    Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex);
        //                }
        //            });

        for (int i = 0; i < data.numInstances(); i++) {
            Instance instance = data.instance(i);
            try {
                double prediction = classifier.classifyInstance(instance);
                LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>();
                predictionMap.put("Weka MLR prediction of " + dependentFeatureName, prediction);
                predictions.add(predictionMap);
            } catch (Exception ex) {
                Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex);
                return Response.status(Response.Status.BAD_REQUEST).entity(
                        ErrorReportFactory.badRequest("Error while getting predictions.", ex.getMessage()))
                        .build();
            }
        }

        PredictionResponse response = new PredictionResponse();
        response.setPredictions(predictions);
        return Response.ok(response).build();
    } catch (Exception ex) {
        Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build();
    }
}

From source file:org.jaqpot.algorithm.resource.WekaPLS.java

License:Open Source License

@POST
@Path("prediction")
public Response prediction(PredictionRequest request) {
    try {
        if (request.getDataset().getDataEntry().isEmpty()
                || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
            return Response
                    .status(Response.Status.BAD_REQUEST).entity(ErrorReportFactory
                            .badRequest("Dataset is empty", "Cannot make predictions on empty dataset"))
                    .build();
        }

        String base64Model = (String) request.getRawModel();
        byte[] modelBytes = Base64.getDecoder().decode(base64Model);
        ByteArrayInputStream bais = new ByteArrayInputStream(modelBytes);
        ObjectInput in = new ObjectInputStream(bais);
        WekaModel model = (WekaModel) in.readObject();

        Classifier classifier = model.getClassifier();
        Instances data = InstanceUtils.createFromDataset(request.getDataset());
        String dependentFeature = (String) request.getAdditionalInfo();
        data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes());
        data.setClass(data.attribute(dependentFeature));
        List<LinkedHashMap<String, Object>> predictions = new ArrayList<>();
        //            data.stream().forEach(instance -> {
        //                try {
        //                    double prediction = classifier.classifyInstance(instance);
        //                    Map<String, Object> predictionMap = new HashMap<>();
        //                    predictionMap.put("Weka PLS prediction of " + dependentFeature, prediction);
        //                    predictions.add(predictionMap);
        //                } catch (Exception ex) {
        //                    Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex);
        //                }
        //            });

        for (int i = 0; i < data.numInstances(); i++) {
            Instance instance = data.instance(i);
            try {
                double prediction = classifier.classifyInstance(instance);
                LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>();
                predictionMap.put("Weka PLS prediction of " + dependentFeature, prediction);
                predictions.add(predictionMap);
            } catch (Exception ex) {
                Logger.getLogger(WekaPLS.class.getName()).log(Level.SEVERE, null, ex);
                return Response.status(Response.Status.BAD_REQUEST).entity(
                        ErrorReportFactory.badRequest("Error while getting predictions.", ex.getMessage()))
                        .build();
            }
        }

        PredictionResponse response = new PredictionResponse();
        response.setPredictions(predictions);
        return Response.ok(response).build();
    } catch (IOException | ClassNotFoundException ex) {
        Logger.getLogger(WekaPLS.class.getName()).log(Level.SEVERE, null, ex);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build();
    }
}

From source file:org.jaqpot.algorithm.resource.WekaRBF.java

License:Open Source License

@POST
@Path("prediction")
public Response prediction(PredictionRequest request) {

    try {
        if (request.getDataset().getDataEntry().isEmpty()
                || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
            return Response
                    .status(Response.Status.BAD_REQUEST).entity(ErrorReportFactory
                            .badRequest("Dataset is empty", "Cannot make predictions on empty dataset"))
                    .build();
        }

        String base64Model = (String) request.getRawModel();
        byte[] modelBytes = Base64.getDecoder().decode(base64Model);
        ByteArrayInputStream bais = new ByteArrayInputStream(modelBytes);
        ObjectInput in = new ObjectInputStream(bais);
        WekaModel model = (WekaModel) in.readObject();

        Classifier classifier = model.getClassifier();
        Instances data = InstanceUtils.createFromDataset(request.getDataset());
        String dependentFeature = (String) request.getAdditionalInfo();
        data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes());
        data.setClass(data.attribute(dependentFeature));

        List<LinkedHashMap<String, Object>> predictions = new ArrayList<>();
        //            data.stream().forEach(instance -> {
        //                try {
        //                    double prediction = classifier.classifyInstance(instance);
        //                    Map<String, Object> predictionMap = new HashMap<>();
        //                    predictionMap.put("Weka MLR prediction of " + dependentFeature, prediction);
        //                    predictions.add(predictionMap);
        //                } catch (Exception ex) {
        //                    Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex);
        //                }
        //            });

        for (int i = 0; i < data.numInstances(); i++) {
            Instance instance = data.instance(i);
            try {
                double prediction = classifier.classifyInstance(instance);
                LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>();
                predictionMap.put("Weka RBF prediction of " + dependentFeature, prediction);
                predictions.add(predictionMap);
            } catch (Exception ex) {
                Logger.getLogger(WekaRBF.class.getName()).log(Level.SEVERE, null, ex);
                return Response.status(Response.Status.BAD_REQUEST).entity(
                        ErrorReportFactory.badRequest("Error while getting predictions.", ex.getMessage()))
                        .build();
            }
        }

        PredictionResponse response = new PredictionResponse();
        response.setPredictions(predictions);
        return Response.ok(response).build();
    } catch (IOException | ClassNotFoundException ex) {
        Logger.getLogger(WekaRBF.class.getName()).log(Level.SEVERE, null, ex);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build();
    }
}

From source file:org.jaqpot.algorithm.resource.WekaSVM.java

License:Open Source License

@POST
@Path("prediction")
public Response prediction(PredictionRequest request) {
    try {
        if (request.getDataset().getDataEntry().isEmpty()
                || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
            return Response
                    .status(Response.Status.BAD_REQUEST).entity(ErrorReportFactory
                            .badRequest("Dataset is empty", "Cannot make predictions on empty dataset"))
                    .build();
        }

        String base64Model = (String) request.getRawModel();
        byte[] modelBytes = Base64.getDecoder().decode(base64Model);
        ByteArrayInputStream bais = new ByteArrayInputStream(modelBytes);
        ObjectInput in = new ObjectInputStream(bais);
        WekaModel model = (WekaModel) in.readObject();

        Classifier classifier = model.getClassifier();
        Instances data = InstanceUtils.createFromDataset(request.getDataset());
        String dependentFeature = (String) request.getAdditionalInfo();
        data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes());
        data.setClass(data.attribute(dependentFeature));
        List<LinkedHashMap<String, Object>> predictions = new ArrayList<>();
        //            data.stream().forEach(instance -> {
        //                try {
        //                    double prediction = classifier.classifyInstance(instance);
        //                    Map<String, Object> predictionMap = new HashMap<>();
        //                    predictionMap.put("Weka SVM prediction of " + dependentFeature, prediction);
        //                    predictions.add(predictionMap);
        //                } catch (Exception ex) {
        //                    Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex);
        //                }
        //            });

        for (int i = 0; i < data.numInstances(); i++) {
            Instance instance = data.instance(i);
            try {
                double prediction = classifier.classifyInstance(instance);
                LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>();
                predictionMap.put("Weka SVM prediction of " + dependentFeature, prediction);
                predictions.add(predictionMap);
            } catch (Exception ex) {
                Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex);
                return Response.status(Response.Status.BAD_REQUEST).entity(
                        ErrorReportFactory.badRequest("Error while getting predictions.", ex.getMessage()))
                        .build();
            }
        }

        PredictionResponse response = new PredictionResponse();
        response.setPredictions(predictions);
        return Response.ok(response).build();
    } catch (Exception ex) {
        Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build();
    }
}