de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.discourse.PDTBDiscourseFeatures.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.discourse.PDTBDiscourseFeatures.java

Source

/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.discourse;

import de.tudarmstadt.ukp.dkpro.argumentation.misc.uima.JCasUtil2;
import de.tudarmstadt.ukp.experiments.argumentation.sequence.feature.AbstractUnitSentenceFeatureGenerator;
import de.tudarmstadt.ukp.dkpro.core.api.discourse.DiscourseDumpWriter;
import de.tudarmstadt.ukp.dkpro.core.api.frequency.util.FrequencyDistribution;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;
import de.tudarmstadt.ukp.dkpro.core.discourse.pdtb.DiscourseArgument;
import de.tudarmstadt.ukp.dkpro.core.discourse.pdtb.DiscourseAttribution;
import de.tudarmstadt.ukp.dkpro.core.discourse.pdtb.DiscourseConnective;
import de.tudarmstadt.ukp.dkpro.core.discourse.pdtb.DiscourseRelation;
import de.tudarmstadt.ukp.dkpro.core.io.xmi.XmiReader;
import de.tudarmstadt.ukp.dkpro.tc.api.exception.TextClassificationException;
import de.tudarmstadt.ukp.dkpro.tc.api.features.Feature;
import org.apache.commons.lang.StringUtils;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.CollectionReaderFactory;
import org.apache.uima.fit.pipeline.SimplePipeline;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;

/**
 * Generates discourse features for a sentence from the PDTB-style annotations
 * ({@link DiscourseRelation}, {@link DiscourseArgument}, {@link DiscourseConnective} and
 * {@link DiscourseAttribution}) found on that sentence.
 * <p>
 * Two groups of features are produced: frequency features (relation type combined with argument
 * number and/or argument type; the value is the number of occurrences) and binary features
 * (presence of connectives, connective types, leading attribution tokens; the value is always 1).
 *
 * @author Ivan Habernal
 */
public class PDTBDiscourseFeatures extends AbstractUnitSentenceFeatureGenerator {
    // NOTE(review): only FN_HAS_DISCOURSE_CONN is referenced in this class; the remaining
    // feature-name constants are currently unused.
    private static final String FN_ARG_NUMBER = "discourseArgNumber_";
    private static final String FN_ARG_TYPE = "discourseArgType_";
    private static final String FN_ARG_TYPE_NUMBER = "discourseArgTypeNumber_";
    private static final String FN_HAS_DISCOURSE_CONN = "discourseHasConnectives_";
    private static final String FN_ATTRIBUTION = "discourseAttribution_";

    /** Maximum number of leading attribution tokens that are glued into a feature name. */
    private static final int MAX_ATTRIBUTION_TOKENS = 3;

    /**
     * Extracts the discourse features of one sentence.
     *
     * @param jCas           the CAS holding the discourse annotations
     * @param sentence       the sentence whose overlapping/covered annotations are inspected
     * @param sentencePrefix prefix prepended to every generated feature name
     * @return the binary features followed by the frequency features
     * @throws TextClassificationException declared by the overridden method
     */
    @Override
    protected List<Feature> extract(JCas jCas, Sentence sentence, String sentencePrefix)
            throws TextClassificationException {
        FrequencyDistribution<String> discourseFeaturesFreq = new FrequencyDistribution<>();
        // used as a key set only; the counts of the binary features are never read
        FrequencyDistribution<String> discourseFeaturesBinary = new FrequencyDistribution<>();

        // relations overlapping the sentence, indexed by their id for the argument lookup
        Map<Integer, DiscourseRelation> relationsById = new TreeMap<>();
        for (DiscourseRelation relation : JCasUtil2.selectOverlapping(DiscourseRelation.class, sentence,
                jCas)) {
            relationsById.put(relation.getRelationId(), relation);
        }

        List<DiscourseArgument> arguments = JCasUtil2.selectOverlapping(DiscourseArgument.class, sentence,
                jCas);
        for (DiscourseArgument argument : arguments) {
            addArgumentFeatures(argument, relationsById, discourseFeaturesFreq);
        }

        // Connective and attribution features do not depend on the individual argument, so they
        // are computed once instead of once per argument.
        // NOTE(review): they are emitted only when at least one argument overlaps the sentence,
        // which mirrors the previous behavior (the code used to live inside the argument loop);
        // confirm whether that restriction is intended.
        if (!arguments.isEmpty()) {
            addConnectiveFeatures(jCas, sentence, discourseFeaturesBinary);
            addAttributionFeatures(sentence, discourseFeaturesBinary);
        }

        List<Feature> result = new ArrayList<>();

        // create binary features
        for (String key : discourseFeaturesBinary.getKeys()) {
            result.add(new Feature(sentencePrefix + key, 1));
        }

        // create frequency features
        for (String key : discourseFeaturesFreq.getKeys()) {
            result.add(new Feature(sentencePrefix + key, discourseFeaturesFreq.getCount(key)));
        }

        return result;
    }

    /**
     * Counts the relation-type/argument-number/argument-type combinations of one argument.
     * Arguments whose parent relation is not among {@code relationsById} are skipped.
     */
    private static void addArgumentFeatures(DiscourseArgument argument,
            Map<Integer, DiscourseRelation> relationsById, FrequencyDistribution<String> freq) {
        DiscourseRelation parentRelation = relationsById.get(argument.getParentRelationId());
        if (parentRelation == null) {
            // the parent relation does not overlap the sentence, so its type is unknown here;
            // previously this case ended in a NullPointerException
            return;
        }

        int argumentNumber = argument.getArgumentNumber();
        String argumentType = argument.getArgumentType();
        String relationType = parentRelation.getClass().getSimpleName();

        freq.addSample("discourse_relType_argNo_" + relationType + "_" + argumentNumber, 1);

        if (argumentType != null) {
            freq.addSample("discourse_relType_argType_" + relationType + "_" + argumentType, 1);
            freq.addSample("discourse_relType_argType_argNo_" + relationType + "_" + argumentType + "_"
                    + argumentNumber, 1);
        }
    }

    /** Records whether the sentence overlaps any connective and which connective types occur. */
    private static void addConnectiveFeatures(JCas jCas, Sentence sentence,
            FrequencyDistribution<String> binary) {
        List<DiscourseConnective> discourseConnectives = JCasUtil2.selectOverlapping(
                DiscourseConnective.class, sentence, jCas);

        if (!discourseConnectives.isEmpty()) {
            binary.addSample(FN_HAS_DISCOURSE_CONN, 1);
        }

        for (DiscourseConnective connective : discourseConnectives) {
            binary.addSample("discourse_connectiveTypes_" + connective.getConnectiveType(), 1);
        }
    }

    /** Records the glued leading tokens of every attribution covered by the sentence. */
    private static void addAttributionFeatures(Sentence sentence, FrequencyDistribution<String> binary) {
        for (DiscourseAttribution attribution : JCasUtil.selectCovered(DiscourseAttribution.class,
                sentence)) {
            String gluedTokens = glueAttributionTokens(attribution);
            // an attribution without any token would only yield an empty, uninformative suffix
            if (!gluedTokens.isEmpty()) {
                binary.addSample("discourse_attributionTokens_" + gluedTokens, 1);
            }
        }
    }

    /**
     * Joins the lower-cased texts of the first {@link #MAX_ATTRIBUTION_TOKENS} tokens covered by
     * the attribution with underscores.
     *
     * @param attribution the attribution annotation
     * @return the joined tokens; an empty string if the attribution covers no token
     */
    private static String glueAttributionTokens(DiscourseAttribution attribution) {
        List<String> tokens = new ArrayList<>();
        for (Token token : JCasUtil.selectCovered(Token.class, attribution)) {
            if (tokens.size() == MAX_ATTRIBUTION_TOKENS) {
                break;
            }

            // NOTE(review): as before, the covered text of the lemma annotation is used;
            // confirm whether the lemma value itself was intended.
            String text = token.getLemma() != null
                    ? token.getLemma().getCoveredText()
                    : token.getCoveredText();

            // Locale.ROOT keeps the feature names independent of the JVM default locale
            tokens.add(text.toLowerCase(Locale.ROOT));
        }

        return StringUtils.join(tokens, "_");
    }

    /**
     * Reads all {@code *.xmi} files from the source location given as the first argument and
     * passes them to {@link DiscourseDumpWriter}.
     *
     * @param args {@code args[0]} is the source location of the XMI files
     * @throws IllegalArgumentException if no argument is given
     * @throws Exception                if the pipeline fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            throw new IllegalArgumentException(
                    "Expected one argument: the source location of the *.xmi corpus files");
        }

        final String corpusFilePathTrain = args[0];
        SimplePipeline.runPipeline(
                CollectionReaderFactory.createReaderDescription(XmiReader.class, XmiReader.PARAM_LENIENT, false,
                        XmiReader.PARAM_SOURCE_LOCATION, corpusFilePathTrain, XmiReader.PARAM_PATTERNS,
                        XmiReader.INCLUDE_PREFIX + "*.xmi"),
                AnalysisEngineFactory.createEngineDescription(DiscourseDumpWriter.class));
    }

}