weka.filters.supervised.instance.SpreadSubsample.java Source code

Java tutorial

Introduction

Here is the source code for weka.filters.supervised.instance.SpreadSubsample.java

Source

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    SpreadSubsample.java
 *    Copyright (C) 2002-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.filters.supervised.instance;

import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Random;
import java.util.Vector;

import weka.core.*;
import weka.core.Capabilities.Capability;
import weka.filters.Filter;
import weka.filters.SupervisedFilter;
import weka.gui.ProgrammaticProperty;

/**
 * <!-- globalinfo-start --> Produces a random subsample of a dataset. The
 * original dataset must fit entirely in memory. This filter allows you to
 * specify the maximum "spread" between the rarest and most common class. For
 * example, you may specify that there be at most a 2:1 difference in class
 * frequencies. When used in batch mode, subsequent batches are NOT resampled.
 * <p/>
 * <!-- globalinfo-end -->
 * 
 * <!-- options-start --> Valid options are:
 * <p/>
 * 
 * <pre>
 * -S &lt;num&gt;
 *  Specify the random number seed (default 1)
 * </pre>
 * 
 * <pre>
 * -M &lt;num&gt;
 *  The maximum class distribution spread.
 *  0 = no maximum spread, 1 = uniform distribution, 10 = allow at most
 *  a 10:1 ratio between the classes (default 0)
 * </pre>
 * 
 * <pre>
 * -W
 *  Adjust weights so that total weight per class is maintained.
 *  Individual instance weighting is not preserved. (default no
 *  weights adjustment
 * </pre>
 * 
 * <pre>
 * -X &lt;num&gt;
 *  The maximum count for any class value (default 0 = unlimited).
 * </pre>
 * 
 * <!-- options-end -->
 * 
 * @author Stuart Inglis (stuart@reeltwo.com)
 * @version $Revision$
 **/
public class SpreadSubsample extends Filter
        implements SupervisedFilter, OptionHandler, Randomizable, WeightedAttributesHandler {

    /** for serialization */
    static final long serialVersionUID = -3947033795243930016L;

    /** The random number generator seed */
    private int m_RandomSeed = 1;

    /** The maximum number of instances kept for any class (0 = unlimited) */
    private int m_MaxCount;

    /**
     * The maximum allowed ratio between the most common and the rarest class
     * (0 = no maximum spread, 1 = uniform distribution).
     */
    private double m_DistributionSpread = 0;

    /**
     * True if instance weights will be adjusted to maintain total weight per
     * class.
     */
    private boolean m_AdjustWeights = false;

    /**
     * Returns a string describing this filter
     * 
     * @return a description of the filter suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String globalInfo() {

        return "Produces a random subsample of a dataset. The original dataset must "
                + "fit entirely in memory. This filter allows you to specify the maximum "
                + "\"spread\" between the rarest and most common class. For example, you may "
                + "specify that there be at most a 2:1 difference in class frequencies. "
                + "When used in batch mode, subsequent batches are NOT resampled.";
    }

    /**
     * Returns the tip text for this property
     * 
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String adjustWeightsTipText() {
        return "Wether instance weights will be adjusted to maintain total weight per " + "class.";
    }

    /**
     * Returns true if instance weights will be adjusted to maintain total weight
     * per class.
     * 
     * @return true if instance weights will be adjusted to maintain total weight
     *         per class.
     */
    public boolean getAdjustWeights() {
        return m_AdjustWeights;
    }

    /**
     * Sets whether the instance weights will be adjusted to maintain total weight
     * per class.
     * 
     * @param newAdjustWeights whether to adjust weights
     */
    public void setAdjustWeights(boolean newAdjustWeights) {
        m_AdjustWeights = newAdjustWeights;
    }

    /**
     * Returns an enumeration describing the available options.
     * 
     * @return an enumeration of all the available options.
     */
    @Override
    public Enumeration<Option> listOptions() {

        Vector<Option> newVector = new Vector<Option>(4);

        newVector.addElement(new Option("\tSpecify the random number seed (default 1)", "S", 1, "-S <num>"));
        newVector.addElement(new Option("\tThe maximum class distribution spread.\n"
                + "\t0 = no maximum spread, 1 = uniform distribution, 10 = allow at most\n"
                + "\ta 10:1 ratio between the classes (default 0)", "M", 1, "-M <num>"));
        newVector.addElement(new Option("\tAdjust weights so that total weight per class is maintained.\n"
                + "\tIndividual instance weighting is not preserved. (default no\n" + "\tweights adjustment", "W",
                0, "-W"));
        // -X consumes a value (see setOptions), so it is declared with one argument.
        newVector.addElement(new Option("\tThe maximum count for any class value (default 0 = unlimited).\n", "X",
                1, "-X <num>"));

        return newVector.elements();
    }

    /**
     * Parses a given list of options.
     * <p/>
     * 
     * <!-- options-start --> Valid options are:
     * <p/>
     * 
     * <pre>
     * -S &lt;num&gt;
     *  Specify the random number seed (default 1)
     * </pre>
     * 
     * <pre>
     * -M &lt;num&gt;
     *  The maximum class distribution spread.
     *  0 = no maximum spread, 1 = uniform distribution, 10 = allow at most
     *  a 10:1 ratio between the classes (default 0)
     * </pre>
     * 
     * <pre>
     * -W
     *  Adjust weights so that total weight per class is maintained.
     *  Individual instance weighting is not preserved. (default no
     *  weights adjustment
     * </pre>
     * 
     * <pre>
     * -X &lt;num&gt;
     *  The maximum count for any class value (default 0 = unlimited).
     * </pre>
     * 
     * <!-- options-end -->
     * 
     * @param options the list of options as an array of strings
     * @throws Exception if an option is not supported
     */
    @Override
    public void setOptions(String[] options) throws Exception {

        String seedString = Utils.getOption('S', options);
        setRandomSeed(seedString.length() != 0 ? Integer.parseInt(seedString) : 1);

        String spreadString = Utils.getOption('M', options);
        setDistributionSpread(spreadString.length() != 0 ? Double.parseDouble(spreadString) : 0);

        String maxCountString = Utils.getOption('X', options);
        setMaxCount(maxCountString.length() != 0 ? Double.parseDouble(maxCountString) : 0);

        setAdjustWeights(Utils.getFlag('W', options));

        // Re-apply the current input format (if any) now that the options have changed.
        if (getInputFormat() != null) {
            setInputFormat(getInputFormat());
        }

        Utils.checkForRemainingOptions(options);
    }

    /**
     * Gets the current settings of the filter.
     * 
     * @return an array of strings suitable for passing to setOptions
     */
    @Override
    public String[] getOptions() {

        Vector<String> options = new Vector<String>();

        options.add("-M");
        options.add("" + getDistributionSpread());

        options.add("-X");
        options.add("" + getMaxCount());

        options.add("-S");
        options.add("" + getRandomSeed());

        if (getAdjustWeights()) {
            options.add("-W");
        }

        return options.toArray(new String[0]);
    }

    /**
     * Returns the tip text for this property
     * 
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String distributionSpreadTipText() {
        return "The maximum class distribution spread. "
                + "(0 = no maximum spread, 1 = uniform distribution, 10 = allow at most a "
                + "10:1 ratio between the classes).";
    }

    /**
     * Sets the value for the distribution spread
     * 
     * @param spread the new distribution spread
     */
    public void setDistributionSpread(double spread) {
        m_DistributionSpread = spread;
    }

    /**
     * Gets the value for the distribution spread
     * 
     * @return the distribution spread
     */
    public double getDistributionSpread() {
        return m_DistributionSpread;
    }

    /**
     * Returns the tip text for this property
     * 
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String maxCountTipText() {
        return "The maximum count for any class value (0 = unlimited).";
    }

    /**
     * Sets the value for the max count. The value is stored as an integer, so
     * any fractional part is truncated.
     * 
     * @param maxcount the new max count
     */
    public void setMaxCount(double maxcount) {
        m_MaxCount = (int) maxcount;
    }

    /**
     * Gets the value for the max count
     * 
     * @return the max count
     */
    public double getMaxCount() {
        return m_MaxCount;
    }

    /**
     * Returns the tip text for this property
     * 
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String randomSeedTipText() {
        return "Sets the random number seed for subsampling.";
    }

    /**
     * Gets the random number seed.
     * 
     * @return the random number seed.
     */
    public int getRandomSeed() {
        return m_RandomSeed;
    }

    /**
     * Sets the random number seed.
     * 
     * @param newSeed the new random number seed.
     */
    public void setRandomSeed(int newSeed) {
        m_RandomSeed = newSeed;
    }

    /**
     * Sets the random number seed; delegates to {@link #setRandomSeed(int)}.
     * 
     * @param seed the new random number seed.
     */
    @ProgrammaticProperty
    public void setSeed(int seed) {
        setRandomSeed(seed);
    }

    /**
     * Gets the random number seed; delegates to {@link #getRandomSeed()}.
     * 
     * @return the random number seed.
     */
    @ProgrammaticProperty
    public int getSeed() {
        return getRandomSeed();
    }

    /**
     * Returns the Capabilities of this filter.
     * 
     * @return the capabilities of this object
     * @see Capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        // attributes
        result.enableAllAttributes();
        result.enable(Capability.MISSING_VALUES);

        // class
        result.enable(Capability.NOMINAL_CLASS);

        return result;
    }

    /**
     * Sets the format of the input instances. The output format is the same as
     * the input format.
     * 
     * @param instanceInfo an Instances object containing the input instance
     *          structure (any instances contained in the object are ignored -
     *          only the structure is required).
     * @return true if the outputFormat may be collected immediately
     * @throws Exception if the input format can't be set successfully
     */
    @Override
    public boolean setInputFormat(Instances instanceInfo) throws Exception {

        super.setInputFormat(instanceInfo);
        setOutputFormat(instanceInfo);
        return true;
    }

    /**
     * Input an instance for filtering. Filter requires all training instances be
     * read before producing output. Once the first batch is done, instances are
     * passed straight through to the output queue.
     * 
     * @param instance the input instance
     * @return true if the filtered instance may now be collected with output().
     * @throws IllegalStateException if no input structure has been defined
     */
    @Override
    public boolean input(Instance instance) {

        if (getInputFormat() == null) {
            throw new IllegalStateException("No input instance format defined");
        }
        if (m_NewBatch) {
            resetQueue();
            m_NewBatch = false;
        }
        if (isFirstBatchDone()) {
            // Subsequent batches are not resampled.
            push(instance);
            return true;
        }
        bufferInput(instance);
        return false;
    }

    /**
     * Signify that this batch of input to the filter is finished. If the filter
     * requires all instances prior to filtering, output() may now be called to
     * retrieve the filtered instances.
     * 
     * @return true if there are instances pending output
     * @throws IllegalStateException if no input structure has been defined
     */
    @Override
    public boolean batchFinished() {

        if (getInputFormat() == null) {
            throw new IllegalStateException("No input instance format defined");
        }

        if (!isFirstBatchDone()) {
            // Do the subsample, and clear the input instances.
            createSubsample();
        }

        flushInput();
        m_NewBatch = true;
        m_FirstBatchDone = true;
        return (numPendingOutput() != 0);
    }

    /**
     * Creates a subsample of the current set of input instances. The output
     * instances are pushed onto the output queue for collection.
     * <p/>
     * Instances whose class value is missing are never selected. Within each
     * class the instances are drawn without replacement.
     */
    private void createSubsample() {

        Instances data = getInputFormat();
        int numClasses = data.numClasses();

        // Sort according to class attribute so that every class forms one run.
        data.sort(data.classIndex());
        // Determine where each class starts in the sorted dataset
        int[] classIndices = getClassIndices();

        // Get the existing class distribution (count and total weight per class)
        int[] counts = new int[numClasses];
        double[] weights = new double[numClasses];
        for (int i = 0; i < data.numInstances(); i++) {
            Instance current = data.instance(i);
            if (!current.classIsMissing()) {
                int cls = (int) current.classValue();
                counts[cls]++;
                weights[cls] += current.weight();
            }
        }

        // Convert from total weight to average weight
        for (int i = 0; i < counts.length; i++) {
            if (counts[i] > 0) {
                weights[i] = weights[i] / counts[i];
            }
        }

        // find the (first) non-empty class with the minimum number of instances
        int min = -1;
        int minIndex = -1;
        for (int i = 0; i < counts.length; i++) {
            if ((counts[i] > 0) && ((min < 0) || (counts[i] < min))) {
                min = counts[i];
                minIndex = i;
            }
        }

        if (min < 0) {
            System.err.println("SpreadSubsample: *warning* none of the classes have any values in them.");
            return;
        }

        // determine the new distribution
        int[] newCounts = new int[numClasses];
        for (int i = 0; i < counts.length; i++) {
            if (m_DistributionSpread == 0) {
                // no maximum spread: keep everything
                newCounts[i] = counts[i];
            } else if ((i == minIndex) && (m_DistributionSpread > 0) && (m_DistributionSpread < 1.0)) {
                // don't undersample the minority class!
                newCounts[i] = counts[i];
            } else {
                // Take the magnitude of the spread *before* capping at the class
                // size: a class can never contribute more instances than it has,
                // otherwise sampling without replacement could not terminate.
                newCounts[i] = (int) Math.min(counts[i], Math.abs(min * m_DistributionSpread));
            }

            if (m_MaxCount > 0) {
                newCounts[i] = Math.min(newCounts[i], m_MaxCount);
            }
        }

        // Sample without replacement
        Random random = new Random(m_RandomSeed);
        boolean[] used = new boolean[data.numInstances()];
        for (int j = 0; j < newCounts.length; j++) {
            double newWeight = 1.0;
            if (m_AdjustWeights && (newCounts[j] > 0)) {
                // spread the class's original total weight over the kept instances
                newWeight = weights[j] * counts[j] / newCounts[j];
            }
            int start = classIndices[j];
            int size = classIndices[j + 1] - start;
            for (int k = 0; k < newCounts[j]; k++) {
                // redraw until we hit an instance that has not been used yet
                int index;
                do {
                    index = start + random.nextInt(size);
                } while (used[index]);
                used[index] = true;

                Instance newInst = (Instance) data.instance(index).copy();
                if (m_AdjustWeights) {
                    newInst.setWeight(newWeight);
                }
                push(newInst, false); // No need to copy instance
            }
        }
    }

    /**
     * Creates an index containing the position where each class starts in the
     * getInputFormat(). m_InputFormat must be sorted on the class attribute.
     * <p/>
     * Entry {@code c} is the index of the first instance of class {@code c} and
     * entry {@code c + 1} is the index one past its last instance. Scanning stops
     * at the first instance with a missing class value; that instance and all
     * following ones are excluded from every class range.
     * 
     * @return the positions (array of length numClasses + 1)
     */
    private int[] getClassIndices() {

        Instances data = getInputFormat();
        int numInstances = data.numInstances();

        // Create an index of where each class value starts
        int[] classIndices = new int[data.numClasses() + 1];
        int currentClass = 0;
        classIndices[currentClass] = 0;

        // Exclusive end of the region that holds instances with a known class.
        int labelledEnd = numInstances;
        for (int i = 0; i < numInstances; i++) {
            Instance current = data.instance(i);
            if (current.classIsMissing()) {
                labelledEnd = i;
                break;
            }
            int cls = (int) current.classValue();
            if (cls != currentClass) {
                // Skipped (empty) classes get the same start as the class found here.
                for (int j = currentClass + 1; j <= cls; j++) {
                    classIndices[j] = i;
                }
                currentClass = cls;
            }
        }

        // Close the last populated class, and any empty classes after it, at the
        // end of the labelled region rather than at the end of the dataset.
        for (int j = currentClass + 1; j < classIndices.length; j++) {
            classIndices[j] = labelledEnd;
        }
        return classIndices;
    }

    /**
     * Returns the revision string.
     * 
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision$");
    }

    /**
     * Main method for testing this class.
     * 
     * @param argv should contain arguments to the filter: use -h for help
     */
    public static void main(String[] argv) {
        runFilter(new SpreadSubsample(), argv);
    }
}