Java tutorial
/* * Copyright (c) 2012 Diamond Light Source Ltd. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html */ package uk.ac.diamond.scisoft.analysis.fitting.functions; import java.io.Serializable; import java.lang.reflect.Constructor; import java.util.ArrayList; import java.util.Arrays; import org.apache.commons.math3.analysis.MultivariateFunction; import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateFunctionPenaltyAdapter; import org.eclipse.dawnsci.analysis.api.dataset.IDataset; import org.eclipse.dawnsci.analysis.api.fitting.functions.IFunction; import org.eclipse.dawnsci.analysis.api.fitting.functions.IOperator; import org.eclipse.dawnsci.analysis.api.fitting.functions.IParameter; import org.eclipse.dawnsci.analysis.api.monitor.IMonitor; import org.eclipse.dawnsci.analysis.dataset.impl.Comparisons; import org.eclipse.dawnsci.analysis.dataset.impl.Dataset; import org.eclipse.dawnsci.analysis.dataset.impl.DatasetUtils; import org.eclipse.dawnsci.analysis.dataset.impl.DoubleDataset; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Base abstract class for IFunction implementation. At a minimum, the fillWithValues() method needs * to be added. The fillWithPartialDerivativeValues() and/or calculatePartialDerivativeValues() * methods can be overridden if exact derivatives are needed. * * Note, if the implemented function can alter the number of parameters then it should call its * parent operator's update parameters method. */ public abstract class AFunction implements IFunction, Serializable { /** * Setup the logging facilities */ private static transient final Logger logger = LoggerFactory.getLogger(AFunction.class); /** * The array of parameters which specify all the variables in the minimisation problem */ protected IParameter[] parameters; /** * The name of the function, a description more than anything else. */ protected String name = "default"; /** * The description of the function */ protected String description = "default"; protected boolean dirty = true; protected IMonitor monitor = null; protected IOperator parent; /** * Constructor which simply generates the parameters but uninitialised * * @param numberOfParameters */ public AFunction(int numberOfParameters) { parameters = new Parameter[numberOfParameters]; for (int i = 0; i < numberOfParameters; i++) { parameters[i] = new Parameter(); } } /** * Constructor which takes a list of parameter values as its starting configuration * * @param params * An array of starting parameter values as doubles. */ public AFunction(double... params) { if (params != null) fillParameters(params); } /** * Constructor which is given a set of parameters to begin with. * * @param params * An array of parameters */ public AFunction(IParameter... params) { if (params != null) fillParameters(params); } protected void fillParameters(double... params) { if (parameters == null) parameters = new IParameter[params.length]; int n = Math.min(params.length, parameters.length); for (int i = 0; i < n; i++) { parameters[i] = new Parameter(params[i]); } } protected void fillParameters(IParameter... params) { if (parameters == null) parameters = new IParameter[params.length]; int n = Math.min(params.length, parameters.length); for (int i = 0; i < n; i++) { IParameter p = params[i]; parameters[i] = new Parameter(p); } } /** * @param function * @param parameter * @return index of parameter or -1 if parameter is not in function */ public static int indexOfParameter(IFunction function, IParameter parameter) { if (function == null || parameter == null) return -1; if (function instanceof AFunction) return ((AFunction) function).indexOfParameter(parameter); for (int j = 0, jmax = function.getNoOfParameters(); j < jmax; j++) { if (parameter == function.getParameter(j)) { return j; } } return -1; } /** * @param parameter * @return index of parameter or -1 if parameter is not in function */ protected int indexOfParameter(IParameter parameter) { for (int i = 0; i < parameters.length; i++) { if (parameter == parameters[i]) { return i; } } return -1; } @Override public String getName() { return name; } @Override public void setName(String newName) { name = newName; } @Override public String getDescription() { return description; } @Override public void setDescription(String newDescription) { description = newDescription; } @Override public IParameter getParameter(int index) { return parameters[index]; } @Override public IParameter[] getParameters() { IParameter[] params = new IParameter[parameters.length]; for (int i = 0; i < parameters.length; i++) { params[i] = parameters[i]; } return params; } @Override public int getNoOfParameters() { return parameters.length; } @Override public double getParameterValue(int index) { return parameters[index].getValue(); } @Override final public double[] getParameterValues() { int n = getNoOfParameters(); double[] result = new double[n]; for (int j = 0; j < n; j++) { result[j] = getParameterValue(j); } return result; } @Override public void setParameter(int index, IParameter parameter) { if (indexOfParameter(parameter) == index) return; parameters[index] = parameter; dirty = true; } @Override public void setParameterValues(double... params) { int nparams = Math.min(params.length, parameters.length); for (int j = 0; j < nparams; j++) { parameters[j].setValue(params[j]); } dirty = true; } @Override public String toString() { StringBuffer out = new StringBuffer(); int n = getNoOfParameters(); out.append(String.format("'%s' has %d parameters:\n", name, n)); for (int i = 0; i < n; i++) { IParameter p = getParameter(i); out.append(String.format("%d) %s = %g in range [%g, %g]\n", i, p.getName(), p.getValue(), p.getLowerLimit(), p.getUpperLimit())); } return out.toString(); } @Override @Deprecated public double partialDeriv(int index, double... values) { return partialDeriv(getParameter(index), values); } /** * This implementation is a numerical approximation. Overriding methods should check * for duplicated parameters before doing any calculation and either cope with this * or use this numerical approximation */ @Override public double partialDeriv(IParameter parameter, double... values) { if (indexOfParameter(parameter) < 0) return 0; return calcNumericalDerivative(A_TOLERANCE, R_TOLERANCE, parameter, values); } /** * @param param * @return true if there is more than one occurrence of given parameter in function */ protected boolean isDuplicated(IParameter param) { int c = 0; int n = getNoOfParameters(); for (int i = 0; i < n; i++) { if (getParameter(i) == param) { c++; return c > 1; } } return false; } private final static double DELTA = 1 / 256.; // initial value private final static double DELTA_FACTOR = 0.25; protected final static double A_TOLERANCE = 1e-9; // absolute tolerance protected final static double R_TOLERANCE = 1e-9; // relative tolerance /** * @param abs * @param rel * @param param * @param values * @return partial derivative up to tolerances */ protected double calcNumericalDerivative(double abs, double rel, IParameter param, double... values) { double delta = DELTA; double previous = numericalDerivative(delta, param, values); double aprevious = Math.abs(previous); double current = 0; double acurrent = 0; while (delta > Double.MIN_NORMAL) { delta *= DELTA_FACTOR; current = numericalDerivative(delta, param, values); acurrent = Math.abs(current); if (Math.abs(current - previous) <= abs + rel * Math.max(acurrent, aprevious)) break; previous = current; aprevious = acurrent; } return current; } /** * Calculate partial derivative. This is a numerical approximation. * @param param * @param values * @return partial derivative */ private double numericalDerivative(double delta, IParameter param, double... values) { double v = param.getValue(); double dv = delta * (v != 0 ? v : 1); param.setValue(v - dv); dirty = true; double minval = val(values); param.setValue(v + dv); dirty = true; double maxval = val(values); param.setValue(v); dirty = true; return (maxval - minval) / (2. * dv); } @Override public DoubleDataset makeDataset(IDataset... values) { return calculateValues(values); } /** * @param coords * @return a coordinate iterator */ final static public CoordinatesIterator createIterator(IDataset... coords) { if (coords == null || coords.length == 0) { logger.error("No coordinates given to evaluate function"); throw new IllegalArgumentException("No coordinates given to evaluate function"); } CoordinatesIterator it; int[] shape = coords[0].getShape(); if (coords.length == 1) { it = coords[0].getElementsPerItem() == 1 ? new DatasetsIterator(coords) : new CoordinateDatasetIterator(coords[0]); } else { boolean same = true; for (int i = 1; i < shape.length; i++) { if (!Arrays.equals(shape, coords[i].getShape())) { same = false; break; } } if (same && shape.length == 1) // override for 1D datasets same = false; it = same ? new DatasetsIterator(coords) : new HypergridIterator(coords); } return it; } final public CoordinatesIterator getIterator(IDataset... coords) { return createIterator(coords); } @Override public DoubleDataset calculateValues(IDataset... coords) { CoordinatesIterator it = getIterator(coords); DoubleDataset result = new DoubleDataset(it.getShape()); fillWithValues(result, it); result.setName(name); return result; } @Override public DoubleDataset calculatePartialDerivativeValues(IParameter parameter, IDataset... coords) { CoordinatesIterator it = getIterator(coords); DoubleDataset result = new DoubleDataset(it.getShape()); if (indexOfParameter(parameter) >= 0) internalFillWithPartialDerivativeValues(parameter, result, it); result.setName(name); return result; } private void internalFillWithPartialDerivativeValues(IParameter parameter, DoubleDataset data, CoordinatesIterator it) { if (isDuplicated(parameter)) { calcNumericalDerivativeDataset(A_TOLERANCE, R_TOLERANCE, parameter, data, it); } else { fillWithPartialDerivativeValues(parameter, data, it); } } /** * Fill dataset with values. Implementations should reset the iterator before use * @param data * @param it */ abstract public void fillWithValues(DoubleDataset data, CoordinatesIterator it); /** * Fill dataset with partial derivatives. Implementations should reset the iterator before use * <p> * This implementation is a numerical approximation. * <p> * Note that is called only if there are no duplicated parameters otherwise, * a numerical approximation is used. To change this behaviour, also override * {@link #calculatePartialDerivativeValues(IParameter, IDataset...)} * @param parameter * @param data * @param it */ public void fillWithPartialDerivativeValues(IParameter parameter, DoubleDataset data, CoordinatesIterator it) { calcNumericalDerivativeDataset(A_TOLERANCE, R_TOLERANCE, parameter, data, it); } private static final double SMALLEST_DELTA = Double.MIN_NORMAL * 1024 * 1024; /** * Calculate partial derivatives up to tolerances * @param abs * @param rel * @param param * @param data * @param it */ protected void calcNumericalDerivativeDataset(double abs, double rel, IParameter param, DoubleDataset data, CoordinatesIterator it) { DoubleDataset previous = new DoubleDataset(it.getShape()); double delta = DELTA; fillWithNumericalDerivativeDataset(delta, param, previous, it); DoubleDataset current = new DoubleDataset(it.getShape()); while (delta > SMALLEST_DELTA) { delta *= DELTA_FACTOR; fillWithNumericalDerivativeDataset(delta, param, current, it); if (Comparisons.allCloseTo(previous, current, rel, abs)) break; DoubleDataset temp = previous; previous = current; current = temp; } if (delta <= SMALLEST_DELTA) { logger.warn("Numerical derivative did not converge!"); } data.setSlice(current); } /** * Calculate partial derivative. This is a numerical approximation. * @param delta * @param param * @param data * @param it */ private void fillWithNumericalDerivativeDataset(double delta, IParameter param, DoubleDataset data, CoordinatesIterator it) { double v = param.getValue(); double dv = delta * (v != 0 ? v : 1); param.setValue(v + dv); dirty = true; fillWithValues(data, it); it.reset(); param.setValue(v - dv); dirty = true; DoubleDataset temp = new DoubleDataset(it.getShape()); fillWithValues(temp, it); data.isubtract(temp); data.imultiply(0.5 / dv); param.setValue(v); dirty = true; } /** * @return true if any parameters have changed */ public boolean isDirty() { return dirty; } @Override public void setDirty(boolean isDirty) { dirty = isDirty; } @Override public double residual(boolean allValues, IDataset data, IDataset weight, IDataset... coords) { double residual = 0; if (allValues) { DoubleDataset ddata = (DoubleDataset) DatasetUtils.convertToDataset(data).cast(Dataset.FLOAT64); residual = ddata.residual(calculateValues(coords), DatasetUtils.convertToDataset(weight), false); } else { // stochastic sampling of coords; // int NUMBER_OF_SAMPLES = 100; //TODO logger.error("Stochastic sampling has not been implemented yet"); throw new UnsupportedOperationException("Stochastic sampling has not been implemented yet"); } if (monitor != null) { monitor.worked(1); if (monitor.isCancelled()) { throw new IllegalMonitorStateException("Monitor cancelled"); } } return residual; } @Override @Deprecated public double residual(boolean allValues, IDataset data, IDataset... coords) { return residual(allValues, data, null, coords); } @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + (dirty ? 1231 : 1237); result = prime * result + ((name == null) ? 0 : name.hashCode()); result = prime * result + Arrays.hashCode(parameters); return result; } @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; AFunction other = (AFunction) obj; if (dirty != other.dirty) return false; if (name == null) { if (other.name != null) return false; } else if (!name.equals(other.name)) return false; if (!Arrays.equals(parameters, other.parameters)) return false; return true; } @Override public AFunction copy() throws Exception { Constructor<? extends AFunction> c = getClass().getConstructor(); IParameter[] localParameters = getParameters(); AFunction function = c.newInstance(); function.fillParameters(localParameters); return function; } @Override public IMonitor getMonitor() { return monitor; } @Override public void setMonitor(IMonitor monitor) { this.monitor = monitor; } @Override public boolean isValid() { return true; } @Override public IOperator getParentOperator() { return parent; } @Override public void setParentOperator(IOperator parent) { this.parent = parent; } /** * Generate a Apache MultivariateFunctionPenaltyAdapter from the function * @param inputValues A dataset containing the data values for the optimisation * @param inputCoords A dataset containing the coordinates for the optimization * @return the bounded MultivariateFunctionPenaltyAdapter */ public MultivariateFunctionPenaltyAdapter getApacheMultivariateFunction(IDataset inputValues, IDataset[] inputCoords) { final AFunction function = this; final IDataset values = inputValues; final IDataset[] coords = inputCoords; MultivariateFunction multivariateFunction = new MultivariateFunction() { @Override public double value(double[] arg0) { function.setParameterValuesNoFixed(arg0); double result = function.residual(true, values, null, coords); return result; } }; double offset = 1e12; double[] lowerb = getLowerBoundsNoFixed(); double[] upperb = getUpperBoundsNoFixed(); double[] scale = new double[lowerb.length]; for (int i = 0; i < scale.length; i++) { scale[i] = offset * 0.25; } MultivariateFunctionPenaltyAdapter multivariateFunctionPenaltyAdapter = new MultivariateFunctionPenaltyAdapter( multivariateFunction, lowerb, upperb, offset, scale); return multivariateFunctionPenaltyAdapter; } /** * Get the parameter values as an array, excluding parameters which are fixed * @return a double[] of non fixed parameter values */ public double[] getParameterValuesNoFixed() { ArrayList<Double> values = new ArrayList<Double>(); for (int i = 0; i < getNoOfParameters(); i++) { if (getParameter(i).isFixed() == false) { values.add(getParameter(i).getValue()); } } double[] start = new double[values.size()]; for (int i = 0; i < start.length; i++) { start[i] = values.get(i); } return start; } /** * Get the parameter upper bounds as an array, excluding parameters which are fixed * @return a double[] of non fixed parameter upper bounds */ public double[] getUpperBoundsNoFixed() { ArrayList<Double> values = new ArrayList<Double>(); for (int i = 0; i < getNoOfParameters(); i++) { if (getParameter(i).isFixed() == false) { values.add(getParameter(i).getUpperLimit()); } } double[] start = new double[values.size()]; for (int i = 0; i < start.length; i++) { start[i] = values.get(i); } return start; } /** * Get the parameter lower bounds as an array, excluding parameters which are fixed * @return a double[] of non fixed parameter lower bounds */ public double[] getLowerBoundsNoFixed() { ArrayList<Double> values = new ArrayList<Double>(); for (int i = 0; i < getNoOfParameters(); i++) { if (getParameter(i).isFixed() == false) { values.add(getParameter(i).getLowerLimit()); } } double[] start = new double[values.size()]; for (int i = 0; i < start.length; i++) { start[i] = values.get(i); } return start; } /** * Set the values of all non fixed parameters * @param values */ public void setParameterValuesNoFixed(double[] values) { int argpos = 0; for (int i = 0; i < getNoOfParameters(); i++) { if (getParameter(i).isFixed() == false) { getParameter(i).setValue(values[argpos]); argpos++; } } setDirty(true); } }