org.apache.ctakes.ytex.kernel.ImputedFeatureEvaluatorImpl.java Source code


Introduction

Here is the source code for org.apache.ctakes.ytex.kernel.ImputedFeatureEvaluatorImpl.java, part of the Apache cTAKES YTEX kernel module.

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.ctakes.ytex.kernel;

import java.io.IOException;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Properties;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;

import javax.sql.DataSource;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.ctakes.ytex.kernel.dao.ClassifierEvaluationDao;
import org.apache.ctakes.ytex.kernel.dao.ConceptDao;
import org.apache.ctakes.ytex.kernel.model.ConcRel;
import org.apache.ctakes.ytex.kernel.model.ConceptGraph;
import org.apache.ctakes.ytex.kernel.model.CrossValidationFold;
import org.apache.ctakes.ytex.kernel.model.FeatureEvaluation;
import org.apache.ctakes.ytex.kernel.model.FeatureParentChild;
import org.apache.ctakes.ytex.kernel.model.FeatureRank;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowCallbackHandler;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.transaction.PlatformTransactionManager;

import weka.core.ContingencyTables;

/**
 * Calculate the mutual information of each concept in a corpus with respect to
 * a concept graph and a classification task (label), optionally per
 * cross-validation fold. We calculate the following:
 * <ul>
 * <li>raw mutual information of each concept (infogain). We calculate the joint
 * distribution of concepts (X) and document classes (Y), and compute the mutual
 * information for each concept.
 * <li>mutual information inherited by parents (infogain-parent). For each
 * concept in the concept graph, we merge the joint distribution of child
 * concepts. This is done recursively.
 * <li>mutual information inherited by children from parents (infogain-imputed,
 * infogain-imputed-filt). We take the top n parent concepts and assign their
 * children (entire subgraph) the mutual info of the parent.
 * </ul>
 * <p>
 * The mutual information of each concept is stored in the feature_rank table.
 * The related records in the feature_eval table have the following values:
 * <ul>
 * <li>type = infogain, infogain-parent, infogain-imputed, infogain-imputed-filt
 * <li>feature_set_name = conceptSetName
 * <li>param2 = conceptGraphName
 * </ul>
 * 
 * How this works in broad strokes:
 * <ul>
 * <li> {@link #evaluateCorpus(Parameters)} load instances, iterate through
 * labels
 * <li>
 * {@link #evaluateCorpusLabel(Parameters, ConceptGraph, InstanceData, String)}
 * load the map of concept - bin - instance ids for the specified label, iterate
 * through folds
 * <li>
 * {@link #evaluateCorpusFold(Parameters, Map, ConceptGraph, InstanceData, String, Map, int)}
 * create raw joint distribution of each concept, compute parent joint
 * distributions, assign children mutual info of parents
 * <li> {@link #completeJointDistroForFold(Map, Map, Set, Set, String)} computes
 * raw joint distribution of each concept
 * <li>
 * {@link #propagateJointDistribution(Map, Parameters, String, int, ConceptGraph, Map)}
 * recursively compute parent joint distribution by merging joint distro of
 * children.
 * <li>{@link #storeChildConcepts(List, Parameters, String, int, ConceptGraph, boolean)}
 * take the top ranked parent concepts and assign the concepts in their subtrees
 * the mutual info of the parents. Only concepts that exist in the corpus are
 * added (depends on computing the infocontent of concepts with CorpusEvaluator)
 * </ul>
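 * <p>
 * For a single concept this is the standard identity over the concept/class
 * contingency table: I(X;Y) = H(X) + H(Y) - H(X,Y) (see
 * {@link JointDistribution#getMutualInformation(double)}). For illustration, a
 * 2x2 table with counts {{40, 10}, {10, 40}} gives H(X) = H(Y) = 1 bit and
 * H(X,Y) of about 1.72 bits, so the mutual information (and equivalently the
 * info gain computed by {@link JointDistribution#getInfoGain()}) is about 0.28
 * bits.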
 * 
 * 
 * @author vijay
 * 
 */
public class ImputedFeatureEvaluatorImpl implements ImputedFeatureEvaluator {

    /**
     * fill in a map of concept id - bin - instance ids. Expects result set rows
     * of (concept id, instance id, x value).
     * 
     * @author vijay
     * 
     */
    public class ConceptInstanceMapExtractor implements RowCallbackHandler {
        ConceptGraph cg;
        Map<String, Map<String, Set<Long>>> conceptInstanceMap;

        ConceptInstanceMapExtractor(Map<String, Map<String, Set<Long>>> conceptInstanceMap, ConceptGraph cg) {
            this.cg = cg;
            this.conceptInstanceMap = conceptInstanceMap;
        }

        public void processRow(ResultSet rs) throws SQLException {
            String conceptId = rs.getString(1);
            long instanceId = rs.getLong(2);
            String x = rs.getString(3);
            Map<String, Set<Long>> binInstanceMap = conceptInstanceMap.get(conceptId);
            if (binInstanceMap == null) {
                // first time we see this concept - create its bin map
                binInstanceMap = new HashMap<String, Set<Long>>(2);
                conceptInstanceMap.put(conceptId, binInstanceMap);
            }
            Set<Long> instanceIds = binInstanceMap.get(x);
            if (instanceIds == null) {
                instanceIds = new HashSet<Long>();
                binInstanceMap.put(x, instanceIds);
            }
            instanceIds.add(instanceId);
        }
    }

    /**
     * joint distribution of concept (x) and class (y). The bins for x and y are
     * predetermined. Typical levels for x are 0/1 (absent/present) and -1/0/1
     * (negated/not present/affirmed).
     * 
     * @author vijay
     * 
     */
    public static class JointDistribution {
        /**
         * merge joint distributions into a single distribution. For each value
         * of Y, the cells for each X bin, except for the xMerge bin, are the
         * intersection of all the instances in each of the corresponding bins.
         * The xMerge bin gets everything that is leftover.
         * 
         * @param jointDistros
         *            list of joint distribution tables to merge
         * @param yMargin
         *            map of y val - instance id. this could be calculated on
         *            the fly, but we have this information already.
         * @param xMerge
         *            the x val that contains everything that doesn't land in
         *            any of the other bins.
         * @return the merged joint distribution
         */
        public static JointDistribution merge(List<JointDistribution> jointDistros, Map<String, Set<Long>> yMargin,
                String xMerge) {
            Set<String> xVals = jointDistros.get(0).xVals;
            Set<String> yVals = jointDistros.get(0).yVals;
            JointDistribution mergedDistro = new JointDistribution(xVals, yVals);
            for (String y : yVals) {
                // intersect all bins besides the merge bin
                Set<Long> xMergedInst = mergedDistro.getInstances(xMerge, y);
                // everything comes into the merge bin
                // we take out things that land in other bins
                xMergedInst.addAll(yMargin.get(y));
                // iterate over other bins
                for (String x : xVals) {
                    if (!x.equals(xMerge)) {
                        Set<Long> intersectIds = mergedDistro.getInstances(x, y);
                        boolean bFirstIter = true;
                        // iterate over all joint distribution tables
                        for (JointDistribution distro : jointDistros) {
                            if (bFirstIter) {
                                // 1st iter - add all
                                intersectIds.addAll(distro.getInstances(x, y));
                                bFirstIter = false;
                            } else {
                                // subsequent iteration - intersect
                                intersectIds.retainAll(distro.getInstances(x, y));
                            }
                        }
                        // remove from the merge bin
                        xMergedInst.removeAll(intersectIds);
                    }
                }
            }
            return mergedDistro;
        }
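
        // Illustrative example of merge(): with xVals = {"0", "1"} and xMerge = "1",
        // the merged "0" bin for a class y holds the instances that fall in the
        // "0" bin of *every* child distribution, and the "1" (merge) bin receives
        // every other instance of class y. In effect, a parent concept is treated
        // as "present" in a document if any of its child concepts is present there.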

        protected double[][] contingencyTable;
        /**
         * the entropy of X. Calculated once and returned as needed.
         */
        protected Double entropyX = null;
        /**
         * the entropy of X*Y. Calculated once and returned as needed.
         */
        protected Double entropyXY = null;
        /**
         * A y*x table where the cells hold the instance ids. We use the
         * instance ids instead of counts so we can merge the tables.
         */
        protected SortedMap<String, SortedMap<String, Set<Long>>> jointDistroTable;
        /**
         * the possible values of X (e.g. concept)
         */
        protected Set<String> xVals;

        /**
         * the possible values of Y (e.g. document class)
         */
        protected Set<String> yVals;

        /**
         * set up the joint distribution table.
         * 
         * @param xVals
         *            the possible x values (bins)
         * @param yVals
         *            the possible y values (bins)
         */
        public JointDistribution(Set<String> xVals, Set<String> yVals) {
            this.xVals = xVals;
            this.yVals = yVals;
            jointDistroTable = new TreeMap<String, SortedMap<String, Set<Long>>>();
            for (String yVal : yVals) {
                SortedMap<String, Set<Long>> yMap = new TreeMap<String, Set<Long>>();
                jointDistroTable.put(yVal, yMap);
                for (String xVal : xVals) {
                    yMap.put(xVal, new HashSet<Long>());
                }
            }
        }

        public JointDistribution(Set<String> xVals, Set<String> yVals, Map<String, Set<Long>> xMargin,
                Map<String, Set<Long>> yMargin, String xLeftover) {
            this.xVals = xVals;
            this.yVals = yVals;
            jointDistroTable = new TreeMap<String, SortedMap<String, Set<Long>>>();
            for (String yVal : yVals) {
                SortedMap<String, Set<Long>> yMap = new TreeMap<String, Set<Long>>();
                jointDistroTable.put(yVal, yMap);
                for (String xVal : xVals) {
                    yMap.put(xVal, new HashSet<Long>());
                }
            }
            for (Map.Entry<String, Set<Long>> yEntry : yMargin.entrySet()) {
                // iterate over 'rows' i.e. the class names
                String yName = yEntry.getKey();
                Set<Long> yInst = new HashSet<Long>(yEntry.getValue());
                // iterate over 'columns' i.e. the values of x
                for (Map.Entry<String, Set<Long>> xEntry : xMargin.entrySet()) {
                    // copy the instances
                    Set<Long> foldXInst = jointDistroTable.get(yName).get(xEntry.getKey());
                    foldXInst.addAll(xEntry.getValue());
                    // keep only the ones that are in this fold
                    foldXInst.retainAll(yInst);
                    // remove the instances for this value of x from the set of
                    // all instances
                    yInst.removeAll(foldXInst);
                }
                if (yInst.size() > 0) {
                    // add the leftovers to the leftover bin
                    jointDistroTable.get(yEntry.getKey()).get(xLeftover).addAll(yInst);
                }
            }

        }

        // /**
        // * add an instance to the joint probability table
        // *
        // * @param x
        // * @param y
        // * @param instanceId
        // */
        // public void addInstance(String x, String y, int instanceId) {
        // // add the current row to the bin matrix
        // SortedMap<String, Set<Integer>> xMap = jointDistroTable.get(y);
        // if (xMap == null) {
        // xMap = new TreeMap<String, Set<Integer>>();
        // jointDistroTable.put(y, xMap);
        // }
        // Set<Integer> instanceSet = xMap.get(x);
        // if (instanceSet == null) {
        // instanceSet = new HashSet<Integer>();
        // xMap.put(x, instanceSet);
        // }
        // instanceSet.add(instanceId);
        // }

        // /**
        // * finalize the joint probability table wrt the specified instances.
        // If
        // * we are doing this per fold, then not all instances are going to be
        // in
        // * each fold. Limit to the instances in the specified fold.
        // * <p>
        // * Also, we might not have filled in all the cells. if necessary, put
        // * instances in the 'leftover' cell, fill it in based on the marginal
        // * distribution of the instances wrt classes.
        // *
        // * @param yMargin
        // * map of values of y to the instances with that value
        // * @param xLeftover
        // * the value of x to assign the the leftover instances
        // */
        // public JointDistribution complete(Map<String, Set<Integer>> xMargin,
        // Map<String, Set<Integer>> yMargin, String xLeftover) {
        // JointDistribution foldDistro = new JointDistribution(this.xVals,
        // this.yVals);
        // for (Map.Entry<String, Set<Integer>> yEntry : yMargin.entrySet()) {
        // // iterate over 'rows' i.e. the class names
        // String yName = yEntry.getKey();
        // Set<Integer> yInst = new HashSet<Integer>(yEntry.getValue());
        // // iterate over 'columns' i.e. the values of x
        // for (Map.Entry<String, Set<Integer>> xEntry : this.jointDistroTable
        // .get(yName).entrySet()) {
        // // copy the instances
        // Set<Integer> foldXInst = foldDistro.jointDistroTable.get(
        // yName).get(xEntry.getKey());
        // foldXInst.addAll(xEntry.getValue());
        // // keep only the ones that are in this fold
        // foldXInst.retainAll(yInst);
        // // remove the instances for this value of x from the set of
        // // all instances
        // yInst.removeAll(foldXInst);
        // }
        // if (yInst.size() > 0) {
        // // add the leftovers to the leftover bin
        // foldDistro.jointDistroTable.get(yEntry.getKey())
        // .get(xLeftover).addAll(yInst);
        // }
        // }
        // return foldDistro;
        // }

        public double[][] getContingencyTable() {
            if (contingencyTable == null) {
                contingencyTable = new double[this.yVals.size()][this.xVals.size()];
                int i = 0;
                for (String yVal : yVals) {
                    int j = 0;
                    for (String xVal : xVals) {
                        contingencyTable[i][j] = jointDistroTable.get(yVal).get(xVal).size();
                        j++;
                    }
                    i++;
                }
            }
            return contingencyTable;
        }

        public double getEntropyX() {
            if (entropyX == null) {
                double probs[] = new double[xVals.size()];
                Arrays.fill(probs, 0d);
                double nTotal = 0;
                for (Map<String, Set<Long>> xInstance : this.jointDistroTable.values()) {
                    int i = 0;
                    for (Set<Long> instances : xInstance.values()) {
                        double nCell = (double) instances.size();
                        nTotal += nCell;
                        probs[i] += nCell;
                        i++;
                    }
                }
                for (int i = 0; i < probs.length; i++)
                    probs[i] /= nTotal;
                entropyX = entropy(probs);
            }
            return entropyX;
        }

        public double getEntropyXY() {
            if (entropyXY == null) {
                double probs[] = new double[xVals.size() * yVals.size()];
                Arrays.fill(probs, 0d);
                double nTotal = 0;
                int i = 0;
                for (Map<String, Set<Long>> xInstance : this.jointDistroTable.values()) {
                    for (Set<Long> instances : xInstance.values()) {
                        probs[i] = (double) instances.size();
                        nTotal += probs[i];
                        i++;
                    }
                }
                for (int j = 0; j < probs.length; j++)
                    probs[j] /= nTotal;
                entropyXY = entropy(probs);
            }
            return entropyXY;
        }

        public double getInfoGain() {
            return ContingencyTables.entropyOverColumns(getContingencyTable())
                    - ContingencyTables.entropyConditionedOnRows(getContingencyTable());
        }

        public Set<Long> getInstances(String x, String y) {
            return jointDistroTable.get(y).get(x);
        }

        public double getMutualInformation(double entropyY) {
            return entropyY + this.getEntropyX() - this.getEntropyXY();
        }
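
        // getInfoGain() and getMutualInformation() compute the same
        // information-theoretic quantity two ways: entropy-over-columns minus
        // entropy-conditioned-on-rows (via weka's ContingencyTables) versus the
        // identity I(X;Y) = H(Y) + H(X) - H(X,Y). Illustrative example: for the
        // contingency table {{40, 10}, {10, 40}} both come out to roughly
        // 1.0 - 0.72 = 0.28 bits.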

        /**
         * print out joint distribution table
         */
        public String toString() {
            StringBuilder b = new StringBuilder();
            b.append(this.getClass().getCanonicalName());
            b.append(" [jointDistro=(");
            Iterator<Entry<String, SortedMap<String, Set<Long>>>> yIter = this.jointDistroTable.entrySet()
                    .iterator();
            while (yIter.hasNext()) {
                Entry<String, SortedMap<String, Set<Long>>> yEntry = yIter.next();
                Iterator<Entry<String, Set<Long>>> xIter = yEntry.getValue().entrySet().iterator();
                while (xIter.hasNext()) {
                    Entry<String, Set<Long>> xEntry = xIter.next();
                    b.append(xEntry.getValue().size());
                    if (xIter.hasNext())
                        b.append(", ");
                }
                if (yIter.hasNext())
                    b.append("| ");
            }
            b.append(")]");
            return b.toString();
        }
    }

    /**
     * We are passing around quite a few parameters. It gets to be a pain, so
     * put everything in an object.
     * 
     * @author vijay
     * 
     */
    public static class Parameters {
        String classFeatureQuery;
        String conceptGraphName;
        String conceptSetName;
        String corpusName;
        String freqQuery;
        double imputeWeight;

        String labelQuery;
        MeasureType measure;
        double minInfo;
        Double parentConceptEvalThreshold;
        Integer parentConceptTopThreshold;
        String splitName;
        String xLeftover;
        String xMerge;
        Set<String> xVals;

        public Parameters() {

        }

        public Parameters(Properties props) {
            corpusName = props.getProperty("org.apache.ctakes.ytex.corpusName");
            conceptGraphName = props.getProperty("org.apache.ctakes.ytex.conceptGraphName");
            conceptSetName = props.getProperty("org.apache.ctakes.ytex.conceptSetName");
            splitName = props.getProperty("org.apache.ctakes.ytex.splitName");
            labelQuery = props.getProperty("instanceClassQuery");
            classFeatureQuery = props.getProperty("org.apache.ctakes.ytex.conceptInstanceQuery");
            freqQuery = props.getProperty("org.apache.ctakes.ytex.freqQuery");
            minInfo = Double.parseDouble(props.getProperty("min.info", "1e-4"));
            String xValStr = props.getProperty("org.apache.ctakes.ytex.xVals", "0,1");
            xVals = new HashSet<String>();
            xVals.addAll(Arrays.asList(xValStr.split(",")));
            xLeftover = props.getProperty("org.apache.ctakes.ytex.xLeftover", "0");
            xMerge = props.getProperty("org.apache.ctakes.ytex.xMerge", "1");
            this.measure = MeasureType.valueOf(props.getProperty("org.apache.ctakes.ytex.measure", "INFOGAIN"));
            parentConceptEvalThreshold = FileUtil.getDoubleProperty(props,
                    "org.apache.ctakes.ytex.parentConceptEvalThreshold", null);
            parentConceptTopThreshold = parentConceptEvalThreshold == null
                    ? FileUtil.getIntegerProperty(props, "org.apache.ctakes.ytex.parentConceptTopThreshold", 25)
                    : null;
            imputeWeight = FileUtil.getDoubleProperty(props, "org.apache.ctakes.ytex.imputeWeight", 1d);
        }

        public String getClassFeatureQuery() {
            return classFeatureQuery;
        }

        public String getConceptGraphName() {
            return conceptGraphName;
        }

        public String getConceptSetName() {
            return conceptSetName;
        }

        public String getCorpusName() {
            return corpusName;
        }

        public String getFreqQuery() {
            return freqQuery;
        }

        public double getImputeWeight() {
            return imputeWeight;
        }

        public String getLabelQuery() {
            return labelQuery;
        }

        public MeasureType getMeasure() {
            return measure;
        }

        public double getMinInfo() {
            return minInfo;
        }

        public Double getParentConceptEvalThreshold() {
            return parentConceptEvalThreshold;
        }

        public Integer getParentConceptTopThreshold() {
            return parentConceptTopThreshold;
        }

        public String getSplitName() {
            return splitName;
        }

        public String getxLeftover() {
            return xLeftover;
        }

        public String getxMerge() {
            return xMerge;
        }

        public Set<String> getxVals() {
            return xVals;
        }
    }
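
    // Default values applied by Parameters(Properties) when a key is absent:
    // xVals = {"0", "1"}, xLeftover = "0", xMerge = "1", measure = INFOGAIN,
    // min.info = 1e-4, imputeWeight = 1.0, and parentConceptTopThreshold = 25
    // (used only when parentConceptEvalThreshold is not set).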

    // /**
    // * iterates through query results and computes infogain
    // *
    // * @author vijay
    // *
    // */
    // public class JointDistroExtractor implements RowCallbackHandler {
    // /**
    // * key - fold
    // * <p/>
    // * value - map of concept id - joint distribution
    // */
    // private Map<String, JointDistribution> jointDistroMap;
    // private Set<String> xVals;
    // private Set<String> yVals;
    // private Map<Integer, String> instanceClassMap;
    //
    // public JointDistroExtractor(
    // Map<String, JointDistribution> jointDistroMap,
    // Set<String> xVals, Set<String> yVals,
    // Map<Integer, String> instanceClassMap) {
    // super();
    // this.xVals = xVals;
    // this.yVals = yVals;
    // this.jointDistroMap = jointDistroMap;
    // this.instanceClassMap = instanceClassMap;
    // }
    //
    // public void processRow(ResultSet rs) throws SQLException {
    // int instanceId = rs.getInt(1);
    // String conceptId = rs.getString(2);
    // String x = rs.getString(3);
    // String y = instanceClassMap.get(instanceId);
    // JointDistribution distro = jointDistroMap.get(conceptId);
    // if (distro == null) {
    // distro = new JointDistribution(xVals, yVals);
    // jointDistroMap.put(conceptId, distro);
    // }
    // distro.addInstance(x, y, instanceId);
    // }
    // }

    private static final Log log = LogFactory.getLog(ImputedFeatureEvaluatorImpl.class);

    protected static double entropy(double[] classProbs) {
        double entropy = 0;
        double log2 = Math.log(2);
        for (double prob : classProbs) {
            if (prob > 0)
                entropy += prob * Math.log(prob) / log2;
        }
        return entropy * -1;
    }

    /**
     * calculate entropy from a list/array of probabilities
     * 
     * @param classProbs
     * @return
     */
    protected static double entropy(Iterable<Double> classProbs) {
        double entropy = 0;
        double log2 = Math.log(2);
        for (double prob : classProbs) {
            if (prob > 0)
                entropy += prob * Math.log(prob) / log2;
        }
        return entropy * -1;
    }
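
    // Example: entropy(new double[] { 0.5, 0.5 }) = 1.0 bit,
    // entropy(new double[] { 1.0 }) = 0.0; zero-probability entries are skipped.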

    @SuppressWarnings("static-access")
    public static void main(String args[]) throws ParseException, IOException {
        Options options = new Options();
        options.addOption(OptionBuilder.withArgName("prop").hasArg().isRequired()
                .withDescription("property file with queries and other parameters. todo desc").create("prop"));
        try {
            CommandLineParser parser = new GnuParser();
            CommandLine line = parser.parse(options, args);
            if (!KernelContextHolder.getApplicationContext().getBean(ImputedFeatureEvaluator.class)
                    .evaluateCorpus(line.getOptionValue("prop"))) {
                printHelp(options);
            }
        } catch (ParseException pe) {
            printHelp(options);
        }
    }

    private static void printHelp(Options options) {
        HelpFormatter formatter = new HelpFormatter();
        formatter.printHelp("java " + ImputedFeatureEvaluatorImpl.class.getName()
                + " calculate raw, propagated, and imputed infogain for each feature", options);
    }

    protected ClassifierEvaluationDao classifierEvaluationDao;

    protected ConceptDao conceptDao;

    private InfoContentEvaluator infoContentEvaluator;

    protected JdbcTemplate jdbcTemplate;

    protected KernelUtil kernelUtil;
    protected NamedParameterJdbcTemplate namedParamJdbcTemplate;

    protected PlatformTransactionManager transactionManager;

    private Properties ytexProperties = null;

    /**
     * recursively add children of cr to childConcepts
     * 
     * @param childConcepts
     * @param cr
     */
    private void addSubtree(Set<String> childConcepts, ConcRel cr) {
        childConcepts.add(cr.getConceptID());
        for (ConcRel crc : cr.getChildren()) {
            addSubtree(childConcepts, crc);
        }
    }

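    /**
     * recursively compute the propagated joint distribution for the concept cr:
     * merge cr's own raw joint distribution (if the concept occurs in the
     * corpus) with the propagated distributions of all of its children. Results
     * are memoized in conceptJointDistroMap; conceptDistMap records, for each
     * concept, the minimum number of edges to a concept with raw corpus data.
     * 
     * @return the merged joint distribution, or null if neither cr nor any
     *         descendant occurs in the corpus
     */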
    private JointDistribution calcMergedJointDistribution(Map<String, JointDistribution> conceptJointDistroMap,
            Map<String, Integer> conceptDistMap, ConcRel cr, Map<String, JointDistribution> rawJointDistroMap,
            Map<String, Set<Long>> yMargin, String xMerge, double minInfo, List<String> path) {
        if (conceptJointDistroMap.containsKey(cr.getConceptID())) {
            return conceptJointDistroMap.get(cr.getConceptID());
        } else {
            List<JointDistribution> distroList = new ArrayList<JointDistribution>(cr.getChildren().size() + 1);
            int distance = -1;
            // if this concept is in the raw joint distro map, add it to the
            // list of joint distributions to merge
            if (rawJointDistroMap.containsKey(cr.getConceptID())) {
                JointDistribution rawJointDistro = rawJointDistroMap.get(cr.getConceptID());
                distroList.add(rawJointDistro);
                distance = 0;
            }
            // get the joint distributions of children
            for (ConcRel crc : cr.getChildren()) {
                List<String> pathChild = new ArrayList<String>(path.size() + 1);
                pathChild.addAll(path);
                pathChild.add(crc.getConceptID());
                // recurse - get joint distribution of children
                JointDistribution jdChild = calcMergedJointDistribution(conceptJointDistroMap, conceptDistMap, crc,
                        rawJointDistroMap, yMargin, xMerge, minInfo, pathChild);
                if (jdChild != null) {
                    distroList.add(jdChild);
                    if (distance != 0) {
                        // look at children's distance from raw data, add 1
                        int distChild = conceptDistMap.get(crc.getConceptID());
                        if (distance == -1 || (distChild + 1) < distance) {
                            distance = distChild + 1;
                        }
                    }
                }
            }
            // merge the joint distributions
            JointDistribution mergedDistro;
            if (distroList.size() > 0) {
                if (distroList.size() == 1) {
                    // only one joint distro - trivial merge
                    mergedDistro = distroList.get(0);
                } else {
                    // multiple joint distros - merge them into a new one
                    mergedDistro = JointDistribution.merge(distroList, yMargin, xMerge);
                }
                // if (log.isDebugEnabled()) {
                // log.debug("path = " + path + ", distroList = " + distroList
                // + ", distro = " + mergedDistro);
                // }
            } else {
                // no joint distros to merge - null
                mergedDistro = null;
            }
            // save this in the map
            conceptJointDistroMap.put(cr.getConceptID(), mergedDistro);
            if (distance > -1)
                conceptDistMap.put(cr.getConceptID(), distance);
            return mergedDistro;
        }
    }

    /**
     * calculate the entropy of the class (y) distribution for a fold.
     * 
     * @param classCountMap
     *            map of class value - instance ids in this fold
     * @return entropy of the per-class probabilities
     */
    private double calculateFoldEntropy(Map<String, Set<Long>> classCountMap) {
        int total = 0;
        List<Double> classProbs = new ArrayList<Double>(classCountMap.size());
        // calculate total number of instances in this fold
        for (Set<Long> instances : classCountMap.values()) {
            total += instances.size();
        }
        // calculate per-class probability in this fold
        for (Set<Long> instances : classCountMap.values()) {
            classProbs.add((double) instances.size() / (double) total);
        }
        return entropy(classProbs);
    }

    /**
     * finalize the joint distribution tables wrt a fold.
     * 
     * @param jointDistroMap
     * @param yMargin
     * @param yVals
     * @param xVals
     * @param xLeftover
     */
    private Map<String, JointDistribution> completeJointDistroForFold(
            Map<String, Map<String, Set<Long>>> conceptInstanceMap, Map<String, Set<Long>> yMargin,
            Set<String> xVals, Set<String> yVals, String xLeftover) {
        //
        Map<String, JointDistribution> foldJointDistroMap = new HashMap<String, JointDistribution>(
                conceptInstanceMap.size());
        for (Map.Entry<String, Map<String, Set<Long>>> conceptInstance : conceptInstanceMap.entrySet()) {
            foldJointDistroMap.put(conceptInstance.getKey(),
                    new JointDistribution(xVals, yVals, conceptInstance.getValue(), yMargin, xLeftover));
        }
        return foldJointDistroMap;
    }

    /**
     * delete the feature evaluations before we insert them
     * 
     * @param params
     * @param label
     * @param foldId
     */
    private void deleteFeatureEval(Parameters params, String label, int foldId) {

        for (String type : new String[] { params.getMeasure().getName(),
                params.getMeasure().getName() + SUFFIX_PROP, params.getMeasure().getName() + SUFFIX_IMPUTED,
                params.getMeasure().getName() + SUFFIX_IMPUTED_FILTERED })
            this.classifierEvaluationDao.deleteFeatureEvaluation(params.getCorpusName(), params.getConceptSetName(),
                    label, type, foldId, 0d, params.getConceptGraphName());
    }

    /*
     * (non-Javadoc)
     * 
     * @see org.apache.ctakes.ytex.kernel.CorpusLabelEvaluator#evaluateCorpus(java.lang.String,
     * java.lang.String, java.lang.String, java.lang.String, java.lang.String,
     * java.lang.Double, java.util.Set, java.lang.String, java.lang.String)
     */
    @Override
    public boolean evaluateCorpus(final Parameters params) {
        if (!(params.getCorpusName() != null && params.getConceptGraphName() != null
                && params.getLabelQuery() != null && params.getClassFeatureQuery() != null))
            return false;
        ConceptGraph cg = conceptDao.getConceptGraph(params.getConceptGraphName());
        InstanceData instanceData = kernelUtil.loadInstances(params.getLabelQuery());
        for (String label : instanceData.getLabelToInstanceMap().keySet()) {
            evaluateCorpusLabel(params, cg, instanceData, label);
        }
        return true;
    }

    @Override
    public boolean evaluateCorpus(String propFile) throws IOException {
        Properties props = new Properties();
        // put org.apache.ctakes.ytex properties in props
        props.putAll(this.getYtexProperties());
        // override org.apache.ctakes.ytex properties with propfile
        props.putAll(FileUtil.loadProperties(propFile, true));
        return this.evaluateCorpus(new Parameters(props));
    }

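    /**
     * evaluate one cross-validation fold for a label: delete any existing
     * feature evaluations, build the raw per-concept joint distributions from
     * the fold's training instances, save the raw scores, propagate them across
     * the concept graph, and impute scores to the children of the top-ranked
     * parent concepts, once for all concepts with info content and once
     * restricted to concepts that actually occur in the corpus.
     */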
    private void evaluateCorpusFold(Parameters params, Map<String, Set<Long>> yMargin, ConceptGraph cg,
            InstanceData instanceData, String label, Map<String, Map<String, Set<Long>>> conceptInstanceMap,
            int foldId) {
        if (log.isInfoEnabled())
            log.info("evaluateCorpusFold() label = " + label + ", fold = " + foldId);
        deleteFeatureEval(params, label, foldId);

        // get the entropy of Y for this fold
        double yEntropy = this.calculateFoldEntropy(yMargin);
        // get the joint distribution of concepts and instances
        Map<String, JointDistribution> rawJointDistro = this.completeJointDistroForFold(conceptInstanceMap, yMargin,
                params.getxVals(), instanceData.getLabelToClassMap().get(label), params.getxLeftover());
        List<FeatureRank> listRawRanks = new ArrayList<FeatureRank>(rawJointDistro.size());
        FeatureEvaluation feRaw = saveFeatureEvaluation(rawJointDistro, params, label, foldId, yEntropy, "",
                listRawRanks);
        // propagate across graph and save
        propagateJointDistribution(rawJointDistro, params, label, foldId, cg, yMargin);
        // store children of top concepts
        storeChildConcepts(listRawRanks, params, label, foldId, cg, true);
        storeChildConcepts(listRawRanks, params, label, foldId, cg, false);
    }

    /**
     * evaluate the corpus for the specified label: load the concept - instance
     * map and evaluate each run/fold.
     * 
     * @param params
     *            evaluation parameters (queries, thresholds, bins)
     * @param cg
     *            concept graph
     * @param instanceData
     *            instances per label/run/fold
     * @param label
     *            the label (classification task) to evaluate
     */
    private void evaluateCorpusLabel(Parameters params, ConceptGraph cg, InstanceData instanceData, String label) {
        if (log.isInfoEnabled())
            log.info("evaluateCorpusLabel() label = " + label);
        Map<String, Map<String, Set<Long>>> conceptInstanceMap = loadConceptInstanceMap(
                params.getClassFeatureQuery(), cg, label);
        for (int run : instanceData.getLabelToInstanceMap().get(label).keySet()) {
            for (int fold : instanceData.getLabelToInstanceMap().get(label).get(run).keySet()) {
                int foldId = this.getFoldId(params, label, run, fold);
                // evaluate for the specified fold training set
                // construct map of class - [instance ids]
                Map<String, Set<Long>> yMargin = getFoldYMargin(instanceData, label, run, fold);
                evaluateCorpusFold(params, yMargin, cg, instanceData, label, conceptInstanceMap, foldId);
            }
        }
    }

    public ClassifierEvaluationDao getClassifierEvaluationDao() {
        return classifierEvaluationDao;
    }

    public ConceptDao getConceptDao() {
        return conceptDao;
    }

    public DataSource getDataSource() {
        return this.jdbcTemplate.getDataSource();
    }

    private int getFoldId(Parameters params, String label, int run, int fold) {
        // figure out fold id
        int foldId = 0;
        if (run > 0 && fold > 0) {
            CrossValidationFold cvFold = this.classifierEvaluationDao.getCrossValidationFold(params.getCorpusName(),
                    params.getSplitName(), label, run, fold);
            if (cvFold != null) {
                foldId = cvFold.getCrossValidationFoldId();
            } else {
                log.warn("could not find cv fold, name=" + params.getCorpusName() + ", run=" + run + ", fold="
                        + fold);
            }
        }
        return foldId;
    }

    private Map<String, Set<Long>> getFoldYMargin(InstanceData instanceData, String label, int run, int fold) {
        Map<Long, String> instanceClassMap = instanceData.getLabelToInstanceMap().get(label).get(run).get(fold)
                .get(true);
        Map<String, Set<Long>> yMargin = new HashMap<String, Set<Long>>();
        for (Map.Entry<Long, String> instanceClass : instanceClassMap.entrySet()) {
            Set<Long> instanceIds = yMargin.get(instanceClass.getValue());
            if (instanceIds == null) {
                instanceIds = new HashSet<Long>();
                yMargin.put(instanceClass.getValue(), instanceIds);
            }
            instanceIds.add(instanceClass.getKey());
        }
        return yMargin;
    }

    public InfoContentEvaluator getInfoContentEvaluator() {
        return infoContentEvaluator;
    }

    public KernelUtil getKernelUtil() {
        return kernelUtil;
    }

    public Properties getYtexProperties() {
        return ytexProperties;
    }

    private FeatureEvaluation initFeatureEval(Parameters params, String label, int foldId, String type) {
        FeatureEvaluation feval = new FeatureEvaluation();
        feval.setCorpusName(params.getCorpusName());
        feval.setLabel(label);
        feval.setCrossValidationFoldId(foldId);
        feval.setParam2(params.getConceptGraphName());
        feval.setEvaluationType(type);
        feval.setFeatureSetName(params.getConceptSetName());
        return feval;
    }

    /**
     * load the map of concept id - bin - instance ids.
     * 
     * @param classFeatureQuery
     *            query returning (concept id, instance id, x value) rows; may
     *            reference the named parameter :label
     * @param cg
     *            concept graph
     * @param label
     *            label, bound to the :label query parameter if non-empty
     * @return map of concept id - x bin - instance ids
     */
    private Map<String, Map<String, Set<Long>>> loadConceptInstanceMap(String classFeatureQuery, ConceptGraph cg,
            String label) {
        Map<String, Map<String, Set<Long>>> conceptInstanceMap = new HashMap<String, Map<String, Set<Long>>>();
        Map<String, Object> args = new HashMap<String, Object>(1);
        if (label != null && label.length() > 0) {
            args.put("label", label);
        }
        ConceptInstanceMapExtractor ex = new ConceptInstanceMapExtractor(conceptInstanceMap, cg);
        this.namedParamJdbcTemplate.query(classFeatureQuery, args, ex);
        return conceptInstanceMap;
    }

    /**
     * propagate the joint distribution of all concepts recursively across the
     * concept graph and save the resulting feature evaluation.
     * 
     * @param rawJointDistroMap
     *            raw per-concept joint distributions for this fold
     * @param params
     *            evaluation parameters
     * @param label
     *            label being evaluated
     * @param foldId
     *            cross validation fold id, 0 if none
     * @param cg
     *            concept graph
     * @param yMargin
     *            map of class value - instance ids for this fold
     * @return the saved feature evaluation record
     */
    private FeatureEvaluation propagateJointDistribution(Map<String, JointDistribution> rawJointDistroMap,
            Parameters params, String label, int foldId, ConceptGraph cg, Map<String, Set<Long>> yMargin) {
        // get the entropy of Y for this fold
        double yEntropy = this.calculateFoldEntropy(yMargin);
        // allocate a map to hold the results of the propagation across the
        // concept graph
        Map<String, JointDistribution> conceptJointDistroMap = new HashMap<String, JointDistribution>(
                cg.getConceptMap().size());
        Map<String, Integer> conceptDistMap = new HashMap<String, Integer>();
        // recurse
        calcMergedJointDistribution(conceptJointDistroMap, conceptDistMap, cg.getConceptMap().get(cg.getRoot()),
                rawJointDistroMap, yMargin, params.getxMerge(), params.getMinInfo(),
                Arrays.asList(new String[] { cg.getRoot() }));
        List<FeatureRank> listPropRanks = new ArrayList<FeatureRank>(conceptJointDistroMap.size());
        return this.saveFeatureEvaluation(conceptJointDistroMap, params, label, foldId, yEntropy, SUFFIX_PROP,
                listPropRanks);
    }

    private List<FeatureRank> rank(MeasureType measureType, FeatureEvaluation fe,
            Map<String, JointDistribution> rawJointDistro, double yEntropy, List<FeatureRank> featureRankList) {
        for (Map.Entry<String, JointDistribution> conceptJointDistro : rawJointDistro.entrySet()) {
            JointDistribution d = conceptJointDistro.getValue();
            if (d != null) {
                double evaluation;
                if (MeasureType.MUTUALINFO.equals(measureType)) {
                    evaluation = d.getMutualInformation(yEntropy);
                } else {
                    evaluation = d.getInfoGain();
                }
                if (evaluation > 1e-3) {
                    FeatureRank r = new FeatureRank(fe, conceptJointDistro.getKey(), evaluation);
                    featureRankList.add(r);
                }
            }
        }
        return FeatureRank.sortFeatureRankList(featureRankList, new FeatureRank.FeatureRankDesc());
    }

    private FeatureEvaluation saveFeatureEvaluation(Map<String, JointDistribution> rawJointDistro,
            Parameters params, String label, int foldId, double yEntropy, String suffix,
            List<FeatureRank> listRawRanks) {
        FeatureEvaluation fe = initFeatureEval(params, label, foldId, params.getMeasure().getName() + suffix);
        this.classifierEvaluationDao.saveFeatureEvaluation(fe,
                rank(params.getMeasure(), fe, rawJointDistro, yEntropy, listRawRanks));
        return fe;
    }

    public void setClassifierEvaluationDao(ClassifierEvaluationDao classifierEvaluationDao) {
        this.classifierEvaluationDao = classifierEvaluationDao;
    }

    public void setConceptDao(ConceptDao conceptDao) {
        this.conceptDao = conceptDao;
    }

    public void setDataSource(DataSource ds) {
        this.jdbcTemplate = new JdbcTemplate(ds);
        this.namedParamJdbcTemplate = new NamedParameterJdbcTemplate(ds);
    }

    // private CorpusLabelEvaluation initCorpusLabelEval(CorpusEvaluation eval,
    // String label, String splitName, int run, int fold) {
    // Integer foldId = getFoldId(eval, label, splitName, run, fold);
    // // see if the labelEval is already there
    // CorpusLabelEvaluation labelEval = corpusDao.getCorpusLabelEvaluation(
    // eval.getCorpusName(), eval.getConceptGraphName(),
    // eval.getConceptSetName(), label, foldId);
    // if (labelEval == null) {
    // // not there - add it
    // labelEval = new CorpusLabelEvaluation();
    // labelEval.setCorpus(eval);
    // labelEval.setFoldId(foldId);
    // labelEval.setLabel(label);
    // corpusDao.addCorpusLabelEval(labelEval);
    // }
    // return labelEval;
    // }

    public void setInfoContentEvaluator(InfoContentEvaluator infoContentEvaluator) {
        this.infoContentEvaluator = infoContentEvaluator;
    }

    // /**
    // * create the corpusEvaluation if it doesn't exist
    // *
    // * @param corpusName
    // * @param conceptGraphName
    // * @param conceptSetName
    // * @return
    // */
    // private CorpusEvaluation initEval(String corpusName,
    // String conceptGraphName, String conceptSetName) {
    // CorpusEvaluation eval = this.corpusDao.getCorpus(corpusName,
    // conceptGraphName, conceptSetName);
    // if (eval == null) {
    // eval = new CorpusEvaluation();
    // eval.setConceptGraphName(conceptGraphName);
    // eval.setConceptSetName(conceptSetName);
    // eval.setCorpusName(corpusName);
    // this.corpusDao.addCorpus(eval);
    // }
    // return eval;
    // }

    public void setKernelUtil(KernelUtil kernelUtil) {
        this.kernelUtil = kernelUtil;
    }

    //
    // private void saveLabelStatistic(String conceptID,
    // JointDistribution distroMerged, JointDistribution distroRaw,
    // CorpusLabelEvaluation labelEval, double yEntropy, double minInfo,
    // int distance) {
    // double miMerged = distroMerged.getMutualInformation(yEntropy);
    // double miRaw = distroRaw != null ? distroRaw
    // .getMutualInformation(yEntropy) : 0;
    // if (miMerged > minInfo || miRaw > minInfo) {
    // ConceptLabelStatistic stat = new ConceptLabelStatistic();
    // stat.setCorpusLabel(labelEval);
    // stat.setMutualInfo(miMerged);
    // if (distroRaw != null)
    // stat.setMutualInfoRaw(miRaw);
    // stat.setConceptId(conceptID);
    // stat.setDistance(distance);
    // this.corpusDao.addLabelStatistic(stat);
    // }
    // }

    public void setYtexProperties(Properties ytexProperties) {
        this.ytexProperties = ytexProperties;
    }

    /**
     * save the children of the 'top' parent concepts.
     * 
     * @param listRawRanks
     *            raw feature ranks for this fold; the imputed score is a mixture
     *            of the parent's propagated score and the child's raw score
     * @param params
     *            evaluation parameters (thresholds, impute weight)
     * @param label
     *            label being evaluated
     * @param foldId
     *            cross validation fold id, 0 if none
     * @param cg
     *            concept graph
     * @param bAll
     *            impute to all concepts vs. concepts actually in the corpus. If
     *            we are imputing to all concepts, filter by infocontent (this
     *            includes hypernyms of concepts in the corpus); else only impute
     *            to concrete concepts in the corpus.
     */
    public void storeChildConcepts(List<FeatureRank> listRawRanks, Parameters params, String label, int foldId,
            ConceptGraph cg, boolean bAll) {
        // only include concepts that actually occur in the corpus
        Map<String, Double> conceptICMap = bAll
                ? classifierEvaluationDao.getInfoContent(params.getCorpusName(), params.getConceptGraphName(),
                        params.getConceptSetName())
                : this.infoContentEvaluator.getFrequencies(params.getFreqQuery());
        // get the raw feature evaluations. The imputed feature evaluation is a
        // mixture of the parent feature eval and the raw feature eval.
        Map<String, Double> conceptRawEvalMap = new HashMap<String, Double>(listRawRanks.size());
        for (FeatureRank r : listRawRanks) {
            conceptRawEvalMap.put(r.getFeatureName(), r.getEvaluation());
        }
        // this map will get filled with the links between parent and child
        // concepts for imputation
        Map<FeatureRank, Set<FeatureRank>> childParentMap = bAll ? null
                : new HashMap<FeatureRank, Set<FeatureRank>>();
        // .getFeatureRankEvaluations(params.getCorpusName(),
        // params.getConceptSetName(), null,
        // InfoContentEvaluator.INFOCONTENT, 0,
        // params.getConceptGraphName());
        // get the top parent concepts - use either top N, or those with a
        // cutoff greater than the specified threshold
        // List<ConceptLabelStatistic> listConceptStat =
        // parentConceptTopThreshold != null ? this.corpusDao
        // .getTopCorpusLabelStat(labelEval, parentConceptTopThreshold)
        // : this.corpusDao.getThresholdCorpusLabelStat(labelEval,
        // parentConceptMutualInfoThreshold);
        String propagatedType = params.getMeasure().getName() + SUFFIX_PROP;
        List<FeatureRank> listConceptStat = params.getParentConceptTopThreshold() != null
                ? this.classifierEvaluationDao.getTopFeatures(params.getCorpusName(), params.getConceptSetName(),
                        label, propagatedType, foldId, 0, params.getConceptGraphName(),
                        params.getParentConceptTopThreshold())
                : this.classifierEvaluationDao.getThresholdFeatures(params.getCorpusName(),
                        params.getConceptSetName(), label, propagatedType, foldId, 0, params.getConceptGraphName(),
                        params.getParentConceptEvalThreshold());
        FeatureEvaluation fe = this.initFeatureEval(params, label, foldId,
                params.getMeasure().getName() + (bAll ? SUFFIX_IMPUTED : SUFFIX_IMPUTED_FILTERED));
        // map of concept id to children and the 'best' statistic
        Map<String, FeatureRank> mapChildConcept = new HashMap<String, FeatureRank>();
        // get all the children of the parent concepts
        for (FeatureRank parentConcept : listConceptStat) {
            updateChildren(parentConcept, mapChildConcept, fe, cg, conceptICMap, conceptRawEvalMap, childParentMap,
                    params.getImputeWeight(), params.getMinInfo());
        }
        // save the imputed feature ranks
        List<FeatureRank> features = new ArrayList<FeatureRank>(mapChildConcept.values());
        FeatureRank.sortFeatureRankList(features, new FeatureRank.FeatureRankDesc());
        this.classifierEvaluationDao.saveFeatureEvaluation(fe, features);
        if (!bAll) {
            // save the parent-child links
            for (Map.Entry<FeatureRank, Set<FeatureRank>> childParentEntry : childParentMap.entrySet()) {
                FeatureRank child = childParentEntry.getKey();
                for (FeatureRank parent : childParentEntry.getValue()) {
                    FeatureParentChild parchd = new FeatureParentChild();
                    parchd.setFeatureRankParent(parent);
                    parchd.setFeatureRankChild(child);
                    this.classifierEvaluationDao.saveFeatureParentChild(parchd);
                }
            }
        }
    }

    /**
     * add the children of parentConcept to mapChildConcept. Assign each child
     * the best (highest) imputed score among its top-ranked parent concepts.
     * 
     * @param parentConcept
     *            top-ranked parent concept
     * @param mapChildConcept
     *            map of child concept id - imputed feature rank, updated in place
     * @param fe
     *            feature evaluation the imputed ranks belong to
     * @param cg
     *            concept graph
     * @param conceptICMap
     *            concepts eligible for imputation (with info content / corpus
     *            frequency)
     * @param conceptRawEvalMap
     *            raw per-concept scores
     * @param childParentMap
     *            if not null, filled with child - parent links for later storage
     * @param imputeWeight
     *            weight of the parent's propagated score vs. the child's raw
     *            score in the imputed score
     * @param minInfo
     *            raw score used for children that have no raw evaluation
     */
    private void updateChildren(FeatureRank parentConcept, Map<String, FeatureRank> mapChildConcept,
            FeatureEvaluation fe, ConceptGraph cg, Map<String, Double> conceptICMap,
            Map<String, Double> conceptRawEvalMap, Map<FeatureRank, Set<FeatureRank>> childParentMap,
            double imputeWeight, double minInfo) {
        ConcRel cr = cg.getConceptMap().get(parentConcept.getFeatureName());
        Set<String> childConcepts = new HashSet<String>();
        addSubtree(childConcepts, cr);
        for (String childConceptId : childConcepts) {
            // only add the child to the map if it exists in the corpus
            if (conceptICMap.containsKey(childConceptId)) {
                FeatureRank chd = mapChildConcept.get(childConceptId);
                // create the child if it does not already exist
                if (chd == null) {
                    chd = new FeatureRank(fe, childConceptId, 0d);
                    mapChildConcept.put(childConceptId, chd);
                }
                // give the child the mutual info of the parent with the highest
                // score
                double rawEvaluation = conceptRawEvalMap.containsKey(childConceptId)
                        ? conceptRawEvalMap.get(childConceptId)
                        : minInfo;
                double imputedEvaluation = (imputeWeight * parentConcept.getEvaluation())
                        + ((1 - imputeWeight) * rawEvaluation);
                if (chd.getEvaluation() < imputedEvaluation) {
                    chd.setEvaluation(imputedEvaluation);
                }
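                // e.g. with the default imputeWeight of 1.0 the child simply
                // inherits the parent's propagated score; with 0.5 it gets the
                // average of the parent's score and its own raw score.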
                // add the relationship to the parentChildMap
                // do this only if the childParentMap is not null
                if (childParentMap != null) {
                    Set<FeatureRank> parents = childParentMap.get(chd);
                    if (parents == null) {
                        parents = new HashSet<FeatureRank>(10);
                        childParentMap.put(chd, parents);
                    }
                    parents.add(parentConcept);
                }
            }
        }
    }
}
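
Usage

A minimal sketch of how this evaluator is typically driven. The property keys below are the ones read by Parameters(Properties), and -prop is the option required by main(); the values and query bodies are illustrative placeholders.

# corpus.properties (illustrative values)
org.apache.ctakes.ytex.corpusName=myCorpus
org.apache.ctakes.ytex.conceptGraphName=myConceptGraph
org.apache.ctakes.ytex.conceptSetName=myConceptSet
org.apache.ctakes.ytex.splitName=mySplit
instanceClassQuery=... site-specific SQL mapping instances to classes ...
org.apache.ctakes.ytex.conceptInstanceQuery=... site-specific SQL returning (concept id, instance id, x value) rows ...
org.apache.ctakes.ytex.freqQuery=... site-specific SQL returning concept frequencies ...

java org.apache.ctakes.ytex.kernel.ImputedFeatureEvaluatorImpl -prop corpus.properties

Programmatically, the same entry point is available through the Spring context, exactly as main() does:

KernelContextHolder.getApplicationContext().getBean(ImputedFeatureEvaluator.class).evaluateCorpus("corpus.properties");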