// Java tutorial
/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package de.tudarmstadt.ukp.experiments.argumentation.sequence.evaluation;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import de.tudarmstadt.ukp.dkpro.lab.Lab;
import de.tudarmstadt.ukp.dkpro.lab.task.Dimension;
import de.tudarmstadt.ukp.dkpro.lab.task.ParameterSpace;
import de.tudarmstadt.ukp.dkpro.lab.task.impl.BatchTask;
import de.tudarmstadt.ukp.dkpro.tc.core.Constants;
import de.tudarmstadt.ukp.dkpro.tc.fstore.simple.SparseFeatureStore;
import de.tudarmstadt.ukp.dkpro.tc.ml.ExperimentCrossValidation;
import de.tudarmstadt.ukp.dkpro.tc.ml.ExperimentTrainTest;
import de.tudarmstadt.ukp.dkpro.tc.svmhmm.task.SVMHMMTestTask;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.DocumentDomain;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.DocumentRegister;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.adapter.SVMAdapterBatchTokenReport;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.evaluation.helpers.FeatureSetHelper;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.clustering.ArgumentSpaceFeatureExtractor;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.lda.LDATopicsFeature;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.lexical.LemmaLuceneNGramUFE;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.io.ArgumentSequenceSentenceLevelReader;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.report.TokenLevelEvaluationReport;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.report.TokenLevelMacroFMReport;
import org.apache.commons.lang.StringUtils;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.fit.component.NoOpAnnotator;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.resource.ResourceInitializationException;

import java.io.File;
import java.io.IOException;
import java.text.SimpleDateFormat;
import java.util.*;

import static java.util.Arrays.asList;

/**
 * Command-line entry point for the argument sequence labeling experiments.
 * <p>
 * Supports four evaluation scenarios selected by {@code --scenario}:
 * cross-validation ({@code cv}), in-domain cross-validation ({@code id}),
 * cross-domain train/test ({@code cd}), and cross-register train/test
 * ({@code cr}). Experiments run on the DKPro Lab / DKPro TC framework with
 * an SVM-HMM sequence classifier; results are written under a timestamped
 * folder below {@code --outputPath} (used as {@code DKPRO_HOME}).
 *
 * @author Ivan Habernal
 */
public class ArgumentSequenceLabelingEvaluation
{
    /** Number of folds used by all cross-validation scenarios. */
    private static final int NUM_FOLDS = 10;

    @Parameter(names = { "--featureSet", "--fs" },
            description = "Feature set name (e.g., fs0, fs0fs1fs2, fs3fs4, ...)", required = true)
    String featureSet;

    @Parameter(names = { "--corpusPath", "--c" },
            description = "Corpus path with XMI files", required = true)
    String corpusPath;

    @Parameter(names = { "--outputPath", "--o" },
            description = "Main output path (folder)", required = true)
    String outputPath;

    @Parameter(names = { "--paramE", "--e" }, description = "Parameter e for SVMHMM")
    private Integer paramE = 0;

    @Parameter(names = { "--paramT", "--t" }, description = "Parameter T for SVMHMM")
    private Integer paramT = 1;

    @Parameter(names = { "--scenario", "--s" },
            description = "Evaluation scenario (cv = cross-validation, cd = cross domain,"
                    + " id = in domain)", required = true)
    String scenario;

    @Parameter(names = { "--cl", "--clusters" },
            description = "Which clusters? Comma-delimited, e.g., s100,a500")
    String clusters;

    /**
     * Parses command-line arguments and launches the selected scenario.
     * <p>
     * Note: {@code new JCommander(object, args)} parses eagerly, so the original
     * code threw {@link ParameterException} for missing required parameters
     * <em>before</em> entering the try block and never printed the usage text.
     * Parsing is therefore done explicitly inside the try.
     *
     * @param args command-line arguments, see the {@code @Parameter} fields
     * @throws Exception any failure propagated from the experiment framework
     */
    public static void main(String[] args)
            throws Exception
    {
        ArgumentSequenceLabelingEvaluation evaluation = new ArgumentSequenceLabelingEvaluation();
        JCommander jCommander = new JCommander(evaluation);
        try {
            // parse inside the try so parameter errors reach the catch below
            jCommander.parse(args);
            evaluation.run();
        }
        catch (ParameterException e) {
            e.printStackTrace();
            jCommander.usage();
        }
    }

    /**
     * Prepares the output folder and dispatches to the scenario selected by
     * {@link #scenario} ({@code cv}, {@code id}, {@code cd}, or {@code cr}).
     *
     * @throws Exception            propagated from the DKPro Lab framework
     * @throws IOException          if the output folder cannot be created
     * @throws IllegalArgumentException if {@link #scenario} is unknown
     */
    public void run()
            throws Exception
    {
        // route UIMA logging through log4j
        System.setProperty("org.apache.uima.logger.class",
                "org.apache.uima.util.impl.Log4jLogger_impl");
        SVMHMMTestTask.PRINT_STD_OUT = true;

        File mainOutputFolder = new File(outputPath);

        // timestamp so that every invocation gets a fresh result folder
        String date = new SimpleDateFormat("yyyy-MM-dd-HH-mm-ss")
                .format(new Date(System.currentTimeMillis()));

        File outputFolder = new File(mainOutputFolder, scenario + "_" + featureSet + "_e" + paramE
                + "_t" + paramT + "_" + clusters + "_" + date);

        // fail fast instead of ignoring the mkdirs() result
        if (!outputFolder.exists() && !outputFolder.mkdirs()) {
            throw new IOException("Cannot create output folder " + outputFolder);
        }

        // DKPro Lab stores all task output under DKPRO_HOME
        System.setProperty("DKPRO_HOME", outputFolder.getAbsolutePath());

        if ("cv".equals(scenario)) {
            // cross validation
            runCrossValidation(getParameterSpace(null, false, clusters, null, null));
        }
        else if ("id".equals(scenario)) {
            // in-domain for all domains
            for (DocumentDomain documentDomain : DocumentDomain.values()) {
                runInDomainCrossValidation(documentDomain,
                        getParameterSpace(documentDomain, true, clusters, null, null));
            }
        }
        else if ("cd".equals(scenario)) {
            // cross-domain for all domains; clusters are derived per held-out domain
            for (DocumentDomain documentDomain : DocumentDomain.values()) {
                String domainClusters = determineDomainClusters(documentDomain);
                runCrossDomain(documentDomain,
                        getParameterSpace(documentDomain, false, domainClusters, null, null));
            }
        }
        else if ("cr".equals(scenario)) {
            // cross register: comments+forums vs. blogs+articles, in both directions
            Set<DocumentRegister> commentsForums = new TreeSet<>();
            commentsForums.add(DocumentRegister.ARTCOMMENT);
            commentsForums.add(DocumentRegister.FORUMPOST);

            Set<DocumentRegister> blogsArticles = new TreeSet<>();
            blogsArticles.add(DocumentRegister.BLOGPOST);
            blogsArticles.add(DocumentRegister.ARTICLE);

            // one way
            runCrossRegister(getParameterSpace(null, false, clusters, commentsForums,
                    blogsArticles), "Train-commentsForums_Test-blogsArticles");
            // and the other
            runCrossRegister(getParameterSpace(null, false, clusters, blogsArticles,
                    commentsForums), "Train-blogsArticles_Test-commentsForums");
        }
        else {
            throw new IllegalArgumentException("Unknown 'scenario' argument: " + scenario);
        }
    }

    /**
     * Maps each held-out test domain to the cluster set trained on all other
     * domains ("leave-one-domain-out" clusters) for the cross-domain scenario.
     *
     * @param documentDomain the held-out test domain
     * @return comma-delimited cluster names, or {@code null} for an unmapped domain
     */
    private String determineDomainClusters(DocumentDomain documentDomain)
    {
        final Map<DocumentDomain, String> clusterMapping = new HashMap<>();
        clusterMapping.put(DocumentDomain.HOMESCHOOLING, "arg-all-minus-hs,sent-all-minus-hs");
        clusterMapping.put(DocumentDomain.MAINSTREAMING, "arg-all-minus-ms,sent-all-minus-ms");
        clusterMapping.put(DocumentDomain.PRAYER_IN_SCHOOLS, "arg-all-minus-pis,sent-all-minus-pis");
        clusterMapping.put(DocumentDomain.PUBLIC_PRIVATE_SCHOOLS,
                "arg-all-minus-pps,sent-all-minus-pps");
        clusterMapping.put(DocumentDomain.REDSHIRTING, "arg-all-minus-rs,sent-all-minus-rs");
        clusterMapping.put(DocumentDomain.SINGLE_SEX_EDUCATION,
                "arg-all-minus-sse,sent-all-minus-sse");

        return clusterMapping.get(documentDomain);
    }

    /**
     * Builds the DKPro TC reader dimension(s) for the requested scenario.
     * <ul>
     * <li>{@code documentDomain == null}, no registers: plain CV, train reader only</li>
     * <li>{@code documentDomain == null}, registers set: cross-register train/test</li>
     * <li>{@code inDomain}: CV restricted to a single domain</li>
     * <li>otherwise: cross-domain — train on all domains except {@code documentDomain},
     *     test on {@code documentDomain}</li>
     * </ul>
     *
     * @param documentDomain      test domain, or {@code null} for domain-agnostic runs
     * @param inDomain            whether to restrict CV to {@code documentDomain}
     * @param corpusFilePathTrain path to the XMI corpus (used for train and test)
     * @param trainingRegister    training registers for the cross-register scenario, or {@code null}
     * @param testRegister        test registers for the cross-register scenario, or {@code null}
     * @return map with {@code DIM_READER_*} entries for {@link ParameterSpace}
     */
    public Map<String, Object> createDimReaders(DocumentDomain documentDomain, boolean inDomain,
            String corpusFilePathTrain, Set<DocumentRegister> trainingRegister,
            Set<DocumentRegister> testRegister)
    {
        Map<String, Object> result = new HashMap<>();

        // we take all documents regardless of domain
        if (documentDomain == null) {
            // normal CV
            if (trainingRegister == null && testRegister == null) {
                result.put(Constants.DIM_READER_TRAIN, ArgumentSequenceSentenceLevelReader.class);
                result.put(Constants.DIM_READER_TRAIN_PARAMS,
                        Arrays.asList(ArgumentSequenceSentenceLevelReader.PARAM_SOURCE_LOCATION,
                                corpusFilePathTrain,
                                ArgumentSequenceSentenceLevelReader.PARAM_PATTERNS,
                                ArgumentSequenceSentenceLevelReader.INCLUDE_PREFIX + "*.xmi",
                                ArgumentSequenceSentenceLevelReader.PARAM_LENIENT, false));
            }
            else {
                // we have cross-register train-test
                // NOTE(review): unlike every other reader configuration here, the train
                // reader does not set PARAM_LENIENT — verify this is intentional
                result.put(Constants.DIM_READER_TRAIN, ArgumentSequenceSentenceLevelReader.class);
                result.put(Constants.DIM_READER_TRAIN_PARAMS,
                        Arrays.asList(ArgumentSequenceSentenceLevelReader.PARAM_SOURCE_LOCATION,
                                corpusFilePathTrain,
                                ArgumentSequenceSentenceLevelReader.PARAM_PATTERNS,
                                ArgumentSequenceSentenceLevelReader.INCLUDE_PREFIX + "*.xmi",
                                ArgumentSequenceSentenceLevelReader.PARAM_DOCUMENT_REGISTER,
                                StringUtils.join(trainingRegister, " ")));

                result.put(Constants.DIM_READER_TEST, ArgumentSequenceSentenceLevelReader.class);
                result.put(Constants.DIM_READER_TEST_PARAMS,
                        Arrays.asList(ArgumentSequenceSentenceLevelReader.PARAM_SOURCE_LOCATION,
                                corpusFilePathTrain,
                                ArgumentSequenceSentenceLevelReader.PARAM_PATTERNS,
                                ArgumentSequenceSentenceLevelReader.INCLUDE_PREFIX + "*.xmi",
                                ArgumentSequenceSentenceLevelReader.PARAM_LENIENT, false,
                                ArgumentSequenceSentenceLevelReader.PARAM_DOCUMENT_REGISTER,
                                StringUtils.join(testRegister, " ")));
            }
        }
        else {
            if (inDomain) {
                // in domain cross validation
                result.put(Constants.DIM_READER_TRAIN, ArgumentSequenceSentenceLevelReader.class);
                result.put(Constants.DIM_READER_TRAIN_PARAMS,
                        Arrays.asList(ArgumentSequenceSentenceLevelReader.PARAM_SOURCE_LOCATION,
                                corpusFilePathTrain,
                                ArgumentSequenceSentenceLevelReader.PARAM_PATTERNS,
                                ArgumentSequenceSentenceLevelReader.INCLUDE_PREFIX + "*.xmi",
                                ArgumentSequenceSentenceLevelReader.PARAM_LENIENT, false,
                                ArgumentSequenceSentenceLevelReader.PARAM_DOCUMENT_DOMAIN,
                                documentDomain.toString()));
            }
            else {
                // get all domains minus documentDomain
                Set<DocumentDomain> trainingDomains = new HashSet<>();
                trainingDomains.addAll(Arrays.asList(DocumentDomain.values()));
                trainingDomains.remove(documentDomain);
                String trainingDomainsAsParam = StringUtils.join(trainingDomains, " ");

                // we have cross-domain train-test (param documentDomain is the test domain)
                result.put(Constants.DIM_READER_TRAIN, ArgumentSequenceSentenceLevelReader.class);
                result.put(Constants.DIM_READER_TRAIN_PARAMS,
                        Arrays.asList(ArgumentSequenceSentenceLevelReader.PARAM_SOURCE_LOCATION,
                                corpusFilePathTrain,
                                ArgumentSequenceSentenceLevelReader.PARAM_PATTERNS,
                                ArgumentSequenceSentenceLevelReader.INCLUDE_PREFIX + "*.xmi",
                                ArgumentSequenceSentenceLevelReader.PARAM_LENIENT, false,
                                ArgumentSequenceSentenceLevelReader.PARAM_DOCUMENT_DOMAIN,
                                trainingDomainsAsParam));

                // we have cross-domain train-test (param documentDomain is the test domain)
                result.put(Constants.DIM_READER_TEST, ArgumentSequenceSentenceLevelReader.class);
                result.put(Constants.DIM_READER_TEST_PARAMS,
                        Arrays.asList(ArgumentSequenceSentenceLevelReader.PARAM_SOURCE_LOCATION,
                                corpusFilePathTrain,
                                ArgumentSequenceSentenceLevelReader.PARAM_PATTERNS,
                                ArgumentSequenceSentenceLevelReader.INCLUDE_PREFIX + "*.xmi",
                                ArgumentSequenceSentenceLevelReader.PARAM_LENIENT, false,
                                ArgumentSequenceSentenceLevelReader.PARAM_DOCUMENT_DOMAIN,
                                documentDomain.toString()));
            }
        }

        return result;
    }

    /**
     * Assembles the full DKPro Lab parameter space: readers, learning/feature
     * mode (single-label sequence), sparse feature store, feature-extractor
     * parameters (n-grams, LDA topics, cluster centroids), the configured
     * feature set, and the SVM-HMM hyper-parameters T, e, and C.
     *
     * @param documentDomain   test domain, or {@code null}; see {@link #createDimReaders}
     * @param inDomain         whether to restrict CV to {@code documentDomain}
     * @param requiredClusters comma-delimited cluster names, may be {@code null}
     * @param trainingRegister training registers for cross-register runs, or {@code null}
     * @param testRegister     test registers for cross-register runs, or {@code null}
     * @return the configured {@link ParameterSpace}
     */
    @SuppressWarnings("unchecked")
    public ParameterSpace getParameterSpace(DocumentDomain documentDomain, boolean inDomain,
            String requiredClusters, Set<DocumentRegister> trainingRegister,
            Set<DocumentRegister> testRegister)
    {
        // configure training and test data reader dimension
        Map<String, Object> dimReaders = createDimReaders(documentDomain, inDomain, corpusPath,
                trainingRegister, testRegister);

        Dimension<List<String>> dimFeatureSets = Dimension.create(Constants.DIM_FEATURE_SET,
                FeatureSetHelper.getFeatureSet(featureSet));

        // parameters to configure feature extractors
        Dimension<List<Object>> dimPipelineParameters = Dimension.create(
                Constants.DIM_PIPELINE_PARAMS,
                asList(new Object[] {
                        // top 50k ngrams
                        LemmaLuceneNGramUFE.PARAM_NGRAM_USE_TOP_K, "10000",
                        LemmaLuceneNGramUFE.PARAM_NGRAM_MIN_N, 1,
                        LemmaLuceneNGramUFE.PARAM_NGRAM_MAX_N, 3,
                        LDATopicsFeature.PARAM_LDA_MODEL_FILE,
                        "classpath:/lda/mallet-lda-model-30-topics.bin",
                        ArgumentSpaceFeatureExtractor.PARAM_CENTROIDS,
                        requiredClusters != null ? requiredClusters : "" }));

        // various orders of dependencies of transitions in HMM (max 3)
        Dimension<Integer> dimClassificationArgsT = Dimension.create(SVMHMMTestTask.PARAM_ORDER_T,
                paramT);

        // various orders of dependencies of emissions in HMM (max 1)
        Dimension<Integer> dimClassificationArgsE = Dimension.create(SVMHMMTestTask.PARAM_ORDER_E,
                paramE);

        // try different parametrization of C
        Dimension<Double> dimClassificationArgsC = Dimension.create(SVMHMMTestTask.PARAM_C, 5.0);

        return new ParameterSpace(Dimension.createBundle("readers", dimReaders),
                Dimension.create(Constants.DIM_LEARNING_MODE, Constants.LM_SINGLE_LABEL),
                Dimension.create(Constants.DIM_FEATURE_MODE, Constants.FM_SEQUENCE),
                Dimension.create(Constants.DIM_FEATURE_STORE, SparseFeatureStore.class.getName()),
                dimPipelineParameters, dimFeatureSets, dimClassificationArgsE,
                dimClassificationArgsT, dimClassificationArgsC);
    }

    /**
     * Runs plain {@value #NUM_FOLDS}-fold cross-validation.
     *
     * @param pSpace parameter space from {@link #getParameterSpace}
     * @throws Exception propagated from DKPro Lab
     */
    public void runCrossValidation(ParameterSpace pSpace)
            throws Exception
    {
        ExperimentCrossValidation batch = new ExperimentCrossValidation(
                "ArgumentSequenceLabelingCV", SVMAdapterBatchTokenReport.class, getPreprocessing(),
                NUM_FOLDS);
        batch.setParameterSpace(pSpace);
        batch.addInnerReport(TokenLevelEvaluationReport.class);
        batch.setExecutionPolicy(BatchTask.ExecutionPolicy.RUN_AGAIN);

        // Run
        Lab.getInstance().run(batch);
    }

    /**
     * Preprocessing pipeline; the corpus is already annotated, so this is a no-op.
     *
     * @return a {@link NoOpAnnotator} engine description
     * @throws ResourceInitializationException on UIMA initialization failure
     */
    protected AnalysisEngineDescription getPreprocessing()
            throws ResourceInitializationException
    {
        return AnalysisEngineFactory.createEngineDescription(NoOpAnnotator.class);
    }

    /**
     * Runs {@value #NUM_FOLDS}-fold cross-validation restricted to one domain.
     *
     * @param documentDomain the domain to evaluate in
     * @param pSpace         parameter space from {@link #getParameterSpace}
     * @throws Exception propagated from DKPro Lab
     */
    public void runInDomainCrossValidation(DocumentDomain documentDomain, ParameterSpace pSpace)
            throws Exception
    {
        ExperimentCrossValidation batch = new ExperimentCrossValidation(
                "ArgumentSequenceLabeling_InDomain_" + documentDomain.toString(),
                SVMAdapterBatchTokenReport.class, getPreprocessing(), NUM_FOLDS);
        batch.setParameterSpace(pSpace);
        batch.addInnerReport(TokenLevelEvaluationReport.class);
        batch.addInnerReport(TokenLevelMacroFMReport.class);
        batch.setExecutionPolicy(BatchTask.ExecutionPolicy.RUN_AGAIN);

        // Run
        Lab.getInstance().run(batch);
    }

    /**
     * Runs a train/test experiment: train on all other domains, test on
     * {@code documentDomain}.
     *
     * @param documentDomain the held-out test domain
     * @param pSpace         parameter space from {@link #getParameterSpace}
     * @throws Exception propagated from DKPro Lab
     */
    public void runCrossDomain(DocumentDomain documentDomain, ParameterSpace pSpace)
            throws Exception
    {
        ExperimentTrainTest batch = new ExperimentTrainTest(
                "ArgumentSequenceLabeling_CrossDomain_" + documentDomain.toString(),
                SVMAdapterBatchTokenReport.class, getPreprocessing());
        batch.setParameterSpace(pSpace);
        batch.addInnerReport(TokenLevelEvaluationReport.class);
        batch.addInnerReport(TokenLevelMacroFMReport.class);
        batch.setExecutionPolicy(BatchTask.ExecutionPolicy.RUN_AGAIN);

        // Run
        Lab.getInstance().run(batch);
    }

    /**
     * Runs a cross-register train/test experiment.
     *
     * @param pSpace parameter space from {@link #getParameterSpace}
     * @param name   experiment name suffix identifying the train/test direction
     * @throws Exception propagated from DKPro Lab
     */
    public void runCrossRegister(ParameterSpace pSpace, String name)
            throws Exception
    {
        ExperimentTrainTest batch = new ExperimentTrainTest(
                "ArgumentSequenceLabeling_CrossRegister_" + name,
                SVMAdapterBatchTokenReport.class, getPreprocessing());
        batch.setParameterSpace(pSpace);
        batch.addInnerReport(TokenLevelEvaluationReport.class);
        batch.addInnerReport(TokenLevelMacroFMReport.class);
        batch.setExecutionPolicy(BatchTask.ExecutionPolicy.RUN_AGAIN);

        // Run
        Lab.getInstance().run(batch);
    }
}