models.TopicModel.java Source code

Introduction

Here is the source code for models.TopicModel.java, a Play Framework Ebean entity that trains a MALLET LDA topic model from CSV input, persists the trained model, its inferencer, and its feature sequences as serialized blobs, and serves topic inference and Elasticsearch-backed document recommendations.

Source

/*
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 */

package models;

import java.io.*;
import java.util.*;
import java.util.regex.Pattern;
import javax.persistence.*;

import cc.mallet.pipe.iterator.JsonIterator;
import cc.mallet.topics.PancakeTopicInferencer;
import com.avaje.ebean.Ebean;
import org.apache.commons.lang.ArrayUtils;
import org.codehaus.jackson.JsonNode;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;

import static org.elasticsearch.index.query.QueryBuilders.filteredQuery;
import static org.elasticsearch.index.query.QueryBuilders.fuzzyQuery;
import static org.elasticsearch.index.query.FilterBuilders.termFilter;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;

import play.Configuration;
import play.Play;
import play.db.ebean.Model;

import cc.mallet.topics.PersistentParallelTopicModel;
import cc.mallet.util.CharSequenceLexer;
import cc.mallet.types.InstanceList;
import cc.mallet.pipe.iterator.CsvIterator;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.CharSequence2TokenSequence;
import cc.mallet.pipe.TokenSequenceLowercase;
import cc.mallet.pipe.TokenSequenceRemoveStopwords;
import cc.mallet.pipe.TokenSequence2FeatureSequence;

/**
 * Ebean entity wrapping a MALLET LDA topic model: trains the model from CSV
 * input, persists the serialized model, inferencer, and feature sequences,
 * and answers topic-inference and recommendation queries.
 *
 * Created with IntelliJ IDEA.
 * User: oyiptong
 * Date: 2012-08-03
 * Time: 5:59 PM
 */

@Entity
@Table(name = "smarts_topic_model")
public class TopicModel extends Model {

    @Id
    @GeneratedValue
    private long id;

    private double alpha;

    private double beta;

    @Column(name = "num_topics", nullable = false)
    private int numTopics;

    @Lob
    @Basic(fetch = FetchType.EAGER)
    private byte[] model;

    @Lob
    @Basic(fetch = FetchType.EAGER)
    private byte[] inferencer;

    @Lob
    @Column(name = "feature_sequence", nullable = false)
    private byte[] featureSequence;

    @Column(length = 255)
    private String name;

    @OneToMany(cascade = CascadeType.ALL, mappedBy = "topicModel")
    private List<Topic> topics;

    @OneToMany(cascade = CascadeType.ALL, mappedBy = "topicModel")
    private List<Document> documents;

    @Transient
    private PersistentParallelTopicModel malletTopicModel;

    @Transient
    private InstanceList currentInstanceList;

    // Getters & Setters

    public long getId() {
        return id;
    }

    public double getAlpha() {
        return alpha;
    }

    public double getBeta() {
        return beta;
    }

    public int getNumTopics() {
        return numTopics;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public List<Topic> getTopics() {
        return topics;
    }

    public List<Document> getDocuments() {
        return documents;
    }

    /*
     * Standard pipe configuration for LDA modelling: tokenize alphabetic
     * sequences, lowercase, drop stopwords, then map tokens to feature ids.
     */
    public static SerialPipes getStandardPipes() {
        ArrayList<Pipe> pipeList = new ArrayList<Pipe>();

        Pattern tokenPattern = Pattern.compile(CharSequenceLexer.LEX_ALPHA.toString());
        pipeList.add(new CharSequence2TokenSequence(tokenPattern));

        pipeList.add(new TokenSequenceLowercase());

        pipeList.add(new TokenSequenceRemoveStopwords(false, false));

        pipeList.add(new TokenSequence2FeatureSequence());

        return new SerialPipes(pipeList);
    }

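    /**
     * Builds and trains a new topic model from CSV data. Refuses to create a
     * model whose name is already taken, serializes the training instances so
     * later inference can reuse the same pipe, and runs estimation with the
     * iteration counts configured under smarts.topicModel.*.
     */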
    public TopicModel(String name, int numTopics, double alpha, double beta, Reader dataReader) throws Exception {
        this.name = name;
        this.alpha = alpha;
        this.beta = beta;
        this.numTopics = numTopics;

        TopicModel namedModel = TopicModel.find.where().eq("name", name).findUnique();

        if (namedModel != null) {
            // TODO: a better exception. Also, handle concurrency
            throw new Exception("A model of that name already exists");
        }

        // convert input to vectors
        Pipe instancePipe = getStandardPipes();
        InstanceList instances = new InstanceList(instancePipe);
        // CsvIterator groups: 1 = document name, 2 = label, 3 = body text
        instances.addThruPipe(
                new CsvIterator(dataReader, Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"), 3, 2, 1));

        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        ObjectOutputStream oos = new ObjectOutputStream(baos);
        oos.writeObject(instances);
        oos.close();
        this.featureSequence = baos.toByteArray();

        Configuration config = Play.application().configuration();
        // train model
        malletTopicModel = new PersistentParallelTopicModel(this.numTopics, this.alpha, this.beta);
        malletTopicModel.addInstances(instances);
        malletTopicModel.setNumIterations(config.getInt("smarts.topicModel.numIterations"));
        malletTopicModel.setOptimizeInterval(config.getInt("smarts.topicModel.optimizeIntervals"));
        malletTopicModel.setBurninPeriod(config.getInt("smarts.topicModel.burnInPeriod"));
        malletTopicModel.setSymmetricAlpha(config.getBoolean("smarts.topicModel.symmetricAlpha"));
        malletTopicModel.setNumThreads(config.getInt("smarts.topicModel.numThreads"));

        malletTopicModel.estimate();
    }

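    /**
     * Serializes the trained model and its inferencer, derives Topic and
     * Document rows from the fitted model, and saves the whole object graph
     * inside a single Ebean transaction.
     */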
    public void saveObjectGraph() throws Exception {
        Ebean.beginTransaction();
        Configuration config = Play.application().configuration();
        try {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            ObjectOutputStream oos = new ObjectOutputStream(baos);
            oos.writeObject(malletTopicModel);
            oos.close();
            model = baos.toByteArray();

            baos = new ByteArrayOutputStream();
            oos = new ObjectOutputStream(baos);
            oos.writeObject(malletTopicModel.getInferencer());
            oos.close();
            inferencer = baos.toByteArray();

            // create topics
            Object[][] topicWords = malletTopicModel.getTopWords(config.getInt("smarts.topicModel.numTopWords"));
            for (int topicNum = 0; topicNum < this.numTopics; topicNum++) {
                StringBuilder wordList = new StringBuilder();
                Object[] words = topicWords[topicNum];

                for (Object w : words) {
                    wordList.append(w);
                    wordList.append(" ");
                }
                if (wordList.length() > 0) {
                    // remove trailing space
                    wordList.deleteCharAt(wordList.length() - 1);
                }
                Topic topic = new Topic(topicNum, wordList.toString());
                topics.add(topic);
            }

            // create documents
            // renamed from "inferencer" to avoid shadowing the byte[] field above
            PancakeTopicInferencer topicInferencer = malletTopicModel.getInferencer();
            InstanceList docVectors = getDocumentVectors();

            // Only record the n most significant topics
            List<List> orderedDistributions = topicInferencer.inferSortedDistributions(docVectors,
                    config.getInt("smarts.inference.numIterations"), config.getInt("smarts.inference.thinning"),
                    config.getInt("smarts.inference.burnInPeriod"),
                    Double.parseDouble(config.getString("smarts.inference.threshold")),
                    config.getInt("smarts.inference.numSignificantFeatures"));

            for (int docIndex = 0; docIndex < orderedDistributions.size(); docIndex++) {
                List docData = orderedDistributions.get(docIndex);
                String docName = (String) docData.get(0);

                double[] docTopWeights = generateTopTopicWeightVector(docIndex, orderedDistributions);
                Document doc = new Document(docName, docTopWeights);

                documents.add(doc);
            }

            Ebean.save(this);
            Ebean.save(topics);

            // loop to save so the save hook is called
            for (Document doc : documents) {
                doc.save();
            }

            Ebean.commitTransaction();

        } finally {
            Ebean.endTransaction();
        }
    }

    public static Finder<Long, TopicModel> find = new Finder<Long, TopicModel>(Long.class, TopicModel.class);

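    /**
     * Loads a persisted model by name and restores its in-memory MALLET state.
     */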
    public static TopicModel fetch(String name) throws Exception {
        TopicModel model = TopicModel.find.where().eq("name", name).findUnique();

        if (model == null) {
            // avoid a NullPointerException when no model has this name
            throw new Exception("No topic model named: " + name);
        }

        model.malletTopicModel = PersistentParallelTopicModel.read(model.model);
        return model;
    }

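    /**
     * Runs new JSON documents through the same pipe as the training data so
     * that inference shares the training vocabulary.
     */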
    protected InstanceList getInferenceVectors(JsonNode docs) throws IOException, ClassNotFoundException {
        InstanceList docVectors = getDocumentVectors();

        Pipe instancePipe = docVectors.getPipe();

        InstanceList newInstances = new InstanceList(instancePipe);
        newInstances.addThruPipe(new JsonIterator(docs));

        return newInstances;
    }

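    /**
     * Lazily deserializes and caches the training InstanceList from the
     * stored feature_sequence blob.
     */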
    protected InstanceList getDocumentVectors() throws IOException, ClassNotFoundException {
        if (currentInstanceList == null) {
            ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(featureSequence));
            currentInstanceList = (InstanceList) ois.readObject();
            ois.close();
        }
        return currentInstanceList;
    }

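    /**
     * Returns, for each input document, the word samples of its top topics,
     * keyed by document name.
     */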
    public Map<String, List<String>> inferString(JsonNode jsonData, int maxTopics)
            throws ClassNotFoundException, IOException {
        PancakeTopicInferencer topicInferencer = malletTopicModel.getInferencer();
        InstanceList instances = getInferenceVectors(jsonData);

        List<Topic> topics = Topic.find.where().eq("topic_model_id", getId()).orderBy("number ASC").findList();

        Configuration config = Play.application().configuration();

        List<List> distributions = topicInferencer.inferSortedDistributions(instances,
                config.getInt("smarts.inference.numIterations"), config.getInt("smarts.inference.thinning"),
                config.getInt("smarts.inference.burnInPeriod"),
                Double.parseDouble(config.getString("smarts.inference.threshold")), maxTopics);

        Map<String, List<String>> output = new HashMap<String, List<String>>();

        for (int docIndex = 0; docIndex < distributions.size(); docIndex++) {
            List docData = distributions.get(docIndex);
            List<List> topicDist = (List<List>) docData.get(1);

            List<String> docTopicWords = new ArrayList<String>();
            for (int topicIndex = 0; topicIndex < maxTopics; topicIndex++) {
                List topicData = topicDist.get(topicIndex);
                int topicId = ((Integer) topicData.get(0)).intValue();
                Topic topic = topics.get(topicId);
                docTopicWords.add(topic.getWordSample());
            }
            output.put((String) docData.get(0), docTopicWords);
        }

        return output;
    }

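    /**
     * Infers topic weights for the input documents, projects them to an LSH
     * bit signature, and looks up similar documents in Elasticsearch to build
     * a list of recommended URLs.
     */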
    public List recommend(JsonNode jsonData, int maxTopics, int maxRecommendations)
            throws ClassNotFoundException, IOException, InterruptedException {
        PancakeTopicInferencer topicInferencer = malletTopicModel.getInferencer();
        InstanceList inferenceVectors = getInferenceVectors(jsonData);

        Configuration config = Play.application().configuration();

        // a single sorted inference pass suffices; keep at least as many
        // topics as both the textual report and the LSH signature need
        List<List> inferenceOrderedDistribution = topicInferencer.inferSortedDistributions(inferenceVectors,
                config.getInt("smarts.inference.numIterations"), config.getInt("smarts.inference.thinning"),
                config.getInt("smarts.inference.burnInPeriod"),
                Double.parseDouble(config.getString("smarts.inference.threshold")),
                Math.max(maxTopics, config.getInt("smarts.inference.numSignificantFeatures")));

        // output containers
        List output = new ArrayList(3);
        List<String> inferredWords = new ArrayList<String>();
        List<String> distributionDesc = new ArrayList<String>();
        List<String> recommendations = new ArrayList<String>();
        Set<String> allRecommendations = new HashSet<String>();

        // for each document
        for (int distIndex = 0; distIndex < inferenceOrderedDistribution.size(); distIndex++) {
            List docTopTopics = inferenceOrderedDistribution.get(distIndex);
            List<List> topicDist = (List<List>) docTopTopics.get(1);

            // obtain textual topic distribution info
            List<String> docTopicWords = new ArrayList<String>();
            List<String> docTopicWeightDesc = new ArrayList<String>();
            for (int topicIndex = 0; topicIndex < maxTopics; topicIndex++) {
                // obtain textual topic distribution info
                List topicData = topicDist.get(topicIndex);
                int topicNumber = ((Integer) topicData.get(0)).intValue();
                double topicWeight = ((Double) topicData.get(1)).doubleValue();

                // assumes the topics relation is loaded in topic-number order
                Topic topic = topics.get(topicNumber);
                docTopicWords.add(topic.getWordSample());
                docTopicWeightDesc.add(String.format("topic #%d match: %.2f%%", topicNumber, topicWeight));
            }
            // note: only the last document's word samples and weight
            // descriptions survive this loop; earlier ones are overwritten
            inferredWords = docTopicWords;
            distributionDesc = docTopicWeightDesc;

            double[] docTopWeights = generateTopTopicWeightVector(distIndex, inferenceOrderedDistribution);

            // project the topic weights down to an LSH bit signature whose
            // width is set by smarts.lsh.numBits, matching the index
            String signature = RandomProjection.projectString(docTopWeights, config.getInt("smarts.lsh.numBits"));

            ElasticSearch es = ElasticSearch.getElasticSearch();
            Client esClient = es.getClient();

            // fuzzy-match on the signature bits, restricted to documents that
            // belong to this topic model
            SearchResponse response = esClient.prepareSearch("pancake-smarts").setTypes("document")
                    .setQuery(filteredQuery(fuzzyQuery("features_bits", signature).minSimilarity((float) 0.6),
                            termFilter("topic_model_id", this.getId())))
                    .setFrom(0).setSize(maxRecommendations).execute().actionGet();

            SearchHits hits = response.getHits();
            SearchHit[] hitArray = hits.getHits();

            // size by the actual hit count; padding with zeros would make the
            // Ebean query below match id 0
            long[] hitIds = new long[hitArray.length];

            for (int hitIndex = 0; hitIndex < hitArray.length; hitIndex++) {
                SearchHit hit = hitArray[hitIndex];
                hitIds[hitIndex] = Long.parseLong(hit.getId());
            }

            List<Document> recommendedDocs = Ebean.find(Document.class).where()
                    .in("id", ArrayUtils.toObject(hitIds)).findList();
            for (Document doc : recommendedDocs) {
                if (!allRecommendations.contains(doc.getUrl())) {
                    recommendations.add(doc.getUrl());
                    allRecommendations.add(doc.getUrl());
                }
            }
        }
        output.add(inferredWords);
        output.add(recommendations);
        output.add(distributionDesc);
        return output;
    }

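    /**
     * Expands one document's sorted (topic, weight) pairs into a dense weight
     * vector indexed by topic number.
     */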
    public double[] generateTopTopicWeightVector(int docIndex, List<List> sortedDistribution) {
        double[] docWeights = new double[numTopics];

        List<List> topTopics = (List<List>) sortedDistribution.get(docIndex).get(1);
        for (List topicData : topTopics) {
            int topicNum = ((Integer) topicData.get(0)).intValue();
            double weight = ((Double) topicData.get(1)).doubleValue();
            docWeights[topicNum] = weight;
        }

        return docWeights;
    }
}
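
Example usage

The following is a minimal, hypothetical sketch of how this class might be driven, assuming a running Play application with the smarts.* configuration keys set and the custom MALLET classes above (PersistentParallelTopicModel, PancakeTopicInferencer, JsonIterator) on the classpath. The file name, model name, hyperparameters, and JSON document shape are illustrative assumptions, not part of the source above.

import java.io.FileReader;
import java.io.Reader;
import java.util.List;
import java.util.Map;

import models.TopicModel;
import org.codehaus.jackson.JsonNode;
import org.codehaus.jackson.map.ObjectMapper;

public class TopicModelExample {

    public static void main(String[] args) throws Exception {
        // one document per line: <name> <label> <text>, as expected by the
        // CsvIterator regex in the TopicModel constructor
        Reader data = new FileReader("training-data.csv");

        // train and persist a 50-topic model; 50.0 and 0.01 mirror common
        // MALLET defaults for the alpha sum and beta
        TopicModel model = new TopicModel("news-2012", 50, 50.0, 0.01, data);
        model.saveObjectGraph();

        // reload the model by name and infer the top 3 topics per document;
        // the JSON shape here is an assumption about the custom JsonIterator
        TopicModel loaded = TopicModel.fetch("news-2012");
        JsonNode docs = new ObjectMapper().readTree(
                "[{\"title\": \"doc-1\", \"text\": \"the quick brown fox\"}]");
        Map<String, List<String>> topWords = loaded.inferString(docs, 3);
        System.out.println(topWords);
    }
}

Training and persistence are deliberately separate steps here, mirroring the class itself: the constructor only fits the model in memory, and nothing reaches the database until saveObjectGraph() commits its transaction.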