pl.polzone.classifier.Classifier.java Source code

Java tutorial

Introduction

Here is the source code for pl.polzone.classifier.Classifier.java

Source

/*
 * Copyright (C) 2013 Grzegorz Taczyk
 * 
 * Permission is hereby granted, free of charge, to any person obtaining a copy of this 
 * software and associated documentation files (the "Software"), to deal in the Software 
 * without restriction, including without limitation the rights to use, copy, modify, merge, 
 * publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons 
 * to whom the Software is furnished to do so, subject to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included in all copies or 
 * substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING 
 * BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
 * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR 
 * THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 * 
 */

package pl.polzone.classifier;

import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;

import org.apache.commons.collections.Bag;
import org.apache.commons.collections.bag.HashBag;
import org.apache.commons.lang.StringUtils;

import com.google.common.base.Function;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultiset;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multiset;
import com.google.common.collect.Multiset.Entry;
import com.google.common.collect.Multisets;
import com.google.common.collect.Sets;

/**
 * A simple incremental text classifier.
 *
 * <p>Documents are fed as lists of words labelled with a category. Each word is folded onto a
 * "stem" by approximate string matching (Levenshtein distance below {@link #STEM_THRESHOLD}).
 * For every stem the classifier tracks two things: in how many fed documents the stem appeared
 * ({@code wordCount}), and how those documents were distributed over categories
 * ({@code occurences}).
 *
 * <p>Not thread-safe: all state lives in unsynchronized collections, and even {@link #predict}
 * mutates it (stem registration and self-training).
 */
public class Classifier {
    /** Per stem: bag of the categories of fed documents containing that stem (once per document). */
    private final HashMap<String, Bag> occurences = new HashMap<String, Bag>();

    /**
     * Canonical stems seen so far, kept in registration order. Insertion order makes
     * {@link #stem} stable: a word always folds onto the earliest-registered stem within the
     * threshold, however many stems are added later. (A plain HashMap's iteration order shifts
     * on rehash, which could map the same word to different stems over time.)
     */
    private final java.util.Set<String> stems = new java.util.LinkedHashSet<String>();

    /** Per stem: number of fed documents that contained it. */
    private final Bag wordCount = new HashBag();

    /** Number of documents fed so far, including self-training feeds made by {@link #predict}. */
    private int feedCount = 0;

    /** A word folds onto a known stem when their Levenshtein distance is below this value. */
    public static final byte STEM_THRESHOLD = 2;

    /**
     * Maps a word onto a canonical stem, registering the word as a new stem if none is close.
     *
     * @param word a non-null word; callers must filter out nulls because
     *     {@link StringUtils#getLevenshteinDistance} rejects them
     * @return the first known stem (in registration order) within {@link #STEM_THRESHOLD} edits
     *     of {@code word}, or {@code word} itself after registering it as a new stem
     */
    private String stem(String word) {
        if (stems.contains(word)) {
            return word;
        }
        for (String known : stems) {
            if (StringUtils.getLevenshteinDistance(known, word) < STEM_THRESHOLD) {
                return known;
            }
        }
        stems.add(word);
        return word;
    }

    /**
     * Records one training document.
     *
     * <p>Each distinct stem of the document is counted once. Repeated words within the same
     * document do not inflate the counts.
     *
     * @param category the document's label
     * @param words the document's words; {@code null} entries are ignored
     */
    public void feed(String category, java.util.List<String> words) {
        feedCount++;

        // Stem in list order (stemming registers new stems, so order matters), then dedupe.
        HashSet<String> distinctStems = new HashSet<String>();
        for (String word : words) {
            if (word != null) {
                distinctStems.add(stem(word));
            }
        }

        for (String stemmed : distinctStems) {
            wordCount.add(stemmed);
            Bag categories = occurences.get(stemmed);
            if (categories == null) {
                categories = new HashBag();
                occurences.put(stemmed, categories);
            }
            categories.add(category);
        }
    }

    /**
     * Predicts the category of a document and, when the prediction is clear-cut, learns from it.
     *
     * <p>Each non-null word is stemmed. Stemming registers previously unseen words as new stems.
     * Stems that appeared in more than half of the fed documents are skipped as uninformative.
     * For every other stem with known categories, each category's score grows by two terms:
     * <ul>
     *   <li>the number of documents of that category containing the stem, and</li>
     *   <li>a rarity bonus of {@code feedCount} minus the number of documents containing the
     *       stem.</li>
     * </ul>
     * A word repeated in {@code words} contributes once per occurrence.
     *
     * <p>If the best category scores more than twice the runner-up, the document is fed back as
     * training data under that category. When only one category scored, nothing is fed back. The
     * order among equally-scored categories is unspecified.
     *
     * @param words the document's words; {@code null} entries are ignored
     * @return the highest-scoring category, or {@code null} if no word contributed to any score
     */
    public String predict(java.util.List<String> words) {
        final Multiset<String> scores = HashMultiset.create();
        for (String word : words) {
            if (word == null) {
                continue;
            }
            String stemmed = stem(word);
            int documentsWithStem = wordCount.getCount(stemmed);
            // Integer division: e.g. after 3 feeds, stems seen in 2+ documents are skipped.
            if (documentsWithStem > feedCount / 2) {
                continue;
            }
            Bag categories = occurences.get(stemmed);
            if (categories == null) {
                continue;
            }
            int rarityBonus = feedCount - documentsWithStem;
            for (Object category : categories.uniqueSet()) {
                scores.add((String) category, categories.getCount(category) + rarityBonus);
            }
        }

        if (scores.isEmpty()) {
            return null;
        }

        Iterator<Entry<String>> ranked = Multisets.copyHighestCountFirst(scores).entrySet().iterator();
        Entry<String> best = ranked.next();
        if (ranked.hasNext()) {
            Entry<String> runnerUp = ranked.next();
            // Self-training: only learn from predictions with a clear margin.
            if (best.getCount() > runnerUp.getCount() * 2) {
                feed(best.getElement(), words);
            }
        }
        return best.getElement();
    }
}