Java tutorial
/////////////////////////////////////////////////////////////////////////////// // For information as to what this class does, see the Javadoc, below. // // Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // // 2007, 2008, 2009, 2010, 2014, 2015 by Peter Spirtes, Richard Scheines, Joseph // // Ramsey, and Clark Glymour. // // // // This program is free software; you can redistribute it and/or modify // // it under the terms of the GNU General Public License as published by // // the Free Software Foundation; either version 2 of the License, or // // (at your option) any later version. // // // // This program is distributed in the hope that it will be useful, // // but WITHOUT ANY WARRANTY; without even the implied warranty of // // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // // GNU General Public License for more details. // // // // You should have received a copy of the GNU General Public License // // along with this program; if not, write to the Free Software // // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // /////////////////////////////////////////////////////////////////////////////// package edu.cmu.tetrad.bayes; import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.data.DataUtils; import edu.cmu.tetrad.data.DiscreteVariable; import edu.cmu.tetrad.graph.Dag; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.Node; import edu.cmu.tetrad.graph.TimeLagGraph; import edu.cmu.tetrad.util.NumberFormatUtil; import edu.cmu.tetrad.util.RandomUtil; import edu.cmu.tetrad.data.BoxDataSet; import edu.cmu.tetrad.data.VerticalIntDataBox; import org.apache.commons.math3.distribution.ChiSquaredDistribution; import java.io.IOException; import java.io.ObjectInputStream; import java.text.NumberFormat; import java.util.*; import static java.lang.Math.abs; import static java.lang.Math.log; import static java.lang.Math.pow; /** * Stores a table of probabilities for a Bayes net and, together with 
BayesPm
 * and Dag, provides methods to manipulate this table. The division of labor is
 * as follows. The Dag is responsible for manipulating the basic graphical
 * structure of the Bayes net. Dag also stores and manipulates the names of the
 * nodes in the graph; there are no methods in either BayesPm or BayesIm to do
 * this. BayesPm stores and manipulates the *categories* of each node in a DAG,
 * considered as a variable in a Bayes net. The number of categories for a
 * variable can be changed there as well as the names for those categories. This
 * class, BayesIm, stores the actual probability tables which are implied by the
 * structures in the other two classes. The implied parameters take the form of
 * conditional probabilities--e.g., P(N=v0|P1=v1, P2=v2, ...), for all nodes and
 * all combinations of their parent categories. The set of all such
 * probabilities is organized in this class as a three-dimensional table of
 * double values. The first dimension corresponds to the nodes in the Bayes
 * net. For each such node, the second dimension corresponds to a flat list of
 * combinations of parent categories for that node. The third dimension
 * corresponds to the list of categories for that node itself. Two methods
 * allow these values to be set and retrieved: <ul> <li>getProbability(int
 * nodeIndex, int rowIndex, int colIndex); and, <li>setProbability(int
 * nodeIndex, int rowIndex, int colIndex, double value). </ul> To determine
 * the index of the node in question, use the method <ul> <li> getNodeIndex(Node
 * node). </ul> To determine the index of the row in question, use the method
 * <ul> <li>getRowIndex(int nodeIndex, int[] parentVals). </ul> To determine the
 * order of the parent values for a given node so that you can build the
 * parentVals[] array, use the method <ul> <li> getParents(int nodeIndex) </ul>
 * To determine the index of a category, use the method <ul> <li>
 * getCategoryIndex(Node node) </ul> in BayesPm.
The rest of the methods in this class are easily
 * understood as variants of the methods above.
 *
 * <p>Thanks to Pucktada Treeratpituk, Frank Wimberly, and Willie Wheeler for
 * advice and earlier versions.</p>
 *
 * @author Joseph Ramsey jdramsey@andrew.cmu.edu
 */
public final class MlBayesIm implements BayesIm {
    static final long serialVersionUID = 23L;

    private static final double ALLOWABLE_DIFFERENCE = 1.0e-10;

    /**
     * Indicates that new rows in this BayesIm should be initialized as
     * unknowns, forcing them to be specified manually. This is the default.
     */
    public static final int MANUAL = 0;

    /**
     * Indicates that new rows in this BayesIm should be initialized randomly.
     */
    public static final int RANDOM = 1;

    /**
     * The associated Bayes PM model.
     *
     * @serial
     */
    private BayesPm bayesPm;

    /**
     * The array of nodes from the graph. Order is important.
     *
     * @serial
     */
    private Node[] nodes;

    /**
     * The list of parents for each node from the graph. Order of nodes
     * corresponds to the order of nodes in 'nodes', and order in subarrays is
     * important.
     *
     * @serial
     */
    private int[][] parents;

    /**
     * The array of dimensionality (number of categories for each node) for each
     * of the subarrays of 'parents'.
     *
     * @serial
     */
    private int[][] parentDims;

    /**
     * The main data structure; stores the values of all of the conditional
     * probabilities for the Bayes net of the form P(N=v0 | P1=v1, P2=v2,...).
     * The first dimension is the node N, in the order of 'nodes'. The second
     * dimension is the row index for the table of parameters associated with
     * node N; the third dimension is the column index. The row index is
     * calculated by the function getRowIndex(int[] values) where 'values' is an
     * array of numerical indices for each of the parent values; the order of
     * the values in this array is the same as the order of node in 'parents';
     * the value indices are obtained from the Bayes PM for each node. The
     * column is the index of the value of N, where this index is obtained from
     * the Bayes PM.
* * @serial */ private double[][][] probs; //===============================CONSTRUCTORS=========================// /** * Constructs a new BayesIm from the given BayesPm, initializing all values * as Double.NaN ("?"). * * @param bayesPm the given Bayes PM. Carries with it the underlying graph * model. * @throws IllegalArgumentException if the array of nodes provided is not a * permutation of the nodes contained in * the bayes parametric model provided. */ public MlBayesIm(BayesPm bayesPm) throws IllegalArgumentException { this(bayesPm, null, MANUAL); } /** * Constructs a new BayesIm from the given BayesPm, initializing values * either as MANUAL or RANDOM. If initialized manually, all values will be * set to Double.NaN ("?") in each row; if initialized randomly, all values * will distributed randomly in each row. * * @param bayesPm the given Bayes PM. Carries with it the * underlying graph model. * @param initializationMethod either MANUAL or RANDOM. * @throws IllegalArgumentException if the array of nodes provided is not a * permutation of the nodes contained in * the bayes parametric model provided. */ public MlBayesIm(BayesPm bayesPm, int initializationMethod) throws IllegalArgumentException { this(bayesPm, null, initializationMethod); } /** * Constructs a new BayesIm from the given BayesPm, initializing values * either as MANUAL or RANDOM, but using values from the old BayesIm * provided where posssible. If initialized manually, all values that cannot * be retrieved from oldBayesIm will be set to Double.NaN ("?") in each such * row; if initialized randomly, all values that cannot be retrieved from * oldBayesIm will distributed randomly in each such row. * * @param bayesPm the given Bayes PM. Carries with it the * underlying graph model. * @param oldBayesIm an already-constructed BayesIm whose values * may be used where possible to initialize this * BayesIm. May be null. * @param initializationMethod either MANUAL or RANDOM. 
* @throws IllegalArgumentException if the array of nodes provided is not a * permutation of the nodes contained in * the bayes parametric model provided. */ public MlBayesIm(BayesPm bayesPm, BayesIm oldBayesIm, int initializationMethod) throws IllegalArgumentException { if (bayesPm == null) { throw new NullPointerException("BayesPm must not be null."); } this.bayesPm = new BayesPm(bayesPm); // Get the nodes from the BayesPm. This fixes the order of the nodes // in the BayesIm, independently of any change to the BayesPm. // (This order must be maintained.) Graph graph = bayesPm.getDag(); this.nodes = graph.getNodes().toArray(new Node[0]); // Initialize. initialize(oldBayesIm, initializationMethod); } /** * Copy constructor. */ public MlBayesIm(BayesIm bayesIm) throws IllegalArgumentException { if (bayesIm == null) { throw new NullPointerException("BayesIm must not be null."); } this.bayesPm = bayesIm.getBayesPm(); // Get the nodes from the BayesPm, fixing on an order. (This is // important; the nodes must always be in the same order for this // BayesIm.) this.nodes = new Node[bayesIm.getNumNodes()]; for (int i = 0; i < bayesIm.getNumNodes(); i++) { this.nodes[i] = bayesIm.getNode(i); } // Copy all the old values over. initialize(bayesIm, MlBayesIm.MANUAL); } /** * Generates a simple exemplar of this class to test serialization. * * @see edu.cmu.TestSerialization * @see edu.cmu.tetradapp.util.TetradSerializableUtils */ public static MlBayesIm serializableInstance() { return new MlBayesIm(BayesPm.serializableInstance()); } //===============================PUBLIC METHODS========================// /** * Returns the underlying Bayes PM. * * @return this PM. */ public BayesPm getBayesPm() { return bayesPm; } /** * Returns the DAG as a Graph. * * @return the DAG. */ public Dag getDag() { return bayesPm.getDag(); } /** * Returns the number of nodes in the model. 
*/ public int getNumNodes() { return nodes.length; } /** * Returns the node corresponding to the given node index. * * @param nodeIndex * @return this node. */ public Node getNode(int nodeIndex) { return nodes[nodeIndex]; } /** * Returns the node with the given name in the associated graph. * * @param name the name of the node. * @return the node. */ public Node getNode(String name) { return getDag().getNode(name); } /** * Returns the node index for the given node. * * @param node the given node. * @return the index for that node, or -1 if the node is not in the * BayesIm. */ public int getNodeIndex(Node node) { for (int i = 0; i < nodes.length; i++) { if (node == nodes[i]) { return i; } } return -1; } public List<Node> getVariables() { List<Node> variables = new LinkedList<Node>(); for (int i = 0; i < getNumNodes(); i++) { Node node = getNode(i); variables.add(bayesPm.getVariable(node)); } return variables; } /** * Returns the list of measured variableNodes. */ public List<Node> getMeasuredNodes() { return bayesPm.getMeasuredNodes(); } public List<String> getVariableNames() { List<String> variableNames = new LinkedList<String>(); for (int i = 0; i < getNumNodes(); i++) { Node node = getNode(i); variableNames.add(bayesPm.getVariable(node).getName()); } return variableNames; } /** * Returns the number of columns in the table of the given node N with index * 'nodeIndex'--that is, the number of possible values that N can take on. * That is, if P(N=v0 | P1=v1, P2=v2, ... Pn=vn) is a conditional * probability stored in 'probs', then the maximum number of rows in the * table for N is #vals(N). * * @param nodeIndex * @return this number. * @see #getNumRows */ public int getNumColumns(int nodeIndex) { return probs[nodeIndex][0].length; } /** * Returns the number of rows in the table of the given node, which would be * the total number of possible combinations of parent values for a given * node. That is, if P(N=v0 | P1=v1, P2=v2, ... 
Pn=vn) is a conditional * probability stored in 'probs', then the maximum number of rows in the * table for N is #vals(P1) x #vals(P2) x ... x #vals(Pn). * * @param nodeIndex * @return this number. * @see #getRowIndex * @see #getNumColumns */ public int getNumRows(int nodeIndex) { return probs[nodeIndex].length; } /** * Returns the number of parents of the given node. * * @param nodeIndex the given node. * @return the number of parents for this node. */ public int getNumParents(int nodeIndex) { return parents[nodeIndex].length; } /** * Returns the given parent of the given node. */ public int getParent(int nodeIndex, int parentIndex) { return parents[nodeIndex][parentIndex]; } /** * Returns the dimension of the given parent for the given node. */ public int getParentDim(int nodeIndex, int parentIndex) { return parentDims[nodeIndex][parentIndex]; } /** * Returns (a defensive copy of) the array representing the dimensionality * of each parent of a node, that is, the number of values which that node * can take on. The order of entries in this array is the same as the order * of entries of nodes returned by getParents() for that node. * * @return this array of parent dimensions. * @see #getParents */ public int[] getParentDims(int nodeIndex) { int[] dims = parentDims[nodeIndex]; int[] copy = new int[dims.length]; System.arraycopy(dims, 0, copy, 0, dims.length); return copy; } /** * Returns (a defensive copy of) the array containing all of the parents of * a given node in the order in which they are stored internally. * * @see #getParentDims */ public int[] getParents(int nodeIndex) { int[] nodeParents = parents[nodeIndex]; int[] copy = new int[nodeParents.length]; System.arraycopy(nodeParents, 0, copy, 0, nodeParents.length); return copy; } /** * Returns an array containing the combination of parent values for a given * node and given row in the probability table for that node. 
To get the * combination of parent values from the row number, the row number is * represented using a variable-base place value system, where the bases for * each place value are the dimensions of the parents in the order in which * they are given by getParentDims(). For instance, if the row number (base * 10) is 103 and the parent dimension array is [3 5 7], we calculate the * first value as 103 / 7 = 14 with a remainder of 5. We then divide 14 / 5 * = 2 with a remainder of 4. We then divide 2 / 3 = 0 with a remainder of * 2. The variable place value representation is [2 4 5], which is the * combination of parent values. This is the inverse function of * getRowIndex(). * * @param nodeIndex the index of the node. * @param rowIndex the index of the row in question. * @return the array representing the combination of parent values for this * row. * @see #getNodeIndex * @see #getRowIndex */ public int[] getParentValues(int nodeIndex, int rowIndex) { int[] dims = getParentDims(nodeIndex); int[] values = new int[dims.length]; for (int i = dims.length - 1; i >= 0; i--) { values[i] = rowIndex % dims[i]; rowIndex /= dims[i]; } return values; } /** * Returns the value in the probability table for the given node, at the * given row and column. */ public int getParentValue(int nodeIndex, int rowIndex, int colIndex) { return getParentValues(nodeIndex, rowIndex)[colIndex]; } /** * Returns the probability for the given node at the given row and column in * the table for that node. To get the node index, use getNodeIndex(). To * get the row index, use getRowIndex(). To get the column index, use * getCategoryIndex() from the underlying BayesPm(). The value returned * will represent a conditional probability of the form P(N=v0 | P1=v1, * P2=v2, ... , Pn=vn), where N is the node referenced by nodeIndex, v0 is * the value referenced by colIndex, and the combination of parent values * indicated is the combination indicated by rowIndex. 
* * @param nodeIndex the index of the node in question. * @param rowIndex the row in the table for this for node which represents * the combination of parent values in question. * @param colIndex the column in the table for this node which represents * the value of the node in question. * @return the probability stored for this parameter. * @see #getNodeIndex * @see #getRowIndex */ public double getProbability(int nodeIndex, int rowIndex, int colIndex) { return probs[nodeIndex][rowIndex][colIndex]; } /** * Returns the row in the table at which the given combination of parent * values is represented for the given node. The row is calculated as a * variable-base place-value number. For instance, if the array of parent * dimensions is [3, 5, 7] and the parent value combination is [2, 4, 5], * then the row number is (7 * (5 * (3 * 0 + 2) + 4)) + 5 = 103. This is the * inverse function to getVariableValues(). <p> Note: If the node has n * values, the length of 'values' must be >= the number of parents. Only the * first n values are used. * * @param nodeIndex * @param values * @return the row in the table for the given node and combination of parent * values. * @see #getParentValues */ public int getRowIndex(int nodeIndex, int[] values) { int[] dim = getParentDims(nodeIndex); int rowIndex = 0; for (int i = 0; i < dim.length; i++) { rowIndex *= dim[i]; rowIndex += values[i]; } return rowIndex; } /** * Normalizes all rows in the tables associated with each of node in turn. */ public void normalizeAll() { for (int nodeIndex = 0; nodeIndex < nodes.length; nodeIndex++) { normalizeNode(nodeIndex); } } /** * Normalizes all rows in the table associated with a given node. */ public void normalizeNode(int nodeIndex) { for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { normalizeRow(nodeIndex, rowIndex); } } /** * Normalizes the given row. 
*/
    public void normalizeRow(int nodeIndex, final int rowIndex) {
        final int width = getNumColumns(nodeIndex);

        double rowSum = 0.0;
        for (int col = 0; col < width; col++) {
            rowSum += getProbability(nodeIndex, rowIndex, col);
        }

        if (rowSum == 0.0) {
            // Nothing to scale: spread the mass evenly over the columns.
            final double uniform = 1.0 / width;

            for (int col = 0; col < width; col++) {
                setProbability(nodeIndex, rowIndex, col, uniform);
            }

            return;
        }

        // Scale each entry by the row total (a NaN total ends up here too).
        for (int col = 0; col < width; col++) {
            final double scaled = getProbability(nodeIndex, rowIndex, col) / rowSum;
            setProbability(nodeIndex, rowIndex, col, scaled);
        }
    }

    /**
     * Sets the probability for the given node at a given row and column in the
     * table for that node. To get the node index, use getNodeIndex(). To get
     * the row index, use getRowIndex(). To get the column index, use
     * getCategoryIndex() from the underlying BayesPm(). The value stored
     * represents a conditional probability of the form P(N=v0 | P1=v1,
     * P2=v2, ... , Pn=vn), where N is the node referenced by nodeIndex, v0 is
     * the value referenced by colIndex, and the combination of parent values
     * indicated is the combination indicated by rowIndex.
     *
     * @param nodeIndex the index of the node in question.
     * @param rowIndex  the row in the table for this for node which represents
     *                  the combination of parent values in question.
     * @param colIndex  the column in the table for this node which represents
     *                  the value of the node in question.
     * @param value     the desired probability to be set.
* @see #getProbability */ public void setProbability(int nodeIndex, int rowIndex, int colIndex, double value) { if (colIndex >= getNumColumns(nodeIndex)) { throw new IllegalArgumentException( "Column out of range: " + colIndex + " >= " + getNumColumns(nodeIndex)); } if (!(0.0 <= value && value <= 1.0) && !Double.isNaN(value)) { throw new IllegalArgumentException("Probability value must be " + "between 0.0 and 1.0 or Double.NaN."); } probs[nodeIndex][rowIndex][colIndex] = value; } /** * Returns the index of the node with the given name in the specified * BayesIm. */ public int getCorrespondingNodeIndex(int nodeIndex, BayesIm otherBayesIm) { String nodeName = getNode(nodeIndex).getName(); Node oldNode = otherBayesIm.getNode(nodeName); return otherBayesIm.getNodeIndex(oldNode); } /** * Assigns random probability values to the child values of this row that * add to 1. * * @param nodeIndex the node for the table that this row belongs to. * @param rowIndex the index of the row. */ public void clearRow(int nodeIndex, int rowIndex) { for (int colIndex = 0; colIndex < getNumColumns(nodeIndex); colIndex++) { setProbability(nodeIndex, rowIndex, colIndex, Double.NaN); } } /** * Assigns random probability values to the child values of this row that * add to 1. * * @param nodeIndex the node for the table that this row belongs to. * @param rowIndex the index of the row. 
*/ public void randomizeRow(int nodeIndex, int rowIndex) { final int size = getNumColumns(nodeIndex); probs[nodeIndex][rowIndex] = getRandomWeights(size); } public void randomizeRow2(int nodeIndex, int rowIndex, double[] biases) { final int size = getNumColumns(nodeIndex); probs[nodeIndex][rowIndex] = getRandomWeights2(size, biases); } private static double[] getRandomWeights2(int size, double[] biases) { assert size >= 0; double[] row = new double[size]; double sum = 0.0; for (int i = 0; i < size; i++) { // row[i] = RandomUtil.getInstance().nextDouble() + biases[i]; row[i] = RandomUtil.getInstance().nextUniform(0, biases[i]); sum += row[i]; } for (int i = 0; i < size; i++) { row[i] /= sum; } return row; } /** * Randomizes any row in the table for the given node index that has a * Double.NaN value in it. * * @param nodeIndex the node for the table whose incomplete rows are to be * randomized. */ public void randomizeIncompleteRows(int nodeIndex) { for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { if (isIncomplete(nodeIndex, rowIndex)) { randomizeRow(nodeIndex, rowIndex); } } } /** * Randomizes every row in the table for the given node index. * * @param nodeIndex the node for the table to be randomized. 
*/ public void randomizeTable(int nodeIndex) { // for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { // randomizeRow(nodeIndex, rowIndex); // } randomizeTable4(nodeIndex); } private void randomizeTable2(int nodeIndex) { boolean existsIncomplete = true; for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { if (isIncomplete(nodeIndex, rowIndex)) { existsIncomplete = true; break; } } if (!existsIncomplete) return; // Trying for some more power ..jdramsey 5/7/10 List<Integer> rowIndices = new ArrayList<Integer>(); for (int i = 0; i < getNumRows(nodeIndex); i++) { rowIndices.add(i); } Collections.shuffle(rowIndices); randomizeRow(nodeIndex, rowIndices.get(0)); double[][] values = new double[getNumRows(nodeIndex)][getNumColumns(nodeIndex)]; for (int row = 0; row < getNumRows(nodeIndex); row++) { double bestNorm = 0.0; for (int trial = 0; trial < 100; trial++) { randomizeRow(nodeIndex, rowIndices.get(row)); double totalNorm = 0.0; for (int _row = row - 1; _row < row; _row++) { double norm = norm(nodeIndex, rowIndices.get(row), rowIndices.get(_row)); totalNorm += norm; } if (totalNorm > bestNorm) { bestNorm = totalNorm; for (int _row = 0; _row < getNumRows(nodeIndex); _row++) { for (int col = 0; col < getNumColumns(nodeIndex); col++) { values[_row][col] = getProbability(nodeIndex, _row, col); } } } } for (int _row = 0; _row < getNumRows(nodeIndex); _row++) { for (int col = 0; col < getNumColumns(nodeIndex); col++) { setProbability(nodeIndex, _row, col, values[_row][col]); } } } } private void randomizeTable3(int nodeIndex) { for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { randomizeRow(nodeIndex, rowIndex); } double[][] saved = new double[getNumRows(nodeIndex)][getNumColumns(nodeIndex)]; double[][] stored = probs[nodeIndex]; copy(stored, saved); double maxSumSpread = 0.0; for (int i = 0; i < 100; i++) { for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { randomizeRow(nodeIndex, rowIndex); } 
double sumSpread = 0.0; for (int c = 0; c < getNumColumns(nodeIndex); c++) { double min = 1.0, max = 0.0; for (int r = 0; r < getNumRows(nodeIndex); r++) { double p = getProbability(nodeIndex, r, c); if (p <= min) min = p; if (p >= max) max = p; } sumSpread += abs(max - min); } if (sumSpread > maxSumSpread) { copy(stored, saved); maxSumSpread = sumSpread; } } for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { copy(saved, stored); } } private void randomizeTable4(int nodeIndex) { for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { randomizeRow(nodeIndex, rowIndex); } double[][] saved = new double[getNumRows(nodeIndex)][getNumColumns(nodeIndex)]; double max = Double.NEGATIVE_INFINITY; for (int i = 0; i < 1000; i++) { for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { // randomizeRow(nodeIndex, rowIndex); randomizeRow2(nodeIndex, rowIndex, probs[nodeIndex][rowIndex]); } double score = score(nodeIndex); if (score > max) { max = score; copy(probs[nodeIndex], saved); } // if (score == getNumParents(nodeIndex)) { // break; // } } for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { copy(saved, probs[nodeIndex]); } } private double score(int nodeIndex) { double[][] p = new double[getNumRows(nodeIndex)][getNumColumns(nodeIndex)]; copy(probs[nodeIndex], p); double score = 0.0; int num = 0; int numRows = getNumRows(nodeIndex); for (int r = 0; r < p.length; r++) { for (int c = 0; c < p[0].length; c++) { p[r][c] /= numRows; } } int[] parents = getParents(nodeIndex); for (int t = 0; t < parents.length; t++) { int numParentValues = getParentDim(nodeIndex, t); int numColumns = getNumColumns(nodeIndex); double[][] table = new double[numParentValues][numColumns]; for (int childCol = 0; childCol < numColumns; childCol++) { for (int parentValue = 0; parentValue < numParentValues; parentValue++) { for (int row = 0; row < numRows; row++) { if (getParentValues(nodeIndex, row)[t] == parentValue) { 
table[parentValue][childCol] += p[row][childCol]; } } } } double N = 1000.0; for (int r = 0; r < table.length; r++) { for (int c = 0; c < table[0].length; c++) { table[r][c] *= N; } } double chisq = 0.0; for (int r = 0; r < table.length; r++) { for (int c = 0; c < table[0].length; c++) { double _sumRow = sumRow(table, r); double _sumCol = sumCol(table, c); double exp = (_sumRow / N) * (_sumCol / N) * N; double obs = table[r][c]; chisq += pow(obs - exp, 2) / exp; } } int dof = (table.length - 1) * (table[0].length - 1); ChiSquaredDistribution distribution = new ChiSquaredDistribution(dof); double prob = 1 - distribution.cumulativeProbability(chisq); num += prob < 0.0001 ? 1 : 0; score += log(prob); } // return num == parents.length ? -score : 0; return num; } private double sumCol(double[][] marginals, int j) { double sum = 0.0; for (int h = 0; h < marginals.length; h++) { sum += marginals[h][j]; } return sum; } private double sumRow(double[][] marginals, int i) { double sum = 0.0; for (int h = 0; h < marginals[i].length; h++) { sum += marginals[i][h]; } return sum; } private void copy(double[][] a, double[][] b) { for (int r = 0; r < a.length; r++) { System.arraycopy(a[r], 0, b[r], 0, a[r].length); } } private double totalNorm(int nodeIndex, int parent, int cat1, int cat2) { double[] sumProbs1 = new double[getNumColumns(nodeIndex)]; double[] sumProbs2 = new double[getNumColumns(nodeIndex)]; for (int row = 0; row < getNumRows(nodeIndex); row++) { for (int col = 0; col < getNumColumns(nodeIndex); col++) { if (getParentValues(nodeIndex, row)[parent] == cat1) { sumProbs1[col] += getProbability(nodeIndex, row, col); } } } for (int row = 0; row < getNumRows(nodeIndex); row++) { for (int col = 0; col < getNumColumns(nodeIndex); col++) { if (getParentValues(nodeIndex, row)[parent] == cat2) { sumProbs2[col] += getProbability(nodeIndex, row, col); } } } double norm = 0.0; for (int col = 0; col < getNumColumns(nodeIndex); col++) { double value1 = sumProbs1[col]; double value2 
= sumProbs2[col]; double diff = value1 - value2; double absNorm = abs(diff); norm += absNorm; } return norm; } private double norm(int nodeIndex, int row1, int row2) { double norm = 0.0; for (int col = 0; col < getNumColumns(nodeIndex); col++) { double value1 = getProbability(nodeIndex, row1, col); double value2 = getProbability(nodeIndex, row2, col); double diff = value1 - value2; double absNorm = abs(diff); // norm += diff * diff; norm += absNorm; } return norm; } /** * Randomizes every row in the table for the given node index. * * @param nodeIndex the node for the table to be randomized. */ public void clearTable(int nodeIndex) { for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { clearRow(nodeIndex, rowIndex); } } /** * Returns true iff one of the values in the given row is Double.NaN. */ public boolean isIncomplete(int nodeIndex, int rowIndex) { for (int colIndex = 0; colIndex < getNumColumns(nodeIndex); colIndex++) { double p = getProbability(nodeIndex, rowIndex, colIndex); if (Double.isNaN(p)) { return true; } } return false; } /** * Returns true iff any value in the table for the given node is * Double.NaN. */ public boolean isIncomplete(int nodeIndex) { for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { if (isIncomplete(nodeIndex, rowIndex)) { return true; } } return false; } /** * Simulates a sample with the given sample size. * * @param sampleSize the sample size. * @param latentDataSaved * @return the simulated sample as a DataSet. 
*/
    public DataSet simulateData(int sampleSize, boolean latentDataSaved) {
        // Time-lag models take a separate path, which ignores latentDataSaved.
        if (getBayesPm().getDag().isTimeLagModel()) {
            return simulateTimeSeries(sampleSize);
        }

        return simulateDataHelper(sampleSize, latentDataSaved);
    }

    /**
     * Simulates a sample into the given, already allocated data set.
     *
     * @param dataSet         the data set to overwrite.
     * @param latentDataSaved passed on to the helper that builds the sample.
     * @return the simulated sample as a DataSet.
     */
    public DataSet simulateData(DataSet dataSet, boolean latentDataSaved) {
        return simulateDataHelper(dataSet, latentDataSaved);
    }

    /**
     * Simulates a sample for a time-lag model. One discrete variable is created
     * per lag-0 node; for every sample row the lag-0 nodes are visited in a
     * causal order of the contemporaneous subgraph, a value is drawn for each
     * from its table row, and the value is written into that node's column.
     *
     * <p>The values drawn used to be discarded, so the returned data set was
     * never filled in; they are now stored. The parent-value buffer is sized by
     * the total node count, because a lag-0 node may have more parents than
     * there are lag-0 nodes.
     *
     * <p>NOTE(review): only lag-0 nodes are ever assigned a value within a
     * row, so parents that are not lag-0 nodes are always read as value 0.
     * Confirm whether lagged parents should take their values from earlier
     * rows.
     *
     * @param sampleSize the number of rows to simulate.
     * @return the simulated sample over the lag-0 variables.
     * @throws IllegalStateException if a probability needed for a draw is NaN,
     *                               or a node of the contemporaneous ordering
     *                               is not among the lag-0 nodes.
     */
    private DataSet simulateTimeSeries(int sampleSize) {
        TimeLagGraph timeSeriesGraph = getBayesPm().getDag().getTimeLagGraph();
        List<Node> lag0Nodes = timeSeriesGraph.getLag0Nodes();

        List<Node> variables = new ArrayList<Node>();

        for (Node node : lag0Nodes) {
            final DiscreteVariable e = new DiscreteVariable(timeSeriesGraph.getNodeId(node).getName());
            e.setNodeType(node.getNodeType());
            variables.add(e);
        }

        DataSet fullData = new BoxDataSet(new VerticalIntDataBox(sampleSize, variables.size()), variables);

        // Column of each lag-0 node in fullData.
        Map<Node, Integer> nodeIndices = new HashMap<Node, Integer>();

        for (int i = 0; i < lag0Nodes.size(); i++) {
            nodeIndices.put(lag0Nodes.get(i), i);
        }

        Graph contemporaneousDag = timeSeriesGraph.subgraph(lag0Nodes);
        List<Node> tierOrdering = contemporaneousDag.getCausalOrdering();
        int[] tiers = new int[tierOrdering.size()];
        int[] columns = new int[tierOrdering.size()];

        for (int i = 0; i < tierOrdering.size(); i++) {
            Node node = tierOrdering.get(i);
            tiers[i] = getNodeIndex(node);

            Integer column = nodeIndices.get(node);

            if (column == null) {
                throw new IllegalStateException("Node " + node
                        + " of the contemporaneous ordering is not a lag-0 node.");
            }

            columns[i] = column;
        }

        // Construct the sample. getRowIndex() reads only as many entries as
        // the node has parents, so one buffer of full size serves every node.
        int[] combination = new int[nodes.length];

        for (int i = 0; i < sampleSize; i++) {
            int[] point = new int[nodes.length];

            for (int t = 0; t < tiers.length; t++) {
                final int nodeIndex = tiers[t];
                double cutoff = RandomUtil.getInstance().nextDouble();

                for (int k = 0; k < getNumParents(nodeIndex); k++) {
                    combination[k] = point[getParent(nodeIndex, k)];
                }

                int rowIndex = getRowIndex(nodeIndex, combination);
                double sum = 0.0;

                for (int k = 0; k < getNumColumns(nodeIndex); k++) {
                    double probability = getProbability(nodeIndex, rowIndex, k);

                    if (Double.isNaN(probability)) {
                        throw new IllegalStateException("Some probability "
                                + "values in the BayesIm are not filled in; "
                                + "cannot simulate data.");
                    }

                    sum += probability;

                    if (sum >= cutoff) {
                        point[nodeIndex] = k;
                        break;
                    }
                }

                // Keep the draw; it used to be thrown away.
                fullData.setInt(i, columns[t], point[nodeIndex]);
            }
        }

        return fullData;
    }

    /**
     * Simulates a sample with the given sample size after seeding the shared
     * random number generator.
     *
     * <p>NOTE(review): unlike simulateData(int, boolean), this overload does
     * not branch to the time-series simulation for time-lag models; confirm
     * whether that is intended.
     *
     * @param sampleSize      the sample size.
     * @param seed            the random number generator seed allows you
     *                        recreate the simulated data by passing in the same
     *                        seed (so you don't have to store the sample data
     * @param latentDataSaved passed on to the helper that builds the sample.
     * @return the simulated sample as a DataSet.
     */
    public DataSet simulateData(int sampleSize, long seed, boolean latentDataSaved) {
        RandomUtil.getInstance().setSeed(seed);
        return simulateDataHelper(sampleSize, latentDataSaved);
    }

    /**
     * Simulates a sample into the given, already allocated data set after
     * seeding the shared random number generator.
     *
     * @param dataSet         the data set to overwrite.
     * @param seed            the random number generator seed.
     * @param latentDataSaved passed on to the helper that builds the sample.
     * @return the simulated sample as a DataSet.
     */
    public DataSet simulateData(DataSet dataSet, long seed, boolean latentDataSaved) {
        RandomUtil.getInstance().setSeed(seed);
        return simulateDataHelper(dataSet, latentDataSaved);
    }

    /**
     * Simulates a sample with the given sample size.
     *
     * @param sampleSize      the sample size.
     * @param latentDataSaved
     * @return the simulated sample as a DataSet.
*/ private DataSet simulateDataHelper(int sampleSize, boolean latentDataSaved) { int numMeasured = 0; int[] map = new int[nodes.length]; List<Node> variables = new LinkedList<Node>(); for (int j = 0; j < nodes.length; j++) { // if (!latentDataSaved && nodes[j].getNodeType() != NodeType.MEASURED) { // continue; // } int numCategories = bayesPm.getNumCategories(nodes[j]); List<String> categories = new LinkedList<String>(); for (int k = 0; k < numCategories; k++) { categories.add(bayesPm.getCategory(nodes[j], k)); } DiscreteVariable var = new DiscreteVariable(nodes[j].getName(), categories); var.setNodeType(nodes[j].getNodeType()); variables.add(var); int index = ++numMeasured - 1; map[index] = j; } DataSet dataSet = new BoxDataSet(new VerticalIntDataBox(sampleSize, variables.size()), variables); constructSample(sampleSize, dataSet, map); if (!latentDataSaved) { dataSet = DataUtils.restrictToMeasured(dataSet); } return dataSet; } /** * Constructs a random sample using the given already allocated data set, to * avoid allocating more memory. 
*/ private DataSet simulateDataHelper(DataSet dataSet, boolean latentDataSaved) { if (dataSet.getNumColumns() != nodes.length) { throw new IllegalArgumentException("When rewriting the old data set, " + "number of variables in data set must equal number of variables " + "in Bayes net."); } int sampleSize = dataSet.getNumRows(); int numVars = 0; int[] map = new int[nodes.length]; List<Node> variables = new LinkedList<Node>(); for (int j = 0; j < nodes.length; j++) { // if (!latentDataSaved && nodes[j].getNodeType() != NodeType.MEASURED) { // continue; // } int numCategories = bayesPm.getNumCategories(nodes[j]); List<String> categories = new LinkedList<String>(); for (int k = 0; k < numCategories; k++) { categories.add(bayesPm.getCategory(nodes[j], k)); } DiscreteVariable var = new DiscreteVariable(nodes[j].getName(), categories); var.setNodeType(nodes[j].getNodeType()); variables.add(var); int index = ++numVars - 1; map[index] = j; } for (int i = 0; i < variables.size(); i++) { Node node = dataSet.getVariable(i); Node _node = variables.get(i); dataSet.changeVariable(node, _node); } constructSample(sampleSize, dataSet, map); if (latentDataSaved) { return dataSet; } else { return DataUtils.restrictToMeasured(dataSet); } } private void constructSample(int sampleSize, DataSet dataSet, int[] map) { // Get a tier ordering and convert it to an int array. Graph graph = getBayesPm().getDag(); Dag dag = (Dag) graph; List<Node> tierOrdering = dag.getCausalOrdering(); int[] tiers = new int[tierOrdering.size()]; for (int i = 0; i < tierOrdering.size(); i++) { tiers[i] = getNodeIndex(tierOrdering.get(i)); } long t1 = System.currentTimeMillis(); // Construct the sample. 
for (int i = 0; i < sampleSize; i++) { for (int t : tiers) { int[] parentValues = new int[parents[t].length]; for (int k = 0; k < parentValues.length; k++) { parentValues[k] = dataSet.getInt(i, parents[t][k]); } int rowIndex = getRowIndex(t, parentValues); double sum = 0.0; double r = RandomUtil.getInstance().nextDouble(); for (int k = 0; k < getNumColumns(t); k++) { double probability = getProbability(t, rowIndex, k); sum += probability; if (sum >= r) { dataSet.setInt(i, map[t], k); break; } } } } long t2 = System.currentTimeMillis(); System.out.println("Elapsed Sim = " + (t2 - t1) + "ms"); } public boolean equals(Object o) { if (o == this) { return true; } if (!(o instanceof BayesIm)) { return false; } BayesIm otherIm = (BayesIm) o; if (getNumNodes() != otherIm.getNumNodes()) { return false; } for (int i = 0; i < getNumNodes(); i++) { int otherIndex = otherIm.getCorrespondingNodeIndex(i, otherIm); if (otherIndex == -1) { return false; } if (getNumColumns(i) != otherIm.getNumColumns(otherIndex)) { return false; } if (getNumRows(i) != otherIm.getNumRows(otherIndex)) { return false; } for (int j = 0; j < getNumRows(i); j++) { for (int k = 0; k < getNumColumns(i); k++) { double prob = getProbability(i, j, k); double otherProb = otherIm.getProbability(i, j, k); if (Double.isNaN(prob) && Double.isNaN(otherProb)) { continue; } if (abs(prob - otherProb) > ALLOWABLE_DIFFERENCE) { return false; } } } } return true; } /** * Prints out the probability table for each variable. 
*/ public String toString() { StringBuilder buf = new StringBuilder(); NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat(); for (int i = 0; i < getNumNodes(); i++) { buf.append("\n\nNode: ").append(getNode(i)); if (getNumParents(i) == 0) { buf.append("\n"); } else { buf.append("\n\n"); for (int k = 0; k < getNumParents(i); k++) { buf.append(getNode(getParent(i, k))).append("\t"); } } for (int j = 0; j < getNumRows(i); j++) { buf.append("\n"); for (int k = 0; k < getNumParents(i); k++) { buf.append(getParentValue(i, j, k)); if (k < getNumParents(i) - 1) { buf.append("\t"); } } if (getNumParents(i) > 0) { buf.append("\t"); } for (int k = 0; k < getNumColumns(i); k++) { buf.append(nf.format(getProbability(i, j, k))).append("\t"); } } } buf.append("\n"); return buf.toString(); } //=============================PRIVATE METHODS=======================// /** * This method initializes the probability tables for all of the nodes in * the Bayes net. * * @see #initializeNode * @see #randomizeRow */ private void initialize(BayesIm oldBayesIm, int initializationMethod) { parents = new int[this.nodes.length][]; parentDims = new int[this.nodes.length][]; probs = new double[this.nodes.length][][]; for (int nodeIndex = 0; nodeIndex < this.nodes.length; nodeIndex++) { initializeNode(nodeIndex, oldBayesIm, initializationMethod); } } /** * This method initializes the node indicated. */ private void initializeNode(int nodeIndex, BayesIm oldBayesIm, int initializationMethod) { Node node = nodes[nodeIndex]; // Set up parents array. Should store the parents of // each node as ints in a particular order. Graph graph = getBayesPm().getDag(); List<Node> parentList = graph.getParents(node); int[] parentArray = new int[parentList.size()]; for (int i = 0; i < parentList.size(); i++) { parentArray[i] = getNodeIndex(parentList.get(i)); } // Sort parent array. Arrays.sort(parentArray); parents[nodeIndex] = parentArray; // Setup dimensions array for parents. 
int[] dims = new int[parentArray.length]; for (int i = 0; i < dims.length; i++) { Node parNode = nodes[parentArray[i]]; dims[i] = getBayesPm().getNumCategories(parNode); } // Calculate dimensions of table. int numRows = 1; for (int dim : dims) { if (numRows > 1000000 /* Integer.MAX_VALUE / dim*/) { throw new IllegalArgumentException( "The number of rows in the " + "conditional probability table for " + nodes[nodeIndex] + " is greater than 1,000,000 and cannot be " + "represented."); } numRows *= dim; } int numCols = getBayesPm().getNumCategories(node); parentDims[nodeIndex] = dims; probs[nodeIndex] = new double[numRows][numCols]; // Initialize each row. if (initializationMethod == RANDOM) { randomizeTable(nodeIndex); } else { for (int rowIndex = 0; rowIndex < numRows; rowIndex++) { if (oldBayesIm == null) { overwriteRow(nodeIndex, rowIndex, initializationMethod); } else { retainOldRowIfPossible(nodeIndex, rowIndex, oldBayesIm, initializationMethod); } } } } private void overwriteRow(int nodeIndex, int rowIndex, int initializationMethod) { if (initializationMethod == RANDOM) { randomizeRow(nodeIndex, rowIndex); } else if (initializationMethod == MANUAL) { initializeRowAsUnknowns(nodeIndex, rowIndex); } else { throw new IllegalArgumentException("Unrecognized state."); } } /** * This method chooses random probabilities for a row which add up to 1.0. * Random doubles are drawn from a random distribution, and the final row is * then normalized. * * @param size the length of the row. * @return an array with randomly distributed probabilities of this length. * @see #randomizeRow */ private static double[] getRandomWeights(int size) { assert size >= 0; double[] row = new double[size]; double sum = 0.0; // Renders rows more deterministic. 
double bias = 0; for (int i = 0; i < size; i++) { row[i] = RandomUtil.getInstance().nextDouble(); if (row[i] > 0.5) { row[i] += bias; } sum += row[i]; } for (int i = 0; i < size; i++) { row[i] /= sum; } return row; } private static double[] getRandomWeights2(int size) { assert size >= 0; double[] row = new double[size]; double sum = 0.0; // Renders rows more deterministic. double bias = 2; int index = -1; double max = 0.0; for (int i = 0; i < size; i++) { row[i] = RandomUtil.getInstance().nextDouble(); if (row[i] > max) { max = row[i]; index = i; } } row[index] += bias; for (int i = 0; i < size; i++) { sum += row[i]; } for (int i = 0; i < size; i++) { row[i] /= sum; } return row; } private static double[] getRandomWeights3(int size) { assert size >= 0; double[] row = new double[size]; double sum = 0.0; // Renders rows more deterministic. double bias = 0; for (int i = 0; i < size; i++) { row[i] = RandomUtil.getInstance().nextBeta(2, 5); if (row[i] > 0.5) { row[i] += bias; } sum += row[i]; } for (int i = 0; i < size; i++) { row[i] /= sum; } return row; } private void initializeRowAsUnknowns(int nodeIndex, int rowIndex) { final int size = getNumColumns(nodeIndex); double[] row = new double[size]; Arrays.fill(row, Double.NaN); probs[nodeIndex][rowIndex] = row; } /** * This method initializes the node indicated. */ private void retainOldRowIfPossible(int nodeIndex, int rowIndex, BayesIm oldBayesIm, int initializationMethod) { // Set<Node> newParents = new HashSet<Node>(getBayesPm().getDag().getParents(node)); // Set<Node> oldParents = new HashSet<Node>(oldBayesIm.getBayesPm().getDag().getParents(node)); // int method = newParents == oldParents ? 
initializationMethod : MlBayesIm.MANUAL; int oldNodeIndex = getCorrespondingNodeIndex(nodeIndex, oldBayesIm); if (oldNodeIndex == -1) { overwriteRow(nodeIndex, rowIndex, initializationMethod); } else if (getNumColumns(nodeIndex) != oldBayesIm.getNumColumns(oldNodeIndex)) { overwriteRow(nodeIndex, rowIndex, initializationMethod); // } else if (parentsChanged(nodeIndex, this, oldBayesIm)) { // overwriteRow(nodeIndex, rowIndex, initializationMethod); } else { int oldRowIndex = getUniqueCompatibleOldRow(nodeIndex, rowIndex, oldBayesIm); if (oldRowIndex >= 0) { copyValuesFromOldToNew(oldNodeIndex, oldRowIndex, nodeIndex, rowIndex, oldBayesIm); } else { overwriteRow(nodeIndex, rowIndex, initializationMethod); } } } // private boolean parentsChanged(int nodeIndex, BayesIm bayesIm, BayesIm oldBayesIm) { // int[] dims = bayesIm.getParents(nodeIndex); // int[] oldDims = oldBayesIm.getParents(nodeIndex); // // if (dims.length != oldDims.length) { // return false; // } // // for (int i = 0; i < dims.length; i++) { // if (dims[i] != oldDims[i]) { // return false; // } // } // // return true; // } /** * Returns the unique rowIndex in the old BayesIm for the given node that is * compatible with the given rowIndex in the new BayesIm for that node, if * one exists. Otherwise, returns -1. A compatible rowIndex is one in which * all the parents that the given node has in common between the old BayesIm * and the new BayesIm are assigned the values they have in the new * rowIndex. If a parent node is removed in the new BayesIm, there may be * more than one such compatible rowIndex in the old BayesIm, in which case * -1 is returned. Likewise, there may be no compatible rows, in which case * -1 is returned. 
*/ private int getUniqueCompatibleOldRow(int nodeIndex, int rowIndex, BayesIm oldBayesIm) { int oldNodeIndex = getCorrespondingNodeIndex(nodeIndex, oldBayesIm); int oldNumParents = oldBayesIm.getNumParents(oldNodeIndex); int[] oldParentValues = new int[oldNumParents]; Arrays.fill(oldParentValues, -1); int[] parentValues = getParentValues(nodeIndex, rowIndex); // Go through each parent of the node in the new BayesIm. for (int i = 0; i < getNumParents(nodeIndex); i++) { // Get the index of the parent in the new graph and in the old // graph. If it's no longer in the new graph, skip to the next // parent. int parentNodeIndex = getParent(nodeIndex, i); int oldParentNodeIndex = getCorrespondingNodeIndex(parentNodeIndex, oldBayesIm); int oldParentIndex = -1; for (int j = 0; j < oldBayesIm.getNumParents(oldNodeIndex); j++) { if (oldParentNodeIndex == oldBayesIm.getParent(oldNodeIndex, j)) { oldParentIndex = j; break; } } if (oldParentIndex == -1 || oldParentIndex >= oldBayesIm.getNumParents(oldNodeIndex)) { return -1; } // Look up that value index for the new BayesIm for that parent. // If it was a valid value index in the old BayesIm, record // that value in oldParentValues. Otherwise return -1. int newParentValue = parentValues[i]; int oldParentDim = oldBayesIm.getParentDim(oldNodeIndex, oldParentIndex); if (newParentValue < oldParentDim) { oldParentValues[oldParentIndex] = newParentValue; } else { return -1; } } // // Go through each parent of the node in the new BayesIm. // for (int i = 0; i < oldBayesIm.getNumParents(oldNodeIndex); i++) { // // // Get the index of the parent in the new graph and in the old // // graph. If it's no longer in the new graph, skip to the next // // parent. 
// int oldParentNodeIndex = oldBayesIm.getParent(oldNodeIndex, i); // int parentNodeIndex = // oldBayesIm.getCorrespondingNodeIndex(oldParentNodeIndex, this); // int parentIndex = -1; // // for (int j = 0; j < this.getNumParents(nodeIndex); j++) { // if (parentNodeIndex == this.getParent(nodeIndex, j)) { // parentIndex = j; // break; // } // } // // if (parentIndex == -1 || // parentIndex >= this.getNumParents(nodeIndex)) { // continue; // } // // // Look up that value index for the new BayesIm for that parent. // // If it was a valid value index in the old BayesIm, record // // that value in oldParentValues. Otherwise return -1. // int parentValue = oldParentValues[i]; // int parentDim = // this.getParentDim(nodeIndex, parentIndex); // // if (parentValue < parentDim) { // oldParentValues[parentIndex] = oldParentValue; // } else { // return -1; // } // } // If there are any -1's in the combination at this point, return -1. for (int oldParentValue : oldParentValues) { if (oldParentValue == -1) { return -1; } } // Otherwise, return the combination, which will be a row in the // old BayesIm. return oldBayesIm.getRowIndex(oldNodeIndex, oldParentValues); } private void copyValuesFromOldToNew(int oldNodeIndex, int oldRowIndex, int nodeIndex, int rowIndex, BayesIm oldBayesIm) { if (getNumColumns(nodeIndex) != oldBayesIm.getNumColumns(oldNodeIndex)) { throw new IllegalArgumentException( "It's only possible to copy " + "one row of probability values to another in a Bayes IM " + "if the number of columns in the table are the same."); } for (int colIndex = 0; colIndex < getNumColumns(nodeIndex); colIndex++) { double prob = oldBayesIm.getProbability(oldNodeIndex, oldRowIndex, colIndex); setProbability(nodeIndex, rowIndex, colIndex, prob); } } /** * Adds semantic checks to the default deserialization method. This method * must have the standard signature for a readObject method, and the body of * the method must begin with "s.defaultReadObject();". 
Other than that, any * semantic checks can be specified and do not need to stay the same from * version to version. A readObject method of this form may be added to any * class, even if Tetrad sessions were previously saved out using a version * of the class that didn't include it. (That's what the * "s.defaultReadObject();" is for. See J. Bloch, Effective Java, for help. * * @throws java.io.IOException * @throws ClassNotFoundException */ private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException { s.defaultReadObject(); if (bayesPm == null) { throw new NullPointerException(); } if (nodes == null) { throw new NullPointerException(); } if (parents == null) { throw new NullPointerException(); } if (parentDims == null) { throw new NullPointerException(); } if (probs == null) { throw new NullPointerException(); } } }