de.tudarmstadt.ukp.dkpro.core.mallet.lda.MalletLdaTopicModelInferencer.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.dkpro.core.mallet.lda.MalletLdaTopicModelInferencer.java

Source

/*
 * Copyright 2014
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package de.tudarmstadt.ukp.dkpro.core.mallet.lda;

import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.topics.TopicInferencer;
import cc.mallet.types.Instance;
import cc.mallet.types.TokenSequence;
import de.tudarmstadt.ukp.dkpro.core.api.featurepath.FeaturePathException;
import de.tudarmstadt.ukp.dkpro.core.api.io.sequencegenerator.PhraseSequenceGenerator;
import de.tudarmstadt.ukp.dkpro.core.api.io.sequencegenerator.StringSequenceGenerator;
import de.tudarmstadt.ukp.dkpro.core.api.metadata.type.DocumentMetaData;
import de.tudarmstadt.ukp.dkpro.core.api.parameter.ComponentParameters;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;
import de.tudarmstadt.ukp.dkpro.core.mallet.MalletModelTrainer;
import de.tudarmstadt.ukp.dkpro.core.mallet.type.TopicDistribution;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.descriptor.TypeCapability;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.cas.DoubleArray;
import org.apache.uima.jcas.cas.IntegerArray;
import org.apache.uima.resource.ResourceInitializationException;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Infers the topic distribution over documents using a Mallet {@link ParallelTopicModel}.
 */
@TypeCapability(inputs = { "de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token" }, outputs = {
        "de.tudarmstadt.ukp.dkpro.core.mallet.type.TopicDistribution" })

public class MalletLdaTopicModelInferencer extends JCasAnnotator_ImplBase {
    private static final String NONE_LABEL = "X";

    public final static String PARAM_MODEL_LOCATION = ComponentParameters.PARAM_MODEL_LOCATION;
    @ConfigurationParameter(name = PARAM_MODEL_LOCATION, mandatory = true)
    private File modelLocation;

    /**
     * The annotation type to use as tokens. Default: {@link Token}
     */
    public final static String PARAM_TYPE_NAME = "typeName";
    @ConfigurationParameter(name = PARAM_TYPE_NAME, mandatory = true, defaultValue = "de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token")
    private String typeName;

    /**
     * The number of iterations during inference. Default: 100.
     */
    public final static String PARAM_N_ITERATIONS = "nIterations";
    @ConfigurationParameter(name = PARAM_N_ITERATIONS, mandatory = true, defaultValue = "100")
    private int nIterations;

    /**
     * The number of iterations before hyperparameter optimization begins. Default: 1
     */
    public final static String PARAM_BURN_IN = "burnIn";
    @ConfigurationParameter(name = PARAM_BURN_IN, mandatory = true, defaultValue = "1")
    private int burnIn;

    public final static String PARAM_THINNING = "thinning";
    @ConfigurationParameter(name = PARAM_THINNING, mandatory = true, defaultValue = "5")
    private int thinning;

    /**
     * Minimum topic proportion for the document-topic assignment.
     */
    public final static String PARAM_MIN_TOPIC_PROB = "minTopicProb";
    @ConfigurationParameter(name = PARAM_MIN_TOPIC_PROB, mandatory = true, defaultValue = "0.2")
    private double minTopicProb;

    /**
     * Maximum number of topics to assign. If not set (or &lt;= 0), the number of topics in the
     * model divided by 10 is set.
     */
    public final static String PARAM_MAX_TOPIC_ASSIGNMENTS = "maxTopicAssignments";
    @ConfigurationParameter(name = PARAM_MAX_TOPIC_ASSIGNMENTS, mandatory = true, defaultValue = "0")
    private int maxTopicAssignments;

    /**
     * The annotation type to use for the model. Default: {@code de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token}.
     * For lemmas, use {@code de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token/lemma/value}
     */
    public static final String PARAM_TOKEN_FEATURE_PATH = MalletModelTrainer.PARAM_TOKEN_FEATURE_PATH;
    @ConfigurationParameter(name = PARAM_TOKEN_FEATURE_PATH, mandatory = true, defaultValue = "de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token")
    private String tokenFeaturePath;
    /**
     * Ignore tokens (or lemmas, respectively) that are shorter than the given value. Default: 3.
     */
    public static final String PARAM_MIN_TOKEN_LENGTH = "minTokenLength";
    @ConfigurationParameter(name = PARAM_MIN_TOKEN_LENGTH, mandatory = true, defaultValue = "3")
    private int minTokenLength;

    /**
     * If set to true (default: false), all tokens are lowercased.
     */
    public static final String PARAM_LOWERCASE = "lowercase";
    @ConfigurationParameter(name = PARAM_LOWERCASE, mandatory = true, defaultValue = "false")
    private boolean lowercase;

    private TopicInferencer inferencer;
    private Pipe malletPipe;
    private StringSequenceGenerator sequenceGenerator;

    @Override
    public void initialize(UimaContext context) throws ResourceInitializationException {
        super.initialize(context);

        ParallelTopicModel model;
        try {
            getLogger().info("Loading model file " + modelLocation);
            model = ParallelTopicModel.read(modelLocation);

            if (maxTopicAssignments <= 0) {
                maxTopicAssignments = model.getNumTopics() / 10;
            }
        } catch (Exception e) {
            throw new ResourceInitializationException(e);
        }
        getLogger().info("Model loaded.");

        inferencer = model.getInferencer();
        malletPipe = new TokenSequence2FeatureSequence(model.getAlphabet());
        try {
            sequenceGenerator = new PhraseSequenceGenerator.Builder().featurePath(tokenFeaturePath)
                    .minTokenLength(minTokenLength).lowercase(lowercase).buildStringSequenceGenerator();
        } catch (IOException e) {
            throw new ResourceInitializationException(e);
        }
    }

    @Override
    public void process(JCas aJCas) throws AnalysisEngineProcessException {
        try {
            List<String[]> tokenSequences = sequenceGenerator.tokenSequences(aJCas);

            if (tokenSequences.isEmpty()) {
                getLogger().warn("Empty document.");
            } else {
                DocumentMetaData metadata = DocumentMetaData.get(aJCas);

                /* create Mallet Instance */
                TokenSequence ts = new TokenSequence(tokenSequences.get(0));
                Instance instance = new Instance(ts, NONE_LABEL, metadata.getDocumentId(),
                        metadata.getDocumentUri());

                /* infer topic distribution across document */
                TopicDistribution topicDistributionAnnotation = new TopicDistribution(aJCas);
                double[] topicDistribution = inferencer.getSampledDistribution(malletPipe.instanceFrom(instance),
                        nIterations, thinning, burnIn);

                /* convert data type (Mallet output -> UIMA double array) */
                DoubleArray da = new DoubleArray(aJCas, topicDistribution.length);
                da.copyFromArray(topicDistribution, 0, 0, topicDistribution.length);
                topicDistributionAnnotation.setTopicProportions(da);

                /* assign topics to document according to topic distribution */
                int[] assignedTopicIndexes = assignTopics(topicDistribution);
                IntegerArray topicIndexes = new IntegerArray(aJCas, assignedTopicIndexes.length);
                topicIndexes.copyFromArray(assignedTopicIndexes, 0, 0, assignedTopicIndexes.length);
                topicDistributionAnnotation.setTopicAssignment(topicIndexes);

                aJCas.addFsToIndexes(topicDistributionAnnotation);
            }
        } catch (FeaturePathException e) {
            throw new AnalysisEngineProcessException(e);
        }
    }

    /**
     * Assign topics according to the following formula:
     * <p>
     * Topic proportion must be at least the maximum topic's proportion divided by the maximum
     * number of topics to be assigned. In addition, the topic proportion must not lie under the
     * minTopicProb. If more topics comply with these criteria, only retain the n
     * (maxTopicAssignments) largest values.
     *
     * @param topicDistribution a double array containing the document's topic proportions
     * @return an array of integers pointing to the topics assigned to the document
     * @deprecated this method should be removed at some point because assignment / topic tagging
     * should be done in a dedicated step (module).
     */
    // TODO: should return a boolean[] of the same size as topicDistribution
    // TODO: should probably be moved to a dedicated module because assignments (topic tagging)
    // should not be done at inference level
    @Deprecated
    private int[] assignTopics(final double[] topicDistribution) {
        /*
         * threshold is the largest value divided by the maximum number of topics or the fixed
         * number set as minTopicProb parameter.
         */
        double threshold = Math
                .max(Collections.max(Arrays.asList(ArrayUtils.toObject(topicDistribution))).doubleValue()
                        / maxTopicAssignments, minTopicProb);

        /*
         * assign indexes for values that are above threshold
         */
        List<Integer> indexes = new ArrayList<>(topicDistribution.length);
        for (int i = 0; i < topicDistribution.length; i++) {
            if (topicDistribution[i] >= threshold) {
                indexes.add(i);
            }
        }

        /*
         * Reduce assignments to maximum number of allowed assignments.
         */
        if (indexes.size() > maxTopicAssignments) {

            /* sort index list by corresponding values */
            Collections.sort(indexes, (aO1, aO2) -> Double.compare(topicDistribution[aO1], topicDistribution[aO2]));

            while (indexes.size() > maxTopicAssignments) {
                indexes.remove(0);
            }
        }

        return ArrayUtils.toPrimitive(indexes.toArray(new Integer[indexes.size()]));
    }
}