com.opengamma.analytics.financial.model.volatility.smile.fitting.interpolation.SmileInterpolator.java Source code

Java tutorial

Introduction

Here is the source code for com.opengamma.analytics.financial.model.volatility.smile.fitting.interpolation.SmileInterpolator.java

Source

/**
 * Copyright (C) 2012 - present by OpenGamma Inc. and the OpenGamma group of companies
 *
 * Please see distribution for license.
 */
package com.opengamma.analytics.financial.model.volatility.smile.fitting.interpolation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.List;

import org.apache.commons.lang.ObjectUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import cern.jet.random.engine.MersenneTwister;
import cern.jet.random.engine.MersenneTwister64;
import cern.jet.random.engine.RandomEngine;

import com.opengamma.analytics.financial.model.option.pricing.analytic.formula.EuropeanVanillaOption;
import com.opengamma.analytics.financial.model.volatility.smile.fitting.SmileModelFitter;
import com.opengamma.analytics.financial.model.volatility.smile.function.SmileModelData;
import com.opengamma.analytics.financial.model.volatility.smile.function.VolatilityFunctionProvider;
import com.opengamma.analytics.math.MathException;
import com.opengamma.analytics.math.function.Function1D;
import com.opengamma.analytics.math.function.ParameterizedFunction;
import com.opengamma.analytics.math.matrix.DoubleMatrix1D;
import com.opengamma.analytics.math.statistics.leastsquare.LeastSquareResults;
import com.opengamma.analytics.math.statistics.leastsquare.LeastSquareResultsWithTransform;
import com.opengamma.analytics.math.statistics.leastsquare.NonLinearLeastSquare;
import com.opengamma.util.ArgumentChecker;

/**
 * Interpolate a smile, i.e. fit every data point (market volatility/price), by fitting a smile model (e.g. SABR) through consecutive sets of 3 strikes, so for N data points (prices) there will be N-2
 * 3-point fits. In the interior where smile fits overlap, a weighting between the two smiles is taken, which varies from giving 100% weight to the left smile at the mid point of that fit, down to 0%
 * at the mid point of the right fit.
 *
 * @param <T> The type of the smile model data
 */
public abstract class SmileInterpolator<T extends SmileModelData> implements GeneralSmileInterpolator {

    private static final double FIT_ERROR = 1e-4; //1bps
    private static final double LARGE_ERROR = 0.1;
    private static final WeightingFunction DEFAULT_WEIGHTING_FUNCTION = WeightingFunctionFactory.SINE_WEIGHTING_FUNCTION;

    /**
     * The logger
     */
    protected static final Logger s_logger = LoggerFactory.getLogger(SmileInterpolator.class);

    private final VolatilityFunctionProvider<T> _model;
    private final WeightingFunction _weightingFunction;
    private final RandomEngine _random;

    public SmileInterpolator(final VolatilityFunctionProvider<T> model) {
        this(MersenneTwister.DEFAULT_SEED, model);
    }

    public SmileInterpolator(final int seed, final VolatilityFunctionProvider<T> model) {
        this(seed, model, DEFAULT_WEIGHTING_FUNCTION);
    }

    public SmileInterpolator(final VolatilityFunctionProvider<T> model, final WeightingFunction weightFunction) {
        this(MersenneTwister.DEFAULT_SEED, model, weightFunction);
    }

    public SmileInterpolator(final int seed, final VolatilityFunctionProvider<T> model,
            final WeightingFunction weightFunction) {
        ArgumentChecker.notNull(model, "model");
        ArgumentChecker.notNull(weightFunction, "weightFunction");
        _random = new MersenneTwister64(seed);
        _model = model;
        _weightingFunction = weightFunction;
    }

    public List<T> getFittedModelParameters(final double forward, final double[] strikes, final double expiry,
            final double[] impliedVols) {
        ArgumentChecker.notNull(strikes, "strikes");
        ArgumentChecker.notNull(impliedVols, "implied volatilities");
        final int n = strikes.length;
        ArgumentChecker.isTrue(n > 2, "cannot fit less than three points; have {}", n);
        ArgumentChecker.isTrue(impliedVols.length == n, "#strikes != # vols; have {} and {}", impliedVols.length,
                n);
        validateStrikes(strikes);

        final List<T> modelParameters = new ArrayList<>(n);

        final double[] errors = new double[n];
        Arrays.fill(errors, FIT_ERROR);

        final SmileModelFitter<T> globalFitter = getFitter(forward, strikes, expiry, impliedVols, errors);
        final BitSet gFixed = getGlobalFixedValues();
        LeastSquareResultsWithTransform gBest = null;
        double chiSqr = Double.POSITIVE_INFINITY;

        //TODO set these in sub classes
        int tries = 0;
        int count = 0;
        while (chiSqr > 100.0 * n && count < 5) { //10bps average error
            final DoubleMatrix1D gStart = getGlobalStart(forward, strikes, expiry, impliedVols);
            try {
                final LeastSquareResultsWithTransform glsRes = globalFitter.solve(gStart, gFixed);
                if (glsRes.getChiSq() < chiSqr) {
                    gBest = glsRes;
                    chiSqr = gBest.getChiSq();
                }
                count++;
            } catch (final Exception e) {
            }
            tries++;
            if (tries > 20) {
                throw new MathException("Cannot fit data");
            }
        }
        if (gBest == null) {
            throw new IllegalStateException("Global estimate was null; should never happen");
        }
        if (n == 3) {
            if (gBest.getChiSq() / n > 1.0) {
                s_logger.debug("chi^2 on fit to ", +n + " points is " + gBest.getChiSq());
            }
            modelParameters.add(toSmileModelData(gBest.getModelParameters()));
        } else {
            final BitSet lFixed = getLocalFixedValues();
            DoubleMatrix1D lStart = gBest.getModelParameters();
            for (int i = 0; i < n - 2; i++) {
                final double[][] temp = getStrikesVolsAndErrors(i, strikes, impliedVols, errors);
                final double[] tStrikes = temp[0];
                final double[] tVols = temp[1];
                final double[] tErrors = temp[2];
                final SmileModelFitter<T> localFitter = getFitter(forward, tStrikes, expiry, tVols, tErrors);
                LeastSquareResultsWithTransform lRes = localFitter.solve(lStart, lFixed);
                LeastSquareResultsWithTransform best = lRes;

                count = 0;
                while (lRes.getChiSq() > 3.0 && count < 10) {
                    lStart = getGlobalStart(forward, strikes, expiry, impliedVols);
                    lRes = localFitter.solve(lStart, lFixed);
                    if (lRes.getChiSq() < best.getChiSq()) {
                        best = lRes;
                    }
                    count++;
                }

                if (best.getChiSq() > 3.0) {
                    s_logger.debug("chi^2 on 3-point fit #" + i + " is " + best.getChiSq());
                }
                modelParameters.add(toSmileModelData(best.getModelParameters()));
            }
        }
        return modelParameters;
    }

    /**
     * Returns the random number generator for seeding any interpolation algorithms.
     *
     * @return the random number generator, not null
     */
    protected RandomEngine getRandom() {
        return _random;
    }

    public VolatilityFunctionProvider<T> getModel() {
        return _model;
    }

    public WeightingFunction getWeightingFunction() {
        return _weightingFunction;
    }

    protected double[][] getStrikesVolsAndErrors(final int index, final double[] strikes,
            final double[] impliedVols, final double[] errors) {
        return getStrikesVolsAndErrorsForThreePoints(index, strikes, impliedVols, errors);
    }

    /**
     * Use this for models that can be expressed as having 3 parameters (e.g. SABR with beta fixed). It picks out 3 consecutive strike-volatility pairs for the 3 parameter fit (so the chi^2 should be
     * zero if the model is capable of fitting the data)
     *
     * @param index Index of first strike
     * @param strikes Array of all strikes
     * @param impliedVols Array of all vols
     * @param errors Array of all errors
     * @return array containing the 3 strikes, vols and errors
     */
    protected static double[][] getStrikesVolsAndErrorsForThreePoints(final int index, final double[] strikes,
            final double[] impliedVols, final double[] errors) {
        ArgumentChecker.notNull(strikes, "strikes");
        ArgumentChecker.notNull(impliedVols, "implied vols");
        ArgumentChecker.notNull(errors, "errors");
        double[] tStrikes = new double[3];
        double[] tVols = new double[3];
        double[] tErrors = new double[3];
        tStrikes = Arrays.copyOfRange(strikes, index, index + 3);
        tVols = Arrays.copyOfRange(impliedVols, index, index + 3);
        tErrors = Arrays.copyOfRange(errors, index, index + 3);
        final double[][] res = new double[][] { tStrikes, tVols, tErrors };
        return res;
    }

    /**
     * Use this for models that cannot be easily expressed as having 3 parameters (e.g. mixed log-normal). It picks out 3 consecutive strikes and gives them a small error (1bps by default), while the
     * rest of the data has a relatively large error (100bps by default). The fit is then made to all data (n > 3) which allows more than 3 parameters to be fitted (recall, the start position is set
     * from a true global fit). The chi^2 should be close to zero if the model is capable of fitting the data.
     *
     * @param index Index of first strike
     * @param strikes Array of all strikes
     * @param impliedVols Array of all vols
     * @param errors Array of all errors
     * @return array containing the 3 strikes, vols and errors
     */
    protected static double[][] getStrikesVolsAndErrorsForAllPoints(final int index, final double[] strikes,
            final double[] impliedVols, final double[] errors) {
        ArgumentChecker.notNull(strikes, "strikes");
        ArgumentChecker.notNull(impliedVols, "implied vols");
        ArgumentChecker.notNull(errors, "errors");
        final int n = errors.length;
        final double[] lErrors = new double[n];
        Arrays.fill(lErrors, LARGE_ERROR);
        System.arraycopy(errors, index, lErrors, index, 3); //copy the original errors for the points we really want to fit
        final double[][] res = new double[][] { strikes, impliedVols, lErrors };
        return res;
    }

    protected abstract DoubleMatrix1D getGlobalStart(final double forward, final double[] strikes,
            final double expiry, final double[] impliedVols);

    protected BitSet getGlobalFixedValues() {
        return new BitSet();
    }

    protected BitSet getLocalFixedValues() {
        return new BitSet();
    }

    protected abstract T toSmileModelData(final DoubleMatrix1D modelParameters); //TODO have the same thing in SmileModelFitter - could combine

    protected abstract SmileModelFitter<T> getFitter(final double forward, final double[] strikes,
            final double expiry, final double[] impliedVols, final double[] errors);

    @Override
    public Function1D<Double, Double> getVolatilityFunction(final double forward, final double[] strikes,
            final double expiry, final double[] impliedVols) {

        final List<T> modelParams = getFittedModelParameters(forward, strikes, expiry, impliedVols);
        final int n = strikes.length;

        return new Function1D<Double, Double>() {
            @SuppressWarnings("synthetic-access")
            @Override
            public Double evaluate(final Double strike) {
                final EuropeanVanillaOption option = new EuropeanVanillaOption(strike, expiry, true);
                final Function1D<T, Double> volFunc = _model.getVolatilityFunction(option, forward);
                final int index = SurfaceArrayUtils.getLowerBoundIndex(strikes, strike);
                if (index == 0) {
                    return volFunc.evaluate(modelParams.get(0));
                }
                if (index >= n - 2) {
                    return volFunc.evaluate(modelParams.get(n - 3));
                }
                final double w = _weightingFunction.getWeight(strikes, index, strike);
                if (w == 1) {
                    return volFunc.evaluate(modelParams.get(index - 1));
                } else if (w == 0) {
                    return volFunc.evaluate(modelParams.get(index));
                } else {
                    return w * volFunc.evaluate(modelParams.get(index - 1))
                            + (1 - w) * volFunc.evaluate(modelParams.get(index));
                }
            }
        };
    }

    protected DoubleMatrix1D getPolynomialFit(final double forward, final double[] strikes,
            final double[] impliedVols) {
        final int n = strikes.length;
        final double[] x = new double[n];
        for (int i = 0; i < n; i++) {
            x[i] = Math.log(strikes[i] / forward);
        }

        final ParameterizedFunction<Double, DoubleMatrix1D, Double> func = new ParameterizedFunction<Double, DoubleMatrix1D, Double>() {
            @Override
            public Double evaluate(final Double x1, final DoubleMatrix1D parameters) {
                final double a = parameters.getEntry(0);
                final double b = parameters.getEntry(1);
                final double c = parameters.getEntry(2);
                return a + b * x1 + c * x1 * x1;
            }
        };

        //TODO replace this with an explicit polynomial fitter
        final NonLinearLeastSquare ls = new NonLinearLeastSquare();
        final LeastSquareResults lsRes = ls.solve(new DoubleMatrix1D(x), new DoubleMatrix1D(impliedVols), func,
                new DoubleMatrix1D(0.1, 0.0, 0.0));
        final DoubleMatrix1D fitP = lsRes.getFitParameters();
        return fitP;
    }

    private void validateStrikes(final double[] strikes) {
        final int n = strikes.length;
        for (int i = 1; i < n; i++) {
            ArgumentChecker.isTrue(strikes[i] > strikes[i - 1],
                    "strikes must be in ascending order; have {} (element {}) and {} (element {})", strikes[i - 1],
                    i - 1, strikes[i], i);
        }
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + _model.hashCode();
        result = prime * result + _weightingFunction.hashCode();
        return result;
    }

    @Override
    public boolean equals(final Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null) {
            return false;
        }
        if (getClass() != obj.getClass()) {
            return false;
        }
        final SmileInterpolator<?> other = (SmileInterpolator<?>) obj;
        if (!ObjectUtils.equals(_model, other._model)) {
            return false;
        }
        if (!ObjectUtils.equals(_weightingFunction, other._weightingFunction)) {
            return false;
        }
        return true;
    }

}