Java tutorial
/* * CaseToCaseTreeLikelihood.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.epidemiology.casetocase; import java.io.IOException; import java.io.PrintStream; import java.util.*; import dr.app.tools.NexusExporter; import dr.evolution.tree.FlexibleNode; import dr.evolution.tree.FlexibleTree; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evolution.tree.TreeTrait; import dr.evolution.tree.TreeTraitProvider; import dr.evolution.util.Taxon; import dr.evolution.util.TaxonList; import dr.evomodel.epidemiology.casetocase.periodpriors.AbstractPeriodPriorDistribution; import dr.evomodel.tree.TreeModel; import dr.oldevomodel.treelikelihood.AbstractTreeLikelihood; import dr.inference.loggers.LogColumn; import dr.inference.loggers.Loggable; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.util.Author; import dr.util.Citable; import dr.util.Citation; import org.apache.commons.math.stat.descriptive.moment.Mean; import org.apache.commons.math.stat.descriptive.moment.Variance; import org.apache.commons.math.stat.descriptive.rank.Median; /** * Handles manipulation of the tree partition, and likelihood of the infection times. * * @author Matthew Hall * @author Andrew Rambaut * @version $Id: $ */ public abstract class CaseToCaseTreeLikelihood extends AbstractTreeLikelihood implements Loggable, Citable, TreeTraitProvider { protected static final boolean DEBUG = false; protected static double tolerance = 1E-10; /* The phylogenetic tree. */ protected int noTips; protected int noCases; /* Mapping of outbreak to branches on the tree; old version is stored before operators are applied */ /* Matches outbreak to external nodes */ private double estimatedLastSampleTime; protected TreeTraitProvider.Helper treeTraits = new Helper(); /** * The set of cases */ protected AbstractOutbreak outbreak; protected double[] infectionTimes; private double[] storedInfectionTimes; protected double[] infectiousPeriods; private double[] storedInfectiousPeriods; protected double[] infectiousTimes; private double[] storedInfectiousTimes; protected double[] latentPeriods; private double[] storedLatentPeriods; protected boolean[] recalculateCaseFlags; protected HashMap<AbstractCase, Treelet> elementsAsTrees; protected HashMap<AbstractCase, Treelet> storedElementsAsTrees; //because of the way the former works, we need a maximum value of the time from first infection to root node. protected Parameter maxFirstInfToRoot; // latent periods protected boolean hasLatentPeriods; // PUBLIC STUFF // Name public static final String CASE_TO_CASE_TREE_LIKELIHOOD = "caseToCaseTreeLikelihood"; public static final String PARTITIONS_KEY = "partition"; // Basic constructor. public CaseToCaseTreeLikelihood(PartitionedTreeModel tree, AbstractOutbreak caseData, Parameter maxFirstInfToRoot) throws TaxonList.MissingTaxonException { this(CASE_TO_CASE_TREE_LIKELIHOOD, tree, caseData, maxFirstInfToRoot); } // Constructor for an instance with a non-default name public CaseToCaseTreeLikelihood(String name, PartitionedTreeModel tree, AbstractOutbreak caseData, Parameter maxFirstInfToRoot) { super(name, caseData, tree); if (stateCount != treeModel.getExternalNodeCount()) { throw new RuntimeException("There are duplicate tip outbreak."); } noTips = tree.getExternalNodeCount(); //subclasses should add outbreak as a model if it contains any information that ever changes outbreak = caseData; noCases = outbreak.getCases().size(); addModel(outbreak); estimatedLastSampleTime = getLatestTaxonTime(); //map outbreak to tips addModel(tree.getBranchMap()); hasLatentPeriods = outbreak.hasLatentPeriods(); infectionTimes = new double[outbreak.size()]; infectiousPeriods = new double[outbreak.size()]; if (hasLatentPeriods) { infectiousTimes = new double[outbreak.size()]; latentPeriods = new double[outbreak.size()]; } recalculateCaseFlags = new boolean[outbreak.size()]; Arrays.fill(recalculateCaseFlags, true); this.maxFirstInfToRoot = maxFirstInfToRoot; treeTraits.addTrait(PARTITIONS_KEY, new TreeTrait.S() { public String getTraitName() { return PARTITIONS_KEY; } public Intent getIntent() { return Intent.NODE; } public String getTrait(Tree tree, NodeRef node) { return getNodePartition(tree, node); } }); if (DEBUG) { treeTraits.addTrait("NodeNumber", new TreeTrait.S() { public String getTraitName() { return "NodeNumber"; } public Intent getIntent() { return Intent.NODE; } public String getTrait(Tree tree, NodeRef node) { return Integer.toString(node.getNumber()); } }); } likelihoodKnown = false; } public AbstractOutbreak getOutbreak() { return outbreak; } public boolean hasLatentPeriods() { return hasLatentPeriods; } /* Get the date of the last tip */ private double getLatestTaxonTime() { double latestTime = Double.NEGATIVE_INFINITY; for (int i = 0; i < treeModel.getExternalNodeCount(); i++) { Taxon taxon = treeModel.getNodeTaxon(treeModel.getExternalNode(i)); if (taxon.getDate().getTimeValue() > latestTime) { latestTime = taxon.getDate().getTimeValue(); } } return latestTime; } private NodeRef[] getChildren(NodeRef node) { NodeRef[] children = new NodeRef[treeModel.getChildCount(node)]; for (int i = 0; i < treeModel.getChildCount(node); i++) { children[i] = treeModel.getChild(node, i); } return children; } protected void explodeTree() { for (int i = 0; i < outbreak.size(); i++) { AbstractCase aCase = outbreak.getCase(i); if (aCase.wasEverInfected() && elementsAsTrees.get(aCase) == null) { NodeRef partitionRoot = ((PartitionedTreeModel) treeModel).getEarliestNodeInElement(aCase); double extraHeight; if (treeModel.isRoot(partitionRoot)) { extraHeight = maxFirstInfToRoot.getParameterValue(0) * aCase.getInfectionBranchPosition().getParameterValue(0); } else { extraHeight = treeModel.getBranchLength(partitionRoot) * aCase.getInfectionBranchPosition().getParameterValue(0); } FlexibleNode newRoot = new FlexibleNode(); FlexibleTree littleTree = new FlexibleTree(newRoot); littleTree.beginTreeEdit(); if (!treeModel.isExternal(partitionRoot)) { for (int j = 0; j < treeModel.getChildCount(partitionRoot); j++) { copyElementToTreelet(littleTree, treeModel.getChild(partitionRoot, j), newRoot, aCase); } } littleTree.endTreeEdit(); littleTree.resolveTree(); Treelet treelet = new Treelet(littleTree, littleTree.getRootHeight() + extraHeight); elementsAsTrees.put(aCase, treelet); } } } private void copyElementToTreelet(FlexibleTree littleTree, NodeRef oldNode, NodeRef newParent, AbstractCase element) { if (element.wasEverInfected()) { if (getBranchMap().get(oldNode.getNumber()) == element) { if (treeModel.isExternal(oldNode)) { NodeRef newTip = new FlexibleNode(new Taxon(treeModel.getNodeTaxon(oldNode).getId())); littleTree.addChild(newParent, newTip); littleTree.setBranchLength(newTip, treeModel.getBranchLength(oldNode)); } else { NodeRef newChild = new FlexibleNode(); littleTree.addChild(newParent, newChild); littleTree.setBranchLength(newChild, treeModel.getBranchLength(oldNode)); for (int i = 0; i < treeModel.getChildCount(oldNode); i++) { copyElementToTreelet(littleTree, treeModel.getChild(oldNode, i), newChild, element); } } } else { // we need a new tip NodeRef transmissionTip = new FlexibleNode( new Taxon("Transmission_" + getBranchMap().get(oldNode.getNumber()).getName())); double parentTime = getNodeTime(treeModel.getParent(oldNode)); double childTime = getInfectionTime(getBranchMap().get(oldNode.getNumber())); littleTree.addChild(newParent, transmissionTip); littleTree.setBranchLength(transmissionTip, childTime - parentTime); } } } protected class Treelet extends FlexibleTree { private double zeroHeight; protected Treelet(FlexibleTree tree, double zeroHeight) { super(tree); this.zeroHeight = zeroHeight; } protected double getZeroHeight() { return zeroHeight; } protected void setZeroHeight(double rootBranchLength) { this.zeroHeight = zeroHeight; } } // find all partitions of the descendant tips of the current node. If map is specified then it makes a map of node // number to possible partitions; map can be null. public HashSet<AbstractCase> descendantTipPartitions(NodeRef node, HashMap<Integer, HashSet<AbstractCase>> map) { HashSet<AbstractCase> out = new HashSet<AbstractCase>(); if (treeModel.isExternal(node)) { out.add(getBranchMap().get(node.getNumber())); if (map != null) { map.put(node.getNumber(), out); } return out; } else { for (int i = 0; i < treeModel.getChildCount(node); i++) { out.addAll(descendantTipPartitions(treeModel.getChild(node, i), map)); } if (map != null) { map.put(node.getNumber(), out); } return out; } } // change flags to indicate that something needs recalculation further down the tree protected static void flagForDescendantRecalculation(TreeModel tree, NodeRef node, boolean[] flags) { flags[node.getNumber()] = true; for (int i = 0; i < tree.getChildCount(node); i++) { flags[tree.getChild(node, i).getNumber()] = true; } NodeRef currentNode = node; while (!tree.isRoot(currentNode) && !flags[currentNode.getNumber()]) { currentNode = tree.getParent(currentNode); flags[currentNode.getNumber()] = true; } } public void flagForDescendantRecalculation(TreeModel tree, NodeRef node) { flagForDescendantRecalculation(tree, node, updateNode); } // ************************************************************** // ModelListener IMPLEMENTATION // ************************************************************** protected void handleModelChangedEvent(Model model, Object object, int index) { if (!(model instanceof AbstractPeriodPriorDistribution)) { if (model == treeModel) { if (object instanceof PartitionedTreeModel.PartitionsChangedEvent) { HashSet<AbstractCase> changedPartitions = ((PartitionedTreeModel.PartitionsChangedEvent) object) .getCasesToRecalculate(); for (AbstractCase aCase : changedPartitions) { recalculateCase(aCase); } } } else if (model == getBranchMap()) { if (object instanceof ArrayList) { for (int i = 0; i < ((ArrayList) object).size(); i++) { BranchMapModel.BranchMapChangedEvent event = (BranchMapModel.BranchMapChangedEvent) ((ArrayList) object) .get(i); recalculateCase(event.getOldCase()); recalculateCase(event.getNewCase()); NodeRef node = treeModel.getNode(event.getNodeToRecalculate()); NodeRef parent = treeModel.getParent(node); if (parent != null) { recalculateCase(getBranchMap().get(parent.getNumber())); } } } else { throw new RuntimeException("Unanticipated model changed event from BranchMapModel"); } } else if (model == outbreak) { if (object instanceof AbstractCase) { recalculateCase((AbstractCase) object); } else { for (AbstractCase aCase : outbreak.getCases()) { recalculateCase(aCase); } } } fireModelChanged(model); likelihoodKnown = false; } } protected void recalculateCase(int index) { recalculateCaseFlags[index] = true; } protected void recalculateCase(AbstractCase aCase) { if (aCase.wasEverInfected()) { recalculateCase(outbreak.getCaseIndex(aCase)); } } // ************************************************************** // VariableListener IMPLEMENTATION // ************************************************************** protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { fireModelChanged(); likelihoodKnown = false; } // ************************************************************** // Model IMPLEMENTATION // ************************************************************** /** * Stores the precalculated state (in this case the node labels and subtree likelihoods) */ protected void storeState() { super.storeState(); storedInfectionTimes = Arrays.copyOf(infectionTimes, infectionTimes.length); storedInfectiousPeriods = Arrays.copyOf(infectiousPeriods, infectiousPeriods.length); if (hasLatentPeriods) { storedInfectiousTimes = Arrays.copyOf(infectiousTimes, infectionTimes.length); storedLatentPeriods = Arrays.copyOf(latentPeriods, latentPeriods.length); } } /** * Restores the precalculated state. */ protected void restoreState() { super.restoreState(); infectionTimes = storedInfectionTimes; infectiousPeriods = storedInfectiousPeriods; if (hasLatentPeriods) { infectiousTimes = storedInfectiousTimes; latentPeriods = storedLatentPeriods; } } protected final void acceptState() { } // nothing to do // ************************************************************** // Likelihood IMPLEMENTATION // ************************************************************** public final BranchMapModel getBranchMap() { return ((PartitionedTreeModel) treeModel).getBranchMap(); } public final PartitionedTreeModel getTreeModel() { return (PartitionedTreeModel) treeModel; } public void makeDirty() { likelihoodKnown = false; Arrays.fill(recalculateCaseFlags, true); } protected void prepareTimings() { infectionTimes = getInfectionTimes(true); if (hasLatentPeriods) { infectiousTimes = getInfectiousTimes(true); } infectiousPeriods = getInfectiousPeriods(true); if (hasLatentPeriods) { latentPeriods = getLatentPeriods(true); } Arrays.fill(recalculateCaseFlags, false); } /** * Calculates the log likelihood of this set of node labels given the tree. */ protected abstract double calculateLogLikelihood(); // if no infectious models, just need to check whether any infections occur after the infector was no // longer infectious protected boolean isAllowed() { return isAllowed(treeModel.getRoot()); } private boolean isAllowed(NodeRef node) { if (!treeModel.isRoot(node)) { AbstractCase childCase = getBranchMap().get(node.getNumber()); AbstractCase parentCase = getBranchMap().get(treeModel.getParent(node).getNumber()); if (childCase != parentCase) { double infectionTime = infectionTimes[outbreak.getCaseIndex(childCase)]; if (infectionTime > parentCase.getEndTime() || (hasLatentPeriods && infectionTime < infectiousTimes[outbreak.getCaseIndex(parentCase)])) { return false; } } } return treeModel.isExternal(node) || (isAllowed(treeModel.getChild(node, 0)) && isAllowed(treeModel.getChild(node, 1))); } /* Return the double time at which the given node occurred */ public double getNodeTime(NodeRef node) { double nodeHeight = getHeight(node); return estimatedLastSampleTime - nodeHeight; } public double heightToTime(double height) { return estimatedLastSampleTime - height; } public double timeToHeight(double time) { return estimatedLastSampleTime - time; } private double getHeight(NodeRef node) { return treeModel.getNodeHeight(node); } public double getInfectionTime(AbstractCase thisCase) { if (!recalculateCaseFlags[outbreak.getCaseIndex(thisCase)]) { return infectionTimes[outbreak.getCaseIndex(thisCase)]; } else { if (thisCase.wasEverInfected()) { NodeRef child = ((PartitionedTreeModel) treeModel).getEarliestNodeInElement(thisCase); NodeRef parent = treeModel.getParent(child); if (parent != null) { double min = heightToTime(treeModel.getNodeHeight(parent)); // Let the likelihood evaluate to zero due to culling dates if it must... double max = heightToTime(treeModel.getNodeHeight(child)); return getInfectionTime(min, max, thisCase); } else { return getRootInfectionTime(getBranchMap()); } } else { return Double.POSITIVE_INFINITY; } } } private double getInfectionTime(double min, double max, AbstractCase infected) { final double branchLength = max - min; return min + branchLength * (1 - infected.getInfectionBranchPosition().getParameterValue(0)); } public double[] getInfectionTimes(boolean recalculate) { if (recalculate) { for (int i = 0; i < noCases; i++) { if (recalculateCaseFlags[i]) { infectionTimes[i] = getInfectionTime(outbreak.getCase(i)); } } } return infectionTimes; } public void setInfectionTime(AbstractCase thisCase, double time) { setInfectionHeight(thisCase, timeToHeight(time)); } public void setInfectionHeight(AbstractCase thisCase, double height) { if (thisCase.wasEverInfected()) { NodeRef child = ((PartitionedTreeModel) treeModel).getEarliestNodeInElement(thisCase); NodeRef parent = treeModel.getParent(child); double minHeight = treeModel.getNodeHeight(child); double maxHeight = parent != null ? treeModel.getNodeHeight(parent) : minHeight + maxFirstInfToRoot.getParameterValue(0); if (height < minHeight || height > maxHeight) { throw new RuntimeException( "Trying to set an infection time outside the branch on which it must occur"); } double branchPosition = (height - minHeight) / (maxHeight - minHeight); thisCase.setInfectionBranchPosition(branchPosition); } } public double getInfectiousTime(AbstractCase thisCase) { if (!hasLatentPeriods) { return getInfectionTime(thisCase); } else { if (recalculateCaseFlags[outbreak.getCaseIndex(thisCase)]) { if (thisCase.wasEverInfected()) { String latentCategory = ((CategoryOutbreak) outbreak).getLatentCategory(thisCase); Parameter latentPeriod = ((CategoryOutbreak) outbreak).getLatentPeriod(latentCategory); infectiousTimes[outbreak.getCaseIndex(thisCase)] = getInfectionTime(thisCase) + latentPeriod.getParameterValue(0); } else { infectiousTimes[outbreak.getCaseIndex(thisCase)] = Double.POSITIVE_INFINITY; } } return infectiousTimes[outbreak.getCaseIndex(thisCase)]; } } public double[] getInfectiousTimes(boolean recalculate) { if (recalculate) { for (int i = 0; i < noCases; i++) { if (recalculateCaseFlags[i]) { infectiousTimes[i] = getInfectiousTime(outbreak.getCase(i)); } } } return infectiousTimes; } public double getInfectiousPeriod(AbstractCase thisCase) { if (recalculateCaseFlags[outbreak.getCaseIndex(thisCase)]) { if (thisCase.wasEverInfected()) { if (!hasLatentPeriods) { double infectionTime = getInfectionTime(thisCase); double cullTime = thisCase.getEndTime(); infectiousPeriods[outbreak.getCaseIndex(thisCase)] = cullTime - infectionTime; } else { double infectiousTime = getInfectiousTime(thisCase); double cullTime = thisCase.getEndTime(); infectiousPeriods[outbreak.getCaseIndex(thisCase)] = cullTime - infectiousTime; } } else { infectiousPeriods[outbreak.getCaseIndex(thisCase)] = 0; } } return infectiousPeriods[outbreak.getCaseIndex(thisCase)]; } public double[] getInfectiousPeriods(boolean recalculate) { if (recalculate) { for (int i = 0; i < noCases; i++) { if (recalculateCaseFlags[i]) { infectiousPeriods[i] = getInfectiousPeriod(outbreak.getCase(i)); } } } return infectiousPeriods; } public Double[] getNonzeroInfectiousPeriods() { ArrayList<Double> out = new ArrayList<Double>(); for (int i = 0; i < noCases; i++) { AbstractCase thisCase = outbreak.getCase(i); if (thisCase.wasEverInfected()) { out.add(getInfectiousPeriod(thisCase)); } } return out.toArray(new Double[out.size()]); } public double getLatentPeriod(AbstractCase thisCase) { if (!hasLatentPeriods || !thisCase.wasEverInfected()) { return 0.0; } if (recalculateCaseFlags[outbreak.getCaseIndex(thisCase)]) { latentPeriods[outbreak.getCaseIndex(thisCase)] = getInfectiousTime(thisCase) - getInfectionTime(thisCase); } return latentPeriods[outbreak.getCaseIndex(thisCase)]; } public double[] getLatentPeriods(boolean recalculate) { if (recalculate) { for (int i = 0; i < noCases; i++) { if (recalculateCaseFlags[i]) { latentPeriods[i] = getLatentPeriod(outbreak.getCase(i)); } } } return latentPeriods; } public Double[] getNonzeroLatentPeriods() { ArrayList<Double> out = new ArrayList<Double>(); for (int i = 0; i < noCases; i++) { AbstractCase thisCase = outbreak.getCase(i); if (thisCase.wasEverInfected()) { out.add(getLatentPeriod(thisCase)); } } return out.toArray(new Double[out.size()]); } public double[] getInfectedPeriods(boolean recalculate) { if (!hasLatentPeriods) { return getInfectiousPeriods(recalculate); } else { double[] out = new double[noCases]; for (int i = 0; i < noCases; i++) { out[i] = getInfectedPeriod(outbreak.getCase(i)); } return out; } } public Double[] getNonzeroInfectedPeriods() { ArrayList<Double> out = new ArrayList<Double>(); for (int i = 0; i < noCases; i++) { AbstractCase thisCase = outbreak.getCase(i); if (thisCase.wasEverInfected()) { out.add(getInfectedPeriod(thisCase)); } } return out.toArray(new Double[out.size()]); } public double getInfectedPeriod(AbstractCase thisCase) { if (thisCase.wasEverInfected) { return thisCase.getEndTime() - getInfectionTime(thisCase); } return 0; } // return an array of the mean, median, variance and standard deviation of the given array // @todo this is pretty wasteful since it gets called so many times per log entry public static Double[] getSummaryStatistics(Double[] variable) { double[] primitiveVariable = new double[variable.length]; for (int i = 0; i < variable.length; i++) { primitiveVariable[i] = variable[i]; } Double[] out = new Double[4]; out[0] = (new Mean()).evaluate(primitiveVariable); out[1] = (new Median()).evaluate(primitiveVariable); out[2] = (new Variance()).evaluate(primitiveVariable); out[3] = Math.sqrt(out[2]); return out; } private double getRootInfectionTime(BranchMapModel branchMap) { NodeRef root = treeModel.getRoot(); AbstractCase rootCase = branchMap.get(root.getNumber()); final double branchLength = maxFirstInfToRoot.getParameterValue(0); return heightToTime(treeModel.getNodeHeight(root) + branchLength * rootCase.getInfectionBranchPosition().getParameterValue(0)); } protected double getRootInfectionTime() { AbstractCase rootCase = getBranchMap().get(treeModel.getRoot().getNumber()); return getInfectionTime(rootCase); } public void outputTreeToFile(String fileName, boolean includeTransmissionNodes) { outputTreeToFile(getBranchMap(), fileName, includeTransmissionNodes); } public void outputTreeToFile(BranchMapModel map, String fileName, boolean includeTransmissionNodes) { try { FlexibleTree treeCopy; if (!includeTransmissionNodes) { treeCopy = new FlexibleTree(treeModel); for (int j = 0; j < treeCopy.getNodeCount(); j++) { FlexibleNode node = (FlexibleNode) treeCopy.getNode(j); node.setAttribute("Number", node.getNumber()); node.setAttribute("Time", heightToTime(node.getHeight())); node.setAttribute(PARTITIONS_KEY, map.get(node.getNumber())); } } else { treeCopy = addTransmissionNodes(treeModel); } NexusExporter testTreesOut = new NexusExporter(new PrintStream(fileName)); testTreesOut.exportTree(treeCopy); } catch (IOException ignored) { System.out.println("IOException"); } } public FlexibleTree addTransmissionNodes(Tree tree) { prepareTimings(); FlexibleTree outTree = new FlexibleTree(tree, true); for (int j = 0; j < outTree.getNodeCount(); j++) { FlexibleNode node = (FlexibleNode) outTree.getNode(j); node.setAttribute("Number", node.getNumber()); node.setAttribute("Time", heightToTime(node.getHeight())); node.setAttribute(PARTITIONS_KEY, getBranchMap().get(node.getNumber())); } for (AbstractCase aCase : outbreak.getCases()) { if (aCase.wasEverInfected()) { NodeRef originalNode = ((PartitionedTreeModel) treeModel).getEarliestNodeInElement(aCase); int infectionNodeNo = originalNode.getNumber(); if (!treeModel.isRoot(originalNode)) { NodeRef originalParent = treeModel.getParent(originalNode); double nodeTime = getNodeTime(originalNode); double infectionTime = getInfectionTime(aCase); double heightToBreakBranch = getHeight(originalNode) + (nodeTime - infectionTime); FlexibleNode newNode = (FlexibleNode) outTree.getNode(infectionNodeNo); FlexibleNode oldParent = (FlexibleNode) outTree.getParent(newNode); outTree.beginTreeEdit(); outTree.removeChild(oldParent, newNode); FlexibleNode infectionNode = new FlexibleNode(); infectionNode.setHeight(heightToBreakBranch); infectionNode.setLength(oldParent.getHeight() - heightToBreakBranch); infectionNode.setAttribute(PARTITIONS_KEY, getNodePartition(treeModel, originalParent)); infectionNode.setAttribute("Time", heightToTime(heightToBreakBranch)); newNode.setLength(nodeTime - infectionTime); outTree.addChild(oldParent, infectionNode); outTree.addChild(infectionNode, newNode); outTree.endTreeEdit(); } else { double nodeTime = getNodeTime(originalNode); double infectionTime = getInfectionTime(aCase); double heightToInstallRoot = getHeight(originalNode) + (nodeTime - infectionTime); FlexibleNode newNode = (FlexibleNode) outTree.getNode(infectionNodeNo); outTree.beginTreeEdit(); FlexibleNode infectionNode = new FlexibleNode(); infectionNode.setHeight(heightToInstallRoot); infectionNode.setAttribute("Time", heightToTime(heightToInstallRoot)); infectionNode.setAttribute(PARTITIONS_KEY, "Origin"); outTree.addChild(infectionNode, newNode); newNode.setLength(heightToInstallRoot - getHeight(originalNode)); outTree.setRoot(infectionNode); outTree.endTreeEdit(); } } } outTree = new FlexibleTree((FlexibleNode) outTree.getRoot()); for (int i = 0; i < outTree.getNodeCount(); i++) { NodeRef node = outTree.getNode(i); NodeRef parent = outTree.getParent(node); if (parent != null && outTree.getNodeHeight(node) > outTree.getNodeHeight(parent)) { try { NexusExporter exporter = new NexusExporter(new PrintStream("fancyProblem.nex")); exporter.exportTree(outTree); } catch (IOException e) { e.printStackTrace(); } try { ((PartitionedTreeModel) treeModel).checkPartitions(); } catch (BadPartitionException e) { System.out.print("Rewiring messed up because of partition problem."); } throw new RuntimeException("Rewiring messed up; investigate"); } } return outTree; } //************************************************************************ // Loggable implementation //************************************************************************ public LogColumn[] getColumns() { LogColumn[] columns = new LogColumn[outbreak.infectedSize()]; int count = 0; for (int i = 0; i < outbreak.size(); i++) { final AbstractCase infected = outbreak.getCase(i); if (infected.wasEverInfected()) { columns[count] = new LogColumn.Abstract(infected.toString() + "_infector") { protected String getFormattedValue() { if (((PartitionedTreeModel) treeModel).getInfector(infected) == null) { return "Start"; } else { return ((PartitionedTreeModel) treeModel).getInfector(infected).toString(); } } }; count++; } } return columns; } public LogColumn[] passColumns() { ArrayList<LogColumn> columns = new ArrayList<LogColumn>(); for (int i = 0; i < outbreak.size(); i++) { final AbstractCase infected = outbreak.getCase(i); if (infected.wasEverInfected()) { columns.add(new LogColumn.Abstract(infected.toString() + "_infection_date") { protected String getFormattedValue() { return String.valueOf(getInfectionTime(infected)); } }); } } if (hasLatentPeriods) { for (int i = 0; i < outbreak.size(); i++) { final AbstractCase infected = outbreak.getCase(i); if (infected.wasEverInfected()) { columns.add(new LogColumn.Abstract(infected.toString() + "_infectious_date") { protected String getFormattedValue() { return String.valueOf(getInfectiousTime(infected)); } }); } } for (int i = 0; i < outbreak.size(); i++) { final AbstractCase infected = outbreak.getCase(i); if (infected.wasEverInfected()) { columns.add(new LogColumn.Abstract(infected.toString() + "_latent_period") { protected String getFormattedValue() { return String.valueOf(getLatentPeriod(infected)); } }); } } } for (int i = 0; i < outbreak.size(); i++) { final AbstractCase infected = outbreak.getCase(i); if (infected.wasEverInfected()) { columns.add(new LogColumn.Abstract(infected.toString() + "_infectious_period") { protected String getFormattedValue() { return String.valueOf(getInfectiousPeriod(infected)); } }); } } if (hasLatentPeriods) { for (int i = 0; i < outbreak.size(); i++) { final AbstractCase infected = outbreak.getCase(i); if (infected.wasEverInfected()) { columns.add(new LogColumn.Abstract(infected.toString() + "_infected_period") { protected String getFormattedValue() { return String.valueOf(getInfectiousPeriod(infected) + getLatentPeriod(infected)); } }); } } } columns.add(new LogColumn.Abstract("infectious_period.mean") { protected String getFormattedValue() { return String .valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(getNonzeroInfectiousPeriods())[0]); } }); columns.add(new LogColumn.Abstract("infectious_period.median") { protected String getFormattedValue() { return String .valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(getNonzeroInfectiousPeriods())[1]); } }); columns.add(new LogColumn.Abstract("infectious_period.var") { protected String getFormattedValue() { return String .valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(getNonzeroInfectiousPeriods())[2]); } }); columns.add(new LogColumn.Abstract("infectious_period.stdev") { protected String getFormattedValue() { return String .valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(getNonzeroInfectiousPeriods())[3]); } }); if (hasLatentPeriods) { columns.add(new LogColumn.Abstract("latent_period.mean") { protected String getFormattedValue() { return String .valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(getNonzeroLatentPeriods())[0]); } }); columns.add(new LogColumn.Abstract("latent_period.median") { protected String getFormattedValue() { return String .valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(getNonzeroLatentPeriods())[1]); } }); columns.add(new LogColumn.Abstract("latent_period.var") { protected String getFormattedValue() { return String .valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(getNonzeroLatentPeriods())[2]); } }); columns.add(new LogColumn.Abstract("latent_period.stdev") { protected String getFormattedValue() { return String .valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(getNonzeroLatentPeriods())[3]); } }); columns.add(new LogColumn.Abstract("infected_period.mean") { protected String getFormattedValue() { return String .valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(getNonzeroInfectedPeriods())[0]); } }); columns.add(new LogColumn.Abstract("infected_period.median") { protected String getFormattedValue() { return String .valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(getNonzeroInfectedPeriods())[1]); } }); columns.add(new LogColumn.Abstract("infected_period.var") { protected String getFormattedValue() { return String .valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(getNonzeroInfectedPeriods())[2]); } }); columns.add(new LogColumn.Abstract("infected_period.stdev") { protected String getFormattedValue() { return String .valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(getNonzeroInfectedPeriods())[3]); } }); for (int i = 0; i < outbreak.size(); i++) { final AbstractCase infected = outbreak.getCase(i); if (infected.wasEverInfected()) { columns.add(new LogColumn.Abstract(infected.toString() + "_ibp") { protected String getFormattedValue() { return String.valueOf(infected.getInfectionBranchPosition().getParameterValue(0)); } }); } } } return columns.toArray(new LogColumn[columns.size()]); } @Override public Citation.Category getCategory() { return Citation.Category.TREE_PRIORS; } @Override public String getDescription() { return "Case to Case Transmission Tree model"; } public List<Citation> getCitations() { return Arrays.asList(new Citation( new Author[] { new Author("M", "Hall"), new Author("M", "Woolhouse"), new Author("A", "Rambaut") }, "Epidemic Reconstruction in a Phylogenetics Framework: Transmission Trees as Partitions of the Node Set", 2016, "PLOS Comput Biol", 11, 0, 0, "10.1371/journal.pcbi.1004613", Citation.Status.PUBLISHED)); } // ************************************************************** // TreeTraitProvider IMPLEMENTATION // ************************************************************** public TreeTrait[] getTreeTraits() { return treeTraits.getTreeTraits(); } public TreeTrait getTreeTrait(String key) { return treeTraits.getTreeTrait(key); } public String getNodePartition(Tree tree, NodeRef node) { if (tree != treeModel) { // we're trying to annotate a partitioned tree, we hope try { NodeRef oldNode = treeModel.getNode((Integer) tree.getNodeAttribute(node, "Number")); if (treeModel.getNodeHeight(oldNode) != tree.getNodeHeight(node)) { throw new RuntimeException("Can only reconstruct states on treeModel given to constructor or a " + "partitioned tree derived from it"); } else { return getBranchMap().get(oldNode.getNumber()).toString(); } } catch (NullPointerException e) { if (tree.isRoot(node)) { return "Start"; } else { NodeRef parent = tree.getParent(node); int originalParentNumber = (Integer) tree.getNodeAttribute(parent, "Number"); return getBranchMap().get(originalParentNumber).toString(); } } } else { return getBranchMap().get(node.getNumber()).toString(); } } public Integer[] getParentsArray() { Integer[] out = new Integer[outbreak.size()]; for (AbstractCase thisCase : outbreak.getCases()) { if (thisCase.wasEverInfected()) { out[outbreak.getCaseIndex(thisCase)] = outbreak .getCaseIndex(((PartitionedTreeModel) treeModel).getInfector(thisCase)); } else { out[outbreak.getCaseIndex(thisCase)] = null; } } return out; } public AbstractCase getInfector(int i) { return ((PartitionedTreeModel) treeModel).getInfector(getOutbreak().getCase(i)); } }