edu.cmu.cs.lti.ark.fn.identification.latentmodel.LatentAlphabetCreationThreaded.java Source code

Java tutorial

Introduction

Here is the source code for edu.cmu.cs.lti.ark.fn.identification.latentmodel.LatentAlphabetCreationThreaded.java

Source

/*******************************************************************************
 * Copyright (c) 2011 Dipanjan Das 
 * Language Technologies Institute, 
 * Carnegie Mellon University, 
 * All Rights Reserved.
 *
 * AlphabetCreationThreaded.java is part of SEMAFOR 2.0.
 *
 * SEMAFOR 2.0 is free software: you can redistribute it and/or modify  it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or 
 * (at your option) any later version.
 *
 * SEMAFOR 2.0 is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details. 
 *
 * You should have received a copy of the GNU General Public License along
 * with SEMAFOR 2.0.  If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
package edu.cmu.cs.lti.ark.fn.identification.latentmodel;

import com.google.common.base.Charsets;
import com.google.common.collect.Sets;
import com.google.common.io.Files;
import edu.cmu.cs.lti.ark.fn.data.prep.formats.AllLemmaTags;
import edu.cmu.cs.lti.ark.fn.identification.RequiredDataForFrameIdentification;
import edu.cmu.cs.lti.ark.fn.utils.FNModelOptions;
import edu.cmu.cs.lti.ark.fn.wordnet.CachedRelations;
import edu.cmu.cs.lti.ark.fn.wordnet.Relations;
import edu.cmu.cs.lti.ark.util.SerializedObjects;
import edu.cmu.cs.lti.ark.util.ds.map.IntCounter;
import edu.cmu.cs.lti.ark.util.nlp.CachedLemmatizer;
import edu.cmu.cs.lti.ark.util.nlp.Lemmatizer;
import edu.cmu.cs.lti.ark.util.nlp.parse.DependencyParse;
import gnu.trove.THashMap;
import gnu.trove.THashSet;
import org.apache.commons.io.FileUtils;

import java.io.*;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.logging.*;

import static java.util.concurrent.Executors.newFixedThreadPool;
import static org.apache.commons.io.IOUtils.closeQuietly;

public class LatentAlphabetCreationThreaded {
    private static final Logger logger = Logger.getLogger(LatentAlphabetCreationThreaded.class.getCanonicalName());

    private static final int BATCH_SIZE = 100;
    private static int EXPECTED_NUM_FEATURES = 3000000;
    public static final String ALPHABET_FILENAME = "alphabet.dat";
    final static FilenameFilter LOCAL_ALPHABET_FILENAME_FILTER = new FilenameFilter() {
        public boolean accept(File dir, String filename) {
            return filename.startsWith(ALPHABET_FILENAME + "_");
        }
    };
    private final THashMap<String, THashSet<String>> frameMap;
    private final String parseFile;
    private final File alphabetDir;
    private final String frameElementsFile;
    private final int startIndex;
    private final int endIndex;
    private final int numThreads;
    private LatentFeatureExtractor featureExtractor;

    public static void main(String[] args) throws IOException, ClassNotFoundException {
        final FNModelOptions options = new FNModelOptions(args);

        LogManager.getLogManager().reset();
        final FileHandler fileHandler = new FileHandler(options.logOutputFile.get(), true);
        fileHandler.setFormatter(new SimpleFormatter());
        logger.addHandler(fileHandler);

        final int startIndex = options.startIndex.get();
        final int endIndex = options.endIndex.get();
        logger.info("Start:" + startIndex + " end:" + endIndex);
        final RequiredDataForFrameIdentification r = SerializedObjects.readObject(options.fnIdReqDataFile.get());
        final Relations wnRelations = new CachedRelations(r.getRevisedRelMap(), r.getRelatedWordsForWord());
        final Lemmatizer lemmatizer = new CachedLemmatizer(r.getHvLemmaCache());
        final int numThreads = options.numThreads.present() ? options.numThreads.get()
                : Runtime.getRuntime().availableProcessors();
        final File alphabetDir = new File(options.modelFile.get());
        final LatentAlphabetCreationThreaded events = new LatentAlphabetCreationThreaded(alphabetDir,
                options.trainFrameElementFile.get(), options.trainParseFile.get(), r.getFrameMap(),
                new LatentFeatureExtractor(wnRelations, lemmatizer), startIndex, endIndex, numThreads);
        events.createLocalAlphabets();
        combineAlphabets(alphabetDir);
    }

    /**
     * Creates a new AlphabetCreationThreaded with the given arguments
     *
     * @param alphabetDir        the file prefix to which to write the resulting alphabet
     * @param frameElementsFile   path to file containing gold standard frame element
     *                            annotations
     * @param parseFile           path to file containing dependency parsed sentences (the same
    *                            ones that are in frameElementsFile
     * @param frameMap            a map from frames to a set of disambiguated predicates
    *                            (words along with part of speech tags, but in the style of FrameNet)
     * @param startIndex          the line of the frameElementsFile to start at
     * @param endIndex            the line of frameElementsFile to end at
     * @param numThreads          the number of threads to run
     */
    public LatentAlphabetCreationThreaded(File alphabetDir, String frameElementsFile, String parseFile,
            THashMap<String, THashSet<String>> frameMap, LatentFeatureExtractor featureExtractor, int startIndex,
            int endIndex, int numThreads) {
        this.alphabetDir = alphabetDir;
        this.frameElementsFile = frameElementsFile;
        this.parseFile = parseFile;
        this.frameMap = frameMap;
        this.startIndex = startIndex;
        this.endIndex = endIndex;
        this.numThreads = numThreads;
        this.featureExtractor = featureExtractor;
    }

    /**
     * Splits frameElementLines into numThreads equally-sized batches and creates an alphabet
     * file for each one.
     *
     * @throws java.io.IOException
     */
    public void createLocalAlphabets() throws IOException {
        final List<String> frameLines = Files.readLines(new File(frameElementsFile), Charsets.UTF_8)
                .subList(startIndex, endIndex);
        final List<String> parseLines = Files.readLines(new File(parseFile), Charsets.UTF_8);

        final ExecutorService threadPool = newFixedThreadPool(numThreads);
        int i = 0;
        for (int start = startIndex; start < endIndex; start += BATCH_SIZE) {
            final int threadId = i;
            final List<String> frameLineBatch = frameLines.subList(start,
                    Math.min(frameLines.size(), start + BATCH_SIZE));
            threadPool.execute(new Runnable() {
                public void run() {
                    logger.info("Thread " + threadId + " : start");
                    processBatch(threadId, frameLineBatch, parseLines);
                    logger.info("Thread " + threadId + " : end");
                }
            });
            i++;
        }
        threadPool.shutdown();
    }

    private void processBatch(int threadId, List<String> frameLines, List<String> parseLines) {
        final Set<String> alphabet = Sets.newHashSet();
        for (int i = 0; i < frameLines.size(); i++) {
            processLine(frameLines.get(i), parseLines, alphabet);
            if (i % 10 == 0) {
                logger.info("Thread " + threadId + "\n" + "Processed index:" + i + " of " + frameLines.size() + "\n"
                        + "Alphabet size:" + alphabet.size());
            }
        }
        try {
            writeAlphabetFile(alphabet, alphabetDir + "_" + threadId);
        } catch (IOException e) {
            e.printStackTrace();
            throw new RuntimeException(e);
        }
    }

    private void processLine(String frameLine, List<String> parseLines, Set<String> alphabet) {
        // Parse the frameLine
        final String[] toks = frameLine.split("\t");
        // throw out first two fields
        final List<String> tokens = Arrays.asList(toks).subList(2, toks.length);
        final String frameName = tokens.get(1);
        final String[] targetIdxsStr = tokens.get(3).split("_");
        final int sentNum = Integer.parseInt(tokens.get(5));

        final int[] targetTokenIdxs = new int[targetIdxsStr.length];
        for (int j = 0; j < targetIdxsStr.length; j++)
            targetTokenIdxs[j] = Integer.parseInt(targetIdxsStr[j]);
        Arrays.sort(targetTokenIdxs);

        // Parse the parse line
        final String parseLine = parseLines.get(sentNum);
        final String[][] data = AllLemmaTags.readLine(parseLine);

        // extract features for every frame
        for (String frame : frameMap.keySet()) {
            extractFeaturesAndAddToAlphabet(frame, targetTokenIdxs, data, alphabet);
        }
    }

    private void extractFeaturesAndAddToAlphabet(String frame, int[] targetTokenIdxs, String[][] allLemmaTags,
            Set<String> alphabet) {
        final THashSet<String> hiddenUnits = frameMap.get(frame);
        final DependencyParse parse = DependencyParse.processFN(allLemmaTags, 0.0);
        for (String unit : hiddenUnits) {
            final IntCounter<String> valMap = featureExtractor.extractFeatures(frame, targetTokenIdxs, unit,
                    allLemmaTags, parse, true);
            alphabet.addAll(valMap.keySet());
        }
    }

    /**
     * Combines the multiple alphabet files created by createLocalAlphabets into one
     * alphabet file
     */
    public static void combineAlphabets(File alphabetDir) throws IOException {
        final String[] files = alphabetDir.list(LOCAL_ALPHABET_FILENAME_FILTER);
        final Set<String> alphabet = Sets.newHashSet();
        for (String file : files) {
            final String path = alphabetDir.getAbsolutePath() + "/" + file;
            if (logger.isLoggable(Level.INFO))
                logger.info("reading path: " + path);
            final Map<String, Integer> localAlphabet = readAlphabetFile(path);
            alphabet.addAll(localAlphabet.keySet());
        }
        writeAlphabetFile(alphabet, alphabetDir.getAbsolutePath() + "/" + ALPHABET_FILENAME);
    }

    public static Map<String, Integer> readAlphabetFile(String filename) throws IOException {
        final BufferedReader bReader = new BufferedReader(new FileReader(filename));
        try {
            final Map<String, Integer> alphabet = new THashMap<String, Integer>(EXPECTED_NUM_FEATURES);
            String line;
            int i = 0;
            while ((line = bReader.readLine()) != null) {
                alphabet.put(line.trim(), i);
                i++;
            }
            return alphabet;
        } finally {
            closeQuietly(bReader);
        }
    }

    public static void writeAlphabetFile(Set<String> alphabet, String filename) throws IOException {
        FileUtils.writeLines(new File(filename), alphabet);
    }
}