de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.deeplearning.EmbeddingsCachePreprocessor.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.deeplearning.EmbeddingsCachePreprocessor.java

Source

/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.deeplearning;

import de.tudarmstadt.ukp.experiments.argumentation.clustering.dl.Embedding;
import de.tudarmstadt.ukp.experiments.argumentation.clustering.dl.Word2VecReader;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;
import de.tudarmstadt.ukp.dkpro.core.io.xmi.XmiReader;
import no.uib.cipr.matrix.Vector;
import org.apache.commons.io.IOUtils;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.collection.CollectionReaderDescription;
import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.CollectionReaderFactory;
import org.apache.uima.fit.pipeline.SimplePipeline;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;

import java.io.File;
import java.io.FileOutputStream;
import java.io.ObjectOutputStream;
import java.util.*;

/**
 * Collects the distinct token strings of a corpus and serializes a map from each
 * token to its embedding vector, so that the embeddings for this vocabulary can
 * later be loaded from a small cache file.
 * <p>
 * {@link #process(JCas)} gathers the vocabulary document by document;
 * {@link #collectionProcessComplete()} looks up the embeddings for the whole
 * vocabulary and writes the resulting {@code Map<String, Vector>} with Java
 * serialization.
 *
 * @author Ivan Habernal
 */
public class EmbeddingsCachePreprocessor extends JCasAnnotator_ImplBase {
    /** Location of the XMI corpus read by {@link #main(String[])} (placeholder value). */
    public static final String corpusFilePathTrain = "TBD";

    /** Path of the embeddings file handed to {@link Word2VecReader} (placeholder value). */
    public static final String EMBEDDINGS = "TBD";

    /** Path of the serialized cache file written on completion (placeholder value). */
    private static final String CACHE_FILE = "TBD";

    /** Tokens whose covered text is this long or longer are not added to the vocabulary. */
    private static final int MAX_TOKEN_LENGTH = 50;

    /** Distinct covered texts of all accepted tokens seen so far. */
    Set<String> tokens = new HashSet<>();

    /**
     * Runs this annotator over all {@code *.xmi} files found under
     * {@link #corpusFilePathTrain}.
     *
     * @param args ignored
     * @throws Exception if the reader or the pipeline cannot be created or fails
     */
    public static void main(String[] args) throws Exception {
        CollectionReaderDescription readerDescription = CollectionReaderFactory.createReaderDescription(
                XmiReader.class, XmiReader.PARAM_SOURCE_LOCATION, corpusFilePathTrain, XmiReader.PARAM_PATTERNS,
                XmiReader.INCLUDE_PREFIX + "*.xmi", XmiReader.PARAM_LENIENT, false);

        SimplePipeline.runPipeline(readerDescription,
                AnalysisEngineFactory.createEngineDescription(EmbeddingsCachePreprocessor.class));
    }

    /**
     * Adds the covered text of every {@link Token} in the document to the
     * vocabulary, skipping tokens of {@value #MAX_TOKEN_LENGTH} or more characters.
     *
     * @param aJCas the document to scan
     * @throws AnalysisEngineProcessException declared by the overridden method
     */
    @Override
    public void process(JCas aJCas) throws AnalysisEngineProcessException {
        for (Token token : JCasUtil.select(aJCas, Token.class)) {
            String coveredText = token.getCoveredText();
            if (coveredText.length() < MAX_TOKEN_LENGTH) {
                tokens.add(coveredText);
            }
        }
    }

    /**
     * Looks up the embeddings for the collected vocabulary and serializes the
     * token-to-vector map to {@link #CACHE_FILE}. A token for which the reader
     * returned no embedding is stored with a {@code null} value.
     *
     * @throws AnalysisEngineProcessException wrapping any failure while reading
     *                                        the embeddings or writing the cache
     */
    @Override
    public void collectionProcessComplete() throws AnalysisEngineProcessException {
        try {
            Word2VecReader reader = new Word2VecReader(new File(EMBEDDINGS), true);

            String[] tokenArray = tokens.toArray(new String[0]);
            System.out.println("Vocabulary size: " + tokenArray.length);
            Embedding[] embeddings = reader.getEmbeddings(tokenArray);

            // The loop below pairs tokens and embeddings by index, so the arrays must line up.
            if (tokenArray.length != embeddings.length) {
                throw new IllegalStateException("Requested embeddings for " + tokenArray.length
                        + " tokens but received " + embeddings.length);
            }

            Map<String, Vector> cache = new HashMap<>();
            for (int i = 0; i < tokenArray.length; i++) {
                Embedding embedding = embeddings[i];
                // Keep an explicit null entry for tokens without an embedding.
                cache.put(tokenArray[i], embedding != null ? embedding.getVector() : null);
            }

            // try-with-resources closes (and thereby flushes) the object stream before the
            // file stream, also on failure, and lets close-time I/O errors propagate
            // instead of being silently dropped.
            File cacheFile = new File(CACHE_FILE);
            try (FileOutputStream fos = new FileOutputStream(cacheFile);
                    ObjectOutputStream os = new ObjectOutputStream(fos)) {
                os.writeObject(cache);
            }
        } catch (Exception e) {
            throw new AnalysisEngineProcessException(e);
        }
    }
}