gov.llnl.iscr.iris.LDAHandler.java Source code

Java tutorial

Introduction

Here is the source code for gov.llnl.iscr.iris.LDAHandler.java

Source

/**
* Copyright (c) 2011, Lawrence Livermore National Security, LLC. 
* Produced at the Lawrence Livermore National Laboratory. 
* Written by Kevin Lawrence, lawrence22@llnl.gov
* Under the guidance of: 
* David Andrzejewski, andrzejewski1@llnl.gov
* David Buttler, buttler1@llnl.gov 
* LLNL-CODE-521811 All rights reserved. This file is part of IRIS
*
* This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public
* License (as published by the Free Software Foundation) version 2, dated June 1991. This program is distributed in the
* hope that it will be useful, but WITHOUT ANY WARRANTY; without even the IMPLIED WARRANTY OF MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the terms and conditions of the GNU General Public License for more details.
* You should have received a copy of the GNU General Public License along with this program; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA For full text see license.txt
*
*
*/
package gov.llnl.iscr.iris;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import com.mongodb.BasicDBObject;
import com.mongodb.DBCursor;
import com.mongodb.DBObject;

import gov.llnl.iscr.iris.LDAModel;

/**
 * Processes the data retrieved by an {@link LDAModel} according to the topic
 * selection procedure of the latent topic feedback system.
 * <p>
 * The handler wraps an LDAModel and makes every data request through it. Each
 * {@code setXxx} method computes a result, stores it on the handler and returns
 * {@code this} so calls can be chained; the matching {@code getXxx} method
 * returns the stored result.
 * <p>A typical invocation sequence is:
 * <blockquote><pre>
 * MongoInstance mongo = new MongoInstance("127.0.0.1", "topicModel");
 * LDAModel model = new LDAModel(mongo);
 * LDAHandler lda = new LDAHandler(model);
 * </pre></blockquote>
 * <p>
 * Instances hold mutable state and perform no synchronization.
 */
public class LDAHandler {
    /** Number of leading result documents whose topics seed the enriched set. */
    private static final int ENRICHING_DOC_COUNT = 2;
    /** Maximum number of new topics taken from one document in a single pass. */
    private static final int TOPICS_PER_DOC_PER_PASS = 2;
    /** Target size of the enriched topic set. */
    private static final int ENRICHED_SET_SIZE = 4;
    /** Maximum number of passes over the seeding documents' topic lists. */
    private static final int MAX_ENRICHMENT_PASSES = 2;
    /** Number of related topics kept for each enriched topic. */
    private static final int RELATED_PER_ENRICHED = 2;
    /** Number of bigrams selected for display (in addition to one trigram). */
    private static final int SELECTED_BIGRAM_COUNT = 2;
    /** Number of unigrams selected for display; one more word, if present, is used for expansion only. */
    private static final int DISPLAY_UNIGRAM_COUNT = 4;

    private final LDAModel model;
    // Topics reported by LDAModel#getTopicsLessThan for this value are dropped as "junk".
    // NOTE(review): -100.0 presumably lies below every semco score so nothing is filtered
    // by default -- confirm against the actual semco range.
    private double topicThreshold = -100.0;
    private List<Integer> enrichedSet = null;
    private List<Integer> relatedSet = null;
    private List<BasicDBObject> selectedNgrams = null;
    private List<BasicDBObject> selectedUnigrams = null;
    // Insertion-ordered map: topic ID -> words chosen for query expansion.
    private final Map<Integer, List<String>> expansionWords = new LinkedHashMap<Integer, List<String>>();

    /**
     * Distinguishes the two shapes of topic list accepted by
     * {@link LDAHandler#filterTopics(List, double, TopicType)}.
     */
    public static enum TopicType {
        /** Document topic entries; the topic ID is stored under {@code "topic"}. */
        ENRICHED,
        /** Co-occurring topic entries; the topic ID is stored under {@code "cotopic"}. */
        RELATED
    }

    /**
     * Creates an LDAHandler that retrieves all of its data through the given model.
     * @param model the model used for every data request
     */
    public LDAHandler(LDAModel model) {
        this.model = model;
    }

    /**
     * Sets the junk-topic threshold to the semantic coherence (semco) score found
     * at the given percentile of all topics' semco scores.
     * <p>
     * The index used is {@code (int) (thresholdPercentile * n) - 1}, clamped to
     * {@code [0, n - 1]}, where {@code n} is the number of semco values. If no semco
     * values exist, a warning is printed and the current threshold is kept.
     * <p>
     * NOTE: a plain {@code double} literal such as {@code 0.1} resolves to
     * {@link #setTopicThreshold(double)} (an absolute threshold); pass a
     * {@code float} such as {@code 0.1f} to select this percentile overload.
     * <p>
     * NOTE(review): the percentile interpretation assumes
     * {@code LDAModel#getSemcoValues()} returns the values sorted ascending -- confirm.
     * @param thresholdPercentile fraction in [0, 1]; out-of-range values are clamped
     * @return this handler, for chaining
     */
    public LDAHandler setTopicThreshold(float thresholdPercentile) {
        DBCursor semcoCur = model.getSemcoValues();
        // Index into the materialized list using its own size so the index can never
        // exceed what was actually fetched.
        List<DBObject> semcoValues = semcoCur.toArray();
        int semcoCount = semcoValues.size();
        if (semcoCount == 0) {
            System.err.println("Topic threshold cannot be set from a percentile: no semco values found.");
            System.err.println("Keeping current topic threshold: " + topicThreshold);
            return this;
        }

        int rank = (int) (thresholdPercentile * semcoCount) - 1;
        // Small percentiles give rank -1; values above 1 overshoot the end.
        int index = Math.max(0, Math.min(semcoCount - 1, rank));
        this.topicThreshold = ((Number) semcoValues.get(index).get("semco")).doubleValue();
        return this;
    }

    /**
     * Sets the junk-topic threshold to the absolute value given.
     * @param threshold semco value; topics reported as below it are filtered out
     * @return this handler, for chaining
     */
    public LDAHandler setTopicThreshold(double threshold) {
        this.topicThreshold = threshold;
        return this;
    }

    /**
     * Builds the enriched topic set for the user query from the first documents
     * of the result list. Call {@link LDAHandler#getEnrichedTopicSet()} on the
     * returned LDAHandler to retrieve the list.
     * <p>
     * Algorithm: the topics of each of the first {@value #ENRICHING_DOC_COUNT}
     * documents are sorted by probability (descending) and filtered by the
     * current threshold. In each pass, up to {@value #TOPICS_PER_DOC_PER_PASS}
     * not-yet-selected topics are taken from each document in turn, stopping
     * once {@value #ENRICHED_SET_SIZE} topics are selected. If the set is still
     * short after a pass, another pass is made, up to
     * {@value #MAX_ENRICHMENT_PASSES} passes in total.
     * <p>
     * If fewer documents are supplied, only those are used. A null or empty list
     * yields an empty enriched set.
     * @param docIDs IDs of the result documents, best match first
     * @return this handler, for chaining
     */
    public LDAHandler setEnrichedTopicSet(List<Object> docIDs) {
        enrichedSet = new ArrayList<Integer>();
        if (docIDs == null || docIDs.isEmpty()) {
            return this;
        }

        // Rank each seeding document's topics once; every pass reuses these lists.
        int docCount = Math.min(ENRICHING_DOC_COUNT, docIDs.size());
        List<List<Integer>> rankedPerDoc = new ArrayList<List<Integer>>(docCount);
        for (int d = 0; d < docCount; d++) {
            rankedPerDoc.add(rankedTopicsForDocument(docIDs.get(d)));
        }

        for (int pass = 0; pass < MAX_ENRICHMENT_PASSES && enrichedSet.size() < ENRICHED_SET_SIZE; pass++) {
            for (List<Integer> ranked : rankedPerDoc) {
                addUnseenTopics(ranked, TOPICS_PER_DOC_PER_PASS);
            }
        }
        return this;
    }

    /**
     * Returns a document's topic IDs sorted by probability (descending), with
     * junk topics removed; empty when the document has no "topics" field.
     */
    private List<Integer> rankedTopicsForDocument(Object docID) {
        @SuppressWarnings("unchecked")
        List<DBObject> topics = (List<DBObject>) model.getTopics(docID).get("topics");
        if (topics == null) {
            return new ArrayList<Integer>();
        }
        Collections.sort(topics, new TopicSortByProb());
        return filterTopics(topics, topicThreshold, TopicType.ENRICHED);
    }

    /**
     * Appends up to {@code maxToAdd} topics from {@code ranked} (in order) that are
     * not yet in the enriched set, never growing it past {@value #ENRICHED_SET_SIZE}.
     */
    private void addUnseenTopics(List<Integer> ranked, int maxToAdd) {
        int added = 0;
        for (int idx = 0; idx < ranked.size() && added < maxToAdd && enrichedSet.size() < ENRICHED_SET_SIZE; idx++) {
            Integer topic = ranked.get(idx);
            if (!enrichedSet.contains(topic)) {
                enrichedSet.add(topic);
                ++added;
            }
        }
    }

    /**
     * Builds the list of topics related to the enriched topic set. Call
     * {@link LDAHandler#getRelatedTopicSet()} on the returned LDAHandler to
     * retrieve the list.
     * <p>
     * For each cursor returned by {@code LDAModel#getRelatedTopics}, the co-topics
     * are filtered by the current threshold and the first
     * {@value #RELATED_PER_ENRICHED} survivors (fewer if fewer survive) are
     * appended. Duplicates across cursors are not removed. If the enriched set has
     * not been established, the related set is left empty and a warning is printed.
     * @return this handler, for chaining
     */
    public LDAHandler setRelatedTopicSet() {
        relatedSet = new ArrayList<Integer>();
        if (enrichedSet == null) {
            System.err.println("Related Topics Set cannot be populated!");
            System.err.println("Ensure enriched topic set has been established.");
            return this;
        }

        List<DBCursor> curList = model.getRelatedTopics(enrichedSet);
        for (DBCursor cur : curList) {
            List<Integer> cotopics = filterTopics(cur.toArray(), topicThreshold, TopicType.RELATED);
            // Filtering may leave fewer than the desired number of co-topics.
            int take = Math.min(RELATED_PER_ENRICHED, cotopics.size());
            relatedSet.addAll(cotopics.subList(0, take));
        }
        return this;
    }

    /**
     * Selects the n-grams to display for the given topic. Call
     * {@link LDAHandler#getSelectedNgrams()} on the returned LDAHandler to
     * retrieve the list.
     * <p>
     * The topic's n-grams are sorted by size (descending) and then by score
     * (descending). The selection is the best trigram followed by the best
     * {@value #SELECTED_BIGRAM_COUNT} bigrams, or fewer if not available.
     * @param selectedTopic the topic whose n-grams are requested
     * @return this handler, for chaining
     */
    public LDAHandler setNgrams(Object selectedTopic) {
        selectedNgrams = new ArrayList<BasicDBObject>();

        @SuppressWarnings("unchecked")
        List<BasicDBObject> topicNgrams = (List<BasicDBObject>) model.getNgrams(selectedTopic).get("ngrams");
        if (topicNgrams == null) {
            return this;
        }

        List<Comparator<BasicDBObject>> comps = new ArrayList<Comparator<BasicDBObject>>();
        comps.add(new NgramSortBySize());
        comps.add(new NgramSortByScore());
        Collections.sort(topicNgrams, new MultiComparator<BasicDBObject>(comps));

        // After the size-descending sort all trigrams precede all bigrams, so the
        // selection keeps the order: trigram first, then bigrams.
        boolean needTrigram = true;
        int bigramCount = 0;
        for (BasicDBObject ngram : topicNgrams) {
            if (!needTrigram && bigramCount >= SELECTED_BIGRAM_COUNT) {
                break;
            }
            // getInt matches NgramSortBySize and accepts any numeric storage type.
            int size = ngram.getInt("size");
            if (size == 3 && needTrigram) {
                selectedNgrams.add(ngram);
                needTrigram = false;
            } else if (size == 2 && bigramCount < SELECTED_BIGRAM_COUNT) {
                selectedNgrams.add(ngram);
                ++bigramCount;
            }
        }
        return this;
    }

    /**
     * Selects the unigrams to display for the given topic and records its query
     * expansion words. Call {@link LDAHandler#getSelectedUnigrams()} on the
     * returned LDAHandler to retrieve the display list.
     * <p>
     * The topic's words are sorted by probability (descending). The top
     * {@value #DISPLAY_UNIGRAM_COUNT} (or fewer, if the topic has fewer) are
     * selected for display. Their words, plus the next word if one exists, are
     * stored as the topic's expansion words. When no extra word exists, a
     * warning is printed.
     * @param selectedTopic the topic whose unigrams are requested; must be an {@link Integer}
     *        (it is used as the expansion-word map key)
     * @return this handler, for chaining
     */
    public LDAHandler setUnigrams(Object selectedTopic) {
        @SuppressWarnings("unchecked")
        List<DBObject> topicUnigrams = (List<DBObject>) model.getUnigrams(selectedTopic).get("words");
        if (topicUnigrams == null) {
            topicUnigrams = new ArrayList<DBObject>();
        }
        Collections.sort(topicUnigrams, new TopicSortByProb());

        selectedUnigrams = new ArrayList<BasicDBObject>();
        List<String> words = new ArrayList<String>(); // words used for query expansion
        int displayCount = Math.min(DISPLAY_UNIGRAM_COUNT, topicUnigrams.size());
        for (int i = 0; i < displayCount; i++) {
            DBObject unigram = topicUnigrams.get(i);
            selectedUnigrams.add((BasicDBObject) unigram);
            words.add(unigram.get("word").toString());
        }

        // The word right after the displayed ones is used for expansion only.
        Object extraWord = topicUnigrams.size() > DISPLAY_UNIGRAM_COUNT
                ? topicUnigrams.get(DISPLAY_UNIGRAM_COUNT).get("word")
                : null;
        if (extraWord != null) {
            words.add(extraWord.toString());
        } else {
            System.err.println("Could NOT add the last (5th) term to the query expansion word list:");
            if (words.size() == DISPLAY_UNIGRAM_COUNT) {
                System.err.println("Only four (4) words will be used for query expansion.");
            } else {
                System.err.println("Only " + words.size() + " word(s) will be used for query expansion.");
            }
        }
        expansionWords.put((Integer) selectedTopic, words);
        return this;
    }

    /**
     * Returns the enriched topic list.
     * @return the enriched topic IDs, or {@code null} if
     *         {@link #setEnrichedTopicSet(List)} has not been called
     */
    public List<Integer> getEnrichedTopicSet() {
        return enrichedSet;
    }

    /**
     * Returns the map of topic IDs to their query expansion words, in the order
     * the topics were processed by {@link #setUnigrams(Object)}.
     * @return the live (mutable) expansion-word map
     */
    public Map<Integer, List<String>> getAllExpansionWords() {
        return expansionWords;
    }

    /**
     * Returns the query expansion words recorded for the given topic.
     * @param topicID the topic ID
     * @return the words, or {@code null} if none were recorded for the topic
     */
    public List<String> getTopicExpansionWords(Integer topicID) {
        return expansionWords.get(topicID);
    }

    /**
     * Returns the related topic list.
     * @return the related topic IDs, or {@code null} if
     *         {@link #setRelatedTopicSet()} has not been called
     */
    public List<Integer> getRelatedTopicSet() {
        return relatedSet;
    }

    /**
     * Returns the n-grams selected by {@link #setNgrams(Object)}.
     * @return the selected n-grams, or {@code null} if none have been selected yet
     */
    public List<BasicDBObject> getSelectedNgrams() {
        return selectedNgrams;
    }

    /**
     * Returns the unigrams selected by {@link #setUnigrams(Object)}.
     * @return the selected unigrams, or {@code null} if none have been selected yet
     */
    public List<BasicDBObject> getSelectedUnigrams() {
        return selectedUnigrams;
    }

    /**
     * Returns the current junk-topic threshold.
     * @return the threshold value
     */
    public double getTopicThreshold() {
        return topicThreshold;
    }

    /**
     * Returns the model associated with this LDAHandler.
     * @return the model
     */
    public LDAModel getModel() {
        return model;
    }

    /**
     * Filters the given topics by removing those that
     * {@code LDAModel#getTopicsLessThan} reports as scoring below the threshold.
     * @param topics topic entries; each holds its ID under {@code "topic"} (ENRICHED)
     *        or {@code "cotopic"} (RELATED)
     * @param threshold value used to remove topics that fall short
     * @param topicType which key holds the topic ID (ENRICHED or RELATED)
     * @return the surviving topic IDs, in the same order as {@code topics}
     * @throws IllegalArgumentException if {@code topicType} is not a known constant
     */
    public List<Integer> filterTopics(List<DBObject> topics, double threshold, TopicType topicType) {
        String idKey;
        switch (topicType) {
        case ENRICHED:
            idKey = "topic";
            break;
        case RELATED:
            idKey = "cotopic";
            break;
        default:
            throw new IllegalArgumentException("Unsupported topic type: " + topicType);
        }

        List<Integer> topicIDs = new ArrayList<Integer>(topics.size());
        for (DBObject topic : topics) {
            topicIDs.add((Integer) topic.get(idKey));
        }

        // Query with a copy: topicIDs is mutated below while the (lazily evaluated)
        // cursor is being iterated, and the model may keep a reference to its argument.
        DBCursor junkCur = model.getTopicsLessThan(new ArrayList<Integer>(topicIDs), threshold);
        while (junkCur.hasNext()) {
            // get() returns Object, so this is List.remove(Object): removes by value, not index.
            topicIDs.remove(junkCur.next().get("topic"));
        }
        return topicIDs;
    }

    //-|=============================================================
    //-|Private classes used for sorting data retrieved from model
    //-|=============================================================

    /**
     * Orders topic/word objects by their numeric {@code "prob"} field, highest first.
     */
    private static class TopicSortByProb implements Comparator<DBObject>, Serializable {

        private static final long serialVersionUID = 1L;

        @Override
        public int compare(DBObject o1, DBObject o2) {
            double p1 = ((Number) o1.get("prob")).doubleValue();
            double p2 = ((Number) o2.get("prob")).doubleValue();
            return Double.compare(p2, p1); // arguments swapped for descending order
        }
    }

    /**
     * Orders n-gram objects by their {@code "score"} field, highest first.
     */
    private static class NgramSortByScore implements Comparator<BasicDBObject>, Serializable {

        private static final long serialVersionUID = 2L;

        @Override
        public int compare(BasicDBObject o1, BasicDBObject o2) {
            return Double.compare(o2.getDouble("score"), o1.getDouble("score")); // descending
        }
    }

    /**
     * Orders n-gram objects by their {@code "size"} field, largest first.
     */
    private static class NgramSortBySize implements Comparator<BasicDBObject>, Serializable {

        private static final long serialVersionUID = 3L;

        @Override
        public int compare(BasicDBObject o1, BasicDBObject o2) {
            int s1 = o1.getInt("size");
            int s2 = o2.getInt("size");
            // Descending; explicit comparison avoids subtraction overflow.
            return (s1 > s2) ? -1 : ((s1 == s2) ? 0 : 1);
        }
    }

    /**
     * Comparator that applies a list of comparators in order and returns the
     * first non-zero result (0 if all consider the elements equal).
     */
    private static class MultiComparator<T> implements Comparator<T>, Serializable {

        private static final long serialVersionUID = 4L;

        private final List<Comparator<T>> comparators;

        public MultiComparator(List<Comparator<T>> comparators) {
            // Defensive copy so later changes to the caller's list do not alter ordering.
            this.comparators = new ArrayList<Comparator<T>>(comparators);
        }

        @Override
        public int compare(T o1, T o2) {
            for (Comparator<T> comparator : comparators) {
                int comparison = comparator.compare(o1, o2);
                if (comparison != 0) {
                    return comparison;
                }
            }
            return 0;
        }
    }
}