org.broadinstitute.gatk.engine.recalibration.RecalDatumNode.java Source code

Java tutorial

Introduction

Here is the source code for org.broadinstitute.gatk.engine.recalibration.RecalDatumNode.java

Source

/*
* Copyright 2012-2016 Broad Institute, Inc.
* 
* Permission is hereby granted, free of charge, to any person
* obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without
* restriction, including without limitation the rights to use,
* copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following
* conditions:
* 
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
* 
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
* THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

package org.broadinstitute.gatk.engine.recalibration;

import com.google.java.contract.Ensures;
import com.google.java.contract.Requires;
import org.apache.commons.math.MathException;
import org.apache.commons.math.stat.inference.ChiSquareTestImpl;
import org.apache.log4j.Logger;
import org.broadinstitute.gatk.utils.collections.Pair;
import org.broadinstitute.gatk.utils.exceptions.ReviewedGATKException;

import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Set;

/**
 * A tree of recal datum, where each contains a set of sub datum representing sub-states of the higher level one
 *
 * @author Mark DePristo
 * @since 07/27/12
 */
public class RecalDatumNode<T extends RecalDatum> {
    /**
     * Floor applied to chi^2 p-values before phred-scaling, so that the resulting penalty
     * stays finite (with this floor the largest possible calculated penalty is 3000)
     */
    private final static double SMALLEST_CHI2_PVALUE = 1e-300;
    protected static final Logger logger = Logger.getLogger(RecalDatumNode.class);

    /**
     * Sentinel held by fixedPenalty while no penalty has been fixed for the node; any other
     * value of fixedPenalty is treated as a fixed penalty by getPenalty()
     */
    private final static double UNINITIALIZED = Double.NEGATIVE_INFINITY;

    private final T recalDatum;
    private double fixedPenalty = UNINITIALIZED;
    private final Set<RecalDatumNode<T>> subnodes;

    /**
     * Create a leaf node (no subnodes) without a fixed penalty
     *
     * @param recalDatum the data held by this node
     */
    @Requires({ "recalDatum != null" })
    public RecalDatumNode(final T recalDatum) {
        this(recalDatum, new HashSet<RecalDatumNode<T>>());
    }

    /**
     * @return the string representation of the recal datum held by this node
     */
    @Override
    public String toString() {
        return recalDatum.toString();
    }

    /**
     * Create a node with the given subnodes and without a fixed penalty
     *
     * @param recalDatum the data held by this node
     * @param subnodes   the immediate subnodes; the set is copied, the nodes in it are not
     */
    @Requires({ "recalDatum != null", "subnodes != null" })
    public RecalDatumNode(final T recalDatum, final Set<RecalDatumNode<T>> subnodes) {
        this(recalDatum, UNINITIALIZED, subnodes);
    }

    /**
     * Create a leaf node (no subnodes) carrying the given fixed penalty
     *
     * @param recalDatum   the data held by this node
     * @param fixedPenalty the penalty to report for this node, or UNINITIALIZED for none
     */
    @Requires({ "recalDatum != null" })
    protected RecalDatumNode(final T recalDatum, final double fixedPenalty) {
        this(recalDatum, fixedPenalty, new HashSet<RecalDatumNode<T>>());
    }

    /**
     * Create a node with the given fixed penalty and subnodes
     *
     * @param recalDatum   the data held by this node
     * @param fixedPenalty the penalty to report for this node, or UNINITIALIZED for none
     * @param subnodes     the immediate subnodes; the set is copied, the nodes in it are not
     */
    @Requires({ "recalDatum != null", "subnodes != null" })
    protected RecalDatumNode(final T recalDatum, final double fixedPenalty, final Set<RecalDatumNode<T>> subnodes) {
        this.recalDatum = recalDatum;
        this.fixedPenalty = fixedPenalty;
        // defensive copy so later changes to the caller's set do not alter this node
        this.subnodes = new HashSet<RecalDatumNode<T>>(subnodes);
    }

    /**
     * Get the recal data associated with this node
     *
     * @return the recal datum held by this node
     */
    @Ensures("result != null")
    public T getRecalDatum() {
        return recalDatum;
    }

    /**
     * The set of immediate subnodes of this node.  May be modified.
     *
     * @return the live internal set of subnodes (not a copy)
     */
    @Ensures("result != null")
    public Set<RecalDatumNode<T>> getSubnodes() {
        return subnodes;
    }

    /**
     * Return the fixed penalty, if set, or else the calculated penalty for this node
     *
     * @return the fixed penalty when one has been set, otherwise a fresh calculation
     */
    public double getPenalty() {
        if (fixedPenalty != UNINITIALIZED) {
            return fixedPenalty;
        }
        return calcPenalty();
    }

    /**
     * Set the fixed penalty for this node to a fresh calculation from calcPenalty
     *
     * This is important in the case where you want to compute the penalty from a full
     * tree and then chop the tree up afterwards while considering the previous penalties.
     * If you don't call this function then manipulating the tree may result in the
     * penalty functions changing with changes in the tree.
     *
     * @param doEntireTree recurse into all subnodes?
     * @return the fixed penalty for this node
     */
    public double calcAndSetFixedPenalty(final boolean doEntireTree) {
        fixedPenalty = calcPenalty();
        if (doEntireTree) {
            for (final RecalDatumNode<T> sub : subnodes) {
                sub.calcAndSetFixedPenalty(doEntireTree);
            }
        }
        return fixedPenalty;
    }

    /**
     * Add node to the set of subnodes of this node
     *
     * @param sub the node to add as an immediate subnode
     */
    @Requires("sub != null")
    public void addSubnode(final RecalDatumNode<T> sub) {
        subnodes.add(sub);
    }

    /**
     * Is this a leaf node (i.e., has no subnodes)?
     *
     * @return true if this node has no subnodes
     */
    public boolean isLeaf() {
        return subnodes.isEmpty();
    }

    /**
     * Is this node immediately above only leaf nodes?
     *
     * Note that a leaf has no subnodes at all, so this also returns true for a leaf.
     *
     * @return true if every immediate subnode is a leaf
     */
    public boolean isAboveOnlyLeaves() {
        for (final RecalDatumNode<T> sub : subnodes) {
            if (!sub.isLeaf()) {
                return false;
            }
        }
        return true;
    }

    /**
     * What's the immediate number of subnodes from this node?
     *
     * @return the number of immediate subnodes
     */
    @Ensures("result >= 0")
    public int getNumSubnodes() {
        return subnodes.size();
    }

    /**
     * Total penalty is the sum of leaf node penalties
     *
     * This algorithm assumes that penalties have been fixed before pruning, as leaf nodes by
     * definition have 0 penalty unless they represent a pruned tree with underlying -- but now
     * pruned -- subtrees
     *
     * @return the sum of getPenalty() over all leaves reachable from this node
     */
    public double totalPenalty() {
        if (isLeaf()) {
            return getPenalty();
        }

        double sum = 0.0;
        for (final RecalDatumNode<T> sub : subnodes) {
            sum += sub.totalPenalty();
        }
        return sum;
    }

    /**
     * The maximum penalty among all nodes reachable from this node (including this one)
     *
     * @param leafOnly if true, only the penalties of leaf nodes are considered
     * @return the largest penalty among the considered nodes
     */
    public double maxPenalty(final boolean leafOnly) {
        // Seed with -infinity when this node is excluded.  Double.MIN_VALUE would be wrong here:
        // it is the smallest *positive* double, so it would win over leaves whose penalty is 0.0.
        // A non-leaf always has at least one subnode, so the seed never escapes as the result.
        double max = !leafOnly || isLeaf() ? getPenalty() : Double.NEGATIVE_INFINITY;
        for (final RecalDatumNode<T> sub : subnodes) {
            max = Math.max(max, sub.maxPenalty(leafOnly));
        }
        return max;
    }

    /**
     * The minimum penalty among all nodes reachable from this node (including this one)
     *
     * @param leafOnly if true, only the penalties of leaf nodes are considered
     * @return the smallest penalty among the considered nodes
     */
    public double minPenalty(final boolean leafOnly) {
        // Seed with +infinity when this node is excluded; a non-leaf always has at least one
        // subnode, so the seed never escapes as the result.
        double min = !leafOnly || isLeaf() ? getPenalty() : Double.POSITIVE_INFINITY;
        for (final RecalDatumNode<T> sub : subnodes) {
            min = Math.min(min, sub.minPenalty(leafOnly));
        }
        return min;
    }

    /**
     * What's the longest branch from this node to any leaf?  Includes this node
     *
     * @return the number of nodes on the longest path from this node down to a leaf
     */
    public int maxDepth() {
        int subMax = 0;
        for (final RecalDatumNode<T> sub : subnodes) {
            subMax = Math.max(subMax, sub.maxDepth());
        }
        return subMax + 1;
    }

    /**
     * What's the shortest branch from this node to any leaf?  Includes this node
     *
     * @return the number of nodes on the shortest path from this node down to a leaf
     */
    @Ensures("result > 0")
    public int minDepth() {
        if (isLeaf()) {
            return 1;
        }

        // not a leaf, so there is at least one subnode and MAX_VALUE is always replaced
        int subMin = Integer.MAX_VALUE;
        for (final RecalDatumNode<T> sub : subnodes) {
            subMin = Math.min(subMin, sub.minDepth());
        }
        return subMin + 1;
    }

    /**
     * Return the number of nodes, including this one, reachable from this node
     *
     * @return the total node count of the tree rooted at this node
     */
    @Ensures("result > 0")
    public int size() {
        int size = 1;
        for (final RecalDatumNode<T> sub : subnodes) {
            size += sub.size();
        }
        return size;
    }

    /**
     * Count the number of leaf nodes reachable from this node
     *
     * @return the number of leaves; 1 if this node is itself a leaf
     */
    @Ensures("result >= 0")
    public int numLeaves() {
        if (isLeaf()) {
            return 1;
        }

        int count = 0;
        for (final RecalDatumNode<T> sub : subnodes) {
            count += sub.numLeaves();
        }
        return count;
    }

    /**
     * Calculate the phred-scaled p-value for a chi^2 test for independence among subnodes of this node.
     *
     * The chi^2 value indicates the degree of independence of the implied error rates among the
     * immediate subnodes
     *
     * @return the phred-scaled p-value for chi2 penalty, or 0.0 if this node is a leaf, is free
     *         to merge, or has only a single subnode
     * @throws ReviewedGATKException if the chi^2 calculation fails or yields a non-finite penalty
     */
    private double calcPenalty() {
        if (isLeaf() || freeToMerge()) {
            return 0.0;
        }
        if (subnodes.size() == 1) {
            // only one value, so its free to merge away
            return 0.0;
        }

        // one row per subnode: { mismatches + 1, observations + 2 }
        final long[][] counts = new long[subnodes.size()][2];

        int i = 0;
        for (final RecalDatumNode<T> subnode : subnodes) {
            // use the yates correction to help avoid all zeros => NaN
            counts[i][0] = Math.round(subnode.getRecalDatum().getNumMismatches()) + 1L;
            counts[i][1] = subnode.getRecalDatum().getNumObservations() + 2L;
            i++;
        }

        try {
            final double chi2PValue = new ChiSquareTestImpl().chiSquareTest(counts);
            // clamp the p-value from below so that log10 cannot produce an infinite penalty
            final double penalty = -10.0 * Math.log10(Math.max(chi2PValue, SMALLEST_CHI2_PVALUE));

            // make sure things are reasonable and fail early if not
            if (Double.isInfinite(penalty) || Double.isNaN(penalty)) {
                throw new ReviewedGATKException("chi2 value is " + chi2PValue + " at " + getRecalDatum());
            }

            return penalty;
        } catch (MathException e) {
            throw new ReviewedGATKException("Failed in calculating chi2 value", e);
        }
    }

    /**
     * Is this node free to merge because its rounded Q score is the same as that of all of its
     * immediate subnodes?
     *
     * @return true for a leaf, or when every immediate subnode has the same empirical quality
     *         (as a byte) as this node
     */
    private boolean freeToMerge() {
        if (isLeaf()) {
            // leaves are free to merge
            return true;
        }

        final byte myQual = getRecalDatum().getEmpiricalQualityAsByte();
        for (final RecalDatumNode<T> sub : subnodes) {
            if (sub.getRecalDatum().getEmpiricalQualityAsByte() != myQual) {
                return false;
            }
        }
        return true;
    }

    /**
     * Calculate the penalty of this interval, given the overall error rate for the interval
     *
     * If the globalErrorRate is e, this value is:
     *
     * sum_i |log10(e_i) - log10(e)| * nObservations_i
     *
     * where the index i applies to all leaves of the tree accessible from this interval
     * (found recursively from subnodes as necessary)
     *
     * @param globalErrorRate overall error rate in real space against which we calculate the penalty
     * @return the cost of approximating the bins in this interval with the globalErrorRate
     */
    @Requires("globalErrorRate >= 0.0")
    @Ensures("result >= 0.0")
    private double calcPenaltyLog10(final double globalErrorRate) {
        if (globalErrorRate == 0.0) {
            // there were no observations, so there's no penalty
            return 0.0;
        }

        if (isLeaf()) {
            // this is a leaf node
            // TODO -- how can we generalize this calculation?  One idea: make it free (penalty 0)
            // TODO -- to merge up quality scores below the smallest interesting quality
            return (Math.abs(Math.log10(recalDatum.getEmpiricalErrorRate()) - Math.log10(globalErrorRate)))
                    * (double) recalDatum.getNumObservations();
        }

        double sum = 0;
        for (final RecalDatumNode<T> hrd : subnodes) {
            sum += hrd.calcPenaltyLog10(globalErrorRate);
        }
        return sum;
    }

    /**
     * Return a freshly allocated tree pruned to have no more than maxDepth nodes from the root to any leaf
     *
     * @param maxDepth the maximum number of nodes, including the root, allowed on any branch; must be >= 1
     * @return a new tree sharing the recal data and fixed penalties of the retained nodes
     * @throws IllegalArgumentException if maxDepth < 1
     */
    public RecalDatumNode<T> pruneToDepth(final int maxDepth) {
        if (maxDepth < 1) {
            throw new IllegalArgumentException("maxDepth < 1");
        }

        final Set<RecalDatumNode<T>> subPruned = new HashSet<RecalDatumNode<T>>(getNumSubnodes());
        if (maxDepth > 1) {
            for (final RecalDatumNode<T> sub : subnodes) {
                subPruned.add(sub.pruneToDepth(maxDepth - 1));
            }
        }
        // with maxDepth == 1 the copy keeps no subnodes and so becomes a leaf
        return new RecalDatumNode<T>(getRecalDatum(), fixedPenalty, subPruned);
    }

    /**
     * Return a tree with no more than maxElements nodes, merging nodes in order of penalty
     *
     * Note that nodes must have fixed penalties or this algorithm will fail.
     *
     * @param maxElements the maximum number of nodes allowed in the result; must be >= 1
     * @return the pruned tree; this node itself if the tree is already small enough
     * @throws IllegalArgumentException if maxElements < 1
     */
    public RecalDatumNode<T> pruneByPenalty(final int maxElements) {
        // a tree always contains at least its root, so a target below 1 can never be reached
        if (maxElements < 1) {
            throw new IllegalArgumentException("maxElements < 1");
        }

        RecalDatumNode<T> root = this;

        while (root.size() > maxElements) {
            // remove the lowest penalty element, and continue
            root = root.removeLowestPenaltyNode();
        }

        // our size is at or below the target, so we are good, return
        return root;
    }

    /**
     * Return a tree where all mergeable nodes with penalty no greater than the maximum are merged
     *
     * Note that nodes must have fixed penalties or this algorithm will fail.
     *
     * @param maxPenaltyIn the maximum penalty we are allowed to incur for a merge
     * @param applyBonferroniCorrection if true, we will adjust penalty by the phred-scaled bonferroni correction
     *                                  for the size of the initial tree.  That is, if there are 10 nodes in the
     *                                  tree and maxPenalty is 20 we will actually enforce 10^-2 / 10 = 10^-3 = 30
     *                                  penalty for multiple testing
     * @return the pruned tree; this node itself if nothing could be merged
     */
    public RecalDatumNode<T> pruneToNoMoreThanPenalty(final double maxPenaltyIn,
            final boolean applyBonferroniCorrection) {
        RecalDatumNode<T> root = this;

        // size() walks the entire tree, so compute it once
        final int initialSize = this.size();
        final double bonferroniCorrection = 10 * Math.log10(initialSize);
        final double maxPenalty = applyBonferroniCorrection ? maxPenaltyIn + bonferroniCorrection : maxPenaltyIn;

        if (applyBonferroniCorrection) {
            logger.info(String.format(
                    "Applying Bonferroni correction for %d nodes = %.2f to initial penalty %.2f for total "
                            + "corrected max penalty of %.2f",
                    initialSize, bonferroniCorrection, maxPenaltyIn, maxPenalty));
        }

        while (true) {
            final Pair<RecalDatumNode<T>, Double> minPenaltyNode = root.getMinPenaltyAboveLeafNode();

            if (minPenaltyNode == null || minPenaltyNode.getSecond() > maxPenalty) {
                // nothing to merge, or the best candidate is above our max allowed
                if (minPenaltyNode == null) {
                    if (logger.isDebugEnabled()) {
                        logger.debug("Stopping because no candidates could be found");
                    }
                } else {
                    if (logger.isDebugEnabled()) {
                        logger.debug("Stopping because node " + minPenaltyNode.getFirst() + " has penalty "
                                + minPenaltyNode.getSecond() + " > max " + maxPenalty);
                    }
                }
                break;
            } else {
                // remove the lowest penalty element, and continue
                if (logger.isDebugEnabled()) {
                    logger.debug("Removing node " + minPenaltyNode.getFirst() + " with penalty "
                            + minPenaltyNode.getSecond());
                }
                root = root.removeLowestPenaltyNode();
            }
        }

        // no more candidates exist with penalty <= maxPenalty
        return root;
    }

    /**
     * Find the lowest penalty above-leaf node in the tree, and return a tree in which that node
     * has been collapsed into a leaf
     *
     * The selected node may be this (root) node itself, when all of its subnodes are leaves.
     *
     * @return a new tree with the selected node stripped of its subnodes
     * @throws IllegalStateException if there is no candidate node (this node is a leaf) or the
     *                               selected node could not be found during removal
     */
    private RecalDatumNode<T> removeLowestPenaltyNode() {
        final Pair<RecalDatumNode<T>, Double> nodeToRemove = getMinPenaltyAboveLeafNode();
        if (nodeToRemove == null) {
            // only happens when called on a leaf; fail clearly rather than with a NullPointerException
            throw new IllegalStateException("No node above only leaves is available to remove from " + this);
        }
        if (logger.isDebugEnabled()) {
            logger.debug("Removing " + nodeToRemove.getFirst() + " with penalty " + nodeToRemove.getSecond());
        }

        final Pair<RecalDatumNode<T>, Boolean> result = removeNode(nodeToRemove.getFirst());

        if (!result.getSecond()) {
            throw new IllegalStateException("Never removed any node!");
        }

        final RecalDatumNode<T> oneRemoved = result.getFirst();
        if (oneRemoved == null) {
            throw new IllegalStateException("Removed our root node, wow, didn't expect that");
        }
        return oneRemoved;
    }

    /**
     * Finds in the tree the node with the lowest penalty whose subnodes are all leaves
     *
     * @return the node and its penalty, or null if no such node exists
     */
    private Pair<RecalDatumNode<T>, Double> getMinPenaltyAboveLeafNode() {
        if (isLeaf()) {
            // not allowed to remove leafs directly
            return null;
        }
        if (isAboveOnlyLeaves()) {
            // we only consider removing nodes above all leaves
            return new Pair<RecalDatumNode<T>, Double>(this, getPenalty());
        }

        // just recurse, taking the result with the min penalty of all subnodes
        Pair<RecalDatumNode<T>, Double> minNode = null;
        for (final RecalDatumNode<T> sub : subnodes) {
            final Pair<RecalDatumNode<T>, Double> subFind = sub.getMinPenaltyAboveLeafNode();
            if (subFind != null && (minNode == null || subFind.getSecond() < minNode.getSecond())) {
                minNode = subFind;
            }
        }
        return minNode;
    }

    /**
     * Return a tree in which nodeToRemove (matched by identity) has been replaced by a leaf that
     * keeps its recal datum and fixed penalty but none of its subnodes
     *
     * Nodes visited before the target is found are copied; sibling subtrees visited after it
     * has been found are reused by reference.
     *
     * @param nodeToRemove the node to collapse into a leaf
     * @return the rebuilt (sub)tree paired with a flag saying whether nodeToRemove was found in it
     * @throws IllegalStateException if nodeToRemove is found but is a leaf
     */
    private Pair<RecalDatumNode<T>, Boolean> removeNode(final RecalDatumNode<T> nodeToRemove) {
        if (this == nodeToRemove) {
            if (isLeaf()) {
                throw new IllegalStateException(
                        "Trying to remove a leaf node from the tree! " + this + " " + nodeToRemove);
            }
            // node is the thing we are going to remove, but without any subnodes
            final RecalDatumNode<T> node = new RecalDatumNode<T>(getRecalDatum(), fixedPenalty);
            return new Pair<RecalDatumNode<T>, Boolean>(node, true);
        }

        // did we remove something in a sub branch?
        boolean removedSomething = false;

        // our sub nodes with the penalty node removed
        final Set<RecalDatumNode<T>> sub = new HashSet<RecalDatumNode<T>>(getNumSubnodes());

        for (final RecalDatumNode<T> sub1 : subnodes) {
            if (removedSomething) {
                // already removed something, just add sub1 back to sub
                sub.add(sub1);
            } else {
                // haven't removed anything yet, so try
                final Pair<RecalDatumNode<T>, Boolean> maybeRemoved = sub1.removeNode(nodeToRemove);
                removedSomething = maybeRemoved.getSecond();
                sub.add(maybeRemoved.getFirst());
            }
        }

        final RecalDatumNode<T> node = new RecalDatumNode<T>(getRecalDatum(), fixedPenalty, sub);
        return new Pair<RecalDatumNode<T>, Boolean>(node, removedSomething);
    }

    /**
     * Return a collection of all of the data in the leaf nodes of this tree
     *
     * @return a newly allocated collection holding the recal datum of every leaf
     */
    public Collection<T> getAllLeaves() {
        final LinkedList<T> list = new LinkedList<T>();
        getAllLeavesRec(list);
        return list;
    }

    /**
     * Helpful recursive function for getAllLeaves()
     *
     * @param list the destination for the list of leaves
     */
    private void getAllLeavesRec(final LinkedList<T> list) {
        if (isLeaf()) {
            list.add(getRecalDatum());
        } else {
            for (final RecalDatumNode<T> sub : subnodes) {
                sub.getAllLeavesRec(list);
            }
        }
    }
}