de.tudarmstadt.ukp.dariah.IO.DARIAHWriter.java Source code

Introduction

Here is the full source code of de.tudarmstadt.ukp.dariah.IO.DARIAHWriter, a DKPro Core / uimaFIT consumer that writes an annotated JCas to a tab-separated file with one row per token and one column per annotation layer (lemma, POS, chunks, named entities, dependencies, coreference, semantic roles, and more).
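
Usage

The writer plugs into a uimaFIT pipeline like any other DKPro Core consumer. The following is a minimal sketch, not part of the original source: it assumes the standard DKPro Core TextReader and the PARAM_TARGET_LOCATION parameter inherited from JCasFileWriter_ImplBase. Without upstream annotators (segmenter, POS tagger, parser, ...) most columns will simply contain the "_" placeholder.

import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription;
import static org.apache.uima.fit.factory.CollectionReaderFactory.createReaderDescription;

import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.collection.CollectionReaderDescription;
import org.apache.uima.fit.pipeline.SimplePipeline;

import de.tudarmstadt.ukp.dkpro.core.io.text.TextReader;
import de.tudarmstadt.ukp.dariah.IO.DARIAHWriter;

public class WriterExample {
    public static void main(String[] args) throws Exception {
        // Read plain-text documents with the standard DKPro Core text reader
        CollectionReaderDescription reader = createReaderDescription(TextReader.class,
                TextReader.PARAM_SOURCE_LOCATION, "input/*.txt",
                TextReader.PARAM_LANGUAGE, "en");

        // Configure the writer; PARAM_TARGET_LOCATION is inherited from JCasFileWriter_ImplBase
        AnalysisEngineDescription writer = createEngineDescription(DARIAHWriter.class,
                DARIAHWriter.PARAM_TARGET_LOCATION, "output",
                DARIAHWriter.PARAM_FILENAME_SUFFIX, ".csv");

        // A real pipeline would add a segmenter, tagger, parser, etc. between reader and writer
        SimplePipeline.runPipeline(reader, writer);
    }
}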

Source

/*******************************************************************************
 * Copyright 2015
 * Ubiquitous Knowledge Processing (UKP) Lab 
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/

package de.tudarmstadt.ukp.dariah.IO;

import static org.apache.commons.io.IOUtils.closeQuietly;
import static org.apache.uima.fit.util.JCasUtil.select;
import static org.apache.uima.fit.util.JCasUtil.selectCovered;

import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;

import org.apache.commons.lang.StringUtils;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.descriptor.TypeCapability;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.cas.FSArray;

import de.tudarmstadt.ukp.dkpro.core.api.coref.type.CoreferenceChain;
import de.tudarmstadt.ukp.dkpro.core.api.coref.type.CoreferenceLink;
import de.tudarmstadt.ukp.dkpro.core.api.io.JCasFileWriter_ImplBase;
import de.tudarmstadt.ukp.dkpro.core.api.lexmorph.type.morph.Morpheme;
import de.tudarmstadt.ukp.dkpro.core.api.lexmorph.type.pos.POS;
import de.tudarmstadt.ukp.dkpro.core.api.metadata.type.DocumentMetaData;
import de.tudarmstadt.ukp.dkpro.core.api.ner.type.NamedEntity;
import de.tudarmstadt.ukp.dkpro.core.api.parameter.ComponentParameters;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Paragraph;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;
import de.tudarmstadt.ukp.dkpro.core.api.semantics.type.SemanticArgument;
import de.tudarmstadt.ukp.dkpro.core.api.semantics.type.SemanticPredicate;
import de.tudarmstadt.ukp.dkpro.core.api.syntax.type.chunk.Chunk;
import de.tudarmstadt.ukp.dkpro.core.api.syntax.type.constituent.ROOT;
import de.tudarmstadt.ukp.dkpro.core.api.syntax.type.dependency.Dependency;
import de.tudarmstadt.ukp.dkpro.core.io.penntree.PennTreeNode;
import de.tudarmstadt.ukp.dkpro.core.io.penntree.PennTreeUtils;
import de.tudarmstadt.ukp.dariah.type.DirectSpeech;
import de.tudarmstadt.ukp.dariah.type.Hyphenation;
import de.tudarmstadt.ukp.dariah.type.Section;

/**
 * @author Nils Reimers
 */
@TypeCapability(inputs = { "de.tudarmstadt.ukp.dkpro.core.api.metadata.type.DocumentMetaData",
        "de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence",
        "de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token",
        "de.tudarmstadt.ukp.dkpro.core.api.lexmorph.type.morph.Morpheme",
        "de.tudarmstadt.ukp.dkpro.core.api.lexmorph.type.pos.POS",
        "de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Lemma",
        "de.tudarmstadt.ukp.dkpro.core.api.syntax.type.dependency.Dependency" })
public class DARIAHWriter extends JCasFileWriter_ImplBase {
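    /** Placeholders written when an annotation layer is absent (CoNLL-style underscore). */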
    private static final String UNUSED = "_";
    private static final int UNUSED_INT = -2;

    /**
     * Name of configuration parameter that contains the character encoding used for the output files.
     */
    public static final String PARAM_ENCODING = ComponentParameters.PARAM_SOURCE_ENCODING;
    @ConfigurationParameter(name = PARAM_ENCODING, mandatory = true, defaultValue = "UTF-8")
    private String encoding;

    public static final String PARAM_FILENAME_SUFFIX = "filenameSuffix";
    @ConfigurationParameter(name = PARAM_FILENAME_SUFFIX, mandatory = true, defaultValue = ".csv")
    private String filenameSuffix;

    public static final String PARAM_WRITE_POS = ComponentParameters.PARAM_WRITE_POS;
    @ConfigurationParameter(name = PARAM_WRITE_POS, mandatory = true, defaultValue = "true")
    private boolean writePos;

    public static final String PARAM_WRITE_MORPH = "writeMorph";
    @ConfigurationParameter(name = PARAM_WRITE_MORPH, mandatory = true, defaultValue = "true")
    private boolean writeMorph;

    public static final String PARAM_WRITE_HYPHENATION = "writeHyphenation";
    @ConfigurationParameter(name = PARAM_WRITE_HYPHENATION, mandatory = true, defaultValue = "true")
    private boolean writeHyphenation;

    public static final String PARAM_WRITE_LEMMA = ComponentParameters.PARAM_WRITE_LEMMA;
    @ConfigurationParameter(name = PARAM_WRITE_LEMMA, mandatory = true, defaultValue = "true")
    private boolean writeLemma;

    public static final String PARAM_WRITE_DEPENDENCY = ComponentParameters.PARAM_WRITE_DEPENDENCY;
    @ConfigurationParameter(name = PARAM_WRITE_DEPENDENCY, mandatory = true, defaultValue = "true")
    private boolean writeDependency;

    @Override
    public void process(JCas aJCas) throws AnalysisEngineProcessException {

        PrintWriter out = null;
        try {

            out = new PrintWriter(new OutputStreamWriter(getOutputStream(aJCas, filenameSuffix), encoding));
            convert(aJCas, out);
        } catch (Exception e) {
            throw new AnalysisEngineProcessException(e);
        } finally {
            closeQuietly(out);
        }

    }

    private void convert(JCas aJCas, PrintWriter aOut) {
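        // Two passes: first build one Row per token (with NE, chunk, coreference,
        // dependency and SRL information), then print the header and all rows.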
        int paragraphId = 0, sentenceId = 0, tokenId = 0;

        Map<Token, Collection<NamedEntity>> neCoveringMap = JCasUtil.indexCovering(aJCas, Token.class,
                NamedEntity.class);
        Map<Token, Collection<Chunk>> chunksCoveringMap = JCasUtil.indexCovering(aJCas, Token.class, Chunk.class);

        Map<Token, Collection<Section>> sectionCoveringMap = JCasUtil.indexCovering(aJCas, Token.class,
                Section.class);
        Map<Token, Collection<DirectSpeech>> directSpeechCoveringMap = JCasUtil.indexCovering(aJCas, Token.class,
                DirectSpeech.class);

        Map<Token, Collection<SemanticPredicate>> predIdx = JCasUtil.indexCovered(aJCas, Token.class,
                SemanticPredicate.class);

        Map<SemanticPredicate, Collection<Token>> pred2TokenIdx = JCasUtil.indexCovering(aJCas,
                SemanticPredicate.class, Token.class);

        Map<SemanticArgument, Collection<Token>> argIdx = JCasUtil.indexCovered(aJCas, SemanticArgument.class,
                Token.class);

        //Coreference
        Map<Token, Collection<CoreferenceLink>> corefLinksCoveringMap = JCasUtil.indexCovering(aJCas, Token.class,
                CoreferenceLink.class);
        HashMap<CoreferenceLink, CoreferenceChain> linkToChainMap = new HashMap<>();
        HashMap<CoreferenceChain, Integer> corefChainToIntMap = new HashMap<>();

        int corefChainId = 0;
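        // Walk each chain once: map every link to its chain and give each
        // non-empty chain a dense integer id used in the CoreferenceChainIds column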
        for (CoreferenceChain chain : JCasUtil.select(aJCas, CoreferenceChain.class)) {

            CoreferenceLink link = chain.getFirst();
            int count = 0;
            while (link != null) {
                linkToChainMap.put(link, chain);
                link = link.getNext();
                count++;
            }
            if (count > 0) {
                corefChainToIntMap.put(chain, corefChainId);
                corefChainId++;
            }
        }

        HashMap<Token, Row> ctokens = new LinkedHashMap<Token, Row>();

        for (Paragraph para : select(aJCas, Paragraph.class)) {

            for (Sentence sentence : selectCovered(Sentence.class, para)) {

                // Tokens
                List<Token> tokens = selectCovered(Token.class, sentence);

                // Check if we should try to include the morphology in output
                List<Morpheme> morphologies = selectCovered(Morpheme.class, sentence);
                boolean useMorphology = tokens.size() == morphologies.size();

                // Check if we should try to include the hyphenation in output
                List<Hyphenation> hyphenations = selectCovered(Hyphenation.class, sentence);
                boolean useHyphenation = tokens.size() == hyphenations.size();

                //Parsing information
                String[] parseFragments = null;
                List<ROOT> root = selectCovered(ROOT.class, sentence);
                if (root.size() == 1) {
                    PennTreeNode rootNode = PennTreeUtils.convertPennTree(root.get(0));
                    if ("ROOT".equals(rootNode.getLabel())) {
                        rootNode.setLabel("TOP");
                    }
                    parseFragments = toPrettyPennTree(rootNode);
                }
                boolean useParseFragments = (parseFragments != null && parseFragments.length == tokens.size());

                List<SemanticPredicate> preds = selectCovered(SemanticPredicate.class, sentence);

                for (int i = 0; i < tokens.size(); i++) {
                    Row row = new Row();

                    row.paragraphId = paragraphId;
                    row.sentenceId = sentenceId;
                    row.tokenId = tokenId;
                    row.token = tokens.get(i);
                    row.args = new SemanticArgument[preds.size()];

                    if (useParseFragments) {
                        row.parseFragment = parseFragments[i];
                    }

                    if (useMorphology) {
                        row.morphology = morphologies.get(i);
                    }

                    if (useHyphenation) {
                        row.hyphenation = hyphenations.get(i);
                    }

                    // Section ID
                    Collection<Section> section = sectionCoveringMap.get(row.token);
                    if (section.size() > 0)
                        row.sectionId = section.toArray(new Section[0])[0].getValue();

                    // Named entities
                    Collection<NamedEntity> ne = neCoveringMap.get(row.token);
                    if (ne.size() > 0)
                        row.ne = ne.toArray(new NamedEntity[0])[0];

                    // Chunk
                    Collection<Chunk> chunks = chunksCoveringMap.get(row.token);
                    if (chunks.size() > 0)
                        row.chunk = chunks.toArray(new Chunk[0])[0];

                    //Quote annotation
                    Collection<DirectSpeech> ds = directSpeechCoveringMap.get(row.token);
                    if (ds.size() > 0)
                        row.directSpeech = ds.toArray(new DirectSpeech[0])[0];

                    //Coref
                    Collection<CoreferenceLink> corefLinks = corefLinksCoveringMap.get(row.token);
                    row.corefChains = UNUSED;
                    if (corefLinks.size() > 0) {

                        String[] chainIds = new String[corefLinks.size()];

                        int k = 0;
                        for (CoreferenceLink link : corefLinks) {
                            CoreferenceChain chain = linkToChainMap.get(link);
                            int chainId = corefChainToIntMap.get(chain);

                            // B if the token starts the mention, I for all following tokens
                            String bioMarker = "I";
                            if (link.getCoveredText().substring(0, row.token.getCoveredText().length())
                                    .equals(row.token.getCoveredText())) {
                                bioMarker = "B";
                            }
                            chainIds[k++] = bioMarker + "-" + chainId;
                        }

                        //Sort by chain id, ignoring the leading BIO marker ("B-" / "I-")
                        Arrays.sort(chainIds, new Comparator<String>() {
                            @Override
                            public int compare(String idx1, String idx2) {
                                int id1 = Integer.parseInt(idx1.substring(2));
                                int id2 = Integer.parseInt(idx2.substring(2));

                                return Integer.compare(id1, id2);
                            }
                        });

                        row.corefChains = StringUtils.join(chainIds, ',');
                    }

                    //Predicate
                    Collection<SemanticPredicate> predsForToken = predIdx.get(row.token);
                    if (predsForToken != null && !predsForToken.isEmpty()) {
                        row.pred = predsForToken.iterator().next();
                    }

                    ctokens.put(row.token, row);
                    tokenId++;
                }

                // Dependencies
                for (Dependency rel : selectCovered(Dependency.class, sentence)) {
                    ctokens.get(rel.getDependent()).deprel = rel;
                }

                // Semantic arguments
                for (int p = 0; p < preds.size(); p++) {
                    FSArray args = preds.get(p).getArguments();

                    //Set the column position info
                    Collection<Token> tokensOfPredicate = pred2TokenIdx.get(preds.get(p));
                    for (Token t : tokensOfPredicate) {
                        Row row = ctokens.get(t);
                        row.semanticArgIndex = p;
                    }

                    //Set the arguments information
                    for (SemanticArgument arg : select(args, SemanticArgument.class)) {
                        for (Token t : argIdx.get(arg)) {
                            Row row = ctokens.get(t);
                            row.args[p] = arg;
                        }
                    }
                }

                sentenceId++;
            }
            paragraphId++;
        }

        // Write to output file
        int maxPredArguments = 0;
        for (Row row : ctokens.values()) {
            maxPredArguments = Math.max(maxPredArguments, row.args.length);
        }

        aOut.printf("%s\n", StringUtils.join(getHeader(maxPredArguments), "\t").trim());

        for (Row row : ctokens.values()) {
            String[] output = getData(ctokens, maxPredArguments, row);
            aOut.printf("%s\n", StringUtils.join(output, "\t").trim());
        }

    }

    private String[] getData(HashMap<Token, Row> ctokens, int numPredArguments, Row row) {
        String lemma = UNUSED;
        if (writeLemma && (row.token.getLemma() != null)) {
            lemma = row.token.getLemma().getValue();
        }

        String pos = UNUSED;
        String cpos = UNUSED;
        if (writePos && (row.token.getPos() != null)) {
            POS posAnno = row.token.getPos();
            pos = posAnno.getPosValue();
            if (!posAnno.getClass().equals(POS.class)) {
                cpos = posAnno.getClass().getSimpleName();
            } else {
                cpos = pos;
            }
        }

        int headId = UNUSED_INT;
        String deprel = UNUSED;
        if (writeDependency && (row.deprel != null)) {
            deprel = row.deprel.getDependencyType();
            headId = ctokens.get(row.deprel.getGovernor()).tokenId;
            if (headId == row.tokenId) {
                // ROOT dependencies may be modeled as a loop, ignore these.
                headId = -1;
            }
        }

        String head = UNUSED;
        if (headId != UNUSED_INT) {
            head = Integer.toString(headId);
        }

        String morphology = UNUSED;
        if (writeMorph && (row.morphology != null)) {
            morphology = row.morphology.getMorphTag();
        }

        String hyphenation = UNUSED;
        if (writeHyphenation && (row.hyphenation != null)) {
            hyphenation = row.hyphenation.getValue();
        }

        String chunk = UNUSED;
        if (row.chunk != null) {
            chunk = row.chunk.getChunkValue();
            //BIO tagging: B for the first token of a chunk, I for all following tokens
            if (row.chunk.getBegin() == row.token.getBegin()) {
                chunk = "B-" + chunk;
            } else {
                chunk = "I-" + chunk;
            }
        }

        String ne = UNUSED;

        if (row.ne != null) {
            if (row.ne.getValue().length() > 1 && row.ne.getValue().substring(1, 2).equals("-"))
                ne = row.ne.getValue().substring(2); //Remove IOB tagging from Stanford Tagger
            else
                ne = row.ne.getValue();

            //BIO tagging: B for the first token of the entity, I for all following tokens
            if (row.ne.getBegin() == row.token.getBegin()) {
                ne = "B-" + ne;
            } else {
                ne = "I-" + ne;
            }
        }

        String quoteMarker = "0";
        if (row.directSpeech != null) {
            quoteMarker = "1";
        }

        String parseFragment = UNUSED;
        if (row.parseFragment != null)
            parseFragment = row.parseFragment;

        String fillpred = UNUSED;
        String pred = UNUSED;
        String semanticArgumentIndex = UNUSED;

        if (row.pred != null) {
            fillpred = "Y";
            pred = row.pred.getCategory();
            semanticArgumentIndex = String.valueOf(row.semanticArgIndex);
        }

        String[] apreds = new String[numPredArguments];
        for (int i = 0; i < apreds.length; i++) {
            apreds[i] = UNUSED;

            if (row.args.length > i && row.args[i] != null) {
                apreds[i] = row.args[i].getRole();
            }
        }

        String[] output = new String[] { row.sectionId, Integer.toString(row.paragraphId),
                Integer.toString(row.sentenceId), Integer.toString(row.tokenId),
                Integer.toString(row.token.getBegin()), Integer.toString(row.token.getEnd()),
                row.token.getCoveredText(), lemma, cpos, pos, chunk, morphology, hyphenation, head, deprel, ne,
                quoteMarker, row.corefChains, parseFragment, pred, semanticArgumentIndex,
                StringUtils.join(apreds, '\t') };
        return output;
    }

    private String[] getHeader(int numPredArguments) {
        List<String> header = new LinkedList<>(Arrays.asList("SectionId", "ParagraphId",
                "SentenceId", "TokenId", "Begin", "End", "Token", "Lemma", "CPOS", "POS", "Chunk", "Morphology",
                "Hyphenation", "DependencyHead", "DependencyRelation", "NamedEntity", "QuoteMarker",
                "CoreferenceChainIds", "SyntaxTree", "Predicate", "SemanticArgumentIndex"));

        for (int i = 0; i < numPredArguments; i++) {
            header.add("SemanticArgument" + i);
        }

        return header.toArray(new String[0]);
    }

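    /**
     * Converts a Penn tree into per-token bracket fragments: every pre-terminal is
     * rendered as "*" and the tree is split at those points, so each output row
     * carries the part of the parse tree that its token opens or closes.
     */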
    public static String[] toPrettyPennTree(PennTreeNode aNode) {
        StringBuilder sb = new StringBuilder();
        toPennTree(sb, aNode);
        return sb.toString().trim().split("\n+");
    }

    private static void toPennTree(StringBuilder aSb, PennTreeNode aNode) {
        // A pre-terminal ("(Label token)") is rendered as a "*" placeholder
        if (aNode.isPreTerminal()) {
            aSb.append("*");
        } else {
            aSb.append('(');
            aSb.append(aNode.getLabel());

            Iterator<PennTreeNode> i = aNode.getChildren().iterator();
            while (i.hasNext()) {
                PennTreeNode child = i.next();
                toPennTree(aSb, child);
                if (i.hasNext()) {
                    aSb.append("\n");
                }
            }

            aSb.append(')');
        }
    }

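    /**
     * One row of the output file: a token plus all annotations that feed its
     * tab-separated columns.
     */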
    private static final class Row {
        String sectionId = "_";
        int paragraphId;
        int sentenceId;
        int tokenId;
        Token token;
        Chunk chunk;
        Morpheme morphology;
        Hyphenation hyphenation;
        Dependency deprel;
        NamedEntity ne;
        DirectSpeech directSpeech;
        String parseFragment;
        String corefChains;
        SemanticPredicate pred;
        int semanticArgIndex;
        SemanticArgument[] args;
    }
}