com.memonews.mahout.sentiment.SentimentModelTester.java Source code

Java tutorial

Introduction

Here is the source code for com.memonews.mahout.sentiment.SentimentModelTester.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.memonews.mahout.sentiment;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.List;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.ResultAnalyzer;
import org.apache.mahout.classifier.sgd.ModelSerializer;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.Dictionary;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Lists;
import com.google.common.collect.Multiset;

/**
 * Command-line tool that evaluates a trained sentiment model against a labelled test set.
 * <p>
 * The input directory must contain one sub-directory per category. The name of each
 * sub-directory is used as the category label, and every entry inside it is treated as one test
 * document. Each document is encoded with {@code SentimentModelHelper} and classified with the
 * {@link OnlineLogisticRegression} model read from the model file. The outcome is accumulated in
 * a {@link ResultAnalyzer}, whose report is printed at the end.
 * <p>
 * Usage: {@code SentimentModelTester --input <test-dir> --model <model-file>}
 */
public final class SentimentModelTester {

    /** Path of the directory holding the per-category test documents; set by {@link #parseArgs}. */
    private String inputFile;
    /** Path of the binary-serialized model file; set by {@link #parseArgs}. */
    private String modelFile;

    private SentimentModelTester() {
    }

    /**
     * Entry point. Parses the command line and, if {@link #parseArgs} succeeds, runs the
     * evaluation with output sent to an auto-flushing writer on standard output.
     *
     * @param args command-line arguments ({@code --input}, {@code --model}, {@code --help})
     * @throws IOException if the input directory, the model, or a test document cannot be read
     */
    public static void main(final String[] args) throws IOException {
        final SentimentModelTester runner = new SentimentModelTester();
        if (runner.parseArgs(args)) {
            runner.run(new PrintWriter(System.out, true));
        }
    }

    /**
     * Classifies every test document. Writes the number of test files and the
     * {@link ResultAnalyzer} report to {@code output}.
     *
     * @param output destination for all text produced by this method
     * @throws IOException if the model or a test document cannot be read, or if the input
     *         directory or one of its category directories cannot be listed
     * @throws IllegalStateException if called before {@link #parseArgs} has succeeded
     */
    public void run(final PrintWriter output) throws IOException {
        if (inputFile == null || modelFile == null) {
            throw new IllegalStateException("parseArgs must succeed before run is called");
        }

        // Load the model and release the file handle as soon as deserialization finishes.
        final OnlineLogisticRegression classifier;
        final FileInputStream modelIn = new FileInputStream(modelFile);
        try {
            classifier = ModelSerializer.readBinary(modelIn, OnlineLogisticRegression.class);
        } finally {
            modelIn.close();
        }

        final File base = new File(inputFile);
        // listFiles() returns null (rather than throwing) when the path is not a readable directory.
        final File[] categoryDirs = base.listFiles();
        if (categoryDirs == null) {
            throw new IOException("Cannot list test directory: " + base);
        }

        // Category indices are assigned in directory-listing order.
        // NOTE(review): File.listFiles() order is unspecified. Confirm it matches the category order
        // used when the model was trained; otherwise predicted indices may map to the wrong labels.
        final Dictionary newsGroups = new Dictionary();
        final Multiset<String> overallCounts = HashMultiset.create();
        final List<File> files = Lists.newArrayList();
        for (final File newsgroup : categoryDirs) {
            if (newsgroup.isDirectory()) {
                newsGroups.intern(newsgroup.getName());
                final File[] docs = newsgroup.listFiles();
                if (docs == null) {
                    throw new IOException("Cannot list category directory: " + newsgroup);
                }
                files.addAll(Arrays.asList(docs));
            }
        }
        output.printf("%d test files\n", files.size());

        final ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT");
        // Every category was interned above and the loop below only looks up existing names.
        // The label list therefore does not change and can be fetched once.
        final List<String> labels = newsGroups.values();
        for (final File file : files) {
            final String ng = file.getParentFile().getName();
            final int actual = newsGroups.intern(ng);
            // NOTE(review): a new helper is created for every document. Hoisting it out of the loop
            // is only safe if SentimentModelHelper keeps no per-document state; confirm first.
            final SentimentModelHelper helper = new SentimentModelHelper();
            final Vector input = helper.encodeFeatureVector(file, overallCounts);
            final Vector result = classifier.classifyFull(input);
            final int cat = result.maxValueIndex();
            final double score = result.maxValue();
            final double ll = classifier.logLikelihood(actual, input);
            final ClassifierResult cr = new ClassifierResult(labels.get(cat), score, ll);
            ra.addInstance(labels.get(actual), cr);
        }
        output.printf("%s\n\n", ra.toString());
    }

    /**
     * Parses the command line and stores the {@code --input} and {@code --model} values in
     * {@link #inputFile} and {@link #modelFile}.
     *
     * @param args raw command-line arguments
     * @return {@code true} if the options were parsed and stored; {@code false} if
     *         {@code parseAndHelp} returned {@code null} (presumably a parse failure or
     *         {@code --help})
     */
    boolean parseArgs(final String[] args) {
        final DefaultOptionBuilder builder = new DefaultOptionBuilder();

        final Option help = builder.withLongName("help").withDescription("print this list").create();

        final ArgumentBuilder argumentBuilder = new ArgumentBuilder();
        final Option inputFileOption = builder.withLongName("input").withRequired(true)
                .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
                .withDescription("where to get test data").create();

        final Option modelFileOption = builder.withLongName("model").withRequired(true)
                .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
                .withDescription("where to get a model").create();

        final Group normalArgs = new GroupBuilder().withOption(help).withOption(inputFileOption)
                .withOption(modelFileOption).create();

        final Parser parser = new Parser();
        parser.setHelpOption(help);
        parser.setHelpTrigger("--help");
        parser.setGroup(normalArgs);
        parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
        final CommandLine cmdLine = parser.parseAndHelp(args);

        if (cmdLine == null) {
            return false;
        }

        inputFile = (String) cmdLine.getValue(inputFileOption);
        modelFile = (String) cmdLine.getValue(modelFileOption);
        return true;
    }

}