de.tudarmstadt.ukp.experiments.argumentation.clustering.entropy.ClusterTopicMatrixGenerator.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.experiments.argumentation.clustering.entropy.ClusterTopicMatrixGenerator.java

Source

/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package de.tudarmstadt.ukp.experiments.argumentation.clustering.entropy;

import de.tudarmstadt.ukp.experiments.argumentation.clustering.ClusteringUtils;
import de.tudarmstadt.ukp.experiments.argumentation.clustering.VectorUtils;
import de.tudarmstadt.ukp.dkpro.argumentation.type.DebateArgumentMetaData;
import de.tudarmstadt.ukp.dkpro.argumentation.type.Embeddings;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.Vector;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.component.JCasConsumer_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;

import java.io.*;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;

/**
 * UIMA consumer that builds a co-occurrence matrix of clusters (rows) and topics (columns).
 * <p>
 * For every sentence of every processed document, the cluster with the largest value in the
 * sentence's distance-to-centroids vector and the topic with the largest value in the document's
 * topic distribution are determined, and the corresponding matrix cell is incremented by one.
 * When the collection is complete, the matrix is written to the configured output file
 * (tab-separated, one line per cluster) or, if no output file is configured, printed to
 * standard output.
 *
 * @author Ivan Habernal
 */
public class ClusterTopicMatrixGenerator extends JCasConsumer_ImplBase {
    /**
     * Serialized {@code TreeMap<Integer, Vector>} of cluster centroids; output from
     * {@link ClusterCentroidsMain}
     */
    public static final String PARAM_CENTROIDS_FILE = "centroidsFile";
    @ConfigurationParameter(name = PARAM_CENTROIDS_FILE, mandatory = true)
    File centroidsFile;

    /**
     * Optional file the cluster-topic matrix is written to; when not set, the matrix is printed
     * to standard output instead.
     */
    public static final String PARAM_OUTPUT_FILE = "outputFile";
    @ConfigurationParameter(name = PARAM_OUTPUT_FILE, mandatory = false)
    File outputFile;

    /**
     * Serialized {@code Map<String, List<Double>>} mapping debate URL -> topic distribution;
     * output from {@link DebateTopicExtractorMain}
     */
    public static final String PARAM_DEBATE_TOPIC_MAP_FILE = "debateTopicMapFile";
    @ConfigurationParameter(name = PARAM_DEBATE_TOPIC_MAP_FILE, mandatory = true)
    File debateTopicMapFile;

    /** Debate URL -> topic distribution, loaded from {@link #debateTopicMapFile}. */
    Map<String, List<Double>> debateTopicMap;

    /** Cluster centroids, loaded from {@link #centroidsFile}. */
    TreeMap<Integer, Vector> centroids;

    /**
     * A real-number matrix with one row per cluster and one column per topic. This class stores
     * co-occurrence counts in it; the original design also allows probabilities.
     */
    protected Matrix clusterTopicMatrix;

    /**
     * Loads the debate-topic map and the cluster centroids and allocates the (clusters x topics)
     * matrix. The number of topics is taken from the first topic distribution in the map.
     *
     * @throws ResourceInitializationException if a file cannot be read or deserialized, or if
     *                                         the debate-topic map is empty
     */
    @Override
    public void initialize(UimaContext context) throws ResourceInitializationException {
        super.initialize(context);

        try {
            // load mapping debateURL -> topic distribution
            debateTopicMap = readObject(debateTopicMapFile);

            // load centroids
            centroids = readObject(centroidsFile);
        } catch (IOException | ClassNotFoundException e) {
            throw new ResourceInitializationException(e);
        }

        // the topic count is read from an arbitrary entry, so at least one entry must exist
        if (debateTopicMap.isEmpty()) {
            throw new ResourceInitializationException(new IllegalStateException(
                    "Debate topic map loaded from " + debateTopicMapFile + " is empty"));
        }

        // initialize matrix
        int numTopics = debateTopicMap.values().iterator().next().size();
        clusterTopicMatrix = new DenseMatrix(centroids.size(), numTopics);
    }

    /**
     * Deserializes a single object from {@code file}, always closing the underlying stream.
     * <p>
     * NOTE(review): Java deserialization must only be used on trusted files; these are
     * presumably produced by this project's own preprocessing steps - confirm.
     *
     * @param file file containing one serialized object
     * @param <T>  expected type of the object (unchecked)
     * @return the deserialized object
     */
    @SuppressWarnings("unchecked")
    private static <T> T readObject(File file) throws IOException, ClassNotFoundException {
        // open the file stream separately so it is closed even if the ObjectInputStream header
        // cannot be read
        try (InputStream fileIn = new FileInputStream(file);
                ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(fileIn))) {
            return (T) in.readObject();
        }
    }

    /**
     * Looks up the topic distribution of the document's debate and updates the matrix once per
     * sentence, using the single {@link Embeddings} annotation covered by that sentence.
     *
     * @throws IllegalStateException           if the debate URL is missing from the topic map
     * @throws AnalysisEngineProcessException if a sentence does not cover exactly one
     *                                         embeddings annotation
     */
    @Override
    public void process(JCas aJCas) throws AnalysisEngineProcessException {
        // get debate topic distribution
        DebateArgumentMetaData debate = JCasUtil.selectSingle(aJCas, DebateArgumentMetaData.class);
        String debateUrl = debate.getDebateUrl();

        // get it from the cache
        List<Double> topicDistribution = debateTopicMap.get(debateUrl);

        if (topicDistribution == null) {
            throw new IllegalStateException("Cannot find topic distribution for debate " + debateUrl
                    + " in cache file " + debateTopicMapFile);
        }

        // the topic distribution is per document, so convert it only once
        Vector topicDistributionVector = VectorUtils.listToVector(topicDistribution);

        // iterate over sentences
        for (Sentence sentence : JCasUtil.select(aJCas, Sentence.class)) {
            // and load the appropriate distance to centroids
            List<Embeddings> embeddingsList = JCasUtil.selectCovered(Embeddings.class, sentence);

            if (embeddingsList.size() != 1) {
                throw new AnalysisEngineProcessException(
                        new IllegalStateException("Expected 1 embedding annotations for sentence, but "
                                + embeddingsList.size() + " found. Sentence: " + sentence.getBegin()
                                + "-" + sentence.getEnd() + ", "
                                + StringUtils.join(embeddingsList.iterator(), "\n")));
            }

            Embeddings embeddings = embeddingsList.get(0);
            DenseVector embeddingsVector = new DenseVector(embeddings.getVector().toArray());

            Vector distanceToClusterCentroidsVector = ClusteringUtils
                    .transformEmbeddingVectorToDistanceToClusterCentroidsVector(embeddingsVector, centroids);

            updateClusterTopicMatrix(distanceToClusterCentroidsVector, topicDistributionVector);
        }
    }

    /**
     * Updates the co-occurrence matrix of clusters and topics
     *
     * @param distanceToClusterCentroidsVector distance to cluster centroids of the particular
     *                                         sentence
     * @param topicDistributionVector          topic distribution for the document
     */
    protected void updateClusterTopicMatrix(Vector distanceToClusterCentroidsVector,
            Vector topicDistributionVector) {
        // NOTE(review): the cluster with the LARGEST value is selected. If this vector holds true
        // distances, the nearest centroid would be the SMALLEST value - confirm what
        // ClusteringUtils.transformEmbeddingVectorToDistanceToClusterCentroidsVector returns.
        int cluster = indexOfLargestValue(distanceToClusterCentroidsVector);
        int topic = indexOfLargestValue(topicDistributionVector);

        // just increase co-occurrence
        clusterTopicMatrix.add(cluster, topic, 1.0);
    }

    /**
     * Returns the value of the single entry of {@code VectorUtils.largestValues(vector, 1)}.
     * The value is used as a matrix index, i.e. presumably the index of the vector's largest
     * entry.
     */
    private static int indexOfLargestValue(Vector vector) {
        return VectorUtils.largestValues(vector, 1).entrySet().iterator().next().getValue();
    }

    /**
     * Writes the matrix to {@link #outputFile} as tab-separated values with five decimals
     * (one line per cluster), or prints it to standard output when no output file is set.
     *
     * @throws AnalysisEngineProcessException if the output file cannot be opened or written
     */
    @Override
    public void collectionProcessComplete() throws AnalysisEngineProcessException {
        super.collectionProcessComplete();

        if (outputFile == null) {
            // print the matrix
            System.out.println(clusterTopicMatrix);
            return;
        }

        try (PrintWriter pw = new PrintWriter(new FileOutputStream(outputFile))) {
            for (int i = 0; i < clusterTopicMatrix.numRows(); i++) {
                for (int j = 0; j < clusterTopicMatrix.numColumns(); j++) {
                    pw.printf(Locale.ENGLISH, "%.5f\t", clusterTopicMatrix.get(i, j));
                }
                pw.println();
            }

            // PrintWriter never throws on write failures; it only records them (checkError also
            // flushes, so buffered output is covered)
            if (pw.checkError()) {
                throw new AnalysisEngineProcessException(
                        new IOException("Error while writing cluster-topic matrix to " + outputFile));
            }
        } catch (IOException e) {
            throw new AnalysisEngineProcessException(e);
        }
    }
}