beast.evolution.tree.ConstrainedClusterTree.java Source code

Java tutorial

Introduction

Here is the source code for beast.evolution.tree.ConstrainedClusterTree.java

Source

/*
* File ClusterTree.java
*
* Copyright (C) 2010 Remco Bouckaert remco@cs.auckland.ac.nz
*
* This file is part of BEAST2.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
*  BEAST is distributed in the hope that it will be useful,
*  but WITHOUT ANY WARRANTY; without even the implied warranty of
*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
*  GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA  02110-1301  USA
*/
package beast.evolution.tree;

import beast.core.Description;
import beast.core.Input;
import beast.core.StateNode;
import beast.core.StateNodeInitialiser;
import beast.core.parameter.RealParameter;
import beast.core.util.Log;
import beast.evolution.alignment.Alignment;
import beast.evolution.alignment.distance.Distance;
import beast.evolution.alignment.distance.JukesCantorDistance;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import beast.math.distributions.MRCAPrior;
import beast.math.distributions.MultiMonophyleticConstraint;
import beast.math.distributions.ParametricDistribution;

import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.*;

import org.apache.commons.math.MathException;

/**
 * Adapted from Weka's HierarchicalClustering class *
 */
@Description("Create initial beast.tree by hierarchical clustering, either through one of the classic link methods "
        + "or by neighbor joining. The following link methods are supported: " + "<br/>o single link, "
        + "<br/>o complete link, " + "<br/>o UPGMA=average link, " + "<br/>o mean link, " + "<br/>o centroid, "
        + "<br/>o Ward and " + "<br/>o adjusted complete link " + "<br/>o neighborjoining "
        + "<br/>o neighborjoining2 - corrects tree for tip data, unlike plain neighborjoining")
public class ConstrainedClusterTree extends Tree implements StateNodeInitialiser {
    // The tree in the XML should have a taxon set, since it is not fully initialised at this stage
    public Input<List<MRCAPrior>> calibrationsInput = new Input<List<MRCAPrior>>("constraint",
            "specifies (monophyletic or height distribution) constraints on internal nodes",
            new ArrayList<MRCAPrior>());
    public final Input<MultiMonophyleticConstraint> allConstraints = new Input<>("constraints",
            "all constraints as encoded by one unresolved tree.", Input.Validate.REQUIRED);
    public final Input<Double> epsilonInput = new Input<Double>("minBranchLength",
            "lower bound on lengths used for creating branches", 1e-10);

    enum Type {
        single, average, complete, upgma, mean, centroid, ward, adjcomplete, neighborjoining, neighborjoining2
    }

    //  minimum branch length
    double EPSILON = 1e-10;

    public Input<Type> clusterTypeInput = new Input<Type>(
            "clusterType", "type of clustering algorithm used for generating initial beast.tree. "
                    + "Should be one of " + Type.values() + " (default " + Type.average + ")",
            Type.average, Type.values());
    public Input<Alignment> dataInput = new Input<Alignment>("taxa",
            "alignment data used for calculating distances for clustering");

    public Input<Distance> distanceInput = new Input<Distance>("distance",
            "method for calculating distance between two sequences (default Jukes Cantor)");

    public Input<RealParameter> clockRateInput = new Input<RealParameter>("clock.rate",
            "the clock rate parameter, used to divide all divergence times by, to convert from substitutions to times. (default 1.0)",
            new RealParameter(new Double[] { 1.0 }));

    /**
     * Whether the distance represent node height (if false) or branch length (if true).
     */
    protected boolean distanceIsBranchLength = false;
    Distance distance;
    List<String> taxaNames;

    /**
     * Holds the Link type used calculate distance between clusters
     */
    Type nLinkType = Type.single;

    List<boolean[]> constraints;
    List<Integer> constraintsize;

    @Override
    public void initAndValidate() {
        EPSILON = epsilonInput.get();
        RealParameter clockRate = clockRateInput.get();

        if (dataInput.get() != null) {
            taxaNames = dataInput.get().getTaxaNames();
        } else {
            if (m_taxonset.get() == null) {
                throw new IllegalArgumentException("At least one of taxa and taxonset input needs to be specified");
            }
            taxaNames = m_taxonset.get().asStringList();
        }

        constraints = new ArrayList<boolean[]>();
        constraintsize = new ArrayList<Integer>();

        List<MRCAPrior> calibrations = collectCalibrations(taxaNames, m_initial.get(), allConstraints.get(),
                calibrationsInput.get(), constraints, constraintsize);

        if (Boolean.valueOf(System.getProperty("beast.resume")) && (isEstimatedInput.get()
                || (m_initial.get() != null && m_initial.get().isEstimatedInput.get()))) {
            // don't bother creating a cluster tree to save some time, if it is read from file anyway
            // make a caterpillar
            Node left = newNode();
            left.setNr(0);
            left.setID(taxaNames.get(0));
            left.setHeight(0);
            for (int i = 1; i < taxaNames.size(); i++) {
                final Node right = newNode();
                right.setNr(i);
                right.setID(taxaNames.get(i));
                right.setHeight(0);
                final Node parent = newNode();
                parent.setNr(taxaNames.size() + i - 1);
                parent.setHeight(i);
                left.setParent(parent);
                parent.setLeft(left);
                right.setParent(parent);
                parent.setRight(right);
                left = parent;
            }
            root = left;
            leafNodeCount = taxaNames.size();
            nodeCount = leafNodeCount * 2 - 1;
            internalNodeCount = leafNodeCount - 1;
            super.initAndValidate();
            return;
        }

        distance = distanceInput.get();
        if (distance == null) {
            distance = new JukesCantorDistance();
        }
        if (distance instanceof Distance.Base) {
            if (dataInput.get() == null) {
                // Distance requires an alignment?
            }
            ((Distance.Base) distance).setPatterns(dataInput.get());
        }

        nLinkType = clusterTypeInput.get();

        if (nLinkType == Type.upgma)
            nLinkType = Type.average;

        if (nLinkType == Type.neighborjoining || nLinkType == Type.neighborjoining2) {
            distanceIsBranchLength = true;
        }
        final Node root = buildClusterer();
        setRoot(root);
        root.labelInternalNodes((getNodeCount() + 1) / 2);
        super.initAndValidate();
        if (nLinkType == Type.neighborjoining2) {
            // set tip dates to zero
            final Node[] nodes = getNodesAsArray();
            for (int i = 0; i < getLeafNodeCount(); i++) {
                nodes[i].setHeight(0);
            }
            super.initAndValidate();
        }

        if (m_initial.get() != null)
            processTraits(m_initial.get().m_traitList.get());
        else
            processTraits(m_traitList.get());

        if (timeTraitSet != null)
            adjustTreeNodeHeights(root);
        else {
            // all nodes should be at zero height if no date-trait is available
            for (int i = 0; i < getLeafNodeCount(); i++) {
                getNode(i).setHeight(0);
            }
        }

        //divide all node heights by clock rate to convert from substitutions to time.
        for (Node node : getInternalNodes()) {
            double height = node.getHeight();
            node.setHeight(height / clockRate.getValue());
        }

        Map<Node, MRCAPrior> nodeToBoundMap = new HashMap<>();
        // calculate nodeToBoundMap
        findConstrainedNodes(calibrations, getRoot(), nodeToBoundMap);
        if (nodeToBoundMap.size() > 0) {
            // adjust node heights to MRCAPriors
            try {
                handlebounds(getRoot(), nodeToBoundMap, EPSILON);
            } catch (MathException e) {
                Log.warning.println("Bounds could not be set");
            }
        }

        initStateNodes();
    }

    /**
     * Collect all monophyletic constraints, through MRCAPRiors, MultiMonoPhyleticConstraints and MRCAPrior outputs of a Tree
     * 
     * @param taxaNames list of taxa names to use
     * @param tree (optional) if present, output MRCAPriors are collected as well
     * @param multiMonophyleticConstraint
     * @param MRCAPriors
     * @param constraints returns constraints encoded as boolean arrays
     * @param constraintsize returns constraint sizes corresponding to 'constraints' array
     * @return list of MRCAPriors so time calibrations can be taken care of
     * @throws Exception when taxa in constraint cannot be found in taxaNames  
     */
    public static List<MRCAPrior> collectCalibrations(List<String> taxaNames, Tree tree,
            MultiMonophyleticConstraint multiMonophyleticConstraint, List<MRCAPrior> MRCAPriors,
            List<boolean[]> constraints, List<Integer> constraintsize) {
        List<MRCAPrior> calibrations = new ArrayList<>();

        int nrOfTaxa = taxaNames.size();

        // collect all calibrations
        calibrations.addAll(MRCAPriors);

        //  pick up constraints from calibrations on m_initial input
        if (tree != null) {
            for (final Object plugin : tree.getOutputs()) {
                if (plugin instanceof MRCAPrior && !calibrations.contains(plugin)) {
                    calibrations.add((MRCAPrior) plugin);
                }
            }
        }

        for (MRCAPrior prior : calibrations) {
            final boolean[] bTaxa = new boolean[nrOfTaxa];
            List<String> taxa = prior.taxonsetInput.get().asStringList();
            if (taxa == null) {
                prior.taxonsetInput.get().initAndValidate();
                taxa = prior.taxonsetInput.get().asStringList();
            }
            int size = 0;
            for (final String sTaxonID : taxa) {
                final int iID = taxaNames.indexOf(sTaxonID);
                if (iID < 0) {
                    throw new IllegalArgumentException(
                            "Taxon <" + sTaxonID + "> could not be found in list of taxa. Choose one of "
                                    + taxaNames.toArray(new String[0]));
                }
                bTaxa[iID] = true;
                size++;
            }
            if (prior.isMonophyleticInput.get() && size > 1) {
                // add any monophyletic constraint
                constraints.add(bTaxa);
                constraintsize.add(size);
            }
        }

        final MultiMonophyleticConstraint mul = multiMonophyleticConstraint;
        List<List<String>> allc = mul.getConstraints();

        for (List<String> c : allc) {
            final boolean[] bTaxa = new boolean[nrOfTaxa];
            for (String sTaxonID : c) {
                final int iID = taxaNames.indexOf(sTaxonID);
                if (iID < 0) {
                    throw new IllegalArgumentException(
                            "Taxon <" + sTaxonID + "> could not be found in list of taxa. Choose one of "
                                    + taxaNames.toArray(new String[0]));
                }
                bTaxa[iID] = true;
            }
            constraints.add(bTaxa);
            constraintsize.add(c.size());
        }

        return calibrations;
    }

    /** calculate nodeToBoundMap; for every MRCAPrior, a node will be added to the map **/
    static public void findConstrainedNodes(List<MRCAPrior> calibrations, Node root,
            Map<Node, MRCAPrior> nodeToBoundMap) {
        for (MRCAPrior calibration : calibrations) {
            int nrOfTaxa = calibration.taxonsetInput.get().getTaxonCount();
            findConstrainedNode(calibration, calibration.taxonsetInput.get().asStringList(), nodeToBoundMap, root,
                    new int[1], nrOfTaxa);
        }
    }

    /** process a specific MRCAPrior for the nodeToBoundMap **/
    static int findConstrainedNode(MRCAPrior calibration, List<String> taxa, Map<Node, MRCAPrior> nodeToBoundMap,
            final Node node, final int[] nTaxonCount, int nrOfTaxa) {
        if (node.isLeaf()) {
            nTaxonCount[0]++;

            if (taxa.contains(node.getID())) {
                return 1;
            } else {
                return 0;
            }
        } else {
            int taxonCount = findConstrainedNode(calibration, taxa, nodeToBoundMap, node.getLeft(), nTaxonCount,
                    nrOfTaxa);
            int nMatchingTaxa = nTaxonCount[0];
            nTaxonCount[0] = 0;
            for (int i = 1; i < node.getChildCount(); i++) {
                Node child = node.getChild(i);
                taxonCount += findConstrainedNode(calibration, taxa, nodeToBoundMap, child, nTaxonCount, nrOfTaxa);
                nMatchingTaxa += nTaxonCount[0];
            }
            nTaxonCount[0] = nMatchingTaxa;
            if (taxonCount == nrOfTaxa) {
                if (nrOfTaxa == 1 && calibration.useOriginateInput.get()) {
                    nodeToBoundMap.put(node, calibration);
                    return taxonCount + 1;
                }
                // we are at the MRCA, so record the height
                if (calibration.useOriginateInput.get()) {
                    Node parent = node.getParent();
                    if (parent != null) {
                        nodeToBoundMap.put(parent, calibration);
                    }
                } else {
                    nodeToBoundMap.put(node, calibration);
                }
                return taxonCount + 1;
            }

            //            if (node.getRight() != null) {
            //                taxonCount += findConstrainedNode(calibration, taxa, nodeToBoundMap, node.getRight(), nTaxonCount, nrOfTaxa);
            //                final int nRightTaxa = nTaxonCount[0];
            //                nTaxonCount[0] = nLeftTaxa + nRightTaxa;
            //                if (taxonCount == nrOfTaxa) {
            //                   if (nrOfTaxa == 1 && calibration.useOriginateInput.get()) {
            //                      nodeToBoundMap.put(node, calibration);
            //                        return taxonCount + 1;
            //                   }
            //                    // we are at the MRCA, so record the height
            //                   if (calibration.useOriginateInput.get()) {
            //                      Node parent = node.getParent();
            //                      if (parent != null) {
            //                          nodeToBoundMap.put(parent, calibration);
            //                      }
            //                   } else {
            //                      nodeToBoundMap.put(node, calibration);
            //                   }
            //                    return taxonCount + 1;
            //                }
            //            }
            return taxonCount;
        }
    }

    /** go through MRCAPriors
     * Since we can easily scale a clade, start with the highest MRCAPrior, then process the nested ones 
     * @throws MathException **/
    static public void handlebounds(Node node, Map<Node, MRCAPrior> nodeToBoundMap, double EPSILON)
            throws MathException {
        if (!node.isLeaf()) {
            if (nodeToBoundMap.containsKey(node)) {
                MRCAPrior calibration = nodeToBoundMap.get(node);
                if (calibration.distInput.get() != null) {
                    ParametricDistribution distr = calibration.distInput.get();
                    distr.initAndValidate();
                    double lower = distr.inverseCumulativeProbability(0.0) + distr.offsetInput.get();
                    double upper = distr.inverseCumulativeProbability(1.0) + distr.offsetInput.get();

                    // make sure the timing fits the constraint
                    double height = node.getHeight();
                    double newHeight = Double.NEGATIVE_INFINITY;
                    if (height < lower) {
                        if (Double.isFinite(upper)) {
                            newHeight = (lower + upper) / 2.0;
                        } else {
                            newHeight = lower;
                        }
                    }
                    if (height > upper) {
                        if (Double.isFinite(lower)) {
                            newHeight = (lower + upper) / 2.0;
                        } else {
                            newHeight = upper;
                        }
                    }
                    if (Double.isFinite(newHeight)) {
                        double scale = newHeight / height;

                        // scale clade
                        node.scale(scale);

                        // adjust parents if necessary
                        Node node2 = node;
                        Node parent = node2.getParent();
                        while (parent != null && parent.getHeight() < node2.getHeight()) {
                            parent.setHeight(node2.getHeight() + EPSILON);
                            node2 = node2.getParent();
                            parent = node2.getParent();
                        }
                    }

                }
            }
            for (Node child : node.getChildren()) {
                handlebounds(child, nodeToBoundMap, EPSILON);
            }
        }
    }

    public ConstrainedClusterTree() {
    } // c'tor

    /**
     * class representing node in cluster hierarchy *
     */
    class NodeX {
        NodeX m_left;
        NodeX m_right;
        NodeX m_parent;
        int m_iLeftInstance;
        int m_iRightInstance;
        double m_fLeftLength = 0;
        double m_fRightLength = 0;
        double m_fHeight = 0;

        void setHeight(double fHeight1, double fHeight2) {
            if (fHeight1 < EPSILON) {
                fHeight1 = EPSILON;
            }
            if (fHeight2 < EPSILON) {
                fHeight2 = EPSILON;
            }
            m_fHeight = fHeight1;
            if (m_left == null) {
                m_fLeftLength = fHeight1;
            } else {
                m_fLeftLength = fHeight1 - m_left.m_fHeight;
            }
            if (m_right == null) {
                m_fRightLength = fHeight2;
            } else {
                m_fRightLength = fHeight2 - m_right.m_fHeight;
            }
        }

        void setLength(double fLength1, double fLength2) {
            if (fLength1 < EPSILON) {
                fLength1 = EPSILON;
            }
            if (fLength2 < EPSILON) {
                fLength2 = EPSILON;
            }
            m_fLeftLength = fLength1;
            m_fRightLength = fLength2;
            m_fHeight = fLength1;
            if (m_left != null) {
                m_fHeight += m_left.m_fHeight;
            }
        }

        public String toString() {
            final DecimalFormat myFormatter = new DecimalFormat("#.#####", new DecimalFormatSymbols(Locale.US));

            if (m_left == null) {
                if (m_right == null) {
                    return "(" + taxaNames.get(m_iLeftInstance) + ":" + myFormatter.format(m_fLeftLength) + ","
                            + taxaNames.get(m_iRightInstance) + ":" + myFormatter.format(m_fRightLength) + ")";
                } else {
                    return "(" + taxaNames.get(m_iLeftInstance) + ":" + myFormatter.format(m_fLeftLength) + ","
                            + m_right.toString() + ":" + myFormatter.format(m_fRightLength) + ")";
                }
            } else {
                if (m_right == null) {
                    return "(" + m_left.toString() + ":" + myFormatter.format(m_fLeftLength) + ","
                            + taxaNames.get(m_iRightInstance) + ":" + myFormatter.format(m_fRightLength) + ")";
                } else {
                    return "(" + m_left.toString() + ":" + myFormatter.format(m_fLeftLength) + ","
                            + m_right.toString() + ":" + myFormatter.format(m_fRightLength) + ")";
                }
            }
        }

        Node toNode() {
            final Node node = newNode();
            node.setHeight(m_fHeight);
            if (m_left == null) {
                node.setLeft(newNode());
                node.getLeft().setNr(m_iLeftInstance);
                node.getLeft().setID(taxaNames.get(m_iLeftInstance));
                node.getLeft().setHeight(m_fHeight - m_fLeftLength);
                if (m_right == null) {
                    node.setRight(newNode());
                    node.getRight().setNr(m_iRightInstance);
                    node.getRight().setID(taxaNames.get(m_iRightInstance));
                    node.getRight().setHeight(m_fHeight - m_fRightLength);
                } else {
                    node.setRight(m_right.toNode());
                }
            } else {
                node.setLeft(m_left.toNode());
                if (m_right == null) {
                    node.setRight(newNode());
                    node.getRight().setNr(m_iRightInstance);
                    node.getRight().setID(taxaNames.get(m_iRightInstance));
                    node.getRight().setHeight(m_fHeight - m_fRightLength);
                } else {
                    node.setRight(m_right.toNode());
                }
            }
            if (node.getHeight() < node.getLeft().getHeight() + EPSILON) {
                node.setHeight(node.getLeft().getHeight() + EPSILON);
            }
            if (node.getHeight() < node.getRight().getHeight() + EPSILON) {
                node.setHeight(node.getRight().getHeight() + EPSILON);
            }

            node.getRight().setParent(node);
            node.getLeft().setParent(node);
            return node;
        }

    } // class NodeX

    /**
     * used for priority queue for efficient retrieval of pair of clusters to merge*
     */
    class Tuple {
        public Tuple(final double d, final int i, final int j, final int nSize1, final int nSize2) {
            m_fDist = d;
            m_iCluster1 = i;
            m_iCluster2 = j;
            m_nClusterSize1 = nSize1;
            m_nClusterSize2 = nSize2;
        }

        double m_fDist;
        int m_iCluster1;
        int m_iCluster2;
        int m_nClusterSize1;
        int m_nClusterSize2;
    }

    /**
     * comparator used by priority queue*
     */
    class TupleComparator implements Comparator<Tuple> {
        public int compare(final Tuple o1, final Tuple o2) {
            if (o1.m_fDist < o2.m_fDist) {
                return -1;
            } else if (o1.m_fDist == o2.m_fDist) {
                return 0;
            }
            return 1;
        }
    }

    // return distance according to distance metric
    double distance(final int iTaxon1, final int iTaxon2) {
        return distance.pairwiseDistance(iTaxon1, iTaxon2);
    } // distance

    // 1-norm
    double distance(final double[] nPattern1, final double[] nPattern2) {
        double fDist = 0;
        for (int i = 0; i < dataInput.get().getPatternCount(); i++) {
            fDist += dataInput.get().getPatternWeight(i) * Math.abs(nPattern1[i] - nPattern2[i]);
        }
        return fDist / dataInput.get().getSiteCount();
    }

    @SuppressWarnings("unchecked")
    public Node buildClusterer() {
        final int nTaxa = taxaNames.size();
        if (nTaxa == 1) {
            // pathological case
            final Node node = newNode();
            node.setHeight(1);
            node.setNr(0);
            return node;
        }

        // use array of integer vectors to store cluster indices,
        // starting with one cluster per instance
        final List<Integer>[] nClusterID = new ArrayList[nTaxa];
        for (int i = 0; i < nTaxa; i++) {
            nClusterID[i] = new ArrayList<Integer>();
            nClusterID[i].add(i);
        }
        // calculate distance matrix
        final int nClusters = nTaxa;

        // used for keeping track of hierarchy
        final NodeX[] clusterNodes = new NodeX[nTaxa];
        if (nLinkType == Type.neighborjoining || nLinkType == Type.neighborjoining2) {
            neighborJoining(nClusters, nClusterID, clusterNodes);
        } else {
            doLinkClustering(nClusters, nClusterID, clusterNodes);
        }

        // move all clusters in m_nClusterID array
        // & collect hierarchy
        for (int i = 0; i < nTaxa; i++) {
            if (nClusterID[i].size() > 0) {
                return clusterNodes[i].toNode();
            }
        }
        return null;
    } // buildClusterer

    /**
     * use neighbor joining algorithm for clustering
     * This is roughly based on the RapidNJ simple implementation and runs at O(n^3)
     * More efficient implementations exist, see RapidNJ (or my GPU implementation :-))
     *
     * @param nClusters
     * @param nClusterID
     * @param clusterNodes
     */
    void neighborJoining(int nClusters, final List<Integer>[] nClusterID, final NodeX[] clusterNodes) {
        final int n = taxaNames.size();
        Log.warning.print("ClusterTree: calc distances, ");
        final double[][] fDist = new double[nClusters][nClusters];
        for (int i = 0; i < nClusters; i++) {
            fDist[i][i] = 0;
            for (int j = i + 1; j < nClusters; j++) {
                fDist[i][j] = getDistance0(nClusterID[i], nClusterID[j]);
                fDist[j][i] = fDist[i][j];
            }
        }

        final double[] fSeparationSums = new double[n];
        final double[] fSeparations = new double[n];
        final int[] nNextActive = new int[n];

        //calculate initial separation rows
        for (int i = 0; i < n; i++) {
            double fSum = 0;
            for (int j = 0; j < n; j++) {
                fSum += fDist[i][j];
            }
            fSeparationSums[i] = fSum;
            fSeparations[i] = fSum / (nClusters - 2);
            nNextActive[i] = i + 1;
        }

        while (nClusters > 2) {
            // find minimum
            int iMin1 = -1;
            int iMin2 = -1;
            double fMin = Double.MAX_VALUE;
            {
                int i = 0;
                while (i < n) {
                    final double fSep1 = fSeparations[i];
                    final double[] fRow = fDist[i];
                    int j = nNextActive[i];
                    while (j < n) {
                        final double fSep2 = fSeparations[j];
                        final double fVal = fRow[j] - fSep1 - fSep2;
                        if (fVal < fMin && isCompatible(i, j, nClusterID)) {
                            // new minimum
                            iMin1 = i;
                            iMin2 = j;
                            fMin = fVal;
                        }
                        j = nNextActive[j];
                    }
                    i = nNextActive[i];
                }
            }
            // record distance
            final double fMinDistance = fDist[iMin1][iMin2];
            nClusters--;
            final double fSep1 = fSeparations[iMin1];
            final double fSep2 = fSeparations[iMin2];
            final double fDist1 = (0.5 * fMinDistance) + (0.5 * (fSep1 - fSep2));
            final double fDist2 = (0.5 * fMinDistance) + (0.5 * (fSep2 - fSep1));
            if (nClusters > 2) {
                // update separations  & distance
                double fNewSeparationSum = 0;
                final double fMutualDistance = fDist[iMin1][iMin2];
                final double[] fRow1 = fDist[iMin1];
                final double[] fRow2 = fDist[iMin2];
                for (int i = 0; i < n; i++) {
                    if (i == iMin1 || i == iMin2 || nClusterID[i].size() == 0) {
                        fRow1[i] = 0;
                    } else {
                        final double fVal1 = fRow1[i];
                        final double fVal2 = fRow2[i];
                        final double fDistance = (fVal1 + fVal2 - fMutualDistance) / 2.0;
                        fNewSeparationSum += fDistance;
                        // update the separationsum of cluster i.
                        fSeparationSums[i] += (fDistance - fVal1 - fVal2);
                        fSeparations[i] = fSeparationSums[i] / (nClusters - 2);
                        fRow1[i] = fDistance;
                        fDist[i][iMin1] = fDistance;
                    }
                }
                fSeparationSums[iMin1] = fNewSeparationSum;
                fSeparations[iMin1] = fNewSeparationSum / (nClusters - 2);
                fSeparationSums[iMin2] = 0;
                merge(iMin1, iMin2, fDist1, fDist2, nClusterID, clusterNodes);
                updateConstraints(nClusterID[iMin1]);
                int iPrev = iMin2;
                // since iMin1 < iMin2 we havenActiveRows[0] >= 0, so the next loop should be save
                while (nClusterID[iPrev].size() == 0) {
                    iPrev--;
                }
                nNextActive[iPrev] = nNextActive[iMin2];
            } else {
                merge(iMin1, iMin2, fDist1, fDist2, nClusterID, clusterNodes);
                updateConstraints(nClusterID[iMin1]);
                break;
            }

            // feedback on progress
            if (nClusters % 100 == 0) {
                if (nClusters % 1000 == 0) {
                    Log.warning.print('|');
                } else {
                    Log.warning.print('.');
                }
            }
        }

        Log.warning.print(" merge");
        for (int i = 0; i < n; i++) {
            if (nClusterID[i].size() > 0) {
                for (int j = i + 1; j < n; j++) {
                    if (nClusterID[j].size() > 0) {
                        final double fDist1 = fDist[i][j];
                        if (nClusterID[i].size() == 1) {
                            merge(i, j, fDist1, 0, nClusterID, clusterNodes);
                        } else if (nClusterID[j].size() == 1) {
                            merge(i, j, 0, fDist1, nClusterID, clusterNodes);
                        } else {
                            merge(i, j, fDist1 / 2.0, fDist1 / 2.0, nClusterID, clusterNodes);
                        }
                        break;
                    }
                }
            }
        }
        Log.warning.println(" done.");
    } // neighborJoining

    /**
     * Perform clustering using a link method
     * This implementation uses a priority queue resulting in a O(n^2 log(n)) algorithm
     *
     * @param nClusters    number of clusters
     * @param nClusterID
     * @param clusterNodes
     */
    void doLinkClustering(int nClusters, final List<Integer>[] nClusterID, final NodeX[] clusterNodes) {
        Log.warning.print("Calculating distance");
        final int nInstances = taxaNames.size();
        final PriorityQueue<Tuple> queue = new PriorityQueue<Tuple>(nClusters * nClusters / 2,
                new TupleComparator());
        final double[][] fDistance0 = new double[nClusters][nClusters];
        for (int i = 0; i < nClusters; i++) {
            fDistance0[i][i] = 0;
            for (int j = i + 1; j < nClusters; j++) {
                fDistance0[i][j] = getDistance0(nClusterID[i], nClusterID[j]);
                fDistance0[j][i] = fDistance0[i][j];
                if (isCompatible(i, j, nClusterID)) {
                    queue.add(new Tuple(fDistance0[i][j], i, j, 1, 1));
                }
            }
            // feedback on progress
            if ((i + 1) % 100 == 0) {
                if ((i + 1) % 1000 == 0) {
                    Log.warning.print('|');
                } else {
                    Log.warning.print('.');
                }
            }
        }
        Log.warning.print("\nClustering: ");
        while (nClusters > 1) {
            int iMin1 = -1;
            int iMin2 = -1;
            // use priority queue to find next best pair to cluster
            Tuple t;
            do {
                t = queue.poll();
            } while (t != null && (nClusterID[t.m_iCluster1].size() != t.m_nClusterSize1
                    || nClusterID[t.m_iCluster2].size() != t.m_nClusterSize2));
            iMin1 = t.m_iCluster1;
            iMin2 = t.m_iCluster2;
            merge(iMin1, iMin2, t.m_fDist / 2.0, t.m_fDist / 2.0, nClusterID, clusterNodes);
            updateConstraints(nClusterID[iMin1]);
            // merge  clusters

            // update distances & queue
            for (int i = 0; i < nInstances; i++) {
                if (i != iMin1 && nClusterID[i].size() != 0) {
                    final int i1 = Math.min(iMin1, i);
                    final int i2 = Math.max(iMin1, i);
                    if (isCompatible(i1, i2, nClusterID)) {
                        final double fDistance = getDistance(fDistance0, nClusterID[i1], nClusterID[i2]);
                        queue.add(new Tuple(fDistance, i1, i2, nClusterID[i1].size(), nClusterID[i2].size()));
                    }
                }
            }

            nClusters--;

            // feedback on progress
            if (nClusters % 100 == 0) {
                if (nClusters % 1000 == 0) {
                    Log.warning.print('|');
                } else {
                    Log.warning.print('.');
                }
            }
        }
        Log.warning.println(" done.");
    } // doLinkClustering

    private boolean isCompatible(int i1, int i2, final List<Integer>[] nClusterID) {
        for (boolean[] bits : constraints) {
            boolean value = bits[nClusterID[i1].get(0)];
            for (int i : nClusterID[i1]) {
                if (value != bits[i]) {
                    return false;
                }
            }
            for (int i : nClusterID[i2]) {
                if (value != bits[i]) {
                    return false;
                }
            }
        }
        return true;
    }

    private void updateConstraints(List<Integer> nClusterIDs) {
        for (int i = constraints.size() - 1; i >= 0; i--) {
            if (constraintsize.get(i) == nClusterIDs.size()) {
                boolean[] bits = constraints.get(i);
                boolean match = true;
                for (int j : nClusterIDs) {
                    if (!bits[j]) {
                        match = false;
                        break;
                    }
                }
                if (match) {
                    constraints.remove(i);
                    constraintsize.remove(i);
                }
            }
        }
    }

    NodeX merge(int iMin1, int iMin2, double fDist1, double fDist2, final List<Integer>[] nClusterID,
            final NodeX[] clusterNodes) {
        if (iMin1 > iMin2) {
            final int h = iMin1;
            iMin1 = iMin2;
            iMin2 = h;
            final double f = fDist1;
            fDist1 = fDist2;
            fDist2 = f;
        }
        nClusterID[iMin1].addAll(nClusterID[iMin2]);
        //nClusterID[iMin2].removeAllElements();
        nClusterID[iMin2].removeAll(nClusterID[iMin2]);

        // track hierarchy
        final NodeX node = new NodeX();
        if (clusterNodes[iMin1] == null) {
            node.m_iLeftInstance = iMin1;
        } else {
            node.m_left = clusterNodes[iMin1];
            clusterNodes[iMin1].m_parent = node;
        }
        if (clusterNodes[iMin2] == null) {
            node.m_iRightInstance = iMin2;
        } else {
            node.m_right = clusterNodes[iMin2];
            clusterNodes[iMin2].m_parent = node;
        }
        if (distanceIsBranchLength) {
            node.setLength(fDist1, fDist2);
        } else {
            node.setHeight(fDist1, fDist2);
        }
        clusterNodes[iMin1] = node;
        return node;
    } // merge

    /**
     * calculate distance the first time when setting up the distance matrix *
     */
    double getDistance0(final List<Integer> cluster1, final List<Integer> cluster2) {
        double fBestDist = Double.MAX_VALUE;
        switch (nLinkType) {
        case single:
        case neighborjoining:
        case neighborjoining2:
        case centroid:
        case complete:
        case adjcomplete:
        case average:
        case mean:
            // set up two instances for distance function
            fBestDist = distance(cluster1.get(0), cluster2.get(0));
            break;
        case ward: {
            // finds the distance of the change in caused by merging the cluster.
            // The information of a cluster is calculated as the error sum of squares of the
            // centroids of the cluster and its members.
            final double ESS1 = calcESS(cluster1);
            final double ESS2 = calcESS(cluster2);
            final List<Integer> merged = new ArrayList<Integer>();
            merged.addAll(cluster1);
            merged.addAll(cluster2);
            final double ESS = calcESS(merged);
            fBestDist = ESS * merged.size() - ESS1 * cluster1.size() - ESS2 * cluster2.size();
        }
            break;
        }
        return fBestDist;
    } // getDistance0

    /**
     * calculate the distance between two clusters
     *
     * @param cluster1 list of indices of instances in the first cluster
     * @param cluster2 dito for second cluster
     * @return distance between clusters based on link type
     */
    double getDistance(final double[][] fDistance, final List<Integer> cluster1, final List<Integer> cluster2) {
        double fBestDist = Double.MAX_VALUE;
        switch (nLinkType) {
        case single:
            // find single link distance aka minimum link, which is the closest distance between
            // any item in cluster1 and any item in cluster2
            fBestDist = Double.MAX_VALUE;
            for (int i = 0; i < cluster1.size(); i++) {
                final int i1 = cluster1.get(i);
                for (int j = 0; j < cluster2.size(); j++) {
                    final int i2 = cluster2.get(j);
                    final double fDist = fDistance[i1][i2];
                    if (fBestDist > fDist) {
                        fBestDist = fDist;
                    }
                }
            }
            break;
        case complete:
        case adjcomplete:
            // find complete link distance aka maximum link, which is the largest distance between
            // any item in cluster1 and any item in cluster2
            fBestDist = 0;
            for (int i = 0; i < cluster1.size(); i++) {
                final int i1 = cluster1.get(i);
                for (int j = 0; j < cluster2.size(); j++) {
                    final int i2 = cluster2.get(j);
                    final double fDist = fDistance[i1][i2];
                    if (fBestDist < fDist) {
                        fBestDist = fDist;
                    }
                }
            }
            if (nLinkType == Type.complete) {
                break;
            }
            // calculate adjustment, which is the largest within cluster distance
            double fMaxDist = 0;
            for (int i = 0; i < cluster1.size(); i++) {
                final int i1 = cluster1.get(i);
                for (int j = i + 1; j < cluster1.size(); j++) {
                    final int i2 = cluster1.get(j);
                    final double fDist = fDistance[i1][i2];
                    if (fMaxDist < fDist) {
                        fMaxDist = fDist;
                    }
                }
            }
            for (int i = 0; i < cluster2.size(); i++) {
                final int i1 = cluster2.get(i);
                for (int j = i + 1; j < cluster2.size(); j++) {
                    final int i2 = cluster2.get(j);
                    final double fDist = fDistance[i1][i2];
                    if (fMaxDist < fDist) {
                        fMaxDist = fDist;
                    }
                }
            }
            fBestDist -= fMaxDist;
            break;
        case average:
            // finds average distance between the elements of the two clusters
            fBestDist = 0;
            for (int i = 0; i < cluster1.size(); i++) {
                final int i1 = cluster1.get(i);
                for (int j = 0; j < cluster2.size(); j++) {
                    final int i2 = cluster2.get(j);
                    fBestDist += fDistance[i1][i2];
                }
            }
            fBestDist /= (cluster1.size() * cluster2.size());
            break;
        case mean: {
            // calculates the mean distance of a merged cluster (akak Group-average agglomerative clustering)
            final List<Integer> merged = new ArrayList<Integer>();
            merged.addAll(cluster1);
            merged.addAll(cluster2);
            fBestDist = 0;
            for (int i = 0; i < merged.size(); i++) {
                final int i1 = merged.get(i);
                for (int j = i + 1; j < merged.size(); j++) {
                    final int i2 = merged.get(j);
                    fBestDist += fDistance[i1][i2];
                }
            }
            final int n = merged.size();
            fBestDist /= (n * (n - 1.0) / 2.0);
        }
            break;
        case centroid:
            // finds the distance of the centroids of the clusters
            final int nPatterns = dataInput.get().getPatternCount();
            final double[] centroid1 = new double[nPatterns];
            for (int i = 0; i < cluster1.size(); i++) {
                final int iTaxon = cluster1.get(i);
                for (int j = 0; j < nPatterns; j++) {
                    centroid1[j] += dataInput.get().getPattern(iTaxon, j);
                }
            }
            final double[] centroid2 = new double[nPatterns];
            for (int i = 0; i < cluster2.size(); i++) {
                final int iTaxon = cluster2.get(i);
                for (int j = 0; j < nPatterns; j++) {
                    centroid2[j] += dataInput.get().getPattern(iTaxon, j);
                }
            }
            for (int j = 0; j < nPatterns; j++) {
                centroid1[j] /= cluster1.size();
                centroid2[j] /= cluster2.size();
            }
            fBestDist = distance(centroid1, centroid2);
            break;
        case ward: {
            // finds the distance of the change in caused by merging the cluster.
            // The information of a cluster is calculated as the error sum of squares of the
            // centroids of the cluster and its members.
            final double ESS1 = calcESS(cluster1);
            final double ESS2 = calcESS(cluster2);
            final List<Integer> merged = new ArrayList<Integer>();
            merged.addAll(cluster1);
            merged.addAll(cluster2);
            final double ESS = calcESS(merged);
            fBestDist = ESS * merged.size() - ESS1 * cluster1.size() - ESS2 * cluster2.size();
        }
            break;
        }
        return fBestDist;
    } // getDistance

    /**
     * calculated error sum-of-squares for instances wrt centroid *
     */
    double calcESS(final List<Integer> cluster) {
        final int nPatterns = dataInput.get().getPatternCount();
        final double[] centroid = new double[nPatterns];
        for (int i = 0; i < cluster.size(); i++) {
            final int iTaxon = cluster.get(i);
            for (int j = 0; j < nPatterns; j++) {
                centroid[j] += dataInput.get().getPattern(iTaxon, j);
            }
        }
        for (int j = 0; j < nPatterns; j++) {
            centroid[j] /= cluster.size();
        }
        // set up two instances for distance function
        double fESS = 0;
        for (int i = 0; i < cluster.size(); i++) {
            final double[] instance = new double[nPatterns];
            final int iTaxon = cluster.get(i);
            for (int j = 0; j < nPatterns; j++) {
                instance[j] += dataInput.get().getPattern(iTaxon, j);
            }
            fESS += distance(centroid, instance);
        }
        return fESS / cluster.size();
    } // calcESS

    @Override
    public void initStateNodes() {
        if (m_initial.get() != null) {
            m_initial.get().assignFromWithoutID(this);
        }
    }

    @Override
    public void getInitialisedStateNodes(final List<StateNode> stateNodes) {
        if (m_initial.get() != null) {
            stateNodes.add(m_initial.get());
        }
    }

} // class ClusterTree