edu.pitt.csb.stability.StabilityUtils.java Source code

Java tutorial

Introduction

Here is the source code for edu.pitt.csb.stability.StabilityUtils.java

Source

///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below.       //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006,       //
// 2007, 2008, 2009, 2010, 2014, 2015 by Peter Spirtes, Richard Scheines, Joseph   //
// Ramsey, and Clark Glymour.                                                //
//                                                                           //
// This program is free software; you can redistribute it and/or modify      //
// it under the terms of the GNU General Public License as published by      //
// the Free Software Foundation; either version 2 of the License, or         //
// (at your option) any later version.                                       //
//                                                                           //
// This program is distributed in the hope that it will be useful,           //
// but WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             //
// GNU General Public License for more details.                              //
//                                                                           //
// You should have received a copy of the GNU General Public License         //
// along with this program; if not, write to the Free Software               //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA //
///////////////////////////////////////////////////////////////////////////////

package edu.pitt.csb.stability;

import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.math.Functions;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.RandomSampler;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.GraphSearch;
import edu.cmu.tetrad.util.ForkJoinPoolInstance;
import edu.cmu.tetrad.util.TetradMatrix;
import edu.pitt.csb.mgm.MGM;
import edu.pitt.csb.mgm.MixedUtils;
import org.apache.commons.math3.util.CombinatoricsUtils;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;

/**
 * Runs a search algorithm over N subsamples of size b to assess stability,
 * as in "Stability Selection" and "Stability Approach to Regularization Selection".
 * Created by ajsedgewick on 9/4/15.
 */
public class StabilityUtils {

    /**
     * Runs {@code gs} sequentially on {@code N} distinct subsamples of {@code b} rows each and
     * returns, for every pair of variables, the fraction of subsamples whose estimated skeleton
     * contains an edge between them (the edge selection frequency theta).
     *
     * <p>The edgewise instability of Liu et al. is {@code 2 * theta * (1 - theta)}; it is not
     * applied here ({@link #totalInstabilityUndir} applies it).
     *
     * @param data the full data set; subsamples are drawn from its rows
     * @param gs   the search to run on each subsample
     * @param N    number of subsamples; must be at least 1
     * @param b    rows per subsample; must be between 1 and {@code data.getNumRows()}
     * @return a numVars x numVars matrix: the sum of {@code MixedUtils.skeletonToMatrix} over all
     *         subsamples divided by {@code N} (entries are in [0, 1] assuming that helper
     *         produces 0/1 adjacency matrices — confirm)
     * @throws IllegalArgumentException if {@code N < 1} or {@code b} is out of range
     */
    public static DoubleMatrix2D StabilitySearch(DataSet data, DataGraphSearch gs, int N, int b) {
        requirePositiveSubsampleCount(N);
        int numVars = data.getNumColumns();
        DoubleMatrix2D thetaMat = DoubleFactory2D.dense.make(numVars, numVars, 0.0);

        int[][] samps = subSampleNoReplacement(data.getNumRows(), b, N);

        for (int s = 0; s < N; s++) {
            DataSet dataSubSamp = data.subsetRows(samps[s]);
            Graph g = gs.search(dataSubSamp);

            //TODO update graphToMatrix method
            DoubleMatrix2D curAdj = MixedUtils.skeletonToMatrix(g);
            thetaMat.assign(curAdj, Functions.plus);
        }
        // Turn the edge counts into frequencies.
        thetaMat.assign(Functions.mult(1.0 / N));
        return thetaMat;
    }

    /**
     * Parallel version of {@link #StabilitySearch}: the same edge selection frequencies,
     * computed on the shared Tetrad fork/join pool. Each subsample is searched with its own
     * {@code gs.copy()} on its own copy of the subsampled rows.
     *
     * @param data the full data set; subsamples are drawn from its rows
     * @param gs   the search prototype; copied once per subsample
     * @param N    number of subsamples; must be at least 1
     * @param b    rows per subsample; must be between 1 and {@code data.getNumRows()}
     * @return a numVars x numVars matrix of edge selection frequencies
     * @throws IllegalArgumentException if {@code N < 1} or {@code b} is out of range
     */
    public static DoubleMatrix2D StabilitySearchPar(final DataSet data, final DataGraphSearch gs, int N, int b) {
        requirePositiveSubsampleCount(N);

        final int numVars = data.getNumColumns();
        final DoubleMatrix2D thetaMat = DoubleFactory2D.dense.make(numVars, numVars, 0.0);

        final int[][] samps = subSampleNoReplacement(data.getNumRows(), b, N);

        final ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool();

        /** Searches subsamples [from, to), splitting in half until a range has at most chunk items. */
        class StabilityAction extends RecursiveAction {
            private final int chunk;
            private final int from;
            private final int to;

            StabilityAction(int chunk, int from, int to) {
                this.chunk = chunk;
                this.from = from;
                this.to = to;
            }

            // Every task is a distinct object, so a synchronized *instance* method would lock a
            // different monitor per task and provide no mutual exclusion. Lock the shared
            // accumulator instead. (Keeping one matrix per subsample and summing at the end
            // would avoid the lock, but needs much more memory.)
            private void addToMat(DoubleMatrix2D matSum, DoubleMatrix2D curMat) {
                synchronized (matSum) {
                    matSum.assign(curMat, Functions.plus);
                }
            }

            @Override
            protected void compute() {
                if (to - from <= chunk) {
                    for (int s = from; s < to; s++) {
                        DataSet dataSubSamp = data.subsetRows(samps[s]).copy();
                        DataGraphSearch curGs = gs.copy();
                        Graph g = curGs.search(dataSubSamp);

                        //TODO update graphToMatrix method
                        DoubleMatrix2D curAdj = MixedUtils.skeletonToMatrix(g); //set weights so that undirected stability works
                        addToMat(thetaMat, curAdj);
                    }
                } else {
                    final int mid = from + (to - from) / 2;
                    invokeAll(new StabilityAction(chunk, from, mid), new StabilityAction(chunk, mid, to));
                }
            }
        }

        final int chunk = 2;

        // invoke() returns only after all subtasks complete, so every addToMat is visible here.
        pool.invoke(new StabilityAction(chunk, 0, N));

        // Turn the edge counts into frequencies; the instability transform is applied by callers.
        thetaMat.assign(Functions.mult(1.0 / N));
        return thetaMat;
    }

    /**
     * Averages the edgewise instability {@code 2 * theta * (1 - theta)} of an edge-frequency
     * matrix over four groups of variable pairs: [all, cc, cd, dd], where c = continuous and
     * d = discrete.
     *
     * <p>{@code xi} must be symmetric. Within-group pairs are counted once using
     * {@code MGM.upperTri(..., 1)}, presumably the strict upper triangle (confirm).
     * TODO directed version.
     *
     * @param xi   square matrix of edge selection frequencies, indexed like {@code vars}
     * @param vars the variables; used to split indices into continuous and discrete
     * @return {all, cc, cd, dd} average instabilities; a group with no pairs divides by zero
     * @throws IllegalArgumentException if {@code xi} is not vars.size() x vars.size()
     */
    public static double[] totalInstabilityUndir(DoubleMatrix2D xi, List<Node> vars) {
        if (vars.size() != xi.columns() || vars.size() != xi.rows()) {
            throw new IllegalArgumentException(
                    "stability mat must have same number of rows and columns as there are vars");
        }

        // Per-entry instability 2 * theta * (1 - theta), computed as -2 * theta * (theta - 1).
        DoubleMatrix2D xiu = xi.copy().assign(xi.copy().assign(Functions.minus(1.0)), Functions.mult)
                .assign(Functions.mult(-2.0));

        double[] D = new double[4];
        int[] discInds = MixedUtils.getDiscreteInds(vars);
        int[] contInds = MixedUtils.getContinuousInds(vars);
        int p = contInds.length;
        int q = discInds.length;
        double temp = MGM.upperTri(xiu.copy(), 1).zSum();
        D[0] = temp / ((p + q - 1.0) * (p + q) / 2.0);
        temp = MGM.upperTri(xiu.viewSelection(contInds, contInds).copy(), 1).zSum();
        D[1] = temp / (p * (p - 1.0) / 2.0);
        // The cd block has no diagonal, so every entry is a distinct pair.
        temp = xiu.viewSelection(contInds, discInds).zSum();
        D[2] = temp / ((double) p * q);
        temp = MGM.upperTri(xiu.viewSelection(discInds, discInds).copy(), 1).zSum();
        D[3] = temp / (q * (q - 1.0) / 2.0);

        return D;
    }

    /**
     * Averages the entries of {@code xi} over [all, cc, cd, dd] variable groups without assuming
     * symmetry. Unlike {@link #totalInstabilityUndir}, no instability transform is applied.
     *
     * <p>NOTE(review): the normalization looks inconsistent. D[0] sums every entry but divides
     * by the number of unordered pairs. D[1] and D[3] divide by the number of ordered pairs.
     * D[2] only counts the continuous-row / discrete-column block. It is left as-is pending the
     * directed version (TODO).
     *
     * @param xi   square matrix indexed like {@code vars}
     * @param vars the variables; used to split indices into continuous and discrete
     * @return {all, cc, cd, dd} averages; a group with no pairs divides by zero
     * @throws IllegalArgumentException if {@code xi} is not vars.size() x vars.size()
     */
    public static double[] totalInstabilityDir(DoubleMatrix2D xi, List<Node> vars) {
        if (vars.size() != xi.columns() || vars.size() != xi.rows()) {
            throw new IllegalArgumentException(
                    "stability mat must have same number of rows and columns as there are vars");
        }

        double[] D = new double[4];
        int[] discInds = MixedUtils.getDiscreteInds(vars);
        int[] contInds = MixedUtils.getContinuousInds(vars);
        int p = contInds.length;
        int q = discInds.length;
        // Floating-point denominators, matching totalInstabilityUndir (n*(n-1) is even, so the
        // values equal the former integer expressions).
        D[0] = xi.zSum() / ((p + q - 1.0) * (p + q) / 2.0);

        D[1] = xi.viewSelection(contInds, contInds).zSum() / (p * (p - 1.0));
        D[2] = xi.viewSelection(contInds, discInds).zSum() / ((double) p * q);
        D[3] = xi.viewSelection(discInds, discInds).zSum() / (q * (q - 1.0));

        return D;
    }

    /**
     * Draws {@code numSub} distinct subsamples of {@code subSize} indices from
     * {@code 0 .. sampSize - 1}. Indices within a subsample are drawn without replacement and
     * returned in ascending order. No two subsamples contain the same set of indices.
     *
     * @param sampSize number of indices to draw from
     * @param subSize  indices per subsample; must be in [1, sampSize]
     * @param numSub   number of subsamples; must be in [0, C(sampSize, subSize)]
     * @return a numSub x subSize matrix of subsample indices
     * @throws IllegalArgumentException if any argument is out of range. Previously, asking for
     *         more distinct subsamples than exist looped forever.
     */
    public static int[][] subSampleNoReplacement(int sampSize, int subSize, int numSub) {

        if (subSize < 1) {
            throw new IllegalArgumentException("Sample size must be > 0.");
        }
        if (subSize > sampSize) {
            throw new IllegalArgumentException("Subsample size " + subSize
                    + " exceeds the number of available samples " + sampSize + ".");
        }
        if (numSub < 0) {
            throw new IllegalArgumentException("Number of subsamples must be >= 0, got " + numSub + ".");
        }
        if (!hasAtLeastCombinations(sampSize, subSize, numSub)) {
            throw new IllegalArgumentException("Cannot draw " + numSub + " distinct subsamples of size "
                    + subSize + " from " + sampSize + " samples.");
        }

        int[][] sampMat = new int[numSub][];

        for (int i = 0; i < numSub; i++) {
            // Rejection sampling: redraw until the subsample differs from all earlier ones.
            int[] curSamp;
            do {
                curSamp = subSampleIndices(sampSize, subSize);
            } while (containsSample(sampMat, i, curSamp));
            sampMat[i] = curSamp;
        }

        return sampMat;
    }

    /** Returns true if {@code samp} equals any of the first {@code count} rows of {@code sampMat}. */
    private static boolean containsSample(int[][] sampMat, int count, int[] samp) {
        for (int j = 0; j < count; j++) {
            if (Arrays.equals(sampMat[j], samp)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns true if C(n, k) >= target, for 0 <= k <= n. It builds C(n - k + i, i) for
     * i = 1..k, a non-decreasing sequence that ends at C(n, k). It stops as soon as the running
     * value reaches {@code target}, so the long never overflows.
     */
    private static boolean hasAtLeastCombinations(int n, int k, int target) {
        long c = 1;
        for (int i = 1; i <= k; i++) {
            if (c >= target) {
                return true;
            }
            // C(m, i) = C(m - 1, i - 1) * m / i with m = n - k + i; the division is exact.
            c = c * (n - k + i) / i;
        }
        return c >= target;
    }

    /**
     * Returns {@code subSize} distinct indices from {@code 0 .. N - 1}, chosen uniformly at
     * random, in ascending order.
     */
    private static int[] subSampleIndices(int N, int subSize) {
        List<Integer> indices = new ArrayList<Integer>(N);
        for (int i = 0; i < N; i++) {
            indices.add(i);
        }

        Collections.shuffle(indices);
        int[] samp = new int[subSize];
        for (int i = 0; i < subSize; i++) {
            samp[i] = indices.get(i);
        }
        // Canonical order so that Arrays.equals compares subsamples as sets.
        Arrays.sort(samp);
        return samp;
    }

    /** Ad-hoc timing comparison of the sequential and parallel searches on a local data set. */
    public static void main(String[] args) {
        String fn = "/Users/ajsedgewick/tetrad_mgm_runs/run2/networks/DAG_0_graph.txt";
        Graph trueGraph = GraphUtils.loadGraphTxt(new File(fn));
        DataSet ds;
        try {
            ds = MixedUtils.loadData("/Users/ajsedgewick/tetrad_mgm_runs/run2/data/", "DAG_0_data.txt");
        } catch (Throwable t) {
            // Without data the searches below would only fail with a NullPointerException.
            t.printStackTrace();
            return;
        }

        double lambda = .1;
        SearchWrappers.MGMWrapper mgm = new SearchWrappers.MGMWrapper(new double[] { lambda, lambda, lambda });
        long start = System.currentTimeMillis();
        DoubleMatrix2D xi = StabilitySearch(ds, mgm, 8, 200);
        long end = System.currentTimeMillis();
        System.out.println("Not parallel: " + ((end - start) / 1000.0));

        start = System.currentTimeMillis();
        DoubleMatrix2D xi2 = StabilitySearchPar(ds, mgm, 8, 200);
        end = System.currentTimeMillis();
        System.out.println("Parallel: " + ((end - start) / 1000.0));

        System.out.println(xi);
        System.out.println(xi2);
    }

    /** Rejects subsample counts for which the averaged frequencies would be undefined (1.0 / 0). */
    private static void requirePositiveSubsampleCount(int N) {
        if (N < 1) {
            throw new IllegalArgumentException("Number of subsamples must be > 0, got " + N + ".");
        }
    }
}