de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.lda.LDATopicsFeature.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.lda.LDATopicsFeature.java

Source

/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.lda;

import cc.mallet.pipe.*;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.topics.TopicInferencer;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.AbstractUnitSentenceFeatureGenerator;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;
import de.tudarmstadt.ukp.dkpro.tc.api.exception.TextClassificationException;
import de.tudarmstadt.ukp.dkpro.tc.api.features.Feature;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.resource.ResourceSpecifier;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import static de.tudarmstadt.ukp.dkpro.core.api.resources.ResourceUtils.resolveLocation;

/**
 * Use pre-trained LDA model with k topics created externally on a large corpus; extract k features,
 * each represent weight of k-th topic on the document. The topic distribution on document is
 * computed using Gibbs sampling (see
 * {@link TopicInferencer#getSampledDistribution(Instance, int, int, int)}).
 * <p/>
 * This extractor is based on lemma; thus both the pre-trained LDA and the annotated jCas must
 * contain lemmas.
 *
 * @author Ivan Habernal
 */
public class LDATopicsFeature extends AbstractUnitSentenceFeatureGenerator {
    public static final String PARAM_LDA_MODEL_FILE = "ldaModelFile";

    /**
     * Number of Gibbs-sampling iterations used when inferring a document's topic
     * distribution. Kept as a mutable public field for backward compatibility with
     * existing callers that tune it externally.
     */
    public static int inferIteration = 20;

    /** Number of top words per topic used to build a human-readable feature name. */
    private static final int TOP_WORDS_FOR_NAMING = 10;

    @ConfigurationParameter(name = PARAM_LDA_MODEL_FILE, mandatory = true)
    private String ldaModelFile;

    /** Pre-trained LDA model deserialized from {@link #ldaModelFile}. */
    protected ParallelTopicModel model;

    /** Mallet preprocessing pipeline whose alphabet is aligned with {@link #model}. */
    private SerialPipes pipes;

    /** topWords[k] holds the top words of topic k; used only for feature naming. */
    private Object[][] topWords;

    /**
     * Loads the serialized {@link ParallelTopicModel}, builds the preprocessing
     * pipeline, aligns the pipeline's alphabet with the model's, and caches the
     * per-topic top words for feature naming.
     *
     * @return {@code false} if the superclass initialization fails, {@code true} otherwise
     * @throws ResourceInitializationException if the model cannot be read or deserialized
     */
    @Override
    public boolean initialize(ResourceSpecifier aSpecifier, Map<String, Object> aAdditionalParams)
            throws ResourceInitializationException {
        if (!super.initialize(aSpecifier, aAdditionalParams)) {
            return false;
        }

        try {
            URL source = resolveLocation(ldaModelFile);

            // try-with-resources closes the stream even when readObject() throws;
            // the previous code leaked the ObjectInputStream (and the underlying
            // stream on the failure path, since closeQuietly was only reached on success)
            try (ObjectInputStream ois = new ObjectInputStream(source.openStream())) {
                this.model = (ParallelTopicModel) ois.readObject();
            }

            this.pipes = createPipes();
            this.pipes.setDataAlphabet(model.getAlphabet());

            // every pipe must share the model's alphabet so that inference maps
            // tokens to the same feature indices the model was trained with
            for (Pipe pipe : this.pipes.pipes()) {
                pipe.setDataAlphabet(this.model.getAlphabet());
            }

            // cache top words per topic; they are joined into the feature names
            topWords = this.model.getTopWords(TOP_WORDS_FOR_NAMING);
        } catch (IOException | ClassNotFoundException ex) {
            throw new ResourceInitializationException(ex);
        }

        return true;
    }

    /**
     * Creates the Mallet preprocessing pipeline: lowercase, tokenize (letters with
     * optional interior punctuation), remove stopwords, and map tokens to feature
     * indices. Must match the preprocessing used when the LDA model was trained.
     *
     * @return a new {@link SerialPipes} instance
     */
    protected static SerialPipes createPipes() {
        ArrayList<Pipe> pipeList = new ArrayList<>();

        // Pipes: lowercase, tokenize, remove stopwords, map to features
        pipeList.add(new CharSequenceLowercase());
        pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}[\\p{L}\\p{P}]+\\p{L}")));
        pipeList.add(new TokenSequenceRemoveStopwords());
        pipeList.add(new TokenSequence2FeatureSequence());
        return new SerialPipes(pipeList);
    }

    /**
     * Infers the topic distribution of the given text using Gibbs sampling
     * ({@link TopicInferencer#getSampledDistribution(Instance, int, int, int)}).
     *
     * @param text whitespace-separated (lemmatized) text
     * @return array of length k (number of topics) with the topic weights
     */
    public double[] getVector(String text) {
        InstanceList instances = new InstanceList(this.pipes);
        // target/name/source values are placeholders; only the data field matters here
        instances.addThruPipe(new Instance(text, "X", "doc1", null));

        TopicInferencer inference = model.getInferencer();
        // thinning = 1, burn-in = 5 samples before recording
        return inference.getSampledDistribution(instances.get(0), inferIteration, 1, 5);
    }

    /**
     * Extracts one feature per topic for the given sentence: the feature value is
     * the inferred topic weight and the feature name encodes the topic's top words.
     * Requires the jCas to contain {@link Token}s with lemma annotations.
     *
     * @param jCas           the annotated document
     * @param sentence       the sentence to extract features for
     * @param sentencePrefix prefix prepended to each feature name
     * @return one {@link Feature} per topic
     * @throws TextClassificationException on feature-extraction failure
     */
    @Override
    protected List<Feature> extract(JCas jCas, Sentence sentence, String sentencePrefix)
            throws TextClassificationException {

        // concatenate the sentence's lemmas; the pipeline lowercases again, but the
        // explicit lowercase keeps the behavior of the original implementation
        StringBuilder sb = new StringBuilder();
        for (Token token : JCasUtil.selectCovered(Token.class, sentence)) {
            String lemma = token.getLemma().getValue().toLowerCase();

            sb.append(lemma);
            sb.append(" ");
        }

        double[] topicDistribution = getVector(sb.toString());

        List<Feature> features = new ArrayList<>(topicDistribution.length);
        for (int i = 0; i < topicDistribution.length; i++) {
            double value = topicDistribution[i];
            String name = getKthTopicDescription(i);

            features.add(new Feature(sentencePrefix + name, value));
        }

        return features;
    }

    /**
     * Builds a human-readable description of topic {@code k} by joining its top
     * words with underscores, e.g. {@code word1_word2_..._word10}.
     *
     * @param k topic index (0-based)
     * @return underscore-joined top words of the topic
     */
    private String getKthTopicDescription(int k) {
        Object[] words = this.topWords[k];
        return StringUtils.join(words, "_");
    }

}