de.tudarmstadt.ukp.dkpro.tc.weka.task.serialization.LoadModelConnectorWeka.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.dkpro.tc.weka.task.serialization.LoadModelConnectorWeka.java

Source

/**
 * Copyright 2015
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universitt Darmstadt
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see http://www.gnu.org/licenses/.
 */
package de.tudarmstadt.ukp.dkpro.tc.weka.task.serialization;

import static de.tudarmstadt.ukp.dkpro.tc.core.Constants.FM_UNIT;
import static de.tudarmstadt.ukp.dkpro.tc.core.Constants.MODEL_CLASSIFIER;
import static de.tudarmstadt.ukp.dkpro.tc.core.Constants.MODEL_CLASS_LABELS;
import static de.tudarmstadt.ukp.dkpro.tc.core.Constants.MODEL_FEATURE_NAMES;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.commons.io.FileUtils;
import org.apache.log4j.Logger;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.descriptor.ExternalResource;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;

import de.tudarmstadt.ukp.dkpro.tc.api.features.FeatureExtractorResource_ImplBase;
import de.tudarmstadt.ukp.dkpro.tc.api.features.Instance;
import de.tudarmstadt.ukp.dkpro.tc.api.type.TextClassificationFocus;
import de.tudarmstadt.ukp.dkpro.tc.api.type.TextClassificationOutcome;
import de.tudarmstadt.ukp.dkpro.tc.core.ml.ModelSerialization_ImplBase;
import de.tudarmstadt.ukp.dkpro.tc.ml.uima.TcAnnotatorDocument;
import de.tudarmstadt.ukp.dkpro.tc.weka.util.WekaUtils;
import weka.classifiers.Classifier;
import weka.core.Attribute;

public class LoadModelConnectorWeka extends ModelSerialization_ImplBase {

    @ConfigurationParameter(name = TcAnnotatorDocument.PARAM_TC_MODEL_LOCATION, mandatory = true)
    private File tcModelLocation;

    public static final String PARAM_BIPARTITION_THRESHOLD = "bipartitionThreshold";
    @ConfigurationParameter(name = PARAM_BIPARTITION_THRESHOLD, mandatory = true, defaultValue = "0.5")
    private String bipartitionThreshold;

    @ExternalResource(key = PARAM_FEATURE_EXTRACTORS, mandatory = true)
    protected FeatureExtractorResource_ImplBase[] featureExtractors;

    @ConfigurationParameter(name = PARAM_LEARNING_MODE, mandatory = true)
    private String learningMode;

    @ConfigurationParameter(name = PARAM_FEATURE_MODE, mandatory = true)
    private String featureMode;

    private Classifier cls;
    private List<Attribute> attributes;
    private List<String> classLabels;

    @Override
    public void initialize(UimaContext context) throws ResourceInitializationException {
        super.initialize(context);

        try {
            cls = (Classifier) weka.core.SerializationHelper
                    .read(new File(tcModelLocation, MODEL_CLASSIFIER).getAbsolutePath());
        } catch (Exception e) {
            throw new ResourceInitializationException(e);
        }

        attributes = new ArrayList<>();
        try {
            for (String attributeName : FileUtils.readLines(new File(tcModelLocation, MODEL_FEATURE_NAMES))) {
                attributes.add(new Attribute(attributeName));
            }
        } catch (IOException e) {
            throw new ResourceInitializationException(e);
        }

        classLabels = new ArrayList<>();
        try {
            for (String classLabel : FileUtils.readLines(new File(tcModelLocation, MODEL_CLASS_LABELS))) {
                classLabels.add(classLabel);
            }
        } catch (IOException e) {
            throw new ResourceInitializationException(e);
        }

    }

    @Override
    public void process(JCas jcas) throws AnalysisEngineProcessException {
        try {
            Logger.getLogger(getClass()).debug("START: process(JCAS) - applying Weka Model");

            Instance instance = de.tudarmstadt.ukp.dkpro.tc.core.util.TaskUtils.getSingleInstance(featureMode,
                    featureExtractors, jcas, false, false);

            weka.core.Instance wekaInstance = WekaUtils.tcInstanceToWekaInstance(instance, attributes, classLabels,
                    false);

            String val = classLabels.get((int) cls.classifyInstance(wekaInstance));

            TextClassificationOutcome outcome = null;
            if (!FM_UNIT.equals(featureMode))
                outcome = JCasUtil.selectSingle(jcas, TextClassificationOutcome.class);
            else
                outcome = getOutcomeForFocus(jcas);

            outcome.setOutcome(val);

            Logger.getLogger(getClass()).debug(
                    "Found classification result \"" + val + "\" for text: \"" + outcome.getCoveredText() + "\"");

            Logger.getLogger(getClass()).debug("END: process(JCAS) - applying Weka Model");
        } catch (Exception e) {
            throw new AnalysisEngineProcessException(new IllegalStateException(e.getMessage()));
        }
    }

    private TextClassificationOutcome getOutcomeForFocus(JCas jcas) {
        TextClassificationOutcome outcome = null;
        TextClassificationFocus focus = null;

        try {
            focus = JCasUtil.selectSingle(jcas, TextClassificationFocus.class);
        } catch (Exception ex) {
            String msg = "Error while trying to retrieve TC focus from CAS. Details: " + ex.getMessage();
            Logger.getLogger(getClass()).error(msg, ex);
            throw new RuntimeException(msg);
        }

        Iterator<TextClassificationOutcome> outcomeIterator = JCasUtil.iterator(focus,
                TextClassificationOutcome.class, false /* ambiguous*/, false /* strict */);

        if (!outcomeIterator.hasNext())
            throw new IllegalStateException("There should be exactly one TC outcome covered by the TC focus from "
                    + focus.getBegin() + " to " + focus.getEnd());

        outcome = outcomeIterator.next();

        if (outcomeIterator.hasNext())
            throw new IllegalStateException("There should be exactly one TC outcome covered by the TC focus from "
                    + focus.getBegin() + " to " + focus.getEnd());

        return outcome;
    }
}