com.ivanrf.smsspam.SpamClassifier.java — source code

Java tutorial

Introduction

Below is the complete source code for com.ivanrf.smsspam.SpamClassifier.java, an SMS spam classifier built on the Weka machine-learning library.

Source

/*
 * Copyright (C) 2013 Ivan Ridao Freitas
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.ivanrf.smsspam;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.List;
import java.util.Map;
import java.util.Random;

import javax.swing.JTextArea;

import com.ivanrf.utils.Utils;

import weka.attributeSelection.InfoGainAttributeEval;
import weka.attributeSelection.Ranker;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.functions.SMO;
import weka.classifiers.lazy.IBk;
import weka.classifiers.meta.AdaBoostM1;
import weka.classifiers.meta.FilteredClassifier;
import weka.classifiers.rules.PART;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instances;
import weka.core.converters.ArffLoader.ArffReader;
import weka.core.tokenizers.WordTokenizer;
import weka.filters.Filter;
import weka.filters.MultiFilter;
import weka.filters.supervised.attribute.AttributeSelection;
import weka.filters.unsupervised.attribute.StringToWordVector;

/**
 * SMS spam classifier built on Weka: trains, evaluates (10-fold cross-validation) and
 * applies text classifiers over the "SMSSpamCollection.arff" dataset.
 *
 * All methods are static; progress is reported by appending lines to a Swing
 * {@link JTextArea} supplied by the caller. Trained models are serialized to disk
 * under a name derived from the chosen options (see {@link #getModelName}).
 */
public class SpamClassifier {

    public static final String TOKENIZER_DEFAULT = "Default";
    public static final String TOKENIZER_COMPLETE = "Complete";
    public static final String TOKENIZER_COMPLETE_NUMBERS = "Complete with numbers";

    public static final String CLASSIFIER_SMO = "SMO";
    public static final String CLASSIFIER_NB = "NaiveBayes";
    public static final String CLASSIFIER_IB1 = "IB1";
    public static final String CLASSIFIER_IB3 = "IB3";
    public static final String CLASSIFIER_IB5 = "IB5";
    public static final String CLASSIFIER_PART = "PART";

    /** Utility class: all members are static, no instances are meant to exist. */
    private SpamClassifier() {
    }

    /**
     * Builds a {@link FilteredClassifier} that turns the SMS text attribute into a word
     * vector ({@link StringToWordVector}) before feeding it to the selected base classifier.
     *
     * @param wordsToKeep           number of words the vectorizer keeps (global, not per class)
     * @param tokenizerOp           one of the TOKENIZER_* constants
     * @param useAttributeSelection if true, chains an InfoGain/Ranker attribute-selection
     *                              filter after the vectorizer
     * @param classifierOp          one of the CLASSIFIER_* constants
     * @param boosting              if true, wraps the base classifier in AdaBoostM1
     * @return the configured, untrained classifier
     * @throws Exception if a Weka component rejects its configuration
     */
    private static FilteredClassifier initFilterClassifier(int wordsToKeep, String tokenizerOp,
            boolean useAttributeSelection, String classifierOp, boolean boosting) throws Exception {
        StringToWordVector filter = new StringToWordVector();
        filter.setDoNotOperateOnPerClassBasis(true);
        filter.setLowerCaseTokens(true);
        filter.setWordsToKeep(wordsToKeep);

        if (!tokenizerOp.equals(TOKENIZER_DEFAULT)) {
            //Use a custom tokenizer with an extended delimiter set
            WordTokenizer wt = new WordTokenizer();
            if (tokenizerOp.equals(TOKENIZER_COMPLETE))
                wt.setDelimiters(" \r\n\t.,;:\'\"()?!-+*&#$%/=<>[]_`@\\^{}");
            else //TOKENIZER_COMPLETE_NUMBERS: additionally splits on digits, '|' and '~'
                wt.setDelimiters(" \r\n\t.,;:\'\"()?!-+*&#$%/=<>[]_`@\\^{}|~0123456789");
            filter.setTokenizer(wt);
        }

        FilteredClassifier classifier = new FilteredClassifier();
        classifier.setFilter(filter);

        if (useAttributeSelection) {
            AttributeSelection as = new AttributeSelection();
            as.setEvaluator(new InfoGainAttributeEval());
            Ranker r = new Ranker();
            r.setThreshold(0); //keep only attributes with positive information gain
            as.setSearch(r);

            //Chain vectorizer + attribute selection so both are applied inside the classifier
            MultiFilter mf = new MultiFilter();
            mf.setFilters(new Filter[] { filter, as });

            classifier.setFilter(mf);
        }

        if (classifierOp.equals(CLASSIFIER_SMO))
            classifier.setClassifier(new SMO());
        else if (classifierOp.equals(CLASSIFIER_NB))
            classifier.setClassifier(new NaiveBayes());
        else if (classifierOp.equals(CLASSIFIER_IB1))
            classifier.setClassifier(new IBk(1));
        else if (classifierOp.equals(CLASSIFIER_IB3))
            classifier.setClassifier(new IBk(3));
        else if (classifierOp.equals(CLASSIFIER_IB5))
            classifier.setClassifier(new IBk(5));
        else if (classifierOp.equals(CLASSIFIER_PART))
            classifier.setClassifier(new PART()); //Takes a long time

        if (boosting) {
            AdaBoostM1 boost = new AdaBoostM1();
            boost.setClassifier(classifier.getClassifier());
            classifier.setClassifier(boost); //With NaiveBayes this takes a long time
        }

        return classifier;
    }

    /**
     * Trains a classifier on the full dataset and serializes the resulting model to disk.
     * Errors are reported to {@code log}; this method never throws.
     *
     * @param wordsToKeep           vectorizer vocabulary size
     * @param tokenizerOp           one of the TOKENIZER_* constants
     * @param useAttributeSelection whether to add InfoGain attribute selection
     * @param classifierOp          one of the CLASSIFIER_* constants
     * @param boosting              whether to wrap the classifier in AdaBoostM1
     * @param log                   sink for progress messages
     */
    public static void train(int wordsToKeep, String tokenizerOp, boolean useAttributeSelection,
            String classifierOp, boolean boosting, JTextArea log) {
        try {
            long start = System.currentTimeMillis();

            String modelName = getModelName(wordsToKeep, tokenizerOp, useAttributeSelection, classifierOp,
                    boosting);
            showEstimatedTime(true, modelName, log);

            Instances trainData = loadDataset("SMSSpamCollection.arff", log);
            trainData.setClassIndex(0); //first attribute is the ham/spam class

            FilteredClassifier classifier = initFilterClassifier(wordsToKeep, tokenizerOp, useAttributeSelection,
                    classifierOp, boosting);

            publishEstado("=== Building the classifier on the filtered data ===", log);
            classifier.buildClassifier(trainData);

            publishEstado(classifier.toString(), log);
            publishEstado("=== Training done ===", log);

            saveModel(classifier, modelName, log);

            publishEstado("Elapsed time: " + Utils.getDateHsMinSegString(System.currentTimeMillis() - start), log);
        } catch (Exception e) {
            e.printStackTrace();
            publishEstado("Error found when training", log);
        }
    }

    /**
     * Evaluates a classifier configuration with 10-fold cross-validation and reports the
     * summary, per-class details and confusion matrix to {@code log}. Never throws.
     *
     * @param wordsToKeep           vectorizer vocabulary size
     * @param tokenizerOp           one of the TOKENIZER_* constants
     * @param useAttributeSelection whether to add InfoGain attribute selection
     * @param classifierOp          one of the CLASSIFIER_* constants
     * @param boosting              whether to wrap the classifier in AdaBoostM1
     * @param log                   sink for progress messages
     */
    public static void evaluate(int wordsToKeep, String tokenizerOp, boolean useAttributeSelection,
            String classifierOp, boolean boosting, JTextArea log) {
        try {
            long start = System.currentTimeMillis();

            String modelName = getModelName(wordsToKeep, tokenizerOp, useAttributeSelection, classifierOp,
                    boosting);
            showEstimatedTime(false, modelName, log);

            Instances trainData = loadDataset("SMSSpamCollection.arff", log);
            trainData.setClassIndex(0); //first attribute is the ham/spam class
            FilteredClassifier classifier = initFilterClassifier(wordsToKeep, tokenizerOp, useAttributeSelection,
                    classifierOp, boosting);

            publishEstado("=== Performing cross-validation ===", log);
            Evaluation eval = new Evaluation(trainData);
            //Fixed seed so the fold split (and therefore the results) is reproducible
            eval.crossValidateModel(classifier, trainData, 10, new Random(1));

            publishEstado(eval.toSummaryString(), log);
            publishEstado(eval.toClassDetailsString(), log);
            publishEstado(eval.toMatrixString(), log);
            publishEstado("=== Evaluation finished ===", log);

            publishEstado("Elapsed time: " + Utils.getDateHsMinSegString(System.currentTimeMillis() - start), log);
        } catch (Exception e) {
            e.printStackTrace();
            publishEstado("Error found when evaluating", log);
        }
    }

    /**
     * Classifies a single SMS text with a previously saved model.
     *
     * @param model file name of a serialized model created by {@link #train}
     * @param text  the SMS message to classify
     * @param log   sink for progress messages
     * @return the predicted class label ("ham" or "spam"), or null if the model could not
     *         be loaded or classification failed
     */
    public static String classify(String model, String text, JTextArea log) {
        FilteredClassifier classifier = loadModel(model, log);
        if (classifier == null)
            return null; //loadModel already reported the failure

        //Create a one-instance dataset with the same structure as the training data:
        //a nominal ham/spam class attribute followed by a string text attribute
        ArrayList<String> fvNominalVal = new ArrayList<String>();
        fvNominalVal.add("ham");
        fvNominalVal.add("spam");

        Attribute attribute1 = new Attribute("spam_class", fvNominalVal);
        Attribute attribute2 = new Attribute("text", (List<String>) null); //null list = string attribute
        ArrayList<Attribute> fvWekaAttributes = new ArrayList<Attribute>();
        fvWekaAttributes.add(attribute1);
        fvWekaAttributes.add(attribute2);

        Instances instances = new Instances("Test relation", fvWekaAttributes, 1);
        instances.setClassIndex(0);

        DenseInstance instance = new DenseInstance(2);
        instance.setValue(attribute2, text); //class value is left missing on purpose
        instances.add(instance);

        publishEstado("=== Instance created ===", log);
        publishEstado(instances.toString(), log);

        //Classify the instance
        try {
            publishEstado("=== Classifying instance ===", log);

            double pred = classifier.classifyInstance(instances.instance(0));

            publishEstado("=== Instance classified  ===", log);

            String classPredicted = instances.classAttribute().value((int) pred);
            publishEstado("Class predicted: " + classPredicted, log);

            return classPredicted;
        } catch (Exception e) {
            e.printStackTrace();
            publishEstado("Error found when classifying the text", log);
            return null;
        }
    }

    /**
     * Loads an ARFF dataset from disk.
     *
     * @param fileName path of the ARFF file
     * @param log      sink for progress messages
     * @return the loaded instances, or null if the file could not be read
     */
    private static Instances loadDataset(String fileName, JTextArea log) {
        publishEstado("=== Loading dataset: " + fileName + " ===", log);
        //try-with-resources guarantees the reader is closed even if parsing fails
        try (BufferedReader reader = new BufferedReader(new FileReader(fileName))) {
            ArffReader arff = new ArffReader(reader);
            Instances trainData = arff.getData();
            publishEstado("=== Dataset loaded ===", log);
            return trainData;
        } catch (IOException e) {
            publishEstado("Error found when reading: " + fileName, log);
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Derives the model file name from the configuration, e.g.
     * "SMO_W1000000_TK-Complete_AttSelection_Boosting.dat". Used both to save models and
     * as the key (without ".dat") into the estimated-time tables.
     */
    private static String getModelName(int wordsToKeep, String tokenizerOp, boolean useAttributeSelection,
            String classifierOp, boolean boosting) {
        String tk = "TK-";
        if (tokenizerOp.equals(TOKENIZER_DEFAULT))
            tk += "Default";
        else if (tokenizerOp.equals(TOKENIZER_COMPLETE))
            tk += "Complete";
        else //TOKENIZER_COMPLETE_NUMBERS
            tk += "Numbers";

        String modelName = classifierOp + "_W" + wordsToKeep + "_" + tk;

        if (useAttributeSelection)
            modelName += "_AttSelection";
        if (boosting)
            modelName += "_Boosting";

        return modelName + ".dat";
    }

    /**
     * Serializes a trained classifier to {@code fileName}. Failures are reported to
     * {@code log} and printed to stderr; this method never throws.
     */
    private static void saveModel(FilteredClassifier classifier, String fileName, JTextArea log) {
        publishEstado("=== Saving model: " + fileName + " ===", log);
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(fileName))) {
            out.writeObject(classifier);
            publishEstado("=== Model saved ===", log);
        } catch (IOException e) {
            publishEstado("Error found when writing: " + fileName, log);
            e.printStackTrace();
        }
    }

    /**
     * Deserializes a classifier previously saved with {@link #saveModel}.
     *
     * @return the loaded classifier, or null if the file could not be read or did not
     *         contain a {@link FilteredClassifier}
     */
    private static FilteredClassifier loadModel(String fileName, JTextArea log) {
        publishEstado("=== Loading model: " + fileName + " ===", log);
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(fileName))) {
            FilteredClassifier classifier = (FilteredClassifier) in.readObject();
            publishEstado("=== Model loaded ===", log);
            return classifier;
        } catch (Exception e) {
            publishEstado("Error found when reading: " + fileName, log);
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Appends a status line to the log area.
     * NOTE(review): this may be called from a worker thread; JTextArea.append is not
     * guaranteed thread-safe on modern JDKs — consider SwingUtilities.invokeLater.
     */
    private static void publishEstado(String estado, JTextArea log) {
        log.append(estado + "\n");
    }

    //************************************* ESTIMATED TIMES ****************************************//

    /**
     * Publishes the previously measured running time for this configuration, if known.
     *
     * @param train true for training time, false for cross-validation time
     */
    private static void showEstimatedTime(boolean train, String modelName, JTextArea log) {
        modelName = modelName.replace(".dat", ""); //table keys omit the file extension
        String estimatedTime = (train) ? TRAINING_TIME.get(modelName) : EVALUATION_TIME.get(modelName);
        if (estimatedTime != null)
            publishEstado("Estimated time: " + estimatedTime, log);
    }

    /** Measured training times per model name (without ".dat"). Read-only after init. */
    private static final Map<String, String> TRAINING_TIME = initTrainingTime();
    /** Measured 10-fold cross-validation times per model name (without ".dat"). */
    private static final Map<String, String> EVALUATION_TIME = initEvaluationTime();

    private static Map<String, String> initTrainingTime() {
        Map<String, String> tt = new HashMap<String, String>();
        tt.put("SMO_W1000000_TK-Numbers", "23s");
        tt.put("SMO_W1000000_TK-Numbers_Boosting", "3m25s");
        tt.put("SMO_W1000000_TK-Complete_Boosting", "29s");
        tt.put("SMO_W1000000_TK-Complete", "33s");
        tt.put("SMO_W1000000_TK-Default", "27s");
        tt.put("SMO_W1000000_TK-Complete_AttSelection", "1m9s");
        tt.put("NaiveBayes_W1000000_TK-Numbers_Boosting", "39m47s (YOU CAN CANCEL)");
        tt.put("NaiveBayes_W1000000_TK-Complete_Boosting", "44m42s (YOU CAN CANCEL)");
        tt.put("SMO_W1000_TK-Complete", "11s");
        tt.put("SMO_W1000_TK-Default", "12s");
        tt.put("NaiveBayes_W1000000_TK-Numbers", "37s");
        tt.put("NaiveBayes_W1000000_TK-Complete", "41s");
        tt.put("NaiveBayes_W1000000_TK-Complete_AttSelection", "1m18s");
        tt.put("NaiveBayes_W1000_TK-Complete", "5s");
        tt.put("NaiveBayes_W1000000_TK-Default", "42s");
        tt.put("NaiveBayes_W1000_TK-Default", "5s");
        tt.put("PART_W1000000_TK-Default", "31m42s (YOU CAN CANCEL)");
        tt.put("PART_W1000_TK-Default", "3m39s");
        tt.put("PART_W1000_TK-Complete", "3m47s");
        tt.put("PART_W1000000_TK-Complete", "26m36s (YOU CAN CANCEL)");
        tt.put("IB1_W1000000_TK-Numbers", "4s");
        tt.put("IB1_W1000000_TK-Numbers_Boosting", "45s");
        tt.put("IB1_W1000_TK-Complete", "1s");
        tt.put("IB1_W1000_TK-Default", "1s");
        tt.put("IB1_W1000000_TK-Complete", "4s");
        tt.put("IB1_W1000000_TK-Complete_Boosting", "47s");
        tt.put("IB1_W1000000_TK-Complete_AttSelection", "1m10s");
        tt.put("IB1_W1000000_TK-Default", "4s");
        tt.put("IB3_W1000_TK-Complete", "1s");
        tt.put("IB3_W1000_TK-Default", "1s");
        tt.put("IB3_W1000000_TK-Complete", "4s");
        tt.put("IB3_W1000000_TK-Default", "5s");
        tt.put("IB5_W1000_TK-Complete", "2s");
        tt.put("IB5_W1000_TK-Default", "1s");
        tt.put("IB5_W1000000_TK-Complete", "4s");
        tt.put("IB5_W1000000_TK-Default", "5s");
        return tt;
    }

    private static Map<String, String> initEvaluationTime() {
        Map<String, String> et = new HashMap<String, String>();
        et.put("SMO_W1000000_TK-Numbers", "2m16s");
        et.put("SMO_W1000000_TK-Numbers_Boosting", "18m24s");
        et.put("SMO_W1000000_TK-Complete_Boosting", "5m55s");
        et.put("SMO_W1000000_TK-Complete", "2m27s");
        et.put("SMO_W1000000_TK-Default", "2m38s");
        et.put("SMO_W1000000_TK-Complete_AttSelection", "8m53s");
        et.put("NaiveBayes_W1000000_TK-Numbers_Boosting", "5h42m6s (YOU CAN CANCEL)");
        et.put("NaiveBayes_W1000000_TK-Complete_Boosting", "8h1m (YOU CAN CANCEL)");
        et.put("SMO_W1000_TK-Complete", "1m34s");
        et.put("SMO_W1000_TK-Default", "1m31s");
        et.put("NaiveBayes_W1000000_TK-Numbers", "6m22s");
        et.put("NaiveBayes_W1000000_TK-Complete", "7m8s");
        et.put("NaiveBayes_W1000000_TK-Complete_AttSelection", "10m24s");
        et.put("NaiveBayes_W1000_TK-Complete", "58s");
        et.put("NaiveBayes_W1000000_TK-Default", "7m31s");
        et.put("NaiveBayes_W1000_TK-Default", "59s");
        et.put("PART_W1000000_TK-Default", "3h52m48s (YOU CAN CANCEL)");
        et.put("PART_W1000_TK-Default", "28m1s (YOU CAN CANCEL)");
        et.put("PART_W1000_TK-Complete", "33m22s (YOU CAN CANCEL)");
        et.put("PART_W1000000_TK-Complete", "3h35m51s (YOU CAN CANCEL)");
        et.put("IB1_W1000000_TK-Numbers", "1m38s");
        et.put("IB1_W1000000_TK-Numbers_Boosting", "6m29s");
        et.put("IB1_W1000_TK-Complete", "55s");
        et.put("IB1_W1000_TK-Default", "44s");
        et.put("IB1_W1000000_TK-Complete", "1m38s");
        et.put("IB1_W1000000_TK-Complete_Boosting", "6m52s");
        et.put("IB1_W1000000_TK-Complete_AttSelection", "9m8s");
        et.put("IB1_W1000000_TK-Default", "1m36s");
        et.put("IB3_W1000_TK-Complete", "57s");
        et.put("IB3_W1000_TK-Default", "50s");
        et.put("IB3_W1000000_TK-Complete", "1m44s");
        et.put("IB3_W1000000_TK-Default", "1m46s");
        et.put("IB5_W1000_TK-Complete", "59s");
        et.put("IB5_W1000_TK-Default", "51s");
        et.put("IB5_W1000000_TK-Complete", "1m46s");
        et.put("IB5_W1000000_TK-Default", "1m55s");
        return et;
    }
}