List of usage examples for weka.core Instances setClass
public void setClass(Attribute att)
From source file:org.opentox.jaqpot3.qsar.trainer.PLSTrainer.java
License:Open Source License
@Override public Model train(Instances data) throws JaqpotException { Model model = new Model(Configuration.getBaseUri().augment("model", getUuid().toString())); data.setClass(data.attribute(targetUri.toString())); Boolean targetURIIncluded = false; for (Feature tempFeature : independentFeatures) { if (StringUtils.equals(tempFeature.getUri().toString(), targetUri.toString())) { targetURIIncluded = true;/*from w ww . j a v a2 s .com*/ break; } } if (!targetURIIncluded) { independentFeatures.add(new Feature(targetUri)); } model.setIndependentFeatures(independentFeatures); /* * Train the PLS filter */ PLSFilter pls = new PLSFilter(); try { pls.setInputFormat(data); pls.setOptions(new String[] { "-C", Integer.toString(numComponents), "-A", pls_algorithm, "-P", preprocessing, "-U", doUpdateClass }); PLSFilter.useFilter(data, pls); } catch (Exception ex) { Logger.getLogger(PLSTrainer.class.getName()).log(Level.SEVERE, null, ex); } PLSModel actualModel = new PLSModel(pls); try { PLSClassifier cls = new PLSClassifier(); cls.setFilter(pls); cls.buildClassifier(data); // evaluate classifier and print some statistics Evaluation eval = new Evaluation(data); eval.evaluateModel(cls, data); String stats = eval.toSummaryString("", false); ActualModel am = new ActualModel(actualModel); am.setStatistics(stats); model.setActualModel(am); } catch (NotSerializableException ex) { Logger.getLogger(PLSTrainer.class.getName()).log(Level.SEVERE, null, ex); throw new JaqpotException(ex); } catch (Exception ex) { Logger.getLogger(PLSTrainer.class.getName()).log(Level.SEVERE, null, ex); throw new JaqpotException(ex); } model.setDataset(datasetUri); model.setAlgorithm(Algorithms.plsFilter()); model.getMeta().addTitle("PLS Model for " + datasetUri); Set<Parameter> parameters = new HashSet<Parameter>(); Parameter targetPrm = new Parameter(Configuration.getBaseUri().augment("parameter", RANDOM.nextLong()), "target", new LiteralValue(targetUri.toString(), XSDDatatype.XSDstring)) .setScope(Parameter.ParameterScope.MANDATORY); Parameter nComponentsPrm = new Parameter(Configuration.getBaseUri().augment("parameter", RANDOM.nextLong()), "numComponents", new LiteralValue(numComponents, XSDDatatype.XSDpositiveInteger)) .setScope(Parameter.ParameterScope.MANDATORY); Parameter preprocessingPrm = new Parameter( Configuration.getBaseUri().augment("parameter", RANDOM.nextLong()), "preprocessing", new LiteralValue(preprocessing, XSDDatatype.XSDstring)).setScope(Parameter.ParameterScope.OPTIONAL); Parameter algorithmPrm = new Parameter(Configuration.getBaseUri().augment("parameter", RANDOM.nextLong()), "algorithm", new LiteralValue(pls_algorithm, XSDDatatype.XSDstring)) .setScope(Parameter.ParameterScope.OPTIONAL); Parameter doUpdatePrm = new Parameter(Configuration.getBaseUri().augment("parameter", RANDOM.nextLong()), "doUpdateClass", new LiteralValue(doUpdateClass, XSDDatatype.XSDboolean)) .setScope(Parameter.ParameterScope.OPTIONAL); parameters.add(targetPrm); parameters.add(nComponentsPrm); parameters.add(preprocessingPrm); parameters.add(doUpdatePrm); parameters.add(algorithmPrm); model.setParameters(parameters); for (int i = 0; i < numComponents; i++) { Feature f = publishFeature(model, "", "PLS-" + i, datasetUri, featureService); model.addPredictedFeatures(f); } //save the instances being predicted to abstract trainer for calculating DoA predictedInstances = data; //in pls target is not excluded return model; }
From source file:org.opentox.jaqpot3.qsar.trainer.SvmRegression.java
License:Open Source License
@Override public Model train(Instances data) throws JaqpotException { try {//from w w w .ja v a2s.c om Attribute target = data.attribute(targetUri.toString()); if (target == null) { throw new QSARException("The prediction feature you provided was not found in the dataset"); } else { if (!target.isNumeric()) { throw new QSARException("The prediction feature you provided is not numeric."); } } data.setClass(target); //data.deleteAttributeAt(0);//remove the first attribute, i.e. 'compound_uri' or 'URI' /* Very important: place the target feature at the end! (target = last)*/ int numAttributes = data.numAttributes(); int classIndex = data.classIndex(); Instances orderedTrainingSet = null; List<String> properOrder = new ArrayList<String>(numAttributes); for (int j = 0; j < numAttributes; j++) { if (j != classIndex) { properOrder.add(data.attribute(j).name()); } } properOrder.add(data.attribute(classIndex).name()); try { orderedTrainingSet = InstancesUtil.sortByFeatureAttrList(properOrder, data, -1); } catch (JaqpotException ex) { logger.error(null, ex); } orderedTrainingSet.setClass(orderedTrainingSet.attribute(targetUri.toString())); getTask().getMeta() .addComment("Dataset successfully retrieved and converted into a weka.core.Instances object"); UpdateTask firstTaskUpdater = new UpdateTask(getTask()); firstTaskUpdater.setUpdateMeta(true); firstTaskUpdater.setUpdateTaskStatus(true);//TODO: Is this necessary? try { firstTaskUpdater.update(); } catch (DbException ex) { throw new JaqpotException(ex); } finally { try { firstTaskUpdater.close(); } catch (DbException ex) { throw new JaqpotException(ex); } } Model m = new Model(Configuration.getBaseUri().augment("model", getUuid().toString())); // INITIALIZE THE REGRESSOR regressor SVMreg regressor = new SVMreg(); final String[] regressorOptions = { "-P", Double.toString(epsilon), "-T", Double.toString(tolerance) }; Kernel svm_kernel = null; if (kernel.equalsIgnoreCase("rbf")) { RBFKernel rbf_kernel = new RBFKernel(); rbf_kernel.setGamma(Double.parseDouble(Double.toString(gamma))); rbf_kernel.setCacheSize(Integer.parseInt(Integer.toString(cacheSize))); svm_kernel = rbf_kernel; } else if (kernel.equalsIgnoreCase("polynomial")) { PolyKernel poly_kernel = new PolyKernel(); poly_kernel.setExponent(Double.parseDouble(Integer.toString(degree))); poly_kernel.setCacheSize(Integer.parseInt(Integer.toString(cacheSize))); poly_kernel.setUseLowerOrder(true); svm_kernel = poly_kernel; } else if (kernel.equalsIgnoreCase("linear")) { PolyKernel poly_kernel = new PolyKernel(); poly_kernel.setExponent((double) 1.0); poly_kernel.setCacheSize(Integer.parseInt(Integer.toString(cacheSize))); poly_kernel.setUseLowerOrder(true); svm_kernel = poly_kernel; } try { regressor.setOptions(regressorOptions); } catch (final Exception ex) { throw new QSARException("Bad options in SVM trainer for epsilon = {" + epsilon + "} or " + "tolerance = {" + tolerance + "}.", ex); } regressor.setKernel(svm_kernel); // START TRAINING & CREATE MODEL try { regressor.buildClassifier(orderedTrainingSet); // evaluate classifier and print some statistics Evaluation eval = new Evaluation(orderedTrainingSet); eval.evaluateModel(regressor, orderedTrainingSet); String stats = eval.toSummaryString("", false); ActualModel am = new ActualModel(regressor); am.setStatistics(stats); m.setActualModel(am); // m.setStatistics(stats); } catch (NotSerializableException ex) { String message = "Model is not serializable"; logger.error(message, ex); throw new JaqpotException(message, ex); } catch (final Exception ex) { throw new QSARException("Unexpected condition while trying to train " + "the model. Possible explanation : {" + ex.getMessage() + "}", ex); } m.setAlgorithm(getAlgorithm()); m.setCreatedBy(getTask().getCreatedBy()); m.setDataset(datasetUri); m.addDependentFeatures(dependentFeature); try { dependentFeature.loadFromRemote(); } catch (ServiceInvocationException ex) { java.util.logging.Logger.getLogger(SvmRegression.class.getName()).log(Level.SEVERE, null, ex); } m.addDependentFeatures(dependentFeature); m.setIndependentFeatures(independentFeatures); String predictionFeatureUri = null; Feature predictedFeature = publishFeature(m, dependentFeature.getUnits(), "Feature created as prediction feature for SVM model " + m.getUri(), datasetUri, featureService); m.addPredictedFeatures(predictedFeature); predictionFeatureUri = predictedFeature.getUri().toString(); getTask().getMeta().addComment("Prediction feature " + predictionFeatureUri + " was created."); /* SET PARAMETERS FOR THE TRAINED MODEL */ m.setParameters(new HashSet<Parameter>()); Parameter<String> kernelParam = new Parameter("kernel", new LiteralValue<String>(kernel)) .setScope(Parameter.ParameterScope.OPTIONAL); kernelParam.setUri(Services.anonymous().augment("parameter", RANDOM.nextLong())); Parameter<Double> costParam = new Parameter("cost", new LiteralValue<Double>(cost)) .setScope(Parameter.ParameterScope.OPTIONAL); costParam.setUri(Services.anonymous().augment("parameter", RANDOM.nextLong())); Parameter<Double> gammaParam = new Parameter("gamma", new LiteralValue<Double>(gamma)) .setScope(Parameter.ParameterScope.OPTIONAL); gammaParam.setUri(Services.anonymous().augment("parameter", RANDOM.nextLong())); Parameter<Double> epsilonParam = new Parameter("espilon", new LiteralValue<Double>(epsilon)) .setScope(Parameter.ParameterScope.OPTIONAL); epsilonParam.setUri(Services.anonymous().augment("parameter", RANDOM.nextLong())); Parameter<Integer> degreeParam = new Parameter("degree", new LiteralValue<Integer>(degree)) .setScope(Parameter.ParameterScope.OPTIONAL); degreeParam.setUri(Services.anonymous().augment("parameter", RANDOM.nextLong())); Parameter<Double> toleranceParam = new Parameter("tolerance", new LiteralValue<Double>(tolerance)) .setScope(Parameter.ParameterScope.OPTIONAL); toleranceParam.setUri(Services.anonymous().augment("parameter", RANDOM.nextLong())); m.getParameters().add(kernelParam); m.getParameters().add(costParam); m.getParameters().add(gammaParam); m.getParameters().add(epsilonParam); m.getParameters().add(degreeParam); m.getParameters().add(toleranceParam); //save the instances being predicted to abstract trainer for calculating DoA predictedInstances = orderedTrainingSet; excludeAttributesDoA.add(dependentFeature.getUri().toString()); return m; } catch (QSARException ex) { logger.debug(null, ex); throw new JaqpotException(ex); } }
From source file:org.opentox.qsar.processors.predictors.SimplePredictor.java
License:Open Source License
/** * Perform the prediction which is based on the serialized model file on the server. * @param data/*from w w w. j a v a 2 s .c o m*/ * Input data for with respect to which the predicitons are calculated * @return * A dataset containing the compounds submitted along with their predicted values. * @throws QSARException * In case the prediction (as a whole) is not feasible. If the prediction is not * feasible for a single instance, the prediction is set to <code>?</code> (unknown/undefined/missing). * If the prediction is not feasible for all instances, an exception (QSARException) is thrown. */ @Override public Instances predict(final Instances data) throws QSARException { Instances dataClone = new Instances(data); /** * IMPORTANT! * String attributes have to be removed from the dataset before * applying the prediciton */ dataClone = new AttributeCleanup(ATTRIBUTE_TYPE.string).filter(dataClone); /** * Set the class attribute of the incoming data to any arbitrary attribute * (Choose the last for instance). */ dataClone.setClass(dataClone.attribute(model.getDependentFeature().getURI())); /** * * Create the Instances that will host the predictions. This object contains * only two attributes: the compound_uri and the target feature of the model. */ Instances predictions = null; FastVector attributes = new FastVector(); final Attribute compoundAttribute = new Attribute("compound_uri", (FastVector) null); final Attribute targetAttribute = dataClone.classAttribute(); attributes.addElement(compoundAttribute); attributes.addElement(targetAttribute); predictions = new Instances("predictions", attributes, 0); predictions.setClassIndex(1); Instance predictionInstance = new Instance(2); try { final Classifier cls = (Classifier) SerializationHelper.read(filePath); for (int i = 0; i < data.numInstances(); i++) { try { String currentCompound = data.instance(i).stringValue(0); predictionInstance.setValue(compoundAttribute, currentCompound); if (targetAttribute.type() == Attribute.NUMERIC) { double clsLabel = cls.classifyInstance(dataClone.instance(i)); predictionInstance.setValue(targetAttribute, clsLabel); } else if (targetAttribute.type() == Attribute.NOMINAL) { double[] clsLable = cls.distributionForInstance(dataClone.instance(i)); int indexForNominalElement = maxInArray(clsLable).getPosition(); Enumeration nominalValues = targetAttribute.enumerateValues(); int counter = 0; String nomValue = ""; while (nominalValues.hasMoreElements()) { if (counter == indexForNominalElement) { nomValue = nominalValues.nextElement().toString(); break; } counter++; } predictionInstance.setValue(targetAttribute, nomValue); predictionInstance.setValue(targetAttribute, cls.classifyInstance(dataClone.instance(i))); } predictions.add(predictionInstance); } catch (Exception ex) { System.out.println(ex); } } } catch (Exception ex) { } return predictions; }
From source file:org.opentox.qsar.processors.trainers.classification.WekaClassifier.java
License:Open Source License
@Override public Instances preprocessData(Instances data) throws QSARException { /*//from ww w .j av a 2 s . com * TODO: In case a client choses a non-nominal feature for the classifier, * provide a list of some available nominal features. */ if (data == null) { throw new NullPointerException("Cannot train a classification model without data"); } /* The incoming dataset always has the first attribute set to 'compound_uri' which is of type "String". This is removed at the begining of the training procedure */ AttributeCleanup filter = new AttributeCleanup(ATTRIBUTE_TYPE.string); // NOTE: Removal of string attributes should be always performed prior to any kind of training! data = filter.filter(data); SimpleMVHFilter fil = new SimpleMVHFilter(); data = fil.filter(data); // CHECK IF THE GIVEN URI IS AN ATTRIBUTE OF THE DATASET Attribute classAttribute = data.attribute(predictionFeature); if (classAttribute == null) { throw new QSARException(Cause.XQReg202, "The prediction feature you provided is not a valid numeric attribute of the dataset :{" + predictionFeature + "}"); } // CHECK IF THE DATASET CONTAINS ANY NOMINAL ATTRIBUTES if (!data.checkForAttributeType(Attribute.NOMINAL)) { throw new QSARException(Cause.XQC4040, "Improper dataset! The dataset you provided has no " + "nominal features therefore classification models cannot be built."); } // CHECK WHETHER THE CLASS ATTRIBUTE IS NOMINAL if (!classAttribute.isNominal()) { StringBuilder list_of_nominal_features = new StringBuilder(); int j = 0; for (int i = 0; i < data.numAttributes() && j < 10; i++) { if (data.attribute(i).isNominal()) { j++; list_of_nominal_features.append(data.attribute(i).name() + "\n"); } } throw new QSARException(Cause.XQC4041, "The prediction feature you provided " + "is not a nominal. Here is a list of some nominal features in the dataset you might " + "be interested in :\n" + list_of_nominal_features.toString()); } // CHECK IF THE RANGE OF THE CLASS ATTRIBUTE IS NON-UNARY Enumeration nominalValues = classAttribute.enumerateValues(); String v = nominalValues.nextElement().toString(); if (!nominalValues.hasMoreElements()) { throw new QSARException(Cause.XQC4042, "This classifier cannot handle unary nominal classes, that is " + "nominal class attributes whose range includes only one value. Singleton value : {" + v + "}"); } // SET THE CLASS ATTRIBUTE OF THE DATASET data.setClass(classAttribute); return data; }
From source file:org.opentox.qsar.processors.trainers.regression.WekaRegressor.java
License:Open Source License
@Override public Instances preprocessData(Instances data) throws QSARException { /* Check if data == null*/ if (data == null) throw new NullPointerException("Trainers do not accept null datasets."); /* The incoming dataset always has the first attribute set to 'compound_uri' which is of type "String". This is removed at the begining of the training procedure */ AttributeCleanup filter = new AttributeCleanup(ATTRIBUTE_TYPE.string); // NOTE: Removal of string attributes should be always performed prior to any kind of training! data = filter.filter(data);/*from ww w .ja va 2 s .co m*/ SimpleMVHFilter fil = new SimpleMVHFilter(); data = fil.filter(data); /* * Do some checks for the prediction feature... */ // CHECK IF THE PREDICTION FEATURE EXISTS // IF IT DOESN'T PROVIDE A LIST OF SOME NUMERIC FEATURES IN THE DATASET Attribute classAttribute = data.attribute(predictionFeature); if (classAttribute == null) { String message = "The prediction feature you provided is is not included in the dataset :{" + predictionFeature + "}. " + attributeHint(data); throw new QSARException(Cause.XQReg202, message); } // CHECK IF THE PREDICTION FEATURE IS NUMERIC: // IF IT DOESN'T PROVIDE A LIST OF SOME NUMERIC FEATURES IN THE DATASET if (classAttribute.type() != Attribute.NUMERIC) { String message = "The prediction feature you provided is not numeric : " + "{" + predictionFeature + "}. " + attributeHint(data); throw new QSARException(Cause.XQReg203, message); } data.setClass(data.attribute(predictionFeature.toString())); return data; }
From source file:probcog.J48Reader.java
License:Open Source License
public static Instances readDB(String dbname) throws IOException, ClassNotFoundException, DDException, FileNotFoundException, Exception { Database db = Database.fromFile(new FileInputStream(dbname)); probcog.srldb.datadict.DataDictionary dd = db.getDataDictionary(); //the vector of attributes FastVector fvAttribs = new FastVector(); HashMap<String, Attribute> mapAttrs = new HashMap<String, Attribute>(); for (DDAttribute attribute : dd.getObject("object").getAttributes().values()) { if (attribute.isDiscarded() && !attribute.getName().equals("objectT")) { continue; }//from ww w .j a va 2 s . c o m FastVector attValues = new FastVector(); Domain dom = attribute.getDomain(); for (String s : dom.getValues()) attValues.addElement(s); Attribute attr = new Attribute(attribute.getName(), attValues); fvAttribs.addElement(attr); mapAttrs.put(attribute.getName(), attr); } Instances instances = new Instances("name", fvAttribs, 10000); instances.setClass(mapAttrs.get("objectT")); //for each object add an instance for (Object o : db.getObjects()) { if (o.hasAttribute("objectT")) { Instance instance = new Instance(fvAttribs.size()); for (Entry<String, String> e : o.getAttributes().entrySet()) { if (!dd.getAttribute(e.getKey()).isDiscarded()) { instance.setValue(mapAttrs.get(e.getKey()), e.getValue()); } } instances.add(instance); } } return instances; }