de.tudarmstadt.ukp.dkpro.argumentation.sequence.annotator.EmbeddingsCachePreprocessor.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.dkpro.argumentation.sequence.annotator.EmbeddingsCachePreprocessor.java

Source

/*
 * Copyright 2015
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package de.tudarmstadt.ukp.dkpro.argumentation.sequence.annotator;

import de.tudarmstadt.ukp.dkpro.argumentation.sequence.dl.Embedding;
import de.tudarmstadt.ukp.dkpro.argumentation.sequence.dl.Word2VecReader;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;
import org.apache.commons.io.IOUtils;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;

import java.io.File;
import java.io.FileOutputStream;
import java.io.ObjectOutputStream;
import java.util.*;

/**
 * Collects the distinct token strings seen across all processed CASes and, once the
 * collection is complete, looks up their embeddings through a {@link Word2VecReader}
 * and serializes the resulting token-to-vector map to {@link #PARAM_CACHE_FILE}.
 * <p>
 * Tokens for which the reader returns no embedding are stored in the cache with a
 * {@code null} value, so the cache's key set always equals the collected vocabulary.
 *
 * @author Ivan Habernal
 */
public class EmbeddingsCachePreprocessor extends JCasAnnotator_ImplBase {
    /** Name of the configuration parameter holding the word2vec model file. */
    public static final String PARAM_WORD_2_VEC_FILE = "word2VecFile";

    @ConfigurationParameter(name = PARAM_WORD_2_VEC_FILE, mandatory = true)
    private File word2VecFile;

    /** Name of the configuration parameter holding the output file for the serialized cache. */
    public static final String PARAM_CACHE_FILE = "cacheFile";
    @ConfigurationParameter(name = PARAM_CACHE_FILE, mandatory = true)
    private File cacheFile;

    /** Only tokens strictly shorter than this many characters are added to the vocabulary. */
    private static final int MAX_TOKEN_LENGTH = 50;

    /** Distinct token strings accumulated over all calls to {@link #process(JCas)}. */
    Set<String> tokens = new HashSet<>();

    /**
     * Adds the covered text of every {@link Token} in the given CAS to the vocabulary,
     * skipping tokens of {@value #MAX_TOKEN_LENGTH} or more characters.
     *
     * @param aJCas the CAS whose tokens are collected
     */
    @Override
    public void process(JCas aJCas) throws AnalysisEngineProcessException {
        for (Token token : JCasUtil.select(aJCas, Token.class)) {
            String coveredText = token.getCoveredText();
            if (coveredText.length() < MAX_TOKEN_LENGTH) {
                tokens.add(coveredText);
            }
        }
    }

    /**
     * Looks up embeddings for the collected vocabulary and writes the token-to-vector
     * map to the cache file using Java serialization.
     *
     * @throws AnalysisEngineProcessException wrapping any failure while reading the
     *         embeddings or writing the cache file
     */
    @Override
    public void collectionProcessComplete() throws AnalysisEngineProcessException {
        try {
            // NOTE(review): meaning of the boolean flag is defined by Word2VecReader
            // (not visible here) - confirm before changing.
            Word2VecReader reader = new Word2VecReader(word2VecFile, true);

            String[] tokenArray = tokens.toArray(new String[0]);
            System.out.println("Vocabulary size: " + tokenArray.length);
            Embedding[] embeddings = reader.getEmbeddings(tokenArray);

            // The loop below pairs tokens and embeddings by index, so the arrays must align.
            if (tokenArray.length != embeddings.length) {
                throw new IllegalStateException("Expected " + tokenArray.length
                        + " embeddings but reader returned " + embeddings.length);
            }

            Map<String, no.uib.cipr.matrix.Vector> cache = new HashMap<>();
            for (int i = 0; i < tokenArray.length; i++) {
                Embedding embedding = embeddings[i];
                // Tokens without an embedding are kept with a null value.
                cache.put(tokenArray[i], embedding != null ? embedding.getVector() : null);
            }

            // ObjectOutputStream buffers internally; it must be closed (not just the
            // underlying file stream) or the tail of the serialized data may be lost.
            // try-with-resources also releases the file handle if writeObject fails.
            try (FileOutputStream fos = new FileOutputStream(cacheFile);
                    ObjectOutputStream os = new ObjectOutputStream(fos)) {
                os.writeObject(cache);
            }
        } catch (Exception e) {
            throw new AnalysisEngineProcessException(e);
        }
    }
}