de.tudarmstadt.ukp.csniper.webapp.evaluation.MlPipeline.java Source code

Introduction

Here is the source code for de.tudarmstadt.ukp.csniper.webapp.evaluation.MlPipeline.java, a machine-learning pipeline from the CSniper web application that trains a tree-kernel SVM (SVMlight-TK) on the parses of manually annotated evaluation results and uses the model to predict labels for unannotated items.

Source

/*******************************************************************************
 * Copyright 2013
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/
package de.tudarmstadt.ukp.csniper.webapp.evaluation;

import static java.util.Collections.singleton;
import static org.apache.uima.fit.factory.AnalysisEngineFactory.createPrimitive;
import static org.apache.uima.fit.factory.TypeSystemDescriptionFactory.createTypeSystemDescription;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EmptyStackException;
import java.util.List;
import java.util.Set;

import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.SystemUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.uima.UIMAException;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CAS;
import org.apache.uima.cas.CASException;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.util.CasCreationUtils;
import org.cleartk.classifier.CleartkProcessingException;
import org.cleartk.classifier.DataWriter;
import org.cleartk.classifier.Feature;
import org.cleartk.classifier.Instance;
import org.cleartk.classifier.jar.DirectoryDataWriterFactory;
import org.cleartk.classifier.jar.Train;

import com.google.common.io.Files;

import de.tudarmstadt.ukp.csniper.ml.DummySentenceSplitter;
import de.tudarmstadt.ukp.csniper.ml.GoldFromMetadataAnnotator;
import de.tudarmstadt.ukp.csniper.ml.TKSVMlightFeatureExtractor;
import de.tudarmstadt.ukp.csniper.ml.tksvm.DefaultTKSVMlightDataWriterFactory;
import de.tudarmstadt.ukp.csniper.ml.tksvm.TKSVMlightDataWriter;
import de.tudarmstadt.ukp.csniper.ml.tksvm.TKSVMlightSequenceClassifier;
import de.tudarmstadt.ukp.csniper.ml.tksvm.TKSVMlightSequenceClassifierBuilder;
import de.tudarmstadt.ukp.csniper.ml.tksvm.TreeFeatureVector;
import de.tudarmstadt.ukp.csniper.webapp.evaluation.model.CachedParse;
import de.tudarmstadt.ukp.csniper.webapp.evaluation.model.EvaluationItem;
import de.tudarmstadt.ukp.csniper.webapp.evaluation.model.EvaluationResult;
import de.tudarmstadt.ukp.csniper.webapp.evaluation.model.Mark;
import de.tudarmstadt.ukp.csniper.webapp.project.model.AnnotationType;
import de.tudarmstadt.ukp.csniper.webapp.search.tgrep.PennTreeUtils;
import de.tudarmstadt.ukp.csniper.webapp.statistics.SortableAggregatedEvaluationResultDataProvider.ResultFilter;
import de.tudarmstadt.ukp.csniper.webapp.statistics.model.AggregatedEvaluationResult;
import de.tudarmstadt.ukp.csniper.webapp.support.task.Task;
import de.tudarmstadt.ukp.csniper.webapp.support.uima.AnalysisEngineFactory;
import de.tudarmstadt.ukp.dkpro.core.api.metadata.type.DocumentMetaData;
import de.tudarmstadt.ukp.dkpro.core.api.syntax.type.PennTree;

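/**
 * Machine-learning pipeline which trains a tree-kernel SVM (SVMlight-TK) on the
 * constituency parses of manually annotated {@link EvaluationResult}s and uses
 * the trained model to predict labels for unannotated items. Parses are cached
 * via the {@link EvaluationRepository} to avoid repeated parsing.
 */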
public class MlPipeline {
    private static final Log LOG = LogFactory.getLog(MlPipeline.class);

    private static final double THRESHOLD = 0.0;

    private String language;

    private AnalysisEngine gold;
    private AnalysisEngine sent;
    private AnalysisEngine tok;
    private AnalysisEngine parser;

    private EvaluationRepository repository;

    private Task task;

    public MlPipeline(String aLanguage) throws ResourceInitializationException {
        language = aLanguage;
        gold = createPrimitive(GoldFromMetadataAnnotator.class);
        sent = createPrimitive(DummySentenceSplitter.class);
        tok = AnalysisEngineFactory.createAnalysisEngine(AnalysisEngineFactory.SEGMENTER, "language", aLanguage,
                "createSentences", false);
        parser = AnalysisEngineFactory.createAnalysisEngine(AnalysisEngineFactory.PARSER, "language", aLanguage);
    }

    public void setRepository(EvaluationRepository aRepository) {
        repository = aRepository;
    }

    public void setTask(Task aTask) {
        task = aTask;
    }

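    /**
     * Look up the item's parse in the repository cache; if none is cached, run
     * the parser on the CAS and cache the outcome. In both cases the penn tree
     * ends up in the CAS. Returns the empty string if the item cannot be parsed.
     */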
    public String parse(EvaluationResult result, CAS cas) throws UIMAException {
        // get parse from db, or parse now
        String pennTree = "";
        CachedParse cp = repository.getCachedParse(result.getItem());
        if (cp != null && !cp.getPennTree().isEmpty()) {
            if ("ERROR".equals(cp.getPennTree())) {
                System.out.println("Unable to parse: [" + result.getItem().getCoveredText() + "] (cached)");
                return "";
            }
            // write existing parse to cas for extraction
            pennTree = cp.getPennTree();
            addPennTree(cas, cp.getPennTree());
        } else {
            parser.process(cas);
            try {
                pennTree = StringUtils
                        .normalizeSpace(JCasUtil.selectSingle(cas.getJCas(), PennTree.class).getPennTree());
                repository.writeCachedParse(new CachedParse(result.getItem(), pennTree));
            } catch (IllegalArgumentException e) {
                System.out.println("Unable to parse: [" + result.getItem().getCoveredText() + "]");
                repository.writeCachedParse(new CachedParse(result.getItem(), "ERROR"));
            }
        }

        return pennTree;
    }

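    /**
     * Run each training result through the preprocessing chain (gold annotation,
     * sentence split, tokenization, parsing) and write its tree features to the
     * model directory in TK-SVMlight format.
     */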
    public void createTrainingData(File aModelDir, List<EvaluationResult> aTrainingList)
            throws UIMAException, IOException {
        AnalysisEngine extract = createPrimitive(TKSVMlightFeatureExtractor.class,
                DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY, aModelDir.getAbsolutePath(),
                TKSVMlightFeatureExtractor.PARAM_DATA_WRITER_FACTORY_CLASS_NAME,
                DefaultTKSVMlightDataWriterFactory.class.getName());

        ProgressMeter progress = new ProgressMeter(aTrainingList.size());
        // extract features
        CAS cas = CasCreationUtils.createCas(createTypeSystemDescription(), null, null);
        for (EvaluationResult result : aTrainingList) {
            // add gold annotation
            DocumentMetaData.create(cas).setDocumentTitle(result.getResult());
            // set doc text
            cas.setDocumentText(result.getItem().getCoveredText());
            // set language
            cas.setDocumentLanguage(language);

            // convert gold annotations
            gold.process(cas);
            // preprocessing
            sent.process(cas);
            tok.process(cas);
            // get parse from db, or parse now
            parse(result, cas);
            // extract features
            extract.process(cas);
            cas.reset();
            progress.next();
            LOG.info(progress);
            if (task != null) {
                task.increment();
                task.checkCanceled();
            }
        }
        extract.collectionProcessComplete();
    }

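    /**
     * Load the model trained in aModelDir, encode each result's parse tree into
     * a temporary TK-SVMlight file, predict all items in one batch, and store
     * the predicted mark (PRED_CORRECT or PRED_WRONG) on each result.
     */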
    public void classify(File aModelDir, List<EvaluationResult> aToPredictList) throws IOException, UIMAException {
        TKSVMlightSequenceClassifierBuilder builder = new TKSVMlightSequenceClassifierBuilder();
        TKSVMlightSequenceClassifier classifier = builder.loadClassifierFromTrainingDirectory(aModelDir);
        File cFile = File.createTempFile("tkclassify", ".txt");

        BufferedWriter bw = null;
        try {
            bw = new BufferedWriter(new FileWriter(cFile));

            // predict unclassified
            CAS cas = CasCreationUtils.createCas(createTypeSystemDescription(), null, null);
            ProgressMeter progress = new ProgressMeter(aToPredictList.size());
            for (EvaluationResult result : aToPredictList) {
                cas.setDocumentText(result.getItem().getCoveredText());
                cas.setDocumentLanguage(language);

                // dummy sentence split
                sent.process(cas);

                // tokenize
                tok.process(cas);

                // get parse from db, or parse now
                String pennTree = parse(result, cas);

                // write tree to file
                Feature tree = new Feature("TK_tree", StringUtils.normalizeSpace(pennTree));
                TreeFeatureVector tfv = classifier.getFeaturesEncoder().encodeAll(Arrays.asList(tree));
                try {
                    bw.write("0");
                    bw.write(TKSVMlightDataWriter.createString(tfv));
                    bw.write(SystemUtils.LINE_SEPARATOR);
                } catch (IOException e) {
                    throw new AnalysisEngineProcessException(e);
                }
                cas.reset();
                progress.next();
                LOG.info(progress);
                if (task != null) {
                    task.increment();
                    task.checkCanceled();
                }
            }
        } finally {
            IOUtils.closeQuietly(bw);
        }

        // classify all
        List<Double> predictions = classifier.tkSvmLightPredict2(cFile);

        if (predictions.size() != aToPredictList.size()) {
            // TODO throw different exception instead
            throw new IOException("there are [" + predictions.size() + "] predictions, but ["
                    + aToPredictList.size() + "] were expected.");
        }

        for (int i = 0; i < aToPredictList.size(); i++) {
            Mark m = (predictions.get(i) > THRESHOLD) ? Mark.PRED_CORRECT : Mark.PRED_WRONG;
            aToPredictList.get(i).setResult(m.getTitle());
        }
    }

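    /**
     * Train a tree-kernel SVM on the training results (in a temporary model
     * directory) and classify the given unannotated results. Does nothing if
     * the training list is empty.
     */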
    public void predict(List<EvaluationResult> aTrainingList, List<EvaluationResult> aToPredictList)
            throws UIMAException, IOException {
        if (aTrainingList.isEmpty()) {
            return;
        }

        if (task != null) {
            task.setTotal(aTrainingList.size() + aToPredictList.size());
        }

        // create temp dir for model files
        File modelDir = Files.createTempDir();
        createTrainingData(modelDir, aTrainingList);

        // train model
        try {
            Train.main(modelDir.getPath(), "-t", "5", "-c", "1.0", "-C", "+");
        } catch (Exception e) {
            throw new UIMAException(e);
        }

        // classify
        classify(modelDir, aToPredictList);
    }

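    /**
     * Split the given results into annotated (CORRECT/WRONG) and unannotated
     * items, then train on the former and predict the latter. Returns false
     * without predicting if fewer than aMinItemsAnnotated items are annotated.
     */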
    public boolean predict(List<EvaluationResult> aResults, int aMinItemsAnnotated)
            throws UIMAException, IOException {
        // split results in annotated and empty
        List<EvaluationResult> annotated = new ArrayList<EvaluationResult>();
        List<EvaluationResult> empty = new ArrayList<EvaluationResult>();

        for (EvaluationResult result : aResults) {
            Mark m = Mark.fromString(result.getResult());
            switch (m) {
            case CORRECT:
            case WRONG:
                annotated.add(result);
                break;
            case NA:
            case PRED_CORRECT:
            case PRED_WRONG:
                empty.add(result);
                break;
            default:
                // remaining marks (e.g. CHECK) are neither used for training nor predicted
                break;
            }
        }

        // exit if not enough items have been annotated
        // TODO differentiate between correct/wrong?
        // i.e. require the user to annotate at least X correct and X wrong items before
        // predicting? A classifier trained only on "correct" items will never predict
        // "wrong" for anything, etc.
        if (annotated.size() < aMinItemsAnnotated) {
            return false;
        }
        predict(annotated, empty);

        return true;
    }

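    /**
     * Variant of {@link #predict(List, int)} that trains on aggregated
     * multi-user results fetched from the repository and predicts all given
     * results not manually marked CORRECT or WRONG. Returns false if no
     * aggregated results are available.
     */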
    public boolean predictAggregated(List<EvaluationResult> aResults, String aCollectionId, AnnotationType aType,
            Set<String> aUsers, double aUserThreshold, double aConfidenceThreshold)
            throws UIMAException, IOException {
        // get aggregated results
        List<AggregatedEvaluationResult> aggregatedResults = repository.listAggregatedResults(
                singleton(aCollectionId), singleton(aType), aUsers, aUserThreshold, aConfidenceThreshold);

        if (aggregatedResults.isEmpty()) {
            return false;
        }

        // create training list
        List<EvaluationResult> trainingList = convertToSimple(aggregatedResults);

        // create toPredict list
        List<EvaluationResult> toPredict = new ArrayList<EvaluationResult>();
        for (EvaluationResult er : aResults) {
            Mark result = Mark.fromString(er.getResult());
            if (result != Mark.CORRECT && result != Mark.WRONG) {
                toPredict.add(er);
            }
        }

        predict(trainingList, toPredict);

        return true;
    }

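    /**
     * Attach the given penn tree string to the CAS as a {@link PennTree}
     * annotation spanning the entire document text.
     */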
    private void addPennTree(CAS aCas, String aPennTree) throws CASException {
        PennTree tree = new PennTree(aCas.getJCas(), 0, aCas.getDocumentText().length());
        tree.setPennTree(aPennTree);
        tree.addToIndexes();
    }

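    /**
     * Convert aggregated results classified as CORRECT or WRONG into plain
     * {@link EvaluationResult}s usable as training data; results with any other
     * classification are skipped.
     */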
    public static List<EvaluationResult> convertToSimple(List<AggregatedEvaluationResult> aAgg) {
        // create training list
        List<EvaluationResult> trainingList = new ArrayList<EvaluationResult>();
        for (AggregatedEvaluationResult aer : aAgg) {
            ResultFilter aggregated = aer.getClassification();
            if (aggregated == ResultFilter.CORRECT || aggregated == ResultFilter.WRONG) {
                trainingList.add(new EvaluationResult(aer.getItem(), "__dummy__", aggregated.getLabel()));
            }
        }

        return trainingList;
    }

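    /**
     * Write TK-SVMlight training data for the given results into a fresh
     * temporary directory, using their cached parses (items without a usable
     * cached parse are skipped), and train a model there. Returns the model
     * directory.
     */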
    public static File train(List<EvaluationResult> aTrainingList, EvaluationRepository aRepository)
            throws IOException, CleartkProcessingException {
        File modelDir = Files.createTempDir();
        DefaultTKSVMlightDataWriterFactory dataWriterFactory = new DefaultTKSVMlightDataWriterFactory();
        dataWriterFactory.setOutputDirectory(modelDir);
        DataWriter<Boolean> dataWriter = dataWriterFactory.createDataWriter();

        for (EvaluationResult result : aTrainingList) {
            CachedParse cp = aRepository.getCachedParse(result.getItem());
            if (cp == null || cp.getPennTree().isEmpty() || "ERROR".equals(cp.getPennTree())) {
                System.out.println("Unable to parse: [" + result.getItem().getCoveredText() + "] (cached)");
                continue;
            }

            Instance<Boolean> instance = new Instance<Boolean>();
            instance.add(new Feature("TK_tree", StringUtils.normalizeSpace(cp.getPennTree())));
            instance.setOutcome(Mark.fromString(result.getResult()) == Mark.CORRECT);
            dataWriter.write(instance);
        }

        dataWriter.finish();

        // train model
        try {
            Train.main(modelDir.getPath(), "-t", "5", "-c", "1.0", "-C", "+");
        } catch (Exception e) {
            throw new CleartkProcessingException(e);
        }

        return modelDir;
    }

    /**
     * Note that this method may return fewer results than the number of parses
     * passed to it, because a cached parse may be empty, be marked "ERROR", or
     * be syntactically invalid, in which case no result is generated for it.
     */
    public static List<EvaluationResult> classifyPreParsed(File aModelDir, List<CachedParse> aParses, String aType,
            String aUser) throws IOException, UIMAException {
        TKSVMlightSequenceClassifierBuilder builder = new TKSVMlightSequenceClassifierBuilder();
        TKSVMlightSequenceClassifier classifier = builder.loadClassifierFromTrainingDirectory(aModelDir);
        File cFile = File.createTempFile("tkclassify", ".txt");

        List<EvaluationItem> items = new ArrayList<EvaluationItem>();
        BufferedWriter bw = null;
        try {
            bw = new BufferedWriter(new FileWriter(cFile));

            for (CachedParse parse : aParses) {
                if (parse.getPennTree().isEmpty() || "ERROR".equals(parse.getPennTree())) {
                    continue;
                }

                String coveredText;
                try {
                    coveredText = PennTreeUtils.toText(parse.getPennTree());
                } catch (EmptyStackException e) {
                    LOG.error("Invalid Penn Tree: [" + parse.getPennTree() + "]", e);
                    continue;
                }

                // Prepare evaluation item to return
                EvaluationItem item = new EvaluationItem();
                item.setType(aType);
                item.setBeginOffset(parse.getBeginOffset());
                item.setEndOffset(parse.getEndOffset());
                item.setDocumentId(parse.getDocumentId());
                item.setCollectionId(parse.getCollectionId());
                item.setCoveredText(coveredText);
                items.add(item);

                // write tree to file
                Feature tree = new Feature("TK_tree", StringUtils.normalizeSpace(parse.getPennTree()));
                TreeFeatureVector tfv = classifier.getFeaturesEncoder().encodeAll(Arrays.asList(tree));

                bw.write("0");
                bw.write(TKSVMlightDataWriter.createString(tfv));
                bw.write(SystemUtils.LINE_SEPARATOR);
            }
        } catch (IOException e) {
            throw new AnalysisEngineProcessException(e);
        } finally {
            IOUtils.closeQuietly(bw);
        }

        // classify all
        List<Double> predictions = classifier.tkSvmLightPredict2(cFile);

        if (predictions.size() != items.size()) {
            // TODO throw different exception instead
            throw new IOException("there are [" + predictions.size() + "] predictions, but [" + items.size()
                    + "] were expected.");
        }

        List<EvaluationResult> results = new ArrayList<EvaluationResult>();
        for (EvaluationItem item : items) {
            results.add(new EvaluationResult(item, aUser, ""));
        }

        for (int i = 0; i < results.size(); i++) {
            Mark m = (predictions.get(i) > THRESHOLD) ? Mark.PRED_CORRECT : Mark.PRED_WRONG;
            results.get(i).setResult(m.getTitle());
        }

        return results;
    }
}
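
Usage

The following is a minimal, hypothetical sketch of how this class might be driven. How the EvaluationRepository and the list of EvaluationResults are obtained depends on the surrounding CSniper application and is only assumed here; the minimum count of 50 annotated items is likewise an arbitrary example value.

import java.util.List;

import de.tudarmstadt.ukp.csniper.webapp.evaluation.EvaluationRepository;
import de.tudarmstadt.ukp.csniper.webapp.evaluation.MlPipeline;
import de.tudarmstadt.ukp.csniper.webapp.evaluation.model.EvaluationResult;

public class MlPipelineUsageExample {
    public static void run(EvaluationRepository repository, List<EvaluationResult> results)
            throws Exception {
        // Build the pipeline for English and wire in the parse cache.
        MlPipeline pipeline = new MlPipeline("en");
        pipeline.setRepository(repository);
        // pipeline.setTask(task); // optional: progress reporting and cancellation

        // Train on items already marked CORRECT/WRONG and predict the rest,
        // but only once at least 50 items have been annotated.
        boolean predicted = pipeline.predict(results, 50);
        if (!predicted) {
            System.out.println("Not enough annotated items yet; no predictions made.");
        }
    }
}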