weka.estimators.KernelEstimator.java Source code

Java tutorial

Introduction

Here is the source code for weka.estimators.KernelEstimator.java

Source

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    KernelEstimator.java
 *    Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.estimators;

import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Aggregateable;
import weka.core.RevisionUtils;
import weka.core.Statistics;
import weka.core.Utils;

/**
 * Simple kernel density estimator. Uses one gaussian kernel per observed data
 * value.
 * 
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @version $Revision$
 */
public class KernelEstimator extends Estimator implements IncrementalEstimator, Aggregateable<KernelEstimator> {

    /** for serialization */
    private static final long serialVersionUID = 3646923563367683925L;

    /** Vector containing all of the values seen */
    private double[] m_Values;

    /** Vector containing the associated weights */
    private double[] m_Weights;

    /** Number of values stored in m_Weights and m_Values so far */
    private int m_NumValues;

    /** The sum of the weights so far */
    private double m_SumOfWeights;

    /** The standard deviation */
    private double m_StandardDev;

    /** The precision of data values */
    private double m_Precision;

    /** Whether we can optimise the kernel summation */
    private boolean m_AllWeightsOne;

    /** Maximum percentage error permitted in probability calculations */
    private static double MAX_ERROR = 0.01;

    /**
     * Execute a binary search to locate the nearest data value
     * 
     * @param the data value to locate
     * @return the index of the nearest data value
     */
    private int findNearestValue(double key) {

        int low = 0;
        int high = m_NumValues;
        int middle = 0;
        while (low < high) {
            middle = (low + high) / 2;
            double current = m_Values[middle];
            if (current == key) {
                return middle;
            }
            if (current > key) {
                high = middle;
            } else if (current < key) {
                low = middle + 1;
            }
        }
        return low;
    }

    /**
     * Round a data value using the defined precision for this estimator
     * 
     * @param data the value to round
     * @return the rounded data value
     */
    private double round(double data) {

        return Math.rint(data / m_Precision) * m_Precision;
    }

    // ===============
    // Public methods.
    // ===============

    /**
     * Constructor that takes a precision argument.
     * 
     * @param precision the precision to which numeric values are given. For
     *          example, if the precision is stated to be 0.1, the values in the
     *          interval (0.25,0.35] are all treated as 0.3.
     */
    public KernelEstimator(double precision) {

        m_Values = new double[50];
        m_Weights = new double[50];
        m_NumValues = 0;
        m_SumOfWeights = 0;
        m_AllWeightsOne = true;
        m_Precision = precision;
        // precision cannot be zero
        if (m_Precision < Utils.SMALL)
            m_Precision = Utils.SMALL;
        // m_StandardDev = 1e10 * m_Precision; // Set the standard deviation
        // initially very wide
        m_StandardDev = m_Precision / (2 * 3);
    }

    /**
     * Add a new data value to the current estimator.
     * 
     * @param data the new data value
     * @param weight the weight assigned to the data value
     */
    @Override
    public void addValue(double data, double weight) {

        if (weight == 0) {
            return;
        }
        data = round(data);
        int insertIndex = findNearestValue(data);
        if ((m_NumValues <= insertIndex) || (m_Values[insertIndex] != data)) {
            if (m_NumValues < m_Values.length) {
                int left = m_NumValues - insertIndex;
                System.arraycopy(m_Values, insertIndex, m_Values, insertIndex + 1, left);
                System.arraycopy(m_Weights, insertIndex, m_Weights, insertIndex + 1, left);

                m_Values[insertIndex] = data;
                m_Weights[insertIndex] = weight;
                m_NumValues++;
            } else {
                double[] newValues = new double[m_Values.length * 2];
                double[] newWeights = new double[m_Values.length * 2];
                int left = m_NumValues - insertIndex;
                System.arraycopy(m_Values, 0, newValues, 0, insertIndex);
                System.arraycopy(m_Weights, 0, newWeights, 0, insertIndex);
                newValues[insertIndex] = data;
                newWeights[insertIndex] = weight;
                System.arraycopy(m_Values, insertIndex, newValues, insertIndex + 1, left);
                System.arraycopy(m_Weights, insertIndex, newWeights, insertIndex + 1, left);
                m_NumValues++;
                m_Values = newValues;
                m_Weights = newWeights;
            }
            if (weight != 1) {
                m_AllWeightsOne = false;
            }
        } else {
            m_Weights[insertIndex] += weight;
            m_AllWeightsOne = false;
        }
        m_SumOfWeights += weight;
        double range = m_Values[m_NumValues - 1] - m_Values[0];
        if (range > 0) {
            m_StandardDev = Math.max(range / Math.sqrt(m_SumOfWeights),
                    // allow at most 3 sds within one interval
                    m_Precision / (2 * 3));
        }
    }

    /**
     * Get a probability estimate for a value.
     * 
     * @param data the value to estimate the probability of
     * @return the estimated probability of the supplied value
     */
    @Override
    public double getProbability(double data) {

        double delta = 0, sum = 0, currentProb = 0;
        double zLower = 0, zUpper = 0;
        if (m_NumValues == 0) {
            zLower = (data - (m_Precision / 2)) / m_StandardDev;
            zUpper = (data + (m_Precision / 2)) / m_StandardDev;
            return (Statistics.normalProbability(zUpper) - Statistics.normalProbability(zLower));
        }
        double weightSum = 0;
        int start = findNearestValue(data);
        for (int i = start; i < m_NumValues; i++) {
            delta = m_Values[i] - data;
            zLower = (delta - (m_Precision / 2)) / m_StandardDev;
            zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
            currentProb = (Statistics.normalProbability(zUpper) - Statistics.normalProbability(zLower));
            sum += currentProb * m_Weights[i];
            /*
             * System.out.print("zL" + (i + 1) + ": " + zLower + " ");
             * System.out.print("zU" + (i + 1) + ": " + zUpper + " ");
             * System.out.print("P" + (i + 1) + ": " + currentProb + " ");
             * System.out.println("total: " + (currentProb * m_Weights[i]) + " ");
             */
            weightSum += m_Weights[i];
            if (currentProb * (m_SumOfWeights - weightSum) < sum * MAX_ERROR) {
                break;
            }
        }
        for (int i = start - 1; i >= 0; i--) {
            delta = m_Values[i] - data;
            zLower = (delta - (m_Precision / 2)) / m_StandardDev;
            zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
            currentProb = (Statistics.normalProbability(zUpper) - Statistics.normalProbability(zLower));
            sum += currentProb * m_Weights[i];
            weightSum += m_Weights[i];
            if (currentProb * (m_SumOfWeights - weightSum) < sum * MAX_ERROR) {
                break;
            }
        }
        return sum / m_SumOfWeights;
    }

    /** Display a representation of this estimator */
    @Override
    public String toString() {

        String result = m_NumValues + " Normal Kernels. \nStandardDev = "
                + Utils.doubleToString(m_StandardDev, 6, 4) + " Precision = " + m_Precision;
        if (m_NumValues == 0) {
            result += "  \nMean = 0";
        } else {
            result += "  \nMeans =";
            for (int i = 0; i < m_NumValues; i++) {
                result += " " + m_Values[i];
            }
            if (!m_AllWeightsOne) {
                result += "\nWeights = ";
                for (int i = 0; i < m_NumValues; i++) {
                    result += " " + m_Weights[i];
                }
            }
        }
        return result + "\n";
    }

    /**
     * Return the number of kernels in this kernel estimator
     * 
     * @return the number of kernels
     */
    public int getNumKernels() {
        return m_NumValues;
    }

    /**
     * Return the means of the kernels.
     * 
     * @return the means of the kernels
     */
    public double[] getMeans() {
        return m_Values;
    }

    /**
     * Return the weights of the kernels.
     * 
     * @return the weights of the kernels
     */
    public double[] getWeights() {
        return m_Weights;
    }

    /**
     * Return the precision of this kernel estimator.
     * 
     * @return the precision
     */
    public double getPrecision() {
        return m_Precision;
    }

    /**
     * Return the standard deviation of this kernel estimator.
     * 
     * @return the standard deviation
     */
    public double getStdDev() {
        return m_StandardDev;
    }

    /**
     * Returns default capabilities of the classifier.
     * 
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        // class
        if (!m_noClass) {
            result.enable(Capability.NOMINAL_CLASS);
            result.enable(Capability.MISSING_CLASS_VALUES);
        } else {
            result.enable(Capability.NO_CLASS);
        }

        // attributes
        result.enable(Capability.NUMERIC_ATTRIBUTES);
        return result;
    }

    /**
     * Returns the revision string.
     * 
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision$");
    }

    @Override
    public KernelEstimator aggregate(KernelEstimator toAggregate) throws Exception {

        for (int i = 0; i < toAggregate.m_NumValues; i++) {
            addValue(toAggregate.m_Values[i], toAggregate.m_Weights[i]);
        }

        return this;
    }

    @Override
    public void finalizeAggregation() throws Exception {
        // nothing to do
    }

    public static void testAggregation() {
        KernelEstimator ke = new KernelEstimator(0.01);
        KernelEstimator one = new KernelEstimator(0.01);
        KernelEstimator two = new KernelEstimator(0.01);

        java.util.Random r = new java.util.Random(1);

        for (int i = 0; i < 100; i++) {
            double z = r.nextDouble();

            ke.addValue(z, 1);
            if (i < 50) {
                one.addValue(z, 1);
            } else {
                two.addValue(z, 1);
            }
        }

        try {

            System.out.println("\n\nFull\n");
            System.out.println(ke.toString());
            System.out.println("Prob (0): " + ke.getProbability(0));

            System.out.println("\nOne\n" + one.toString());
            System.out.println("Prob (0): " + one.getProbability(0));

            System.out.println("\nTwo\n" + two.toString());
            System.out.println("Prob (0): " + two.getProbability(0));

            one = one.aggregate(two);

            System.out.println("Aggregated\n");
            System.out.println(one.toString());
            System.out.println("Prob (0): " + one.getProbability(0));
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }

    /**
     * Main method for testing this class.
     * 
     * @param argv should contain a sequence of numeric values
     */
    public static void main(String[] argv) {

        try {
            if (argv.length < 2) {
                System.out.println("Please specify a set of instances.");
                return;
            }
            KernelEstimator newEst = new KernelEstimator(0.01);
            for (int i = 0; i < argv.length - 3; i += 2) {
                newEst.addValue(Double.valueOf(argv[i]).doubleValue(), Double.valueOf(argv[i + 1]).doubleValue());
            }
            System.out.println(newEst);

            double start = Double.valueOf(argv[argv.length - 2]).doubleValue();
            double finish = Double.valueOf(argv[argv.length - 1]).doubleValue();
            for (double current = start; current < finish; current += (finish - start) / 50) {
                System.out.println("Data: " + current + " " + newEst.getProbability(current));
            }

            KernelEstimator.testAggregation();
        } catch (Exception e) {
            System.out.println(e.getMessage());
        }
    }
}