Java tutorial
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below.       //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006,       //
// 2007, 2008, 2009, 2010, 2014, 2015 by Peter Spirtes, Richard Scheines, Joseph //
// Ramsey, and Clark Glymour.                                                //
//                                                                           //
// This program is free software; you can redistribute it and/or modify      //
// it under the terms of the GNU General Public License as published by      //
// the Free Software Foundation; either version 2 of the License, or         //
// (at your option) any later version.                                       //
//                                                                           //
// This program is distributed in the hope that it will be useful,           //
// but WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the              //
// GNU General Public License for more details.                              //
//                                                                           //
// You should have received a copy of the GNU General Public License         //
// along with this program; if not, write to the Free Software               //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA   //
///////////////////////////////////////////////////////////////////////////////

package edu.pitt.csb.mgm;

import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.math.Functions;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.LogisticRegression;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.SearchLogUtils;
import edu.cmu.tetrad.util.TetradLogger;
import edu.cmu.tetrad.util.TetradMatrix;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;

import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.*;

/**
 * Performs a test of conditional independence X _||_ Y | Z1...Zn where all searchVariables are either continuous or
 * discrete. This test is valid for both ordinal and non-ordinal discrete searchVariables.
 * <p>
 * Dispatch, as implemented in {@link #isIndependent(Node, Node, List)}:
 * <ul>
 * <li>X and Y both discrete: a likelihood-ratio test that compares logistic regressions of the indicator columns
 * of X on Z against regressions on Y and Z;</li>
 * <li>otherwise: a regression (via {@code RegressionDataset}) of the non-discrete variable on Y and Z, where the
 * coefficient p-values belonging to Y are combined into one p-value.</li>
 * </ul>
 * Discrete variables with three or more categories are expanded into binary indicator columns in an internal copy
 * of the data; the first category is the reference level.
 * <p>
 * This logisticRegression makes multiple assumptions: 1. IIA 2. Large sample size (multiple regressions needed on
 * subsets of sample)
 *
 * @author Joseph Ramsey
 * @author Augustus Mayo.
 */
public class IndTestMultinomialAJ implements IndependenceTest {

    /** Untouched copy of the data handed to the constructor; returned by {@link #getData()}. */
    private final DataSet originalData;

    /** Variables of the data set handed to the constructor (no indicator columns). */
    private final List<Node> searchVariables;

    /** Working copy of the data, extended with one binary indicator column per non-reference category. */
    private final DataSet internalData;

    /** Significance level: independence is accepted when p &gt; alpha. */
    private double alpha;

    /** P-value of the most recent test; -1.0 until a test has been run. */
    private double lastP = -1.0;

    /** Maps each variable to the column(s) that represent it in {@link #internalData}. */
    private final Map<Node, List<Node>> variablesPerNode = new HashMap<>();

    private final LogisticRegression logisticRegression;
    private final RegressionDataset regression;
    private boolean verbose = false;
    private final DoubleFactory2D factory2D = DoubleFactory2D.dense;

    /** Cached row indices handed to the regressions; built lazily by {@link #getNonMissingRows}. */
    int[] _rows = null;

    /**
     * Builds the test over a copy of the given data set, expanding multi-category discrete variables into indicator
     * columns.
     *
     * @param data  the data to test on; it is copied, never modified.
     * @param alpha the significance level.
     * @throws IllegalArgumentException if a variable is neither continuous nor discrete.
     */
    public IndTestMultinomialAJ(DataSet data, double alpha) {
        this.searchVariables = data.getVariables();
        this.originalData = data.copy();
        this.internalData = data.copy();
        this.alpha = alpha;

        // Iterate over a snapshot: expandVariable appends columns to internalData while we loop.
        List<Node> variables = new ArrayList<>(internalData.getVariables());

        for (Node node : variables) {
            variablesPerNode.put(node, expandVariable(internalData, node));
        }

        this.logisticRegression = new LogisticRegression(internalData);
        this.regression = new RegressionDataset(internalData);
    }

    /**
     * Not supported by this test.
     *
     * @throws UnsupportedOperationException always.
     */
    public IndependenceTest indTestSubset(List<Node> vars) {
        throw new UnsupportedOperationException();
    }

    /**
     * @return true if the given independence question is judged true, false if not. The independence question is of
     * the form x _||_ y | z, z = <z1,...,zn>, where x, y, z1,...,zn are searchVariables in the list returned by
     * getVariableNames().
     * @throws IllegalArgumentException if x, y or any member of z is not a variable of this test.
     * @throws IllegalStateException    if the underlying regression fails.
     */
    public boolean isIndependent(Node x, Node y, List<Node> z) {
        if (x instanceof DiscreteVariable && y instanceof DiscreteVariable) {
            return isIndependentMultinomialLogisticRegression(x, y, z);
        }

        if (x instanceof DiscreteVariable) {
            // Swap so that the non-discrete variable is the regression response.
            return isIndependentRegression(y, x, z);
        }

        return isIndependentRegression(x, y, z);
    }

    /**
     * Returns the column(s) that represent {@code node} in the regressions. Continuous variables and discrete
     * variables with fewer than three categories represent themselves. Any other discrete variable gets one new
     * binary indicator column per category except the first (the reference); these columns are added to
     * {@code dataSet} and filled in.
     *
     * @param dataSet the data set to extend; it is modified.
     * @param node    the variable to expand.
     * @return the representing columns, never empty.
     * @throws IllegalArgumentException if the node is neither continuous nor discrete.
     */
    private List<Node> expandVariable(DataSet dataSet, Node node) {
        if (node instanceof ContinuousVariable) {
            return Collections.singletonList(node);
        }

        if (!(node instanceof DiscreteVariable)) {
            throw new IllegalArgumentException("Unsupported variable type: " + node);
        }

        DiscreteVariable discrete = (DiscreteVariable) node;

        if (discrete.getNumCategories() < 3) {
            return Collections.singletonList(node);
        }

        List<String> categories = new ArrayList<>(discrete.getCategories());

        // first category is reference
        categories.remove(0);

        List<Node> indicators = new ArrayList<>(categories.size());
        int numCases = dataSet.getNumRows();

        for (String cat : categories) {
            // Pick a column name that is not taken yet. The previous implementation retried the very same name
            // forever when it was already in use.
            String baseName = node.getName() + "MULTINOM" + "." + cat;
            String name = baseName;
            int suffix = 1;

            while (dataSet.getVariable(name) != null) {
                name = baseName + "_" + suffix++;
            }

            Node indicator = new DiscreteVariable(name, 2);
            indicators.add(indicator);
            dataSet.addVariable(indicator);

            int indicatorColumn = dataSet.getColumn(indicator);
            int sourceColumn = dataSet.getColumn(node);
            int catIndex = discrete.getIndex(cat);

            for (int row = 0; row < numCases; row++) {
                Object cell = dataSet.getObject(row, sourceColumn);
                int cellIndex = discrete.getIndex(cell.toString());
                dataSet.setInt(row, indicatorColumn, cellIndex == catIndex ? 1 : 0);
            }
        }

        return indicators;
    }

    /**
     * Tests x _||_ y | z for discrete x and y with a likelihood-ratio test. Every indicator column of x is
     * regressed (logistic regression) once on z only and once on y and z. The statistic is twice the difference of
     * the two log-likelihoods computed by {@link #multiLL}, referred to a chi-square distribution with
     * (#columns of y) * (#columns of x) degrees of freedom.
     *
     * @throws IllegalArgumentException if x, y or any member of z is not a variable of this test.
     */
    private boolean isIndependentMultinomialLogisticRegression(Node x, Node y, List<Node> z) {
        requireKnown(x, y, z);

        logisticRegression.setRows(getNonMissingRows(x, y, z));

        List<Node> xColumns = variablesPerNode.get(x);
        List<Node> yColumns = variablesPerNode.get(y);

        // Regressors of the null model (z only) and of the full model (y first, then z).
        List<Node> zList = new ArrayList<>();
        List<Node> yzList = new ArrayList<>(yColumns);

        for (Node _z : z) {
            List<Node> zColumns = variablesPerNode.get(_z);
            yzList.addAll(zColumns);
            zList.addAll(zColumns);
        }

        // One coefficient column per indicator of x; row 0 holds the intercept.
        DoubleMatrix2D coeffsNull = DoubleFactory2D.dense.make(zList.size() + 1, xColumns.size());
        DoubleMatrix2D coeffsDep = DoubleFactory2D.dense.make(yzList.size() + 1, xColumns.size());

        for (int i = 0; i < xColumns.size(); i++) {
            DiscreteVariable target = (DiscreteVariable) xColumns.get(i);

            LogisticRegression.Result result0 = logisticRegression.regress(target, zList);
            LogisticRegression.Result result1 = logisticRegression.regress(target, yzList);

            coeffsNull.viewColumn(i).assign(result0.getCoefs());
            coeffsDep.viewColumn(i).assign(result1.getCoefs());
        }

        double chisq = 2 * (multiLL(coeffsDep, x, yzList) - multiLL(coeffsNull, x, zList));
        int df = yColumns.size() * xColumns.size();
        double p = 1.0 - new ChiSquaredDistribution(df).cumulativeProbability(chisq);

        return recordAndLog(x, y, z, p);
    }

    /**
     * Returns the row indices the regressions should use. Filtering out rows with missing values was disabled
     * because it took an inordinate amount of time (jdramsey 20150929), so this returns every row of the internal
     * data set; the index array is built once and cached in {@link #_rows}. The parameters are currently unused.
     */
    private int[] getNonMissingRows(Node x, Node y, List<Node> z) {
        if (_rows == null) {
            _rows = new int[internalData.getNumRows()];

            for (int k = 0; k < _rows.length; k++) {
                _rows[k] = k;
            }
        }

        return _rows;
    }

    /**
     * Tells whether row {@code i} of the internal data holds a missing value for {@code x}: -99 for a discrete
     * variable, NaN for a continuous one. Currently not called; kept for the disabled missing-row filtering.
     */
    private boolean isMissing(Node x, int i) {
        int j = internalData.getColumn(x);

        if (x instanceof DiscreteVariable && internalData.getInt(i, j) == -99) {
            return true;
        }

        if (x instanceof ContinuousVariable && Double.isNaN(internalData.getDouble(i, j))) {
            return true;
        }

        return false;
    }

    /**
     * Computes the multinomial log-likelihood of the observed categories of {@code dep} given the regressors.
     * Column k of {@code coeffs} holds intercept and slopes for category k + 1; category 0 is the reference and has
     * linear predictor 0. Uses all rows of the internal data set.
     *
     * @param coeffs (#regressors + 1) x (#non-reference categories) coefficient matrix, intercept in row 0.
     * @param dep    the discrete dependent variable (its original, unexpanded column).
     * @param indep  the regressor columns, in the order matching the rows of {@code coeffs} after the intercept.
     * @return the sum over rows of log P(observed category).
     * @throws IllegalArgumentException if {@code dep} is null.
     */
    private double multiLL(DoubleMatrix2D coeffs, Node dep, List<Node> indep) {
        if (dep == null) {
            throw new IllegalArgumentException("must have a dependent node to regress on!");
        }

        List<Node> depList = new ArrayList<>();
        depList.add(dep);

        DoubleMatrix2D depData = factory2D.make(internalData.subsetColumns(depList).getDoubleData().toArray());
        int N = depData.rows();

        // Design matrix: a leading column of ones for the intercept, then the regressors.
        DoubleMatrix2D indepData = factory2D.make(N, 1, 1.0);

        if (!indep.isEmpty()) {
            DoubleMatrix2D regressorData =
                    factory2D.make(internalData.subsetColumns(indep).getDoubleData().toArray());
            indepData = factory2D.appendColumns(indepData, regressorData);
        }

        DoubleMatrix2D linear = Algebra.DEFAULT.mult(indepData, coeffs);

        // Prepend the reference category with linear predictor 0 so that it contributes exp(0) = 1 to the
        // normaliser. (A column of 1.0 here would be exponentiated to e and distort every probability.)
        DoubleMatrix2D probs =
                factory2D.appendColumns(factory2D.make(indepData.rows(), 1, 0.0), linear).assign(Functions.exp);

        double ll = 0;

        for (int i = 0; i < N; i++) {
            DoubleMatrix1D curRow = probs.viewRow(i);
            // Normalise the row in place so that it sums to one.
            curRow.assign(Functions.div(curRow.zSum()));
            ll += Math.log(curRow.get((int) depData.get(i, 0)));
        }

        return ll;
    }

    /**
     * Tests x _||_ y | z by regressing x on the columns of y and z and looking at the p-values of the coefficients
     * that belong to y. A single coefficient is used directly; several (y expanded into indicators) are combined
     * by {@link #combinedPValue}.
     *
     * @throws IllegalArgumentException if x, y or any member of z is not a variable of this test.
     * @throws IllegalStateException    if the regression itself fails.
     */
    private boolean isIndependentRegression(Node x, Node y, List<Node> z) {
        requireKnown(x, y, z);

        List<Node> yColumns = variablesPerNode.get(y);

        // Regressors: the columns of y first, so that their coefficients directly follow the intercept.
        List<Node> regressors = new ArrayList<>(yColumns);

        for (Node _z : z) {
            regressors.addAll(variablesPerNode.get(_z));
        }

        regression.setRows(getNonMissingRows(x, y, z));

        RegressionResult result;

        try {
            result = regression.regress(x, regressors);
        } catch (Exception e) {
            // Previously the stack trace was printed and the null result dereferenced right afterwards.
            throw new IllegalStateException("Regression of " + x + " on " + regressors + " failed", e);
        }

        // Index 0 is the intercept; the coefficients of y start at index 1.
        double p = combinedPValue(result.getP(), 1, yColumns.size());

        return recordAndLog(x, y, z, p);
    }

    /**
     * Combines {@code count} coefficient p-values starting at index {@code first}. One p-value is returned as is.
     * Several are combined as 1 - F(-2 * sum(log p_i)) with F the chi-square CDF on 2 * count degrees of freedom
     * (Fisher's combination formula); an infinite statistic yields 0.
     */
    private static double combinedPValue(double[] pCoef, int first, int count) {
        if (count == 1) {
            return pCoef[first];
        }

        double logSum = 0;

        for (int k = first; k < first + count; k++) {
            logSum += Math.log(pCoef[k]);
        }

        if (logSum == Double.NEGATIVE_INFINITY) {
            return 0.0;
        }

        return 1.0 - new ChiSquaredDistribution(2 * count).cumulativeProbability(-2 * logSum);
    }

    /**
     * Checks that x, y and every member of z are variables known to this test.
     *
     * @throws IllegalArgumentException naming the first unknown node.
     */
    private void requireKnown(Node x, Node y, List<Node> z) {
        if (!variablesPerNode.containsKey(x)) {
            throw new IllegalArgumentException("Unrecogized node: " + x);
        }

        if (!variablesPerNode.containsKey(y)) {
            throw new IllegalArgumentException("Unrecogized node: " + y);
        }

        for (Node node : z) {
            if (!variablesPerNode.containsKey(node)) {
                throw new IllegalArgumentException("Unrecogized node: " + node);
            }
        }
    }

    /**
     * Stores p as the last p-value, decides independence (p &gt; alpha) and, when verbose, logs the resulting fact.
     *
     * @return true if independence is accepted.
     */
    private boolean recordAndLog(Node x, Node y, List<Node> z, double p) {
        boolean indep = p > alpha;
        this.lastP = p;

        if (verbose) {
            if (indep) {
                TetradLogger.getInstance().log("independencies", SearchLogUtils.independenceFactMsg(x, y, z, p));
            } else {
                TetradLogger.getInstance().log("dependencies", SearchLogUtils.dependenceFactMsg(x, y, z, p));
            }
        }

        return indep;
    }

    /**
     * Varargs convenience form of {@link #isIndependent(Node, Node, List)}.
     */
    public boolean isIndependent(Node x, Node y, Node... z) {
        return isIndependent(x, y, Arrays.asList(z));
    }

    /**
     * @return true if the given independence question is judged false, false if not. The independence question is
     * of the form x _||_ y | z, z = <z1,...,zn>, where x, y, z1,...,zn are searchVariables in the list returned by
     * getVariableNames().
     */
    public boolean isDependent(Node x, Node y, List<Node> z) {
        return !this.isIndependent(x, y, z);
    }

    /**
     * Varargs convenience form of {@link #isDependent(Node, Node, List)}.
     */
    public boolean isDependent(Node x, Node y, Node... z) {
        return isDependent(x, y, Arrays.asList(z));
    }

    /**
     * @return the p-value of the most recently executed independence test, or -1.0 if no test has been run yet.
     */
    public double getPValue() {
        return this.lastP;
    }

    /**
     * @return the list of searchVariables over which this independence checker is capable of determining
     * independence relations. These are the variables of the ORIGINAL data set, not of the internal data set with
     * its added indicator columns.
     */
    public List<Node> getVariables() {
        return searchVariables;
    }

    /**
     * @return the list of variable names, in the order of {@link #getVariables()}.
     */
    public List<String> getVariableNames() {
        List<Node> variables = getVariables();
        List<String> variableNames = new ArrayList<>(variables.size());

        for (Node variable : variables) {
            variableNames.add(variable.getName());
        }

        return variableNames;
    }

    /**
     * @return the first search variable with the given name, or null if there is none.
     */
    public Node getVariable(String name) {
        for (Node variable : getVariables()) {
            if (variable.getName().equals(name)) {
                return variable;
            }
        }

        return null;
    }

    /**
     * Stub: determinism is not checked by this test.
     *
     * @return false, always.
     */
    public boolean determines(List<Node> z, Node y) {
        return false;
    }

    /**
     * @return the significance level of the independence test.
     */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Sets the significance level.
     */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /**
     * @return the copy of the data set taken at construction time (without indicator columns).
     */
    public DataSet getData() {
        return this.originalData;
    }

    /** Not available for this test; returns null. */
    @Override
    public ICovarianceMatrix getCov() {
        return null;
    }

    /** Not available for this test; returns null. */
    @Override
    public List<DataSet> getDataSets() {
        return null;
    }

    /**
     * Stub; returns 0.
     * NOTE(review): callers that need the real sample size probably expect the row count of the data - confirm.
     */
    @Override
    public int getSampleSize() {
        return 0;
    }

    /** Not available for this test; returns null. */
    @Override
    public List<TetradMatrix> getCovMatrices() {
        return null;
    }

    /**
     * @return the score of the last test, which for this test is its p-value.
     */
    @Override
    public double getScore() {
        return getPValue();
    }

    /**
     * @return a string representation of this test.
     */
    public String toString() {
        NumberFormat nf = new DecimalFormat("0.0000");
        return "Multinomial Logistic Regression, alpha = " + nf.format(getAlpha());
    }

    public boolean isVerbose() {
        return verbose;
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }
}