sentinets.Prediction.java Source code

Introduction

Here is the source code for sentinets.Prediction.java

Source

/*******************************************************************************
 * Copyright (c) 2015 University of Illinois Board of Trustees, All rights reserved.
 * Developed at GSLIS/ the iSchool, by Dr. Jana Diesner, Shubhanshu Mishra, Liang Tao, and Chieh-Li Chin.    
 * This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or any later version.
 * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 * You should have received a copy of the GNU General Public License along with this program; if not, see <http://www.gnu.org/licenses>.
 *******************************************************************************/
/**
 * 
 */
package sentinets;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;

import org.apache.commons.lang3.StringUtils;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.SGD;
import weka.classifiers.meta.FilteredClassifier;
import weka.classifiers.trees.J48;
import weka.core.Debug.Random;
import weka.core.Instances;
import weka.core.converters.CSVLoader;
import weka.core.converters.CSVSaver;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Handles all routines related to prediction and to updating the model.
 *
 * @author Shubhanshu
 */
public class Prediction {

    Instances original, labled, unlabled, ep_instances, sns_instances;
    String inputFile, outputDir;
    MODELTYPE classifierType;
    int clsIndex;
    String filePrefix;
    Classifier cls;

    private static int CLASSINDEX = 44;
    private static String stringAttr = "3,4,6,8,10-12,14,42,43,45";
    private static String removeAttr = "3,4,6,10-12,14,42,43,45-last";

    /*private static int CLASSINDEX = 19;
    private static String stringAttr = "3,4,6,8,10-12,14,17,18,20";
    private static String removeAttr = "3,4,6,10-12,14,17,18,20-last";*/
    //FilteredClassifier cls;
    public static enum outDirIndex {
        ORIGINAL, LABELED, UPDATES, MODELS
    };

    public static final String[] outDirs = { "original", "labled", "updates", "models" };
    private static final String SentimentModelFile = "./data/models/All_filtered_pos_j48.model"; // This file no longer exists.
    private static String SentimentModelFile_word = "./data/models/All_filtered_pos_SGD_10000_3.model";
    private String customModelFile = "";
    private static String FILEPREFIX = "Sentiment";
    private boolean showProbability = true;

    public static enum MODELTYPE {
        SENTIMENT("Meta Model"), SENTIMENT_WORD("Word Model"), CUSTOM("Custom...");
        private String text;

        private MODELTYPE(String text) {
            this.text = text;
        }

        @Override
        public String toString() {
            return this.text;
        }

    };

    private ArrayList<Double[][]> classDist;
    private String[] classNames = { "positive", "negative" };
    private int classProbIndex = -1;

    public static void setModelParams(int class_index, String string_list, String remove_list, String model) {
        CLASSINDEX = class_index;
        stringAttr = string_list;
        removeAttr = remove_list;
        SentimentModelFile_word = model;
    }

    public Prediction(String inputFile, String outputDir, MODELTYPE classifierType) {
        this.inputFile = inputFile;
        this.outputDir = outputDir;
        this.classifierType = classifierType;
        this.setInstances(this.inputFile);
    }

    public Prediction(String inputFile, String outputDir, MODELTYPE classifierType, String prefix) {
        this(inputFile, outputDir, classifierType);
        this.filePrefix = prefix;
    }

    /**
     * @return the customModelFile
     */
    public String getCustomModelFile() {
        return customModelFile;
    }

    /**
     * @param customModelFile the customModelFile to set
     */
    public void setCustomModelFile(String customModelFile) {
        this.customModelFile = customModelFile;
        System.out.println("Using custom model file for predictions: " + customModelFile);
    }

    /**
     * @param args
     */
    public static void main(String[] args) {
        String inputFile = "./output/Supplementary_POS.tsv";
        String outputDir = "./output";
        Prediction obj = new Prediction(inputFile, outputDir, MODELTYPE.SENTIMENT);
        obj.doPredictions(false);
    }

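    /**
     * Creates the "original", "labled", "updates", and "models" sub-folders
     * under the output directory.
     */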
    private void prepareOutputFolder() {
        for (String dirName : outDirs) {
            dirName = outputDir + "/" + dirName;
            Utils.createFolder(dirName);
        }
    }

    /**
     * Runs the full prediction pipeline: prepares the output folders, loads the
     * classifier for the selected model type, labels the instances (unless they
     * have already been predicted), and writes the labeled and original
     * instances to the output directory.
     *
     * @param predicted if false, run the classifier on the unlabeled instances;
     *                  if true, skip classification and only write the output files
     * @return 0 on success, a non-zero error code otherwise
     */
    public int doPredictions(boolean predicted) {
        //classDist = new ArrayList<Double[][]>();
        int status = 0;
        this.prepareOutputFolder();
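        // Note: this loop is currently a no-op; it only remains as a placeholder
        // for the commented-out class-distribution bookkeeping below.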
        for (int i = 0; i < unlabled.numInstances(); i++) {
            /*
             * Distribution stored as:
             * {{E, P}, {S, NS}}
             */
            //classDist.add(TweetCorpusStatistics.getNullDist()); 
        }
        status = setClassifier(classifierType);
        if (status > 0) {
            System.err.println("Error in setting classifier.");
            return 1;
        }
        labled = new Instances(unlabled);
        labled.setClassIndex(labled.numAttributes() - 1);
        if (!predicted) {
            status = performClassification();
            if (status > 0) {
                System.err.println("Error in performing classification");
                return 2;
            }
        }

        status = writePredictions(labled, "/" + Utils.getOutDir(Utils.OutDirIndex.LABELED) + "/" + filePrefix);
        if (status > 0) {
            System.err.println("Writing output file failed: " + Utils.getOutDir(Utils.OutDirIndex.LABELED));
            return 3;
        }
        try {
            Remove r = new Remove();
            status = writePredictions(original,
                    "/" + Utils.getOutDir(Utils.OutDirIndex.ORIGINAL) + "/" + filePrefix);
            if (status > 0) {
                System.err.println("Writing output file failed: " + Utils.getOutDir(Utils.OutDirIndex.ORIGINAL));
                return 4;
            }
            //writeStats(original);
        } catch (Exception e) {
            e.printStackTrace();
            return 5;
        }
        return 0;
    }

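    /**
     * Formats a 2x2 class distribution as a tab-separated string, or returns an
     * empty string if the distribution is null.
     */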
    public String printDist(Double[][] dist) {
        String distStr = "";
        if (dist != null)
            distStr = dist[0][0] + "\t" + dist[0][1] + "\t" + dist[1][0] + "\t" + dist[1][1];
        return distStr;
    }

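    /**
     * Prints a summary of the given instances and, for each tweet, the user,
     * mentions, hashtags, and stored class distribution to standard output.
     */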
    public void writeStats(Instances tweetInstances) {
        //TweetCorpusStatistics stats = new TweetCorpusStatistics();
        System.out.println("Stats Instances: \n" + tweetInstances.toSummaryString());
        for (int i = 0; i < tweetInstances.size(); i++) {
            String user = tweetInstances.get(i).stringValue(11 - 1);
            String mentions = tweetInstances.get(i).stringValue(3 - 1);
            String hashtags = tweetInstances.get(i).stringValue(14 - 1);
            String epClass = tweetInstances.get(i).stringValue(15 - 1);
            String snsClass = tweetInstances.get(i).stringValue(16 - 1);
            System.out.println("Tweet Details:\t" + user + "\t" + mentions + "\t" + hashtags + "\t"
                    + printDist(classDist.get(i)));
            //stats.updateStatistics(user, mentions, hashtags, epClass+","+snsClass, classDist.get(i));
        }
    }

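    /**
     * Classifies each unlabeled instance with the loaded classifier, stores the
     * predicted label in the labeled and original instances, and (optionally)
     * records the prediction probability.
     *
     * @return 0 on success, 1 if classifying any instance fails
     */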
    public int performClassification() {
        for (int i = 0; i < unlabled.numInstances(); i++) {
            double clsLabel = 0;
            //Double[][] instanceDist = classDist.get(i);
            double[] dist = { -1., -1. };
            //int distIndex = 0;
            if (classifierType == MODELTYPE.SENTIMENT || classifierType == MODELTYPE.SENTIMENT_WORD
                    || classifierType == MODELTYPE.CUSTOM) {
                //distIndex = 0;
            }
            //System.out.println(unlabled.instance(i).toString(4));
            try {
                dist = cls.distributionForInstance(unlabled.instance(i));
                //System.out.println("Distributions for "+filePrefix+" :\t"+dist[0]+","+dist[1]+","+dist[2]);
                //instanceDist[distIndex][0] = dist[0];
                //instanceDist[distIndex][1] = dist[1];
                clsLabel = cls.classifyInstance(unlabled.instance(i));
                //System.out.println("CLSLABEL: " + clsLabel);
            } catch (Exception e) {
                e.printStackTrace();
                System.err.println("Classification task failed.");
                return 1;
            }
            labled.instance(i).setClassValue(clsLabel);
            original.instance(i).setValue(clsIndex, clsLabel);
            if (showProbability) {
                original.instance(i).setValue(classProbIndex, Double.max(dist[0], dist[1]));
            }
            //classDist.set(i, instanceDist);
        }
        return 0;
    }

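    /**
     * Incrementally updates the wrapped SGD classifier with the instances from
     * the given input file, evaluating the model before, during (per
     * cross-validation fold), and after the update, and saves the updated model
     * to the output directory.
     *
     * @param inputFile the file with the new training instances
     * @param metrics   receives {fMeasure(class 0), fMeasure(class 1),
     *                  weighted fMeasure} for each evaluation round
     * @return a textual report of the evaluations
     */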
    public String updateModel(String inputFile, ArrayList<Double[]> metrics) {
        String output = "";
        this.setInstances(inputFile);
        FilteredClassifier fcls = (FilteredClassifier) this.cls;
        SGD cls = (SGD) fcls.getClassifier();
        Filter filter = fcls.getFilter();
        Instances insAll;
        try {
            insAll = Filter.useFilter(this.unlabled, filter);
            if (insAll.size() > 0) {
                Random rand = new Random(10);
                int folds = 10 > insAll.size() ? 2 : 10;
                Instances randData = new Instances(insAll);
                randData.randomize(rand);
                if (randData.classAttribute().isNominal()) {
                    randData.stratify(folds);
                }
                Evaluation eval = new Evaluation(randData);
                eval.evaluateModel(cls, insAll);
                System.out.println("Initial Evaluation");
                System.out.println(eval.toSummaryString());
                System.out.println(eval.toClassDetailsString());
                metrics.add(new Double[] { eval.fMeasure(0), eval.fMeasure(1), eval.weightedFMeasure() });
                output += "\n====" + "Initial Evaluation" + "====\n";
                output += "\n" + eval.toSummaryString();
                output += "\n" + eval.toClassDetailsString();
                System.out.println("Cross Validated Evaluation");
                output += "\n====" + "Cross Validated Evaluation" + "====\n";
                for (int n = 0; n < folds; n++) {
                    Instances train = randData.trainCV(folds, n);
                    Instances test = randData.testCV(folds, n);

                    for (int i = 0; i < train.numInstances(); i++) {
                        cls.updateClassifier(train.instance(i));
                    }

                    eval.evaluateModel(cls, test);
                    System.out.println("Cross Validated Evaluation fold: " + n);
                    output += "\n====" + "Cross Validated Evaluation fold (" + n + ")====\n";
                    System.out.println(eval.toSummaryString());
                    System.out.println(eval.toClassDetailsString());
                    output += "\n" + eval.toSummaryString();
                    output += "\n" + eval.toClassDetailsString();
                    metrics.add(new Double[] { eval.fMeasure(0), eval.fMeasure(1), eval.weightedFMeasure() });
                }
                for (int i = 0; i < insAll.numInstances(); i++) {
                    cls.updateClassifier(insAll.instance(i));
                }
                eval.evaluateModel(cls, insAll);
                System.out.println("Final Evaluation");
                System.out.println(eval.toSummaryString());
                System.out.println(eval.toClassDetailsString());
                output += "\n====" + "Final Evaluation" + "====\n";
                output += "\n" + eval.toSummaryString();
                output += "\n" + eval.toClassDetailsString();
                metrics.add(new Double[] { eval.fMeasure(0), eval.fMeasure(1), eval.weightedFMeasure() });
                fcls.setClassifier(cls);
                String modelFilePath = outputDir + "/" + Utils.getOutDir(Utils.OutDirIndex.MODELS)
                        + "/updatedClassifier.model";
                weka.core.SerializationHelper.write(modelFilePath, fcls);
                output += "\n" + "Updated Model saved at: " + modelFilePath;
            } else {
                output += "No new instances for training the model.";
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return output;
    }

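    /**
     * Loads the serialized classifier for the given model type and sets the
     * class attribute index.
     *
     * @return 0 on success, 1 for an unknown model type, 2 if loading fails
     */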
    public int setClassifier(MODELTYPE classifierType) {
        String modelFile = "";
        cls = null;
        try {
            if (classifierType == MODELTYPE.SENTIMENT) {
                modelFile = SentimentModelFile;
                clsIndex = CLASSINDEX - 1;

                cls = (J48) weka.core.SerializationHelper.read(modelFile);

                //filePrefix = FILEPREFIX;
            } else if (classifierType == MODELTYPE.SENTIMENT_WORD) {
                modelFile = SentimentModelFile_word;
                clsIndex = CLASSINDEX - 1;
                cls = (FilteredClassifier) weka.core.SerializationHelper.read(modelFile);
                //filePrefix = FILEPREFIX;
            } else if (classifierType == MODELTYPE.CUSTOM) {
                modelFile = getCustomModelFile();
                //this.classifierType = MODELTYPE.SENTIMENT_WORD; 
                clsIndex = CLASSINDEX - 1;
                cls = (FilteredClassifier) weka.core.SerializationHelper.read(modelFile);
            } else {
                System.out.println("Wrong Classifier type");
                return 1;
            }
        } catch (Exception e) {
            e.printStackTrace();
            return 2;
        }
        return 0;
    }

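    /**
     * Loads the tab-separated input file into Weka instances, removes the
     * attributes not used by the selected model, and sets the class attribute
     * on the filtered (unlabeled) copy.
     *
     * @return 0 on success, a non-zero error code if loading or filtering fails
     */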
    public int setInstances(String inputFile) {
        //String[] nominalVals =  {"42:positive,neutral,negative"};
        String[] nominalVals = { CLASSINDEX + ":" + StringUtils.join(classNames, ",") };
        original = null;
        try {
            System.out.println("[In Prediction] Loading instances. ");
            CSVLoader csvSource = new CSVLoader();
            csvSource.setSource(new File(inputFile));
            csvSource.setFieldSeparator("\t");
            csvSource.setNominalAttributes(CLASSINDEX + "");
            csvSource.setStringAttributes(stringAttr);
            csvSource.setNominalLabelSpecs(nominalVals);
            original = csvSource.getDataSet();
            unlabled = original;
            classProbIndex = original.numAttributes() - 1;
            //System.out.println(unlabled.toSummaryString());
            Remove r = new Remove();
            //r.setAttributeIndices("3-4,6,10-12,14");
            if (classifierType == MODELTYPE.SENTIMENT || classifierType == MODELTYPE.SENTIMENT_WORD
                    || classifierType == MODELTYPE.CUSTOM) {
                if (showProbability) {
                    /*
                    Add afilter;
                    afilter = new Add();
                    afilter.setAttributeName("last");
                    afilter.setAttributeName("prediction_prob");
                    afilter.setInputFormat(original);
                    original = Filter.useFilter(original, afilter);
                    classProbIndex = original.numAttributes()-1;*/
                }
                if (classifierType == MODELTYPE.SENTIMENT) {
                    r.setAttributeIndices("3,4,6,8,10-12,14,42,43,45-last");
                    System.out.println("Filtering instances for SENTIMENT");
                } else if (classifierType == MODELTYPE.SENTIMENT_WORD || classifierType == MODELTYPE.CUSTOM) {
                    r.setAttributeIndices(removeAttr);
                    System.out.println("Filtering instances for SENTIMENT WORD");
                }
            }
            //r.setAttributeIndices("3-4,6,10-12,14,40-41,43-last");
            r.setInputFormat(unlabled);
            unlabled = Filter.useFilter(unlabled, r);
            //System.out.println(unlabled.toSummaryString());
            r = new Remove();
            //System.out.println(unlabled.toSummaryString());

        } catch (FileNotFoundException e) {
            e.printStackTrace();
            return 1;
        } catch (IOException e) {
            e.printStackTrace();
            return 2;
        } catch (Exception e) {
            e.printStackTrace();
            return 3;
        }
        int cIdx = unlabled.numAttributes() - 1;
        unlabled.setClassIndex(cIdx);
        System.out.println(
                "Class Attribute is: " + unlabled.classAttribute() + " at index: " + unlabled.classIndex());
        return 0;
    }

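    /**
     * Writes the given instances both as an ARFF file and as a tab-separated
     * file under the output directory, using the supplied file prefix.
     *
     * @return 0 on success, 1 on an I/O error
     */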
    public int writePredictions(Instances ins, String filePrefix) {
        try {
            System.out.println("Trying to create the following files:");
            System.out.println(outputDir + "/" + filePrefix + ".arff");
            System.out.println(outputDir + "/" + filePrefix + ".tsv");
            BufferedWriter writer = new BufferedWriter(new FileWriter(outputDir + "/" + filePrefix + ".arff"));
            writer.write(ins.toString());
            writer.newLine();
            writer.flush();
            writer.close();
            CSVSaver s = new CSVSaver();

            s.setFile(new File(outputDir + "/" + filePrefix + ".tsv"));
            s.setInstances(ins);
            s.setFieldSeparator("\t");
            s.writeBatch();

        } catch (IOException e) {
            e.printStackTrace();
            return 1;
        }
        return 0;
    }

}
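
Usage

Below is a minimal sketch of how this class might be driven from your own code. The input file, output directory, file prefix, and custom model path are illustrative assumptions; the example also assumes the custom model is a serialized FilteredClassifier wrapping an SGD learner, as the word model above is.

import java.util.ArrayList;

import sentinets.Prediction;
import sentinets.Prediction.MODELTYPE;

public class PredictionExample {

    public static void main(String[] args) {
        // Hypothetical paths; adjust to your own project layout.
        String inputFile = "./output/MyTweets_POS.tsv";
        String outputDir = "./output";

        // Predict with a custom model file (assumed to be a FilteredClassifier).
        Prediction p = new Prediction(inputFile, outputDir, MODELTYPE.CUSTOM, "MyTweets");
        p.setCustomModelFile("./data/models/my_classifier.model"); // assumed model path
        int status = p.doPredictions(false); // false = actually run the classifier
        System.out.println("Prediction finished with status: " + status);

        // Optionally update the underlying model with newly labeled data
        // (only meaningful if the loaded model wraps an updateable SGD classifier).
        ArrayList<Double[]> metrics = new ArrayList<Double[]>();
        String report = p.updateModel("./output/MyLabeled_POS.tsv", metrics);
        System.out.println(report);
    }
}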