edu.cmu.lti.oaqa.baseqa.providers.ml.classifiers.WekaProvider.java Source code


Introduction

Here is the source code for edu.cmu.lti.oaqa.baseqa.providers.ml.classifiers.WekaProvider.java.

Source

/*
 * Open Advancement Question Answering (OAQA) Project Copyright 2016 Carnegie Mellon University
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations
 * under the License.
 */

package edu.cmu.lti.oaqa.baseqa.providers.ml.classifiers;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multiset;
import com.google.common.io.Files;
import edu.cmu.lti.oaqa.ecd.config.ConfigurableProvider;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.resource.ResourceSpecifier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.*;
import weka.core.converters.ArffSaver;

import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.stream.IntStream;

import static java.util.stream.Collectors.toMap;

/**
 * <p>
 *   A {@link ClassifierProvider} that wraps <a href="http://www.cs.waikato.ac.nz/ml/weka/">Weka</a>
 *   classifiers.
 *   A descriptor of this {@link ConfigurableProvider} should specify the actual classifier
 *   (fully qualified class name) via the <tt>classifier-name</tt> parameter.
 * </p>
 * <p>
 *   Other parameters include <tt>model-file</tt>, <tt>dataset-schema-file</tt>,
 *   <tt>dataset-export</tt>, <tt>options</tt>, and <tt>balance-weight</tt>.
 * </p>
 * <p>
 *   Note that Weka is licensed under the GPL!
 * </p>
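 * <p>
 *   A minimal descriptor sketch (the parameter names are the ones read in
 *   {@link #initialize}; the YAML layout and file paths are illustrative assumptions):
 * </p>
 * <pre>
 * classifier-name: weka.classifiers.functions.Logistic
 * options: ['-R', '1.0']
 * model-file: result/weka-logistic.model
 * dataset-schema-file: result/weka-logistic.schema
 * dataset-export: result/training-dataset.arff
 * balance-weight: true
 * </pre>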
 *
 * @author <a href="mailto:ziy@cs.cmu.edu">Zi Yang</a> created on 4/8/15
 */
public class WekaProvider extends ConfigurableProvider implements ClassifierProvider {

    private File modelFile;

    private File datasetSchemaFile;

    private Classifier classifier;

    private Instances datasetSchema;

    private File datasetExportFile;

    private String classifierName;

    private String[] options;

    private boolean balanceWeight;

    private static final Logger LOG = LoggerFactory.getLogger(WekaProvider.class);

    @Override
    @SuppressWarnings("unchecked")
    public boolean initialize(ResourceSpecifier aSpecifier, Map<String, Object> aAdditionalParams)
            throws ResourceInitializationException {
        boolean ret = super.initialize(aSpecifier, aAdditionalParams);
        // model: load a previously trained classifier if the model file already exists
        if ((modelFile = new File((String) getParameterValue("model-file"))).exists()) {
            try {
                classifier = (Classifier) SerializationHelper.read(modelFile.getAbsolutePath());
            } catch (Exception e) {
                throw new ResourceInitializationException(e);
            }
        }
        // dataset schema: load the attribute layout saved by a previous training run
        if ((datasetSchemaFile = new File((String) getParameterValue("dataset-schema-file"))).exists()) {
            try {
                datasetSchema = (Instances) SerializationHelper.read(datasetSchemaFile.getAbsolutePath());
            } catch (Exception e) {
                throw new ResourceInitializationException(e);
            }
        }
        // optional export of the training instances as an ARFF file
        Object datasetExport;
        if ((datasetExport = getParameterValue("dataset-export")) != null) {
            datasetExportFile = new File(String.class.cast(datasetExport));
        }
        // classifier configuration: class name, Weka options, and class-balancing flag
        classifierName = String.class.cast(getParameterValue("classifier-name"));
        options = Iterables.toArray((Iterable<String>) getParameterValue("options"), String.class);
        balanceWeight = (boolean) getParameterValue("balance-weight");
        return ret;
    }

    @Override
    public Map<String, Double> infer(Map<String, Double> features) throws AnalysisEngineProcessException {
        Instances testInstances = new Instances(datasetSchema, 1);
        Instance instance = newInstance(features, null, 1.0, testInstances);
        double[] probs;
        try {
            probs = classifier.distributionForInstance(instance);
        } catch (Exception e) {
            throw new AnalysisEngineProcessException(e);
        }
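        // map each class index to its label name, paired with the predicted probability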
        return IntStream.range(0, probs.length).boxed()
                .collect(toMap(i -> datasetSchema.classAttribute().value(i), i -> probs[i]));
    }

    @Override
    public void train(List<Map<String, Double>> X, List<String> Y, boolean crossValidation)
            throws AnalysisEngineProcessException {
        // create attribute (including label) info
        ArrayList<Attribute> attributes = new ArrayList<>();
        ClassifierProvider.featureNames(X).stream().map(Attribute::new).forEachOrdered(attributes::add);
        Attribute label = new Attribute("__label__", ClassifierProvider.labelNames(Y));
        attributes.add(label);
        String name = Files.getNameWithoutExtension(modelFile.getName());
        datasetSchema = new Instances(name, attributes, X.size());
        datasetSchema.setClass(label);
        // add instances
        Instances trainingInstances = new Instances(datasetSchema, X.size());
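        // optionally up-weight each instance by (majority class count / its class count),
        // so that every label contributes equal total weight to training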
        if (balanceWeight) {
            Multiset<String> labelCounts = HashMultiset.create(Y);
            double maxCount = labelCounts.entrySet().stream().mapToInt(Multiset.Entry::getCount).max()
                    .orElseThrow(AnalysisEngineProcessException::new);
            for (int i = 0; i < X.size(); i++) {
                String y = Y.get(i);
                double weight = maxCount / labelCounts.count(y);
                trainingInstances.add(newInstance(X.get(i), y, weight, trainingInstances));
            }
        } else {
            for (int i = 0; i < X.size(); i++) {
                trainingInstances.add(newInstance(X.get(i), Y.get(i), 1.0, trainingInstances));
            }
        }
        // training
        try {
            classifier = AbstractClassifier.forName(classifierName, options);
            classifier.buildClassifier(trainingInstances);
        } catch (Exception e) {
            throw new AnalysisEngineProcessException(e);
        }
        // write model and dataset schema
        try {
            SerializationHelper.write(modelFile.getAbsolutePath(), classifier);
            SerializationHelper.write(datasetSchemaFile.getAbsolutePath(), datasetSchema);
        } catch (Exception e) {
            throw new AnalysisEngineProcessException(e);
        }
        // back up the training dataset as an ARFF file
        if (datasetExportFile != null) {
            try {
                ArffSaver saver = new ArffSaver();
                saver.setInstances(trainingInstances);
                saver.setFile(datasetExportFile);
                saver.writeBatch();
            } catch (IOException e) {
                throw new AnalysisEngineProcessException(e);
            }
        }
        if (crossValidation) {
            try {
                Evaluation eval = new Evaluation(trainingInstances);
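                // an unseeded Random makes the 10-fold split, and thus the summary, vary across runs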
                Random rand = new Random();
                eval.crossValidateModel(classifier, trainingInstances, 10, rand);
                LOG.debug(eval.toSummaryString());
            } catch (Exception e) {
                throw new AnalysisEngineProcessException(e);
            }
        }
    }

    private static Instance newInstance(Map<String, Double> features, String label, double weight,
            Instances dataset) {
        double[] values = new double[dataset.numAttributes()];
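        // copy each known feature value into its attribute slot; features absent from the
        // dataset schema are silently ignored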
        for (Map.Entry<String, Double> entry : features.entrySet()) {
            Attribute attribute = dataset.attribute(entry.getKey());
            if (attribute == null)
                continue;
            values[attribute.index()] = entry.getValue();
        }
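        // a SparseInstance stores only the non-zero values, which suits sparse feature maps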
        SparseInstance instance = new SparseInstance(weight, values);
        instance.setDataset(dataset);
        if (label != null)
            instance.setClassValue(label);
        return instance;
    }

}
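
Example usage

The class above is driven by the OAQA framework, but its core Weka calls can be exercised on their own. Below is a hypothetical, self-contained sketch of the same schema/train/infer cycle; the class name, feature names, labels, and classifier options are made up for illustration, while the Weka API calls mirror those in WekaProvider.

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SparseInstance;

import java.util.ArrayList;
import java.util.Arrays;

public class WekaProviderSketch {

    public static void main(String[] args) throws Exception {
        // 1. Build the dataset schema: one numeric attribute per feature plus a nominal label,
        //    as WekaProvider.train() does.
        ArrayList<Attribute> attributes = new ArrayList<>();
        attributes.add(new Attribute("score"));
        attributes.add(new Attribute("length"));
        Attribute label = new Attribute("__label__", Arrays.asList("yes", "no"));
        attributes.add(label);
        Instances schema = new Instances("demo", attributes, 0);
        schema.setClass(label);

        // 2. Add training instances created against the shared schema.
        Instances training = new Instances(schema, 4);
        training.add(newInstance(new double[] {0.9, 5.0}, "yes", 1.0, training));
        training.add(newInstance(new double[] {0.8, 4.0}, "yes", 1.0, training));
        training.add(newInstance(new double[] {0.1, 9.0}, "no", 1.0, training));
        training.add(newInstance(new double[] {0.2, 8.0}, "no", 1.0, training));

        // 3. Instantiate the classifier by name with options, as initialize()/train() do.
        Classifier classifier = AbstractClassifier.forName(
                "weka.classifiers.functions.Logistic", new String[] {"-R", "1.0"});
        classifier.buildClassifier(training);

        // 4. Infer a class distribution for an unlabeled instance, as infer() does.
        Instance query = newInstance(new double[] {0.7, 6.0}, null, 1.0, training);
        double[] probs = classifier.distributionForInstance(query);
        for (int i = 0; i < probs.length; i++) {
            System.out.printf("%s -> %.3f%n", training.classAttribute().value(i), probs[i]);
        }
    }

    // same construction as WekaProvider.newInstance, but from a plain double[]
    private static Instance newInstance(double[] featureValues, String label, double weight,
            Instances dataset) {
        double[] values = new double[dataset.numAttributes()];
        System.arraycopy(featureValues, 0, values, 0, featureValues.length);
        SparseInstance instance = new SparseInstance(weight, values);
        instance.setDataset(dataset);
        if (label != null)
            instance.setClassValue(label);
        return instance;
    }
}

Creating every Instance against a shared schema is what lets a model trained in one run score feature maps produced in another: the saved dataset-schema-file preserves the attribute order that infer() relies on.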