meka.classifiers.multilabel.PMCC.java Source code

Java tutorial

Introduction

Here is the source code for meka.classifiers.multilabel.PMCC.java

Source

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package meka.classifiers.multilabel;

import meka.core.A;
import meka.core.MLUtils;
import meka.core.OptionUtils;
import weka.classifiers.AbstractClassifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Utils;

import java.util.*;

/**
 * PMCC.java - Like MCC but creates a population of M chains at training time (from Is <i>candidate</i> chains, using Monte Carlo sampling), and uses this population for inference at test time; If you are looking for a 'more typical' majority-vote ensemble method, use something like EnsembleML or BaggingML with MCC.
 *
 * <p>
 * <b>NOTE:</b> this implementation was faster, because the chain was only rebuilt from the first node which was different -- this is no longer the case (due to updates to way classifier chains works, using the CNode class).
 * </p>
 *
 * @see meka.classifiers.multilabel.MCC
 * @author Jesse Read
 * @version   Sep 2014
 */
public class PMCC extends MCC {

    private static final long serialVersionUID = 1999206808758133267L;
    protected int m_M = 10;
    protected int m_O = 0;
    protected double m_Beta = 0.03;

    protected CC h[] = null;
    protected double[] w = null;

    public PMCC() {
        // a new default
        super.setChainIterations(50);
    }

    /**
     * MatchedUpTo - returns the index i of the first character which differs between two strings.
     * TODO this is a generic function, could go somewhere else into utils.
     */
    private static int matchedUpto(String s1, String s2, String DELIM) {
        String s_1[] = s1.split(DELIM);
        String s_2[] = s2.split(DELIM);
        int i = 0;
        while (i < s_1.length && i < s_2.length && s_1[i].equals(s_2[i]))
            i++;
        return i;
    }

    /**
     * GetClosest - returns the 'CC' in 'map' which is built on the sequence most matched to 'sequence'.
     */
    protected static CC getClosest(HashMap<String, CC> map, String sequence) {
        int score = -1;
        String best = sequence;
        for (String key : map.keySet()) {
            int score_ = matchedUpto(key, sequence, ",");
            if (score_ > score) {
                score = score_;
                best = key;
            }
        }
        return map.get(best);
    }

    /**
     * RebuildCC - rebuild a classifier chain 'h_old' to have a new sequence 's_new'.
     */
    protected CC rebuildCC(CC h_old, int s_new[], Instances D) throws Exception {

        // make a deep copy
        CC h = (CC) AbstractClassifier.makeCopy(h_old);

        // rebuild this chain
        h.rebuildClassifier(s_new, new Instances(D));
        return h;
    }

    /** 
     * BuildCC - Build a CC of chain-order 's' on dataset 'D'. 
     */
    protected CC buildCC(int s[], Instances D) throws Exception {

        // a new classifier chain
        CC h = new CC();

        // build this chain
        h.prepareChain(s);
        h.setClassifier(m_Classifier);
        h.buildClassifier(new Instances(D));
        return h;
    }

    /**
     * pi - proposal distribution; swap elements in s, depending on iteration t (temperature).
     * <br>
     * TODO - make faster!
     * @param   s      a chain sequence
     * @param   r        a random number generator
     * @param   t      the current iteration
     * @return    s' ~ p(s'|s)
     */
    public static int[] pi(int s[], Random r, int t, double beta) {
        int L = s.length;

        System.out.println("--- t = " + t + " , Beta = " + beta + "---");

        // select some entry j
        double p[] = new double[s.length];
        for (int j = 0; j < L; j++) {
            p[j] = Math.pow((1. / L), beta * t / (1 + j));
        }
        Utils.normalize(p);
        int j = A.samplePMF(p, r);
        System.out.println("elect j=" + j + " from pmf: " + A.toString(p));

        // blank out the j-th entry, and renormalize, now select k
        p[j] = 0.0;
        Utils.normalize(p);
        int k = A.samplePMF(p, r);
        System.out.println("elect k=" + k + " from pmf: " + A.toString(p));

        // swap j and k
        return A.swap(s, j, k);
    }

    @Override
    public void buildClassifier(Instances D) throws Exception {

        m_R = new Random(m_S);

        // Variables

        int L = D.classIndex();
        int N = D.numInstances();
        int d = D.numAttributes() - L;

        h = new CC[m_M];
        w = new double[m_M];
        //int s[][] = new int[m_M][L]; // for interest's sake

        if (m_Is >= m_M) {

            //HashMap<String,CC> id2cc = new HashMap<String,CC>();

            // Make CC
            int s[] = MLUtils.gen_indices(L);
            MLUtils.randomize(s, m_R);
            h[0] = buildCC(Arrays.copyOf(s, s.length), D); // @todo move into setChain(..)
            w[0] = payoff(h[0], D);
            //id2cc.put(Arrays.toString(s),h[0]);         // save a copy
            //s[0] = s_;
            if (getDebug())
                System.out.println("s[0] = " + Arrays.toString(s));

            for (int t = 0; t < m_Is; t++) {

                // propose a chain s' ~ pi(s'|s) 
                int s_[] = (m_O > 0) ? pi(Arrays.copyOf(s, s.length), m_R, t, m_Beta) : // default cond. option - with temperature
                        A.swap(Arrays.copyOf(s, s.length), m_R); // special simple option - swap two elements

                // build h' with sequence s'
                //CC h_ = rebuildCC(getClosest(id2cc,Arrays.toString(s_)),s_,D);
                CC h_ = buildCC(Arrays.copyOf(s_, s_.length), D);
                //id2cc.put(Arrays.toString(s_), h_);

                // rate h' (by its performance on the training data)
                double w_ = payoff(h_, D);

                // accept h' weighted more than the weakest h in the population
                int min = Utils.sort(w)[0]; // (min index)
                if (w_ > w[min]) {
                    w[min] = w_;
                    h[min] = h_;
                    if (getDebug())
                        System.out.println(" accepted h_ with score " + w_ + " > " + w[min]);
                    s = s_;
                } else if (getDebug())
                    System.out.println(" DENIED h_ with score " + w_ + " !> score " + w[min]);
            }
            if (getDebug())
                System.out.println("---");

            // normalise weights
            Utils.normalize(w);
        } else {
            throw new Exception(
                    "[Error] Number of chains evaluated (Is) should be at least as great as the population selected (M), and always greater than 0.");
        }

    }

    @Override
    public double[] distributionForInstance(Instance x) throws Exception {

        // Start with a good guess
        int max = Utils.maxIndex(w);
        double y[] = h[max].distributionForInstance(x);
        double wm = A.product(h[max].probabilityForInstance(x, y));

        for (int t = 0; t < m_Iy; t++) {
            // m ~ p(m|w) 
            int m = A.samplePMF(w, m_R);
            // y ~ p(y|x,m)
            double y_[] = h[m].sampleForInstance(x, m_R); // <-- TODO: can do this faster, see #MCC
            // w = prod_j p(y[j]|x,m)
            double w_ = A.product(h[m].getConfidences());
            // accept ? 
            if (w_ > wm) {
                wm = w_;
                y = y_;
            }
        }

        return y;
    }

    @Override
    public Enumeration listOptions() {
        Vector result = new Vector();
        result.addElement(new Option(
                "\tThe population size (of chains) -- should be smaller than the total number of chains evaluated (Is) \n\tdefault: 10",
                "M", 1, "-M <value>"));
        result.addElement(new Option(
                "\tUse temperature: cool the chain down over time (from the beginning of the chain) -- can be faster\n\tdefault: 0 (no temperature)",
                "O", 1, "-O <value>"));
        result.addElement(
                new Option("\tIf using O = 1 for temperature, this sets the Beta constant      \n\tdefault: 0.03",
                        "B", 1, "-B <value>"));
        OptionUtils.add(result, super.listOptions());
        return OptionUtils.toEnumeration(result);
    }

    @Override
    public void setOptions(String[] options) throws Exception {
        setM(OptionUtils.parse(options, 'M', 10));
        setO(OptionUtils.parse(options, 'O', 0));
        setBeta(OptionUtils.parse(options, 'B', 0.03));
        super.setOptions(options);
    }

    @Override
    public String[] getOptions() {
        List<String> result = new ArrayList<>();
        OptionUtils.add(result, 'M', getM());
        OptionUtils.add(result, 'O', getO());
        OptionUtils.add(result, 'B', getBeta());
        OptionUtils.add(result, super.getOptions());
        return OptionUtils.toArray(result);
    }

    /** Set the temperature factor  */
    public void setBeta(double t) {
        m_Beta = t;
    }

    /** Get the temperature factor */
    public double getBeta() {
        return m_Beta;
    }

    public String betaTipText() {
        return "Sets the temperature factor.";
    }

    /** Set the temperature switch  */
    public void setO(int t) {
        m_O = t;
    }

    /** Get the temperature switch */
    public int getO() {
        return m_O;
    }

    public String oTipText() {
        return "Sets the temperature switch.";
    }

    /** Set the population size */
    public void setM(int M) {
        m_M = M;
    }

    /** Get the population size */
    public int getM() {
        return m_M;
    }

    public String mTipText() {
        return "Sets the population size.";
    }

    @Override
    public String globalInfo() {
        return "PMCC - Like MCC but selects the top M chains at training time, and uses all them at test time (using Monte Carlo sampling -- this is not a typical majority-vote ensemble method). For more information see:\n"
                + getTechnicalInformation().toString();
    }

    public static void main(String args[]) {
        ProblemTransformationMethod.evaluation(new PMCC(), args);
    }

}