mulan.classifier.meta.thresholding.ThresholdPrediction.java Source code

Introduction

Here is the source code for mulan.classifier.meta.thresholding.ThresholdPrediction.java, a Mulan meta-learner that predicts an instance-specific threshold for turning a multi-label learner's confidence scores into a bipartition.

Source

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    ThresholdPrediction.java
 *    Copyright (C) 2009 Aristotle University of Thessaloniki, Thessaloniki, Greece
 */
package mulan.classifier.meta.thresholding;

import java.util.ArrayList;
import java.util.Collections;
import java.util.logging.Level;
import java.util.logging.Logger;

import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.data.DataUtils;
import mulan.data.MultiLabelInstances;

import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;

/**
 * Learner that trains a meta-model to predict an instance-specific threshold
 * on the confidence scores of the underlying multi-label learner; labels whose
 * scores reach the predicted threshold form the output bipartition.
 *
 * @author Marios Ioannou
 * @author George Sakkas
 * @author Grigorios Tsoumakas
 * @version 2010.12.14
 */
public class ThresholdPrediction extends Meta {

    /**
     * Constructor that initializes the learner
     *
     * @param baseLearner the underlying multi-label learner
     * @param classifier the classifier used for predicting the threshold
     * @param metaDataChoice the type of meta-data
     * @param folds the number of internal cv folds
     */
    public ThresholdPrediction(MultiLabelLearner baseLearner, Classifier classifier, String metaDataChoice,
            int folds) {
        super(baseLearner, classifier, metaDataChoice);
        try {
            foldLearner = baseLearner.makeCopy();
        } catch (Exception ex) {
            Logger.getLogger(ThresholdPrediction.class.getName()).log(Level.SEVERE, null, ex);
        }
        kFoldsCV = folds;
    }

    @Override
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
        boolean[] predictedLabels = new boolean[numLabels];
        Instance modifiedIns = modifiedInstanceX(instance, metaDatasetChoice);

        modifiedIns.insertAttributeAt(modifiedIns.numAttributes());
        // set dataset to instance
        modifiedIns.setDataset(classifierInstances);
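        // the meta-model's numeric prediction is used as an instance-specific
        // threshold on the base learner's confidence scores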
        double bipartition_key = classifier.classifyInstance(modifiedIns);

        MultiLabelOutput mlo = baseLearner.makePrediction(instance);
        double[] arrayOfScores = mlo.getConfidences();
        for (int i = 0; i < numLabels; i++) {
            predictedLabels[i] = arrayOfScores[i] >= bipartition_key;
        }
        return new MultiLabelOutput(predictedLabels, mlo.getConfidences());
    }

    /**
     * Builds the meta-level dataset: for each training instance (taken from the
     * internal cross-validation folds when more than one fold is used) the class
     * value is set to the threshold that best separates the instance's relevant
     * labels from its irrelevant ones.
     *
     * @param trainingData the training data
     * @return the meta-level dataset used to train the threshold-predicting classifier
     * @throws Exception if the meta-level dataset cannot be constructed
     */
    public Instances transformData(MultiLabelInstances trainingData) throws Exception {
        // initialize  classifier instances

        classifierInstances = prepareClassifierInstances(trainingData);
        classifierInstances.insertAttributeAt(new Attribute("Threshold"), classifierInstances.numAttributes());
        classifierInstances.setClassIndex(classifierInstances.numAttributes() - 1);

        for (int k = 0; k < kFoldsCV; k++) {
            //Split data to train and test sets
            MultiLabelLearner tempLearner;
            MultiLabelInstances mlTest;
            if (kFoldsCV == 1) {
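                // a single fold means no internal cross-validation:
                // the base learner and the full training set are used directly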
                tempLearner = baseLearner;
                mlTest = trainingData;
            } else {
                Instances train = trainingData.getDataSet().trainCV(kFoldsCV, k);
                Instances test = trainingData.getDataSet().testCV(kFoldsCV, k);
                MultiLabelInstances mlTrain = new MultiLabelInstances(train, trainingData.getLabelsMetaData());
                mlTest = new MultiLabelInstances(test, trainingData.getLabelsMetaData());
                tempLearner = foldLearner.makeCopy();
                tempLearner.build(mlTrain);
            }

            // copy features and labels, set metalabels
            for (int instanceIndex = 0; instanceIndex < mlTest.getDataSet().numInstances(); instanceIndex++) {
                Instance instance = mlTest.getDataSet().instance(instanceIndex);

                // initialize new class values
                double[] newValues = new double[classifierInstances.numAttributes()];

                // create features
                valuesX(tempLearner, instance, newValues, metaDatasetChoice);

                boolean[] trueLabels = new boolean[numLabels];
                // read the true labels of the current test instance from the label attributes
                for (int i = 0; i < numLabels; i++) {
                    int labelIndex = labelIndices[i];
                    String classValue = mlTest.getDataSet().attribute(labelIndex)
                            .value((int) mlTest.getDataSet().instance(instanceIndex).value(labelIndex));
                    trueLabels[i] = classValue.equals("1");
                }

                MultiLabelOutput mlo = tempLearner.makePrediction(mlTest.getDataSet().instance(instanceIndex));
                double[] arrayOfScores = mlo.getConfidences();
                ArrayList<Double> list = new ArrayList<Double>();
                for (int i = 0; i < numLabels; i++) {
                    list.add(arrayOfScores[i]);
                }
                Collections.sort(list);
                // try the midpoint between consecutive sorted scores as a candidate
                // threshold and keep the one that misclassifies the fewest labels
                double tempThreshold = 0, threshold = 0;
                double prev = list.get(0);
                int t = numLabels, tempT = 0;
                for (Double x : list) {
                    tempThreshold = (x + prev) / 2;
                    for (int i = 0; i < numLabels; i++) {
                        if (trueLabels[i] && (arrayOfScores[i] <= tempThreshold)) {
                            tempT++;
                        } else if (!trueLabels[i] && (arrayOfScores[i] >= tempThreshold)) {
                            tempT++;
                        }
                    }
                    if (tempT < t) {
                        t = tempT;
                        threshold = tempThreshold;
                    }
                    tempT = 0;
                    prev = x;
                }
                newValues[newValues.length - 1] = threshold;
                // add the new instance to  classifierInstances
                Instance newInstance = DataUtils.createInstance(mlTest.getDataSet().instance(instanceIndex),
                        mlTest.getDataSet().instance(instanceIndex).weight(), newValues);
                classifierInstances.add(newInstance);
            }
        }
        return classifierInstances;
    }

    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(Type.INPROCEEDINGS);
        result.setValue(Field.AUTHOR, "Elisseeff, Andre and Weston, Jason");
        result.setValue(Field.TITLE, "A kernel method for multi-labelled classification");
        result.setValue(Field.BOOKTITLE, "Proceedings of NIPS 14");
        result.setValue(Field.YEAR, "2002");
        return result;
    }
}
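
Example usage

The following is a minimal usage sketch rather than part of the library source. It assumes Mulan's standard MultiLabelInstances and BinaryRelevance API; the dataset file names ("emotions.arff"/"emotions.xml"), the "Content-Based" meta-data choice, and the choice of M5P as the threshold-predicting regressor are illustrative assumptions to adapt to your own setup.

import java.util.Arrays;

import mulan.classifier.MultiLabelOutput;
import mulan.classifier.meta.thresholding.ThresholdPrediction;
import mulan.classifier.transformation.BinaryRelevance;
import mulan.data.MultiLabelInstances;

import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.trees.M5P;

public class ThresholdPredictionExample {

    public static void main(String[] args) throws Exception {
        // hypothetical ARFF/XML pair; replace with your own multi-label dataset
        MultiLabelInstances train = new MultiLabelInstances("emotions.arff", "emotions.xml");

        // base multi-label learner whose confidence scores will be thresholded
        BinaryRelevance base = new BinaryRelevance(new NaiveBayes());

        // the meta-level class attribute ("Threshold") is numeric, so a
        // regression-capable Weka classifier such as M5P is used; "Content-Based"
        // is an assumed meta-data choice supported by the Meta superclass
        ThresholdPrediction learner = new ThresholdPrediction(base, new M5P(), "Content-Based", 3);
        learner.build(train);

        // predict one instance and print the resulting bipartition
        MultiLabelOutput out = learner.makePrediction(train.getDataSet().instance(0));
        System.out.println(Arrays.toString(out.getBipartition()));
    }
}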