Java tutorial
/**
 * Copyright (C) 2007-2009, Jens Lehmann
 *
 * This file is part of DL-Learner.
 *
 * DL-Learner is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * DL-Learner is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
package org.dllearner.scripts;

import static java.util.Arrays.asList;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;

import joptsimple.OptionParser;
import joptsimple.OptionSet;

import org.apache.commons.beanutils.PropertyUtils;
import org.apache.log4j.ConsoleAppender;
import org.apache.log4j.FileAppender;
import org.apache.log4j.Layout;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.PatternLayout;
import org.dllearner.cli.CLI;
import org.dllearner.core.AbstractCELA;
import org.dllearner.core.AbstractLearningProblem;
import org.dllearner.core.AbstractReasonerComponent;
import org.dllearner.core.ComponentInitException;
import org.dllearner.core.owl.Description;
import org.dllearner.core.owl.Individual;
import org.dllearner.learningproblems.PosNegLP;
import org.dllearner.parser.ParseException;
import org.dllearner.utilities.Helper;
import org.dllearner.utilities.datastructures.TrainTestList;
import org.dllearner.utilities.statistics.Stat;

import com.google.common.collect.Lists;

/**
 * Performs nested cross validation for the given problem. A k fold outer and l
 * fold inner cross validation is used. Parameters:
 * <ul>
 * <li>The conf file to use.</li>
 * <li>k (number of outer folds)</li>
 * <li>l (number of inner folds)</li>
 * <li>parameter name to vary</li>
 * <li>a set of parameter values to test</li>
 * </ul>
 *
 * Example arguments: -c bla.conf -o 10 -i 5 -p noise -r 25-40
 *
 * Currently, only the optimisation of a single parameter is supported.
 *
 * Later versions may include support for testing a variety of parameters, e.g.
 * --conf bla.conf --outerfolds 10 --innerfolds 5 --complexparameters=
 * "para1=val1;...paran=valn#...#para1=val1;...paran=valn". This would test all
 * parameter combinations separated by #.
 *
 * Alternatively: --conf bla.conf --outerfolds 10 --innerfolds 5 --parameters=
 * "para1#para2#para3" --values="25-40#(high,medium,low)#boolean". This would
 * test all combinations of the given parameters and values, where the script
 * recognises special patterns, e.g. integer ranges or the keyword "boolean"
 * for "true/false".
 *
 * Currently, only learning from positive and negative examples is supported.
 *
 * Currently, the script can only optimise towards classification accuracy.
 * (It could be extended to optimise F measure or other combinations of
 * precision, recall, and accuracy.)
 *
 * @author Jens Lehmann
 */
public class NestedCrossValidation {

    private static final Logger logger = Logger.getLogger(NestedCrossValidation.class.getName());
    private static File logFile = new File("log/nested-cv.log");

    DecimalFormat df = new DecimalFormat();

    // overall statistics
    Stat globalAcc = new Stat();
    Stat globalF = new Stat();
    Stat globalRecall = new Stat();
    Stat globalPrecision = new Stat();
    Map<Double, Stat> globalParaStats = new HashMap<Double, Stat>();

    /**
     * Entry method, which uses JOptSimple to parse parameters.
     *
     * @param args Command line arguments (see class documentation).
     * @throws IOException
     * @throws ParseException
     * @throws ComponentInitException
     * @throws org.dllearner.confparser.ParseException
     */
    public static void main(String[] args) throws IOException, ComponentInitException, ParseException,
            org.dllearner.confparser.ParseException {

        OptionParser parser = new OptionParser();
        parser.acceptsAll(asList("h", "?", "help"), "Show help.");
        parser.acceptsAll(asList("c", "conf"), "The comma separated list of config files to be used.")
                .withRequiredArg().describedAs("file1, file2, ...");
        parser.acceptsAll(asList("v", "verbose"), "Be more verbose.");
        parser.acceptsAll(asList("o", "outerfolds"), "Number of outer folds.").withRequiredArg()
                .ofType(Integer.class).describedAs("#folds");
        parser.acceptsAll(asList("i", "innerfolds"), "Number of inner folds.").withRequiredArg()
                .ofType(Integer.class).describedAs("#folds");
        parser.acceptsAll(asList("p", "parameter"), "Parameter to vary.").withRequiredArg();
        parser.acceptsAll(asList("r", "pvalues", "range"),
                "Values of parameter. $x-$y can be used for integer ranges.").withRequiredArg();
        parser.acceptsAll(asList("s", "stepsize", "steps"), "Step size of range.").withOptionalArg()
                .ofType(Double.class).defaultsTo(1d);

        // parse options and display a message for the user in case of problems
        OptionSet options = null;
        try {
            options = parser.parse(args);
        } catch (Exception e) {
            System.out.println("Error: " + e.getMessage() + ". Use -? to get help.");
            System.exit(0);
        }
        // print help screen
        if (options.has("?")) {
            parser.printHelpOn(System.out);
        }
        // all options present => start nested cross validation
        else if (options.has("c") && options.has("o") && options.has("i") && options.has("p")
                && options.has("r")) {

            // read all options into variables and parse option values
            String confFilesString = (String) options.valueOf("c");
            List<File> confFiles = new ArrayList<File>();
            for (String fileString : confFilesString.split(",")) {
                confFiles.add(new File(fileString.trim()));
            }
            int outerFolds = (Integer) options.valueOf("o");
            int innerFolds = (Integer) options.valueOf("i");
            String parameter = (String) options.valueOf("p");
            String range = (String) options.valueOf("r");
            String[] rangeSplit = range.split("-");
            double rangeStart = Double.valueOf(rangeSplit[0]);
            double rangeEnd = Double.valueOf(rangeSplit[1]);
            double stepsize = (Double) options.valueOf("s");
            boolean verbose = options.has("v");

            // create a simple logger which outputs its messages to the console
            Layout layout = new PatternLayout("%m%n");
            ConsoleAppender consoleAppender = new ConsoleAppender(layout);
            Logger logger = Logger.getRootLogger();
            logger.removeAllAppenders();
            logger.addAppender(consoleAppender);
            logger.setLevel(Level.ERROR);
            Logger.getLogger("org.dllearner.algorithms").setLevel(Level.INFO);
            Logger.getLogger("org.dllearner.scripts").setLevel(Level.INFO);
            FileAppender fileAppender = new FileAppender(layout, logFile.getPath(), false);
            logger.addAppender(fileAppender);
            fileAppender.setThreshold(Level.INFO);

            // disable OWL API info output
            java.util.logging.Logger.getLogger("").setLevel(java.util.logging.Level.WARNING);

            System.out.println(
                    "Warning: The script is not well tested yet. (No known bugs, but needs more testing.)");

            new NestedCrossValidation(confFiles, outerFolds, innerFolds, parameter, rangeStart,
                    rangeEnd, stepsize, verbose);

        }
        // an option is missing => print help screen and message
        else {
            parser.printHelpOn(System.out);
            System.out.println("\nYou need to specify the options c, i, o, p, r."
                    + " Please consult the help table above.");
        }
    }
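    /*
     * Example command line (a sketch; the conf file name and parameter values
     * are illustrative, and DL-Learner with its dependencies is assumed to be
     * on the classpath):
     *
     *   java org.dllearner.scripts.NestedCrossValidation \
     *       -c bla.conf -o 10 -i 5 -p noise -r 25-40 -s 5
     *
     * This varies the "noise" parameter from 25 to 40 in steps of 5, using 10
     * outer and 5 inner folds.
     */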
    public NestedCrossValidation(File confFile, int outerFolds, int innerFolds, String parameter,
            double startValue, double endValue, double stepsize, boolean verbose)
            throws ComponentInitException, ParseException, org.dllearner.confparser.ParseException,
            IOException {
        this(Lists.newArrayList(confFile), outerFolds, innerFolds, parameter, startValue, endValue,
                stepsize, verbose);
    }

    public NestedCrossValidation(List<File> confFiles, int outerFolds, int innerFolds, String parameter,
            double startValue, double endValue, double stepsize, boolean verbose)
            throws ComponentInitException, ParseException, org.dllearner.confparser.ParseException,
            IOException {

        for (File confFile : confFiles) {
            logger.info("++++++++++++++++++++++++++++++++++++++++++++++");
            logger.info(confFile.getPath());
            logger.info("++++++++++++++++++++++++++++++++++++++++++++++");
            validate(confFile, outerFolds, innerFolds, parameter, startValue, endValue, stepsize, verbose);
        }

        logger.info("############################################");
        logger.info("############################################");

        // decide for the best parameter
        logger.info(" Overall summary over parameter values:");
        double bestPara = startValue;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (Entry<Double, Stat> entry : globalParaStats.entrySet()) {
            double para = entry.getKey();
            Stat stat = entry.getValue();
            logger.info(" value " + para + ": " + stat.prettyPrint("%"));
            if (stat.getMean() > bestValue) {
                bestPara = para;
                bestValue = stat.getMean();
            }
        }
        logger.info(" selected " + bestPara + " as best parameter value (criterion value "
                + df.format(bestValue) + "%)");

        // overall statistics
        logger.info("*******************");
        logger.info("* Overall Results *");
        logger.info("*******************");
        logger.info("accuracy: " + globalAcc.prettyPrint("%"));
        logger.info("F measure: " + globalF.prettyPrint("%"));
        logger.info("precision: " + globalPrecision.prettyPrint("%"));
        logger.info("recall: " + globalRecall.prettyPrint("%"));
    }

    private void validate(File confFile, int outerFolds, int innerFolds, String parameter,
            double startValue, double endValue, double stepsize, boolean verbose)
            throws IOException, ComponentInitException {

        CLI start = new CLI(confFile);
        start.init();
        AbstractLearningProblem lp = start.getLearningProblem();

        if (!(lp instanceof PosNegLP)) {
            System.out.println("Positive only learning not supported yet.");
            System.exit(0);
        }

        // get examples and shuffle them
        LinkedList<Individual> posExamples = new LinkedList<Individual>(((PosNegLP) lp).getPositiveExamples());
        Collections.shuffle(posExamples, new Random(1));
        LinkedList<Individual> negExamples = new LinkedList<Individual>(((PosNegLP) lp).getNegativeExamples());
        Collections.shuffle(negExamples, new Random(2));

        AbstractReasonerComponent rc = start.getReasonerComponent();
        rc.init();
        String baseURI = rc.getBaseURI();

        List<TrainTestList> posLists = getFolds(posExamples, outerFolds);
        List<TrainTestList> negLists = getFolds(negExamples, outerFolds);

        // overall statistics for this conf file
        Stat accOverall = new Stat();
        Stat fOverall = new Stat();
        Stat recallOverall = new Stat();
        Stat precisionOverall = new Stat();

        for (int currOuterFold = 0; currOuterFold < outerFolds; currOuterFold++) {

            logger.info("Outer fold " + currOuterFold);
            TrainTestList posList = posLists.get(currOuterFold);
            TrainTestList negList = negLists.get(currOuterFold);

            // measure the relevant criterion (accuracy, F measure) over the
            // different parameter values
            Map<Double, Stat> paraStats = new HashMap<Double, Stat>();
            for (double currParaValue = startValue; currParaValue <= endValue; currParaValue += stepsize) {

                logger.info("  Parameter value " + currParaValue + ":");
                // split the train folds again (computation of inner folds for each
                // parameter value is redundant, but not a big problem)
                List<Individual> trainPosList = posList.getTrainList();
                List<TrainTestList> innerPosLists = getFolds(trainPosList, innerFolds);
                List<Individual> trainNegList = negList.getTrainList();
                List<TrainTestList> innerNegLists = getFolds(trainNegList, innerFolds);

                // measure the relevant criterion for this parameter value
                // (by default accuracy, but this could also be F measure)
                Stat paraCriterionStat = new Stat();
                for (int currInnerFold = 0; currInnerFold < innerFolds; currInnerFold++) {

                    logger.info("    Inner fold " + currInnerFold + ":");
                    // get positive & negative examples for the training run
                    Set<Individual> posEx = new TreeSet<Individual>(
                            innerPosLists.get(currInnerFold).getTrainList());
                    Set<Individual> negEx = new TreeSet<Individual>(
                            innerNegLists.get(currInnerFold).getTrainList());

                    // read the conf file and exchange the options for pos/neg examples
                    // and the parameter to optimise
                    start = new CLI(confFile);
                    start.init();
                    AbstractLearningProblem lpIn = start.getLearningProblem();
                    ((PosNegLP) lpIn).setPositiveExamples(posEx);
                    ((PosNegLP) lpIn).setNegativeExamples(negEx);
                    AbstractCELA laIn = start.getLearningAlgorithm();
                    try {
                        PropertyUtils.setSimpleProperty(laIn, parameter, currParaValue);
                    } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
                        e.printStackTrace();
                    }
                    lpIn.init();
                    laIn.init();
                    laIn.start();

                    // evaluate the learned expression
                    Description concept = laIn.getCurrentlyBestDescription();

                    TreeSet<Individual> posTest = new TreeSet<Individual>(
                            innerPosLists.get(currInnerFold).getTestList());
                    TreeSet<Individual> negTest = new TreeSet<Individual>(
                            innerNegLists.get(currInnerFold).getTestList());

                    // true positives: positive test examples covered by the concept
                    Set<Individual> posCorrect = rc.hasType(concept, posTest);
                    // false negatives: positive test examples not covered by the concept
                    Set<Individual> posError = Helper.difference(posTest, posCorrect);
                    // false positives: negative test examples covered by the concept
                    Set<Individual> negError = rc.hasType(concept, negTest);
                    // true negatives: negative test examples not covered by the concept
                    Set<Individual> negCorrect = Helper.difference(negTest, negError);

                    double accuracy = 100 * ((double) (posCorrect.size() + negCorrect.size())
                            / (posTest.size() + negTest.size()));
                    // guard against empty denominators to avoid division by zero
                    double precision = (posCorrect.size() + negError.size()) == 0 ? 0
                            : 100 * (double) posCorrect.size() / (posCorrect.size() + negError.size());
                    double recall = (posCorrect.size() + posError.size()) == 0 ? 0
                            : 100 * (double) posCorrect.size() / (posCorrect.size() + posError.size());
                    double fmeasure = (precision + recall) == 0 ? 0
                            : 2 * (precision * recall) / (precision + recall);
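                    // Worked example with hypothetical counts: |posTest| = |negTest| = 10,
                    // |posCorrect| = 8, |negCorrect| = 7, hence |posError| = 2, |negError| = 3:
                    //   accuracy  = 100 * (8 + 7) / (10 + 10)     = 75%
                    //   precision = 100 * 8 / (8 + 3)             ~ 72.7%
                    //   recall    = 100 * 8 / (8 + 2)             = 80%
                    //   F measure = 2 * (72.7 * 80) / (72.7 + 80) ~ 76.2%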
                    paraCriterionStat.addNumber(accuracy);

                    logger.info("      hypothesis: " + concept.toManchesterSyntaxString(baseURI, null));
                    logger.info("      accuracy: " + df.format(accuracy) + "%");
                    logger.info("      precision: " + df.format(precision) + "%");
                    logger.info("      recall: " + df.format(recall) + "%");
                    logger.info("      F measure: " + df.format(fmeasure) + "%");
                    if (verbose) {
                        // negError holds the false positives and posError the false
                        // negatives, so they are reported under those labels
                        logger.info("      false positives (neg. examples classified as pos.): "
                                + formatIndividualSet(negError, baseURI));
                        logger.info("      false negatives (pos. examples classified as neg.): "
                                + formatIndividualSet(posError, baseURI));
                    }
                }

                paraStats.put(currParaValue, paraCriterionStat);
                Stat globalParaStat = globalParaStats.get(currParaValue);
                if (globalParaStat == null) {
                    globalParaStat = new Stat();
                    globalParaStats.put(currParaValue, globalParaStat);
                }
                globalParaStat.add(paraCriterionStat);
            }

            // decide for the best parameter
            logger.info("  Summary over parameter values:");
            double bestPara = startValue;
            double bestValue = Double.NEGATIVE_INFINITY;
            for (Entry<Double, Stat> entry : paraStats.entrySet()) {
                double para = entry.getKey();
                Stat stat = entry.getValue();
                logger.info("  value " + para + ": " + stat.prettyPrint("%"));
                if (stat.getMean() > bestValue) {
                    bestPara = para;
                    bestValue = stat.getMean();
                }
            }
            logger.info("  selected " + bestPara + " as best parameter value (criterion value "
                    + df.format(bestValue) + "%)");

            logger.info("  Learn on outer fold:");

            // start a learning process with the best parameter and evaluate it on the outer fold
            start = new CLI(confFile);
            start.init();
            AbstractLearningProblem lpOut = start.getLearningProblem();
            ((PosNegLP) lpOut).setPositiveExamples(
                    new TreeSet<Individual>(posLists.get(currOuterFold).getTrainList()));
            ((PosNegLP) lpOut).setNegativeExamples(
                    new TreeSet<Individual>(negLists.get(currOuterFold).getTrainList()));
            AbstractCELA laOut = start.getLearningAlgorithm();
            try {
                PropertyUtils.setSimpleProperty(laOut, parameter, bestPara);
            } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
                e.printStackTrace();
            }
            lpOut.init();
            laOut.init();
            laOut.start();

            // evaluate the learned expression
            Description concept = laOut.getCurrentlyBestDescription();

            TreeSet<Individual> posTest = new TreeSet<Individual>(posLists.get(currOuterFold).getTestList());
            TreeSet<Individual> negTest = new TreeSet<Individual>(negLists.get(currOuterFold).getTestList());

            AbstractReasonerComponent rs = start.getReasonerComponent();
            // true positives
            Set<Individual> posCorrect = rs.hasType(concept, posTest);
            // false negatives
            Set<Individual> posError = Helper.difference(posTest, posCorrect);
            // false positives
            Set<Individual> negError = rs.hasType(concept, negTest);
            // true negatives
            Set<Individual> negCorrect = Helper.difference(negTest, negError);

            double accuracy = 100 * ((double) (posCorrect.size() + negCorrect.size())
                    / (posTest.size() + negTest.size()));
            // same zero-denominator guards as in the inner folds
            double precision = (posCorrect.size() + negError.size()) == 0 ? 0
                    : 100 * (double) posCorrect.size() / (posCorrect.size() + negError.size());
            double recall = (posCorrect.size() + posError.size()) == 0 ? 0
                    : 100 * (double) posCorrect.size() / (posCorrect.size() + posError.size());
            double fmeasure = (precision + recall) == 0 ? 0
                    : 2 * (precision * recall) / (precision + recall);

            logger.info("  hypothesis: " + concept.toManchesterSyntaxString(baseURI, null));
            logger.info("  accuracy: " + df.format(accuracy) + "%");
            logger.info("  precision: " + df.format(precision) + "%");
            logger.info("  recall: " + df.format(recall) + "%");
            logger.info("  F measure: " + df.format(fmeasure) + "%");
            if (verbose) {
                logger.info("  false positives (neg. examples classified as pos.): "
                        + formatIndividualSet(negError, baseURI));
                logger.info("  false negatives (pos. examples classified as neg.): "
                        + formatIndividualSet(posError, baseURI));
            }
            // update overall statistics
            accOverall.addNumber(accuracy);
            fOverall.addNumber(fmeasure);
            recallOverall.addNumber(recall);
            precisionOverall.addNumber(precision);

            // free memory
            rs.releaseKB();
        }

        globalAcc.add(accOverall);
        globalF.add(fOverall);
        globalPrecision.add(precisionOverall);
        globalRecall.add(recallOverall);

        // overall statistics
        logger.info("*******************");
        logger.info("* Overall Results *");
        logger.info("*******************");
        logger.info("accuracy: " + accOverall.prettyPrint("%"));
        logger.info("F measure: " + fOverall.prettyPrint("%"));
        logger.info("precision: " + precisionOverall.prettyPrint("%"));
        logger.info("recall: " + recallOverall.prettyPrint("%"));
    }

    // convenience method which takes a list of examples and divides it into
    // train/test lists according to the number of folds specified
    public static List<TrainTestList> getFolds(List<Individual> list, int folds) {
        List<TrainTestList> ret = new LinkedList<TrainTestList>();
        int[] splits = CrossValidation.calculateSplits(list.size(), folds);
        for (int i = 0; i < folds; i++) {
            int fromIndex = (i == 0) ? 0 : splits[i - 1];
            int toIndex = splits[i];
            List<Individual> test = list.subList(fromIndex, toIndex);
            List<Individual> train = new LinkedList<Individual>(list);
            train.removeAll(test);
            ret.add(new TrainTestList(train, test));
        }
        return ret;
    }

    private static String formatIndividualSet(Set<Individual> inds, String baseURI) {
        // print at most 20 individuals
        StringBuilder ret = new StringBuilder();
        int i = 0;
        for (Individual ind : inds) {
            ret.append(ind.toManchesterSyntaxString(baseURI, null)).append(" ");
            i++;
            if (i == 20) {
                break;
            }
        }
        return ret.toString();
    }
}
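The script can also be invoked programmatically via the public constructor shown above. The following is a minimal sketch; the conf file path is a placeholder, and the parameter values mirror the example from the class documentation:

import java.io.File;

import org.dllearner.scripts.NestedCrossValidation;

public class NestedCVExample {

    public static void main(String[] args) throws Exception {
        // placeholder conf file; point this at an existing DL-Learner conf file
        File conf = new File("bla.conf");

        // 10 outer folds, 5 inner folds, varying the "noise" parameter
        // from 25 to 40 in steps of 5, with verbose output enabled
        new NestedCrossValidation(conf, 10, 5, "noise", 25, 40, 5, true);
    }
}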