CV.java Source code

Java tutorial

Introduction

Here is the source code for CV.java

Source

/**
 * Copyright (C) 2007-2008, Jens Lehmann
 *
 * This file is part of DL-Learner.
 * 
 * DL-Learner is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * DL-Learner is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

import java.io.File;
import java.lang.reflect.InvocationTargetException;
import java.text.DecimalFormat;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.apache.http.impl.cookie.BestMatchSpec;
import org.dllearner.core.AbstractCELA;
import org.dllearner.core.AbstractLearningProblem;
import org.dllearner.core.AbstractReasonerComponent;
import org.dllearner.core.ComponentInitException;
import org.dllearner.core.owl.Description;
import org.dllearner.core.owl.Individual;
import org.dllearner.learningproblems.Heuristics;
import org.dllearner.learningproblems.PosNegLP;
import org.dllearner.learningproblems.PosOnlyLP;
import org.dllearner.utilities.Files;
import org.dllearner.utilities.Helper;
import org.dllearner.utilities.datastructures.Datastructures;
import org.dllearner.utilities.owl.OWLAPIConverter;
import org.dllearner.utilities.statistics.Stat;
import org.semanticweb.owlapi.io.ToStringRenderer;
import org.semanticweb.owlapi.model.OWLClassExpression;
import org.semanticweb.owlapi.model.OWLIndividual;
import org.semanticweb.owlapi.util.SimpleShortFormProvider;

import uk.ac.manchester.cs.owl.owlapi.mansyntaxrenderer.ManchesterOWLSyntaxOWLObjectRendererImpl;

/**
 * Performs cross validation for the given problem. Supports
 * k-fold cross-validation; leave-one-out cross-validation is
 * recognised as an option but not implemented yet (requesting it
 * prints a message and terminates the JVM).
 * <p>
 * All work is done in the constructor; afterwards the collected
 * statistics can be read through the getters. Output behaviour is
 * configured through the static fields {@link #writeToFile},
 * {@link #outputFile} and {@link #multiThreaded}.
 * 
 * @author Jens Lehmann
 *
 */
public class CV {

    // statistical values, one sample is added per fold
    protected Stat runtime = new Stat();
    protected Stat accuracy = new Stat();
    // NOTE: nothing in this class adds samples to "length" at the moment
    protected Stat length = new Stat();
    protected Stat accuracyTraining = new Stat();
    protected Stat fMeasure = new Stat();
    protected Stat fMeasureTraining = new Stat();
    /** If true, all output is appended to {@link #outputFile} in addition to being printed. */
    public static boolean writeToFile = false;
    /** Target file for the output; only used when {@link #writeToFile} is true. */
    public static File outputFile;
    /** If true (and both algorithm and problem are Cloneable), folds are run in parallel. */
    public static boolean multiThreaded = false;

    protected Stat trainingCompletenessStat = new Stat();
    protected Stat trainingCorrectnessStat = new Stat();

    protected Stat testingCompletenessStat = new Stat();
    protected Stat testingCorrectnessStat = new Stat();

    DecimalFormat df = new DecimalFormat();

    /**
     * Creates an instance without running any validation.
     */
    public CV() {

    }

    /**
     * Runs a cross validation and prints (and optionally writes to a file)
     * the per-fold results followed by a summary.
     * 
     * @param la the learning algorithm to evaluate
     * @param lp the learning problem; must be a {@link PosNegLP} or a {@link PosOnlyLP}
     * @param rs the reasoner used to classify the examples against the learned description
     * @param folds the number of folds
     * @param leaveOneOut whether leave-one-out validation is requested (not supported yet:
     *            prints a message and exits with status 1)
     * @throws IllegalArgumentException if the learning problem is of an unsupported type
     */
    public CV(AbstractCELA la, AbstractLearningProblem lp, final AbstractReasonerComponent rs, int folds,
            boolean leaveOneOut) {
        //console rendering of class expressions
        ManchesterOWLSyntaxOWLObjectRendererImpl renderer = new ManchesterOWLSyntaxOWLObjectRendererImpl();
        ToStringRenderer.getInstance().setRenderer(renderer);
        ToStringRenderer.getInstance().setShortFormProvider(new SimpleShortFormProvider());

        // the training and test sets used later on
        List<Set<OWLIndividual>> trainingSetsPos = new LinkedList<Set<OWLIndividual>>();
        List<Set<OWLIndividual>> trainingSetsNeg = new LinkedList<Set<OWLIndividual>>();
        List<Set<OWLIndividual>> testSetsPos = new LinkedList<Set<OWLIndividual>>();
        List<Set<OWLIndividual>> testSetsNeg = new LinkedList<Set<OWLIndividual>>();

        // get examples and shuffle them too
        Set<OWLIndividual> posExamples;
        Set<OWLIndividual> negExamples;
        if (lp instanceof PosNegLP) {
            posExamples = OWLAPIConverter.getOWLAPIIndividuals(((PosNegLP) lp).getPositiveExamples());
            negExamples = OWLAPIConverter.getOWLAPIIndividuals(((PosNegLP) lp).getNegativeExamples());
        } else if (lp instanceof PosOnlyLP) {
            // the problem has to be cast to PosOnlyLP here; casting it to
            // PosNegLP fails with a ClassCastException in this branch
            posExamples = OWLAPIConverter.getOWLAPIIndividuals(((PosOnlyLP) lp).getPositiveExamples());
            negExamples = new HashSet<OWLIndividual>();
        } else {
            throw new IllegalArgumentException("Only PosNeg and PosOnly learning problems are supported");
        }
        List<OWLIndividual> posExamplesList = new LinkedList<OWLIndividual>(posExamples);
        List<OWLIndividual> negExamplesList = new LinkedList<OWLIndividual>(negExamples);
        // fixed seeds keep the fold assignment reproducible between runs
        Collections.shuffle(posExamplesList, new Random(1));
        Collections.shuffle(negExamplesList, new Random(2));

        // sanity check whether nr. of folds makes sense for this benchmark;
        // it only triggers if BOTH example sets are smaller than the number
        // of folds (for positive only problems the negative set is always empty)
        if (!leaveOneOut && (posExamples.size() < folds && negExamples.size() < folds)) {
            System.out.println("The number of folds is higher than the number of "
                    + "positive/negative examples. This can result in empty test sets. Exiting.");
            System.exit(0);
        }

        if (leaveOneOut) {
            // note that leave-one-out is not identical to k-fold with
            // k = nr. of examples in the current implementation, because
            // with n folds and n examples there is no guarantee that a fold
            // is never empty (this is an implementation issue)
            System.out.println("Leave-one-out not supported yet.");
            System.exit(1);
        } else {
            // calculating where to split the sets; note that we split
            // positive and negative examples separately such that the 
            // distribution of positive and negative examples remains similar
            // (note that there are better but more complex ways to implement this,
            // which guarantee that the sum of the elements of a fold for pos
            // and neg differs by at most 1 - it can differ by 2 in our implementation,
            // e.g. with 3 folds, 4 pos. examples, 4 neg. examples)
            int[] splitsPos = calculateSplits(posExamples.size(), folds);
            int[] splitsNeg = calculateSplits(negExamples.size(), folds);

            // calculating training and test sets
            for (int i = 0; i < folds; i++) {
                Set<OWLIndividual> testPos = getTestingSet(posExamplesList, splitsPos, i);
                Set<OWLIndividual> testNeg = getTestingSet(negExamplesList, splitsNeg, i);
                testSetsPos.add(i, testPos);
                testSetsNeg.add(i, testNeg);
                trainingSetsPos.add(i, getTrainingSet(posExamples, testPos));
                trainingSetsNeg.add(i, getTrainingSet(negExamples, testNeg));
            }

        }

        // run the algorithm
        if (multiThreaded && lp instanceof Cloneable && la instanceof Cloneable) {
            // leave one processor free, but never ask for an empty pool:
            // newFixedThreadPool rejects a size of 0 (single core machines)
            int poolSize = Math.max(1, Runtime.getRuntime().availableProcessors() - 1);
            ExecutorService es = Executors.newFixedThreadPool(poolSize);
            for (int currFold = 0; currFold < folds; currFold++) {
                try {
                    // every fold works on its own copy of problem and algorithm
                    final AbstractLearningProblem lpClone = (AbstractLearningProblem) lp.getClass()
                            .getMethod("clone").invoke(lp);
                    final Set<OWLIndividual> trainPos = trainingSetsPos.get(currFold);
                    final Set<OWLIndividual> trainNeg = trainingSetsNeg.get(currFold);
                    final Set<OWLIndividual> testPos = testSetsPos.get(currFold);
                    final Set<OWLIndividual> testNeg = testSetsNeg.get(currFold);
                    applyTrainingExamples(lpClone, trainPos, trainNeg);
                    final AbstractCELA laClone = (AbstractCELA) la.getClass().getMethod("clone").invoke(la);
                    final int i = currFold;
                    es.submit(new Runnable() {

                        @Override
                        public void run() {
                            try {
                                validate(laClone, lpClone, rs, i, trainPos, trainNeg, testPos, testNeg);
                            } catch (Exception e) {
                                e.printStackTrace();
                            }
                        }
                    });
                } catch (IllegalAccessException e) {
                    e.printStackTrace();
                } catch (IllegalArgumentException e) {
                    e.printStackTrace();
                } catch (InvocationTargetException e) {
                    e.printStackTrace();
                } catch (NoSuchMethodException e) {
                    e.printStackTrace();
                } catch (SecurityException e) {
                    e.printStackTrace();
                }
            }
            es.shutdown();
            try {
                es.awaitTermination(1, TimeUnit.DAYS);
            } catch (InterruptedException e) {
                e.printStackTrace();
                // restore the interrupt flag so that callers can react to it
                Thread.currentThread().interrupt();
            }
        } else {
            for (int currFold = 0; currFold < folds; currFold++) {
                final Set<OWLIndividual> trainPos = trainingSetsPos.get(currFold);
                final Set<OWLIndividual> trainNeg = trainingSetsNeg.get(currFold);
                final Set<OWLIndividual> testPos = testSetsPos.get(currFold);
                final Set<OWLIndividual> testNeg = testSetsNeg.get(currFold);

                applyTrainingExamples(lp, trainPos, trainNeg);

                validate(la, lp, rs, currFold, trainPos, trainNeg, testPos, testNeg);
            }
        }

        outputWriter("");
        outputWriter("Finished " + folds + "-folds cross-validation.");
        outputWriter("runtime: " + statOutput(df, runtime, "s"));
        outputWriter("length: " + statOutput(df, length, ""));
        outputWriter("F-Measure on training set: " + statOutput(df, fMeasureTraining, "%"));
        outputWriter("F-Measure: " + statOutput(df, fMeasure, "%"));
        outputWriter("predictive accuracy on training set: " + statOutput(df, accuracyTraining, "%"));
        outputWriter("predictive accuracy: " + statOutput(df, accuracy, "%"));

    }

    // Installs the training examples of one fold on the given learning problem
    // (positive and negative ones for PosNegLP, only positive ones for PosOnlyLP).
    private static void applyTrainingExamples(AbstractLearningProblem problem, Set<OWLIndividual> trainPos,
            Set<OWLIndividual> trainNeg) {
        if (problem instanceof PosNegLP) {
            ((PosNegLP) problem).setPositiveExamples(OWLAPIConverter.convertIndividuals(trainPos));
            ((PosNegLP) problem).setNegativeExamples(OWLAPIConverter.convertIndividuals(trainNeg));
        } else if (problem instanceof PosOnlyLP) {
            ((PosOnlyLP) problem).setPositiveExamples(
                    new TreeSet<Individual>(OWLAPIConverter.convertIndividuals(trainPos)));
        }
    }

    // Runs the learning algorithm for a single fold, evaluates the best
    // description on training and test examples, records the statistics
    // and writes the report for this fold.
    private void validate(AbstractCELA la, AbstractLearningProblem lp, AbstractReasonerComponent rs, int currFold,
            Set<OWLIndividual> trainPos, Set<OWLIndividual> trainNeg, Set<OWLIndividual> testPos,
            Set<OWLIndividual> testNeg) {
        Set<String> pos = Datastructures.individualSetToStringSet(OWLAPIConverter.convertIndividuals(trainPos));
        Set<String> neg = Datastructures.individualSetToStringSet(OWLAPIConverter.convertIndividuals(trainNeg));
        try {
            lp.init();
            la.setLearningProblem(lp);
            la.init();
        } catch (ComponentInitException e) {
            // TODO decide how a failed initialisation should be handled;
            // for now the error is reported and the fold is still attempted
            e.printStackTrace();
        }

        long algorithmStartTime = System.nanoTime();
        la.start();
        long algorithmDuration = System.nanoTime() - algorithmStartTime;
        double runtimeSeconds = algorithmDuration / (double) 1000000000;

        Description currentlyBestDescription = la.getCurrentlyBestDescription();

        // positives of the test set which are not covered and negatives of
        // the test set which are covered by the learned description
        SortedSet<Individual> testPosIndividuals = new TreeSet<Individual>(
                OWLAPIConverter.convertIndividuals(testPos));
        SortedSet<Individual> coveredTestPos = rs.hasType(currentlyBestDescription, testPosIndividuals);
        Set<Individual> testErrorsPos = Helper.difference(testPosIndividuals, coveredTestPos);
        Set<Individual> testErrorsNeg = rs.hasType(currentlyBestDescription,
                OWLAPIConverter.convertIndividuals(testNeg));

        // calculate training accuracies 
        int trainingCorrectPosClassified = getCorrectPosClassified(rs, currentlyBestDescription, trainPos);
        int trainingCorrectNegClassified = getCorrectNegClassified(rs, currentlyBestDescription, trainNeg);
        int trainingCorrectExamples = trainingCorrectPosClassified + trainingCorrectNegClassified;
        double trainingAccuracy = 100 * ((double) trainingCorrectExamples / (trainPos.size() + trainNeg.size()));
        // calculate test accuracies
        int correctPosClassified = getCorrectPosClassified(rs, currentlyBestDescription, testPos);
        int correctNegClassified = getCorrectNegClassified(rs, currentlyBestDescription, testNeg);
        int correctExamples = correctPosClassified + correctNegClassified;
        double currAccuracy = 100 * ((double) correctExamples / (testPos.size() + testNeg.size()));
        // calculate training F-Score
        int negAsPosTraining = rs.hasType(currentlyBestDescription, OWLAPIConverter.convertIndividuals(trainNeg))
                .size();
        double precisionTraining = trainingCorrectPosClassified + negAsPosTraining == 0 ? 0
                : trainingCorrectPosClassified / (double) (trainingCorrectPosClassified + negAsPosTraining);
        double recallTraining = trainingCorrectPosClassified / (double) trainPos.size();
        double fScoreTraining = 100 * Heuristics.getFScore(recallTraining, precisionTraining);
        // calculate test F-Score; the covered negatives were already
        // retrieved above, so the reasoner is not asked again
        int negAsPos = testErrorsNeg.size();
        double precision = correctPosClassified + negAsPos == 0 ? 0
                : correctPosClassified / (double) (correctPosClassified + negAsPos);
        double recall = correctPosClassified / (double) testPos.size();
        double fScore = 100 * Heuristics.getFScore(recall, precision);

        // In multi-threaded mode several folds reach this point concurrently.
        // DecimalFormat must not be used by several threads at once and the
        // Stat accumulators are not known to be thread-safe, so recording and
        // formatting are done under the lock of this instance.
        synchronized (this) {
            runtime.addNumber(runtimeSeconds);
            accuracyTraining.addNumber(trainingAccuracy);
            accuracy.addNumber(currAccuracy);
            fMeasureTraining.addNumber(fScoreTraining);
            fMeasure.addNumber(fScore);

            StringBuilder output = new StringBuilder();
            output.append("+").append(new TreeSet<String>(pos)).append("\n");
            output.append("-").append(new TreeSet<String>(neg)).append("\n");
            output.append("test set errors pos: ").append(testErrorsPos).append("\n");
            output.append("test set errors neg: ").append(testErrorsNeg).append("\n");
            output.append("fold ").append(currFold).append(":").append("\n");
            // the line break keeps "training" and "testing" on separate lines
            output.append("  training: ").append(pos.size()).append(" positive and ").append(neg.size())
                    .append(" negative examples").append("\n");
            output.append("  testing: ").append(correctPosClassified).append("/").append(testPos.size())
                    .append(" correct positives, ").append(correctNegClassified).append("/")
                    .append(testNeg.size()).append(" correct negatives").append("\n");
            output.append("  concept: ").append(currentlyBestDescription.toString().replace("\n", " "))
                    .append("\n");
            output.append("  accuracy: ").append(df.format(currAccuracy)).append("% (")
                    .append(df.format(trainingAccuracy)).append("% on training set)").append("\n");
            output.append("  runtime: ").append(df.format(runtimeSeconds)).append("s").append("\n");

            outputWriter(output.toString());
        }
    }

    /**
     * Counts the positive examples which are instances of the concept
     * according to the reasoner.
     * 
     * @param rs the reasoner used for the instance checks
     * @param concept the description to test against
     * @param testSetPos the positive examples
     * @return the number of given examples covered by the concept
     */
    protected int getCorrectPosClassified(AbstractReasonerComponent rs, Description concept,
            Set<OWLIndividual> testSetPos) {
        return rs.hasType(concept, OWLAPIConverter.convertIndividuals(testSetPos)).size();
    }

    /**
     * Counts the negative examples which are not instances of the concept
     * according to the reasoner.
     * 
     * @param rs the reasoner used for the instance checks
     * @param concept the description to test against
     * @param testSetNeg the negative examples
     * @return the number of given examples not covered by the concept
     */
    protected int getCorrectNegClassified(AbstractReasonerComponent rs, Description concept,
            Set<OWLIndividual> testSetNeg) {
        return testSetNeg.size() - rs.hasType(concept, OWLAPIConverter.convertIndividuals(testSetNeg)).size();
    }

    /**
     * Returns the examples which form the test set of the given fold.
     * 
     * @param examples the (shuffled) list of all examples
     * @param splits the fold boundaries as returned by {@link #calculateSplits(int, int)}
     * @param fold the zero-based index of the fold
     * @return a new set with the examples from the end of the previous fold
     *         (or 0 for the first fold) up to, but excluding, splits[fold]
     */
    public static Set<OWLIndividual> getTestingSet(List<OWLIndividual> examples, int[] splits, int fold) {
        // we either start from 0 or where the previous fold ended
        int fromIndex = (fold == 0) ? 0 : splits[fold - 1];
        // the split is the exclusive end of the fold, which is exactly
        // what subList expects as its second argument
        int toIndex = splits[fold];
        return new HashSet<OWLIndividual>(examples.subList(fromIndex, toIndex));
    }

    /**
     * Returns the training set belonging to a test set, i.e. all examples
     * which are not part of the test set.
     * 
     * @param examples all examples
     * @param testingSet the examples reserved for testing
     * @return the difference of both sets
     */
    public static Set<OWLIndividual> getTrainingSet(Set<OWLIndividual> examples, Set<OWLIndividual> testingSet) {
        return Helper.difference(examples, testingSet);
    }

    /**
     * Calculates where each fold ends.
     * 
     * @param nrOfExamples the number of examples to distribute
     * @param folds the number of folds
     * @return an array in which splits[i] is the exclusive end index of
     *         fold i in the example list (the last entry equals nrOfExamples)
     */
    public static int[] calculateSplits(int nrOfExamples, int folds) {
        int[] splits = new int[folds];
        for (int i = 1; i <= folds; i++) {
            // we always round up to the next integer; the product is computed
            // as double so that it cannot overflow the int range
            splits[i - 1] = (int) Math.ceil((double) i * nrOfExamples / folds);
        }
        return splits;
    }

    /**
     * Formats mean, standard deviation, minimum and maximum of a statistic.
     * 
     * @param df the number format to use
     * @param stat the statistic to print
     * @param unit the unit appended to every number (may be empty)
     * @return a one line summary of the statistic
     */
    public static String statOutput(DecimalFormat df, Stat stat, String unit) {
        StringBuilder str = new StringBuilder();
        str.append("av. ").append(df.format(stat.getMean())).append(unit);
        str.append(" (deviation ").append(df.format(stat.getStandardDeviation())).append(unit).append("; ");
        str.append("min ").append(df.format(stat.getMin())).append(unit).append("; ");
        str.append("max ").append(df.format(stat.getMax())).append(unit).append(")");
        return str.toString();
    }

    /**
     * @return the predictive accuracies on the test sets (in percent, one value per fold)
     */
    public Stat getAccuracy() {
        return accuracy;
    }

    /**
     * @return the statistic reserved for description lengths
     */
    public Stat getLength() {
        return length;
    }

    /**
     * @return the runtimes of the learning algorithm (in seconds, one value per fold)
     */
    public Stat getRuntime() {
        return runtime;
    }

    /**
     * Prints the given text to the console and, if {@link #writeToFile} is
     * set, also appends it to {@link #outputFile}.
     * 
     * @param output the text to write (a line break is added)
     */
    protected void outputWriter(String output) {
        if (writeToFile) {
            Files.appendToFile(outputFile, output + "\n");
        }
        System.out.println(output);
    }

    /**
     * @return the F-measures on the test sets (in percent, one value per fold)
     */
    public Stat getfMeasure() {
        return fMeasure;
    }

    /**
     * @return the F-measures on the training sets (in percent, one value per fold)
     */
    public Stat getfMeasureTraining() {
        return fMeasureTraining;
    }

}