List of usage examples for weka.core.Instances.setClass
public void setClass(Attribute att)
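setClass(att) marks att as the class (target) attribute of the dataset; it is equivalent to setClassIndex(att.index()). Before the project examples, a minimal self-contained sketch of the call itself (assumes the Weka 3.7+ ArrayList-based API; all attribute and label names are illustrative):

import java.util.ArrayList;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instances;

public class SetClassDemo {
    public static void main(String[] args) {
        ArrayList<Attribute> attrs = new ArrayList<Attribute>();
        attrs.add(new Attribute("width"));   // numeric attribute
        attrs.add(new Attribute("height"));  // numeric attribute
        ArrayList<String> labels = new ArrayList<String>();
        labels.add("small");
        labels.add("large");
        Attribute label = new Attribute("label", labels); // nominal attribute
        attrs.add(label);

        Instances data = new Instances("demo", attrs, 0);
        data.setClass(label); // same effect as data.setClassIndex(data.numAttributes() - 1)

        DenseInstance inst = new DenseInstance(data.numAttributes());
        inst.setDataset(data);       // attach to the dataset so nominal lookups work
        inst.setValue(0, 2.0);
        inst.setValue(1, 3.0);
        inst.setClassValue("large"); // writes into whichever attribute setClass selected
        data.add(inst);
        System.out.println(data);
    }
}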
From source file:org.jaqpot.algorithms.resource.WekaSVM.java
License:Open Source License
@POST
@Path("prediction")
public Response prediction(PredictionRequest request) {
    try {
        if (request.getDataset().getDataEntry().isEmpty()
                || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
            return Response.status(Response.Status.BAD_REQUEST)
                    .entity("Dataset is empty. Cannot make predictions on empty dataset.").build();
        }
        String base64Model = (String) request.getRawModel();
        byte[] modelBytes = Base64.getDecoder().decode(base64Model);
        ByteArrayInputStream bais = new ByteArrayInputStream(modelBytes);
        ObjectInput in = new ObjectInputStream(bais);
        WekaModel model = (WekaModel) in.readObject();
        Classifier classifier = model.getClassifier();

        Instances data = InstanceUtils.createFromDataset(request.getDataset());
        String dependentFeature = (String) request.getAdditionalInfo();
        // Append the target attribute and mark it as the class.
        data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes());
        data.setClass(data.attribute(dependentFeature));

        List<LinkedHashMap<String, Object>> predictions = new ArrayList<>();
        for (int i = 0; i < data.numInstances(); i++) {
            Instance instance = data.instance(i);
            try {
                double prediction = classifier.classifyInstance(instance);
                LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>();
                predictionMap.put("Weka SVM prediction of " + dependentFeature, prediction);
                predictions.add(predictionMap);
            } catch (Exception ex) {
                // NOTE: the original logged against WekaMLR.class here, apparently
                // copied from the MLR resource; WekaSVM.class is the enclosing class.
                Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex);
                return Response.status(Response.Status.BAD_REQUEST)
                        .entity("Error while getting predictions. " + ex.getMessage()).build();
            }
        }

        PredictionResponse response = new PredictionResponse();
        response.setPredictions(predictions);
        return Response.ok(response).build();
    } catch (Exception ex) {
        Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build();
    }
}
From source file:org.knime.knip.suise.node.boundarymodel.contourdata.WekaContourDataClassifier.java
License:Open Source License
private Instances initDataset(int numFeatures, int numClasses, int capacity) {
    ArrayList<Attribute> attributes = new ArrayList<Attribute>(numFeatures + 5);
    for (int i = 0; i < numFeatures; i++) {
        attributes.add(new Attribute("att" + i));
    }
    ArrayList<String> classNames = new ArrayList<String>();
    for (int i = 0; i < numClasses; i++) {
        classNames.add("class" + i);
    }
    Attribute classAtt = new Attribute("class", classNames);
    attributes.add(classAtt);
    Instances res = new Instances("trainingData", attributes, capacity);
    res.setClass(classAtt);
    return res;
}
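A dataset built this way is ready for training as soon as rows are added, because the class attribute was fixed up front. A hypothetical caller might populate it like this (the feature values and label are illustrative; assumes import weka.core.DenseInstance and weka.core.Instances):

// Hypothetical usage of initDataset: 3 features, 2 classes, room for 100 rows.
Instances train = initDataset(3, 2, 100);
double[] featureValues = { 0.1, 0.7, 0.3 };
DenseInstance row = new DenseInstance(train.numAttributes());
row.setDataset(train); // attach so the nominal class label can be resolved
for (int i = 0; i < featureValues.length; i++) {
    row.setValue(i, featureValues[i]);
}
row.setClassValue("class0"); // one of the labels generated by initDataset
train.add(row);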
From source file:org.mcennis.graphrat.algorithm.machinelearning.WekaClassifierMultiAttribute.java
License:Open Source License
@Override
public void execute(Graph g) {
    Actor[] source = g.getActor((String) parameter[1].getValue());
    if (source != null) {
        // create the attributes for each artist
        FastVector sourceTypes = new FastVector();
        Actor[] dest = g.getActor((String) parameter[3].getValue());
        if (dest != null) {
            // create the Instances set backing this object
            Instances masterSet = null;
            Instance[] trainingData = new Instance[source.length];
            for (int i = 0; i < source.length; ++i) {
                // First, acquire the instance objects for each actor
                Property p = null;
                if ((Boolean) parameter[10].getValue()) {
                    p = source[i].getProperty((String) parameter[2].getValue() + g.getID());
                } else {
                    p = source[i].getProperty((String) parameter[2].getValue());
                }
                if (p != null) {
                    Object[] values = p.getValue();
                    if (values.length > 0) {
                        // (the original added source[i].getID() to sourceTypes twice here)
                        sourceTypes.addElement(source[i].getID());
                        trainingData[i] = (Instance) ((Instance) values[0]).copy();
                        // assume that this Instance has a backing dataset
                        // that contains all Instance objects to be tested
                        if (masterSet == null) {
                            masterSet = new Instances(trainingData[i].dataset(), source.length);
                        }
                        masterSet.add(trainingData[i]);
                    } else {
                        trainingData[i] = null;
                        Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.WARNING,
                                "Actor " + source[i].getType() + ":" + source[i].getID()
                                        + " does not have an Instance value of property ID " + p.getType());
                    }
                } else {
                    trainingData[i] = null;
                    // NOTE: the original logged p.getType() here, which throws a
                    // NullPointerException because p is null on this branch.
                    Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.WARNING,
                            "Actor " + source[i].getType() + ":" + source[i].getID()
                                    + " does not have a property of ID " + (String) parameter[2].getValue());
                }
            }
            Vector<Attribute> destVector = new Vector<Attribute>();
            for (int i = 0; i < dest.length; ++i) {
                FastVector type = new FastVector();
                type.addElement("false");
                type.addElement("true");
                Attribute tmp = new Attribute(dest[i].getID(), type);
                destVector.add(tmp);
                masterSet.insertAttributeAt(tmp, masterSet.numAttributes());
            }
            Attribute sourceID = new Attribute("sourceID", sourceTypes);
            masterSet.insertAttributeAt(sourceID, masterSet.numAttributes());
            // set ground truth for evaluation
            for (int i = 0; i < masterSet.numInstances(); ++i) {
                Instance inst = masterSet.instance(i);
                // NOTE: the original indexed parameter[i] here; parameter[1] (the source
                // mode, as used at the top of this method) appears to be what was intended.
                Actor user = g.getActor((String) parameter[1].getValue(),
                        sourceID.value((int) inst.value(sourceID)));
                if (user != null) {
                    for (int j = 0; j < dest.length; ++j) {
                        // NOTE: the original wrote "true"/"false"/NaN into sourceID, a nominal
                        // of actor IDs; the per-destination attribute destVector.get(j) is
                        // evidently the intended target.
                        if (g.getLink((String) parameter[4].getValue(), user, dest[j]) != null) {
                            inst.setValue(destVector.get(j), "true");
                        } else {
                            if ((Boolean) parameter[9].getValue()) {
                                inst.setValue(destVector.get(j), "false");
                            } else {
                                inst.setValue(destVector.get(j), Double.NaN);
                            }
                        }
                    }
                } else {
                    Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.SEVERE,
                            "Actor " + sourceID.value((int) inst.value(sourceID)) + " does not exist in graph");
                }
            }
            // perform cross fold evaluation of each classifier in turn
            String[] opts = ((String) parameter[9].getValue()).split("\\s+");
            Properties props = new Properties();
            if ((Boolean) parameter[11].getValue()) {
                props.setProperty("LinkType", (String) parameter[5].getValue() + g.getID());
            } else {
                props.setProperty("LinkType", (String) parameter[5].getValue());
            }
            props.setProperty("LinkClass", "Basic");
            try {
                for (int destCount = 0; destCount < dest.length; ++destCount) {
                    masterSet.setClass(destVector.get(destCount));
                    for (int i = 0; i < (Integer) parameter[8].getValue(); ++i) {
                        Instances test = masterSet.testCV((Integer) parameter[8].getValue(), i);
                        // The original built the training split with testCV as well, which
                        // trains and tests on the same fold; trainCV is the intended call.
                        Instances train = masterSet.trainCV((Integer) parameter[8].getValue(), i);
                        Classifier classifier = (Classifier) ((Class) parameter[7].getValue()).newInstance();
                        classifier.setOptions(opts);
                        classifier.buildClassifier(train);
                        for (int j = 0; j < test.numInstances(); ++j) {
                            String sourceName = sourceID.value((int) test.instance(j).value(sourceID));
                            double result = classifier.classifyInstance(test.instance(j));
                            String predicted = masterSet.classAttribute().value((int) result);
                            Link derived = LinkFactory.newInstance().create(props);
                            derived.set(g.getActor((String) parameter[2].getValue(), sourceName), 1.0,
                                    g.getActor((String) parameter[3].getValue(), predicted));
                            g.add(derived);
                        }
                    }
                }
            } catch (InstantiationException ex) {
                Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.SEVERE, null, ex);
            } catch (IllegalAccessException ex) {
                Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.SEVERE, null, ex);
            } catch (Exception ex) {
                Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.SEVERE, null, ex);
            }
        } else { // dest == null
            Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.WARNING,
                    "Ground truth mode '" + (String) parameter[3].getValue() + "' has no actors");
        }
    } else { // source == null
        Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.WARNING,
                "Source mode '" + (String) parameter[2].getValue() + "' has no actors");
    }
}
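As noted in the comments above, this example (and the one that follows) originally built its training split with testCV, so each fold trained and tested on the same instances. The canonical Weka cross-validation loop looks like this minimal sketch (classifier choice and names are illustrative; the dataset must already have its class set via setClass or setClassIndex):

import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Instances;

// Sketch: k-fold cross-validation over a class-labelled dataset.
public static void crossValidate(Instances data, int folds) throws Exception {
    data.randomize(new java.util.Random(42)); // shuffle before splitting
    for (int fold = 0; fold < folds; fold++) {
        Instances train = data.trainCV(folds, fold); // all instances outside this fold
        Instances test = data.testCV(folds, fold);   // the instances inside this fold
        Classifier classifier = new J48();           // any Weka classifier works here
        classifier.buildClassifier(train);
        for (int j = 0; j < test.numInstances(); j++) {
            double predicted = classifier.classifyInstance(test.instance(j));
            System.out.println("fold " + fold + ", instance " + j + " -> " + predicted);
        }
    }
}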
From source file:org.mcennis.graphrat.algorithm.machinelearning.WekaClassifierOneAttribute.java
License:Open Source License
@Override
public void execute(Graph g) {
    Actor[] source = g.getActor((String) parameter[1].getValue());
    if (source != null) {
        // create the Instance sets for each actor
        FastVector classTypes = new FastVector();
        FastVector sourceTypes = new FastVector();
        Actor[] dest = g.getActor((String) parameter[3].getValue());
        if (dest != null) {
            for (int i = 0; i < dest.length; ++i) {
                classTypes.addElement(dest[i].getID());
            }
            Attribute classAttribute = new Attribute((String) parameter[5].getValue(), classTypes);
            Instance[] trainingData = new Instance[source.length];
            Instances masterSet = null;
            for (int i = 0; i < source.length; ++i) {
                // First, acquire the instance objects for each actor
                Property p = null;
                if ((Boolean) parameter[9].getValue()) {
                    p = source[i].getProperty((String) parameter[2].getValue() + g.getID());
                } else {
                    p = source[i].getProperty((String) parameter[2].getValue());
                }
                if (p != null) {
                    Object[] values = p.getValue();
                    if (values.length > 0) {
                        sourceTypes.addElement(source[i].getID());
                        trainingData[i] = (Instance) ((Instance) values[0]).copy();
                        // assume that this Instance has a backing dataset
                        // that contains all Instance objects to be tested
                        if (masterSet == null) {
                            masterSet = new Instances(trainingData[i].dataset(), source.length);
                        }
                        masterSet.add(trainingData[i]);
                    } else {
                        trainingData[i] = null;
                        Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.WARNING,
                                "Actor " + source[i].getType() + ":" + source[i].getID()
                                        + " does not have an Instance value of property ID " + p.getType());
                    }
                } else {
                    trainingData[i] = null;
                    // NOTE: the original logged p.getType() here, which throws a
                    // NullPointerException because p is null on this branch.
                    Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.WARNING,
                            "Actor " + source[i].getType() + ":" + source[i].getID()
                                    + " does not have a property of ID " + (String) parameter[2].getValue());
                }
            }
            // for every actor, fix the instance
            Attribute sourceID = new Attribute("sourceID", sourceTypes);
            masterSet.insertAttributeAt(sourceID, masterSet.numAttributes());
            masterSet.insertAttributeAt(classAttribute, masterSet.numAttributes());
            masterSet.setClass(classAttribute);
            for (int i = 0; i < source.length; ++i) {
                if (trainingData[i] != null) {
                    trainingData[i].setValue(sourceID, source[i].getID());
                    Link[] link = g.getLinkBySource((String) parameter[4].getValue(), source[i]);
                    if (link == null) {
                        trainingData[i].setClassValue(Double.NaN);
                    } else {
                        trainingData[i].setClassValue(link[0].getDestination().getID());
                    }
                }
            }
            String[] opts = ((String) parameter[7].getValue()).split("\\s+");
            Properties props = new Properties();
            if ((Boolean) parameter[10].getValue()) {
                props.setProperty("LinkType", (String) parameter[5].getValue() + g.getID());
            } else {
                props.setProperty("LinkType", (String) parameter[5].getValue());
            }
            props.setProperty("LinkClass", "Basic");
            try {
                for (int i = 0; i < (Integer) parameter[8].getValue(); ++i) {
                    Instances test = masterSet.testCV((Integer) parameter[8].getValue(), i);
                    // The original used testCV for the training split as well; trainCV is
                    // the intended call so that training and test folds are disjoint.
                    Instances train = masterSet.trainCV((Integer) parameter[8].getValue(), i);
                    Classifier classifier = (Classifier) ((Class) parameter[6].getValue()).newInstance();
                    classifier.setOptions(opts);
                    classifier.buildClassifier(train);
                    for (int j = 0; j < test.numInstances(); ++j) {
                        String sourceName = sourceID.value((int) test.instance(j).value(sourceID));
                        double result = classifier.classifyInstance(test.instance(j));
                        String predicted = masterSet.classAttribute().value((int) result);
                        Link derived = LinkFactory.newInstance().create(props);
                        derived.set(g.getActor((String) parameter[2].getValue(), sourceName), 1.0,
                                g.getActor((String) parameter[3].getValue(), predicted));
                        g.add(derived);
                    }
                }
            } catch (InstantiationException ex) {
                Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.SEVERE, null, ex);
            } catch (IllegalAccessException ex) {
                Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.SEVERE, null, ex);
            } catch (Exception ex) {
                Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.SEVERE, null, ex);
            }
        } else { // dest == null
            Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.WARNING,
                    "Ground truth mode '" + (String) parameter[3].getValue() + "' has no actors");
        }
    } else { // source == null
        Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.WARNING,
                "Source mode '" + (String) parameter[2].getValue() + "' has no actors");
    }
}
From source file:org.openml.webapplication.features.ExtractFeatures.java
License:Open Source License
public static List<Feature> getFeatures(Instances dataset, String defaultClass) {
    if (defaultClass != null) {
        dataset.setClass(dataset.attribute(defaultClass));
    } else {
        dataset.setClassIndex(dataset.numAttributes() - 1);
    }
    final ArrayList<Feature> resultFeatures = new ArrayList<Feature>();
    for (int i = 0; i < dataset.numAttributes(); i++) {
        Attribute att = dataset.attribute(i);
        int numValues = dataset.classAttribute().isNominal() ? dataset.classAttribute().numValues() : 0;
        AttributeStatistics attributeStats = new AttributeStatistics(dataset.attribute(i), numValues);
        for (int j = 0; j < dataset.numInstances(); ++j) {
            attributeStats.addValue(dataset.get(j).value(i), dataset.get(j).classValue());
        }
        String data_type = null;
        Integer numberOfDistinctValues = null;
        Integer numberOfUniqueValues = null;
        Integer numberOfMissingValues = null;
        Integer numberOfIntegerValues = null;
        Integer numberOfRealValues = null;
        Integer numberOfNominalValues = null;
        Integer numberOfValues = null;
        Double maximumValue = null;
        Double minimumValue = null;
        Double meanValue = null;
        Double standardDeviation = null;

        AttributeStats as = dataset.attributeStats(i);
        numberOfDistinctValues = as.distinctCount;
        numberOfUniqueValues = as.uniqueCount;
        numberOfMissingValues = as.missingCount; // (assigned twice in the original)
        numberOfIntegerValues = as.intCount;
        numberOfRealValues = as.realCount;

        if (att.isNominal()) {
            numberOfNominalValues = att.numValues();
        }
        numberOfValues = attributeStats.getTotalObservations();

        if (att.isNumeric()) {
            maximumValue = attributeStats.getMaximum();
            minimumValue = attributeStats.getMinimum();
            meanValue = attributeStats.getMean();
            standardDeviation = 0.0;
            try {
                standardDeviation = attributeStats.getStandardDeviation();
            } catch (Exception e) {
                Conversion.log("WARNING", "StdDev", "Could not compute standard deviation of feature "
                        + att.name() + ": " + e.getMessage());
            }
        }

        // weka.core.Attribute's named type constants replace the raw 0/1/2 literals
        // compared against att.type() in the original.
        if (att.type() == Attribute.NUMERIC) {
            data_type = "numeric";
        } else if (att.type() == Attribute.NOMINAL) {
            data_type = "nominal";
        } else if (att.type() == Attribute.STRING) {
            data_type = "string";
        } else {
            data_type = "unknown";
        }

        resultFeatures.add(new Feature(att.index(), att.name(), data_type,
                att.index() == dataset.classIndex(), numberOfDistinctValues, numberOfUniqueValues,
                numberOfMissingValues, numberOfIntegerValues, numberOfRealValues, numberOfNominalValues,
                numberOfValues, maximumValue, minimumValue, meanValue, standardDeviation,
                attributeStats.getClassDistribution()));
    }
    return resultFeatures;
}
From source file:org.openml.webapplication.features.ExtractFeatures.java
License:Open Source License
public static List<Quality> getQualities(Instances dataset, String defaultClass) {
    if (defaultClass != null) {
        dataset.setClass(dataset.attribute(defaultClass));
    } else {
        dataset.setClassIndex(dataset.numAttributes() - 1);
    }
    List<Quality> result = new ArrayList<Quality>();
    Characterizer simpleQualities = new SimpleMetaFeatures();
    Map<String, Double> qualities = simpleQualities.characterize(dataset);
    for (String quality : qualities.keySet()) {
        result.add(new Quality(quality, qualities.get(quality) + ""));
    }
    return result;
}
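Both ExtractFeatures methods pass the result of dataset.attribute(defaultClass) straight into setClass. If no attribute has that name, attribute(...) returns null and setClass fails. A guarded variant, as a minimal sketch (the method name and exception message are illustrative):

import weka.core.Attribute;
import weka.core.Instances;

// Sketch: resolve a class attribute by name, falling back to the last
// attribute when no name is given, and failing loudly on an unknown name.
public static void resolveClass(Instances dataset, String defaultClass) {
    if (defaultClass == null) {
        dataset.setClassIndex(dataset.numAttributes() - 1); // last attribute as class
        return;
    }
    Attribute att = dataset.attribute(defaultClass); // null if no such attribute exists
    if (att == null) {
        throw new IllegalArgumentException("No attribute named '" + defaultClass + "' in dataset");
    }
    dataset.setClass(att); // equivalent to dataset.setClassIndex(att.index())
}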
From source file:org.openml.webapplication.features.FantailConnector.java
License:Open Source License
private boolean extractFeatures(Integer did, Integer interval_size) throws Exception {
    Conversion.log("OK", "Extract Features", "Start extracting features for dataset: " + did);
    List<String> qualitiesAvailable = Arrays.asList(apiconnector.dataQualities(did).getQualityNames());
    // TODO: initialize this properly!!!!!!
    streamCharacterizers = new StreamCharacterizer[1];
    streamCharacterizers[0] = new ChangeDetectors(interval_size);
    DataSetDescription dsd = apiconnector.dataGet(did);
    Conversion.log("OK", "Extract Features", "Start downloading dataset: " + did);
    Instances dataset = new Instances(new FileReader(dsd.getDataset(apiconnector.getApiKey())));
    dataset.setClass(dataset.attribute(dsd.getDefault_target_attribute()));
    if (dsd.getRow_id_attribute() != null) {
        if (dataset.attribute(dsd.getRow_id_attribute()) != null) {
            dataset.deleteAttributeAt(dataset.attribute(dsd.getRow_id_attribute()).index());
        }
    }
    if (dsd.getIgnore_attribute() != null) {
        for (String att : dsd.getIgnore_attribute()) {
            if (dataset.attribute(att) != null) {
                dataset.deleteAttributeAt(dataset.attribute(att).index());
            }
        }
    }
    // first run stream characterizers
    for (StreamCharacterizer sc : streamCharacterizers) {
        if (qualitiesAvailable.containsAll(Arrays.asList(sc.getIDs())) == false) {
            Conversion.log("OK", "Extract Features", "Running Stream Characterizers (full data)");
            sc.characterize(dataset);
        } else {
            Conversion.log("OK", "Extract Features",
                    "Skipping Stream Characterizers (full data) - already in database");
        }
    }
    List<Quality> qualities = new ArrayList<DataQuality.Quality>();
    if (interval_size != null) {
        Conversion.log("OK", "Extract Features", "Running Batch Characterizers (partial data)");
        for (int i = 0; i < dataset.numInstances(); i += interval_size) {
            if (apiconnector.getVerboselevel() >= Constants.VERBOSE_LEVEL_ARFF) {
                Conversion.log("OK", "FantailConnector", "Starting window [" + i + "," + (i + interval_size)
                        + "> (did = " + did + ", total size = " + dataset.numInstances() + ")");
            }
            qualities.addAll(datasetCharacteristics(dataset, i, interval_size, null));
            for (StreamCharacterizer sc : streamCharacterizers) {
                qualities.addAll(hashMaptoList(sc.interval(i), i, interval_size));
            }
        }
    } else {
        Conversion.log("OK", "Extract Features", "Running Batch Characterizers (full data, might take a while)");
        qualities.addAll(datasetCharacteristics(dataset, null, null, qualitiesAvailable));
        for (StreamCharacterizer sc : streamCharacterizers) {
            Map<String, Double> streamqualities = sc.global();
            if (streamqualities != null) {
                qualities.addAll(hashMaptoList(streamqualities, null, null));
            }
        }
    }
    Conversion.log("OK", "Extract Features", "Done generating features, start wrapping up");
    DataQuality dq = new DataQuality(did, qualities.toArray(new Quality[qualities.size()]));
    String strQualities = xstream.toXML(dq);
    DataQualityUpload dqu = apiconnector
            .dataQualitiesUpload(Conversion.stringToTempFile(strQualities, "qualities_did_" + did, "xml"));
    Conversion.log("OK", "Extract Features", "DONE: " + dqu.getDid());
    return true;
}
From source file:org.opentox.jaqpot3.qsar.predictor.FastRbfNnPredictor.java
License:Open Source License
@Override
public Instances predict(Instances inputSet) throws JaqpotException {
    FastRbfNnModel actualModel = (FastRbfNnModel) model.getActualModel().getSerializableActualModel();
    Instances orderedDataset = null;
    try {
        orderedDataset = InstancesUtil.sortForPMMLModel(model.getIndependentFeatures(), trFieldsAttrIndex,
                inputSet, -1);
    } catch (JaqpotException ex) {
        logger.error(null, ex);
    }
    Instances predictions = new Instances(orderedDataset);
    Add attributeAdder = new Add();
    attributeAdder.setAttributeIndex("last");
    attributeAdder.setAttributeName(model.getPredictedFeatures().iterator().next().getUri().toString());
    try {
        attributeAdder.setInputFormat(predictions);
        predictions = Filter.useFilter(predictions, attributeAdder);
        predictions.setClass(
                predictions.attribute(model.getPredictedFeatures().iterator().next().getUri().toString()));
    } catch (Exception ex) {
        String message = "Exception while trying to add prediction feature to Instances";
        logger.debug(message, ex);
        throw new JaqpotException(message, ex);
    }
    Instances nodes = actualModel.getNodes();
    double[] sigma = actualModel.getSigma();
    double[] coeffs = actualModel.getLrCoefficients();
    double sum;
    for (int i = 0; i < orderedDataset.numInstances(); i++) {
        sum = 0;
        for (int j = 0; j < nodes.numInstances(); j++) {
            sum += rbf(sigma[j], orderedDataset.instance(i), nodes.instance(j)) * coeffs[j];
        }
        predictions.instance(i).setClassValue(sum);
    }
    List<Integer> trFieldsIndex = WekaInstancesProcess.getTransformationFieldsAttrIndex(predictions, pmmlObject);
    predictions = WekaInstancesProcess.removeInstancesAttributes(predictions, trFieldsIndex);
    Instances resultSet = Instances.mergeInstances(justCompounds, predictions);
    return resultSet;
}
From source file:org.opentox.jaqpot3.qsar.predictor.WekaPredictor.java
License:Open Source License
@Override
public Instances predict(Instances inputSet) throws JaqpotException {
    /* THE OBJECT newData WILL HOST THE PREDICTIONS... */
    Instances newData = InstancesUtil.sortForPMMLModel(model.getIndependentFeatures(), trFieldsAttrIndex,
            inputSet, -1);
    /* ADD TO THE NEW DATA THE PREDICTION FEATURE */
    Add attributeAdder = new Add();
    attributeAdder.setAttributeIndex("last");
    attributeAdder.setAttributeName(model.getPredictedFeatures().iterator().next().getUri().toString());
    Instances predictions = null;
    try {
        attributeAdder.setInputFormat(newData);
        predictions = Filter.useFilter(newData, attributeAdder);
        predictions.setClass(
                predictions.attribute(model.getPredictedFeatures().iterator().next().getUri().toString()));
    } catch (Exception ex) {
        String message = "Exception while trying to add prediction feature to Instances";
        logger.debug(message, ex);
        throw new JaqpotException(message, ex);
    }
    if (predictions != null) {
        Classifier classifier = (Classifier) model.getActualModel().getSerializableActualModel();
        int numInstances = predictions.numInstances();
        for (int i = 0; i < numInstances; i++) {
            try {
                double predictionValue = classifier.distributionForInstance(predictions.instance(i))[0];
                predictions.instance(i).setClassValue(predictionValue);
            } catch (Exception ex) {
                logger.warn("Prediction failed :-(", ex);
            }
        }
    }
    List<Integer> trFieldsIndex = WekaInstancesProcess.getTransformationFieldsAttrIndex(predictions, pmmlObject);
    predictions = WekaInstancesProcess.removeInstancesAttributes(predictions, trFieldsIndex);
    Instances result = Instances.mergeInstances(justCompounds, predictions);
    return result;
}
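Both Jaqpot predictors above rely on the same Weka idiom: append a fresh attribute with the Add filter, mark it as the class with setClass, then write predictions into it via setClassValue. Stripped of the Jaqpot plumbing, the pattern is roughly this sketch (the attribute name parameter is illustrative):

import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Add;

// Sketch of the append-then-setClass idiom used by both predictors above.
public static Instances withPredictionColumn(Instances data, String name) throws Exception {
    Add add = new Add();
    add.setAttributeIndex("last");   // append at the end of the attribute list
    add.setAttributeName(name);      // e.g. the predicted feature's URI
    add.setInputFormat(data);
    Instances out = Filter.useFilter(data, add);
    out.setClass(out.attribute(name)); // the new column becomes the class attribute
    return out;                        // fill via out.instance(i).setClassValue(...)
}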
From source file:org.opentox.jaqpot3.qsar.trainer.MlrRegression.java
License:Open Source License
@Override
public Model train(Instances data) throws JaqpotException {
    try {
        getTask().getMeta().addComment(
                "Dataset successfully retrieved and converted into a weka.core.Instances object");
        UpdateTask firstTaskUpdater = new UpdateTask(getTask());
        firstTaskUpdater.setUpdateMeta(true);
        firstTaskUpdater.setUpdateTaskStatus(true); // TODO: Is this necessary?
        try {
            firstTaskUpdater.update();
        } catch (DbException ex) {
            throw new JaqpotException(ex);
        } finally {
            try {
                firstTaskUpdater.close();
            } catch (DbException ex) {
                throw new JaqpotException(ex);
            }
        }

        Instances trainingSet = data;
        getTask().getMeta().addComment("The downloaded dataset is now preprocessed");
        firstTaskUpdater = new UpdateTask(getTask());
        firstTaskUpdater.setUpdateMeta(true);
        firstTaskUpdater.setUpdateTaskStatus(true); // TODO: Is this necessary?
        try {
            firstTaskUpdater.update();
        } catch (DbException ex) {
            throw new JaqpotException(ex);
        } finally {
            try {
                firstTaskUpdater.close();
            } catch (DbException ex) {
                throw new JaqpotException(ex);
            }
        }

        /* SET CLASS ATTRIBUTE */
        Attribute target = trainingSet.attribute(targetUri.toString());
        if (target == null) {
            throw new BadParameterException("The prediction feature you provided was not found in the dataset");
        } else {
            if (!target.isNumeric()) {
                throw new QSARException("The prediction feature you provided is not numeric.");
            }
        }
        trainingSet.setClass(target);

        /* Very important: place the target feature at the end! (target = last) */
        int numAttributes = trainingSet.numAttributes();
        int classIndex = trainingSet.classIndex();
        Instances orderedTrainingSet = null;
        List<String> properOrder = new ArrayList<String>(numAttributes);
        for (int j = 0; j < numAttributes; j++) {
            if (j != classIndex) {
                properOrder.add(trainingSet.attribute(j).name());
            }
        }
        properOrder.add(trainingSet.attribute(classIndex).name());
        try {
            orderedTrainingSet = InstancesUtil.sortByFeatureAttrList(properOrder, trainingSet, -1);
        } catch (JaqpotException ex) {
            logger.error("Improper dataset - training will stop", ex);
            throw ex;
        }
        orderedTrainingSet.setClass(orderedTrainingSet.attribute(targetUri.toString()));

        /* START CONSTRUCTION OF MODEL */
        Model m = new Model(Configuration.getBaseUri().augment("model", getUuid().toString()));
        m.setAlgorithm(getAlgorithm());
        m.setCreatedBy(getTask().getCreatedBy());
        m.setDataset(datasetUri);
        m.addDependentFeatures(dependentFeature);
        try {
            dependentFeature.loadFromRemote();
        } catch (ServiceInvocationException ex) {
            Logger.getLogger(MlrRegression.class.getName()).log(Level.SEVERE, null, ex);
        }
        Set<LiteralValue> depFeatTitles = null;
        if (dependentFeature.getMeta() != null) {
            depFeatTitles = dependentFeature.getMeta().getTitles();
        }
        String depFeatTitle = dependentFeature.getUri().toString();
        if (depFeatTitles != null) {
            depFeatTitle = depFeatTitles.iterator().next().getValueAsString();
            m.getMeta().addTitle("MLR model for " + depFeatTitle)
                    .addDescription("MLR model for the prediction of " + depFeatTitle + " (uri: "
                            + dependentFeature.getUri() + " ).");
        } else {
            m.getMeta().addTitle("MLR model for the prediction of the feature with URI " + depFeatTitle)
                    .addComment("No name was found for the feature " + depFeatTitle);
        }

        /*
         * COMPILE THE LIST OF INDEPENDENT FEATURES with the exact order in which
         * these appear in the Instances object (training set).
         */
        m.setIndependentFeatures(independentFeatures);

        /* CREATE PREDICTED FEATURE AND POST IT TO REMOTE SERVER */
        String predictionFeatureUri = null;
        Feature predictedFeature = publishFeature(m, dependentFeature.getUnits(),
                "Predicted " + depFeatTitle + " by MLR model", datasetUri, featureService);
        m.addPredictedFeatures(predictedFeature);
        predictionFeatureUri = predictedFeature.getUri().toString();
        getTask().getMeta().addComment("Prediction feature " + predictionFeatureUri + " was created.");
        firstTaskUpdater = new UpdateTask(getTask());
        firstTaskUpdater.setUpdateMeta(true);
        firstTaskUpdater.setUpdateTaskStatus(true); // TODO: Is this necessary?
        try {
            firstTaskUpdater.update();
        } catch (DbException ex) {
            throw new JaqpotException(ex);
        } finally {
            try {
                firstTaskUpdater.close();
            } catch (DbException ex) {
                throw new JaqpotException(ex);
            }
        }

        /* ACTUAL TRAINING OF THE MODEL USING WEKA */
        LinearRegression linreg = new LinearRegression();
        String[] linRegOptions = { "-S", "1", "-C" };
        try {
            linreg.setOptions(linRegOptions);
            linreg.buildClassifier(orderedTrainingSet);
        } catch (final Exception ex) { // illegal options or could not build the classifier!
            String message = "MLR Model could not be trained";
            logger.error(message, ex);
            throw new JaqpotException(message, ex);
        }

        try {
            // evaluate classifier and print some statistics
            Evaluation eval = new Evaluation(orderedTrainingSet);
            eval.evaluateModel(linreg, orderedTrainingSet);
            String stats = eval.toSummaryString("\nResults\n======\n", false);
            ActualModel am = new ActualModel(linreg);
            am.setStatistics(stats);
            m.setActualModel(am);
        } catch (NotSerializableException ex) {
            String message = "Model is not serializable";
            logger.error(message, ex);
            throw new JaqpotException(message, ex);
        } catch (final Exception ex) { // illegal options or could not build the classifier!
            String message = "MLR Model could not be trained";
            logger.error(message, ex);
            throw new JaqpotException(message, ex);
        }

        m.getMeta().addPublisher("OpenTox").addComment("This is a Multiple Linear Regression Model");
        // save the instances being predicted to abstract trainer for calculating DoA
        predictedInstances = orderedTrainingSet;
        excludeAttributesDoA.add(dependentFeature.getUri().toString());
        return m;
    } catch (QSARException ex) {
        String message = "QSAR Exception: cannot train MLR model";
        logger.error(message, ex);
        throw new JaqpotException(message, ex);
    }
}
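The trainer above reorders attributes by hand (via the project's InstancesUtil) so the class ends up last. With plain Weka, the same move-class-to-last step can also be done with the Reorder filter; a minimal sketch, assuming a dataset whose class attribute is already set:

import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Reorder;

// Sketch: move the current class attribute to the last position using Reorder.
public static Instances classLast(Instances data) throws Exception {
    int c = data.classIndex() + 1; // Reorder uses 1-based attribute indices
    int n = data.numAttributes();
    StringBuilder spec = new StringBuilder();
    if (c > 1) spec.append("1-").append(c - 1).append(",");    // attributes before the class
    if (c < n) spec.append(c + 1).append("-last").append(","); // attributes after the class
    spec.append(c);                                            // class attribute goes last
    Reorder reorder = new Reorder();
    reorder.setAttributeIndices(spec.toString());
    reorder.setInputFormat(data);
    Instances out = Filter.useFilter(data, reorder);
    out.setClassIndex(out.numAttributes() - 1); // re-mark the class after filtering
    return out;
}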