Example usage for weka.core Instances setClassIndex

List of usage examples for weka.core Instances setClassIndex

Introduction

On this page you can find example usages of weka.core.Instances.setClassIndex.

Prototype

public void setClassIndex(int classIndex) 

Source Link

Document

Sets the class index of the set.

Usage

From source file:mulan.data.LabelPowersetStratification.java

License:Open Source License

/**
 * Splits a multi-label dataset into {@code folds} stratified segments.
 *
 * <p>The data is first turned into a single-label problem with a label powerset
 * transformation, so that each label combination is one class value. A leading
 * "instanceID" attribute records each row's position in the original data. The rows are
 * shuffled with {@code new Random(seed)} and stratified with Weka. Each test fold is then
 * mapped back to the original multi-label rows through the stored ids.
 *
 * @param data the multi-label dataset to split
 * @param folds the number of segments to produce
 * @return the segments, or {@code null} if any step throws (the exception is logged)
 */
public MultiLabelInstances[] stratify(MultiLabelInstances data, int folds) {
    try {
        MultiLabelInstances[] segments = new MultiLabelInstances[folds];

        // Map every label combination to one class value so single-label stratification applies.
        Instances singleLabel = new LabelPowersetTransformation().transformInstances(data);

        // Prepend an "instanceID" attribute holding each row's position in the original data.
        Add idAdder = new Add();
        idAdder.setAttributeIndex("first");
        idAdder.setAttributeName("instanceID");
        idAdder.setInputFormat(singleLabel);
        singleLabel = Filter.useFilter(singleLabel, idAdder);
        for (int row = 0; row < singleLabel.numInstances(); row++) {
            singleLabel.instance(row).setValue(0, row);
        }
        singleLabel.setClassIndex(singleLabel.numAttributes() - 1);

        singleLabel.randomize(new Random(seed));
        singleLabel.stratify(folds);

        Instances original = data.getDataSet();
        for (int fold = 0; fold < folds; fold++) {
            Instances foldRows = singleLabel.testCV(folds, fold);
            Instances segment = new Instances(original, 0);
            for (int k = 0; k < foldRows.numInstances(); k++) {
                // Attribute 0 is the id added above: the row's index in the original data.
                int rowId = (int) foldRows.instance(k).value(0);
                segment.add(original.instance(rowId));
            }
            segments[fold] = new MultiLabelInstances(segment, data.getLabelsMetaData());
        }
        return segments;
    } catch (Exception ex) {
        Logger.getLogger(LabelPowersetStratification.class.getName()).log(Level.SEVERE, null, ex);
        return null;
    }
}

From source file:mulan.data.Statistics.java

License:Open Source License

/**
 * Calculates the phi correlation between every ordered pair of labels.
 *
 * <p>Only the label attributes of {@code dataSet} are kept, using an inverted {@code Remove}
 * filter. For each pair of labels (i, l), a 2x2 contingency table is counted over all
 * instances. A label value counts as "0" when its string value equals "0", and as "1"
 * otherwise. The entry is
 * {@code phi[i][l] = (n00 * n11 - n10 * n01) / sqrt(n.0 * n.1 * n0. * n1.)}, where the first
 * subscript refers to label i and the second to label l.
 *
 * <p>If a label is constant, one marginal count is zero and so is the numerator, so the entry
 * is NaN. The fields {@code numLabels} and {@code phi} are overwritten by this call.
 *
 * @param dataSet a multi-label dataset
 * @return a matrix containing phi correlations (the same array that is stored in {@code phi})
 * @throws java.lang.Exception if filtering the dataset fails
 */
public double[][] calculatePhi(MultiLabelInstances dataSet) throws Exception {
    numLabels = dataSet.getNumLabels();
    // the indices of the label attributes
    int[] labelIndices = dataSet.getLabelIndices();
    phi = new double[numLabels][numLabels];

    // Keep only the label attributes: with an inverted selection the listed indices are kept.
    Remove remove = new Remove();
    remove.setInvertSelection(true);
    remove.setAttributeIndicesArray(labelIndices);
    remove.setInputFormat(dataSet.getDataSet());
    Instances result = Filter.useFilter(dataSet.getDataSet(), remove);
    result.setClassIndex(result.numAttributes() - 1);
    // NOTE(review): column k of "result" is treated as the k-th label. This holds when
    // getLabelIndices() is in ascending attribute order — confirm.

    for (int i = 0; i < numLabels; i++) {
        // Contingency counts of label i against every label l:
        // a: i=0,l=0   b: i=1,l=0   c: i=0,l=1   d: i=1,l=1
        int[] a = new int[numLabels];
        int[] b = new int[numLabels];
        int[] c = new int[numLabels];
        int[] d = new int[numLabels];

        for (int j = 0; j < result.numInstances(); j++) {
            // The value of label i does not change inside the inner loop; read it once.
            boolean iIsZero = result.instance(j).stringValue(i).equals("0");
            for (int l = 0; l < numLabels; l++) {
                boolean lIsZero = result.instance(j).stringValue(l).equals("0");
                if (iIsZero) {
                    if (lIsZero) {
                        a[l]++;
                    } else {
                        c[l]++;
                    }
                } else {
                    if (lIsZero) {
                        b[l]++;
                    } else {
                        d[l]++;
                    }
                }
            }
        }

        for (int l = 0; l < numLabels; l++) {
            double lZero = a[l] + b[l];
            double lOne = c[l] + d[l];
            double iZero = a[l] + c[l];
            double iOne = b[l] + d[l];

            double denominator = Math.sqrt(lZero * lOne * iZero * iOne);
            // Multiply in long: the int products a*d and b*c overflow once the counts pass
            // ~46k, which silently corrupts phi on large datasets.
            double nominator = (long) a[l] * d[l] - (long) b[l] * c[l];
            phi[i][l] = nominator / denominator;
        }
    }
    return phi;
}

From source file:mulan.regressor.transformation.RegressorChainSimple.java

License:Open Source License

/**
 * Builds one regressor per target, in the order given by {@code chain}.
 *
 * <p>If no chain is set, the default chain is the order of {@code labelIndices}. If
 * {@code chainSeed} is non-zero, the chain is shuffled in place with a {@code Random}
 * seeded by it. The regressor at chain position i predicts target {@code chain[i]}. The
 * targets at later chain positions are removed from its input by a {@code Remove} filter,
 * and the targets at earlier positions stay as input attributes.
 *
 * @param train the multi-label training data; its dataset is copied, so the caller's class
 *     index is not changed
 * @throws Exception if copying the base regressor, filtering or training fails
 */
protected void buildInternal(MultiLabelInstances train) throws Exception {
    // if no chain has been defined, create the default chain
    if (chain == null) {
        chain = new int[numLabels];
        for (int j = 0; j < numLabels; j++) {
            chain[j] = labelIndices[j];
        }
    }

    if (chainSeed != 0) { // a random chain will be created by shuffling the existing chain
        Random rand = new Random(chainSeed);
        ArrayList<Integer> chainAsList = new ArrayList<Integer>(numLabels);
        for (int j = 0; j < numLabels; j++) {
            chainAsList.add(chain[j]);
        }
        Collections.shuffle(chainAsList, rand);
        for (int j = 0; j < numLabels; j++) {
            chain[j] = chainAsList.get(j);
        }
    }
    debug("Using chain: " + Arrays.toString(chain));

    chainRegressors = new FilteredClassifier[numLabels];
    // Work on a copy: the class index is changed for every model below, and that must not
    // leak into the dataset owned by the caller.
    Instances trainDataset = new Instances(train.getDataSet());

    for (int i = 0; i < numLabels; i++) {
        chainRegressors[i] = new FilteredClassifier();
        chainRegressors[i].setClassifier(AbstractClassifier.makeCopy(baseRegressor));

        // Remove the targets that come after position i in the chain (numLabels - 1 - i of
        // them, collected from the end of the chain). Earlier targets remain as inputs.
        int[] indicesToRemove = new int[numLabels - 1 - i];
        for (int counter1 = 0; counter1 < numLabels - i - 1; counter1++) {
            indicesToRemove[counter1] = chain[numLabels - 1 - counter1];
        }

        Remove remove = new Remove();
        remove.setAttributeIndicesArray(indicesToRemove);
        remove.setInvertSelection(false);
        remove.setInputFormat(trainDataset);
        chainRegressors[i].setFilter(remove);

        trainDataset.setClassIndex(chain[i]);
        debug("Bulding model " + (i + 1) + "/" + numLabels);
        // NOTE(review): debug output of the wrapped FilteredClassifier is always switched on,
        // regardless of this learner's own debug setting — confirm this is intended.
        chainRegressors[i].setDebug(true);
        chainRegressors[i].buildClassifier(trainDataset);
    }
}

From source file:mulan.regressor.transformation.RegressorChainSimple.java

License:Open Source License

/**
 * Predicts all targets by running the regressor chain in order.
 *
 * <p>A copy of {@code instance} is built with {@code DataUtils.createInstance}. After each
 * regressor predicts its target, the prediction is written into that copy, so later
 * regressors in the chain see it as an input. The class index of {@code instance.dataset()}
 * is set to each chain target in turn and is left at the last one.
 *
 * @param instance the instance to predict
 * @return the predicted values, ordered like {@code labelIndices}
 * @throws Exception if a regressor fails to classify
 */
protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
    double[] scores = new double[numLabels];

    // Predictions are written into a private copy; the caller's instance keeps its values.
    Instances header = instance.dataset();
    Instance working = DataUtils.createInstance(instance, instance.weight(), instance.toDoubleArray());

    for (int step = 0; step < numLabels; step++) {
        int target = chain[step];
        header.setClassIndex(target);
        working.setDataset(header);

        // Position of this target in label order; stays 0 if the target is not found.
        int slot = 0;
        int k = 0;
        while (k < numLabels && labelIndices[k] != target) {
            k++;
        }
        if (k < numLabels) {
            slot = k;
        }

        double prediction = chainRegressors[step].classifyInstance(working);
        scores[slot] = prediction;
        // Feed the prediction forward as an input for the regressors later in the chain.
        working.setValue(target, prediction);
    }

    return new MultiLabelOutput(scores, true);
}

From source file:mulan.regressor.transformation.SingleTargetRegressor.java

License:Open Source License

/**
 * Trains one independent regressor per target.
 *
 * <p>Regressor {@code t} is a {@code FilteredClassifier} that wraps a copy of
 * {@code baseRegressor}. Its {@code Remove} filter drops every other target attribute, and
 * it is trained with {@code labelIndices[t]} as the class.
 *
 * @param mlTrainSet the multi-label training data; its dataset is copied before use
 * @throws Exception if copying the base regressor, filtering or training fails
 */
protected void buildInternal(MultiLabelInstances mlTrainSet) throws Exception {
    stRegressors = new FilteredClassifier[numLabels];
    // The class index is changed below, so operate on a private copy of the dataset.
    Instances trainSet = new Instances(mlTrainSet.getDataSet());

    for (int target = 0; target < numLabels; target++) {
        FilteredClassifier regressor = new FilteredClassifier();
        stRegressors[target] = regressor;
        regressor.setClassifier(AbstractClassifier.makeCopy(baseRegressor));

        int targetIndex = labelIndices[target];

        // Every target attribute except the current one is hidden from the regressor.
        int[] otherTargets = new int[numLabels - 1];
        int next = 0;
        for (int pos = 0; pos < numLabels; pos++) {
            int candidate = labelIndices[pos];
            if (candidate == targetIndex) {
                continue;
            }
            otherTargets[next] = candidate;
            next++;
        }

        Remove dropOthers = new Remove();
        dropOthers.setAttributeIndicesArray(otherTargets);
        dropOthers.setInvertSelection(false);
        dropOthers.setInputFormat(trainSet);
        regressor.setFilter(dropOthers);

        trainSet.setClassIndex(targetIndex);
        debug("Bulding model " + (target + 1) + "/" + numLabels);
        regressor.buildClassifier(trainSet);
    }
}

From source file:mulan.regressor.transformation.SingleTargetRegressor.java

License:Open Source License

/**
 * Predicts every target with its own single-target regressor.
 *
 * <p>Before each prediction, the class index of {@code instance.dataset()} is switched to the
 * corresponding target, and it is left at the last target afterwards.
 *
 * @param instance the instance to predict
 * @return the predicted values, ordered like {@code labelIndices}
 * @throws Exception if a regressor fails to classify
 */
protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
    double[] predictions = new double[numLabels];
    Instances header = instance.dataset();

    for (int t = 0; t < numLabels; t++) {
        // Each single-target model expects its own target to be the class attribute.
        header.setClassIndex(labelIndices[t]);
        instance.setDataset(header);
        predictions[t] = stRegressors[t].classifyInstance(instance);
    }

    return new MultiLabelOutput(predictions, true);
}

From source file:mulan.transformations.BinaryRelevanceTransformation.java

License:Open Source License

/**
 * Removes all label attributes except the label at position {@code labelToKeep}.
 *
 * <p>The labels are taken to be the last {@code numOfLabels} attributes of {@code train}.
 * Label {@code k} is attribute {@code train.numAttributes() - numOfLabels + k}. Feature
 * attributes are kept. The remaining label is then the last attribute and becomes the class
 * attribute of the result.
 *
 * @param train the multi-label data
 * @param labelToKeep 0-based position, among the labels, of the label to keep
 * @return transformed Instances object
 * @throws Exception if the Remove filter fails
 */
public Instances transformInstances(Instances train, int labelToKeep) throws Exception {
    // Indices of attributes to remove: every label except labelToKeep
    int[] indicesToRemove = new int[numOfLabels - 1];
    int firstLabel = train.numAttributes() - numOfLabels;

    int k = 0;
    for (int labelIndex = 0; labelIndex < numOfLabels; labelIndex++) {
        if (labelIndex != labelToKeep) {
            indicesToRemove[k] = firstLabel + labelIndex;
            k++;
        }
    }

    // Weka's Remove fixes its selection when setInputFormat() is called. The invert flag
    // must therefore be set before that call. Here the listed indices are the ones to drop,
    // so no inversion is wanted. (The former setInvertSelection(true) after setInputFormat()
    // was silently ignored.)
    Remove remove = new Remove();
    remove.setAttributeIndicesArray(indicesToRemove);
    remove.setInvertSelection(false);
    remove.setInputFormat(train);
    Instances result = Filter.useFilter(train, remove);
    result.setClassIndex(result.numAttributes() - 1);
    return result;
}

From source file:mulan.transformations.BinaryRelevanceTransformation.java

License:Open Source License

/**
 * Removes all label attributes except the one at attribute index {@code indexToKeep}.
 *
 * <p>Side effect: the class index of {@code train} is set to {@code indexToKeep} before
 * filtering. This is intended to keep that label as the class attribute of the returned data.
 *
 * @param train the multi-label data; its class index is modified
 * @param labelIndices attribute indices of all labels in {@code train}
 * @param indexToKeep attribute index of the label to keep; expected to occur exactly once in
 *     {@code labelIndices}
 * @return transformed Instances object
 * @throws Exception if the Remove filter fails
 */
public static Instances transformInstances(Instances train, int[] labelIndices, int indexToKeep)
        throws Exception {
    int numLabels = labelIndices.length;

    train.setClassIndex(indexToKeep);

    // Indices of attributes to remove: every label attribute except indexToKeep
    int[] indicesToRemove = new int[numLabels - 1];
    int counter2 = 0;
    for (int counter1 = 0; counter1 < numLabels; counter1++) {
        if (labelIndices[counter1] != indexToKeep) {
            indicesToRemove[counter2] = labelIndices[counter1];
            counter2++;
        }
    }

    // Weka's Remove fixes its selection when setInputFormat() is called. The invert flag
    // must therefore be set before that call. Here the listed indices are the ones to drop,
    // so no inversion is wanted. (The former setInvertSelection(true) after setInputFormat()
    // was silently ignored.)
    Remove remove = new Remove();
    remove.setAttributeIndicesArray(indicesToRemove);
    remove.setInvertSelection(false);
    remove.setInputFormat(train);
    return Filter.useFilter(train, remove);
}

From source file:mulan.transformations.IncludeLabelsTransformation.java

License:Open Source License

/**
 *
 * @param mlData multi-label data/*from   www.j  av  a  2 s  .  co  m*/
 * @return transformed instances
 * @throws Exception Potential exception thrown. To be handled in an upper level.
 */
public Instances transformInstances(MultiLabelInstances mlData) throws Exception {
    int numLabels = mlData.getNumLabels();
    labelIndices = mlData.getLabelIndices();

    // remove all labels
    Instances transformed = RemoveAllLabels.transformInstances(mlData);

    // add at the end an attribute with values the label names
    ArrayList<String> labelNames = new ArrayList<String>(numLabels);
    for (int counter = 0; counter < numLabels; counter++) {
        labelNames.add(mlData.getDataSet().attribute(labelIndices[counter]).name());
    }
    Attribute attrLabel = new Attribute("Label", labelNames);
    transformed.insertAttributeAt(attrLabel, transformed.numAttributes());

    // and at the end a binary attribute
    ArrayList<String> binaryValues = new ArrayList<String>(2);
    binaryValues.add("0");
    binaryValues.add("1");
    Attribute classAttr = new Attribute("Class", binaryValues);
    transformed.insertAttributeAt(classAttr, transformed.numAttributes());

    // add instances
    transformed = new Instances(transformed, 0);
    transformed.setClassIndex(transformed.numAttributes() - 1);
    Instances data = mlData.getDataSet();
    for (int instanceIndex = 0; instanceIndex < data.numInstances(); instanceIndex++) {
        for (int labelCounter = 0; labelCounter < numLabels; labelCounter++) {
            Instance temp;
            temp = RemoveAllLabels.transformInstance(data.instance(instanceIndex), labelIndices);
            temp.setDataset(null);
            temp.insertAttributeAt(temp.numAttributes());
            temp.insertAttributeAt(temp.numAttributes());
            temp.setDataset(transformed);
            temp.setValue(temp.numAttributes() - 2, (String) labelNames.get(labelCounter));
            if (data.attribute(labelIndices[labelCounter])
                    .value((int) data.instance(instanceIndex).value(labelIndices[labelCounter])).equals("1")) {
                temp.setValue(temp.numAttributes() - 1, "1");
            } else {
                temp.setValue(temp.numAttributes() - 1, "0");
            }
            transformed.add(temp);
        }
    }

    return transformed;
}