org.apache.mahout.classifier.chi_rwcs.mapreduce.TestModel.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.classifier.chi_rwcs.mapreduce.TestModel.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.mahout.classifier.chi_rwcs.mapreduce;

import java.io.IOException;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Arrays;

import com.google.common.io.Closeables;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.classifier.chi_rwcs.Chi_RWCSUtils;
import org.apache.mahout.classifier.ResultAnalyzer;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.chi_rwcs.data.Dataset;
import org.apache.mahout.classifier.chi_rwcs.mapreduce.Chi_RWCSClassifier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Command-line tool that classifies a test data set with a previously built model.
 *
 * <p>The tool parses its options, runs a {@link Chi_RWCSClassifier} over the test data and then
 * writes three text files below the output path:
 * <ul>
 *   <li>{@code Predictions.txt} - one predicted label code per line;</li>
 *   <li>{@code <prefix>_confusion_matrix.txt} - confusion matrix, AUC and geometric mean;</li>
 *   <li>{@code <input name>_classify_time.txt} - elapsed classification time.</li>
 * </ul>
 *
 * <p>The AUC and geometric-mean computations only read the top-left 2x2 corner of the confusion
 * matrix, i.e. they assume a two-class problem.
 */
public class TestModel extends Configured implements Tool {

    private static final Logger log = LoggerFactory.getLogger(TestModel.class);

    private static final String HEAVY_RULE = "=======================================================";
    private static final String LIGHT_RULE = "-------------------------------------------------------";

    private Path dataPath; // test data path
    private Path datasetPath; // path of the dataset description
    private Path modelPath; // path where the model is stored
    private Path outputPath; // directory that receives the predictions and the report files

    /**
     * Parses the command line, runs the classification and records how long it took.
     *
     * @param args command-line arguments (input, dataset, model, output, help)
     * @return {@code 0} on success, {@code -1} when help was requested or the options were invalid
     * @throws IOException if reading the inputs or writing the reports fails
     * @throws ClassNotFoundException propagated from the classifier job
     * @throws InterruptedException propagated from the classifier job
     */
    @Override
    public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
        DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
        ArgumentBuilder abuilder = new ArgumentBuilder();
        GroupBuilder gbuilder = new GroupBuilder();

        Option inputOpt = DefaultOptionCreator.inputOption().create();

        Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true)
                .withArgument(abuilder.withName("dataset").withMinimum(1).withMaximum(1).create())
                .withDescription("Dataset path").create();

        Option modelOpt = obuilder.withLongName("model").withShortName("m").withRequired(true)
                .withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
                .withDescription("Path to the Model").create();

        Option outputOpt = DefaultOptionCreator.outputOption().create();

        Option helpOpt = DefaultOptionCreator.helpOption();

        Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(datasetOpt).withOption(modelOpt)
                .withOption(outputOpt).withOption(helpOpt).create();

        try {
            Parser parser = new Parser();
            parser.setGroup(group);
            CommandLine cmdLine = parser.parse(args);

            if (cmdLine.hasOption("help")) {
                CommandLineUtil.printHelp(group);
                return -1;
            }

            String dataName = cmdLine.getValue(inputOpt).toString();
            String datasetName = cmdLine.getValue(datasetOpt).toString();
            String modelName = cmdLine.getValue(modelOpt).toString();
            String outputName = cmdLine.hasOption(outputOpt) ? cmdLine.getValue(outputOpt).toString() : null;

            // Parameterized logging defers formatting, so no isDebugEnabled() guard is needed.
            log.debug("input     : {}", dataName);
            log.debug("dataset   : {}", datasetName);
            log.debug("model     : {}", modelName);
            log.debug("output    : {}", outputName);

            dataPath = new Path(dataName);
            datasetPath = new Path(datasetName);
            modelPath = new Path(modelName);
            if (outputName != null) {
                outputPath = new Path(outputName);
            }
        } catch (OptionException e) {
            log.warn(e.toString(), e);
            CommandLineUtil.printHelp(group);
            return -1;
        }

        long start = System.currentTimeMillis();
        testModel();
        long elapsed = System.currentTimeMillis() - start;

        writeToFileClassifyTime(Chi_RWCSUtils.elapsedTime(elapsed));

        return 0;
    }

    /**
     * Validates the configured paths, runs the classifier and, when it returns results, writes the
     * predictions and the confusion-matrix report.
     *
     * @throws IllegalArgumentException if the output path already exists or is missing, or if the
     *         model or test data path does not exist
     */
    private void testModel() throws IOException, ClassNotFoundException, InterruptedException {
        Configuration conf = getConf();

        // make sure the output file does not exist
        if (outputPath != null) {
            FileSystem outFS = outputPath.getFileSystem(conf);
            if (outFS.exists(outputPath)) {
                throw new IllegalArgumentException("Output path already exists");
            }
        }

        // make sure the model exists
        FileSystem modelFS = modelPath.getFileSystem(conf);
        if (!modelFS.exists(modelPath)) {
            throw new IllegalArgumentException("The model path does not exist");
        }

        // make sure the test data exists
        FileSystem dataFS = dataPath.getFileSystem(conf);
        if (!dataFS.exists(dataPath)) {
            throw new IllegalArgumentException("The Test data path does not exist");
        }

        if (outputPath == null) {
            throw new IllegalArgumentException(
                    "You must specify the ouputPath when using the mapreduce implementation");
        }

        Chi_RWCSClassifier classifier = new Chi_RWCSClassifier(modelPath, dataPath, datasetPath, outputPath, conf);
        classifier.run();

        double[][] results = classifier.getResults();
        if (results == null) {
            return;
        }

        writePredictions(results);

        Dataset dataset = Dataset.load(conf, datasetPath);
        ResultAnalyzer analyzer = new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown");
        for (double[] res : results) {
            // res[0] is passed as the correct label and res[1] as the classified label.
            analyzer.addInstance(dataset.getLabelString(res[0]),
                    new ClassifierResult(dataset.getLabelString(res[1]), 1.0));
        }
        parseOutput(analyzer);
    }

    /**
     * Writes the confusion matrix, AUC and geometric mean to
     * {@code <outputPath>/<prefix>_confusion_matrix.txt}, where the prefix is derived from the
     * input file name by {@link #reportPrefix(String)}.
     */
    private void parseOutput(ResultAnalyzer analyzer) throws IOException {
        NumberFormat decimalFormatter = new DecimalFormat("0.########");
        int[][] matrix = analyzer.getConfusionMatrix().getConfusionMatrix();

        StringBuilder report = new StringBuilder(200);
        report.append(HEAVY_RULE).append('\n');
        report.append("Confusion Matrix\n");
        report.append(LIGHT_RULE).append('\n');
        // The last row and column are skipped; presumably they hold the "unknown" default label
        // handed to the ResultAnalyzer - TODO confirm.
        for (int i = 0; i < matrix.length - 1; i++) {
            for (int j = 0; j < matrix[i].length - 1; j++) {
                report.append(StringUtils.rightPad(Integer.toString(matrix[i][j]), 5)).append('\t');
            }
            report.append('\n');
        }
        report.append(LIGHT_RULE).append('\n');
        report.append("AUC - Area Under the Curve ROC\n");
        report.append(StringUtils.rightPad(decimalFormatter.format(computeAuc(matrix)), 5)).append('\n');
        report.append(LIGHT_RULE).append('\n');
        report.append("GM - Geometric Mean\n");
        report.append(StringUtils.rightPad(decimalFormatter.format(computeGM(matrix)), 5)).append('\n');
        report.append(LIGHT_RULE).append('\n');

        // Use only the final path component so the report always lands below outputPath, even when
        // the input was given as an absolute or fully-qualified path.
        String prefix = reportPrefix(dataPath.getName());
        writeReport(new Path(outputPath, prefix + "_confusion_matrix").suffix(".txt"), report.toString());
    }

    /**
     * Returns the part of {@code name} that precedes its first {@code 't'}, or the whole name when
     * it contains no {@code 't'} (instead of failing with an index error).
     */
    private static String reportPrefix(String name) {
        int cut = name.indexOf('t');
        return cut < 0 ? name : name.substring(0, cut);
    }

    /**
     * Returns the index of the class with the fewest instances. Row totals are taken over all
     * rows and columns except the last; ties keep the lowest index.
     */
    private static int minorityClassIndex(int[][] matrix) {
        int classes = matrix.length - 1;
        int[] totals = new int[classes];
        for (int i = 0; i < classes; i++) {
            for (int j = 0; j < matrix[i].length - 1; j++) {
                totals[i] += matrix[i][j];
            }
        }
        int minority = 0;
        for (int k = 1; k < classes; k++) {
            if (totals[k] < totals[minority]) {
                minority = k;
            }
        }
        return minority;
    }

    /** Returns {@code hits / (hits + misses)}; the result is NaN when both counts are zero. */
    private static double rate(int hits, int misses) {
        return (double) hits / (hits + misses);
    }

    /**
     * Computes {@code (1 + TPrate - FPrate) / 2}, treating the class with the fewest instances as
     * the positive class. Only the top-left 2x2 corner of the matrix is read.
     */
    private double computeAuc(int[][] matrix) {
        double tpRate;
        double fpRate;
        if (minorityClassIndex(matrix) == 0) {
            tpRate = rate(matrix[0][0], matrix[0][1]);
            fpRate = rate(matrix[1][0], matrix[1][1]);
        } else {
            fpRate = rate(matrix[0][1], matrix[0][0]);
            tpRate = rate(matrix[1][1], matrix[1][0]);
        }
        return (1 + tpRate - fpRate) / 2;
    }

    /**
     * Computes the geometric mean of the per-class accuracies of the first two classes. The product
     * is symmetric, so which of the two classes is the positive one does not affect the result.
     */
    private double computeGM(int[][] matrix) {
        double firstClassAccuracy = rate(matrix[0][0], matrix[0][1]);
        double secondClassAccuracy = rate(matrix[1][1], matrix[1][0]);
        return Math.sqrt(firstClassAccuracy * secondClassAccuracy);
    }

    /**
     * Writes the second element of every result row, one per line, to
     * {@code <outputPath>/Predictions.txt}.
     */
    private void writePredictions(double results[][]) throws IOException {
        FileSystem outFS = outputPath.getFileSystem(getConf());
        Path filenamePath = new Path(outputPath, "Predictions").suffix(".txt");
        FSDataOutputStream ofile = null;
        try {
            ofile = outFS.create(filenamePath);
            for (double[] res : results) {
                ofile.writeBytes(Double.toString(res[1]) + "\n");
            }
            // Close explicitly so that a failure while flushing is reported to the caller.
            ofile.close();
        } finally {
            Closeables.closeQuietly(ofile);
        }
    }

    /**
     * Writes the elapsed classification time to
     * {@code <outputPath>/<input file name>_classify_time.txt}.
     *
     * @param time already formatted elapsed time
     */
    private void writeToFileClassifyTime(String time) throws IOException {
        StringBuilder report = new StringBuilder(200);
        report.append(HEAVY_RULE).append('\n');
        report.append("Classify Time\n");
        report.append(LIGHT_RULE).append('\n');
        report.append(StringUtils.rightPad(time, 5)).append('\n');
        report.append(LIGHT_RULE).append('\n');

        // Final path component only: an absolute input path must not redirect the file outside
        // the output directory.
        Path filenamePath = new Path(outputPath, dataPath.getName() + "_classify_time").suffix(".txt");
        writeReport(filenamePath, report.toString());
    }

    /**
     * Creates {@code file} on the output file system and writes {@code content} into it.
     */
    private void writeReport(Path file, String content) throws IOException {
        FileSystem outFS = outputPath.getFileSystem(getConf());
        FSDataOutputStream ofile = null;
        try {
            ofile = outFS.create(file);
            // NOTE(review): writeUTF prepends a two-byte length and uses modified UTF-8, so the
            // ".txt" file starts with two binary bytes. Kept to preserve the existing on-disk
            // format; switch to writeBytes/write if no consumer depends on the prefix.
            ofile.writeUTF(content);
            // Close explicitly so that a failure while flushing is reported to the caller.
            ofile.close();
        } finally {
            Closeables.closeQuietly(ofile);
        }
    }

    /**
     * Runs the tool through {@link ToolRunner}. The value returned by {@link #run(String[])} is
     * not turned into a process exit code.
     */
    public static void main(String[] args) throws Exception {
        ToolRunner.run(new Configuration(), new TestModel(), args);
    }
}