Java tutorial: MultiPlatformLDA
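The listing below is the full MultiPlatformLDA.java, a collapsed Gibbs sampler for a topic model over posts that users publish on several platforms. Its readData() method expects a folder dataPath/users containing one CSV file per user, where each line holds postID, platform index, batch index, and a space-separated list of word indices, plus a dataPath/vocabulary.csv mapping each word index to its word. A purely illustrative input layout (file names and values here are made up, not taken from the original project) could look like:

    data/users/user42.csv
        p001,0,1,17 230 4 981
        p002,1,2,55 4 17
    data/vocabulary.csv
        0,coffee
        1,deadline
        4,meeting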
package Model;

import hoang.larc.tooler.RankingTool;
import hoang.larc.tooler.SystemTool;
import hoang.larc.tooler.WeightedElement;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.util.HashMap;
import java.util.Random;
import java.util.Scanner;

import org.apache.commons.io.FilenameUtils;

/*
 * Copyright (c) 2016 Roy Ka-Wei LEE and Tuan-Anh HOANG
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 */
public class MultiPlatformLDA {

  public String dataPath;
  public String outputPath;
  public int nTopics;
  public int nPlatforms;
  public int modelType;
  public int burningPeriod;
  public int maxIteration;
  public int samplingGap;
  public int testBatch;
  public Random rand;

  public boolean toOutputLikelihoodPerplexity; // option to output likelihood of train data and perplexity of test data
  public boolean toOutputTopicTopPosts; // option to output top posts of each topic
  public boolean toOutputInferedTopics; // option to output (all) posts' inferred topics
  public boolean toOutputInferedPlatforms; // option to output inferred platform for posts in test batch

  // hyperparameters
  private double alpha;
  private double sum_alpha;
  private double beta;
  private double sum_beta;
  private double[] gamma;
  private double sum_gamma;
  private double mu;
  private double sum_mu;

  // data
  private User[] users;
  private String[] vocabulary;

  // parameters
  private double[][] topics;
  private double[] backgroundTopic;
  private double[] coinBias;
  private double[][] globalTopicPlatformDistribution; // only used for model 2

  // Gibbs sampling variables

  // user-topic count
  private int[][] n_zu; // n_zu[z][u]: number of times topic z is observed in posts by user u
  private int[] sum_nzu; // sum_nzu[u]: total number of topics that are observed in posts by user u

  // topic-word count
  private int[][] n_wz; // n_wz[w][z]: number of times word w is generated by topic z in all posts
  private int[] sum_nwz; // sum_nwz[z]: total number of words that are generated by topic z
  private int[] n_wb; // n_wb[w]: number of times word w is generated by the background topic
  private int sum_nwb; // total number of words that are generated by the background topic

  // topic-coin count
  private int[] n_c; // n_c[c]: total number of words that are associated with coin c
  private int sum_nc;

  // user-topic-platform count: used by model-1
  private int[][][] n_pu; // n_pu[u][z][p]: number of times platform p is selected for a topic z observed in posts by user u
  private int[][] sum_npu; // sum_npu[u][z]: number of times topic z is observed in posts by user u (same as n_zu)

  // global-topic-platform count: used by model-2
  private int[][] n_p; // n_p[z][p]: number of times platform p is selected for a topic z observed in all posts
  private int[] sum_np; // sum_np[z]: number of times topic z is observed in all posts

  private int[][] final_n_zu;
  private int[] final_sum_nzu;
  private int[][] final_n_wz;
  private int[] final_sum_nwz;
  private int[] final_n_wb;
  private int final_sum_nwb;
  private int[] final_n_c;
  private int final_sum_nc;
  private int[][][] final_n_pu;
  private int[][] final_sum_npu;
  private int[][] final_n_p;
  private int[] final_sum_np;

  private double postLogLikelidhood;
  private double postContentLogLikelidhood;
  private double postLogPerplexity;
  private double postContentPerplexity;

  public void readData() {
    Scanner sc = null;
    BufferedReader br = null;
    String line = null;
    HashMap<String, Integer> userId2Index = null;
    HashMap<Integer, String> userIndex2Id = null;
    try {
      String folderName = dataPath + "/users";
      File postFolder = new File(folderName);
      // Read number of users
      int nUser = postFolder.listFiles().length;
      users = new User[nUser];
      userId2Index = new HashMap<String, Integer>(nUser);
      userIndex2Id = new HashMap<Integer, String>(nUser);
      int u = -1;
      // Read the posts from each user file
      for (File postFile : postFolder.listFiles()) {
        u++;
        users[u] = new User();
        // Read index of the user
        String userId = FilenameUtils.removeExtension(postFile.getName());
        userId2Index.put(userId, u);
        userIndex2Id.put(u, userId);
        users[u].userID = userId;
        // Read the number of posts from the user
        int nPost = 0;
        br = new BufferedReader(new FileReader(postFile.getAbsolutePath()));
        while (br.readLine() != null) {
          nPost++;
        }
        br.close();
        // Declare the number of posts from the user
        users[u].posts = new Post[nPost];
        // Read each of the posts
        br = new BufferedReader(new FileReader(postFile.getAbsolutePath()));
        int j = -1;
        while ((line = br.readLine()) != null) {
          j++;
          users[u].posts[j] = new Post();
          sc = new Scanner(line.toString());
          sc.useDelimiter(",");
          while (sc.hasNext()) {
            users[u].posts[j].postID = sc.next();
            users[u].posts[j].platform = sc.nextInt();
            users[u].posts[j].batch = sc.nextInt();
            // Read the words in each post
            String[] tokens = sc.next().toString().split(" ");
            users[u].posts[j].words = new int[tokens.length];
            for (int i = 0; i < tokens.length; i++) {
              users[u].posts[j].words[i] = Integer.parseInt(tokens[i]);
            }
          }
        }
        br.close();
      }
      // Read post vocabulary
      String vocabularyFileName = dataPath + "/vocabulary.csv";
      br = new BufferedReader(new FileReader(vocabularyFileName));
      int nPostWord = 0;
      while (br.readLine() != null) {
        nPostWord++;
      }
      br.close();
      vocabulary = new String[nPostWord];
      br = new BufferedReader(new FileReader(vocabularyFileName));
      while ((line = br.readLine()) != null) {
        String[] tokens = line.split(",");
        int index = Integer.parseInt(tokens[0]);
        vocabulary[index] = tokens[1];
      }
      br.close();
    } catch (Exception e) {
      System.out.println("Error in reading posts from file!");
      e.printStackTrace();
      System.exit(0);
    }
  }

  private void declareFinalCounts() {
    // Final array of user-topics
    final_n_zu = new int[nTopics][users.length];
    final_sum_nzu = new int[users.length];
    for (int u = 0; u < users.length; u++) {
      for (int z = 0; z < nTopics; z++) {
        final_n_zu[z][u] = 0;
      }
      final_sum_nzu[u] = 0;
    }
    // Final array of topic-words which are generated by topic z
    final_n_wz = new int[vocabulary.length][nTopics];
    final_sum_nwz = new int[nTopics];
    for (int z = 0; z < nTopics; z++) {
      for (int w = 0; w < vocabulary.length; w++) {
        final_n_wz[w][z] = 0;
      }
      final_sum_nwz[z] = 0;
    }
    // Final array of topic-words which are generated by the background topic
    final_n_wb = new int[vocabulary.length];
    for (int w = 0; w < vocabulary.length; w++) {
      final_n_wb[w] = 0;
    }
    final_sum_nwb = 0;
    // Final array of coins.
    // Default array size is 2, representing the 2 sides of the coin
    final_n_c = new int[2];
    for (int c = 0; c < 2; c++) {
      final_n_c[c] = 0;
    }
    final_sum_nc = 0;
    // Implementation for MultiPlatformLDA-UserSpecific
    if (modelType == ModelType.USER_SPECIFIC) {
      // Final array of user-topic-platforms
      final_n_pu = new int[users.length][nTopics][nPlatforms];
      final_sum_npu = new int[users.length][nTopics];
      for (int u = 0; u < users.length; u++) {
        for (int z = 0; z < nTopics; z++) {
          for (int p = 0; p < nPlatforms; p++) {
            final_n_pu[u][z][p] = 0;
          }
          final_sum_npu[u][z] = 0;
        }
      }
    } else {
      // Implementation for MultiPlatformLDA-Global
      // Final array of global-topic-platforms
      final_n_p = new int[nTopics][nPlatforms];
      final_sum_np = new int[nTopics];
      for (int z = 0; z < nTopics; z++) {
        for (int p = 0; p < nPlatforms; p++) {
          final_n_p[z][p] = 0;
        }
        final_sum_np[z] = 0;
      }
    }
  }

  private void initilize() {
    // Init coin, topic and platform for each post
    for (int u = 0; u < users.length; u++) {
      for (int j = 0; j < users[u].posts.length; j++) {
        // Check that this is only for training
        if (users[u].posts[j].batch != testBatch) {
          // Randomly assign topic to the post
          users[u].posts[j].topic = rand.nextInt(nTopics);
          // Declare size of coins
          int nWords = users[u].posts[j].words.length;
          users[u].posts[j].coins = new int[nWords];
          // Randomly assign coin to word
          for (int i = 0; i < users[u].posts[j].coins.length; i++)
            users[u].posts[j].coins[i] = rand.nextInt(2);
        }
      }
    }
    // Declare and initiate counting tables
    // Init for user-topics
    n_zu = new int[nTopics][users.length];
    sum_nzu = new int[users.length];
    for (int u = 0; u < users.length; u++) {
      for (int z = 0; z < nTopics; z++) {
        n_zu[z][u] = 0;
      }
      sum_nzu[u] = 0;
    }
    // Init for topic-word for topic z
    n_wz = new int[vocabulary.length][nTopics];
    sum_nwz = new int[nTopics];
    for (int z = 0; z < nTopics; z++) {
      for (int w = 0; w < vocabulary.length; w++) {
        n_wz[w][z] = 0;
      }
      sum_nwz[z] = 0;
    }
    // Init for topic-word for background topic
    n_wb = new int[vocabulary.length];
    for (int w = 0; w < vocabulary.length; w++) {
      n_wb[w] = 0;
    }
    sum_nwb = 0;
    // Init for topic-coin
    n_c = new int[2];
    n_c[0] = 0;
    n_c[1] = 0;
    sum_nc = 0;
    // Implementation for MultiPlatformLDA-UserSpecific
    if (modelType == ModelType.USER_SPECIFIC) {
      // Init for user-topic-platform
      n_pu = new int[users.length][nTopics][nPlatforms];
      sum_npu = new int[users.length][nTopics];
      for (int u = 0; u < users.length; u++) {
        for (int z = 0; z < nTopics; z++) {
          for (int p = 0; p < nPlatforms; p++) {
            n_pu[u][z][p] = 0;
          }
          sum_npu[u][z] = 0;
        }
      }
    } else {
      // Implementation for MultiPlatformLDA-Global
      // Init for global-topic-platform
      n_p = new int[nTopics][nPlatforms];
      sum_np = new int[nTopics];
      for (int z = 0; z < nTopics; z++) {
        for (int p = 0; p < nPlatforms; p++) {
          n_p[z][p] = 0;
        }
        sum_np[z] = 0;
      }
    }
    // Update counting tables
    for (int u = 0; u < users.length; u++) {
      for (int j = 0; j < users[u].posts.length; j++) {
        // Training batch
        if (users[u].posts[j].batch != testBatch) {
          int z = users[u].posts[j].topic;
          // Update user-topic counts
          n_zu[z][u]++;
          sum_nzu[u]++;
          // Implementation for MultiPlatformLDA-UserSpecific
          if (modelType == ModelType.USER_SPECIFIC) {
            // Update user-topic-platform
            int p = users[u].posts[j].platform;
            n_pu[u][z][p]++;
            sum_npu[u][z]++;
          } else {
            // Implementation for MultiPlatformLDA-Global
            // Update global-topic-platform
            int p = users[u].posts[j].platform;
            n_p[z][p]++;
            sum_np[z]++;
          }
          for (int i = 0; i < users[u].posts[j].words.length; i++) {
            int w = users[u].posts[j].words[i];
            int c = users[u].posts[j].coins[i];
            // Update coin count
            n_c[c]++;
            sum_nc++;
            if (c == 0) {
              // Update background topic word count
              n_wb[w]++;
              sum_nwb++;
            } else {
              // Update topic z word count
              n_wz[w][z]++;
              sum_nwz[z]++;
            }
          }
        }
      }
    }
  }

  private void setPriors() {
    // User topic prior
    alpha = 100.0 / nTopics;
    sum_alpha = 100;
    // Topic platform prior
    // mu = 0.1;
    // sum_mu = 0.1 * nPlatforms;
    mu = 30.0 / nPlatforms;
    sum_mu = 30;
    // Topic word prior
    beta = 0.01;
    sum_beta = 0.01 * vocabulary.length;
    // Biased coin prior
    gamma = new double[2];
    gamma[0] = 2;
    gamma[1] = 2;
    sum_gamma = gamma[0] + gamma[1];
  }

  // Sample the topic for post number j of user number u
  private void samplePostTopic(int u, int j) {
    // Implementation for MultiPlatformLDA-UserSpecific
    if (modelType == ModelType.USER_SPECIFIC) {
      // Get current topic
      int currz = users[u].posts[j].topic;
      // Decrement user-topic count and sum
      n_zu[currz][u]--;
      sum_nzu[u]--;
      // Decrement user-topic-platform count and sum
      int currp = users[u].posts[j].platform;
      n_pu[u][currz][currp]--;
      sum_npu[u][currz]--;
      for (int i = 0; i < users[u].posts[j].words.length; i++) {
        // Only consider words belonging to the topic, not the background
        if (users[u].posts[j].coins[i] == 1) {
          // Decrement topic-word count and sum
          int w = users[u].posts[j].words[i];
          n_wz[w][currz]--;
          sum_nwz[currz]--;
        }
      }
      double sump = 0;
      // p: p(z_u,s = z | rest)
      double[] p = new double[nTopics];
      for (int z = 0; z < nTopics; z++) {
        // User-topic
        p[z] = (n_zu[z][u] + alpha) / (sum_nzu[u] + sum_alpha);
        // User-topic-platform
        p[z] = p[z] * (n_pu[u][z][currp] + mu) / (sum_npu[u][z] + sum_mu);
        // Topic-word
        for (int i = 0; i < users[u].posts[j].words.length; i++) {
          // Only consider words belonging to the topic, not the background
          if (users[u].posts[j].coins[i] == 1) {
            int w = users[u].posts[j].words[i];
            p[z] = p[z] * (n_wz[w][z] + beta) / (sum_nwz[z] + sum_beta);
          }
        }
        // cumulative
        p[z] = sump + p[z];
        sump = p[z];
      }
      sump = rand.nextDouble() * sump;
      for (int z = 0; z < nTopics; z++) {
        if (sump > p[z])
          continue;
        // Sample topic
        users[u].posts[j].topic = z;
        // Increment user-topic count and sum
        n_zu[z][u]++;
        sum_nzu[u]++;
        // Increment user-topic-platform count and sum
        n_pu[u][z][currp]++;
        sum_npu[u][z]++;
        // Increment topic-word count and sum
        for (int i = 0; i < users[u].posts[j].words.length; i++) {
          // Only consider words belonging to the topic, not the background
          if (users[u].posts[j].coins[i] == 1) {
            int w = users[u].posts[j].words[i];
            n_wz[w][z]++;
            sum_nwz[z]++;
          }
        }
        return;
      }
      System.out.println("bug in samplePostTopic");
      for (int z = 0; z < nTopics; z++) {
        System.out.print(p[z] + " ");
      }
      System.exit(-1);
    } else {
      // Implementation for MultiPlatformLDA-Global
      // Get current topic
      int currz = users[u].posts[j].topic;
      // Decrement user-topic count and sum
      n_zu[currz][u]--;
      sum_nzu[u]--;
      // Decrement global-topic-platform count and sum
      int currp = users[u].posts[j].platform;
      n_p[currz][currp]--;
      sum_np[currz]--;
      for (int i = 0; i < users[u].posts[j].words.length; i++) {
        // Only consider words belonging to the topic, not the background
        if (users[u].posts[j].coins[i] == 1) {
          // Decrement topic-word count and sum
          int w = users[u].posts[j].words[i];
          n_wz[w][currz]--;
          sum_nwz[currz]--;
        }
      }
      double sump = 0;
      // p: p(z_u,s = z | rest)
      double[] p = new double[nTopics];
      for (int z = 0; z < nTopics; z++) {
        // User-topic
        p[z] = (n_zu[z][u] + alpha) / (sum_nzu[u] + sum_alpha);
        // Global-topic-platform
        p[z] = p[z] * (n_p[z][currp] + mu) / (sum_np[z] + sum_mu);
        // Topic-word
        for (int i = 0; i < users[u].posts[j].words.length; i++) {
          // Only consider words belonging to the topic, not the background
          if (users[u].posts[j].coins[i] == 1) {
            int w = users[u].posts[j].words[i];
            p[z] = p[z] * (n_wz[w][z] + beta) / (sum_nwz[z] + sum_beta);
          }
        }
        // cumulative
        p[z] = sump + p[z];
        sump = p[z];
      }
      sump = rand.nextDouble() * sump;
      for (int z = 0; z < nTopics; z++) {
        if (sump > p[z])
          continue;
        // Sample topic
        users[u].posts[j].topic = z;
        // Increment user-topic count and sum
        n_zu[z][u]++;
        sum_nzu[u]++;
        // Increment global-topic-platform count and sum
        n_p[z][currp]++;
        sum_np[z]++;
        // Increment topic-word count and sum
        for (int i = 0; i < users[u].posts[j].words.length; i++) {
          // Only consider words belonging to the topic, not the background
          if (users[u].posts[j].coins[i] == 1) {
            int w = users[u].posts[j].words[i];
            n_wz[w][z]++;
            sum_nwz[z]++;
          }
        }
        return;
      }
      System.out.println("bug in samplePostTopic");
      for (int z = 0; z < nTopics; z++) {
        System.out.print(p[z] + " ");
      }
      System.exit(-1);
    }
  }

  // Sample the coin for word number i of post number j of user number u
  private void sampleWordCoin(int u, int j, int i) {
    // Get current coin
    int currc = users[u].posts[j].coins[i];
    // Get current word
    int w = users[u].posts[j].words[i];
    // Get current topic
    int z = users[u].posts[j].topic;
    // Decrement topic-coin count and sum
    n_c[currc]--;
    sum_nc--;
    if (currc == 0) {
      // Decrement topic-word for background topic count and sum
      n_wb[w]--;
      sum_nwb--;
    } else {
      // Decrement topic-word for topic z count and sum
      n_wz[w][z]--;
      sum_nwz[z]--;
    }
    // Probability of coin 0 given priors and recent counts
    double p_0 = (n_c[0] + gamma[0]) / (sum_nc + sum_gamma);
    // Probability of w given coin 0
    p_0 = p_0 * (n_wb[w] + beta) / (sum_nwb + sum_beta);
    // Probability of coin 1 given priors and recent counts
    double p_1 = (n_c[1] + gamma[1]) / (sum_nc + sum_gamma);
    // Probability of w given coin 1 and topic z
    p_1 = p_1 * (n_wz[w][z] + beta) / (sum_nwz[z] + sum_beta);
    double sump = p_0 + p_1;
    sump = rand.nextDouble() * sump;
    int c = 0;
    if (sump > p_0)
      c = 1;
    // Increment topic-coin count and sum
    users[u].posts[j].coins[i] = c;
    n_c[c]++;
    sum_nc++;
    if (c == 0) {
      // Increment topic-word for background topic count and sum
      n_wb[w]++;
      sum_nwb++;
    } else {
      // Increment topic-word for topic z count and sum
      n_wz[w][z]++;
      sum_nwz[z]++;
    }
  }

  private void updateFinalCounts() {
    for (int u = 0; u < users.length; u++) {
      for (int z = 0; z < nTopics; z++) {
        final_n_zu[z][u] += n_zu[z][u];
      }
      final_sum_nzu[u] += sum_nzu[u];
    }
    // Implementation for MultiPlatformLDA-UserSpecific
    if (modelType == ModelType.USER_SPECIFIC) {
      for (int u = 0; u < users.length; u++) {
        for (int z = 0; z < nTopics; z++) {
          for (int p = 0; p < nPlatforms; p++) {
            final_n_pu[u][z][p] += n_pu[u][z][p];
          }
          final_sum_npu[u][z] += sum_npu[u][z];
        }
      }
    } else {
      // Implementation for MultiPlatformLDA-Global
      for (int z = 0; z < nTopics; z++) {
        for (int p = 0; p < nPlatforms; p++) {
          final_n_p[z][p] += n_p[z][p];
        }
        final_sum_np[z] += sum_np[z];
      }
    }
    for (int z = 0; z < nTopics; z++) {
      for (int w = 0; w < vocabulary.length; w++) {
        final_n_wz[w][z] += n_wz[w][z];
      }
      final_sum_nwz[z] += sum_nwz[z];
    }
    for (int w = 0; w < vocabulary.length; w++) {
      final_n_wb[w] += n_wb[w];
    }
    final_sum_nwb += sum_nwb;
    for (int c = 0; c < 2; c++) {
      final_n_c[c] += n_c[c];
    }
    final_sum_nc += sum_nc;
  }

  private void gibbsSampling() {
    System.out.println("Running Gibbs sampling");
    System.out.print("Setting priors ...");
    setPriors();
    System.out.println("Done!");
    declareFinalCounts();
    System.out.print("Initializing ... ");
    initilize();
    System.out.println("Done!");
    for (int iter = 0; iter < burningPeriod + maxIteration; iter++) {
      System.out.print("iteration " + iter);
      // topic
      for (int u = 0; u < users.length; u++) {
        for (int t = 0; t < users[u].posts.length; t++) {
          if (users[u].posts[t].batch != testBatch) {
            samplePostTopic(u, t);
          }
        }
      }
      // coin
      for (int u = 0; u < users.length; u++) {
        for (int t = 0; t < users[u].posts.length; t++) {
          if (users[u].posts[t].batch != testBatch) {
            for (int i = 0; i < users[u].posts[t].words.length; i++)
              sampleWordCoin(u, t, i);
          }
        }
      }
      System.out.println(" done!");
      if (samplingGap <= 0)
        continue;
      if (iter < burningPeriod)
        continue;
      if ((iter - burningPeriod) % samplingGap == 0) {
        updateFinalCounts();
      }
    }
    if (samplingGap <= 0)
      updateFinalCounts();
  }

  private void inferingModelParameters() {
    // User-topic distribution
    for (int u = 0; u < users.length; u++) {
      users[u].topicDistribution = new double[nTopics];
      for (int z = 0; z < nTopics; z++) {
        users[u].topicDistribution[z] = (final_n_zu[z][u] + alpha) / (final_sum_nzu[u] + sum_alpha);
      }
    }
    // Implementation for MultiPlatformLDA-UserSpecific
    if (modelType == ModelType.USER_SPECIFIC) {
      // User-topic-platform distribution
      for (int u = 0; u < users.length; u++) {
        users[u].topicPlatformDistribution = new double[nTopics][nPlatforms];
        for (int z = 0; z < nTopics; z++) {
          for (int p = 0; p < nPlatforms; p++) {
            users[u].topicPlatformDistribution[z][p] = (final_n_pu[u][z][p] + mu) / (final_sum_npu[u][z] + sum_mu);
          }
        }
      }
    } else {
      // Implementation for MultiPlatformLDA-Global
      // Global-topic-platform distribution
      globalTopicPlatformDistribution = new double[nTopics][nPlatforms];
      for (int z = 0; z < nTopics; z++) {
        for (int p = 0; p < nPlatforms; p++) {
          globalTopicPlatformDistribution[z][p] = (final_n_p[z][p] + mu) / (final_sum_np[z] + sum_mu);
        }
      }
    }
    // Topic-word distribution
    topics = new double[nTopics][vocabulary.length];
    for (int z = 0; z < nTopics; z++) {
      for (int w = 0; w < vocabulary.length; w++)
        topics[z][w] = (final_n_wz[w][z] + beta) / (final_sum_nwz[z] + sum_beta);
    }
    // Background-word distribution
    backgroundTopic = new double[vocabulary.length];
    for (int w = 0; w < vocabulary.length; w++)
      backgroundTopic[w] = (final_n_wb[w] + beta) / (final_sum_nwb + sum_beta);
    // Topic-coin
    coinBias = new double[2];
    coinBias[0] = (final_n_c[0] + gamma[0]) / (final_sum_nc + sum_gamma);
    coinBias[1] = (final_n_c[1] + gamma[1]) / (final_sum_nc + sum_gamma);
  }

  public void learnModel() {
    gibbsSampling();
    inferingModelParameters();
    inferPostTopic();
    getLikelihoodPerplexity();
  }

  private double getPostLikelihood(int u, int j) {
    // Compute likelihood of post number j of user number u
    // content
    double content_LogLikelihood = 0;
    for (int i = 0; i < users[u].posts[j].words.length; i++) {
      int w = users[u].posts[j].words[i];
      // probability that word i is generated by background topic
      double p_0 = backgroundTopic[w] * coinBias[0];
      // probability that word i is generated by other topics
      double p_1 = 0;
      for (int z = 0; z < nTopics; z++) {
        double p_z = topics[z][w] * users[u].topicDistribution[z];
        p_1 = p_1 + p_z;
      }
      p_1 = p_1 * coinBias[1];
      content_LogLikelihood += Math.log10(p_0 + p_1);
    }
    // platform
    int p = users[u].posts[j].platform;
    double p_Platform = 0;
    for (int z = 0; z < nTopics; z++) {
      if (modelType == ModelType.USER_SPECIFIC) {
        p_Platform += users[u].topicDistribution[z] * users[u].topicPlatformDistribution[z][p];
      } else {
        p_Platform += users[u].topicDistribution[z] *
            globalTopicPlatformDistribution[z][p];
      }
    }
    return content_LogLikelihood + Math.log10(p_Platform);
  }

  private double getPostLikelihood(int u, int j, int z) {
    // Compute likelihood of post number j of user number u given the topic z
    if (z >= 0) {
      double content_logLikelihood = 0;
      for (int i = 0; i < users[u].posts[j].words.length; i++) {
        int w = users[u].posts[j].words[i];
        // Probability that word i is generated by background topic
        double p_0 = backgroundTopic[w] * coinBias[0];
        // Probability that word i is generated by topic z
        double p_1 = topics[z][w] * coinBias[1];
        content_logLikelihood += Math.log10(p_0 + p_1);
      }
      double platform_logLikelihood = 0;
      int p = users[u].posts[j].platform;
      if (modelType == ModelType.USER_SPECIFIC) {
        platform_logLikelihood = Math.log10(users[u].topicPlatformDistribution[z][p]);
      } else {
        platform_logLikelihood = Math.log10(globalTopicPlatformDistribution[z][p]);
      }
      return (content_logLikelihood + platform_logLikelihood);
    } else { // background topic only
      double content_logLikelihood = 0;
      for (int i = 0; i < users[u].posts[j].words.length; i++) {
        int w = users[u].posts[j].words[i];
        // Probability that word i is generated by background topic
        double p_0 = backgroundTopic[w];
        content_logLikelihood = content_logLikelihood + Math.log10(p_0);
      }
      double platform_logLikelihood = Math.log10(1.0 / nPlatforms); // random
      return (content_logLikelihood + platform_logLikelihood);
    }
  }

  private double getPostContentLikelihood(int u, int j) {
    // Compute likelihood of the content of post number j of user number u
    double content_LogLikelihood = 0;
    for (int i = 0; i < users[u].posts[j].words.length; i++) {
      int w = users[u].posts[j].words[i];
      // probability that word i is generated by background topic
      double p_0 = backgroundTopic[w] * coinBias[0];
      // probability that word i is generated by other topics
      double p_1 = 0;
      for (int z = 0; z < nTopics; z++) {
        double p_z = topics[z][w] * users[u].topicDistribution[z];
        p_1 = p_1 + p_z;
      }
      p_1 = p_1 * coinBias[1];
      content_LogLikelihood += Math.log10(p_0 + p_1);
    }
    return content_LogLikelihood;
  }

  private double getPostLikelihood(int u, int j, int p, int z) {
    // Compute likelihood of post number j of user number u given the platform is p and the topic is z
    if (z >= 0) {
      double content_logLikelihood = 0;
      for (int i = 0; i < users[u].posts[j].words.length; i++) {
        int w = users[u].posts[j].words[i];
        // Probability that word i is generated by background topic
        double p_0 = backgroundTopic[w] * coinBias[0];
        // Probability that word i is generated by topic z
        double p_1 = topics[z][w] * coinBias[1];
        content_logLikelihood += Math.log10(p_0 + p_1);
      }
      double platform_logLikelihood = 0;
      if (modelType == ModelType.USER_SPECIFIC) {
        platform_logLikelihood = Math.log10(users[u].topicPlatformDistribution[z][p]);
      } else {
        platform_logLikelihood = Math.log10(globalTopicPlatformDistribution[z][p]);
      }
      return (content_logLikelihood + platform_logLikelihood);
    } else { // background topic only
      double content_logLikelihood = 0;
      for (int i = 0; i < users[u].posts[j].words.length; i++) {
        int w = users[u].posts[j].words[i];
        // Probability that word i is generated by background topic
        double p_0 = backgroundTopic[w];
        content_logLikelihood = content_logLikelihood + Math.log10(p_0);
      }
      double platform_logLikelihood = Math.log10(1.0 / nPlatforms); // random
      return (content_logLikelihood + platform_logLikelihood);
    }
  }

  private void getLikelihoodPerplexity() {
    postLogLikelidhood = 0;
    postContentLogLikelidhood = 0;
    postLogPerplexity = 0;
    postContentPerplexity = 0;
    int nTestPost = 0;
    for (int u = 0; u < users.length; u++) {
      for (int j = 0; j < users[u].posts.length; j++) {
        double logLikelihood = getPostLikelihood(u, j);
        double logContentLikelihood = getPostContentLikelihood(u, j);
        if (users[u].posts[j].batch != testBatch) {
          postLogLikelidhood += logLikelihood;
          postContentLogLikelidhood += logContentLikelihood;
        } else {
          postLogPerplexity += (-logLikelihood);
          postContentPerplexity += (-logContentLikelihood);
          nTestPost++;
        }
      }
    }
    postLogPerplexity /= nTestPost;
    postContentPerplexity /= nTestPost;
  }

  private void inferPostTopic() {
    for (int u = 0; u < users.length; u++) {
      for (int j = 0; j < users[u].posts.length; j++) {
        users[u].posts[j].inferedTopic = -1; // background topic only
        users[u].posts[j].inferedLikelihood = users[u].posts.length * Math.log10(coinBias[0]) + getPostLikelihood(u, j, -1);
        for (int z = 0; z < nTopics; z++) {
          double p_z = getPostLikelihood(u, j, z);
          p_z += Math.log10(users[u].topicDistribution[z]);
          if (users[u].posts[j].inferedLikelihood < p_z) {
            users[u].posts[j].inferedLikelihood = p_z;
            users[u].posts[j].inferedTopic = z;
          }
        }
      }
    }
  }

  private void inferPostPlatform() {
    for (int u = 0; u < users.length; u++) {
      for (int j = 0; j < users[u].posts.length; j++) {
        if (users[u].posts[j].batch != testBatch)
          continue;
        double maxLikelihood = getPostLikelihood(u, j, 0, 0);
        users[u].posts[j].inferedPlatform = 0;
        for (int p = 0; p < nPlatforms; p++) {
          for (int z = 0; z < nTopics; z++) {
            double likeLihood = getPostLikelihood(u, j, p, z);
            if (maxLikelihood < likeLihood) {
              maxLikelihood = likeLihood;
              users[u].posts[j].inferedPlatform = p;
            }
          }
        }
      }
    }
  }

  private void outputPostTopicWordDisributions() {
    try {
      String filename = outputPath + "/topicWordDistributions.csv";
      BufferedWriter bw = new BufferedWriter(new FileWriter(filename));
      for (int z = 0; z < nTopics; z++) {
        bw.write("" + z);
        for (int w = 0; w < vocabulary.length; w++)
          bw.write("," + topics[z][w]);
        bw.write("\n");
      }
      bw.close();
      filename = outputPath + "/backgroundTopicWordDistribution.csv";
      bw = new BufferedWriter(new FileWriter(filename));
      bw.write(backgroundTopic[0] + "");
      for (int w = 1; w < vocabulary.length; w++)
        bw.write("," + backgroundTopic[w]);
      bw.close();
    } catch (Exception e) {
      System.out.println("Error in writing out topics to file!");
      e.printStackTrace();
      System.exit(0);
    }
  }

  private void outputCoinBias() {
    try {
      String fileName = outputPath + "/coinBias.csv";
      File file = new File(fileName);
      if (!file.exists()) {
        file.createNewFile();
      }
      BufferedWriter bw = new BufferedWriter(new FileWriter(file.getAbsoluteFile()));
      bw.write(coinBias[0] + "," + coinBias[1]);
      bw.close();
    } catch (Exception e) {
      System.out.println("Error in writing out coin bias to file!");
      e.printStackTrace();
      System.exit(0);
    }
  }

  private void outputPostTopicTopWords(int k) {
    try {
      String fileName = outputPath + "/postTopicTopWords.csv";
      File file = new File(fileName);
      if (!file.exists()) {
        file.createNewFile();
      }
      BufferedWriter bw = new BufferedWriter(new FileWriter(file.getAbsoluteFile()));
      RankingTool rankTool = new RankingTool();
      WeightedElement[] topWords = null;
      for (int z = 0; z < nTopics; z++) {
        bw.write(z + "\n");
        topWords = rankTool.getTopKbyWeight(vocabulary, topics[z], k);
        for (int j = 0; j < k; j++)
          bw.write("," + topWords[j].name + "," + topWords[j].weight + "\n");
      }
      bw.write("background\n");
      topWords = rankTool.getTopKbyWeight(vocabulary, backgroundTopic, 2 * k);
      for (int j = 0; j < 2 * k; j++)
        bw.write("," + topWords[j].name + "," + topWords[j].weight + "\n");
      bw.close();
    } catch (Exception e) {
System.out.println("Error in writing out post topic top words to file!"); e.printStackTrace(); System.exit(0); } } private void outputTopicTopPosts(int k) { int[] nTopicPosts = new int[nTopics]; int nBackgroundTopicPosts = 0; for (int z = 0; z < nTopics; z++) nTopicPosts[z] = 0; for (int u = 0; u < users.length; u++) { for (int j = 0; j < users[u].posts.length; j++) { if (users[u].posts[j].batch == testBatch) continue; if (users[u].posts[j].inferedTopic >= 0) nTopicPosts[users[u].posts[j].inferedTopic]++; else nBackgroundTopicPosts++; } } String[][] postID = new String[nTopics][]; double[][] postPerplexity = new double[nTopics][]; for (int z = 0; z < nTopics; z++) { postID[z] = new String[nTopicPosts[z]]; postPerplexity[z] = new double[nTopicPosts[z]]; nTopicPosts[z] = 0; } String[] backgroundTopicPostID = new String[nBackgroundTopicPosts]; double[] backgroundTopicPostPerplexity = new double[nBackgroundTopicPosts]; nBackgroundTopicPosts = 0; for (int u = 0; u < users.length; u++) { for (int j = 0; j < users[u].posts.length; j++) { if (users[u].posts[j].batch == testBatch) continue; int z = users[u].posts[j].inferedTopic; if (z >= 0) { postID[z][nTopicPosts[z]] = users[u].posts[j].postID; postPerplexity[z][nTopicPosts[z]] = users[u].posts[j].inferedLikelihood / users[u].posts[j].words.length; nTopicPosts[z]++; } else { backgroundTopicPostID[nBackgroundTopicPosts] = users[u].posts[j].postID; backgroundTopicPostPerplexity[nBackgroundTopicPosts] = users[u].posts[j].inferedLikelihood; nBackgroundTopicPosts++; } } } try { String filename = outputPath + "/topicTopPosts.csv"; BufferedWriter bw = new BufferedWriter(new FileWriter(filename)); RankingTool rankTool = new RankingTool(); WeightedElement[] topPosts = null; for (int z = 0; z < nTopics; z++) { bw.write(z + "\n"); topPosts = rankTool.getTopKbyWeight(postID[z], postPerplexity[z], Math.min(k, nTopicPosts[z])); for (int j = 0; j < Math.min(k, nTopicPosts[z]); j++) bw.write("," + topPosts[j].name + "," + topPosts[j].weight + "\n"); } if (nBackgroundTopicPosts > 0) { bw.write("background\n"); topPosts = rankTool.getTopKbyWeight(backgroundTopicPostID, backgroundTopicPostPerplexity, Math.min(k, nBackgroundTopicPosts)); for (int j = 0; j < Math.min(k, nBackgroundTopicPosts); j++) bw.write("," + topPosts[j].name + "," + topPosts[j].weight + "\n"); } bw.close(); } catch (Exception e) { System.out.println("Error in writing out post topic top posts to file!"); e.printStackTrace(); System.exit(0); } } private void outputUserTopicDistribution() { try { String fileName = outputPath + "/userTopicDistributions.csv"; File file = new File(fileName); if (!file.exists()) { file.createNewFile(); } BufferedWriter bw = new BufferedWriter(new FileWriter(file.getAbsoluteFile())); for (int u = 0; u < users.length; u++) { bw.write("" + users[u].userID); for (int z = 0; z < nTopics; z++) bw.write("," + users[u].topicDistribution[z]); bw.write("\n"); } bw.close(); } catch (Exception e) { System.out.println("Error in writing out users' topic distribution to file!"); e.printStackTrace(); System.exit(0); } } private void outputUserSpecificTopicPlatformDistribution() { try { String fileName = outputPath + "/userTopicPlatformDistributions.csv"; File file = new File(fileName); if (!file.exists()) { file.createNewFile(); } BufferedWriter bw = new BufferedWriter(new FileWriter(file.getAbsoluteFile())); for (int u = 0; u < users.length; u++) { bw.write("" + users[u].userID); for (int z = 0; z < nTopics; z++) { for (int p = 0; p < nPlatforms; p++) { bw.write("," + 
            System.out.println(u + "," + z + "," + p + "," + users[u].topicPlatformDistribution[z][p]);
          }
        }
        bw.write("\n");
      }
      bw.close();
    } catch (Exception e) {
      System.out.println("Error in writing out user topic platform distributions to file!");
      e.printStackTrace();
      System.exit(0);
    }
  }

  private void outputGlobalTopicPlatformDistribution() {
    try {
      String filename = outputPath + "/globalTopicPlatformDistributions.csv";
      BufferedWriter bw = new BufferedWriter(new FileWriter(filename));
      for (int z = 0; z < nTopics; z++) {
        bw.write("" + z);
        for (int p = 0; p < nPlatforms; p++) {
          bw.write("," + globalTopicPlatformDistribution[z][p]);
        }
        bw.write("\n");
      }
      bw.close();
    } catch (Exception e) {
      System.out.println("Error in writing out global topic platform distributions to file!");
      e.printStackTrace();
      System.exit(0);
    }
  }

  private void outputLikelihoodPerplexity() {
    try {
      String fileName = outputPath + "/likelihood-perplexity.csv";
      File file = new File(fileName);
      if (!file.exists()) {
        file.createNewFile();
      }
      BufferedWriter bw = new BufferedWriter(new FileWriter(file.getAbsoluteFile()));
      bw.write("postLogLikelihood,postLogPerplexity,postContentLogLikelihood,postContentPerplexity\n");
      bw.write("" + postLogLikelidhood + "," + postLogPerplexity);
      bw.write("," + postContentLogLikelidhood + "," + postContentPerplexity);
      bw.close();
    } catch (Exception e) {
      System.out.println("Error in writing out likelihood and perplexity to file!");
      e.printStackTrace();
      System.exit(0);
    }
  }

  private void outputInferedTopic() {
    try {
      SystemTool.createFolder(outputPath, "inferedTopics");
      for (int u = 0; u < users.length; u++) {
        String filename = outputPath + SystemTool.pathSeparator + "inferedTopics" + SystemTool.pathSeparator + users[u].userID + ".txt";
        BufferedWriter bw = new BufferedWriter(new FileWriter(filename));
        for (int j = 0; j < users[u].posts.length; j++)
          bw.write(users[u].posts[j].postID + "\t" + users[u].posts[j].inferedTopic + "\n");
        bw.close();
      }
    } catch (Exception e) {
      System.out.println("Error in writing out post topics to file!");
      e.printStackTrace();
      System.exit(0);
    }
  }

  private void outputInferedPlatform() {
    try {
      SystemTool.createFolder(outputPath, "inferedPlatforms");
      for (int u = 0; u < users.length; u++) {
        String filename = outputPath + SystemTool.pathSeparator + "inferedPlatforms" + SystemTool.pathSeparator + users[u].userID + ".csv";
        BufferedWriter bw = new BufferedWriter(new FileWriter(filename));
        for (int j = 0; j < users[u].posts.length; j++) {
          if (users[u].posts[j].batch != testBatch)
            continue;
          bw.write(users[u].posts[j].postID + "," + users[u].posts[j].inferedPlatform + "," + users[u].posts[j].platform + "\n");
        }
        bw.close();
      }
    } catch (Exception e) {
      System.out.println("Error in writing out post platforms to file!");
      e.printStackTrace();
      System.exit(0);
    }
  }

  public void outputAll() {
    outputPostTopicWordDisributions();
    outputPostTopicTopWords(20);
    outputCoinBias();
    outputUserTopicDistribution();
    if (modelType == ModelType.USER_SPECIFIC) {
      outputUserSpecificTopicPlatformDistribution();
    } else {
      outputGlobalTopicPlatformDistribution();
    }
    if (toOutputTopicTopPosts || toOutputInferedTopics) {
      inferPostTopic();
      if (toOutputTopicTopPosts) {
        outputTopicTopPosts(100);
      }
      if (toOutputInferedTopics) {
        outputInferedTopic();
      }
    }
    if (toOutputInferedPlatforms) {
      inferPostPlatform();
      outputInferedPlatform();
    }
    if (toOutputLikelihoodPerplexity) {
      getLikelihoodPerplexity();
      outputLikelihoodPerplexity();
    }
  }
}
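The class above references a ModelType constant holder and the User and Post model classes, which are not part of this listing. Below is a minimal sketch that is consistent with how those types are used above; the field names are taken from the listing, but everything else (constant values, visibility, class layout) is an assumption rather than the original project's code.

    package Model;

    // Hypothetical sketches of the supporting types referenced by MultiPlatformLDA.
    class ModelType {
      // The listing only distinguishes USER_SPECIFIC from "everything else" (the global model),
      // so the concrete constant values here are assumptions.
      public static final int USER_SPECIFIC = 1; // model 1: per-user topic-platform distributions
      public static final int GLOBAL = 2;        // model 2: one shared topic-platform distribution
    }

    class Post {
      String postID;
      int platform;             // platform index in [0, nPlatforms)
      int batch;                // posts with batch == testBatch are held out as test data
      int[] words;              // word indices into the vocabulary
      int topic;                // topic currently assigned by Gibbs sampling
      int[] coins;              // per-word coin: 0 = background topic, 1 = the post's topic
      int inferedTopic;         // set by inferPostTopic(); -1 means background only
      double inferedLikelihood; // set by inferPostTopic()
      int inferedPlatform;      // set by inferPostPlatform()
    }

    class User {
      String userID;
      Post[] posts;
      double[] topicDistribution;
      double[][] topicPlatformDistribution; // only filled for the user-specific model
    }

A small driver shows how the pieces are typically wired together. The method and field names come from the class above; the paths, counts, seed, and flag values are placeholders to adapt to your own data.

    package Model;

    import java.util.Random;

    public class RunMultiPlatformLDA {
      public static void main(String[] args) {
        MultiPlatformLDA lda = new MultiPlatformLDA();
        // Input/output locations (placeholders; the output folder must already exist)
        lda.dataPath = "data";      // expects data/users/*.csv and data/vocabulary.csv
        lda.outputPath = "output";
        // Model configuration (placeholders)
        lda.modelType = ModelType.USER_SPECIFIC;
        lda.nTopics = 20;
        lda.nPlatforms = 3;
        lda.testBatch = 2;          // posts in this batch are excluded from training
        // Gibbs sampling schedule (placeholders)
        lda.burningPeriod = 100;
        lda.maxIteration = 500;
        lda.samplingGap = 10;
        lda.rand = new Random(1);
        // Output options
        lda.toOutputLikelihoodPerplexity = true;
        lda.toOutputTopicTopPosts = true;
        lda.toOutputInferedTopics = true;
        lda.toOutputInferedPlatforms = true;

        lda.readData();   // load users, posts and the vocabulary
        lda.learnModel(); // Gibbs sampling + parameter estimation
        lda.outputAll();  // write distributions, top words/posts, likelihood and perplexity
      }
    }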