mulan.classifier.meta.thresholding.Meta.java Source code

Java tutorial

Introduction

Here is the source code for mulan.classifier.meta.thresholding.Meta.java

Source

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    Meta.java
 *    Copyright (C) 2009 Aristotle University of Thessaloniki, Thessaloniki, Greece
 */
package mulan.classifier.meta.thresholding;

import java.util.ArrayList;
import java.util.Collections;

import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.InvalidDataException;
import mulan.classifier.ModelInitializationException;
import mulan.classifier.meta.*;
import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.data.DataUtils;
import mulan.data.MultiLabelInstances;
import mulan.transformations.RemoveAllLabels;

import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Base class for instance-based prediction of a bipartition from
 * the labels' scores
 *
 * @author Marios Ioannou
 * @author George Sakkas
 * @author Grigorios Tsoumakas
 * @version 2010.12.14
 */

public abstract class Meta extends MultiLabelMetaLearner {
    /** the classifier to learn the number of top labels or the threshold */
    protected Classifier classifier;

    /** the training instances for the single-label model */
    protected Instances classifierInstances;

    /** the type for constructing the meta dataset*/
    protected String metaDatasetChoice;

    /**the number of folds for cross validation*/
    protected int kFoldsCV;

    /** clean multi-label learner for cross-validation  */
    protected MultiLabelLearner foldLearner;

    /**
     * Constructor that initializes the learner 
     *
     * @param baseLearner the MultiLabelLearner
     * @param aClassifier the learner that will predict the number of relevant
     *        labels or a threshold
     * @param aMetaDatasetChoice what features to use for predicting the number
     *        of relevant labels or a threshold
     */
    public Meta(MultiLabelLearner baseLearner, Classifier aClassifier, String aMetaDatasetChoice) {
        super(baseLearner);
        metaDatasetChoice = aMetaDatasetChoice;
        classifier = aClassifier;
    }

    /**
     * Returns the classifier used to predict the number of labels/threshold
     *
     * @return the classifier used to predict the number of labels/threshold
     */
    public Classifier getClassifier() {
        return classifier;
    }

    /**
     * abstract method that transforms the training data to meta data
     *
     * @param trainingData the training data set
     * @return the meta data for training the predictor of labels/threshold
     * @throws Exception
     */
    protected abstract Instances transformData(MultiLabelInstances trainingData) throws Exception;

    /**
     * A method that modify an instance
     *
     * @param instance to modified
     * @param xBased the type for constructing the meta dataset
     * @return a transformed instance for the predictor of labels/threshold
     */
    protected Instance modifiedInstanceX(Instance instance, String xBased) {
        Instance modifiedIns = null;
        MultiLabelOutput mlo = null;
        if (xBased.compareTo("Content-Based") == 0) {
            Instance tempInstance = RemoveAllLabels.transformInstance(instance, labelIndices);
            modifiedIns = DataUtils.createInstance(tempInstance, tempInstance.weight(),
                    tempInstance.toDoubleArray());
        } else if (xBased.compareTo("Score-Based") == 0) {
            double[] arrayOfScores = new double[numLabels];
            try {
                mlo = baseLearner.makePrediction(instance);
            } catch (InvalidDataException ex) {
                Logger.getLogger(Meta.class.getName()).log(Level.SEVERE, null, ex);
            } catch (ModelInitializationException ex) {
                Logger.getLogger(Meta.class.getName()).log(Level.SEVERE, null, ex);
            } catch (Exception ex) {
                Logger.getLogger(Meta.class.getName()).log(Level.SEVERE, null, ex);
            }
            arrayOfScores = mlo.getConfidences();
            modifiedIns = DataUtils.createInstance(instance, numLabels);
            for (int i = 0; i < numLabels; i++) {
                modifiedIns.setValue(i, arrayOfScores[i]);
            }
        } else { //Rank-Based
            try {
                //Rank-Based
                double[] arrayOfScores = new double[numLabels];
                mlo = baseLearner.makePrediction(instance);
                arrayOfScores = mlo.getConfidences();
                ArrayList<Double> list = new ArrayList();
                for (int i = 0; i < numLabels; i++) {
                    list.add(arrayOfScores[i]);
                }
                Collections.sort(list);
                modifiedIns = DataUtils.createInstance(instance, numLabels);
                int j = numLabels - 1;
                for (Double x : list) {
                    modifiedIns.setValue(j, x);
                    j--;
                }
            } catch (InvalidDataException ex) {
                Logger.getLogger(Meta.class.getName()).log(Level.SEVERE, null, ex);
            } catch (ModelInitializationException ex) {
                Logger.getLogger(Meta.class.getName()).log(Level.SEVERE, null, ex);
            } catch (Exception ex) {
                Logger.getLogger(Meta.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        return modifiedIns;
    }

    /**
     * Prepares the instances for the predictor of labels/threshold
     *
     * @param data the training data
     * @return the prepared instances
     */
    protected Instances prepareClassifierInstances(MultiLabelInstances data) {
        Instances temp = null;
        if (metaDatasetChoice.compareTo("Content-Based") == 0) {
            try {
                temp = RemoveAllLabels.transformInstances(data);
                temp = new Instances(temp, 0);
            } catch (Exception ex) {
                Logger.getLogger(Meta.class.getName()).log(Level.SEVERE, null, ex);
            }
        } else {
            ArrayList<Attribute> atts = new ArrayList<Attribute>();
            for (int i = 0; i < numLabels; i++) {
                atts.add(new Attribute("Label" + i));
            }
            temp = new Instances("threshold", atts, 0);
        }
        return temp;
    }

    /**
     * A method that fill the array "newValues"
     *
     * @param learner the multi-label learner
     * @param instance the training instances
     * @param newValues the array to fill
     * @param xBased the type for constructing the meta dataset
     * @throws Exception
     */
    protected void valuesX(MultiLabelLearner learner, Instance instance, double[] newValues, String xBased)
            throws Exception {
        MultiLabelOutput mlo = null;
        if (metaDatasetChoice.compareTo("Content-Based") == 0) {
            double[] values = instance.toDoubleArray();
            for (int i = 0; i < featureIndices.length; i++)
                newValues[i] = values[featureIndices[i]];
        } else if (metaDatasetChoice.compareTo("Score-Based") == 0) {
            mlo = learner.makePrediction(instance);
            double[] values = mlo.getConfidences();
            System.arraycopy(values, 0, newValues, 0, values.length);
        } else if (metaDatasetChoice.compareTo("Rank-Based") == 0) {
            mlo = learner.makePrediction(instance);
            double[] values = mlo.getConfidences();
            ArrayList<Double> list = new ArrayList();
            for (int i = 0; i < numLabels; i++) {
                list.add(values[i]);
            }
            Collections.sort(list);
            int j = numLabels - 1;
            for (Double x : list) {
                newValues[j] = x;
                j--;
            }
        }
    }

    @Override
    protected void buildInternal(MultiLabelInstances trainingData) throws Exception {

        // build the base multilabel learner from the original training data
        baseLearner.build(trainingData);

        classifierInstances = transformData(trainingData);

        // build the prediction model
        classifier.buildClassifier(classifierInstances);

        // keep just the header information
        classifierInstances = new Instances(classifierInstances, 0);
    }

}