probcog.bayesnets.learning.DomainLearner.java Source code

Java tutorial

Introduction

Here is the source code for probcog.bayesnets.learning.DomainLearner.java

Source

/*******************************************************************************
 * Copyright (C) 2006-2012 Dominik Jain.
 * 
 * This file is part of ProbCog.
 * 
 * ProbCog is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * ProbCog is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
package probcog.bayesnets.learning;

import edu.ksu.cis.bnj.ver3.core.*;

import java.sql.*;
import java.util.*;

import probcog.bayesnets.core.*;
import probcog.clustering.ClusterNamer;

import weka.core.*;
import weka.clusterers.*;
//import de.tum.in.fipm.base.data.QueryResult;

/**
 * learns domains for a certain set of nodes in a Bayesian network when given a
 * set of examples to learn from. The domains of discrete variables can be
 * learnt directly (i.e. the set of outcomes found in the examples becomes the
 * new domain); the domains of continuous variables can be learnt using
 * clustering. For clustering, WEKA's SimpleKMeans clustering algorithm is used,
 * which yields a Gaussian distribution for each cluster (i.e. expected value
 * (centroid) and standard deviation).
 * 
 * @author Dominik Jain
 */
public class DomainLearner extends Learner {

    /**
     * an array of ClusteredDomain objects, where each object contains
     * information on a node whose domain is to be learnt via clustering
     */
    ClusteredDomain[] clusteredDomains;

    /**
     * an array of instance collections for clustering; one collection is to
     * hold all the instances that were encountered for one of the nodes whose
     * domains are learnt via clustering. Each entry in the array corresponds to
     * an entry in clusteredDomains.
     */
    protected Instances[] clusterData;

    protected Attribute attrValue;

    /**
     * an array of clusterers, one for each node whose domain is to be learnt
     * via clustering. The clusterers are built from the instance data when the
     * learning process is ended. Each entry in the array corresponds to an
     * entry in clusteredDomains.
     */
    SimpleKMeans[] clusterers;

    /**
     * an object providing a function for naming clusters
     */
    protected ClusterNamer<SimpleKMeans> clusterNamer;

    /**
     * an array of nodes for which the domains
     * are to be learnt directly from the set of examples (i.e. every value that
     * occurs in the examples is also a possible outcome in the domain);
     */
    protected BeliefNode[] directDomains;

    /**
     * an array of hash sets, where each set contains the outcomes that were
     * encountered so far for one of the entries in array directDomains
     */
    protected Vector<HashSet<String>> directDomainData;

    /**
     * an array of arrays of strings specifying domains that can be transferred
     * from one node to another; may be null. If several nodes essentially share
     * the same domain, the domain need only be learnt for one of the nodes; the
     * learnt domain can then be transferred to the other nodes once the
     * learning is complete. Each item in the array is an array of strings,
     * where the first item is the name of the node for which the domain will be
     * learnt and all subsequent items are the names of nodes to which the
     * learnt domain is to be
     */
    String[][] duplicateDomains;

    protected boolean verbose = false;

    /**
     * holds information on a node whose domain is to be learnt by clustering
     * 
     * @author Dominik Jain
     */
    static public class ClusteredDomain {
        public String nodeName;

        public int numClusters;

        /**
         * @param nodeName
         *            the name of the node
         * @param numClusters
         *            the number of clusters to learn (or 0 if the number of
         *            clusters is to be determined automatically)
         */
        public ClusteredDomain(String nodeName, int numClusters) {
            this.nodeName = nodeName;
            this.numClusters = numClusters;
        }
    }

    /**
     * constructs a DomainLearner object from a BeliefNetworkEx object
     * 
     * @param bn
     *            the belief network
     * @param directDomains
     *            an array of strings containing the names of nodes for which
     *            the domains are to be learnt directly from the set of examples
     *            (i.e. every value that occurs in the examples is also a
     *            possible outcome in the domain); may be null
     * @param clusteredDomains
     *            an array of ClusteredDomain objects, where each object
     *            contains information on a node whose domain is to be learnt
     *            via clustering; may be null
     * @param namer
     *            the namer that is to be used for naming clusters (may be null
     *            if no clustered domains are specified)
     * @param duplicateDomains
     *            an array of arrays of strings specifying domains that can be
     *            transferred from one node to another; may be null. If several
     *            nodes essentially share the same domain, the domain need only
     *            be learnt for one of the nodes; the learnt domain can then be
     *            transferred to the other nodes once the learning is complete.
     *            Each item in the array is an array of strings, where the first
     *            item is the name of the node for which the domain will be
     *            learnt and all subsequent items are the names of nodes to
     *            which the learnt domain is to be transferred.
     * @throws Exception
     */
    public DomainLearner(BeliefNetworkEx bn, String[] directDomains, ClusteredDomain[] clusteredDomains,
            ClusterNamer<SimpleKMeans> namer, String[][] duplicateDomains) throws Exception {
        super(bn);
        init(getBeliefNodes(directDomains), clusteredDomains, namer, duplicateDomains);
    }

    /**
     * constructs a DomainLearner object from a BeliefNetwork object
     * 
     * @param bn
     *            the belief network
     * @param directDomains
     *            an array of strings containing the names of nodes for which
     *            the domains are to be learnt directly from the set of examples
     *            (i.e. every value that occurs in the examples is also a
     *            possible outcome in the domain)
     * @param clusteredDomains
     *            an array of ClusteredDomain objects, where each object
     *            contains information on a node whose domain is to be learnt
     *            via clustering
     * @param namer
     *            the namer that is to be used for naming clusters (may be null
     *            if no clustered domains are specified)
     * @param duplicateDomains
     *            an array of arrays of strings specifying domains that can be
     *            transferred from one node to another. If several nodes
     *            essentially share the same domain, the domain need only be
     *            learnt for one of the nodes; the learnt domain can then be
     *            transferred to the other nodes once the learning is complete.
     *            Each item in the array is an array of strings, where the first
     *            item is the name of the node for which the domain will be
     *            learnt and all subsequent items are the names of nodes to
     *            which the learnt domain is to be transferred.
     * @throws Exception
     */
    public DomainLearner(BeliefNetwork bn, String[] directDomains, ClusteredDomain[] clusteredDomains,
            ClusterNamer<SimpleKMeans> namer, String[][] duplicateDomains) {
        super(bn);
        init(getBeliefNodes(directDomains), clusteredDomains, namer, duplicateDomains);
    }

    /**
     * constructs a DomainLearner where the domains of all nodes are to be
     * learnt directly from the set of examples (i.e. every value that occurs in
     * the examples is also a possible outcome in the domain)
     * 
     * @param bn
     *            the belief network
     */
    public DomainLearner(BeliefNetwork bn) {
        this(new BeliefNetworkEx(bn));
    }

    /**
     * constructs a DomainLearner where the domains of all nodes are to be
     * learnt directly from the set of examples (i.e. every value that occurs in
     * the examples is also a possible outcome in the domain)
     * 
     * @param bn
     *            the belief network
     */
    public DomainLearner(BeliefNetworkEx bn) {
        super(bn);
        init(bn.bn.getNodes(), null, null, null);
    }

    protected BeliefNode[] getBeliefNodes(String[] names) {
        BeliefNode[] nodes = new BeliefNode[names.length];
        for (int i = 0; i < names.length; i++)
            nodes[i] = this.bn.getNode(names[i]);
        return nodes;
    }

    private void init(BeliefNode[] directDomains, ClusteredDomain[] clusteredDomains,
            ClusterNamer<SimpleKMeans> namer, String[][] duplicateDomains) {
        this.clusteredDomains = clusteredDomains;
        attrValue = new Attribute("value");
        if (clusteredDomains != null)
            clusterers = new SimpleKMeans[clusteredDomains.length];
        this.clusterNamer = namer;
        this.directDomains = directDomains;
        this.duplicateDomains = duplicateDomains;

        // create outcome sets for direct domain learning
        if (directDomains != null) {
            directDomainData = new Vector<HashSet<String>>();
            for (int i = 0; i < directDomains.length; i++)
                directDomainData.add(new HashSet<String>());
        }

        // create instance storage for learning of domains using clustering
        if (clusteredDomains != null) {
            clusterData = new Instances[clusteredDomains.length];
            for (int i = 0; i < clusteredDomains.length; i++) {
                FastVector attribs = new FastVector(1);
                attribs.addElement(attrValue);
                clusterData[i] = new Instances(clusteredDomains[i].nodeName, attribs, 100);
            }
        }
    }

    /**
     * learns all the examples in the result set. Each row in the result set
     * represents one example. All the random variables (nodes) that have been
     * scheduled for learning in the constructor need to be found in each result
     * row as columns that are named accordingly, i.e. for each random variable
     * for which the domain is to be learnt, there must be a column with a
     * matching name in the result set.
     * 
     * @param rs
     *            the result set
     * @throws Exception
     *             if the result set is empty
     * @throws SQLException
     *             particularly if there is no matching column for one of the
     *             node names
     */
    public void learn(ResultSet rs) throws Exception, SQLException {
        // if it's an empty result set, throw exception
        if (!rs.next())
            throw new Exception("empty result set!");

        // gather domain data
        int numDirectDomains = directDomains != null ? directDomains.length : 0;
        int numClusteredDomains = clusteredDomains != null ? clusteredDomains.length : 0;
        do {
            // for direct learning, add outcomes to the set of outcomes
            for (int i = 0; i < numDirectDomains; i++) {
                directDomainData.get(i).add(rs.getString(directDomains[i].getName()));
            }
            // for clustering, gather all instances
            for (int i = 0; i < numClusteredDomains; i++) {
                Instance inst = new Instance(1);
                inst.setValue(attrValue, rs.getDouble(clusteredDomains[i].nodeName));
                clusterData[i].add(inst);
            }
        } while (rs.next());
    }

    /**
     * learns all the examples in the result set. Each row in the result set
     * represents one example. All the random variables (nodes) that have been
     * scheduled for learning in the constructor need to be found in each result
     * row as columns that are named accordingly, i.e. for each random variable
     * for which the domain is to be learnt, there must be a column with a
     * matching name in the result set.
     * 
     * @param rs
     *            the result set
     * @throws Exception
     *             if the result set is empty
     * @throws SQLException
     *             particularly if there is no matching column for one of the
     *             node names
     */
    public void learn(Instances instances) throws Exception, SQLException {
        // if it's an empty result set, throw exception
        if (instances.numInstances() == 0)
            throw new Exception("empty result set!");

        // gather domain data
        int numDirectDomains = directDomains != null ? directDomains.length : 0;
        int numClusteredDomains = clusteredDomains != null ? clusteredDomains.length : 0;
        @SuppressWarnings("unchecked")
        Enumeration<Instance> instanceEnum = instances.enumerateInstances();
        while (instanceEnum.hasMoreElements()) {
            Instance instance = instanceEnum.nextElement();
            // for direct learning, add outcomes to the set of outcomes
            for (int i = 0; i < numDirectDomains; i++) {
                directDomainData.get(i).add(instance.stringValue(instances.attribute(directDomains[i].getName())));
            }
            // for clustering, gather all instances
            for (int i = 0; i < numClusteredDomains; i++) {
                Instance inst = new Instance(1);
                inst.setValue(attrValue, instance.value(instances.attribute(clusteredDomains[i].nodeName)));
                clusterData[i].add(inst);
            }
        }
    }

    /**
     * learns an example from a HashMap<String,String>.
     * 
     * @param data
     *            a HashMap containing the data for one example. The names of
     *            the random variables (nodes) for which the domains are to be
     *            learnt must be found in the set of keys of the hash map.
     * @throws Exception
     *             if required keys are missing from the HashMap
     */
    public void learn(Map<String, String> data) throws Exception {
        int numDirectDomains = directDomains != null ? directDomains.length : 0;
        int numClusteredDomains = clusteredDomains != null ? clusteredDomains.length : 0;

        // for direct learning, add outcomes to the set of outcomes
        for (int i = 0; i < numDirectDomains; i++) {
            String val = data.get(directDomains[i]);
            if (val == null)
                throw new Exception("Key " + clusteredDomains[i].nodeName + " not found in data!");
            directDomainData.get(i).add(val);
        }
        // for clustering, gather all instances
        for (int i = 0; i < numClusteredDomains; i++) {
            Instance inst = new Instance(1);
            String val = data.get(clusteredDomains[i].nodeName);
            if (val == null) {
                throw new Exception("Key " + clusteredDomains[i].nodeName + " not found in data!");
            }
            inst.setValue(attrValue, Double.parseDouble(val));
            clusterData[i].add(inst);
        }
    }

    /**
     * learns all the examples in a fipm.data.QueryResult (otherwise analogous
     * to learn(ResultSet))
     * 
     * @param res
     *            the query result containing the data for a set of examples
     * @throws Exception
     */
    /*
    public void learn(QueryResult res) throws Exception {
       int numDirectDomains = directDomains != null ? directDomains.length : 0;
       int numClusteredDomains = clusteredDomains != null ? clusteredDomains.length
        : 0;
        
       // get column indices
       Vector colnames = res.getColumnNames();
       int[] colIdx_cd = new int[numClusteredDomains];
       int[] colIdx_dd = new int[numDirectDomains];
       for (int i = 0; i < numClusteredDomains; i++) {
     colIdx_cd[i] = colnames.indexOf(clusteredDomains[i].nodeName);
     if (colIdx_cd[i] == -1)
        throw new Exception("Node/column "
              + clusteredDomains[i].nodeName
              + " was not found in result set");
       }
       for (int i = 0; i < numDirectDomains; i++) {
     colIdx_dd[i] = colnames.indexOf(directDomains[i]);
     if (colIdx_dd[i] == -1)
        throw new Exception("Node/column " + directDomains[i]
              + " was not found in result set");
       }
        
       // gather data
       for (int i = 0; i < res.getRowCount(); i++) {
     Vector<Object> row = new Vector<Object>();
     for(Object r:res.getRow(i))
           row.add(r);
     // for direct learning, add outcomes to the set of outcomes
     for (int j = 0; j < numDirectDomains; j++) {
        ((HashSet<String>) directDomainData[j]).add((String) row
              .get(colIdx_dd[j]));
     }
     // for clustering, gather all instances
     for (int j = 0; j < numClusteredDomains; j++) {
        Instance inst = new Instance(1);
        double value = Double.parseDouble((String) row
              .get(colIdx_cd[j]));
        inst.setValue(attrValue, value);
        clusterData[j].add(inst);
     }
       }
    }
    */

    /**
     * performs the clustering (if some domains are to be learnt by clustering)
     * and applies all the new domains. (This method is called by finish(),
     * which should be called when all the examples have been passed.)
     */
    protected void end_learning() throws Exception {
        if (directDomains != null)
            for (int i = 0; i < directDomains.length; i++) {
                if (verbose)
                    System.out.println(directDomains[i]);
                HashSet<String> hs = directDomainData.get(i);
                Discrete domain = new Discrete();
                for (Iterator<String> iter = hs.iterator(); iter.hasNext();)
                    domain.addName(iter.next());
                BeliefNode node = directDomains[i];
                if (node == null) {
                    System.out.println(
                            "No node with name '" + directDomains[i] + "' found to learn direct domain for.");
                }
                //System.out.println("DomainLearner: applying domain " + hs + " to " + node.getName());
                bn.bn.changeBeliefNodeDomain(node, domain);
            }
        if (clusteredDomains != null)
            for (int i = 0; i < clusteredDomains.length; i++) {
                if (verbose)
                    System.out.println(clusteredDomains[i].nodeName);
                try {
                    // perform clustering
                    clusterers[i] = new SimpleKMeans();
                    if (clusteredDomains[i].numClusters != 0)
                        clusterers[i].setNumClusters(clusteredDomains[i].numClusters);
                    clusterers[i].buildClusterer(clusterData[i]);
                    // update domain
                    bn.bn.changeBeliefNodeDomain(bn.getNode(clusteredDomains[i].nodeName),
                            new Discretized(clusterers[i], clusterNamer));
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        if (duplicateDomains != null) {
            for (int i = 0; i < duplicateDomains.length; i++) {
                Domain srcDomain = bn.getDomain(duplicateDomains[i][0]);
                for (int j = 1; j < duplicateDomains[i].length; j++) {
                    if (verbose)
                        System.out.println(duplicateDomains[i][j]);
                    bn.bn.changeBeliefNodeDomain(bn.getNode(duplicateDomains[i][j]), srcDomain);
                }
            }
        }
    }

    /**
     * returns the array clusterers for all the nodes whose domains were to be
     * learned by clustering
     * 
     * @return the array of clusterers. It is ordered according to the array of
     *         "clustered domains" that was passed at construction.
     * @throws Exception 
     */
    public SimpleKMeans[] getClusterers() throws Exception {
        finish(); // make sure learning is completed
        return clusterers;
    }

    /**
     * sorts the domains that were learned via clustering in ascending order of
     * cluster centroid. Attention: Do not use the clusterers returned by
     * getClusterers() after this function has been called because the indices
     * returned by clusterInstance will otherwise be wrong! In particular, never
     * conduct CPT-learning after calling this function. (You may, of course,
     * call this function after the CPT-learning has been completed.)
     */
    public void sortClusteredDomains() {
        // process all nodes whose domains were subject to clustering
        for (int i = 0; i < clusteredDomains.length; i++) {
            BeliefNode node = bn.getNode(clusteredDomains[i].nodeName);
            sortClusteredDomain(node, clusterers[i]);
        }
        // 
        if (duplicateDomains != null) {
            for (int i = 0; i < duplicateDomains.length; i++)
                for (int j = 0; j < clusteredDomains.length; j++)
                    if (duplicateDomains[i][0].equals(clusteredDomains[j].nodeName)) {
                        for (int k = 1; k < duplicateDomains[i].length; k++)
                            sortClusteredDomain(bn.getNode(duplicateDomains[i][k]), clusterers[j]);
                        break;
                    }
        }
    }

    /**
     * sorts the domain of the given node, for which the given clusterer has
     * been learnt, in ascending order of cluster centroid
     * 
     * @param node
     * @param clusterer
     */
    protected void sortClusteredDomain(BeliefNode node, SimpleKMeans clusterer) {
        // get domain sort order (sort by centroid, ascending),
        // i.e. get an unsorted and a sorted version of
        // the centroids array
        int numClusters = clusterer.getNumClusters();
        double[] values = clusterer.getClusterCentroids().attributeToDoubleArray(0);
        double[] sorted_values = (double[]) values.clone();
        Arrays.sort(sorted_values);
        // create new sorted domain
        Discrete domain = (Discrete) node.getDomain();
        Discrete sorted_domain = new Discrete();
        for (int new_idx = 0; new_idx < numClusters; new_idx++) {
            for (int old_idx = 0; old_idx < numClusters; old_idx++)
                if (values[old_idx] == sorted_values[new_idx])
                    sorted_domain.addName(domain.getName(old_idx));
        }
        // apply new, sorted domain
        bn.bn.changeBeliefNodeDomain(node, sorted_domain);
    }
}