uk.ac.diamond.scisoft.analysis.fitting.functions.Polynomial.java Source code

Java tutorial

Introduction

Here is the source code for uk.ac.diamond.scisoft.analysis.fitting.functions.Polynomial.java

Source

/*-
 * Copyright (c) 2012 Diamond Light Source Ltd.
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 */

package uk.ac.diamond.scisoft.analysis.fitting.functions;

import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.math3.complex.Complex;
import org.ddogleg.solver.PolynomialOps;
import org.ddogleg.solver.PolynomialRoots;
import org.ddogleg.solver.RootFinderType;
import org.eclipse.dawnsci.analysis.api.fitting.functions.IParameter;
import org.eclipse.dawnsci.analysis.dataset.impl.Dataset;
import org.eclipse.dawnsci.analysis.dataset.impl.DatasetUtils;
import org.eclipse.dawnsci.analysis.dataset.impl.DoubleDataset;
import org.eclipse.dawnsci.analysis.dataset.impl.Maths;
import org.ejml.data.Complex64F;

/**
 * Class that wrappers the equation <br>
 * y(x) = a_0 x^n + a_1 x^(n-1) + a_2 x^(n-2) + ... + a_(n-1) x + a_n
 */
public class Polynomial extends AFunction {
    private static final String NAME = "Polynomial";
    private static final String DESC = "A polynomial of degree n."
            + "\n    y(x) = a_0 x^n + a_1 x^(n-1) + a_2 x^(n-2) + ... + a_(n-1) x + a_n";
    private transient double[] a;
    private transient int nparams; // actually degree + 1

    /**
     * Basic constructor, not advisable to use
     */
    public Polynomial() {
        this(0);
    }

    /**
     * Make a polynomial of given degree (0 - constant, 1 - linear, 2 - quadratic, etc)
     * @param degree
     */
    public Polynomial(final int degree) {
        super(degree + 1);
    }

    /**
     * Make a polynomial with given parameters
     * @param params
     */
    public Polynomial(double[] params) {
        super(params);
    }

    /**
     * Make a polynomial with given parameters
     * @param params
     */
    public Polynomial(IParameter... params) {
        super(params);
    }

    /**
     * Constructor that allows for the positioning of all the parameter bounds
     * 
     * @param min
     *            minimum boundaries
     * @param max
     *            maximum boundaries
     */
    public Polynomial(double[] min, double[] max) {
        super(0);
        if (min.length != max.length) {
            throw new IllegalArgumentException("Bound arrays must be of equal length");
        }
        nparams = min.length;
        parameters = new Parameter[nparams];
        a = new double[nparams];

        for (int i = 0; i < nparams; i++) {
            a[i] = 0.5 * (min[i] + max[i]);
            parameters[i] = new Parameter(a[i], min[i], max[i]);
        }

        setNames();
    }

    @Override
    protected void setNames() {
        if (isDirty() && nparams < getNoOfParameters()) {
            nparams = getNoOfParameters();
        }
        String[] paramNames = new String[nparams];
        for (int i = 0; i < nparams; i++) {
            paramNames[i] = "a_" + i;
        }

        setNames(NAME, DESC, paramNames);
    }

    private void calcCachedParameters() {
        if (a == null || a.length != nparams) {
            a = new double[nparams];
        }
        for (int i = 0; i < nparams; i++) {
            a[i] = getParameterValue(i);
        }

        setDirty(false);
    }

    @Override
    public double val(double... values) {
        if (isDirty()) {
            calcCachedParameters();
        }
        final double position = values[0];

        double v = a[0];
        for (int i = 1; i < nparams; i++) {
            v = v * position + a[i];
        }
        return v;
    }

    @Override
    public void fillWithValues(DoubleDataset data, CoordinatesIterator it) {
        if (isDirty())
            calcCachedParameters();

        it.reset();
        double[] coords = it.getCoordinates();
        int i = 0;
        double[] buffer = data.getData();
        while (it.hasNext()) {
            double v = a[0];
            double p = coords[0];
            for (int j = 1; j < nparams; j++) {
                v = v * p + a[j];
            }
            buffer[i++] = v;
        }
    }

    @Override
    public double partialDeriv(IParameter parameter, double... position) {
        if (isDuplicated(parameter))
            return super.partialDeriv(parameter, position);

        int i = indexOfParameter(parameter);
        if (i < 0)
            return 0;

        final double pos = position[0];
        final int n = nparams - 1 - i;
        switch (n) {
        case 0:
            return 1.0;
        case 1:
            return pos;
        case 2:
            return pos * pos;
        default:
            return Math.pow(pos, n);
        }
    }

    @Override
    public void fillWithPartialDerivativeValues(IParameter parameter, DoubleDataset data, CoordinatesIterator it) {
        Dataset pos = DatasetUtils.convertToDataset(it.getValues()[0]);

        final int n = nparams - 1 - indexOfParameter(parameter);
        switch (n) {
        case 0:
            data.fill(1);
            break;
        case 1:
            data.setSlice(pos);
            break;
        case 2:
            Maths.square(pos, data);
            break;
        default:
            Maths.power(pos, n, data);
            break;
        }
    }

    /**
     * Create a 2D dataset which contains in each row a coordinate raised to n-th powers.
     * <p>
     * This is for solving the linear least squares problem 
     * @param coords
     * @return matrix
     */
    public DoubleDataset makeMatrix(Dataset coords) {
        final int rows = coords.getSize();
        DoubleDataset matrix = new DoubleDataset(rows, nparams);

        for (int i = 0; i < rows; i++) {
            final double x = coords.getDouble(i);
            double v = 1.0;
            for (int j = nparams - 1; j >= 0; j--) {
                matrix.setItem(v, i, j);
                v *= x;
            }
        }

        return matrix;
    }

    /**
     * Set the degree after a class instantiation
     * @param degree
     */
    public void setDegree(int degree) {
        nparams = degree + 1;
        parameters = createParameters(nparams);
        dirty = true;

        setNames();

        if (parent != null) {
            parent.updateParameters();
        }
    }

    public String getStringEquation() {

        StringBuilder out = new StringBuilder();

        DecimalFormat df = new DecimalFormat("0.#####E0");

        for (int i = nparams - 1; i >= 2; i--) {
            out.append(df.format(parameters[nparams - 1 - i].getValue()));
            out.append(String.format("x^%d + ", i));
        }

        if (nparams >= 2)
            out.append(df.format(parameters[nparams - 2].getValue()) + "x + ");
        if (nparams >= 1)
            out.append(df.format(parameters[nparams - 1].getValue()));

        return out.toString();
    }

    /**
     * Find all roots
     * @return all roots or null if there is any problem finding the roots
     */
    public Complex[] findRoots() {
        if (isDirty()) {
            calcCachedParameters();
        }
        return findRoots(a);
    }

    /**
     * Find all roots
     * @param coeffs
     * @return all roots or null if there is any problem finding the roots
     */
    public static Complex[] findRoots(double... coeffs) {
        double[] reverse = coeffs.clone();
        ArrayUtils.reverse(reverse);
        double max = Double.NEGATIVE_INFINITY;
        for (double r : reverse) {
            max = Math.max(max, Math.abs(r));
        }
        for (int i = 0; i < reverse.length; i++) {
            reverse[i] /= max;
        }

        org.ddogleg.solver.Polynomial p = org.ddogleg.solver.Polynomial.wrap(reverse);
        PolynomialRoots rf = PolynomialOps.createRootFinder(p.computeDegree(), RootFinderType.EVD);
        if (rf.process(p)) {
            // reorder to NumPy's roots output
            List<Complex64F> rts = rf.getRoots();
            Complex[] out = new Complex[rts.size()];
            int i = 0;
            for (Complex64F r : rts) {
                out[i++] = new Complex(r.getReal(), r.getImaginary());
            }
            return sort(out);
        }

        return null;
    }

    private static Complex[] sort(Complex[] values) {
        // reorder to NumPy's roots output
        List<Complex> rts = Arrays.asList(values);
        Collections.sort(rts, new Comparator<Complex>() {
            @Override
            public int compare(Complex o1, Complex o2) {
                double a = o1.getReal();
                double b = o2.getReal();

                double u = 10 * Math.ulp(Math.max(Math.abs(a), Math.abs(b)));
                if (Math.abs(a - b) > u)
                    return a < b ? -1 : 1;

                a = o1.getImaginary();
                b = o2.getImaginary();
                if (a == b)
                    return 0;
                return a < b ? 1 : -1;
            }
        });

        return rts.toArray(new Complex[0]);
    }
}