beast.structuredCoalescent.distribution.Masco.java Source code

Java tutorial

Introduction

Here is the source code for beast.structuredCoalescent.distribution.Masco.java

Source

/*
 * Copyright (C) 2016 Nicola Felix Mueller (nicola.felix.mueller@gmail.com)
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package beast.structuredCoalescent.distribution;

import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math3.ode.FirstOrderDifferentialEquations;
import org.apache.commons.math3.ode.FirstOrderIntegrator;
import org.apache.commons.math3.ode.nonstiff.ClassicalRungeKuttaIntegrator;
import org.jblas.DoubleMatrix;

import beast.core.Description;
import beast.core.Input;
import beast.core.Loggable;
import beast.core.parameter.BooleanParameter;
import beast.core.parameter.IntegerParameter;
import beast.core.parameter.RealParameter;
import beast.evolution.tree.Node;
import beast.evolution.tree.TraitSet;
import beast.evolution.tree.coalescent.IntervalType;
import beast.structuredCoalescent.math.ode_masco;

/**
 * @author Nicola Felix Mueller
 */

@Description("Calculate the probability of a tree under the structured coalescent assuming lineage independence with constant rates"
        + " as described in Mueller et. al.,2016. The Input rates are backwards in time migration rates"
        + " and pairwise coalescent rates translating to 1/Ne for m=1 in the Wright Fisher model")
public class Masco extends StructuredTreeDistribution implements Loggable {

    public Input<TraitSet> typeTraitInput = new Input<>("typeTrait", "Type trait set.");

    public Input<RealParameter> timeStepInput = new Input<RealParameter>("timeStep",
            "how strong the seasonal patterns are, higher means less seasonal");

    public Input<RealParameter> coalescentRatesInput = new Input<RealParameter>("coalescentRates",
            "how strong the seasonal patterns are, higher means less seasonal", Input.Validate.REQUIRED);

    public Input<RealParameter> migrationRatesInput = new Input<RealParameter>("migrationRates",
            "how strong the seasonal patterns are, higher means less seasonal", Input.Validate.REQUIRED);

    public Input<IntegerParameter> dim = new Input<IntegerParameter>("dim",
            "how strong the seasonal patterns are, higher means less seasonal", Input.Validate.REQUIRED);

    public Input<BooleanParameter> propTimeStepInput = new Input<BooleanParameter>("proportionalTimeStep",
            "how strong the seasonal patterns are, higher means less seasonal");

    public int samples;
    public int nrSamples;
    public DoubleMatrix[] stateProbabilities;
    public DoubleMatrix[] stateFlow;

    private double max_posterior = Double.NEGATIVE_INFINITY;
    private double[] max_mig;
    private double[] max_coal;

    public int states;

    private boolean dependentHistory = true;
    private boolean traitInput = false;
    private boolean normalize = true;
    private boolean propTimeStep = false;
    private int nr_lineages;
    private double independent_likelihood = 0.0;

    // Standard time Step of the RungeKutta if not specified different
    private double timeStep = 0.000001;

    // Set up for lineage state probabilities
    ArrayList<Integer> activeLineages;
    ArrayList<Double> lineStateProbs;

    @Override
    public void initAndValidate() {
        // Check if the time steps should be done in such a way that they are proportional to the number
        // of lineages present in a time interval
        if (propTimeStepInput.get() != null)
            propTimeStep = propTimeStepInput.get().getValue();

        // Calculate the tree intervals (time between events, which nodes participate at a event etc.)
        treeIntervalsInput.get().calculateIntervals();

        stateProbabilities = new DoubleMatrix[treeIntervalsInput.get().getSampleCount()];
        stateFlow = new DoubleMatrix[2 * treeIntervalsInput.get().getSampleCount() + 1];
        nrSamples = treeIntervalsInput.get().getSampleCount() + 1;

        double states_tmp = dim.get().getValue();
        states = (int) states_tmp;

        if (typeTraitInput.get() != null)
            traitInput = true;
        // Check all the boolean input parameters
        if (timeStepInput.get() != null)
            timeStep = timeStepInput.get().getValue();

        if (!dependentHistory)
            normalize = false;

        calculateLogP();
    }

    public double calculateLogP() {
        // newly calculate tree intervals
        treeIntervalsInput.get().calculateIntervals();
        // correctly calculate the daughter nodes at coalescent intervals in the case of
        // bifurcation or in case two nodes are at the same height
        treeIntervalsInput.get().swap();

        // Set up ArrayLists for the indices of active lineages and the lineage state probabilities
        activeLineages = new ArrayList<Integer>();
        lineStateProbs = new ArrayList<Double>();

        // Compute likelihood at each integration time and tree event starting at final sampling time and moving backwards
        logP = 0;

        // set the current time
        double currTime = 0.0;
        // total number of intervals
        final int intervalCount = treeIntervalsInput.get().getIntervalCount();
        // interval time counter
        int t = 0;
        // initialize the number of lineages
        nr_lineages = 0;
        // Captures the probabilities of lineages being in a state
        double[] p;

        // Initialize the migration rates matrix
        double[][] migration_rates = new double[states][states];
        int c = 0;

        for (int k = 0; k < states; k++) {
            for (int l = 0; l < states; l++) {
                if (k != l) {
                    migration_rates[k][l] = migrationRatesInput.get().getArrayValue(c);
                    c++;
                } else { // diagonal
                    migration_rates[k][l] = 0.0;
                }

            }
        }

        // Initialize the coalescent rates
        double[] coalescent_rates = new double[states];
        for (int k = 0; k < states; k++) {
            coalescent_rates[k] = coalescentRatesInput.get().getArrayValue(k) / 2;//(epiModelInput.get().getF(t,k,k) / (Y.get(k)*Y.get(k)));
        }

        // integrate until there are no more tree intervals
        do {
            double nextIntervalTime = treeIntervalsInput.get().getInterval(t);

            // Length of the current interval
            final double duration = nextIntervalTime;// - currTime;
            // if the current interval has a length greater than 0, integrate
            if (duration > 0) {
                if (dependentHistory)
                    p = new double[lineStateProbs.size()]; // Captures the probabilities of lineages being in a state
                else
                    p = new double[lineStateProbs.size() + 1]; // Captures the probabilities of lineages being in a state, last one keeps track of the probability

                // convert the array list to double[]
                for (int i = 0; i < lineStateProbs.size(); i++)
                    p[i] = lineStateProbs.get(i);

                // not needed
                if (!dependentHistory)
                    p[lineStateProbs.size()] = 1;

                double[] p_for_ode = new double[p.length];
                double ts = 0.0;

                // If proportial time step is true, set the integration time for the given interval 
                // inverse proportional to the number of lineages
                if (propTimeStep)
                    ts = timeStep / lineStateProbs.size();
                else
                    ts = timeStep;

                // Never choose a longer time step than the integration window
                if (duration < (ts / 2))
                    ts = duration / 2;

                FirstOrderIntegrator integrator = new ClassicalRungeKuttaIntegrator(ts);
                // set the odes
                FirstOrderDifferentialEquations ode = new ode_masco(migration_rates, coalescent_rates, nr_lineages,
                        states);
                // integrate                   
                integrator.integrate(ode, 0, p, duration, p_for_ode);

                // If the Dimension is larger than the maximum integer, at least one state prob is below 0 and the step is rejected
                if (ode.getDimension() == Integer.MAX_VALUE) {
                    System.out.println(lineStateProbs.size());
                    System.out.println("lalalallal");
                    return Double.NEGATIVE_INFINITY;
                }

                for (int i = 0; i < lineStateProbs.size(); i++)
                    lineStateProbs.set(i, p_for_ode[i]);
            }

            // update the time
            currTime = nextIntervalTime;
            // event is coalescent event
            if (treeIntervalsInput.get().getIntervalType(t) == IntervalType.COALESCENT) {
                logP += coalesce(t);
                nr_lineages--;
            }

            // event is sampling event
            if (treeIntervalsInput.get().getIntervalType(t) == IntervalType.SAMPLE) {
                logP += normalizeLineages();
                addLineages(t);
                nr_lineages++;
            }

            // update the interval number
            t++;
        } while (t < intervalCount);

        //Compute likelihood of remaining tree intervals (coal events occuring before origin)
        if (Double.isInfinite(logP))
            logP = Double.NEGATIVE_INFINITY;
        if (max_posterior < logP && logP < 0) {
            max_posterior = logP;
            max_mig = new double[states * (states - 1)];
            max_coal = new double[states];
            for (int i = 0; i < 1; i++)
                max_mig[i] = migrationRatesInput.get().getArrayValue(i);
            for (int i = 0; i < 1; i++)
                max_coal[i] = coalescentRatesInput.get().getArrayValue(i);
        }
        //        System.exit(0);

        return logP;

    }

    private void addLineages(int currTreeInterval) {
        List<Node> incomingLines = treeIntervalsInput.get().getLineagesAdded(currTreeInterval);
        if (traitInput) {
            /*
             * If there is a typeTrait given as Input the model will take this
             * trait as states for the taxons
             */
            for (Node l : incomingLines) {
                activeLineages.add(l.getNr());
                int sampleState = (int) typeTraitInput.get().getValue(l.getID());
                for (int i = 0; i < states; i++)
                    if (i == sampleState)
                        lineStateProbs.add(1.0);
                    else
                        lineStateProbs.add(0.0);
            }
        } else {
            /*
             * If there is no trait given as Input, the model will simply assume that
             * the last value of the taxon name, the last value after a _, is an integer
             * that gives the type of that taxon
             */
            for (Node l : incomingLines) {
                activeLineages.add(l.getNr());
                String sampleID = l.getID();
                String[] splits = sampleID.split("_");
                int sampleState = Integer.parseInt(splits[splits.length - 1]); //samples states (or priors) should eventually be specified in the XML
                for (int i = 0; i < states; i++)
                    if (i == sampleState)
                        lineStateProbs.add(1.0);
                    else
                        lineStateProbs.add(0.0);
            }
        }
    }

    private double normalizeLineages() {
        double interval = 0.0;
        //       System.out.println("");
        for (int i = 0; i < nr_lineages; i++) {
            double lineProbs = 0.0;
            for (int j = 0; j < states; j++)
                if (lineStateProbs.get(i * states + j) >= 0.0) {
                    lineProbs += lineStateProbs.get(i * states + j);
                } else {
                    System.out.println(lineStateProbs);
                    System.err.println("state probability smaller than 0 before normalizing");
                    System.exit(1);
                }
            for (int j = 0; j < states; j++)
                lineStateProbs.set(i * states + j, lineStateProbs.get(states * i + j) / lineProbs);

            interval += lineProbs;
        }
        if (nr_lineages > 0)
            return Math.log(interval / nr_lineages);
        else
            return 0;

    }

    private double coalesce(int currTreeInterval) {
        double interval = 0.0;

        // normalize all lineage states in order to reduce the chance of too small numbers
        // after multiplication of lineage states
        //       if (normalize){       
        //          for (int i = 0; i < nr_lineages; i++){
        //             double lineProbs = 0.0;
        //             for (int j = 0; j < states; j++)
        //                if (lineStateProbs.get(i*states+j)>=0.0){
        //                   lineProbs += lineStateProbs.get(i*states+j);
        //                }else{
        //                   System.out.println(lineStateProbs);
        //                   System.err.println("state probability smaller than 0 before normalizing");
        //                   System.exit(1);
        //                }
        //             for (int j = 0; j < states; j++)
        //                lineStateProbs.set(i*states+j, lineStateProbs.get(states*i+j)/lineProbs);   
        //             
        //             interval += Math.log(lineProbs);
        //          }
        //       }

        interval += normalizeLineages();

        // Get the daughter lineages
        List<Node> coalLines = treeIntervalsInput.get().getLineagesRemoved(currTreeInterval);
        if (coalLines.size() > 2) {
            System.err.println("Unsupported coalescent at non-binary node");
            System.exit(0);
        }
        if (coalLines.size() < 2) {
            System.out.println();
            System.out.println("WARNING: Less than two lineages found at coalescent event!");
            System.out.println();
            return Double.NaN;
        }

        //  get the indices of the daughter lineages
        final int daughterIndex1 = activeLineages.indexOf(coalLines.get(0).getNr());
        final int daughterIndex2 = activeLineages.indexOf(coalLines.get(1).getNr());
        if (daughterIndex1 == -1 || daughterIndex2 == -1) {
            System.out.println("daughter lineages at coalescent event not found");
            return Double.NaN;
        }
        DoubleMatrix lambda = DoubleMatrix.zeros(states);

        /*
         * Calculate the overall probability for two strains to coalesce 
         * independent of the state at which this coalescent event is 
         * supposed to happen
         */
        for (int k = 0; k < states; k++) {
            double coal = 0.0;
            coal = coalescentRatesInput.get().getArrayValue(k) / 2;
            // calculate the coalscent rate. The 2 originates from lineage i coalescing with lineage j
            // AND vice versa.
            final double pairCoalRate = coal * 2 * (lineStateProbs.get(daughterIndex1 * states + k)
                    * lineStateProbs.get(daughterIndex2 * states + k));

            if (!Double.isNaN(pairCoalRate))
                lambda.put(k, pairCoalRate);
            else
                System.err.println("pairwise coalescent rate is nan");
        }

        activeLineages.add(coalLines.get(0).getParent().getNr());

        DoubleMatrix pVec = new DoubleMatrix();
        pVec.copy(lambda);
        if (normalize)
            pVec = pVec.div(pVec.sum());

        for (int i = 0; i < states; i++)
            lineStateProbs.add(pVec.get(i));

        stateProbabilities[coalLines.get(0).getParent().getNr() - nrSamples] = (pVec);

        //Remove daughter lineages
        if (daughterIndex1 > daughterIndex2) {
            activeLineages.remove(daughterIndex1);
            activeLineages.remove(daughterIndex2);
            for (int i = (states - 1); i >= 0; i--)
                lineStateProbs.remove(daughterIndex1 * states + i);
            for (int i = (states - 1); i >= 0; i--)
                lineStateProbs.remove(daughterIndex2 * states + i);
        } else {
            activeLineages.remove(daughterIndex2);
            activeLineages.remove(daughterIndex1);
            for (int i = (states - 1); i >= 0; i--)
                lineStateProbs.remove(daughterIndex2 * states + i);
            for (int i = (states - 1); i >= 0; i--)
                lineStateProbs.remove(daughterIndex1 * states + i);
        }

        if (lambda.min() < 0.0)
            System.err.println("Coalescent probability smaller than 0");

        if (!dependentHistory)
            interval = Math.log(independent_likelihood);

        return Math.log(lambda.sum()) + interval;
    }

    public DoubleMatrix getStateProb(int nr) {
        return stateProbabilities[nr - nrSamples];
    }

    public DoubleMatrix[] getStateProbabilities() {
        return stateProbabilities;
    }

    public String getType() {
        if (typeTraitInput.get() != null) {
            return typeTraitInput.get().getTraitName();
        } else {
            return "state";
        }
    }

    @Override
    protected boolean requiresRecalculation() {
        return true;
    }

    /*
    * Loggable interface
    */
    @Override
    public void init(PrintStream out) {
        for (int i = 0; i < states; i++)
            for (int j = 0; j < states; j++)
                if (i != j)
                    out.print("backward_rate_" + i + "to" + j + "\t");
        for (int i = 0; i < states; i++)
            for (int j = 0; j < states; j++)
                if (i != j)
                    out.print("forward_rate_" + j + "to" + i + "\t");
        for (int i = 0; i < states; i++)
            out.print("coalescent_rate_deme_" + i + "\t");
        out.print("max_posterior\t");
        for (int i = 0; i < states * (states - 1); i++)
            out.print("max_mig_rate" + i + "\t");
        for (int i = 0; i < states; i++)
            out.print("max_coal_rate" + i + "\t");

    }

    @Override
    public void log(int nSample, PrintStream out) {
        int c = 0;
        for (int i = 0; i < states; i++)
            for (int j = 0; j < states; j++)
                if (i != j) {
                    out.print(migrationRatesInput.get().getArrayValue(c) + "\t");
                    c++;
                }

        c = 0;
        for (int i = 0; i < states; i++)
            for (int j = 0; j < states; j++)
                if (i != j) {
                    out.print(
                            migrationRatesInput.get().getArrayValue(c) / coalescentRatesInput.get().getArrayValue(j)
                                    * coalescentRatesInput.get().getArrayValue(i) + "\t");
                    c++;
                }
        for (int i = 0; i < states; i++)
            out.print(coalescentRatesInput.get().getArrayValue(i) + "\t");

        out.print(max_posterior + "\t");
        for (int i = 0; i < states * (states - 1); i++)
            out.print(max_mig[i] + "\t");
        for (int i = 0; i < states; i++)
            out.print(max_coal[i] + "\t");

    }

    @Override
    public void close(PrintStream out) {
    }

}