org.mitre.ccv.weka.mapred.ClassifyInstances.java Source code

Java tutorial

Introduction

Here is the source code for org.mitre.ccv.weka.mapred.ClassifyInstances.java

Source

/**
 * Created on April 13, 2009.
 *
 * Copyright 2010- The MITRE Corporation. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you  may not 
 * use this file except in compliance with the License. You may obtain a copy of 
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software 
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 
 * License for the specific language governing permissions and limitations under
 * the License.
 *
 * $Id$
 */
package org.mitre.ccv.weka.mapred;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.net.URI;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
import org.apache.hadoop.mapred.Mapper;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.mapred.SequenceFileInputFormat;
import org.apache.hadoop.mapred.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

import org.mitre.ccv.index.IndexedCompleteCompositionVector;
import org.mitre.ccv.index.IndexedCompositionDistribution;
import org.mitre.mapred.io.StringDoublePairWritable;
import org.mitre.ccv.weka.AbstractWekaCompleteCompositionVector;
import org.mitre.ccv.weka.AbstractWekaCompleteCompositionVector.LabeledInstance;

import org.mitre.mapred.fs.FileUtils;

import weka.core.Attribute;
import weka.core.Instance;

/**
 * A Map-Reduce class that uses a Complete Composition Vector Weka classifiers 
 * {@link org.mitre.ccv.weka.AbstractWekaCompleteCompositionVector}, J48 decision tree
 * {@link org.mitre.ccv.weka.CompleteCompositionVectorJ48} and an SMO SVM
 * {@link org.mitre.ccv.weka.CompleteCompositionVectorSMO} to
 * classify complete composition vectors generated from sequeces.
 * <P>
 * This can only classify (i.e., doesn't do the learning). A pre-built trained model is required.
 * Also, the confidence measurements from the SMO seem to be very low (not very meaningful).
 * <P>
 * This currently is a memory hog and can trigger {@link OutOfMemoryError}
 * exceptions (many of which are not from this class).
 *
 * <p>Some properties (<code>-D property=value</code>) supported are:
 * <ul>
 * <li>weka.output.sortby which determines the sort order of the output.</li>
 * </ul>
 * 
 * @see {@link org.mitre.ccv.weka.CompleteCompositionVectorJ48}
 * @author Marc Colosimo
 */
public class ClassifyInstances extends Configured implements Tool {

    /**
     * This uncovers several bugs and limitations of hadoop and JDK6
     * - lots of OutOfMemoryExceptions all over the place because this is highly threaded
     *
     * GC issues caused by this mapping, since we create a new DistributionIndex each time.
     *
     * Several options exist such as
     *  - Use parallel collector (also known as the throughput collector) with "-XX:+UseParallelGC"
     *    and (optionally) enable parallel compaction with "-XX:+UseParallelOldGC".
     *  - Use concurrent collector with "-XX:+UseConcMarkSweepGC"
     *    and (optionally) enable incremental mode with "-XX:+CMSIncrementalMode"
     *  - play with the heap options
     *
     * See <http://java.sun.com/javase/technologies/hotspot/gc/gc_tuning_6.html> for massive list of options.
     *
     * See <http://hadoop.apache.org/core/docs/current/mapred_tutorial.html#Task+Execution+%26+Environment>
     * mapred.child.java.opts
     *
     * Also if we run this with many maps on the same machine (-m 50 out of 56 on 13 mixed machines)
     * we run out of memory and expose this bug
     * <http://bugs.sun.com/bugdatabase/view_bug.do?bug_id=6521677>
     * 
     */
    private static final Log LOG = LogFactory.getLog(ClassifyInstances.class);
    public static final String MODEL_PATH = "weka.classifier.model.path";

    /** Property for setting classifier class name. */
    public static final String CLASSIFIER = "weka.classifier.class";

    /** Default classifier class name. */
    public static final String DEFAULT_CLASSIFIER = "org.mitre.ccv.weka.CompleteCompositionVectorJ48";

    /** Property name for how to format (tab-delimited) and sort the output. */
    public static final String SORT_OUTPUT_BY = "weka.output.sortby";
    /** Sort the output by the sample name (default). Format: sample name, class name, confidence. */
    public static final String SORTBY_SAMPLE = "sample";
    /** Sort the output by the class instance (name). Format: class name, confidence, sample name. */
    public static final String SORTBY_CLASS = "class";
    /** Sort the output by the classifiers confidence (highest to lowest). Format: class name, confidence, sample name. */
    public static final String SORTBY_CONFIDENCE = "confidence";
    /** Sort the output by the class instance then confidence in that class. Format: class name, confidence, sample name. */
    public static final String SORTBY_CLASS_CONFIDENCE = "class.confidence";

    public static class CompositionVectorJ48Map extends MapReduceBase
            implements Mapper<Text, Text, WritableComparable, Text> {

        private static final Log MAP_LOG = LogFactory.getLog(CompositionVectorJ48Map.class);
        private AbstractWekaCompleteCompositionVector classifier = null;
        private FileSystem localFs = null;
        private Attribute classAttribute = null;
        //private Text classText = new Text();
        private String sortBy = null;
        private Text outText = new Text();
        //private Text outSampleName = new Text();
        private StringDoublePairWritable outClassConfidence = new StringDoublePairWritable();

        @SuppressWarnings("unchecked")
        public <T> T newClassifierInstance(Class<T> theClassifier, String inputFile) {
            T result;
            try {
                Constructor<T> meth = theClassifier.getDeclaredConstructor(new Class[] { String.class });
                //meth.setAccessible(true);
                result = meth.newInstance(inputFile);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
            return result;
        }

        @Override
        public void configure(JobConf conf) {
            sortBy = conf.get(SORT_OUTPUT_BY, null);
            try {
                this.localFs = FileSystem.getLocal(new Configuration());
                String listInput = conf.get(MODEL_PATH, null);
                Path[] localFiles = DistributedCache.getLocalCacheFiles(conf);

                if (listInput != null && localFiles.length != 0) {

                    for (int cv = 0; cv < localFiles.length; cv++) {
                        if (!localFiles[cv].getName().equals(listInput)) {
                            continue;
                        }
                        Class<? extends AbstractWekaCompleteCompositionVector> classifierClass = (Class<? extends AbstractWekaCompleteCompositionVector>) conf
                                .getClassByName(conf.get(CLASSIFIER, DEFAULT_CLASSIFIER));

                        // We might choke loading the model into memory!
                        // It would be nice to use nio classes especially MappedByteBuffer
                        //AbstractWekaCompleteCompositionVector classifierClass =
                        //        conf.getClass(CLASSIFIER, CompleteCompositionVectorJ48.class);
                        this.classifier = newClassifierInstance(classifierClass, localFiles[cv].toString());
                        //Constructor<CompleteCompositionVectorJ48> meth =
                        //        CompleteCompositionVectorJ48.class.getDeclaredConstructor(new Class[]{String.class});

                        //this.classifier = meth.newInstance(localFiles[cv].toString());

                        //this.classifier = ReflectionUtils.newInstance(CompleteCompositionVectorJ48.class, conf);
                        //this.classifier = new CompleteCompositionVectorJ48(localFiles[cv].toString());
                        this.classAttribute = (Attribute) this.classifier.getAttributes()
                                .elementAt(this.classifier.getClassIndex());
                        break;
                    }
                }
                // hopefully throws exception when map is called, just log it here
            } catch (FileNotFoundException ex) {
                MAP_LOG.fatal("Unable to get cached file", ex);
                //  } catch (ClassNotFoundException ex) {
                //      MAP_LOG.fatal("Unable to get cached file", ex);
            } catch (IOException ex) {
                MAP_LOG.fatal("Unable to get cached file", ex);
            } catch (ClassNotFoundException ex) {
                throw new RuntimeException(ex);
            }
            if (this.classifier == null) {
                MAP_LOG.warn("Classifier was not loaded!");
            }
        }

        public void formatOuput(OutputCollector<WritableComparable, Text> output, Text sampleName, String clsValue,
                Double clsDist) throws IOException {

            this.outClassConfidence.set(clsValue, clsDist);
            //this.outSampleName.set(sampleName);
            if (this.sortBy != null && !this.sortBy.equals(SORTBY_SAMPLE)) {
                if (this.sortBy.equals(SORTBY_CONFIDENCE)) {
                    // reverse sort order
                    this.outClassConfidence.compareValues(true);
                } else if (this.sortBy.equals(SORTBY_CLASS_CONFIDENCE)) {
                    // keep natural order for the key, reverse for values
                    this.outClassConfidence.compareKeyValues(false, true);
                } // SORTBY_CLASS (key) is default
                output.collect(this.outClassConfidence, sampleName);
            } else {

                this.outText.set(String.format("%s\t%f", clsValue, clsDist));
                output.collect(sampleName, this.outText);
            }
        }

        @Override
        public void map(Text key, Text value, OutputCollector<WritableComparable, Text> output, Reporter reporter)
                throws IOException {

            if (this.classifier == null) {
                throw new IOException("No classifier!");
            }
            if (this.classAttribute == null) {
                throw new IOException("Have classifier, but no classes!");
            }
            int start = this.classifier.getBegin();
            int end = this.classifier.getEnd();

            /**
             * IndexedCompositionDistribution will throw IllegalArgumentException
             */
            String seq = value.toString();
            if (end > seq.length()) {
                LOG.info(String.format("%s length (%d) is smaller than the end window size (%d). Skipping..",
                        key.toString(), seq.length(), end));
                return;
            }
            // DEBUG
            if (seq.length() > 1000000) {
                LOG.warn(String.format("Skipping %s because it is to long (%d)", key.toString(), seq.length()));
                System.out.printf("Skipping %s because it is to long (%d)", key.toString(), seq.length());
                return;
            }

            /**
             * Generate the complete composition vector for this sample.
             * IndexedCompositionDistribution might throw IllegalArgumentException.
             */
            reporter.setStatus(
                    String.format("Classifying %s using window sizes of %d to %d", key.toString(), start, end));
            IndexedCompleteCompositionVector ccv;
            try {
                IndexedCompositionDistribution cd = new IndexedCompositionDistribution(null, 1, seq, start, end);
                ccv = new IndexedCompleteCompositionVector(key.toString(), 1, start, end, cd);
            } catch (OutOfMemoryError me) {
                System.gc();
                LOG.warn(String.format(
                        "Ran out of memory while generating complete composition vector for %s of %d length",
                        key.toString(), seq.length()), me);
                System.err.printf(
                        "Ran out of memory while generating complete composition vector for %s of %d length",
                        key.toString(), seq.length());

                output.collect(new Text("[ERROR: Out of Memory generating CCV]"), key);
                return;
            }

            /** Generate an instance for this sample. */
            try {
                /** This can throw an exception for something */
                Instance inst = classifier.getInstanceSparse(ccv);
                /** To classify, we need to be part of a DataSet/Instance */
                LabeledInstance li = classifier.runClassifier(inst, classAttribute);

                int clsValue = (int) li.inst.classValue();
                //System.out.printf("%s\t%s\t%f\n", key.toString(),
                //        classAttribute.value(clsValue), li.clsDist[clsValue]);
                //this.classText.set(String.format("%s\t%f", classAttribute.value(clsValue), li.clsDist[clsValue]));
                //output.collect(key, this.classText);
                this.formatOuput(output, key, classAttribute.value(clsValue), li.clsDist[clsValue]);
            } catch (OutOfMemoryError me) {
                System.gc();
                LOG.warn(String.format("Ran out of memory while classifying vector for %s of %d length",
                        key.toString(), seq.length()), me);
                System.err.printf("Ran out of memory while classifying vector for %s of %d length", key.toString(),
                        seq.length());

            } catch (Exception ex) {
                MAP_LOG.warn("Exception when classifing!", ex);
            }
        }
    }

    public int initJob(JobConf jobConf, String modelInput, String input, String output) throws Exception {
        JobConf conf = new JobConf(jobConf, ClassifyInstances.class);
        conf.setJobName("ClassifyInstances");

        Path listPath = new Path(modelInput);
        FileSystem fs = listPath.getFileSystem(conf);
        if (modelInput != null) {
            Path qPath = fs.makeQualified(listPath);
            LOG.info(String.format("Caching model file %s", qPath.toString()));
            URI listURI = new URI(qPath.toString());
            DistributedCache.addCacheFile(listURI, conf);
            conf.set(MODEL_PATH, listPath.getName());
        } else {
            throw new Exception("ClassifyInstances requires a model!");
        }

        // Set up mapper
        SequenceFileInputFormat.setInputPaths(conf, new Path(input));
        conf.setInputFormat(SequenceFileInputFormat.class);
        conf.setMapperClass(CompositionVectorJ48Map.class);
        // Painful way to set job output key class because we can't use WritableComparable
        String sortBy = conf.get(SORT_OUTPUT_BY, null);
        if (sortBy != null && !sortBy.equals(SORTBY_SAMPLE)) {
            LOG.info("Sorting output by class name and/or confidence.");
            conf.setOutputKeyClass(StringDoublePairWritable.class);
        } else {
            LOG.info("Sorting output by sample name.");
            conf.setOutputKeyClass(Text.class);
        }
        conf.setOutputValueClass(Text.class); // job output value class

        // Uses default reducer (IdentityReducer) and save it to a plain text file
        conf.setOutputFormat(TextOutputFormat.class);
        TextOutputFormat.setOutputPath(conf, new Path(output));

        JobClient.runJob(conf);
        return 0;
    }

    static int printUsage() {
        System.out.println(
                "ClassifyInstances [-m <maps>] [-r <reduces>] [-D <property=value>] [-C <classifier class name>]<model> <input> <output>");
        System.out.println("Sorting of output is controlled by the 'weka.output.sortby' property");
        System.out.println(
                "Known classifiers are:\n\torg.mitre.ccv.weka.CompleteCompositionVectorJ48 (default)\n\torg.mitre.ccv.weka.CompleteCompositionVectorSMO\n"
                        + "Models must be generated by the same classifier used here.");
        return -1;
    }

    @Override
    public int run(String[] args) throws Exception {
        JobConf conf = new JobConf(getConf());

        ArrayList<String> other_args = new ArrayList<String>();
        for (int i = 0; i < args.length; ++i) {
            try {
                if ("-m".equals(args[i])) {
                    conf.setNumMapTasks(Integer.parseInt(args[++i]));
                } else if ("-r".equals(args[i])) {
                    conf.setNumReduceTasks(Integer.parseInt(args[++i]));
                } else if ("-D".equals(args[i])) {
                    String[] props = args[++i].split("=");
                    conf.set(props[0], props[1]);
                } else if ("-libjars".equals(args[i])) {
                    conf.set("tmpjars", FileUtils.validateFiles(args[++i], conf));

                    URL[] libjars = FileUtils.getLibJars(conf);
                    if (libjars != null && libjars.length > 0) {
                        // Add libjars to client/tasks classpath
                        conf.setClassLoader(new URLClassLoader(libjars, conf.getClassLoader()));
                        // Adds libjars to our classpath
                        Thread.currentThread().setContextClassLoader(
                                new URLClassLoader(libjars, Thread.currentThread().getContextClassLoader()));
                    }
                } else if ("-C".equals(args[i])) {
                    conf.set(CLASSIFIER, args[++i]);
                } else {
                    other_args.add(args[i]);
                }
            } catch (NumberFormatException except) {
                System.out.println("ERROR: Integer expected instead of " + args[i]);
                return printUsage();
            } catch (ArrayIndexOutOfBoundsException except) {
                System.out.println("ERROR: Required parameter missing from " + args[i - 1]);
                return printUsage();
            }
        }
        // Make sure there are exactly 3 parameters left.
        if (other_args.size() != 3) {
            System.out.println("ERROR: Wrong number of parameters: " + other_args.size() + " instead of 3.");

            return printUsage();
        }

        return initJob(conf, other_args.get(0), other_args.get(1), other_args.get(2));
    }

    static public void main(String[] args) throws Exception {
        int res = ToolRunner.run(new Configuration(), new ClassifyInstances(), args);
        System.exit(res);
    }
}