com.tamingtext.classifier.mlt.MoreLikeThisCategorizer.java Source code

Java tutorial

Introduction

Here is the source code for com.tamingtext.classifier.mlt.MoreLikeThisCategorizer.java

Source

/*
 * Copyright 2008-2011 Grant Ingersoll, Thomas Morton and Drew Farris
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 * -------------------
 * To purchase or learn more about Taming Text, by Grant Ingersoll, Thomas Morton and Drew Farris, visit
 * http://www.manning.com/ingersoll
 */

package com.tamingtext.classifier.mlt;

import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.shingle.ShingleAnalyzerWrapper;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Fieldable;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermEnum;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.similar.MoreLikeThis;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.tamingtext.classifier.mlt.TrainMoreLikeThis.MatchMode;

/**
 * Categorizes text by running a Lucene {@link MoreLikeThis} query against an index of
 * labelled documents and accumulating, per category, the scores of the documents returned.
 *
 * <p>The category of each indexed document is read from a stored field whose name is supplied
 * at construction time. The set of known categories is loaded once, in the constructor.
 *
 * <p>Note: the {@link MatchMode} is stored and exposed through its getter/setter, but it is not
 * consulted by {@link #categorize(Reader)} in this class.
 */
public class MoreLikeThisCategorizer {

    private static final Logger log = LoggerFactory.getLogger(MoreLikeThisCategorizer.class);

    MatchMode matchMode = MatchMode.TFIDF;
    IndexReader indexReader;
    IndexSearcher indexSearcher;
    MoreLikeThis moreLikeThis;
    String categoryFieldName;
    final Set<String> categories = new HashSet<String>();
    boolean captureCategories = false;
    int maxResults = 10;

    /**
     * Creates a categorizer over the given index and loads the known categories from it.
     *
     * <p>This constructor calls the overridable {@link #loadCategoriesFromIndex()}; a subclass
     * that overrides it must not rely on its own fields having been initialized at that point.
     *
     * @param indexReader reader over the index of labelled documents; it is not closed by this class
     * @param categoryFieldName name of the field that holds each document's category
     * @throws IOException if the index cannot be read while loading the categories
     */
    public MoreLikeThisCategorizer(IndexReader indexReader, String categoryFieldName) throws IOException {
        this.indexReader = indexReader;
        this.indexSearcher = new IndexSearcher(indexReader);
        this.moreLikeThis = new MoreLikeThis(indexReader);
        this.categoryFieldName = categoryFieldName;
        loadCategoriesFromIndex();
    }

    /**
     * Populates the set of categories from the '|'-separated value stored under
     * {@link TrainMoreLikeThis#CATEGORY_KEY} in the index commit user data. Falls back to
     * {@link #scanCategories()} when the user data, the key, or any non-empty category
     * name is missing.
     *
     * @throws IOException if the index cannot be read
     */
    protected void loadCategoriesFromIndex() throws IOException {
        Map<String, String> userData = indexReader.getCommitUserData();
        // The commit user data is null when the index was committed without any,
        // so guard against it instead of failing with a NullPointerException.
        String categoryString = (userData == null) ? null : userData.get(TrainMoreLikeThis.CATEGORY_KEY);
        if (categoryString == null) {
            scanCategories();
            return;
        }

        // String.split returns a single empty element for an empty input and may return
        // empty elements for leading or doubled separators; those are not category names.
        int loaded = 0;
        for (String part : categoryString.split("\\|")) {
            if (part.length() > 0) {
                categories.add(part);
                loaded++;
            }
        }

        if (loaded == 0) {
            scanCategories();
            return;
        }

        log.info("Loaded {} categories from index", categories.size());
    }

    /**
     * Populates the set of categories by enumerating the terms of the category field in the index.
     *
     * @throws IOException if the term enumeration cannot be read or closed
     */
    protected void scanCategories() throws IOException {
        TermEnum te = indexReader.terms(new Term(categoryFieldName));
        try {
            do {
                Term term = te.term();
                // term() is null once the enumeration is exhausted, which is also the case
                // up front when the index has no terms at or after the category field.
                if (term == null || !term.field().equals(categoryFieldName)) {
                    break;
                }
                categories.add(term.text());
            } while (te.next());
        } finally {
            te.close();
        }

        log.info("Scanned {} categories from index", categories.size());
    }

    /**
     * Sets the maximum number of similar documents retrieved per call to {@link #categorize(Reader)}.
     *
     * @param maxResults number of hits requested from the searcher
     */
    public void setMaxResults(int maxResults) {
        this.maxResults = maxResults;
    }

    /**
     * Returns the categories loaded from the index.
     *
     * @return an unmodifiable view of the category names
     */
    public Collection<String> getCategories() {
        return Collections.unmodifiableSet(categories);
    }

    /**
     * @return the currently configured match mode
     */
    public MatchMode getMatchMode() {
        return matchMode;
    }

    /**
     * @param matchMode the match mode to record on this categorizer
     */
    public void setMatchMode(MatchMode matchMode) {
        this.matchMode = matchMode;
    }

    /**
     * Sets the index fields the underlying {@link MoreLikeThis} uses when building queries.
     *
     * @param fieldNames names of the content fields
     */
    public void setFieldNames(String[] fieldNames) {
        moreLikeThis.setFieldNames(fieldNames);
    }

    /**
     * Sets the analyzer the underlying {@link MoreLikeThis} uses to tokenize input text.
     *
     * @param analyzer analyzer to use
     */
    public void setAnalyzer(Analyzer analyzer) {
        moreLikeThis.setAnalyzer(analyzer);
    }

    /**
     * Configures word n-grams (shingles) of exactly {@code size} tokens. Sizes of 1 or less are
     * ignored. The current analyzer is wrapped in a {@link ShingleAnalyzerWrapper} unless it
     * already is one, in which case that wrapper is reconfigured.
     *
     * @param size the n-gram size; values &lt;= 1 leave the analyzer untouched
     */
    public void setNgramSize(int size) {
        if (size <= 1) {
            return;
        }

        Analyzer current = moreLikeThis.getAnalyzer();
        ShingleAnalyzerWrapper shingles;
        if (current instanceof ShingleAnalyzerWrapper) {
            shingles = (ShingleAnalyzerWrapper) current;
        } else {
            shingles = new ShingleAnalyzerWrapper(current);
            moreLikeThis.setAnalyzer(shingles);
        }

        // Max is set before min; presumably the wrapper rejects a min larger than the
        // current max, so this order matters when growing the size -- verify against Lucene.
        shingles.setMaxShingleSize(size);
        shingles.setMinShingleSize(size);
    }

    /**
     * Categorizes the text supplied by {@code reader}. A more-like-this query is built from the
     * text, up to {@code maxResults} hits are retrieved, and the score of each hit is added to
     * the {@link CategoryHits} of that hit's category. Hits without a category field are skipped.
     *
     * @param reader source of the text to categorize; this method does not itself close it
     * @return one entry per category seen among the hits, ordered by
     *         {@link CategoryHits#byScoreComparator()}
     * @throws IOException if the text or the index cannot be read
     */
    public CategoryHits[] categorize(Reader reader) throws IOException {
        Query query = moreLikeThis.like(reader);

        Map<String, CategoryHits> hitsByCategory = new HashMap<String, CategoryHits>(25);

        for (ScoreDoc sd : indexSearcher.search(query, maxResults).scoreDocs) {
            String cat = getDocClass(sd.doc);
            if (cat == null) {
                continue;
            }
            CategoryHits ch = hitsByCategory.get(cat);
            if (ch == null) {
                ch = new CategoryHits();
                ch.setLabel(cat);
                hitsByCategory.put(cat, ch);
            }

            ch.incrementScore(sd.score);
        }

        // Sort an array rather than collecting into a TreeSet: a TreeSet treats elements the
        // comparator reports as equal as duplicates, which would silently drop a category
        // if the score comparator returns 0 for ties.
        CategoryHits[] result = hitsByCategory.values().toArray(new CategoryHits[hitsByCategory.size()]);
        Arrays.sort(result, CategoryHits.byScoreComparator());
        return result;
    }

    /**
     * Looks up the category of an indexed document.
     *
     * @param doc the Lucene document number
     * @return the value of the category field, or {@code null} if the document has no such field
     * @throws IOException if the document cannot be loaded
     * @throws IllegalArgumentException if the category field is present but not stored
     */
    protected String getDocClass(int doc) throws IOException {
        Document d = indexReader.document(doc);
        Fieldable f = d.getFieldable(categoryFieldName);
        if (f == null) {
            return null;
        }
        if (!f.isStored()) {
            throw new IllegalArgumentException("Field " + f.name() + " is not stored.");
        }
        return f.stringValue();
    }

    /**
     * Command-line entry point: classifies a single input file against an index model and prints
     * one tab-separated line (label, hits, score) per matching category.
     *
     * @param args command-line arguments; run with the help option for the full list
     * @throws Exception if an argument value is invalid or the input/index cannot be read
     */
    public static void main(String[] args) throws Exception {
        DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
        ArgumentBuilder abuilder = new ArgumentBuilder();
        GroupBuilder gbuilder = new GroupBuilder();

        Option helpOpt = DefaultOptionCreator.helpOption();

        Option inputDirOpt = obuilder.withLongName("input").withRequired(true)
                .withArgument(abuilder.withName("input").withMinimum(1).withMaximum(1).create())
                .withDescription("The input file to classify").withShortName("i").create();

        Option modelOpt = obuilder.withLongName("model").withRequired(true)
                .withArgument(abuilder.withName("index").withMinimum(1).withMaximum(1).create())
                .withDescription("The directory containing the index model").withShortName("m").create();

        // Argument display names below match their option; they were previously copy-pasted
        // from other options ("index", "gramSize"), which made the usage output misleading.
        Option categoryFieldOpt = obuilder.withLongName("categoryField").withRequired(true)
                .withArgument(abuilder.withName("categoryField").withMinimum(1).withMaximum(1).create())
                .withDescription("Name of the field containing category information").withShortName("catf")
                .create();

        Option contentFieldOpt = obuilder.withLongName("contentField").withRequired(true)
                .withArgument(abuilder.withName("contentField").withMinimum(1).withMaximum(1).create())
                .withDescription("Name of the field containing content information").withShortName("contf")
                .create();

        Option maxResultsOpt = obuilder.withLongName("maxResults").withRequired(false)
                .withArgument(abuilder.withName("maxResults").withMinimum(1).withMaximum(1).create())
                .withDescription("Number of results to retrive, default: 10 ").withShortName("r").create();

        Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(false)
                .withArgument(abuilder.withName("gramSize").withMinimum(1).withMaximum(1).create())
                .withDescription("Size of the n-gram. Default Value: 1 ").withShortName("ng").create();

        // The only accepted values are knn and tfidf, so the advertised default must be one of them.
        Option typeOpt = obuilder.withLongName("classifierType").withRequired(false)
                .withArgument(abuilder.withName("classifierType").withMinimum(1).withMaximum(1).create())
                .withDescription("Type of classifier: knn|tfidf. Default: tfidf").withShortName("type").create();

        Group group = gbuilder.withName("Options").withOption(gramSizeOpt).withOption(helpOpt)
                .withOption(inputDirOpt).withOption(modelOpt).withOption(typeOpt).withOption(contentFieldOpt)
                .withOption(categoryFieldOpt).withOption(maxResultsOpt).create();

        try {
            Parser parser = new Parser();

            parser.setGroup(group);
            parser.setHelpOption(helpOpt);
            CommandLine cmdLine = parser.parse(args);
            if (cmdLine.hasOption(helpOpt)) {
                CommandLineUtil.printHelp(group);
                return;
            }

            int gramSize = 1;
            if (cmdLine.hasOption(gramSizeOpt)) {
                gramSize = Integer.parseInt((String) cmdLine.getValue(gramSizeOpt));
            }

            int maxResults = 10;
            if (cmdLine.hasOption(maxResultsOpt)) {
                maxResults = Integer.parseInt((String) cmdLine.getValue(maxResultsOpt));
            }

            String inputPath = (String) cmdLine.getValue(inputDirOpt);
            String modelPath = (String) cmdLine.getValue(modelOpt);
            String categoryField = (String) cmdLine.getValue(categoryFieldOpt);
            String contentField = (String) cmdLine.getValue(contentFieldOpt);
            MatchMode mode = parseMatchMode((String) cmdLine.getValue(typeOpt));

            classifyFile(inputPath, modelPath, categoryField, contentField, mode, gramSize, maxResults);

        } catch (OptionException e) {
            log.error("Error while parsing options", e);
            // Show the usage so the user can see which options are expected.
            CommandLineUtil.printHelp(group);
        }
    }

    /** Maps the classifierType option value to a MatchMode; a missing value selects TFIDF. */
    private static MatchMode parseMatchMode(String classifierType) {
        // The option is optional, so null (not supplied) must select the default
        // rather than being rejected as an unknown type.
        if (classifierType == null || "tfidf".equalsIgnoreCase(classifierType)) {
            return MatchMode.TFIDF;
        }
        if ("knn".equalsIgnoreCase(classifierType)) {
            return MatchMode.KNN;
        }
        throw new IllegalArgumentException("Unkown classifierType: " + classifierType);
    }

    /** Opens the input file and index, prints the categories found, and closes everything it opened. */
    private static void classifyFile(String inputPath, String modelPath, String categoryField,
            String contentField, MatchMode mode, int gramSize, int maxResults) throws IOException {
        Reader reader = new FileReader(inputPath);
        try {
            Directory directory = FSDirectory.open(new File(modelPath));
            try {
                IndexReader indexReader = IndexReader.open(directory);
                try {
                    MoreLikeThisCategorizer categorizer = new MoreLikeThisCategorizer(indexReader, categoryField);
                    categorizer.setMatchMode(mode);
                    categorizer.setFieldNames(new String[] { contentField });
                    categorizer.setMaxResults(maxResults);
                    // setNgramSize ignores sizes <= 1, so no extra guard is needed here.
                    categorizer.setNgramSize(gramSize);

                    for (CategoryHits c : categorizer.categorize(reader)) {
                        System.out.println(c.getLabel() + "\t" + c.getHits() + "\t" + c.getScore());
                    }
                } finally {
                    indexReader.close();
                }
            } finally {
                directory.close();
            }
        } finally {
            reader.close();
        }
    }
}