mulan.experiments.ICDM08EnsembleOfPrunedSets.java Source code

Introduction

Here is the source code for mulan.experiments.ICDM08EnsembleOfPrunedSets.java, a Mulan experiment class that replicates the Ensemble of Pruned Sets experiment from Jesse Read's ICDM 2008 paper "Multi-label Classification using Ensembles of Pruned Sets".
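
For orientation, here is a minimal sketch of the core Mulan API the experiment is built on: constructing an EnsembleOfPrunedSets learner and cross-validating it with Evaluator. The file names are placeholders and the parameter values are examples only (the experiment below selects p, the strategy and b by internal cross-validation).

import mulan.classifier.transformation.EnsembleOfPrunedSets;
import mulan.classifier.transformation.PrunedSets;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.Evaluator;
import mulan.evaluation.MultipleEvaluation;
import weka.classifiers.functions.SMO;

public class EPSQuickStart {

    public static void main(String[] args) throws Exception {
        // Hypothetical dataset files; replace with an ARFF/XML pair of your own
        MultiLabelInstances data = new MultiLabelInstances("data.arff", "labels.xml");

        // Same constructor call shape as in the experiment below; the values are examples only
        EnsembleOfPrunedSets eps =
                new EnsembleOfPrunedSets(63, 10, 0.5, 3, PrunedSets.Strategy.A, 2, new SMO());

        // 10-fold cross-validation with Mulan's default measures
        Evaluator evaluator = new Evaluator();
        MultipleEvaluation results = evaluator.crossValidate(eps, data, 10);
        System.out.println(results);
    }
}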

Source

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    ICDM08EnsembleOfPrunedSets.java
 *    Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
 */
package mulan.experiments;

import java.util.HashMap;
import java.util.LinkedList;
import java.util.Random;
import mulan.classifier.meta.thresholding.OneThreshold;
import mulan.classifier.transformation.EnsembleOfPrunedSets;
import mulan.classifier.transformation.PrunedSets;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.Evaluation;
import mulan.evaluation.Evaluator;
import mulan.evaluation.MultipleEvaluation;
import mulan.evaluation.measure.BipartitionMeasureBase;
import mulan.evaluation.measure.ExampleBasedAccuracy;
import mulan.evaluation.measure.SubsetAccuracy;
import mulan.evaluation.measure.HammingLoss;
import mulan.evaluation.measure.OneError;
import mulan.evaluation.measure.AveragePrecision;
import mulan.evaluation.measure.Measure;
import weka.classifiers.functions.SMO;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.Utils;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;

/**
 * Class replicating an experiment from a published paper
 *
 * @author Emmanouela Stachtiari
 * @author Grigorios Tsoumakas
 * @version 2010.12.10
 */
public class ICDM08EnsembleOfPrunedSets extends Experiment {

    /**
     * Main method
     *
     * @param args command line arguments
     */
    public static void main(String[] args) {

        try {

            String arffFilename = "SinaEmotion20140117.arff";
            String xmlFilename = "Emotion.xml";

            System.out.println("Loading the data set");
            MultiLabelInstances dataSet = new MultiLabelInstances(arffFilename, xmlFilename);

            // Alternatively, read the dataset location from the command line:
            // String path = Utils.getOption("path", args);
            // String filestem = Utils.getOption("filestem", args);
            // MultiLabelInstances dataSet = new MultiLabelInstances(path + filestem + ".arff", path + filestem + ".xml");

            Evaluator evaluator;

            Measure[] evaluationMeasures = new Measure[5];
            evaluationMeasures[0] = new ExampleBasedAccuracy(false);
            evaluationMeasures[1] = new HammingLoss();
            evaluationMeasures[2] = new SubsetAccuracy();
            evaluationMeasures[3] = new OneError();
            evaluationMeasures[4] = new AveragePrecision();

            HashMap<String, MultipleEvaluation> result = new HashMap<String, MultipleEvaluation>();
            for (Measure m : evaluationMeasures) {
                MultipleEvaluation me = new MultipleEvaluation();
                result.put(m.getName(), me);
            }

            Random random = new Random(1);

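            // 5 repetitions of randomized 2-fold cross-validation (the 5x2 CV protocol)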
            for (int repetition = 0; repetition < 5; repetition++) {
                // perform 2-fold CV and add each to the current results
                dataSet.getDataSet().randomize(random);
                for (int fold = 0; fold < 2; fold++) {
                    System.out.println("Experiment " + (repetition * 2 + fold + 1));
                    Instances train = dataSet.getDataSet().trainCV(2, fold);
                    MultiLabelInstances multiTrain = new MultiLabelInstances(train, dataSet.getLabelsMetaData());
                    Instances test = dataSet.getDataSet().testCV(2, fold);
                    MultiLabelInstances multiTest = new MultiLabelInstances(test, dataSet.getLabelsMetaData());

                    HashMap<String, Integer> bestP = new HashMap<String, Integer>();
                    HashMap<String, Integer> bestB = new HashMap<String, Integer>();
                    HashMap<String, PrunedSets.Strategy> bestStrategy = new HashMap<String, PrunedSets.Strategy>();
                    HashMap<String, Double> bestDiff = new HashMap<String, Double>();
                    for (Measure m : evaluationMeasures) {
                        bestDiff.put(m.getName(), Double.MAX_VALUE);
                    }

                    System.out.println("Searching parameters");
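                    // Grid search over the pruning parameter p (5 down to 2), the strategy
                    // parameter b (1 to 3) and subset strategies A/B, scored by 5-fold CV
                    // on the training portion of the current fold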
                    for (int p = 5; p > 1; p--) {
                        for (int b = 1; b < 4; b++) {
                            MultipleEvaluation innerResult = null;
                            LinkedList<Measure> measures;
                            PrunedSets ps;
                            double diff;

                            evaluator = new Evaluator();
                            ps = new PrunedSets(new SMO(), p, PrunedSets.Strategy.A, b);
                            measures = new LinkedList<Measure>();
                            for (Measure m : evaluationMeasures) {
                                measures.add(m.makeCopy());
                            }
                            System.out.print("p=" + p + " b=" + b + " strategy=A ");
                            innerResult = evaluator.crossValidate(ps, multiTrain, measures, 5);
                            for (Measure m : evaluationMeasures) {
                                System.out.print(m.getName() + ": " + innerResult.getMean(m.getName()) + " ");
                                diff = Math.abs(m.getIdealValue() - innerResult.getMean(m.getName()));
                                if (diff <= bestDiff.get(m.getName())) {
                                    bestDiff.put(m.getName(), diff);
                                    bestP.put(m.getName(), p);
                                    bestB.put(m.getName(), b);
                                    bestStrategy.put(m.getName(), PrunedSets.Strategy.A);
                                }
                            }
                            System.out.println();

                            evaluator = new Evaluator();
                            ps = new PrunedSets(new SMO(), p, PrunedSets.Strategy.B, b);
                            measures = new LinkedList<Measure>();
                            for (Measure m : evaluationMeasures) {
                                measures.add(m.makeCopy());
                            }
                            System.out.print("p=" + p + " b=" + b + " strategy=B ");
                            innerResult = evaluator.crossValidate(ps, multiTrain, measures, 5);
                            for (Measure m : evaluationMeasures) {
                                System.out.print(m.getName() + ": " + innerResult.getMean(m.getName()) + " ");
                                diff = Math.abs(m.getIdealValue() - innerResult.getMean(m.getName()));
                                if (diff <= bestDiff.get(m.getName())) {
                                    bestDiff.put(m.getName(), diff);
                                    bestP.put(m.getName(), p);
                                    bestB.put(m.getName(), b);
                                    bestStrategy.put(m.getName(), PrunedSets.Strategy.B);
                                }
                            }
                            System.out.println();
                        }
                    }

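                    // For each measure, train an EnsembleOfPrunedSets with the best parameters,
                    // let OneThreshold tune the output threshold (5-fold CV on the training data),
                    // and evaluate the tuned model on the held-out test fold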
                    for (Measure m : evaluationMeasures) {
                        System.out.println(m.getName());
                        System.out.println("Best p: " + bestP.get(m.getName()));
                        System.out.println("Best strategy: " + bestStrategy.get(m.getName()));
                        System.out.println("Best b: " + bestB.get(m.getName()));
                        EnsembleOfPrunedSets eps = new EnsembleOfPrunedSets(63, 10, 0.5, bestP.get(m.getName()),
                                bestStrategy.get(m.getName()), bestB.get(m.getName()), new SMO());
                        OneThreshold ot = new OneThreshold(eps, (BipartitionMeasureBase) m.makeCopy(), 5);
                        ot.build(multiTrain);
                        System.out.println("Best threshold: " + ot.getThreshold());
                        evaluator = new Evaluator();
                        Evaluation e = evaluator.evaluate(ot, multiTest);
                        System.out.println(e.toCSV());
                        result.get(m.getName()).addEvaluation(e);
                    }
                }
            }
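            // Aggregate the 10 evaluations (5 repetitions x 2 folds) and print statistics per measure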
            for (Measure m : evaluationMeasures) {
                System.out.println(m.getName());
                result.get(m.getName()).calculateStatistics();
                System.out.println(result.get(m.getName()));
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Returns an instance of a TechnicalInformation object, containing detailed
     * information about the technical background of this class, e.g., paper
     * reference or book this class is based on.
     *
     * @return the technical information about this class
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(Type.CONFERENCE);
        result.setValue(Field.AUTHOR, "Read, Jesse");
        result.setValue(Field.TITLE, "Multi-label Classification using Ensembles of Pruned Sets");
        result.setValue(Field.PAGES, "995-1000");
        result.setValue(Field.BOOKTITLE, "ICDM'08: Eighth IEEE International Conference on Data Mining");
        result.setValue(Field.YEAR, "2008");
        return result;
    }
}
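
The experiment above only reports evaluation statistics. If predictions from a tuned model are also needed, Mulan's standard makePrediction call applies. Below is a minimal sketch under the assumption of hypothetical file names and example parameter values, rather than the parameters the experiment selects by cross-validation:

import java.io.FileReader;
import java.util.Arrays;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.meta.thresholding.OneThreshold;
import mulan.classifier.transformation.EnsembleOfPrunedSets;
import mulan.classifier.transformation.PrunedSets;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.measure.HammingLoss;
import weka.classifiers.functions.SMO;
import weka.core.Instances;

public class EPSPredictSketch {

    public static void main(String[] args) throws Exception {
        // Hypothetical training files; replace with an ARFF/XML pair of your own
        MultiLabelInstances train = new MultiLabelInstances("train.arff", "labels.xml");

        // Build the same kind of model as the experiment, with example parameter values,
        // and let OneThreshold tune the output threshold via 5-fold cross-validation
        EnsembleOfPrunedSets eps =
                new EnsembleOfPrunedSets(63, 10, 0.5, 3, PrunedSets.Strategy.A, 2, new SMO());
        OneThreshold ot = new OneThreshold(eps, new HammingLoss(), 5);
        ot.build(train);

        // Label unseen instances (assumed to share the attribute structure of the training data)
        Instances unlabeled = new Instances(new FileReader("unlabeled.arff"));
        for (int i = 0; i < unlabeled.numInstances(); i++) {
            MultiLabelOutput output = ot.makePrediction(unlabeled.instance(i));
            System.out.println(Arrays.toString(output.getBipartition()));
        }
    }
}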