// Java tutorial
/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package biz.itcons.wsdm.hw2_2;

import biz.itcons.wsdm.hw2_2.util.StringTokenization;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A Naive Bayes classifier. Must be trained with {@link #trainingDataSet}
 * and then {@link #buildGlobalVocabulary} before documents can be classified.
 *
 * @author Maciej Czyowicz <maciej.czyzowicz@itcons.biz>
 */
public class NaiveBayes {
    private static final Logger LOGGER =
            LoggerFactory.getLogger(NaiveBayes.class.getName());

    // Classification label -> accumulated per-class training statistics.
    private Map<String, ClassificationItem> parsedEntries = new HashMap<>();
    // Union of all token types seen across every class.
    private Set<String> vocabulary = new HashSet<>();
    // Union of all stemmed token types seen across every class.
    private Set<String> stemmedVocabulary = new HashSet<>();
    // Total number of training documents read, across all classes.
    private int globalTrainingCount = 0;

    /**
     * Fills NaiveBayes with data for classification.
     *
     * @param file a CSV file with documents that are classified
     * @param classifierPos a position in CSV file where classification is
     * stored (zero based indexing)
     * @param documentPos a position in CSV file where document body is stored
     * (zero based indexing)
     * @throws FileNotFoundException if the training file cannot be opened
     * @throws IOException if reading the training file fails
     */
    public void trainingDataSet(String file, int classifierPos, int documentPos)
            throws FileNotFoundException, IOException {
        // FileReader would use the platform default charset; decode as UTF-8
        // explicitly so results do not depend on host configuration.
        try (Reader in = new InputStreamReader(new FileInputStream(file),
                StandardCharsets.UTF_8)) {
            LOGGER.debug("Opening training data set: {}", file);
            globalTrainingCount = 0;
            for (CSVRecord record : CSVFormat.EXCEL.parse(in)) {
                String classifier = record.get(classifierPos);
                globalTrainingCount++;
                ClassificationItem ci = parsedEntries.get(classifier);
                if (ci == null) {
                    ci = new ClassificationItem(classifier);
                    parsedEntries.put(classifier, ci);
                }
                ci.addDocument(record.get(documentPos));
            }
            LOGGER.trace("Read {} from training set", globalTrainingCount);
        }
    }

    /**
     * Creates global vocabulary based on training set. Required for
     * calculating probabilities for final classification.
     */
    public void buildGlobalVocabulary() {
        for (ClassificationItem ci : parsedEntries.values()) {
            vocabulary.addAll(ci.getTokens().keySet());
            stemmedVocabulary.addAll(ci.getStemedTokens().keySet());
        }
        LOGGER.debug("Vocabulary size: {} in {} classes.",
                vocabulary.size(), parsedEntries.keySet().size());
        // BUG FIX: this message previously logged vocabulary.size() again
        // instead of the stemmed vocabulary size.
        LOGGER.debug("Stem vocabulary size: {} in {} classes.",
                stemmedVocabulary.size(), parsedEntries.keySet().size());
    }

    /**
     * Checks feasibility of classification based on test data. For given set
     * it runs NaiveBayes classifier and checks the result with one that is
     * provided for test data.
     *
     * @param file a CSV file with documents that are classified
     * @param classifierPos a position in CSV file where classification is
     * stored (zero based indexing)
     * @param documentPos a position in CSV file where document body is stored
     * (zero based indexing)
     * @return ratio of correctly classified documents vs. all test documents;
     * 0.0 when the test set is empty
     * @throws FileNotFoundException if the test file cannot be opened
     * @throws IOException if reading the test file fails
     */
    public double testDataSet(String file, int classifierPos, int documentPos)
            throws FileNotFoundException, IOException {
        int testSetCount = 0;
        int correctClass = 0;
        try (Reader in = new InputStreamReader(new FileInputStream(file),
                StandardCharsets.UTF_8)) {
            // BUG FIX: previously logged "training data set" for the test file.
            LOGGER.debug("Opening test data set: {}", file);
            for (CSVRecord record : CSVFormat.EXCEL.parse(in)) {
                testSetCount++;
                String classification = classifyDocument(record.get(documentPos));
                if (record.get(classifierPos).equals(classification)) {
                    correctClass++;
                }
            }
            LOGGER.trace("Read {} from test set", testSetCount);
            LOGGER.info("success ratio: {} of {}", correctClass, testSetCount);
            // Guard against 0/0 -> NaN for an empty test file.
            if (testSetCount == 0) {
                LOGGER.warn("Test data set {} contained no records", file);
                return 0.0;
            }
            return (double) correctClass / testSetCount;
        }
    }

    /**
     * A classification for a document. Based on the training data, attempts
     * to classify the document.
     *
     * @param document a document to be classified
     * @return the most probable classification for the given document, or an
     * empty string when no training data has been loaded
     */
    public String classifyDocument(String document) {
        Map<String, Integer> docTokens =
                StringTokenization.getTokensWithMultiplicity(document);
        double bestLogProb = Double.NEGATIVE_INFINITY;
        String bestClass = "";
        for (Map.Entry<String, ClassificationItem> e : parsedEntries.entrySet()) {
            // Prior P(class) estimated from the share of training documents.
            double priorProb = ((double) e.getValue().getDocumentCount())
                    / globalTrainingCount;
            // Work in log space to avoid floating-point underflow on
            // long documents.
            double logProb = Math.log10(priorProb)
                    + e.getValue().calcCondProbDoc(docTokens, vocabulary.size());
            if (logProb > bestLogProb) {
                bestLogProb = logProb;
                bestClass = e.getKey();
            }
        }
        return bestClass;
    }
}