Java tutorial: StringToDictionaryVector, a dictionary-based text vectorizer for Weka
/*
 * This file is part of ALOE.
 *
 * ALOE is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * ALOE is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with ALOE. If not, see <http://www.gnu.org/licenses/>.
 *
 * Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl)
 */
package etc.aloe.filters;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.Serializable;
import java.util.*;
import weka.core.*;
import weka.filters.Filter;
import weka.filters.SimpleBatchFilter;

/**
 * Based on Weka's StringToWordVector, this filter derives its initial word
 * list from a provided lexicon instead of from the string attribute itself.
 * Filtering of terms is performed similarly to StringToWordVector.
 *
 * @author Michael Brooks <mjbrooks@uw.edu>
 */
public class StringToDictionaryVector extends SimpleBatchFilter {

    private int stringAttributeIndex = -1;
    private String stringAttribute;
    List<String> termList;

    /**
     * Contains the number of documents (instances) in the input format from
     * which the dictionary is created. It is used in the IDF transform.
     */
    private int m_NumInstances = -1;

    /**
     * Whether to operate on a per-class basis.
     */
    private boolean m_doNotOperateOnPerClassBasis = false;

    /**
     * The default number of words (per class if there is a class attribute
     * assigned) to attempt to keep.
     */
    private int m_WordsToKeep = 1000;

    /**
     * The minimum (per-class) word frequency.
     */
    private int m_minTermFreq = 1;

    /**
     * Contains the number of documents (instances) a particular word appears
     * in. The counts are stored with the same indexing as m_selectedTerms.
     */
    private int[] m_DocsCounts;

    /**
     * A String prefix for the attribute names.
     */
    private String m_Prefix = "";

    /**
     * The set of terms that occurred frequently enough to be included as
     * attributes.
     */
    private ArrayList<String> m_selectedTerms;

    /**
     * The trie containing the selected terms for matching.
     */
    private Trie m_selectedTermsTrie;

    /**
     * Maps the terms to indices in m_selectedTerms.
     */
    private HashMap<String, Integer> m_selectedTermIndices;

    /**
     * True if word frequencies should be transformed into log(1+fi), where fi
     * is the frequency of word i.
     */
    private boolean m_TFTransform;

    /**
     * True if word frequencies should be transformed into
     * fij*log(numOfDocs/numOfDocsWithWordi).
     */
    private boolean m_IDFTransform;

    /**
     * True if output instances should contain word frequency rather than
     * boolean 0 or 1.
     */
    private boolean m_OutputCounts = false;

    /**
     * The normalization to apply.
     */
    protected int m_filterType = FILTER_NONE;

    /**
     * normalization: No normalization.
     */
    public static final int FILTER_NONE = 0;

    /**
     * normalization: Normalize all data.
     */
    public static final int FILTER_NORMALIZE_ALL = 1;

    /**
     * normalization: Normalize test data only.
     */
    public static final int FILTER_NORMALIZE_TEST_ONLY = 2;

    /**
     * Specifies whether a document's (instance's) word frequencies are to be
     * normalized. They are normalized to the average length of the documents
     * specified as input format.
     */
    public static final Tag[] TAGS_FILTER = {
        new Tag(FILTER_NONE, "No normalization"),
        new Tag(FILTER_NORMALIZE_ALL, "Normalize all data"),
        new Tag(FILTER_NORMALIZE_TEST_ONLY, "Normalize test data only")
    };

    /**
     * Contains the average length of documents (among the first batch of
     * instances, aka training data). This is used in length normalization of
     * documents, which will be normalized to the average document length.
     */
    private double m_AvgDocLength = -1;
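    // Illustrative sketch (not part of the original source): how the TF and
    // IDF transforms above combine, for hypothetical numbers. Suppose word i
    // occurs fij = 3 times in document j, the first batch held 10 documents,
    // and word i appeared in 5 of them. Then:
    //   TF only:      log(1 + 3)               ~= 1.386
    //   IDF only:     3 * log(10 / 5)          ~= 2.079
    //   TF then IDF:  log(1 + 3) * log(10 / 5) ~= 0.961
    // With m_OutputCounts == false, fij is first clamped to 0/1 before the
    // transforms are applied (see convertInstancewoDocNorm below).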
    /**
     * Gets the type of document-length normalization to apply to the word
     * frequencies of a document (instance).
     *
     * @return the selected normalization type.
     */
    public SelectedTag getNormalizeDocLength() {
        return new SelectedTag(m_filterType, TAGS_FILTER);
    }

    /**
     * Sets the type of document-length normalization to apply to the word
     * frequencies of a document (instance).
     *
     * @param newType the new normalization type.
     */
    public void setNormalizeDocLength(SelectedTag newType) {
        if (newType.getTags() == TAGS_FILTER) {
            m_filterType = newType.getSelectedTag().getID();
        }
    }

    /**
     * Gets whether the word frequencies should be transformed into log(1+fij),
     * where fij is the frequency of word i in document (instance) j.
     *
     * @return true if word frequencies are to be transformed.
     */
    public boolean getTFTransform() {
        return this.m_TFTransform;
    }

    /**
     * Sets whether the word frequencies should be transformed into log(1+fij),
     * where fij is the frequency of word i in document (instance) j.
     *
     * @param TFTransform true if word frequencies are to be transformed.
     */
    public void setTFTransform(boolean TFTransform) {
        this.m_TFTransform = TFTransform;
    }

    /**
     * Gets whether the word frequencies in a document should be transformed
     * into: <br> fij*log(num of Docs/num of Docs with word i) <br> where fij
     * is the frequency of word i in document (instance) j.
     *
     * @return true if the word frequencies are to be transformed.
     */
    public boolean getIDFTransform() {
        return this.m_IDFTransform;
    }

    /**
     * Sets whether the word frequencies in a document should be transformed
     * into: <br> fij*log(num of Docs/num of Docs with word i) <br> where fij
     * is the frequency of word i in document (instance) j.
     *
     * @param IDFTransform true if the word frequencies are to be transformed.
     */
    public void setIDFTransform(boolean IDFTransform) {
        this.m_IDFTransform = IDFTransform;
    }

    /**
     * Gets whether output instances contain 0 or 1 indicating word presence,
     * or word counts.
     *
     * @return true if word counts should be output.
     */
    public boolean getOutputWordCounts() {
        return m_OutputCounts;
    }

    /**
     * Sets whether output instances contain 0 or 1 indicating word presence,
     * or word counts.
     *
     * @param outputWordCounts true if word counts should be output.
     */
    public void setOutputWordCounts(boolean outputWordCounts) {
        m_OutputCounts = outputWordCounts;
    }

    /**
     * Get the attribute name prefix.
     *
     * @return The current attribute name prefix.
     */
    public String getAttributeNamePrefix() {
        return m_Prefix;
    }

    /**
     * Set the attribute name prefix.
     *
     * @param newPrefix String to use as the attribute name prefix.
     */
    public void setAttributeNamePrefix(String newPrefix) {
        m_Prefix = newPrefix;
    }

    /**
     * Get the MinTermFreq value.
     *
     * @return the MinTermFreq value.
     */
    public int getMinTermFreq() {
        return m_minTermFreq;
    }

    /**
     * Set the MinTermFreq value.
     *
     * @param newMinTermFreq The new MinTermFreq value.
     */
    public void setMinTermFreq(int newMinTermFreq) {
        this.m_minTermFreq = newMinTermFreq;
    }

    /**
     * Gets the number of words (per class if there is a class attribute
     * assigned) to attempt to keep.
     *
     * @return the target number of words in the output vector (per class if
     * assigned).
     */
    public int getWordsToKeep() {
        return m_WordsToKeep;
    }

    /**
     * Sets the number of words (per class if there is a class attribute
     * assigned) to attempt to keep.
     *
     * @param newWordsToKeep the target number of words in the output vector
     * (per class if assigned).
     */
    public void setWordsToKeep(int newWordsToKeep) {
        m_WordsToKeep = newWordsToKeep;
    }

    public String getStringAttribute() {
        return stringAttribute;
    }

    public void setStringAttribute(String name) {
        stringAttribute = name;
    }

    public List<String> getTermList() {
        return termList;
    }

    public void setTermList(List<String> termList) {
        this.termList = termList;
    }

    /**
     * Get the DoNotOperateOnPerClassBasis value.
     *
     * @return the DoNotOperateOnPerClassBasis value.
     */
    public boolean getDoNotOperateOnPerClassBasis() {
        return m_doNotOperateOnPerClassBasis;
    }

    /**
     * Set the DoNotOperateOnPerClassBasis value.
     *
     * @param newDoNotOperateOnPerClassBasis The new DoNotOperateOnPerClassBasis
     * value.
     */
    public void setDoNotOperateOnPerClassBasis(boolean newDoNotOperateOnPerClassBasis) {
        this.m_doNotOperateOnPerClassBasis = newDoNotOperateOnPerClassBasis;
    }

    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.enableAllAttributes();
        result.enableAllClasses();
        result.enable(Capabilities.Capability.NO_CLASS); // filter doesn't need class to be set
        return result;
    }

    @Override
    public String globalInfo() {
        return "Creates a bag of words for a given string attribute. The values in the bag are selected from the provided dictionary.";
    }
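    // Illustrative sketch (not in the original source): the normalization mode
    // is chosen by wrapping one of the FILTER_* constants in a SelectedTag, as
    // main() below also demonstrates:
    //
    //   filter.setNormalizeDocLength(new SelectedTag(FILTER_NORMALIZE_ALL, TAGS_FILTER));
    //
    // FILTER_NONE skips normalization entirely; FILTER_NORMALIZE_ALL scales
    // the first (training) batch and all later batches; and
    // FILTER_NORMALIZE_TEST_ONLY scales only batches after the first.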
    @Override
    protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
        if (getStringAttribute() == null) {
            throw new IllegalStateException("String attribute name not set");
        }

        stringAttributeIndex = inputFormat.attribute(getStringAttribute()).index();
        inputFormat = getInputFormat();

        // This generates m_selectedTerms and m_DocsCounts
        int[] docsCountsByTermIdx = determineDictionary(inputFormat);

        // Initialize the output format to be just like the input
        Instances outputFormat = new Instances(inputFormat, 0);

        // Set up the map from attribute index to document frequency
        m_DocsCounts = new int[m_selectedTerms.size()];

        // And add the new attributes
        for (int i = 0; i < m_selectedTerms.size(); i++) {
            int attrIdx = outputFormat.numAttributes();
            int docsCount = docsCountsByTermIdx[i];
            m_DocsCounts[i] = docsCount;
            outputFormat.insertAttributeAt(new Attribute(m_Prefix + m_selectedTerms.get(i)), attrIdx);
        }

        return outputFormat;
    }

    private class Count {

        public int count = 0;
        public int docCount = 0;

        public Count(int count) {
            this.count = count;
        }
    }
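    // Illustrative sketch (not in the original source): with input attributes
    // {message, id, class}, an empty prefix, and selected terms {":)", ":("},
    // determineOutputFormat appends one numeric attribute per selected term,
    // producing the format {message, id, class, :), :(}. Setting the prefix
    // to, say, "emote_" would yield emote_:) and emote_:( instead.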
    /**
     * Sorts an array in place.
     *
     * @param array the array to sort
     */
    private static void sortArray(int[] array) {
        int i, j, h, N = array.length - 1;

        for (h = 1; h <= N / 9; h = 3 * h + 1) {
        }

        for (; h > 0; h /= 3) {
            for (i = h + 1; i <= N; i++) {
                int v = array[i];
                j = i;
                while (j > h && array[j - h] > v) {
                    array[j] = array[j - h];
                    j -= h;
                }
                array[j] = v;
            }
        }
    }

    private int[] determineDictionary(Instances instances) {
        if (stringAttributeIndex < 0) {
            throw new IllegalStateException("String attribute index not valid");
        }

        // Operate on a per-class basis if class attribute is set
        int classInd = instances.classIndex();
        int values = 1;
        if (!m_doNotOperateOnPerClassBasis && (classInd != -1)) {
            values = instances.attribute(classInd).numValues();
        }

        HashMap<String, Integer> termIndices = new HashMap<String, Integer>();
        for (int i = 0; i < termList.size(); i++) {
            termIndices.put(termList.get(i), i);
        }

        // Create the trie for matching terms
        Trie termTrie = new Trie(termList);

        // Initialize the dictionary/count map
        ArrayList<HashMap<Integer, Count>> termCounts = new ArrayList<HashMap<Integer, Count>>();
        for (int z = 0; z < values; z++) {
            termCounts.add(new HashMap<Integer, Count>());
        }

        // Go through all the instances and count the dictionary terms
        for (int i = 0; i < instances.numInstances(); i++) {
            Instance instance = instances.instance(i);

            int vInd = 0;
            if (!m_doNotOperateOnPerClassBasis && (classInd != -1)) {
                vInd = (int) instance.classValue();
            }

            // Get the string attribute to examine
            String stringValue = instance.stringValue(stringAttributeIndex);

            HashMap<Integer, Count> termCountsForClass = termCounts.get(vInd);

            HashMap<String, Integer> termMatches = termTrie.countNonoverlappingMatches(stringValue);
            for (Map.Entry<String, Integer> entry : termMatches.entrySet()) {
                String term = entry.getKey();
                int termIdx = termIndices.get(term);
                int matches = entry.getValue();

                Count count = termCountsForClass.get(termIdx);
                if (count == null) {
                    count = new Count(0);
                    termCountsForClass.put(termIdx, count);
                }

                if (matches > 0) {
                    count.docCount += 1;
                    count.count += matches;
                }
            }
        }

        // Figure out the minimum required word frequency
        int[] prune = new int[values];
        for (int z = 0; z < values; z++) {
            HashMap<Integer, Count> termCountsForClass = termCounts.get(z);

            int[] array = new int[termCountsForClass.size()];
            int pos = 0;
            for (Map.Entry<Integer, Count> entry : termCountsForClass.entrySet()) {
                array[pos] = entry.getValue().count;
                pos++;
            }

            // sort the array
            sortArray(array);

            if (array.length < m_WordsToKeep) {
                // if there aren't enough words, set the threshold to minFreq
                prune[z] = m_minTermFreq;
            } else {
                // otherwise set it to be at least minFreq
                prune[z] = Math.max(m_minTermFreq, array[array.length - m_WordsToKeep]);
            }
        }

        // Add the word vector attributes (eliminating duplicates
        // that occur in multiple classes)
        HashSet<String> selectedTerms = new HashSet<String>();
        for (int z = 0; z < values; z++) {
            HashMap<Integer, Count> termCountsForClass = termCounts.get(z);
            for (Map.Entry<Integer, Count> entry : termCountsForClass.entrySet()) {
                int termIndex = entry.getKey();
                String term = termList.get(termIndex);
                Count count = entry.getValue();
                if (count.count >= prune[z]) {
                    selectedTerms.add(term);
                }
            }
        }

        // Save the selected terms as a list
        this.m_selectedTerms = new ArrayList<String>(selectedTerms);
        this.m_selectedTermsTrie = new Trie(this.m_selectedTerms);
        this.m_NumInstances = instances.size();

        // Construct the selected terms to index map
        this.m_selectedTermIndices = new HashMap<String, Integer>();
        for (int i = 0; i < m_selectedTerms.size(); i++) {
            m_selectedTermIndices.put(m_selectedTerms.get(i), i);
        }

        // Compute document frequencies, organized by selected term index
        // (not original term index)
        int[] docsCounts = new int[m_selectedTerms.size()];
        for (int i = 0; i < m_selectedTerms.size(); i++) {
            String term = m_selectedTerms.get(i);
            int termIndex = termIndices.get(term);
            int docsCount = 0;
            for (int z = 0; z < values; z++) {
                HashMap<Integer, Count> termCountsForClass = termCounts.get(z);
                Count count = termCountsForClass.get(termIndex);
                if (count != null) {
                    docsCount += count.docCount;
                }
            }
            docsCounts[i] = docsCount;
        }

        return docsCounts;
    }
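    // Illustrative sketch (not in the original source) of the pruning rule in
    // determineDictionary. Suppose one class has sorted term counts
    // {1, 2, 2, 3, 5}, with m_WordsToKeep = 2 and m_minTermFreq = 1. The
    // threshold is max(1, array[5 - 2]) = max(1, 3) = 3, so only terms with
    // count >= 3 (here the counts 3 and 5) survive for that class. Terms kept
    // for any class are kept overall, so the final combined list can exceed
    // m_WordsToKeep.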
    /**
     * Converts an instance without document-length normalization.
     *
     * @param instance the instance to convert
     * @param converted the list to which the converted instance is appended
     * @return the document length
     */
    private double convertInstancewoDocNorm(Instance instance, ArrayList<Instance> converted) {
        if (stringAttributeIndex < 0) {
            throw new IllegalStateException("String attribute index not valid");
        }

        int numOldValues = instance.numAttributes();
        double[] newValues = new double[numOldValues + m_selectedTerms.size()];

        // Copy all attributes from input to output
        for (int i = 0; i < getInputFormat().numAttributes(); i++) {
            if (getInputFormat().attribute(i).type() != Attribute.STRING) {
                // Add simple nominal and numeric attributes directly
                if (instance.value(i) != 0.0) {
                    newValues[i] = instance.value(i);
                }
            } else {
                if (instance.isMissing(i)) {
                    newValues[i] = Utils.missingValue();
                } else {
                    // If this is a string attribute, we have to first add
                    // this value to the range of possible values, then add
                    // its new internal index.
                    if (outputFormatPeek().attribute(i).numValues() == 0) {
                        // Note that the first string value in a
                        // SparseInstance doesn't get printed.
                        outputFormatPeek().attribute(i).addStringValue("Hack to defeat SparseInstance bug");
                    }
                    int newIndex = outputFormatPeek().attribute(i).addStringValue(instance.stringValue(i));
                    newValues[i] = newIndex;
                }
            }
        }

        String stringValue = instance.stringValue(stringAttributeIndex);
        double docLength = 0;

        HashMap<String, Integer> termMatches = m_selectedTermsTrie.countNonoverlappingMatches(stringValue);
        for (Map.Entry<String, Integer> entry : termMatches.entrySet()) {
            String term = entry.getKey();
            int termIdx = m_selectedTermIndices.get(term);
            double matches = entry.getValue();
            if (!m_OutputCounts && matches > 0) {
                matches = 1;
            }

            if (matches > 0) {
                if (m_TFTransform == true) {
                    matches = Math.log(matches + 1);
                }
                if (m_IDFTransform == true) {
                    matches = matches * Math.log(m_NumInstances / (double) m_DocsCounts[termIdx]);
                }
                newValues[numOldValues + termIdx] = matches;
                docLength += matches * matches;
            }
        }

        Instance result = new SparseInstance(instance.weight(), newValues);
        converted.add(result);

        return Math.sqrt(docLength);
    }

    /**
     * Normalizes a given instance to the average doc length (only the newly
     * constructed attributes).
     *
     * @param inst the instance to normalize
     * @param docLength the document length
     * @throws Exception if the average doc length is not set
     */
    private void normalizeInstance(Instance inst, double docLength) throws Exception {
        if (docLength == 0) {
            return;
        }

        int numOldValues = getInputFormat().numAttributes();

        if (m_AvgDocLength < 0) {
            throw new Exception("Average document length not set.");
        }

        // Normalize document vector
        for (int j = numOldValues; j < inst.numAttributes(); j++) {
            double val = inst.value(j) * m_AvgDocLength / docLength;
            inst.setValue(j, val);
        }
    }
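    // Illustrative sketch (not in the original source): if a converted
    // instance has new attribute values {3.0, 4.0} (all other new values 0),
    // convertInstancewoDocNorm returns docLength = sqrt(3^2 + 4^2) = 5. With
    // m_AvgDocLength = 10, normalizeInstance rescales each new value by
    // 10 / 5 = 2, yielding {6.0, 8.0}; the original attributes are untouched.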
    @Override
    protected Instances process(Instances instances) throws Exception {
        Instances result = new Instances(getOutputFormat(), 0);

        // Convert all instances w/o normalization
        ArrayList<Instance> converted = new ArrayList<Instance>();
        ArrayList<Double> docLengths = new ArrayList<Double>();
        if (!isFirstBatchDone()) {
            m_AvgDocLength = 0;
        }
        for (int i = 0; i < instances.size(); i++) {
            double docLength = convertInstancewoDocNorm(instances.instance(i), converted);

            // Need to compute average document length if necessary
            if (m_filterType != FILTER_NONE) {
                if (!isFirstBatchDone()) {
                    m_AvgDocLength += docLength;
                }
                docLengths.add(docLength);
            }
        }

        if (m_filterType != FILTER_NONE) {
            if (!isFirstBatchDone()) {
                m_AvgDocLength /= instances.size();
            }

            // Perform normalization if necessary.
            if (isFirstBatchDone() || (!isFirstBatchDone() && m_filterType == FILTER_NORMALIZE_ALL)) {
                for (int i = 0; i < converted.size(); i++) {
                    normalizeInstance(converted.get(i), docLengths.get(i));
                }
            }
        }

        // Push all instances into the output queue
        for (int i = 0; i < converted.size(); i++) {
            result.add(converted.get(i));
        }

        return result;
    }

    private static class Trie implements Serializable {

        private static class TrieNode implements Serializable {

            boolean exists = false;
            HashMap<Character, TrieNode> branches = new HashMap<Character, Trie.TrieNode>();
        }
        TrieNode root = new TrieNode();

        public Trie() {
        }

        public Trie(Collection<String> tokens) {
            for (String token : tokens) {
                this.add(token);
            }
        }

        public void add(String term) {
            TrieNode currentNode = this.root;

            for (int i = 0; i < term.length(); i++) {
                char c = term.charAt(i);
                TrieNode next = currentNode.branches.get(c);
                if (next == null) {
                    next = new TrieNode();
                    currentNode.branches.put(c, next);
                }
                currentNode = next;
            }

            currentNode.exists = true;
        }

        public boolean contains(String term) {
            TrieNode currentNode = this.root;

            for (int i = 0; i < term.length(); i++) {
                char c = term.charAt(i);
                TrieNode next = currentNode.branches.get(c);
                if (next == null) {
                    return false;
                }
                currentNode = next;
            }

            return currentNode.exists;
        }

        /**
         * Finds the longest term in the Trie that matches haystack starting
         * from the given index.
         *
         * @param haystack the string to search within
         * @param startingFrom the index to start matching at
         * @return the longest matching term, or null if none matches
         */
        public String getLongestMatch(String haystack, int startingFrom) {
            TrieNode currentNode = this.root;

            String longestMatch = null;
            String workingString = "";

            for (int i = startingFrom; i < haystack.length(); i++) {
                char c = haystack.charAt(i);
                TrieNode next = currentNode.branches.get(c);
                if (next == null) {
                    return longestMatch;
                }
                currentNode = next;

                // Build up the next match
                workingString += c;
                if (currentNode.exists) {
                    // This is the best match so far
                    longestMatch = workingString;
                }
            }

            return longestMatch;
        }

        /**
         * Matches all Trie terms against the haystack, without overlapping.
         * Returns the matched terms and the number of times each occurred.
         */
        public HashMap<String, Integer> countNonoverlappingMatches(String haystack) {
            HashMap<String, Integer> matchCounts = new HashMap<String, Integer>();

            // Go through the string greedily
            for (int i = 0; i < haystack.length();) {
                String longestMatch = getLongestMatch(haystack, i);
                if (longestMatch != null) {
                    if (!matchCounts.containsKey(longestMatch)) {
                        matchCounts.put(longestMatch, 1);
                    } else {
                        int count = matchCounts.get(longestMatch);
                        matchCounts.put(longestMatch, count + 1);
                    }
                    // Skip ahead by the match length
                    i += longestMatch.length();
                } else {
                    i++;
                }
            }

            return matchCounts;
        }
    }
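    // Illustrative sketch (not in the original source) of the greedy matching
    // above. With dictionary terms {":)", ":-)"} and haystack ":-) :) :)x",
    // getLongestMatch prefers the longer ":-)" at index 0, and
    // countNonoverlappingMatches then skips past each match, yielding
    // {":-)" -> 1, ":)" -> 2}. Matches never overlap because the scan index
    // jumps forward by the length of each match.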
    public static List<String> readDictionaryFile(File file) throws FileNotFoundException {
        // Read in the dictionary file
        HashSet<String> termSet = new HashSet<String>();
        Scanner dict = new Scanner(file);
        while (dict.hasNextLine()) {
            String line = dict.nextLine();
            if (!line.startsWith("### ")) {
                line = line.trim();
                if (!line.isEmpty()) {
                    termSet.add(line);
                }
            }
        }

        ArrayList<String> termList = new ArrayList<String>(termSet);
        return termList;
    }

    public static void main(String[] args) {

        // Create a test dataset
        ArrayList<Attribute> attributes = new ArrayList<Attribute>();
        attributes.add(new Attribute("message", (ArrayList<String>) null));
        attributes.add(new Attribute("id"));
        {
            ArrayList<String> classValues = new ArrayList<String>();
            classValues.add("0");
            classValues.add("1");
            attributes.add(new Attribute("class", classValues));
        }
        Instances instances = new Instances("test", attributes, 0);
        instances.setClassIndex(2);

        String[] messages = new String[]{
            "No emoticons here",
            "I have a smiley :)",
            "Two smileys and a frownie :) :) :(",
            "Several emoticons :( :-( :) :-) ;-) 8-) :-/ :-P"
        };

        for (int i = 0; i < messages.length; i++) {
            Instance instance = new DenseInstance(instances.numAttributes());
            instance.setValue(instances.attribute(0), messages[i]);
            instance.setValue(instances.attribute(1), i);
            instance.setValue(instances.attribute(2), Integer.toString(i % 2));
            instances.add(instance);
        }

        System.out.println("Before filter:");
        for (int i = 0; i < instances.size(); i++) {
            System.out.println(instances.instance(i).toString());
        }

        try {
            String dictionaryName = "emoticons.txt";
            StringToDictionaryVector filter = new StringToDictionaryVector();
            List<String> termList = StringToDictionaryVector.readDictionaryFile(new File(dictionaryName));
            filter.setTermList(termList);
            filter.setMinTermFreq(1);
            filter.setTFTransform(true);
            filter.setIDFTransform(true);
            filter.setNormalizeDocLength(new SelectedTag(FILTER_NORMALIZE_TEST_ONLY, TAGS_FILTER));
            filter.setOutputWordCounts(true);
            filter.setStringAttribute("message");

            filter.setInputFormat(instances);
            Instances trans1 = Filter.useFilter(instances, filter);
            Instances trans2 = Filter.useFilter(instances, filter);

            System.out.println("\nFirst application:");
            System.out.println(trans1.toString());

            System.out.println("\nSecond application:");
            System.out.println(trans2.toString());

        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
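// Illustrative sketch (not in the original source): readDictionaryFile expects
// one term per line and skips lines beginning with "### " as comments, so a
// hypothetical emoticons.txt for the demo in main() might look like:
//
//   ### smileys
//   :)
//   :-)
//   ### frownies
//   :(
//   :-(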