biz.itcons.wsdm.hw2_2.NaiveBayes.java Source code

Java tutorial

Introduction

Here is the source code for biz.itcons.wsdm.hw2_2.NaiveBayes.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package biz.itcons.wsdm.hw2_2;

import biz.itcons.wsdm.hw2_2.util.StringTokenization;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A Naive Bayes classifier. It needs training data to work: load one or more
 * training sets with {@link #trainingDataSet(String, int, int)}, then call
 * {@link #buildGlobalVocabulary()}, and only then classify documents with
 * {@link #classifyDocument(String)} or evaluate a labelled set with
 * {@link #testDataSet(String, int, int)}.
 *
 * @author Maciej Czyowicz <maciej.czyzowicz@itcons.biz>
 */
public class NaiveBayes {

    private static final Logger LOGGER = LoggerFactory.getLogger(NaiveBayes.class);

    /** Per-class training data, keyed by the class label read from the CSV. */
    private final Map<String, ClassificationItem> parsedEntries = new HashMap<>();
    /** Union of the token sets of all classes; filled by {@link #buildGlobalVocabulary()}. */
    private final Set<String> vocabulary = new HashSet<>();
    /** Union of the stemmed token sets of all classes; filled by {@link #buildGlobalVocabulary()}. */
    private final Set<String> stemmedVocabulary = new HashSet<>();
    /** Total number of training records read so far, over all training files. */
    private int globalTrainingCount = 0;

    /**
     * Fills NaiveBayes with data for classification. May be called more than
     * once; documents and the global document count accumulate across calls,
     * so the class priors stay consistent with the per-class document counts.
     *
     * @param file a CSV file with documents that are classified.
     * @param classifierPos a position in CSV file where classification is
     * stored (zero based indexing)
     * @param documentPos a position in CSV file where document body is stored
     * (zero based indexing)
     * @throws FileNotFoundException if the file cannot be opened for reading.
     * @throws IOException if reading or parsing the file fails.
     */
    public void trainingDataSet(String file, int classifierPos, int documentPos)
            throws FileNotFoundException, IOException {
        try (Reader in = new FileReader(file)) {
            LOGGER.debug("Opening training data set: {}", file);
            // Counts only the records of this file; used for the log line below.
            int recordsRead = 0;
            for (CSVRecord record : CSVFormat.EXCEL.parse(in)) {
                String classifier = record.get(classifierPos);
                String document = record.get(documentPos);

                ClassificationItem ci = parsedEntries.get(classifier);
                if (ci == null) {
                    // First document of this class: create its container.
                    ci = new ClassificationItem(classifier);
                    ci.addDocument(document);
                    parsedEntries.put(classifier, ci);
                } else {
                    ci.addDocument(document);
                }
                // The global count is NOT reset per call: parsedEntries keeps
                // the documents of earlier calls, and the prior computed in
                // classifyDocument divides a class count by this total.
                globalTrainingCount++;
                recordsRead++;
            }
            LOGGER.trace("Read {} from training set", recordsRead);
        }
    }

    /**
     * Creates global vocabulary based on training set. Required for calculating
     * probabilities for final classification. Tokens are only ever added, so
     * calling this again after loading more training data extends the
     * vocabulary.
     */
    public void buildGlobalVocabulary() {
        for (ClassificationItem ci : parsedEntries.values()) {
            vocabulary.addAll(ci.getTokens().keySet());
            stemmedVocabulary.addAll(ci.getStemedTokens().keySet());
        }
        LOGGER.debug("Vocabulary size: {} in {} classes.", vocabulary.size(), parsedEntries.size());
        LOGGER.debug("Stem vocabulary size: {} in {} classes.", stemmedVocabulary.size(),
                parsedEntries.size());
    }

    /**
     * Checks feasibility of classification based on test data. For given set
     * it runs NaiveBayes classifier and checks the result with one that is
     * provided for test data.
     *
     * @param file A CSV file with documents that are classified.
     * @param classifierPos a position in CSV file where classification is
     * stored (zero based indexing)
     * @param documentPos a position in CSV file where document body is stored
     * (zero based indexing)
     * @return ratio of correctly classified documents vs. all test documents;
     * {@code NaN} when the file contains no records.
     * @throws FileNotFoundException if the file cannot be opened for reading.
     * @throws IOException if reading or parsing the file fails.
     */
    public double testDataSet(String file, int classifierPos, int documentPos)
            throws FileNotFoundException, IOException {
        int testSetCount = 0;
        int correctClass = 0;
        try (Reader in = new FileReader(file)) {
            LOGGER.debug("Opening test data set: {}", file);
            for (CSVRecord record : CSVFormat.EXCEL.parse(in)) {
                testSetCount++;
                String classification = classifyDocument(record.get(documentPos));

                if (record.get(classifierPos).equals(classification)) {
                    correctClass++;
                }
            }
            LOGGER.trace("Read {} from test set", testSetCount);
            LOGGER.info("success ratio: {} of {}", correctClass, testSetCount);
            // Floating-point division: 0 / 0 yields NaN instead of throwing.
            return (double) correctClass / testSetCount;
        }
    }

    /**
     * A classification for document. Based on the training data, attempts to
     * classify a document by picking the class with the highest score, where
     * the score is {@code log10(prior)} plus the value returned by
     * {@code ClassificationItem.calcCondProbDoc}.
     *
     * @param document A document to be classified.
     * @return Most probable classification for given document, or an empty
     * string when no class scores above negative infinity (for example when
     * no training data has been loaded).
     */
    public String classifyDocument(String document) {
        Map<String, Integer> docTokens = StringTokenization.getTokensWithMultiplicity(document);
        // Loop invariant. Stays 0 until buildGlobalVocabulary() has been called.
        int vocabularySize = vocabulary.size();

        double bestScore = Double.NEGATIVE_INFINITY;
        String bestClass = "";
        for (Map.Entry<String, ClassificationItem> e : parsedEntries.entrySet()) {
            ClassificationItem item = e.getValue();
            double priorProb = ((double) item.getDocumentCount()) / globalTrainingCount;
            // NOTE(review): calcCondProbDoc is presumably a log10-scale
            // likelihood so that it can be summed with the log prior - confirm.
            double currentClassProb = Math.log10(priorProb)
                    + item.calcCondProbDoc(docTokens, vocabularySize);
            // Strict comparison: on a tie the class seen first is kept.
            if (currentClassProb > bestScore) {
                bestScore = currentClassProb;
                bestClass = e.getKey();
            }
        }
        return bestClass;
    }
}