org.opentox.qsar.processors.trainers.regression.SVMTrainer.java Source code

Java tutorial

Introduction

Here is the source code for org.opentox.qsar.processors.trainers.regression.SVMTrainer.java

Source

/*
 *
 * YAQP - Yet Another QSAR Project:
 * Machine Learning algorithms designed for the prediction of toxicological
 * features of chemical compounds become available on the Web. Yaqp is developed
 * under OpenTox (http://opentox.org) which is an FP7-funded EU research project.
 * This project was developed at the Automatic Control Lab in the Chemical Engineering
 * School of the National Technical University of Athens. Please read README for more
 * information.
 *
 * Copyright (C) 2009-2010 Pantelis Sopasakis & Charalampos Chomenides
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 * Contact:
 * Pantelis Sopasakis
 * chvng@mail.ntua.gr
 * Address: Iroon Politechniou St. 9, Zografou, Athens Greece
 * tel. +30 210 7723236
 */
package org.opentox.qsar.processors.trainers.regression;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Locale;
import org.opentox.config.ServerFolders;
import org.opentox.core.exceptions.Cause;
import org.opentox.ontology.components.Feature;
import org.opentox.ontology.components.QSARModel;
import org.opentox.ontology.components.QSARModel.ModelStatus;
import org.opentox.ontology.util.AlgorithmParameter;
import org.opentox.ontology.util.YaqpAlgorithms;
import org.opentox.ontology.util.vocabulary.ConstantParameters;
import org.opentox.qsar.exceptions.QSARException;
import org.opentox.qsar.processors.filters.AttributeCleanup;
import org.opentox.qsar.processors.filters.AttributeCleanup.ATTRIBUTE_TYPE;
import org.opentox.qsar.processors.filters.SimpleMVHFilter;
import org.opentox.qsar.processors.trainers.WekaTrainer;
import org.opentox.www.rest.components.YaqpForm;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.SVMreg;
import weka.classifiers.functions.supportVector.Kernel;
import weka.classifiers.functions.supportVector.PolyKernel;
import weka.classifiers.functions.supportVector.RBFKernel;
import weka.core.Instances;
import weka.core.converters.ArffSaver;

/**
 * Trainer for support vector machine (SVM) regression models, backed by Weka's
 * {@link SVMreg}.
 *
 * <p>The training parameters (gamma, cost, epsilon, coeff0, cache size, degree,
 * tolerance and kernel) start from the defaults in
 * {@link ConstantParameters#SVMParams()}. They may be overridden through the
 * {@link YaqpForm} passed to {@link #SVMTrainer(YaqpForm)}, which validates every
 * effective value and registers it with {@code putParameter}. The registered
 * parameters are then stored in the resulting {@link QSARModel}.
 *
 * @author Pantelis Sopasakis
 * @author Charalampos Chomenides
 */
final public class SVMTrainer extends WekaRegressor {

    /**
     * The parameter gamma (used by the RBF kernel).
     */
    private double gamma = Double
            .parseDouble(ConstantParameters.SVMParams().get(ConstantParameters.gamma).paramValue.toString());
    /**
     * The cost used in the trainer's cost function (Weka's complexity constant C).
     */
    private double cost = Double
            .parseDouble(ConstantParameters.SVMParams().get(ConstantParameters.cost).paramValue.toString());
    /**
     * Parameter epsilon of the e-SVM training algorithm.
     */
    private double epsilon = Double
            .parseDouble(ConstantParameters.SVMParams().get(ConstantParameters.epsilon).paramValue.toString());
    /**
     * The bias of the kernel function of the SVM model.
     * NOTE(review): this value is recorded as a model parameter, but train() does not
     * pass it to any kernel.
     */
    private double coeff0 = Double
            .parseDouble(ConstantParameters.SVMParams().get(ConstantParameters.coeff0).paramValue.toString());
    /**
     * Maximum cache size of the kernel.
     */
    private int cacheSize = Integer
            .parseInt(ConstantParameters.SVMParams().get(ConstantParameters.cacheSize).paramValue.toString());
    /**
     * Degree of a polynomial kernel.
     */
    private int degree = Integer
            .parseInt(ConstantParameters.SVMParams().get(ConstantParameters.degree).paramValue.toString());
    /**
     * Convergence criterion.
     */
    private double tolerance = Double
            .parseDouble(ConstantParameters.SVMParams().get(ConstantParameters.tolerance).paramValue.toString());
    /**
     * The kernel of the SVM model: one of RBF, LINEAR or POLYNOMIAL (compared
     * case-insensitively by train()).
     */
    private String kernel = ConstantParameters.SVMParams().get(ConstantParameters.kernel).paramValue.toString();

    /**
     * Creates an SVM trainer configured from the given form.
     *
     * <p>Each parameter present in the form overrides its default; absent parameters
     * keep their defaults. Every effective value is registered with
     * {@code putParameter}.
     *
     * @param form the request form carrying the (optional) SVM parameters
     * @throws QSARException if a parameter that must be numeric or integer cannot be
     *         parsed; if gamma, cost or epsilon is not strictly positive (or is NaN);
     *         if the tolerance is below 1E-6 (or is NaN); or if the kernel is not one
     *         of RBF, LINEAR or POLYNOMIAL
     */
    @SuppressWarnings({ "unchecked" })
    public SVMTrainer(final YaqpForm form) throws QSARException {
        super(form);

        // CHECK GAMMA
        // NaN compares false with everything, so each range check rejects it explicitly.
        try {
            if (form.getFirstValue(ConstantParameters.gamma) != null) {
                this.gamma = Double.parseDouble(form.getFirstValue(ConstantParameters.gamma));
            }
            if (Double.isNaN(gamma) || gamma <= 0) {
                throw new QSARException(Cause.XQReg3002, "The parameter gamma must be strictly positive. "
                        + "You provided the illegal value: {" + gamma + "}");
            }
        } catch (final NumberFormatException ex) {
            throw new QSARException(Cause.XQReg3001, "Parameter gamma should be numeric. "
                    + "You provided the illegal value : {" + form.getFirstValue(ConstantParameters.gamma) + "}",
                    ex);
        }
        putParameter(ConstantParameters.gamma, new AlgorithmParameter(gamma));

        // CHECK COST
        try {
            if (form.getFirstValue(ConstantParameters.cost) != null) {
                this.cost = Double.parseDouble(form.getFirstValue(ConstantParameters.cost));
            }
            if (Double.isNaN(cost) || cost <= 0) {
                throw new QSARException(Cause.XQReg3004, "The parameter " + ConstantParameters.cost
                        + " must be strictly positive. " + "You provided the illegal value: {" + cost + "}");
            }
        } catch (final NumberFormatException ex) {
            throw new QSARException(Cause.XQReg3003, "Parameter " + ConstantParameters.cost + " should be numeric. "
                    + "You provided the illegal " + "value : {" + form.getFirstValue(ConstantParameters.cost) + "}",
                    ex);
        }
        putParameter(ConstantParameters.cost, new AlgorithmParameter(cost));

        // CHECK EPSILON
        try {
            if (form.getFirstValue(ConstantParameters.epsilon) != null) {
                this.epsilon = Double.parseDouble(form.getFirstValue(ConstantParameters.epsilon));
            }
            if (Double.isNaN(epsilon) || epsilon <= 0) {
                throw new QSARException(Cause.XQReg3006, "The parameter " + ConstantParameters.epsilon
                        + " must be strictly positive. " + "You provided the illegal value: {" + epsilon + "}");
            }
        } catch (final NumberFormatException ex) {
            throw new QSARException(Cause.XQReg3005,
                    "Parameter " + ConstantParameters.epsilon + " should be numeric. " + "You provided the illegal "
                            + "value : {" + form.getFirstValue(ConstantParameters.epsilon) + "}",
                    ex);
        }
        putParameter(ConstantParameters.epsilon, new AlgorithmParameter(epsilon));

        // CHECK COEFF_0 (any numeric value is accepted)
        try {
            if (form.getFirstValue(ConstantParameters.coeff0) != null) {
                this.coeff0 = Double.parseDouble(form.getFirstValue(ConstantParameters.coeff0));
            }
        } catch (final NumberFormatException ex) {
            throw new QSARException(Cause.XQReg3007,
                    "Parameter " + ConstantParameters.coeff0 + " should be numeric. " + "You provided the illegal "
                            + "value : {" + form.getFirstValue(ConstantParameters.coeff0) + "}",
                    ex);
        }
        putParameter(ConstantParameters.coeff0, new AlgorithmParameter(coeff0));

        // CHECK CACHE SIZE (any integer value is accepted)
        try {
            if (form.getFirstValue(ConstantParameters.cacheSize) != null) {
                this.cacheSize = Integer.parseInt(form.getFirstValue(ConstantParameters.cacheSize));
            }
        } catch (final NumberFormatException ex) {
            throw new QSARException(Cause.XQReg3008,
                    "Parameter " + ConstantParameters.cacheSize + " should be integer. "
                            + "You provided the illegal " + "value : {"
                            + form.getFirstValue(ConstantParameters.cacheSize) + "}",
                    ex);
        }
        putParameter(ConstantParameters.cacheSize, new AlgorithmParameter(cacheSize));

        // CHECK DEGREE (any integer value is accepted)
        try {
            if (form.getFirstValue(ConstantParameters.degree) != null) {
                this.degree = Integer.parseInt(form.getFirstValue(ConstantParameters.degree));
            }
        } catch (final NumberFormatException ex) {
            throw new QSARException(Cause.XQReg3009,
                    "Parameter " + ConstantParameters.degree + " should be integer. " + "You provided the illegal "
                            + "value : {" + form.getFirstValue(ConstantParameters.degree) + "}",
                    ex);
        }
        putParameter(ConstantParameters.degree, new AlgorithmParameter(degree));

        // CHECK TOLERANCE
        try {
            if (form.getFirstValue(ConstantParameters.tolerance) != null) {
                this.tolerance = Double.parseDouble(form.getFirstValue(ConstantParameters.tolerance));
            }
            if (Double.isNaN(tolerance) || tolerance < 1E-6) {
                throw new QSARException(Cause.XQReg3011, "The parameter " + ConstantParameters.tolerance
                        + " must be greater than or equal to 1E-6. " + "You provided the illegal value: {"
                        + tolerance + "}");
            }
        } catch (final NumberFormatException ex) {
            throw new QSARException(Cause.XQReg3010,
                    "Parameter " + ConstantParameters.tolerance + " should be numeric. "
                            + "You provided the illegal value : {"
                            + form.getFirstValue(ConstantParameters.tolerance) + "}",
                    ex);
        }
        putParameter(ConstantParameters.tolerance, new AlgorithmParameter(tolerance));

        // CHECK KERNEL
        final String requestedKernel = form.getFirstValue(ConstantParameters.kernel);
        if (requestedKernel != null) {
            // Upper-case with a fixed locale: in e.g. a Turkish default locale,
            // "linear".toUpperCase() yields "LİNEAR" and would be rejected.
            this.kernel = requestedKernel.toUpperCase(Locale.ENGLISH);
            if (!kernel.equals("RBF") && !kernel.equals("LINEAR") && !kernel.equals("POLYNOMIAL")) {
                throw new QSARException(Cause.XQReg3012,
                        "The available kernels are [RBF; LINEAR; POLYNOMIAL]. Note that "
                                + "this parameter is not case-sensitive, i.e. rbf is the same as RbF. However you provided "
                                + "the illegal value : {" + kernel + "}");
            }
        }
        putParameter(ConstantParameters.kernel, new AlgorithmParameter(kernel));

    }

    /**
     * Creates an SVM trainer that keeps the default parameter values from
     * {@link ConstantParameters#SVMParams()}. This constructor does not itself call
     * {@code putParameter}.
     */
    public SVMTrainer() {
        super();
    }

    /**
     * Trains an SVM regression model on the given data.
     *
     * <p>The data are first written to a temporary ARFF file under
     * {@code ServerFolders.temp}, which is always deleted before this method returns.
     * An {@link SVMreg} regressor is configured with this trainer's parameters and
     * trained through {@link Evaluation#evaluateModel(weka.classifiers.Classifier, String[])}.
     * The {@code -d} option tells that call to save the model at
     * {@code ServerFolders.models_weka + "/" + uuid}.
     *
     * @param data the training instances; their class attribute becomes the
     *        dependent (and prediction) feature of the model
     * @return the description of the trained model, with status UNDER_DEVELOPMENT
     * @throws QSARException if Weka fails while training the model
     * @throws RuntimeException if the temporary ARFF file cannot be written
     * @throws IllegalArgumentException if the regressor rejects the epsilon/tolerance options
     */
    public QSARModel train(Instances data) throws QSARException {

        // NOTE: The checks (whether data is null and whether the prediction feature is
        //       acceptable) are found in WekaRegressor; the method preprocessData(Instances)
        //       does this job. NOTE(review): train() does not call it itself - confirm the caller does.

        // GET A UUID AND DEFINE THE TEMPORARY FILE WHERE THE TRAINING DATA
        // ARE STORED IN ARFF FORMAT PRIOR TO TRAINING.
        final String rand = java.util.UUID.randomUUID().toString();
        final String temporaryFilePath = ServerFolders.temp + "/" + rand + ".arff";
        final File tempFile = new File(temporaryFilePath);

        try {
            // SAVE THE DATA IN THE TEMPORARY FILE
            saveAsArff(data, tempFile);

            // INITIALIZE THE REGRESSOR
            // NOTE(review): confirm that SVMreg in the Weka version in use honours -P/-T
            // at this level (rather than only through its -I optimizer specification).
            final SVMreg regressor = new SVMreg();
            final String[] regressorOptions = { "-P", Double.toString(epsilon), "-T", Double.toString(tolerance) };
            try {
                regressor.setOptions(regressorOptions);
            } catch (final Exception ex) {
                throw new IllegalArgumentException("Bad options in SVM trainer for epsilon = {" + epsilon + "} or "
                        + "tolerance = {" + tolerance + "}.", ex);
            }
            // The kernel and the cost are applied after setOptions so that option
            // parsing cannot replace them with defaults. Previously the cost was
            // validated and recorded but never given to the regressor.
            regressor.setKernel(buildKernel());
            regressor.setC(cost);

            // PERFORM THE TRAINING
            final String[] generalOptions = { "-c", Integer.toString(data.classIndex() + 1), "-t",
                    temporaryFilePath,
                    // Save the model in the following file
                    "-d", ServerFolders.models_weka + "/" + uuid };
            try {
                Evaluation.evaluateModel(regressor, generalOptions);
            } catch (final Exception ex) {
                throw new QSARException(Cause.XQReg350, "Unexpected condition while trying to train "
                        + "an SVM model. Possible explanation : {" + ex.getMessage() + "}", ex);
            }

            return buildModel(data);
        } finally {
            // The temporary ARFF file is only needed during training; remove it on every path.
            tempFile.delete();
        }
    }

    /**
     * Writes the given instances to {@code destination} in ARFF format. The output
     * stream is always closed.
     *
     * @throws RuntimeException wrapping the IOException if the file cannot be written
     */
    private static void saveAsArff(final Instances data, final File destination) {
        FileOutputStream out = null;
        try {
            final ArffSaver dataSaver = new ArffSaver();
            dataSaver.setInstances(data);
            out = new FileOutputStream(destination);
            dataSaver.setDestination(out);
            dataSaver.writeBatch();
        } catch (final IOException ex) {
            throw new RuntimeException(
                    "Unexpected condition while trying to save the " + "dataset in a temporary ARFF file", ex);
        } finally {
            if (out != null) {
                try {
                    out.close();
                } catch (final IOException ignored) {
                    // Best-effort close: a failed write has already been reported above.
                }
            }
        }
    }

    /**
     * Creates the Weka kernel named by {@link #kernel}, compared case-insensitively.
     *
     * @return an RBF kernel using gamma for RBF; a polynomial kernel of the
     *         configured degree for POLYNOMIAL; a degree-1 polynomial kernel for
     *         LINEAR; or {@code null} if the name is not recognised
     */
    private Kernel buildKernel() {
        if (kernel.equalsIgnoreCase("rbf")) {
            final RBFKernel rbfKernel = new RBFKernel();
            rbfKernel.setGamma(gamma);
            rbfKernel.setCacheSize(cacheSize);
            return rbfKernel;
        }
        if (kernel.equalsIgnoreCase("polynomial")) {
            return polyKernel(degree);
        }
        if (kernel.equalsIgnoreCase("linear")) {
            // A linear kernel is expressed as a polynomial kernel of exponent 1.
            return polyKernel(1.0);
        }
        // NOTE(review): an unrecognised name results in SVMreg.setKernel(null);
        // consider rejecting it explicitly.
        return null;
    }

    /**
     * Creates a polynomial kernel with the given exponent and lower-order terms
     * enabled, using the configured cache size.
     */
    private PolyKernel polyKernel(final double exponent) {
        final PolyKernel polyKernel = new PolyKernel();
        polyKernel.setExponent(exponent);
        polyKernel.setCacheSize(cacheSize);
        polyKernel.setUseLowerOrder(true);
        return polyKernel;
    }

    /**
     * Builds the {@link QSARModel} description of a freshly trained model: its
     * parameters, code (uuid), algorithm, dataset, status and features.
     */
    private QSARModel buildModel(final Instances data) {
        final QSARModel model = new QSARModel();

        model.setParams(getParameters());
        model.setCode(uuid.toString());
        model.setAlgorithm(YaqpAlgorithms.SVM);
        model.setDataset(datasetUri);
        model.setModelStatus(ModelStatus.UNDER_DEVELOPMENT);

        // Every attribute except the class attribute is an independent feature.
        final ArrayList<Feature> independentFeatures = new ArrayList<Feature>();
        for (int i = 0; i < data.numAttributes(); i++) {
            if (i != data.classIndex()) {
                independentFeatures.add(new Feature(data.attribute(i).name()));
            }
        }

        // The model predicts its dependent (class) feature directly.
        final Feature dependentFeature = new Feature(data.classAttribute().name());
        model.setDependentFeature(dependentFeature);
        model.setIndependentFeatures(independentFeatures);
        model.setPredictionFeature(dependentFeature);
        return model;
    }
}