de.tudarmstadt.ukp.experiments.argumentation.sequence.evaluation.ArgumentSequenceLabelingEvaluation.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.experiments.argumentation.sequence.evaluation.ArgumentSequenceLabelingEvaluation.java

Source

/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package de.tudarmstadt.ukp.experiments.argumentation.sequence.evaluation;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.DocumentDomain;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.DocumentRegister;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.adapter.SVMAdapterBatchTokenReport;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.evaluation.helpers.FeatureSetHelper;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.clustering.ArgumentSpaceFeatureExtractor;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.lda.LDATopicsFeature;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.lexical.LemmaLuceneNGramUFE;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.io.ArgumentSequenceSentenceLevelReader;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.report.TokenLevelEvaluationReport;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.report.TokenLevelMacroFMReport;
import de.tudarmstadt.ukp.dkpro.lab.Lab;
import de.tudarmstadt.ukp.dkpro.lab.task.Dimension;
import de.tudarmstadt.ukp.dkpro.lab.task.ParameterSpace;
import de.tudarmstadt.ukp.dkpro.lab.task.impl.BatchTask;
import de.tudarmstadt.ukp.dkpro.tc.core.Constants;
import de.tudarmstadt.ukp.dkpro.tc.fstore.simple.SparseFeatureStore;
import de.tudarmstadt.ukp.dkpro.tc.ml.ExperimentCrossValidation;
import de.tudarmstadt.ukp.dkpro.tc.ml.ExperimentTrainTest;
import de.tudarmstadt.ukp.dkpro.tc.svmhmm.task.SVMHMMTestTask;
import org.apache.commons.lang.StringUtils;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.fit.component.NoOpAnnotator;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.resource.ResourceInitializationException;

import java.io.File;
import java.text.SimpleDateFormat;
import java.util.*;

import static java.util.Arrays.asList;

/**
 * @author Ivan Habernal
 */
public class ArgumentSequenceLabelingEvaluation {
    private static final int NUM_FOLDS = 10;

    /**
     * Cluster (centroid) names used in the cross-domain scenario, keyed by the held-out test domain.
     * The "all-minus-xx" names suggest clusters computed without that domain — confirm against the
     * cluster resources if this matters.
     */
    private static final Map<DocumentDomain, String> CROSS_DOMAIN_CLUSTERS = createCrossDomainClusterMapping();

    @Parameter(names = { "--featureSet",
            "--fs" }, description = "Feature set name (e.g., fs0, fs0fs1fs2, fs3fs4, ...)", required = true)
    String featureSet;

    @Parameter(names = { "--corpusPath", "--c" }, description = "Corpus path with XMI files", required = true)
    String corpusPath;

    @Parameter(names = { "--outputPath", "--o" }, description = "Main output path (folder)", required = true)
    String outputPath;

    @Parameter(names = { "--paramE", "--e" }, description = "Parameter e for SVMHMM")
    private Integer paramE = 0;

    @Parameter(names = { "--paramT", "--t" }, description = "Parameter T for SVMHMM")
    private Integer paramT = 1;

    @Parameter(names = { "--scenario",
            "--s" }, description = "Evaluation scenario (cv = cross-validation, cd = cross domain,"
                    + " id = in domain, cr = cross register)", required = true)
    String scenario;

    @Parameter(names = { "--cl", "--clusters" }, description = "Which clusters? Comma-delimited, e.g., s100,a500")
    String clusters;

    /**
     * Command-line entry point: parses the arguments into a new instance and runs the evaluation.
     * If the arguments cannot be parsed (e.g. a required option is missing), the usage text is
     * printed and the parsing exception is rethrown.
     *
     * @param args command-line arguments (see the {@code @Parameter} fields)
     * @throws Exception if argument parsing or the evaluation itself fails
     */
    public static void main(String[] args) throws Exception {
        ArgumentSequenceLabelingEvaluation evaluation = new ArgumentSequenceLabelingEvaluation();
        JCommander jCommander = new JCommander(evaluation);
        try {
            // Parse inside the try. The (object, args) constructor parses eagerly, which let the
            // ParameterException escape before usage() could be printed.
            jCommander.parse(args);
        } catch (ParameterException e) {
            jCommander.usage();
            throw e;
        }

        evaluation.run();
    }

    /**
     * Runs the evaluation selected by {@code --scenario} ("cv", "id", "cd" or "cr").
     * Configures the UIMA logger and SVMHMM output, then creates a timestamped output folder
     * below {@code --outputPath}. That folder is exported as the {@code DKPRO_HOME} system
     * property.
     *
     * @throws IllegalArgumentException if the scenario is unknown
     * @throws IllegalStateException if the output folder cannot be created
     * @throws Exception if any experiment fails
     */
    public void run() throws Exception {
        System.setProperty("org.apache.uima.logger.class", "org.apache.uima.util.impl.Log4jLogger_impl");
        SVMHMMTestTask.PRINT_STD_OUT = true;

        File mainOutputFolder = new File(outputPath);
        // The timestamp (to the second) separates repeated runs that use identical settings.
        String date = new SimpleDateFormat("yyyy-MM-dd-HH-mm-ss").format(new Date());
        File outputFolder = new File(mainOutputFolder,
                scenario + "_" + featureSet + "_e" + paramE + "_t" + paramT + "_" + clusters + "_" + date);

        // Fail fast. Otherwise DKPRO_HOME would point at a missing folder and the experiment
        // would fail later with a less helpful error.
        if (!outputFolder.isDirectory() && !outputFolder.mkdirs()) {
            throw new IllegalStateException("Cannot create output folder " + outputFolder.getAbsolutePath());
        }

        System.setProperty("DKPRO_HOME", outputFolder.getAbsolutePath());

        if ("cv".equals(scenario)) {
            // cross validation over all documents
            runCrossValidation(getParameterSpace(null, false, clusters, null, null));

        } else if ("id".equals(scenario)) {
            // in-domain cross validation, separately for every domain
            for (DocumentDomain documentDomain : DocumentDomain.values()) {
                runInDomainCrossValidation(documentDomain,
                        getParameterSpace(documentDomain, true, clusters, null, null));
            }
        } else if ("cd".equals(scenario)) {
            // Cross-domain: each domain in turn is the test domain. The clusters come from the
            // per-domain mapping; --clusters is not used in this scenario.
            for (DocumentDomain documentDomain : DocumentDomain.values()) {
                String domainClusters = determineDomainClusters(documentDomain);
                runCrossDomain(documentDomain,
                        getParameterSpace(documentDomain, false, domainClusters, null, null));
            }
        } else if ("cr".equals(scenario)) {
            // cross register
            Set<DocumentRegister> commentsForums = new TreeSet<>();
            commentsForums.add(DocumentRegister.ARTCOMMENT);
            commentsForums.add(DocumentRegister.FORUMPOST);

            Set<DocumentRegister> blogsArticles = new TreeSet<>();
            blogsArticles.add(DocumentRegister.BLOGPOST);
            blogsArticles.add(DocumentRegister.ARTICLE);

            // one way
            runCrossRegister(getParameterSpace(null, false, clusters, commentsForums, blogsArticles),
                    "Train-commentsForums_Test-blogsArticles");

            // and the other
            runCrossRegister(getParameterSpace(null, false, clusters, blogsArticles, commentsForums),
                    "Train-blogsArticles_Test-commentsForums");
        } else {
            throw new IllegalArgumentException("Unknown 'scenario' argument: " + scenario);
        }
    }

    /**
     * Returns the cluster names used when {@code documentDomain} is the held-out test domain
     * in the cross-domain scenario.
     *
     * @param documentDomain the test domain
     * @return comma-delimited cluster names, or {@code null} if the domain has no mapping
     */
    private String determineDomainClusters(DocumentDomain documentDomain) {
        return CROSS_DOMAIN_CLUSTERS.get(documentDomain);
    }

    /** Builds the immutable test-domain → cluster-names mapping for the cross-domain scenario. */
    private static Map<DocumentDomain, String> createCrossDomainClusterMapping() {
        Map<DocumentDomain, String> mapping = new EnumMap<>(DocumentDomain.class);
        mapping.put(DocumentDomain.HOMESCHOOLING, "arg-all-minus-hs,sent-all-minus-hs");
        mapping.put(DocumentDomain.MAINSTREAMING, "arg-all-minus-ms,sent-all-minus-ms");
        mapping.put(DocumentDomain.PRAYER_IN_SCHOOLS, "arg-all-minus-pis,sent-all-minus-pis");
        mapping.put(DocumentDomain.PUBLIC_PRIVATE_SCHOOLS, "arg-all-minus-pps,sent-all-minus-pps");
        mapping.put(DocumentDomain.REDSHIRTING, "arg-all-minus-rs,sent-all-minus-rs");
        mapping.put(DocumentDomain.SINGLE_SEX_EDUCATION, "arg-all-minus-sse,sent-all-minus-sse");
        return Collections.unmodifiableMap(mapping);
    }

    /**
     * Builds the reader dimension (train reader and, for train/test setups, test reader) for the
     * given scenario. Every reader uses {@link ArgumentSequenceSentenceLevelReader} on all
     * {@code *.xmi} files under {@code corpusFilePathTrain}. The readers are then restricted by
     * domain or register.
     *
     * @param documentDomain      {@code null} to use all domains. Otherwise, the single domain for
     *                            in-domain CV, or the held-out test domain for cross-domain.
     * @param inDomain            whether a non-null domain means in-domain CV
     *                            ({@code true}) or cross-domain train/test ({@code false})
     * @param corpusFilePathTrain corpus location, used for both training and test readers
     * @param trainingRegister    registers to train on (cross-register only), or {@code null}
     * @param testRegister        registers to test on (cross-register only), or {@code null}
     * @return map of DKPro TC reader dimension keys to reader classes and parameter lists
     */
    public Map<String, Object> createDimReaders(DocumentDomain documentDomain, boolean inDomain,
            String corpusFilePathTrain, Set<DocumentRegister> trainingRegister,
            Set<DocumentRegister> testRegister) {
        Map<String, Object> result = new HashMap<>();

        // we take all documents regardless of domain
        if (documentDomain == null) {
            // normal CV
            if (trainingRegister == null && testRegister == null) {
                result.put(Constants.DIM_READER_TRAIN, ArgumentSequenceSentenceLevelReader.class);
                result.put(Constants.DIM_READER_TRAIN_PARAMS,
                        Arrays.asList(ArgumentSequenceSentenceLevelReader.PARAM_SOURCE_LOCATION,
                                corpusFilePathTrain, ArgumentSequenceSentenceLevelReader.PARAM_PATTERNS,
                                ArgumentSequenceSentenceLevelReader.INCLUDE_PREFIX + "*.xmi",
                                ArgumentSequenceSentenceLevelReader.PARAM_LENIENT, false));
            } else {
                // we have cross-register train-test
                // NOTE(review): unlike every other reader configuration here, the training reader
                // does not set PARAM_LENIENT=false. Confirm whether the reader's default is intended.
                result.put(Constants.DIM_READER_TRAIN, ArgumentSequenceSentenceLevelReader.class);
                result.put(Constants.DIM_READER_TRAIN_PARAMS,
                        Arrays.asList(ArgumentSequenceSentenceLevelReader.PARAM_SOURCE_LOCATION,
                                corpusFilePathTrain, ArgumentSequenceSentenceLevelReader.PARAM_PATTERNS,
                                ArgumentSequenceSentenceLevelReader.INCLUDE_PREFIX + "*.xmi",
                                ArgumentSequenceSentenceLevelReader.PARAM_DOCUMENT_REGISTER,
                                StringUtils.join(trainingRegister, " ")));
                result.put(Constants.DIM_READER_TEST, ArgumentSequenceSentenceLevelReader.class);
                result.put(Constants.DIM_READER_TEST_PARAMS,
                        Arrays.asList(ArgumentSequenceSentenceLevelReader.PARAM_SOURCE_LOCATION,
                                corpusFilePathTrain, ArgumentSequenceSentenceLevelReader.PARAM_PATTERNS,
                                ArgumentSequenceSentenceLevelReader.INCLUDE_PREFIX + "*.xmi",
                                ArgumentSequenceSentenceLevelReader.PARAM_LENIENT, false,
                                ArgumentSequenceSentenceLevelReader.PARAM_DOCUMENT_REGISTER,
                                StringUtils.join(testRegister, " ")));
            }
        } else {
            if (inDomain) {
                // in domain cross validation
                result.put(Constants.DIM_READER_TRAIN, ArgumentSequenceSentenceLevelReader.class);
                result.put(Constants.DIM_READER_TRAIN_PARAMS,
                        Arrays.asList(ArgumentSequenceSentenceLevelReader.PARAM_SOURCE_LOCATION,
                                corpusFilePathTrain, ArgumentSequenceSentenceLevelReader.PARAM_PATTERNS,
                                ArgumentSequenceSentenceLevelReader.INCLUDE_PREFIX + "*.xmi",
                                ArgumentSequenceSentenceLevelReader.PARAM_LENIENT, false,
                                ArgumentSequenceSentenceLevelReader.PARAM_DOCUMENT_DOMAIN,
                                documentDomain.toString()));
            } else {
                // Train on all domains except the held-out test domain. EnumSet iterates in
                // declaration order, so the joined parameter string is the same on every run.
                // A HashSet of enums uses identity hash codes, so its order may vary between JVM runs.
                Set<DocumentDomain> trainingDomains = EnumSet.complementOf(EnumSet.of(documentDomain));
                String trainingDomainsAsParam = StringUtils.join(trainingDomains, " ");

                // we have cross-domain train-test (param documentDomain is the test domain)
                result.put(Constants.DIM_READER_TRAIN, ArgumentSequenceSentenceLevelReader.class);
                result.put(Constants.DIM_READER_TRAIN_PARAMS,
                        Arrays.asList(ArgumentSequenceSentenceLevelReader.PARAM_SOURCE_LOCATION,
                                corpusFilePathTrain, ArgumentSequenceSentenceLevelReader.PARAM_PATTERNS,
                                ArgumentSequenceSentenceLevelReader.INCLUDE_PREFIX + "*.xmi",
                                ArgumentSequenceSentenceLevelReader.PARAM_LENIENT, false,
                                ArgumentSequenceSentenceLevelReader.PARAM_DOCUMENT_DOMAIN, trainingDomainsAsParam));

                // we have cross-domain train-test (param documentDomain is the test domain)
                result.put(Constants.DIM_READER_TEST, ArgumentSequenceSentenceLevelReader.class);
                result.put(Constants.DIM_READER_TEST_PARAMS,
                        Arrays.asList(ArgumentSequenceSentenceLevelReader.PARAM_SOURCE_LOCATION,
                                corpusFilePathTrain, ArgumentSequenceSentenceLevelReader.PARAM_PATTERNS,
                                ArgumentSequenceSentenceLevelReader.INCLUDE_PREFIX + "*.xmi",
                                ArgumentSequenceSentenceLevelReader.PARAM_LENIENT, false,
                                ArgumentSequenceSentenceLevelReader.PARAM_DOCUMENT_DOMAIN,
                                documentDomain.toString()));
            }
        }

        return result;
    }

    /**
     * Builds the DKPro Lab parameter space for one experiment. It contains the reader bundle,
     * single-label sequence mode, a sparse feature store, feature-extractor parameters, the
     * feature set from {@code --featureSet}, and the SVMHMM parameters. Those are e and T from
     * the command line, and C fixed at 5.0.
     *
     * @param documentDomain   see {@link #createDimReaders}
     * @param inDomain         see {@link #createDimReaders}
     * @param requiredClusters comma-delimited cluster names passed as
     *                         {@code ArgumentSpaceFeatureExtractor.PARAM_CENTROIDS};
     *                         {@code null} is passed as an empty string
     * @param trainingRegister see {@link #createDimReaders}
     * @param testRegister     see {@link #createDimReaders}
     * @return the parameter space
     */
    @SuppressWarnings("unchecked")
    public ParameterSpace getParameterSpace(DocumentDomain documentDomain, boolean inDomain,
            String requiredClusters, Set<DocumentRegister> trainingRegister, Set<DocumentRegister> testRegister) {
        // configure training and test data reader dimension
        Map<String, Object> dimReaders = createDimReaders(documentDomain, inDomain, corpusPath, trainingRegister,
                testRegister);

        Dimension<List<String>> dimFeatureSets = Dimension.create(Constants.DIM_FEATURE_SET,
                FeatureSetHelper.getFeatureSet(featureSet));

        // parameters to configure feature extractors
        Dimension<List<Object>> dimPipelineParameters = Dimension.create(Constants.DIM_PIPELINE_PARAMS,
                asList(new Object[] {
                        // top 10k lemma n-grams, n = 1..3
                        LemmaLuceneNGramUFE.PARAM_NGRAM_USE_TOP_K, "10000", LemmaLuceneNGramUFE.PARAM_NGRAM_MIN_N,
                        1, LemmaLuceneNGramUFE.PARAM_NGRAM_MAX_N, 3, LDATopicsFeature.PARAM_LDA_MODEL_FILE,
                        "classpath:/lda/mallet-lda-model-30-topics.bin",
                        ArgumentSpaceFeatureExtractor.PARAM_CENTROIDS,
                        requiredClusters != null ? requiredClusters : "" }));

        // order of dependencies of transitions in HMM (max 3), taken from --paramT
        Dimension<Integer> dimClassificationArgsT = Dimension.create(SVMHMMTestTask.PARAM_ORDER_T, paramT);

        // order of dependencies of emissions in HMM (max 1), taken from --paramE
        Dimension<Integer> dimClassificationArgsE = Dimension.create(SVMHMMTestTask.PARAM_ORDER_E, paramE);

        // C is fixed to a single value
        Dimension<Double> dimClassificationArgsC = Dimension.create(SVMHMMTestTask.PARAM_C, 5.0);

        return new ParameterSpace(Dimension.createBundle("readers", dimReaders),
                Dimension.create(Constants.DIM_LEARNING_MODE, Constants.LM_SINGLE_LABEL),
                Dimension.create(Constants.DIM_FEATURE_MODE, Constants.FM_SEQUENCE),
                Dimension.create(Constants.DIM_FEATURE_STORE, SparseFeatureStore.class.getName()),
                dimPipelineParameters, dimFeatureSets, dimClassificationArgsE, dimClassificationArgsT,
                dimClassificationArgsC);
    }

    /**
     * Runs a {@code NUM_FOLDS}-fold cross-validation over the given parameter space. Unlike the
     * other scenarios, only the {@link TokenLevelEvaluationReport} is attached as an inner report.
     *
     * @param pSpace parameter space, typically from {@link #getParameterSpace}
     * @throws Exception if the experiment fails
     */
    public void runCrossValidation(ParameterSpace pSpace) throws Exception {
        ExperimentCrossValidation batch = new ExperimentCrossValidation("ArgumentSequenceLabelingCV",
                SVMAdapterBatchTokenReport.class, getPreprocessing(), NUM_FOLDS);
        batch.setParameterSpace(pSpace);
        batch.addInnerReport(TokenLevelEvaluationReport.class);
        batch.setExecutionPolicy(BatchTask.ExecutionPolicy.RUN_AGAIN);

        // Run
        Lab.getInstance().run(batch);
    }

    /**
     * Returns the preprocessing pipeline applied before feature extraction. Here it is a no-op
     * annotator; subclasses may override it.
     *
     * @return the preprocessing engine description
     * @throws ResourceInitializationException if the description cannot be created
     */
    protected AnalysisEngineDescription getPreprocessing() throws ResourceInitializationException {
        return AnalysisEngineFactory.createEngineDescription(NoOpAnnotator.class);
    }

    /**
     * Runs a {@code NUM_FOLDS}-fold cross-validation restricted to one domain.
     *
     * @param documentDomain domain whose name is used in the experiment name
     * @param pSpace         parameter space whose readers are restricted to that domain
     * @throws Exception if the experiment fails
     */
    public void runInDomainCrossValidation(DocumentDomain documentDomain, ParameterSpace pSpace) throws Exception {
        ExperimentCrossValidation batch = new ExperimentCrossValidation(
                "ArgumentSequenceLabeling_InDomain_" + documentDomain.toString(), SVMAdapterBatchTokenReport.class,
                getPreprocessing(), NUM_FOLDS);
        batch.setParameterSpace(pSpace);
        batch.addInnerReport(TokenLevelEvaluationReport.class);
        batch.addInnerReport(TokenLevelMacroFMReport.class);
        batch.setExecutionPolicy(BatchTask.ExecutionPolicy.RUN_AGAIN);

        // Run
        Lab.getInstance().run(batch);
    }

    /**
     * Runs a train/test experiment: training on all other domains, testing on
     * {@code documentDomain}.
     *
     * @param documentDomain held-out test domain, used in the experiment name
     * @param pSpace         parameter space with matching train and test readers
     * @throws Exception if the experiment fails
     */
    public void runCrossDomain(DocumentDomain documentDomain, ParameterSpace pSpace) throws Exception {
        ExperimentTrainTest batch = new ExperimentTrainTest(
                "ArgumentSequenceLabeling_CrossDomain_" + documentDomain.toString(),
                SVMAdapterBatchTokenReport.class, getPreprocessing());
        batch.setParameterSpace(pSpace);
        batch.addInnerReport(TokenLevelEvaluationReport.class);
        batch.addInnerReport(TokenLevelMacroFMReport.class);
        batch.setExecutionPolicy(BatchTask.ExecutionPolicy.RUN_AGAIN);

        // Run
        Lab.getInstance().run(batch);
    }

    /**
     * Runs a train/test experiment across document registers.
     *
     * @param pSpace parameter space with register-restricted train and test readers
     * @param name   suffix for the experiment name describing the train/test direction
     * @throws Exception if the experiment fails
     */
    public void runCrossRegister(ParameterSpace pSpace, String name) throws Exception {
        ExperimentTrainTest batch = new ExperimentTrainTest("ArgumentSequenceLabeling_CrossRegister_" + name,
                SVMAdapterBatchTokenReport.class, getPreprocessing());
        batch.setParameterSpace(pSpace);
        batch.addInnerReport(TokenLevelEvaluationReport.class);
        batch.addInnerReport(TokenLevelMacroFMReport.class);
        batch.setExecutionPolicy(BatchTask.ExecutionPolicy.RUN_AGAIN);

        // Run
        Lab.getInstance().run(batch);
    }

}