ubic.pubmedgate.mallet.SimpleMalletRunner.java Source code

Java tutorial

Introduction

Here is the source code for ubic.pubmedgate.mallet.SimpleMalletRunner.java

Source

/*
 * The WhiteText project
 * 
 * Copyright (c) 2012 University of British Columbia
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package ubic.pubmedgate.mallet;

import gate.Corpus;
import gate.util.FMeasure;

import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.FileAppender;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.PatternLayout;

import ubic.basecode.dataStructure.CountingMap;
import ubic.basecode.dataStructure.DoubleAddingMap;
import ubic.basecode.dataStructure.params.ParameterGrabber;
import ubic.pubmedgate.Config;
import ubic.pubmedgate.ConnectionsDocument;
import ubic.pubmedgate.GateInterface;
import ubic.pubmedgate.mallet.features.TargetFromGATE;
import ubic.pubmedgate.ner.TokenTargetLabeller;
import cc.mallet.pipe.Noop;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.SequencePairAlignment;

/**
 * Drives cross-validated runs over the documents of a GATE corpus: documents are wrapped as Mallet instances, split
 * into folds, pushed through the configured {@link BrainRegionPipes}, and each fold is handed to a {@link FoldRunner}
 * executed on its own thread. Results and the run parameters can be written to a dedicated "results" log4j logger.
 * <p>
 * NOTE(review): {@link #run()} allocates at most <code>whitetext.max_threads</code> runner slots and reuses a slot
 * once its runner is done, so when that limit is below the number of folds only the last runner of each slot remains
 * in {@link #foldRunners} afterwards. Methods that iterate the runners ({@link #weightsTest(String)},
 * {@link #compareFeatures()}, {@link #getTop10(Instance)}, {@link #getAllInstances()}) then only see those runners.
 */
public class SimpleMalletRunner extends MalletRunner {
    protected static Log log = LogFactory.getLog(SimpleMalletRunner.class);

    // Name of the GATE feature holding the target label; chosen by setBio().
    String inputTargetFeature;
    boolean bio;
    boolean reverse;
    boolean induceFeatures;
    int windowBefore, windowAfter;
    // Runner slots filled by run(); null until run() has been called.
    FoldRunner[] foldRunners;
    BrainRegionPipes brainPipes;
    // Seed passed to the cross-validation split.
    int seed;
    // Extra key/value results (e.g. from compareFeatures()) appended to the result line.
    Map<String, String> otherResults;

    /**
     * Bare constructor for subclasses. It only invokes {@link #getParams()} and discards the result; none of the
     * fields set by the public constructors are initialised here.
     */
    protected SimpleMalletRunner() {
        getParams();
    }

    // Logger that receives one line per writeOutResults() call; configured once at class load.
    static Logger resultLog;
    static {
        setupResultLogger();
    }

    /**
     * Evaluates against the configured truth set and writes one result line.
     */
    public void writeOutResults() {
        writeOutResults(truthSet);
    }

    /**
     * Evaluates against the given truth annotation set and writes a single <code>key=value, key=value</code> line to
     * the result logger. Run parameters come first in sorted key order, followed by the evaluation measures and the
     * entries of {@link #otherResults} in their own (unsorted) iteration order.
     * 
     * @param truthAnnSet name of the annotation set to evaluate against
     */
    public void writeOutResults(String truthAnnSet) {
        Map<String, String> params = new HashMap<String, String>();
        params.putAll(getParams());
        params.put("size", documents.size() + "");
        params.put("pipeNamesSize", brainPipes.size() + "");
        params.put("pipes", brainPipes.toString());

        // only the parameter keys are sorted; result keys are appended afterwards
        List<String> orderedKeys = new LinkedList<String>(params.keySet());
        Collections.sort(orderedKeys);

        FMeasure f = evaluate(truthAnnSet);
        Map<String, String> results = ParameterGrabber.getParams(f.getClass(), f);
        log.info("Params for results size:" + results.size());
        params.putAll(results);
        params.putAll(otherResults);

        // add to end (unsorted)
        orderedKeys.addAll(results.keySet());
        orderedKeys.addAll(otherResults.keySet());

        StringBuilder line = new StringBuilder();
        for (String key : orderedKeys) {
            if (line.length() > 0) {
                line.append(", ");
            }
            line.append(key).append("=").append(params.get(key));
        }
        resultLog.info(line.toString());
    }

    /**
     * Collects the parameters of this object as reported by {@link ParameterGrabber} for the runtime class and for
     * its direct superclass. Superclass entries overwrite same-named entries of the runtime class.
     * 
     * @return a new map of parameter names to string values
     */
    public Map<String, String> getParams() {
        Map<String, String> params = new HashMap<String, String>();
        params.putAll(getParams(this.getClass()));
        params.putAll(getParams(this.getClass().getSuperclass()));
        return params;
    }

    /**
     * @param c class whose parameters should be read from this instance
     * @return whatever {@link ParameterGrabber#getParams} reports for <code>c</code> and this object
     */
    public Map<String, String> getParams(Class<?> c) {
        return ParameterGrabber.getParams(c, this);
    }

    /**
     * Creates the "results" logger and attaches a file appender writing to
     * <code>results&lt;currentTimeMillis&gt;.csv</code> in the working directory.
     * 
     * @throws RuntimeException wrapping any failure to create the appender
     */
    private static void setupResultLogger() {
        resultLog = Logger.getLogger("results");
        PatternLayout layout = new PatternLayout("%d{dd MMM yyyy HH:mm:ss} %C{1} %p | %m%n");

        try {
            FileAppender appender = new FileAppender(layout, "results" + System.currentTimeMillis() + ".csv");
            resultLog.addAppender(appender);
            resultLog.setLevel(Level.DEBUG);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Convenience constructor with no explicit annotation set name (a default is derived from <code>reverse</code>)
     * and the default pipes.
     */
    public SimpleMalletRunner(String truthSet, String tokenSet, Corpus corpus, boolean reverse, int windowBefore,
            int windowAfter, int numFolds) {
        this(null, truthSet, tokenSet, corpus, reverse, windowBefore, windowAfter, numFolds);
    }

    /**
     * Convenience constructor using the default pipes.
     * 
     * @param annotationSet name of the output annotation set, or null for the default
     */
    public SimpleMalletRunner(String annotationSet, String truthSet, String tokenSet, Corpus corpus,
            boolean reverse, int windowBefore, int windowAfter, int numFolds) {
        // forward the caller's annotation set; it was previously replaced by null and therefore ignored
        this(annotationSet, truthSet, tokenSet, corpus, reverse, windowBefore, windowAfter, numFolds, null);
    }

    /**
     * Full constructor.
     * 
     * @param annotationSet name of the output annotation set; when null, "MalletReverse" is used if
     *        <code>reverse</code> is set and "Mallet" otherwise
     * @param truthSet name of the annotation set used for evaluation
     * @param tokenSet name of the input token annotation set
     * @param corpus GATE corpus providing the documents
     * @param reverse stored as-is and used to pick the default annotation set name
     * @param windowBefore window size before a token
     * @param windowAfter window size after a token
     * @param numFolds number of cross-validation folds
     * @param brainPipes pipes to use; when null a new {@link BrainRegionPipes} with all "good" pipes is created
     * @throws RuntimeException wrapping any failure while adding the default pipes
     */
    public SimpleMalletRunner(String annotationSet, String truthSet, String tokenSet, Corpus corpus,
            boolean reverse, int windowBefore, int windowAfter, int numFolds, BrainRegionPipes brainPipes) {

        // for logging
        if (corpus.getDataStore() == null) {
            dataStoreLocation = "null";
        } else {
            dataStoreLocation = corpus.getDataStore().getStorageUrl();
        }

        otherResults = new HashMap<String, String>();
        this.corpus = corpus;
        documents = GateInterface.getDocuments(corpus);
        this.truthSet = truthSet;
        seed = 1;

        induceFeatures = false;

        foldRunners = null;
        sentenceLevel = false;
        setBio(false);
        // inputTokenSet = ConnectionsDocument.GATETOKENS;
        inputTokenSet = tokenSet;
        this.numFolds = numFolds;

        // original author's note on reverse/window settings:
        // false then 1 before and 0 after
        // true then 0 before and 1 after
        this.reverse = reverse;
        this.windowBefore = windowBefore;
        this.windowAfter = windowAfter;

        if (annotationSet == null) {
            annotationSet = reverse ? "MalletReverse" : "Mallet";
        }
        this.annotationSet = annotationSet;

        // if we don't give brain pipes, then make a new one with the 'good pipes'
        if (brainPipes == null) {
            this.brainPipes = new BrainRegionPipes();
            try {
                this.brainPipes.addAllGoodPipes();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        } else {
            // pipes provided by constructor
            this.brainPipes = brainPipes;
        }

        skipAbbrev = false;
    }

    /**
     * Fails fast when the fold runners have not been created yet, i.e. {@link #run()} has not been called.
     * 
     * @throws IllegalStateException if {@link #foldRunners} is null
     */
    private void requireTrained() {
        if (foldRunners == null) {
            throw new IllegalStateException("Classifiers not trained yet");
        }
    }

    /**
     * Combines the weights reported by every retained fold runner for the given states.
     * 
     * @param states state description passed through to {@link FoldRunner#weightsTest}
     * @return the per-fold weights accumulated into one map
     * @throws IllegalStateException if {@link #run()} has not been called
     */
    public HashMap<String, Double> weightsTest(String states) {
        requireTrained();
        DoubleAddingMap<String> weights = new DoubleAddingMap<String>();
        for (FoldRunner foldRunner : foldRunners) {
            // combine all the folds
            DoubleAddingMap<String> foldWeights = foldRunner.weightsTest(states);
            weights.addPutAll(foldWeights);
        }
        return weights;
    }

    /**
     * Builds every ordered pair "s1,s2" of the brain and outside target labels.
     * 
     * @return the four state pairs, or <code>null</code> when BIO labelling is enabled
     */
    public Set<String> getStatePairs() {
        if (bio) {
            // callers receive null in BIO mode; kept for backward compatibility
            return null;
        }
        String[] states = { TokenTargetLabeller.BRAIN_TARGET, TokenTargetLabeller.OUTSIDE_TARGET };
        Set<String> result = new HashSet<String>();
        for (String from : states) {
            for (String to : states) {
                result.add(from + "," + to);
            }
        }
        return result;
    }

    /**
     * For each retained fold runner, counts features on its training and testing instances and measures how many
     * test features (by distinct feature and by total count) also occur in training. The fold averages are logged
     * and stored in {@link #otherResults} under "%NewFeatures" and "%NewWords".
     * 
     * @throws IllegalStateException if {@link #run()} has not been called
     * @throws Exception propagated from the feature counter
     */
    public void compareFeatures() throws Exception {
        requireTrained();
        float seenFeatureFractionSum = 0;
        float unseenCountFractionSum = 0;
        // average over the runners actually inspected, which may be fewer than numFolds
        int foldsCompared = 0;
        for (FoldRunner foldRunner : foldRunners) {
            // feature count on training
            FeatureCounter counter = new FeatureCounter(this, false);
            counter.run(foldRunner.training);
            Set<String> trainFeatures = counter.getMap().keySet();
            int trainSize = trainFeatures.size();

            // feature count on test
            counter = new FeatureCounter(this, false);
            counter.run(foldRunner.testing);
            CountingMap<String> testMap = counter.getMap();
            Set<String> testFeatures = testMap.keySet();
            // capture sizes now: testMap is pruned below and testFeatures may be a live view of it
            int testSize = testFeatures.size();
            long testSizeTotal = testMap.summation();

            // copy first so that retainAll does not modify the training key set
            Set<String> intersect = new HashSet<String>(trainFeatures);
            intersect.retainAll(testFeatures);
            int intersectSize = intersect.size();

            // drop everything seen in training; what remains in testMap is unseen
            for (String key : intersect) {
                testMap.remove(key);
            }
            long unseenTotal = testMap.summation();

            float seenFeatureFraction = (float) intersectSize / (float) testSize;
            float unseenCountFraction = unseenTotal / (float) testSizeTotal;

            log.info("Train Features=" + trainSize + " Test Features=" + testSize + " Intersect:" + intersectSize
                    + " %seen" + seenFeatureFraction);
            log.info("test summation:" + testSizeTotal + " unseen summation:" + unseenTotal + " %unseen"
                    + unseenCountFraction);

            seenFeatureFractionSum += seenFeatureFraction;
            unseenCountFractionSum += unseenCountFraction;
            foldsCompared++;
        }

        float newFeatures = 1 - seenFeatureFractionSum / (float) foldsCompared;
        float newWords = unseenCountFractionSum / (float) foldsCompared;
        log.info("Percent of features not seen in training set: " + newFeatures);
        log.info("Percent of words not seen in training set: " + newWords);
        otherResults.put("%NewFeatures", newFeatures + "");
        otherResults.put("%NewWords", newWords + "");
    }

    /**
     * Logs the entries with the smallest weights, then the entries with the largest weights. At most 50 entries are
     * logged per direction; fewer when the map is smaller.
     * 
     * @param weights feature name to weight
     */
    public void printTopWeights(HashMap<String, Double> weights) {
        final int limit = 50;

        // convert to a list
        List<Map.Entry<String, Double>> features = new LinkedList<Map.Entry<String, Double>>(weights.entrySet());

        // sort by the value, ascending
        Comparator<Map.Entry<String, Double>> byValue = new Comparator<Map.Entry<String, Double>>() {
            public int compare(Map.Entry<String, Double> a, Map.Entry<String, Double> b) {
                return a.getValue().compareTo(b.getValue());
            }
        };
        Collections.sort(features, byValue);
        logLeadingEntries(features, limit);

        log.info("______________________ OTHER DIRECTION ____________________");
        Collections.reverse(features);
        logLeadingEntries(features, limit);
    }

    /**
     * Logs "index entry" for the first <code>limit</code> entries of the list, or all of them if there are fewer.
     */
    private static void logLeadingEntries(List<Map.Entry<String, Double>> entries, int limit) {
        int rank = 0;
        for (Map.Entry<String, Double> entry : entries) {
            if (rank >= limit) {
                break;
            }
            log.info(rank + " " + entry);
            rank++;
        }
    }

    /**
     * Assembles the pipes applied to the instances of every fold.
     * 
     * @return the pipe list: an optional document-to-sentences pipe, the sentence-to-token-sequence pipe, then the
     *         current brain region pipes
     * @throws Exception propagated from pipe construction
     */
    public Collection<Pipe> getAllFoldPipes() throws Exception {

        Collection<Pipe> pipeList = new LinkedList<Pipe>();

        // convert GATE document in to many sentence based instances
        // (in sentence-level mode getInstanceList() has already done the split)
        if (!sentenceLevel) {
            pipeList.add(new DocumentToSentencesPipe(inputTokenSet));
        }

        // convert GATE sentence annotations to Mallet tokens
        // this also makes the targets (should be refactored to be a different pipe)
        pipeList.add(new Sentence2TokenSequence(skipAbbrev, inputTokenSet));

        pipeList.addAll(brainPipes.getCurrentPipes());

        return pipeList;
    }

    /**
     * Wraps {@link #getAllFoldPipes()} into a single pipe. In abstract (non sentence-level) mode the serial pipe is
     * additionally wrapped in a {@link DocCachePipe}.
     * 
     * @return the pipe used to build the training and testing instance lists
     * @throws Exception propagated from {@link #getAllFoldPipes()}
     */
    public Pipe getPipe() throws Exception {
        SerialPipes pipeSerial = new SerialPipes(getAllFoldPipes());
        log.info("number of pipes for all folds:" + pipeSerial.size());

        if (sentenceLevel) {
            return pipeSerial;
        }
        // use the cache for abstracts
        return new DocCachePipe(pipeSerial);
    }

    /**
     * Wraps each document in a Mallet instance. In sentence-level mode the documents are then expanded into sentence
     * instances; otherwise there is one instance per document.
     * 
     * @param documents documents to wrap
     * @return the resulting instance list
     * @throws Exception propagated from the pipes
     */
    public InstanceList getInstanceList(List<ConnectionsDocument> documents) throws Exception {
        // documents are stored untouched, hence the no-op pipe
        InstanceList instanceList = new InstanceList(new Noop());
        for (ConnectionsDocument doc : documents) {
            instanceList.addThruPipe(new Instance(doc, null, null, null));
        }

        // are we dealing with sentences or abstracts?
        if (sentenceLevel) {
            InstanceList sentences = new InstanceList(new DocumentToSentencesPipe(inputTokenSet));
            sentences.addThruPipe(instanceList.iterator());
            instanceList = sentences;
            log.info(instanceList.size() + " sentences");
        } else {
            log.info(instanceList.size() + " abstracts");
        }

        return instanceList;
    }

    /**
     * @param seed seed used for the cross-validation split in {@link #run()}
     */
    public void setSeed(int seed) {
        this.seed = seed;
    }

    /**
     * Splits the instances into {@link #numFolds} folds and processes each fold with a {@link FoldRunner} on its own
     * thread, using at most <code>whitetext.max_threads</code> (from {@link Config}) concurrent runner slots. The
     * calling thread prepares the training/testing lists, starts runners in free slots and then waits on a shared
     * monitor object until every fold has been started and every retained runner reports done.
     * <p>
     * NOTE(review): this relies on each FoldRunner calling notify on the monitor object it is given when it finishes
     * - confirm against FoldRunner.
     * 
     * @throws Exception propagated from the pipes, or InterruptedException while waiting
     */
    public void run() throws Exception {
        InstanceList instanceList = getInstanceList(documents);
        Pipe pipes = getPipe();
        // send the docs through the folds, then expand and classify
        Iterator<InstanceList[]> folds = instanceList.crossValidationIterator(numFolds, seed);

        Object notifyWhenDone = new Object();
        int fold = 0;

        // never allocate more slots than there are folds
        int foldRunnersSize = Math.min(Config.config.getInt("whitetext.max_threads"), numFolds);
        foldRunners = new FoldRunner[foldRunnersSize];
        boolean allDone = false;
        // while there is folds to do, do em
        // the monitor is held except while inside wait()
        synchronized (notifyWhenDone) {
            while (!allDone) {
                // init threads
                for (int i = 0; i < foldRunners.length; i++) {
                    FoldRunner foldRunner = foldRunners[i];
                    // if we have something to do, and this thread is empty or done, then do it
                    if (folds.hasNext() && (foldRunner == null || foldRunner.isDone())) {
                        log.info("Starting fold:" + ++fold + " of " + numFolds);

                        // element 0 is used as the training split, element 1 as the testing split
                        InstanceList[] foldInstances = folds.next();

                        log.info("Training size:" + foldInstances[0].size() + " Testing size:"
                                + foldInstances[1].size());

                        InstanceList training = new InstanceList(pipes);
                        training.addThruPipe(foldInstances[0].iterator());

                        InstanceList testing = new InstanceList(pipes);
                        testing.addThruPipe(foldInstances[1].iterator());

                        // reusing the slot drops the reference to the previous, finished runner
                        foldRunners[i] = new FoldRunner(fold, training, testing, notifyWhenDone, this);
                        // start it going
                        new Thread(foldRunners[i], "Thread-" + i).start();
                    }
                }

                int stillRunning = 0;
                for (FoldRunner thread : foldRunners) {
                    if (thread != null && !thread.isDone())
                        stillRunning++;
                }

                log.info("Main thread is waiting on " + stillRunning + " threads");
                if (fold < numFolds)
                    log.info("Next fold is " + (fold + 1) + "of " + numFolds);

                notifyWhenDone.wait();

                // check for all done
                allDone = !folds.hasNext();
                // if we ran out of folds then check to see if everyone is done
                if (allDone) {
                    for (FoldRunner thread : foldRunners) {
                        // an unused slot counts as done, matching the stillRunning count above
                        allDone = allDone && (thread == null || thread.isDone());
                    }
                }

            }
            log.info("Before window size = " + windowBefore + " after window size = " + windowAfter);
            log.info("Pipes(" + brainPipes.size() + "):" + brainPipes.toString());

        }

    }

    /**
     * Finds the fold runner that has the instance in its test set and returns what that runner's
     * <code>getTop10</code> reports for it.
     * 
     * @param i instance to look up
     * @return the alignments from the matching fold runner, or <code>null</code> if no retained runner has the
     *         instance in its test set
     * @throws IllegalStateException if {@link #run()} has not been called
     */
    public List<SequencePairAlignment<java.lang.Object, java.lang.Object>> getTop10(Instance i) {
        requireTrained();
        for (FoldRunner foldRunner : foldRunners) {
            if (foldRunner.isInTestSet(i)) {
                return foldRunner.getTop10(i);
            }
        }
        return null;
    }

    /**
     * Gets all instances; since features vary based on fold, it returns the test instances of each retained fold
     * runner.
     * 
     * @return the concatenated test instances
     * @throws IllegalStateException if {@link #run()} has not been called
     */
    public List<Instance> getAllInstances() {
        requireTrained();
        List<Instance> instances = new LinkedList<Instance>();
        for (FoldRunner foldRunner : foldRunners) {
            instances.addAll(foldRunner.getTestInstances());
        }
        return instances;
    }

    /** @return the class logger */
    public static Log getLog() {
        return log;
    }

    /** @return whether BIO target labels are in use */
    public boolean isBio() {
        return bio;
    }

    /** @return name of the GATE feature that holds the target label */
    public String getInputTargetFeature() {
        return inputTargetFeature;
    }

    /** @return name of the input token annotation set */
    public String getInputTokenSet() {
        return inputTokenSet;
    }

    /** @return the reverse flag given at construction */
    public boolean isReverse() {
        return reverse;
    }

    /** @return window size after a token */
    public int getWindowAfter() {
        return windowAfter;
    }

    /** @return window size before a token */
    public int getWindowBefore() {
        return windowBefore;
    }

    /** @return whether feature induction is enabled */
    public boolean isInduceFeatures() {
        return induceFeatures;
    }

    /**
     * Switches between BIO and simple target labels and selects the matching GATE target feature name.
     * 
     * @param bio true for BIO labels, false for simple labels
     */
    public void setBio(boolean bio) {
        inputTargetFeature = bio ? TargetFromGATE.GATE_BIO_TARGET_FEATURE
                : TargetFromGATE.GATE_SIMPLE_TARGET_FEATURE;
        this.bio = bio;
    }

    /** @return the pipes currently selected in the brain region pipes */
    public Collection<Pipe> getPipes() {
        return brainPipes.getCurrentPipes();
    }

    /** @param induceFeatures whether feature induction should be enabled */
    public void setInduceFeatures(boolean induceFeatures) {
        this.induceFeatures = induceFeatures;
    }

    /** @return the skip-abbreviation flag passed to the sentence-to-token pipe */
    public boolean getSkipAbbrev() {
        return skipAbbrev;
    }
}