de.tudarmstadt.ukp.dkpro.argumentation.sequence.annotator.EmbeddingsAnnotator.java Source code

Introduction

Here is the source code for de.tudarmstadt.ukp.dkpro.argumentation.sequence.annotator.EmbeddingsAnnotator.java. The annotator computes a sentence-level embedding for each Sentence annotation by combining the word2vec vectors of its tokens, optionally weighting each vector by its tf-idf value and averaging the sum, and stores the result as an Embeddings annotation.

Source

/*
 * Copyright 2015
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package de.tudarmstadt.ukp.dkpro.argumentation.sequence.annotator;

import de.tudarmstadt.ukp.dkpro.argumentation.sequence.dl.Embedding;
import de.tudarmstadt.ukp.dkpro.argumentation.sequence.dl.Word2VecReader;
import de.tudarmstadt.ukp.dkpro.argumentation.type.Embeddings;
import de.tudarmstadt.ukp.dkpro.core.api.frequency.tfidf.type.Tfidf;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Vector;
import org.apache.commons.io.IOUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.cas.DoubleArray;
import org.apache.uima.jcas.tcas.Annotation;
import org.apache.uima.resource.ResourceInitializationException;

import java.io.*;
import java.net.URL;
import java.util.*;

import static de.tudarmstadt.ukp.dkpro.core.api.resources.ResourceUtils.resolveLocation;

/**
 * @author Ivan Habernal
 */
public class EmbeddingsAnnotator extends JCasAnnotator_ImplBase {
    public static final String PARAM_WORD_2_VEC_FILE = "word2VecFile";

    @ConfigurationParameter(name = PARAM_WORD_2_VEC_FILE, mandatory = true)
    protected File word2VecFile;

    public static final String PARAM_CACHE_FILE = "cacheFile";
    @ConfigurationParameter(name = PARAM_CACHE_FILE, mandatory = false)
    protected String cacheFile;

    public static final String PARAM_TO_LOWERCASE = "toLowerCase";
    @ConfigurationParameter(name = PARAM_TO_LOWERCASE, mandatory = true, defaultValue = "false")
    boolean toLowerCase;

    /**
     * If true (default), the resulting sentence embedding vector is the average of the word
     * vectors; otherwise it is their plain sum
     */
    public static final String PARAM_VECTOR_AVERAGING = "vectorAveraging";
    @ConfigurationParameter(name = PARAM_VECTOR_AVERAGING, mandatory = true, defaultValue = "true")
    boolean vectorAveraging;

    public static final String PARAM_TFIDF_WEIGHTING = "tfIdfWeighting";
    @ConfigurationParameter(name = PARAM_TFIDF_WEIGHTING, mandatory = true, defaultValue = "true")
    boolean tfIdfWeighting;

    protected Word2VecReader reader;

    protected Map<String, Vector> cache = new HashMap<>();

    // FIXME: make this dynamic, e.g. derive it from the embedding model
    public static final int VECTOR_SIZE = 300;

    @Override
    public void initialize(UimaContext context) throws ResourceInitializationException {
        super.initialize(context);

        try {
            if (cacheFile != null) {
                loadCache();
            }
        } catch (IOException e) {
            throw new ResourceInitializationException(e);
        }
    }

    protected void initReader() throws IOException {
        if (reader == null) {
            reader = new Word2VecReader(word2VecFile, true);
        }
    }

    @SuppressWarnings("unchecked")
    protected void loadCache() throws IOException {
        URL source = resolveLocation(cacheFile);

        // read the serialized cache fully into memory before deserializing
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (InputStream stream = source.openStream()) {
            IOUtils.copy(stream, baos);
        }

        try (ObjectInputStream ois = new ObjectInputStream(
                new ByteArrayInputStream(baos.toByteArray()))) {
            this.cache = (Map<String, Vector>) ois.readObject();
        } catch (ClassNotFoundException e) {
            throw new IOException(e);
        }
    }

    protected Embedding getEmbeddings(String t) {
        String token = preprocessToken(t);
        if (!cache.containsKey(token)) {
            System.err.println("Word " + token + " not cached; maybe you forgot to run "
                    + EmbeddingsCachePreprocessor.class + " to prepare the cache?");

            // only query the reader for reasonably short tokens; overly long ones get a null vector
            if (token.length() < 50) {
                Embedding[] embeddings;
                try {
                    initReader();

                    embeddings = reader.getEmbeddings(token);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }

                return embeddings[0];
            }

            return new Embedding(token, null);
        } else {
            return new Embedding(token, cache.get(token));
        }
    }

    /**
     * Normalizes the token, e.g. lower casing (no change by default)
     *
     * @param token token
     * @return token
     */
    protected String preprocessToken(String token) {
        if (toLowerCase) {
            return token.toLowerCase();
        }

        return token;
    }

    protected Collection<? extends Annotation> selectAnnotationsForEmbeddings(JCas aJCas) {
        return JCasUtil.select(aJCas, Sentence.class);
    }

    /**
     * Combines the list of vectors into a single one by summation, followed by averaging
     * (by default)
     *
     * @param vectors vectors
     * @return vector
     */
    protected DenseVector createFinalVector(List<Vector> vectors) {
        DenseVector result = new DenseVector(VECTOR_SIZE);

        for (Vector v : vectors) {
            result.add(v);
        }

        // averaging; guard against an empty vector list to avoid dividing by zero
        if (vectorAveraging && !vectors.isEmpty()) {
            result.scale(1.0 / (double) vectors.size());
        }

        return result;
    }

    @Override
    public void process(JCas aJCas) throws AnalysisEngineProcessException {
        // process each annotation (sentence, etc.)
        for (Annotation annotation : selectAnnotationsForEmbeddings(aJCas)) {
            // get tfidf values for all tokens
            LinkedHashMap<String, Double> tokenTfIdf = new LinkedHashMap<>();

            for (Token t : JCasUtil.selectCovered(Token.class, annotation)) {
                String coveredText = t.getCoveredText();
                // only consider reasonably short tokens (consistent with getEmbeddings)
                if (coveredText.length() < 50) {
                    // retrieve the tf-idf value
                    List<Tfidf> tfidfs = JCasUtil.selectCovered(Tfidf.class, t);

                    if (tfidfs.isEmpty()) {
                        throw new AnalysisEngineProcessException(
                                new IllegalStateException("Word embeddings annotations require TFIDF annotations"));
                    }

                    double tfidfValue = tfidfs.iterator().next().getTfidfValue();
                    tokenTfIdf.put(coveredText, tfidfValue);
                }
            }

            List<Vector> vectors = new ArrayList<>();

            // get list of embeddings for each token
            for (Map.Entry<String, Double> entry : tokenTfIdf.entrySet()) {
                Embedding embedding = getEmbeddings(entry.getKey());

                if (embedding != null && embedding.getVector() != null) {
                    // create a deep copy so that scaling does not modify the cached vector
                    DenseVector vector = new DenseVector(embedding.getVector());

                    // multiply by tfidf
                    double tfidf = entry.getValue();

                    if (this.tfIdfWeighting) {
                        vector.scale(tfidf);
                    }

                    vectors.add(vector);
                }
            }

            // create the final vector
            DenseVector finalVector = createFinalVector(vectors);

            // make new annotation
            Embeddings embeddings = new Embeddings(aJCas, annotation.getBegin(), annotation.getEnd());

            // copy double values
            DoubleArray doubleArray = new DoubleArray(aJCas, VECTOR_SIZE);
            doubleArray.copyFromArray(finalVector.getData(), 0, 0, VECTOR_SIZE);
            embeddings.setVector(doubleArray);

            embeddings.addToIndexes();
        }
    }
}
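
For illustration, here is a minimal, hypothetical sketch of how this annotator might be wired into a uimaFIT pipeline. It is not part of the original source: the class name EmbeddingsAnnotatorExample, the document text, and the word2vec model path are made-up placeholders, and the annotator additionally expects Sentence, Token, and Tfidf annotations to be present, so a DKPro Core segmenter and tf-idf annotator would normally run first.

import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.JCasFactory;
import org.apache.uima.fit.pipeline.SimplePipeline;
import org.apache.uima.jcas.JCas;

import de.tudarmstadt.ukp.dkpro.argumentation.sequence.annotator.EmbeddingsAnnotator;

public class EmbeddingsAnnotatorExample {
    public static void main(String[] args) throws Exception {
        // path to a binary word2vec model; adjust to your local copy (placeholder)
        String word2VecModel = "/path/to/word2vec-model.bin";

        // configure the annotator; averaging and tf-idf weighting are enabled by default
        AnalysisEngineDescription embeddings = AnalysisEngineFactory.createEngineDescription(
                EmbeddingsAnnotator.class,
                EmbeddingsAnnotator.PARAM_WORD_2_VEC_FILE, word2VecModel,
                EmbeddingsAnnotator.PARAM_TO_LOWERCASE, true,
                EmbeddingsAnnotator.PARAM_VECTOR_AVERAGING, true,
                EmbeddingsAnnotator.PARAM_TFIDF_WEIGHTING, true);

        JCas jcas = JCasFactory.createJCas();
        jcas.setDocumentText("Word embeddings make useful sentence features.");

        // the CAS must already carry Sentence, Token, and Tfidf annotations before this runs,
        // e.g. produced by a DKPro Core segmenter and tf-idf annotator:
        // ... run segmentation and tf-idf engines on jcas here ...

        SimplePipeline.runPipeline(jcas, embeddings);
    }
}

After the pipeline has run, each processed Sentence is covered by an Embeddings annotation whose DoubleArray holds the 300-dimensional sentence vector.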