// Java tutorial
/*
 * Copyright 2015
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package de.tudarmstadt.ukp.dkpro.argumentation.sequence.annotator;

import de.tudarmstadt.ukp.dkpro.argumentation.sequence.dl.Embedding;
import de.tudarmstadt.ukp.dkpro.argumentation.sequence.dl.Word2VecReader;
import de.tudarmstadt.ukp.dkpro.argumentation.type.Embeddings;
import de.tudarmstadt.ukp.dkpro.core.api.frequency.tfidf.type.Tfidf;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Vector;
import org.apache.commons.io.IOUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.cas.DoubleArray;
import org.apache.uima.jcas.tcas.Annotation;
import org.apache.uima.resource.ResourceInitializationException;

import java.io.*;
import java.net.URL;
import java.util.*;

import static de.tudarmstadt.ukp.dkpro.core.api.resources.ResourceUtils.resolveLocation;

/**
 * Annotates each selected annotation (sentences by default) with an {@link Embeddings} feature
 * structure whose vector is the (optionally tf-idf weighted, optionally averaged) combination of
 * the word2vec embeddings of its covered tokens.
 *
 * @author Ivan Habernal
 */
public class EmbeddingsAnnotator
        extends JCasAnnotator_ImplBase
{
    /**
     * Path to the word2vec model file; consulted only on cache misses.
     */
    public static final String PARAM_WORD_2_VEC_FILE = "word2VecFile";
    @ConfigurationParameter(name = PARAM_WORD_2_VEC_FILE, mandatory = true)
    protected File word2VecFile;

    /**
     * Optional location of a serialized {@code Map<String, Vector>} token-embedding cache.
     */
    public static final String PARAM_CACHE_FILE = "cacheFile";
    @ConfigurationParameter(name = PARAM_CACHE_FILE, mandatory = false)
    protected String cacheFile;

    /**
     * If true, tokens are lower-cased before the embedding lookup.
     */
    public static final String PARAM_TO_LOWERCASE = "toLowerCase";
    @ConfigurationParameter(name = PARAM_TO_LOWERCASE, mandatory = true, defaultValue = "false")
    boolean toLowerCase;

    /**
     * If true, the resulting sentence embedding vector will be averaged (default); otherwise
     * it will be just a sum of word vectors
     */
    public static final String PARAM_VECTOR_AVERAGING = "vectorAveraging";
    @ConfigurationParameter(name = PARAM_VECTOR_AVERAGING, mandatory = true, defaultValue = "true")
    boolean vectorAveraging;

    /**
     * If true, each word vector is scaled by the token's tf-idf value before summation.
     */
    public static final String PARAM_TFIDF_WEIGHTING = "tfIdfWeighting";
    @ConfigurationParameter(name = PARAM_TFIDF_WEIGHTING, mandatory = true, defaultValue = "true")
    boolean tfIdfWeighting;

    // Lazily initialized in initReader(); only needed on cache misses
    protected Word2VecReader reader;

    // token -> embedding vector; may be pre-populated from PARAM_CACHE_FILE
    protected Map<String, Vector> cache = new HashMap<>();

    // fixme make dynamical
    public static final int VECTOR_SIZE = 300;

    @Override
    public void initialize(UimaContext context)
            throws ResourceInitializationException
    {
        super.initialize(context);

        try {
            if (cacheFile != null) {
                loadCache();
            }
        }
        catch (IOException e) {
            throw new ResourceInitializationException(e);
        }
    }

    /**
     * Opens the word2vec model on first use; subsequent calls are no-ops.
     *
     * @throws IOException if the model file cannot be read
     */
    protected void initReader()
            throws IOException
    {
        if (reader == null) {
            reader = new Word2VecReader(word2VecFile, true);
        }
    }

    /**
     * Loads the serialized token-embedding cache from {@link #cacheFile}.
     * <p>
     * NOTE(review): this uses Java native deserialization; only load cache files from trusted
     * locations.
     *
     * @throws IOException if reading or deserializing the cache fails
     */
    @SuppressWarnings("unchecked")
    protected void loadCache()
            throws IOException
    {
        URL source = resolveLocation(cacheFile);

        // try-with-resources closes the stream on all paths (the previous version leaked the
        // underlying InputStream and skipped closing on exceptions)
        try (ObjectInputStream ois = new ObjectInputStream(
                new BufferedInputStream(source.openStream()))) {
            this.cache = (Map<String, Vector>) ois.readObject();
        }
        catch (ClassNotFoundException e) {
            throw new IOException(e);
        }
    }

    /**
     * Looks up the embedding for a token, first in the cache, then (for reasonably short tokens)
     * in the word2vec model directly.
     *
     * @param t raw token text
     * @return an {@link Embedding}; its vector may be {@code null} if the token is unknown
     */
    protected Embedding getEmbeddings(String t)
    {
        String token = preprocessToken(t);

        if (cache.containsKey(token)) {
            return new Embedding(token, cache.get(token));
        }

        System.err.println("Word " + token + " not cached; maybe you forgot to run "
                + EmbeddingsCachePreprocessor.class + " to prepare the cache?");

        // very long "tokens" are almost certainly noise; skip the expensive model lookup
        if (token.length() < 50) {
            Embedding[] embeddings;
            try {
                initReader();
                embeddings = reader.getEmbeddings(token);
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }

            // guard against an empty result (the original indexed [0] unconditionally)
            if (embeddings != null && embeddings.length > 0) {
                return embeddings[0];
            }
        }

        return new Embedding(token, null);
    }

    /**
     * Normalizes the token, e.g. lower casing (no change by default)
     *
     * @param token token
     * @return token
     */
    protected String preprocessToken(String token)
    {
        if (toLowerCase) {
            return token.toLowerCase();
        }

        return token;
    }

    /**
     * Selects the annotations over which embeddings are computed; sentences by default.
     *
     * @param aJCas the CAS
     * @return annotations to process
     */
    protected Collection<? extends Annotation> selectAnnotationsForEmbeddings(JCas aJCas)
    {
        return JCasUtil.select(aJCas, Sentence.class);
    }

    /**
     * Combines the list of vectors into a single one, by summation and averaging (by default)
     *
     * @param vectors vectors
     * @return vector
     */
    protected DenseVector createFinalVector(List<Vector> vectors)
    {
        DenseVector result = new DenseVector(VECTOR_SIZE);
        for (Vector v : vectors) {
            result.add(v);
        }

        // averaging; guard against division by zero (an empty list previously produced NaNs)
        if (vectorAveraging && !vectors.isEmpty()) {
            result.scale(1.0 / (double) vectors.size());
        }

        return result;
    }

    @Override
    public void process(JCas aJCas)
            throws AnalysisEngineProcessException
    {
        // process each annotation (sentence, etc.)
        for (Annotation annotation : selectAnnotationsForEmbeddings(aJCas)) {
            // get tfidf values for all tokens
            LinkedHashMap<String, Double> tokenTfIdf = new LinkedHashMap<>();
            for (Token t : JCasUtil.selectCovered(Token.class, annotation)) {
                String coveredText = t.getCoveredText();

                // ignore pathologically long "tokens"
                if (coveredText.length() < 50) {
                    // retrieve tfidf value
                    List<Tfidf> tfidfs = JCasUtil.selectCovered(Tfidf.class, t);
                    if (tfidfs.isEmpty()) {
                        throw new AnalysisEngineProcessException(new IllegalStateException(
                                "Word embeddings annotations require TFIDF annotations"));
                    }
                    double tfidfValue = tfidfs.iterator().next().getTfidfValue();

                    tokenTfIdf.put(coveredText, tfidfValue);
                }
            }

            List<Vector> vectors = new ArrayList<>();

            // get list of embeddings for each token
            for (Map.Entry<String, Double> entry : tokenTfIdf.entrySet()) {
                Embedding embedding = getEmbeddings(entry.getKey());

                if (embedding != null && embedding.getVector() != null) {
                    // create a deep copy!!!!!!
                    DenseVector vector = new DenseVector(embedding.getVector());

                    // multiply by tfidf
                    double tfidf = entry.getValue();
                    if (this.tfIdfWeighting) {
                        vector.scale(tfidf);
                    }

                    vectors.add(vector);
                }
            }

            // create the final vector
            DenseVector finalVector = createFinalVector(vectors);

            // make new annotation
            Embeddings embeddings = new Embeddings(aJCas, annotation.getBegin(),
                    annotation.getEnd());

            // copy double values
            DoubleArray doubleArray = new DoubleArray(aJCas, VECTOR_SIZE);
            doubleArray.copyFromArray(finalVector.getData(), 0, 0, VECTOR_SIZE);
            embeddings.setVector(doubleArray);

            embeddings.addToIndexes();
        }
    }
}