com.tamingtext.classifier.bayes.ExtractTrainingData.java Source code


Here is the source code for com.tamingtext.classifier.bayes.ExtractTrainingData.java

Source

/*
 * Copyright 2008-2011 Grant Ingersoll, Thomas Morton and Drew Farris
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 * -------------------
 * To purchase or learn more about Taming Text, by Grant Ingersoll, Thomas Morton and Drew Farris, visit
 * http://www.manning.com/ingersoll
 */

package com.tamingtext.classifier.bayes;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.TermFreqVector;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.math.map.OpenObjectIntHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** A utility to extract training data from a Lucene index, using document term vectors to recreate the list of terms
 *  found in each document. Output is written in the Mahout Bayes classifier input format: one training item per
 *  line, consisting of the category label, a tab, and the whitespace-separated terms of the document. */
public class ExtractTrainingData {

    private static final Logger log = LoggerFactory.getLogger(ExtractTrainingData.class);

    static final Map<String, PrintWriter> trainingWriters = new HashMap<String, PrintWriter>();

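    /** Parse the command-line options and run the extraction.
     *  <p>
     *  Example invocation (the paths and field names shown are hypothetical):
     *  <pre>
     *  java com.tamingtext.classifier.bayes.ExtractTrainingData \
     *      --dir path/to/lucene-index \
     *      --categories categories.txt \
     *      --output training-data \
     *      --category-fields category \
     *      --text-fields title,body \
     *      --use-term-vectors
     *  </pre>
     */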
    public static void main(String[] args) {

        log.info("Command-line arguments: " + Arrays.toString(args));

        DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
        ArgumentBuilder abuilder = new ArgumentBuilder();
        GroupBuilder gbuilder = new GroupBuilder();

        Option inputOpt = obuilder.withLongName("dir").withRequired(true)
                .withArgument(abuilder.withName("dir").withMinimum(1).withMaximum(1).create())
                .withDescription("Lucene index directory containing input data").withShortName("d").create();

        Option categoryOpt = obuilder.withLongName("categories").withRequired(true)
                .withArgument(abuilder.withName("file").withMinimum(1).withMaximum(1).create())
                .withDescription("File containing a list of categories").withShortName("c").create();

        Option outputOpt = obuilder.withLongName("output").withRequired(true)
                .withArgument(abuilder.withName("output").withMinimum(1).withMaximum(1).create())
                .withDescription("Output directory").withShortName("o").create();

        Option categoryFieldsOpt = obuilder.withLongName("category-fields").withRequired(true)
                .withArgument(abuilder.withName("fields").withMinimum(1).withMaximum(1).create())
                .withDescription("Fields to match categories against (comma-delimited)").withShortName("cf")
                .create();

        Option textFieldsOpt = obuilder.withLongName("text-fields").withRequired(true)
                .withArgument(abuilder.withName("fields").withMinimum(1).withMaximum(1).create())
                .withDescription("Fields from which to extract training text (comma-delimited)").withShortName("tf")
                .create();

        Option useTermVectorsOpt = obuilder.withLongName("use-term-vectors").withDescription(
                "Extract term vectors containing preprocessed data instead of unprocessed, stored text values")
                .withShortName("tv").create();

        Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
                .create();

        Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(categoryOpt)
                .withOption(outputOpt).withOption(categoryFieldsOpt).withOption(textFieldsOpt)
                .withOption(useTermVectorsOpt).withOption(helpOpt).create();

        try {
            Parser parser = new Parser();
            parser.setGroup(group);
            CommandLine cmdLine = parser.parse(args);

            if (cmdLine.hasOption(helpOpt)) {
                CommandLineUtil.printHelp(group);
                return;
            }

            File inputDir = new File(cmdLine.getValue(inputOpt).toString());

            if (!inputDir.isDirectory()) {
                throw new IllegalArgumentException(inputDir + " does not exist or is not a directory");
            }

            File categoryFile = new File(cmdLine.getValue(categoryOpt).toString());

            if (!categoryFile.isFile()) {
                throw new IllegalArgumentException(categoryFile + " does not exist or is not a file");
            }

            File outputDir = new File(cmdLine.getValue(outputOpt).toString());

            outputDir.mkdirs();

            if (!outputDir.isDirectory()) {
                throw new IllegalArgumentException(outputDir + " is not a directory or could not be created");
            }

            Collection<String> categoryFields = stringToList(cmdLine.getValue(categoryFieldsOpt).toString());

            if (categoryFields.size() < 1) {
                throw new IllegalArgumentException("At least one category field must be specified.");
            }

            Collection<String> textFields = stringToList(cmdLine.getValue(textFieldsOpt).toString());

            if (textFields.size() < 1) {
                throw new IllegalArgumentException("At least one text field must be specified.");
            }

            boolean useTermVectors = cmdLine.hasOption(useTermVectorsOpt);

            extractTrainingData(inputDir, categoryFile, categoryFields, textFields, outputDir, useTermVectors);

        } catch (OptionException e) {
            log.error("Exception", e);
            CommandLineUtil.printHelp(group);
        } catch (IOException e) {
            log.error("IOException", e);
        } finally {
            closeWriters();
        }
    }

    /**
     * Extract training data from a Lucene index.
     * <p>
     * Iterates over the documents in the Lucene index and inspects the values of the categoryFields. If a value
     * contains any of the strings listed in the category file, a training data item is emitted, assigned to the
     * matching category and containing the terms found in the fields listed in textFields. Output is written to
     * the output directory, one file per category.
     * <p>
     * The category file contains one line per category; each line holds a number of whitespace-delimited strings.
     * The first string on each line is the category name, while the subsequent strings are used to identify documents
     * that belong in that category.
     * <p>
     * For example, the line 'Technology Computers Macintosh' will cause documents that contain 'Technology',
     * 'Computers', or 'Macintosh' in one of their categoryFields to be assigned to the 'Technology' category.
     * Matching is case-insensitive: both the category strings and the field values are lowercased before comparison.
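     * <p>
     * A small example category file (the categories and match strings here are hypothetical):
     * <pre>
     * technology computers macintosh software
     * sports baseball basketball hockey
     * </pre>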
     * 
     * 
     * @param indexDir 
     *   directory of lucene index to extract from
     *   
     * @param categoryFile
     *   file containing category strings to extract
     *   
     * @param categoryFields
     *   list of fields to match against category data
     *   
     * @param textFields
     *   list of fields containing terms to extract
     *   
     * @param outputDir
     *   directory to write output to
     *   
     * @param useTermVectors
     *   if true, extract preprocessed terms from document term vectors; otherwise use unprocessed, stored field text
     *   
     * @throws IOException
     */
    public static void extractTrainingData(File indexDir, File categoryFile, Collection<String> categoryFields,
            Collection<String> textFields, File outputDir, boolean useTermVectors) throws IOException {

        log.info("Index dir: " + indexDir);
        log.info("Category file: " + categoryFile);
        log.info("Output dir: " + outputDir);
        log.info("Category fields: " + categoryFields.toString());
        log.info("Text fields: " + textFields.toString());
        log.info("Use Term Vectors?: " + useTermVectors);
        OpenObjectIntHashMap<String> categoryCounts = new OpenObjectIntHashMap<String>();
        Map<String, List<String>> categories = readCategoryFile(categoryFile);

        Directory dir = FSDirectory.open(indexDir);
        IndexReader reader = IndexReader.open(dir, true);
        int max = reader.maxDoc();

        StringBuilder buf = new StringBuilder();

        for (int i = 0; i < max; i++) {
            if (!reader.isDeleted(i)) {
                Document d = reader.document(i);
                String category = null;

                // determine whether any of the fields in this document contain a 
                // category in the category list
                fields: for (String field : categoryFields) {
                    for (Field f : d.getFields(field)) {
                        if (f.isStored() && !f.isBinary()) {
                            String fieldValue = f.stringValue().toLowerCase();
                            for (Map.Entry<String, List<String>> cat : categories.entrySet()) {
                                for (String c : cat.getValue()) {
                                    if (fieldValue.contains(c)) {
                                        category = cat.getKey();
                                        break fields;
                                    }
                                }
                            }
                        }
                    }
                }

                if (category == null)
                    continue;

                // append the terms from each of the textFields to the training data for this document.
                buf.setLength(0);
                for (String field : textFields) {
                    if (useTermVectors) {
                        appendVectorTerms(buf, reader.getTermFreqVector(i, field));
                    } else {
                        appendFieldText(buf, d.getField(field));
                    }
                }
                getWriterForCategory(outputDir, category).printf("%s\t%s\n", category, buf.toString());
                categoryCounts.adjustOrPutValue(category, 1, 1);
            }
        }

        if (log.isInfoEnabled()) {
            StringBuilder b = new StringBuilder();
            b.append("\nCategory document counts:\n");
            LinkedList<String> keyList = new LinkedList<String>();
            categoryCounts.keysSortedByValue(keyList);
            String key;
            while (!keyList.isEmpty()) {
                key = keyList.removeLast();
                b.append(categoryCounts.get(key)).append('\t').append(key).append('\n');
            }
            log.info(b.toString());
        }

        // release the index; the per-category output writers are closed separately by closeWriters()
        reader.close();
        dir.close();
    }

    /** Read the category file from disk; see {@link #extractTrainingData(File, File, Collection, Collection, File, boolean)}
     *  for a description of the format.
     * 
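     * For example, the (hypothetical) line {@code technology computers macintosh} yields the map entry
     * {@code technology -> [technology, computers, macintosh]}: the category name itself is included as a
     * match string, and all input is lowercased.
     * 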
     * @param categoryFile
     *   file containing the category definitions, one category per line
     * @return a map from each category name to the list of strings that identify documents in that category
     * @throws IOException
     */
    public static Map<String, List<String>> readCategoryFile(File categoryFile) throws IOException {
        Map<String, List<String>> categoryMap = new HashMap<String, List<String>>();
        BufferedReader rin = new BufferedReader(new InputStreamReader(new FileInputStream(categoryFile), "UTF-8"));
        String line;
        while ((line = rin.readLine()) != null) {
            String[] parts = line.trim().toLowerCase().split("\\s+");
            // splitting an empty line yields a single empty token, so check for it explicitly
            if (parts.length > 0 && !parts[0].isEmpty()) {
                String key = parts[0];
                List<String> entries = categoryMap.get(key);
                if (entries == null) {
                    entries = new LinkedList<String>();
                    categoryMap.put(key, entries);
                }
                // the category name itself is also included as a match string
                for (String e : parts) {
                    entries.add(e);
                }
            }
        }
        rin.close();
        return categoryMap;
    }

    /** Obtain a writer for the training data assigned to the specified category.
     * <p>
     * Maintains an internal map of per-category writers, which must be closed by {@link #closeWriters()}.
     * 
     * @param outputDir
     *   directory in which the output file is created
     * @param category
     *   category name, also used as the name of the output file
     * @return the writer for this category's training data
     * @throws IOException
     */
    protected static PrintWriter getWriterForCategory(File outputDir, String category) throws IOException {
        PrintWriter out = trainingWriters.get(category);
        if (out == null) {
            out = new PrintWriter(new OutputStreamWriter(new FileOutputStream(new File(outputDir, category)), "UTF-8"));
            trainingWriters.put(category, out);
        }
        return out;
    }

    /** Close writers opened by {@link #getWriterForCategory(File, String)} */
    protected static void closeWriters() {
        for (PrintWriter p : trainingWriters.values()) {
            p.close();
        }
    }

    /** Append the contents of the specified term vector to a buffer as a whitespace-separated list of terms,
     *  repeating each term once per occurrence so that term frequencies are preserved.
     * 
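     * For example, (hypothetical) terms {@code [cat, dog]} with frequencies {@code [2, 1]} append
     * {@code "cat cat dog "} to the buffer.
     * 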
     * @param buf
     *   buffer to append terms to
     * @param tv
     *   term vector to read terms and frequencies from; may be null, in which case nothing is appended
     */
    protected static void appendVectorTerms(StringBuilder buf, TermFreqVector tv) {
        if (tv == null)
            return;

        String[] terms = tv.getTerms();
        int[] frequencies = tv.getTermFrequencies();

        for (int j = 0; j < terms.length; j++) {
            int freq = frequencies[j];
            String term = terms[j];
            for (int k = 0; k < freq; k++) {
                buf.append(term).append(' ');
            }
        }
    }

    /** Append the contents of the specified field to a buffer of text,
     *  normalizing whitespace in the process. Null, binary, and unstored fields are skipped.
     *  
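     * For example, a (hypothetical) stored value of {@code "two\n  lines"} is appended as {@code "two lines"}.
     * 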
     * @param buf
     *   buffer to append the normalized field text to
     * @param f
     *   field whose stored string value is appended; may be null
     */
    protected static void appendFieldText(StringBuilder buf, Field f) {
        if (f == null)
            return;
        if (f.isBinary())
            return;
        if (!f.isStored())
            return;
        if (buf.length() > 0)
            buf.append(' ');

        String s = f.stringValue();
        s = s.replaceAll("\\s+", " "); // normalize whitespace.
        buf.append(s);
    }

    /** Split a comma-delimited string into a list, trimming whitespace around each delimiter.
     * 
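     * For example, the (hypothetical) input {@code "title, body ,text"} yields {@code [title, body, text]}.
     * 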
     * @param input
     *   comma-delimited string, possibly null or empty
     * @return the list of substrings; an empty list if the input is null or empty
     */
    private static Collection<String> stringToList(String input) {
        if (input == null || input.equals(""))
            return Collections.emptyList();
        String[] parts = input.split("\\s*,\\s*");
        return Arrays.asList(parts);
    }

}