org.apache.ctakes.temporal.eval.EvaluationOfTimeSpans.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.ctakes.temporal.eval.EvaluationOfTimeSpans.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.ctakes.temporal.eval;

import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;

import org.apache.ctakes.temporal.ae.BackwardsTimeAnnotator;
import org.apache.ctakes.temporal.ae.CRFTimeAnnotator;
import org.apache.ctakes.temporal.ae.ConstituencyBasedTimeAnnotator;
import org.apache.ctakes.temporal.ae.MetaTimeAnnotator;
import org.apache.ctakes.temporal.ae.TimeAnnotator;
import org.apache.ctakes.temporal.ae.feature.selection.FeatureSelection;
import org.apache.ctakes.typesystem.type.textsem.TimeMention;
import org.apache.ctakes.typesystem.type.textspan.Segment;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.collection.CollectionReader;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.eval.AnnotationStatistics;
import org.cleartk.ml.CleartkAnnotator;
import org.cleartk.ml.CleartkSequenceAnnotator;
import org.cleartk.ml.Instance;
import org.cleartk.ml.crfsuite.CrfSuiteStringOutcomeDataWriter;
import org.cleartk.ml.feature.transform.InstanceDataWriter;
import org.cleartk.ml.feature.transform.InstanceStream;
import org.cleartk.ml.jar.DefaultDataWriterFactory;
import org.cleartk.ml.jar.DefaultSequenceDataWriterFactory;
import org.cleartk.ml.jar.DirectoryDataWriterFactory;
import org.cleartk.ml.jar.GenericJarClassifierFactory;
import org.cleartk.ml.jar.JarClassifierBuilder;
import org.cleartk.ml.liblinear.LibLinearStringOutcomeDataWriter;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import com.lexicalscope.jewel.cli.CliFactory;
import com.lexicalscope.jewel.cli.Option;

/**
 * Trains and evaluates time-expression ({@link TimeMention}) span detection for one or more of the
 * ctakes-temporal time annotators, reporting {@link AnnotationStatistics} for each.
 *
 * <p>Each selected annotator class gets its own evaluation rooted at {@code target/eval/time-spans},
 * with its model stored in a sub-directory named after the annotator's simple class name.
 */
public class EvaluationOfTimeSpans extends EvaluationOfAnnotationSpans_ImplBase {

    /** Command-line options specific to the time-span evaluation. */
    static interface Options extends Evaluation_ImplBase.Options {

        /**
         * Feature-selection threshold; values {@code > 0} enable feature selection. This is applied
         * only when the annotator is {@link TimeAnnotator}.
         */
        @Option(longName = "featureSelectionThreshold", defaultValue = "1")
        public float getFeatureSelectionThreshold();

        /**
         * Value forwarded to {@link TimeAnnotator}'s data-writer description (presumably the SMOTE
         * oversampling neighbour count — confirm in TimeAnnotator).
         */
        @Option(longName = "SMOTENeighborNumber", defaultValue = "0")
        public float getSMOTENeighborNumber();

        /** Evaluate {@link BackwardsTimeAnnotator}. */
        @Option(shortName = "b")
        public boolean getRunBackwards();

        /** Evaluate {@link TimeAnnotator}. */
        @Option(shortName = "f")
        public boolean getRunForwards();

        /** Evaluate {@link ConstituencyBasedTimeAnnotator}. */
        @Option(shortName = "p")
        public boolean getRunParserBased();

        /** Evaluate {@link CRFTimeAnnotator}. */
        @Option(shortName = "c")
        public boolean getRunCrfBased();

        /** Skip the training step (models are expected to already exist). */
        @Option
        public boolean getSkipTrain();
    }

    /**
     * Entry point: parses the options, builds the train/test patient splits, runs one evaluation
     * per selected annotator and prints the statistics of every annotator in ascending F1 order
     * (best model last).
     *
     * @param args command-line arguments, parsed into {@link Options}
     * @throws Exception if any evaluation fails
     */
    public static void main(String[] args) throws Exception {
        Options options = CliFactory.parseArguments(Options.class, args);
        List<Integer> patientSets = options.getPatients().getList();

        List<Integer> trainItems;
        List<Integer> devItems;
        List<Integer> testItems;
        if (options.getXMLFormat() == XMLFormat.I2B2) {
            trainItems = I2B2Data.getTrainPatientSets(options.getXMLDirectory());
            devItems = I2B2Data.getDevPatientSets(options.getXMLDirectory());
            testItems = I2B2Data.getTestPatientSets(options.getXMLDirectory());
        } else {
            trainItems = THYMEData.getPatientSets(patientSets, options.getTrainRemainders().getList());
            devItems = THYMEData.getPatientSets(patientSets, options.getDevRemainders().getList());
            testItems = THYMEData.getPatientSets(patientSets, options.getTestRemainders().getList());
        }

        // In test mode, train on train+dev and evaluate on test.
        // Otherwise, train on train and evaluate on dev.
        List<Integer> allTrain = new ArrayList<>(trainItems);
        List<Integer> allTest;
        if (options.getTest()) {
            allTrain.addAll(devItems);
            allTest = new ArrayList<>(testItems);
        } else {
            allTest = new ArrayList<>(devItems);
        }

        List<Class<? extends JCasAnnotator_ImplBase>> annotatorClasses = selectAnnotatorClasses(options);
        Map<Class<? extends JCasAnnotator_ImplBase>, String[]> annotatorTrainingArguments = createTrainingArguments();

        // run one evaluation per annotator class
        final Map<Class<?>, AnnotationStatistics<?>> annotatorStats = Maps.newHashMap();
        for (Class<? extends JCasAnnotator_ImplBase> annotatorClass : annotatorClasses) {
            EvaluationOfTimeSpans evaluation = new EvaluationOfTimeSpans(new File("target/eval/time-spans"),
                    options.getRawTextDirectory(), options.getXMLDirectory(), options.getXMLFormat(),
                    options.getSubcorpus(), options.getXMIDirectory(), options.getTreebankDirectory(),
                    options.getFeatureSelectionThreshold(), options.getSMOTENeighborNumber(), annotatorClass,
                    options.getPrintOverlappingSpans(), annotatorTrainingArguments.get(annotatorClass));
            evaluation.prepareXMIsFor(patientSets);
            evaluation.setSkipTrain(options.getSkipTrain());
            evaluation.printErrors = options.getPrintErrors();
            if (options.getI2B2Output() != null) {
                evaluation.setI2B2Output(options.getI2B2Output() + "/" + annotatorClass.getSimpleName());
            }
            String name = String.format("%s.errors", annotatorClass.getSimpleName());
            evaluation.setLogging(Level.FINE, new File("target/eval", name));
            AnnotationStatistics<String> stats = evaluation.trainAndTest(allTrain, allTest);
            annotatorStats.put(annotatorClass, stats);
        }

        // allow ordering of models by F1
        Ordering<Class<? extends JCasAnnotator_ImplBase>> byF1 = Ordering.natural()
                .onResultOf(new Function<Class<? extends JCasAnnotator_ImplBase>, Double>() {
                    @Override
                    public Double apply(Class<? extends JCasAnnotator_ImplBase> annotatorClass) {
                        return annotatorStats.get(annotatorClass).f1();
                    }
                });

        // print out models in ascending F1 order, so the best model is printed last
        for (Class<?> annotatorClass : byF1.sortedCopy(annotatorClasses)) {
            System.err.printf("===== %s =====%n", annotatorClass.getSimpleName());
            System.err.println(annotatorStats.get(annotatorClass));
        }
    }

    /**
     * Builds the list of annotator classes to evaluate from the command-line flags. When no
     * annotator flag is set, all four annotators are evaluated.
     */
    private static List<Class<? extends JCasAnnotator_ImplBase>> selectAnnotatorClasses(Options options) {
        List<Class<? extends JCasAnnotator_ImplBase>> annotatorClasses = Lists.newArrayList();
        if (options.getRunBackwards()) {
            annotatorClasses.add(BackwardsTimeAnnotator.class);
        }
        if (options.getRunForwards()) {
            annotatorClasses.add(TimeAnnotator.class);
        }
        if (options.getRunParserBased()) {
            annotatorClasses.add(ConstituencyBasedTimeAnnotator.class);
        }
        if (options.getRunCrfBased()) {
            annotatorClasses.add(CRFTimeAnnotator.class);
        }
        if (annotatorClasses.isEmpty()) {
            // run all
            annotatorClasses.add(BackwardsTimeAnnotator.class);
            annotatorClasses.add(TimeAnnotator.class);
            annotatorClasses.add(ConstituencyBasedTimeAnnotator.class);
            annotatorClasses.add(CRFTimeAnnotator.class);
        }
        return annotatorClasses;
    }

    /**
     * Returns the per-annotator classifier training arguments. These are passed unchanged to
     * {@link JarClassifierBuilder#trainAndPackage} by {@link #trainAndPackage(File)}.
     */
    private static Map<Class<? extends JCasAnnotator_ImplBase>, String[]> createTrainingArguments() {
        Map<Class<? extends JCasAnnotator_ImplBase>, String[]> trainingArguments = Maps.newHashMap();
        // THYME best params: Backwards: 0.1, CRF 0.3, Time 0.1, Constituency 0.3
        // i2b2 best params: Backwards 0.1, CRF 3.0, Time 0.1, Constituency 0.3
        trainingArguments.put(BackwardsTimeAnnotator.class, new String[] { "-c", "0.1" });
        trainingArguments.put(TimeAnnotator.class, new String[] { "-c", "0.1" });
        trainingArguments.put(ConstituencyBasedTimeAnnotator.class, new String[] { "-c", "0.3" });
        trainingArguments.put(CRFTimeAnnotator.class, new String[] { "-p", "c2=" + "0.3" });
        return trainingArguments;
    }

    /** The annotator being trained and evaluated. */
    private final Class<? extends JCasAnnotator_ImplBase> annotatorClass;

    /** Arguments forwarded to the classifier trainer. */
    private final String[] trainingArguments;

    /** Feature-selection threshold; {@code > 0} enables feature selection for {@link TimeAnnotator} only. */
    private final float featureSelectionThreshold;

    /** Value forwarded to {@link TimeAnnotator#createDataWriterDescription}. */
    private final float smoteNeighborNumber;

    /** When {@code true}, {@link #train(CollectionReader, File)} does nothing. */
    private boolean skipTrain = false;

    /**
     * Creates an evaluation of {@link TimeMention} spans for a single annotator class.
     *
     * @param baseDirectory root directory for evaluation output and models
     * @param rawTextDirectory directory of raw input texts
     * @param xmlDirectory directory of gold-standard XML annotations
     * @param xmlFormat format of the gold-standard XML
     * @param subcorpus subcorpus to evaluate on
     * @param xmiDirectory directory of cached XMI files
     * @param treebankDirectory directory of treebank files
     * @param featureSelectionThreshold feature-selection threshold ({@code > 0} enables it, TimeAnnotator only)
     * @param numOfSmoteNeighbors value forwarded to TimeAnnotator's data-writer description
     * @param annotatorClass the annotator to train and evaluate
     * @param printOverlapping whether to print overlapping spans
     * @param trainingArguments arguments forwarded to the classifier trainer
     */
    public EvaluationOfTimeSpans(File baseDirectory, File rawTextDirectory, File xmlDirectory, XMLFormat xmlFormat,
            Subcorpus subcorpus, File xmiDirectory, File treebankDirectory, float featureSelectionThreshold,
            float numOfSmoteNeighbors, Class<? extends JCasAnnotator_ImplBase> annotatorClass,
            boolean printOverlapping, String[] trainingArguments) {
        super(baseDirectory, rawTextDirectory, xmlDirectory, xmlFormat, subcorpus, xmiDirectory, treebankDirectory,
                TimeMention.class);
        this.annotatorClass = annotatorClass;
        this.featureSelectionThreshold = featureSelectionThreshold;
        this.trainingArguments = trainingArguments;
        this.printOverlapping = printOverlapping;
        this.smoteNeighborNumber = numOfSmoteNeighbors;
    }

    /**
     * Enables or disables the training step.
     *
     * @param val {@code true} to skip training
     */
    public void setSkipTrain(boolean val) {
        this.skipTrain = val;
    }

    /** Delegates to the base implementation unless training has been disabled via {@link #setSkipTrain}. */
    @Override
    public void train(CollectionReader reader, File directory) throws Exception {
        if (!skipTrain) {
            super.train(reader, directory);
        }
    }

    /**
     * Builds the training-time (data-writing) description for the configured annotator.
     *
     * @param directory base model directory
     * @return the analysis engine description that writes training data
     * @throws ResourceInitializationException if the annotator class is not a MetaTimeAnnotator,
     *     CleartkAnnotator or CleartkSequenceAnnotator
     */
    @Override
    protected AnalysisEngineDescription getDataWriterDescription(File directory)
            throws ResourceInitializationException {
        File modelDirectory = this.getModelDirectory(directory);
        if (MetaTimeAnnotator.class.isAssignableFrom(this.annotatorClass)) {
            // MetaTimeAnnotator receives the base directory, not the per-annotator sub-directory
            return MetaTimeAnnotator.getDataWriterDescription(CrfSuiteStringOutcomeDataWriter.class, directory);
        } else if (CleartkAnnotator.class.isAssignableFrom(this.annotatorClass)) {
            // limit feature selection only to TimeAnnotator
            if (this.isTimeAnnotator()) {
                // With feature selection, write raw instances now.
                // trainAndPackage later applies the selection and re-writes them in LibLinear format.
                Class<?> dataWriterClass = this.featureSelectionThreshold > 0f ? InstanceDataWriter.class
                        : LibLinearStringOutcomeDataWriter.class;
                return TimeAnnotator.createDataWriterDescription(dataWriterClass, modelDirectory,
                        this.featureSelectionThreshold, this.smoteNeighborNumber);
            }
            return AnalysisEngineFactory.createEngineDescription(this.annotatorClass,
                    CleartkAnnotator.PARAM_IS_TRAINING, true, DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
                    LibLinearStringOutcomeDataWriter.class, DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
                    modelDirectory);
        } else if (CleartkSequenceAnnotator.class.isAssignableFrom(this.annotatorClass)) {
            return AnalysisEngineFactory.createEngineDescription(this.annotatorClass,
                    CleartkSequenceAnnotator.PARAM_IS_TRAINING, true,
                    DefaultSequenceDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
                    CrfSuiteStringOutcomeDataWriter.class, DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
                    modelDirectory);
        } else {
            // The (String, Object[]) constructor treats its string as a resource-bundle key, which would
            // produce a "localization failed" message. Wrap a plain cause so the real message survives.
            throw new ResourceInitializationException(new IllegalArgumentException(
                    "Annotator class was not recognized as an acceptable class: " + this.annotatorClass.getName()));
        }
    }

    /**
     * Trains the classifier and packages it into a model jar.
     *
     * <p>Feature selection is enabled when the annotator is {@link TimeAnnotator} and the threshold is
     * {@code > 0}. In that case, the instances written by {@link InstanceDataWriter} are first used to
     * train the feature selection. The selection is saved next to the model. The transformed instances
     * are then re-written with a {@link LibLinearStringOutcomeDataWriter} before the classifier is
     * trained.
     *
     * @param directory base model directory
     * @throws Exception if feature selection, data writing or training fails
     */
    @Override
    protected void trainAndPackage(File directory) throws Exception {
        File modelDirectory = this.getModelDirectory(directory);
        if (this.featureSelectionThreshold > 0 && this.isTimeAnnotator()) {
            // read back the raw instances written during data writing
            Iterable<Instance<String>> instances = InstanceStream.loadFromDirectory(modelDirectory);
            // train the feature selection on the raw instances and save it alongside the model
            FeatureSelection<String> featureSelection = TimeAnnotator
                    .createFeatureSelection(this.featureSelectionThreshold);
            featureSelection.train(instances);
            featureSelection.save(TimeAnnotator.createFeatureSelectionURI(modelDirectory));
            // re-write the feature-selected instances in the LibLinear training format
            LibLinearStringOutcomeDataWriter dataWriter = new LibLinearStringOutcomeDataWriter(modelDirectory);
            for (Instance<String> instance : instances) {
                dataWriter.write(featureSelection.transform(instance));
            }
            dataWriter.finish();
        }
        JarClassifierBuilder.trainAndPackage(modelDirectory, this.trainingArguments);
    }

    /**
     * Builds the prediction-time description for the configured annotator.
     *
     * @param directory base model directory
     * @return the analysis engine description that applies the trained model
     * @throws ResourceInitializationException if the description cannot be created
     */
    @Override
    protected AnalysisEngineDescription getAnnotatorDescription(File directory)
            throws ResourceInitializationException {
        if (MetaTimeAnnotator.class.isAssignableFrom(this.annotatorClass)) {
            return MetaTimeAnnotator.getAnnotatorDescription(directory);
        } else if (this.isTimeAnnotator()) {
            return TimeAnnotator.createAnnotatorDescription(this.getModelDirectory(directory));
        }
        return AnalysisEngineFactory.createEngineDescription(this.annotatorClass,
                CleartkAnnotator.PARAM_IS_TRAINING, false, GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
                new File(this.getModelDirectory(directory), "model.jar"));
    }

    /**
     * Returns the {@link TimeMention}s in the segment. The selection is identical to
     * {@link #getSystemAnnotations}; presumably the base class passes the gold view here (confirm).
     */
    @Override
    protected Collection<? extends Annotation> getGoldAnnotations(JCas jCas, Segment segment) {
        return selectExact(jCas, TimeMention.class, segment);
    }

    /** Returns the {@link TimeMention}s in the segment of the system's output view. */
    @Override
    protected Collection<? extends Annotation> getSystemAnnotations(JCas jCas, Segment segment) {
        return selectExact(jCas, TimeMention.class, segment);
    }

    /** Returns the per-annotator model directory: {@code directory/<annotator simple name>}. */
    private File getModelDirectory(File directory) {
        return new File(directory, this.annotatorClass.getSimpleName());
    }

    /**
     * Returns {@code true} when the configured annotator is exactly {@link TimeAnnotator} (not a
     * subclass). This is the only annotator that supports feature selection.
     */
    private boolean isTimeAnnotator() {
        return TimeAnnotator.class.equals(this.annotatorClass);
    }
}