jsat.distributions.empirical.KernelDensityEstimatorButla.java Source code

Java tutorial

Introduction

Here is the source code for jsat.distributions.empirical.KernelDensityEstimatorButla.java

Source

/**
 * This file is part of SADL, a library for learning all sorts of (timed) automata and performing sequence-based anomaly detection.
 * Copyright (C) 2013-2016  the original author or authors.
 *
 * SADL is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
 *
 * SADL is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along with SADL.  If not, see <http://www.gnu.org/licenses/>.
 */
package jsat.distributions.empirical;

import java.util.Arrays;

import org.apache.commons.math3.util.Precision;

import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
import jsat.distributions.ContinuousDistribution;
import jsat.distributions.empirical.kernelfunc.GaussKF;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.optimization.GoldenSearch;
import sadl.constants.KDEFormelVariant;

/**
 * One-dimensional kernel density estimate that exposes a density function, a
 * companion "derivation" function and a search for local minima of the
 * density.
 * <p>
 * The concrete formulas are selected by a {@link KDEFormelVariant}:
 * {@code OriginalKDE} delegates to {@link MyKernelDensityEstimator} with a
 * fixed bandwidth, while the three Butla variants use hard-coded closed forms
 * whose kernel width depends on each sample value.
 * <p>
 * The samples are sorted once in the constructor; every window look-up,
 * summation and search range below relies on that order.
 */
public class KernelDensityEstimatorButla {

    /** Sorted copy of the samples passed to the constructor. */
    protected Vec dataPoints;

    /**
     * The samples of {@link #dataPoints} in ascending order. The array is
     * binary-searched on every evaluation and therefore has to stay sorted.
     */
    protected double[] X;

    /** Density function chosen by the constructor for the requested variant. */
    protected Function kernelPdfFunction;

    /** Function whose sign change (negative to positive) marks a minimum candidate in {@link #getMinima()}. */
    protected Function kernelDerivationFunction;

    /** Accuracy handed to the golden-section search in {@link #getMinima()}. */
    protected double minSearchAccuracy;

    /** Initial scan step of {@link #getMinima()}; for {@code OriginalKDE} it is used for the whole scan. */
    protected double minSearchStep;

    /** Position at which the scan in {@link #getMinima()} starts. */
    protected double startX;

    /** Exclusive upper bound of the scan in {@link #getMinima()}. */
    protected double endX;

    /** Variant this estimator was built with. */
    protected KDEFormelVariant kdeFormelVariant;

    /** Default bandwidth offered to callers; it is not referenced inside this class. */
    public static final double DEFAULT_BANDWIDTH = 50d;

    /** Accuracy used by the constructors that do not take an explicit accuracy. */
    public static final double DEFAULT_MIN_SEARCH_ACCURACY = 0.25d;

    /**
     * Creates an estimator whose bandwidth is taken from
     * {@link MyKernelDensityEstimator#BandwithGuassEstimate(Vec)}.
     *
     * @param dataPoints    the samples (any order, must not be empty)
     * @param formelVariant the formula variant to use
     */
    public KernelDensityEstimatorButla(double[] dataPoints, KDEFormelVariant formelVariant) {
        this(new DenseVector(dataPoints), formelVariant);
    }

    /**
     * Creates an estimator whose bandwidth is taken from
     * {@link MyKernelDensityEstimator#BandwithGuassEstimate(Vec)}.
     *
     * @param dataPoints    the samples (any order, must not be empty)
     * @param formelVariant the formula variant to use
     */
    public KernelDensityEstimatorButla(Vec dataPoints, KDEFormelVariant formelVariant) {
        this(dataPoints, formelVariant, MyKernelDensityEstimator.BandwithGuassEstimate(dataPoints));
    }

    /**
     * Creates an estimator with the given bandwidth, a scan step of
     * {@code bandwidth / 4} and {@link #DEFAULT_MIN_SEARCH_ACCURACY}.
     *
     * @param dataPoints    the samples (any order, must not be empty)
     * @param formelVariant the formula variant to use
     * @param bandwidth     the bandwidth; a value equal to zero triggers the estimate
     */
    public KernelDensityEstimatorButla(double[] dataPoints, KDEFormelVariant formelVariant, double bandwidth) {
        this(new DenseVector(dataPoints), formelVariant, bandwidth);
    }

    /**
     * Creates an estimator with the given bandwidth, a scan step of
     * {@code bandwidth / 4} and {@link #DEFAULT_MIN_SEARCH_ACCURACY}.
     *
     * @param dataPoints    the samples (any order, must not be empty)
     * @param formelVariant the formula variant to use
     * @param bandwidth     the bandwidth; a value equal to zero triggers the estimate
     */
    public KernelDensityEstimatorButla(Vec dataPoints, KDEFormelVariant formelVariant, double bandwidth) {
        this(dataPoints, formelVariant, bandwidth, bandwidth / 4.0, DEFAULT_MIN_SEARCH_ACCURACY);
    }

    /**
     * Array-based convenience form of
     * {@link #KernelDensityEstimatorButla(Vec, KDEFormelVariant, double, double, double)}.
     *
     * @param dataPoints        the samples (any order, must not be empty)
     * @param formelVariant     the formula variant to use
     * @param bandwidth         the bandwidth; a value equal to zero triggers the estimate
     * @param minSearchStep     initial scan step for {@link #getMinima()}
     * @param minSearchAccuracy accuracy for the golden-section search in {@link #getMinima()}
     */
    public KernelDensityEstimatorButla(double[] dataPoints, KDEFormelVariant formelVariant, double bandwidth,
            double minSearchStep, double minSearchAccuracy) {
        this(new DenseVector(dataPoints), formelVariant, bandwidth, minSearchStep, minSearchAccuracy);
    }

    /**
     * Builds the density and derivation functions for the requested variant.
     * <p>
     * If {@code bandwidth} equals zero it is replaced by
     * {@link MyKernelDensityEstimator#BandwithGuassEstimate(Vec)} and the scan
     * step is reset to a quarter of that estimate. A scan step below
     * {@code 0.0001} is replaced by {@code 0.25}. If {@code formelVariant}
     * matches none of the four handled variants, the functions and the search
     * range are left unset.
     *
     * @param dataPoints        the samples (any order, must not be empty)
     * @param formelVariant     the formula variant to use
     * @param bandwidth         the bandwidth; only the {@code OriginalKDE} branch uses it
     * @param minSearchStep     initial scan step for {@link #getMinima()}
     * @param minSearchAccuracy accuracy for the golden-section search in {@link #getMinima()}
     * @throws IllegalArgumentException if {@code dataPoints} is empty
     */
    public KernelDensityEstimatorButla(Vec dataPoints, KDEFormelVariant formelVariant, double bandwidth,
            double minSearchStep, double minSearchAccuracy) {

        if (dataPoints.length() == 0) {
            throw new IllegalArgumentException("dataPoints must contain at least one value");
        }

        // Sort once and derive the array from the SORTED copy: the window
        // look-up uses Arrays.binarySearch, which is undefined on unsorted data.
        this.dataPoints = dataPoints.sortedCopy();
        this.X = this.dataPoints.arrayCopy();
        this.minSearchStep = minSearchStep;
        this.minSearchAccuracy = minSearchAccuracy;
        this.kdeFormelVariant = formelVariant;

        double h = bandwidth;
        if (Precision.equals(h, 0)) {
            h = MyKernelDensityEstimator.BandwithGuassEstimate(dataPoints);
            this.minSearchStep = h / 4.0;
        }

        if (this.minSearchStep < 0.0001) {
            this.minSearchStep = 0.25d;
        }

        if (formelVariant == KDEFormelVariant.OriginalKDE) {

            final MyKernelDensityEstimator kernelDensity = new MyKernelDensityEstimator(dataPoints,
                    GaussKF.getInstance(), h);
            kernelPdfFunction = ContinuousDistribution.getFunctionPDF(kernelDensity);
            kernelDerivationFunction = ContinuousDistribution.getFunctionPDF(
                    new MyKernelDensityEstimator(dataPoints, GaussKFDerivation.getInstance(), h));

            startX = kernelDensity.min() + h;
            endX = kernelDensity.max() - h;

        } else if (formelVariant == KDEFormelVariant.OriginalButlaVariableBandwidth) {

            kernelPdfFunction = originalButlaPdf(X);
            kernelDerivationFunction = originalButlaDerivation(X);
            useSampleRangeForSearch();

        } else if (formelVariant == KDEFormelVariant.ButlaBandwidthNotSquared) {

            kernelPdfFunction = notSquaredPdf(X);
            kernelDerivationFunction = notSquaredDerivation(X);
            useSampleRangeForSearch();

        } else if (formelVariant == KDEFormelVariant.ButlaBandwidthSquared) {

            kernelPdfFunction = squaredPdf(X);
            kernelDerivationFunction = squaredDerivation(X);
            useSampleRangeForSearch();
        }
    }

    /**
     * Evaluates the density function at {@code x}.
     *
     * @param x the position to evaluate
     * @return the value of the density function at {@code x}
     */
    public double pdf(double x) {
        return kernelPdfFunction.f(x);
    }

    /**
     * Evaluates the derivation function at {@code x}.
     *
     * @param x the position to evaluate
     * @return the value of the derivation function at {@code x}
     */
    public double prime(double x) {
        return kernelDerivationFunction.f(x);
    }

    /**
     * @param accuracy
     *            The accuracy of min-search in getMinima function.
     * @throws IllegalArgumentException if {@code accuracy} is NaN
     */
    public void setAccuracy(double accuracy) {
        if (Double.isNaN(accuracy)) {
            throw new IllegalArgumentException("accuracy must not be NaN");
        }
        this.minSearchAccuracy = accuracy;
    }

    /**
     * Scans from {@link #startX} up to (excluding) {@link #endX} and collects
     * a point wherever the derivation function changes from negative to
     * positive between two consecutive scan positions. Each such interval is
     * refined with {@link GoldenSearch} on the density function.
     * <p>
     * For the Butla variants a fixed point at {@code 0.5} is reported first
     * when the smallest sample equals zero, and the step size grows with the
     * current scan position.
     *
     * @return the collected points in scan order
     */
    public double[] getMinima() {

        final TDoubleList minima = new TDoubleArrayList();

        double previousX = startX;
        double previousSlope = kernelDerivationFunction.f(previousX);
        double step = minSearchStep;

        if (kdeFormelVariant != KDEFormelVariant.OriginalKDE && Precision.equals(X[0], 0.0)) {
            minima.add(0.5);
        }

        for (double x = previousX + step; x < endX; x = x + step) {
            final double slope = kernelDerivationFunction.f(x);

            final boolean crossesUpwards = previousSlope < 0 && slope > 0;
            if (crossesUpwards) {
                minima.add(GoldenSearch.minimize(minSearchAccuracy, 100, previousX, x, 0, kernelPdfFunction,
                        new double[1]));
            }

            previousX = x;
            previousSlope = slope;
            step = nextStep(x, step);
        }

        return minima.toArray();
    }

    /** Returns the scan step to use after position {@code x}; variants without a rule keep {@code current}. */
    private double nextStep(double x, double current) {
        if (kdeFormelVariant == KDEFormelVariant.ButlaBandwidthNotSquared
                || kdeFormelVariant == KDEFormelVariant.OriginalButlaVariableBandwidth) {
            return x * 0.05d / 4.0;
        }
        if (kdeFormelVariant == KDEFormelVariant.ButlaBandwidthSquared) {
            return Math.pow(x * 0.05d, 2) / 4.0;
        }
        return current;
    }

    /** Sets the search range of the Butla variants from the smallest (clamped to at least 1.0) and largest sample. */
    private void useSampleRangeForSearch() {
        startX = Math.max(X[0], 1.0);
        endX = X[X.length - 1];
    }

    /** Density of the {@code OriginalButlaVariableBandwidth} variant. */
    private static Function originalButlaPdf(final double[] sorted) {
        return new WindowedKernelSum(sorted, true) {
            private static final long serialVersionUID = 337703545623146489L;

            @Override
            double term(double t, double ti) {
                return Math.exp(-Math.pow(t - ti, 2) / (2 * 0.05 * ti))
                        / (Math.sqrt(2.0 * Math.PI) * 0.05 * ti);
            }
        };
    }

    /** Derivation function of the {@code OriginalButlaVariableBandwidth} variant (hard-coded closed form). */
    private static Function originalButlaDerivation(final double[] sorted) {
        return new WindowedKernelSum(sorted, true) {
            private static final long serialVersionUID = 1896912471233540595L;

            @Override
            double term(double t, double ti) {
                return (-79.7885 * Math.exp(-10 * Math.pow(t - ti, 2) / ti)
                        * (Math.pow(ti, 2) + 0.1 * ti - Math.pow(t, 2))) / Math.pow(ti, 3);
            }
        };
    }

    /** Density of the {@code ButlaBandwidthNotSquared} variant. */
    private static Function notSquaredPdf(final double[] sorted) {
        return new WindowedKernelSum(sorted, true) {
            private static final long serialVersionUID = -8200289641116502672L;

            @Override
            double term(double t, double ti) {
                return Math.exp(-Math.pow(t - ti, 2) / (2 * Math.pow(0.05 * ti, 2)))
                        / (Math.sqrt(2.0 * Math.PI) * 0.05 * ti);
            }
        };
    }

    /** Derivation function of the {@code ButlaBandwidthNotSquared} variant (hard-coded closed form). */
    private static Function notSquaredDerivation(final double[] sorted) {
        return new WindowedKernelSum(sorted, true) {
            private static final long serialVersionUID = -2561020473687438986L;

            @Override
            double term(double t, double ti) {
                return ((-7.97885 * Math.pow(ti, 2) - 3191.54 * ti * t + 3191.54 * Math.pow(t, 2))
                        * Math.exp(-200 * Math.pow(t - ti, 2) / Math.pow(ti, 2))) / Math.pow(ti, 4);
            }
        };
    }

    /** Density of the {@code ButlaBandwidthSquared} variant. */
    private static Function squaredPdf(final double[] sorted) {
        return new WindowedKernelSum(sorted, false) {
            private static final long serialVersionUID = 6749547413109881687L;

            @Override
            double term(double t, double ti) {
                return Math.exp(-Math.pow(t - ti, 2) / (2 * 0.05 * ti))
                        / (Math.sqrt(2.0 * Math.PI * 0.05 * ti));
            }
        };
    }

    /** Derivation function of the {@code ButlaBandwidthSquared} variant (hard-coded closed form). */
    private static Function squaredDerivation(final double[] sorted) {
        return new WindowedKernelSum(sorted, false) {
            private static final long serialVersionUID = 3612595828189571262L;

            @Override
            double term(double t, double ti) {
                return (Math.exp(-10 * Math.pow(t - ti, 2) / ti)
                        * (-17.8412 * Math.pow(ti, 2) - 0.892062 * ti + 17.8412 * Math.pow(t, 2)))
                        / Math.sqrt(Math.pow(ti, 5));
            }
        };
    }

    /**
     * Averages a per-sample term over the samples that lie inside a window
     * around the evaluation point. Samples equal to zero are skipped because
     * every term divides by the sample value. The sum is divided by the total
     * number of samples, not by the number of samples inside the window.
     */
    private abstract static class WindowedKernelSum implements Function {

        private static final long serialVersionUID = 1L;

        /** Samples in ascending order (shared with the enclosing estimator). */
        private final double[] sorted;

        /** Whether the window width is the square of {@code 0.05 * largest sample} instead of the plain product. */
        private final boolean squareWindowWidth;

        WindowedKernelSum(double[] sorted, boolean squareWindowWidth) {
            this.sorted = sorted;
            this.squareWindowWidth = squareWindowWidth;
        }

        /** Contribution of the sample {@code ti} to the value at {@code t}. */
        abstract double term(double t, double ti);

        @Override
        public double f(Vec x) {
            return f(new double[] { x.get(0) });
        }

        @Override
        public double f(double... x) {

            final double t = x[0];

            // NOTE(review): which variants square the window width is kept
            // exactly as in the previous implementation - confirm it is intended.
            final double scaledMax = sorted[sorted.length - 1] * 0.05;
            final double maxH = squareWindowWidth ? Math.pow(scaledMax, 2) : scaledMax;

            final int first = Math.max(0, insertionIndex(t - maxH * 13));
            final int end = Math.min(sorted.length, insertionIndex(t + maxH * 13) + 1);

            double sum = 0.0d;
            for (int i = first; i < end; i++) {
                final double ti = sorted[i];
                if (!Precision.equals(ti, 0)) {
                    sum += term(t, ti);
                }
            }

            return sum / sorted.length;
        }

        /** Index of {@code key} if present, otherwise the index at which it would be inserted. */
        private int insertionIndex(double key) {
            final int idx = Arrays.binarySearch(sorted, key);
            return idx < 0 ? -idx - 1 : idx;
        }
    }

}