com.tamingtext.tagrecommender.TestStackOverflowTagger.java Source code

Java tutorial

Introduction

Here is the source code for com.tamingtext.tagrecommender.TestStackOverflowTagger.java

Source

/*
 * Copyright 2008-2011 Grant Ingersoll, Thomas Morton and Drew Farris
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 * -------------------
 * To purchase or learn more about Taming Text, by Grant Ingersoll, Thomas Morton and Drew Farris, visit
 * http://www.manning.com/ingersoll
 */

package com.tamingtext.tagrecommender;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.net.MalformedURLException;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.HashSet;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.math.function.ObjectIntProcedure;
import org.apache.mahout.math.map.OpenObjectIntHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.tamingtext.tagrecommender.TagRecommenderClient.ScoreTag;

public class TestStackOverflowTagger {

    private static final Logger log = LoggerFactory.getLogger(TestStackOverflowTagger.class);

    private final NumberFormat nf = new DecimalFormat("##.##");
    private TagRecommenderClient client;
    private File inputFile;
    private File countFile;
    private File outputFile;

    private String solrUrl;
    private int maxTags = 5;

    public static void main(String[] args) {
        TestStackOverflowTagger t = new TestStackOverflowTagger();
        if (t.parseArgs(args)) {
            t.execute();
        }
    }

    public boolean parseArgs(String[] args) {
        DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
        ArgumentBuilder abuilder = new ArgumentBuilder();
        GroupBuilder gbuilder = new GroupBuilder();
        Option helpOpt = DefaultOptionCreator.helpOption();

        Option inputFileOpt = obuilder.withLongName("inputFile").withRequired(true)
                .withArgument(abuilder.withName("inputFile").withMinimum(1).withMaximum(1).create())
                .withDescription("The input file").withShortName("i").create();

        Option countFileOpt = obuilder.withLongName("countFile").withRequired(true)
                .withArgument(abuilder.withName("countFile").withMinimum(1).withMaximum(1).create())
                .withDescription("The tag count file").withShortName("c").create();

        Option outputFileOpt = obuilder.withLongName("outputFile").withRequired(true)
                .withArgument(abuilder.withName("outputFile").withMinimum(1).withMaximum(1).create())
                .withDescription("The output file").withShortName("c").create();

        Option solrUrlOpt = obuilder.withLongName("solrUrl").withRequired(true)
                .withArgument(abuilder.withName("solrUrl").withMinimum(1).withMaximum(1).create())
                .withDescription("URL of the solr server").withShortName("s").create();

        Group group = gbuilder.withName("Options").withOption(inputFileOpt).withOption(countFileOpt)
                .withOption(outputFileOpt).withOption(solrUrlOpt).create();

        try {
            Parser parser = new Parser();
            parser.setGroup(group);
            CommandLine cmdLine = parser.parse(args);

            if (cmdLine.hasOption(helpOpt)) {
                CommandLineUtil.printHelp(group);
                return false;
            }

            inputFile = new File((String) cmdLine.getValue(inputFileOpt));
            countFile = new File((String) cmdLine.getValue(countFileOpt));
            outputFile = new File((String) cmdLine.getValue(outputFileOpt));
            solrUrl = (String) cmdLine.getValue(solrUrlOpt);
            client = new TagRecommenderClient(solrUrl);
        } catch (OptionException e) {
            log.error("Command-line option Exception", e);
            CommandLineUtil.printHelp(group);
            return false;
        } catch (MalformedURLException e) {
            log.error("MalformedURLException", e);
            return false;
        }

        validate();
        return true;
    }

    public void validate() {
        Util.validateFileWritable(outputFile);
    }

    public void loadTags(OpenObjectIntHashMap<String> tags) throws IOException {
        BufferedReader reader = new BufferedReader(new FileReader(countFile));
        String line;
        while ((line = reader.readLine()) != null) {
            int pos = line.lastIndexOf('\t');
            String tag = new String(line.substring(pos + 1));
            tags.adjustOrPutValue(tag, 0, 0);
        }
    }

    public void execute() {
        PrintStream out = null;

        try {
            OpenObjectIntHashMap<String> tagCounts = new OpenObjectIntHashMap<String>();
            OpenObjectIntHashMap<String> tagCorrect = new OpenObjectIntHashMap<String>();
            loadTags(tagCounts);

            StackOverflowStream stream = new StackOverflowStream();
            stream.open(inputFile.getAbsolutePath());

            out = new PrintStream(new FileOutputStream(outputFile));

            int correctTagCount = 0;
            int postCount = 0;

            HashSet<String> postTags = new HashSet<String>();
            float postPctCorrect;

            int totalSingleCorrect = 0;
            int totalHalfCorrect = 0;

            for (StackOverflowPost post : stream) {
                correctTagCount = 0;
                postCount++;

                postTags.clear();
                postTags.addAll(post.getTags());
                for (String tag : post.getTags()) {
                    if (tagCounts.containsKey(tag)) {
                        tagCounts.adjustOrPutValue(tag, 1, 1);
                    }
                }

                ScoreTag[] tags = client.getTags(post.getTitle() + "\n" + post.getBody(), maxTags);

                for (ScoreTag tag : tags) {
                    if (postTags.contains(tag.getTag())) {
                        correctTagCount += 1;
                        tagCorrect.adjustOrPutValue(tag.getTag(), 1, 1);
                    }
                }

                if (correctTagCount > 0) {
                    totalSingleCorrect += 1;
                }

                postPctCorrect = correctTagCount / (float) postTags.size();
                if (postPctCorrect >= 0.50f) {
                    totalHalfCorrect += 1;
                }

                if ((postCount % 100) == 0) {
                    dumpStats(System.err, postCount, totalSingleCorrect, totalHalfCorrect);
                }

            }

            dumpStats(System.err, postCount, totalSingleCorrect, totalHalfCorrect);
            dumpStats(out, postCount, totalSingleCorrect, totalHalfCorrect);
            dumpTags(out, tagCounts, tagCorrect);
        } catch (Exception ex) {
            throw (RuntimeException) new RuntimeException().initCause(ex);
        } finally {
            if (out != null) {
                out.close();
            }
        }
    }

    /** Dump the tag metrics */
    public void dumpTags(final PrintStream out, final OpenObjectIntHashMap<String> tagCounts,
            final OpenObjectIntHashMap<String> tagCorrect) {

        out.println("-- tag\ttotal\tcorrect\tpct-correct --");

        tagCounts.forEachPair(new ObjectIntProcedure<String>() {
            @Override
            public boolean apply(String tag, int total) {
                int correct = tagCorrect.get(tag);

                out.println(
                        tag + "\t" + total + "\t" + correct + "\t" + nf.format(((correct * 100) / (float) total)));
                return true;
            }
        });

        out.println();
        out.flush();
    }

    /** Dump the overall metrics */
    public void dumpStats(PrintStream out, int postCount, int totalSingleCorrect, int totalHalfCorrect) {
        out.println("evaluated " + postCount + " posts; " + totalSingleCorrect + " with one correct tag, "
                + totalHalfCorrect + " with half correct");

        out.print("\t %single correct: " + nf.format((totalSingleCorrect * 100) / (float) postCount));
        out.println(", %half correct: " + nf.format((totalHalfCorrect * 100) / (float) postCount));
        out.println();
        out.flush();
    }
}