Model.MultiPlatformLDA.java Source code

Java tutorial

Introduction

Here is the source code for Model.MultiPlatformLDA.java

Source

package Model;

import hoang.larc.tooler.RankingTool;
import hoang.larc.tooler.SystemTool;
import hoang.larc.tooler.WeightedElement;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Random;
import java.util.Scanner;

import org.apache.commons.io.FilenameUtils;

/* Copyright (c) 2016 Roy Ka-Wei LEE and Tuan-Anh HOANG
 * 
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * 
 */

public class MultiPlatformLDA {

    // Input folder: readData() reads dataPath + "/users" (one file per user)
    // and dataPath + "/vocabulary.csv".
    public String dataPath;
    // NOTE(review): not used in the visible part of this class; presumably the
    // directory for output files -- confirm.
    public String outputPath;
    // Number of (non-background) topics.
    public int nTopics;
    // Number of platforms; a post's platform value is used as an index into
    // tables of size nPlatforms.
    public int nPlatforms;
    // ModelType.USER_SPECIFIC uses per-user topic-platform counts (n_pu);
    // any other value uses global topic-platform counts (n_p).
    public int modelType;

    // Number of initial Gibbs iterations whose counts are never collected.
    public int burningPeriod;
    // Number of Gibbs iterations run after the burn-in period.
    public int maxIteration;
    // After burn-in, counts are collected every samplingGap iterations; when
    // samplingGap <= 0 only the state after the last iteration is collected.
    public int samplingGap;
    // Posts whose batch equals testBatch are excluded from training.
    public int testBatch;

    // Random source used for initialization and Gibbs sampling.
    public Random rand;

    // Option to output likelihood of train data and perplexity of test data.
    public boolean toOutputLikelihoodPerplexity;
    // Option to output top posts of each topic.
    public boolean toOutputTopicTopPosts;
    // Option to output (all) posts' inferred topics.
    public boolean toOutputInferedTopics;
    // Option to output inferred platform for posts in the test batch.
    public boolean toOutputInferedPlatforms;

    // Hyperparameters (values are assigned in setPriors()).
    // Per-topic pseudo-count for each user's topic distribution; sum_alpha is
    // its total over all topics.
    private double alpha;
    private double sum_alpha;
    // Per-word pseudo-count for topic and background word distributions;
    // sum_beta = beta * vocabulary.length.
    private double beta;
    private double sum_beta;
    // Coin pseudo-counts: gamma[0] for background words, gamma[1] for topic
    // words; sum_gamma = gamma[0] + gamma[1].
    private double[] gamma;
    private double sum_gamma;
    // Per-platform pseudo-count for topic-platform distributions; sum_mu is its
    // total over all platforms.
    private double mu;
    private double sum_mu;

    // Data loaded by readData().
    // All users, each holding its posts.
    private User[] users;
    // vocabulary[w]: the word with index w (from vocabulary.csv).
    private String[] vocabulary;

    // Parameters estimated in inferingModelParameters().
    // topics[z][w]: probability of word w under topic z.
    private double[][] topics;
    // backgroundTopic[w]: probability of word w under the background distribution.
    private double[] backgroundTopic;
    // coinBias[c]: probability of coin c (0 = background word, 1 = topic word).
    private double[] coinBias;
    // globalTopicPlatformDistribution[z][p]: probability of platform p given
    // topic z; only used when modelType != ModelType.USER_SPECIFIC.
    private double[][] globalTopicPlatformDistribution;

    // Gibbs sampling count tables (only training posts are counted).
    // n_zu[z][u]: number of posts by user u currently assigned topic z (each
    // post carries exactly one topic).
    private int[][] n_zu;
    // sum_nzu[u]: number of training posts by user u.
    private int[] sum_nzu;

    // n_wz[w][z]: number of occurrences of word w with coin 1 in posts assigned
    // topic z.
    private int[][] n_wz;
    // sum_nwz[z]: total number of coin-1 words in posts assigned topic z.
    private int[] sum_nwz;
    // n_wb[w]: number of occurrences of word w with coin 0 (background topic).
    private int[] n_wb;
    // sum_nwb: total number of coin-0 (background) words.
    private int sum_nwb;

    // n_c[c]: number of words, over all training posts, whose coin is c.
    private int[] n_c;
    // sum_nc: total number of words in training posts.
    private int sum_nc;

    // User-specific model (ModelType.USER_SPECIFIC) only.
    // n_pu[u][z][p]: number of posts by user u assigned topic z whose platform is p.
    private int[][][] n_pu;
    // sum_npu[u][z]: number of posts by user u assigned topic z (same values as
    // n_zu[z][u]).
    private int[][] sum_npu;

    // Global model (modelType != ModelType.USER_SPECIFIC) only.
    // n_p[z][p]: number of posts (all users) assigned topic z whose platform is p.
    private int[][] n_p;
    // sum_np[z]: number of posts (all users) assigned topic z.
    private int[] sum_np;

    // Counts accumulated over the collected Gibbs samples by updateFinalCounts();
    // same layout as the corresponding tables above.
    private int[][] final_n_zu;
    private int[] final_sum_nzu;
    private int[][] final_n_wz;
    private int[] final_sum_nwz;
    private int[] final_n_wb;
    private int final_sum_nwb;
    private int[] final_n_c;
    private int final_sum_nc;
    private int[][][] final_n_pu;
    private int[][] final_sum_npu;
    private int[][] final_n_p;
    private int[] final_sum_np;

    // NOTE(review): the following are not used in the visible part of this
    // class; presumably filled by getLikelihoodPerplexity() -- confirm.
    private double postLogLikelidhood;
    private double postContentLogLikelidhood;

    private double postLogPerplexity;
    private double postContentPerplexity;

    /**
     * Loads the corpus from {@link #dataPath}.
     * <p>
     * Each file in {@code dataPath/users} holds the posts of one user. The user id is the
     * file name without its extension. Every line is one post of the form
     * {@code postID,platform,batch,wordIndices}, where {@code wordIndices} are
     * space-separated vocabulary indices. {@code dataPath/vocabulary.csv} holds one
     * {@code index,word} pair per line.
     * <p>
     * On any failure the error is printed and the JVM exits with a non-zero status, so
     * calling scripts can detect the failure.
     */
    public void readData() {
        try {
            File postFolder = new File(dataPath + "/users");
            // listFiles() is called once so the array size and the iteration agree;
            // it returns null when the folder is missing or unreadable.
            File[] postFiles = postFolder.listFiles();
            if (postFiles == null) {
                throw new IOException("Cannot list user files in " + postFolder.getAbsolutePath());
            }

            users = new User[postFiles.length];
            for (int u = 0; u < postFiles.length; u++) {
                users[u] = readUserPosts(postFiles[u]);
            }

            vocabulary = readVocabularyFile(dataPath + "/vocabulary.csv");
        } catch (Exception e) {
            System.out.println("Error in reading post from file!");
            e.printStackTrace();
            // Non-zero status signals abnormal termination (consistent with samplePostTopic).
            System.exit(-1);
        }
    }

    /** Reads one user file. The user id is the file name without extension; there is one post per line. */
    private User readUserPosts(File postFile) throws IOException {
        User user = new User();
        user.userID = FilenameUtils.removeExtension(postFile.getName());
        String[] lines = readFileLines(postFile.getAbsolutePath());
        user.posts = new Post[lines.length];
        for (int j = 0; j < lines.length; j++) {
            user.posts[j] = parsePostLine(lines[j]);
        }
        return user;
    }

    /**
     * Parses one {@code postID,platform,batch,w1 w2 ...} line. A line with no tokens
     * yields a Post whose fields keep their defaults. Fields after the fourth are ignored.
     */
    private Post parsePostLine(String line) {
        Post post = new Post();
        Scanner sc = new Scanner(line);
        try {
            sc.useDelimiter(",");
            if (sc.hasNext()) {
                post.postID = sc.next();
                post.platform = sc.nextInt();
                post.batch = sc.nextInt();

                // Word indices of the post, separated by single spaces.
                String[] tokens = sc.next().split(" ");
                post.words = new int[tokens.length];
                for (int i = 0; i < tokens.length; i++) {
                    post.words[i] = Integer.parseInt(tokens[i]);
                }
            }
        } finally {
            sc.close();
        }
        return post;
    }

    /** Reads {@code index,word} lines into an array sized by the number of lines. */
    private String[] readVocabularyFile(String fileName) throws IOException {
        String[] lines = readFileLines(fileName);
        String[] words = new String[lines.length];
        for (String line : lines) {
            String[] tokens = line.split(",");
            words[Integer.parseInt(tokens[0])] = tokens[1];
        }
        return words;
    }

    /** Returns all lines of a file. The reader is closed even if reading fails. */
    private static String[] readFileLines(String fileName) throws IOException {
        List<String> lines = new ArrayList<String>();
        BufferedReader br = new BufferedReader(new FileReader(fileName));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                lines.add(line);
            }
        } finally {
            br.close();
        }
        return lines.toArray(new String[lines.size()]);
    }

    /**
     * Allocates the final_* accumulators that updateFinalCounts() adds sampled counts into.
     * <p>
     * New int arrays are zero-filled by the JVM, so no explicit clearing loops are needed.
     * Only the topic-platform accumulator of the selected model variant is allocated:
     * per-user (final_n_pu / final_sum_npu) for ModelType.USER_SPECIFIC, global
     * (final_n_p / final_sum_np) otherwise.
     */
    private void declareFinalCounts() {
        // User-topic counts
        final_n_zu = new int[nTopics][users.length];
        final_sum_nzu = new int[users.length];

        // Word counts generated by each topic z
        final_n_wz = new int[vocabulary.length][nTopics];
        final_sum_nwz = new int[nTopics];

        // Word counts generated by the background topic
        final_n_wb = new int[vocabulary.length];
        final_sum_nwb = 0;

        // Coin counts: two sides, 0 = background word, 1 = topic word
        final_n_c = new int[2];
        final_sum_nc = 0;

        if (modelType == ModelType.USER_SPECIFIC) {
            // MultiPlatformLDA-UserSpecific: user-topic-platform counts
            final_n_pu = new int[users.length][nTopics][nPlatforms];
            final_sum_npu = new int[users.length][nTopics];
        } else {
            // MultiPlatformLDA-Global: global topic-platform counts
            final_n_p = new int[nTopics][nPlatforms];
            final_sum_np = new int[nTopics];
        }
    }

    /**
     * Randomly initializes the latent assignments of all training posts (posts whose
     * batch is not testBatch) and builds the count tables from them.
     * <p>
     * Each training post gets a uniformly random topic in [0, nTopics), and each of its
     * words gets a uniformly random coin (0 = background word, 1 = topic word). Posts in
     * the test batch are neither assigned nor counted.
     */
    private void initilize() {
        // Allocate the count tables; the JVM zero-fills new int arrays.
        n_zu = new int[nTopics][users.length];
        sum_nzu = new int[users.length];
        n_wz = new int[vocabulary.length][nTopics];
        sum_nwz = new int[nTopics];
        n_wb = new int[vocabulary.length];
        sum_nwb = 0;
        n_c = new int[2];
        sum_nc = 0;
        boolean userSpecific = modelType == ModelType.USER_SPECIFIC;
        if (userSpecific) {
            n_pu = new int[users.length][nTopics][nPlatforms];
            sum_npu = new int[users.length][nTopics];
        } else {
            n_p = new int[nTopics][nPlatforms];
            sum_np = new int[nTopics];
        }

        // Random draws are made post by post (the topic, then each word's coin), so a
        // fixed seed reproduces the same initial assignments.
        for (int u = 0; u < users.length; u++) {
            for (Post post : users[u].posts) {
                if (post.batch == testBatch) {
                    continue;
                }
                int z = rand.nextInt(nTopics);
                post.topic = z;
                post.coins = new int[post.words.length];
                for (int i = 0; i < post.coins.length; i++) {
                    post.coins[i] = rand.nextInt(2);
                }

                // User-topic counts
                n_zu[z][u]++;
                sum_nzu[u]++;

                // Topic-platform counts for the selected model variant
                int p = post.platform;
                if (userSpecific) {
                    n_pu[u][z][p]++;
                    sum_npu[u][z]++;
                } else {
                    n_p[z][p]++;
                    sum_np[z]++;
                }

                // Coin counts, plus word counts of the background (coin 0) or the post topic (coin 1)
                for (int i = 0; i < post.words.length; i++) {
                    int w = post.words[i];
                    int c = post.coins[i];
                    n_c[c]++;
                    sum_nc++;
                    if (c == 0) {
                        n_wb[w]++;
                        sum_nwb++;
                    } else {
                        n_wz[w][z]++;
                        sum_nwz[z]++;
                    }
                }
            }
        }
    }

    /**
     * Sets the fixed hyperparameters:
     * <ul>
     * <li>a total prior mass of 100 spread evenly over the topics (alpha);</li>
     * <li>a total of 30 spread evenly over the platforms (mu);</li>
     * <li>0.01 per vocabulary word (beta);</li>
     * <li>pseudo-counts (2, 2) on the background/topic coin (gamma).</li>
     * </ul>
     */
    private void setPriors() {
        final double totalAlpha = 100;
        final double totalMu = 30;
        final double betaPerWord = 0.01;

        // User-topic prior
        alpha = totalAlpha / nTopics;
        sum_alpha = totalAlpha;

        // Topic-platform prior
        mu = totalMu / nPlatforms;
        sum_mu = totalMu;

        // Topic-word prior
        beta = betaPerWord;
        sum_beta = betaPerWord * vocabulary.length;

        // Coin prior: index 0 = background word, index 1 = topic word
        gamma = new double[] { 2, 2 };
        sum_gamma = gamma[0] + gamma[1];
    }

    /**
     * Resamples the topic of post j of user u from its collapsed conditional.
     * <p>
     * The conditional is p(z | rest) proportional to
     * p(z | user) * p(platform | z) * product over coin-1 words of p(w | z).
     * The platform term uses the per-user table (n_pu) when
     * modelType == ModelType.USER_SPECIFIC and the global table (n_p) otherwise.
     * <p>
     * The weights are accumulated in log space and rescaled by their maximum before
     * exponentiating. For long posts the product of per-word probabilities can otherwise
     * underflow to 0 for every topic, and the draw would silently always pick topic 0.
     * <p>
     * If no topic can be selected (NaN weights or no topics), the cumulative weights are
     * printed and the JVM exits with status -1.
     */
    private void samplePostTopic(int u, int j) {
        Post post = users[u].posts[j];
        int currp = post.platform;

        // Remove the post's current topic assignment from all counts.
        updatePostTopicCounts(u, post, post.topic, -1);

        // Unnormalized log p(z | rest) for every topic.
        double[] logp = new double[nTopics];
        double maxLogp = Double.NEGATIVE_INFINITY;
        for (int z = 0; z < nTopics; z++) {
            double lp = Math.log((n_zu[z][u] + alpha) / (sum_nzu[u] + sum_alpha));
            lp += Math.log(platformTopicWeight(u, z, currp));
            for (int i = 0; i < post.words.length; i++) {
                // Only words generated by the post topic (coin 1) depend on z.
                if (post.coins[i] == 1) {
                    int w = post.words[i];
                    lp += Math.log((n_wz[w][z] + beta) / (sum_nwz[z] + sum_beta));
                }
            }
            logp[z] = lp;
            maxLogp = Math.max(maxLogp, lp);
        }

        // Cumulative weights, rescaled so that the largest term is exp(0) = 1.
        double[] cumulative = new double[nTopics];
        double sump = 0;
        for (int z = 0; z < nTopics; z++) {
            sump += Math.exp(logp[z] - maxLogp);
            cumulative[z] = sump;
        }

        double r = rand.nextDouble() * sump;
        for (int z = 0; z < nTopics; z++) {
            if (r <= cumulative[z]) {
                post.topic = z;
                updatePostTopicCounts(u, post, z, +1);
                return;
            }
        }

        // Reached only when the weights are NaN or there are no topics.
        System.out.println("bug in samplePostTopic");
        for (int z = 0; z < nTopics; z++) {
            System.out.print(cumulative[z] + " ");
        }
        System.exit(-1);
    }

    /**
     * Adds {@code delta} to every count that depends on post {@code post} of user u
     * having topic z. The counts are user-topic, topic-platform (per the model variant)
     * and topic-word for its coin-1 words.
     */
    private void updatePostTopicCounts(int u, Post post, int z, int delta) {
        n_zu[z][u] += delta;
        sum_nzu[u] += delta;

        int p = post.platform;
        if (modelType == ModelType.USER_SPECIFIC) {
            n_pu[u][z][p] += delta;
            sum_npu[u][z] += delta;
        } else {
            n_p[z][p] += delta;
            sum_np[z] += delta;
        }

        for (int i = 0; i < post.words.length; i++) {
            // Background (coin 0) words are not tied to the post topic.
            if (post.coins[i] == 1) {
                n_wz[post.words[i]][z] += delta;
                sum_nwz[z] += delta;
            }
        }
    }

    /** Smoothed p(platform p | topic z) from the current counts of the selected model variant. */
    private double platformTopicWeight(int u, int z, int p) {
        if (modelType == ModelType.USER_SPECIFIC) {
            return (n_pu[u][z][p] + mu) / (sum_npu[u][z] + sum_mu);
        }
        return (n_p[z][p] + mu) / (sum_np[z] + sum_mu);
    }

    /**
     * Resamples the coin of word i in post j of user u.
     * <p>
     * Coin 0 assigns the word to the background distribution and coin 1 to the post's
     * current topic. The coin is drawn with odds
     * p(coin 0) * p(word | background) : p(coin 1) * p(word | post topic),
     * computed from the counts with this word's own assignment removed.
     */
    private void sampleWordCoin(int u, int j, int i) {
        Post post = users[u].posts[j];
        int oldCoin = post.coins[i];
        int word = post.words[i];
        int topic = post.topic;

        // Withdraw the word's current assignment from the counts.
        n_c[oldCoin]--;
        sum_nc--;
        if (oldCoin != 0) {
            n_wz[word][topic]--;
            sum_nwz[topic]--;
        } else {
            n_wb[word]--;
            sum_nwb--;
        }

        double coinNorm = sum_nc + sum_gamma;
        double weightBackground = (n_c[0] + gamma[0]) / coinNorm * (n_wb[word] + beta) / (sum_nwb + sum_beta);
        double weightTopic = (n_c[1] + gamma[1]) / coinNorm * (n_wz[word][topic] + beta)
                / (sum_nwz[topic] + sum_beta);

        double draw = rand.nextDouble() * (weightBackground + weightTopic);
        int newCoin = draw > weightBackground ? 1 : 0;

        // Record the new assignment and add it back into the counts.
        post.coins[i] = newCoin;
        n_c[newCoin]++;
        sum_nc++;
        if (newCoin != 0) {
            n_wz[word][topic]++;
            sum_nwz[topic]++;
        } else {
            n_wb[word]++;
            sum_nwb++;
        }
    }

    /**
     * Adds the current Gibbs count tables to the final_* accumulators.
     * <p>
     * Called once per collected sample. inferingModelParameters() later turns the
     * accumulated counts into smoothed estimates. Every table is added exactly once per
     * call, so all accumulated counts have the same scale relative to the priors.
     */
    private void updateFinalCounts() {
        // User-topic counts
        for (int u = 0; u < users.length; u++) {
            for (int z = 0; z < nTopics; z++) {
                final_n_zu[z][u] += n_zu[z][u];
            }
            final_sum_nzu[u] += sum_nzu[u];
        }

        if (modelType == ModelType.USER_SPECIFIC) {
            // MultiPlatformLDA-UserSpecific: user-topic-platform counts
            for (int u = 0; u < users.length; u++) {
                for (int z = 0; z < nTopics; z++) {
                    for (int p = 0; p < nPlatforms; p++) {
                        final_n_pu[u][z][p] += n_pu[u][z][p];
                    }
                    final_sum_npu[u][z] += sum_npu[u][z];
                }
            }
        } else {
            // MultiPlatformLDA-Global: global topic-platform counts
            for (int z = 0; z < nTopics; z++) {
                for (int p = 0; p < nPlatforms; p++) {
                    final_n_p[z][p] += n_p[z][p];
                }
                final_sum_np[z] += sum_np[z];
            }
        }

        // Topic-word counts
        for (int z = 0; z < nTopics; z++) {
            for (int w = 0; w < vocabulary.length; w++) {
                final_n_wz[w][z] += n_wz[w][z];
            }
            final_sum_nwz[z] += sum_nwz[z];
        }

        // Background-word counts
        for (int w = 0; w < vocabulary.length; w++) {
            final_n_wb[w] += n_wb[w];
        }
        final_sum_nwb += sum_nwb;

        // Coin counts
        for (int c = 0; c < 2; c++) {
            final_n_c[c] += n_c[c];
        }
        final_sum_nc += sum_nc;
    }

    /**
     * Runs the collapsed Gibbs sampler on the training posts.
     * <p>
     * First sets the priors, allocates the accumulators and randomly initializes the
     * assignments. It then runs burningPeriod + maxIteration sweeps. Each sweep resamples
     * every training post's topic, then every training word's coin. After burn-in the
     * counts are accumulated every samplingGap sweeps; when samplingGap <= 0 only the
     * state after the final sweep is accumulated.
     */
    private void gibbsSampling() {
        System.out.println("Runing Gibbs sampling");
        System.out.print("Setting priors ...");
        setPriors();
        System.out.println("Done!");
        declareFinalCounts();
        System.out.print("Initializing ... ");
        initilize();
        System.out.println("Done!");

        final boolean collectPeriodically = samplingGap > 0;
        final int sweeps = burningPeriod + maxIteration;
        for (int iter = 0; iter < sweeps; iter++) {
            System.out.print("iteration " + iter);

            // Sweep 1: post topics
            for (int u = 0; u < users.length; u++) {
                Post[] posts = users[u].posts;
                for (int t = 0; t < posts.length; t++) {
                    if (posts[t].batch == testBatch) {
                        continue;
                    }
                    samplePostTopic(u, t);
                }
            }

            // Sweep 2: word coins
            for (int u = 0; u < users.length; u++) {
                Post[] posts = users[u].posts;
                for (int t = 0; t < posts.length; t++) {
                    if (posts[t].batch == testBatch) {
                        continue;
                    }
                    int nWords = posts[t].words.length;
                    for (int i = 0; i < nWords; i++) {
                        sampleWordCoin(u, t, i);
                    }
                }
            }

            System.out.println(" done!");

            boolean afterBurnIn = iter >= burningPeriod;
            if (collectPeriodically && afterBurnIn && (iter - burningPeriod) % samplingGap == 0) {
                updateFinalCounts();
            }
        }

        if (!collectPeriodically) {
            updateFinalCounts();
        }
    }

    /**
     * Turns the accumulated final_* counts into smoothed parameter estimates of the form
     * (count + prior) / (total + prior total). It computes:
     * <ul>
     * <li>each user's topic distribution;</li>
     * <li>the topic-platform distribution (per user for ModelType.USER_SPECIFIC, global
     * otherwise);</li>
     * <li>the topic-word and background-word distributions;</li>
     * <li>the coin bias.</li>
     * </ul>
     */
    private void inferingModelParameters() {
        // User-topic distributions
        for (int u = 0; u < users.length; u++) {
            double[] theta = new double[nTopics];
            double userTotal = final_sum_nzu[u] + sum_alpha;
            for (int z = 0; z < nTopics; z++) {
                theta[z] = (final_n_zu[z][u] + alpha) / userTotal;
            }
            users[u].topicDistribution = theta;
        }

        if (modelType == ModelType.USER_SPECIFIC) {
            // Per-user topic-platform distributions
            for (int u = 0; u < users.length; u++) {
                double[][] dist = new double[nTopics][nPlatforms];
                for (int z = 0; z < nTopics; z++) {
                    double topicTotal = final_sum_npu[u][z] + sum_mu;
                    for (int p = 0; p < nPlatforms; p++) {
                        dist[z][p] = (final_n_pu[u][z][p] + mu) / topicTotal;
                    }
                }
                users[u].topicPlatformDistribution = dist;
            }
        } else {
            // Global topic-platform distribution
            double[][] dist = new double[nTopics][nPlatforms];
            for (int z = 0; z < nTopics; z++) {
                double topicTotal = final_sum_np[z] + sum_mu;
                for (int p = 0; p < nPlatforms; p++) {
                    dist[z][p] = (final_n_p[z][p] + mu) / topicTotal;
                }
            }
            globalTopicPlatformDistribution = dist;
        }

        // Topic-word distributions
        double[][] phi = new double[nTopics][vocabulary.length];
        for (int z = 0; z < nTopics; z++) {
            double topicTotal = final_sum_nwz[z] + sum_beta;
            for (int w = 0; w < vocabulary.length; w++) {
                phi[z][w] = (final_n_wz[w][z] + beta) / topicTotal;
            }
        }
        topics = phi;

        // Background-word distribution
        double[] background = new double[vocabulary.length];
        double backgroundTotal = final_sum_nwb + sum_beta;
        for (int w = 0; w < vocabulary.length; w++) {
            background[w] = (final_n_wb[w] + beta) / backgroundTotal;
        }
        backgroundTopic = background;

        // Coin bias: index 0 = background word, index 1 = topic word
        double coinTotal = final_sum_nc + sum_gamma;
        coinBias = new double[] { (final_n_c[0] + gamma[0]) / coinTotal, (final_n_c[1] + gamma[1]) / coinTotal };
    }

    /**
     * Trains the model, then post-processes the result.
     * <p>
     * Runs Gibbs sampling and computes the parameter estimates from the collected counts.
     * It then calls inferPostTopic() and getLikelihoodPerplexity(), which are not visible
     * here.
     */
    public void learnModel() {
        gibbsSampling();
        inferingModelParameters();
        inferPostTopic();
        getLikelihoodPerplexity();
    }

    /**
     * Returns the log10-likelihood of post j of user u under the learned parameters.
     * <p>
     * This is the content term from getPostContentLikelihood(u, j) plus the log10 of the
     * platform probability. The platform probability is mixed over the user's topic
     * distribution using the per-user (ModelType.USER_SPECIFIC) or global topic-platform
     * table.
     * <p>
     * NOTE(review): content and platform are marginalized over topics separately, and the
     * content term per word, whereas the sampler assigns a single topic per post. Confirm
     * this approximation is intended for the reported likelihood/perplexity.
     */
    private double getPostLikelihood(int u, int j) {
        double contentLogLikelihood = getPostContentLikelihood(u, j);

        int p = users[u].posts[j].platform;
        double[] theta = users[u].topicDistribution;
        double[][] platformGivenTopic = (modelType == ModelType.USER_SPECIFIC)
                ? users[u].topicPlatformDistribution
                : globalTopicPlatformDistribution;

        double platformProbability = 0;
        for (int z = 0; z < nTopics; z++) {
            platformProbability += theta[z] * platformGivenTopic[z][p];
        }
        return contentLogLikelihood + Math.log10(platformProbability);
    }

    /**
     * Returns the log10-likelihood of post j of user u given topic z.
     * <p>
     * For z >= 0, each word is scored as the coin-weighted mixture of the background and
     * topic z. The log10 of p(post platform | z) is then added, taken from the per-user
     * table (ModelType.USER_SPECIFIC) or the global table.
     * <p>
     * For z < 0, the post is scored with the background distribution only, plus a uniform
     * platform probability 1 / nPlatforms.
     */
    private double getPostLikelihood(int u, int j, int z) {
        int[] words = users[u].posts[j].words;
        double logLikelihood = 0;

        if (z < 0) {
            // Background topic only
            for (int w : words) {
                logLikelihood += Math.log10(backgroundTopic[w]);
            }
            return logLikelihood + Math.log10(1.0 / nPlatforms);
        }

        for (int w : words) {
            logLikelihood += Math.log10(backgroundTopic[w] * coinBias[0] + topics[z][w] * coinBias[1]);
        }
        int p = users[u].posts[j].platform;
        double[][] platformGivenTopic = (modelType == ModelType.USER_SPECIFIC)
                ? users[u].topicPlatformDistribution
                : globalTopicPlatformDistribution;
        return logLikelihood + Math.log10(platformGivenTopic[z][p]);
    }

    /**
     * Returns the log10-likelihood of the words of post j of user u.
     * <p>
     * Each word's probability is coinBias[0] * p(w | background) plus coinBias[1] times
     * p(w | topic) mixed over the user's topic distribution.
     */
    private double getPostContentLikelihood(int u, int j) {
        double[] theta = users[u].topicDistribution;
        double total = 0;
        for (int w : users[u].posts[j].words) {
            double background = backgroundTopic[w] * coinBias[0];
            double topicMixture = 0;
            for (int z = 0; z < nTopics; z++) {
                topicMixture += topics[z][w] * theta[z];
            }
            total += Math.log10(background + topicMixture * coinBias[1]);
        }
        return total;
    }

    /**
     * Computes the log10-likelihood of post {@code j} of user {@code u} under the hypothesis that
     * it was published on platform {@code p} and is about topic {@code z}.
     *
     * <p>For {@code z >= 0} each word contributes
     * {@code log10(coinBias[0] * backgroundTopic[w] + coinBias[1] * topics[z][w])}, and the
     * platform term is read from the user-specific or global topic-platform distribution,
     * depending on {@code modelType}. For {@code z < 0} (background topic only) each word
     * contributes {@code log10(backgroundTopic[w])} with no coin-bias factor, the platform term
     * is uniform ({@code 1 / nPlatforms}), and {@code p} is ignored.
     *
     * @param u index of the user in {@code users}
     * @param j index of the post within {@code users[u].posts}
     * @param p candidate platform index
     * @param z topic index, or a negative value for "background topic only"
     * @return log10-likelihood of the post's words plus the platform term
     */
    private double getPostLikelihood(int u, int j, int p, int z) {
        if (z < 0) {
            double backgroundLog = 0;
            for (int w : users[u].posts[j].words) {
                backgroundLog = backgroundLog + Math.log10(backgroundTopic[w]);
            }
            // No topic, so every platform is treated as equally likely.
            return backgroundLog + Math.log10(1.0 / nPlatforms);
        }

        double contentLog = 0;
        for (int w : users[u].posts[j].words) {
            double fromBackground = backgroundTopic[w] * coinBias[0];
            double fromTopic = topics[z][w] * coinBias[1];
            contentLog += Math.log10(fromBackground + fromTopic);
        }

        double platformLog;
        if (modelType != ModelType.USER_SPECIFIC) {
            platformLog = Math.log10(globalTopicPlatformDistribution[z][p]);
        } else {
            platformLog = Math.log10(users[u].topicPlatformDistribution[z][p]);
        }
        return contentLog + platformLog;
    }

    /**
     * Computes log10-likelihood statistics over every post and stores them in fields.
     *
     * <p>Posts outside {@code testBatch} add their log-likelihoods to the training totals
     * {@code postLogLikelidhood} (from {@code getPostLikelihood(u, j)}) and
     * {@code postContentLogLikelidhood} (content only). Posts in {@code testBatch} add their
     * negated log-likelihoods to {@code postLogPerplexity} and {@code postContentPerplexity},
     * which are then averaged over the number of test posts.
     *
     * <p>NOTE(review): with no test posts the averages are 0.0 / 0 and therefore NaN.
     */
    private void getLikelihoodPerplexity() {
        postLogLikelidhood = 0;
        postContentLogLikelidhood = 0;
        postLogPerplexity = 0;
        postContentPerplexity = 0;
        int testPostCount = 0;
        for (int u = 0; u < users.length; u++) {
            int nPosts = users[u].posts.length;
            for (int j = 0; j < nPosts; j++) {
                double fullLog = getPostLikelihood(u, j);
                double contentLog = getPostContentLikelihood(u, j);
                boolean isTestPost = users[u].posts[j].batch == testBatch;
                if (isTestPost) {
                    postLogPerplexity -= fullLog;
                    postContentPerplexity -= contentLog;
                    testPostCount++;
                } else {
                    postLogLikelidhood += fullLog;
                    postContentLogLikelidhood += contentLog;
                }
            }
        }
        postLogPerplexity /= testPostCount;
        postContentPerplexity /= testPostCount;
    }

    /**
     * Infers the most likely topic of every post (training and test alike).
     *
     * <p>The background-only hypothesis scores each word as drawn from the background side of the
     * coin ({@code coinBias[0]} per word) and from {@code backgroundTopic}. Each topic
     * {@code z} is scored by {@code getPostLikelihood(u, j, z)} plus the log of the user's topic
     * proportion. The best-scoring hypothesis wins; ties keep the earlier one (background first).
     *
     * <p>Results are stored in {@code inferedTopic} ({@code -1} means background only) and
     * {@code inferedLikelihood} (the winning log10 score).
     */
    private void inferPostTopic() {
        double logBackgroundCoin = Math.log10(coinBias[0]);
        for (int u = 0; u < users.length; u++) {
            for (int j = 0; j < users[u].posts.length; j++) {
                // getPostLikelihood(u, j, -1) omits the coin factor, so add it once per WORD
                // of this post (previously this wrongly used the user's number of posts).
                int nWords = users[u].posts[j].words.length;
                double bestLikelihood = nWords * logBackgroundCoin + getPostLikelihood(u, j, -1);
                int bestTopic = -1;

                for (int z = 0; z < nTopics; z++) {
                    double p_z = getPostLikelihood(u, j, z) + Math.log10(users[u].topicDistribution[z]);
                    if (bestLikelihood < p_z) {
                        bestLikelihood = p_z;
                        bestTopic = z;
                    }
                }
                users[u].posts[j].inferedTopic = bestTopic;
                users[u].posts[j].inferedLikelihood = bestLikelihood;
            }
        }
    }

    /**
     * For every post in {@code testBatch}, infers the platform {@code p} whose best
     * (platform, topic) pair maximises {@code getPostLikelihood(u, j, p, z)}, and stores it in
     * {@code inferedPlatform}. The search is seeded with (platform 0, topic 0) and only a strictly
     * larger likelihood replaces the current best, so ties keep the earliest candidate.
     *
     * <p>NOTE(review): unlike {@code inferPostTopic()}, this does not add the user's topic prior
     * {@code log10(topicDistribution[z])}. Confirm that is intended.
     */
    private void inferPostPlatform() {
        for (int u = 0; u < users.length; u++) {
            for (int j = 0; j < users[u].posts.length; j++) {
                if (users[u].posts[j].batch == testBatch) {
                    int bestPlatform = 0;
                    double best = getPostLikelihood(u, j, 0, 0);
                    for (int p = 0; p < nPlatforms; p++) {
                        for (int z = 0; z < nTopics; z++) {
                            double candidate = getPostLikelihood(u, j, p, z);
                            if (candidate > best) {
                                best = candidate;
                                bestPlatform = p;
                            }
                        }
                    }
                    users[u].posts[j].inferedPlatform = bestPlatform;
                }
            }
        }
    }

    /**
     * Dumps the raw word distributions to CSV files in the output directory.
     *
     * <p>{@code topicWordDistributions.csv} holds one line per topic: the topic index followed by
     * the weight of every vocabulary word. {@code backgroundTopicWordDistribution.csv} holds a
     * single comma-separated line (no trailing newline) with the background topic's word weights.
     * On any error a message and stack trace are printed and the JVM exits with status 0.
     */
    private void outputPostTopicWordDisributions() {
        try {
            BufferedWriter topicWriter = new BufferedWriter(
                    new FileWriter(outputPath + "/topicWordDistributions.csv"));
            for (int z = 0; z < nTopics; z++) {
                StringBuilder row = new StringBuilder();
                row.append(z);
                for (int w = 0; w < vocabulary.length; w++) {
                    row.append(',').append(topics[z][w]);
                }
                row.append('\n');
                topicWriter.write(row.toString());
            }
            topicWriter.close();

            BufferedWriter backgroundWriter = new BufferedWriter(
                    new FileWriter(outputPath + "/backgroundTopicWordDistribution.csv"));
            StringBuilder line = new StringBuilder();
            line.append(backgroundTopic[0]);
            for (int w = 1; w < vocabulary.length; w++) {
                line.append(',').append(backgroundTopic[w]);
            }
            backgroundWriter.write(line.toString());
            backgroundWriter.close();
        } catch (Exception e) {
            System.out.println("Error in writing out topics to file!");
            e.printStackTrace();
            System.exit(0);
        }
    }

    /**
     * Writes the two coin-bias values ({@code coinBias[0]}, which weights the background word
     * distribution, and {@code coinBias[1]}, which weights the topic word distributions) to
     * {@code coinBias.csv} as one comma-separated line. On any error a message and stack trace
     * are printed and the JVM exits with status 0.
     */
    private void outputCoinBias() {
        try {
            File target = new File(outputPath + "/coinBias.csv");
            if (!target.exists()) {
                target.createNewFile();
            }
            BufferedWriter writer = new BufferedWriter(new FileWriter(target.getAbsoluteFile()));
            String line = coinBias[0] + "," + coinBias[1];
            writer.write(line);
            writer.close();
        } catch (Exception e) {
            System.out.println("Error in writing out coin bias to file!");
            e.printStackTrace();
            System.exit(0);
        }
    }

    /**
     * Writes the highest-weight words of every topic, and of the background topic, to
     * {@code postTopicTopWords.csv}.
     *
     * <p>Each section starts with a header line (the topic index, or {@code background}) followed
     * by one {@code ,word,weight} line per ranked word. Topics list {@code k} words and the
     * background lists {@code 2 * k}. Both counts are capped at {@code vocabulary.length}, so a
     * small vocabulary cannot cause a read past the end of the ranking result. On any error a
     * message and stack trace are printed and the JVM exits with status 0.
     *
     * @param k number of top words to list per topic (the background lists {@code 2 * k})
     */
    private void outputPostTopicTopWords(int k) {
        String fileName = outputPath + "/postTopicTopWords.csv";
        try (BufferedWriter bw = new BufferedWriter(new FileWriter(fileName))) {
            RankingTool rankTool = new RankingTool();

            int nTopicWords = Math.min(k, vocabulary.length);
            for (int z = 0; z < nTopics; z++) {
                bw.write(z + "\n");
                WeightedElement[] topWords = rankTool.getTopKbyWeight(vocabulary, topics[z], nTopicWords);
                for (int j = 0; j < nTopicWords; j++) {
                    bw.write("," + topWords[j].name + "," + topWords[j].weight + "\n");
                }
            }

            int nBackgroundWords = Math.min(2 * k, vocabulary.length);
            bw.write("background\n");
            WeightedElement[] backgroundWords = rankTool.getTopKbyWeight(vocabulary, backgroundTopic,
                    nBackgroundWords);
            for (int j = 0; j < nBackgroundWords; j++) {
                bw.write("," + backgroundWords[j].name + "," + backgroundWords[j].weight + "\n");
            }
        } catch (Exception e) {
            System.out.println("Error in writing out post topic top words to file!");
            e.printStackTrace();
            System.exit(0);
        }
    }

    /**
     * Writes to {@code topicTopPosts.csv}, for every topic and for the background topic, the top
     * {@code k} training posts (posts outside {@code testBatch}) assigned to it.
     *
     * <p>Reads {@code inferedTopic} / {@code inferedLikelihood}, which {@code inferPostTopic()}
     * populates. Each post is scored by its inferred log10-likelihood divided by its word count (a
     * per-word score, despite the "perplexity" variable names). Background posts are normalised
     * the same way as topic posts; previously they used the raw likelihood, which favoured short
     * posts and made the two rankings inconsistent. Each section starts with a header line (topic
     * index or {@code background}) followed by {@code ,postID,score} lines. The background section
     * is omitted when no post was assigned to it. On any error a message and stack trace are
     * printed and the JVM exits with status 0.
     *
     * @param k maximum number of posts listed per topic
     */
    private void outputTopicTopPosts(int k) {
        // First pass: count training posts per inferred topic so the arrays can be sized exactly.
        int[] nTopicPosts = new int[nTopics]; // Java zero-initialises int arrays
        int nBackgroundTopicPosts = 0;
        for (int u = 0; u < users.length; u++) {
            for (int j = 0; j < users[u].posts.length; j++) {
                if (users[u].posts[j].batch == testBatch) {
                    continue;
                }
                if (users[u].posts[j].inferedTopic >= 0) {
                    nTopicPosts[users[u].posts[j].inferedTopic]++;
                } else {
                    nBackgroundTopicPosts++;
                }
            }
        }

        String[][] postID = new String[nTopics][];
        double[][] postPerplexity = new double[nTopics][];
        for (int z = 0; z < nTopics; z++) {
            postID[z] = new String[nTopicPosts[z]];
            postPerplexity[z] = new double[nTopicPosts[z]];
            nTopicPosts[z] = 0; // reused as a fill cursor in the second pass
        }
        String[] backgroundTopicPostID = new String[nBackgroundTopicPosts];
        double[] backgroundTopicPostPerplexity = new double[nBackgroundTopicPosts];
        nBackgroundTopicPosts = 0;

        // Second pass: collect IDs and per-word scores.
        for (int u = 0; u < users.length; u++) {
            for (int j = 0; j < users[u].posts.length; j++) {
                if (users[u].posts[j].batch == testBatch) {
                    continue;
                }
                double perWordScore = users[u].posts[j].inferedLikelihood / users[u].posts[j].words.length;
                int z = users[u].posts[j].inferedTopic;
                if (z >= 0) {
                    postID[z][nTopicPosts[z]] = users[u].posts[j].postID;
                    postPerplexity[z][nTopicPosts[z]] = perWordScore;
                    nTopicPosts[z]++;
                } else {
                    backgroundTopicPostID[nBackgroundTopicPosts] = users[u].posts[j].postID;
                    backgroundTopicPostPerplexity[nBackgroundTopicPosts] = perWordScore;
                    nBackgroundTopicPosts++;
                }
            }
        }

        String filename = outputPath + "/topicTopPosts.csv";
        try (BufferedWriter bw = new BufferedWriter(new FileWriter(filename))) {
            RankingTool rankTool = new RankingTool();
            for (int z = 0; z < nTopics; z++) {
                bw.write(z + "\n");
                int nTop = Math.min(k, nTopicPosts[z]);
                WeightedElement[] topPosts = rankTool.getTopKbyWeight(postID[z], postPerplexity[z], nTop);
                for (int j = 0; j < nTop; j++) {
                    bw.write("," + topPosts[j].name + "," + topPosts[j].weight + "\n");
                }
            }

            if (nBackgroundTopicPosts > 0) {
                bw.write("background\n");
                int nTop = Math.min(k, nBackgroundTopicPosts);
                WeightedElement[] topPosts = rankTool.getTopKbyWeight(backgroundTopicPostID,
                        backgroundTopicPostPerplexity, nTop);
                for (int j = 0; j < nTop; j++) {
                    bw.write("," + topPosts[j].name + "," + topPosts[j].weight + "\n");
                }
            }
        } catch (Exception e) {
            System.out.println("Error in writing out post topic top posts to file!");
            e.printStackTrace();
            System.exit(0);
        }
    }

    /**
     * Writes {@code userTopicDistributions.csv}: one line per user, holding the user ID followed
     * by the user's {@code nTopics} topic proportions. On any error a message and stack trace are
     * printed and the JVM exits with status 0.
     */
    private void outputUserTopicDistribution() {
        try {
            File target = new File(outputPath + "/userTopicDistributions.csv");
            if (!target.exists()) {
                target.createNewFile();
            }
            BufferedWriter writer = new BufferedWriter(new FileWriter(target.getAbsoluteFile()));
            for (int u = 0; u < users.length; u++) {
                StringBuilder row = new StringBuilder();
                row.append(users[u].userID);
                for (int z = 0; z < nTopics; z++) {
                    row.append(',').append(users[u].topicDistribution[z]);
                }
                row.append('\n');
                writer.write(row.toString());
            }
            writer.close();
        } catch (Exception e) {
            System.out.println("Error in writing out users' topic distribution to file!");
            e.printStackTrace();
            System.exit(0);
        }
    }

    /**
     * Writes {@code userTopicPlatformDistributions.csv}: one line per user, holding the user ID
     * followed by all {@code nTopics * nPlatforms} entries of the user's topic-platform
     * distribution in topic-major order (all platforms of topic 0, then topic 1, ...).
     *
     * <p>The leftover debug print of every entry to stdout has been removed. On any error a
     * message and stack trace are printed and the JVM exits with status 0.
     */
    private void outputUserSpecificTopicPlatformDistribution() {
        String fileName = outputPath + "/userTopicPlatformDistributions.csv";
        try (BufferedWriter bw = new BufferedWriter(new FileWriter(fileName))) {
            for (int u = 0; u < users.length; u++) {
                bw.write("" + users[u].userID);
                for (int z = 0; z < nTopics; z++) {
                    for (int p = 0; p < nPlatforms; p++) {
                        bw.write("," + users[u].topicPlatformDistribution[z][p]);
                    }
                }
                bw.write("\n");
            }
        } catch (Exception e) {
            System.out.println("Error in writing out user topic platform distributions to file!");
            e.printStackTrace();
            System.exit(0);
        }
    }

    /**
     * Writes {@code globalTopicPlatformDistributions.csv}: one line per topic, holding the topic
     * index followed by that topic's {@code nPlatforms} platform probabilities from
     * {@code globalTopicPlatformDistribution}. On any error a message (now naming the correct
     * file content) and stack trace are printed and the JVM exits with status 0.
     */
    private void outputGlobalTopicPlatformDistribution() {
        String filename = outputPath + "/globalTopicPlatformDistributions.csv";
        try (BufferedWriter bw = new BufferedWriter(new FileWriter(filename))) {
            for (int z = 0; z < nTopics; z++) {
                bw.write("" + z);
                for (int p = 0; p < nPlatforms; p++) {
                    bw.write("," + globalTopicPlatformDistribution[z][p]);
                }
                bw.write("\n");
            }
        } catch (Exception e) {
            System.out.println("Error in writing out global topic platform distributions to file!");
            e.printStackTrace();
            System.exit(0);
        }
    }

    /**
     * Writes {@code likelihood-perplexity.csv}: a header line and one data line with the four
     * statistics computed by {@code getLikelihoodPerplexity()}. These are the training
     * log-likelihood, the test perplexity, the training content-only log-likelihood and the test
     * content-only perplexity.
     *
     * <p>The header now names all four columns (it previously listed only two). On any error a
     * message and stack trace are printed and the JVM exits with status 0.
     */
    private void outputLikelihoodPerplexity() {
        String fileName = outputPath + "/likelihood-perplexity.csv";
        try (BufferedWriter bw = new BufferedWriter(new FileWriter(fileName))) {
            bw.write("postLogLikelihood,postLogPerplexity,postContentLogLikelihood,postContentPerplexity\n");
            bw.write("" + postLogLikelidhood + "," + postLogPerplexity);
            bw.write("," + postContentLogLikelidhood + "," + postContentPerplexity);
        } catch (Exception e) {
            System.out.println("Error in writing out likelihood and perplexity to file!");
            e.printStackTrace();
            System.exit(0);
        }
    }

    /**
     * Writes one file per user into the {@code inferedTopics} sub-folder of the output directory.
     * Each file is named {@code <userID>.txt} and holds one {@code postID<TAB>inferedTopic} line
     * per post. On any error a message and stack trace are printed and the JVM exits with
     * status 0.
     */
    private void outputInferedTopic() {
        try {
            SystemTool.createFolder(outputPath, "inferedTopics");
            String folder = outputPath + SystemTool.pathSeparator + "inferedTopics" + SystemTool.pathSeparator;
            for (int u = 0; u < users.length; u++) {
                BufferedWriter writer = new BufferedWriter(new FileWriter(folder + users[u].userID + ".txt"));
                StringBuilder content = new StringBuilder();
                for (int j = 0; j < users[u].posts.length; j++) {
                    content.append(users[u].posts[j].postID)
                            .append('\t')
                            .append(users[u].posts[j].inferedTopic)
                            .append('\n');
                }
                writer.write(content.toString());
                writer.close();
            }
        } catch (Exception e) {
            System.out.println("Error in writing out post topics to file!");
            e.printStackTrace();
            System.exit(0);
        }
    }

    /**
     * Writes one CSV per user into the {@code inferedPlatforms} sub-folder of the output
     * directory. Each file is named {@code <userID>.csv} and holds one
     * {@code postID,inferedPlatform,actualPlatform} line for each post in {@code testBatch}.
     *
     * <p>The error message now names platforms instead of the copy-pasted "post topics". On any
     * error a message and stack trace are printed and the JVM exits with status 0.
     */
    private void outputInferedPlatform() {
        try {
            SystemTool.createFolder(outputPath, "inferedPlatforms");
            for (int u = 0; u < users.length; u++) {
                String filename = outputPath + SystemTool.pathSeparator + "inferedPlatforms"
                        + SystemTool.pathSeparator + users[u].userID + ".csv";
                try (BufferedWriter bw = new BufferedWriter(new FileWriter(filename))) {
                    for (int j = 0; j < users[u].posts.length; j++) {
                        if (users[u].posts[j].batch != testBatch) {
                            continue;
                        }
                        bw.write(users[u].posts[j].postID + "," + users[u].posts[j].inferedPlatform + ","
                                + users[u].posts[j].platform + "\n");
                    }
                }
            }
        } catch (Exception e) {
            System.out.println("Error in writing out post platforms to file!");
            e.printStackTrace();
            System.exit(0);
        }
    }

    /**
     * Writes every model output to {@code outputPath}.
     *
     * <p>It always writes the word distributions, the top 20 words per topic, the coin bias, the
     * user topic distributions and the topic-platform distributions (user-specific or global,
     * depending on {@code modelType}). Optional outputs are controlled by the
     * {@code toOutput...} flags. Post topics are inferred only when top posts or per-post topics
     * were requested, and post platforms only when inferred platforms were requested. Likelihood
     * and perplexity are computed last.
     */
    public void outputAll() {
        outputPostTopicWordDisributions();
        outputPostTopicTopWords(20);
        outputCoinBias();
        outputUserTopicDistribution();

        if (modelType != ModelType.USER_SPECIFIC) {
            outputGlobalTopicPlatformDistribution();
        } else {
            outputUserSpecificTopicPlatformDistribution();
        }

        boolean needPostTopics = toOutputTopicTopPosts || toOutputInferedTopics;
        if (needPostTopics) {
            inferPostTopic();
        }
        if (toOutputTopicTopPosts) {
            outputTopicTopPosts(100);
        }
        if (toOutputInferedTopics) {
            outputInferedTopic();
        }

        if (toOutputInferedPlatforms) {
            inferPostPlatform();
            outputInferedPlatform();
        }

        if (toOutputLikelihoodPerplexity) {
            getLikelihoodPerplexity();
            outputLikelihoodPerplexity();
        }
    }
}