de.tudarmstadt.ukp.experiments.argumentation.sequence.evaluation.GenerateCrossDomainCVReport.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.experiments.argumentation.sequence.evaluation.GenerateCrossDomainCVReport.java

Source

/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package de.tudarmstadt.ukp.experiments.argumentation.sequence.evaluation;

import de.tudarmstadt.ukp.dkpro.tc.svmhmm.util.ConfusionMatrix;
import de.tudarmstadt.ukp.dkpro.tc.svmhmm.util.SVMHMMUtils;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.csv.CSVRecord;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

/**
 * @author Ivan Habernal
 */
public class GenerateCrossDomainCVReport {
    /** File name of the token-level predictions CSV read from each task folder and written to the output. */
    private static final String TOKEN_LEVEL_PREDICTIONS_CSV = "tokenLevelPredictions.csv";

    /** File name of the plain-text confusion-matrix report. */
    private static final String CONFUSION_MATRIX_TXT = "confusionMatrix.txt";

    /**
     * Aggregates all sub-folders of {@code folder} whose name contains {@code SVMHMMTestTask}
     * into {@code folder/Evaluation_CrossDomain_Full_CV}.
     *
     * @param folder experiment folder containing the task sub-folders
     * @throws IOException if reading or writing any of the files fails
     */
    public static void generateCrossDomainReport(File folder) throws IOException {
        aggregateDomainResults(folder, "", "SVMHMMTestTask", "Evaluation_CrossDomain_Full_CV");
    }

    /**
     * Aggregates all sub-folders of {@code folder} whose name contains
     * {@code BatchTaskRandomCrossValidation$1ArgumentSequenceLabeling_InDomain_}
     * into {@code folder/Evaluation_InDomain_Full_CV}.
     *
     * @param folder experiment folder containing the task sub-folders
     * @throws IOException if reading or writing any of the files fails
     */
    public static void generateInDomainReport(File folder) throws IOException {
        aggregateDomainResults(folder, "", "BatchTaskRandomCrossValidation$1ArgumentSequenceLabeling_InDomain_",
                "Evaluation_InDomain_Full_CV");
    }

    /**
     * Merges the {@value #TOKEN_LEVEL_PREDICTIONS_CSV} files of all matching task sub-folders
     * into a single CSV and writes a confusion-matrix report over all merged records.
     * <p>
     * The merged CSV is written to {@code folder/outputFolderName/subDirPrefix/}.
     * The report ({@value #CONFUSION_MATRIX_TXT}) is written to {@code folder/outputFolderName/}.
     * Input and output files are read and written as UTF-8.
     *
     * @param folder            folder whose direct sub-directories are scanned
     * @param subDirPrefix      relative sub-path inside each task folder (and inside the output
     *                          folder) where the predictions CSV lives; may be empty
     * @param taskFolderSubText a sub-directory is included if its name contains this text
     * @param outputFolderName  name of the output folder created inside {@code folder}
     * @throws IllegalArgumentException if {@code folder} cannot be listed, no matching sub-folder
     *                                  exists, or a matching sub-folder lacks the predictions CSV
     * @throws IOException              if the output directory cannot be created or any I/O fails
     */
    public static void aggregateDomainResults(File folder, String subDirPrefix, final String taskFolderSubText,
            String outputFolderName) throws IOException {
        File[] folders = folder.listFiles(new FileFilter() {
            @Override
            public boolean accept(File pathname) {
                return pathname.isDirectory() && pathname.getName().contains(taskFolderSubText);
            }
        });

        // listFiles returns null (not an empty array) if folder is not a directory or cannot be read
        if (folders == null) {
            throw new IllegalArgumentException("Not a readable directory: " + folder);
        }
        if (folders.length == 0) {
            throw new IllegalArgumentException(
                    "No sub-folders containing '" + taskFolderSubText + "' found in " + folder);
        }

        // listFiles gives no ordering guarantee; sort so the merged CSV is reproducible
        Arrays.sort(folders);

        File outFolder = new File(folder, outputFolderName);
        File output = new File(outFolder, subDirPrefix);
        if (!output.isDirectory() && !output.mkdirs()) {
            throw new IOException("Cannot create output directory " + output);
        }

        File outCsv = new File(output, TOKEN_LEVEL_PREDICTIONS_CSV);
        ConfusionMatrix cm = new ConfusionMatrix();

        // the writer is declared separately so it is closed even if the CSVPrinter constructor fails
        try (Writer writer = new OutputStreamWriter(new FileOutputStream(outCsv), StandardCharsets.UTF_8);
                CSVPrinter csvPrinter = new CSVPrinter(writer, SVMHMMUtils.CSV_FORMAT)) {
            csvPrinter.printComment(SVMHMMUtils.CSV_COMMENT);

            for (File domain : folders) {
                File tokenLevelPredictionsCsv = new File(domain, subDirPrefix + "/" + TOKEN_LEVEL_PREDICTIONS_CSV);

                if (!tokenLevelPredictionsCsv.exists()) {
                    throw new IllegalArgumentException(
                            "Cannot locate tokenLevelPredictions.csv: " + tokenLevelPredictionsCsv);
                }

                // closed per domain so file handles do not accumulate across folders
                try (Reader reader = new InputStreamReader(new FileInputStream(tokenLevelPredictionsCsv),
                        StandardCharsets.UTF_8);
                        CSVParser csvParser = new CSVParser(reader, CSVFormat.DEFAULT.withCommentMarker('#'))) {
                    for (CSVRecord csvRecord : csvParser) {
                        // copy the record verbatim into the merged CSV
                        csvPrinter.printRecord(csvRecord);

                        // columns 0 and 1 are passed to the confusion matrix
                        // (presumably gold and predicted label — confirm against the CSV writer)
                        cm.increaseValue(csvRecord.get(0), csvRecord.get(1));
                    }
                }
            }
        }

        // NOTE(review): unlike the CSV, the report ignores subDirPrefix and goes to outFolder;
        // calls with different non-empty prefixes overwrite each other's report.
        FileUtils.writeStringToFile(new File(outFolder, CONFUSION_MATRIX_TXT),
                cm.toString() + "\n" + cm.printNiceResults() + "\n" + cm.printLabelPrecRecFm() + "\n"
                        + cm.printClassDistributionGold(),
                StandardCharsets.UTF_8.name());
    }

    /**
     * Generates cross-domain and in-domain reports for every sub-folder of the given folder.
     *
     * @param args {@code args[0]} is the main folder whose sub-folders hold experiment results
     * @throws Exception if report generation fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            throw new IllegalArgumentException("Usage: GenerateCrossDomainCVReport <mainFolder>");
        }

        File mainFolder = new File(args[0]);
        for (File dir : EvalHelper.listSubFolders(mainFolder)) {
            generateCrossDomainReport(dir);
            generateInDomainReport(dir);
        }
    }

}