Usage examples for weka.core.Instances.setClassIndex
public void setClassIndex(int classIndex)
From source file: mulan/data/LabelPowersetStratification.java
License: Open Source License
public MultiLabelInstances[] stratify(MultiLabelInstances data, int folds) { try {/*from ww w . java2s.c o m*/ MultiLabelInstances[] segments = new MultiLabelInstances[folds]; LabelPowersetTransformation transformation = new LabelPowersetTransformation(); Instances transformed; // transform to single-label transformed = transformation.transformInstances(data); // add id Add add = new Add(); add.setAttributeIndex("first"); add.setAttributeName("instanceID"); add.setInputFormat(transformed); transformed = Filter.useFilter(transformed, add); for (int i = 0; i < transformed.numInstances(); i++) { transformed.instance(i).setValue(0, i); } transformed.setClassIndex(transformed.numAttributes() - 1); // stratify transformed.randomize(new Random(seed)); transformed.stratify(folds); for (int i = 0; i < folds; i++) { //System.out.println("Fold " + (i + 1) + "/" + folds); Instances temp = transformed.testCV(folds, i); Instances test = new Instances(data.getDataSet(), 0); for (int j = 0; j < temp.numInstances(); j++) { test.add(data.getDataSet().instance((int) temp.instance(j).value(0))); } segments[i] = new MultiLabelInstances(test, data.getLabelsMetaData()); } return segments; } catch (Exception ex) { Logger.getLogger(LabelPowersetStratification.class.getName()).log(Level.SEVERE, null, ex); return null; } }
From source file: mulan/data/Statistics.java
License: Open Source License
/**
 * Calculates the phi correlation coefficient between every pair of labels.
 *
 * <p>For each pair (i, l) a 2x2 contingency table is counted over all instances,
 * treating a label value equal to "0" as zero and any other value as non-zero:
 * a = both zero, b = i non-zero and l zero, c = i zero and l non-zero,
 * d = both non-zero. Then phi = (a*d - b*c) / sqrt((a+b)(c+d)(a+c)(b+d)).
 * When either label is constant the denominator is 0 and the entry is NaN.
 *
 * <p>Side effects: sets the {@code numLabels} and {@code phi} fields.
 *
 * @param dataSet a multi-label dataset
 * @return a numLabels x numLabels matrix of phi correlations (also stored in {@code phi})
 * @throws java.lang.Exception if filtering the dataset fails
 */
public double[][] calculatePhi(MultiLabelInstances dataSet) throws Exception {
    numLabels = dataSet.getNumLabels();
    /** the indices of the label attributes */
    int[] labelIndices = dataSet.getLabelIndices();
    phi = new double[numLabels][numLabels];

    // Keep only the label attributes. The invert flag must be set before
    // setInputFormat, which is where Remove fixes its selection.
    Remove remove = new Remove();
    remove.setInvertSelection(true);
    remove.setAttributeIndicesArray(labelIndices);
    remove.setInputFormat(dataSet.getDataSet());
    Instances result = Filter.useFilter(dataSet.getDataSet(), remove);
    result.setClassIndex(result.numAttributes() - 1);
    // NOTE(review): attribute position i of `result` is assumed to be label i;
    // this holds when labelIndices is in ascending attribute order - verify.

    for (int i = 0; i < numLabels; i++) {
        int[] a = new int[numLabels];
        int[] b = new int[numLabels];
        int[] c = new int[numLabels];
        int[] d = new int[numLabels];
        for (int j = 0; j < result.numInstances(); j++) {
            // Loop-invariant for the inner loop: the value of label i in row j.
            boolean iIsZero = result.instance(j).stringValue(i).equals("0");
            for (int l = 0; l < numLabels; l++) {
                boolean lIsZero = result.instance(j).stringValue(l).equals("0");
                if (iIsZero) {
                    if (lIsZero) {
                        a[l]++;
                    } else {
                        c[l]++;
                    }
                } else {
                    if (lIsZero) {
                        b[l]++;
                    } else {
                        d[l]++;
                    }
                }
            }
        }
        for (int l = 0; l < numLabels; l++) {
            double lZero = a[l] + b[l];
            double lNonZero = c[l] + d[l];
            double iZero = a[l] + c[l];
            double iNonZero = b[l] + d[l];
            double denominator = Math.sqrt(lZero * lNonZero * iZero * iNonZero);
            // Multiply in double: the int products a*d and b*c overflow once the
            // counts exceed ~46k each, which silently corrupted the result.
            double numerator = (double) a[l] * d[l] - (double) b[l] * c[l];
            phi[i][l] = numerator / denominator;
        }
    }
    return phi;
}
From source file: mulan/regressor/transformation/RegressorChainSimple.java
License: Open Source License
protected void buildInternal(MultiLabelInstances train) throws Exception { // if no chain has been defined, create the default chain if (chain == null) { chain = new int[numLabels]; for (int j = 0; j < numLabels; j++) { chain[j] = labelIndices[j];/*from w w w. ja v a 2 s . c o m*/ } } if (chainSeed != 0) { // a random chain will be created by shuffling the existing chain Random rand = new Random(chainSeed); ArrayList<Integer> chainAsList = new ArrayList<Integer>(numLabels); for (int j = 0; j < numLabels; j++) { chainAsList.add(chain[j]); } Collections.shuffle(chainAsList, rand); for (int j = 0; j < numLabels; j++) { chain[j] = chainAsList.get(j); } } debug("Using chain: " + Arrays.toString(chain)); chainRegressors = new FilteredClassifier[numLabels]; Instances trainDataset = train.getDataSet(); for (int i = 0; i < numLabels; i++) { chainRegressors[i] = new FilteredClassifier(); chainRegressors[i].setClassifier(AbstractClassifier.makeCopy(baseRegressor)); // Indices of attributes to remove. // First removes numLabels attributes, then numLabels - 1 attributes and so on. // The loop starts from the last attribute. int[] indicesToRemove = new int[numLabels - 1 - i]; for (int counter1 = 0; counter1 < numLabels - i - 1; counter1++) { indicesToRemove[counter1] = chain[numLabels - 1 - counter1]; } Remove remove = new Remove(); remove.setAttributeIndicesArray(indicesToRemove); remove.setInvertSelection(false); remove.setInputFormat(trainDataset); chainRegressors[i].setFilter(remove); trainDataset.setClassIndex(chain[i]); debug("Bulding model " + (i + 1) + "/" + numLabels); chainRegressors[i].setDebug(true); chainRegressors[i].buildClassifier(trainDataset); } }
From source file: mulan/regressor/transformation/RegressorChainSimple.java
License: Open Source License
protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception { double[] scores = new double[numLabels]; // create a new temporary instance so that the passed instance is not altered Instances dataset = instance.dataset(); Instance tempInstance = DataUtils.createInstance(instance, instance.weight(), instance.toDoubleArray()); for (int counter = 0; counter < numLabels; counter++) { dataset.setClassIndex(chain[counter]); tempInstance.setDataset(dataset); // find the appropriate position for that score in the scores array // i.e. which is the corresponding target int pos = 0; for (int i = 0; i < numLabels; i++) { if (chain[counter] == labelIndices[i]) { pos = i;/*w w w . j a va 2 s .co m*/ break; } } scores[pos] = chainRegressors[counter].classifyInstance(tempInstance); tempInstance.setValue(chain[counter], scores[pos]); } MultiLabelOutput mlo = new MultiLabelOutput(scores, true); return mlo; }
From source file: mulan/regressor/transformation/SingleTargetRegressor.java
License: Open Source License
protected void buildInternal(MultiLabelInstances mlTrainSet) throws Exception { stRegressors = new FilteredClassifier[numLabels]; // any changes are applied to a copy of the original dataset Instances trainSet = new Instances(mlTrainSet.getDataSet()); for (int i = 0; i < numLabels; i++) { stRegressors[i] = new FilteredClassifier(); stRegressors[i].setClassifier(AbstractClassifier.makeCopy(baseRegressor)); // Indices of attributes to remove. All labelIndices except for the current index int[] indicesToRemove = new int[numLabels - 1]; int counter2 = 0; for (int counter1 = 0; counter1 < numLabels; counter1++) { if (labelIndices[counter1] != labelIndices[i]) { indicesToRemove[counter2] = labelIndices[counter1]; counter2++;//from w w w . j a v a 2 s.c o m } } Remove remove = new Remove(); remove.setAttributeIndicesArray(indicesToRemove); remove.setInvertSelection(false); remove.setInputFormat(trainSet); stRegressors[i].setFilter(remove); trainSet.setClassIndex(labelIndices[i]); debug("Bulding model " + (i + 1) + "/" + numLabels); stRegressors[i].buildClassifier(trainSet); } }
From source file: mulan/regressor/transformation/SingleTargetRegressor.java
License: Open Source License
protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception { double[] scores = new double[numLabels]; Instances dataset = instance.dataset(); for (int counter = 0; counter < numLabels; counter++) { dataset.setClassIndex(labelIndices[counter]); instance.setDataset(dataset);// www . ja v a2 s . c o m scores[counter] = stRegressors[counter].classifyInstance(instance); } MultiLabelOutput mlo = new MultiLabelOutput(scores, true); return mlo; }
From source file: mulan/transformations/BinaryRelevanceTransformation.java
License: Open Source License
/** * Remove all label attributes except labelToKeep * @param train //from ww w. jav a 2 s .c o m * @param labelToKeep * @return transformed Instances object * @throws Exception */ public Instances transformInstances(Instances train, int labelToKeep) throws Exception { // Indices of attributes to remove int indices[] = new int[numOfLabels - 1]; int k = 0; for (int labelIndex = 0; labelIndex < numOfLabels; labelIndex++) { if (labelIndex != labelToKeep) { indices[k] = train.numAttributes() - numOfLabels + labelIndex; k++; } } Remove remove = new Remove(); remove.setAttributeIndicesArray(indices); remove.setInputFormat(train); remove.setInvertSelection(true); Instances result = Filter.useFilter(train, remove); result.setClassIndex(result.numAttributes() - 1); return result; }
From source file: mulan/transformations/BinaryRelevanceTransformation.java
License: Open Source License
/** * Remove all label attributes except that at indexOfLabelToKeep * @param train //from ww w . ja v a 2 s . c o m * @param labelIndices * @param indexToKeep * @return transformed Instances object * @throws Exception */ public static Instances transformInstances(Instances train, int[] labelIndices, int indexToKeep) throws Exception { int numLabels = labelIndices.length; train.setClassIndex(indexToKeep); // Indices of attributes to remove int[] indicesToRemove = new int[numLabels - 1]; int counter2 = 0; for (int counter1 = 0; counter1 < numLabels; counter1++) { if (labelIndices[counter1] != indexToKeep) { indicesToRemove[counter2] = labelIndices[counter1]; counter2++; } } Remove remove = new Remove(); remove.setAttributeIndicesArray(indicesToRemove); remove.setInputFormat(train); remove.setInvertSelection(true); Instances result = Filter.useFilter(train, remove); return result; }
From source file: mulan/transformations/IncludeLabelsTransformation.java
License: Open Source License
/** * * @param mlData multi-label data/*from www.j av a 2 s . co m*/ * @return transformed instances * @throws Exception Potential exception thrown. To be handled in an upper level. */ public Instances transformInstances(MultiLabelInstances mlData) throws Exception { int numLabels = mlData.getNumLabels(); labelIndices = mlData.getLabelIndices(); // remove all labels Instances transformed = RemoveAllLabels.transformInstances(mlData); // add at the end an attribute with values the label names ArrayList<String> labelNames = new ArrayList<String>(numLabels); for (int counter = 0; counter < numLabels; counter++) { labelNames.add(mlData.getDataSet().attribute(labelIndices[counter]).name()); } Attribute attrLabel = new Attribute("Label", labelNames); transformed.insertAttributeAt(attrLabel, transformed.numAttributes()); // and at the end a binary attribute ArrayList<String> binaryValues = new ArrayList<String>(2); binaryValues.add("0"); binaryValues.add("1"); Attribute classAttr = new Attribute("Class", binaryValues); transformed.insertAttributeAt(classAttr, transformed.numAttributes()); // add instances transformed = new Instances(transformed, 0); transformed.setClassIndex(transformed.numAttributes() - 1); Instances data = mlData.getDataSet(); for (int instanceIndex = 0; instanceIndex < data.numInstances(); instanceIndex++) { for (int labelCounter = 0; labelCounter < numLabels; labelCounter++) { Instance temp; temp = RemoveAllLabels.transformInstance(data.instance(instanceIndex), labelIndices); temp.setDataset(null); temp.insertAttributeAt(temp.numAttributes()); temp.insertAttributeAt(temp.numAttributes()); temp.setDataset(transformed); temp.setValue(temp.numAttributes() - 2, (String) labelNames.get(labelCounter)); if (data.attribute(labelIndices[labelCounter]) .value((int) data.instance(instanceIndex).value(labelIndices[labelCounter])).equals("1")) { temp.setValue(temp.numAttributes() - 1, "1"); } else { temp.setValue(temp.numAttributes() - 1, "0"); } 
transformed.add(temp); } } return transformed; }