de.tudarmstadt.lt.nlkg.EvaluatePreds.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.lt.nlkg.EvaluatePreds.java

Source

/*
 *   Copyright 2012
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 */
package de.tudarmstadt.lt.nlkg;

import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.collections.Transformer;
import org.apache.commons.io.LineIterator;

import de.tudarmstadt.lt.nlkg.DT.Word;

/**
 * Evaluates a predicate-entailment heuristic against a labelled, tab-separated file.
 * <p>
 * The first line of the input is skipped as a header. Every other non-blank line must hold the
 * columns {@code X Y PRED_L PRED_R ENTAILING}. A pair is predicted as entailing when the
 * normalized {@code PRED_R} equals one of the words in the {@link DT.Entry} that
 * {@code DT.get} returns for the normalized {@code PRED_L}. The confusion counts and the
 * derived precision, recall and F1 are printed to standard output.
 *
 * @author Steffen Remus
 **/
public class EvaluatePreds {

    /** Input file evaluated by {@link #main(String[])} when no path argument is given. */
    static final String DEFAULT_INPUT_FILE = "data/updates/3/preds.tsv";

    /**
     * Second argument passed to {@code DT.get}. It is presumably the number of top-ranked
     * entries to retrieve; TODO confirm against DT.
     */
    static final int TOP_ENTRIES = 200;

    // Columns of every parsed data line, in file order. evaluate() appends to these and never
    // clears them, so repeated calls accumulate rows.
    static List<String> _X = new ArrayList<String>();
    static List<String> _Y = new ArrayList<String>();
    static List<String> _PRED_L = new ArrayList<String>();
    static List<String> _PRED_R = new ArrayList<String>();
    static List<Boolean> _ENTAILING = new ArrayList<Boolean>();

    /**
     * Runs {@link #evaluate(String)}, then prints the first and last parsed X value (if any)
     * and the number of parsed rows.
     *
     * @param args optional; {@code args[0]} overrides {@link #DEFAULT_INPUT_FILE}
     */
    public static void main(String[] args) throws IllegalArgumentException, FileNotFoundException {
        String file = args.length > 0 ? args[0] : DEFAULT_INPUT_FILE;
        evaluate(file);
        // An empty or header-only file yields no rows, and get(0) would throw.
        if (!_X.isEmpty()) {
            System.out.println(_X.get(0));
            System.out.println(_X.get(_X.size() - 1));
        }
        System.out.println(_X.size());
    }

    /**
     * Classifies every data line of {@code file}, logs each decision, and prints the confusion
     * counts together with precision, recall and F1.
     * <p>
     * Side effects:
     * <ul>
     * <li>Appends each parsed row to {@link #_X}, {@link #_Y}, {@link #_PRED_L},
     * {@link #_PRED_R} and {@link #_ENTAILING}.</li>
     * <li>Calls {@code Evaluate.log_true} for correct decisions (tp/tn) and
     * {@code Evaluate.log_false} for errors (fp/fn).</li>
     * <li>Calls {@code Evaluate.log_progress} on every line number divisible by 100.</li>
     * </ul>
     * The input file is closed and the {@link DT} instance disconnected even if processing fails.
     *
     * @param file path of the tab-separated input; its first line is a header
     * @throws FileNotFoundException if {@code file} cannot be opened for reading
     * @throws IllegalArgumentException if a non-blank data line has fewer than five columns
     */
    // Columns: X Y PRED_L PRED_R ENTAILING
    static void evaluate(String file) throws IllegalArgumentException, FileNotFoundException {
        // The instance initializer of this anonymous subclass runs after DT's constructor.
        DT dt = new DT() {
            {
                _mysql_dbname = "nlkg_1";
            }
        };
        try {
            FileReader reader = new FileReader(file);
            try {
                LineIterator iter = new LineIterator(reader);
                // Skip the header line. An empty file has none.
                if (iter.hasNext())
                    iter.nextLine();
                int lineno = 1; // 1-based physical line number of the line just read
                int tp = 0, tn = 0, fp = 0, fn = 0;
                while (iter.hasNext()) {
                    lineno++;
                    String line = iter.nextLine();
                    if (line.trim().isEmpty())
                        continue;

                    String[] splits = line.split("\t");
                    if (splits.length < 5)
                        throw new IllegalArgumentException(String.format(
                                "%s:%d: expected 5 tab-separated columns (X Y PRED_L PRED_R ENTAILING), found %d",
                                file, lineno, splits.length));
                    String x = splits[0].trim();
                    String y = splits[1].trim();
                    String pred_l = splits[2].trim();
                    String pred_r = splits[3].trim();
                    boolean entailing_trueclass = Boolean.parseBoolean(splits[4].trim());

                    _X.add(x);
                    _Y.add(y);
                    _PRED_L.add(pred_l);
                    _PRED_R.add(pred_r);
                    _ENTAILING.add(entailing_trueclass);

                    boolean entailing_predicted = predictEntailing(dt, pred_l, pred_r);
                    if (lineno % 100 == 0)
                        Evaluate.log_progress();

                    // Exactly one confusion-matrix cell applies to each line.
                    String label;
                    if (entailing_predicted) {
                        if (entailing_trueclass) {
                            label = "tp";
                            tp++;
                        } else {
                            label = "fp";
                            fp++;
                        }
                    } else {
                        if (entailing_trueclass) {
                            label = "fn";
                            fn++;
                        } else {
                            label = "tn";
                            tn++;
                        }
                    }
                    String record = String.format("%d %-10s %-30s %-30s %b %n", lineno, label, pred_l, pred_r,
                            entailing_trueclass);
                    if (entailing_predicted == entailing_trueclass)
                        Evaluate.log_true(record);
                    else
                        Evaluate.log_false(record);
                }
                printScores(tp, fp, fn, tn);
            } finally {
                closeQuietly(reader);
            }
        } finally {
            dt.disconnect();
        }
    }

    /**
     * Prints the confusion counts and the derived precision, recall and F1. A ratio whose
     * denominator is zero prints as NaN (floating-point 0/0).
     */
    private static void printScores(int tp, int fp, int fn, int tn) {
        System.out.format("tp: %d; fp: %d; fn: %d; tn: %d; %n", tp, fp, fn, tn);
        System.out.println("Precision = " + ((double) tp / (tp + fp)));
        System.out.println("Recall    = " + ((double) tp / (tp + fn)));
        System.out.println("F1        = " + ((2.0 * tp) / (2 * tp + fn + fp)));
    }

    /** Closes {@code reader}, ignoring failures: the file was only read, so no data can be lost. */
    private static void closeQuietly(FileReader reader) {
        try {
            reader.close();
        } catch (IOException ignored) {
            // Read-only input: a failed close cannot lose data.
        }
    }

    /**
     * Predicts whether {@code arg_l} entails {@code arg_r}. It currently delegates to
     * {@link #predictEntailingContainedInTopEntries}.
     */
    static boolean predictEntailing(DT dt, String arg_l, String arg_r) {
        return predictEntailingContainedInTopEntries(dt, arg_l, arg_r);
    }

    /**
     * Returns {@code true} if the normalized {@code pred_r} equals the {@code word} of any
     * entry that {@code dt.get} returns for the normalized {@code pred_l}.
     */
    static boolean predictEntailingContainedInTopEntries(DT dt, String pred_l, String pred_r) {
        String key_l = normalizePredicate(pred_l);
        String key_r = normalizePredicate(pred_r);

        DT.Entry e = dt.get(key_l, TOP_ENTRIES);
        for (Word w : e)
            if (key_r.equals(w.word))
                return true;
        return false;
    }

    /**
     * Converts a pattern such as {@code "X was born in Y"} into a lookup key such as
     * {@code "was_born_in"}.
     * <p>
     * The steps are:
     * <ol>
     * <li>Every 'X' and 'Y' character is replaced by U+0000.</li>
     * <li>{@link String#trim()} strips those characters at both ends, because it removes all
     * characters up to U+0020.</li>
     * <li>Remaining spaces become underscores.</li>
     * </ol>
     * NOTE(review): an 'X' or 'Y' that is not at either end (e.g. {@code "X and Y are"}, or the
     * capital in {@code "York"}) remains as an embedded U+0000. Confirm the DT key format before
     * changing this.
     */
    private static String normalizePredicate(String pred) {
        return pred.replace('X', '\0').replace('Y', '\0').trim().replace(' ', '_');
    }

}