edu.cmu.lti.oaqa.baseqa.providers.ml.classifiers.LibSvmProvider.java Source code

Java tutorial

Introduction

Here is the source code for edu.cmu.lti.oaqa.baseqa.providers.ml.classifiers.LibSvmProvider.java

Source

/*
 * Open Advancement Question Answering (OAQA) Project Copyright 2016 Carnegie Mellon University
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations
 * under the License.
 */

package edu.cmu.lti.oaqa.baseqa.providers.ml.classifiers;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.io.Files;
import edu.cmu.lti.oaqa.ecd.config.ConfigurableProvider;
import libsvm.*;
import org.apache.commons.codec.Charsets;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.resource.ResourceSpecifier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;

import static java.util.stream.Collectors.toMap;

/**
 * <p>
 *   A {@link ClassifierProvider} that wraps
 *   <a href="https://www.csie.ntu.edu.tw/~cjlin/libsvm/">LibSVM</a> classifiers.
 *   The solver type is hardwired in the code (C_SVC / RBF); the remaining solver parameters use
 *   the defaults of LibSVM's <tt>svm-train</tt> command-line tool.
 * </p>
 * <p>
 *   Parameters include <tt>model-file</tt>, <tt>feat-index-file</tt>, <tt>label-index-file</tt>,
 *   etc.
 *   The latter two files map the string-based feature and label names to integer indexes.
 * </p>
 * <p>
 *   Note that LibSVM is distributed under its own
 *   <a href="http://www.csie.ntu.edu.tw/~cjlin/libsvm/COPYRIGHT">copyright</a>.
 * </p>
 *
 * @author <a href="mailto:ziy@cs.cmu.edu">Zi Yang</a> created on 4/5/15
 */
public class LibSvmProvider extends ConfigurableProvider implements ClassifierProvider {

    /** Number of folds used when {@link #train} is asked to cross-validate. */
    private static final int CV_FOLDS = 10;

    private File featIndexFile;

    private File labelIndexFile;

    private File modelFile;

    // feature id (1-based, used as the LibSVM node index) -> feature name
    private Map<Integer, String> fid2feat;

    // label id (used as the LibSVM class label) -> label name
    private BiMap<Integer, String> lid2label;

    private BiMap<String, Integer> label2lid;

    private svm_model model;

    // training template; gamma == 0 means "1 / number of features", resolved in train()
    private svm_parameter param;

    private static final Logger LOG = LoggerFactory.getLogger(LibSvmProvider.class);

    /**
     * Initializes the provider, loading whichever of the feature index, label index and model
     * files already exist, and prepares the training parameters.
     *
     * @param aSpecifier        passed through to {@code super.initialize}
     * @param aAdditionalParams passed through to {@code super.initialize}
     * @return the value returned by {@code super.initialize}
     * @throws ResourceInitializationException if an existing index or model file cannot be read
     *                                         or parsed
     */
    @Override
    public boolean initialize(ResourceSpecifier aSpecifier, Map<String, Object> aAdditionalParams)
            throws ResourceInitializationException {
        boolean ret = super.initialize(aSpecifier, aAdditionalParams);
        // feature id map
        if ((featIndexFile = new File((String) getParameterValue("feat-index-file"))).exists()) {
            try {
                fid2feat = ClassifierProvider.loadIdKeyMap(featIndexFile);
            } catch (IOException e) {
                throw new ResourceInitializationException(e);
            }
        }
        // label id map
        if ((labelIndexFile = new File((String) getParameterValue("label-index-file"))).exists()) {
            try {
                lid2label = HashBiMap.create(ClassifierProvider.loadIdKeyMap(labelIndexFile));
                label2lid = lid2label.inverse();
            } catch (IOException e) {
                throw new ResourceInitializationException(e);
            }
        }
        // model
        if ((modelFile = new File((String) getParameterValue("model-file"))).exists()) {
            // try-with-resources so the reader is released even if parsing throws
            try (BufferedReader reader = Files.newReader(modelFile, Charsets.UTF_8)) {
                model = svm.svm_load_model(reader);
            } catch (IOException e) {
                throw new ResourceInitializationException(e);
            }
            // svm_load_model reports an unreadable model header by returning null, not throwing
            if (model == null) {
                throw new ResourceInitializationException(
                        new IOException("Unable to parse LibSVM model file " + modelFile));
            }
        }
        param = createDefaultParameter();
        return ret;
    }

    /**
     * Creates the C_SVC / RBF training parameters.
     * <p>
     * Unlike the <tt>svm-train</tt> tool, the {@code svm.svm_train} API applies no defaults. A
     * fresh {@link svm_parameter} has every field at zero, so without this the solver would run
     * with C = 0, eps = 0 and gamma = 0 (a constant RBF kernel). The values below mirror the
     * <tt>svm-train</tt> defaults. {@code gamma} is left at 0 and resolved to
     * 1 / #features in {@link #train}, as <tt>svm-train</tt> does.
     *
     * @return a new parameter object
     */
    private static svm_parameter createDefaultParameter() {
        svm_parameter p = new svm_parameter();
        p.svm_type = svm_parameter.C_SVC;
        p.kernel_type = svm_parameter.RBF;
        p.degree = 3;
        p.gamma = 0;
        p.coef0 = 0;
        p.nu = 0.5;
        p.cache_size = 100;
        p.C = 1;
        p.eps = 1e-3;
        p.p = 0.1;
        p.shrinking = 1;
        p.probability = 0;
        p.nr_weight = 0;
        p.weight_label = new int[0];
        p.weight = new double[0];
        return p;
    }

    /**
     * Encodes a feature map as a dense LibSVM vector over feature ids 1..|fid2feat|.
     * Features missing from {@code features} are encoded as 0. Entries of {@code features}
     * whose name is not in the feature index are ignored.
     *
     * @param features feature name to value
     * @return one node per indexed feature, in increasing index order
     */
    private svm_node[] toSvmNodes(Map<String, Double> features) {
        return IntStream.rangeClosed(1, fid2feat.size()).mapToObj(j -> {
            svm_node node = new svm_node();
            node.index = j;
            node.value = features.getOrDefault(fid2feat.get(j), 0.0);
            return node;
        }).toArray(svm_node[]::new);
    }

    /**
     * Scores every label of the current model for the given features.
     * <p>
     * C_SVC is one-vs-one. For k classes, {@code svm_predict_values} emits k(k-1)/2 pairwise
     * decision values in the order (0,1), (0,2), ..., (k-2,k-1), and a positive value favours
     * the first class of the pair. Each label's score is the sum of its signed pairwise decision
     * values. For two labels this gives {@code +d} for the first model label and {@code -d} for
     * the second.
     *
     * @param features feature name to value
     * @return label name to aggregated decision value
     */
    @Override
    public Map<String, Double> infer(Map<String, Double> features) {
        int nrClass = svm.svm_get_nr_class(model);
        // one decision value per class pair; a k-sized buffer overflows as soon as k >= 4
        double[] pairwise = new double[nrClass * (nrClass - 1) / 2];
        svm.svm_predict_values(model, toSvmNodes(features), pairwise);
        int[] lids = new int[nrClass];
        svm.svm_get_labels(model, lids);
        double[] scores = new double[nrClass];
        int p = 0;
        for (int i = 0; i < nrClass; i++) {
            for (int j = i + 1; j < nrClass; j++) {
                scores[i] += pairwise[p];
                scores[j] -= pairwise[p];
                p++;
            }
        }
        return IntStream.range(0, nrClass).boxed()
                .collect(toMap(i -> lid2label.get(lids[i]), i -> scores[i]));
    }

    /**
     * Predicts the most likely label for the given features.
     *
     * @param features feature name to value
     * @return the predicted label name
     */
    @Override
    public String predict(Map<String, Double> features) {
        // svm_predict returns the winning label id encoded as a double
        double lid = svm.svm_predict(model, toSvmNodes(features));
        return lid2label.get((int) lid);
    }

    /**
     * Builds and saves the feature/label indexes, then trains and saves a model. Optionally
     * runs {@value #CV_FOLDS}-fold cross-validation and logs the accuracy.
     *
     * @param X               one feature map per instance
     * @param Y               one label per instance, aligned with {@code X}
     * @param crossValidation whether to run and log cross-validation after training
     * @throws AnalysisEngineProcessException if {@code X} and {@code Y} differ in size, if the
     *                                        parameters are rejected by LibSVM, or if an index
     *                                        or model file cannot be written
     */
    @Override
    public void train(List<Map<String, Double>> X, List<String> Y, boolean crossValidation)
            throws AnalysisEngineProcessException {
        // explicit check: an assert is disabled in production and LibSVM would fail obscurely
        if (X.size() != Y.size()) {
            throw new AnalysisEngineProcessException(new IllegalArgumentException(
                    "Got " + X.size() + " feature vectors but " + Y.size() + " labels."));
        }
        // create feature to id map
        fid2feat = ClassifierProvider.createFeatureIdKeyMap(X);
        // create label to id map
        lid2label = ClassifierProvider.createLabelIdKeyMap(Y);
        label2lid = lid2label.inverse();
        try {
            ClassifierProvider.saveIdKeyMap(fid2feat, featIndexFile);
            ClassifierProvider.saveIdKeyMap(lid2label, labelIndexFile);
        } catch (IOException e) {
            throw new AnalysisEngineProcessException(e);
        }
        // create libsvm data structure and train
        int dataCount = X.size();
        int featCount = fid2feat.size();
        LOG.info("Training for {} instances, {} features, {} labels.", dataCount, featCount, lid2label.size());
        svm_problem prob = new svm_problem();
        prob.l = dataCount;
        prob.x = X.stream().map(this::toSvmNodes).toArray(svm_node[][]::new);
        prob.y = Y.stream().mapToDouble(label2lid::get).toArray();
        // clone so the "auto" gamma of the template is re-resolved on every call
        svm_parameter trainParam = (svm_parameter) param.clone();
        if (trainParam.gamma == 0 && featCount > 0) {
            trainParam.gamma = 1.0 / featCount;
        }
        String paramError = svm.svm_check_parameter(prob, trainParam);
        if (paramError != null) {
            throw new AnalysisEngineProcessException(new IllegalArgumentException(paramError));
        }
        model = svm.svm_train(prob, trainParam);
        try {
            svm.svm_save_model(modelFile.getAbsolutePath(), model);
        } catch (IOException e) {
            throw new AnalysisEngineProcessException(e);
        }
        if (crossValidation) {
            double[] target = new double[prob.l];
            svm.svm_cross_validation(prob, trainParam, CV_FOLDS, target);
            // labels are integral ids stored as doubles, so exact comparison is safe
            long correct = IntStream.range(0, prob.l).filter(i -> target[i] == prob.y[i]).count();
            LOG.info("{}-fold cross-validation accuracy: {} ({}/{} correct).", CV_FOLDS,
                    (double) correct / prob.l, correct, prob.l);
        }
    }

}