smile.wide.Network.java Source code


Introduction

Here is the source code for smile.wide.Network.java, the main class of SMILE-WIDE. It forwards most calls to an embedded smile.Network instance and adds an infer() method that runs inference as a Hadoop job.

Source

/*
 * Licensed to the DARPA XDATA project. DARPA XDATA licenses this file to you
 * under the Apache License, Version 2.0 (the "License"); you may not use this
 * file except in compliance with the License. You may obtain a copy of the
 * License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package smile.wide;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.util.ToolRunner;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;

import smile.wide.BNQuery;
import smile.wide.InferenceDriver;
import smile.wide.data.DataSet;
import smile.wide.data.DataSetReader;
import smile.wide.hadoop.io.DoubleArrayWritable;

/**
 * The main class of SMILE-WIDE. Forwards most (construction) calls to an embedded
 * instance of smile.Network, but adds the infer() method, which runs inference on
 * Hadoop infrastructure.
 * 
 * @author robert.e.cranfill@boeing.com
 * @author tomas.singliar@boeing.com
 * 
 * @see  <a href="http://genie.sis.pitt.edu/wiki/SMILE.NET:_Network" target="_blank">SMILE online documentation</a>
 * 
 */
public class Network implements NetworkIF {
    private smile.Network hiddenNetwork_;
    private Logger logger_;

    // Hadoop execution
    private Configuration conf_ = null;
    private Path jobHDFSPath_ = null; // the HDFS path of the last Hadoop job
    private Path tempDir_ = null;

    // intermediate result
    private int[] instanceIDs_ = null;
    // each row corresponds to an instance, each column to a query variable's possible value
    private float[][] posteriors_ = null;

    // ========================================================================
    // Construction

    public Network() {
        logger_ = Logger.getLogger(Network.class);
        logger_.setLevel(Level.DEBUG);

        hiddenNetwork_ = new smile.Network();
    }

    public Network(Logger logger) {
        logger_ = logger;
        hiddenNetwork_ = new smile.Network();
    }

    public Network(String networkFileName) {
        logger_ = Logger.getLogger(Network.class);
        logger_.setLevel(Level.DEBUG);
        hiddenNetwork_ = new smile.Network();
        hiddenNetwork_.readFile(networkFileName);
    }

    // ========================================================================
    // Setting the network up - calls forwarded to jsmile.Network

    @Override
    // smilewide.NetworkIF
    public void addOutcome(int nodeHandle, String outcomeId) {
        logger_.debug("smilewide.addOutcome(int nodeHandle, String outcomeId)");
        hiddenNetwork_.addOutcome(nodeHandle, outcomeId);
    }

    public void addOutcome(String nodeId, String outcomeId) {
        logger_.debug("smilewide.addOutcome(String nodeId, String outcomeId)");
        hiddenNetwork_.addOutcome(nodeId, outcomeId);
    }

    @Override
    // smilewide.NetworkIF
    public int addNode(int nodeType, String nodeId) {
        logger_.debug("smilewide.addNode");
        return hiddenNetwork_.addNode(nodeType, nodeId);
    }

    @Override
    // smilewide.NetworkIF
    public void deleteOutcome(int nodeHandle, int outcomeIndex) {
        logger_.debug("smilewide.deleteOutcome(int nodeHandle, int outcomeIndex)");
        hiddenNetwork_.deleteOutcome(nodeHandle, outcomeIndex);
    }

    @Override
    // smilewide.NetworkIF
    public void deleteOutcome(String nodeId, String outcomeId) {
        logger_.debug("smilewide.deleteOutcome(String nodeId, String outcomeId)");
        hiddenNetwork_.deleteOutcome(nodeId, outcomeId);
    }

    @Override
    // smilewide.NetworkIF
    public void setNodeDefinition(int nodeHandle, double[] definition) {
        logger_.debug("smilewide.setNodeDefinition(int nodeHandle, double[] definition)");
        hiddenNetwork_.setNodeDefinition(nodeHandle, definition);
    }

    @Override
    // smilewide.NetworkIF
    public void setNodeDefinition(String nodeId, double[] definition) {
        logger_.debug("smilewide.setNodeDefinition(String nodeId, double[] definition)");
        hiddenNetwork_.setNodeDefinition(nodeId, definition);
    }

    @Override
    // smilewide.NetworkIF
    public void setOutcomeId(int nodeHandle, int outcomeIndex, String id) {
        logger_.debug("smilewide.setOutcomeId");
        hiddenNetwork_.setOutcomeId(nodeHandle, outcomeIndex, id);
    }

    @Override
    // smilewide.NetworkIF
    public void updateBeliefs() {
        logger_.debug("smilewide.updateBeliefs");
        hiddenNetwork_.updateBeliefs();
    }

    @Override
    // smilewide.NetworkIF
    public int getNode(String nodeId) {
        logger_.debug("smilewide.getNode");
        return hiddenNetwork_.getNode(nodeId);
    }

    @Override
    // smilewide.NetworkIF
    public void addArc(int parentHandle, int childHandle) {
        logger_.debug("smilewide.addArc");
        hiddenNetwork_.addArc(parentHandle, childHandle);
    }

    @Override
    // smilewide.NetworkIF
    public void setEvidence(int nodeHandle, int outcomeIndex) {
        logger_.debug("smilewide.setEvidence");
        hiddenNetwork_.setEvidence(nodeHandle, outcomeIndex);
    }

    @Override
    // smilewide.NetworkIF
    public double[] getNodeDefinition(String nodeId) {
        logger_.debug("getNodeDefinition");
        return hiddenNetwork_.getNodeDefinition(nodeId);
    }

    @Override
    // smilewide.NetworkIF
    public double[] getNodeValue(int nodeHandle) {
        logger_.debug("smilewide.getNodeValue");
        return hiddenNetwork_.getNodeValue(nodeHandle);
    }

    @Override
    // smilewide.NetworkIF
    public double[] getNodeValue(String nodeID) {
        logger_.debug("smilewide.getNodeValue");
        return hiddenNetwork_.getNodeValue(nodeID);
    }

    @Override
    // smilewide.NetworkIF
    public void clearAllEvidence() {
        logger_.debug("smilewide.clearAllEvidence");
        hiddenNetwork_.clearAllEvidence();
    }

    @Override
    // smilewide.NetworkIF
    public int getNodeCount() {
        logger_.debug("smilewide.getNodeCount");
        return hiddenNetwork_.getNodeCount();
    }

    @Override
    // smilewide.NetworkIF
    public int getOutcomeCount(String nodeId) {
        logger_.debug("getOutcomeCount");
        return hiddenNetwork_.getOutcomeCount(nodeId);
    }

    @Override
    // smilewide.NetworkIF
    public String[] getOutcomeIds(int nodeHandle) {
        logger_.debug("smilewide.getOutcomeIds");
        return hiddenNetwork_.getOutcomeIds(nodeHandle);
    }

    @Override
    // smilewide.NetworkIF
    public String[] getOutcomeIds(String nodeId) {
        logger_.debug("smilewide.getOutcomeIds");
        return hiddenNetwork_.getOutcomeIds(nodeId);
    }

    @Override
    // smilewide.NetworkIF
    public void readFile(String fileName) {
        logger_.debug("smilewide.readFile");
        hiddenNetwork_.readFile(fileName);
    }

    @Override
    // smilewide.NetworkIF
    public void setEvidence(int nodeHandle, String outcomeId) {
        logger_.debug("smilewide.setEvidence");
        hiddenNetwork_.setEvidence(nodeHandle, outcomeId);
    }

    @Override
    // smilewide.NetworkIF
    public void writeFile(String fileName) {
        logger_.debug("smilewide.writeFile");
        hiddenNetwork_.writeFile(fileName);
    }

    // ========================================================================================
    // Accessors

    /**
     * Set the SMILE network using an XDSL-formatted string.
     * 
     * @param xdsl   The string representation of the network.
     */
    public void setNetwork(String xdsl) {
        logger_.debug("smilewide.setNetwork");
        hiddenNetwork_.readString(xdsl);
    }

    /**
     * Get the SMILE network as an XDSL-formatted string.
     * 
     * @return The XDSL string representation of the network.
     */
    public String getNetworkString() {
        return hiddenNetwork_.writeString();
    }

    /**
     * Set the underlying smile.Network.
     * 
     * @param network   The network.
     */
    public void setNetwork(smile.Network network) {
        hiddenNetwork_ = network;
    }

    /**
     * Get the underlying smile.Network.
     * 
     * @return A smile.Network object.
     */
    public smile.Network getNetwork() {
        return hiddenNetwork_;
    }

    // ========================================================================================
    // BigData calls

    /**
     * Runs inference in parallel on a large dataset. The result is stored as a two-dimensional
     * array - the major dimension corresponds to the instance, the minor to the query
     * variable's value.
     * 
     * The result can be retrieved by calling retrieveResult() followed by inferenceResult().
     * 
     * Preconditions for correct use:
     * - all variables in the evidence set are present in the BN
     * - the query variable is present in the BN
     * - the names of all variables in the evidence set are names of columns in the DataSet
     * (if the variable is not in the evidence, why condition on it?)
     * 
     * @param dataset   The dataset to run inference on
     * @param reader    The DataSetReader used to parse the records of the dataset
     * @param q         The query to ask of each instance
     */
    public void infer(DataSet dataset, DataSetReader<?, ?> reader, BNQuery q) {
        InferenceDriver id = new InferenceDriver();
        conf_ = id.getConf(); // get the inference driver's config and set it up

        // find out the working location
        try {
            Job j = new Job(conf_);
            jobHDFSPath_ = j.getWorkingDirectory();
        } catch (IOException e1) {
            e1.printStackTrace();
            return;
        }

        int r = (new Random()).nextInt();
        tempDir_ = new Path(jobHDFSPath_ + "/tmp/infresult_" + r);

        // write the network somewhere into HDFS - relies on two subsequent jobs
        // starting in the same directory, usually user home
        // TODO: communicate the actual location to the subsequent job?
        String name = hiddenNetwork_.getName() + ".xdsl";
        try {

            hiddenNetwork_.writeFile("/tmp/" + name);
            FileSystem fs = FileSystem.get(conf_);
            fs.mkdirs(new Path(jobHDFSPath_ + "/tmp/"));
            // move the locally written copy (/tmp/<name>) into HDFS under the job's working directory
            fs.moveFromLocalFile(new Path("/tmp/" + name), new Path(jobHDFSPath_ + "/tmp/" + name));
        } catch (IOException e) {
            logger_.error("I/O Error recording the Bayes network " + name + " to " + jobHDFSPath_ + "/tmp/" + name);
            e.printStackTrace();
        }
        conf_.set("xdata.bayesnets.networkfile", jobHDFSPath_ + "/tmp/" + name);

        // tell the driver the reader class
        conf_.set("xdata.bayesnets.datasetreader.class", reader.getClass().getName());

        // pull out the column indices that correspond to the query variables
        // DataSet should know which columns are which

        ArrayList<String> evVars = q.getEvidenceVars();
        String qvar = q.getQueryVar();
        int[] colIndices = new int[evVars.size() + 1];
        String[] colNames = new String[evVars.size() + 1];

        colIndices[0] = dataset.indexOfColumn(qvar);
        colNames[0] = qvar; // the query variable's own column

        for (int i = 1; i <= evVars.size(); ++i) {
            colIndices[i] = dataset.indexOfColumn(evVars.get(i - 1));
            colNames[i] = evVars.get(i - 1);
        }
        Arrays.sort(colIndices);

        conf_.set("xdata.bayesnets.datasetreader.filter", concat(colIndices, ","));

        // name the variables to which the dataset columns map
        conf_.set("xdata.bayesnets.datasetreader.variablenames", concat(colNames, ","));

        int instID = dataset.instanceIDColumnIndex();
        if (instID == -1) {
            logger_.error("No instance ID column index in dataset " + dataset.getName()
                    + ". Dataset must provide one for inference.");
            return;
        }
        conf_.set("xdata.bayesnets.datasetreader.instid", "" + instID);

        conf_.set("xdata.bayesnets.queryvariable", q.getQueryVar());

        // arguments for the inference driver are the location of the dataset and where
        // to write
        String[] args = new String[2]; // 2 arguments
        args[0] = dataset.location().toString();
        args[1] = tempDir_.toString();

        try {
            id.setConf(conf_); // make sure we are not setting up a side copy of the conf...
            ToolRunner.run(id, args);
        } catch (Exception e) {
            logger_.error("Something went wrong in executing the inference job");
            e.printStackTrace();
        }

    }

    /**
     * Retrieve the inference results from HDFS so that they can be returned as arrays
     * by getInstanceIDs() and inferenceResult().
     */
    public void retrieveResult() {

        FileSystem fs;
        try {
            fs = FileSystem.get(conf_);

            FileStatus[] stats = fs.listStatus(tempDir_);

            SequenceFile.Reader reader = null;
            ArrayList<Integer> instids = new ArrayList<Integer>();
            ArrayList<float[]> posts = new ArrayList<float[]>();

            // read all output files
            for (FileStatus stat : stats) {
                if (stat.getPath().toUri().toString().contains("part-r-")) {
                    try {
                        logger_.info("Reading results from " + stat.getPath());
                        reader = new SequenceFile.Reader(fs, stat.getPath(), conf_);
                        IntWritable key = new IntWritable();
                        DoubleArrayWritable value = new DoubleArrayWritable();

                        while (reader.next(key, value)) {
                            instids.add(key.get());
                            DoubleWritable[] unpack = (DoubleWritable[]) value.get();
                            float[] post = new float[unpack.length];
                            for (int i = 0; i < unpack.length; ++i) {
                                post[i] = (float) unpack[i].get();
                            }
                            posts.add(post);
                        }
                    } finally {
                        IOUtils.closeStream(reader);
                    }
                }
            }

            // copy over the result
            assert (instids.size() == posts.size());
            instanceIDs_ = new int[instids.size()];
            posteriors_ = new float[posts.size()][];
            for (int i = 0; i < instids.size(); ++i) {
                instanceIDs_[i] = instids.get(i);
                posteriors_[i] = posts.get(i);
            }

        } catch (IOException e) {
            logger_.error("Could not read in the inference results");
            e.printStackTrace();
        }

    }

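    /**
     * @return The instance IDs read back by retrieveResult(), in the same order as the rows
     *         of inferenceResult().
     */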
    public int[] getInstanceIDs() {
        return instanceIDs_;
    }

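    /**
     * @return The posteriors computed by the last inference job; one row per instance,
     *         one column per possible value of the query variable.
     */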
    public float[][] inferenceResult() {
        return posteriors_;
    }

    // ================================================================================
    // Little helpers

    /**
     * Concatenate integer values into a single string, separated by the given delimiter.
     * 
     * @param colIndices   the values to concatenate
     * @param delim        the delimiter inserted between consecutive values
     * @return the delimited string (empty if the array is empty)
     */
    private String concat(int[] colIndices, String delim) {

        String result = "";
        for (int i = 0; i < colIndices.length - 1; ++i) {
            result += colIndices[i] + delim;
        }

        if (colIndices.length > 0) {
            result += colIndices[colIndices.length - 1];
        }

        return result;
    }

    /**
     * Concatenate the string representations of the given values, separated by the given delimiter.
     * 
     * @param colIndices   the values to concatenate
     * @param delim        the delimiter inserted between consecutive values
     * @return the delimited string (empty if the array is empty)
     */
    private <T> String concat(T[] colIndices, String delim) {

        String result = "";
        for (int i = 0; i < colIndices.length - 1; ++i) {
            result += colIndices[i] + delim;
        }

        if (colIndices.length > 0) {
            result += colIndices[colIndices.length - 1];
        }

        return result;
    }

}
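
Example usage

The sketch below shows one way this class might be driven end to end, assuming a DataSet, a DataSetReader and a BNQuery have already been constructed elsewhere (their concrete implementations are not part of this file). The class name InferExample, the method name runQuery and the file name "alarm.xdsl" are placeholders for illustration, not part of the SMILE-WIDE API.

import java.util.Arrays;

import smile.wide.BNQuery;
import smile.wide.Network;
import smile.wide.data.DataSet;
import smile.wide.data.DataSetReader;

public class InferExample {

    // dataset, reader and q are assumed to be built elsewhere; their concrete
    // classes depend on how and where the instance data is stored.
    public static void runQuery(DataSet dataset, DataSetReader<?, ?> reader, BNQuery q) {
        // load a previously saved network; "alarm.xdsl" is a placeholder file name
        Network net = new Network("alarm.xdsl");

        // submit the Hadoop inference job, then pull the posteriors back from HDFS
        net.infer(dataset, reader, q);
        net.retrieveResult();

        // one row of posteriors per instance, aligned with the returned instance IDs
        int[] ids = net.getInstanceIDs();
        float[][] posteriors = net.inferenceResult();
        for (int i = 0; i < ids.length; ++i) {
            System.out.println(ids[i] + " -> " + Arrays.toString(posteriors[i]));
        }
    }
}

Note that infer() requires the dataset to expose an instance ID column (instanceIDColumnIndex()); if none is reported, the job is not submitted.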