// Java tutorial
/*******************************************************************************
 * Copyright 2014
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/
package de.tudarmstadt.ukp.dkpro.core.mallet.topicmodel;

import static org.apache.uima.fit.util.JCasUtil.selectCovered;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map.Entry;

import org.apache.commons.lang.ArrayUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.text.AnnotationFS;
import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.cas.DoubleArray;
import org.apache.uima.jcas.cas.IntegerArray;
import org.apache.uima.resource.ResourceInitializationException;

import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.topics.TopicInferencer;
import cc.mallet.types.Instance;
import cc.mallet.types.TokenSequence;
import de.tudarmstadt.ukp.dkpro.core.api.featurepath.FeaturePathException;
import de.tudarmstadt.ukp.dkpro.core.api.featurepath.FeaturePathFactory;
import de.tudarmstadt.ukp.dkpro.core.api.metadata.type.DocumentMetaData; import de.tudarmstadt.ukp.dkpro.core.api.parameter.ComponentParameters; import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Lemma; import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token; import de.tudarmstadt.ukp.dkpro.core.mallet.type.TopicDistribution; /** * Infers the topic distribution over documents using a Mallet {@link ParallelTopicModel}. * * @author Carsten Schnober */ public class MalletTopicModelInferencer extends JCasAnnotator_ImplBase { private static final String NONE_LABEL = "X"; public final static String PARAM_MODEL_LOCATION = ComponentParameters.PARAM_MODEL_LOCATION; @ConfigurationParameter(name = PARAM_MODEL_LOCATION, mandatory = true) private File modelLocation; /** * The annotation type to use as tokens. Default: {@link Token} */ public final static String PARAM_TYPE_NAME = "typeName"; @ConfigurationParameter(name = PARAM_TYPE_NAME, mandatory = true, defaultValue = "de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token") private String typeName; /** * The number of iterations during inference. Default: 10. */ public final static String PARAM_N_ITERATIONS = "nIterations"; @ConfigurationParameter(name = PARAM_N_ITERATIONS, mandatory = true, defaultValue = "10") private int nIterations; /** * The number of iterations before hyperparameter optimization begins. Default: 1 */ public final static String PARAM_BURN_IN = "burnIn"; @ConfigurationParameter(name = PARAM_BURN_IN, mandatory = true, defaultValue = "1") private int burnIn; public final static String PARAM_THINNING = "thinning"; @ConfigurationParameter(name = PARAM_THINNING, mandatory = true, defaultValue = "5") private int thinning; /** * Minimum topic proportion for the document-topic assignment. 
*/ public final static String PARAM_MIN_TOPIC_PROB = "minTopicProb"; @ConfigurationParameter(name = PARAM_MIN_TOPIC_PROB, mandatory = true, defaultValue = "0.2") private double minTopicProb; /** * Maximum number of topics to assign. If not set (or <= 0), the number of topics in the model * divided by 10 is set. */ public final static String PARAM_MAX_TOPIC_ASSIGNMENTS = "maxTopicAssignments"; @ConfigurationParameter(name = PARAM_MAX_TOPIC_ASSIGNMENTS, mandatory = true, defaultValue = "0") private int maxTopicAssignments; /** * If set, uses lemmas instead of original text as features. */ public static final String PARAM_USE_LEMMA = "useLemma"; @ConfigurationParameter(name = PARAM_USE_LEMMA, mandatory = true, defaultValue = "false") private boolean useLemma; /** * Ignore tokens (or lemmas, respectively) that are shorter than the given value. Default: 3. */ public static final String PARAM_MIN_TOKEN_LENGTH = "minTokenLength"; @ConfigurationParameter(name = PARAM_MIN_TOKEN_LENGTH, mandatory = true, defaultValue = "3") private int minTokenLength; private TopicInferencer inferencer; private Pipe malletPipe; @Override public void initialize(UimaContext context) throws ResourceInitializationException { super.initialize(context); try { ParallelTopicModel model = ParallelTopicModel.read(modelLocation); inferencer = model.getInferencer(); if (maxTopicAssignments <= 0) { maxTopicAssignments = model.getNumTopics() / 10; } } catch (Exception e) { throw new ResourceInitializationException(e); } malletPipe = new TokenSequence2FeatureSequence(); }; @Override public void process(JCas aJCas) throws AnalysisEngineProcessException { /* convert tokens (or other annotation type) into a Mallet TokenSequence */ TokenSequence tokenStream = new TokenSequence(); try { for (Entry<AnnotationFS, String> entry : FeaturePathFactory.select(aJCas.getCas(), typeName)) { if (useLemma) { for (Lemma lemma : selectCovered(Lemma.class, entry.getKey())) { String text = lemma.getValue(); if (text.length() 
>= minTokenLength) { tokenStream.add(text); } } } else { String text = entry.getValue(); if (text.length() >= minTokenLength) { tokenStream.add(text); } } } } catch (FeaturePathException e) { throw new AnalysisEngineProcessException(e); } /* create Mallet Instance */ DocumentMetaData metadata = DocumentMetaData.get(aJCas); Instance instance = new Instance(tokenStream, NONE_LABEL, metadata.getDocumentId(), metadata.getDocumentUri()); /* infer topic distribution across document */ TopicDistribution topicDistributionAnnotation = new TopicDistribution(aJCas); double[] topicDistribution = inferencer.getSampledDistribution(malletPipe.instanceFrom(instance), nIterations, thinning, burnIn); /* convert data type */ DoubleArray da = new DoubleArray(aJCas, topicDistribution.length); da.copyFromArray(topicDistribution, 0, 0, topicDistribution.length); topicDistributionAnnotation.setTopicProportions(da); /* assign topics to document according to topic distribution */ int[] assignedTopicIndexes = assignTopics(topicDistribution); IntegerArray topicIndexes = new IntegerArray(aJCas, assignedTopicIndexes.length); topicIndexes.copyFromArray(assignedTopicIndexes, 0, 0, assignedTopicIndexes.length); topicDistributionAnnotation.setTopicAssignment(topicIndexes); aJCas.addFsToIndexes(topicDistributionAnnotation); } /** * Assign topics according to the following formula: * <p> * Topic proportion must be at least the maximum topic's proportion divided by the maximum * number of topics to be assigned. In addition, the topic proportion must not lie under the * minTopicProb. If more topics comply with these criteria, only retain the n * (maxTopicAssignments) largest values. 
* * @param topicDistribution * a double array containing the document's topic proportions * @return an array of integers pointing to the topics assigned to the document */ private int[] assignTopics(final double[] topicDistribution) { /* * threshold is the largest value divided by the maximum number of topics or the fixed * number set as minTopicProb parameter. */ double threshold = Math.max( Collections.max(Arrays.asList(ArrayUtils.toObject(topicDistribution))) / maxTopicAssignments, minTopicProb); /* * assign indexes for values that are above threshold */ List<Integer> indexes = new ArrayList<>(topicDistribution.length); for (int i = 0; i < topicDistribution.length; i++) { if (topicDistribution[i] >= threshold) { indexes.add(i); } } /* * Reduce assignments to maximum number of allowed assignments. */ if (indexes.size() > maxTopicAssignments) { /* sort index list by corresponding values */ Collections.sort(indexes, new Comparator<Integer>() { @Override public int compare(Integer aO1, Integer aO2) { return Double.compare(topicDistribution[aO1], topicDistribution[aO2]); } }); while (indexes.size() > maxTopicAssignments) { indexes.remove(0); } } return ArrayUtils.toPrimitive(indexes.toArray(new Integer[indexes.size()])); } }