tpt.dbweb.cat.evaluation.ReferenceEvaluator.java Source code

Java tutorial

Introduction

Here is the source code for tpt.dbweb.cat.evaluation.ReferenceEvaluator.java

Source

/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

package tpt.dbweb.cat.evaluation;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.beust.jcommander.Parameter;

import tpt.dbweb.cat.datatypes.Fraction;
import tpt.dbweb.cat.datatypes.TaggedText;
import tpt.dbweb.cat.io.ConllWriter;
import tpt.dbweb.cat.io.TaggedTextXMLReader;

/**
 * Calls lib/reference-coreference-scorers and parses the output.
 *
 * <p>The scorer script is started as an external process; its path is relative,
 * so it is resolved against the working directory of the JVM.
 *
 * @author Thomas Rebele
 *
 */
public class ReferenceEvaluator {

    private static final Logger log = LoggerFactory.getLogger(ReferenceEvaluator.class);

    /** Relative path of the scorer script that gets executed. */
    private static final String SCORER_SCRIPT = "lib/reference-coreference-scorers/scorer.pl";

    /**
     * Matches one result line of the scorer, i.e. a line starting with
     * {@code Recall:} followed by {@code Precision:} and {@code F1:}.
     * Compiled once and shared, so every instance can parse scorer output
     * regardless of the constructor that created it.
     */
    private static final Pattern RESULT_PATTERN = buildResultPattern();

    public static class Options {

        @Parameter(names = "--single-file", description = "put all articles in one file for reference-coreference-scorers")
        public boolean singleFile = true;

        @Parameter(names = "--remove-tmp-files", description = "temporarily generated files will get removed after execution")
        public boolean removeTmpFiles = true;

    }

    private final Options options;

    /**
     * Creates an evaluator with default {@link Options}.
     */
    public ReferenceEvaluator() {
        this(new Options());
    }

    /**
     * Creates an evaluator with the given options.
     * @param refEvalOptions options to use, must not be null
     * @throws NullPointerException if {@code refEvalOptions} is null
     */
    public ReferenceEvaluator(Options refEvalOptions) {
        this.options = Objects.requireNonNull(refEvalOptions, "refEvalOptions");
    }

    /**
     * Builds the regex for a scorer result line. The named groups are
     * {@code RecallNom}, {@code RecallDenom}, {@code RecallPct} and the same
     * three for {@code Precision} and {@code F1}. The "(nom / denom)" part is
     * optional for every measure, so the corresponding groups may be null.
     */
    private static Pattern buildResultPattern() {
        String floatingPointRegex = "\\d*\\.?\\d+";
        String nomRegex = "\\((?<RecallNom>" + floatingPointRegex + ")\\s*";
        String denomRegex = "\\s*(?<RecallDenom>" + floatingPointRegex + ")\\)";
        String percentRegex = "\\s*(?<RecallPct>" + floatingPointRegex + ")\\%\\s*";
        String recallPattern = "Recall:\\s*(?:" + nomRegex + "/" + denomRegex + ")?" + percentRegex;
        // the precision and F1 parts have the same shape, only the label / group names differ
        String precisionPattern = recallPattern.replace("Recall", "Precision");
        String f1Pattern = recallPattern.replace("Recall", "F1");
        return Pattern.compile(recallPattern + precisionPattern + f1Pattern + ".*");
    }

    /**
     * Best-effort deletion of a temporary file; failures are only logged.
     * @param path file to delete
     */
    private void silentDelete(Path path) {
        try {
            Files.deleteIfExists(path);
        } catch (IOException | RuntimeException e) {
            // deliberately not propagated: leftover temporary files are not fatal
            log.debug("could not delete temporary file {}", path, e);
        }
    }

    /**
     * Deletes the given temporary files if the options ask for it.
     */
    private void removeTmpFiles(Path... paths) {
        if (!options.removeTmpFiles) {
            return;
        }
        for (Path path : paths) {
            silentDelete(path);
        }
    }

    /**
     * Evaluates metrics for two tagged text XML files.
     * @param goldstandard path to tagged text XML
     * @param compare path to tagged text XML
     * @param tmpDirectory where to store generated files
     * @return the parsed scorer result
     * @throws IOException
     */
    public ComparisonResult compareXMLFiles(Path goldstandard, Path compare, Path tmpDirectory) throws IOException {
        TaggedTextXMLReader.Options readerOptions = new TaggedTextXMLReader.Options();
        TaggedTextXMLReader reader = new TaggedTextXMLReader(readerOptions);
        List<TaggedText> goldstd = reader.getTaggedText(goldstandard);
        List<TaggedText> cmp = reader.getTaggedText(compare);
        return compare(goldstd, goldstandard.getFileName().toString(), cmp, compare.getFileName().toString(),
                tmpDirectory);
    }

    /**
     * Evaluates metrics for list of tagged texts, using the default temporary
     * file names "goldstd" and "compare".
     * @param goldstandard list of tagged texts
     * @param compare list of tagged texts
     * @param tmpDirectory where to store generated files
     * @return the parsed scorer result
     * @throws IOException
     */
    public ComparisonResult compare(List<TaggedText> goldstandard, List<TaggedText> compare, Path tmpDirectory)
            throws IOException {
        return compare(goldstandard, "goldstd", compare, "compare", tmpDirectory);
    }

    /**
     * Evaluates metrics for list of tagged texts.
     *
     * <p>With {@link Options#singleFile} all texts are written to one file per
     * side and scored in one scorer run. Otherwise the texts are paired by
     * index (surplus texts of the longer list are ignored), every pair is
     * scored on its own in a parallel stream and the results are merged.
     *
     * @param goldstandard list of tagged texts
     * @param goldstandardFilename temporary filename for goldstandard .conll file
     * @param compare list of tagged texts
     * @param compareFilename temporary filename for compare .conll file
     * @param tmpDirectory where to store generated files
     * @return the parsed scorer result
     * @throws IOException
     */
    public ComparisonResult compare(List<TaggedText> goldstandard, String goldstandardFilename,
            List<TaggedText> compare, String compareFilename, Path tmpDirectory) throws IOException {
        String scorerOutput = goldstandardFilename + "-" + compareFilename + "-scorer-output";
        ConllWriter conll = new ConllWriter();
        if (this.options.singleFile) {
            log.info("using only one thread, try to use the split file option to speed things up");
            Path goldstdConllFile = tmpDirectory.resolve(goldstandardFilename + ".conll");
            conll.writeTTList(goldstandard, goldstdConllFile);

            Path compareConllFile = tmpDirectory.resolve(compareFilename + ".conll");
            conll.writeTTList(compare, compareConllFile);

            Path scorerOutputFile = tmpDirectory.resolve(scorerOutput + ".txt");
            ComparisonResult result = compareConllFiles(goldstdConllFile, compareConllFile, scorerOutputFile);
            removeTmpFiles(goldstdConllFile, compareConllFile, scorerOutputFile);
            return result;
        }

        int cnt = Math.min(goldstandard.size(), compare.size());
        List<ComparisonResult> lst = IntStream.range(0, cnt).parallel().mapToObj((i) -> {
            String id = goldstandard.get(i).id;
            Path goldstdConllFile = tmpDirectory.resolve(goldstandardFilename + "-" + id);
            conll.writeTT(goldstandard.get(i), goldstdConllFile);

            Path compareConllFile = tmpDirectory.resolve(compareFilename + "-" + id);
            conll.writeTT(compare.get(i), compareConllFile);

            // the thread name keeps concurrently running tasks from sharing an output file
            Path scorerOutputFile = tmpDirectory
                    .resolve(scorerOutput + Thread.currentThread().getName() + ".txt");
            ComparisonResult singleResult = compareConllFiles(goldstdConllFile, compareConllFile,
                    scorerOutputFile);
            removeTmpFiles(goldstdConllFile, compareConllFile, scorerOutputFile);
            return singleResult;
        }).collect(Collectors.toList());

        ComparisonResult result = new ComparisonResult();
        lst.forEach(cr -> result.merge(cr));
        return result;
    }

    /**
     * Compare two files in Conll format (!)
     * @param goldstandard
     * @param compare
     * @param scorerOutput save output of scorer to this file if set
     * @return the parsed scorer result
     * @throws UncheckedIOException if the scorer process cannot be started
     */
    public ComparisonResult compareConllFiles(Path goldstandard, Path compare, Path scorerOutput) {
        log.debug("comparing {} {}", goldstandard, compare);
        // arguments are passed separately, so paths containing whitespace stay intact
        String str = execExternalCommand(SCORER_SCRIPT, "all", goldstandard.toString(), compare.toString());
        log.trace("reference-coreference-scorers output: {}", str);
        if (scorerOutput != null) {
            try {
                FileUtils.writeStringToFile(scorerOutput.toFile(), str);
                if (!options.removeTmpFiles) {
                    log.debug("wrote scorer output to {}", scorerOutput);
                }
            } catch (IOException e) {
                // saving the raw output is optional, the result can still be parsed
                log.error("could not write scorer output to {}", scorerOutput, e);
            }
        }

        return parseScorerOutput(str);
    }

    /**
     * Compare two Conll directories by concatenating them into one file.
     * Please ensure that the two directories have the same structure.
     * Files must end with 'conll'. Files are paired by file name; compare
     * files without a goldstandard file of the same name are skipped.
     * @param goldstandardDir
     * @param compareDir
     * @param tmpDirectory where the concatenated files and the scorer output are stored
     * @return the parsed scorer result, or null if reading or writing files failed
     */
    public ComparisonResult compareConllDirectorysAsSingleFile(Path goldstandardDir, Path compareDir,
            Path tmpDirectory) {
        try {
            Map<String, Path> goldstd = new HashMap<>();

            // Files.walk keeps directories open until the stream is closed
            List<Path> cmp;
            try (Stream<Path> paths = Files.walk(compareDir)) {
                cmp = paths.filter(p -> p.toString().endsWith("conll")).collect(Collectors.toList());
            }
            try (Stream<Path> paths = Files.walk(goldstandardDir)) {
                paths.filter(p -> p.toString().endsWith("conll"))
                        .forEach(p -> goldstd.put(p.getFileName().toString(), p));
            }

            Path cmpFileGoldstd = tmpDirectory.resolve(goldstandardDir.toString().replace("/", "-"));
            Path cmpFileCmp = tmpDirectory.resolve(compareDir.toString().replace("/", "-"));

            // the files below are written in append mode, so leftovers of an
            // earlier run would end up in this evaluation as well
            deleteStaleFile(cmpFileGoldstd);
            deleteStaleFile(cmpFileCmp);

            for (Path cmpPath : cmp) {
                // try to find the goldstandard file with the same name
                Path goldstdPath = goldstd.get(cmpPath.getFileName().toString());
                if (goldstdPath == null) {
                    log.warn("warning, goldstandard file not found for compare file: {}", cmpPath);
                    continue;
                }

                FileUtils.writeStringToFile(cmpFileGoldstd.toFile(),
                        FileUtils.readFileToString(goldstdPath.toFile()), true);
                FileUtils.writeStringToFile(cmpFileCmp.toFile(), FileUtils.readFileToString(cmpPath.toFile()),
                        true);
            }

            String scorerOutput = cmpFileGoldstd.getFileName() + "-" + cmpFileCmp.getFileName() + "-scorer-output";
            return compareConllFiles(cmpFileGoldstd, cmpFileCmp, tmpDirectory.resolve(scorerOutput + ".txt"));

        } catch (IOException e) {
            log.error("could not compare directories {} and {}", goldstandardDir, compareDir, e);
        }
        return null;
    }

    /**
     * Removes a file left over from an earlier run. Only regular files are
     * touched, so a directory that happens to have the same path is never deleted.
     */
    private static void deleteStaleFile(Path path) throws IOException {
        if (Files.isRegularFile(path)) {
            Files.delete(path);
        }
    }

    /**
     * Call external command and wait until it has finished.
     * The standard error stream of the command is forwarded to the one of
     * this process, so that the command cannot block on a full pipe.
     * @param cmd program and its arguments, one element per argument
     * @return its std output as a string (decoded with the platform default charset)
     * @throws UncheckedIOException if the command cannot be started
     */
    private String execExternalCommand(String... cmd) {
        String cmdLine = String.join(" ", cmd);
        log.debug("executing command: {}", cmdLine);
        ProcessBuilder builder = new ProcessBuilder(cmd);
        builder.redirectError(ProcessBuilder.Redirect.INHERIT);
        Process p;
        try {
            p = builder.start();
        } catch (IOException e) {
            throw new UncheckedIOException("could not execute command: " + cmdLine, e);
        }

        StringWriter sw = new StringWriter();
        try (InputStream stdout = p.getInputStream();
                BufferedReader br = new BufferedReader(new InputStreamReader(stdout))) {
            // read until end of stream, i.e. until the command has closed its output;
            // this also collects the output of a command that has already exited
            char[] buffer = new char[1024 * 4];
            int n;
            while (-1 != (n = br.read(buffer))) {
                sw.write(buffer, 0, n);
            }
            int exitCode = p.waitFor();
            if (exitCode != 0) {
                log.warn("command exited with code {}: {}", exitCode, cmdLine);
            }
        } catch (IOException e) {
            log.error("could not read output of command: {}", cmdLine, e);
            // nobody reads the output any more, do not leave the process behind
            p.destroy();
        } catch (InterruptedException e) {
            p.destroy();
            Thread.currentThread().interrupt();
        }
        return sw.toString();
    }

    /**
     * Get documents, metrics and their respective recall and precision values.
     *
     * <p>Result lines are assigned to the metric of the last "METRIC " line and
     * to the document of the last "====&gt;" line. Result lines without a
     * current metric or document (e.g. after the totals marker) are logged and
     * skipped, as are result lines that only contain percentages.
     *
     * @param output std output of the scorer; null is treated like empty output
     * @return the parsed values, never null; empty if no result line was found
     */
    public ComparisonResult parseScorerOutput(String output) {
        ComparisonResult result = new ComparisonResult();
        // accept both Unix and Windows line endings
        String[] lines = (output == null ? "" : output).split("\\r?\\n");

        // split drops trailing empty strings, so the array is empty for output
        // consisting of newlines only
        if (lines.length < 1) {
            return result;
        }
        if (!lines[0].startsWith("version: 8.01")) {
            log.warn("expected version 8.01, but got {}", lines[0]);
        }

        String metricPrefix = "METRIC ";
        String docidPrefix = "====>";
        String totalsPrefix = "====== TOTALS =======";

        // go through all lines
        String actMetric = null;
        String actDocid = null;
        for (int i = 0; i < lines.length; i++) {
            String line = lines[i];

            // track document and metric info
            if (line.startsWith(metricPrefix)) {
                actMetric = headerValue(line, metricPrefix.length());
            }
            if (line.startsWith(docidPrefix)) {
                // + 1 skips the character following the prefix
                actDocid = headerValue(line, docidPrefix.length() + 1);
            }
            if (line.startsWith(totalsPrefix)) {
                actDocid = null;
            }

            // parse line if possible
            Matcher m = RESULT_PATTERN.matcher(line);
            if (!m.matches()) {
                if (line.contains("Recall")) {
                    log.debug("pattern didn't match {}", line);
                }
                continue;
            }

            String recallNomStr = m.group("RecallNom");
            String precisionNomStr = m.group("PrecisionNom");
            if (recallNomStr == null || precisionNomStr == null) {
                // the "(nom / denom)" part is optional in the pattern; without it
                // there are no counts to build the fractions from
                log.debug("no counts in result line {}: {}", i, line);
                continue;
            }

            if (actMetric == null) {
                log.error("metric is null, line {}", i);
            } else if (actDocid == null) {
                log.error("docid is null, line {}", i);
            } else {
                double recallnom = Double.parseDouble(recallNomStr);
                double recalldenom = Double.parseDouble(m.group("RecallDenom"));
                double precisionnom = Double.parseDouble(precisionNomStr);
                double precisiondenom = Double.parseDouble(m.group("PrecisionDenom"));

                Fraction recall = new Fraction(recallnom, recalldenom);
                Fraction precision = new Fraction(precisionnom, precisiondenom);
                ValueEvaluationStatistics value = new ValueEvaluationStatistics(recall, precision);
                result.docidToMetricToResult.computeIfAbsent(actDocid, k -> new HashMap<>()).put(actMetric,
                        value);
            }
        }
        return result;
    }

    /**
     * Extracts the value of a header line: everything from {@code start} up to,
     * but excluding, the last character of the line.
     * @return the value, or null if the line is too short to contain one
     */
    private static String headerValue(String line, int start) {
        int end = line.length() - 1;
        return start <= end ? line.substring(start, end) : null;
    }

    /**
     * Example run on fixed ACE 2004 paths; prints the parsed result.
     */
    public static void main(String[] args) throws IOException {
        ReferenceEvaluator eval = new ReferenceEvaluator();

        Path goldstd = Paths.get("test-data/ace2004/roth-dev/dev/ace2004-aida2014-with-nme.xml");
        Path compare = Paths.get("result/ace2004/roth-dev/dev/tagged-by-aida.xml");
        Path tmpDir = Paths.get("result/ace2004/roth-dev/dev/conll-format");
        ComparisonResult r = eval.compareXMLFiles(goldstd, compare, tmpDir);

        System.out.println(r.docidToMetricToResult);
    }
}