de.tudarmstadt.ukp.experiments.argumentation.comments.pipeline.ExtendedMalletTopicModelEstimator.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.experiments.argumentation.comments.pipeline.ExtendedMalletTopicModelEstimator.java

Source

/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package de.tudarmstadt.ukp.experiments.argumentation.comments.pipeline;

import cc.mallet.types.TokenSequence;
import de.tudarmstadt.ukp.dkpro.core.api.featurepath.FeaturePathException;
import de.tudarmstadt.ukp.dkpro.core.api.featurepath.FeaturePathFactory;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Lemma;
import de.tudarmstadt.ukp.dkpro.core.mallet.topicmodel.MalletTopicModelEstimator;
import org.apache.commons.io.IOUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.text.AnnotationFS;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Map;
import java.util.Set;

import static org.apache.uima.fit.util.JCasUtil.selectCovered;

/**
 * @author Ivan Habernal
 */
public class ExtendedMalletTopicModelEstimator extends MalletTopicModelEstimator {
    /**
     * Name of the mandatory configuration parameter pointing to the vocabulary file, which must
     * contain a Java-serialized {@code Map<String, Integer>} (see {@link #readVocabulary(File)}).
     */
    public static final String PARAM_VOCABULARY_FILE = "vocabularyFile";
    @ConfigurationParameter(name = PARAM_VOCABULARY_FILE, mandatory = true)
    File vocabularyFile;

    /**
     * Known words, loaded in {@link #initialize(UimaContext)}; only token values contained in
     * this set are added to the token sequence.
     */
    Set<String> vocabulary;

    /**
     * Loads a serialized vocabulary ({@code Map<String, Integer>}) from file and returns its
     * key set.
     *
     * <p>NOTE(review): this relies on Java native deserialization, which must only be used on
     * trusted files; consider an {@code ObjectInputFilter} or a plain-text format if the file
     * could ever come from an untrusted source.
     *
     * @param vocabularyFile file containing a Java-serialized {@code Map<String, Integer>}
     * @return vocabulary entries (the deserialized map's key set)
     * @throws IOException            if the file cannot be read, or does not contain a
     *                                {@code Map}
     * @throws ClassNotFoundException if the class of the serialized object cannot be resolved
     */
    @SuppressWarnings("unchecked")
    public static Set<String> readVocabulary(File vocabularyFile)
            throws IOException, ClassNotFoundException {
        Object deserialized;
        // Both streams are resources so they are closed even when the ObjectInputStream
        // constructor (bad header) or readObject() throws; closing twice is harmless.
        try (FileInputStream fileStream = new FileInputStream(vocabularyFile);
                ObjectInputStream objectStream = new ObjectInputStream(fileStream)) {
            deserialized = objectStream.readObject();
        }

        // Report an unexpected payload as an IOException (which initialize() wraps) instead of
        // an uncaught ClassCastException / NullPointerException.
        if (!(deserialized instanceof Map)) {
            throw new IOException("Expected a serialized java.util.Map in " + vocabularyFile
                    + " but found "
                    + (deserialized == null ? "null" : deserialized.getClass().getName()));
        }

        return ((Map<String, Integer>) deserialized).keySet();
    }

    /**
     * Initializes the parent estimator and then loads the vocabulary from
     * {@link #PARAM_VOCABULARY_FILE}.
     *
     * @param context the UIMA context
     * @throws ResourceInitializationException if the parent initialization fails or the
     *                                         vocabulary cannot be loaded
     */
    @Override
    public void initialize(UimaContext context) throws ResourceInitializationException {
        super.initialize(context);

        try {
            vocabulary = readVocabulary(vocabularyFile);
        } catch (IOException | ClassNotFoundException e) {
            throw new ResourceInitializationException(e);
        }
    }

    /**
     * Builds the token sequence for one CAS, keeping only values found in the vocabulary.
     *
     * <p>For every annotation selected by {@code PARAM_TYPE_NAME}: if {@code PARAM_USE_LEMMA}
     * is set, the value is the last covered {@link Lemma} whose length is at least
     * {@code PARAM_MIN_TOKEN_LENGTH}; otherwise it is the feature-path value itself, subject to
     * the same length threshold. At most one value per annotation is added, and only if it is
     * contained in {@link #vocabulary}.
     *
     * @param aJCas the CAS to read annotations from
     * @return the filtered token sequence
     * @throws AnalysisEngineProcessException if the feature path cannot be resolved
     */
    @Override
    protected TokenSequence generateTokenSequence(JCas aJCas) throws AnalysisEngineProcessException {
        // These parameters are invariant while processing a CAS, so look them up once instead
        // of once per annotation / lemma.
        String annotationType = (String) this.getContext()
                .getConfigParameterValue(PARAM_TYPE_NAME);
        boolean lemmatize = (boolean) this.getContext().getConfigParameterValue(PARAM_USE_LEMMA);
        int minLength = (int) this.getContext().getConfigParameterValue(PARAM_MIN_TOKEN_LENGTH);

        TokenSequence tokenStream = new TokenSequence();
        try {
            for (Map.Entry<AnnotationFS, String> entry : FeaturePathFactory.select(aJCas.getCas(),
                    annotationType)) {
                String value = null;

                if (lemmatize) {
                    // if several lemmas qualify, the last one wins (one value per annotation)
                    for (Lemma lemma : selectCovered(Lemma.class, entry.getKey())) {
                        String text = lemma.getValue();
                        if (text.length() >= minLength) {
                            value = text;
                        }
                    }
                } else {
                    String text = entry.getValue();
                    if (text.length() >= minLength) {
                        value = text;
                    }
                }

                // do the filtering; add only known words from the vocabulary
                if (value != null && vocabulary.contains(value)) {
                    tokenStream.add(value);
                }
            }
        } catch (FeaturePathException e) {
            throw new AnalysisEngineProcessException(e);
        }
        return tokenStream;
    }
}