de.tudarmstadt.ukp.dkpro.core.opennlp.OpenNlpNamedEntityRecognizerTrainer.java Source code

Introduction

Here is the source code for de.tudarmstadt.ukp.dkpro.core.opennlp.OpenNlpNamedEntityRecognizerTrainer.java, a UIMA consumer component that trains an OpenNLP named entity recognizer model from the annotated CASes it receives.
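
Before the listing, a minimal uimaFIT sketch of how this trainer might be wired into a training pipeline. The class name, the reader (Conll2002Reader from dkpro-core-io-conll) and all paths are assumptions for illustration; any collection reader that produces token, sentence, and named-entity annotations should work.

import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription;
import static org.apache.uima.fit.factory.CollectionReaderFactory.createReaderDescription;

import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.collection.CollectionReaderDescription;
import org.apache.uima.fit.pipeline.SimplePipeline;

import de.tudarmstadt.ukp.dkpro.core.io.conll.Conll2002Reader;
import de.tudarmstadt.ukp.dkpro.core.opennlp.OpenNlpNamedEntityRecognizerTrainer;

public class TrainNerModelExample {
    public static void main(String[] args) throws Exception {
        // Assumed: CoNLL-2002-style NER training data under src/main/resources/ner
        CollectionReaderDescription reader = createReaderDescription(
                Conll2002Reader.class,
                Conll2002Reader.PARAM_SOURCE_LOCATION, "src/main/resources/ner",
                Conll2002Reader.PARAM_PATTERNS, "*.conll",
                Conll2002Reader.PARAM_LANGUAGE, "en");

        // The trainer writes the serialized OpenNLP model to PARAM_TARGET_LOCATION.
        AnalysisEngineDescription trainer = createEngineDescription(
                OpenNlpNamedEntityRecognizerTrainer.class,
                OpenNlpNamedEntityRecognizerTrainer.PARAM_LANGUAGE, "en",
                OpenNlpNamedEntityRecognizerTrainer.PARAM_TARGET_LOCATION,
                "target/opennlp-ner-model.bin");

        SimplePipeline.runPipeline(reader, trainer);
    }
}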

Source

/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package de.tudarmstadt.ukp.dkpro.core.opennlp;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Collections;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.commons.io.IOUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.component.JCasConsumer_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;

import de.tudarmstadt.ukp.dkpro.core.api.parameter.ComponentParameters;
import de.tudarmstadt.ukp.dkpro.core.opennlp.internal.CasNameSampleStream;
import opennlp.tools.ml.BeamSearch;
import opennlp.tools.ml.EventTrainer;
import opennlp.tools.ml.maxent.GIS;
import opennlp.tools.ml.maxent.quasinewton.QNTrainer;
import opennlp.tools.ml.perceptron.PerceptronTrainer;
import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer;
import opennlp.tools.namefind.BilouCodec;
import opennlp.tools.namefind.BioCodec;
import opennlp.tools.namefind.NameFinderME;
import opennlp.tools.namefind.TokenNameFinderFactory;
import opennlp.tools.namefind.TokenNameFinderModel;
import opennlp.tools.util.SequenceCodec;
import opennlp.tools.util.TrainingParameters;

/**
 * Train a named entity recognizer model for OpenNLP.
 */
public class OpenNlpNamedEntityRecognizerTrainer extends JCasConsumer_ImplBase {
    public enum SequenceEncoding {
        BIO(BioCodec.class), BILOU(BilouCodec.class);

        private Class<? extends SequenceCodec<String>> codec;

        SequenceEncoding(Class<? extends SequenceCodec<String>> aCodec) {
            codec = aCodec;
        }

        private SequenceCodec<String> getCodec() {
            try {
                return codec.newInstance();
            } catch (InstantiationException | IllegalAccessException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    public static final String PARAM_LANGUAGE = ComponentParameters.PARAM_LANGUAGE;
    @ConfigurationParameter(name = PARAM_LANGUAGE, mandatory = true)
    private String language;

    public static final String PARAM_TARGET_LOCATION = ComponentParameters.PARAM_TARGET_LOCATION;
    @ConfigurationParameter(name = PARAM_TARGET_LOCATION, mandatory = true)
    private File targetLocation;

    /**
     * @see GIS#MAXENT_VALUE
     * @see QNTrainer#MAXENT_QN_VALUE
     * @see PerceptronTrainer#PERCEPTRON_VALUE
     * @see SimplePerceptronSequenceTrainer#PERCEPTRON_SEQUENCE_VALUE
     */
    public static final String PARAM_ALGORITHM = "algorithm";
    @ConfigurationParameter(name = PARAM_ALGORITHM, mandatory = true, defaultValue = GIS.MAXENT_VALUE)
    private String algorithm;

    public static final String PARAM_TRAINER_TYPE = "trainerType";
    @ConfigurationParameter(name = PARAM_TRAINER_TYPE, mandatory = true, defaultValue = EventTrainer.EVENT_VALUE)
    private String trainerType;

    public static final String PARAM_ITERATIONS = "iterations";
    @ConfigurationParameter(name = PARAM_ITERATIONS, mandatory = true, defaultValue = "100")
    private int iterations;

    public static final String PARAM_CUTOFF = "cutoff";
    @ConfigurationParameter(name = PARAM_CUTOFF, mandatory = true, defaultValue = "5")
    private int cutoff;

    /**
     * @see NameFinderME#DEFAULT_BEAM_SIZE
     */
    public static final String PARAM_BEAMSIZE = "beamSize";
    @ConfigurationParameter(name = PARAM_BEAMSIZE, mandatory = true, defaultValue = "3")
    private int beamSize;

    public static final String PARAM_FEATURE_GEN = "featureGen";
    @ConfigurationParameter(name = PARAM_FEATURE_GEN, mandatory = false)
    private File featureGen;

    public static final String PARAM_SEQUENCE_ENCODING = "sequenceEncoding";
    @ConfigurationParameter(name = PARAM_SEQUENCE_ENCODING, mandatory = true, defaultValue = "BILOU")
    private SequenceEncoding sequenceEncoding;

    // Training runs on a separate thread; documents are fed to it through the sample stream
    // as they arrive in process(), and the Future yields the trained model at the end.
    private CasNameSampleStream stream;
    private ExecutorService executor = Executors.newSingleThreadExecutor();
    private Future<TokenNameFinderModel> future;

    @Override
    public void initialize(UimaContext aContext) throws ResourceInitializationException {
        super.initialize(aContext);

        stream = new CasNameSampleStream();

        TrainingParameters params = new TrainingParameters();
        params.put(TrainingParameters.ALGORITHM_PARAM, algorithm);
        //        params.put(TrainingParameters.TRAINER_TYPE_PARAM,
        //                TrainerFactory.getTrainerType(params.getSettings()).name());
        params.put(TrainingParameters.ITERATIONS_PARAM, Integer.toString(iterations));
        params.put(TrainingParameters.CUTOFF_PARAM, Integer.toString(cutoff));
        params.put(BeamSearch.BEAM_SIZE_PARAMETER, Integer.toString(beamSize));

        byte[] featureGenCfg = loadFeatureGen(featureGen);

        // Kick off training on the background thread. NameFinderME.train() pulls training
        // samples from the stream as they are provided by process().
        Callable<TokenNameFinderModel> trainTask = () -> {
            try {
                return NameFinderME.train(language, null, stream, params, new TokenNameFinderFactory(featureGenCfg,
                        Collections.<String, Object>emptyMap(), sequenceEncoding.getCodec()));
            } catch (Throwable e) {
                stream.close();
                throw e;
            }
        };

        future = executor.submit(trainTask);
    }

    @Override
    public void process(JCas aJCas) throws AnalysisEngineProcessException {
        // Hand the document over to the training thread unless training has already been aborted.
        if (!future.isCancelled()) {
            stream.send(aJCas);
        }
    }

    @Override
    public void collectionProcessComplete() throws AnalysisEngineProcessException {
        // Closing the stream signals the end of the training data to the background thread.
        try {
            stream.close();
        } catch (IOException e) {
            throw new AnalysisEngineProcessException(e);
        }

        // Wait for training to finish, then release the background thread.
        TokenNameFinderModel model;
        try {
            model = future.get();
        } catch (InterruptedException | ExecutionException e) {
            throw new AnalysisEngineProcessException(e);
        } finally {
            executor.shutdown();
        }

        // Write the trained model to the configured target location.
        try (OutputStream out = new FileOutputStream(targetLocation)) {
            model.serialize(out);
        } catch (IOException e) {
            throw new AnalysisEngineProcessException(e);
        }
    }

    /**
     * Loads the custom feature generator descriptor if one was configured. Returns null
     * otherwise, in which case OpenNLP falls back to its default feature generation.
     */
    private byte[] loadFeatureGen(File aFile) throws ResourceInitializationException {
        byte[] featureGenCfg = null;
        if (aFile != null) {
            try (InputStream in = new FileInputStream(aFile)) {
                featureGenCfg = IOUtils.toByteArray(in);
            } catch (IOException e) {
                throw new ResourceInitializationException(e);
            }
        }
        return featureGenCfg;
    }
}
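
The model written to PARAM_TARGET_LOCATION is a regular OpenNLP name finder model, so it can be loaded with OpenNLP's NameFinderME directly (or, presumably, with DKPro Core's OpenNlpNamedEntityRecognizer by pointing its model location parameter at the file). A minimal sketch using plain OpenNLP; the class name, the path, and the token array are assumptions:

import java.io.FileInputStream;
import java.io.InputStream;

import opennlp.tools.namefind.NameFinderME;
import opennlp.tools.namefind.TokenNameFinderModel;
import opennlp.tools.util.Span;

public class UseTrainedModelExample {
    public static void main(String[] args) throws Exception {
        // Assumed path: wherever PARAM_TARGET_LOCATION pointed during training.
        try (InputStream in = new FileInputStream("target/opennlp-ner-model.bin")) {
            TokenNameFinderModel model = new TokenNameFinderModel(in);
            NameFinderME finder = new NameFinderME(model);

            // Tokens of a single sentence; spans report the entity type and token offsets.
            String[] tokens = { "Angela", "Merkel", "visited", "Darmstadt", "." };
            for (Span span : finder.find(tokens)) {
                System.out.printf("%s [%d..%d]%n", span.getType(), span.getStart(), span.getEnd());
            }
        }
    }
}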