etc.aloe.filters.StringToDictionaryVector.java Source code

Java tutorial

Introduction

Here is the source code for etc.aloe.filters.StringToDictionaryVector.java

Source

/*
 * This file is part of ALOE.
 *
 * ALOE is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
    
 * ALOE is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
    
 * You should have received a copy of the GNU General Public License
 * along with ALOE.  If not, see <http://www.gnu.org/licenses/>.
 *
 * Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl)
 */
package etc.aloe.filters;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.Serializable;
import java.util.*;
import weka.core.*;
import weka.filters.Filter;
import weka.filters.SimpleBatchFilter;

/**
 * Based on Weka's StringToWordVector, this derives its initial word list from a
 * provided lexicon instead of the string attribute. Filtering of terms is
 * performed similarly to StringToWordVector.
 *
 * @author Michael Brooks <mjbrooks@uw.edu>
 */
public class StringToDictionaryVector extends SimpleBatchFilter {

    private int stringAttributeIndex = -1;
    private String stringAttribute;
    List<String> termList;
    /**
     * Contains the number of documents (instances) in the input format from
     * which the dictionary is created. It is used in IDF transform.
     */
    private int m_NumInstances = -1;
    /**
     * whether to operate on a per-class basis.
     */
    private boolean m_doNotOperateOnPerClassBasis = false;
    /**
     * The default number of words (per class if there is a class attribute
     * assigned) to attempt to keep.
     */
    private int m_WordsToKeep = 1000;
    /**
     * the minimum (per-class) word frequency.
     */
    private int m_minTermFreq = 1;
    /**
     * Contains the number of documents (instances) a particular word appears
     * in. The counts are stored with the same indexing as m_selectedTerms.
     */
    private int[] m_DocsCounts;
    /**
     * A String prefix for the attribute names.
     */
    private String m_Prefix = "";
    /**
     * The set of terms that occurred frequently enough to be included as
     * attributes.
     */
    private ArrayList<String> m_selectedTerms;
    /**
     * The trie containing the selected terms for matching.
     */
    private Trie m_selectedTermsTrie;
    /**
     * Maps the terms to indices in m_selectedTerms
     */
    private HashMap<String, Integer> m_selectedTermIndices;
    /**
     * True if word frequencies should be transformed into log(1+fi) where fi is
     * the frequency of word i.
     */
    private boolean m_TFTransform;
    /**
     * True if word frequencies should be transformed into
     * fij*log(numOfDocs/numOfDocsWithWordi).
     */
    private boolean m_IDFTransform;
    /**
     * True if output instances should contain word frequency rather than
     * boolean 0 or 1.
     */
    private boolean m_OutputCounts = false;
    /**
     * The normalization to apply.
     */
    protected int m_filterType = FILTER_NONE;
    /**
     * normalization: No normalization.
     */
    public static final int FILTER_NONE = 0;
    /**
     * normalization: Normalize all data.
     */
    public static final int FILTER_NORMALIZE_ALL = 1;
    /**
     * normalization: Normalize test data only.
     */
    public static final int FILTER_NORMALIZE_TEST_ONLY = 2;
    /**
     * Specifies whether document's (instance's) word frequencies are to be
     * normalized. The are normalized to average length of documents specified
     * as input format.
     */
    public static final Tag[] TAGS_FILTER = { new Tag(FILTER_NONE, "No normalization"),
            new Tag(FILTER_NORMALIZE_ALL, "Normalize all data"),
            new Tag(FILTER_NORMALIZE_TEST_ONLY, "Normalize test data only") };
    /**
     * Contains the average length of documents (among the first batch of
     * instances aka training data). This is used in length normalization of
     * documents which will be normalized to average document length.
     */
    private double m_AvgDocLength = -1;

    /**
     * Gets whether if the word frequencies for a document (instance) should be
     * normalized or not.
     *
     * @return true if word frequencies are to be normalized.
     */
    public SelectedTag getNormalizeDocLength() {

        return new SelectedTag(m_filterType, TAGS_FILTER);
    }

    /**
     * Sets whether if the word frequencies for a document (instance) should be
     * normalized or not.
     *
     * @param newType the new type.
     */
    public void setNormalizeDocLength(SelectedTag newType) {

        if (newType.getTags() == TAGS_FILTER) {
            m_filterType = newType.getSelectedTag().getID();
        }
    }

    /**
     * Gets whether if the word frequencies should be transformed into
     * log(1+fij) where fij is the frequency of word i in document(instance) j.
     *
     * @return true if word frequencies are to be transformed.
     */
    public boolean getTFTransform() {
        return this.m_TFTransform;
    }

    /**
     * Sets whether if the word frequencies should be transformed into
     * log(1+fij) where fij is the frequency of word i in document(instance) j.
     *
     * @param TFTransform true if word frequencies are to be transformed.
     */
    public void setTFTransform(boolean TFTransform) {
        this.m_TFTransform = TFTransform;
    }

    /**
     * Sets whether if the word frequencies in a document should be transformed
     * into: <br> fij*log(num of Docs/num of Docs with word i) <br> where fij is
     * the frequency of word i in document(instance) j.
     *
     * @return true if the word frequencies are to be transformed.
     */
    public boolean getIDFTransform() {
        return this.m_IDFTransform;
    }

    /**
     * Sets whether if the word frequencies in a document should be transformed
     * into: <br> fij*log(num of Docs/num of Docs with word i) <br> where fij is
     * the frequency of word i in document(instance) j.
     *
     * @param IDFTransform true if the word frequecies are to be transformed
     */
    public void setIDFTransform(boolean IDFTransform) {
        this.m_IDFTransform = IDFTransform;
    }

    /**
     * Gets whether output instances contain 0 or 1 indicating word presence, or
     * word counts.
     *
     * @return true if word counts should be output.
     */
    public boolean getOutputWordCounts() {
        return m_OutputCounts;
    }

    /**
     * Sets whether output instances contain 0 or 1 indicating word presence, or
     * word counts.
     *
     * @param outputWordCounts true if word counts should be output.
     */
    public void setOutputWordCounts(boolean outputWordCounts) {
        m_OutputCounts = outputWordCounts;
    }

    /**
     * Get the attribute name prefix.
     *
     * @return The current attribute name prefix.
     */
    public String getAttributeNamePrefix() {
        return m_Prefix;
    }

    /**
     * Set the attribute name prefix.
     *
     * @param newPrefix String to use as the attribute name prefix.
     */
    public void setAttributeNamePrefix(String newPrefix) {
        m_Prefix = newPrefix;
    }

    /**
     * Get the MinTermFreq value.
     *
     * @return the MinTermFreq value.
     */
    public int getMinTermFreq() {
        return m_minTermFreq;
    }

    /**
     * Set the MinTermFreq value.
     *
     * @param newMinTermFreq The new MinTermFreq value.
     */
    public void setMinTermFreq(int newMinTermFreq) {
        this.m_minTermFreq = newMinTermFreq;
    }

    /**
     * Gets the number of words (per class if there is a class attribute
     * assigned) to attempt to keep.
     *
     * @return the target number of words in the output vector (per class if
     * assigned).
     */
    public int getWordsToKeep() {
        return m_WordsToKeep;
    }

    /**
     * Sets the number of words (per class if there is a class attribute
     * assigned) to attempt to keep.
     *
     * @param newWordsToKeep the target number of words in the output vector
     * (per class if assigned).
     */
    public void setWordsToKeep(int newWordsToKeep) {
        m_WordsToKeep = newWordsToKeep;
    }

    public String getStringAttribute() {
        return stringAttribute;
    }

    public void setStringAttribute(String name) {
        stringAttribute = name;
    }

    public List<String> getTermList() {
        return termList;
    }

    public void setTermList(List<String> termList) {
        this.termList = termList;
    }

    /**
     * Get the DoNotOperateOnPerClassBasis value.
     *
     * @return the DoNotOperateOnPerClassBasis value.
     */
    public boolean getDoNotOperateOnPerClassBasis() {
        return m_doNotOperateOnPerClassBasis;
    }

    /**
     * Set the DoNotOperateOnPerClassBasis value.
     *
     * @param newDoNotOperateOnPerClassBasis The new DoNotOperateOnPerClassBasis
     * value.
     */
    public void setDoNotOperateOnPerClassBasis(boolean newDoNotOperateOnPerClassBasis) {
        this.m_doNotOperateOnPerClassBasis = newDoNotOperateOnPerClassBasis;
    }

    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.enableAllAttributes();
        result.enableAllClasses();
        result.enable(Capabilities.Capability.NO_CLASS); //// filter doesn't need class to be set//
        return result;
    }

    @Override
    public String globalInfo() {
        return "Creates a bag of words for a given string attribute. The values in the bag are selected from the provided dictionary.";
    }

    @Override
    protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
        if (getStringAttribute() == null) {
            throw new IllegalStateException("String attribute name not set");
        }

        stringAttributeIndex = inputFormat.attribute(getStringAttribute()).index();

        inputFormat = getInputFormat();
        //This generates m_selectedTerms and m_DocsCounts
        int[] docsCountsByTermIdx = determineDictionary(inputFormat);

        //Initialize the output format to be just like the input
        Instances outputFormat = new Instances(inputFormat, 0);

        //Set up the map from attr index to document frequency
        m_DocsCounts = new int[m_selectedTerms.size()];
        //And add the new attributes
        for (int i = 0; i < m_selectedTerms.size(); i++) {
            int attrIdx = outputFormat.numAttributes();
            int docsCount = docsCountsByTermIdx[i];
            m_DocsCounts[i] = docsCount;

            outputFormat.insertAttributeAt(new Attribute(m_Prefix + m_selectedTerms.get(i)), attrIdx);
        }

        return outputFormat;
    }

    private class Count {

        public int count = 0;
        public int docCount = 0;

        public Count(int count) {
            this.count = count;
        }
    }

    /**
     * sorts an array.
     *
     * @param array the array to sort
     */
    private static void sortArray(int[] array) {

        int i, j, h, N = array.length - 1;

        for (h = 1; h <= N / 9; h = 3 * h + 1)
            ;

        for (; h > 0; h /= 3) {
            for (i = h + 1; i <= N; i++) {
                int v = array[i];
                j = i;
                while (j > h && array[j - h] > v) {
                    array[j] = array[j - h];
                    j -= h;
                }
                array[j] = v;
            }
        }
    }

    private int[] determineDictionary(Instances instances) {
        if (stringAttributeIndex < 0) {
            throw new IllegalStateException("String attribute index not valid");
        }

        // Operate on a per-class basis if class attribute is set
        int classInd = instances.classIndex();
        int values = 1;
        if (!m_doNotOperateOnPerClassBasis && (classInd != -1)) {
            values = instances.attribute(classInd).numValues();
        }

        HashMap<String, Integer> termIndices = new HashMap<String, Integer>();
        for (int i = 0; i < termList.size(); i++) {
            termIndices.put(termList.get(i), i);
        }

        //Create the trie for matching terms
        Trie termTrie = new Trie(termList);

        //Initialize the dictionary/count map
        ArrayList<HashMap<Integer, Count>> termCounts = new ArrayList<HashMap<Integer, Count>>();
        for (int z = 0; z < values; z++) {
            termCounts.add(new HashMap<Integer, Count>());
        }

        //Go through all the instances and count the emoticons
        for (int i = 0; i < instances.numInstances(); i++) {
            Instance instance = instances.instance(i);
            int vInd = 0;
            if (!m_doNotOperateOnPerClassBasis && (classInd != -1)) {
                vInd = (int) instance.classValue();
            }

            //Get the string attribute to examine
            String stringValue = instance.stringValue(stringAttributeIndex);

            HashMap<Integer, Count> termCountsForClass = termCounts.get(vInd);

            HashMap<String, Integer> termMatches = termTrie.countNonoverlappingMatches(stringValue);
            for (Map.Entry<String, Integer> entry : termMatches.entrySet()) {
                String term = entry.getKey();
                int termIdx = termIndices.get(term);

                int matches = entry.getValue();

                Count count = termCountsForClass.get(termIdx);
                if (count == null) {
                    count = new Count(0);
                    termCountsForClass.put(termIdx, count);
                }

                if (matches > 0) {
                    count.docCount += 1;
                    count.count += matches;
                }
            }
        }

        // Figure out the minimum required word frequency
        int prune[] = new int[values];
        for (int z = 0; z < values; z++) {
            HashMap<Integer, Count> termCountsForClass = termCounts.get(z);

            int array[] = new int[termCountsForClass.size()];
            int pos = 0;
            for (Map.Entry<Integer, Count> entry : termCountsForClass.entrySet()) {
                array[pos] = entry.getValue().count;
                pos++;
            }

            // sort the array
            sortArray(array);

            if (array.length < m_WordsToKeep) {
                // if there aren't enough words, set the threshold to
                // minFreq
                prune[z] = m_minTermFreq;
            } else {
                // otherwise set it to be at least minFreq
                prune[z] = Math.max(m_minTermFreq, array[array.length - m_WordsToKeep]);
            }
        }

        // Add the word vector attributes (eliminating duplicates
        // that occur in multiple classes)
        HashSet<String> selectedTerms = new HashSet<String>();
        for (int z = 0; z < values; z++) {
            HashMap<Integer, Count> termCountsForClass = termCounts.get(z);

            for (Map.Entry<Integer, Count> entry : termCountsForClass.entrySet()) {
                int termIndex = entry.getKey();
                String term = termList.get(termIndex);
                Count count = entry.getValue();
                if (count.count >= prune[z]) {
                    selectedTerms.add(term);
                }
            }
        }

        //Save the selected terms as a list
        this.m_selectedTerms = new ArrayList<String>(selectedTerms);
        this.m_selectedTermsTrie = new Trie(this.m_selectedTerms);
        this.m_NumInstances = instances.size();

        //Construct the selected terms to index map
        this.m_selectedTermIndices = new HashMap<String, Integer>();
        for (int i = 0; i < m_selectedTerms.size(); i++) {
            m_selectedTermIndices.put(m_selectedTerms.get(i), i);
        }

        // Compute document frequencies, organized by selected term index (not original term index)
        int[] docsCounts = new int[m_selectedTerms.size()];
        for (int i = 0; i < m_selectedTerms.size(); i++) {
            String term = m_selectedTerms.get(i);
            int termIndex = termIndices.get(term);
            int docsCount = 0;
            for (int z = 0; z < values; z++) {
                HashMap<Integer, Count> termCountsForClass = termCounts.get(z);

                Count count = termCountsForClass.get(termIndex);
                if (count != null) {
                    docsCount += count.docCount;
                }
            }
            docsCounts[i] = docsCount;
        }
        return docsCounts;
    }

    /**
     * Converts the instance w/o normalization.
     *
     * @param instance the instance to convert
     *
     * @param ArrayList<Instance> the list of instances
     * @return the document length
     */
    private double convertInstancewoDocNorm(Instance instance, ArrayList<Instance> converted) {
        if (stringAttributeIndex < 0) {
            throw new IllegalStateException("String attribute index not valid");
        }

        int numOldValues = instance.numAttributes();
        double[] newValues = new double[numOldValues + m_selectedTerms.size()];

        // Copy all attributes from input to output
        for (int i = 0; i < getInputFormat().numAttributes(); i++) {
            if (getInputFormat().attribute(i).type() != Attribute.STRING) {
                // Add simple nominal and numeric attributes directly
                if (instance.value(i) != 0.0) {
                    newValues[i] = instance.value(i);
                }
            } else {
                if (instance.isMissing(i)) {
                    newValues[i] = Utils.missingValue();
                } else {

                    // If this is a string attribute, we have to first add
                    // this value to the range of possible values, then add
                    // its new internal index.
                    if (outputFormatPeek().attribute(i).numValues() == 0) {
                        // Note that the first string value in a
                        // SparseInstance doesn't get printed.
                        outputFormatPeek().attribute(i).addStringValue("Hack to defeat SparseInstance bug");
                    }
                    int newIndex = outputFormatPeek().attribute(i).addStringValue(instance.stringValue(i));
                    newValues[i] = newIndex;
                }
            }
        }

        String stringValue = instance.stringValue(stringAttributeIndex);
        double docLength = 0;

        HashMap<String, Integer> termMatches = m_selectedTermsTrie.countNonoverlappingMatches(stringValue);
        for (Map.Entry<String, Integer> entry : termMatches.entrySet()) {
            String term = entry.getKey();
            int termIdx = m_selectedTermIndices.get(term);
            double matches = entry.getValue();
            if (!m_OutputCounts && matches > 0) {
                matches = 1;
            }

            if (matches > 0) {
                if (m_TFTransform == true) {
                    matches = Math.log(matches + 1);
                }

                if (m_IDFTransform == true) {
                    matches = matches * Math.log(m_NumInstances / (double) m_DocsCounts[termIdx]);
                }

                newValues[numOldValues + termIdx] = matches;
                docLength += matches * matches;
            }
        }

        Instance result = new SparseInstance(instance.weight(), newValues);
        converted.add(result);

        return Math.sqrt(docLength);
    }

    /**
     * Normalizes given instance to average doc length (only the newly
     * constructed attributes).
     *
     * @param inst   the instance to normalize
     * @param double the document length
     * @throws Exception if avg. doc length not set
     */
    private void normalizeInstance(Instance inst, double docLength) throws Exception {

        if (docLength == 0) {
            return;
        }

        int numOldValues = getInputFormat().numAttributes();

        if (m_AvgDocLength < 0) {
            throw new Exception("Average document length not set.");
        }

        // Normalize document vector
        for (int j = numOldValues; j < inst.numAttributes(); j++) {
            double val = inst.value(j) * m_AvgDocLength / docLength;
            inst.setValue(j, val);
        }
    }

    @Override
    protected Instances process(Instances instances) throws Exception {
        Instances result = new Instances(getOutputFormat(), 0);

        // Convert all instances w/o normalization
        ArrayList<Instance> converted = new ArrayList<Instance>();
        ArrayList<Double> docLengths = new ArrayList<Double>();
        if (!isFirstBatchDone()) {
            m_AvgDocLength = 0;
        }
        for (int i = 0; i < instances.size(); i++) {
            double docLength = convertInstancewoDocNorm(instances.instance(i), converted);

            // Need to compute average document length if necessary
            if (m_filterType != FILTER_NONE) {
                if (!isFirstBatchDone()) {
                    m_AvgDocLength += docLength;
                }
                docLengths.add(docLength);
            }
        }
        if (m_filterType != FILTER_NONE) {

            if (!isFirstBatchDone()) {
                m_AvgDocLength /= instances.size();
            }

            // Perform normalization if necessary.
            if (isFirstBatchDone() || (!isFirstBatchDone() && m_filterType == FILTER_NORMALIZE_ALL)) {
                for (int i = 0; i < converted.size(); i++) {
                    normalizeInstance(converted.get(i), docLengths.get(i));
                }
            }
        }
        // Push all instances into the output queue
        for (int i = 0; i < converted.size(); i++) {
            result.add(converted.get(i));
        }

        return result;
    }

    private static class Trie implements Serializable {

        private static class TrieNode implements Serializable {

            boolean exists = false;
            HashMap<Character, TrieNode> branches = new HashMap<Character, Trie.TrieNode>();
        }

        TrieNode root = new TrieNode();

        public Trie() {
        }

        public Trie(Collection<String> tokens) {
            for (String token : tokens) {
                this.add(token);
            }
        }

        public void add(String term) {
            TrieNode currentNode = this.root;
            for (int i = 0; i < term.length(); i++) {
                char c = term.charAt(i);
                TrieNode next = currentNode.branches.get(c);
                if (next == null) {
                    next = new TrieNode();
                    currentNode.branches.put(c, next);
                }
                currentNode = next;
            }
            currentNode.exists = true;
        }

        public boolean contains(String term) {
            TrieNode currentNode = this.root;
            for (int i = 0; i < term.length(); i++) {
                char c = term.charAt(i);
                TrieNode next = currentNode.branches.get(c);
                if (next == null) {
                    return false;
                }
                currentNode = next;
            }
            return currentNode.exists;
        }

        /**
         * Finds the longest substring in the Trie that matches haystack
         * starting from the given index.
         *
         * @param haystack
         * @return
         */
        public String getLongestMatch(String haystack, int startingFrom) {
            TrieNode currentNode = this.root;
            String longestMatch = null;
            String workingString = "";
            for (int i = startingFrom; i < haystack.length(); i++) {
                char c = haystack.charAt(i);
                TrieNode next = currentNode.branches.get(c);
                if (next == null) {
                    return longestMatch;
                }
                currentNode = next;
                //Build up the next match
                workingString += c;

                if (currentNode.exists) {
                    //This is the best match so far
                    longestMatch = workingString;
                }
            }
            return longestMatch;
        }

        /**
         * Matches all Trie terms against the haystack, not overlapping. Returns
         * the matched words and the number of times they occurred.
         */
        public HashMap<String, Integer> countNonoverlappingMatches(String haystack) {
            HashMap<String, Integer> matchCounts = new HashMap<String, Integer>();

            //Go through the string greedily
            for (int i = 0; i < haystack.length();) {
                String longestMatch = getLongestMatch(haystack, i);
                if (longestMatch != null) {
                    if (!matchCounts.containsKey(longestMatch)) {
                        matchCounts.put(longestMatch, 1);
                    } else {
                        int count = matchCounts.get(longestMatch);
                        matchCounts.put(longestMatch, count + 1);
                    }

                    //Skip ahead by length
                    i += longestMatch.length();
                } else {
                    i++;
                }
            }

            return matchCounts;
        }
    }

    public static List<String> readDictionaryFile(File file) throws FileNotFoundException {
        //Read in the dictionary file
        HashSet<String> termSet = new HashSet<String>();

        Scanner dict = new Scanner(file);
        while (dict.hasNextLine()) {
            String line = dict.nextLine();
            if (!line.startsWith("### ")) {
                line = line.trim();
                if (!line.isEmpty()) {
                    termSet.add(line);
                }
            }
        }

        ArrayList<String> termList = new ArrayList<String>(termSet);
        return termList;
    }

    public static void main(String[] args) {

        //Create a test dataset
        ArrayList<Attribute> attributes = new ArrayList<Attribute>();
        attributes.add(new Attribute("message", (ArrayList<String>) null));
        attributes.add(new Attribute("id"));
        {
            ArrayList<String> classValues = new ArrayList<String>();
            classValues.add("0");
            classValues.add("1");
            attributes.add(new Attribute("class", classValues));
        }

        Instances instances = new Instances("test", attributes, 0);
        instances.setClassIndex(2);

        String[] messages = new String[] { "No emoticons here", "I have a smiley :)",
                "Two smileys and a frownie :) :) :(", "Several emoticons :( :-( :) :-) ;-) 8-) :-/ :-P" };

        for (int i = 0; i < messages.length; i++) {
            Instance instance = new DenseInstance(instances.numAttributes());
            instance.setValue(instances.attribute(0), messages[i]);
            instance.setValue(instances.attribute(1), i);
            instance.setValue(instances.attribute(2), Integer.toString(i % 2));
            instances.add(instance);
        }

        System.out.println("Before filter:");
        for (int i = 0; i < instances.size(); i++) {
            System.out.println(instances.instance(i).toString());
        }

        try {
            String dictionaryName = "emoticons.txt";
            StringToDictionaryVector filter = new StringToDictionaryVector();
            List<String> termList = StringToDictionaryVector.readDictionaryFile(new File(dictionaryName));
            filter.setTermList(termList);
            filter.setMinTermFreq(1);
            filter.setTFTransform(true);
            filter.setIDFTransform(true);
            filter.setNormalizeDocLength(new SelectedTag(FILTER_NORMALIZE_TEST_ONLY, TAGS_FILTER));
            filter.setOutputWordCounts(true);
            filter.setStringAttribute("message");

            filter.setInputFormat(instances);
            Instances trans1 = Filter.useFilter(instances, filter);
            Instances trans2 = Filter.useFilter(instances, filter);

            System.out.println("\nFirst application:");
            System.out.println(trans1.toString());

            System.out.println("\nSecond application:");
            System.out.println(trans2.toString());

        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}