Java tutorial
package naivebayesimplementation;

/*
 * Copyright (C) 2014 Vasilis Vryniotis <bbriniotis at datumbox.com>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */

import classifiers.NaiveBayes;
import dataobjects.NaiveBayesKnowledgeBase;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.Map.Entry;
import org.apache.poi.ss.usermodel.Sheet;
import org.apache.poi.ss.usermodel.Workbook;
import org.apache.poi.xssf.usermodel.XSSFWorkbook;

/**
 * Trains a {@link NaiveBayes} classifier from topic/word rows stored in an Excel sheet and uses it
 * to assign topics (or per-topic scores) to videos read from the {@code videos} and {@code tags}
 * database tables.
 *
 * <p>All database access goes through the shared static {@link #conn}. The methods that open it
 * ({@link #GetVideoScoresForTopic} and the private bulk classifier) leave it open so that
 * {@link #GetVideoScoreForTopic} can be called afterwards; {@link #main} closes it when done.
 */
public class NaiveBayesClassifier {

    /** Chi-square critical value passed to the classifier (source comment: "0.01 pvalue"). */
    private static final double CHI_SQUARE_CRITICAL_VALUE = 6.63;

    /** Location of the workbook holding the training topics. */
    private static final String TRAINING_FILE_PATH = "resources/topics/AssignedTopics.xlsx";

    /** Upper bound on the number of sheet rows read as training topics. */
    private static final int MAX_TOPIC_ROWS = 60;

    /** Maximum number of entries printed by {@link #PrintTopVideosForTopic}. */
    private static final int TOP_VIDEO_LIMIT = 100;

    private static final String JDBC_DRIVER = "com.mysql.jdbc.Driver";
    private static final String DB_URL = "jdbc:mysql://localhost:3306/cs653";
    private static final String DB_USER = "root";
    private static final String DB_PASSWORD = "";

    private static final String ALL_VIDEOS_QUERY = "select vid, title, description from videos";
    private static final String ONE_VIDEO_QUERY =
            "select vid, title, description from videos where vid = ?";
    private static final String TAG_QUERY =
            "select t.vid, t.tag from videos v, tags t where t.vid = v.vid and t.vid = ?;";

    private static final String SELECT_ERROR_MESSAGE =
            "Error: Could not select title and description 1";

    /** When true the per-topic lists store the raw video text, otherwise the cleaned text. */
    private static final boolean STORE_RAW_INFO = true;

    StopWords sWord = new StopWords();
    PorterStemmer pWord = new PorterStemmer();

    /** Shared JDBC connection used by every query in this class. */
    public static Connection conn = null;

    /** Training examples captured by {@link #ReadTrainingExamples}; null until it succeeds. */
    private Map<String, String[]> allTrainingTopics;

    /**
     * Reads the training sheet, trains the classifier and classifies every video in the database.
     *
     * @param args unused
     * @throws IOException declared for backward compatibility with existing callers
     */
    public static void main(String[] args) throws IOException {
        conn = null;

        NaiveBayesClassifier nbc = new NaiveBayesClassifier();
        Map<String, String[]> trainingExamples = new HashMap<>();
        Map<String, Integer> assignedTopicCount = new TreeMap<>();
        Map<String, List<String>> assignedTopics = new TreeMap<String, List<String>>();

        // Read Training Examples
        nbc.ReadTrainingExamples(trainingExamples);
        if (trainingExamples.isEmpty()) {
            // ReadTrainingExamples only logs I/O failures, so detect the "nothing loaded" case here.
            System.out.println("Error: No training examples could be read from " + TRAINING_FILE_PATH);
            return;
        }

        // Train, then rebuild the classifier from its knowledge base.
        NaiveBayes nb = trainClassifier(trainingExamples);

        // Classify a query to a topic queryAssignedTopic
        // String query = "Missionary Service. I want to go on a Mission.";
        // String queryAssignedTopic = nb.predict(nbc.GetCleanQuery(query));

        //----Give score to videos on a topic queryAssignedTopic (topic assigned to query)
        // Map<String, Double> videoScoresOnTopic = nbc.GetVideoScoresForTopic(nb, queryAssignedTopic);
        // nbc.PrintTopVideosForTopic(videoScoresOnTopic);

        try {
            //----Classify videos from database
            nbc.ClassifyVideosFromDB(nb, assignedTopics, assignedTopicCount);

            //----Creates a file that shows each assigned topic to video. Shows the data of the video
            //nbc.CreateFileWithRawDataAndTopicAssignment(assignedTopics, assignedTopicCount);
            //nbc.CreateFilesOfVideoData(assignedTopics);
            //nbc.WordFrequencyCounter();
        } finally {
            // The classification step opens the shared connection; release it before exiting.
            nbc.safeClose(conn);
            conn = null;
        }
    }

    /**
     * Trains a classifier on the given examples and returns a fresh classifier built from the
     * resulting knowledge base.
     */
    private static NaiveBayes trainClassifier(Map<String, String[]> trainingExamples) {
        NaiveBayes trainer = new NaiveBayes();
        trainer.setChisquareCriticalValue(CHI_SQUARE_CRITICAL_VALUE);
        trainer.train(trainingExamples);

        NaiveBayesKnowledgeBase knowledgeBase = trainer.getKnowledgeBase();
        return new NaiveBayes(knowledgeBase);
    }

    /**
     * Builds a classifier trained only on the stored examples of one topic.
     *
     * @param topic key into the examples captured by {@link #ReadTrainingExamples}
     * @return classifier built from the single-topic knowledge base
     * @throws IllegalStateException if no training examples have been loaded yet
     */
    private NaiveBayes GetNaiveBayesWithOneTopic(String topic) {
        if (allTrainingTopics == null) {
            throw new IllegalStateException("Training examples have not been loaded");
        }
        Map<String, String[]> trainingExamples = new HashMap<>();
        trainingExamples.put(topic, allTrainingTopics.get(topic));
        return trainClassifier(trainingExamples);
    }

    /**
     * Cleans a free-text query the same way video titles are cleaned.
     *
     * @param query raw query text; null yields an empty string
     * @return cleaned, stop-word-filtered and stemmed text
     */
    public String GetCleanQuery(String query) {
        return GetCleanTitle(query);
    }

    /**
     * Prints up to {@value #TOP_VIDEO_LIMIT} entries of the map in descending score order, one
     * {@code score = key} line each.
     *
     * @param videoScores scores keyed by video id; not modified
     */
    private void PrintTopVideosForTopic(Map<String, Double> videoScores) {
        Map<String, Double> sorted = sortByComparator(videoScores, false);
        Iterator<Map.Entry<String, Double>> it = sorted.entrySet().iterator();
        int printed = 0;
        // hasNext() guards against maps with fewer entries than the limit.
        while (printed < TOP_VIDEO_LIMIT && it.hasNext()) {
            Map.Entry<String, Double> pair = it.next();
            System.out.println(pair.getValue() + " = " + pair.getKey());
            printed++;
        }
    }

    /**
     * Scores a single video against one topic.
     *
     * <p>Uses the shared {@link #conn} as-is and does not open it; if it is null or unusable the
     * failure is logged and null is returned.
     *
     * @param nb trained classifier
     * @param topic topic whose score is wanted
     * @param videoId value matched against the {@code vid} column
     * @return the topic's score from {@code getScoresForEveryTopicNormWithin}, or null when the
     *     video is not found, the topic has no score, or the lookup fails
     */
    public Double GetVideoScoreForTopic(NaiveBayes nb, String topic, String videoId) {
        PreparedStatement stmt = null;
        ResultSet rs = null;
        try {
            // Bind the id instead of concatenating it into the SQL text.
            stmt = conn.prepareStatement(ONE_VIDEO_QUERY);
            stmt.setString(1, videoId);
            rs = stmt.executeQuery();
            if (!rs.next()) {
                return null;
            }

            String[] combined = buildCombinedInfo(rs.getString(1), rs.getString(2), rs.getString(3));
            Map<String, Double> topicScores = nb.getScoresForEveryTopicNormWithin(combined[1]);
            return topicScores.get(topic);
        } catch (Exception e) {
            System.out.println(SELECT_ERROR_MESSAGE);
            e.printStackTrace();
            return null;
        } finally {
            safeClose(rs);
            safeClose(stmt);
        }
    }

    /**
     * Scores every video in the database against one topic.
     *
     * <p>Opens the shared {@link #conn} if it is not already open and leaves it open.
     *
     * @param nb trained classifier
     * @param topic topic whose score is wanted for each video
     * @return scores keyed by video id, or null if the connection or query fails or the
     *     classifier reports no score for {@code topic}
     */
    public Map<String, Double> GetVideoScoresForTopic(NaiveBayes nb, String topic) {
        try {
            openConnection();
        } catch (ClassNotFoundException | SQLException e) {
            e.printStackTrace();
            return null;
        }

        Map<String, Double> videoScores = new TreeMap<String, Double>();
        PreparedStatement stmt = null;
        ResultSet rs = null;
        try {
            stmt = conn.prepareStatement(ALL_VIDEOS_QUERY);
            rs = stmt.executeQuery();
            while (rs.next()) {
                String vidId = rs.getString(1);
                String[] combined = buildCombinedInfo(vidId, rs.getString(2), rs.getString(3));

                // Per the original note: relative to the ranking of the topics, 1 is best, 0 is last.
                Map<String, Double> topicScores = nb.getScoresForEveryTopicNormWithin(combined[1]);
                // Unboxing fails for an unknown topic; that lands in the catch below (returns null).
                double score = topicScores.get(topic);
                videoScores.put(vidId, score);
            }
            return videoScores;
        } catch (Exception e) {
            System.out.println(SELECT_ERROR_MESSAGE);
            e.printStackTrace();
            return null;
        } finally {
            safeClose(rs);
            safeClose(stmt);
        }
    }

    /**
     * Predicts a topic for every video in the database and records the results.
     *
     * <p>Videos whose raw text contains "Church Audit", "Statistical Report" or
     * "Committee Report" are skipped. Opens the shared {@link #conn} if needed and leaves it open.
     *
     * @param nb trained classifier
     * @param assignedTopics receives, per predicted topic, the text of each video assigned to it
     * @param assignedTopicCount receives, per predicted topic, the number of videos assigned to it
     */
    private void ClassifyVideosFromDB(NaiveBayes nb, Map<String, List<String>> assignedTopics,
            Map<String, Integer> assignedTopicCount) {
        try {
            openConnection();
        } catch (ClassNotFoundException | SQLException e) {
            e.printStackTrace();
            return;
        }

        PreparedStatement stmt = null;
        ResultSet rs = null;
        try {
            stmt = conn.prepareStatement(ALL_VIDEOS_QUERY);
            rs = stmt.executeQuery();
            while (rs.next()) {
                String vidId = rs.getString(1);
                String[] combined = buildCombinedInfo(vidId, rs.getString(2), rs.getString(3));
                String rawInfo = combined[0];
                String cleanInfo = combined[1];

                // Does not assign topics to the Church Auditing Department/Committee Report Videos
                if (rawInfo.contains("Church Audit")
                        || rawInfo.contains("Statistical Report")
                        || rawInfo.contains("Committee Report")) {
                    continue;
                }

                String output = nb.predict(cleanInfo);
                if ("FYVvE4tr2BI".equals(vidId)) {
                    System.out.println(output + ": " + rawInfo);
                }

                List<String> allInfo = assignedTopics.get(output);
                if (allInfo == null) {
                    allInfo = new ArrayList<String>();
                    assignedTopics.put(output, allInfo);
                }
                allInfo.add(STORE_RAW_INFO ? rawInfo : cleanInfo);

                Integer count = assignedTopicCount.get(output);
                assignedTopicCount.put(output, count == null ? 1 : count + 1);
            }
        } catch (Exception e) {
            System.out.println(SELECT_ERROR_MESSAGE);
            e.printStackTrace();
        } finally {
            safeClose(rs);
            safeClose(stmt);
        }
    }

    /**
     * Ensures the shared {@link #conn} is open, creating it only when it is null or closed so
     * repeated calls do not leak connections.
     */
    private static Connection openConnection() throws ClassNotFoundException, SQLException {
        if (conn == null || conn.isClosed()) {
            Class.forName(JDBC_DRIVER);
            conn = DriverManager.getConnection(DB_URL, DB_USER, DB_PASSWORD);
        }
        return conn;
    }

    /**
     * Combines a video's title, tags and description into one raw and one cleaned string.
     *
     * @return two-element array: index 0 is {@code title + " " + rawTags + " " + description}
     *     (a null title or description appears as the text "null"), index 1 is the cleaned
     *     title, tags and description joined by single spaces and trimmed
     */
    private String[] buildCombinedInfo(String vidId, String title, String description) {
        String[] tagInfo = new String[2];
        GetTagInfo(vidId, tagInfo);

        String rawCombined = title + " " + tagInfo[0] + " " + description;

        String cleanCombined =
                GetCleanTitle(title) + " " + tagInfo[1] + " " + GetDescriptionInfo(description);
        // Clean extra spaces
        cleanCombined = cleanCombined.replaceAll("\\s+", " ").trim();

        return new String[] {rawCombined, cleanCombined};
    }

    /**
     * Loads the tags of one video through the shared {@link #conn}.
     *
     * <p>Null, numeric and "notag" tags are ignored. Every other tag is appended (followed by a
     * space) to the raw text, and its stemmed form to the clean text unless the tag is empty.
     *
     * @param vidId video whose tags are read
     * @param tagInfo output array of length at least 2: index 0 receives the raw tag text, index 1
     *     the stemmed tag text; both are empty strings when the query fails
     */
    private void GetTagInfo(String vidId, String[] tagInfo) {
        // Default to empty text so a failed query never injects the literal "null" downstream.
        tagInfo[0] = "";
        tagInfo[1] = "";

        StringBuilder rawTags = new StringBuilder();
        StringBuilder cleanTags = new StringBuilder();
        PreparedStatement tagStmt = null;
        ResultSet tagRs = null;
        try {
            tagStmt = conn.prepareStatement(TAG_QUERY);
            tagStmt.setString(1, vidId);
            tagRs = tagStmt.executeQuery();

            while (tagRs.next()) {
                String tag = tagRs.getString(2);
                if (tag == null || isNumeric(tag) || "notag".equals(tag)) {
                    continue;
                }
                rawTags.append(tag).append(" ");
                if (!tag.equals("")) {
                    cleanTags.append(pWord.stem(tag)).append(" ");
                }
            }

            tagInfo[0] = rawTags.toString();
            tagInfo[1] = cleanTags.toString();
        } catch (SQLException e) {
            e.printStackTrace();
        } finally {
            // This runs once per video row, so the statement and result set must be released.
            safeClose(tagRs);
            safeClose(tagStmt);
        }
    }

    /**
     * Returns a cleaned-up version of a title.
     *
     * @param title raw title; null yields an empty string
     * @return text with non-alphanumeric characters removed, whitespace collapsed, stop words
     *     dropped and remaining words stemmed
     */
    private String GetCleanTitle(String title) {
        if (title == null) {
            return "";
        }
        return cleanText(title);
    }

    /**
     * Returns a cleaned-up version of a description.
     *
     * @param description raw description, may be null
     * @return cleaned text, or an empty string when the description is null or cleans down to
     *     nothing or to the single word "null"
     */
    private String GetDescriptionInfo(String description) {
        if (description == null) {
            return "";
        }
        String cleaned = cleanText(description);
        if (cleaned == null || cleaned.equals("null") || cleaned.equals("")) {
            return "";
        }
        return cleaned;
    }

    /**
     * Replaces every character other than ASCII letters, digits and space with a space, collapses
     * whitespace runs, then removes stop words and stems what is left.
     */
    private String cleanText(String text) {
        String stripped = text.replaceAll("[^0-9a-zA-Z ]", " ");
        stripped = stripped.replaceAll("\\s+", " ").trim();
        return RemoveStopWordsFromSentenceAndStem(stripped);
    }

    /**
     * Reads training topics from the first sheet of the workbook at
     * {@value #TRAINING_FILE_PATH}.
     *
     * <p>For each of the first {@value #MAX_TOPIC_ROWS} rows, the cell at index 1 is the topic
     * name and the cell at index 2 is its words, split on single whitespace characters. Rows that
     * are missing or lack either cell are skipped. On success the map is also remembered for
     * single-topic training. I/O failures are logged and leave the map as far as it was filled.
     *
     * @param trainingExamples map that receives topic name to word array entries
     */
    public void ReadTrainingExamples(Map<String, String[]> trainingExamples) {
        try (FileInputStream inputStream = new FileInputStream(new File(TRAINING_FILE_PATH))) {
            Workbook workbook = new XSSFWorkbook(inputStream);
            try {
                Sheet firstSheet = workbook.getSheetAt(0);
                for (int row = 0; row < MAX_TOPIC_ROWS; row++) {
                    if (firstSheet.getRow(row) == null) {
                        continue; // sheet has no such row
                    }
                    Object topicCell = firstSheet.getRow(row).getCell(1);
                    Object wordsCell = firstSheet.getRow(row).getCell(2);
                    if (topicCell == null || wordsCell == null) {
                        continue; // incomplete row
                    }
                    trainingExamples.put(topicCell.toString(), wordsCell.toString().split("\\s"));
                }

                // Saving this topics for scoring
                allTrainingTopics = trainingExamples;
            } finally {
                workbook.close();
            }
        } catch (IOException e) {
            // Covers FileNotFoundException as well; both were only logged originally.
            e.printStackTrace();
        }
    }

    /**
     * Writes the classification results to {@code Raw-Data-With-Topic-Assignment.txt} (UTF-8).
     *
     * <p>Output order: one "topic : text" line per classified video, one "topic count" line per
     * topic, then the topic names alone, then the counts alone.
     *
     * @param assignedTopics per-topic list of video text
     * @param assignedTopicCount per-topic number of videos
     */
    private void CreateFileWithRawDataAndTopicAssignment(Map<String, List<String>> assignedTopics,
            Map<String, Integer> assignedTopicCount) {
        try (PrintWriter writer = new PrintWriter("Raw-Data-With-Topic-Assignment.txt", "UTF-8")) {
            for (Map.Entry<String, List<String>> entry : assignedTopics.entrySet()) {
                for (String info : entry.getValue()) {
                    writer.println("\n" + entry.getKey() + " : " + info);
                }
            }

            for (Map.Entry<String, Integer> entry : assignedTopicCount.entrySet()) {
                writer.println("\n" + entry.getKey() + " " + entry.getValue());
            }

            for (String assignedTopic : assignedTopicCount.keySet()) {
                writer.println(assignedTopic);
            }

            for (Integer count : assignedTopicCount.values()) {
                writer.println(count);
            }
        } catch (FileNotFoundException | UnsupportedEncodingException e) {
            e.printStackTrace();
        }
    }

    /**
     * Splits the sentence on single spaces, drops stop words and unstemmable words, drops stems
     * that start with "http" or equal "www", and joins the remaining stems with single spaces.
     *
     * @param sent space-separated words, must not be null
     * @return the surviving stems joined by single spaces
     */
    private String RemoveStopWordsFromSentenceAndStem(String sent) {
        StringBuilder joined = new StringBuilder();
        for (String word : sent.split(" ")) {
            String stem = rStemStop(word);
            if (stem == null || stem.startsWith("http") || stem.equals("www")) {
                continue;
            }
            if (joined.length() > 0) {
                joined.append(" ");
            }
            joined.append(stem);
        }
        return joined.toString();
    }

    /**
     * Lower-cases a word and stems it unless it is a stop word.
     *
     * @return the stem, or null when the word is a stop word or the stemmer answers
     *     "Invalid term" or "No term entered"
     */
    private String rStemStop(String getString) {
        String lowered = getString.toLowerCase();
        if (sWord.contains(lowered)) {
            return null;
        }
        String stem = pWord.stem(lowered);
        if (stem.contentEquals("Invalid term") || stem.contentEquals("No term entered")) {
            return null;
        }
        return stem;
    }

    /** Closes the connection if non-null; close failures are deliberately ignored. */
    private void safeClose(Connection con) {
        if (con != null) {
            try {
                con.close();
            } catch (SQLException ignored) {
                // best-effort cleanup
            }
        }
    }

    /** Closes the statement if non-null; close failures are deliberately ignored. */
    private void safeClose(Statement st) {
        if (st != null) {
            try {
                st.close();
            } catch (SQLException ignored) {
                // best-effort cleanup
            }
        }
    }

    /** Closes the prepared statement if non-null; close failures are deliberately ignored. */
    private void safeClose(PreparedStatement cps) {
        if (cps != null) {
            try {
                cps.close();
            } catch (SQLException ignored) {
                // best-effort cleanup
            }
        }
    }

    /** Closes the result set if non-null; close failures are deliberately ignored. */
    private void safeClose(ResultSet crs) {
        if (crs != null) {
            try {
                crs.close();
            } catch (SQLException ignored) {
                // best-effort cleanup
            }
        }
    }

    /**
     * Reports whether {@link Double#parseDouble(String)} accepts the text.
     *
     * @param str text to test; null is treated as not numeric
     */
    private boolean isNumeric(String str) {
        if (str == null) {
            // parseDouble(null) throws NullPointerException, not NumberFormatException.
            return false;
        }
        try {
            Double.parseDouble(str);
            return true;
        } catch (NumberFormatException nfe) {
            return false;
        }
    }

    /**
     * Returns a copy of the map whose iteration order is sorted by value.
     *
     * @param unsortMap map to sort; not modified
     * @param order true for ascending values, false for descending
     * @return a new insertion-ordered map holding the same entries in sorted order
     */
    private Map<String, Double> sortByComparator(Map<String, Double> unsortMap, final boolean order) {
        List<Entry<String, Double>> entries =
                new ArrayList<Entry<String, Double>>(unsortMap.entrySet());

        // Sorting the list based on values
        Collections.sort(entries, new Comparator<Entry<String, Double>>() {
            @Override
            public int compare(Entry<String, Double> o1, Entry<String, Double> o2) {
                return order
                        ? o1.getValue().compareTo(o2.getValue())
                        : o2.getValue().compareTo(o1.getValue());
            }
        });

        // LinkedHashMap preserves the sorted insertion order.
        Map<String, Double> sortedMap = new LinkedHashMap<String, Double>();
        for (Entry<String, Double> entry : entries) {
            sortedMap.put(entry.getKey(), entry.getValue());
        }
        return sortedMap;
    }
}