com.opengamma.strata.pricer.sensitivity.RatesFiniteDifferenceSensitivityCalculator.java Source code

Java tutorial

Introduction

Here is the source code for com.opengamma.strata.pricer.sensitivity.RatesFiniteDifferenceSensitivityCalculator.java

Source

/**
 * Copyright (C) 2015 - present by OpenGamma Inc. and the OpenGamma group of companies
 *
 * Please see distribution for license.
 */
package com.opengamma.strata.pricer.sensitivity;

import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.function.BiFunction;
import java.util.function.Function;

import org.joda.beans.MetaProperty;

import com.google.common.collect.ImmutableMap;
import com.opengamma.strata.basics.currency.Currency;
import com.opengamma.strata.basics.currency.CurrencyAmount;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.tuple.Pair;
import com.opengamma.strata.market.curve.Curve;
import com.opengamma.strata.market.param.CurrencyParameterSensitivities;
import com.opengamma.strata.pricer.DiscountFactors;
import com.opengamma.strata.pricer.SimpleDiscountFactors;
import com.opengamma.strata.pricer.ZeroRateDiscountFactors;
import com.opengamma.strata.pricer.bond.ImmutableLegalEntityDiscountingProvider;
import com.opengamma.strata.pricer.bond.LegalEntityDiscountingProvider;
import com.opengamma.strata.pricer.rate.ImmutableRatesProvider;
import com.opengamma.strata.pricer.rate.RatesProvider;

/**
 * Computes the curve parameter sensitivity by finite difference.
 * <p>
 * This is based on an {@link ImmutableRatesProvider} or {@link ImmutableLegalEntityDiscountingProvider}.
 * Each curve parameter is bumped in turn by a fixed shift, the value is recomputed,
 * and the sensitivity is estimated as {@code (bumpedValue - baseValue) / shift}.
 */
public class RatesFiniteDifferenceSensitivityCalculator {

    /**
     * Default implementation. The shift is one basis point (0.0001).
     */
    public static final RatesFiniteDifferenceSensitivityCalculator DEFAULT = new RatesFiniteDifferenceSensitivityCalculator(
            1.0E-4);

    /**
     * The shift used for finite difference.
     */
    private final double shift;

    /**
     * Create an instance of the finite difference calculator.
     * <p>
     * The shift is used as the divisor of the finite difference quotient, so it must be
     * a finite, non-zero number; otherwise every sensitivity would be NaN or infinite.
     * 
     * @param shift  the shift used in the finite difference computation, finite and non-zero
     * @throws IllegalArgumentException if the shift is zero, NaN or infinite
     */
    public RatesFiniteDifferenceSensitivityCalculator(double shift) {
        if (shift == 0d || !Double.isFinite(shift)) {
            throw new IllegalArgumentException("Shift must be finite and non-zero, but was " + shift);
        }
        this.shift = shift;
    }

    //-------------------------------------------------------------------------
    /**
     * Computes the first order sensitivities of a function of a {@code RatesProvider}
     * to the curve parameters by finite difference.
     * <p>
     * The finite difference is computed by forward type.
     * The function should return a value in the same currency for any rate provider.
     * <p>
     * The discount curves and the index curves of the provider are bumped independently:
     * when the discount curves are bumped the index curves are left untouched, and vice versa.
     * The two sets of sensitivities are then combined.
     * 
     * @param provider  the rates provider
     * @param valueFn  the function from a rate provider to a currency amount for which the sensitivity should be computed
     * @return the curve sensitivity
     */
    public CurrencyParameterSensitivities sensitivity(RatesProvider provider,
            Function<ImmutableRatesProvider, CurrencyAmount> valueFn) {

        ImmutableRatesProvider immProv = provider.toImmutableRatesProvider();
        // the base value is computed once and shared by all bumps
        CurrencyAmount valueInit = valueFn.apply(immProv);
        CurrencyParameterSensitivities discounting = sensitivity(immProv, immProv.getDiscountCurves(),
                (base, bumped) -> base.toBuilder().discountCurves(bumped).build(), valueFn, valueInit);
        CurrencyParameterSensitivities forward = sensitivity(immProv, immProv.getIndexCurves(),
                (base, bumped) -> base.toBuilder().indexCurves(bumped).build(), valueFn, valueInit);
        return discounting.combinedWith(forward);
    }

    // computes the sensitivity with respect to each parameter of each curve in the map;
    // storeBumpedFn rebuilds a provider from the base provider and a map holding one bumped curve
    private <T> CurrencyParameterSensitivities sensitivity(ImmutableRatesProvider provider,
            Map<T, Curve> baseCurves,
            BiFunction<ImmutableRatesProvider, Map<T, Curve>, ImmutableRatesProvider> storeBumpedFn,
            Function<ImmutableRatesProvider, CurrencyAmount> valueFn, CurrencyAmount valueInit) {

        double baseAmount = valueInit.getAmount();
        CurrencyParameterSensitivities result = CurrencyParameterSensitivities.empty();
        for (Entry<T, Curve> entry : baseCurves.entrySet()) {
            T key = entry.getKey();
            Curve curve = entry.getValue();
            DoubleArray sensitivity = DoubleArray.of(curve.getParameterCount(), i -> {
                Curve bumpedCurve = curve.withParameter(i, curve.getParameter(i) + shift);
                // copy the base map so that only the current curve differs from the base provider
                Map<T, Curve> bumpedMap = new HashMap<>(baseCurves);
                bumpedMap.put(key, bumpedCurve);
                ImmutableRatesProvider bumpedProvider = storeBumpedFn.apply(provider, bumpedMap);
                return forwardDifference(valueFn.apply(bumpedProvider).getAmount(), baseAmount);
            });
            result = result.combinedWith(curve.createParameterSensitivity(valueInit.getCurrency(), sensitivity));
        }
        return result;
    }

    //-------------------------------------------------------------------------
    /**
     * Computes the first order sensitivities of a function of a {@code LegalEntityDiscountingProvider}
     * to the curve parameters by finite difference.
     * <p>
     * The finite difference is computed by forward type.
     * The function should return a value in the same currency for any rates provider of LegalEntityDiscountingProvider.
     * <p>
     * The repo curves and the issuer curves of the provider are bumped independently
     * and the two sets of sensitivities are then combined.
     * The discount factors held by the provider must be {@link ZeroRateDiscountFactors}
     * or {@link SimpleDiscountFactors}.
     * 
     * @param provider  the rates provider
     * @param valueFn  the function from a rate provider to a currency amount for which the sensitivity should be computed
     * @return the curve sensitivity
     * @throws IllegalArgumentException if a discount factors instance of an unsupported type is found
     */
    public CurrencyParameterSensitivities sensitivity(LegalEntityDiscountingProvider provider,
            Function<ImmutableLegalEntityDiscountingProvider, CurrencyAmount> valueFn) {

        ImmutableLegalEntityDiscountingProvider immProv = provider.toImmutableLegalEntityDiscountingProvider();
        // the base value is computed once and shared by all bumps
        CurrencyAmount valueInit = valueFn.apply(immProv);
        CurrencyParameterSensitivities discounting = sensitivity(immProv, valueFn,
                ImmutableLegalEntityDiscountingProvider.meta().repoCurves(), valueInit);
        CurrencyParameterSensitivities forward = sensitivity(immProv, valueFn,
                ImmutableLegalEntityDiscountingProvider.meta().issuerCurves(), valueInit);
        return discounting.combinedWith(forward);
    }

    // computes the sensitivity with respect to each parameter of each curve held by the
    // property of the provider identified by the meta-property
    private <T> CurrencyParameterSensitivities sensitivity(ImmutableLegalEntityDiscountingProvider provider,
            Function<ImmutableLegalEntityDiscountingProvider, CurrencyAmount> valueFn,
            MetaProperty<ImmutableMap<Pair<T, Currency>, DiscountFactors>> metaProperty, CurrencyAmount valueInit) {

        double baseAmount = valueInit.getAmount();
        ImmutableMap<Pair<T, Currency>, DiscountFactors> baseCurves = metaProperty.get(provider);
        CurrencyParameterSensitivities result = CurrencyParameterSensitivities.empty();
        for (Entry<Pair<T, Currency>, DiscountFactors> entry : baseCurves.entrySet()) {
            Pair<T, Currency> key = entry.getKey();
            DiscountFactors discountFactors = entry.getValue();
            Curve curve = checkDiscountFactors(discountFactors);
            DoubleArray sensitivity = DoubleArray.of(curve.getParameterCount(), i -> {
                Curve bumpedCurve = curve.withParameter(i, curve.getParameter(i) + shift);
                // copy the base map so that only the current curve differs from the base provider
                Map<Pair<T, Currency>, DiscountFactors> bumpedMap = new HashMap<>(baseCurves);
                bumpedMap.put(key, createDiscountFactors(discountFactors, bumpedCurve));
                ImmutableLegalEntityDiscountingProvider bumpedProvider = provider.toBuilder()
                        .set(metaProperty, bumpedMap).build();
                return forwardDifference(valueFn.apply(bumpedProvider).getAmount(), baseAmount);
            });
            result = result.combinedWith(curve.createParameterSensitivity(valueInit.getCurrency(), sensitivity));
        }
        return result;
    }

    //-------------------------------------------------------------------------
    // forward finite difference quotient for a single bumped parameter
    private double forwardDifference(double bumpedAmount, double baseAmount) {
        return (bumpedAmount - baseAmount) / shift;
    }

    // check that the discountFactors is ZeroRateDiscountFactors or SimpleDiscountFactors
    // and return the underlying curve
    private Curve checkDiscountFactors(DiscountFactors discountFactors) {
        if (discountFactors instanceof ZeroRateDiscountFactors) {
            return ((ZeroRateDiscountFactors) discountFactors).getCurve();
        } else if (discountFactors instanceof SimpleDiscountFactors) {
            return ((SimpleDiscountFactors) discountFactors).getCurve();
        }
        throw new IllegalArgumentException(unsupportedMessage(discountFactors));
    }

    // return an instance of DiscountFactors of the same type as the original, based on the bumped curve
    private DiscountFactors createDiscountFactors(DiscountFactors originalDsc, Curve bumpedCurve) {
        if (originalDsc instanceof ZeroRateDiscountFactors) {
            return ZeroRateDiscountFactors.of(originalDsc.getCurrency(), originalDsc.getValuationDate(),
                    bumpedCurve);
        } else if (originalDsc instanceof SimpleDiscountFactors) {
            return SimpleDiscountFactors.of(originalDsc.getCurrency(), originalDsc.getValuationDate(), bumpedCurve);
        }
        throw new IllegalArgumentException(unsupportedMessage(originalDsc));
    }

    // builds the error message for an unsupported DiscountFactors instance, naming the offending type
    private static String unsupportedMessage(DiscountFactors discountFactors) {
        String typeName = (discountFactors == null ? "null" : discountFactors.getClass().getName());
        return "Not supported: DiscountFactors must be ZeroRateDiscountFactors or SimpleDiscountFactors, but was "
                + typeName;
    }

}