Java tutorial: topic modelling and content recommendation with MALLET, Play Framework, and Elasticsearch

The Play model class below trains an LDA topic model with MALLET, persists the trained model and its topics through Ebean, and recommends similar documents by hashing topic-weight vectors into bit signatures and running a fuzzy Elasticsearch query over them.
/*
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 */
package models;

import java.io.*;
import java.util.*;
import java.util.regex.Pattern;

import javax.persistence.*;

import cc.mallet.pipe.CharSequence2TokenSequence;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.pipe.TokenSequenceLowercase;
import cc.mallet.pipe.TokenSequenceRemoveStopwords;
import cc.mallet.pipe.iterator.CsvIterator;
import cc.mallet.pipe.iterator.JsonIterator;
import cc.mallet.topics.PancakeTopicInferencer;
import cc.mallet.topics.PersistentParallelTopicModel;
import cc.mallet.types.InstanceList;
import cc.mallet.util.CharSequenceLexer;

import com.avaje.ebean.Ebean;
import org.apache.commons.lang.ArrayUtils;
import org.codehaus.jackson.JsonNode;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;

import play.Configuration;
import play.Play;
import play.db.ebean.Model;

import static org.elasticsearch.index.query.FilterBuilders.termFilter;
import static org.elasticsearch.index.query.QueryBuilders.filteredQuery;
import static org.elasticsearch.index.query.QueryBuilders.fuzzyQuery;

/**
 * Created with IntelliJ IDEA.
 * User: oyiptong
 * Date: 2012-08-03
 * Time: 5:59 PM
 */
@Entity
@Table(name = "smarts_topic_model")
public class TopicModel extends Model {

    @Id
    @GeneratedValue
    private long id;

    private double alpha;
    private double beta;

    @Column(name = "num_topics", nullable = false)
    private int numTopics;

    @Lob
    @Basic(fetch = FetchType.EAGER)
    private byte[] model;

    @Lob
    @Basic(fetch = FetchType.EAGER)
    private byte[] inferencer;

    @Lob
    @Column(name = "feature_sequence", nullable = false)
    private byte[] featureSequence;

    @Column(length = 255)
    private String name;

    @OneToMany(cascade = CascadeType.ALL, mappedBy = "topicModel")
    private List<Topic> topics;

    @OneToMany(cascade = CascadeType.ALL, mappedBy = "topicModel")
    private List<Document> documents;

    @Transient
    private PersistentParallelTopicModel malletTopicModel;

    @Transient
    private InstanceList currentInstanceList;

    // Getters & Setters

    public long getId() { return id; }

    public double getAlpha() { return alpha; }

    public double getBeta() { return beta; }

    public int getNumTopics() { return numTopics; }

    public String getName() { return name; }

    public void setName(String name) { this.name = name; }

    public List<Topic> getTopics() { return topics; }

    public List<Document> getDocuments() { return documents; }

    /*
     * Standard pipe configuration for LDA modelling
     */
    public static SerialPipes getStandardPipes() {
        ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
        Pattern tokenPattern = Pattern.compile(CharSequenceLexer.LEX_ALPHA.toString());
        pipeList.add(new CharSequence2TokenSequence(tokenPattern));
        pipeList.add(new TokenSequenceLowercase());
        pipeList.add(new TokenSequenceRemoveStopwords(false, false));
        pipeList.add(new TokenSequence2FeatureSequence());
        return new SerialPipes(pipeList);
    }
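    /*
     * Illustrative sketch (not part of the original file): what the standard
     * pipe chain does to raw text. StringArrayIterator is assumed to be
     * available in the MALLET version used here, and the sample sentence is
     * made up.
     *
     *   InstanceList instances = new InstanceList(TopicModel.getStandardPipes());
     *   instances.addThruPipe(new cc.mallet.pipe.iterator.StringArrayIterator(
     *           new String[] { "The Quick Brown Fox Jumps Over a Lazy Dog" }));
     *   // The single instance now carries a FeatureSequence over roughly
     *   // [quick, brown, fox, jumps, lazy, dog]: alphabetic tokens,
     *   // lowercased, with common stopwords removed, each mapped to an
     *   // integer feature id in the pipe's alphabet.
     */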
    public TopicModel(String name, int numTopics, double alpha, double beta, Reader dataReader)
            throws Exception {
        this.name = name;
        this.alpha = alpha;
        this.beta = beta;
        this.numTopics = numTopics;

        TopicModel namedModel = TopicModel.find.where().eq("name", name).findUnique();
        if (namedModel != null) {
            // TODO: a better exception. Also, handle concurrency
            throw new Exception("A model of that name already exists");
        }

        // Convert the input to feature vectors. Each line is expected to look
        // like "name label text...": regex group 3 is the data, group 2 the
        // label and group 1 the instance name.
        Pipe instancePipe = getStandardPipes();
        InstanceList instances = new InstanceList(instancePipe);
        instances.addThruPipe(new CsvIterator(dataReader,
                Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"), 3, 2, 1));

        // Serialize the vectors so inference can later reuse the same pipe
        // and alphabet as training
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        ObjectOutputStream oos = new ObjectOutputStream(baos);
        oos.writeObject(instances);
        oos.close();
        this.featureSequence = baos.toByteArray();

        Configuration config = Play.application().configuration();

        // train model
        malletTopicModel = new PersistentParallelTopicModel(this.numTopics, this.alpha, this.beta);
        malletTopicModel.addInstances(instances);
        malletTopicModel.setNumIterations(config.getInt("smarts.topicModel.numIterations"));
        malletTopicModel.setOptimizeInterval(config.getInt("smarts.topicModel.optimizeIntervals"));
        malletTopicModel.setBurninPeriod(config.getInt("smarts.topicModel.burnInPeriod"));
        malletTopicModel.setSymmetricAlpha(config.getBoolean("smarts.topicModel.symmetricAlpha"));
        malletTopicModel.setNumThreads(config.getInt("smarts.topicModel.numThreads"));
        malletTopicModel.estimate();
    }

    public void saveObjectGraph() throws Exception {
        Ebean.beginTransaction();
        Configuration config = Play.application().configuration();
        try {
            // Serialize the trained model and its inferencer into the blob columns
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            ObjectOutputStream oos = new ObjectOutputStream(baos);
            oos.writeObject(malletTopicModel);
            oos.close();
            model = baos.toByteArray();

            baos = new ByteArrayOutputStream();
            oos = new ObjectOutputStream(baos);
            oos.writeObject(malletTopicModel.getInferencer());
            oos.close();
            inferencer = baos.toByteArray();

            // Create topics, each described by a sample of its top words
            Object[][] topicWords =
                    malletTopicModel.getTopWords(config.getInt("smarts.topicModel.numTopWords"));
            for (int topicNum = 0; topicNum < this.numTopics; topicNum++) {
                StringBuilder wordList = new StringBuilder();
                for (Object w : topicWords[topicNum]) {
                    wordList.append(w);
                    wordList.append(" ");
                }
                if (wordList.length() > 0) {
                    // remove trailing space
                    wordList.deleteCharAt(wordList.length() - 1);
                }
                topics.add(new Topic(topicNum, wordList.toString()));
            }

            // Create documents. Only record the n most significant topics
            PancakeTopicInferencer modelInferencer = malletTopicModel.getInferencer();
            InstanceList docVectors = getDocumentVectors();
            List<List> orderedDistributions = modelInferencer.inferSortedDistributions(docVectors,
                    config.getInt("smarts.inference.numIterations"),
                    config.getInt("smarts.inference.thinning"),
                    config.getInt("smarts.inference.burnInPeriod"),
                    Double.parseDouble(config.getString("smarts.inference.threshold")),
                    config.getInt("smarts.inference.numSignificantFeatures"));

            for (int docIndex = 0; docIndex < orderedDistributions.size(); docIndex++) {
                List docData = orderedDistributions.get(docIndex);
                String docName = (String) docData.get(0);
                double[] docTopWeights = generateTopTopicWeightVector(docIndex, orderedDistributions);
                documents.add(new Document(docName, docTopWeights));
            }

            Ebean.save(this);
            Ebean.save(topics);
            // loop to save so the save hook is called
            for (Document doc : documents) {
                doc.save();
            }
            Ebean.commitTransaction();
        } finally {
            Ebean.endTransaction();
        }
    }

    public static Finder<Long, TopicModel> find =
            new Finder<Long, TopicModel>(Long.class, TopicModel.class);
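    /*
     * Note on the distribution format (inferred from how this class consumes
     * it; PancakeTopicInferencer is project-specific, so this describes
     * observed usage rather than documented API): inferSortedDistributions
     * returns one entry per document of the form
     *
     *   [ docName, [ [topicId, weight], [topicId, weight], ... ] ]
     *
     * with the [topicId, weight] pairs sorted by descending weight and
     * truncated to the requested number of significant topics.
     */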
    public static TopicModel fetch(String name) throws Exception {
        TopicModel model = TopicModel.find.where().eq("name", name).findUnique();
        if (model == null) {
            // TODO: a better exception here as well
            throw new Exception("No model of that name exists");
        }
        model.malletTopicModel = PersistentParallelTopicModel.read(model.model);
        return model;
    }

    protected InstanceList getInferenceVectors(JsonNode docs) throws IOException, ClassNotFoundException {
        // Reuse the training pipe so new documents share the training alphabet
        InstanceList docVectors = getDocumentVectors();
        Pipe instancePipe = docVectors.getPipe();
        InstanceList newInstances = new InstanceList(instancePipe);
        newInstances.addThruPipe(new JsonIterator(docs));
        return newInstances;
    }

    protected InstanceList getDocumentVectors() throws IOException, ClassNotFoundException {
        // Lazily deserialize the training vectors stored at construction time
        if (currentInstanceList == null) {
            ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(featureSequence));
            currentInstanceList = (InstanceList) ois.readObject();
            ois.close();
        }
        return currentInstanceList;
    }

    public Map<String, List<String>> inferString(JsonNode jsonData, int maxTopics)
            throws ClassNotFoundException, IOException {
        PancakeTopicInferencer inferencer = malletTopicModel.getInferencer();
        InstanceList instances = getInferenceVectors(jsonData);
        List<Topic> topics = Topic.find.where()
                .eq("topic_model_id", getId()).orderBy("number ASC").findList();

        Configuration config = Play.application().configuration();
        List<List> distributions = inferencer.inferSortedDistributions(instances,
                config.getInt("smarts.inference.numIterations"),
                config.getInt("smarts.inference.thinning"),
                config.getInt("smarts.inference.burnInPeriod"),
                Double.parseDouble(config.getString("smarts.inference.threshold")),
                maxTopics);

        Map<String, List<String>> output = new HashMap<String, List<String>>();
        for (int docIndex = 0; docIndex < distributions.size(); docIndex++) {
            List docData = distributions.get(docIndex);
            List<List> topicDist = (List<List>) docData.get(1);

            // Map each of the document's top topics to that topic's word sample
            List<String> docTopicWords = new ArrayList<String>();
            int topicCount = Math.min(maxTopics, topicDist.size());
            for (int topicIndex = 0; topicIndex < topicCount; topicIndex++) {
                List topicData = topicDist.get(topicIndex);
                int topicId = ((Integer) topicData.get(0)).intValue();
                docTopicWords.add(topics.get(topicId).getWordSample());
            }
            output.put((String) docData.get(0), docTopicWords);
        }
        return output;
    }
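    /*
     * Usage sketch for inferString (illustrative; the JSON layout is whatever
     * the project's JsonIterator expects -- an array of documents with name
     * and text fields is assumed here, and the model name is hypothetical):
     *
     *   JsonNode docs = new org.codehaus.jackson.map.ObjectMapper().readTree(
     *           "[{\"name\":\"doc1\",\"text\":\"stock markets rallied today\"}]");
     *   TopicModel model = TopicModel.fetch("news-model");
     *   Map<String, List<String>> wordsPerDoc = model.inferString(docs, 3);
     *   // wordsPerDoc maps "doc1" to the word samples of its top 3 topics
     */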
    public List recommend(JsonNode jsonData, int maxTopics, int maxRecommendations)
            throws ClassNotFoundException, IOException, InterruptedException {
        PancakeTopicInferencer inferencer = malletTopicModel.getInferencer();
        InstanceList inferenceVectors = getInferenceVectors(jsonData);
        Configuration config = Play.application().configuration();

        List<List> inferenceOrderedDistribution = inferencer.inferSortedDistributions(inferenceVectors,
                config.getInt("smarts.inference.numIterations"),
                config.getInt("smarts.inference.thinning"),
                config.getInt("smarts.inference.burnInPeriod"),
                Double.parseDouble(config.getString("smarts.inference.threshold")),
                Math.max(maxTopics, config.getInt("smarts.inference.numSignificantFeatures")));

        // output containers
        List output = new ArrayList(3);
        List<String> inferredWords = new ArrayList<String>();
        List<String> distributionDesc = new ArrayList<String>();
        List<String> recommendations = new ArrayList<String>();
        Set<String> allRecommendations = new HashSet<String>();

        // for each document
        for (int distIndex = 0; distIndex < inferenceOrderedDistribution.size(); distIndex++) {
            List docTopTopics = inferenceOrderedDistribution.get(distIndex);
            List<List> topicDist = (List<List>) docTopTopics.get(1);

            // obtain textual topic distribution info
            List<String> docTopicWords = new ArrayList<String>();
            List<String> docTopicWeightDesc = new ArrayList<String>();
            int topicCount = Math.min(maxTopics, topicDist.size());
            for (int topicIndex = 0; topicIndex < topicCount; topicIndex++) {
                List topicData = topicDist.get(topicIndex);
                int topicNumber = ((Integer) topicData.get(0)).intValue();
                double topicWeight = ((Double) topicData.get(1)).doubleValue();
                Topic topic = topics.get(topicNumber);
                docTopicWords.add(topic.getWordSample());
                docTopicWeightDesc.add(String.format("topic #%d match: %.2f%%", topicNumber, topicWeight));
            }
            // note: only the last document's words and weights end up in the output
            inferredWords = docTopicWords;
            distributionDesc = docTopicWeightDesc;

            // project the dense topic weight vector down to the same LSH bit
            // signature used when the documents were indexed
            double[] docTopWeights = generateTopTopicWeightVector(distIndex, inferenceOrderedDistribution);
            String signature = RandomProjection.projectString(docTopWeights,
                    config.getInt("smarts.lsh.numBits"));

            // find documents from the same model with a similar signature
            ElasticSearch es = ElasticSearch.getElasticSearch();
            Client esClient = es.getClient();
            SearchResponse response = esClient.prepareSearch("pancake-smarts")
                    .setTypes("document")
                    .setQuery(filteredQuery(
                            fuzzyQuery("features_bits", signature).minSimilarity((float) 0.6),
                            termFilter("topic_model_id", this.getId())))
                    .setFrom(0).setSize(maxRecommendations)
                    .execute()
                    .actionGet();

            SearchHits hits = response.getHits();
            SearchHit[] hitArray = hits.getHits();
            long[] hitIds = new long[hitArray.length];
            for (int hitIndex = 0; hitIndex < hitArray.length; hitIndex++) {
                hitIds[hitIndex] = Long.parseLong(hitArray[hitIndex].getId());
            }

            List<Document> recommendedDocs = Ebean.find(Document.class).where()
                    .in("id", ArrayUtils.toObject(hitIds)).findList();

            for (Document doc : recommendedDocs) {
                if (!allRecommendations.contains(doc.getUrl())) {
                    recommendations.add(doc.getUrl());
                    allRecommendations.add(doc.getUrl());
                }
            }
        }
        output.add(inferredWords);
        output.add(recommendations);
        output.add(distributionDesc);
        return output;
    }

    public double[] generateTopTopicWeightVector(int docIndex, List<List> sortedDistribution) {
        // Expand the sorted (topicId, weight) pairs into a dense vector of
        // length numTopics; topics below the significance cutoff stay at 0.0
        double[] docWeights = new double[numTopics];
        List<List> topTopics = (List<List>) sortedDistribution.get(docIndex).get(1);
        for (List topicData : topTopics) {
            int topicNum = ((Integer) topicData.get(0)).intValue();
            double weight = ((Double) topicData.get(1)).doubleValue();
            docWeights[topicNum] = weight;
        }
        return docWeights;
    }
}
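Putting it all together, a minimal end-to-end sketch (the file name, model name, and hyperparameter values below are made up, and application.conf must supply the smarts.* settings referenced above):

    // Train a 50-topic model from lines of "name label text...", persist the
    // object graph, then load it back and ask for recommendations.
    Reader data = new BufferedReader(new InputStreamReader(
            new FileInputStream("articles.csv"), "UTF-8"));
    TopicModel model = new TopicModel("news-model", 50, 0.01, 0.01, data);
    model.saveObjectGraph();

    TopicModel loaded = TopicModel.fetch("news-model");
    List results = loaded.recommend(jsonDocs, 3, 10); // jsonDocs: a JsonNode of new documents
    List<String> urls = (List<String>) results.get(1); // recommended document URLs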