Java tutorial
/** * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.github.brandtg.stl; import org.apache.commons.math3.analysis.interpolation.LoessInterpolator; import org.apache.commons.math3.analysis.interpolation.NevilleInterpolator; import org.apache.commons.math3.analysis.polynomials.PolynomialFunctionLagrangeForm; import org.apache.commons.math3.analysis.polynomials.PolynomialSplineFunction; import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics; import java.io.BufferedReader; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.List; /** * This package contains an implementation of STL: * A Seasonal-Trend Decomposition Procedure based on Loess. * * <p> * Robert B. Cleveland et al., * "STL: A Seasonal-Trend Decomposition Procedure based on Loess," in Journal * of Official Statistics Vol. 6 No. 1, 1990, pp. 3-73 * </p> */ public class StlDecomposition { /** The configuration with which to run STL. */ private final StlConfig config; /** * Constructs an STL function that can de-trend data. * * <p> * n.b. The Java Loess implementation only does linear local polynomial * regression, but R supports linear (degree=1), quadratic (degree=2), and a * strange degree=0 option. * </p> * * <p> * Also, the Java Loess implementation accepts "bandwidth", the fraction of * source points closest to the current point, as opposed to integral values. * </p> * * @param numberOfObservations The number of observations in a season. 
*/
  public StlDecomposition(int numberOfObservations) {
    this.config = new StlConfig(numberOfObservations);
  }

  /**
   * Exposes the configuration so callers can tune it before decomposing.
   *
   * @return The (mutable) configuration used by this function.
   */
  public StlConfig getConfig() {
    return this.config;
  }

  /**
   * Convenience overload that accepts boxed numbers instead of primitive arrays.
   *
   * @param times
   *     A sequence of time values.
   * @param series
   *     A dependent variable on times.
   * @return
   *     The STL decomposition of the time series.
   */
  public StlResult decompose(List<Number> times, List<Number> series) {
    double[] timeValues = toDoubleArray(times);
    double[] seriesValues = toDoubleArray(series);
    return decompose(timeValues, seriesValues);
  }

  /**
   * Unboxes a list of numbers into a new primitive array, preserving iteration order.
   *
   * @param values The numbers to convert.
   * @return A new array holding {@code values.get(k).doubleValue()} at index {@code k}.
   */
  private static double[] toDoubleArray(List<Number> values) {
    double[] out = new double[values.size()];
    int pos = 0;
    for (Number v : values) {
      out[pos] = v.doubleValue();
      pos++;
    }
    return out;
  }

  /**
   * Computes the STL decomposition of a time series.
   *
   * @param times
   *     A sequence of time values.
   * @param series
   *     A dependent variable on times.
   * @return
   *     The STL decomposition of the time series.
*/ public StlResult decompose(double[] times, double[] series) { if (times.length != series.length) { throw new IllegalArgumentException( "Times (" + times.length + ") and series (" + series.length + ") must be same size"); } int numberOfDataPoints = series.length; config.check(numberOfDataPoints); double[] trend = new double[numberOfDataPoints]; double[] seasonal = new double[numberOfDataPoints]; double[] remainder = new double[numberOfDataPoints]; double[] robustness = null; double[] detrend = new double[numberOfDataPoints]; double[] combinedSmoothed = new double[numberOfDataPoints + 2 * config.getNumberOfObservations()]; double[] combinedSmoothedTimes = new double[numberOfDataPoints + 2 * config.getNumberOfObservations()]; for (int i = 0; i < combinedSmoothedTimes.length; i++) { combinedSmoothedTimes[i] = i; } for (int l = 0; l < config.getNumberOfRobustnessIterations(); l++) { for (int k = 0; k < config.getNumberOfInnerLoopPasses(); k++) { // Step 1: De-trending for (int i = 0; i < numberOfDataPoints; i++) { detrend[i] = series[i] - trend[i]; } // Get cycle sub-series int numberOfObservations = config.getNumberOfObservations(); CycleSubSeries cycle = new CycleSubSeries(times, series, robustness, detrend, numberOfObservations); cycle.compute(); List<double[]> cycleSubseries = cycle.getCycleSubSeries(); List<double[]> cycleTimes = cycle.getCycleTimes(); List<double[]> cycleRobustnessWeights = cycle.getCycleRobustnessWeights(); // Step 2: Cycle-subseries Smoothing for (int i = 0; i < cycleSubseries.size(); i++) { // Pad times / values TimesAndValues padded = padEdges(cycleTimes.get(i), cycleSubseries.get(i)); double[] paddedTimes = padded.getTs(); double[] paddedSeries = padded.getXs(); // Pad weights double[] weights = cycleRobustnessWeights.get(i); double[] paddedWeights = null; if (weights != null) { paddedWeights = padEdges(cycleTimes.get(i), weights).getXs(); } // Loess smoothing double[] smoothed = loessSmooth(paddedTimes, paddedSeries, 
config.getSeasonalComponentBandwidth(), paddedWeights); cycleSubseries.set(i, smoothed); } // Combine smoothed series into one for (int i = 0; i < cycleSubseries.size(); i++) { double[] subseriesValues = cycleSubseries.get(i); for (int cycleIdx = 0; cycleIdx < subseriesValues.length; cycleIdx++) { combinedSmoothed[numberOfObservations * cycleIdx + i] = subseriesValues[cycleIdx]; } } // Step 3: Low-Pass Filtering of Smoothed Cycle-Subseries double[] filtered = lowPassFilter(combinedSmoothedTimes, combinedSmoothed, null); // Step 4: Detrending of Smoothed Cycle-Subseries int offset = config.getNumberOfObservations(); for (int i = 0; i < seasonal.length; i++) { seasonal[i] = combinedSmoothed[i + offset] - filtered[i + offset]; } // Step 5: Deseasonalizing for (int i = 0; i < numberOfDataPoints; i++) { trend[i] = series[i] - seasonal[i]; } // Step 6: Trend Smoothing trend = loessSmooth(times, trend, config.getTrendComponentBandwidth(), robustness); } // --- Now in outer loop --- // Calculate remainder for (int i = 0; i < numberOfDataPoints; i++) { remainder[i] = series[i] - trend[i] - seasonal[i]; } // Calculate robustness weights using remainder robustness = robustnessWeights(remainder); } if (config.isPeriodic()) { for (int i = 0; i < config.getNumberOfObservations(); i++) { // Compute weighted mean for one season double sum = 0.0; int count = 0; for (int j = i; j < numberOfDataPoints; j += config.getNumberOfObservations()) { sum += seasonal[j]; count++; } double mean = sum / count; // Copy this to rest of seasons for (int j = i; j < numberOfDataPoints; j += config.getNumberOfObservations()) { seasonal[j] = mean; } } // Recalculate remainder for (int i = 0; i < series.length; i++) { remainder[i] = series[i] - trend[i] - seasonal[i]; } } return new StlResult(times, series, trend, seasonal, remainder); } /** * The cycle subseries of a time series. 
* * <p> * The cycle subseries is a set of series whose members are of length * N, where N is the number of observations in a season. * </p> * * <p> * For example, if we have monthly data from 1990 to 2000, the cycle * subseries would be [[Jan_1990, Jan_1991, ...], ..., [Dec_1990, Dec_1991]]. * </p> */ private static class CycleSubSeries { /** Output: The list of cycle subseries series data. */ private final List<double[]> cycleSubSeries = new ArrayList<double[]>(); /** Output: The list of cycle subseries times. */ private final List<double[]> cycleTimes = new ArrayList<double[]>(); /** Output: The list of cycle subseries robustness weights. */ private final List<double[]> cycleRobustnessWeights = new ArrayList<double[]>(); /** Input: The number of observations in a season. */ private final int numberOfObservations; /** Input: The input times. */ private final double[] times; /** Input: The input series data. */ private final double[] series; /** Input: The robustness weights, from STL. */ private final double[] robustness; /** Input: The de-trended series, from STL. */ private final double[] detrend; /** * Constructs a cycle subseries computation. * * @param times * The input times. * @param series * A dependent variable on times. * @param robustness * The robustness weights from STL loop. * @param detrend * The de-trended data. * @param numberOfObservations * The number of observations in a season. */ CycleSubSeries(double[] times, double[] series, double[] robustness, double[] detrend, int numberOfObservations) { this.times = times; this.series = series; this.robustness = robustness; this.detrend = detrend; this.numberOfObservations = numberOfObservations; } /** * @return * A list of size numberOfObservations, whose elements are of length * times.length / numberOfObservations: the cycle subseries. */ List<double[]> getCycleSubSeries() { return cycleSubSeries; } /** * @return The times corresponding to getCycleSubSeries. 
*/ List<double[]> getCycleTimes() { return cycleTimes; } /** * @return The robustness weights corresponding to getCycleSubSeries. */ List<double[]> getCycleRobustnessWeights() { return cycleRobustnessWeights; } /** * Computes the cycle subseries of the input. * * <p> * Must call this before getters return anything meaningful. * </p> */ void compute() { for (int i = 0; i < numberOfObservations; i++) { int subseriesLength = series.length / numberOfObservations; double[] subseriesValues = new double[subseriesLength]; double[] subseriesTimes = new double[subseriesLength]; double[] subseriesRobustnessWeights = null; if (robustness != null) { subseriesRobustnessWeights = new double[subseriesLength]; } for (int cycleIdx = 0; cycleIdx < subseriesLength; cycleIdx++) { subseriesValues[cycleIdx] = detrend[cycleIdx * numberOfObservations + i]; subseriesTimes[cycleIdx] = times[cycleIdx * numberOfObservations + i]; if (subseriesRobustnessWeights != null) { subseriesRobustnessWeights[cycleIdx] = robustness[cycleIdx * numberOfObservations + i]; // TODO: Hack to ensure no divide by zero if (subseriesRobustnessWeights[cycleIdx] < 0.001) { subseriesRobustnessWeights[cycleIdx] = 0.01; } } } cycleSubSeries.add(subseriesValues); cycleTimes.add(subseriesTimes); cycleRobustnessWeights.add(subseriesRobustnessWeights); } } } /** * Computes robustness weights using bisquare weight function. * * @param remainder * The remainder, series - trend - seasonal. * @return * A new array containing the robustness weights. 
*/ private double[] robustnessWeights(double[] remainder) { // Compute "h" = 6 median(|R_v|) double[] absRemainder = new double[remainder.length]; for (int i = 0; i < remainder.length; i++) { absRemainder[i] = Math.abs(remainder[i]); } DescriptiveStatistics stats = new DescriptiveStatistics(absRemainder); double outlierThreshold = 6 * stats.getPercentile(50); // Compute robustness weights double[] robustness = new double[remainder.length]; for (int i = 0; i < remainder.length; i++) { robustness[i] = biSquareWeight(absRemainder[i] / outlierThreshold); } return robustness; } /** * The bisquare weight function. * * @param value * Any real number. * @return * <pre> * (1 - value^2)^2 for 0 <= value < 1 * 0 for value > 1 * </pre> */ private double biSquareWeight(double value) { if (value < 0) { throw new IllegalArgumentException("Invalid value, must be >= 0: " + value); } else if (value < 1) { return Math.pow(1 - Math.pow(value, 2), 2); } else { return 0; } } /** * A low pass filter used on combined smoothed cycle subseries. * * <p> * The filter consists of the following steps: * <ol> * <li>Moving average of length n_p, seasonal size</li> * <li>Moving average of length 3, (magic number from paper)</li> * <li>Loess smoothing</li> * </ol> * </p> * * @param times * The times. * @param series * The time series data. * @param weights * Weights to use in Loess stage. * @return * A smoother, less noisy series. */ private double[] lowPassFilter(double[] times, double[] series, double[] weights) { // Find the next odd integer >= n_p (see: section 3.4) double nextOdd = config.getNumberOfObservations() % 2 == 1 ? 
config.getNumberOfObservations() : config.getNumberOfObservations() + 1; // Determine bandwidth as a percentage of points double lowPassBandwidth = nextOdd / series.length; // Apply moving average of length n_p, twice series = movingAverage(series, config.getNumberOfObservations()); series = movingAverage(series, config.getNumberOfObservations()); // Apply moving average of length 3 series = movingAverage(series, 3); // Loess smoothing with d = 1, q = n_l series = loessSmooth(times, series, lowPassBandwidth, weights); return series; } /** * Performs weighted Loess smoothing on a series. * * <p> * Does not assume contiguous time. * </p> * * @param times * The times. * @param series * The time series data. * @param bandwidth * The amount of neighbor points to consider for each point in Loess. * @param weights * The weights to use for smoothing, if null, equal weights are assumed. * @return * Loess-smoothed series. */ private double[] loessSmooth(double[] times, double[] series, double bandwidth, double[] weights) { if (weights == null) { return new LoessInterpolator(bandwidth, config.getLoessRobustnessIterations()).smooth(times, series); } else { return new LoessInterpolator(bandwidth, config.getLoessRobustnessIterations()).smooth(times, series, weights); } } /** * Computes the moving average. * * <p> * The first "window" values are meaningless in the return value. * </p> * * @param series * An input series of data. * @param window * The moving average sliding window. * @return * A new series that contains moving average of series. 
*/ private double[] movingAverage(double[] series, int window) { double[] movingAverage = new double[series.length]; double average = 0; for (int i = 0; i < window; i++) { average += series[i] / window; movingAverage[i] = average; } for (int i = window; i < series.length - window; i++) { average -= series[i - window] / window; average += series[i] / window; movingAverage[i] = average; } for (int i = series.length - window; i < series.length; i++) { average -= series[i] / window; movingAverage[i] = average; } return movingAverage; } private static class TimesAndValues { private final double[] ts; private final double[] xs; TimesAndValues(double[] ts, double[] xs) { this.ts = ts; this.xs = xs; } public double[] getTs() { return ts; } public double[] getXs() { return xs; } } private TimesAndValues padEdges(double[] ts, double[] xs) { // Find step between times double step = Math.abs(ts[1] - ts[0]); // Times (assuming uniform double[] paddedTimes = new double[ts.length + 2]; System.arraycopy(ts, 0, paddedTimes, 1, ts.length); paddedTimes[0] = paddedTimes[1] - step; paddedTimes[paddedTimes.length - 1] = paddedTimes[paddedTimes.length - 2] + step; // Series double[] paddedSeries = new double[xs.length + 2]; System.arraycopy(xs, 0, paddedSeries, 1, xs.length); // Use Loess at ends to pad // n.b. For some reason, this can result in NaN values - perhaps similar to // https://issues.apache.org/jira/browse/MATH-296. 
If we see NaN, just "extrapolate" by copying // the end points :( double left = paddedSeries[1]; double right = paddedSeries[paddedSeries.length - 2]; double bandwidth = 0.3; if (ts.length * bandwidth > 2) { PolynomialSplineFunction loess = new LoessInterpolator(bandwidth, 2).interpolate(ts, xs); double loessLeft = loess.value(ts[0]); if (!Double.isNaN(loessLeft)) { left = loessLeft; } double loessRight = loess.value(ts[ts.length - 1]); if (!Double.isNaN(loessRight)) { right = loessRight; } } paddedSeries[0] = left; paddedSeries[paddedSeries.length - 1] = right; return new TimesAndValues(paddedTimes, paddedSeries); } /** * Runs STL on a CSV of time,measure. * * <p> * Outputs a CSV of time,measure,trend,seasonal,remainder. * </p> * * @param args * args[0] = numberOfObservations * @throws Exception * If could not process data */ public static void main(String[] args) throws Exception { List<Number> times = new ArrayList<Number>(); List<Number> measures = new ArrayList<Number>(); // Read from STDIN String line; BufferedReader reader = new BufferedReader(new InputStreamReader(System.in)); while ((line = reader.readLine()) != null) { String[] tokens = line.split(","); times.add(Long.valueOf(tokens[0])); measures.add(Double.valueOf(tokens[1])); } // Compute STL StlDecomposition stl = new StlDecomposition(Integer.valueOf(args[0])); stl.getConfig().setSeasonalComponentBandwidth(Double.valueOf( System.getProperty("seasonal.bandwidth", String.valueOf(StlConfig.DEFAULT_SEASONAL_BANDWIDTH)))); stl.getConfig().setTrendComponentBandwidth(Double .valueOf(System.getProperty("trend.bandwidth", String.valueOf(StlConfig.DEFAULT_TREND_BANDWIDTH)))); stl.getConfig().setNumberOfInnerLoopPasses(Integer .valueOf(System.getProperty("inner.loop", String.valueOf(StlConfig.DEFAULT_INNER_LOOP_PASSES)))); StlResult res = stl.decompose(times, measures); // Output to STDOUT for (int i = 0; i < times.size(); i++) { System.out.println(String.format("%d,%02f,%02f,%02f,%02f", (long) 
res.getTimes()[i], res.getSeries()[i], res.getTrend()[i], res.getSeasonal()[i], res.getRemainder()[i])); } } }