Java tutorial
/* * * Jaqpot - version 3 * * The JAQPOT-3 web services are OpenTox API-1.2 compliant web services. Jaqpot * is a web application that supports model training and data preprocessing algorithms * such as multiple linear regression, support vector machines, neural networks * (an in-house implementation based on an efficient algorithm), an implementation * of the leverage algorithm for domain of applicability estimation and various * data preprocessing algorithms like PLS and data cleanup. * * Copyright (C) 2009-2012 Pantelis Sopasakis & Charalampos Chomenides * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. * * Contact: * Pantelis Sopasakis * chvng@mail.ntua.gr * Address: Iroon Politechniou St. 9, Zografou, Athens Greece * tel. +30 210 7723236 * */ package org.opentox.jaqpot3.qsar.trainer; import java.io.NotSerializableException; import java.net.URISyntaxException; import java.util.ArrayList; import static java.util.Arrays.asList; import java.util.List; import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; import org.opentox.jaqpot3.exception.JaqpotException; import org.opentox.jaqpot3.qsar.AbstractTrainer; import org.opentox.jaqpot3.qsar.IClientInput; import org.opentox.jaqpot3.qsar.ITrainer; import org.opentox.jaqpot3.qsar.InstancesUtil; import org.opentox.jaqpot3.qsar.exceptions.BadParameterException; import org.opentox.jaqpot3.qsar.exceptions.QSARException; import org.opentox.jaqpot3.resources.collections.Algorithms; import org.opentox.jaqpot3.util.Configuration; import org.opentox.toxotis.client.VRI; import org.opentox.toxotis.client.collection.Services; import org.opentox.toxotis.core.component.ActualModel; import org.opentox.toxotis.core.component.Algorithm; import org.opentox.toxotis.core.component.Feature; import org.opentox.toxotis.core.component.Model; import org.opentox.toxotis.database.engine.task.UpdateTask; import org.opentox.toxotis.database.exception.DbException; import org.opentox.toxotis.exceptions.impl.ServiceInvocationException; import org.opentox.toxotis.ontology.LiteralValue; import weka.classifiers.Evaluation; import weka.core.Attribute; import weka.core.Instances; import weka.classifiers.functions.LinearRegression; /** * * @author Pantelis Sopasakis * @author Charalampos Chomenides */ public class MlrRegression extends AbstractTrainer { private VRI targetUri; private VRI datasetUri; private VRI featureService; private org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(MlrRegression.class); @Override protected boolean keepNumeric() { return true; } @Override protected boolean keepNominal() { return true; } @Override protected boolean keepString() { return false; } @Override protected boolean pmmlSupported() { return true; } @Override protected boolean scalingSupported() { return true; } @Override protected boolean normalizationSupported() { return true; } @Override protected boolean DoASupported() { return true; } @Override protected boolean performMVH() { return true; } public MlrRegression() { } @Override public ITrainer doParametrize(IClientInput clientParameters) throws BadParameterException { String targetString = clientParameters.getFirstValue("prediction_feature"); if (targetString == null) { throw new BadParameterException("The parameter 'prediction_feature' is mandatory for this algorithm."); } try { targetUri = new VRI(targetString); } catch (URISyntaxException ex) { throw new BadParameterException("The parameter 'prediction_feature' you provided is not a valid URI.", ex); } String datasetUriString = clientParameters.getFirstValue("dataset_uri"); if (datasetUriString == null) { throw new BadParameterException("The parameter 'dataset_uri' is mandatory for this algorithm."); } try { datasetUri = new VRI(datasetUriString); } catch (URISyntaxException ex) { throw new BadParameterException("The parameter 'dataset_uri' you provided is not a valid URI.", ex); } String featureServiceString = clientParameters.getFirstValue("feature_service"); if (featureServiceString != null) { try { featureService = new VRI(featureServiceString); } catch (URISyntaxException ex) { throw new BadParameterException("The parameter 'feature_service' you provided is not a valid URI.", ex); } } else { featureService = Services.ideaconsult().augment("feature"); } return this; } @Override public Algorithm getAlgorithm() { return Algorithms.mlr(); } @Override public Model train(Instances data) throws JaqpotException { try { getTask().getMeta().addComment( "Dataset successfully retrieved and converted " + "into a weka.core.Instances object"); UpdateTask firstTaskUpdater = new UpdateTask(getTask()); firstTaskUpdater.setUpdateMeta(true); firstTaskUpdater.setUpdateTaskStatus(true);//TODO: Is this necessary? try { firstTaskUpdater.update(); } catch (DbException ex) { throw new JaqpotException(ex); } finally { try { firstTaskUpdater.close(); } catch (DbException ex) { throw new JaqpotException(ex); } } Instances trainingSet = data; getTask().getMeta().addComment("The downloaded dataset is now preprocessed"); firstTaskUpdater = new UpdateTask(getTask()); firstTaskUpdater.setUpdateMeta(true); firstTaskUpdater.setUpdateTaskStatus(true);//TODO: Is this necessary? try { firstTaskUpdater.update(); } catch (DbException ex) { throw new JaqpotException(ex); } finally { try { firstTaskUpdater.close(); } catch (DbException ex) { throw new JaqpotException(ex); } } /* SET CLASS ATTRIBUTE */ Attribute target = trainingSet.attribute(targetUri.toString()); if (target == null) { throw new BadParameterException("The prediction feature you provided was not found in the dataset"); } else { if (!target.isNumeric()) { throw new QSARException("The prediction feature you provided is not numeric."); } } trainingSet.setClass(target); /* Very important: place the target feature at the end! (target = last)*/ int numAttributes = trainingSet.numAttributes(); int classIndex = trainingSet.classIndex(); Instances orderedTrainingSet = null; List<String> properOrder = new ArrayList<String>(numAttributes); for (int j = 0; j < numAttributes; j++) { if (j != classIndex) { properOrder.add(trainingSet.attribute(j).name()); } } properOrder.add(trainingSet.attribute(classIndex).name()); try { orderedTrainingSet = InstancesUtil.sortByFeatureAttrList(properOrder, trainingSet, -1); } catch (JaqpotException ex) { logger.error("Improper dataset - training will stop", ex); throw ex; } orderedTrainingSet.setClass(orderedTrainingSet.attribute(targetUri.toString())); /* START CONSTRUCTION OF MODEL */ Model m = new Model(Configuration.getBaseUri().augment("model", getUuid().toString())); m.setAlgorithm(getAlgorithm()); m.setCreatedBy(getTask().getCreatedBy()); m.setDataset(datasetUri); m.addDependentFeatures(dependentFeature); try { dependentFeature.loadFromRemote(); } catch (ServiceInvocationException ex) { Logger.getLogger(MlrRegression.class.getName()).log(Level.SEVERE, null, ex); } Set<LiteralValue> depFeatTitles = null; if (dependentFeature.getMeta() != null) { depFeatTitles = dependentFeature.getMeta().getTitles(); } String depFeatTitle = dependentFeature.getUri().toString(); if (depFeatTitles != null) { depFeatTitle = depFeatTitles.iterator().next().getValueAsString(); m.getMeta().addTitle("MLR model for " + depFeatTitle) .addDescription("MLR model for the prediction of " + depFeatTitle + " (uri: " + dependentFeature.getUri() + " )."); } else { m.getMeta().addTitle("MLR model for the prediction of the feature with URI " + depFeatTitle) .addComment("No name was found for the feature " + depFeatTitle); } /* * COMPILE THE LIST OF INDEPENDENT FEATURES with the exact order in which * these appear in the Instances object (training set). */ m.setIndependentFeatures(independentFeatures); /* CREATE PREDICTED FEATURE AND POST IT TO REMOTE SERVER */ String predictionFeatureUri = null; Feature predictedFeature = publishFeature(m, dependentFeature.getUnits(), "Predicted " + depFeatTitle + " by MLR model", datasetUri, featureService); m.addPredictedFeatures(predictedFeature); predictionFeatureUri = predictedFeature.getUri().toString(); getTask().getMeta().addComment("Prediction feature " + predictionFeatureUri + " was created."); firstTaskUpdater = new UpdateTask(getTask()); firstTaskUpdater.setUpdateMeta(true); firstTaskUpdater.setUpdateTaskStatus(true);//TODO: Is this necessary? try { firstTaskUpdater.update(); } catch (DbException ex) { throw new JaqpotException(ex); } finally { try { firstTaskUpdater.close(); } catch (DbException ex) { throw new JaqpotException(ex); } } /* ACTUAL TRAINING OF THE MODEL USING WEKA */ LinearRegression linreg = new LinearRegression(); String[] linRegOptions = { "-S", "1", "-C" }; try { linreg.setOptions(linRegOptions); linreg.buildClassifier(orderedTrainingSet); } catch (final Exception ex) {// illegal options or could not build the classifier! String message = "MLR Model could not be trained"; logger.error(message, ex); throw new JaqpotException(message, ex); } try { // evaluate classifier and print some statistics Evaluation eval = new Evaluation(orderedTrainingSet); eval.evaluateModel(linreg, orderedTrainingSet); String stats = eval.toSummaryString("\nResults\n======\n", false); ActualModel am = new ActualModel(linreg); am.setStatistics(stats); m.setActualModel(am); } catch (NotSerializableException ex) { String message = "Model is not serializable"; logger.error(message, ex); throw new JaqpotException(message, ex); } catch (final Exception ex) {// illegal options or could not build the classifier! String message = "MLR Model could not be trained"; logger.error(message, ex); throw new JaqpotException(message, ex); } m.getMeta().addPublisher("OpenTox").addComment("This is a Multiple Linear Regression Model"); //save the instances being predicted to abstract trainer for calculating DoA predictedInstances = orderedTrainingSet; excludeAttributesDoA.add(dependentFeature.getUri().toString()); return m; } catch (QSARException ex) { String message = "QSAR Exception: cannot train MLR model"; logger.error(message, ex); throw new JaqpotException(message, ex); } } }