io.agi.core.ml.supervised.Svm.java Source code

Java tutorial

Introduction

Here is the source code for io.agi.core.ml.supervised.Svm.java

Source

/*
 * Copyright (c) 2016.
 *
 * This file is part of Project AGI. <http://agi.io>
 *
 * Project AGI is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Project AGI is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Project AGI.  If not, see <http://www.gnu.org/licenses/>.
 */

package io.agi.core.ml.supervised;

import io.agi.core.data.Data;
import io.agi.core.orm.Callback;
import io.agi.core.orm.NamedObject;
import io.agi.core.orm.ObjectMap;
import libsvm.*;
import org.apache.commons.io.FileUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.io.File;
import java.io.Reader;
import java.io.IOException;
import java.io.StringReader;
import java.io.BufferedReader;

/**
 * Support Vector Machine classifier built on libsvm.
 * <p>
 * The trained model is held in memory and mirrored, in libsvm's textual model format, into the
 * {@link SvmConfig} supplied to {@link #setup(SvmConfig)} so that it can be persisted and restored.
 * <p>
 * Created by gideon on 14/12/16.
 */
public class Svm extends NamedObject implements Callback, SupervisedBatchTraining<SvmConfig> {

    /**
     * Historical name of the scratch file used while serialising the model. Serialisation now goes
     * through a uniquely named temporary file instead; the constant is retained so that existing
     * references to it keep compiling.
     */
    public static final String MODEL_FILENAME = "temp_svm.model";

    /** Prefix and suffix of the per-call temporary file used to serialise the model. */
    private static final String TEMP_FILE_PREFIX = "temp_svm";
    private static final String TEMP_FILE_SUFFIX = ".model";

    /** Encoding used to read the serialised model back from disk. */
    private static final String MODEL_FILE_ENCODING = "UTF-8";

    protected static final Logger _logger = LogManager.getLogger();

    private SvmConfig _config;
    private svm_model _model = null;

    public Svm(String name, ObjectMap om) {
        super(name, om);
    }

    /**
     * Callback entry point; delegates to {@link #update()}.
     */
    @Override
    public void call() {
        update();
    }

    /**
     * Intentionally empty: this class does no work per update.
     */
    public void update() {

    }

    /**
     * Stores the configuration and restores the model from it, if the configuration holds one.
     *
     * @param config the configuration that supplies hyper-parameters and stores the model string
     */
    @Override
    public void setup(SvmConfig config) {
        this._config = config;
        loadModel(); // load model if it exists in config object
    }

    /**
     * Discards the in-memory model and clears the model string held by the configuration.
     */
    @Override
    public void reset() {
        _model = null;
        // Clear the persisted copy directly; going through saveModel() with no model would only
        // log spurious errors before arriving at the same state.
        _config.setModelString(null);
    }

    /**
     * Loads the model from the configuration's model string. Does nothing when the string is
     * {@code null} or empty.
     */
    @Override
    public void loadModel() {
        String modelString = _config.getModelString();
        if (modelString != null && modelString.length() != 0) {
            loadModel(modelString);
        }
    }

    /**
     * Parses the given libsvm model text, and on success makes it the current model and writes it
     * back to the configuration. If parsing fails the current model and the configuration are left
     * unchanged and the failure is logged.
     *
     * @param modelString the model in libsvm's textual model format
     */
    @Override
    public void loadModel(String modelString) {
        Reader stringReader = new StringReader(modelString);
        BufferedReader bufferedReader = new BufferedReader(stringReader);

        svm_model loaded;
        try {
            loaded = svm.svm_load_model(bufferedReader);
        } catch (IOException e) {
            _logger.error("Unable to load svm model.");
            _logger.error(e.toString(), e);
            return;
        }

        // NOTE(review): some libsvm versions signal a malformed model by returning null rather
        // than throwing. Guard against it so that a bad string cannot wipe the stored model.
        if (loaded == null) {
            _logger.error("Unable to load svm model.");
            return;
        }

        _model = loaded;
        saveModel();
    }

    /**
     * @return the model string currently held by the configuration
     */
    @Override
    public String getModelString() {
        return _config.getModelString();
    }

    /**
     * Serialises the current model and stores the result in the configuration.
     * <p>
     * When there is no model, or serialisation fails, the configuration's model string is set to
     * {@code null} and {@code null} is returned.
     *
     * @return the serialised model, or {@code null} if it could not be produced
     */
    @Override
    public String saveModel() {
        String modelString = modelString();
        if (modelString == null) {
            _logger.error("Could not save model to config.");
        }

        _config.setModelString(modelString);
        return modelString;
    }

    /**
     * Serialises the model into a string.
     * <p>
     * libsvm is asked to write to a file, so the model is written to a uniquely named temporary
     * file, read back, and the file is then removed. A unique name keeps concurrent instances from
     * overwriting each other's output.
     *
     * @return the model as a string, or {@code null} when no model is defined or an I/O error occurs
     */
    private String modelString() {

        if (_model == null) {
            _logger.error("Cannot save svm model before it is defined");
            return null;
        }

        File modelFile = null;
        try {
            modelFile = File.createTempFile(TEMP_FILE_PREFIX, TEMP_FILE_SUFFIX);
            svm.svm_save_model(modelFile.getAbsolutePath(), _model);
            return FileUtils.readFileToString(modelFile, MODEL_FILE_ENCODING);
        } catch (IOException e) {
            _logger.error("Unable to save svm model.");
            _logger.error(e.toString(), e);
            return null;
        } finally {
            // Best-effort clean-up; fall back to removal at JVM exit if the delete fails now.
            if (modelFile != null && !modelFile.delete()) {
                modelFile.deleteOnExit();
            }
        }
    }

    /**
     * Trains a new model on the given data, replacing any existing model, and saves it to the
     * configuration.
     *
     * @param featuresMatrix   the training feature vectors, one per data point
     * @param classTruthVector the class label of each data point
     */
    public void train(Data featuresMatrix, Data classTruthVector) {

        svm_parameter parameters = setupParameters();

        svm_problem problem = setupProblem(featuresMatrix, classTruthVector);

        _model = svm.svm_train(problem, parameters);

        saveModel(); // save the model to config object
    }

    /**
     * Predicts a class for every data point in the feature matrix and writes the result for data
     * point {@code j} into {@code predictionsVector._values[j]}.
     *
     * @param featuresMatrixTrain the feature vectors to classify, one per data point
     * @param predictionsVector   receives one prediction per data point
     * @throws IllegalStateException if no model has been trained or loaded
     */
    public void predict(Data featuresMatrixTrain, Data predictionsVector) {

        if (_model == null) {
            throw new IllegalStateException("Cannot predict before the svm model is trained or loaded");
        }

        int m = SupervisedUtil.calcMFromFeatureMatrix(featuresMatrixTrain); // m = number of data points
        int n = SupervisedUtil.calcNFromFeatureMatrix(featuresMatrixTrain); // n = feature vector size

        // iterate data points (vectors in the VectorSeries - each vector is a data point)
        for (int j = 0; j < m; ++j) {
            svm_node[] x = toNodes(featuresMatrixTrain, n, j);
            predictionsVector._values[j] = (float) svm.svm_predict(_model, x);
        }
    }

    /**
     * Converts the training data into a libsvm problem definition.
     *
     * @param featuresMatrix   the training feature vectors, one per data point
     * @param classTruthVector the class label of each data point
     * @return the populated problem
     */
    private svm_problem setupProblem(Data featuresMatrix, Data classTruthVector) {

        int m = SupervisedUtil.calcMFromFeatureMatrix(featuresMatrix); // m = number of data points
        int n = SupervisedUtil.calcNFromFeatureMatrix(featuresMatrix); // n = feature vector size

        svm_problem prob = new svm_problem();
        prob.l = m;
        prob.y = new double[m];
        prob.x = new svm_node[m][];

        // iterate data points (vectors in the VectorSeries - each vector is a data point)
        for (int j = 0; j < m; ++j) {
            // The label belongs to the data point, so set it once per row rather than per feature.
            prob.y[j] = SupervisedUtil.getClassTruth(classTruthVector, j);
            prob.x[j] = toNodes(featuresMatrix, n, j);
        }

        return prob;
    }

    /**
     * Builds the dense libsvm node array for one data point.
     *
     * @param featuresMatrix the feature matrix to read from
     * @param n              the feature vector size
     * @param j              the index of the data point
     * @return {@code n} nodes whose indices run from 1 to {@code n}
     */
    private static svm_node[] toNodes(Data featuresMatrix, int n, int j) {
        svm_node[] nodes = new svm_node[n];

        // iterate dimensions of x (elements of the vector)
        for (int i = 0; i < n; ++i) {
            svm_node node = new svm_node();
            node.index = i + 1; // node indices are emitted 1-based
            node.value = SupervisedUtil.getFeatureValue(featuresMatrix, n, j, i);
            nodes[i] = node;
        }

        return nodes;
    }

    /**
     * Builds the libsvm training parameters: a C-SVC with an RBF kernel, using fixed defaults for
     * everything except {@code C} and {@code gamma}, which come from the configuration.
     *
     * @return the populated parameter object
     */
    private svm_parameter setupParameters() {
        svm_parameter param = new svm_parameter();

        // default values
        param.svm_type = svm_parameter.C_SVC;
        param.kernel_type = svm_parameter.RBF;
        param.degree = 3;
        param.coef0 = 0;
        param.nu = 0.5;
        param.cache_size = 40;
        param.eps = 1e-3;
        param.p = 0.1;
        param.shrinking = 1;
        param.probability = 0;
        param.nr_weight = 0;
        param.weight_label = new int[0];
        param.weight = new double[0];

        // values from config
        param.C = _config.getConstraintsViolation();
        param.gamma = _config.getGamma();

        return param;
    }

}