de.tudarmstadt.ukp.dkpro.core.mallet.lda.MalletLdaTopicModelUtils.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.dkpro.core.mallet.lda.MalletLdaTopicModelUtils.java

Source

/*
 * Copyright 2015
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package de.tudarmstadt.ukp.dkpro.core.mallet.lda;

import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Alphabet;
import cc.mallet.types.IDSorter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import java.io.File;
import java.io.IOException;
import java.util.*;

/**
 * Static helpers for inspecting a serialized Mallet {@link ParallelTopicModel}: extracting the
 * top words of every topic as maps, or writing them to a file.
 */
public class MalletLdaTopicModelUtils {
    private static final Log LOGGER = LogFactory.getLog(MalletLdaTopicModelUtils.class);

    /**
     * Retrieve the top n topic words for each topic in the given model.
     * <p>
     * Every returned map iterates its entries in the same order in which the model's sorted word
     * sets yield them (sorted by weight), so the ranking of the words is preserved.
     *
     * @param modelFile
     *            the model file
     * @param nWords
     *            the maximum number of words to retrieve per topic; {@code 0} yields empty maps
     * @param normalize
     *            if {@code true}, each word weight is divided by the size of the model's alphabet;
     *            otherwise the raw weight is returned
     *
     * @return a list of maps where each map represents a topic, mapping words to weights
     * @throws IOException
     *             if the model cannot be read
     * @throws IllegalArgumentException
     *             if {@code nWords} is negative
     */
    public static List<Map<String, Double>> getTopWords(File modelFile, int nWords, boolean normalize)
            throws IOException {
        /* fail before the (potentially expensive) model load rather than after it */
        if (nWords < 0) {
            throw new IllegalArgumentException("nWords must not be negative: " + nWords);
        }

        LOGGER.info("Reading model file " + modelFile + "...");
        ParallelTopicModel model = readModel(modelFile);
        Alphabet alphabet = model.getAlphabet();
        int alphabetSize = alphabet.size();

        /*
         * A topic cannot contribute more words than the alphabet holds, so cap the pre-allocation;
         * otherwise a very large nWords ("give me everything") would allocate a huge table.
         */
        int initialCapacity = Math.min(nWords, alphabetSize);

        List<Map<String, Double>> topics = new ArrayList<>(model.getNumTopics());

        /* iterate over topics */
        for (TreeSet<IDSorter> topic : model.getSortedWords()) {
            /* LinkedHashMap keeps the insertion order, i.e. the weight ranking of the words */
            Map<String, Double> topicWords = new LinkedHashMap<>(initialCapacity);

            /* iterate over word IDs in topic (sorted by weight) */
            for (IDSorter id : topic) {
                /* check the limit before inserting so that nWords == 0 yields no words at all */
                if (topicWords.size() >= nWords) {
                    break; // go to next topic
                }
                double weight = normalize ? id.getWeight() / alphabetSize : id.getWeight();
                String word = (String) alphabet.lookupObject(id.getID());

                topicWords.put(word, weight);
            }
            topics.add(topicWords);
        }
        return topics;
    }

    /**
     * Print the top n words of each topic into a file.
     *
     * @param modelFile
     *            the model file
     * @param targetFile
     *            the file in which the topic words are written
     * @param nWords
     *            the number of words to extract
     * @throws IOException
     *             if the model file cannot be read or if the target file cannot be written
     */
    public static void printTopicWords(File modelFile, File targetFile, int nWords) throws IOException {
        boolean newLineAfterEachWord = false;

        ParallelTopicModel model = readModel(modelFile);
        model.printTopWords(targetFile, nWords, newLineAfterEachWord);
    }

    /**
     * Load a topic model from disk, reporting any failure as an {@link IOException} (with the
     * original exception as its cause) so that callers deal with a single checked type.
     *
     * @param modelFile
     *            the model file
     * @return the model read from the file
     * @throws IOException
     *             if the model cannot be read
     */
    private static ParallelTopicModel readModel(File modelFile) throws IOException {
        try {
            return ParallelTopicModel.read(modelFile);
        } catch (Exception e) {
            throw new IOException("Cannot read topic model from " + modelFile, e);
        }
    }

}