de.tudarmstadt.lt.nlkg.EvaluateArgs.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.lt.nlkg.EvaluateArgs.java

Source

/*
 *   Copyright 2012
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 */
package de.tudarmstadt.lt.nlkg;

import java.io.FileNotFoundException;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.collections.Transformer;
import org.apache.commons.io.LineIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import de.tudarmstadt.lt.nlkg.DT.Word;

/**
 * Command line evaluation of a {@link DT}-based entailment predictor.
 *
 * <p>
 * Reads a tab-separated file of argument pairs with a gold label, asks the predictor for each pair
 * and prints the confusion matrix together with precision, recall and F1 to standard out.
 *
 * @author Steffen Remus
 **/
public class EvaluateArgs {

    /** Input file that is evaluated when no path is given on the command line. */
    private static final String DEFAULT_FILE = "data/updates/6/args_unique_v4.tsv";

    /**
     * Runs the evaluation.
     *
     * @param args
     *            optional; if present, {@code args[0]} is the path of the file to evaluate, otherwise
     *            {@link #DEFAULT_FILE} is used
     * @throws IllegalArgumentException
     *             if the input file contains a malformed line
     * @throws FileNotFoundException
     *             if the input file cannot be opened
     */
    public static void main(String[] args) throws IllegalArgumentException, FileNotFoundException {
        String file = (args != null && args.length > 0) ? args[0] : DEFAULT_FILE;
        evaluate(file);
    }

    /** Not filled by {@link #evaluate(String)}: the evaluated file format has no context column. */
    static List<String> _CONTEXT = new ArrayList<String>();
    /** Left-hand arguments, one per evaluated line, in file order. */
    static List<String> _ARG_L = new ArrayList<String>();
    /** Right-hand arguments, one per evaluated line, in file order. */
    static List<String> _ARG_R = new ArrayList<String>();
    /** Gold labels, one per evaluated line, in file order. */
    static List<Boolean> _ENTAILING = new ArrayList<Boolean>();

    /**
     * Evaluates {@link #predictEntailing(DT, String, String)} against the gold labels in
     * {@code file}.
     *
     * <p>
     * Expected format: one header line (skipped), then one pair per line with the tab-separated
     * columns {@code ARG_L}, {@code ARG_R}, {@code ENTAILING}. Blank lines are ignored. Every parsed
     * line is appended to {@link #_ARG_L}, {@link #_ARG_R} and {@link #_ENTAILING}; each decision is
     * reported through {@code Evaluate.log_true} / {@code Evaluate.log_false}.
     *
     * <p>
     * Precision, recall and F1 are printed as {@code NaN} when their denominator is zero.
     *
     * @param file
     *            path of the tab-separated input file
     * @throws IllegalArgumentException
     *             if a non-blank line has fewer than three columns
     * @throws FileNotFoundException
     *             if {@code file} cannot be opened
     */
    static void evaluate(String file) throws IllegalArgumentException, FileNotFoundException {
        int tp = 0, tn = 0, fp = 0, fn = 0;

        // Open the input before creating the DT so that a missing file fails early.
        // NOTE(review): FileReader decodes with the platform default charset.
        LineIterator iter = new LineIterator(new FileReader(file));
        try {
            DT dt = new DT() {
                {
                    _mysql_host = "localhost";
                    _mysql_dbname = "nlkg_1";
                }
            };
            try {
                if (iter.hasNext())
                    iter.nextLine(); // skip header line; an empty file simply yields no results
                int lineno = 1;
                while (iter.hasNext()) {
                    lineno++;
                    String line = iter.nextLine();
                    if (line.trim().isEmpty())
                        continue;

                    String[] splits = line.split("\t");
                    if (splits.length < 3)
                        throw new IllegalArgumentException(String.format(
                                "%s:%d: expected at least 3 tab-separated columns but found %d", file, lineno,
                                splits.length));
                    String arg_l = splits[0].trim();
                    String arg_r = splits[1].trim();
                    boolean entailing_trueclass = Boolean.parseBoolean(splits[2].trim());

                    _ARG_L.add(arg_l);
                    _ARG_R.add(arg_r);
                    _ENTAILING.add(entailing_trueclass);

                    boolean entailing_predicted = predictEntailing(dt, arg_l, arg_r);
                    if (lineno % 100 == 0)
                        Evaluate.log_progress();

                    // Exactly one of the four confusion matrix cells applies to each pair.
                    String cell;
                    if (entailing_predicted) {
                        if (entailing_trueclass) {
                            cell = "tp";
                            tp++;
                        } else {
                            cell = "fp";
                            fp++;
                        }
                    } else {
                        if (entailing_trueclass) {
                            cell = "fn";
                            fn++;
                        } else {
                            cell = "tn";
                            tn++;
                        }
                    }

                    String message = String.format("%d %-10s %-30s %-30s %b %n", lineno, cell, arg_l, arg_r,
                            entailing_trueclass);
                    if (entailing_predicted == entailing_trueclass)
                        Evaluate.log_true(message);
                    else
                        Evaluate.log_false(message);
                }
            } finally {
                // Release the DT even if a line could not be processed.
                dt.disconnect();
            }
        } finally {
            // Closes the underlying FileReader as well.
            LineIterator.closeQuietly(iter);
        }

        System.out.format("%ntp: %d; fp: %d; fn: %d; tn: %d; %n", tp, fp, fn, tn);
        System.out.println("Precision = " + ((double) tp / (tp + fp)));
        System.out.println("Recall    = " + ((double) tp / (tp + fn)));
        System.out.println("F1        = " + ((2d * tp) / ((2d * tp) + fn + fp)));
    }

    /**
     * Predicts whether the pair is entailing. Currently delegates to
     * {@link #predictEntailingContainedInNTopEntries(DT, String, String)}.
     *
     * @param dt
     *            the {@link DT} to query
     * @param arg_l
     *            left-hand argument
     * @param arg_r
     *            right-hand argument
     * @return the prediction for the pair
     */
    static boolean predictEntailing(DT dt, String arg_l, String arg_r) {
        return predictEntailingContainedInNTopEntries(dt, arg_l, arg_r);
    }

    /** Second argument passed to {@code DT.get}; presumably the number of top entries to fetch. */
    static int ntop = 7;

    /**
     * Predicts entailment by membership: fetches the {@link DT} entry for {@code arg_l} (limited by
     * {@link #ntop}) and checks whether {@code arg_r} equals the {@code word} of one of its
     * {@code dtwords}.
     *
     * @param dt
     *            the {@link DT} to query
     * @param arg_l
     *            the word whose entry is looked up
     * @param arg_r
     *            the word searched for in that entry
     * @return {@code true} iff {@code arg_r} is among the words of the fetched entry
     */
    static boolean predictEntailingContainedInNTopEntries(DT dt, String arg_l, String arg_r) {
        DT.Entry e = dt.get(arg_l, ntop);
        // Map every Word of the entry to its plain string form.
        @SuppressWarnings("unchecked")
        Iterator<String> string_iter = IteratorUtils.transformedIterator(e.dtwords, new Transformer() {
            @Override
            public Object transform(Object input) {
                return ((Word) input).word;
            }
        });

        // toArray(..., String.class) creates a String[] at runtime, hence the cast is safe.
        Set<String> dtwords = new HashSet<String>(
                Arrays.asList((String[]) IteratorUtils.toArray(string_iter, String.class)));
        return dtwords.contains(arg_r);
    }

}