List of usage examples for weka.core Instances add
@Override public boolean add(Instance instance)
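The signature above is the weka >= 3.7 form, where Instances implements java.util.List&lt;Instance&gt; (hence the boolean return); note that add() appends a shallow copy of the instance, not the reference itself. A minimal self-contained sketch against that newer API (most examples below use the older 3.6 FastVector/Instance API instead):

import java.util.ArrayList;

import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

public class AddExample {
    public static void main(String[] args) {
        ArrayList<Attribute> attrs = new ArrayList<Attribute>();
        attrs.add(new Attribute("x"));
        attrs.add(new Attribute("y"));
        Instances data = new Instances("demo", attrs, 0);

        Instance row = new DenseInstance(1.0, new double[] { 1.5, 2.5 });
        boolean changed = data.add(row); // appends a copy; returns true per the List contract
        System.out.println(changed + ", size = " + data.numInstances()); // true, size = 1
    }
}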
From source file:org.openml.webapplication.generatefolds.GenerateFolds.java
License:Open Source License
private Instances sample_splits_bootstrap(String name) throws Exception {
    Instances splits = new Instances(name, am.getArffHeader(), splits_size);
    for (int r = 0; r < evaluationMethod.getRepeats(); ++r) {
        // -B 0.0: no bias towards a uniform class distribution,
        // -Z 100.0: sample size is 100% of the input, drawn with replacement (a bootstrap),
        // -S r: use the repeat number as the random seed
        Resample resample = new Resample();
        String[] resampleOptions = { "-B", "0.0", "-Z", "100.0", "-S", r + "" };
        resample.setOptions(resampleOptions);
        resample.setInputFormat(dataset);
        Instances trainingsset = Filter.useFilter(dataset, resample);

        // create the training set, consisting of the instances drawn into the bootstrap sample
        for (int i = 0; i < trainingsset.numInstances(); ++i) {
            int rowid = (int) trainingsset.instance(i).value(0);
            splits.add(am.createInstance(true, rowid, r, 0));
        }
        // mark every instance of the original dataset as test material
        for (int i = 0; i < dataset.numInstances(); ++i) {
            int rowid = (int) dataset.instance(i).value(0);
            splits.add(am.createInstance(false, rowid, r, 0));
        }
    }
    return splits;
}
From source file:org.openml.webapplication.generatefolds.GenerateFolds.java
License:Open Source License
private Instances sample_splits_holdout_unlabeled(String name) {
    Instances splits = new Instances(name, am.getArffHeader(), splits_size);
    // do not randomize the data set, as this method is based on user defined splits
    for (int i = 0; i < dataset.size(); ++i) {
        if (dataset.get(i).classIsMissing()) {
            splits.add(am.createInstance(false, i, 0, 0));
        } else {
            splits.add(am.createInstance(true, i, 0, 0));
        }
    }
    return splits;
}
From source file:org.openml.webapplication.generatefolds.GenerateFolds.java
License:Open Source License
private Instances sample_splits_holdout_userdefined(String name, List<List<List<Integer>>> testset) {
    Instances splits = new Instances(name, am.getArffHeader(), splits_size);
    if (testset == null) {
        throw new RuntimeException("Option -test not set correctly.");
    }
    for (int r = 0; r < evaluationMethod.getRepeats(); ++r) {
        for (int f = 0; f < evaluationMethod.getFolds(); ++f) {
            Collections.sort(testset.get(r).get(f));
            // do not randomize the data set, as this method is based on user defined splits
            for (int i = 0; i < dataset.size(); ++i) {
                if (Collections.binarySearch(testset.get(r).get(f), i) >= 0) {
                    splits.add(am.createInstance(false, i, r, f));
                } else {
                    splits.add(am.createInstance(true, i, r, f));
                }
            }
        }
    }
    return splits;
}
From source file:org.opentox.jaqpot3.qsar.InstancesUtil.java
License:Open Source License
/**
 * Accepts a list of features and a dataset and returns a copy of the dataset
 * restricted to those features, in the requested order.
 *
 * @param features
 *      The list of feature names defining the attribute order.
 * @param data
 *      The dataset to be re-ordered.
 * @param compoundURIposition
 *      Position where the compound URI should be placed. If set to <code>-1</code>
 *      the compound URI will not be included in the created dataset.
 * @return
 *      A subset of the provided dataset (parameter data in this method) with the
 *      features specified in the provided list, in that exact order. The compound
 *      URI feature (string) is placed at the position specified by the parameter
 *      compoundURIposition.
 * @throws JaqpotException
 *      A JaqpotException is thrown with error code {@link ErrorCause#FeatureNotInDataset FeatureNotInDataset}
 *      in case you provide a feature that is not found in the submitted Instances.
 */
public static Instances sortByFeatureAttrList(List<String> features, final Instances data, int compoundURIposition)
        throws JaqpotException {
    int position = compoundURIposition > features.size() ? features.size() : compoundURIposition;
    if (compoundURIposition != -1) {
        features.add(position, "compound_uri");
    }
    FastVector vector = new FastVector(features.size());
    for (int i = 0; i < features.size(); i++) {
        String feature = features.get(i);
        Attribute attribute = data.attribute(feature);
        if (attribute == null) {
            throw new JaqpotException("The Dataset you provided does not contain feature:" + feature);
        }
        vector.addElement(attribute.copy());
    }
    Instances result = new Instances(data.relationName(), vector, 0);
    Enumeration instances = data.enumerateInstances();
    while (instances.hasMoreElements()) {
        Instance instance = (Instance) instances.nextElement();
        double[] vals = new double[features.size()];
        for (int i = 0; i < features.size(); i++) {
            vals[i] = instance.value(data.attribute(result.attribute(i).name()));
        }
        Instance in = new Instance(1.0, vals);
        result.add(in);
    }
    return result;
}
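A possible invocation, sketched under the assumption of an Instances object 'data' that contains the listed features plus the string attribute compound_uri (the feature URIs below are placeholders):

// hypothetical feature URIs; 'data' is an Instances object containing them
List<String> order = new ArrayList<String>();
order.add("http://example.org/feature/1");
order.add("http://example.org/feature/2");
// place compound_uri first, then the two features in the given order
Instances sorted = InstancesUtil.sortByFeatureAttrList(order, data, 0);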
From source file:org.opentox.jaqpot3.qsar.InstancesUtil.java
License:Open Source License
public static Instances sortForPMMLModel(List<Feature> list, List<Integer> trFieldsAttrIndex,
        final Instances data, int compoundURIposition) throws JaqpotException {
    List<String> features = new ArrayList<String>();
    for (Feature feature : list) {
        features.add(feature.getUri().toString());
    }
    int position = compoundURIposition > features.size() ? features.size() : compoundURIposition;
    if (compoundURIposition != -1) {
        features.add(position, "compound_uri");
    }
    FastVector vector = new FastVector(features.size());
    for (int i = 0; i < features.size(); i++) {
        String feature = features.get(i);
        Attribute attribute = data.attribute(feature);
        if (attribute == null) {
            throw new JaqpotException("The Dataset you provided does not contain feature:" + feature);
        }
        vector.addElement(attribute.copy());
    }
    int attributeSize = features.size();
    if (trFieldsAttrIndex.size() > 0) {
        for (int i = 0; i < trFieldsAttrIndex.size(); i++) {
            Attribute attribute = data.attribute(trFieldsAttrIndex.get(i));
            if (attribute == null) {
                throw new JaqpotException("The Dataset you provided does not contain this pmml feature");
            }
            vector.addElement(attribute.copy());
        }
        attributeSize += trFieldsAttrIndex.size();
    }
    Instances result = new Instances(data.relationName(), vector, 0);
    Enumeration instances = data.enumerateInstances();
    while (instances.hasMoreElements()) {
        Instance instance = (Instance) instances.nextElement();
        double[] vals = new double[attributeSize];
        for (int i = 0; i < attributeSize; i++) {
            vals[i] = instance.value(data.attribute(result.attribute(i).name()));
        }
        Instance in = new Instance(1.0, vals);
        result.add(in);
    }
    return result;
}
From source file:org.opentox.jaqpot3.qsar.trainer.FastRbfNnTrainer.java
License:Open Source License
@Override
public Model train(Instances training) throws JaqpotException {
    /*
     * For this algorithm we need to remove all string and nominal attributes
     * and additionally we will remove the target attribute too.
     */
    Instances cleanedTraining = training;
    Attribute targetAttribute = cleanedTraining.attribute(targetUri.toString());
    if (targetAttribute == null) {
        throw new JaqpotException("The prediction feature you provided was not found in the dataset. "
                + "Prediction Feature provided by the client: " + targetUri.toString());
    } else {
        if (!targetAttribute.isNumeric()) {
            throw new JaqpotException("The prediction feature you provided is not numeric.");
        }
    }
    double[] targetValues = new double[cleanedTraining.numInstances()];
    for (int i = 0; i < cleanedTraining.numInstances(); i++) {
        targetValues[i] = cleanedTraining.instance(i).value(targetAttribute);
    }
    cleanedTraining.deleteAttributeAt(targetAttribute.index());

    Instances rbfNnNodes = new Instances(cleanedTraining);
    rbfNnNodes.delete();
    double[] potential = calculatePotential(cleanedTraining);

    int L = 1;
    int i_star = locationOfMax(potential);
    double potential_star = potential[i_star];
    double potential_star_1 = potential_star;
    do {
        rbfNnNodes.add(cleanedTraining.instance(i_star));
        potential = updatePotential(potential, i_star, cleanedTraining);
        i_star = locationOfMax(potential);
        double diff = potential[i_star] - e * potential_star_1;
        if (Double.isNaN(diff)) {
            throw new JaqpotException("Not converging");
        }
        if (potential[i_star] <= e * potential_star_1) {
            break;
        } else {
            L = L + 1;
            potential_star = potential[i_star];
        }
    } while (true);

    /* P-nearest neighbors */
    double[] pNn = null;
    double[] sigma = new double[rbfNnNodes.numInstances()];
    double s = 0;
    for (int i = 0; i < rbfNnNodes.numInstances(); i++) {
        pNn = new double[cleanedTraining.numInstances()];
        s = 0;
        for (int j = 0; j < cleanedTraining.numInstances(); j++) {
            if (j != i) {
                pNn[j] = squaredNormDifference(rbfNnNodes.instance(i), cleanedTraining.instance(j));
            } else {
                pNn[j] = 0;
            }
        }
        int[] minPoints = locationOfpMinimum(p, pNn); // indices refer to 'cleanedTraining'
        for (int q : minPoints) {
            s += squaredNormDifference(rbfNnNodes.instance(i), cleanedTraining.instance(q));
        }
        sigma[i] = Math.sqrt(s / p);
    }

    /* Calculate the matrix X = (l_{i,j})_{i,j} */
    double[][] X = new double[cleanedTraining.numInstances()][rbfNnNodes.numInstances()];
    for (int i = 0; i < cleanedTraining.numInstances(); i++) { // for DoA
        for (int j = 0; j < rbfNnNodes.numInstances(); j++) {
            X[i][j] = rbf(sigma[j], cleanedTraining.instance(i), rbfNnNodes.instance(j));
        }
    }
    Jama.Matrix X_matr = new Matrix(X);
    Jama.Matrix Y_matr = new Matrix(targetValues, targetValues.length);
    Jama.Matrix coeffs = (X_matr.transpose().times(X_matr)).inverse().times(X_matr.transpose()).times(Y_matr);

    FastRbfNnModel actualModel = new FastRbfNnModel();
    actualModel.setAlpha(a);
    actualModel.setBeta(b);
    actualModel.setEpsilon(e);
    actualModel.setNodes(rbfNnNodes);
    actualModel.setSigma(sigma);
    actualModel.setLrCoefficients(coeffs.getColumnPackedCopy());

    Model m = new Model(Configuration.getBaseUri().augment("model", getUuid().toString()));
    m.setAlgorithm(getAlgorithm());
    m.setCreatedBy(getTask().getCreatedBy());
    m.setDataset(datasetUri);
    m.addDependentFeatures(dependentFeature);
    Feature predictedFeature = publishFeature(m, dependentFeature.getUnits(),
            "Created as prediction feature for the RBF NN model " + m.getUri(), datasetUri, featureService);
    m.addPredictedFeatures(predictedFeature);
    m.setIndependentFeatures(independentFeatures);
    try {
        m.setActualModel(new ActualModel(actualModel));
    } catch (NotSerializableException ex) {
        logger.error("The provided instance of model cannot be serialized! Critical Error!", ex);
    }
    m.setParameters(new HashSet<Parameter>());
    Parameter<Double> aParam = new Parameter("a", new LiteralValue<Double>(a))
            .setScope(Parameter.ParameterScope.OPTIONAL);
    aParam.setUri(Services.anonymous().augment("parameter", RANDOM.nextLong()));
    Parameter<Double> bParam = new Parameter("b", new LiteralValue<Double>(b))
            .setScope(Parameter.ParameterScope.OPTIONAL);
    bParam.setUri(Services.anonymous().augment("parameter", RANDOM.nextLong()));
    Parameter<Double> eParam = new Parameter("e", new LiteralValue<Double>(e))
            .setScope(Parameter.ParameterScope.OPTIONAL);
    eParam.setUri(Services.anonymous().augment("parameter", RANDOM.nextLong()));
    m.getParameters().add(aParam);
    m.getParameters().add(bParam);
    m.getParameters().add(eParam);

    // save the instances being predicted to the abstract trainer and set the
    // features to be excluded when calculating the DoA
    predictedInstances = training;
    excludeAttributesDoA.add(dependentFeature.getUri().toString());
    return m;
}
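One step above worth isolating: the weights are computed by forming the normal equations (X^T X)^-1 X^T Y explicitly. A sketch of an equivalent but numerically safer alternative with the same Jama API, which solves the least-squares system via QR instead of inverting X^T X:

// equivalent least-squares solve; Jama's solve() uses QR for non-square
// systems, avoiding the explicit inverse of X^T X.
// X_matr (n x L) and Y_matr (n x 1) as constructed above
Jama.Matrix coeffs = X_matr.solve(Y_matr);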
From source file:org.opentox.ontology.data.Dataset.java
License:Open Source License
/**
 * The dataset as <code>Instances</code>. These objects are used by weka as
 * input/output objects to most algorithms (training, data preprocessing etc.).
 * The Instances equivalent of the dataset may contain three different types of
 * <code>attributes</code>: numeric, nominal and/or string ones. The first attribute
 * is always a string one corresponding to the compound of the dataentry while
 * acting as an identifier for it. The name of this attribute is <code>compound_uri</code>
 * and is unique among all data entries.
 *
 * @return Instances object for the dataset.
 * @throws YaqpOntException In case something goes wrong with the provided
 *     representation (e.g. it does not correspond to a valid dataset).
 */
public Instances getInstances() throws YaqpOntException {
    // SOME INITIAL DEFINITIONS:
    Resource _DATAENTRY = OTClass.DataEntry.getOntClass(oo),
            _DATASET = OTClass.Dataset.getOntClass(oo),
            _FEATURE = OTClass.Feature.getOntClass(oo),
            _NUMERIC_FEATURE = OTClass.NumericFeature.getOntClass(oo),
            _NOMINAL_FEATURE = OTClass.NominalFeature.getOntClass(oo),
            _STRING_FEATURE = OTClass.StringFeature.getOntClass(oo);
    FastVector attributes = null;
    Instances data = null;
    StmtIterator dataSetIterator = null, featureIterator = null, valuesIterator = null, dataEntryIterator = null;
    String relationName = null;
    Map<Resource, WekaDataTypes> featureTypes = new HashMap<Resource, WekaDataTypes>();
    Map<Resource, ArrayList<String>> featureNominalValues = new HashMap<Resource, ArrayList<String>>();

    // CHECK IF THE RESOURCE IS A DATASET. IF YES, GET ITS IDENTIFIER AND SET
    // THE RELATION NAME ACCORDINGLY. IF NOT, THROW AN ImproperEntityException.
    // ALSO CHECK IF THERE ARE MULTIPLE DATASETS AND IF YES THROW AN EXCEPTION.
    dataSetIterator = oo.listStatements(new SimpleSelector(null, RDF.type, _DATASET));
    if (dataSetIterator.hasNext()) {
        relationName = dataSetIterator.next().getSubject().getURI();
        if (dataSetIterator.hasNext()) {
            throw new YaqpOntException(Cause.XONT518, "More than one datasets found");
        }
    } else {
        // this is not a dataset model
        throw new ImproperEntityException(Cause.XIE2, "Not a dataset");
    }
    dataSetIterator.close();

    // POPULATE THE MAP WHICH CORRELATES RESOURCES TO WEKA DATA TYPES
    ArrayList<String> nominalValues = new ArrayList<String>();
    featureIterator = oo.listStatements(new SimpleSelector(null, RDF.type, _FEATURE));
    while (featureIterator.hasNext()) {
        Resource feature = featureIterator.next().getSubject().as(Resource.class);
        StmtIterator featureTypeIterator = oo
                .listStatements(new SimpleSelector(feature, RDF.type, (RDFNode) null));
        Set<Resource> featureTypesSet = new HashSet<Resource>();
        while (featureTypeIterator.hasNext()) {
            Resource type = featureTypeIterator.next().getObject().as(Resource.class);
            featureTypesSet.add(type);
        }
        if (featureTypesSet.contains(_NUMERIC_FEATURE)) {
            featureTypes.put(feature, WekaDataTypes.numeric);
        } else if (featureTypesSet.contains(_STRING_FEATURE)) {
            featureTypes.put(feature, WekaDataTypes.string);
        } else if (featureTypesSet.contains(_NOMINAL_FEATURE)) {
            featureTypes.put(feature, WekaDataTypes.nominal);
            StmtIterator acceptValueIterator = oo.listStatements(new SimpleSelector(feature,
                    OTDataTypeProperties.acceptValue.createProperty(oo), (RDFNode) null));
            // GET THE RANGE OF THE FEATURE:
            while (acceptValueIterator.hasNext()) {
                nominalValues.add(acceptValueIterator.next().getObject().as(Literal.class).getString());
            }
            featureNominalValues.put(feature, nominalValues);
            nominalValues = new ArrayList<String>();
        } else {
            assert (featureTypesSet.contains(_FEATURE));
            featureTypes.put(feature, WekaDataTypes.general);
        }
    }

    // GET THE ATTRIBUTES FOR THE DATASET:
    attributes = getAttributes(featureTypes, featureNominalValues);
    data = new Instances(relationName, attributes, 0);

    // ITERATE OVER ALL DATA ENTRIES IN THE DATASET:
    dataEntryIterator = oo.listStatements(new SimpleSelector(null, RDF.type, _DATAENTRY));
    while (dataEntryIterator.hasNext()) {
        Statement dataEntry = dataEntryIterator.next();
        /*
         * B2. For every dataEntry, iterate over all values nodes.
         */
        Instance temp = null;
        valuesIterator = oo.listStatements(new SimpleSelector(dataEntry.getSubject(),
                OTObjectProperties.values.createProperty(oo), (Resource) null));
        double[] vals = new double[data.numAttributes()];
        for (int i = 0; i < data.numAttributes(); i++) {
            vals[i] = Instance.missingValue();
        }
        StmtIterator compoundNamesIterator = oo.listStatements(new SimpleSelector(dataEntry.getSubject(),
                OTObjectProperties.compound.createProperty(oo), (Resource) null));
        String compoundName = null;
        if (compoundNamesIterator.hasNext()) {
            compoundName = compoundNamesIterator.next().getObject().as(Resource.class).getURI();
        }
        vals[data.attribute(compound_uri).index()] = data.attribute(compound_uri).addStringValue(compoundName);
        while (valuesIterator.hasNext()) {
            Statement values = valuesIterator.next();
            /*
             * A pair of the form (AttributeName, AttributeValue) is created.
             * This will be registered in an Instance-type object which in
             * turn will be used to update the dataset.
             */
            // atVal is the value of the attribute
            String atVal = values.getProperty(OTDataTypeProperties.value.createProperty(oo)).getObject()
                    .as(Literal.class).getValue().toString();
            // and atName is the name of the corresponding attribute.
            String atName = values.getProperty(OTObjectProperties.feature.createProperty(oo)).getObject()
                    .as(Resource.class).getURI();
            if (featureTypes.get(oo.createResource(atName)).equals(WekaDataTypes.numeric)) {
                try {
                    vals[data.attribute(atName).index()] = Double.parseDouble(atVal);
                } catch (NumberFormatException ex) {
                    /*
                     * Handles cases where a value is declared as numeric (double,
                     * float etc.) but cannot be cast to double: just don't include
                     * this value in the dataset.
                     */
                }
            } else if (featureTypes.get(oo.createResource(atName)).equals(WekaDataTypes.string)) {
                vals[data.attribute(atName).index()] = data.attribute(atName).addStringValue(atVal);
            } else if (XSDDatatype.XSDdate.getURI().equals(atName)) {
                try {
                    vals[data.attribute(atName).index()] = data.attribute(atName).parseDate(atVal);
                } catch (ParseException ex) {
                    System.out.println(ex);
                    //Logger.getLogger(Dataset.class.getName()).log(Level.SEVERE, null, ex);
                }
            }
        }
        temp = new Instance(1.0, vals);
        // Add the Instance only if it is compatible with the dataset!
        if (data.checkInstance(temp)) {
            data.add(temp);
        } else {
            System.err.println("Warning! The instance " + temp + " is not compatible with the dataset!");
        }
    }
    dataEntryIterator.close();
    return data;
}
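The population loop above follows a pattern that recurs across these examples: pre-fill every slot with a missing value, overwrite what is known, and let checkInstance() guard the add(). A minimal standalone sketch of just that pattern, against the same 3.6-era API ('data' and the attribute name are placeholders):

// 'data' is an existing Instances object (weka 3.6-style API)
double[] vals = new double[data.numAttributes()];
for (int i = 0; i < data.numAttributes(); i++) {
    vals[i] = Instance.missingValue(); // everything missing until proven otherwise
}
vals[data.attribute("some_numeric_feature").index()] = 3.14; // hypothetical attribute name
Instance row = new Instance(1.0, vals);
if (data.checkInstance(row)) { // verify the values fit the attribute definitions
    data.add(row);             // add() stores a copy of the instance
}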
From source file:org.opentox.qsar.processors.predictors.SimplePredictor.java
License:Open Source License
/**
 * Perform the prediction which is based on the serialized model file on the server.
 *
 * @param data
 *      Input data with respect to which the predictions are calculated.
 * @return
 *      A dataset containing the compounds submitted along with their predicted values.
 * @throws QSARException
 *      In case the prediction (as a whole) is not feasible. If the prediction is not
 *      feasible for a single instance, the prediction is set to <code>?</code>
 *      (unknown/undefined/missing). If the prediction is not feasible for all
 *      instances, an exception (QSARException) is thrown.
 */
@Override
public Instances predict(final Instances data) throws QSARException {
    Instances dataClone = new Instances(data);
    /*
     * IMPORTANT!
     * String attributes have to be removed from the dataset before
     * applying the prediction.
     */
    dataClone = new AttributeCleanup(ATTRIBUTE_TYPE.string).filter(dataClone);
    /*
     * Set the class attribute of the incoming data to the model's
     * dependent feature.
     */
    dataClone.setClass(dataClone.attribute(model.getDependentFeature().getURI()));
    /*
     * Create the Instances that will host the predictions. This object contains
     * only two attributes: the compound_uri and the target feature of the model.
     */
    Instances predictions = null;
    FastVector attributes = new FastVector();
    final Attribute compoundAttribute = new Attribute("compound_uri", (FastVector) null);
    final Attribute targetAttribute = dataClone.classAttribute();
    attributes.addElement(compoundAttribute);
    attributes.addElement(targetAttribute);
    predictions = new Instances("predictions", attributes, 0);
    predictions.setClassIndex(1);

    Instance predictionInstance = new Instance(2);
    try {
        final Classifier cls = (Classifier) SerializationHelper.read(filePath);
        for (int i = 0; i < data.numInstances(); i++) {
            try {
                String currentCompound = data.instance(i).stringValue(0);
                predictionInstance.setValue(compoundAttribute, currentCompound);
                if (targetAttribute.type() == Attribute.NUMERIC) {
                    double clsLabel = cls.classifyInstance(dataClone.instance(i));
                    predictionInstance.setValue(targetAttribute, clsLabel);
                } else if (targetAttribute.type() == Attribute.NOMINAL) {
                    double[] clsLabel = cls.distributionForInstance(dataClone.instance(i));
                    int indexForNominalElement = maxInArray(clsLabel).getPosition();
                    // walk the enumeration up to the winning index; the enumeration
                    // must be advanced on every iteration, otherwise it would always
                    // yield its first element
                    Enumeration nominalValues = targetAttribute.enumerateValues();
                    int counter = 0;
                    String nomValue = "";
                    while (nominalValues.hasMoreElements()) {
                        String candidate = nominalValues.nextElement().toString();
                        if (counter == indexForNominalElement) {
                            nomValue = candidate;
                            break;
                        }
                        counter++;
                    }
                    predictionInstance.setValue(targetAttribute, nomValue);
                }
                // Instances.add() stores a copy, so reusing predictionInstance
                // across iterations is safe here.
                predictions.add(predictionInstance);
            } catch (Exception ex) {
                System.out.println(ex);
            }
        }
    } catch (Exception ex) {
        // model could not be deserialized; fall through and return
        // the (possibly empty) predictions
    }
    return predictions;
}
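Reusing the single predictionInstance object across iterations works only because Instances.add() stores a shallow copy of its argument. A minimal sketch of that behavior with the same 3.6-era API (the attribute name is illustrative):

FastVector attrs = new FastVector();
attrs.addElement(new Attribute("x")); // hypothetical numeric attribute
Instances buffer = new Instances("demo", attrs, 0);

Instance reused = new Instance(1);
reused.setValue(0, 1.0);
buffer.add(reused);      // a copy holding 1.0 is stored
reused.setValue(0, 2.0); // mutating the original...
buffer.add(reused);      // ...does not alter the first stored row
// buffer now holds two rows: 1.0 and 2.0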
From source file:org.opentox.toxotis.core.component.Dataset.java
License:Open Source License
/**
 * <p align="justify">Creates and returns a <code>weka.core.Instances</code>
 * object from the data contained in this Dataset. The Instances object created has the
 * following specific structure: the first element in each Instance is always the
 * Compound's URI, identified by the keyword <code>compound_uri</code>. Following that
 * comes a sequence of all Features contained in the Dataset's DataEntries, described as
 * either <code>String</code>, <code>Numeric</code> or <code>Nominal</code>.
 * If a compound does not possess a value for a specific Feature, or the value is
 * unreadable or unacceptable (e.g. a String value is present when a Numeric is
 * expected), a missing value is placed instead. If a Feature is tagged as both
 * Numeric|String and Nominal, the Nominal property wins. If it is tagged as
 * both Numeric and String, the String property wins.
 * </p>
 *
 * @return
 *      Weka Instances from the data contained in this Dataset.
 */
public Instances getInstances() {
    long timeFlag = System.currentTimeMillis();
    // GET THE ATTRIBUTES FOR THE DATASET:
    FastVector attributes = new FastVector();
    Set<Feature> features = getContainedFeatures();
    // THE EXISTENCE OF THE (STRING) ATTRIBUTE 'COMPOUND_URI' IS MANDATORY FOR ALL
    // DATASETS. THIS IS ALWAYS THE FIRST ATTRIBUTE IN THE LIST.
    attributes.addElement(new Attribute(COMPOUND_URI, (FastVector) null));
    // ADD NUMERIC AND STRING ATTRIBUTES INTO THE FASTVECTOR:
    for (Feature feature : features) {
        WekaDataTypes dataType = WekaDataTypes.getFromFeature(feature);
        if (dataType.equals(WekaDataTypes.numeric)) {
            attributes.addElement(new Attribute(feature.getUri().getStringNoQuery()));
        } else if (dataType.equals(WekaDataTypes.string) || dataType.equals(WekaDataTypes.general)) {
            attributes.addElement(new Attribute(feature.getUri().getStringNoQuery(), (FastVector) null));
        } else if (dataType.equals(WekaDataTypes.nominal)) {
            // COPE WITH NOMINAL VALUES:
            FastVector nominalFVec = new FastVector(feature.getAdmissibleValues().size());
            for (LiteralValue value : feature.getAdmissibleValues()) {
                nominalFVec.addElement(value.getValue());
            }
            attributes.addElement(new Attribute(feature.getUri().getStringNoQuery(), nominalFVec));
        }
    }
    Instances data = new Instances(this.getUri().getStringNoQuery(), attributes, 0);

    // POPULATE WITH VALUES:
    for (DataEntry dataEntry : this.getDataEntries()) {
        double[] vals = new double[data.numAttributes()];
        for (int i = 0; i < data.numAttributes(); i++) {
            vals[i] = Instance.missingValue();
        }
        Compound conformer = dataEntry.getConformer();
        vals[data.attribute(COMPOUND_URI).index()] = data.attribute(COMPOUND_URI)
                .addStringValue(conformer.getUri().getStringNoQuery());
        for (FeatureValue featureValue : dataEntry.getFeatureValues()) {
            Feature feature = featureValue.getFeature();
            String featureName = feature.getUri().getStringNoQuery();
            LiteralValue value = featureValue.getValue();
            if (value != null) {
                if (WekaDataTypes.getFromFeature(feature).equals(WekaDataTypes.numeric)) {
                    try {
                        vals[data.attribute(featureName).index()] = Double
                                .parseDouble(value.getValue().toString());
                    } catch (NumberFormatException ex) {
                        logger.warn("NFE while trying to convert to double the value " + value.getValue(), ex);
                    }
                } else if (WekaDataTypes.getFromFeature(feature).equals(WekaDataTypes.string)) {
                    vals[data.attribute(featureName).index()] = data.attribute(featureName)
                            .addStringValue((String) value.getValue().toString());
                } else if (XSDDatatype.XSDdate.getURI().equals(featureName)) {
                    try {
                        vals[data.attribute(featureName).index()] = data.attribute(featureName)
                                .parseDate((String) value.getValue());
                    } catch (ParseException ex) {
                        logger.error("Parsing Exception for Date in Dataset", ex);
                    }
                } else if (WekaDataTypes.getFromFeature(feature).equals(WekaDataTypes.nominal)) {
                    // TODO: Nominals may not work, testing is needed.
                    vals[data.attribute(featureName).index()] = data.attribute(featureName)
                            .indexOfValue(value.getValue().toString());
                }
            }
        }
        Instance valuesInstance = new Instance(1.0, vals);
        // Add the Instance only if it is compatible with the dataset!
        if (data.checkInstance(valuesInstance)) {
            data.add(valuesInstance);
        } else {
            logger.warn("Warning! The instance " + valuesInstance + " is not compatible with the dataset!");
        }
    }
    timeInstancesConversion = System.currentTimeMillis() - timeFlag;
    return data;
}
From source file:org.packDataMining.SMOTE.java
License:Open Source License
/**
 * The procedure implementing the SMOTE algorithm. The output
 * instances are pushed onto the output queue for collection.
 *
 * @throws Exception if provided options cannot be executed
 *      on input instances
 */
protected void doSMOTE() throws Exception {
    int minIndex = 0;
    int min = Integer.MAX_VALUE;
    if (m_DetectMinorityClass) {
        // find minority class
        int[] classCounts = getInputFormat().attributeStats(getInputFormat().classIndex()).nominalCounts;
        for (int i = 0; i < classCounts.length; i++) {
            if (classCounts[i] != 0 && classCounts[i] < min) {
                min = classCounts[i];
                minIndex = i;
            }
        }
    } else {
        String classVal = getClassValue();
        if (classVal.equalsIgnoreCase("first")) {
            minIndex = 1;
        } else if (classVal.equalsIgnoreCase("last")) {
            minIndex = getInputFormat().numClasses();
        } else {
            minIndex = Integer.parseInt(classVal);
        }
        if (minIndex > getInputFormat().numClasses()) {
            throw new Exception("value index must be <= the number of classes");
        }
        minIndex--; // make it an index
    }

    int nearestNeighbors;
    if (min <= getNearestNeighbors()) {
        nearestNeighbors = min - 1;
    } else {
        nearestNeighbors = getNearestNeighbors();
    }
    if (nearestNeighbors < 1) {
        throw new Exception("Cannot use 0 neighbors!");
    }

    // compose minority class dataset
    // also push all dataset instances
    Instances sample = getInputFormat().stringFreeStructure();
    Enumeration instanceEnum = getInputFormat().enumerateInstances();
    while (instanceEnum.hasMoreElements()) {
        Instance instance = (Instance) instanceEnum.nextElement();
        push((Instance) instance.copy());
        if ((int) instance.classValue() == minIndex) {
            sample.add(instance);
        }
    }

    // compute Value Distance Metric matrices for nominal features
    Map vdmMap = new HashMap();
    Enumeration attrEnum = getInputFormat().enumerateAttributes();
    while (attrEnum.hasMoreElements()) {
        Attribute attr = (Attribute) attrEnum.nextElement();
        if (!attr.equals(getInputFormat().classAttribute())) {
            if (attr.isNominal() || attr.isString()) {
                double[][] vdm = new double[attr.numValues()][attr.numValues()];
                vdmMap.put(attr, vdm);
                int[] featureValueCounts = new int[attr.numValues()];
                int[][] featureValueCountsByClass =
                        new int[getInputFormat().classAttribute().numValues()][attr.numValues()];
                instanceEnum = getInputFormat().enumerateInstances();
                while (instanceEnum.hasMoreElements()) {
                    Instance instance = (Instance) instanceEnum.nextElement();
                    int value = (int) instance.value(attr);
                    int classValue = (int) instance.classValue();
                    featureValueCounts[value]++;
                    featureValueCountsByClass[classValue][value]++;
                }
                for (int valueIndex1 = 0; valueIndex1 < attr.numValues(); valueIndex1++) {
                    for (int valueIndex2 = 0; valueIndex2 < attr.numValues(); valueIndex2++) {
                        double sum = 0;
                        for (int classValueIndex = 0; classValueIndex < getInputFormat()
                                .numClasses(); classValueIndex++) {
                            double c1i = (double) featureValueCountsByClass[classValueIndex][valueIndex1];
                            double c2i = (double) featureValueCountsByClass[classValueIndex][valueIndex2];
                            double c1 = (double) featureValueCounts[valueIndex1];
                            double c2 = (double) featureValueCounts[valueIndex2];
                            double term1 = c1i / c1;
                            double term2 = c2i / c2;
                            sum += Math.abs(term1 - term2);
                        }
                        vdm[valueIndex1][valueIndex2] = sum;
                    }
                }
            }
        }
    }

    // use this random source for all required randomness
    Random rand = new Random(getRandomSeed());

    // find the set of extra indices to use if the percentage is not evenly divisible by 100
    List extraIndices = new LinkedList();
    double percentageRemainder = (getPercentage() / 100) - Math.floor(getPercentage() / 100.0);
    int extraIndicesCount = (int) (percentageRemainder * sample.numInstances());
    if (extraIndicesCount >= 1) {
        for (int i = 0; i < sample.numInstances(); i++) {
            extraIndices.add(i);
        }
    }
    Collections.shuffle(extraIndices, rand);
    extraIndices = extraIndices.subList(0, extraIndicesCount);
    Set extraIndexSet = new HashSet(extraIndices);

    // the main loop to handle computing nearest neighbors and generating SMOTE
    // examples from each instance in the original minority class data
    Instance[] nnArray = new Instance[nearestNeighbors];
    for (int i = 0; i < sample.numInstances(); i++) {
        Instance instanceI = sample.instance(i);
        // find k nearest neighbors for each instance
        List distanceToInstance = new LinkedList();
        for (int j = 0; j < sample.numInstances(); j++) {
            Instance instanceJ = sample.instance(j);
            if (i != j) {
                double distance = 0;
                attrEnum = getInputFormat().enumerateAttributes();
                while (attrEnum.hasMoreElements()) {
                    Attribute attr = (Attribute) attrEnum.nextElement();
                    if (!attr.equals(getInputFormat().classAttribute())) {
                        double iVal = instanceI.value(attr);
                        double jVal = instanceJ.value(attr);
                        if (attr.isNumeric()) {
                            distance += Math.pow(iVal - jVal, 2);
                        } else {
                            distance += ((double[][]) vdmMap.get(attr))[(int) iVal][(int) jVal];
                        }
                    }
                }
                distance = Math.pow(distance, .5);
                distanceToInstance.add(new Object[] { distance, instanceJ });
            }
        }
        // sort the neighbors according to distance
        Collections.sort(distanceToInstance, new Comparator() {
            public int compare(Object o1, Object o2) {
                double distance1 = (Double) ((Object[]) o1)[0];
                double distance2 = (Double) ((Object[]) o2)[0];
                // Double.compare avoids the truncation bug of casting the
                // difference to int, which collapses sub-unit differences
                return Double.compare(distance1, distance2);
            }
        });
        // populate the actual nearest neighbor instance array
        Iterator entryIterator = distanceToInstance.iterator();
        int j = 0;
        while (entryIterator.hasNext() && j < nearestNeighbors) {
            nnArray[j] = (Instance) ((Object[]) entryIterator.next())[1];
            j++;
        }

        // create synthetic examples
        int n = (int) Math.floor(getPercentage() / 100);
        while (n > 0 || extraIndexSet.remove(i)) {
            double[] values = new double[sample.numAttributes()];
            int nn = rand.nextInt(nearestNeighbors);
            attrEnum = getInputFormat().enumerateAttributes();
            while (attrEnum.hasMoreElements()) {
                Attribute attr = (Attribute) attrEnum.nextElement();
                if (!attr.equals(getInputFormat().classAttribute())) {
                    if (attr.isNumeric()) {
                        double dif = nnArray[nn].value(attr) - instanceI.value(attr);
                        double gap = rand.nextDouble();
                        values[attr.index()] = (double) (instanceI.value(attr) + gap * dif);
                    } else if (attr.isDate()) {
                        double dif = nnArray[nn].value(attr) - instanceI.value(attr);
                        double gap = rand.nextDouble();
                        values[attr.index()] = (long) (instanceI.value(attr) + gap * dif);
                    } else {
                        // nominal attributes: take a majority vote among the neighbors
                        int[] valueCounts = new int[attr.numValues()];
                        int iVal = (int) instanceI.value(attr);
                        valueCounts[iVal]++;
                        for (int nnEx = 0; nnEx < nearestNeighbors; nnEx++) {
                            int val = (int) nnArray[nnEx].value(attr);
                            valueCounts[val]++;
                        }
                        int maxIndex = 0;
                        int max = Integer.MIN_VALUE;
                        for (int index = 0; index < attr.numValues(); index++) {
                            if (valueCounts[index] > max) {
                                max = valueCounts[index];
                                maxIndex = index;
                            }
                        }
                        values[attr.index()] = maxIndex;
                    }
                }
            }
            values[sample.classIndex()] = minIndex;
            Instance synthetic = new Instance(1.0, values);
            push(synthetic);
            n--;
        }
    }
}
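A hedged usage sketch for this filter: setInputFormat()/useFilter() is the standard weka.filters.Filter contract this class builds on, while the three setters are assumptions mirroring the getPercentage()/getNearestNeighbors()/getRandomSeed() getters used in doSMOTE(); 'train' stands for any Instances object with a nominal class attribute set.

SMOTE smote = new SMOTE();
smote.setPercentage(200.0);   // assumed setter: +200% synthetic minority examples
smote.setNearestNeighbors(5); // assumed setter: k for the k-NN step
smote.setRandomSeed(1);       // assumed setter: seed for 'rand' above
smote.setInputFormat(train);  // standard weka.filters.Filter contract
Instances balanced = Filter.useFilter(train, smote);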