naivebayesimplementation.NaiveBayesClassifier.java Source code

Java tutorial

Introduction

Here is the source code for naivebayesimplementation.NaiveBayesClassifier.java

Source

package naivebayesimplementation;

/* 
 * Copyright (C) 2014 Vasilis Vryniotis <bbriniotis at datumbox.com>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

import classifiers.NaiveBayes;
import dataobjects.NaiveBayesKnowledgeBase;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.Map.Entry;

import org.apache.poi.ss.usermodel.Sheet;
import org.apache.poi.ss.usermodel.Workbook;
import org.apache.poi.xssf.usermodel.XSSFWorkbook;

public class NaiveBayesClassifier {

    StopWords sWord = new StopWords();
    PorterStemmer pWord = new PorterStemmer();
    public static Connection conn = null;
    private Map<String, String[]> allTrainingTopics;

    /**
     * Entry point: reads the training examples, trains a Naive Bayes classifier,
     * rebuilds a classifier from the resulting knowledge base and uses it to
     * classify the videos stored in the database.
     *
     * <p>The shared JDBC connection opened while classifying is closed before
     * this method returns, whether or not classification succeeded.
     *
     * @param args command-line arguments (not used)
     * @throws IOException declared for callers; presumably only raised by code
     *         outside this method - TODO confirm
     */
    public static void main(String[] args) throws IOException {

        conn = null;
        NaiveBayesClassifier nbc = new NaiveBayesClassifier();

        Map<String, String[]> trainingExamples = new HashMap<>();
        Map<String, Integer> assignedTopicCount = new TreeMap<>();
        Map<String, List<String>> assignedTopics = new TreeMap<>();

        try {
            //Read Training Examples
            nbc.ReadTrainingExamples(trainingExamples);

            //train classifier
            NaiveBayes trainer = new NaiveBayes();
            trainer.setChisquareCriticalValue(6.63); //0.01 pvalue
            trainer.train(trainingExamples);

            //get trained classifier knowledgeBase
            NaiveBayesKnowledgeBase knowledgeBase = trainer.getKnowledgeBase();

            //Use classifier
            NaiveBayes nb = new NaiveBayes(knowledgeBase);

            //Classify a query to a topic queryAssignedTopic

            //        String query = "Missionary Service. I want to go on a Mission.";
            //        String queryAssignedTopic = nb.predict(nbc.GetCleanQuery(query));

            //----Give score to videos on a topic queryAssignedTopic (topic assigned to query)
            //        Map<String, Double> videoScoresOnTopic = nbc.GetVideoScoresForTopic(nb, queryAssignedTopic);
            //        nbc.PrintTopVideosForTopic(videoScoresOnTopic);

            //----Classify videos from database
            nbc.ClassifyVideosFromDB(nb, assignedTopics, assignedTopicCount);

            //----Creates a file that shows each assigned topic to video. Shows the data of the video
            //nbc.CreateFileWithRawDataAndTopicAssignment(assignedTopics, assignedTopicCount);

            //nbc.CreateFilesOfVideoData(assignedTopics);
            //nbc.WordFrequencyCounter();
        } finally {
            // Release the connection opened during classification; it was previously leaked.
            nbc.safeClose(conn);
            conn = null;
        }
    }

    /**
     * Trains a classifier on the stored training words of a single topic and
     * returns a new classifier built from the resulting knowledge base.
     *
     * @param topic key into the training topics saved by ReadTrainingExamples
     * @return a classifier constructed from the one-topic knowledge base
     */
    private NaiveBayes GetNaiveBayesWithOneTopic(String topic) {

        // Training set containing only the requested topic
        Map<String, String[]> singleTopicExamples = new HashMap<>();
        singleTopicExamples.put(topic, allTrainingTopics.get(topic));

        NaiveBayes trainer = new NaiveBayes();
        trainer.setChisquareCriticalValue(6.63); //0.01 pvalue
        trainer.train(singleTopicExamples);

        // Wrap the trained knowledge base in a fresh classifier instance
        NaiveBayesKnowledgeBase trainedKnowledge = trainer.getKnowledgeBase();
        return new NaiveBayes(trainedKnowledge);
    }

    /**
     * Cleans a free-text query with the same routine that is applied to video titles.
     *
     * @param query raw query text
     * @return the value produced by GetCleanTitle for the query
     */
    public String GetCleanQuery(String query) {
        String cleanedQuery = GetCleanTitle(query);
        return cleanedQuery;
    }

    /**
     * Prints the highest-scoring videos, one per line, as {@code score = videoId}.
     *
     * <p>At most 100 entries are printed. When fewer videos are available all of
     * them are printed; the previous implementation called {@code next()} 100
     * times unconditionally and failed with NoSuchElementException on shorter maps.
     *
     * @param videoScores map from video id to score; a null map prints nothing
     */
    private void PrintTopVideosForTopic(Map<String, Double> videoScores) {

        // GetVideoScoresForTopic returns null on failure, so tolerate it here.
        if (videoScores == null) {
            return;
        }

        // Highest score first
        Map<String, Double> ranked = sortByComparator(videoScores, false);

        int printed = 0;
        for (Map.Entry<String, Double> entry : ranked.entrySet()) {
            if (printed >= 100) {
                break;
            }
            System.out.println(entry.getValue() + " = " + entry.getKey());
            printed++;
        }
    }

    /**
     * Looks up a single video and returns its score for the given topic.
     *
     * <p>The video's cleaned title, cleaned tags and cleaned description are
     * joined into one string, which is passed to
     * {@code nb.getScoresForEveryTopicNormWithin}; the entry for {@code topic}
     * is returned.
     *
     * <p>The video id is bound as a statement parameter rather than concatenated
     * into the SQL text, so ids containing quotes can neither break nor alter
     * the query.
     *
     * <p>Uses the shared static connection, which must already be open.
     *
     * @param nb      classifier used to score the video text
     * @param topic   topic whose score is requested
     * @param videoId value matched against the {@code vid} column
     * @return the score for the topic, or {@code null} when the video is not
     *         found, the topic has no score, or any error occurs
     */
    public Double GetVideoScoreForTopic(NaiveBayes nb, String topic, String videoId) {

        PreparedStatement stmt = null;
        ResultSet rs = null;

        try {
            //         Class.forName("com.mysql.jdbc.Driver");
            //         conn = DriverManager.getConnection("jdbc:mysql://localhost:3306/cs653","root", "103191");

            String query = "select vid, title, description from videos where vid = ?";
            stmt = conn.prepareStatement(query);
            stmt.setString(1, videoId);
            rs = stmt.executeQuery();

            // Only the first matching row is ever used.
            if (!rs.next()) {
                return null;
            }

            String vidId = rs.getString(1);
            String title = rs.getString(2);
            String description = rs.getString(3);

            // Get Clean Title Info
            String cleanTitleInfo = GetCleanTitle(title);

            // Get Clean and Raw Tag Info for Video (index 0 = raw, index 1 = clean)
            String[] tagInfo = new String[2];
            GetTagInfo(vidId, tagInfo);
            String cleanTagInfo = tagInfo[1];

            //Get Clean Description Info
            String cleanDescriptionInfo = GetDescriptionInfo(description);

            // Title, tags and description separated by single spaces
            StringBuilder cleanCombinedInfo = new StringBuilder();
            cleanCombinedInfo.append(cleanTitleInfo);
            cleanCombinedInfo.append(" ");
            cleanCombinedInfo.append(cleanTagInfo);
            cleanCombinedInfo.append(" ");
            cleanCombinedInfo.append(cleanDescriptionInfo);

            // Clean extra spaces
            String combinedInfoStr = cleanCombinedInfo.toString().replaceAll("\\s+", " ").trim();

            Map<String, Double> scoresByTopic = nb.getScoresForEveryTopicNormWithin(combinedInfoStr);
            return scoresByTopic.get(topic);

        } catch (Exception e) {
            System.out.println("Error: Could not select title and description 1");
            e.printStackTrace();
        } finally {
            safeClose(rs);
            safeClose(stmt);
        }

        return null;
    }

    /**
     * Scores every row of the {@code videos} table against one topic.
     *
     * <p>For each video the cleaned title, cleaned tags and cleaned description
     * are joined into a single string and passed to
     * {@code nb.getScoresForEveryTopicNormWithin}; the entry for {@code topic}
     * is stored under the video id. Videos for which the classifier reports no
     * score for the topic are skipped (previously the resulting unboxing
     * failure aborted the whole scan).
     *
     * <p>The shared static connection is reused when it is already open instead
     * of being replaced, which used to leak the previous connection.
     *
     * @param nb    classifier used to score each video's text
     * @param topic topic whose score is collected
     * @return map from video id to score, sorted by video id; {@code null} when
     *         the driver cannot be loaded, the connection cannot be opened, or
     *         the scan fails
     */
    public Map<String, Double> GetVideoScoresForTopic(NaiveBayes nb, String topic) {

        Map<String, Double> videoScores = new TreeMap<String, Double>();

        try {
            Class.forName("com.mysql.jdbc.Driver");
            // Open a connection only when there is no usable one yet.
            if (conn == null || conn.isClosed()) {
                conn = DriverManager.getConnection("jdbc:mysql://localhost:3306/cs653", "root", "");
            }

            PreparedStatement stmt = null;
            ResultSet rs = null;

            try {

                String query = "select vid, title, description from videos";
                stmt = conn.prepareStatement(query);
                rs = stmt.executeQuery();

                while (rs.next()) {

                    String vidId = rs.getString(1);
                    String title = rs.getString(2);
                    String description = rs.getString(3);

                    // Get Clean Title Info
                    String cleanTitleInfo = GetCleanTitle(title);

                    // Get Clean and Raw Tag Info for Video (index 0 = raw, index 1 = clean)
                    String[] tagInfo = new String[2];
                    GetTagInfo(vidId, tagInfo);
                    String cleanTagInfo = tagInfo[1];

                    //Get Clean Description Info
                    String cleanDescriptionInfo = GetDescriptionInfo(description);

                    // Title, tags and description separated by single spaces
                    StringBuilder cleanCombinedInfo = new StringBuilder();
                    cleanCombinedInfo.append(cleanTitleInfo);
                    cleanCombinedInfo.append(" ");
                    cleanCombinedInfo.append(cleanTagInfo);
                    cleanCombinedInfo.append(" ");
                    cleanCombinedInfo.append(cleanDescriptionInfo);

                    // Clean extra spaces
                    String combinedInfoStr = cleanCombinedInfo.toString().replaceAll("\\s+", " ").trim();

                    // Original author's note: this score is relative to the ranking of the
                    // topics. 1 is best, 0 is last
                    Map<String, Double> scoresByTopic = nb.getScoresForEveryTopicNormWithin(combinedInfoStr);
                    Double score = scoresByTopic.get(topic);

                    // This score is relative to all the scores given by the classifier
                    //double score = nb.getRawScoresForEveryTopic(combinedInfoStr).get(topic);

                    if (score == null) {
                        // No score for this topic: skip the video rather than failing the scan.
                        continue;
                    }

                    videoScores.put(vidId, score);
                }

                //videoScores = nb.NormalizeScoresThoughout(videoScores);

                return videoScores;

            } catch (Exception e) {
                System.out.println("Error: Could not select title and description 1");
                e.printStackTrace();
            } finally {
                safeClose(rs);
                safeClose(stmt);
            }

        } catch (ClassNotFoundException e) {
            // JDBC driver is not on the classpath
            e.printStackTrace();
        } catch (SQLException e) {
            // Connection could not be checked or opened
            e.printStackTrace();
        }

        return null;
    }

    /**
     * Predicts a topic for every row of the {@code videos} table and records the results.
     *
     * <p>For each video the cleaned title, tags and description are joined and
     * passed to {@code nb.predict}. The video's raw text (title, raw tags and
     * description) is appended to the list stored under the predicted topic in
     * {@code assignedTopics}, and the topic's counter in
     * {@code assignedTopicCount} is incremented. Videos whose raw text contains
     * "Church Audit", "Statistical Report" or "Committee Report" are not classified.
     *
     * <p>The shared static connection is reused when it is already open instead
     * of being replaced, which used to leak the previous connection.
     *
     * @param nb                 classifier used for prediction
     * @param assignedTopics     output: predicted topic to raw text of each video assigned to it
     * @param assignedTopicCount output: predicted topic to number of videos assigned to it
     */
    private void ClassifyVideosFromDB(NaiveBayes nb, Map<String, List<String>> assignedTopics,
            Map<String, Integer> assignedTopicCount) {

        try {
            Class.forName("com.mysql.jdbc.Driver");
            // Open a connection only when there is no usable one yet.
            if (conn == null || conn.isClosed()) {
                conn = DriverManager.getConnection("jdbc:mysql://localhost:3306/cs653", "root", "");
            }

            PreparedStatement stmt = null;
            ResultSet rs = null;

            try {

                String query = "select vid, title, description from videos";
                stmt = conn.prepareStatement(query);
                rs = stmt.executeQuery();

                while (rs.next()) {

                    String vidId = rs.getString(1);
                    String title = rs.getString(2);
                    String description = rs.getString(3);

                    // Get Clean Title Info
                    String cleanTitleInfo = GetCleanTitle(title);

                    // Get Clean and Raw Tag Info for Video (index 0 = raw, index 1 = clean)
                    String[] tagInfo = new String[2];
                    GetTagInfo(vidId, tagInfo);
                    String rawTagInfo = tagInfo[0];
                    String cleanTagInfo = tagInfo[1];

                    //Get Clean Description Info
                    String cleanDescriptionInfo = GetDescriptionInfo(description);

                    // Raw text: title, tags and description as stored
                    StringBuilder rawCombinedInfo = new StringBuilder();
                    rawCombinedInfo.append(title);
                    rawCombinedInfo.append(" ");
                    rawCombinedInfo.append(rawTagInfo);
                    rawCombinedInfo.append(" ");
                    rawCombinedInfo.append(description);
                    String rawCombinedInfoStr = rawCombinedInfo.toString();

                    // Clean text: the same three parts after cleaning
                    StringBuilder cleanCombinedInfo = new StringBuilder();
                    cleanCombinedInfo.append(cleanTitleInfo);
                    cleanCombinedInfo.append(" ");
                    cleanCombinedInfo.append(cleanTagInfo);
                    cleanCombinedInfo.append(" ");
                    cleanCombinedInfo.append(cleanDescriptionInfo);

                    // Clean extra spaces
                    String combinedInfoStr = cleanCombinedInfo.toString().replaceAll("\\s+", " ").trim();

                    // Does not assign topics to the Church Auditing Department/Committee Report Videos
                    if (rawCombinedInfoStr.contains("Church Audit")
                            || rawCombinedInfoStr.contains("Statistical Report")
                            || rawCombinedInfoStr.contains("Committee Report")) {
                        continue;
                    }

                    // Predict Topics
                    String output = nb.predict(combinedInfoStr);

                    // Debug trace for one specific video id
                    if ("FYVvE4tr2BI".equals(vidId)) {
                        System.out.println(output + ": " + rawCombinedInfoStr);
                    }

                    // Record the raw text under the predicted topic.
                    List<String> allInfo = assignedTopics.get(output);
                    if (allInfo == null) {
                        allInfo = new ArrayList<String>();
                        assignedTopics.put(output, allInfo);
                    }
                    allInfo.add(rawCombinedInfoStr);

                    // Count the assignment.
                    Integer previousCount = assignedTopicCount.get(output);
                    assignedTopicCount.put(output, previousCount == null ? 1 : previousCount + 1);
                }
            } catch (Exception e) {
                System.out.println("Error: Could not select title and description 1");
                e.printStackTrace();
            } finally {
                safeClose(rs);
                safeClose(stmt);
            }

        } catch (ClassNotFoundException e) {
            // JDBC driver is not on the classpath
            e.printStackTrace();
        } catch (SQLException e) {
            // Connection could not be checked or opened
            e.printStackTrace();
        }
    }

    /**
     * Collects the tags of one video from the {@code tags} table.
     *
     * <p>On return {@code tagInfo[0]} holds the raw tags and {@code tagInfo[1]}
     * the stemmed tags, each tag followed by a single space. Tags that parse as
     * numbers, the placeholder tag "notag" and null tags are ignored. Both slots
     * are set to the empty string up front, so a failed lookup leaves empty
     * strings instead of {@code null} values that callers would append as the
     * text "null".
     *
     * <p>The statement and result set are closed before returning; they were
     * previously leaked on every call.
     *
     * @param vidId   video id matched against {@code tags.vid}
     * @param tagInfo output array with at least two elements (0 = raw, 1 = clean)
     */
    private void GetTagInfo(String vidId, String[] tagInfo) {

        tagInfo[0] = "";
        tagInfo[1] = "";

        StringBuilder rawCombinedInfo = new StringBuilder();
        StringBuilder cleanCombinedInfo = new StringBuilder();

        // Get Tags from Video
        String tagQuery = "select t.vid, t.tag from videos v, tags t where t.vid = v.vid and t.vid = ?;";

        try (PreparedStatement tagStmt = conn.prepareStatement(tagQuery)) {
            tagStmt.setString(1, vidId);

            try (ResultSet tagRs = tagStmt.executeQuery()) {
                // Append Tags
                while (tagRs.next()) {
                    String tag = tagRs.getString(2);

                    // Check for null before anything dereferences or parses the tag.
                    if (tag == null || isNumeric(tag) || "notag".equals(tag)) {
                        continue;
                    }

                    // raw info
                    rawCombinedInfo.append(tag);
                    rawCombinedInfo.append(" ");

                    // An empty tag still contributes its separator to the raw text
                    // but is never stemmed.
                    if (!tag.equals("")) {
                        String stemmedTag = pWord.stem(tag);
                        cleanCombinedInfo.append(stemmedTag);
                        cleanCombinedInfo.append(" ");
                    }
                }
            }

            tagInfo[0] = rawCombinedInfo.toString();
            tagInfo[1] = cleanCombinedInfo.toString();

        } catch (SQLException e) {
            // Lookup failed: tagInfo keeps the empty defaults set above.
            e.printStackTrace();
        }
    }

    /**
     * Returns a cleaned-up version of a title.
     *
     * <p>Every character other than ASCII letters, digits and spaces is replaced
     * by a space, runs of whitespace are collapsed, and the result is passed
     * through stop-word removal and stemming. A {@code null} title yields the
     * empty string, matching how GetDescriptionInfo treats a null description;
     * previously it raised a NullPointerException.
     *
     * @param title raw title, may be {@code null}
     * @return the cleaned title, or "" when {@code title} is {@code null}
     */
    private String GetCleanTitle(String title) {

        if (title == null) {
            return "";
        }

        // Replace punctuation and other symbols with spaces (digits are kept)
        title = title.replaceAll("[^0-9a-zA-Z ]", " ");

        //Get rid of Extra spaces
        title = title.replaceAll("\\s+", " ").trim();
        title = RemoveStopWordsFromSentenceAndStem(title);

        return title;

    }

    /**
     * Returns a cleaned-up version of a description.
     *
     * <p>Characters other than ASCII letters, digits and spaces become spaces,
     * whitespace runs are collapsed, and the text goes through stop-word removal
     * and stemming. The empty string is returned when the description is
     * {@code null} or when the cleaned text is null, empty or exactly "null".
     *
     * @param description raw description, may be {@code null}
     * @return the cleaned description, or "" when nothing usable remains
     */
    private String GetDescriptionInfo(String description) {

        if (description == null) {
            return "";
        }

        // Replace punctuation and other symbols with spaces
        String symbolsStripped = description.replaceAll("[^0-9a-zA-Z ]", " ");

        // Collapse whitespace, then drop stop words and stem
        String collapsed = symbolsStripped.replaceAll("\\s+", " ").trim();
        String cleaned = RemoveStopWordsFromSentenceAndStem(collapsed);

        boolean nothingUsable = cleaned == null || cleaned.equals("null") || cleaned.equals("");
        return nothingUsable ? "" : cleaned;

    }

    /**
     * Reads the training examples from the first sheet of
     * {@code resources/topics/AssignedTopics.xlsx}.
     *
     * <p>For each of the first 60 rows, the second cell (index 1) is used as the
     * topic name and the third cell (index 2) as its words, split on single
     * whitespace characters. Rows that are missing, or that lack either cell,
     * are skipped instead of aborting the import with a NullPointerException.
     * The populated map is also remembered in {@code allTrainingTopics}.
     *
     * <p>The input stream and the workbook are closed even when reading fails.
     *
     * @param trainingExamples output map that receives topic name to topic words
     */
    public void ReadTrainingExamples(Map<String, String[]> trainingExamples) {
        String excelFilePath = "resources/topics/AssignedTopics.xlsx";

        try (FileInputStream inputStream = new FileInputStream(new File(excelFilePath))) {

            Workbook workbook = new XSSFWorkbook(inputStream);
            try {
                Sheet firstSheet = workbook.getSheetAt(0);

                // Using for Topics that don't fit any of the topics
                //trainingExamples.put("Another Topic", new String[] { " " });

                for (int row = 0; row < 60; row++) {
                    // getRow/getCell return null for rows or cells that were never created.
                    if (firstSheet.getRow(row) == null
                            || firstSheet.getRow(row).getCell(1) == null
                            || firstSheet.getRow(row).getCell(2) == null) {
                        continue;
                    }

                    String topicName = firstSheet.getRow(row).getCell(1).toString();
                    String topicWords = firstSheet.getRow(row).getCell(2).toString();

                    String[] topicWordsArray = topicWords.split("\\s");

                    trainingExamples.put(topicName, topicWordsArray);
                }

                // Saving this topics for scoring
                allTrainingTopics = trainingExamples;
            } finally {
                workbook.close();
            }

        } catch (FileNotFoundException e) {
            // Spreadsheet is not at the expected path
            e.printStackTrace();
        } catch (IOException e) {
            // Spreadsheet could not be read or closed
            e.printStackTrace();
        }
    }

    /**
     * Writes the classification results to {@code Raw-Data-With-Topic-Assignment.txt} (UTF-8).
     *
     * <p>The file contains, in order: one "topic : text" line per assigned video
     * (each preceded by a blank line), one "topic count" line per topic (each
     * preceded by a blank line), the topic names one per line, and finally the
     * counts one per line.
     *
     * @param assignedTopics     topic to the text of every video assigned to it
     * @param assignedTopicCount topic to number of videos assigned to it
     */
    private void CreateFileWithRawDataAndTopicAssignment(Map<String, List<String>> assignedTopics,
            Map<String, Integer> assignedTopicCount) {
        try {
            PrintWriter out = new PrintWriter("Raw-Data-With-Topic-Assignment.txt", "UTF-8");

            // Section 1: every video's text, labelled with its assigned topic
            for (Map.Entry<String, List<String>> topicEntry : assignedTopics.entrySet()) {
                String topicName = topicEntry.getKey();
                for (String videoText : topicEntry.getValue()) {
                    out.println("\n" + topicName + " : " + videoText);
                }
            }

            // Section 2: topic name together with its count
            for (Map.Entry<String, Integer> countEntry : assignedTopicCount.entrySet()) {
                out.println("\n" + countEntry.getKey() + " " + countEntry.getValue());
            }

            // Section 3: topic names only
            for (String topicName : assignedTopicCount.keySet()) {
                out.println(topicName);
            }

            // Section 4: counts only, in the same order as the names above
            for (Integer topicCount : assignedTopicCount.values()) {
                out.println(topicCount);
            }

            out.close();

        } catch (FileNotFoundException | UnsupportedEncodingException e) {
            // Output file could not be created or the encoding is unsupported
            e.printStackTrace();
        }

    }

    /**
     * Splits a sentence on single spaces, runs every word through rStemStop and
     * joins the surviving words with single spaces.
     *
     * <p>A word is dropped when rStemStop returns {@code null}, when its result
     * has four or more characters and starts with "http", or when its result is
     * exactly "www".
     *
     * @param sent sentence to process
     * @return the kept, processed words separated by single spaces
     */
    private String RemoveStopWordsFromSentenceAndStem(String sent) {
        StringBuilder joined = new StringBuilder();
        boolean firstKept = true;

        for (String rawWord : sent.split(" ")) {
            String processed = rStemStop(rawWord);
            if (processed == null) {
                continue;
            }

            // Long tokens are rejected when they look like URLs, short ones when they are "www".
            boolean keep;
            if (processed.length() >= 4) {
                keep = !processed.startsWith("http");
            } else {
                keep = !processed.equals("www");
            }
            if (!keep) {
                continue;
            }

            // Separator goes between kept words, never before the first one.
            if (!firstKept) {
                joined.append(" ");
            }
            joined.append(processed);
            firstKept = false;
        }

        return joined.toString();
    }

    /**
     * Lower-cases a word and stems it unless it is a stop word.
     *
     * @param getString word to process
     * @return the stemmed word, or {@code null} when the lower-cased word is in
     *         the stop-word list or the stemmer returns "Invalid term" or
     *         "No term entered"
     */
    private String rStemStop(String getString) {

        String lowered = getString.toLowerCase();
        if (sWord.contains(lowered)) {
            return null;
        }

        String stemmed = pWord.stem(lowered);

        // These two strings are treated as stemmer failure markers, not as stems.
        boolean stemmerRejected = stemmed.contentEquals("Invalid term")
                || stemmed.contentEquals("No term entered");

        return stemmerRejected ? null : stemmed;
    }

    /**
     * Closes a connection, ignoring {@code null} and any SQLException raised by the close.
     *
     * @param con connection to close, may be {@code null}
     */
    private void safeClose(Connection con) {
        if (con == null) {
            return;
        }
        try {
            con.close();
        } catch (SQLException ignored) {
            // Best-effort close: a failure here is deliberately not reported.
        }
    }

    /**
     * Closes a statement, ignoring {@code null} and any SQLException raised by the close.
     *
     * @param st statement to close, may be {@code null}
     */
    private void safeClose(Statement st) {
        if (st == null) {
            return;
        }
        try {
            st.close();
        } catch (SQLException ignored) {
            // Best-effort close: a failure here is deliberately not reported.
        }
    }

    /**
     * Closes a prepared statement, ignoring {@code null} and any SQLException raised by the close.
     *
     * @param cps prepared statement to close, may be {@code null}
     */
    private void safeClose(PreparedStatement cps) {
        if (cps == null) {
            return;
        }
        try {
            cps.close();
        } catch (SQLException ignored) {
            // Best-effort close: a failure here is deliberately not reported.
        }
    }

    /**
     * Closes a result set, ignoring {@code null} and any SQLException raised by the close.
     *
     * @param crs result set to close, may be {@code null}
     */
    private void safeClose(ResultSet crs) {
        if (crs == null) {
            return;
        }
        try {
            crs.close();
        } catch (SQLException ignored) {
            // Best-effort close: a failure here is deliberately not reported.
        }
    }

    /**
     * Reports whether a string can be parsed by {@link Double#parseDouble(String)}.
     *
     * <p>A {@code null} string is reported as not numeric. Previously it raised
     * a NullPointerException, because {@code parseDouble} signals {@code null}
     * input with that exception rather than NumberFormatException.
     *
     * @param str text to test, may be {@code null}
     * @return {@code true} when {@code str} parses as a double, otherwise {@code false}
     */
    private boolean isNumeric(String str) {
        if (str == null) {
            return false;
        }
        try {
            Double.parseDouble(str);
            return true;
        } catch (NumberFormatException nfe) {
            return false;
        }
    }

    /**
     * Returns a copy of a map whose iteration order follows the map's values.
     *
     * @param unsortMap map to sort; it is not modified
     * @param order     {@code true} for ascending values, {@code false} for descending
     * @return a new LinkedHashMap holding the same entries in sorted order
     */
    private Map<String, Double> sortByComparator(Map<String, Double> unsortMap, final boolean order) {

        List<Entry<String, Double>> entries = new ArrayList<Entry<String, Double>>(unsortMap.entrySet());

        // Ascending comparison on the entry value; reversed below for descending order.
        Comparator<Entry<String, Double>> ascendingByValue = new Comparator<Entry<String, Double>>() {
            @Override
            public int compare(Entry<String, Double> left, Entry<String, Double> right) {
                return left.getValue().compareTo(right.getValue());
            }
        };

        Comparator<Entry<String, Double>> chosen;
        if (order) {
            chosen = ascendingByValue;
        } else {
            chosen = Collections.reverseOrder(ascendingByValue);
        }
        Collections.sort(entries, chosen);

        // LinkedHashMap iterates in insertion order, which preserves the sorted sequence.
        Map<String, Double> sortedMap = new LinkedHashMap<String, Double>();
        for (Entry<String, Double> sortedEntry : entries) {
            sortedMap.put(sortedEntry.getKey(), sortedEntry.getValue());
        }

        return sortedMap;
    }

}