org.apache.ctakes.ytex.kernel.FoldGeneratorImpl.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.ctakes.ytex.kernel.FoldGeneratorImpl.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.ctakes.ytex.kernel;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Random;
import java.util.Set;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.ctakes.ytex.kernel.dao.ClassifierEvaluationDao;
import org.apache.ctakes.ytex.kernel.model.CrossValidationFold;
import org.apache.ctakes.ytex.kernel.model.CrossValidationFoldInstance;

/**
 * utility generates cv fold splits, stores in db. Takes as a command line
 * parameter -prop [property file]. Also reads properties from org.apache.ctakes.ytex.properties.
 * Required properties:
 * <ul>
 * <li>org.apache.ctakes.ytex.corpusName
 * <li>instanceClassQuery
 * </ul>
 * 
 * Optional properties:
 * <ul>
 * <li>minPerClass default 1 minimum number of instances per class/fold. if not
 * enough instances of a specific class, the instance will be repeated across
 * folds. E.g. if you have only one example of one class, and 2 folds, that one
 * example will be duplicated in both folds.
 * <li>rand random number seed, defaults to current time millis
 * <li>org.apache.ctakes.ytex.splitName default null cv_fold.split_name
 * <li>folds default 2
 * <li>runs default 5
 * </ul>
 * 
 * @author vijay
 */
public class FoldGeneratorImpl implements FoldGenerator {
    private static final Log log = LogFactory.getLog(FoldGeneratorImpl.class);

    /**
     * iterate through the labels, split instances into folds
     * 
     * @param mapClassToInstanceId
     * @param nFolds
     * @param nMinPerClass
     * @param nSeed
     * @return list with nFolds sets of instance ids corresponding to the folds
     */
    private static List<Set<Long>> createFolds(Map<String, List<Long>> mapClassToInstanceId, int nFolds,
            int nMinPerClass, Random r) {
        List<Set<Long>> folds = new ArrayList<Set<Long>>(nFolds);
        Map<String, List<Set<Long>>> mapLabelFolds = new HashMap<String, List<Set<Long>>>();
        for (Map.Entry<String, List<Long>> classToInstanceId : mapClassToInstanceId.entrySet()) {
            List<Long> instanceIds = classToInstanceId.getValue();
            Collections.shuffle(instanceIds, r);
            List<Set<Long>> classFolds = new ArrayList<Set<Long>>(nFolds);
            int blockSize = instanceIds.size() / nFolds;
            for (int i = 0; i < nFolds; i++) {
                Set<Long> foldInstanceIds = new HashSet<Long>(blockSize);
                if (instanceIds.size() <= nMinPerClass) {
                    // we don't have minPerClass for the given class
                    // just add all of them to each fold
                    foldInstanceIds.addAll(instanceIds);
                } else if (blockSize < nMinPerClass) {
                    // too few of the given class - just randomly select
                    // nMinPerClass
                    double fraction = (double) nMinPerClass / (double) instanceIds.size();
                    // iterate through the list, start somewhere in the middle
                    int instanceIdIndex = (int) (r.nextDouble() * instanceIds.size());
                    while (foldInstanceIds.size() < nMinPerClass) {
                        // go back to beginning of list if we hit the end
                        if (instanceIdIndex >= instanceIds.size()) {
                            instanceIdIndex = 0;
                        }
                        // randomly select this line
                        if (r.nextDouble() <= fraction) {
                            long instanceId = instanceIds.get(instanceIdIndex);
                            foldInstanceIds.add(instanceId);
                        }
                        // go to next line
                        instanceIdIndex++;
                    }
                } else {
                    int nStart = i * blockSize;
                    int nEnd = (i == nFolds - 1) ? instanceIds.size() : nStart + blockSize;
                    for (int instanceIdIndex = nStart; instanceIdIndex < nEnd; instanceIdIndex++) {
                        foldInstanceIds.add(instanceIds.get(instanceIdIndex));
                    }
                }
                classFolds.add(foldInstanceIds);
            }
            mapLabelFolds.put(classToInstanceId.getKey(), classFolds);
        }
        for (int i = 0; i < nFolds; i++) {
            Set<Long> foldInstanceIds = new HashSet<Long>();
            for (List<Set<Long>> labelFold : mapLabelFolds.values()) {
                foldInstanceIds.addAll(labelFold.get(i));
            }
            folds.add(foldInstanceIds);
        }
        return folds;
    }

    @SuppressWarnings("static-access")
    public static void main(String args[]) throws ParseException, IOException {
        Options options = new Options();
        options.addOption(OptionBuilder.withArgName("prop").hasArg()
                .withDescription("property file with query to retrieve instance id - label - class triples")
                .create("prop"));
        // OptionGroup group = new OptionGroup();
        // group
        // .addOption(OptionBuilder
        // .withArgName("query")
        // .hasArg()
        // .withDescription(
        // "query to retrieve instance id - label - class triples")
        // .create("query"));
        // group
        // .addOption(OptionBuilder
        // .withArgName("prop")
        // .hasArg()
        // .withDescription(
        // "property file with query to retrieve instance id - label - class triples")
        // .create("prop"));
        // group.isRequired();
        // options.addOptionGroup(group);
        // options.addOption(OptionBuilder.withArgName("name").hasArg()
        // .isRequired().withDescription("name. required").create("name"));
        // options.addOption(OptionBuilder.withArgName("runs").hasArg()
        // .withDescription("number of runs, default 1").create("runs"));
        // options.addOption(OptionBuilder.withArgName("folds").hasArg()
        // .withDescription("number of folds, default 4").create("folds"));
        // options.addOption(OptionBuilder.withArgName("minPerClass").hasArg()
        // .withDescription("minimum instances per class, default 1")
        // .create("minPerClass"));
        // options.addOption(OptionBuilder.withArgName("rand").hasArg()
        // .withDescription(
        // "random number seed; default current time in millis")
        // .create("rand"));
        try {
            if (args.length == 0)
                printHelp(options);
            else {
                CommandLineParser parser = new GnuParser();
                CommandLine line = parser.parse(options, args);
                String propFile = line.getOptionValue("prop");
                Properties props = FileUtil.loadProperties(propFile, true);
                // Integer rand = line.hasOption("rand") ? Integer.parseInt(line
                // .getOptionValue("rand")) : null;
                // int runs = Integer.parseInt(line.getOptionValue("runs",
                // "1"));
                // int minPerClass = Integer.parseInt(line.getOptionValue(
                // "minPerClass", "1"));
                // int folds = Integer.parseInt(line.getOptionValue("folds",
                // "4"));
                String corpusName = props.getProperty("org.apache.ctakes.ytex.corpusName");
                String splitName = props.getProperty("org.apache.ctakes.ytex.splitName");
                String query = props.getProperty("instanceClassQuery");
                int folds = Integer.parseInt(props.getProperty("folds", "2"));
                int runs = Integer.parseInt(props.getProperty("runs", "5"));
                int minPerClass = Integer.parseInt(props.getProperty("minPerClass", "1"));
                Integer rand = props.containsKey("rand") ? Integer.parseInt(props.getProperty("rand")) : null;
                boolean argsOk = true;
                if (corpusName == null) {
                    log.error("missing parameter: org.apache.ctakes.ytex.corpusName");
                    argsOk = false;
                }
                if (query == null) {
                    log.error("missing parameter: instanceClassQuery");
                    argsOk = false;
                }
                if (!argsOk) {
                    printHelp(options);
                    System.exit(1);
                } else {
                    KernelContextHolder.getApplicationContext().getBean(FoldGenerator.class)
                            .generateRuns(corpusName, splitName, query, folds, minPerClass, rand, runs);
                }
            }
        } catch (ParseException pe) {
            printHelp(options);
        }
    }

    private static void printHelp(Options options) {
        HelpFormatter formatter = new HelpFormatter();
        formatter.printHelp(
                "java org.apache.ctakes.ytex.kernel.FoldGeneratorImpl splits training data into mxn training/test sets for mxn-fold cross validation",
                options);
    }

    ClassifierEvaluationDao classifierEvaluationDao;

    KernelUtil kernelUtil;

    /**
     * generate folds for a run
     * 
     * @param labels
     * @param mapInstanceToClassLabel
     * @param name
     * @param splitName
     * @param run
     * @param query
     * @param nFolds
     * @param nMinPerClass
     * @param r
     */
    public void generateFolds(Set<String> labels, InstanceData instances, String corpusName, String splitName,
            int run, String query, int nFolds, int nMinPerClass, Random r) {
        for (String label : instances.getLabelToInstanceMap().keySet()) {
            // there should not be any runs/folds/train test split - just unpeel
            // until we get to the instance - class map
            SortedMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>> runMap = instances
                    .getLabelToInstanceMap().get(label);
            SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>> foldMap = runMap.values().iterator()
                    .next();
            SortedMap<Boolean, SortedMap<Long, String>> trainMap = foldMap.values().iterator().next();
            SortedMap<Long, String> mapInstanceIdToClass = trainMap.values().iterator().next();
            List<Set<Long>> folds = createFolds(nFolds, nMinPerClass, r, mapInstanceIdToClass);
            // insert the folds
            insertFolds(folds, corpusName, splitName, label, run);
        }
    }

    /**
     * inver the map of instance id to class, call createFolds
     * 
     * @param nFolds
     * @param nMinPerClass
     * @param r
     * @param mapInstanceIdToClass
     * @return
     */
    private List<Set<Long>> createFolds(int nFolds, int nMinPerClass, Random r,
            SortedMap<Long, String> mapInstanceIdToClass) {
        // invert the mapInstanceIdToClass
        Map<String, List<Long>> mapClassToInstanceId = new TreeMap<String, List<Long>>();
        for (Map.Entry<Long, String> instance : mapInstanceIdToClass.entrySet()) {
            String className = instance.getValue();
            long instanceId = instance.getKey();
            List<Long> classInstanceIds = mapClassToInstanceId.get(className);
            if (classInstanceIds == null) {
                classInstanceIds = new ArrayList<Long>();
                mapClassToInstanceId.put(className, classInstanceIds);
            }
            classInstanceIds.add(instanceId);
        }
        // stratified split into folds
        List<Set<Long>> folds = createFolds(mapClassToInstanceId, nFolds, nMinPerClass, r);
        return folds;
    }

    /*
     * (non-Javadoc)
     * 
     * @see org.apache.ctakes.ytex.kernel.FoldGenerator#generateRuns(java.lang.String,
     * java.lang.String, int, int, java.lang.Integer, int)
     */
    @Override
    public void generateRuns(String corpusName, String splitName, String query, int nFolds, int nMinPerClass,
            Integer nSeed, int nRuns) {
        Random r = new Random(nSeed != null ? nSeed : System.currentTimeMillis());
        SortedSet<String> labels = new TreeSet<String>();
        InstanceData instances = kernelUtil.loadInstances(query);
        this.getClassifierEvaluationDao().deleteCrossValidationFoldByName(corpusName, splitName);
        for (int run = 1; run <= nRuns; run++) {
            generateFolds(labels, instances, corpusName, splitName, run, query, nFolds, nMinPerClass, r);
        }
    }

    public ClassifierEvaluationDao getClassifierEvaluationDao() {
        return classifierEvaluationDao;
    }

    public KernelUtil getKernelUtil() {
        return kernelUtil;
    }

    /**
     * insert the folds into the database
     * 
     * @param folds
     * @param corpusName
     * @param run
     */
    private void insertFolds(List<Set<Long>> folds, String corpusName, String splitName, String label, int run) {
        // iterate over fold numbers
        for (int foldNum = 1; foldNum <= folds.size(); foldNum++) {
            Set<CrossValidationFoldInstance> instanceIds = new HashSet<CrossValidationFoldInstance>();
            // iterate over instances in each fold
            for (int trainFoldNum = 1; trainFoldNum <= folds.size(); trainFoldNum++) {
                // add the instance, set the train flag
                for (long instanceId : folds.get(trainFoldNum - 1))
                    instanceIds.add(new CrossValidationFoldInstance(instanceId, trainFoldNum != foldNum));
            }
            classifierEvaluationDao
                    .saveFold(new CrossValidationFold(corpusName, splitName, label, run, foldNum, instanceIds));
            // insert test set
            // classifierEvaluationDao.saveFold(new CrossValidationFold(name,
            // label, run, foldNum, false, folds.get(foldNum - 1)));
            // insert training set
            // Set<Integer> trainInstances = new TreeSet<Integer>();
            // for (int trainFoldNum = 1; trainFoldNum <= folds.size();
            // trainFoldNum++) {
            // if (trainFoldNum != foldNum)
            // trainInstances.addAll(folds.get(trainFoldNum - 1));
            // }
            // classifierEvaluationDao.saveFold(new CrossValidationFold(name,
            // label, run, foldNum, true, trainInstances));
        }
    }

    public void setClassifierEvaluationDao(ClassifierEvaluationDao classifierEvaluationDao) {
        this.classifierEvaluationDao = classifierEvaluationDao;
    }

    public void setKernelUtil(KernelUtil kernelUtil) {
        this.kernelUtil = kernelUtil;
    }

    @Override
    public SortedMap<String, SortedMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>>> generateRuns(
            SortedMap<String, SortedMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>>> labelToInstanceMap,
            int nFolds, int nMinPerClass, Integer nSeed, int nRuns) {
        // allocate map to return
        SortedMap<String, SortedMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>>> labelToInstanceFoldMap = new TreeMap<String, SortedMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>>>();
        // initialize random seed
        Random r = new Random(nSeed != null ? nSeed : System.currentTimeMillis());
        // iterate over labels
        for (Map.Entry<String, SortedMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>>> labelRun : labelToInstanceMap
                .entrySet()) {
            String label = labelRun.getKey();
            // extract the instance id - class map
            SortedMap<Long, String> instanceClassMap = labelRun.getValue().get(0).get(0).get(true);
            // allocate the run to fold map
            SortedMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>> runMap = new TreeMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>>();
            labelToInstanceFoldMap.put(label, runMap);
            // iterate over runs
            for (int run = 1; run <= nRuns; run++) {
                // generate folds for run
                List<Set<Long>> folds = createFolds(nFolds, nMinPerClass, r, instanceClassMap);
                SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>> foldMap = new TreeMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>();
                // add the fold map to the run map
                runMap.put(run, foldMap);
                // iterate over folds
                for (int trainFoldNum = 1; trainFoldNum <= folds.size(); trainFoldNum++) {
                    // add train/test sets for the fold
                    SortedMap<Boolean, SortedMap<Long, String>> trainTestMap = new TreeMap<Boolean, SortedMap<Long, String>>();
                    foldMap.put(trainFoldNum, trainTestMap);
                    trainTestMap.put(true, new TreeMap<Long, String>());
                    trainTestMap.put(false, new TreeMap<Long, String>());
                    // populate the train/test sets
                    Set<Long> testIds = folds.get(trainFoldNum - 1);
                    // iterate over all instances
                    for (Map.Entry<Long, String> instanceClass : instanceClassMap.entrySet()) {
                        long instanceId = instanceClass.getKey();
                        String clazz = instanceClass.getValue();
                        // add the instance to the test set if it is in testIds,
                        // else to the train set
                        trainTestMap.get(!testIds.contains(instanceId)).put(instanceId, clazz);
                    }
                }
            }
        }
        return labelToInstanceFoldMap;
    }
}