de.tudarmstadt.ukp.experiments.argumentation.clustering.entropy.ClusterSentencesCollector.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.experiments.argumentation.clustering.entropy.ClusterSentencesCollector.java

Source

/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package de.tudarmstadt.ukp.experiments.argumentation.clustering.entropy;

import de.tudarmstadt.ukp.experiments.argumentation.clustering.ClusterCentroidsMain;
import de.tudarmstadt.ukp.experiments.argumentation.clustering.ClusteringUtils;
import de.tudarmstadt.ukp.experiments.argumentation.clustering.VectorUtils;
import de.tudarmstadt.ukp.dkpro.argumentation.type.Embeddings;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Vector;
import org.apache.commons.io.IOUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.component.JCasConsumer_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;
import java.util.regex.Pattern;

/**
 * UIMA consumer that assigns every {@link Embeddings} annotation to a cluster and appends the
 * annotation's covered text to a per-cluster text file.
 * <p>
 * Cluster centroids are loaded once in {@link #initialize(UimaContext)} from a Java-serialized
 * {@code TreeMap<Integer, Vector>} (produced by {@link ClusterCentroidsMain}). Each output line has
 * the form {@code <score>\t<sentence>}. Files are opened in append mode, so re-running the pipeline
 * adds to existing files instead of overwriting them.
 *
 * @author Ivan Habernal
 */
public class ClusterSentencesCollector extends JCasConsumer_ImplBase {
    /**
     * Output from {@link ClusterCentroidsMain}
     */
    public static final String PARAM_CENTROIDS_FILE = "centroidsFile";
    @ConfigurationParameter(name = PARAM_CENTROIDS_FILE, mandatory = true)
    File centroidsFile;

    /**
     * Directory receiving the per-cluster text files. Optional: when unset, the file names are
     * resolved against the current working directory.
     */
    public static final String PARAM_OUTPUT_DIR = "outputDir";
    @ConfigurationParameter(name = PARAM_OUTPUT_DIR, mandatory = false)
    File outputDir;

    /** Cluster id to centroid vector, as deserialized from {@link #centroidsFile}. */
    TreeMap<Integer, Vector> centroids;

    /** Runs of CR/LF characters; collapsed so that each sentence occupies exactly one line. */
    private static final Pattern LINE_BREAKS = Pattern.compile("[\\r\\n]+");

    /**
     * Loads the serialized cluster centroids from {@link #centroidsFile}.
     *
     * @param context UIMA context
     * @throws ResourceInitializationException if the file cannot be read or deserialized
     */
    @Override
    public void initialize(UimaContext context) throws ResourceInitializationException {
        super.initialize(context);

        // try-with-resources releases the file handle even if deserialization fails.
        // NOTE(review): Java native deserialization - only point this at trusted files.
        try (ObjectInputStream in = new ObjectInputStream(
                new BufferedInputStream(new FileInputStream(centroidsFile)))) {
            @SuppressWarnings("unchecked")
            TreeMap<Integer, Vector> loaded = (TreeMap<Integer, Vector>) in.readObject();
            centroids = loaded;
        } catch (IOException | ClassNotFoundException e) {
            throw new ResourceInitializationException(e);
        }
    }

    /**
     * Assigns each {@link Embeddings} annotation in the CAS to one cluster and appends its covered
     * text to that cluster's file.
     *
     * @param aJCas the CAS to process
     * @throws AnalysisEngineProcessException if writing an output file fails
     */
    @Override
    public void process(JCas aJCas) throws AnalysisEngineProcessException {
        for (Embeddings embeddings : JCasUtil.select(aJCas, Embeddings.class)) {
            DenseVector embeddingsVector = new DenseVector(embeddings.getVector().toArray());

            Vector distanceToClusterCentroidsVector = ClusteringUtils
                    .transformEmbeddingVectorToDistanceToClusterCentroidsVector(embeddingsVector, centroids);

            // Key = the selected vector value, value = its index (the cluster id).
            // NOTE(review): the entry with the LARGEST value is chosen although the vector is named
            // a distance vector; confirm the transform yields a similarity-like score (larger =
            // closer), otherwise this picks the farthest cluster.
            Map.Entry<Double, Integer> entry = VectorUtils.largestValues(distanceToClusterCentroidsVector, 1)
                    .entrySet().iterator().next();
            int cluster = entry.getValue();
            double distance = entry.getKey();

            try {
                appendSentence(cluster, distance, embeddings.getCoveredText());
            } catch (IOException e) {
                throw new AnalysisEngineProcessException(e);
            }
        }
    }

    /**
     * Appends one {@code <distance>\t<sentence>} line (UTF-8) to the file of the given cluster.
     *
     * @param cluster     cluster id; determines the file name
     * @param distance    score written with four decimals
     * @param coveredText sentence text; line breaks are replaced by single spaces
     * @throws IOException if the file cannot be opened or written
     */
    private void appendSentence(int cluster, double distance, String coveredText) throws IOException {
        // NOTE(review): "%3d" pads with spaces (e.g. "  7.txt"); kept for compatibility with existing
        // output, but zero padding ("%03d") was probably intended.
        String fileName = String.format(Locale.ENGLISH, "%3d.txt", cluster);
        // A null outputDir makes File resolve fileName relative to the working directory.
        File file = new File(outputDir, fileName);

        String sentence = LINE_BREAKS.matcher(coveredText).replaceAll(" ");

        // Explicit UTF-8 instead of FileWriter's platform-default charset.
        try (PrintWriter pw = new PrintWriter(new BufferedWriter(new OutputStreamWriter(
                new FileOutputStream(file, true), StandardCharsets.UTF_8)))) {
            pw.printf(Locale.ENGLISH, "%.4f\t%s%n", distance, sentence);
            // PrintWriter swallows IOExceptions; checkError() flushes and reports them.
            if (pw.checkError()) {
                throw new IOException("Failed to write to " + file);
            }
        }
    }
}