Java tutorial
/* * Copyright 2008-2011 Grant Ingersoll, Thomas Morton and Drew Farris * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ------------------- * To purchase or learn more about Taming Text, by Grant Ingersoll, Thomas Morton and Drew Farris, visit * http://www.manning.com/ingersoll */ package com.tamingtext.tagrecommender; import java.io.BufferedReader; import java.io.File; import java.io.FileOutputStream; import java.io.FileReader; import java.io.IOException; import java.io.PrintStream; import java.net.MalformedURLException; import java.text.DecimalFormat; import java.text.NumberFormat; import java.util.HashSet; import org.apache.commons.cli2.CommandLine; import org.apache.commons.cli2.Group; import org.apache.commons.cli2.Option; import org.apache.commons.cli2.OptionException; import org.apache.commons.cli2.builder.ArgumentBuilder; import org.apache.commons.cli2.builder.DefaultOptionBuilder; import org.apache.commons.cli2.builder.GroupBuilder; import org.apache.commons.cli2.commandline.Parser; import org.apache.mahout.common.CommandLineUtil; import org.apache.mahout.common.commandline.DefaultOptionCreator; import org.apache.mahout.math.function.ObjectIntProcedure; import org.apache.mahout.math.map.OpenObjectIntHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.tamingtext.tagrecommender.TagRecommenderClient.ScoreTag; public class TestStackOverflowTagger { private static final Logger log = LoggerFactory.getLogger(TestStackOverflowTagger.class); private final NumberFormat nf = new DecimalFormat("##.##"); private TagRecommenderClient client; private File inputFile; private File countFile; private File outputFile; private String solrUrl; private int maxTags = 5; public static void main(String[] args) { TestStackOverflowTagger t = new TestStackOverflowTagger(); if (t.parseArgs(args)) { t.execute(); } } public boolean parseArgs(String[] args) { DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); ArgumentBuilder abuilder = new ArgumentBuilder(); GroupBuilder gbuilder = new GroupBuilder(); Option helpOpt = DefaultOptionCreator.helpOption(); Option inputFileOpt = obuilder.withLongName("inputFile").withRequired(true) .withArgument(abuilder.withName("inputFile").withMinimum(1).withMaximum(1).create()) .withDescription("The input file").withShortName("i").create(); Option countFileOpt = obuilder.withLongName("countFile").withRequired(true) .withArgument(abuilder.withName("countFile").withMinimum(1).withMaximum(1).create()) .withDescription("The tag count file").withShortName("c").create(); Option outputFileOpt = obuilder.withLongName("outputFile").withRequired(true) .withArgument(abuilder.withName("outputFile").withMinimum(1).withMaximum(1).create()) .withDescription("The output file").withShortName("c").create(); Option solrUrlOpt = obuilder.withLongName("solrUrl").withRequired(true) .withArgument(abuilder.withName("solrUrl").withMinimum(1).withMaximum(1).create()) .withDescription("URL of the solr server").withShortName("s").create(); Group group = gbuilder.withName("Options").withOption(inputFileOpt).withOption(countFileOpt) .withOption(outputFileOpt).withOption(solrUrlOpt).create(); try { Parser parser = new Parser(); parser.setGroup(group); CommandLine cmdLine = parser.parse(args); if (cmdLine.hasOption(helpOpt)) { CommandLineUtil.printHelp(group); return false; } inputFile = new File((String) cmdLine.getValue(inputFileOpt)); countFile = new File((String) cmdLine.getValue(countFileOpt)); outputFile = new File((String) cmdLine.getValue(outputFileOpt)); solrUrl = (String) cmdLine.getValue(solrUrlOpt); client = new TagRecommenderClient(solrUrl); } catch (OptionException e) { log.error("Command-line option Exception", e); CommandLineUtil.printHelp(group); return false; } catch (MalformedURLException e) { log.error("MalformedURLException", e); return false; } validate(); return true; } public void validate() { Util.validateFileWritable(outputFile); } public void loadTags(OpenObjectIntHashMap<String> tags) throws IOException { BufferedReader reader = new BufferedReader(new FileReader(countFile)); String line; while ((line = reader.readLine()) != null) { int pos = line.lastIndexOf('\t'); String tag = new String(line.substring(pos + 1)); tags.adjustOrPutValue(tag, 0, 0); } } public void execute() { PrintStream out = null; try { OpenObjectIntHashMap<String> tagCounts = new OpenObjectIntHashMap<String>(); OpenObjectIntHashMap<String> tagCorrect = new OpenObjectIntHashMap<String>(); loadTags(tagCounts); StackOverflowStream stream = new StackOverflowStream(); stream.open(inputFile.getAbsolutePath()); out = new PrintStream(new FileOutputStream(outputFile)); int correctTagCount = 0; int postCount = 0; HashSet<String> postTags = new HashSet<String>(); float postPctCorrect; int totalSingleCorrect = 0; int totalHalfCorrect = 0; for (StackOverflowPost post : stream) { correctTagCount = 0; postCount++; postTags.clear(); postTags.addAll(post.getTags()); for (String tag : post.getTags()) { if (tagCounts.containsKey(tag)) { tagCounts.adjustOrPutValue(tag, 1, 1); } } ScoreTag[] tags = client.getTags(post.getTitle() + "\n" + post.getBody(), maxTags); for (ScoreTag tag : tags) { if (postTags.contains(tag.getTag())) { correctTagCount += 1; tagCorrect.adjustOrPutValue(tag.getTag(), 1, 1); } } if (correctTagCount > 0) { totalSingleCorrect += 1; } postPctCorrect = correctTagCount / (float) postTags.size(); if (postPctCorrect >= 0.50f) { totalHalfCorrect += 1; } if ((postCount % 100) == 0) { dumpStats(System.err, postCount, totalSingleCorrect, totalHalfCorrect); } } dumpStats(System.err, postCount, totalSingleCorrect, totalHalfCorrect); dumpStats(out, postCount, totalSingleCorrect, totalHalfCorrect); dumpTags(out, tagCounts, tagCorrect); } catch (Exception ex) { throw (RuntimeException) new RuntimeException().initCause(ex); } finally { if (out != null) { out.close(); } } } /** Dump the tag metrics */ public void dumpTags(final PrintStream out, final OpenObjectIntHashMap<String> tagCounts, final OpenObjectIntHashMap<String> tagCorrect) { out.println("-- tag\ttotal\tcorrect\tpct-correct --"); tagCounts.forEachPair(new ObjectIntProcedure<String>() { @Override public boolean apply(String tag, int total) { int correct = tagCorrect.get(tag); out.println( tag + "\t" + total + "\t" + correct + "\t" + nf.format(((correct * 100) / (float) total))); return true; } }); out.println(); out.flush(); } /** Dump the overall metrics */ public void dumpStats(PrintStream out, int postCount, int totalSingleCorrect, int totalHalfCorrect) { out.println("evaluated " + postCount + " posts; " + totalSingleCorrect + " with one correct tag, " + totalHalfCorrect + " with half correct"); out.print("\t %single correct: " + nf.format((totalSingleCorrect * 100) / (float) postCount)); out.println(", %half correct: " + nf.format((totalHalfCorrect * 100) / (float) postCount)); out.println(); out.flush(); } }