Example usage for weka.core Instances setClass

List of usage examples for weka.core Instances setClass

Introduction

On this page you can find example usage of weka.core Instances setClass.

Prototype

public void setClass(Attribute att) 

Source Link

Document

Sets the class attribute.

Usage

From source file:org.jaqpot.algorithms.resource.WekaSVM.java

License:Open Source License

@POST
@Path("prediction")
public Response prediction(PredictionRequest request) {
    try {
        // Reject datasets with no rows, or whose first row carries no values.
        if (request.getDataset().getDataEntry().isEmpty()
                || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
            return Response.status(Response.Status.BAD_REQUEST)
                    .entity("Dataset is empty. Cannot make predictions on empty dataset.").build();
        }

        // SECURITY NOTE(review): Java-native deserialization of a client-supplied
        // payload is dangerous; consider an ObjectInputFilter or a safer format.
        String base64Model = (String) request.getRawModel();
        byte[] modelBytes = Base64.getDecoder().decode(base64Model);
        WekaModel model;
        // try-with-resources so the stream is always closed (the original leaked it).
        try (ObjectInput in = new ObjectInputStream(new ByteArrayInputStream(modelBytes))) {
            model = (WekaModel) in.readObject();
        }

        Classifier classifier = model.getClassifier();
        Instances data = InstanceUtils.createFromDataset(request.getDataset());
        String dependentFeature = (String) request.getAdditionalInfo();
        // Append the (empty) target attribute and mark it as the class attribute.
        data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes());
        data.setClass(data.attribute(dependentFeature));

        List<LinkedHashMap<String, Object>> predictions = new ArrayList<>();
        for (int i = 0; i < data.numInstances(); i++) {
            Instance instance = data.instance(i);
            try {
                double prediction = classifier.classifyInstance(instance);
                LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>();
                predictionMap.put("Weka SVM prediction of " + dependentFeature, prediction);
                predictions.add(predictionMap);
            } catch (Exception ex) {
                // FIX: the original logged under WekaMLR.class — copy/paste slip.
                Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex);
                return Response.status(Response.Status.BAD_REQUEST)
                        .entity("Error while getting predictions. " + ex.getMessage()).build();
            }
        }

        PredictionResponse response = new PredictionResponse();
        response.setPredictions(predictions);
        return Response.ok(response).build();
    } catch (Exception ex) {
        Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build();
    }
}

From source file:org.knime.knip.suise.node.boundarymodel.contourdata.WekaContourDataClassifier.java

License:Open Source License

private Instances initDataset(int numFeatures, int numClasses, int capacity) {
    // Numeric feature attributes "att0" .. "att<numFeatures-1>".
    ArrayList<Attribute> attrs = new ArrayList<Attribute>(numFeatures + 5);
    for (int f = 0; f < numFeatures; f++) {
        attrs.add(new Attribute("att" + f));
    }

    // Nominal class attribute with labels "class0" .. "class<numClasses-1>".
    ArrayList<String> labels = new ArrayList<String>();
    for (int c = 0; c < numClasses; c++) {
        labels.add("class" + c);
    }
    Attribute classAttribute = new Attribute("class", labels);
    attrs.add(classAttribute);

    // Empty dataset with the requested initial capacity; mark the class attribute.
    Instances dataset = new Instances("trainingData", attrs, capacity);
    dataset.setClass(classAttribute);
    return dataset;
}

From source file:org.mcennis.graphrat.algorithm.machinelearning.WekaClassifierMultiAttribute.java

License:Open Source License

@Override
public void execute(Graph g) {
    Actor[] source = g.getActor((String) parameter[1].getValue());
    if (source != null) {

        // Nominal value pool for the sourceID attribute (one entry per usable actor).
        FastVector sourceTypes = new FastVector();
        Actor[] dest = g.getActor((String) parameter[3].getValue());
        if (dest != null) {
            // create the Instances set backing this object
            Instances masterSet = null;
            Instance[] trainingData = new Instance[source.length];
            for (int i = 0; i < source.length; ++i) {
                // First, acquire the instance objects for each actor
                Property p = null;
                if ((Boolean) parameter[10].getValue()) {
                    p = source[i].getProperty((String) parameter[2].getValue() + g.getID());
                } else {
                    p = source[i].getProperty((String) parameter[2].getValue());
                }
                if (p != null) {
                    Object[] values = p.getValue();
                    if (values.length > 0) {
                        // FIX: the original added the ID a second time after
                        // masterSet.add(), duplicating the nominal value.
                        sourceTypes.addElement(source[i].getID());
                        trainingData[i] = (Instance) ((Instance) values[0]).copy();
                        // assume that this Instance has a backing dataset
                        // that contains all Instance objects to be tested
                        if (masterSet == null) {
                            masterSet = new Instances(trainingData[i].dataset(), source.length);
                        }
                        masterSet.add(trainingData[i]);
                    } else {
                        trainingData[i] = null;
                        Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.WARNING,
                                "Actor " + source[i].getType() + ":" + source[i].getID()
                                        + " does not have an Instance value of property ID " + p.getType());
                    }
                } else {
                    trainingData[i] = null;
                    // FIX: p is null on this branch, so the original p.getType() threw an
                    // NPE; log the requested property ID instead.
                    Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.WARNING,
                            "Actor " + source[i].getType() + ":" + source[i].getID()
                                    + " does not have a property of ID " + (String) parameter[2].getValue());
                }
            }

            // FIX: masterSet is still null when no actor carried instance data;
            // abort instead of throwing an NPE below.
            if (masterSet == null) {
                Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.WARNING,
                        "No source actor provided any Instance data - aborting");
                return;
            }

            // One boolean ("false"/"true") ground-truth attribute per destination actor.
            Vector<Attribute> destVector = new Vector<Attribute>();
            for (int i = 0; i < dest.length; ++i) {
                FastVector type = new FastVector();
                type.addElement("false");
                type.addElement("true");
                Attribute tmp = new Attribute(dest[i].getID(), type);
                destVector.add(tmp);
                masterSet.insertAttributeAt(tmp, masterSet.numAttributes());
            }
            Attribute sourceID = new Attribute("sourceID", sourceTypes);
            masterSet.insertAttributeAt(sourceID, masterSet.numAttributes());

            //set ground truth for evaluation
            for (int i = 0; i < masterSet.numInstances(); ++i) {
                Instance inst = masterSet.instance(i);
                // FIX: the original indexed parameter[i] with the loop variable; the
                // source mode is parameter[1] (as used to retrieve `source` above) —
                // confirm the intended parameter index.
                Actor user = g.getActor((String) parameter[1].getValue(),
                        sourceID.value((int) inst.value(sourceID)));
                if (user != null) {
                    for (int j = 0; j < dest.length; ++j) {
                        // FIX: ground truth belongs in the per-destination attribute;
                        // the original wrote every value into the sourceID attribute.
                        if (g.getLink((String) parameter[4].getValue(), user, dest[j]) != null) {
                            inst.setValue(destVector.get(j), "true");
                        } else {
                            if ((Boolean) parameter[9].getValue()) {
                                inst.setValue(destVector.get(j), "false");
                            } else {
                                inst.setValue(destVector.get(j), Double.NaN);
                            }
                        }
                    }
                } else {
                    Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.SEVERE,
                            "Actor " + sourceID.value((int) inst.value(sourceID)) + " does not exist in graph");
                }
            }

            // perform cross fold evaluation of each classifier in turn
            // NOTE(review): parameter[9] is read as a Boolean above but cast to String
            // here; confirm which parameter actually holds the classifier options.
            String[] opts = ((String) parameter[9].getValue()).split("\\s+");
            Properties props = new Properties();
            if ((Boolean) parameter[11].getValue()) {
                props.setProperty("LinkType", (String) parameter[5].getValue() + g.getID());
            } else {
                props.setProperty("LinkType", (String) parameter[5].getValue());
            }
            props.setProperty("LinkClass", "Basic");
            try {
                for (int destCount = 0; destCount < dest.length; ++destCount) {
                    masterSet.setClass(destVector.get(destCount));
                    for (int i = 0; i < (Integer) parameter[8].getValue(); ++i) {
                        Instances test = masterSet.testCV((Integer) parameter[8].getValue(), i);
                        // FIX: the original built the training split with testCV too,
                        // training and evaluating on the same fold.
                        Instances train = masterSet.trainCV((Integer) parameter[8].getValue(), i);
                        Classifier classifier = (Classifier) ((Class) parameter[7].getValue()).newInstance();
                        classifier.setOptions(opts);
                        classifier.buildClassifier(train);
                        for (int j = 0; j < test.numInstances(); ++j) {
                            String sourceName = sourceID.value((int) test.instance(j).value(sourceID));
                            double result = classifier.classifyInstance(test.instance(j));
                            String predicted = masterSet.classAttribute().value((int) result);
                            Link derived = LinkFactory.newInstance().create(props);
                            derived.set(g.getActor((String) parameter[2].getValue(), sourceName), 1.0,
                                    g.getActor((String) parameter[3].getValue(), predicted));
                            g.add(derived);
                        }
                    }
                }
            } catch (InstantiationException ex) {
                Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.SEVERE, null, ex);
            } catch (IllegalAccessException ex) {
                Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.SEVERE, null, ex);
            } catch (Exception ex) {
                Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.SEVERE, null, ex);
            }

        } else { // dest==null
            Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.WARNING,
                    "Ground truth mode '" + (String) parameter[3].getValue() + "' has no actors");
        }
    } else { // source==null
        Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.WARNING,
                "Source mode '" + (String) parameter[2].getValue() + "' has no actors");
    }
}

From source file:org.mcennis.graphrat.algorithm.machinelearning.WekaClassifierOneAttribute.java

License:Open Source License

@Override
public void execute(Graph g) {
    Actor[] source = g.getActor((String) parameter[1].getValue());
    if (source != null) {

        // Nominal value pools for the class and sourceID attributes.
        FastVector classTypes = new FastVector();
        FastVector sourceTypes = new FastVector();
        Actor[] dest = g.getActor((String) parameter[3].getValue());
        if (dest != null) {
            // Every destination actor ID is a possible class label.
            for (int i = 0; i < dest.length; ++i) {
                classTypes.addElement(dest[i].getID());
            }
            Attribute classAttribute = new Attribute((String) parameter[5].getValue(), classTypes);

            Instance[] trainingData = new Instance[source.length];
            Instances masterSet = null;
            for (int i = 0; i < source.length; ++i) {

                // First, acquire the instance objects for each actor
                Property p = null;
                if ((Boolean) parameter[9].getValue()) {
                    p = source[i].getProperty((String) parameter[2].getValue() + g.getID());
                } else {
                    p = source[i].getProperty((String) parameter[2].getValue());
                }
                if (p != null) {
                    Object[] values = p.getValue();
                    if (values.length > 0) {
                        sourceTypes.addElement(source[i].getID());
                        trainingData[i] = (Instance) ((Instance) values[0]).copy();
                        // assume that this Instance has a backing dataset
                        // that contains all Instance objects to be tested
                        if (masterSet == null) {
                            masterSet = new Instances(trainingData[i].dataset(), source.length);
                        }
                        masterSet.add(trainingData[i]);
                    } else {
                        trainingData[i] = null;
                        Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.WARNING,
                                "Actor " + source[i].getType() + ":" + source[i].getID()
                                        + " does not have an Instance value of property ID " + p.getType());
                    }
                } else {
                    trainingData[i] = null;
                    // FIX: p is null on this branch, so the original p.getType() threw an
                    // NPE; log the requested property ID instead.
                    Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.WARNING,
                            "Actor " + source[i].getType() + ":" + source[i].getID()
                                    + " does not have a property of ID " + (String) parameter[2].getValue());
                }

            } // for every actor, fix the instance

            // FIX: masterSet is still null when no actor carried instance data;
            // abort instead of throwing an NPE below.
            if (masterSet == null) {
                Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.WARNING,
                        "No source actor provided any Instance data - aborting");
                return;
            }
            Attribute sourceID = new Attribute("sourceID", sourceTypes);
            masterSet.insertAttributeAt(sourceID, masterSet.numAttributes());
            masterSet.insertAttributeAt(classAttribute, masterSet.numAttributes());
            masterSet.setClass(classAttribute);
            // NOTE(review): Instances.add() may store a copy of each Instance, in which
            // case these writes to trainingData[i] never reach masterSet — confirm
            // against the Weka version in use.
            for (int i = 0; i < source.length; ++i) {
                if (trainingData[i] != null) {
                    trainingData[i].setValue(sourceID, source[i].getID());
                    Link[] link = g.getLinkBySource((String) parameter[4].getValue(), source[i]);
                    if (link == null) {
                        trainingData[i].setClassValue(Double.NaN);
                    } else {
                        trainingData[i].setClassValue(link[0].getDestination().getID());
                    }
                }
            }

            // Cross-fold evaluation: train on the training fold, classify the test
            // fold, and materialise each prediction as a derived link in the graph.
            String[] opts = ((String) parameter[7].getValue()).split("\\s+");
            Properties props = new Properties();
            if ((Boolean) parameter[10].getValue()) {
                props.setProperty("LinkType", (String) parameter[5].getValue() + g.getID());
            } else {
                props.setProperty("LinkType", (String) parameter[5].getValue());
            }
            props.setProperty("LinkClass", "Basic");
            try {
                for (int i = 0; i < (Integer) parameter[8].getValue(); ++i) {
                    Instances test = masterSet.testCV((Integer) parameter[8].getValue(), i);
                    // FIX: the original built the training split with testCV too,
                    // training and evaluating on the same fold.
                    Instances train = masterSet.trainCV((Integer) parameter[8].getValue(), i);
                    Classifier classifier = (Classifier) ((Class) parameter[6].getValue()).newInstance();
                    classifier.setOptions(opts);
                    classifier.buildClassifier(train);
                    for (int j = 0; j < test.numInstances(); ++j) {
                        String sourceName = sourceID.value((int) test.instance(j).value(sourceID));
                        double result = classifier.classifyInstance(test.instance(j));
                        String predicted = masterSet.classAttribute().value((int) result);
                        Link derived = LinkFactory.newInstance().create(props);
                        derived.set(g.getActor((String) parameter[2].getValue(), sourceName), 1.0,
                                g.getActor((String) parameter[3].getValue(), predicted));
                        g.add(derived);
                    }
                }
            } catch (InstantiationException ex) {
                Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.SEVERE, null, ex);
            } catch (IllegalAccessException ex) {
                Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.SEVERE, null, ex);
            } catch (Exception ex) {
                Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.SEVERE, null, ex);
            }

        } else { // dest==null
            Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.WARNING,
                    "Ground truth mode '" + (String) parameter[3].getValue() + "' has no actors");
        }
    } else { // source==null
        Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.WARNING,
                "Source mode '" + (String) parameter[2].getValue() + "' has no actors");
    }
}

From source file:org.openml.webapplication.features.ExtractFeatures.java

License:Open Source License

public static List<Feature> getFeatures(Instances dataset, String defaultClass) {
    // Use the requested class attribute, or default to the last attribute.
    if (defaultClass != null) {
        dataset.setClass(dataset.attribute(defaultClass));
    } else {
        dataset.setClassIndex(dataset.numAttributes() - 1);
    }

    final ArrayList<Feature> resultFeatures = new ArrayList<Feature>();

    for (int i = 0; i < dataset.numAttributes(); i++) {
        Attribute att = dataset.attribute(i);
        // Per-class counts only make sense for a nominal class attribute.
        int numValues = dataset.classAttribute().isNominal() ? dataset.classAttribute().numValues() : 0;
        AttributeStatistics attributeStats = new AttributeStatistics(dataset.attribute(i), numValues);

        for (int j = 0; j < dataset.numInstances(); ++j) {
            attributeStats.addValue(dataset.get(j).value(i), dataset.get(j).classValue());
        }

        AttributeStats as = dataset.attributeStats(i);

        Integer numberOfDistinctValues = as.distinctCount;
        Integer numberOfUniqueValues = as.uniqueCount;
        // FIX: the original assigned missingCount twice; once is enough.
        Integer numberOfMissingValues = as.missingCount;
        Integer numberOfIntegerValues = as.intCount;
        Integer numberOfRealValues = as.realCount;

        Integer numberOfNominalValues = null;
        if (att.isNominal()) {
            numberOfNominalValues = att.numValues();
        }
        Integer numberOfValues = attributeStats.getTotalObservations();

        // min/max/mean/stddev only apply to numeric attributes.
        Double maximumValue = null;
        Double minimumValue = null;
        Double meanValue = null;
        Double standardDeviation = null;
        if (att.isNumeric()) {
            maximumValue = attributeStats.getMaximum();
            minimumValue = attributeStats.getMinimum();
            meanValue = attributeStats.getMean();
            standardDeviation = 0.0;
            try {
                standardDeviation = attributeStats.getStandardDeviation();
            } catch (Exception e) {
                Conversion.log("WARNING", "StdDev", "Could not compute standard deviation of feature "
                        + att.name() + ": " + e.getMessage());
            }
        }

        // Map the Weka attribute type onto the OpenML type name
        // (named constants instead of the original magic numbers 0/1/2).
        String data_type;
        switch (att.type()) {
        case Attribute.NUMERIC:
            data_type = "numeric";
            break;
        case Attribute.NOMINAL:
            data_type = "nominal";
            break;
        case Attribute.STRING:
            data_type = "string";
            break;
        default:
            data_type = "unknown";
            break;
        }

        resultFeatures.add(new Feature(att.index(), att.name(), data_type, att.index() == dataset.classIndex(),
                numberOfDistinctValues, numberOfUniqueValues, numberOfMissingValues, numberOfIntegerValues,
                numberOfRealValues, numberOfNominalValues, numberOfValues, maximumValue, minimumValue,
                meanValue, standardDeviation, attributeStats.getClassDistribution()));
    }
    return resultFeatures;
}

From source file:org.openml.webapplication.features.ExtractFeatures.java

License:Open Source License

public static List<Quality> getQualities(Instances dataset, String defaultClass) {
    // Use the requested class attribute, or default to the last attribute.
    if (defaultClass != null) {
        dataset.setClass(dataset.attribute(defaultClass));
    } else {
        dataset.setClassIndex(dataset.numAttributes() - 1);
    }
    List<Quality> result = new ArrayList<Quality>();
    Characterizer simpleQualities = new SimpleMetaFeatures();
    Map<String, Double> qualities = simpleQualities.characterize(dataset);
    // FIX: iterate entries instead of keySet + get (avoids one lookup per quality).
    for (Map.Entry<String, Double> quality : qualities.entrySet()) {
        result.add(new Quality(quality.getKey(), String.valueOf(quality.getValue())));
    }
    return result;
}

From source file:org.openml.webapplication.features.FantailConnector.java

License:Open Source License

private boolean extractFeatures(Integer did, Integer interval_size) throws Exception {
    Conversion.log("OK", "Extract Features", "Start extracting features for dataset: " + did);

    List<String> qualitiesAvailable = Arrays.asList(apiconnector.dataQualities(did).getQualityNames());

    // TODO: initialize this properly!!!!!!
    streamCharacterizers = new StreamCharacterizer[1];
    streamCharacterizers[0] = new ChangeDetectors(interval_size);

    DataSetDescription dsd = apiconnector.dataGet(did);

    Conversion.log("OK", "Extract Features", "Start downloading dataset: " + did);

    // FIX: close the reader when done (the original leaked it).
    // NOTE(review): FileReader uses the platform default charset; confirm the
    // dataset files are encoded accordingly (or pass an explicit charset).
    Instances dataset;
    try (FileReader datasetReader = new FileReader(dsd.getDataset(apiconnector.getApiKey()))) {
        dataset = new Instances(datasetReader);
    }

    dataset.setClass(dataset.attribute(dsd.getDefault_target_attribute()));
    // Drop the row-id attribute, if declared and present.
    if (dsd.getRow_id_attribute() != null) {
        if (dataset.attribute(dsd.getRow_id_attribute()) != null) {
            dataset.deleteAttributeAt(dataset.attribute(dsd.getRow_id_attribute()).index());
        }
    }
    // Drop any declared ignore-attributes that exist in the dataset.
    if (dsd.getIgnore_attribute() != null) {
        for (String att : dsd.getIgnore_attribute()) {
            if (dataset.attribute(att) != null) {
                dataset.deleteAttributeAt(dataset.attribute(att).index());
            }
        }
    }

    // first run stream characterizers
    for (StreamCharacterizer sc : streamCharacterizers) {

        if (qualitiesAvailable.containsAll(Arrays.asList(sc.getIDs())) == false) {
            Conversion.log("OK", "Extract Features", "Running Stream Characterizers (full data)");
            sc.characterize(dataset);
        } else {
            Conversion.log("OK", "Extract Features",
                    "Skipping Stream Characterizers (full data) - already in database");
        }
    }

    List<Quality> qualities = new ArrayList<DataQuality.Quality>();
    if (interval_size != null) {
        // Windowed mode: characterize each interval of interval_size instances.
        Conversion.log("OK", "Extract Features", "Running Batch Characterizers (partial data)");

        for (int i = 0; i < dataset.numInstances(); i += interval_size) {
            if (apiconnector.getVerboselevel() >= Constants.VERBOSE_LEVEL_ARFF) {
                Conversion.log("OK", "FantailConnector", "Starting window [" + i + "," + (i + interval_size)
                        + "> (did = " + did + ",total size = " + dataset.numInstances() + ")");
            }
            qualities.addAll(datasetCharacteristics(dataset, i, interval_size, null));

            for (StreamCharacterizer sc : streamCharacterizers) {
                qualities.addAll(hashMaptoList(sc.interval(i), i, interval_size));
            }
        }

    } else {
        // Full-data mode: characterize the complete dataset in one pass.
        Conversion.log("OK", "Extract Features",
                "Running Batch Characterizers (full data, might take a while)");
        qualities.addAll(datasetCharacteristics(dataset, null, null, qualitiesAvailable));
        for (StreamCharacterizer sc : streamCharacterizers) {
            Map<String, Double> streamqualities = sc.global();
            if (streamqualities != null) {
                qualities.addAll(hashMaptoList(streamqualities, null, null));
            }
        }
    }
    Conversion.log("OK", "Extract Features", "Done generating features, start wrapping up");
    DataQuality dq = new DataQuality(did, qualities.toArray(new Quality[qualities.size()]));
    String strQualities = xstream.toXML(dq);

    // Upload the serialized qualities to the OpenML server.
    DataQualityUpload dqu = apiconnector
            .dataQualitiesUpload(Conversion.stringToTempFile(strQualities, "qualities_did_" + did, "xml"));
    Conversion.log("OK", "Extract Features", "DONE: " + dqu.getDid());

    return true;
}

From source file:org.opentox.jaqpot3.qsar.predictor.FastRbfNnPredictor.java

License:Open Source License

@Override
public Instances predict(Instances inputSet) throws JaqpotException {
    FastRbfNnModel actualModel = (FastRbfNnModel) model.getActualModel().getSerializableActualModel();
    // Order the input attributes as the PMML model expects.
    Instances orderedDataset;
    try {
        orderedDataset = InstancesUtil.sortForPMMLModel(model.getIndependentFeatures(), trFieldsAttrIndex,
                inputSet, -1);
    } catch (JaqpotException ex) {
        // FIX: the original logged and fell through, guaranteeing an NPE on the
        // null dataset below; propagate the failure instead.
        logger.error(null, ex);
        throw ex;
    }

    // Copy the ordered data and append the prediction feature as the class attribute.
    Instances predictions = new Instances(orderedDataset);
    Add attributeAdder = new Add();
    attributeAdder.setAttributeIndex("last");
    attributeAdder.setAttributeName(model.getPredictedFeatures().iterator().next().getUri().toString());
    try {
        attributeAdder.setInputFormat(predictions);
        predictions = Filter.useFilter(predictions, attributeAdder);
        predictions.setClass(
                predictions.attribute(model.getPredictedFeatures().iterator().next().getUri().toString()));
    } catch (Exception ex) {
        String message = "Exception while trying to add prediction feature to Instances";
        logger.debug(message, ex);
        throw new JaqpotException(message, ex);
    }

    // RBF network evaluation: each prediction is the coefficient-weighted sum of
    // the RBF kernel between the instance and every model node.
    Instances nodes = actualModel.getNodes();
    double[] sigma = actualModel.getSigma();
    double[] coeffs = actualModel.getLrCoefficients();
    double sum;
    for (int i = 0; i < orderedDataset.numInstances(); i++) {
        sum = 0;
        for (int j = 0; j < nodes.numInstances(); j++) {
            sum += rbf(sigma[j], orderedDataset.instance(i), nodes.instance(j)) * coeffs[j];
        }
        predictions.instance(i).setClassValue(sum);
    }

    // Strip transformation fields and re-attach the compound identifiers.
    List<Integer> trFieldsIndex = WekaInstancesProcess.getTransformationFieldsAttrIndex(predictions,
            pmmlObject);
    predictions = WekaInstancesProcess.removeInstancesAttributes(predictions, trFieldsIndex);
    Instances resultSet = Instances.mergeInstances(justCompounds, predictions);

    return resultSet;

}

From source file:org.opentox.jaqpot3.qsar.predictor.WekaPredictor.java

License:Open Source License

@Override
public Instances predict(Instances inputSet) throws JaqpotException {

    /* THE OBJECT newData WILL HOST THE PREDICTIONS... */
    Instances sorted = InstancesUtil.sortForPMMLModel(model.getIndependentFeatures(), trFieldsAttrIndex,
            inputSet, -1);

    /* ADD TO THE NEW DATA THE PREDICTION FEATURE */
    String predictedFeatureUri = model.getPredictedFeatures().iterator().next().getUri().toString();
    Add adder = new Add();
    adder.setAttributeIndex("last");
    adder.setAttributeName(predictedFeatureUri);
    Instances predicted = null;
    try {
        adder.setInputFormat(sorted);
        predicted = Filter.useFilter(sorted, adder);
        // The freshly added prediction attribute becomes the class attribute.
        predicted.setClass(predicted.attribute(predictedFeatureUri));
    } catch (Exception ex) {
        String message = "Exception while trying to add prediction feature to Instances";
        logger.debug(message, ex);
        throw new JaqpotException(message, ex);
    }

    if (predicted != null) {
        Classifier classifier = (Classifier) model.getActualModel().getSerializableActualModel();

        // Fill the class attribute of every instance with the classifier's output.
        int total = predicted.numInstances();
        for (int idx = 0; idx < total; idx++) {
            try {
                double value = classifier.distributionForInstance(predicted.instance(idx))[0];
                predicted.instance(idx).setClassValue(value);
            } catch (Exception ex) {
                logger.warn("Prediction failed :-(", ex);
            }
        }
    }

    // Drop transformation fields and merge the compound identifiers back in.
    List<Integer> trFieldsIndex = WekaInstancesProcess.getTransformationFieldsAttrIndex(predicted,
            pmmlObject);
    predicted = WekaInstancesProcess.removeInstancesAttributes(predicted, trFieldsIndex);
    return Instances.mergeInstances(justCompounds, predicted);
}

From source file:org.opentox.jaqpot3.qsar.trainer.MlrRegression.java

License:Open Source License

/**
 * Trains a multiple linear regression (MLR) model on the given dataset.
 * <p>
 * The target feature (identified by {@code targetUri}) must exist in the
 * dataset and be numeric; it is moved to the last attribute position before
 * Weka's {@link LinearRegression} is built (options {@code -S 1 -C}: no
 * attribute selection, no collinearity elimination). Training statistics are
 * computed on the training set itself and stored alongside the serialized
 * classifier. Task meta/status updates are persisted at several checkpoints.
 *
 * @param data the training dataset as a weka.core.Instances object
 * @return the trained {@code Model} with actual model, features and metadata set
 * @throws JaqpotException if the target feature is missing/non-numeric, the
 *                         dataset cannot be ordered, the classifier cannot be
 *                         built, or a task update fails
 */
@Override
public Model train(Instances data) throws JaqpotException {
    try {
        getTask().getMeta().addComment(
                "Dataset successfully retrieved and converted " + "into a weka.core.Instances object");
        persistTaskMetaAndStatus();

        Instances trainingSet = data;
        getTask().getMeta().addComment("The downloaded dataset is now preprocessed");
        persistTaskMetaAndStatus();

        /* SET CLASS ATTRIBUTE */
        Attribute target = trainingSet.attribute(targetUri.toString());
        if (target == null) {
            throw new BadParameterException("The prediction feature you provided was not found in the dataset");
        } else {
            if (!target.isNumeric()) {
                throw new QSARException("The prediction feature you provided is not numeric.");
            }
        }
        trainingSet.setClass(target);
        /* Very important: place the target feature at the end! (target = last)*/
        int numAttributes = trainingSet.numAttributes();
        int classIndex = trainingSet.classIndex();
        Instances orderedTrainingSet = null;
        List<String> properOrder = new ArrayList<String>(numAttributes);
        for (int j = 0; j < numAttributes; j++) {
            if (j != classIndex) {
                properOrder.add(trainingSet.attribute(j).name());
            }
        }
        properOrder.add(trainingSet.attribute(classIndex).name());
        try {
            orderedTrainingSet = InstancesUtil.sortByFeatureAttrList(properOrder, trainingSet, -1);
        } catch (JaqpotException ex) {
            logger.error("Improper dataset - training will stop", ex);
            throw ex;
        }
        orderedTrainingSet.setClass(orderedTrainingSet.attribute(targetUri.toString()));

        /* START CONSTRUCTION OF MODEL */
        Model m = new Model(Configuration.getBaseUri().augment("model", getUuid().toString()));
        m.setAlgorithm(getAlgorithm());
        m.setCreatedBy(getTask().getCreatedBy());
        m.setDataset(datasetUri);
        m.addDependentFeatures(dependentFeature);
        try {
            dependentFeature.loadFromRemote();
        } catch (ServiceInvocationException ex) {
            // Non-fatal: we only lose the human-readable feature title below.
            Logger.getLogger(MlrRegression.class.getName()).log(Level.SEVERE, null, ex);
        }

        Set<LiteralValue> depFeatTitles = null;
        if (dependentFeature.getMeta() != null) {
            depFeatTitles = dependentFeature.getMeta().getTitles();
        }

        /* Prefer the feature's declared title; fall back to its URI. */
        String depFeatTitle = dependentFeature.getUri().toString();
        if (depFeatTitles != null) {
            depFeatTitle = depFeatTitles.iterator().next().getValueAsString();
            m.getMeta().addTitle("MLR model for " + depFeatTitle)
                    .addDescription("MLR model for the prediction of " + depFeatTitle + " (uri: "
                            + dependentFeature.getUri() + " ).");
        } else {
            m.getMeta().addTitle("MLR model for the prediction of the feature with URI " + depFeatTitle)
                    .addComment("No name was found for the feature " + depFeatTitle);
        }

        /*
         * COMPILE THE LIST OF INDEPENDENT FEATURES with the exact order in which
         * these appear in the Instances object (training set).
         */
        m.setIndependentFeatures(independentFeatures);

        /* CREATE PREDICTED FEATURE AND POST IT TO REMOTE SERVER */
        Feature predictedFeature = publishFeature(m, dependentFeature.getUnits(),
                "Predicted " + depFeatTitle + " by MLR model", datasetUri, featureService);
        m.addPredictedFeatures(predictedFeature);
        String predictionFeatureUri = predictedFeature.getUri().toString();

        getTask().getMeta().addComment("Prediction feature " + predictionFeatureUri + " was created.");
        persistTaskMetaAndStatus();

        /* ACTUAL TRAINING OF THE MODEL USING WEKA */
        LinearRegression linreg = new LinearRegression();
        String[] linRegOptions = { "-S", "1", "-C" };

        try {
            linreg.setOptions(linRegOptions);
            linreg.buildClassifier(orderedTrainingSet);
        } catch (final Exception ex) {// illegal options or could not build the classifier!
            String message = "MLR Model could not be trained";
            logger.error(message, ex);
            throw new JaqpotException(message, ex);
        }

        try {
            // evaluate classifier and print some statistics
            Evaluation eval = new Evaluation(orderedTrainingSet);
            eval.evaluateModel(linreg, orderedTrainingSet);
            String stats = eval.toSummaryString("\nResults\n======\n", false);

            ActualModel am = new ActualModel(linreg);
            am.setStatistics(stats);
            m.setActualModel(am);
        } catch (NotSerializableException ex) {
            String message = "Model is not serializable";
            logger.error(message, ex);
            throw new JaqpotException(message, ex);
        } catch (final Exception ex) {// illegal options or could not build the classifier!
            String message = "MLR Model could not be trained";
            logger.error(message, ex);
            throw new JaqpotException(message, ex);
        }

        m.getMeta().addPublisher("OpenTox").addComment("This is a Multiple Linear Regression Model");

        //save the instances being predicted to abstract trainer for calculating DoA
        predictedInstances = orderedTrainingSet;
        excludeAttributesDoA.add(dependentFeature.getUri().toString());

        return m;
    } catch (QSARException ex) {
        String message = "QSAR Exception: cannot train MLR model";
        logger.error(message, ex);
        throw new JaqpotException(message, ex);
    }
}

/**
 * Persists the current task's metadata and status to the database,
 * always closing the updater. Extracted from three verbatim copies of
 * this update/close boilerplate that previously lived inline in
 * {@link #train(Instances)}.
 *
 * @throws JaqpotException if either the update or the close fails
 *                         (a failing close masks an update failure,
 *                         matching the original inline behavior)
 */
private void persistTaskMetaAndStatus() throws JaqpotException {
    UpdateTask taskUpdater = new UpdateTask(getTask());
    taskUpdater.setUpdateMeta(true);
    taskUpdater.setUpdateTaskStatus(true);//TODO: Is this necessary?
    try {
        taskUpdater.update();
    } catch (DbException ex) {
        throw new JaqpotException(ex);
    } finally {
        try {
            taskUpdater.close();
        } catch (DbException ex) {
            throw new JaqpotException(ex);
        }
    }
}