Pair.java Source code


Introduction

Here is the source code for Pair.java. Note that the file defines both the helper class Pair and the public class TwoStageTrAdaBoostR2, so to compile it should be saved as TwoStageTrAdaBoostR2.java.

Source

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*  This code is based on:      
 *    TrAdaBoost.java
 *    Copyright (C) 2007 Raymond J. Mooney
 *  Modified by David Pardoe, 2012
 */

/*
*
* Valid options are:<p>
*
* -W classname <br>
* Specify the full class name of a classifier as the basis for 
* boosting (required).<p>
*
* -I num <br>
* Set the number of (second stage) boost iterations (default 10). <p>
*
* -II num <br>
* Set the number of first stage boost iterations; if not set, the value from -I will be used. <p>
*
* -S filename <br>
* Data file to use as source data for transfer. <p>
*
* -Ratio num <br>
* Ratio of source to target starting weight; default 2; if negative, then all instances are weighted equally <p>
*
* -F <br>
* Decrease source weights uniformly in stage 1. <p>
* 
* -B <br>
* Use bagging in stage 2, not AdaBoost.R2 <p>
* 
* -R <br>
* Resample training data at each iteration (necessary if the base learner is not a WeightedInstancesHandler) <p>
*
* -U <br>
* Handle source weights as usual in stage 2 instead of fixing them. <p>
*
* -M <br>
* Use weighted median instead of mean <p>
* 
* -TT filename <br>
* The test file for which errors will be printed after each iteration (required). <p>
* 
* -FB <br>
* Use a fixed value of beta.  The theoretical justification for beta in AdaBoost.M1 does not carry over to AdaBoost.R2, and using a fixed value can actually improve performance and stability, especially if cross validation is used to control the number of iterations. <p>
*
* -SS <br>
* When resampling in stage 2, trim each sample to an effective size determined by the current source fraction. <p>
*
* Options after -- are passed to the designated classifier.<p>
*
*/
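
/*
 * Example invocation (a sketch only; the data file names and the choice of
 * base classifier below are placeholders, and weka.jar must be on the
 * classpath):
 *
 *   java TwoStageTrAdaBoostR2 -t target.arff -TT test.arff -S source.arff \
 *        -I 10 -W weka.classifiers.trees.REPTree
 *
 * The -t option and the other standard evaluation options are handled by
 * weka.classifiers.Evaluation in main(); -S and -TT are required by this
 * class as written.
 */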

import weka.classifiers.*;
import java.io.*;
import java.util.*;

import weka.core.*;

class Pair implements Comparable<Pair> {
    double x, y;

    public int compareTo(Pair p) {
        // order pairs by their x value (the prediction)
        return Double.compare(x, p.x);
    }
}

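/**
 * Two-stage boosting for regression transfer (a summary of the code below).
 * Stage 1 (doCV) gradually shifts weight away from the source instances,
 * either uniformly (-F) or via the binary search in setWeights, which finds
 * the factor that brings the total source weight down to the desired
 * fraction. Stage 2 (buildClassifierWithWeights) runs AdaBoost.R2 on the
 * reweighted data, keeping the source weights fixed unless -U is given.
 * The test error is printed after each stage-1 iteration so that the number
 * of iterations can be chosen by cross validation.
 */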
public class TwoStageTrAdaBoostR2 extends IteratedSingleClassifierEnhancer {

    private boolean resample = false;
    private double sourceRatio = 2;
    private boolean doFraction = false;
    private boolean doUpsource = false;
    private boolean doBagging = false;
    private boolean doSampleSize = false;
    private boolean fixedBeta = false;

    double sourceFraction;
    int numTargetInstances;

    boolean useMedian = false;
    private Instances testData;

    boolean allcvDone = false;
    double[][] outputs;
    int sourceIterations = -1;
    double targetWeight = 0;

    int finalStartIt = -1;
    int finalStopIt = -1;

    protected double betaSum;

    /** The source data file to transfer from */
    protected File m_SourceFile = new File("-");

    /** The Instances from the source domain to transfer from */
    protected Instances m_SourceInstances;

    /** Use source data as normal training data */
    protected boolean m_NormalSource;

    /** Array for storing the weights for the votes. */
    protected double[] m_Betas;

    /** The number of successfully generated base classifiers. */
    protected int m_NumIterationsPerformed;

    /** The number of classes */
    protected int m_NumClasses;

    /**
     * Constructor.
     */
    public TwoStageTrAdaBoostR2() {

        m_Classifier = new weka.classifiers.trees.DecisionStump();
    }

    public void setOptions(String[] options) throws Exception {

        String sourceFileName = Utils.getOption('S', options);
        if (sourceFileName.length() == 0) {
            throw new Exception("A filename must be specified with" + " the -S option.");
        } else {
            setSourceFile(new File(sourceFileName));
        }

        doFraction = (Utils.getFlag('F', options));
        doBagging = (Utils.getFlag('B', options));
        doUpsource = (Utils.getFlag('U', options));
        useMedian = (Utils.getFlag('M', options));
        resample = (Utils.getFlag('R', options));
        doSampleSize = Utils.getFlag("SS", options);
        fixedBeta = Utils.getFlag("FB", options);

        String optionString = Utils.getOption("TT", options);
        if (optionString.length() == 0) {
            throw new Exception("A test filename must be specified with the -TT option.");
        }
        testData = new Instances(new BufferedReader(new FileReader(optionString)));
        testData.setClassIndex(testData.numAttributes() - 1);

        String r = Utils.getOption("Ratio", options);
        if (!r.equals(""))
            sourceRatio = Double.parseDouble(r);

        super.setOptions(options);

        r = Utils.getOption("II", options);
        if (!r.equals(""))
            sourceIterations = Integer.parseInt(r);
        else
            sourceIterations = m_NumIterations;

    }

    /**
     * Set the data file for the source data for transfer
     * and get the source Instances from the file
     */
    public void setSourceFile(File sourceFile) throws Exception {

        m_SourceFile = sourceFile;
        m_SourceInstances = new Instances(new BufferedReader(new FileReader(sourceFile)));
        m_SourceInstances.setClassIndex(m_SourceInstances.numAttributes() - 1);
    }

    /**
     * Boosting method.
     *
     * @param data the training data to be used for generating the
     * boosted classifier.
     * @exception Exception if the classifier could not be built successfully
     */

    public void buildClassifier(Instances data) throws Exception {

        super.buildClassifier(data);

        if (data.checkForStringAttributes()) {
            throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
        }
        data = new Instances(data);
        data.deleteWithMissingClass();
        if (data.numInstances() == 0) {
            throw new Exception("No train instances without class missing!");
        }
        if (!data.classAttribute().isNumeric()) {
            throw new UnsupportedClassTypeException("TrAdaBoostR2 can only handle a numeric class!");
        }
        if (m_SourceInstances == null) {
            throw new Exception("Source data has not been specified!");
        }

        m_NumClasses = data.numClasses();
        try {
            doCV(data);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    private void doCV(Instances targetData) throws Exception {
        System.out.println();
        System.out.flush();
        int numSourceInstances = m_SourceInstances.numInstances();
        int numInstances = targetData.numInstances() + numSourceInstances;
        numTargetInstances = numInstances - numSourceInstances;
        double weightSource, weightTarget;
        double initialSourceFraction;
        double[] weights = new double[numInstances];
        Random randomInstance = new Random(1);

        Instances data = new Instances(m_SourceInstances, 0, numSourceInstances);
        // Now add the target data; Instances.add copies each instance as it is
        // added, so the reweighting below will not affect the caller's data
        Enumeration enumer = targetData.enumerateInstances();
        while (enumer.hasMoreElements()) {
            Instance instance = (Instance) enumer.nextElement();
            data.add(instance);
        }

        if (sourceRatio < 0) { //weight all equally
            weightSource = weightTarget = 1.0/*/numInstances*/;
            initialSourceFraction = numSourceInstances / (double) numInstances;
        } else {
            double totalWeight = 1 + sourceRatio;
            weightSource = sourceRatio / totalWeight/*/numSourceInstances*/;
            weightTarget = 1.0 / totalWeight/*/numTargetInstances*/;
            initialSourceFraction = weightSource;
        }
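        // e.g. with the default -Ratio of 2: weightSource = 2/3 and
        // weightTarget = 1/3, so each source instance starts with twice the
        // weight of each target instance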
        for (int j = 0; j < numInstances; j++) {
            Instance instance = data.instance(j);
            if (j < numSourceInstances)
                instance.setWeight(weightSource);
            else
                instance.setWeight(weightTarget);
        }

        if (doFraction) {
            for (int it = 0; it < sourceIterations/*m_NumIterations*/; it++) {

                sourceFraction = (1 - (it / (double) m_NumIterations)) * initialSourceFraction; //[same weights as regular]
                if (sourceFraction > .995)
                    sourceFraction = .995;
                //double sourceWeight = (sourceFraction * numInstances) / numSourceInstances;
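                // choose sourceWeight so that, with each target instance at
                // weight 1, the source instances carry exactly sourceFraction
                // of the total weight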
                double sourceWeight = (sourceFraction * numTargetInstances)
                        / (numSourceInstances * (1 - sourceFraction));
                for (int j = 0; j < numInstances; j++) {
                    Instance instance = data.instance(j);
                    if (j < numSourceInstances)
                        instance.setWeight(sourceWeight);
                    else
                        instance.setWeight(1);
                }
                buildClassifierWithWeights(data);
                System.out.println("Iteration " + it + ":" + getTestError());
            }
        } else {

            for (int i = 0; i < numInstances; i++)
                weights[i] = data.instance(i).weight();
            buildClassifierWithWeights(data);
            System.out.println("Iteration -1:" + getTestError());
            for (int i = 0; i < numInstances; i++)
                data.instance(i).setWeight(weights[i]);

            for (int it = 0; it < sourceIterations; it++) {

                Instances sample = null;
                if (!resample || m_NumIterationsPerformed == 0) {
                    sample = data;
                } else {
                    double sum = data.sumOfWeights();
                    double[] sweights = new double[data.numInstances()];
                    for (int i = 0; i < sweights.length; i++) {
                        sweights[i] = data.instance(i).weight() / sum;
                    }
                    sample = data.resampleWithWeights(randomInstance, sweights);
                }

                try {
                    m_Classifiers[it].buildClassifier(sample);
                } catch (Exception e) {
                    e.printStackTrace();
                    System.out.println("E: " + e);
                }

                sourceFraction = initialSourceFraction * (1 - (it + 1) / (double) m_NumIterations);
                setWeights(data, m_Classifiers[it], sourceFraction, numSourceInstances, false);

                for (int i = 0; i < numInstances; i++)
                    weights[i] = data.instance(i).weight();

                buildClassifierWithWeights(data);

                System.out.println("Iteration " + it + ":" + getTestError());

                for (int i = 0; i < numInstances; i++)
                    data.instance(i).setWeight(weights[i]);

            }

        }

    }

    /**
     * Boosting method. Boosts any classifier that can handle weighted
     * instances.
     *
     * @param data the training data to be used for generating the
     * boosted classifier.
     * @exception Exception if the classifier could not be built successfully
     */
    protected void buildClassifierWithWeights(Instances data) throws Exception {

        Random randomInstance = new Random(0);
        double beta = 0;
        Instances sample;
        // Initialize data
        m_Betas = new double[m_Classifiers.length];
        m_NumIterationsPerformed = 0;
        int numSourceInstances = m_SourceInstances.numInstances();

        // Do boosting iterations (resampling if -R is set)
        for (m_NumIterationsPerformed = 0; m_NumIterationsPerformed < m_Classifiers.length; m_NumIterationsPerformed++) {
            // Build the classifier
            sample = null;
            if (!resample || m_NumIterationsPerformed == 0) {
                sample = data;
            } else {
                double sum = data.sumOfWeights();
                double[] weights = new double[data.numInstances()];
                for (int i = 0; i < weights.length; i++) {
                    weights[i] = data.instance(i).weight() / sum;
                }
                sample = data.resampleWithWeights(randomInstance, weights);

                if (doSampleSize) {
                    int effectiveInstances = (int) (sourceFraction * weights.length + numTargetInstances);
                    if (effectiveInstances > numSourceInstances + numTargetInstances)
                        effectiveInstances = numSourceInstances + numTargetInstances;
                    //System.out.println(effectiveInstances);               
                    sample.randomize(randomInstance);
                    Instances q = new Instances(sample, 0, effectiveInstances);
                    sample = q;
                }
            }
            try {
                m_Classifiers[m_NumIterationsPerformed].buildClassifier(sample);
            } catch (Exception e) {
                e.printStackTrace();
                System.out.println("E: " + e);
            }

            if (doBagging)
                beta = 0.4 / .6; //always same beta
            else
                beta = setWeights(data, m_Classifiers[m_NumIterationsPerformed], -1, numSourceInstances, true);

            // Stop and ignore this model if setWeights signals a degenerate
            // error (it returns a negative beta in that case)
            if (beta < 0) {
                if (m_NumIterationsPerformed == 0) {
                    m_NumIterationsPerformed = 1; // if this is the first model, we have to use it
                }
                break;
            }

            // Determine the weight to assign to this model
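            // As in AdaBoost.R2, the model's vote is ln(1/beta): a lower
            // weighted loss gives a smaller beta and hence a larger vote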

            m_Betas[m_NumIterationsPerformed] = Math.log(1 / beta);

        }

        betaSum = 0;

        for (int i = 0; i < m_NumIterationsPerformed; i++)
            betaSum += m_Betas[i];
    }

    private double getTestError() throws Exception {

        Evaluation evaluation;
        evaluation = new Evaluation(testData);
        evaluation.evaluateModel(this, testData);
        return evaluation.errorRate();

    }

    /**
     * Sets the weights for the next iteration.
     */
    protected double setWeights(Instances trainData, Classifier cls, double sourceFraction, int numSourceInstances,
            boolean isFinal) throws Exception {

        Enumeration enu = trainData.enumerateInstances();
        int instNum = 0;
        double[] errors = new double[trainData.numInstances()];
        double max = 0;
        int i = 0;
        while (enu.hasMoreElements()) {
            Instance instance = (Instance) enu.nextElement();
            errors[i] = Math.abs(cls.classifyInstance(instance) - instance.classValue());
            if (i >= numSourceInstances && errors[i] > max)
                max = errors[i];
            i++;
        }

        if (max == 0)
            return -1;

        //get avg loss
        double loss = 0;
        double initialTWeightSum = 0;
        double allWeightSum = 0;
        for (int j = 0; j < errors.length; j++) {
            errors[j] /= max;
            Instance instance = trainData.instance(j);
            loss += instance.weight() * errors[j];
            if (j >= numSourceInstances) {
                //loss += instance.weight() * errors[j];
                initialTWeightSum += instance.weight();
            }
            allWeightSum += instance.weight();
        }
        //loss /= weightSum;
        loss /= allWeightSum;

        targetWeight = initialTWeightSum / allWeightSum;
        /*
        if (!isFinal){
        System.out.println("Target weight: " + targetWeight);
        System.out.println("max: " + max);
        System.out.println("avg error: " + loss * max);
        System.out.println("Loss: " + loss);
        }
        */

        double beta;

        if (fixedBeta)
            beta = 0.4 / 0.6;
        else {
            if (isFinal && loss > 0.499) // a loss above 0.5 would normally stop boosting,
                //return -1;
                loss = 0.499; // but since we are doing CV there is no reason to quit

            beta = loss / (1 - loss); // or just use beta = .4/.6, since beta isn't as meaningful in AdaBoost.R2
        }

        double tWeightSum = 0;
        if (!isFinal) {
            //find b so that the total source weight becomes sourceFraction times
            //the total weight (errors.length); the weight sum is monotone in b,
            //so use binary search
            double goal = sourceFraction * errors.length;
            double bMin = .001;
            double bMax = .999;
            double b;
            double sourceSum = 0;
            while (bMax - bMin > .001) {
                b = (bMax + bMin) / 2;
                double sum = 0;
                for (int j = 0; j < numSourceInstances; j++) {
                    Instance instance = trainData.instance(j);
                    sum += Math.pow(b, errors[j]) * instance.weight();
                }
                if (sum > goal)
                    bMax = b;
                else
                    bMin = b;
            }
            b = (bMax + bMin) / 2;
            //System.out.println(b);
            for (int j = 0; j < numSourceInstances; j++) {
                Instance instance = trainData.instance(j);
                instance.setWeight(instance.weight() * Math.pow(b, errors[j]));
                sourceSum += instance.weight();
            }

            //now adjust target weights
            goal = errors.length - sourceSum;
            double m = goal / initialTWeightSum;

            for (int j = numSourceInstances; j < errors.length; j++) {
                Instance instance = trainData.instance(j);
                instance.setWeight(instance.weight() * m);
            }
        } else {//final
            if (!doUpsource) { //modify only target weights
                for (int j = numSourceInstances; j < errors.length; j++) {
                    Instance instance = trainData.instance(j);
                    instance.setWeight(instance.weight() * Math.pow(beta, -errors[j]));
                    tWeightSum += instance.weight();
                }

                double weightSumInverse = initialTWeightSum / tWeightSum;
                for (int j = numSourceInstances; j < errors.length; j++) {
                    Instance instance = trainData.instance(j);
                    instance.setWeight(instance.weight() * weightSumInverse);
                }
            } else { //modify all weights
                for (int j = 0; j < errors.length; j++) {
                    Instance instance = trainData.instance(j);
                    instance.setWeight(instance.weight() * Math.pow(beta, -errors[j]));
                    tWeightSum += instance.weight();
                }

                double weightSumInverse = errors.length / tWeightSum;
                for (int j = 0; j < errors.length; j++) {
                    Instance instance = trainData.instance(j);
                    instance.setWeight(instance.weight() * weightSumInverse);
                }
            }

        }

        return beta;
    }

    public double classifyInstance(Instance inst) throws Exception {
        if (useMedian)
            return classifyInstanceMedian(inst);
        else
            return classifyInstanceMean(inst);
    }

    public double classifyInstanceMean(Instance inst) throws Exception {
        double sum = 0;
        for (int i = 0; i < m_NumIterationsPerformed; i++) {
            double q = m_Classifiers[i].classifyInstance(inst);
            sum += q * m_Betas[i];
        }

        return sum / betaSum;

    }

    public double classifyInstanceMedian(Instance inst) throws Exception {
        double invBetaSum = 1.0 / betaSum; // or try m_NumIterationsPerformed
        Pair[] p = new Pair[m_NumIterationsPerformed];
        for (int i = 0; i < m_NumIterationsPerformed; i++) {
            p[i] = new Pair();
            p[i].x = m_Classifiers[i].classifyInstance(inst);
            p[i].y = invBetaSum * m_Betas[i]; //or try 1
        }
        return getWeightedMedian(p);

    }

    public double getWeightedMedian(Pair[] p) {
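        // weighted median: sort the predictions, then walk up the sorted list
        // until the accumulated (normalized) vote weight reaches one half and
        // return that prediction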
        Arrays.sort(p);
        double sum = 0;
        int pos = -1;
        while (sum < 0.5) {
            pos++;
            sum += p[pos].y;
        }
        return p[pos].x;
    }

    /**
     * Returns description of the boosted classifier.
     *
     * @return description of the boosted classifier as a string
     */
    public String toString() {

        StringBuffer text = new StringBuffer();

        if (m_NumIterationsPerformed == 0) {
            text.append("TrAdaBoost: No model built yet.\n");
        } else if (m_NumIterationsPerformed == 1) {
            text.append("TrAdaBoost: No boosting possible, one classifier used!\n");
            text.append(m_Classifiers[0].toString() + "\n");
        } else {
            text.append("TrAdaBoost: Base classifiers and their weights: \n\n");
            for (int i = 0; i < m_NumIterationsPerformed; i++) {
                text.append(m_Classifiers[i].toString() + "\n\n");
                text.append("Weight: " + Utils.roundDouble(m_Betas[i], 2) + "\n\n");
            }
            text.append("Number of performed Iterations: " + m_NumIterationsPerformed + "\n");
        }

        return text.toString();
    }

    /**
     * Main method for testing this class.
     *
     * @param argv the options
     */
    public static void main(String[] argv) {

        try {
            System.out.println(Evaluation.evaluateModel(new TwoStageTrAdaBoostR2(), argv));
        } catch (Exception e) {
            System.out.println(e.getMessage());
        }
    }
}