org.jaqpot.algorithm.resource.WekaMLR.java Source code

Java tutorial

Introduction

Here is the source code for org.jaqpot.algorithm.resource.WekaMLR.java

Source

/*
 *
 * JAQPOT Quattro
 *
 * JAQPOT Quattro and the components shipped with it (web applications and beans)
 * are licensed by GPL v3 as specified hereafter. Additional components may ship
 * with some other licence as will be specified therein.
 *
 * Copyright (C) 2014-2015 KinkyDesign (Charalampos Chomenidis, Pantelis Sopasakis)
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * Source code:
 * The source code of JAQPOT Quattro is available on github at:
 * https://github.com/KinkyDesign/JaqpotQuattro
 * All source files of JAQPOT Quattro that are stored on github are licensed
 * with the aforementioned licence. 
 */
package org.jaqpot.algorithm.resource;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInput;
import java.io.ObjectInputStream;
import java.io.ObjectOutput;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import javax.ws.rs.Consumes;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.jaqpot.algorithm.model.WekaModel;
import org.jaqpot.algorithm.pmml.PmmlUtils;
import org.jaqpot.algorithm.weka.InstanceUtils;
import org.jaqpot.core.model.dto.jpdi.PredictionRequest;
import org.jaqpot.core.model.dto.jpdi.PredictionResponse;
import org.jaqpot.core.model.dto.jpdi.TrainingRequest;
import org.jaqpot.core.model.dto.jpdi.TrainingResponse;
import org.jaqpot.core.model.factory.ErrorReportFactory;
import weka.classifiers.Classifier;
import weka.classifiers.functions.LinearRegression;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;

/**
 * JAX-RS resource that trains and applies multiple linear regression (MLR)
 * models using Weka's {@link LinearRegression}.
 *
 * <p>Both endpoints consume and produce JSON. Training returns the fitted
 * model as a Base64 string of a Java-serialized {@link WekaModel}; prediction
 * expects that same string back together with the {@code additionalInfo}
 * list produced by training.</p>
 *
 * @author Charalampos Chomenidis
 * @author Pantelis Sopasakis
 */
@Path("mlr")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public class WekaMLR {

    private static final Logger LOG = Logger.getLogger(WekaMLR.class.getName());

    /** Prefix of the predicted-feature label; must be identical in training and prediction. */
    private static final String PREDICTED_FEATURE_PREFIX = "Weka MLR prediction of ";

    /**
     * Trains a linear regression model on the dataset carried by the request.
     *
     * @param request training request holding the dataset and the URI of the
     *                feature to be predicted
     * @return 200 with a {@link TrainingResponse} (Base64 model, PMML,
     *         independent features, additional info and predicted feature label);
     *         400 if the dataset is missing/empty or the prediction feature is
     *         missing or not declared among the dataset features;
     *         500 with the exception message on any other failure
     */
    @POST
    @Path("training")
    public Response training(TrainingRequest request) {

        try {
            if (request == null
                    || request.getDataset() == null
                    || request.getDataset().getDataEntry() == null
                    || request.getDataset().getDataEntry().isEmpty()
                    || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
                return badRequest("Dataset is empty", "Cannot train model on empty dataset");
            }

            final String predictionFeature = request.getPredictionFeature();
            if (predictionFeature == null) {
                return badRequest("Prediction feature is missing",
                        "Cannot train model without a prediction feature");
            }

            // Resolve the human-readable name up front so that an unknown feature
            // is reported as a client error before any training work is done.
            List<String> matchingNames = request.getDataset().getFeatures().stream()
                    .filter(f -> predictionFeature.equals(f.getURI()))
                    .map(f -> f.getName())
                    .limit(1)
                    .collect(Collectors.toList());
            if (matchingNames.isEmpty()) {
                return badRequest("Unknown prediction feature",
                        "Prediction feature " + predictionFeature + " is not among the dataset features");
            }
            String predictionFeatureName = matchingNames.get(0);

            // Feature identifiers are taken from the first data entry.
            List<String> features = new ArrayList<>(
                    request.getDataset().getDataEntry().get(0).getValues().keySet());

            Instances data = InstanceUtils.createFromDataset(request.getDataset(), predictionFeature);

            LinearRegression linreg = new LinearRegression();
            // Per the Weka documentation: "-S 1" disables attribute selection and
            // "-C" disables elimination of colinear attributes.
            String[] linRegOptions = { "-S", "1", "-C" };
            linreg.setOptions(linRegOptions);
            linreg.buildClassifier(data);

            WekaModel model = new WekaModel();
            model.setClassifier(linreg);

            String pmml = PmmlUtils.createRegressionModel(features, predictionFeature,
                    linreg.coefficients(), "MLR");

            // Close the object stream before reading the bytes so that everything
            // is guaranteed to have been flushed into the byte buffer.
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            try (ObjectOutput out = new ObjectOutputStream(baos)) {
                out.writeObject(model);
            }
            String base64Model = Base64.getEncoder().encodeToString(baos.toByteArray());

            List<String> independentFeatures = features.stream()
                    .filter(feature -> !feature.equals(predictionFeature))
                    .collect(Collectors.toList());

            TrainingResponse response = new TrainingResponse();
            response.setRawModel(base64Model);
            response.setIndependentFeatures(independentFeatures);
            response.setPmmlModel(pmml);
            // The prediction endpoint reads these two entries back by position.
            response.setAdditionalInfo(Arrays.asList(predictionFeature, predictionFeatureName));
            response.setPredictedFeatures(Arrays.asList(PREDICTED_FEATURE_PREFIX + predictionFeatureName));

            return Response.ok(response).build();
        } catch (Exception ex) {
            LOG.log(Level.SEVERE, "MLR training failed", ex);
            return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build();
        }
    }

    /**
     * Applies a previously trained model to every instance of the request dataset.
     *
     * @param request prediction request holding the dataset, the Base64 raw
     *                model and the additional info list
     *                ({@code [dependentFeatureUri, dependentFeatureName]})
     * @return 200 with a {@link PredictionResponse} containing one single-entry
     *         map per instance; 400 if the dataset is missing/empty, the raw
     *         model or additional info is malformed, or an instance cannot be
     *         classified; 500 with the exception message on any other failure
     */
    @POST
    @Path("prediction")
    public Response prediction(PredictionRequest request) {

        try {
            if (request == null
                    || request.getDataset() == null
                    || request.getDataset().getDataEntry() == null
                    || request.getDataset().getDataEntry().isEmpty()
                    || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
                return badRequest("Dataset is empty", "Cannot make predictions on empty dataset");
            }

            Object rawModel = request.getRawModel();
            if (!(rawModel instanceof String)) {
                return badRequest("Raw model is missing",
                        "Expected the raw model as a Base64 encoded string");
            }

            Object rawInfo = request.getAdditionalInfo();
            if (!(rawInfo instanceof List) || ((List<?>) rawInfo).size() < 2) {
                return badRequest("Additional info is missing",
                        "Expected additional info holding the dependent feature and its name");
            }
            List<?> additionalInfo = (List<?>) rawInfo;
            String dependentFeature = (String) additionalInfo.get(0);
            String dependentFeatureName = (String) additionalInfo.get(1);

            // SECURITY NOTE(review): this performs Java-native deserialization of a
            // payload taken from the request. If callers are not fully trusted,
            // restrict the accepted classes (e.g. with a serialization filter) or
            // switch to a safer model format.
            byte[] modelBytes = Base64.getDecoder().decode((String) rawModel);
            WekaModel model;
            try (ObjectInput in = new ObjectInputStream(new ByteArrayInputStream(modelBytes))) {
                model = (WekaModel) in.readObject();
            }

            Classifier classifier = model.getClassifier();
            Instances data = InstanceUtils.createFromDataset(request.getDataset());
            // Append an attribute for the dependent feature as the last column.
            data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes());

            List<LinkedHashMap<String, Object>> predictions = new ArrayList<>();
            for (int i = 0; i < data.numInstances(); i++) {
                Instance instance = data.instance(i);
                try {
                    double prediction = classifier.classifyInstance(instance);
                    LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>();
                    predictionMap.put(PREDICTED_FEATURE_PREFIX + dependentFeatureName, prediction);
                    predictions.add(predictionMap);
                } catch (Exception ex) {
                    // A single failing instance aborts the whole request.
                    LOG.log(Level.SEVERE, "MLR prediction failed for instance " + i, ex);
                    return badRequest("Error while gettting predictions.", ex.getMessage());
                }
            }

            PredictionResponse response = new PredictionResponse();
            response.setPredictions(predictions);
            return Response.ok(response).build();
        } catch (Exception ex) {
            LOG.log(Level.SEVERE, "MLR prediction failed", ex);
            return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build();
        }
    }

    /** Builds a 400 response whose entity is the error report for the given message and details. */
    private static Response badRequest(String message, String details) {
        return Response.status(Response.Status.BAD_REQUEST)
                .entity(ErrorReportFactory.badRequest(message, details))
                .build();
    }

}