de.tudarmstadt.ukp.experiments.argumentation.clustering.EmbeddingsClutoDataWriter.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.experiments.argumentation.clustering.EmbeddingsClutoDataWriter.java

Source

/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package de.tudarmstadt.ukp.experiments.argumentation.clustering;

import de.tudarmstadt.ukp.experiments.argumentation.clustering.embeddings.EmbeddingsAnnotator;
import de.tudarmstadt.ukp.dkpro.argumentation.type.Embeddings;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.component.JCasConsumer_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.cas.DoubleArray;
import org.apache.uima.resource.ResourceInitializationException;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Collection;
import java.util.Locale;

/**
 * Writes the vector of every {@link Embeddings} annotation found in the processed JCases to a
 * single output file in the Cluto dense matrix format: a heading line with the number of
 * vectors (rows) and their dimension (columns), followed by one line of space-separated values
 * per vector.
 * <p>
 * The number of rows is only known after the whole collection has been processed, so the
 * vectors are first buffered in a temporary file and copied behind the heading in
 * {@link #collectionProcessComplete()}.
 *
 * @author Ivan Habernal
 */
public class EmbeddingsClutoDataWriter extends JCasConsumer_ImplBase {
    /**
     * Name of the mandatory parameter holding the output file. Despite the "FOLDER" in the
     * constant name (kept for backward compatibility), the value denotes a single file.
     */
    public static final String PARAM_OUTPUT_FOLDER = "outputFile";
    @ConfigurationParameter(name = PARAM_OUTPUT_FOLDER, mandatory = true)
    private File outputFile;

    /**
     * Vector dimension reported in the heading when no vector has been written at all;
     * otherwise the dimension of the actually written vectors is used.
     */
    public static final int VECTOR_SIZE = 300;

    /** Explicit encoding for both files, so the output does not depend on the platform default. */
    private static final String ENCODING = "UTF-8";

    /** Marker for "no vector seen yet". */
    private static final int UNKNOWN_DIMENSION = -1;

    private PrintWriter pwTemp;
    private File tmpFile;
    private int collectionSize = 0;
    /** Dimension of the first written vector; all subsequent vectors must match it. */
    private int vectorDimension = UNKNOWN_DIMENSION;

    /**
     * Creates the temporary file that buffers the vectors until the collection is complete.
     *
     * @param context UIMA context passed on to the super class
     * @throws ResourceInitializationException if the temporary file cannot be created or opened
     */
    @Override
    public void initialize(UimaContext context) throws ResourceInitializationException {

        super.initialize(context);

        collectionSize = 0;
        vectorDimension = UNKNOWN_DIMENSION;

        try {
            // temporary file and writer
            tmpFile = File.createTempFile("tmp_cluto", ".dat");
            pwTemp = new PrintWriter(tmpFile, ENCODING);
        } catch (IOException e) {
            // do not leave an orphaned temp file behind if the writer could not be opened
            FileUtils.deleteQuietly(tmpFile);
            throw new ResourceInitializationException(e);
        }
    }

    /**
     * Appends one line per {@link Embeddings} annotation of the given JCas to the temporary
     * file.
     *
     * @param aJCas document to read the embeddings from
     * @throws AnalysisEngineProcessException if the document contains no embeddings, or if a
     *                                        vector has a different dimension than the vectors
     *                                        written before (the matrix would be malformed)
     */
    @Override
    public void process(JCas aJCas) throws AnalysisEngineProcessException {
        Collection<Embeddings> embeddingsCollection = JCasUtil.select(aJCas, Embeddings.class);

        if (embeddingsCollection.isEmpty()) {
            throw new AnalysisEngineProcessException(new IllegalStateException(
                    "No embeddings found in the document. You should annotate it first using "
                            + EmbeddingsAnnotator.class.getName()));
        }

        for (Embeddings embeddings : embeddingsCollection) {
            DoubleArray vector = embeddings.getVector();
            double[] doubles = vector.toArray();

            // the heading declares a single column count, so every row must have that length
            if (vectorDimension == UNKNOWN_DIMENSION) {
                vectorDimension = doubles.length;
            } else if (doubles.length != vectorDimension) {
                throw new AnalysisEngineProcessException(new IllegalStateException(
                        "Inconsistent embedding vector size: expected " + vectorDimension
                                + " but got " + doubles.length));
            }

            // print final vector values
            for (double value : doubles) {
                pwTemp.printf(Locale.ENGLISH, "%f ", value);
            }
            pwTemp.println();

            collectionSize++;
        }
    }

    /**
     * Produces the final output file: the heading line followed by the buffered vectors. The
     * temporary file is deleted afterwards, regardless of whether writing succeeded.
     *
     * @throws AnalysisEngineProcessException if writing the temporary or the output file failed
     */
    @Override
    public void collectionProcessComplete() throws AnalysisEngineProcessException {
        super.collectionProcessComplete();

        try {
            // close() flushes; PrintWriter never throws, so failures must be polled explicitly
            pwTemp.close();
            if (pwTemp.checkError()) {
                throw new IOException("Writing to temporary file " + tmpFile + " failed");
            }

            int columns = vectorDimension == UNKNOWN_DIMENSION ? VECTOR_SIZE : vectorDimension;

            try (PrintWriter pw = new PrintWriter(outputFile, ENCODING);
                    FileInputStream tmpStream = new FileInputStream(tmpFile)) {
                pw.printf(Locale.ENGLISH, "%d %d%n", collectionSize, columns);

                // copy the rest
                IOUtils.copy(tmpStream, pw, ENCODING);

                pw.flush();
                if (pw.checkError()) {
                    throw new IOException("Writing to output file " + outputFile + " failed");
                }
            }
        } catch (IOException e) {
            throw new AnalysisEngineProcessException(e);
        } finally {
            // delete tmp file
            FileUtils.deleteQuietly(tmpFile);
        }
    }
}