// Java tutorial
/* * * JAQPOT Quattro * * JAQPOT Quattro and the components shipped with it (web applications and beans) * are licensed by GPL v3 as specified hereafter. Additional components may ship * with some other licence as will be specified therein. * * Copyright (C) 2014-2015 KinkyDesign (Charalampos Chomenidis, Pantelis Sopasakis) * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. * * Source code: * The source code of JAQPOT Quattro is available on github at: * https://github.com/KinkyDesign/JaqpotQuattro * All source files of JAQPOT Quattro that are stored on github are licensed * with the aforementioned licence. 
*/
package org.jaqpot.algorithm.resource;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectInputStream;
import java.io.ObjectOutput;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import javax.ws.rs.Consumes;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.jaqpot.algorithm.model.WekaModel;
import org.jaqpot.algorithm.pmml.PmmlUtils;
import org.jaqpot.algorithm.weka.InstanceUtils;
import org.jaqpot.core.model.dto.jpdi.PredictionRequest;
import org.jaqpot.core.model.dto.jpdi.PredictionResponse;
import org.jaqpot.core.model.dto.jpdi.TrainingRequest;
import org.jaqpot.core.model.dto.jpdi.TrainingResponse;
import org.jaqpot.core.model.factory.ErrorReportFactory;
import weka.classifiers.Classifier;
import weka.classifiers.functions.LinearRegression;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;

/**
 * JAX-RS resource exposing a Weka {@link LinearRegression} (multiple linear
 * regression) under the {@code mlr} path.
 *
 * <p>{@code POST mlr/training} fits a model on the supplied dataset and returns
 * it as a Base64-encoded, Java-serialized {@link WekaModel} together with a
 * PMML representation. {@code POST mlr/prediction} takes such a raw model back
 * and returns one prediction per dataset instance.
 *
 * @author Charalampos Chomenidis
 * @author Pantelis Sopasakis
 */
@Path("mlr")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public class WekaMLR {

    private static final Logger LOG = Logger.getLogger(WekaMLR.class.getName());

    /** Prefix of the predicted-feature title; shared so training and prediction always agree. */
    private static final String PREDICTION_TITLE_PREFIX = "Weka MLR prediction of ";

    /**
     * Trains a linear regression model on the request's dataset.
     *
     * @param request training request carrying the dataset and the URI of the
     *                prediction (dependent) feature
     * @return 200 with a {@link TrainingResponse}; 400 if the dataset is empty
     *         or the prediction feature is not among the dataset's features;
     *         500 with the exception message for any other failure
     */
    @POST
    @Path("training")
    public Response training(TrainingRequest request) {
        try {
            if (request.getDataset() == null
                    || request.getDataset().getDataEntry() == null
                    || request.getDataset().getDataEntry().isEmpty()
                    || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
                return badRequest("Dataset is empty", "Cannot train model on empty dataset");
            }

            String predictionFeature = request.getPredictionFeature();

            // Resolve the display name of the prediction feature before doing any
            // expensive work, so that an unknown feature is reported as a client
            // error instead of surfacing later as a NoSuchElementException.
            List<String> matchingNames = request.getDataset().getFeatures().stream()
                    .filter(f -> f.getURI().equals(predictionFeature))
                    .map(f -> f.getName())
                    .limit(1)
                    .collect(Collectors.toList());
            if (matchingNames.isEmpty()) {
                return badRequest("Unknown prediction feature",
                        "Prediction feature " + predictionFeature
                        + " is not among the features of the dataset");
            }
            String predictionFeatureName = matchingNames.get(0);

            // Feature keys are taken from the first data entry.
            List<String> features = request.getDataset().getDataEntry().get(0).getValues()
                    .keySet().stream().collect(Collectors.toList());

            Instances data = InstanceUtils.createFromDataset(request.getDataset(), predictionFeature);

            LinearRegression linreg = new LinearRegression();
            // Per the Weka documentation: "-S 1" selects no attribute selection and
            // "-C" turns off elimination of colinear attributes.
            String[] linRegOptions = {"-S", "1", "-C"};
            linreg.setOptions(linRegOptions);
            linreg.buildClassifier(data);

            WekaModel model = new WekaModel();
            model.setClassifier(linreg);

            String pmml = PmmlUtils.createRegressionModel(
                    features, predictionFeature, linreg.coefficients(), "MLR");

            List<String> independentFeatures = features.stream()
                    .filter(feature -> !feature.equals(predictionFeature))
                    .collect(Collectors.toList());

            TrainingResponse response = new TrainingResponse();
            response.setRawModel(serializeModel(model));
            response.setIndependentFeatures(independentFeatures);
            response.setPmmlModel(pmml);
            // The prediction endpoint reads these two entries back by position.
            response.setAdditionalInfo(Arrays.asList(predictionFeature, predictionFeatureName));
            response.setPredictedFeatures(
                    Arrays.asList(PREDICTION_TITLE_PREFIX + predictionFeatureName));
            return Response.ok(response).build();
        } catch (Exception ex) {
            LOG.log(Level.SEVERE, "MLR training failed", ex);
            return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
                    .entity(ex.getMessage()).build();
        }
    }

    /**
     * Applies a previously trained model to every instance of the request's dataset.
     *
     * @param request prediction request carrying the dataset, the Base64 raw
     *                model and the additional info list produced by
     *                {@link #training(TrainingRequest)} (prediction feature URI
     *                at index 0, its name at index 1)
     * @return 200 with a {@link PredictionResponse}; 400 if the dataset is
     *         empty, the raw model or additional info is missing or malformed,
     *         or the classifier fails on an instance; 500 with the exception
     *         message for any other failure
     */
    @POST
    @Path("prediction")
    public Response prediction(PredictionRequest request) {
        try {
            if (request.getDataset() == null
                    || request.getDataset().getDataEntry() == null
                    || request.getDataset().getDataEntry().isEmpty()
                    || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
                return badRequest("Dataset is empty", "Cannot make predictions on empty dataset");
            }

            Object rawModel = request.getRawModel();
            if (!(rawModel instanceof String)) {
                return badRequest("Raw model is missing",
                        "The raw model must be a Base64-encoded string");
            }

            Object rawInfo = request.getAdditionalInfo();
            if (!(rawInfo instanceof List) || ((List<?>) rawInfo).size() < 2) {
                return badRequest("Additional info is missing",
                        "Additional info must list the prediction feature and its name");
            }
            List<?> additionalInfo = (List<?>) rawInfo;
            String dependentFeature = String.valueOf(additionalInfo.get(0));
            String dependentFeatureName = String.valueOf(additionalInfo.get(1));

            WekaModel model = deserializeModel((String) rawModel);
            Classifier classifier = model.getClassifier();

            Instances data = InstanceUtils.createFromDataset(request.getDataset());
            // Append an attribute for the dependent feature as the last column.
            data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes());

            String predictionTitle = PREDICTION_TITLE_PREFIX + dependentFeatureName;
            List<LinkedHashMap<String, Object>> predictions = new ArrayList<>(data.numInstances());
            for (int i = 0; i < data.numInstances(); i++) {
                Instance instance = data.instance(i);
                try {
                    double prediction = classifier.classifyInstance(instance);
                    LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>();
                    predictionMap.put(predictionTitle, prediction);
                    predictions.add(predictionMap);
                } catch (Exception ex) {
                    LOG.log(Level.SEVERE, "MLR prediction failed for instance " + i, ex);
                    return badRequest("Error while gettting predictions.", ex.getMessage());
                }
            }

            PredictionResponse response = new PredictionResponse();
            response.setPredictions(predictions);
            return Response.ok(response).build();
        } catch (Exception ex) {
            LOG.log(Level.SEVERE, "MLR prediction failed", ex);
            return Response.status(Response.Status.INTERNAL_SERVER_ERROR)
                    .entity(ex.getMessage()).build();
        }
    }

    /** Builds a 400 response whose entity is the error report for the given message and details. */
    private static Response badRequest(String message, String details) {
        return Response.status(Response.Status.BAD_REQUEST)
                .entity(ErrorReportFactory.badRequest(message, details))
                .build();
    }

    /** Java-serializes the model and returns the bytes Base64-encoded. */
    private static String serializeModel(WekaModel model) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        // Close the object stream before reading the buffer so everything is flushed.
        try (ObjectOutput out = new ObjectOutputStream(buffer)) {
            out.writeObject(model);
        }
        return Base64.getEncoder().encodeToString(buffer.toByteArray());
    }

    /** Decodes a Base64 string produced by {@link #serializeModel} back into a model. */
    private static WekaModel deserializeModel(String base64Model)
            throws IOException, ClassNotFoundException {
        byte[] modelBytes = Base64.getDecoder().decode(base64Model);
        // SECURITY NOTE(review): this Java-deserializes bytes taken from the request.
        // Native deserialization of untrusted data can execute arbitrary code; if this
        // endpoint is reachable by untrusted clients, restrict the accepted classes
        // (e.g. an ObjectInputFilter on Java 9+) or switch to a safer model format.
        try (ObjectInput in = new ObjectInputStream(new ByteArrayInputStream(modelBytes))) {
            return (WekaModel) in.readObject();
        }
    }
}