org.dllearner.experiments.Examples.java Source code

Java tutorial

Introduction

Here is the source code for org.dllearner.experiments.Examples.java

Source

/**
 * Copyright (C) 2007-2011, Jens Lehmann
 *
 * This file is part of DL-Learner.
 *
 * DL-Learner is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * DL-Learner is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package org.dllearner.experiments;

import com.google.common.collect.Sets;
import org.apache.log4j.Logger;
import org.dllearner.utilities.URLencodeUTF8;

import java.io.FileWriter;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.Collection;
import java.util.SortedSet;
import java.util.TreeSet;

/**
 * A container for positive and negative examples, each split into a training
 * set and a test set, used for operations like randomization.
 * <p>
 * The {@code add...} methods keep an example in at most one of the four sets:
 * before an example is added to its target set it is removed from all four
 * sets, so the most recent add decides where the example lives.
 * 
 * @author Sebastian Hellmann <hellmann@informatik.uni-leipzig.de>
 * 
 */
public class Examples {
    private static final Logger logger = Logger.getLogger(Examples.class);

    /** Percent format with up to one fraction digit. */
    public static DecimalFormat df1 = new DecimalFormat("00.#%");
    /** Percent format with up to two fraction digits. */
    public static DecimalFormat df2 = new DecimalFormat("00.##%");
    /** Percent format with up to three fraction digits. */
    public static DecimalFormat df3 = new DecimalFormat("00.###%");
    /** Format used by {@link #toString()} for the training-share percentages. */
    private DecimalFormat myDf = df2;

    private final SortedSet<String> posTrain = new TreeSet<>();
    private final SortedSet<String> negTrain = new TreeSet<>();
    private final SortedSet<String> posTest = new TreeSet<>();
    private final SortedSet<String> negTest = new TreeSet<>();

    /**
     * Creates an empty container.
     */
    public Examples() {
    }

    /**
     * Creates a container holding only training examples.
     * 
     * @param posTrain positive training examples
     * @param negTrain negative training examples; an example also contained
     *                 in {@code posTrain} ends up in the negative set only
     */
    public Examples(SortedSet<String> posTrain, SortedSet<String> negTrain) {
        this.addPosTrain(posTrain);
        this.addNegTrain(negTrain);
    }

    /**
     * Creates a container from all four example sets. Duplicates across the
     * sets are removed automatically; the sets are added in the order
     * posTrain, posTest, negTrain, negTest and a later set wins.
     * 
     * @param posTrain positive training examples
     * @param negTrain negative training examples
     * @param posTest  positive test examples
     * @param negTest  negative test examples
     */
    public Examples(SortedSet<String> posTrain, SortedSet<String> negTrain, SortedSet<String> posTest,
            SortedSet<String> negTest) {
        this.addPosTrain(posTrain);
        this.addPosTest(posTest);
        this.addNegTrain(negTrain);
        this.addNegTest(negTest);
    }

    /**
     * Calculates precision on the test set. All training examples are removed
     * from a copy of {@code retrieved} first; the argument itself is not
     * modified.
     * 
     * @param retrieved the retrieved examples
     * @return the fraction of the remaining retrieved examples that are
     *         positive test examples, or {@code 0.0} if nothing remains
     */
    public double precision(SortedSet<String> retrieved) {
        if (retrieved.isEmpty()) {
            return 0.0d;
        }
        SortedSet<String> retrievedClean = new TreeSet<>(retrieved);
        retrievedClean.removeAll(posTrain);
        retrievedClean.removeAll(negTrain);
        // Everything retrieved was training data: avoid 0/0 = NaN.
        if (retrievedClean.isEmpty()) {
            return 0.0d;
        }

        int posAsPos = Sets.intersection(retrievedClean, getPosTest()).size();
        return posAsPos / (double) retrievedClean.size();
    }

    /**
     * Calculates recall on the test set.
     * 
     * @param retrieved the retrieved examples
     * @return the fraction of positive test examples contained in
     *         {@code retrieved}, or {@code 0.0} if there are no positive
     *         test examples
     */
    public double recall(SortedSet<String> retrieved) {
        // Guard on the actual denominator; otherwise a container with
        // positive training but no positive test examples yields NaN.
        if (posTest.isEmpty()) {
            return 0.0d;
        }
        int posAsPos = Sets.intersection(getPosTest(), retrieved).size();
        return posAsPos / (double) posTest.size();
    }

    /**
     * Removes the given examples from all four sets and logs a warning if
     * any of them was present.
     */
    private void removeFromAllSets(Collection<String> toBeRemoved) {
        // Every set must be cleaned, so the results are combined without
        // short-circuit evaluation.
        boolean overlap = posTrain.removeAll(toBeRemoved);
        overlap |= negTrain.removeAll(toBeRemoved);
        overlap |= posTest.removeAll(toBeRemoved);
        overlap |= negTest.removeAll(toBeRemoved);
        if (overlap) {
            logger.warn("There has been some overlap in the examples, but it was removed automatically");
        }
    }

    /**
     * Removes the given examples from all four sets and then adds them to
     * {@code target}.
     */
    private void moveInto(SortedSet<String> target, Collection<String> examples) {
        // Work on a snapshot: the caller may pass one of the live sets
        // returned by the getters, which the removal step would empty
        // before its content could be added.
        SortedSet<String> snapshot = new TreeSet<>(examples);
        removeFromAllSets(snapshot);
        target.addAll(snapshot);
    }

    /**
     * Adds positive training examples, removing them from any other set.
     * 
     * @param pos the examples to add
     */
    public void addPosTrain(Collection<String> pos) {
        moveInto(posTrain, pos);
    }

    /**
     * Adds positive test examples, removing them from any other set.
     * 
     * @param pos the examples to add
     */
    public void addPosTest(Collection<String> pos) {
        moveInto(posTest, pos);
    }

    /**
     * Adds negative training examples, removing them from any other set.
     * 
     * @param neg the examples to add
     */
    public void addNegTrain(Collection<String> neg) {
        moveInto(negTrain, neg);
    }

    /**
     * Adds negative test examples, removing them from any other set.
     * 
     * @param neg the examples to add
     */
    public void addNegTest(Collection<String> neg) {
        moveInto(negTest, neg);
    }

    /**
     * Adds one positive training example, removing it from any other set.
     * 
     * @param pos the example to add
     */
    public void addPosTrain(String pos) {
        moveInto(posTrain, Arrays.asList(pos));
    }

    /**
     * Adds one positive test example, removing it from any other set.
     * 
     * @param pos the example to add
     */
    public void addPosTest(String pos) {
        moveInto(posTest, Arrays.asList(pos));
    }

    /**
     * Adds one negative training example, removing it from any other set.
     * 
     * @param neg the example to add
     */
    public void addNegTrain(String neg) {
        moveInto(negTrain, Arrays.asList(neg));
    }

    /**
     * Adds one negative test example, removing it from any other set.
     * 
     * @param neg the example to add
     */
    public void addNegTest(String neg) {
        moveInto(negTest, Arrays.asList(neg));
    }

    /**
     * Checks that no example is both a positive and a negative training
     * example. Only the two training sets are compared; the first
     * overlapping example found is logged as an error.
     * 
     * @return {@code true} if the training sets are disjoint
     */
    public boolean checkConsistency() {
        for (String one : posTrain) {
            if (negTrain.contains(one)) {
                logger.error("positve and negative example overlap " + one);
                return false;
            }
        }
        return true;
    }

    /**
     * Returns {@code part / total}, or {@code 0.0} when {@code total} is
     * zero, so that empty sets do not produce NaN.
     */
    private static double ratio(int part, int total) {
        return total == 0 ? 0.0d : part / (double) total;
    }

    /**
     * Summarizes the set sizes: the total, then for positives and negatives
     * the training size, the test size and the training share in percent.
     */
    @Override
    public String toString() {
        double posPercent = ratio(posTrain.size(), sizeTotalOfPositives());
        double negPercent = ratio(negTrain.size(), sizeTotalOfNegatives());

        StringBuilder ret = new StringBuilder();
        ret.append("Total: ").append(size());
        ret.append("\nPositive: ").append(posTrain.size()).append(" | ").append(posTest.size())
                .append(" (").append(myDf.format(posPercent)).append(")");
        ret.append("\nNegative: ").append(negTrain.size()).append(" | ").append(negTest.size())
                .append(" (").append(myDf.format(negPercent)).append(")");
        return ret.toString();
    }

    /**
     * Appends one line per example: the sign, the quoted example and a
     * newline.
     */
    private static void appendQuoted(StringBuilder out, char sign, Collection<String> examples) {
        for (String one : examples) {
            out.append(sign).append('"').append(one).append("\"\n");
        }
    }

    /**
     * Lists every example, one per line, prefixed with {@code +} or
     * {@code -}, grouped into training and testing, followed by the summary
     * from {@link #toString()}.
     * 
     * @return the full listing
     */
    public String toFullString() {
        StringBuilder ret = new StringBuilder("Training:\n");
        appendQuoted(ret, '+', posTrain);
        appendQuoted(ret, '-', negTrain);
        ret.append("Testing:\n");
        appendQuoted(ret, '+', posTest);
        appendQuoted(ret, '-', negTest);
        return ret.append(this.toString()).toString();
    }

    /**
     * Writes the training examples to a file, overwriting it if it exists.
     * The file first gets one {@code import("...");} line per training
     * example (URL-encoded), then five newlines, then one {@code +"..."} or
     * {@code -"..."} line per training example. Test examples are not
     * written. Failures are logged and not propagated.
     * <p>
     * NOTE: {@link FileWriter} encodes with the platform default charset on
     * JDKs before 18.
     * 
     * @param filename path of the file to write
     */
    public void writeExamples(String filename) {
        // try-with-resources closes the writer even when a write fails.
        try (FileWriter writer = new FileWriter(filename, false)) {
            StringBuilder buffer = new StringBuilder();
            buffer.append("\n\n\n\n\n");
            for (String s : posTrain) {
                writer.write("import(\"" + URLencodeUTF8.encode(s) + "\");\n");
                buffer.append("+\"").append(s).append("\"\n");
            }
            for (String s : negTrain) {
                writer.write("import(\"" + URLencodeUTF8.encode(s) + "\");\n");
                buffer.append("-\"").append(s).append("\"\n");
            }

            writer.write(buffer.toString());
            writer.flush();
        } catch (Exception e) {
            // Best effort, as before, but reported through the logger.
            logger.error("could not write examples to " + filename, e);
            return;
        }
        logger.info("wrote examples to " + filename);
    }

    /**
     * @return the number of examples in all four sets (training and test)
     */
    public int size() {
        return posTrain.size() + negTrain.size() + posTest.size() + negTest.size();
    }

    /**
     * @return the number of positive examples (training and test)
     */
    public int sizeTotalOfPositives() {
        return posTrain.size() + posTest.size();
    }

    /**
     * @return the number of negative examples (training and test)
     */
    public int sizeTotalOfNegatives() {
        return negTrain.size() + negTest.size();
    }

    /**
     * @return the number of training examples (positive and negative)
     */
    public int sizeOfTrainingSets() {
        return posTrain.size() + negTrain.size();
    }

    /**
     * @return the number of test examples (positive and negative)
     */
    public int sizeOfTestSets() {
        return posTest.size() + negTest.size();
    }

    /**
     * @return a new set containing the examples of all four sets
     */
    public SortedSet<String> getAllExamples() {
        SortedSet<String> total = new TreeSet<>();
        total.addAll(getPositiveExamples());
        total.addAll(getNegativeExamples());
        return total;
    }

    /**
     * @return a new set containing all positive examples (training and test)
     */
    public SortedSet<String> getPositiveExamples() {
        SortedSet<String> total = new TreeSet<>(posTrain);
        total.addAll(posTest);
        return total;
    }

    /**
     * @return a new set containing all negative examples (training and test)
     */
    public SortedSet<String> getNegativeExamples() {
        SortedSet<String> total = new TreeSet<>(negTrain);
        total.addAll(negTest);
        return total;
    }

    /**
     * @return a new set containing all test examples (positive and negative)
     */
    public SortedSet<String> getTestExamples() {
        SortedSet<String> total = new TreeSet<>(posTest);
        total.addAll(negTest);
        return total;
    }

    /**
     * @return a new set containing all training examples (positive and
     *         negative)
     */
    public SortedSet<String> getTrainExamples() {
        SortedSet<String> total = new TreeSet<>(posTrain);
        total.addAll(negTrain);
        return total;
    }

    /**
     * @return the internal set of positive training examples (not a copy)
     */
    public SortedSet<String> getPosTrain() {
        return posTrain;
    }

    /**
     * @return the internal set of negative training examples (not a copy)
     */
    public SortedSet<String> getNegTrain() {
        return negTrain;
    }

    /**
     * @return the internal set of positive test examples (not a copy)
     */
    public SortedSet<String> getPosTest() {
        return posTest;
    }

    /**
     * @return the internal set of negative test examples (not a copy)
     */
    public SortedSet<String> getNegTest() {
        return negTest;
    }

}