weka.estimators.NormalEstimator.java Source code

Java tutorial

Introduction

Here is the source code for weka.estimators.NormalEstimator.java

Source

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    NormalEstimator.java
 *    Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.estimators;

import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Aggregateable;
import weka.core.RevisionUtils;
import weka.core.Statistics;
import weka.core.Utils;

/**
 * Simple probability estimator that places a single normal distribution over
 * the observed values.
 * 
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @version $Revision$
 */
public class NormalEstimator extends Estimator implements IncrementalEstimator, Aggregateable<NormalEstimator> {

    /** for serialization */
    private static final long serialVersionUID = 93584379632315841L;

    /** The sum of the weights */
    private double m_SumOfWeights;

    /** The sum of the values seen */
    private double m_SumOfValues;

    /** The sum of the values squared */
    private double m_SumOfValuesSq;

    /** The current mean */
    private double m_Mean;

    /** The current standard deviation */
    private double m_StandardDev;

    /** The precision of numeric values ( = minimum std dev permitted) */
    private double m_Precision;

    /**
     * Round a data value using the defined precision for this estimator
     * 
     * @param data the value to round
     * @return the rounded data value
     */
    private double round(double data) {

        return Math.rint(data / m_Precision) * m_Precision;
    }

    // ===============
    // Public methods.
    // ===============

    /**
     * Constructor that takes a precision argument.
     * 
     * @param precision the precision to which numeric values are given. For
     *          example, if the precision is stated to be 0.1, the values in the
     *          interval (0.25,0.35] are all treated as 0.3.
     */
    public NormalEstimator(double precision) {

        m_Precision = precision;

        // Allow at most 3 sd's within one interval
        m_StandardDev = m_Precision / (2 * 3);
    }

    /**
     * Add a new data value to the current estimator.
     * 
     * @param data the new data value
     * @param weight the weight assigned to the data value
     */
    @Override
    public void addValue(double data, double weight) {

        if (weight == 0) {
            return;
        }
        data = round(data);
        m_SumOfWeights += weight;
        m_SumOfValues += data * weight;
        m_SumOfValuesSq += data * data * weight;

        computeParameters();
    }

    /**
     * Compute the parameters of the distribution
     */
    protected void computeParameters() {
        if (m_SumOfWeights > 0) {
            m_Mean = m_SumOfValues / m_SumOfWeights;
            double stdDev = Math.sqrt(Math.abs(m_SumOfValuesSq - m_Mean * m_SumOfValues) / m_SumOfWeights);
            // If the stdDev ~= 0, we really have no idea of scale yet,
            // so stick with the default. Otherwise...
            if (stdDev > 1e-10) {
                m_StandardDev = Math.max(m_Precision / (2 * 3),
                        // allow at most 3sd's within one interval
                        stdDev);
            }
        }
    }

    /**
     * Get a probability estimate for a value
     * 
     * @param data the value to estimate the probability of
     * @return the estimated probability of the supplied value
     */
    @Override
    public double getProbability(double data) {

        data = round(data);
        double zLower = (data - m_Mean - (m_Precision / 2)) / m_StandardDev;
        double zUpper = (data - m_Mean + (m_Precision / 2)) / m_StandardDev;

        double pLower = Statistics.normalProbability(zLower);
        double pUpper = Statistics.normalProbability(zUpper);
        return pUpper - pLower;
    }

    /**
     * Display a representation of this estimator
     */
    @Override
    public String toString() {

        return "Normal Distribution. Mean = " + Utils.doubleToString(m_Mean, 4) + " StandardDev = "
                + Utils.doubleToString(m_StandardDev, 4) + " WeightSum = " + Utils.doubleToString(m_SumOfWeights, 4)
                + " Precision = " + m_Precision + "\n";
    }

    /**
     * Returns default capabilities of the classifier.
     * 
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        // class
        if (!m_noClass) {
            result.enable(Capability.NOMINAL_CLASS);
            result.enable(Capability.MISSING_CLASS_VALUES);
        } else {
            result.enable(Capability.NO_CLASS);
        }

        // attributes
        result.enable(Capability.NUMERIC_ATTRIBUTES);
        return result;
    }

    /**
     * Return the value of the mean of this normal estimator.
     * 
     * @return the mean
     */
    public double getMean() {
        return m_Mean;
    }

    /**
     * Return the value of the standard deviation of this normal estimator.
     * 
     * @return the standard deviation
     */
    public double getStdDev() {
        return m_StandardDev;
    }

    /**
     * Return the value of the precision of this normal estimator.
     * 
     * @return the precision
     */
    public double getPrecision() {
        return m_Precision;
    }

    /**
     * Return the sum of the weights for this normal estimator.
     * 
     * @return the sum of the weights
     */
    public double getSumOfWeights() {
        return m_SumOfWeights;
    }

    /**
     * Returns the revision string.
     * 
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision$");
    }

    @Override
    public NormalEstimator aggregate(NormalEstimator toAggregate) throws Exception {

        m_SumOfWeights += toAggregate.m_SumOfWeights;
        m_SumOfValues += toAggregate.m_SumOfValues;
        m_SumOfValuesSq += toAggregate.m_SumOfValuesSq;

        if (toAggregate.m_Precision < m_Precision) {
            m_Precision = toAggregate.m_Precision;
        }

        computeParameters();

        return this;
    }

    @Override
    public void finalizeAggregation() throws Exception {
        // nothing to do
    }

    public static void testAggregation() {
        NormalEstimator ne = new NormalEstimator(0.01);
        NormalEstimator one = new NormalEstimator(0.01);
        NormalEstimator two = new NormalEstimator(0.01);

        java.util.Random r = new java.util.Random(1);

        for (int i = 0; i < 100; i++) {
            double z = r.nextDouble();

            ne.addValue(z, 1);
            if (i < 50) {
                one.addValue(z, 1);
            } else {
                two.addValue(z, 1);
            }
        }

        try {
            System.out.println("\n\nFull\n");
            System.out.println(ne.toString());
            System.out.println("Prob (0): " + ne.getProbability(0));

            System.out.println("\nOne\n" + one.toString());
            System.out.println("Prob (0): " + one.getProbability(0));

            System.out.println("\nTwo\n" + two.toString());
            System.out.println("Prob (0): " + two.getProbability(0));

            one = one.aggregate(two);

            System.out.println("\nAggregated\n");
            System.out.println(one.toString());
            System.out.println("Prob (0): " + one.getProbability(0));
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }

    /**
     * Main method for testing this class.
     * 
     * @param argv should contain a sequence of numeric values
     */
    public static void main(String[] argv) {

        try {

            if (argv.length == 0) {
                System.out.println("Please specify a set of instances.");
                return;
            }
            NormalEstimator newEst = new NormalEstimator(0.01);
            for (int i = 0; i < argv.length; i++) {
                double current = Double.valueOf(argv[i]).doubleValue();
                System.out.println(newEst);
                System.out.println("Prediction for " + current + " = " + newEst.getProbability(current));
                newEst.addValue(current, 1);
            }

            NormalEstimator.testAggregation();
        } catch (Exception e) {
            System.out.println(e.getMessage());
        }
    }
}