edu.duke.cs.osprey.tupexp.IterativeCGTupleFitter.java Source code

Java tutorial

Introduction

Here is the source code for edu.duke.cs.osprey.tupexp.IterativeCGTupleFitter.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package edu.duke.cs.osprey.tupexp;

import java.util.ArrayList;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.ConjugateGradient;
import org.apache.commons.math3.linear.RealLinearOperator;
import org.apache.commons.math3.linear.RealVector;

/**
 * Modified least-squares fitter for LUTE coefficients.
 *
 * <p>Each sample has a "good region" {@code [lower, upper]} for its fit value
 * (the sum of the coefficients of the tuples the sample contains). A sample only
 * contributes to the least-squares objective while its fit value lies outside
 * that region; its target is then the bound that was violated. Because the set
 * of active restraints depends on the current fit, the problem is solved by
 * repeated conjugate-gradient solves of the normal equations, re-evaluating the
 * active set before each solve.
 *
 * <p>The fields {@code AtA}, {@code Atb}, {@code numSamp}, {@code numTup},
 * {@code tupIndMat} and {@code samples} are not declared here; they are inherited
 * from {@link CGTupleFitter} and are assigned by this class.
 *
 * @author mhall44
 */
public class IterativeCGTupleFitter extends CGTupleFitter {

    /** Bounds {lower, upper} on the "good" region of fit values, one entry per sample. */
    ArrayList<double[]> goodRegionBounds;

    /** Fit value at each sample from the last accepted iteration; null before the first. */
    double[] curFitVals = null;

    /** Coefficients from the last accepted iteration; null before the first. */
    RealVector curCoeffs = null;

    /** Residual of the last accepted iteration. */
    double curResid = Double.POSITIVE_INFINITY;

    //Fit matrix: A_samp,tup = 1_(tup in samp) * 1_(samp restrained)
    //Hence AtA = sum over restrained samples of 1_(tup in samp) * 1_(tup2 in samp)
    //and Atb = sum over restrained samples of 1_(tup in samp) * target_samp
    //This holds for the initial fit and for every later iteration.

    /**
     * Weight of the penalty on coefficient changes between iterations.
     * Some tuples will have no equality, only inequality restraints, and thus may
     * revert to 0 when moved to the inactive restraint set, preventing convergence.
     * A slight penalty on coefficient changes prevents this.
     */
    double damperLambda = 1e-4;

    /**
     * Sets up a fit of the given samples to an expansion in the tuples in {@code tim}.
     *
     * @param tim index matrix used to look up which tuples a sample contains
     * @param samp the samples to fit
     * @param numTuples number of tuples, i.e. number of coefficients to fit
     * @param goodRegionBounds for each sample, the {lower, upper} bounds of the
     *        region in which its fit value incurs no penalty; indexed like {@code samp}
     */
    public IterativeCGTupleFitter(TupleIndexMatrix tim, ArrayList<int[]> samp, int numTuples,
            ArrayList<double[]> goodRegionBounds) {
        samples = samp;
        numSamp = samples.size();
        numTup = numTuples;
        tupIndMat = tim;

        this.goodRegionBounds = goodRegionBounds;

        //Matrix-free operator: A^T A is never formed explicitly, only applied
        AtA = new RealLinearOperator() {

            @Override
            public int getRowDimension() {
                return numTup;
            }

            @Override
            public int getColumnDimension() {
                return numTup;
            }

            @Override
            public RealVector operate(RealVector rv) throws DimensionMismatchException {
                return applyAtA(rv);
            }
        };
    }

    /**
     * Evaluates the expansion at every sample.
     * Values are computed for all samples, restrained or not, because they are
     * needed for validation as well as for applying AtA.
     *
     * @param rv tuple coefficients
     * @return for each sample, the sum of the coefficients of the tuples it contains
     */
    double[] calcFitVals(RealVector rv) {
        double[] fitVals = new double[numSamp];
        for (int s = 0; s < numSamp; s++) {
            ArrayList<Integer> sampTup = tupIndMat.calcSampleTuples(samples.get(s));
            for (int t : sampTup) {
                fitVals[s] += rv.getEntry(t);
            }
        }

        return fitVals;
    }

    /**
     * Applies the (damped) normal-equations operator to a coefficient vector.
     *
     * @param rv vector to operate on
     * @return A^T A rv, restricted to currently restrained samples, plus
     *         {@code damperLambda * rv} on every iteration after the first
     */
    RealVector applyAtA(RealVector rv) {
        //first apply A
        double[] Arv = calcFitVals(rv);

        //then apply A^T to Arv, counting only samples whose restraint is active
        double[] ans = new double[numTup];
        for (int s = 0; s < numSamp; s++) {
            if (isSampleRestrained(s)) {
                ArrayList<Integer> sampTup = tupIndMat.calcSampleTuples(samples.get(s));
                for (int t : sampTup) {
                    ans[t] += Arv[s];
                }
            }
        }

        //damping (there is nothing to damp towards on the first iteration)
        if (curCoeffs != null) {
            for (int t = 0; t < numTup; t++) {
                ans[t] += damperLambda * rv.getEntry(t);
            }
        }

        return new ArrayRealVector(ans, false);//make RealVector without copying ans
    }

    /**
     * Gets the current "target" value for a sample (for this iteration of the
     * modified least squares).
     *
     * @param s sample index
     * @return the bound the sample is currently restrained to, or NaN if no
     *         restraint is active for the sample
     */
    double getCurTarget(int s) {
        double[] bounds = goodRegionBounds.get(s);
        if (curFitVals == null) {
            //First iteration.  Assume samples with any leeway don't have a target
            if (bounds[0] == bounds[1]) {
                return bounds[0];
            }
            return Double.NaN;
        }

        double curFitVal = curFitVals[s];
        if (curFitVal < bounds[0]) {//lower restraint active
            return bounds[0];
        } else if (curFitVal > bounds[1]) {//upper restraint active
            return bounds[1];
        }
        return Double.NaN;//neither
    }

    /**
     * Tells whether a sample currently contributes to the objective function.
     *
     * @param s sample index
     * @return on the first iteration, whether the sample's bounds coincide;
     *         afterwards, whether its current fit value lies outside its bounds
     */
    boolean isSampleRestrained(int s) {
        double[] bounds = goodRegionBounds.get(s);
        if (curFitVals == null) {
            //First iteration.  Assume samples with any leeway don't have a target
            return bounds[0] == bounds[1];
        }

        double curFitVal = curFitVals[s];
        return curFitVal < bounds[0] || curFitVal > bounds[1];
    }

    /**
     * Calculates the right-hand side vector of the normal equations and stores
     * it in {@code Atb}.
     *
     * @return A^T applied to the current targets, plus
     *         {@code damperLambda * curCoeffs} on every iteration after the first
     */
    RealVector calcRHS() {
        double[] atb = new double[numTup];

        //apply A^T to the current targets
        for (int s = 0; s < numSamp; s++) {
            double curTarget = getCurTarget(s);
            if (!Double.isNaN(curTarget)) {//restraint active for sample
                ArrayList<Integer> sampTup = tupIndMat.calcSampleTuples(samples.get(s));
                for (int t : sampTup) {
                    atb[t] += curTarget;
                }
            }
        }

        //damping.  Slightly penalizes changes from curCoeffs
        if (curCoeffs != null) {
            for (int t = 0; t < numTup; t++) {
                atb[t] += damperLambda * curCoeffs.getEntry(t);
            }
        }

        Atb = new ArrayRealVector(atb);
        return Atb;
    }

    /**
     * Calculates the mean squared deviation of the fit values from their good regions.
     *
     * @param fitVals fit value for each sample
     * @return sum over samples of the squared distance to the nearest violated
     *         bound (0 for samples within bounds), divided by the number of
     *         samples; NaN if there are no samples
     */
    double calcResidual(double[] fitVals) {
        double resid = 0;

        for (int s = 0; s < numSamp; s++) {
            double dev = 0;
            double[] bounds = goodRegionBounds.get(s);

            if (fitVals[s] < bounds[0]) {
                dev = fitVals[s] - bounds[0];
            } else if (fitVals[s] > bounds[1]) {
                dev = fitVals[s] - bounds[1];
            }

            resid += dev * dev;
        }

        return resid / numSamp;
    }

    /**
     * Checks whether the set of active restraints is unchanged between two
     * consecutive iterations, ignoring boundary crossings smaller than a
     * fixed tolerance of 1e-6.
     *
     * @param oldFitVals fit values from the previous iteration, or null on the first
     * @param newFitVals fit values from the current iteration
     * @return false if {@code oldFitVals} is null or some sample with distinct
     *         bounds crossed one of them by more than the tolerance; true otherwise
     */
    boolean checkDone(double[] oldFitVals, double[] newFitVals) {
        if (oldFitVals == null) {//first iteration, can't be done (no comparison to do)
            return false;
        }

        double tol = 1e-6;

        for (int s = 0; s < numSamp; s++) {
            double[] bounds = goodRegionBounds.get(s);
            if (bounds[0] < bounds[1]) {//there are boundaries for this sample
                //(i.e. not pure quadratic objective function term)
                if (oldFitVals[s] < bounds[0]) {//check if lower restraint disappeared
                    if (newFitVals[s] > bounds[0] + tol) {
                        return false;
                    }
                } else if (oldFitVals[s] > bounds[1]) {//upper
                    if (newFitVals[s] < bounds[1] - tol) {
                        return false;
                    }
                } else {//no restraint now...see if one appeared
                    if (newFitVals[s] < bounds[0] - tol || newFitVals[s] > bounds[1] + tol) {
                        return false;
                    }
                }
            }
        }

        return true;
    }

    /**
     * Runs conjugate-gradient solves repeatedly until the residual stops improving.
     *
     * <p>After each solve the residual is compared with that of the last accepted
     * iteration: if it got worse the previous coefficients are returned; if it
     * improved by less than 1e-4 the new coefficients are returned; otherwise the
     * new solution is accepted and another iteration is run.
     *
     * @return the fit tuple coefficients
     * @throws IllegalStateException if the residual is NaN (e.g. there are no
     *         samples), since no comparison against it can ever end the iteration
     */
    @Override
    double[] doFit() {
        ConjugateGradient cg = new ConjugateGradient(100000, 1e-6, false);//max_iter; delta; whether to check pos def
        //delta is target ratio of residual norm to true vals norm

        long startTime = System.currentTimeMillis();

        while (true) {
            long iterStartTime = System.currentTimeMillis();

            Atb = calcRHS();
            RealVector ans = cg.solve(AtA, Atb);
            double[] newFitVals = calcFitVals(ans);

            System.out.println(
                    "Conjugate gradient fitting time (ms): " + (System.currentTimeMillis() - iterStartTime));

            double resid = calcResidual(newFitVals);
            System.out.println("Step residual: " + resid);

            if (Double.isNaN(resid)) {
                //Every comparison with NaN is false, so without this check the
                //loop below would always take the "keep going" branch and never end
                throw new IllegalStateException(
                        "Step residual is NaN (no samples, or non-finite fit values or bounds);"
                                + " iterative fit cannot converge");
            }

            if (resid > curResid) {//gotten worse...use previous vals
                printTotalFittingTime(startTime);
                return curCoeffs.toArray();
            } else if (resid > curResid - 1e-4) {//basically converged
                printTotalFittingTime(startTime);
                return ans.toArray();
            } else {//keep going
                curCoeffs = ans;
                curFitVals = newFitVals;
                curResid = resid;
            }
        }
    }

    /** Prints the wall-clock time elapsed since {@code startTime} (in ms). */
    private void printTotalFittingTime(long startTime) {
        System.out.println("Iterative conjugate gradient fitting time (ms): "
                + (System.currentTimeMillis() - startTime));
    }
}