Example usage for weka.core Instances setClass

List of usage examples for weka.core Instances setClass

Introduction

This page collects example usages of the method weka.core.Instances.setClass, taken from open-source projects.

Prototype

public void setClass(Attribute att) 

Source Link

Document

Sets the class attribute.

Usage

From source file:org.opentox.jaqpot3.qsar.trainer.PLSTrainer.java

License:Open Source License

/**
 * Trains a PLS model on {@code data}.
 *
 * <p>The attribute named after {@code targetUri} becomes the class attribute of
 * {@code data}; the target feature is appended to {@code independentFeatures} when it
 * is not already listed there. A {@code PLSFilter} is configured from the trainer's
 * fields, wrapped in a {@code PLSClassifier} and evaluated on the training data; the
 * evaluation summary is stored with the actual model. One predicted feature is
 * published per PLS component.
 *
 * @param data training instances; must contain an attribute named after the target URI
 * @return the model describing the trained PLS filter
 * @throws JaqpotException if the target attribute is missing from {@code data}, or if
 *         building or evaluating the classifier fails
 */
@Override
public Model train(Instances data) throws JaqpotException {
    Model model = new Model(Configuration.getBaseUri().augment("model", getUuid().toString()));

    final String targetName = targetUri.toString();
    // Fail with a meaningful error instead of the NullPointerException that
    // setClass(null) would raise when the target is absent from the dataset.
    if (data.attribute(targetName) == null) {
        throw new JaqpotException(new IllegalArgumentException(
                "The target feature " + targetName + " was not found in the dataset"));
    }
    data.setClass(data.attribute(targetName));

    // In PLS the target is kept among the independent features.
    boolean targetUriIncluded = false;
    for (Feature candidate : independentFeatures) {
        if (StringUtils.equals(candidate.getUri().toString(), targetName)) {
            targetUriIncluded = true;
            break;
        }
    }
    if (!targetUriIncluded) {
        independentFeatures.add(new Feature(targetUri));
    }
    model.setIndependentFeatures(independentFeatures);

    /*
     * Train the PLS filter
     */
    PLSFilter pls = new PLSFilter();
    try {
        pls.setInputFormat(data);
        pls.setOptions(new String[] { "-C", Integer.toString(numComponents), "-A", pls_algorithm, "-P",
                preprocessing, "-U", doUpdateClass });
        // The filtered output is discarded; the call is made for its effect on the filter.
        PLSFilter.useFilter(data, pls);
    } catch (Exception ex) {
        // NOTE(review): a failure here is only logged and training continues with the
        // filter in whatever state it was left - confirm this is intended.
        Logger.getLogger(PLSTrainer.class.getName()).log(Level.SEVERE, null, ex);
    }

    PLSModel actualModel = new PLSModel(pls);
    try {
        PLSClassifier cls = new PLSClassifier();
        cls.setFilter(pls);
        cls.buildClassifier(data);

        // evaluate classifier and keep the summary statistics with the model
        Evaluation eval = new Evaluation(data);
        eval.evaluateModel(cls, data);
        String stats = eval.toSummaryString("", false);

        ActualModel am = new ActualModel(actualModel);
        am.setStatistics(stats);

        model.setActualModel(am);
    } catch (Exception ex) {
        // Covers serialization problems as well as any training/evaluation failure.
        Logger.getLogger(PLSTrainer.class.getName()).log(Level.SEVERE, null, ex);
        throw new JaqpotException(ex);
    }

    model.setDataset(datasetUri);
    model.setAlgorithm(Algorithms.plsFilter());
    model.getMeta().addTitle("PLS Model for " + datasetUri);

    // Record the training parameters on the model.
    Set<Parameter> parameters = new HashSet<Parameter>();
    Parameter targetPrm = new Parameter(Configuration.getBaseUri().augment("parameter", RANDOM.nextLong()),
            "target", new LiteralValue(targetName, XSDDatatype.XSDstring))
                    .setScope(Parameter.ParameterScope.MANDATORY);
    Parameter nComponentsPrm = new Parameter(Configuration.getBaseUri().augment("parameter", RANDOM.nextLong()),
            "numComponents", new LiteralValue(numComponents, XSDDatatype.XSDpositiveInteger))
                    .setScope(Parameter.ParameterScope.MANDATORY);
    Parameter preprocessingPrm = new Parameter(
            Configuration.getBaseUri().augment("parameter", RANDOM.nextLong()), "preprocessing",
            new LiteralValue(preprocessing, XSDDatatype.XSDstring)).setScope(Parameter.ParameterScope.OPTIONAL);
    Parameter algorithmPrm = new Parameter(Configuration.getBaseUri().augment("parameter", RANDOM.nextLong()),
            "algorithm", new LiteralValue(pls_algorithm, XSDDatatype.XSDstring))
                    .setScope(Parameter.ParameterScope.OPTIONAL);
    Parameter doUpdatePrm = new Parameter(Configuration.getBaseUri().augment("parameter", RANDOM.nextLong()),
            "doUpdateClass", new LiteralValue(doUpdateClass, XSDDatatype.XSDboolean))
                    .setScope(Parameter.ParameterScope.OPTIONAL);

    parameters.add(targetPrm);
    parameters.add(nComponentsPrm);
    parameters.add(preprocessingPrm);
    parameters.add(doUpdatePrm);
    parameters.add(algorithmPrm);
    model.setParameters(parameters);

    // One predicted feature per PLS component.
    for (int i = 0; i < numComponents; i++) {
        Feature f = publishFeature(model, "", "PLS-" + i, datasetUri, featureService);
        model.addPredictedFeatures(f);
    }

    //save the instances being predicted to abstract trainer for calculating DoA
    predictedInstances = data;
    //in pls target is not excluded

    return model;
}

From source file:org.opentox.jaqpot3.qsar.trainer.SvmRegression.java

License:Open Source License

/**
 * Trains an SVM regression model on {@code data}.
 *
 * <p>The attribute named after {@code targetUri} must exist and be numeric. The
 * dataset is reordered so that the target is the last attribute, an {@code SVMreg}
 * regressor is configured with the selected kernel and trained on the reordered
 * data, and the resulting model is populated with its features, statistics and
 * parameters.
 *
 * @param data training instances
 * @return the trained model
 * @throws JaqpotException if the target is missing or not numeric, the kernel name is
 *         not one of {@code rbf}, {@code polynomial} or {@code linear}, the dataset
 *         cannot be reordered, the task cannot be updated, or training fails
 */
@Override
public Model train(Instances data) throws JaqpotException {
    try {
        Attribute target = data.attribute(targetUri.toString());
        if (target == null) {
            throw new QSARException("The prediction feature you provided was not found in the dataset");
        }
        if (!target.isNumeric()) {
            throw new QSARException("The prediction feature you provided is not numeric.");
        }
        data.setClass(target);

        /* Very important: place the target feature at the end! (target = last)*/
        int numAttributes = data.numAttributes();
        int classIndex = data.classIndex();
        List<String> properOrder = new ArrayList<String>(numAttributes);
        for (int j = 0; j < numAttributes; j++) {
            if (j != classIndex) {
                properOrder.add(data.attribute(j).name());
            }
        }
        properOrder.add(data.attribute(classIndex).name());

        // A reordering failure propagates to the caller; previously it was logged and
        // the null result dereferenced on the next line.
        Instances orderedTrainingSet = InstancesUtil.sortByFeatureAttrList(properOrder, data, -1);
        orderedTrainingSet.setClass(orderedTrainingSet.attribute(targetUri.toString()));

        getTask().getMeta()
                .addComment("Dataset successfully retrieved and converted into a weka.core.Instances object");
        UpdateTask firstTaskUpdater = new UpdateTask(getTask());
        firstTaskUpdater.setUpdateMeta(true);
        firstTaskUpdater.setUpdateTaskStatus(true);//TODO: Is this necessary?
        try {
            firstTaskUpdater.update();
        } catch (DbException ex) {
            throw new JaqpotException(ex);
        } finally {
            try {
                firstTaskUpdater.close();
            } catch (DbException ex) {
                throw new JaqpotException(ex);
            }
        }

        Model m = new Model(Configuration.getBaseUri().augment("model", getUuid().toString()));

        // INITIALIZE THE REGRESSOR
        // NOTE(review): 'cost' is recorded as a model parameter below but is not
        // passed to the regressor anywhere in this method - confirm whether it should be.
        SVMreg regressor = new SVMreg();
        final String[] regressorOptions = { "-P", Double.toString(epsilon), "-T", Double.toString(tolerance) };
        final Kernel svm_kernel;
        if (kernel.equalsIgnoreCase("rbf")) {
            RBFKernel rbf_kernel = new RBFKernel();
            rbf_kernel.setGamma(gamma);
            rbf_kernel.setCacheSize(cacheSize);
            svm_kernel = rbf_kernel;
        } else if (kernel.equalsIgnoreCase("polynomial")) {
            PolyKernel poly_kernel = new PolyKernel();
            poly_kernel.setExponent(degree);
            poly_kernel.setCacheSize(cacheSize);
            poly_kernel.setUseLowerOrder(true);
            svm_kernel = poly_kernel;
        } else if (kernel.equalsIgnoreCase("linear")) {
            // A linear kernel is a polynomial kernel of exponent 1.
            PolyKernel poly_kernel = new PolyKernel();
            poly_kernel.setExponent(1.0);
            poly_kernel.setCacheSize(cacheSize);
            poly_kernel.setUseLowerOrder(true);
            svm_kernel = poly_kernel;
        } else {
            // Previously an unknown name left the kernel null and training failed later
            // with an unrelated error.
            throw new QSARException("Unsupported kernel {" + kernel
                    + "}. Supported kernels are rbf, polynomial and linear.");
        }

        try {
            regressor.setOptions(regressorOptions);
        } catch (final Exception ex) {
            throw new QSARException("Bad options in SVM trainer for epsilon = {" + epsilon + "} or "
                    + "tolerance = {" + tolerance + "}.", ex);
        }
        regressor.setKernel(svm_kernel);

        // START TRAINING & CREATE MODEL
        try {
            regressor.buildClassifier(orderedTrainingSet);

            // evaluate classifier and keep the summary statistics with the model
            Evaluation eval = new Evaluation(orderedTrainingSet);
            eval.evaluateModel(regressor, orderedTrainingSet);
            String stats = eval.toSummaryString("", false);

            ActualModel am = new ActualModel(regressor);
            am.setStatistics(stats);
            m.setActualModel(am);
        } catch (NotSerializableException ex) {
            String message = "Model is not serializable";
            logger.error(message, ex);
            throw new JaqpotException(message, ex);
        } catch (final Exception ex) {
            throw new QSARException("Unexpected condition while trying to train "
                    + "the model. Possible explanation : {" + ex.getMessage() + "}", ex);
        }

        m.setAlgorithm(getAlgorithm());
        m.setCreatedBy(getTask().getCreatedBy());
        m.setDataset(datasetUri);
        try {
            dependentFeature.loadFromRemote();
        } catch (ServiceInvocationException ex) {
            java.util.logging.Logger.getLogger(SvmRegression.class.getName()).log(Level.SEVERE, null, ex);
        }
        // Registered once, after the remote load attempt (it used to be added twice).
        m.addDependentFeatures(dependentFeature);

        m.setIndependentFeatures(independentFeatures);

        Feature predictedFeature = publishFeature(m, dependentFeature.getUnits(),
                "Feature created as prediction feature for SVM model " + m.getUri(), datasetUri,
                featureService);
        m.addPredictedFeatures(predictedFeature);
        String predictionFeatureUri = predictedFeature.getUri().toString();

        getTask().getMeta().addComment("Prediction feature " + predictionFeatureUri + " was created.");

        /* SET PARAMETERS FOR THE TRAINED MODEL */
        m.setParameters(new HashSet<Parameter>());
        Parameter<String> kernelParam = new Parameter("kernel", new LiteralValue<String>(kernel))
                .setScope(Parameter.ParameterScope.OPTIONAL);
        kernelParam.setUri(Services.anonymous().augment("parameter", RANDOM.nextLong()));
        Parameter<Double> costParam = new Parameter("cost", new LiteralValue<Double>(cost))
                .setScope(Parameter.ParameterScope.OPTIONAL);
        costParam.setUri(Services.anonymous().augment("parameter", RANDOM.nextLong()));
        Parameter<Double> gammaParam = new Parameter("gamma", new LiteralValue<Double>(gamma))
                .setScope(Parameter.ParameterScope.OPTIONAL);
        gammaParam.setUri(Services.anonymous().augment("parameter", RANDOM.nextLong()));
        // NOTE(review): "espilon" looks like a misspelling of "epsilon"; left untouched
        // because stored models or clients may look the parameter up by this name.
        Parameter<Double> epsilonParam = new Parameter("espilon", new LiteralValue<Double>(epsilon))
                .setScope(Parameter.ParameterScope.OPTIONAL);
        epsilonParam.setUri(Services.anonymous().augment("parameter", RANDOM.nextLong()));
        Parameter<Integer> degreeParam = new Parameter("degree", new LiteralValue<Integer>(degree))
                .setScope(Parameter.ParameterScope.OPTIONAL);
        degreeParam.setUri(Services.anonymous().augment("parameter", RANDOM.nextLong()));
        Parameter<Double> toleranceParam = new Parameter("tolerance", new LiteralValue<Double>(tolerance))
                .setScope(Parameter.ParameterScope.OPTIONAL);
        toleranceParam.setUri(Services.anonymous().augment("parameter", RANDOM.nextLong()));

        m.getParameters().add(kernelParam);
        m.getParameters().add(costParam);
        m.getParameters().add(gammaParam);
        m.getParameters().add(epsilonParam);
        m.getParameters().add(degreeParam);
        m.getParameters().add(toleranceParam);

        //save the instances being predicted to abstract trainer for calculating DoA
        predictedInstances = orderedTrainingSet;
        excludeAttributesDoA.add(dependentFeature.getUri().toString());

        return m;
    } catch (QSARException ex) {
        logger.debug(null, ex);
        throw new JaqpotException(ex);
    }
}

From source file:org.opentox.qsar.processors.predictors.SimplePredictor.java

License:Open Source License

/**
 * Perform the prediction which is based on the serialized model file on the server.
 * @param data/*from  w  w w.  j  a  v a 2 s  .c o  m*/
 *      Input data for with respect to which the predicitons are calculated
 * @return
 *      A dataset containing the compounds submitted along with their predicted values.
 * @throws QSARException
 *      In case the prediction (as a whole) is not feasible. If the prediction is not
 *      feasible for a single instance, the prediction is set to <code>?</code> (unknown/undefined/missing).
 *      If the prediction is not feasible for all instances, an exception (QSARException) is thrown.
 */
@Override
public Instances predict(final Instances data) throws QSARException {

    Instances dataClone = new Instances(data);
    /**
     * IMPORTANT!
     * String attributes have to be removed from the dataset before
     * applying the prediciton
     */
    dataClone = new AttributeCleanup(ATTRIBUTE_TYPE.string).filter(dataClone);

    /**
     * Set the class attribute of the incoming data to any arbitrary attribute
     * (Choose the last for instance).
     */
    dataClone.setClass(dataClone.attribute(model.getDependentFeature().getURI()));

    /**
     *
     * Create the Instances that will host the predictions. This object contains
     * only two attributes: the compound_uri and the target feature of the model.
     */
    Instances predictions = null;
    FastVector attributes = new FastVector();
    final Attribute compoundAttribute = new Attribute("compound_uri", (FastVector) null);
    final Attribute targetAttribute = dataClone.classAttribute();
    attributes.addElement(compoundAttribute);
    attributes.addElement(targetAttribute);

    predictions = new Instances("predictions", attributes, 0);
    predictions.setClassIndex(1);

    Instance predictionInstance = new Instance(2);
    try {
        final Classifier cls = (Classifier) SerializationHelper.read(filePath);

        for (int i = 0; i < data.numInstances(); i++) {
            try {
                String currentCompound = data.instance(i).stringValue(0);
                predictionInstance.setValue(compoundAttribute, currentCompound);

                if (targetAttribute.type() == Attribute.NUMERIC) {
                    double clsLabel = cls.classifyInstance(dataClone.instance(i));
                    predictionInstance.setValue(targetAttribute, clsLabel);
                } else if (targetAttribute.type() == Attribute.NOMINAL) {
                    double[] clsLable = cls.distributionForInstance(dataClone.instance(i));
                    int indexForNominalElement = maxInArray(clsLable).getPosition();
                    Enumeration nominalValues = targetAttribute.enumerateValues();
                    int counter = 0;
                    String nomValue = "";
                    while (nominalValues.hasMoreElements()) {
                        if (counter == indexForNominalElement) {
                            nomValue = nominalValues.nextElement().toString();
                            break;
                        }
                        counter++;
                    }
                    predictionInstance.setValue(targetAttribute, nomValue);

                    predictionInstance.setValue(targetAttribute, cls.classifyInstance(dataClone.instance(i)));
                }

                predictions.add(predictionInstance);
            } catch (Exception ex) {
                System.out.println(ex);
            }
        }

    } catch (Exception ex) {
    }

    return predictions;
}

From source file:org.opentox.qsar.processors.trainers.classification.WekaClassifier.java

License:Open Source License

/**
 * Prepares a dataset for training a classification model.
 *
 * <p>String attributes are removed, the dataset is passed through
 * {@code SimpleMVHFilter}, and the attribute named by {@code predictionFeature} is
 * validated and set as the class attribute.
 *
 * @param data the training dataset; must not be {@code null}
 * @return the filtered dataset with its class attribute set
 * @throws NullPointerException if {@code data} is {@code null}
 * @throws QSARException if the prediction feature is not in the dataset, the dataset
 *         has no nominal attributes, the prediction feature is not nominal, or its
 *         range has fewer than two values
 */
@Override
public Instances preprocessData(Instances data) throws QSARException {
    if (data == null) {
        throw new NullPointerException("Cannot train a classification model without data");
    }

    /* The incoming dataset always has the first attribute set to
    'compound_uri' which is of type "String". This is removed at the
    begining of the training procedure */
    // NOTE: Removal of string attributes should be always performed prior to any kind of training!
    data = new AttributeCleanup(ATTRIBUTE_TYPE.string).filter(data);
    data = new SimpleMVHFilter().filter(data);

    // CHECK IF THE GIVEN URI IS AN ATTRIBUTE OF THE DATASET
    Attribute classAttribute = data.attribute(predictionFeature);
    if (classAttribute == null) {
        // The message used to claim the feature was "not a valid numeric attribute",
        // which is wrong for a classifier; the attribute is simply absent.
        throw new QSARException(Cause.XQReg202,
                "The prediction feature you provided is not included in the dataset :{"
                        + predictionFeature + "}");
    }

    // CHECK IF THE DATASET CONTAINS ANY NOMINAL ATTRIBUTES
    if (!data.checkForAttributeType(Attribute.NOMINAL)) {
        throw new QSARException(Cause.XQC4040, "Improper dataset! The dataset you provided has no "
                + "nominal features therefore classification models cannot be built.");
    }

    // CHECK WHETHER THE CLASS ATTRIBUTE IS NOMINAL
    if (!classAttribute.isNominal()) {
        // Suggest up to ten nominal features the client could use instead.
        StringBuilder nominalFeatures = new StringBuilder();
        int listed = 0;
        for (int i = 0; i < data.numAttributes() && listed < 10; i++) {
            if (data.attribute(i).isNominal()) {
                listed++;
                nominalFeatures.append(data.attribute(i).name()).append("\n");
            }
        }

        throw new QSARException(Cause.XQC4041,
                "The prediction feature you provided "
                        + "is not a nominal. Here is a list of some nominal features in the dataset you might "
                        + "be interested in :\n" + nominalFeatures.toString());
    }

    // CHECK IF THE RANGE OF THE CLASS ATTRIBUTE IS NON-UNARY
    // Counting the values also covers an empty range, which previously failed with
    // NoSuchElementException.
    final int numValues = classAttribute.numValues();
    if (numValues < 2) {
        final String singleton = numValues == 1 ? classAttribute.value(0) : "";
        throw new QSARException(Cause.XQC4042,
                "This classifier cannot handle unary nominal classes, that is "
                        + "nominal class attributes whose range includes only one value. Singleton value : {"
                        + singleton + "}");
    }

    // SET THE CLASS ATTRIBUTE OF THE DATASET
    data.setClass(classAttribute);

    return data;
}

From source file:org.opentox.qsar.processors.trainers.regression.WekaRegressor.java

License:Open Source License

/**
 * Prepares a dataset for training a regression model.
 *
 * <p>String attributes are stripped, the dataset is passed through
 * {@code SimpleMVHFilter}, and the attribute named by {@code predictionFeature} is
 * checked to exist and be numeric before it is set as the class attribute.
 *
 * @param data the training dataset; must not be {@code null}
 * @return the filtered dataset with its class attribute set
 * @throws NullPointerException if {@code data} is {@code null}
 * @throws QSARException if the prediction feature is absent from the dataset or is
 *         not numeric
 */
@Override
public Instances preprocessData(Instances data) throws QSARException {
    if (data == null) {
        throw new NullPointerException("Trainers do not accept null datasets.");
    }

    // The first attribute of an incoming dataset is the string-typed 'compound_uri';
    // string attributes must always be dropped before any kind of training.
    Instances cleaned = new AttributeCleanup(ATTRIBUTE_TYPE.string).filter(data);
    cleaned = new SimpleMVHFilter().filter(cleaned);

    // The prediction feature has to be part of the dataset; otherwise the error
    // message is completed with a hint built from the dataset.
    final Attribute classAttribute = cleaned.attribute(predictionFeature);
    if (classAttribute == null) {
        throw new QSARException(Cause.XQReg202,
                "The prediction feature you provided is is not included in the  dataset :{"
                        + predictionFeature + "}. " + attributeHint(cleaned));
    }

    // ... and it has to be numeric for regression.
    if (classAttribute.type() != Attribute.NUMERIC) {
        throw new QSARException(Cause.XQReg203,
                "The prediction feature you provided is not numeric : " + "{" + predictionFeature
                        + "}. " + attributeHint(cleaned));
    }

    cleaned.setClass(classAttribute);
    return cleaned;
}

From source file:probcog.J48Reader.java

License:Open Source License

/**
 * Reads a database file and converts its objects into a Weka dataset.
 *
 * <p>One nominal attribute is created for every attribute of the data dictionary's
 * {@code "object"} entry that is not discarded; {@code objectT} is always kept and
 * becomes the class attribute. Every database object that has an {@code objectT}
 * attribute yields one instance; values of attributes that are not part of the
 * header are ignored.
 *
 * @param dbname path of the database file
 * @return the dataset, with {@code objectT} as its class attribute
 * @throws IOException if the file cannot be read or closed
 * @throws IllegalStateException if the data dictionary defines no {@code objectT}
 *         attribute for {@code "object"}
 */
public static Instances readDB(String dbname)
        throws IOException, ClassNotFoundException, DDException, FileNotFoundException, Exception {
    final String classAttributeName = "objectT";

    // Close the stream once the database has been loaded; it used to be leaked.
    final FileInputStream in = new FileInputStream(dbname);
    final Database db;
    try {
        db = Database.fromFile(in);
    } finally {
        in.close();
    }
    probcog.srldb.datadict.DataDictionary dd = db.getDataDictionary();

    //the vector of attributes
    FastVector fvAttribs = new FastVector();
    HashMap<String, Attribute> mapAttrs = new HashMap<String, Attribute>();
    for (DDAttribute attribute : dd.getObject("object").getAttributes().values()) {
        // Discarded attributes are left out, except the class attribute.
        if (attribute.isDiscarded() && !attribute.getName().equals(classAttributeName)) {
            continue;
        }
        FastVector attValues = new FastVector();
        Domain dom = attribute.getDomain();
        for (String s : dom.getValues()) {
            attValues.addElement(s);
        }
        Attribute attr = new Attribute(attribute.getName(), attValues);
        fvAttribs.addElement(attr);
        mapAttrs.put(attribute.getName(), attr);
    }

    final Attribute classAttribute = mapAttrs.get(classAttributeName);
    if (classAttribute == null) {
        throw new IllegalStateException("The data dictionary of " + dbname
                + " defines no '" + classAttributeName + "' attribute for 'object'");
    }

    Instances instances = new Instances("name", fvAttribs, 10000);
    instances.setClass(classAttribute);

    //for each object add an instance
    for (Object o : db.getObjects()) {
        if (o.hasAttribute(classAttributeName)) {
            Instance instance = new Instance(fvAttribs.size());
            for (Entry<String, String> e : o.getAttributes().entrySet()) {
                // Copy exactly the values that have a column in the header. This keeps
                // the class value even when objectT is marked as discarded, and skips
                // attributes the header does not know about.
                Attribute attr = mapAttrs.get(e.getKey());
                if (attr != null) {
                    instance.setValue(attr, e.getValue());
                }
            }
            instances.add(instance);
        }
    }
    return instances;
}