mao.datamining.ModelProcess.java Source code

Introduction

Here is the source code for mao.datamining.ModelProcess.java. The class loads a training/test ARFF pair through DataSetPair, instantiates a Weka classifier from a class name and an options file (wrapping it in a CostSensitiveClassifier when the training file name contains the DataSetPair.resampleMatrix marker), evaluates it on the held-out test set, and appends the result to a shared report file guarded by a simple lock file.

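The program expects the training ARFF, the test ARFF, a Weka classifier class name, an options file, a report file, and a test case number, with an optional trailing "dummy" argument for a dry run that only logs the configuration. The options file is read one line at a time, each line becoming one element of the option array handed to weka.core.Utils.forName. As an illustration only (these values are not part of the project), an options file for weka.classifiers.trees.J48 might contain:

-C
0.25
-M
2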
Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package mao.datamining;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;
import weka.core.Utils;

/**
 * Usage:
 * java -Xmx10g -Djava.util.logging.config.file=c:/DS1/logging.properties -jar
 * datamining.jar c:/DS1/processed/train.arff c:/DS1/processed/test.arff
 * classifierClass optionsFile reportFile caseNum [dummy]
 *
 * @author mao
 */
public class ModelProcess {

    //One logger
    public static final Logger ProcessLogger = Logger.getLogger(ModelProcess.class.getName());
    //How many folds for cross validation
    static int folds = 10;
    private String caseNum;
    private boolean useCostMatrix = false;

    public static void main(String[] args) {
        String trainDSPath = null;
        String testDSPath = null;
        String classifier = null;
        String optionsFilePath = null;
        String reportFilePath = null;
        String testCaseNum = null;
        boolean runDummy = false;

        //        String options[] = null;
        //weka.classifiers.meta.Bagging -P 100 -S 1 -I 10 -W weka.classifiers.lazy.IBk -- 
        //    -K 1 -W 0 -A "weka.core.neighboursearch.LinearNNSearch -A \"weka.core.EuclideanDistance -R first-last\"" 
        //weka.classifiers.meta.AdaBoostM1 -P 100 -S 1 -I 10 -W weka.classifiers.lazy.IBk -- -K 1 -W 0 -A "weka.core.neighboursearch.LinearNNSearch -A \"weka.core.EuclideanDistance -R first-last\""
        ArrayList<String> optionList = new ArrayList<>();

        for (int i = 0; i < args.length; i++) {
            if (args[i].endsWith(".arff")) {

                trainDSPath = args[i];
                testDSPath = args[++i];
                classifier = args[++i];
                optionsFilePath = args[++i];
                reportFilePath = args[++i];
                testCaseNum = args[++i];
                // the trailing "dummy" flag is optional; guard against reading past the end of args
                if (++i < args.length && "dummy".equals(args[i])) {
                    runDummy = true;
                }

                break;
            }
        }

        logging("trainDSPath file path: " + trainDSPath);
        logging("testDSPath file path: " + testDSPath);
        logging("classifier: " + classifier);
        logging("optionsFilePath file path: " + optionsFilePath);
        logging("report file path: " + reportFilePath);
        logging("Testcase Num: " + testCaseNum);
        logging("Dummy Run: " + runDummy);

        ModelProcess m = new ModelProcess();
        m.setCaseNum(testCaseNum);
        if (trainDSPath.contains(DataSetPair.resampleMatrix)) {
            m.useCostMatrix = true;
        }
        logging("Use Cost Matrix: " + m.useCostMatrix);

        String testCaseDetailFile = optionsFilePath + ".summary.txt";

        // read the classifier options, one Weka option token per line
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(optionsFilePath)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                optionList.add(line);
            }
        } catch (IOException ex) {
            Logger.getLogger(ModelProcess.class.getName()).log(Level.SEVERE, null, ex);
        }

        ModelProcess.logging(trainDSPath + "," + testDSPath + "," + classifier);

        DataSetPair ds = new DataSetPair(trainDSPath, testDSPath);
        ds.loadBothDataSets();
        String[] optionArray = optionList.toArray(new String[0]);
        //run test case
        m.runTestCase(ds, classifier, reportFilePath, optionArray, testCaseNum, runDummy, testCaseDetailFile);
    }

    /**
     * Builds the requested classifier (optionally wrapped in a cost-sensitive
     * meta-classifier), evaluates it, and appends the result to the shared report file.
     *
     * @param ds training/test data set pair, already loaded
     * @param classifierClazz fully qualified Weka classifier class name
     * @param reportFilePath report file shared by all test cases
     * @param options classifier options, one token per element
     * @param caseNum test case identifier
     * @param runDummy if true, only log the configuration
     * @param testCaseDetailFile path of the per-case summary file
     */
    private void runTestCase(DataSetPair ds, String classifierClazz, String reportFilePath, String[] options,
            String caseNum, boolean runDummy, String testCaseDetailFile) {

        String optionsCopy[] = new String[options.length];
        System.arraycopy(options, 0, optionsCopy, 0, options.length);
        //write the summary
        FileOutputStream testCaseSummaryOut = null;
        try {
            testCaseSummaryOut = new FileOutputStream(testCaseDetailFile);
            testCaseSummaryOut.write(("Train File: " + ds.getTrainFileName()).getBytes());
            testCaseSummaryOut.write("\n\n".getBytes());
        } catch (Exception ex) {
            Logger.getLogger(ModelProcess.class.getName()).log(Level.SEVERE, null, ex);
        }

        //Prepare the TestResult
        TestResult result = new TestResult();
        result.setCaseNum(caseNum);
        result.setMissingValueMode(ds.getMissingProcessMode());
        result.setResampleMode(ds.getResampleMethod());
        result.setFeatureSelectMode(ds.getFeatureSelectionMode());
        result.setClassifier(classifierClazz);
        StringBuilder optionsStr = new StringBuilder();
        for (String s : optionsCopy) {
            if (s == null) {
                break;
            }
            optionsStr.append(s).append(" ");
        }
        result.setClassifierOptions(optionsStr.toString());

        //get the training and test data sets
        Instances finalTrainDataSet = ds.getFinalTrainDataSet();
        Instances finalTestDataSet = ds.getFinalTestDataSet();
        //start building and testing
        ModelProcess.logging(
                "\n\n\n=========================================START=================================================\n"
                        + "======= " + classifierClazz + "," + optionsStr.toString() + "========\n"
                        + "=========================================START=================================================");
        if (runDummy) {
            ModelProcess.logging("Dummy Run the process");
        } else {

            Classifier classifier = null;
            String tmpClazz = null;
            String tmpOption[] = null;
            //if a cost matrix is in use, wrap the base classifier in a CostSensitiveClassifier, e.g.:
            //weka.classifiers.meta.CostSensitiveClassifier -cost-matrix "[0.0 2.0; 5.0 0.0]" -S 1 -W weka.classifiers.trees.J48 -- -C 0.25 -M 2
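            //in Weka's CostMatrix the rows are the actual class and the columns the predicted class,
            //so [0.0 3.0; 30.0 0.0] makes a missed positive (class 1 predicted as class 0)
            //ten times more costly than a false alarm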
            if (this.useCostMatrix) {
                List<String> costOptionsList = new ArrayList<>();
                costOptionsList.add("-cost-matrix");
                costOptionsList.add("\"[0.0 3.0; 30.0 0.0]\"");
                costOptionsList.add("-S");
                costOptionsList.add("1");
                costOptionsList.add("-W");
                costOptionsList.add(classifierClazz);
                costOptionsList.add("--");
                costOptionsList.addAll(Arrays.asList(optionsCopy));
                String[] newOptions = costOptionsList.toArray(new String[0]);
                tmpClazz = "weka.classifiers.meta.CostSensitiveClassifier";
                tmpOption = newOptions;
            } else {
                tmpClazz = classifierClazz;
                tmpOption = optionsCopy;
            } //end cost-sensitive wrapping

            try {

                try {
                    //record the data set summary and classifier configuration in the summary file
                    testCaseSummaryOut.write(
                            ("Data Set Summary: " + finalTrainDataSet.toSummaryString() + "\n\n").getBytes());
                    testCaseSummaryOut.write(("classifier: " + tmpClazz + "\n").getBytes());
                    testCaseSummaryOut.write(("options: " + Arrays.toString(tmpOption) + "\n\n").getBytes());
                } catch (FileNotFoundException ex) {
                    Logger.getLogger(ModelProcess.class.getName()).log(Level.SEVERE, null, ex);
                }

                classifier = (Classifier) Utils.forName(Classifier.class, tmpClazz, tmpOption);
            } catch (Exception ex) {
                Logger.getLogger(ModelProcess.class.getName()).log(Level.SEVERE, null, ex);
            } //end create the classifier

            //            this.testCV(classifier, finalTrainDataSet, testCaseSummaryOut, result);
            this.testWithExtraDS(classifier, finalTrainDataSet, finalTestDataSet, testCaseSummaryOut, result);

            try {
                testCaseSummaryOut.close();
            } catch (IOException ex) {
                Logger.getLogger(ModelProcess.class.getName()).log(Level.SEVERE, null, ex);
            }

            ModelProcess.logging(
                    "========================================END==================================================\n"
                            + result.toString()
                            + "========================================END==================================================");

            this.lockFile(reportFilePath);
            try (FileOutputStream fout = new FileOutputStream(reportFilePath, true)) {
                fout.write(result.toString().getBytes());
                fout.write("\n".getBytes());
                fout.flush();
            } catch (Exception ex) {
                Main.logging(null, ex);
            }
            this.unLockFile(reportFilePath);
        }
    }

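    /**
     * Crude cross-process lock on the shared report file: spin while
     * <reportFilePath>.lock exists, then create it containing this test case number.
     * The exists()/create sequence is not atomic, so two concurrent processes can
     * still race; File.createNewFile() would be a safer primitive here.
     */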
    private boolean lockFile(String reportFilePath) {
        String lockFile = reportFilePath + ".lock";

        while (new File(lockFile).exists()) {
            System.out.println("[" + this.getCaseNum() + "] Waiting 100 milliseconds...");
            try {
                Thread.sleep(100);
            } catch (InterruptedException ex) {
                Logger.getLogger(ModelProcess.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
        try (FileOutputStream fout = new FileOutputStream(lockFile)) {
            fout.write(this.getCaseNum().getBytes());
            fout.flush();
        } catch (Exception ex) {
            Main.logging(null, ex);
        }

        return true;
    }

    private void unLockFile(String reportFilePath) {
        String lockFilePath = reportFilePath + ".lock";
        File lockFile = new File(lockFilePath);
        if (lockFile.exists()) {
            lockFile.delete();
        }
    }

    private static void logging(String msg) {
        ProcessLogger.info(msg);
    }

    private static void logging(String msg, Throwable t) {
        ProcessLogger.log(Level.SEVERE, msg, t);
    }

    private String getCaseNum() {
        return this.caseNum;
    }

    private void setCaseNum(String n) {
        this.caseNum = n;
    }

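    /**
     * Builds the classifier on the final training set and evaluates it once against the
     * held-out test set, recording wall-clock train/test time, the confusion matrix, and
     * the area under the ROC curve, precision and recall for class index 1.
     */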
    private void testWithExtraDS(Classifier classifier, Instances finalTrainDataSet, Instances finalTestDataSet,
            FileOutputStream testCaseSummaryOut, TestResult result) {
        //Use final training dataset and final test dataset
        double confusionMatrix[][] = null;

        long start, end, trainTime = 0, testTime = 0;
        if (finalTestDataSet != null) {
            try {
                //counting training time
                start = System.currentTimeMillis();
                classifier.buildClassifier(finalTrainDataSet);
                end = System.currentTimeMillis();
                trainTime += end - start;

                //counting test time
                start = System.currentTimeMillis();
                Evaluation testEvalOnly = new Evaluation(finalTrainDataSet);
                testEvalOnly.evaluateModel(classifier, finalTestDataSet);
                end = System.currentTimeMillis();
                testTime += end - start;

                testCaseSummaryOut.write("=====================================================\n".getBytes());
                testCaseSummaryOut.write((testEvalOnly.toSummaryString("=== Test Summary ===", true)).getBytes());
                testCaseSummaryOut.write("\n".getBytes());
                testCaseSummaryOut
                        .write((testEvalOnly.toClassDetailsString("=== Test Class Detail ===\n")).getBytes());
                testCaseSummaryOut.write("\n".getBytes());
                testCaseSummaryOut
                        .write((testEvalOnly.toMatrixString("=== Confusion matrix for Test ===\n")).getBytes());
                testCaseSummaryOut.flush();

                confusionMatrix = testEvalOnly.confusionMatrix();
                result.setConfusionMatrix4Test(confusionMatrix);

                result.setAUT(testEvalOnly.areaUnderROC(1));
                result.setPrecision(testEvalOnly.precision(1));
                result.setRecall(testEvalOnly.recall(1));
            } catch (Exception e) {
                ModelProcess.logging(null, e);
            }
            result.setTrainingTime(trainTime);
            result.setTestTime(testTime);
        } //end evaluation with the held-out test data set

    }

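    /**
     * Stratified cross-validation ({@link #folds} folds) on the training set, pooling all
     * folds into a single Evaluation that is written to the summary file. Currently unused;
     * see the commented-out call in runTestCase. The per-fold Evaluation objects are built
     * but their statistics are never reported.
     */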
    private void testCV(Classifier classifier, Instances finalTrainDataSet, FileOutputStream testCaseSummaryOut,
            TestResult result) {
        long start, end, trainTime = 0, testTime = 0;
        Evaluation evalAll = null;
        double confusionMatrix[][] = null;
        // randomize data, and then stratify it into 10 groups
        Random rand = new Random(1);
        Instances randData = new Instances(finalTrainDataSet);
        randData.randomize(rand);
        if (randData.classAttribute().isNominal()) {
            //always run with 10-fold (stratified) cross-validation
            randData.stratify(folds);
        }

        try {
            evalAll = new Evaluation(randData);
            for (int i = 0; i < folds; i++) {
                Evaluation eval = new Evaluation(randData);
                Instances train = randData.trainCV(folds, i);
                Instances test = randData.testCV(folds, i);
                //counting training time
                start = System.currentTimeMillis();
                Classifier classifierCopy = Classifier.makeCopy(classifier);
                classifierCopy.buildClassifier(train);
                end = System.currentTimeMillis();
                trainTime += end - start;

                //counting test time
                start = System.currentTimeMillis();
                eval.evaluateModel(classifierCopy, test);
                evalAll.evaluateModel(classifierCopy, test);
                end = System.currentTimeMillis();
                testTime += end - start;
            }

        } catch (Exception e) {
            ModelProcess.logging(null, e);
        } //end test by cross validation

        // output evaluation
        try {
            ModelProcess.logging("");
            //write into summary file
            testCaseSummaryOut
                    .write((evalAll.toSummaryString("=== Cross Validation Summary ===", true)).getBytes());
            testCaseSummaryOut.write("\n".getBytes());
            testCaseSummaryOut.write(
                    (evalAll.toClassDetailsString("=== " + folds + "-fold Cross-validation Class Detail ===\n"))
                            .getBytes());
            testCaseSummaryOut.write("\n".getBytes());
            testCaseSummaryOut
                    .write((evalAll.toMatrixString("=== Confusion matrix for all folds ===\n")).getBytes());
            testCaseSummaryOut.flush();

            confusionMatrix = evalAll.confusionMatrix();
            result.setConfusionMatrix10Folds(confusionMatrix);
        } catch (Exception e) {
            ModelProcess.logging(null, e);
        }
    }

}
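
For reference, below is a minimal, self-contained sketch of the plain Weka train-and-evaluate flow that ModelProcess wraps. It is not part of the original class: it uses the stock weka.core.converters.ConverterUtils.DataSource loader instead of the project-specific DataSetPair, and the file paths and J48 options are placeholders.

import java.util.Random;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.converters.ConverterUtils.DataSource;

public class EvaluateSketch {

    public static void main(String[] args) throws Exception {
        // Load both ARFF files; assume the last attribute is the class attribute.
        Instances train = DataSource.read("train.arff");
        Instances test = DataSource.read("test.arff");
        train.setClassIndex(train.numAttributes() - 1);
        test.setClassIndex(test.numAttributes() - 1);

        // Instantiate the classifier from a class name and option tokens,
        // the same Utils.forName call ModelProcess uses.
        Classifier cls = (Classifier) Utils.forName(Classifier.class,
                "weka.classifiers.trees.J48", new String[]{"-C", "0.25", "-M", "2"});

        // Train once, then evaluate against the held-out test set.
        cls.buildClassifier(train);
        Evaluation eval = new Evaluation(train);
        eval.evaluateModel(cls, test);
        System.out.println(eval.toSummaryString("=== Test Summary ===", true));
        System.out.println("AUC for class 1:       " + eval.areaUnderROC(1));
        System.out.println("Precision for class 1: " + eval.precision(1));
        System.out.println("Recall for class 1:    " + eval.recall(1));

        // Alternative: stratified 10-fold cross-validation on the training set,
        // the approach taken by the (currently unused) testCV method.
        // crossValidateModel copies and retrains the classifier for each fold.
        Evaluation cv = new Evaluation(train);
        cv.crossValidateModel(cls, train, 10, new Random(1));
        System.out.println(cv.toSummaryString("=== Cross Validation Summary ===", true));
    }
}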