com.deafgoat.ml.prognosticator.AppClassifier.java Source code

Java tutorial

Introduction

Here is the source code for com.deafgoat.ml.prognosticator.AppClassifier.java

Source

/**
 * Copyright 2012, Wisdom Omuya.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at:
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.deafgoat.ml.prognosticator;

//Java
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;

//Log4j
import org.apache.log4j.Logger;

//Weka
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Attribute;
import weka.core.Instances;
import weka.core.SerializationHelper;

/**
 * Builds and cross validates models, classifies test instances.
 */
public final class AppClassifier {
    /**
     * Performs k-fold cross-validation on the training data and persists the
     * resulting model.
     * <p>
     * For every fold a fresh copy of the configured classifier is trained on
     * the fold's training split and evaluated on its test split; the results
     * accumulate in the shared evaluation object. The copy made for the last
     * fold is the model that gets persisted.
     *
     * @throws Exception
     *             If the folds can not be created, the classifier can not be
     *             copied, or the model/results can not be written.
     */
    public void crossValidate() throws Exception {
        // stratify nominal target class
        if (_trainInstances.classAttribute().isNominal()) {
            _trainInstances.stratify(_folds);
        }
        _eval = new Evaluation(_trainInstances);
        for (int n = 0; n < _folds; n++) {
            if (_logger.isDebugEnabled()) {
                _logger.debug("Cross validation fold: " + (n + 1));
            }
            _train = _trainInstances.trainCV(_folds, n);
            _test = _trainInstances.testCV(_folds, n);
            _clsCopy = AbstractClassifier.makeCopy(_cls);
            try {
                _clsCopy.buildClassifier(_train);
            } catch (Exception e) {
                // keep the cause: the failure is not necessarily about the class attribute type
                _logger.warn(_config._classifier + " can not handle " + getAttributeType(_test.classAttribute())
                        + " class attributes", e);
                // an unbuilt classifier can not be evaluated, so skip this fold
                continue;
            }

            try {
                _eval.evaluateModel(_clsCopy, _test);
            } catch (Exception e) {
                _logger.warn("Can not evaluate model", e);
            }
        }

        if (_config._writeToMongoDB) {
            _logger.info("Writing model to mongoDB");
        }
        if (_config._writeToFile) {
            _logger.info("Writing model to file");
        }
        if (_config._writeToMongoDB || _config._writeToFile) {
            // saveModel() honours both flags, so one call covers every target
            // (previously the model file was written twice when both were set)
            saveModel();
        }
        if (_config._writeToMongoDB) {
            // save CV performance of trained model
            writeToMongoDB(_eval);
        }
    }

    /**
     * Gets details on classified instances according to supplied attribute.
     * <p>
     * Each test instance whose actual class equals the configured positive or
     * negative class value is placed in one of four buckets (true/false
     * positives/negatives); instances with any other actual class are ignored.
     * Every bucket is sorted in reverse natural order and written to a file
     * named after the bucket's configured key. At most the first
     * {@code _maxCount} entries of a bucket are considered, and an entry is
     * written only when the second delimiter-separated field of its
     * attribute distribution line is at least {@code _minProb}.
     *
     * @param attribute
     *            The focal attribute for error analysis
     * @throws Exception
     *             If the model can not be read, an instance can not be
     *             classified, or a bucket file can not be written
     */
    public void errorAnalysis(String attribute) throws Exception {
        readModel();
        _logger.info("Performing error analysis");
        _predictionList = new HashMap<String, List<Prediction>>();
        _predictionList.put(_config._truePositives, new ArrayList<Prediction>());
        _predictionList.put(_config._trueNegatives, new ArrayList<Prediction>());
        _predictionList.put(_config._falsePositives, new ArrayList<Prediction>());
        _predictionList.put(_config._falseNegatives, new ArrayList<Prediction>());

        Attribute classAttribute = _testInstances.classAttribute();
        for (int i = 0; i < _testInstances.numInstances(); i++) {
            double[] distribution = _cls.distributionForInstance(_testInstances.instance(i));
            String actual = classAttribute.value((int) _testInstances.instance(i).classValue());
            String predicted = classAttribute.value((int) _cls.classifyInstance(_testInstances.instance(i)));
            boolean correct = predicted.equals(actual);

            // the bucket is decided by the actual label and whether the prediction matched it
            String bucket;
            if (actual.equals(_config._negativeClassValue)) {
                bucket = correct ? _config._trueNegatives : _config._falsePositives;
            } else if (actual.equals(_config._positiveClassValue)) {
                bucket = correct ? _config._truePositives : _config._falseNegatives;
            } else {
                // neither the configured positive nor negative class: not part of the analysis
                continue;
            }
            // reported indices are 1-based; details come from the unfiltered data at the same position
            _predictionList.get(bucket).add(new Prediction(i + 1, predicted, distribution, _fullData.instance(i)));
        }

        for (Entry<String, List<Prediction>> entry : _predictionList.entrySet()) {
            List<Prediction> bucketPredictions = entry.getValue();
            Collections.sort(bucketPredictions, Collections.reverseOrder());
            BufferedWriter writer = new BufferedWriter(new FileWriter(entry.getKey()));
            try {
                for (int count = 0; count < bucketPredictions.size() && count < _config._maxCount; count++) {
                    String line = bucketPredictions.get(count).attributeDistribution(attribute);
                    if (Double.parseDouble(line.split(_delimeter)[1]) >= _config._minProb) {
                        writer.write(line + "\n");
                    }
                }
            } finally {
                // release the file handle even when formatting or writing fails
                writer.close();
            }
        }
    }

    /**
     * Evaluates model performance on test instances.
     * <p>
     * Reads the persisted model, evaluates it against the test instances and
     * logs the summary, the per-class details and the confusion matrix. The
     * last two are logged on a best-effort basis: if either can not be
     * produced a short notice is logged instead.
     *
     * @throws Exception
     *             If the model can not be read or evaluated.
     */
    public void evaluate() throws Exception {
        readModel();
        _logger.info("Classifying with " + _config._classifier);
        Evaluation eval = new Evaluation(_testInstances);
        eval.evaluateModel(_cls, _testInstances);
        _logger.info("\n" + eval.toSummaryString());
        try {
            _logger.info("\n" + eval.toClassDetailsString());
        } catch (Exception e) {
            _logger.info("Can not create class details for " + _config._classifier);
        }
        try {
            // use the evaluation computed above; the _eval field is only
            // populated by crossValidate() and is not set on this code path
            _logger.info("\n" + eval.toMatrixString());
        } catch (Exception e) {
            _logger.info(
                    "Can not create confusion matrix for " + _config._classifier + " using " + _config._classValue);
        }
    }

    /**
     * Names the Weka type of the given attribute.
     * <p>
     * The checks run in the order date, nominal, numeric; anything else is
     * reported as "string".
     *
     * @param attribute
     *            The attribute to inspect.
     * @return One of "date", "nominal", "numeric" or "string".
     */
    public String getAttributeType(Attribute attribute) {
        if (attribute.isDate()) {
            return "date";
        }
        if (attribute.isNominal()) {
            return "nominal";
        }
        return attribute.isNumeric() ? "numeric" : "string";
    }

    /**
     * Initialize instances classifier.
     * <p>
     * The configured classifier name is looked up in each of the known
     * {@code weka.classifiers} sub-packages in turn; the first one that can be
     * instantiated is used.
     *
     * @throws Exception
     *             If the classifier can not be initialized from any of the
     *             known sub-packages (the failure of the last attempt is
     *             rethrown).
     */
    public void initializeClassifier() throws Exception {
        final String base = "weka.classifiers.";
        final String[] groups = new String[] { "bayes.", "functions.", "lazy.", "meta.", "misc.", "rules.", "trees." };
        Exception lastFailure = null;
        for (String group : groups) {
            try {
                _cls = AbstractClassifier.forName(base + group + _config._classifier, null);
                return;
            } catch (Exception e) {
                // not in this sub-package (or not instantiable): try the next one
                lastFailure = e;
            }
        }
        _logger.error("Could not create classifier - msg: " + lastFailure.getMessage(), lastFailure);
        // fail here rather than leave _cls unset for later steps to trip over
        throw lastFailure;
    }

    /**
     * Does prediction in production.
     * <p>
     * Reads the persisted model and classifies every test instance. Depending
     * on the configuration the predictions are written to the prediction file
     * (sorted, limited to {@code _maxCount} entries whose confidence is at
     * least {@code _minProb}) and/or to mongoDB (all successful predictions).
     * An instance whose prediction fails is logged and left out of both
     * outputs.
     *
     * @throws Exception
     *             If the model can not be read or the results can not be
     *             written
     */
    public void predict() throws Exception {
        _logger.info("Predicting test instances");
        readModel();
        ArrayList<Prediction> predictionList = new ArrayList<Prediction>();
        ArrayList<HashMap<String, String>> predictions = new ArrayList<HashMap<String, String>>();
        Attribute classAttribute = _testInstances.classAttribute();

        for (int i = 0; i < _testInstances.numInstances(); i++) {
            Prediction prediction;
            try {
                double[] distribution = _cls.distributionForInstance(_testInstances.instance(i));
                // NOTE(review): for a numeric class attribute, Attribute.value(int) may not
                // yield the predicted number - confirm against the Weka Attribute documentation
                String predicted = classAttribute.value((int) _cls.classifyInstance(_testInstances.instance(i)));
                prediction = new Prediction(i + 1, predicted, distribution, _fullData.instance(i));
            } catch (Exception e) {
                _logger.debug(_config._classifier + " does not provide instance prediction distribution", e);
                // nothing was predicted for this instance; do not reuse the previous
                // instance's prediction (or dereference null) below
                continue;
            }

            if (classAttribute.isNominal() && _config._onlyPosNominal) {
                // write only 'positive' predictions to file
                if (prediction.getPrediction().equals(_config._positiveClassValue)) {
                    predictionList.add(prediction);
                }
            } else {
                predictionList.add(prediction);
            }

            // writing ALL predictions to database
            if (_config._writeToMongoDB) {
                HashMap<String, String> result = new HashMap<String, String>();
                if (classAttribute.isNumeric()) {
                    result.put("confidence", "");
                } else {
                    result.put("confidence", prediction.getConfidence().toString());
                }
                result.put(_config._classValue, prediction.getPrediction());
                predictions.add(result);
            }
        }

        if (_config._writeToFile) {
            _logger.info("Writing predictions to file");
            // sort prediction list
            try {
                Collections.sort(predictionList);
            } catch (Exception e) {
                _logger.debug("Can not use prediction compareTo", e);
            }

            BufferedWriter writer = new BufferedWriter(new FileWriter(_config._predictionFile));
            try {
                int count = 0;
                for (Prediction entry : predictionList) {
                    if (count >= _config._maxCount) {
                        break;
                    }
                    double confidence = entry.getConfidence();
                    // only entries that are actually written count towards the limit
                    if (confidence >= _config._minProb) {
                        if (classAttribute.isNumeric()) {
                            writer.write(entry.getIndex() + _delimeter + entry.getPrediction() + "\n");
                        } else {
                            writer.write(entry.getIndex() + _delimeter + confidence + _delimeter
                                    + entry.getPrediction() + "\n");
                        }
                        count += 1;
                    }
                }
            } finally {
                // release the file handle even when writing fails
                writer.close();
            }
        }

        if (_config._writeToMongoDB) {
            _logger.info("Writing predictions to mongoDB");
            MongoResult mongoResult = new MongoResult(_config._host, _config._port, _config._db,
                    _config._predictionCollection);
            try {
                mongoResult.writeResult(_config._relation, predictions);
            } finally {
                // close the connection, as the other mongoDB writers in this class do
                mongoResult.close();
            }
        }
    }

    /**
     * Logs the cross-validation results: the evaluation summary followed by
     * the confusion matrix. When the matrix can not be produced a notice is
     * logged in its place.
     *
     * @throws Exception
     *             If the summary can not be produced
     */
    public void printSummary() throws Exception {
        final String heading = "\n" + _folds + "-fold Cross-validation\n";
        final String summary = _eval.toSummaryString(heading, false);
        _logger.info(summary);
        try {
            final String matrix = _eval.toMatrixString();
            _logger.info("\n" + matrix);
        } catch (Exception e) {
            final String notice = "Can not create confusion matrix for " + _config._classifier + " using "
                    + _config._classValue;
            _logger.info(notice);
        }
    }

    /**
     * Reads the trained model.
     * <p>
     * The model is loaded from mongoDB and/or from the model file, depending
     * on the configuration. When both sources are enabled the model read from
     * the file is the one that is kept, because it is read last.
     *
     * @throws Exception
     *             If the model can not be read.
     */
    public void readModel() throws Exception {
        if (_logger.isDebugEnabled()) {
            _logger.debug("Deserializing model");
        }

        if (_config._writeToMongoDB) {
            MongoResult mongoResult = new MongoResult(_config._host, _config._port, _config._db,
                    _config._modelCollection);
            try {
                _cls = mongoResult.readModel(_config._relation);
            } finally {
                // release the connection even when the read fails
                mongoResult.close();
            }
        }

        if (_config._writeToFile) {
            _cls = (Classifier) SerializationHelper.read(_config._modelFile);
        }
    }

    /**
     * Saves the trained model.
     * <p>
     * The classifier copy is written to mongoDB and/or to the model file,
     * depending on the configuration.
     *
     * @throws Exception
     *             If the model can not be saved
     */
    public void saveModel() throws Exception {
        if (_logger.isDebugEnabled()) {
            _logger.debug("Serializing model");
        }

        if (_config._writeToMongoDB) {
            MongoResult mongoResult = new MongoResult(_config._host, _config._port, _config._db,
                    _config._modelCollection);
            try {
                mongoResult.writeModel(_config._relation, _clsCopy);
            } finally {
                // release the connection even when the write fails
                mongoResult.close();
            }
        }

        if (_config._writeToFile) {
            SerializationHelper.write(_config._modelFile, _clsCopy);
        }
    }

    /**
     * Write results to mongoDB.
     * <p>
     * The evaluation summary is always written. The per-class details and the
     * confusion matrix are written on a best-effort basis: a failure is logged
     * and does not abort the method.
     *
     * @param eval
     *            The evaluation object holding data.
     * @throws Exception
     *             If the connection can not be set up or the summary can not
     *             be written
     */
    public void writeToMongoDB(Evaluation eval) throws Exception {
        MongoResult mongoResult = new MongoResult(_config._host, _config._port, _config._db,
                _config._modelCollection);
        try {
            mongoResult.writeExperiment(_config._relation, "summary", eval.toSummaryString());
            try {
                mongoResult.writeExperiment(_config._relation, "class detail", eval.toClassDetailsString());
            } catch (Exception e) {
                _logger.error("Can not create class details for " + _config._classifier, e);
            }
            try {
                mongoResult.writeExperiment(_config._relation, "confusion matrix", eval.toMatrixString());
            } catch (Exception e) {
                _logger.error("Can not create confusion matrix for " + _config._classifier, e);
            }
        } finally {
            // release the connection even when writing the summary fails
            mongoResult.close();
        }
    }

    /**
     * handle to the classifier object; created by initializeClassifier() or
     * loaded by readModel()
     */
    private Classifier _cls;

    /**
     * handle to a copy of the classifier object; a new copy is made for each
     * cross-validation fold and the latest one is what saveModel() persists
     */
    private Classifier _clsCopy;

    /**
     * configuration handle
     */
    private ConfigReader _config;

    /**
     * the delimiter to use in the prediction file
     */
    private static final String _delimeter = "\t";

    /**
     * handle to the evaluation object; populated by crossValidate() and read
     * by printSummary()
     */
    private Evaluation _eval;

    /**
     * number of folds to use in cross-validation
     */
    private int _folds;

    /**
     * the full set of unfiltered data; indexed in step with the test instances
     * when building predictions
     */
    private Instances _fullData;

    /**
     * handle to the logger
     */
    private Logger _logger;

    /**
     * predictions made on test data during error analysis, grouped by the
     * configured true/false positive/negative keys
     */
    private HashMap<String, List<Prediction>> _predictionList;

    /**
     * holds test data of the current fold (used in CV)
     */
    private Instances _test;

    /**
     * holds initialized test data
     */
    private Instances _testInstances;

    /**
     * holds training data of the current fold (used in CV)
     */
    private Instances _train;

    /**
     * holds initialized training data
     */
    private Instances _trainInstances;

    /**
     * Creates a classifier for a set of unfiltered test instances; the same
     * data set serves both as the instances to classify and as the full data.
     * 
     * @param fullData
     *            The full data set.
     * @param config
     *            The config reader handle.
     * @throws Exception
     */
    public AppClassifier(Instances fullData, ConfigReader config) throws Exception {
        // unfiltered case: the test instances are the full data itself
        this(fullData, fullData, config);
    }

    /**
     * Creates a classifier for test instances that have been filtered, keeping
     * the unfiltered data alongside them.
     * 
     * @param filteredData
     *            The filtered data set (the instances that get classified).
     * @param fullData
     *            The full data set.
     * @param config
     *            The config reader handle.
     * @throws Exception
     */
    public AppClassifier(Instances filteredData, Instances fullData, ConfigReader config) throws Exception {
        this._logger = AppLogger.getLogger();
        this._config = config;
        this._fullData = fullData;
        this._testInstances = filteredData;
    }

    /**
     * Constructor to build model based on CV
     * 
     * @param trainData
     *            The training data.
     * @param fold
     *            The number of folds for corss validation.
     * @param config
     *            The config reader handle.
     * @throws Exception
     */
    public AppClassifier(Instances trainData, int fold, ConfigReader config) throws Exception {
        _folds = fold;
        _config = config;
        _trainInstances = new Instances(trainData);
        _logger = AppLogger.getLogger();
    }
}