Java tutorial
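The TrainMoreLikeThis class below builds the Lucene index that the MoreLikeThis-based classifier searches at classification time. Each training file is read line by line, and every line is expected to contain a category label, a tab, and the document text; lines that do not split into exactly two tab-separated parts are skipped. A made-up sample line, with the tab written as \t, could look like:

sports\tThe home team clinched the title with a late goal.

In tfidf mode all lines of a file are concatenated into a single document labeled with that file's category, so the best MoreLikeThis match decides the category; in knn mode each line becomes its own document and the most frequent category among the top hits wins. Assuming the project classpath is set up, the trainer can be run with something like:

java com.tamingtext.classifier.mlt.TrainMoreLikeThis --input <training-dir> --output <index-dir> --classifierType knn --gramSize 2

where --gramSize is optional and, when greater than 1, indexes word n-grams (shingles) in addition to single terms.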
/*
 * Copyright 2008-2011 Grant Ingersoll, Thomas Morton and Drew Farris
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * -------------------
 * To purchase or learn more about Taming Text, by Grant Ingersoll, Thomas Morton and Drew Farris, visit
 * http://www.manning.com/ingersoll
 */
package com.tamingtext.classifier.mlt;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.en.EnglishAnalyzer;
import org.apache.lucene.analysis.shingle.ShingleAnalyzerWrapper;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexWriterConfig.OpenMode;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.Version;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.tamingtext.util.FileUtil;

public class TrainMoreLikeThis {

  private static final Logger log = LoggerFactory.getLogger(TrainMoreLikeThis.class);

  public static final String CATEGORY_KEY = "categories";

  public static enum MatchMode { KNN, TFIDF }

  private IndexWriter writer;
  private int nGramSize = 1;

  public TrainMoreLikeThis() {
  }

  public void setNGramSize(int nGramSize) {
    this.nGramSize = nGramSize;
  }

  public void train(String source, String destination, MatchMode mode) throws Exception {
    File[] inputFiles = FileUtil.buildFileList(new File(source));

    if (inputFiles.length < 2) {
      throw new IllegalStateException("There must be more than one training file in " + source);
    }

    openIndexWriter(destination);

    switch (mode) {
      case TFIDF:
        this.buildTfidfIndex(inputFiles);
        break;
      case KNN:
        this.buildKnnIndex(inputFiles);
        break;
      default:
        throw new IllegalStateException("Unknown match mode: " + mode.toString());
    }

    closeIndexWriter();
  }

  /** Builds a Lucene index suitable for kNN-based classification. Each category's content is indexed
   * into separate documents in the index, and the category that has the highest count in the top N
   * hits is the category that is assigned.
   * @param inputFiles
   * @param writer
   * @throws Exception
   */
  protected void buildKnnIndex(File[] inputFiles) throws Exception {
    int lineCount = 0;
    int fileCount = 0;
    String line = null;
    String category = null;

    Set<String> categories = new HashSet<String>();

    long start = System.currentTimeMillis();

    // reuse these fields
    //<start id="lucene.examples.fields"/>
    Field id = new Field("id", "", Field.Store.YES,
        Field.Index.NOT_ANALYZED, Field.TermVector.NO);
    Field categoryField = new Field("category", "", Field.Store.YES,
        Field.Index.NOT_ANALYZED, Field.TermVector.NO);
    Field contentField = new Field("content", "", Field.Store.NO,
        Field.Index.ANALYZED, Field.TermVector.WITH_POSITIONS_OFFSETS);
    //<end id="lucene.examples.fields"/>

    for (File ff : inputFiles) {
      fileCount++;
      lineCount = 0;
      category = null;

      BufferedReader in = new BufferedReader(new FileReader(ff));
      //<start id="lucene.examples.knn.train"/>
      while ((line = in.readLine()) != null) {
        String[] parts = line.split("\t"); //<co id="luc.knn.content"/>
        if (parts.length != 2) continue;

        category = parts[0];
        categories.add(category);

        Document d = new Document(); //<co id="luc.knn.document"/>
        id.setValue(category + "-" + lineCount++);
        categoryField.setValue(category);
        contentField.setValue(parts[1]);

        d.add(id);
        d.add(categoryField);
        d.add(contentField);

        writer.addDocument(d); //<co id="luc.knn.index"/>
      }
      /*<calloutlist>
        <callout arearefs="luc.knn.content">Collect Content</callout>
        <callout arearefs="luc.knn.document">Build Document</callout>
        <callout arearefs="luc.knn.index">Index Document</callout>
      </calloutlist>*/
      //<end id="lucene.examples.knn.train"/>
      in.close();

      log.info("Knn: Added document for category " + category + " with " + lineCount + " lines");
    }

    writer.commit(generateUserData(categories));

    log.info("Knn: Added " + fileCount + " categories in "
        + (System.currentTimeMillis() - start) + " msec.");
  }

  /** Builds a Lucene index suitable for TF-IDF-based classification. Each category's content is indexed
   * into a single document in the index, and the best match for a MoreLikeThis query is the category
   * that is assigned.
   * @param inputFiles
   * @param writer
   * @throws Exception
   */
  protected void buildTfidfIndex(File[] inputFiles) throws Exception {
    int lineCount = 0;
    int fileCount = 0;
    String line = null;

    Set<String> categories = new HashSet<String>();

    long start = System.currentTimeMillis();

    // reuse these fields
    Field id = new Field("id", "", Field.Store.YES,
        Field.Index.NOT_ANALYZED, Field.TermVector.NO);
    Field categoryField = new Field("category", "", Field.Store.YES,
        Field.Index.NOT_ANALYZED, Field.TermVector.NO);
    Field contentField = new Field("content", "", Field.Store.NO,
        Field.Index.ANALYZED, Field.TermVector.WITH_POSITIONS_OFFSETS);

    // read data from input files.
    for (File ff : inputFiles) {
      fileCount++;
      lineCount = 0;

      // read all training documents into a string
      BufferedReader in = new BufferedReader(new InputStreamReader(new FileInputStream(ff), "UTF-8"));
      //<start id="lucene.examples.tfidf.train"/>
      StringBuilder content = new StringBuilder();
      String category = null;
      while ((line = in.readLine()) != null) {
        String[] parts = line.split("\t"); //<co id="luc.tf.content"/>
        if (parts.length != 2) continue;

        category = parts[0];
        categories.add(category);
        content.append(parts[1]).append(" ");
        lineCount++;
      }
      in.close();

      Document d = new Document(); //<co id="luc.tf.document"/>
      id.setValue(category + "-" + lineCount);
      categoryField.setValue(category);
      contentField.setValue(content.toString());

      d.add(id);
      d.add(categoryField);
      d.add(contentField);

      writer.addDocument(d); //<co id="luc.tf.index"/>
      /*<calloutlist>
        <callout arearefs="luc.tf.content">Collect Content</callout>
        <callout arearefs="luc.tf.document">Build Document</callout>
        <callout arearefs="luc.tf.index">Index Document</callout>
      </calloutlist>*/
      //<end id="lucene.examples.tfidf.train"/>

      log.info("TfIdf: Added document for category " + category + " with " + lineCount + " lines");
    }

    writer.commit(generateUserData(categories));

    log.info("TfIdf: Added " + fileCount + " categories in "
        + (System.currentTimeMillis() - start) + " msec.");
  }

  protected void openIndexWriter(String pathname) throws IOException {
    //<start id="lucene.examples.index.setup"/>
    Directory directory //<co id="luc.index.dir"/>
        = FSDirectory.open(new File(pathname));

    Analyzer analyzer //<co id="luc.index.analyzer"/>
        = new EnglishAnalyzer(Version.LUCENE_36);

    if (nGramSize > 1) { //<co id="luc.index.shingle"/>
      ShingleAnalyzerWrapper sw = new ShingleAnalyzerWrapper(analyzer,
          nGramSize, // min shingle size
          nGramSize, // max shingle size
          "-",       // token separator
          true,      // output unigrams
          true);     // output unigrams if no shingles
      analyzer = sw;
    }

    IndexWriterConfig config //<co id="luc.index.create"/>
        = new IndexWriterConfig(Version.LUCENE_36, analyzer);
    config.setOpenMode(OpenMode.CREATE);
    IndexWriter writer = new IndexWriter(directory, config);
    /* <calloutlist>
      <callout arearefs="luc.index.dir">Create Index Directory</callout>
      <callout arearefs="luc.index.analyzer">Setup Analyzer</callout>
      <callout arearefs="luc.index.shingle">Setup Shingle Filter</callout>
      <callout arearefs="luc.index.create">Create <classname>IndexWriter</classname></callout>
    </calloutlist> */
    //<end id="lucene.examples.index.setup"/>
    this.writer = writer;
  }

  protected void closeIndexWriter() throws IOException {
    log.info("Starting optimize");
    // optimize and close the index.
    writer.optimize();
    writer.close();
    writer = null;
    log.info("Optimize complete, index closed");
  }

  protected static Map<String, String> generateUserData(Collection<String> categories) {
    StringBuilder b = new StringBuilder();
    for (String cat : categories) {
      b.append(cat).append('|');
    }
    b.setLength(b.length() - 1);

    Map<String, String> userData = new HashMap<String, String>();
    userData.put(CATEGORY_KEY, b.toString());
    return userData;
  }

  public static void main(String[] args) throws Exception {
    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
    ArgumentBuilder abuilder = new ArgumentBuilder();
    GroupBuilder gbuilder = new GroupBuilder();

    Option helpOpt = DefaultOptionCreator.helpOption();

    Option inputDirOpt = obuilder.withLongName("input").withRequired(true)
        .withArgument(abuilder.withName("input").withMinimum(1).withMaximum(1).create())
        .withDescription("The input directory, containing properly formatted files: "
            + "One doc per line, first entry on the line is the label, rest is the evidence")
        .withShortName("i").create();

    Option outputOpt = obuilder.withLongName("output").withRequired(true)
        .withArgument(abuilder.withName("output").withMinimum(1).withMaximum(1).create())
        .withDescription("The output directory").withShortName("o").create();

    Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(false)
        .withArgument(abuilder.withName("gramSize").withMinimum(1).withMaximum(1).create())
        .withDescription("Size of the n-gram. Default Value: 1").withShortName("ng").create();

    Option typeOpt = obuilder.withLongName("classifierType").withRequired(false)
        .withArgument(abuilder.withName("classifierType").withMinimum(1).withMaximum(1).create())
        .withDescription("Type of classifier: knn|tfidf.").withShortName("type").create();

    Group group = gbuilder.withName("Options").withOption(gramSizeOpt).withOption(helpOpt)
        .withOption(inputDirOpt).withOption(outputOpt).withOption(typeOpt).create();

    try {
      Parser parser = new Parser();

      parser.setGroup(group);
      parser.setHelpOption(helpOpt);
      CommandLine cmdLine = parser.parse(args);

      if (cmdLine.hasOption(helpOpt)) {
        CommandLineUtil.printHelp(group);
        return;
      }

      String classifierType = (String) cmdLine.getValue(typeOpt);

      int gramSize = 1;
      if (cmdLine.hasOption(gramSizeOpt)) {
        gramSize = Integer.parseInt((String) cmdLine.getValue(gramSizeOpt));
      }

      String inputPath = (String) cmdLine.getValue(inputDirOpt);
      String outputPath = (String) cmdLine.getValue(outputOpt);

      TrainMoreLikeThis trainer = new TrainMoreLikeThis();

      MatchMode mode;
      if ("knn".equalsIgnoreCase(classifierType)) {
        mode = MatchMode.KNN;
      } else if ("tfidf".equalsIgnoreCase(classifierType)) {
        mode = MatchMode.TFIDF;
      } else {
        throw new IllegalArgumentException("Unknown classifierType: " + classifierType);
      }

      if (gramSize > 1) trainer.setNGramSize(gramSize);

      trainer.train(inputPath, outputPath, mode);
    } catch (OptionException e) {
      log.error("Error while parsing options", e);
    }
  }
}
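For completeness, here is a minimal sketch of the query side. It is not taken from the book's sources: it assumes the index layout produced above (a stored category field, an analyzed content field, EnglishAnalyzer, Lucene 3.6) and uses the contrib MoreLikeThis class to pick a category for a piece of unlabeled text. The class name MoreLikeThisCategorizerSketch and the categorize helper are invented for illustration; the book ships its own categorizer with more options. If the index was built with a gramSize greater than 1, the same ShingleAnalyzerWrapper would have to be applied on the query side as well.

import java.io.File;
import java.io.IOException;
import java.io.StringReader;
import java.util.HashMap;
import java.util.Map;

import org.apache.lucene.analysis.en.EnglishAnalyzer;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.similar.MoreLikeThis;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.Version;

public class MoreLikeThisCategorizerSketch {

  /**
   * Returns the most frequent category among the top-n MoreLikeThis hits for the given text,
   * or null if nothing matched. Use n = 1 against a tfidf-style index (one document per
   * category) and a larger n against a knn-style index (one document per training line).
   */
  public static String categorize(String indexPath, String text, int n) throws IOException {
    IndexReader reader = IndexReader.open(FSDirectory.open(new File(indexPath)));
    IndexSearcher searcher = new IndexSearcher(reader);
    try {
      MoreLikeThis mlt = new MoreLikeThis(reader);
      mlt.setAnalyzer(new EnglishAnalyzer(Version.LUCENE_36)); // match the analyzer used at training time
      mlt.setFieldNames(new String[] {"content"});             // query against the indexed content field
      mlt.setMinTermFreq(1);                                   // keep rare terms; training docs may be short
      mlt.setMinDocFreq(1);

      Query query = mlt.like(new StringReader(text));          // build a query from the unlabeled text
      ScoreDoc[] hits = searcher.search(query, n).scoreDocs;

      // vote: count how often each stored category label appears among the hits
      Map<String, Integer> votes = new HashMap<String, Integer>();
      String best = null;
      int bestCount = 0;
      for (ScoreDoc hit : hits) {
        String category = searcher.doc(hit.doc).get("category");
        Integer count = votes.get(category);
        int updated = (count == null ? 1 : count + 1);
        votes.put(category, updated);
        if (updated > bestCount) {
          bestCount = updated;
          best = category;
        }
      }
      return best;
    } finally {
      searcher.close();
      reader.close();
    }
  }
}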