Java tutorial
/*******************************************************************************
 * Copyright (C) 2014 Álvaro Carrera Barroso
 * Grupo de Sistemas Inteligentes - Universidad Politécnica de Madrid
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 *******************************************************************************/
/**
 * es.upm.dit.gsi.barmas.launcher.WekaClassifiersValidator.java
 */
package es.upm.dit.gsi.barmas.launcher;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.SMO;
import weka.classifiers.rules.PART;
import weka.classifiers.trees.J48;
import weka.classifiers.trees.LADTree;
import weka.classifiers.trees.NBTree;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

import com.csvreader.CsvWriter;

import es.upm.dit.gsi.barmas.launcher.logging.LogConfigurator;

/**
 * Project: barmas
 * File: es.upm.dit.gsi.barmas.launcher.WekaClassifiersValidator.java
 *
 * Grupo de Sistemas Inteligentes
 * Departamento de Ingeniería de Sistemas Telemáticos
 * Universidad Politécnica de Madrid (UPM)
 *
 * @author alvarocarrera
 * @email a.carrera@gsi.dit.upm.es
 * @twitter @alvarocarrera
 * @date 17/02/2014
 * @version 0.1
 *
 */
public class WekaClassifiersValidator {

    private Logger logger;
    private String dataset;
    private String inputFolder;
    private String outputFolder;
    private String resultsFilePath;
    private CsvWriter writer;
    private int folds;
    private int columns;
    private Integer[] lebas;

    /**
     * @param args
     */
    public static void main(String[] args) {

        HashMap<String, Integer[]> datasets = new HashMap<String, Integer[]>();
        // Integer[] zoolebas = { 0, 4, 8 };
        // datasets.put("zoo", zoolebas);
        // Integer[] solarflarelebas = { 0, 3, 6 };
        // datasets.put("solarflare", solarflarelebas);
        // Integer[] marketinglebas = { 0, 3, 7 };
        // datasets.put("marketing", marketinglebas);
        // Integer[] mushroomlebas = { 0, 6, 11 };
        // datasets.put("mushroom", mushroomlebas);
        // Integer[] kowlanlebas = { 0, 7, 14 };
        // datasets.put("kowlancz02", kowlanlebas);
        Integer[] pokerlebas = { 0, 3, 5 };
        datasets.put("poker", pokerlebas);
        // Integer[] chesslebas = { 0, 2, 3 };
        // datasets.put("chess", chesslebas);
        // Integer[] nurserylebas = { 0, 2, 5 };
        // datasets.put("nursery", nurserylebas);

        List<Classifier> classifiers = null;
        for (String dataset : datasets.keySet()) {
            String simName = dataset + "-simulation";
            String inputFolder = "../experiments/" + simName + "/input";
            String outputFolder = "../experiments/" + simName + "/weka";
            int folds = 10;
            int maxAgents = 4;
            int minAgents = 2;
            WekaClassifiersValidator validator = new WekaClassifiersValidator(dataset,
                    inputFolder, outputFolder, folds, minAgents, maxAgents,
                    datasets.get(dataset));
            classifiers = validator.validateWekaClassifiers(classifiers);
        }
    }

    /**
     * Constructor
     *
     * @param simulationID
     * @param inputPath
     * @param outputPath
     * @param folds
     * @param minAgents
     * @param maxAgents
     * @param lebas
     */
    public WekaClassifiersValidator(String simulationID, String inputPath, String outputPath,
            int folds, int minAgents, int maxAgents, Integer[] lebas) {

        this.dataset = simulationID;
        this.inputFolder = inputPath;
        this.outputFolder = outputPath;
        this.logger = Logger.getLogger("WekaClassifierValidator-" + this.dataset);
        LogConfigurator.log2File(logger, "WekaClassifierValidator-" + this.dataset, Level.ALL,
                Level.INFO, this.outputFolder);

        logger.info("--> Configuring WekaClassifierValidator...");

        this.resultsFilePath = outputPath + "/weka-results.csv";
        this.columns = 9;
        this.folds = folds;
        this.lebas = lebas;
        // this.minAgents = minAgents;
        // this.maxAgents = maxAgents;
        // this.minLEBA = minLEBA;
        // this.maxLEBA = maxLEBA;

        File dir = new File(this.outputFolder);
        if (!dir.exists() || !dir.isDirectory()) {
            dir.mkdirs();
        }

        try {
            File file = new File(resultsFilePath);
            if (!file.exists()) {
                this.writer = new CsvWriter(new FileWriter(file), ',');
                String[] headers = new String[this.columns];
                headers[0] = "dataset";
                headers[1] = "kfold";
                headers[2] = "classifier";
                headers[3] = "iteration";
                headers[4] = "ratioOk";
                headers[5] = "ratioWrong";
                headers[6] = "agentID";
                headers[7] = "agents";
                headers[8] = "leba";
                writer.writeRecord(headers);
            } else {
                this.writer = new CsvWriter(new FileWriter(file, true), ',');
            }
        } catch (IOException e) {
            logger.severe("Problems creating weka-results.csv file");
            logger.severe(e.getMessage());
        }

        logger.info("<-- WekaClassifierValidator configured");
    }

    /**
     *
     */
    public List<Classifier> validateWekaClassifiers(List<Classifier> classifiers) {
        if (classifiers == null) {
            classifiers = this.getNewClassifiers();
            logger.info("All classifiers are going to be validated, because no list of classifiers has been provided.");
        } else if (classifiers.isEmpty()) {
            logger.warning("No algorithms available!! All of them were eliminated :( Jop");
        } else {
            logger.info("The following algorithms are going to be tested: ");
            for (Classifier classifier : classifiers) {
                logger.info(">> " + classifier.getClass().getSimpleName());
            }
        }

        logger.info(">> Validating all classifiers for dataset: " + this.dataset);
        List<Classifier> eliminateds = new ArrayList<Classifier>();
        for (Classifier classifier : classifiers) {
            try {
                this.validateClassifier(classifier);
            } catch (Exception e) {
                logger.info(">> Eliminating classifier: " + classifier.getClass().getSimpleName());
                eliminateds.add(classifier);
            }
        }
        classifiers.removeAll(eliminateds);

        logger.info(">> Dataset: " + this.dataset + " -> These are the survivors:");
        for (Classifier classifier : classifiers) {
            logger.info("--> " + classifier.getClass().getSimpleName());
        }
        logger.info("<----------------------------------------------->");
        logger.info(">> Dataset: " + this.dataset + " -> These are the eliminated algorithms:");
        for (Classifier classifier : eliminateds) {
            logger.info("--> " + classifier.getClass().getSimpleName());
        }

        logger.info("<-- All classifiers validated for dataset: " + this.dataset);
        this.writer.close();
        return classifiers;
    }

    /**
     * @throws Exception
     *
     */
    public void validateClassifier(Classifier classifier) throws Exception {

        String classifierName = classifier.getClass().getSimpleName();
        logger.info("--> Starting validation for classifier " + classifierName);

        int ratioint = (int) ((1 / (double) folds) * 100);
        double roundedratio = ((double) ratioint) / 100;
        String[] row;
        Classifier copiedClassifier;

        try {
            // Central Agent
            HashMap<Integer, double[][]> resultsMap = new HashMap<Integer, double[][]>();
            HashMap<Integer, double[][]> resultsNoEssMap = new HashMap<Integer, double[][]>();
            for (int leba : lebas) {
                resultsMap.put(leba, new double[this.folds][2]);
                resultsNoEssMap.put(leba, new double[this.folds][2]);
            }

            logger.info("Starting validation for BayesCentralAgent dataset with " + classifierName
                    + " done for dataset: " + this.dataset);
            for (int iteration = 0; iteration < this.folds; iteration++) {
                String inputPath = this.inputFolder + "/" + roundedratio
                        + "testRatio/iteration-" + iteration;
                Instances testData = WekaClassifiersValidator.getDataFromCSV(inputPath
                        + "/test-dataset.arff");
                Instances trainData = WekaClassifiersValidator.getDataFromCSV(inputPath
                        + "/bayes-central-dataset.arff");
                try {
                    logger.info("Learning model...");
                    copiedClassifier = Classifier.makeCopy(classifier);
                    copiedClassifier.buildClassifier(trainData);
                    logger.info("Finishing learning process. Model built for classifier "
                            + classifierName + " in iteration " + iteration);
                } catch (Exception e) {
                    logger.severe("Problems training model for "
                            + classifier.getClass().getSimpleName());
                    logger.severe(e.getMessage());
                    throw e;
                }

                for (int leba : lebas) {
                    double[][] results = resultsMap.get(leba);
                    double[] pcts = this.getValidation(copiedClassifier, trainData, testData, leba);
                    results[iteration][0] = results[iteration][0] + pcts[0];
                    results[iteration][1] = results[iteration][1] + pcts[1];
                    resultsMap.put(leba, results);

                    row = new String[this.columns];
                    row[0] = this.dataset;
                    row[1] = Integer.toString(this.folds);
                    row[2] = classifierName;
                    row[3] = Integer.toString(iteration);
                    row[4] = Double.toString(pcts[0]);
                    row[5] = Double.toString(pcts[1]);
                    row[6] = "BayesCentralAgent";
                    row[7] = "1";
                    row[8] = Integer.toString(leba);
                    writer.writeRecord(row);
                }

                /*
                 * Instances testDataNoEssentials =
                 * WekaClassifiersValidator.getDataFromCSV(inputPath + "/test-dataset.arff");
                 * Instances trainDataNoEssentials =
                 * WekaClassifiersValidator.getDataFromCSV(inputPath
                 * + "/bayes-central-dataset-noEssentials.arff");
                 * try {
                 *     logger.info("Learning model...");
                 *     copiedClassifier = Classifier.makeCopy(classifier);
                 *     copiedClassifier.buildClassifier(trainDataNoEssentials);
                 *
                 *     logger.info("Finishing learning process. Model built for classifier "
                 *             + classifierName + " in iteration " + iteration
                 *             + " without essentials");
                 * } catch (Exception e) {
                 *     logger.severe("Problems training model for "
                 *             + classifier.getClass().getSimpleName());
                 *     logger.severe(e.getMessage());
                 *     throw e;
                 * }
                 *
                 * for (int leba = this.minLEBA; leba <= this.maxLEBA; leba++) {
                 *     double[][] resultsNoEss = resultsNoEssMap.get(leba);
                 *     double[] pcts = this.getValidation(copiedClassifier, trainDataNoEssentials,
                 *             testDataNoEssentials, leba);
                 *     resultsNoEss[iteration][0] = resultsNoEss[iteration][0] + pcts[0];
                 *     resultsNoEss[iteration][1] = resultsNoEss[iteration][1] + pcts[1];
                 *
                 *     resultsNoEssMap.put(leba, resultsNoEss);
                 *
                 *     row = new String[this.columns];
                 *     row[0] = this.dataset;
                 *     row[1] = Integer.toString(this.folds);
                 *     row[2] = classifierName;
                 *     row[3] = Integer.toString(iteration);
                 *     row[4] = Double.toString(pcts[0]);
                 *     row[5] = Double.toString(pcts[1]);
                 *     row[6] = "BayesCentralAgent-NoEssentials";
                 *     row[7] = "1";
                 *     row[8] = Integer.toString(leba);
                 *     writer.writeRecord(row);
                 * }
                 */

                // -------------------------------------------------------------
                // --------------------------- FOR AGENTS DATASETS ISOLATED - NO
                // SENSE BECAUSE WEKA CLASSIFIER WERE DESIGNED TO BE CENTRALISED
                // -------------------------------------------------------------
                //
                // Agents combinations
                // for (int i = this.minAgents; i <= this.maxAgents; i++) {
                // logger.info("Validation for agents datasets with " + classifierName
                // + " done for dataset: " + this.dataset + " with LEBA=" + leba);
                // HashMap<Integer, Double> successRatio = new HashMap<Integer, Double>();
                // HashMap<Integer, Double> wrongRatio = new HashMap<Integer, Double>();
                // for (int j = 0; j < i; j++) {
                // successRatio.put(j, 0.0);
                // wrongRatio.put(j, 0.0);
                // }
                // for (int iteration = 0; iteration < this.folds; iteration++) {
                // String inputPath = this.inputFolder + "/" + roundedratio
                // + "testRatio/iteration-" + iteration;
                // Instances testData = this.getDataFromCSV(inputPath + "/test-dataset.csv");
                // for (int j = 0; j < i; j++) {
                // Instances trainData = this.getDataFromCSV(inputPath + "/" + i
                // + "agents/agent-" + j + "-dataset.csv");
                // double[] pcts = this.getValidation(classifier, trainData, testData, leba);
                // successRatio.put(j, successRatio.get(j) + pcts[0]);
                // wrongRatio.put(j, wrongRatio.get(j) + pcts[1]);
                //
                // row = new String[this.columns];
                // row[0] = this.dataset;
                // row[1] = Integer.toString(this.folds);
                // row[2] = classifierName;
                // row[3] = Integer.toString(iteration);
                // row[4] = Double.toString(pcts[0]);
                // row[5] = Double.toString(pcts[1]);
                // row[6] = "Agent" + j;
                // row[7] = Integer.toString(i);
                // row[8] = Integer.toString(leba);
                // writer.writeRecord(row);
                // }
                //
                // writer.flush();
                // }
                //
                // for (int j = 0; j < i; j++) {
                // row = new String[this.columns];
                // row[0] = this.dataset;
                // row[1] = Integer.toString(this.folds);
                // row[2] = classifierName;
                // row[3] = "AVERAGE";
                // row[4] = Double.toString(successRatio.get(j) / this.folds);
                // row[5] = Double.toString(wrongRatio.get(j) / this.folds);
                // row[6] = "Agent" + j;
                // row[7] = Integer.toString(i);
                // row[8] = Integer.toString(leba);
                // writer.writeRecord(row);
                //
                // logger.info("Validation for Agent" + j + " dataset (for " + i
                // + " agents configuration) with " + classifierName
                // + " done for dataset: " + this.dataset + " with LEBA=" + leba);
                // }
                //
                // writer.flush();
                // }
                // -------------------------------------------------------------
                // ---------- END FOR AGENTS DATASETS ISOLATED -----------------
                // -------------------------------------------------------------

                logger.info("<-- Validation for classifier " + classifierName
                        + " done for dataset: " + this.dataset + " for iteration " + iteration);
            }

            for (int leba : lebas) {
                double[] sum = new double[2];
                double[][] results = resultsMap.get(leba);
                for (int iteration = 0; iteration < this.folds; iteration++) {
                    sum[0] = sum[0] + results[iteration][0];
                    sum[1] = sum[1] + results[iteration][1];
                }

                row = new String[this.columns];
                row[0] = this.dataset;
                row[1] = Integer.toString(this.folds);
                row[2] = classifierName;
                row[3] = "AVERAGE";
                row[4] = Double.toString(sum[0] / this.folds);
                row[5] = Double.toString(sum[1] / this.folds);
                row[6] = "BayesCentralAgent";
                row[7] = "1";
                row[8] = Integer.toString(leba);
                writer.writeRecord(row);

                logger.info("Validation for BayesCentralAgent dataset with " + classifierName
                        + " done for dataset: " + this.dataset + " with LEBA=" + leba);
                writer.flush();
            }

            // for (int leba : lebas) {
            // double[] sum = new double[2];
            // double[][] results = resultsNoEssMap.get(leba);
            // for (int iteration = 0; iteration < this.folds; iteration++) {
            // sum[0] = sum[0] + results[iteration][0];
            // sum[1] = sum[1] + results[iteration][1];
            // }
            //
            // row = new String[this.columns];
            // row[0] = this.dataset;
            // row[1] = Integer.toString(this.folds);
            // row[2] = classifierName;
            // row[3] = "AVERAGE";
            // row[4] = Double.toString(sum[0] / this.folds);
            // row[5] = Double.toString(sum[1] / this.folds);
            // row[6] = "BayesCentralAgent-NoEssentials";
            // row[7] = "1";
            // row[8] = Integer.toString(leba);
            // writer.writeRecord(row);
            //
            // logger.info("Validation for BayesCentralAgent dataset with " + classifierName
            // + " done for dataset: " + this.dataset + " with LEBA=" + leba);
            // writer.flush();
            // }

            logger.info("Validation for BayesCentralAgent dataset with " + classifierName
                    + " done for dataset: " + this.dataset);
            writer.flush();
        } catch (Exception e) {
            logger.severe("Problem validating classifier " + classifierName);
            logger.severe(e.getMessage());
            e.printStackTrace();
            throw e;
        }
    }

    /**
     * @param cls
     * @param trainingData
     * @param testData
     * @param leba
     * @return [0] = pctCorrect, [1] = pctIncorrect
     * @throws Exception
     */
    public double[] getValidation(Classifier cls, Instances trainingData, Instances testData,
            int leba) throws Exception {

        // LEBA: the first 'leba' (non-class) attributes of every test instance are set
        // to missing, so the classifier is evaluated with that evidence hidden.
        Instances testDataWithLEBA = new Instances(testData);
        for (int j = 0; j < leba; j++) {
            if (j < testDataWithLEBA.numAttributes() - 1) {
                for (int i = 0; i < testDataWithLEBA.numInstances(); i++) {
                    testDataWithLEBA.instance(i).setMissing(j);
                }
            }
        }

        Evaluation eval;
        try {
            eval = new Evaluation(trainingData);
            logger.fine("Evaluating model with leba: " + leba);
            eval.evaluateModel(cls, testDataWithLEBA);
            double[] results = new double[2];
            results[0] = eval.pctCorrect() / 100;
            results[1] = eval.pctIncorrect() / 100;
            return results;
        } catch (Exception e) {
            logger.severe("Problems evaluating model for " + cls.getClass().getSimpleName());
            logger.severe(e.getMessage());
            e.printStackTrace();
            throw e;
        }
    }

    /**
     * @return a list of all WEKA classifiers
     */
    public List<Classifier> getNewClassifiers() {
        Classifier classifier;
        List<Classifier> classifiers = new ArrayList<Classifier>();

        // NBTree
        classifier = new NBTree();
        classifiers.add(classifier);
        // PART
        classifier = new PART();
        classifiers.add(classifier);
        // J48
        classifier = new J48();
        ((J48) classifier).setUnpruned(true);
        classifiers.add(classifier);
        // // J48Graft
        // classifier = new J48graft();
        // ((J48graft) classifier).setUnpruned(true);
        // classifiers.add(classifier);
        // // OneR
        // classifier = new OneR();
        // classifiers.add(classifier);
        // LADTree
        classifier = new LADTree();
        classifiers.add(classifier);
        // // REPTree
        // classifier = new REPTree();
        // classifiers.add(classifier);
        // // SimpleLogistic
        // classifier = new SimpleLogistic();
        // classifiers.add(classifier);
        // // Logistic
        // classifier = new Logistic();
        // classifiers.add(classifier);
        // // MultiLayerPerceptron
        // classifier = new MultilayerPerceptron();
        // classifiers.add(classifier);
        // // DecisionStump
        // classifier = new DecisionStump();
        // classifiers.add(classifier);
        // // LMT
        // classifier = new LMT();
        // classifiers.add(classifier);
        // // SimpleCart
        // classifier = new SimpleCart();
        // classifiers.add(classifier);
        // // BFTree
        // classifier = new BFTree();
        // classifiers.add(classifier);
        // // RBFNetwork
        // classifier = new RBFNetwork();
        // classifiers.add(classifier);
        // // DTNB
        // classifier = new DTNB();
        // classifiers.add(classifier);
        // // Jrip
        // classifier = new JRip();
        // classifiers.add(classifier);
        // // Conjunction Rule
        // classifier = new ConjunctiveRule();
        // classifiers.add(classifier);
        // // ZeroR
        // classifier = new ZeroR();
        // classifiers.add(classifier);
        // SMO
        classifier = new SMO();
        classifiers.add(classifier);
        // // OneR
        // classifier = new OneR();
        // classifiers.add(classifier);
        // // RandomForest
        // classifier = new RandomForest();
        // classifiers.add(classifier);

        return classifiers;
    }

    /**
     * @param csvFilePath
     * @return
     * @throws Exception
     */
    public static Instances getDataFromCSV(String csvFilePath) throws Exception {
        DataSource source = new DataSource(csvFilePath);
        Instances data = source.getDataSet();
        data.setClassIndex(data.numAttributes() - 1);
        return data;
    }

}
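If you want to drive the validator from your own code instead of running its main method, the sketch below shows the minimal setup, mirroring the configuration used in main above. The class name ValidatorDriver is hypothetical, and the "poker" dataset, the ../experiments/poker-simulation folder layout, and the LEBA values { 0, 3, 5 } are just the example configuration from this tutorial, not requirements of the class.

    import java.util.List;

    import weka.classifiers.Classifier;

    import es.upm.dit.gsi.barmas.launcher.WekaClassifiersValidator;

    // Hypothetical driver class; adjust dataset name and folders to your experiment layout.
    public class ValidatorDriver {

        public static void main(String[] args) {
            // Example configuration taken from WekaClassifiersValidator.main:
            // 10 folds, 2 to 4 agents, LEBA values 0, 3 and 5 for the "poker" dataset.
            Integer[] lebas = { 0, 3, 5 };
            WekaClassifiersValidator validator = new WekaClassifiersValidator(
                    "poker",
                    "../experiments/poker-simulation/input",
                    "../experiments/poker-simulation/weka",
                    10, 2, 4, lebas);

            // Passing null makes the validator test every classifier returned by
            // getNewClassifiers(); the returned list contains only the classifiers
            // that were validated without errors.
            List<Classifier> survivors = validator.validateWekaClassifiers(null);
            for (Classifier classifier : survivors) {
                System.out.println("Survivor: " + classifier.getClass().getSimpleName());
            }
        }
    }

Results are appended to weka-results.csv in the output folder, one row per classifier, iteration and LEBA value, plus an AVERAGE row per LEBA value.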