edu.cmu.cs.lti.ark.fn.Semafor.java Source code

Java tutorial

Introduction

Here is the source code for edu.cmu.cs.lti.ark.fn.Semafor.java

Source

/*******************************************************************************
 * Copyright (c) 2012 Dipanjan Das
 * Language Technologies Institute,
 * Carnegie Mellon University,
 * All Rights Reserved.
 *
 * CommandLineOptions.java is part of SEMAFOR 2.1.
 *
 * SEMAFOR 2.1 is free software: you can redistribute it and/or modify  it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * SEMAFOR 2.1 is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with SEMAFOR 2.1.  If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
package edu.cmu.cs.lti.ark.fn;

import com.google.common.base.Charsets;
import com.google.common.base.Joiner;
import com.google.common.base.Optional;
import com.google.common.collect.Lists;
import com.google.common.collect.Queues;
import com.google.common.io.Files;
import com.google.common.io.InputSupplier;
import com.google.common.io.OutputSupplier;
import com.google.common.primitives.Ints;
import edu.cmu.cs.lti.ark.fn.data.prep.formats.AllLemmaTags;
import edu.cmu.cs.lti.ark.fn.data.prep.formats.Sentence;
import edu.cmu.cs.lti.ark.fn.data.prep.formats.SentenceCodec;
import edu.cmu.cs.lti.ark.fn.data.prep.formats.Token;
import edu.cmu.cs.lti.ark.fn.evaluation.PrepareFullAnnotationJson;
import edu.cmu.cs.lti.ark.fn.identification.GraphBasedFrameIdentifier;
import edu.cmu.cs.lti.ark.fn.identification.RequiredDataForFrameIdentification;
import edu.cmu.cs.lti.ark.fn.parsing.*;
import edu.cmu.cs.lti.ark.fn.segmentation.RoteSegmenter;
import edu.cmu.cs.lti.ark.fn.utils.DataPointWithFrameElements;
import edu.cmu.cs.lti.ark.fn.utils.FNModelOptions;
import edu.cmu.cs.lti.ark.util.ds.Pair;
import edu.cmu.cs.lti.ark.util.ds.Range0Based;
import edu.cmu.cs.lti.ark.util.nlp.Lemmatizer;
import edu.cmu.cs.lti.ark.util.nlp.MorphaLemmatizer;
import edu.cmu.cs.lti.ark.util.nlp.parse.DependencyParse;
import edu.cmu.cs.lti.ark.util.nlp.parse.DependencyParses;

import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Writer;
import java.net.URISyntaxException;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.*;

import static com.google.common.collect.ImmutableList.copyOf;
import static com.google.common.collect.Lists.transform;
import static edu.cmu.cs.lti.ark.fn.data.prep.formats.SentenceCodec.ConllCodec;
import static edu.cmu.cs.lti.ark.fn.evaluation.PrepareFullAnnotationJson.processPredictionLine;
import static edu.cmu.cs.lti.ark.fn.identification.FrameIdentificationRelease.getTokenRepresentation;
import static edu.cmu.cs.lti.ark.fn.parsing.DataPrep.SpanAndParseIdx;
import static edu.cmu.cs.lti.ark.util.SerializedObjects.readObject;
import static java.util.concurrent.Executors.newFixedThreadPool;
import static org.apache.commons.io.IOUtils.closeQuietly;

public class Semafor {
    public static final String REQUIRED_DATA_FILENAME = "reqData.jobj";
    private static final String ALPHABET_FILENAME = "parser.conf";
    private static final String FRAME_ELEMENT_MAP_FILENAME = "framenet.frame.element.map";
    private static final String ARG_MODEL_FILENAME = "argmodel.dat";

    private static final Joiner TAB = Joiner.on("\t");

    protected final Set<String> allRelatedWords;
    protected final FEDict frameElementsForFrame;
    protected final RoteSegmenter segmenter;
    protected final GraphBasedFrameIdentifier idModel;
    protected final Decoding decoder;
    protected final Map<String, Integer> argIdFeatureIndex;
    protected final Lemmatizer lemmatizer = new MorphaLemmatizer();

    /**
     * required flags:
     * model-dir
     * input-file
     * output-file
     */
    public static void main(String[] args) throws Exception {
        final FNModelOptions options = new FNModelOptions(args);
        final File inputFile = new File(options.inputFile.get());
        final File outputFile = new File(options.outputFile.get());
        final String modelDirectory = options.modelDirectory.get();
        final int numThreads = options.numThreads.present() ? options.numThreads.get() : 1;
        final Semafor semafor = getSemaforInstance(modelDirectory);
        semafor.runParser(Files.newReaderSupplier(inputFile, Charsets.UTF_8),
                Files.newWriterSupplier(outputFile, Charsets.UTF_8), numThreads);
    }

    public Semafor(Set<String> allRelatedWords, FEDict frameElementsForFrame, RoteSegmenter segmenter,
            GraphBasedFrameIdentifier idModel, Decoding decoder, Map<String, Integer> argIdFeatureIndex) {
        this.allRelatedWords = allRelatedWords;
        this.frameElementsForFrame = frameElementsForFrame;
        this.segmenter = segmenter;
        this.idModel = idModel;
        this.decoder = decoder;
        this.argIdFeatureIndex = argIdFeatureIndex;
    }

    public static Semafor getSemaforInstance(String modelDirectory)
            throws IOException, ClassNotFoundException, URISyntaxException {
        final String requiredDataFilename = new File(modelDirectory, REQUIRED_DATA_FILENAME).getAbsolutePath();
        final String alphabetFilename = new File(modelDirectory, ALPHABET_FILENAME).getAbsolutePath();
        final String frameElementMapFilename = new File(modelDirectory, FRAME_ELEMENT_MAP_FILENAME)
                .getAbsolutePath();
        final String argModelFilename = new File(modelDirectory, ARG_MODEL_FILENAME).getAbsolutePath();
        // unpack required data
        final RequiredDataForFrameIdentification r = readObject(requiredDataFilename);
        final Set<String> allRelatedWords = r.getAllRelatedWords();
        final GraphBasedFrameIdentifier idModel = GraphBasedFrameIdentifier.getInstance(modelDirectory);
        final RoteSegmenter segmenter = new RoteSegmenter(allRelatedWords);
        System.err.println("Initializing alphabet for argument identification..");
        final Map<String, Integer> argIdFeatureIndex = DataPrep.readFeatureIndex(new File(alphabetFilename));
        final FEDict frameElementsForFrame = FEDict.fromFile(frameElementMapFilename);
        final Decoding decoder = Decoding.fromFile(argModelFilename, alphabetFilename);
        return new Semafor(allRelatedWords, frameElementsForFrame, segmenter, idModel, decoder, argIdFeatureIndex);
    }

    /**
     * Reads conll sentences, parses them, and writes the json-serialized results.
     *
     * @param inputSupplier where to read conll sentences from
     * @param outputSupplier where to write the results to
     * @param numThreads the number of threads to use
     * @throws IOException
     * @throws InterruptedException
     */
    public void runParser(final InputSupplier<? extends Readable> inputSupplier,
            final OutputSupplier<? extends Writer> outputSupplier, final int numThreads)
            throws IOException, InterruptedException {
        // use the producer-worker-consumer pattern to parse all sentences in multiple threads, while keeping
        // output in order.
        final BlockingQueue<Future<Optional<SemaforParseResult>>> results = Queues
                .newLinkedBlockingDeque(5 * numThreads);
        final ExecutorService workerThreadPool = newFixedThreadPool(numThreads);
        // try to shutdown gracefully. don't worry too much if it doesn't work
        Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    workerThreadPool.shutdown();
                    workerThreadPool.awaitTermination(5, TimeUnit.SECONDS);
                } catch (InterruptedException ignored) {
                }
            }
        }));

        final PrintWriter output = new PrintWriter(outputSupplier.getOutput());
        try {
            // Start thread to fetch computed results and write to file
            final Thread consumer = new Thread(new Runnable() {
                @Override
                public void run() {
                    while (!Thread.currentThread().isInterrupted()) {
                        try {
                            final Optional<SemaforParseResult> oResult = results.take().get();
                            if (!oResult.isPresent())
                                break; // got poison pill. we're done
                            output.println(oResult.get().toJson());
                            output.flush();
                        } catch (Exception e) {
                            e.printStackTrace();
                            throw new RuntimeException(e);
                        }
                    }
                }
            });
            consumer.start();
            // in main thread, put placeholders on results queue (so results stay in order), then
            // tell a worker thread to fill up the placeholder
            final SentenceCodec.SentenceIterator sentences = ConllCodec.readInput(inputSupplier.getInput());
            try {
                int i = 0;
                while (sentences.hasNext()) {
                    final Sentence sentence = sentences.next();
                    final int sentenceId = i;
                    results.put(workerThreadPool.submit(new Callable<Optional<SemaforParseResult>>() {
                        @Override
                        public Optional<SemaforParseResult> call() throws Exception {
                            final long start = System.currentTimeMillis();
                            try {
                                final SemaforParseResult result = parseSentence(sentence);
                                final long end = System.currentTimeMillis();
                                System.err.printf("parsed sentence %d in %d millis.%n", sentenceId, end - start);
                                return Optional.of(result);
                            } catch (Exception e) {
                                e.printStackTrace();
                                throw e;
                            }
                        }
                    }));
                    i++;
                }
                // put a poison pill on the queue to signal that we're done
                results.put(workerThreadPool.submit(new Callable<Optional<SemaforParseResult>>() {
                    @Override
                    public Optional<SemaforParseResult> call() throws Exception {
                        return Optional.absent();
                    }
                }));
                workerThreadPool.shutdown();
            } finally {
                closeQuietly(sentences);
            }
            // wait for consumer to finish
            consumer.join();
        } finally {
            closeQuietly(output);
        }
        System.err.println("Done.");
    }

    public SemaforParseResult parseSentence(Sentence unLemmatizedSentence) throws IOException {
        // look up lemmas
        final Sentence sentence = addLemmas(unLemmatizedSentence);
        // find targets
        final List<List<Integer>> segments = predictTargets(sentence);
        // frame identification
        final List<Pair<List<Integer>, String>> idResult = predictFrames(sentence, segments);
        // argument identification
        return predictArguments(sentence, idResult);
    }

    public List<List<Integer>> predictTargets(Sentence sentence) {
        return segmenter.getSegmentation(sentence);
    }

    public List<Pair<List<Integer>, String>> predictFrames(Sentence sentence, List<List<Integer>> targets) {
        final List<Pair<List<Integer>, String>> idResult = Lists.newArrayList();
        for (List<Integer> targetTokenIdxs : targets) {
            final String frame = idModel.getBestFrame(targetTokenIdxs, sentence);
            idResult.add(Pair.of(targetTokenIdxs, frame));
        }
        return idResult;
    }

    public SemaforParseResult predictArguments(Sentence sentence, List<Pair<List<Integer>, String>> idResults)
            throws IOException {
        final List<String> idResultLines = getArgumentIdInput(sentence, idResults);
        final List<String> argResult = predictArgumentLines(sentence, idResultLines, 1);
        return getSemaforParseResult(sentence, argResult);
    }

    /**
     * Convert to the weird format that {@link #predictArgumentLines} expects.
     * @param sentence the input sentence
     * @param idResults a list of (target, frame) pairs
     * @return a list of strings in the format that {@link #predictArgumentLines} expects.
     */
    public List<String> getArgumentIdInput(Sentence sentence, List<Pair<List<Integer>, String>> idResults) {
        final List<String> idResultLines = Lists.newArrayList();
        final String parseLine = AllLemmaTags.makeLine(sentence.toAllLemmaTagsArray());
        for (Pair<List<Integer>, String> targetAndFrame : idResults) {
            final List<Integer> targetTokenIdxs = targetAndFrame.first;
            final String frame = targetAndFrame.second;
            final String tokenIdxsStr = Joiner.on("_").join(targetTokenIdxs);
            final Pair<String, String> tokenRepresentation = getTokenRepresentation(tokenIdxsStr, parseLine);
            final String lexicalUnit = tokenRepresentation.first;
            final String tokenStrs = tokenRepresentation.second;
            idResultLines.add(TAB.join(0, 1.0, 1, frame, lexicalUnit, tokenIdxsStr, tokenStrs, 0));
        }
        return idResultLines;
    }

    public List<String> predictArgumentLines(Sentence sentence, List<String> idResult, int kBest)
            throws IOException {
        final List<FrameFeatures> frameFeaturesList = Lists.newArrayList();
        final FeatureExtractor featureExtractor = new FeatureExtractor();
        for (String feLine : idResult) {
            final DataPointWithFrameElements dataPoint = new DataPointWithFrameElements(sentence, feLine);
            final String frame = dataPoint.getFrameName();
            final DependencyParses parses = dataPoint.getParses();
            final int targetStartTokenIdx = dataPoint.getTargetTokenIdxs()[0];
            final int targetEndTokenIdx = dataPoint.getTargetTokenIdxs()[dataPoint.getTargetTokenIdxs().length - 1];
            final List<SpanAndParseIdx> spans = DataPrep.findSpans(dataPoint, 1);
            final List<String> frameElements = Lists.newArrayList(frameElementsForFrame.lookupFrameElements(frame));
            final List<SpanAndCorrespondingFeatures[]> featuresAndSpanByArgument = Lists.newArrayList();
            for (String frameElement : frameElements) {
                final List<SpanAndCorrespondingFeatures> spansAndFeatures = Lists.newArrayList();
                for (SpanAndParseIdx candidateSpanAndParseIdx : spans) {
                    final Range0Based span = candidateSpanAndParseIdx.span;
                    final DependencyParse parse = parses.get(candidateSpanAndParseIdx.parseIdx);
                    final Set<String> featureSet = featureExtractor
                            .extractFeatures(dataPoint, frame, frameElement, span, parse).elementSet();
                    final int[] featArray = convertToIdxs(featureSet);
                    spansAndFeatures
                            .add(new SpanAndCorrespondingFeatures(new int[] { span.start, span.end }, featArray));
                }
                featuresAndSpanByArgument
                        .add(spansAndFeatures.toArray(new SpanAndCorrespondingFeatures[spansAndFeatures.size()]));
            }
            frameFeaturesList.add(new FrameFeatures(frame, targetStartTokenIdx, targetEndTokenIdx, frameElements,
                    featuresAndSpanByArgument));
        }
        return decoder.decodeAll(frameFeaturesList, idResult, 0, kBest);
    }

    private int[] convertToIdxs(Iterable<String> featureSet) {
        // convert feature names to feature indexes
        final List<Integer> featureList = Lists.newArrayList();
        for (String feature : featureSet) {
            final Integer idx = argIdFeatureIndex.get(feature);
            if (idx != null) {
                featureList.add(idx);
            }
        }
        return Ints.toArray(featureList);
    }

    public SemaforParseResult getSemaforParseResult(Sentence sentence, List<String> results) {
        final List<RankedScoredRoleAssignment> roleAssignments = copyOf(transform(results, processPredictionLine));
        List<String> tokens = Lists.newArrayListWithExpectedSize(sentence.size());
        for (Token token : sentence.getTokens()) {
            tokens.add(token.getForm());
        }
        return PrepareFullAnnotationJson.getSemaforParse(roleAssignments, tokens);
    }

    public Sentence addLemmas(Sentence sentence) {
        return lemmatizer.addLemmas(sentence);
    }
}