org.apache.mahout.classifier.ResultAnalyzer.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.classifier.ResultAnalyzer.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier;

import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Collection;

import org.apache.commons.lang3.StringUtils;
import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
import org.apache.mahout.math.stats.OnlineSummarizer;

/**
 * Accumulates classification outcomes and renders them as a plain-text report.
 *
 * <p>Every instance passed to {@link #addInstance(String, ClassifierResult)} is counted as correct or
 * incorrect, recorded in the wrapped {@link ConfusionMatrix} and, when it carries a log-likelihood, added
 * to an {@link OnlineSummarizer}. {@link #toString()} lays all of this out in tabular form.
 */
public class ResultAnalyzer {

    /** Rule printed above each section title. */
    private static final String HEAVY_RULE = "=======================================================\n";

    /** Rule printed below each section title. */
    private static final String LIGHT_RULE = "-------------------------------------------------------\n";

    /** Column width of a row label in the "Summary" and "Statistics" sections. */
    private static final int LABEL_WIDTH = 40;

    /** Column width of a right-aligned value. */
    private static final int VALUE_WIDTH = 10;

    /** Column width of the leading heading in the log-likelihood rows. */
    private static final int LL_HEADING_WIDTH = 30;

    private final ConfusionMatrix confusionMatrix;
    private final OnlineSummarizer summarizer;

    /** Set once at least one instance has contributed a log-likelihood to {@link #summarizer}. */
    private boolean hasLL;

    /*
     * Sample summary layout kept for reference:
     *
     * === Summary ===
     *
     * Correctly Classified Instances        635      92.9722 %
     * Incorrectly Classified Instances       48       7.0278 %
     * Kappa statistic                         0.923
     * Mean absolute error                     0.0096
     * Root mean squared error                 0.0817
     * Relative absolute error                 9.9344 %
     * Root relative squared error            37.2742 %
     * Total Number of Instances             683
     */

    /** Number of instances whose classified label equalled the correct label. */
    private int correctlyClassified;

    /** Number of instances whose classified label differed from the correct label. */
    private int incorrectlyClassified;

    /**
     * Creates an analyzer backed by a new {@link ConfusionMatrix} and a new {@link OnlineSummarizer}.
     *
     * @param labelSet
     *          labels handed to the {@link ConfusionMatrix} constructor
     * @param defaultLabel
     *          default label handed to the {@link ConfusionMatrix} constructor
     */
    public ResultAnalyzer(Collection<String> labelSet, String defaultLabel) {
        this.confusionMatrix = new ConfusionMatrix(labelSet, defaultLabel);
        this.summarizer = new OnlineSummarizer();
    }

    /**
     * @return the confusion matrix this analyzer updates (the live instance, not a copy)
     */
    public ConfusionMatrix getConfusionMatrix() {
        return confusionMatrix;
    }

    /**
     * Records one classified instance.
     *
     * @param correctLabel
     *          The correct label
     * @param classifiedResult
     *          The classified result
     * @return {@code true} if the classified label equals {@code correctLabel}, {@code false} otherwise
     */
    public boolean addInstance(String correctLabel, ClassifierResult classifiedResult) {
        final boolean matched = correctLabel.equals(classifiedResult.getLabel());
        if (!matched) {
            incorrectlyClassified++;
        } else {
            correctlyClassified++;
        }
        confusionMatrix.addInstance(correctLabel, classifiedResult);

        // A log-likelihood equal to Double.MAX_VALUE is not summarized.
        if (classifiedResult.getLogLikelihood() == Double.MAX_VALUE) {
            return matched;
        }
        summarizer.add(classifiedResult.getLogLikelihood());
        hasLL = true;
        return matched;
    }

    /**
     * Builds the report: a "Summary" section with the correct/incorrect/total counts, the confusion
     * matrix, a "Statistics" section and, if any log-likelihood was recorded, its mean and quartiles.
     *
     * @return the formatted multi-line report
     */
    @Override
    public String toString() {
        NumberFormat fmt = new DecimalFormat("0.####");

        int total = correctlyClassified + incorrectlyClassified;
        // Floating-point division: with no recorded instances both percentages are NaN.
        double correctPct = 100.0 * correctlyClassified / total;
        double incorrectPct = 100.0 * incorrectlyClassified / total;

        StringBuilder report = new StringBuilder();
        report.append('\n');
        appendSectionHeading(report, "Summary");
        appendCount(report, "Correctly Classified Instances", correctlyClassified)
                .append('\t').append(alignValue(fmt.format(correctPct))).append("%\n");
        appendCount(report, "Incorrectly Classified Instances", incorrectlyClassified)
                .append('\t').append(alignValue(fmt.format(incorrectPct))).append("%\n");
        appendCount(report, "Total Classified Instances", total).append('\n');
        report.append('\n');

        report.append(confusionMatrix);
        appendSectionHeading(report, "Statistics");

        RunningAverageAndStdDev normStats = confusionMatrix.getNormalizedStats();
        appendStatistic(report, "Kappa", fmt.format(confusionMatrix.getKappa()), "\n");
        appendStatistic(report, "Accuracy", fmt.format(confusionMatrix.getAccuracy()), "%\n");
        appendStatistic(report, "Reliability", fmt.format(normStats.getAverage() * 100.00000001), "%\n");
        appendStatistic(report, "Reliability (standard deviation)",
                fmt.format(normStats.getStandardDeviation()), "\n");
        appendStatistic(report, "Weighted precision",
                fmt.format(confusionMatrix.getWeightedPrecision()), "\n");
        appendStatistic(report, "Weighted recall", fmt.format(confusionMatrix.getWeightedRecall()), "\n");
        appendStatistic(report, "Weighted F1 score",
                fmt.format(confusionMatrix.getWeightedF1score()), "\n");

        if (hasLL) {
            appendLogLikelihood(report, "Log-likelihood", "mean      : ", fmt.format(summarizer.getMean()));
            appendLogLikelihood(report, "", "25%-ile   : ", fmt.format(summarizer.getQuartile(1)));
            appendLogLikelihood(report, "", "75%-ile   : ", fmt.format(summarizer.getQuartile(3)));
        }

        return report.toString();
    }

    /** Right-aligns an already formatted value in the value column. */
    private static String alignValue(String formatted) {
        return StringUtils.leftPad(formatted, VALUE_WIDTH);
    }

    /** Appends a section title framed by the heavy rule above and the light rule below. */
    private static void appendSectionHeading(StringBuilder out, String title) {
        out.append(HEAVY_RULE).append(title).append('\n').append(LIGHT_RULE);
    }

    /** Appends "label: count" without a line end and returns {@code out} so the caller can finish the row. */
    private static StringBuilder appendCount(StringBuilder out, String label, int count) {
        return out.append(StringUtils.rightPad(label, LABEL_WIDTH)).append(": ")
                .append(alignValue(Integer.toString(count)));
    }

    /** Appends one "Statistics" row: padded label, right-aligned value, then {@code lineEnd}. */
    private static void appendStatistic(StringBuilder out, String label, String formatted, String lineEnd) {
        out.append(StringUtils.rightPad(label, LABEL_WIDTH)).append(alignValue(formatted)).append(lineEnd);
    }

    /**
     * Appends one log-likelihood row: heading padded to {@link #LL_HEADING_WIDTH}, the caption as given,
     * the right-aligned value and a newline. Captions are emitted verbatim; they are already wider than
     * the value column, so no padding is applied to them.
     */
    private static void appendLogLikelihood(StringBuilder out, String heading, String caption,
            String formatted) {
        out.append(StringUtils.rightPad(heading, LL_HEADING_WIDTH)).append(caption)
                .append(alignValue(formatted)).append('\n');
    }
}