List of usage examples for weka.classifiers Classifier classifyInstance
/**
 * Classifies a single instance and returns the prediction as a {@code double}.
 *
 * <p>The usage examples in this listing consume the result in two ways: cast to {@code int} and
 * used as an index into the class attribute's values (nominal class), or used directly as the
 * numeric prediction.
 *
 * @param instance the instance to classify
 * @return the prediction for {@code instance}
 * @throws Exception if the instance cannot be classified
 */
public double classifyInstance(Instance instance) throws Exception;
From source file:myclassifier.wekaCode.java
/**
 * Builds one instance from raw attribute strings, classifies it and prints the predicted label.
 *
 * <p>Only attribute indices below {@code data.numAttributes() - 1} are filled from
 * {@code attributes}; numeric attributes are parsed as doubles, all others are set from the raw
 * string.
 *
 * @param attributes  raw attribute values, indexed like the attributes of {@code data}
 * @param classifiers classifier used to predict the class of the new instance
 * @param data        dataset that supplies the attribute definitions and the class attribute
 * @throws Exception if the classifier fails to classify the instance
 */
public static void classifyUnseenData(String[] attributes, Classifier classifiers, Instances data)
        throws Exception {
    final int attributeCount = data.numAttributes();

    Instance unseen = new Instance(attributeCount);
    unseen.setDataset(data);

    // The last attribute index is never populated from the input array.
    for (int idx = 0; idx < attributeCount - 1; idx++) {
        boolean numeric = data.attribute(idx).type() == Attribute.NUMERIC;
        if (numeric) {
            unseen.setValue(idx, Double.parseDouble(attributes[idx]));
        } else {
            unseen.setValue(idx, attributes[idx]);
        }
    }

    double predictedIndex = classifiers.classifyInstance(unseen);
    unseen.setClassValue(predictedIndex);

    String predictedLabel = data.classAttribute().value((int) predictedIndex);
    System.out.println("Hasil Classify Unseen Data Adalah: " + predictedLabel);
}
From source file:net.sf.bddbddb.FindBestDomainOrder.java
License:LGPL
public TrialGuess tryNewGoodOrder(EpisodeCollection ec, List allVars, InferenceRule ir, int opNum, Order chosenOne, boolean returnBest) { out.println("Variables: " + allVars); TrialDataGroup vDataGroup = this.dataRepository.getVariableDataGroup(ir, allVars); TrialDataGroup aDataGroup = dataRepository.getAttribDataGroup(ir, allVars); TrialDataGroup dDataGroup = dataRepository.getDomainDataGroup(ir, allVars); // Build instances based on the experimental data. TrialInstances vData, aData, dData;/*from w w w .ja v a2s . co m*/ vData = vDataGroup.getTrialInstances(); aData = aDataGroup.getTrialInstances(); dData = dDataGroup.getTrialInstances(); /* TrialInstances vTest = dataRepository.buildVarInstances(ir, allVars); Assert._assert(vData.numInstances() == vTest.numInstances(),"vGot " + vData.numInstances() + " Wanted: " + vTest.numInstances()); TrialInstances aTest = dataRepository.buildAttribInstances(ir, allVars); Assert._assert(aData.numInstances() == aTest.numInstances(), "aGot: " + aData.numInstances() + " Wanted: " + aTest.numInstances()); TrialInstances dTest =dataRepository.buildDomainInstances(ir, allVars); Assert._assert(dData.numInstances() == dTest.numInstances(), "dGot: " + dData.numInstances() + " Wanted: " + dTest.numInstances()); out.println(aData); out.println(vData); out.println(dData); */ // Readjust the weights using an exponential decay factor. adjustWeights(vData, aData, dData); Discretization vDis = null, aDis = null, dDis = null; /* // Discretize the experimental data. null if there is no data. if (DISCRETIZE1) vDis = vData.discretize(.5); if (DISCRETIZE2) aDis = aData.discretize(.25); if (DISCRETIZE3) dDis = dData.threshold(DOMAIN_THRESHOLD); */ vDis = vDataGroup.discretize(.5); aDis = aDataGroup.discretize(.25); dDis = dDataGroup.threshold(DOMAIN_THRESHOLD); // Calculate the accuracy of each classifier using cv folds. 
long vCTime = System.currentTimeMillis(); double vConstCV = -1;//constFoldCV(vData, CLASSIFIER1); vCTime = System.currentTimeMillis() - vCTime; long aCTime = System.currentTimeMillis(); double aConstCV = -1;//constFoldCV(aData, CLASSIFIER2); aCTime = System.currentTimeMillis() - aCTime; long dCTime = System.currentTimeMillis(); double dConstCV = -1;//constFoldCV(dData, CLASSIFIER3); dCTime = System.currentTimeMillis() - dCTime; long vLTime = System.currentTimeMillis(); double vLeaveCV = -1; //leaveOneOutCV(vData, CLASSIFIER1); vLTime = System.currentTimeMillis() - vLTime; long aLTime = System.currentTimeMillis(); double aLeaveCV = -1; //leaveOneOutCV(aData, CLASSIFIER2); aLTime = System.currentTimeMillis() - aLTime; long dLTime = System.currentTimeMillis(); double dLeaveCV = -1; //leaveOneOutCV(dData, CLASSIFIER3); dLTime = System.currentTimeMillis() - dLTime; if (TRACE > 1) { out.println(" Var data points: " + vData.numInstances()); //out.println(" Var Classifier " + NUM_CV_FOLDS + " fold CV Score: " + vConstCV + " took " + vCTime + " ms"); // out.println(" Var Classifier leave one out CV Score: " + vLeaveCV + " took " + vLTime + " ms"); out.println(" Var Classifier Weight: " + varClassWeight); //out.println(" Var data points: "+vData); out.println(" Attrib data points: " + aData.numInstances()); // out.println(" Attrib Classifier " + NUM_CV_FOLDS + " fold CV Score : " + aConstCV + " took " + aCTime + " ms"); //out.println(" Attrib Classifier leave one out CV Score: " + aLeaveCV + " took " + aLTime + " ms"); out.println(" Attrib Classifier Weight: " + attrClassWeight); //out.println(" Attrib data points: "+aData); out.println(" Domain data points: " + dData.numInstances()); //out.println(" Domain Classifier " + NUM_CV_FOLDS + " fold CV Score: " + dConstCV + " took " + dCTime + " ms"); //out.println(" Attrib Classifier leave one out CV Score: " + dLeaveCV + " took " + dLTime + " ms"); out.println(" Domain Classifier Weight: " + domClassWeight); //out.println(" 
Domain data points: "+dData); } Classifier vClassifier = null, aClassifier = null, dClassifier = null; // Build the classifiers. /* if (vData.numInstances() > 0) vClassifier = WekaInterface.buildClassifier(CLASSIFIER1, vData); if (aData.numInstances() > 0) aClassifier = WekaInterface.buildClassifier(CLASSIFIER2, aData); if (dData.numInstances() > 0) dClassifier = WekaInterface.buildClassifier(CLASSIFIER3, dData); */ vClassifier = vDataGroup.classify(); aClassifier = aDataGroup.classify(); dClassifier = dDataGroup.classify(); if (DUMP_CLASSIFIER_INFO) { String baseName = solver.getBaseName() + "_rule" + ir.id; if (vClassifier != null) dumpClassifierInfo(baseName + "_vclassifier", vClassifier, vData); if (aClassifier != null) dumpClassifierInfo(baseName + "_aclassifier", aClassifier, aData); if (dClassifier != null) dumpClassifierInfo(baseName + "_dclassifier", dClassifier, dData); try { out_t = new PrintStream(new FileOutputStream(baseName + "_trials")); } catch (IOException x) { solver.err.println("Error while opening file: " + x); } } else { out_t = null; } if (TRACE > 2) { out.println("Var classifier: " + vClassifier); out.println("Attrib classifier: " + aClassifier); out.println("Domain classifier: " + dClassifier); } double[][] bucketmeans = getBucketMeans(vDis, aDis, dDis); Collection sel = null; Collection candidates = null; if (chosenOne == null) { Collection triedOrders = returnBest ? 
new LinkedList() : getTriedOrders((BDDInferenceRule) ir, opNum); if (ec != null) { triedOrders.addAll(ec.trials.keySet()); } Object object = generateCandidateSet(ir, allVars, bucketmeans, vDataGroup, aDataGroup, dDataGroup, triedOrders, returnBest); /*vClassifier, aClassifier, dClassifier, vData, aData, dData, vDis, aDis, dDis,*/ if (object == null) return null; else if (object instanceof Collection) candidates = (Collection) object; else if (object instanceof TrialGuess) return (TrialGuess) object; } else { sel = Collections.singleton(chosenOne); } boolean force = (ec != null && ec.getNumTrials() < 2) || vData.numInstances() < INITIAL_VAR_SET || aData.numInstances() < INITIAL_ATTRIB_SET || dData.numInstances() < INITIAL_DOM_SET; if (!returnBest) sel = selectOrder(candidates, vData, aData, dData, ir, force); if (sel == null || sel.isEmpty()) return null; Order o_v = (Order) sel.iterator().next(); try { OrderTranslator v2a = new VarToAttribTranslator(ir); OrderTranslator a2d = AttribToDomainTranslator.INSTANCE; double vClass = 0, aClass = 0, dClass = 0; if (vClassifier != null) { OrderInstance vInst = OrderInstance.construct(o_v, vData); vClass = vClassifier.classifyInstance(vInst); } Order o_a = v2a.translate(o_v); if (aClassifier != null) { OrderInstance aInst = OrderInstance.construct(o_a, aData); aClass = aClassifier.classifyInstance(aInst); } Order o_d = a2d.translate(o_a); if (dClassifier != null) { OrderInstance dInst = OrderInstance.construct(o_d, dData); dClass = dClassifier.classifyInstance(dInst); } int vi = (int) vClass, ai = (int) aClass, di = (int) dClass; double vScore = 0, aScore = 0, dScore = 0; if (vi < bucketmeans[VMEAN_INDEX].length) vScore = bucketmeans[VMEAN_INDEX][vi]; if (ai < bucketmeans[AMEAN_INDEX].length) aScore = bucketmeans[AMEAN_INDEX][ai]; if (di < bucketmeans[DMEAN_INDEX].length) dScore = bucketmeans[DMEAN_INDEX][di]; double score = varClassWeight * vScore; score += attrClassWeight * aScore; score += domClassWeight * dScore; return 
genGuess(o_v, score, vClass, aClass, dClass, vDis, aDis, dDis); } catch (Exception x) { x.printStackTrace(); Assert.UNREACHABLE(x.toString()); return null; } }
From source file:net.sf.bddbddb.order.WekaInterface.java
License:LGPL
public static double cvError(int numFolds, Instances data0, String cClassName) { if (data0.numInstances() < numFolds) return Double.NaN; //more folds than elements if (numFolds == 0) return Double.NaN; // no folds if (data0.numInstances() == 0) return 0; //no instances Instances data = new Instances(data0); //data.randomize(new Random(System.currentTimeMillis())); data.stratify(numFolds);//from w w w . j a v a2s. c o m Assert._assert(data.classAttribute() != null); double[] estimates = new double[numFolds]; for (int i = 0; i < numFolds; ++i) { Instances trainData = data.trainCV(numFolds, i); Assert._assert(trainData.classAttribute() != null); Assert._assert(trainData.numInstances() != 0, "Cannot train classifier on 0 instances."); Instances testData = data.testCV(numFolds, i); Assert._assert(testData.classAttribute() != null); Assert._assert(testData.numInstances() != 0, "Cannot test classifier on 0 instances."); int temp = FindBestDomainOrder.TRACE; FindBestDomainOrder.TRACE = 0; Classifier classifier = buildClassifier(cClassName, trainData); FindBestDomainOrder.TRACE = temp; int count = testData.numInstances(); double loss = 0; double sum = 0; for (Enumeration e = testData.enumerateInstances(); e.hasMoreElements();) { Instance instance = (Instance) e.nextElement(); Assert._assert(instance != null); Assert._assert(instance.classAttribute() != null && instance.classAttribute() == trainData.classAttribute()); try { double testClass = classifier.classifyInstance(instance); double weight = instance.weight(); if (testClass != instance.classValue()) loss += weight; sum += weight; } catch (Exception ex) { FindBestDomainOrder.out.println("Exception while classifying: " + instance + "\n" + ex); } } estimates[i] = 1 - loss / sum; } double average = 0; for (int i = 0; i < numFolds; ++i) average += estimates[i]; return average / numFolds; }
From source file:nl.bioinf.roelen.thema11.classifier_tools.ClassifierUser.java
License:Open Source License
/** * use the classifier to test the sequences in a genbank or fasta file for boundaries * @param fileLocation the location of the genbank of fasta file * @param classifier the classifier to use * @return // ww w. j av a 2 s . c om */ public static ArrayList<ClassifiedNucleotide> getPossibleBoundaries(String fileLocation, Classifier classifier) { ArrayList<Gene> genesFromFile = new ArrayList<>(); ArrayList<ClassifiedNucleotide> classifiedNucleotides = new ArrayList<>(); //read from fasta if (fileLocation.toUpperCase().endsWith(".FASTA") || fileLocation.toUpperCase().endsWith(".FA") || fileLocation.toUpperCase().endsWith(".FAN")) { genesFromFile.addAll(readFasta(fileLocation)); } //read from genbank else if (fileLocation.toUpperCase().endsWith(".GENBANK") || fileLocation.toUpperCase().endsWith(".GB")) { GenBankReader gbr = new GenBankReader(); gbr.readFile(fileLocation); GenbankResult gbresult = gbr.getResult(); genesFromFile = gbresult.getGenes(); } //get the test data HashMap<String, ArrayList<IntronExonBoundaryTesterResult>> geneTestResults; geneTestResults = TestGenes.testForIntronExonBoundaries(genesFromFile, 1); ArrayList<InstanceToClassify> instanceNucs = new ArrayList<>(); try { //write our results to a temporary file File tempArrf = File.createTempFile("realSet", ".arff"); ArffWriter.write(tempArrf.getAbsolutePath(), geneTestResults, null); //get data ConverterUtils.DataSource source = new ConverterUtils.DataSource(tempArrf.getAbsolutePath()); //SET DATA AND OPTIONS Instances data = source.getDataSet(); for (int i = 0; i < data.numInstances(); i++) { Instance in = data.instance(i); //get the name of the gene or sequence tested String nameOfInstance = in.stringValue(in.numAttributes() - 3); //get the tested position int testedPosition = (int) in.value(in.numAttributes() - 2); //set the class as missing, because we want to find it in.setMissing((in.numAttributes() - 1)); Instance instanceNoExtras = new Instance(in); //delete the name and position, they are 
irrelevant for classifying instanceNoExtras.deleteAttributeAt(instanceNoExtras.numAttributes() - 2); instanceNoExtras.deleteAttributeAt(instanceNoExtras.numAttributes() - 2); InstanceToClassify ic = new InstanceToClassify(instanceNoExtras, testedPosition, nameOfInstance); instanceNucs.add(ic); } for (InstanceToClassify ic : instanceNucs) { Instance in = ic.getInstance(); in.setDataset(data); data.setClassIndex(data.numAttributes() - 1); //classify our instance classifier.classifyInstance(in); //save the likelyhood something is part of something double likelyhoodBoundary = classifier.distributionForInstance(in)[0]; double likelyhoodNotBoundary = classifier.distributionForInstance(in)[1]; //create a classified nucleotide and give it the added data ClassifiedNucleotide cn = new ClassifiedNucleotide(likelyhoodBoundary, likelyhoodNotBoundary, ic.getName(), ic.getPosition()); classifiedNucleotides.add(cn); } } catch (IOException ex) { Logger.getLogger(ClassifierUser.class.getName()).log(Level.SEVERE, null, ex); } catch (Exception ex) { Logger.getLogger(ClassifierUser.class.getName()).log(Level.SEVERE, null, ex); } return classifiedNucleotides; }
From source file:nl.uva.sne.commons.ClusterUtils.java
public static Map<String, String> classify(String testDataPath, Classifier classifier) throws Exception { Instances testData = createInstancesWithClasses(testDataPath); testData.setClassIndex(testData.numAttributes() - 1); Map<String, String> classes = new HashMap(); for (int j = 0; j < testData.numInstances(); j++) { // System.err.println(m); Instance inst = testData.get(j); String id = inst.toString().split(",")[0]; // System.err.println(inst); // System.out.print("ID: " + UNdata.instance(j).value(0) + " "); int clsLabel = (int) classifier.classifyInstance(inst); // String theClass = testData.classAttribute().value(clsLabel); // System.err.println(id + " " + clsLabel); classes.put(testDataPath + File.separator + id, String.valueOf(clsLabel)); }/*from w w w . j av a 2 s. c o m*/ return classes; }
From source file:nlp.NLP.java
public void calculateRate(String review) throws IOException, Exception { double positiveSentences = 0, allSentences = 0; String predictedClass = ""; File writeFile = new File("test.arff"); PrintWriter pw = new PrintWriter(writeFile); pw.println("@relation movie_review"); pw.println("@attribute 'positive_words' numeric"); pw.println("@attribute 'negative_words' numeric"); pw.println("@attribute 'positive_score' numeric"); pw.println("@attribute 'negative_score' numeric"); pw.println("@attribute 'strongPositive' numeric"); pw.println("@attribute 'strongNegative' numeric"); pw.println("@attribute 'subjective_words' numeric"); pw.println("@attribute 'neutral_words' numeric"); pw.println("@attribute 'adj_words' numeric"); pw.println("@attribute 'adv_words' numeric"); pw.println("@attribute 'class' {negative, positive}"); pw.println("@data"); String[] splitByPoint = review.split("\\."); for (int j = 0; j < splitByPoint.length; j++) { // String normalized = normalization(splitByPoint[j]); if (splitByPoint[j] == null || splitByPoint[j].isEmpty()) { continue; }/*ww w. j av a2 s . c o m*/ System.out.println("your review : " + splitByPoint[j]); WekaFileGenerator wk = new WekaFileGenerator(splitByPoint[j], pipeline, ra); pw.print(wk.getSentence().getPositiveWords() + "," + wk.getSentence().getNegativeWords() + "," + +wk.getSentence().getSumOfPositiveScore() + "," + wk.getSentence().getSumOfNegativeScore() + "," + wk.getSentence().getStrongPositive() + "," + wk.getSentence().getStrongNegative() + "," + wk.getSentence().getSubjectiveWords() + "," + wk.getSentence().getNeutralWords() + "," + wk.getSentence().getNumOfAdjective() + "," + wk.getSentence().getNumOfAdverb() + ", ? 
\n"); // System.out.println("here"); // } } pw.close(); DataSource test = new DataSource("test.arff"); Instances testData = test.getDataSet(); testData.setClassIndex(testData.numAttributes() - 1); Classifier j = (Classifier) weka.core.SerializationHelper.read("movieReview.model"); for (int i = 0; i < testData.numInstances(); i++) { Instance inst = testData.instance(i); double predictNum = j.classifyInstance(inst); predictedClass = testData.classAttribute().value((int) predictNum); System.out.println("Class Predicted: " + predictedClass); if (predictedClass.equals("positive")) { positiveSentences++; System.out.println("positiveSentences = " + positiveSentences); if (splitByPoint[i].contains("story")) { story = 1; } if (splitByPoint[i].contains("direction")) { direction = 1; } } else { // positiveSentences--; if (splitByPoint[i].contains("story")) { story = 0; } if (splitByPoint[i].contains("direction")) { direction = 0; } } allSentences++; System.out.println("allSentences = " + allSentences); } DecimalFormat format = new DecimalFormat("#0.000"); rate = (positiveSentences / allSentences) * 100; if (rate != NaN) { rate = Double.parseDouble(format.format(rate)); if (rate > 0 && rate <= 10) { rate = 0.5; } else if (rate > 10 && rate <= 20) { rate = 1.0; } else if (rate > 20 && rate <= 30) { rate = 1.5; } else if (rate > 30 && rate <= 40) { rate = 2; } else if (rate > 40 && rate <= 50) { rate = 2.5; } else if (rate > 50 && rate <= 60) { rate = 3; } else if (rate > 60 && rate <= 70) { rate = 3.5; } else if (rate > 70 && rate <= 80) { rate = 4; } else if (rate > 80 && rate <= 90) { rate = 4.5; } else if (rate > 90 && rate <= 100) { rate = 5; } } System.out.println("rate: " + rate); }
From source file:org.jaqpot.algorithm.resource.WekaMLR.java
License:Open Source License
@POST @Path("prediction") public Response prediction(PredictionRequest request) { try {//ww w . j a va 2s. co m if (request.getDataset().getDataEntry().isEmpty() || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) { return Response.status(Response.Status.BAD_REQUEST).entity( ErrorReportFactory.badRequest("Dataset is empty", "Cannot train model on empty dataset")) .build(); } String base64Model = (String) request.getRawModel(); byte[] modelBytes = Base64.getDecoder().decode(base64Model); ByteArrayInputStream bais = new ByteArrayInputStream(modelBytes); ObjectInput in = new ObjectInputStream(bais); WekaModel model = (WekaModel) in.readObject(); Classifier classifier = model.getClassifier(); Instances data = InstanceUtils.createFromDataset(request.getDataset()); List<String> additionalInfo = (List) request.getAdditionalInfo(); String dependentFeature = additionalInfo.get(0); String dependentFeatureName = additionalInfo.get(1); data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes()); List<LinkedHashMap<String, Object>> predictions = new ArrayList<>(); // data.stream().forEach(instance -> { // try { // double prediction = classifier.classifyInstance(instance); // Map<String, Object> predictionMap = new HashMap<>(); // predictionMap.put("Weka MLR prediction of " + dependentFeature, prediction); // predictions.add(predictionMap); // } catch (Exception ex) { // Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex); // } // }); for (int i = 0; i < data.numInstances(); i++) { Instance instance = data.instance(i); try { double prediction = classifier.classifyInstance(instance); LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>(); predictionMap.put("Weka MLR prediction of " + dependentFeatureName, prediction); predictions.add(predictionMap); } catch (Exception ex) { Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex); return Response.status(Response.Status.BAD_REQUEST).entity( 
ErrorReportFactory.badRequest("Error while gettting predictions.", ex.getMessage())) .build(); } } PredictionResponse response = new PredictionResponse(); response.setPredictions(predictions); return Response.ok(response).build(); } catch (Exception ex) { Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex); return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build(); } }
From source file:org.jaqpot.algorithm.resource.WekaPLS.java
License:Open Source License
@POST @Path("prediction") public Response prediction(PredictionRequest request) { try {/*from w w w.java2s . c o m*/ if (request.getDataset().getDataEntry().isEmpty() || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) { return Response .status(Response.Status.BAD_REQUEST).entity(ErrorReportFactory .badRequest("Dataset is empty", "Cannot make predictions on empty dataset")) .build(); } String base64Model = (String) request.getRawModel(); byte[] modelBytes = Base64.getDecoder().decode(base64Model); ByteArrayInputStream bais = new ByteArrayInputStream(modelBytes); ObjectInput in = new ObjectInputStream(bais); WekaModel model = (WekaModel) in.readObject(); Classifier classifier = model.getClassifier(); Instances data = InstanceUtils.createFromDataset(request.getDataset()); String dependentFeature = (String) request.getAdditionalInfo(); data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes()); data.setClass(data.attribute(dependentFeature)); List<LinkedHashMap<String, Object>> predictions = new ArrayList<>(); // data.stream().forEach(instance -> { // try { // double prediction = classifier.classifyInstance(instance); // Map<String, Object> predictionMap = new HashMap<>(); // predictionMap.put("Weka PLS prediction of " + dependentFeature, prediction); // predictions.add(predictionMap); // } catch (Exception ex) { // Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex); // } // }); for (int i = 0; i < data.numInstances(); i++) { Instance instance = data.instance(i); try { double prediction = classifier.classifyInstance(instance); LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>(); predictionMap.put("Weka PLS prediction of " + dependentFeature, prediction); predictions.add(predictionMap); } catch (Exception ex) { Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex); return Response.status(Response.Status.BAD_REQUEST).entity( ErrorReportFactory.badRequest("Error while gettting 
predictions.", ex.getMessage())) .build(); } } PredictionResponse response = new PredictionResponse(); response.setPredictions(predictions); return Response.ok(response).build(); } catch (IOException | ClassNotFoundException ex) { Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex); return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build(); } }
From source file:org.jaqpot.algorithm.resource.WekaRBF.java
License:Open Source License
@POST @Path("prediction") public Response prediction(PredictionRequest request) { try {/*from w ww. j a v a 2 s.com*/ if (request.getDataset().getDataEntry().isEmpty() || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) { return Response .status(Response.Status.BAD_REQUEST).entity(ErrorReportFactory .badRequest("Dataset is empty", "Cannot make predictions on empty dataset")) .build(); } String base64Model = (String) request.getRawModel(); byte[] modelBytes = Base64.getDecoder().decode(base64Model); ByteArrayInputStream bais = new ByteArrayInputStream(modelBytes); ObjectInput in = new ObjectInputStream(bais); WekaModel model = (WekaModel) in.readObject(); Classifier classifier = model.getClassifier(); Instances data = InstanceUtils.createFromDataset(request.getDataset()); String dependentFeature = (String) request.getAdditionalInfo(); data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes()); data.setClass(data.attribute(dependentFeature)); List<LinkedHashMap<String, Object>> predictions = new ArrayList<>(); // data.stream().forEach(instance -> { // try { // double prediction = classifier.classifyInstance(instance); // Map<String, Object> predictionMap = new HashMap<>(); // predictionMap.put("Weka MLR prediction of " + dependentFeature, prediction); // predictions.add(predictionMap); // } catch (Exception ex) { // Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex); // } // }); for (int i = 0; i < data.numInstances(); i++) { Instance instance = data.instance(i); try { double prediction = classifier.classifyInstance(instance); LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>(); predictionMap.put("Weka RBF prediction of " + dependentFeature, prediction); predictions.add(predictionMap); } catch (Exception ex) { Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex); return Response.status(Response.Status.BAD_REQUEST).entity( ErrorReportFactory.badRequest("Error while gettting 
predictions.", ex.getMessage())) .build(); } } PredictionResponse response = new PredictionResponse(); response.setPredictions(predictions); return Response.ok(response).build(); } catch (IOException | ClassNotFoundException ex) { Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex); return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build(); } }
From source file:org.jaqpot.algorithm.resource.WekaSVM.java
License:Open Source License
@POST @Path("prediction") public Response prediction(PredictionRequest request) { try {/*from w ww. jav a 2 s . c o m*/ if (request.getDataset().getDataEntry().isEmpty() || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) { return Response .status(Response.Status.BAD_REQUEST).entity(ErrorReportFactory .badRequest("Dataset is empty", "Cannot make predictions on empty dataset")) .build(); } String base64Model = (String) request.getRawModel(); byte[] modelBytes = Base64.getDecoder().decode(base64Model); ByteArrayInputStream bais = new ByteArrayInputStream(modelBytes); ObjectInput in = new ObjectInputStream(bais); WekaModel model = (WekaModel) in.readObject(); Classifier classifier = model.getClassifier(); Instances data = InstanceUtils.createFromDataset(request.getDataset()); String dependentFeature = (String) request.getAdditionalInfo(); data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes()); data.setClass(data.attribute(dependentFeature)); List<LinkedHashMap<String, Object>> predictions = new ArrayList<>(); // data.stream().forEach(instance -> { // try { // double prediction = classifier.classifyInstance(instance); // Map<String, Object> predictionMap = new HashMap<>(); // predictionMap.put("Weka SVM prediction of " + dependentFeature, prediction); // predictions.add(predictionMap); // } catch (Exception ex) { // Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex); // } // }); for (int i = 0; i < data.numInstances(); i++) { Instance instance = data.instance(i); try { double prediction = classifier.classifyInstance(instance); LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>(); predictionMap.put("Weka SVM prediction of " + dependentFeature, prediction); predictions.add(predictionMap); } catch (Exception ex) { Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex); return Response.status(Response.Status.BAD_REQUEST).entity( ErrorReportFactory.badRequest("Error while gettting 
predictions.", ex.getMessage())) .build(); } } PredictionResponse response = new PredictionResponse(); response.setPredictions(predictions); return Response.ok(response).build(); } catch (Exception ex) { Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex); return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build(); } }