com.rokittech.ml.server.utils.MLUtils.java Source code

Java tutorial

Introduction

Here is the source code for com.rokittech.ml.server.utils.MLUtils.java

Source

/*
 * Copyright 2017 ROKITT Inc.
 * (https://www.rokittastra.com)
    
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    
 * This program also uses Spring software that is licensed under
 * the Apache License, Version 2.0 (the "License");
 * you may not use Spring files except in compliance with the License.
 *
 * We are using Spring software with Apache 2 license according to the
 * recommendations of ASF:
 *
 *      https://www.apache.org/licenses/GPL-compatibility.html
 *
 * You may obtain a copy of the Apache 2 License at
    
 *      http://www.apache.org/licenses/LICENSE-2.0
    
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License and Apache 2 License for more details.
 *
 */
package com.rokittech.ml.server.utils;

import com.rokittech.ml.server.dto.MLInstance;
import weka.classifiers.Classifier;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.lazy.IBk;
import weka.classifiers.meta.Bagging;
import weka.classifiers.trees.DecisionStump;
import weka.classifiers.trees.J48;
import weka.classifiers.trees.RandomForest;
import weka.classifiers.trees.RandomTree;
import weka.core.Instance;
import weka.core.SparseInstance;
import weka.core.Utils;

import java.util.List;
import java.util.Locale;

import static com.rokittech.ml.server.utils.ValidatorUtils.notEmpty;
import static com.rokittech.ml.server.utils.ValidatorUtils.notNull;

/**
 * Static helpers that turn {@link MLInstance} DTOs into Weka {@link Instance}s and map
 * algorithm names onto Weka {@link Classifier} implementations.
 *
 * Created by andrii.zavalnyi on 3/11/17.
 */
public class MLUtils {

    /**
     * Feature names this class recognises. Values are always written into an attribute vector in
     * this order (F1 first), regardless of the order of the caller's feature list.
     */
    private static final String[] ML_ATTRIBUTE_NAMES = new String[] { "F1", "F2", "F3", "F4", "F5", "F6", "F7",
            "F8", "F9", "F10", "F11", "F12", "F13", "F14" };

    /**
     * Builds a weight-1 {@link SparseInstance} whose last written slot carries the class value
     * derived from {@code mlInstance.getIsReal()} (1.0 when true, 0.0 otherwise).
     *
     * @param mlInstance source of the feature values and of the class flag
     * @param features   names of the features to include; only F1..F14 are recognised
     * @return the populated instance
     */
    public static Instance toTestInstance(MLInstance mlInstance, List<String> features) {
        return new SparseInstance(1, toTestAttribute(mlInstance, features));
    }

    /**
     * Misspelled original name of {@link #toTestInstance(MLInstance, List)}, kept so existing
     * callers continue to compile. New code should call {@code toTestInstance}.
     *
     * @param mlInstance source of the feature values and of the class flag
     * @param features   names of the features to include; only F1..F14 are recognised
     * @return the populated instance
     */
    public static Instance toTesInstance(MLInstance mlInstance, List<String> features) {
        return toTestInstance(mlInstance, features);
    }

    /** Attribute vector with the class slot filled from {@code getIsReal()}. */
    private static double[] toTestAttribute(MLInstance mlInstance, List<String> features) {
        return createAttributeFromMLInstance(mlInstance, features, true);
    }

    /**
     * Builds a weight-1 {@link SparseInstance} whose class slot is left as a Weka missing value.
     *
     * @param mlInstance source of the feature values
     * @param features   names of the features to include; only F1..F14 are recognised
     * @return the populated instance
     */
    public static Instance toInstance(MLInstance mlInstance, List<String> features) {
        return new SparseInstance(1, toAttribute(mlInstance, features));
    }

    /**
     * Builds the raw attribute vector with the class slot set to a Weka missing value.
     *
     * @param mlInstance source of the feature values
     * @param features   names of the features to include; only F1..F14 are recognised
     * @return array of length {@code features.size() + 1}
     */
    public static double[] toAttribute(MLInstance mlInstance, List<String> features) {
        return createAttributeFromMLInstance(mlInstance, features, false);
    }

    /**
     * Copies the requested feature values of {@code mlInstance} into a new array, followed by
     * the class value.
     *
     * The array has {@code features.size() + 1} slots. Recognised features are written
     * contiguously from slot 0 in F1..F14 order, and the class value goes into the slot right
     * after the last one written. Consequently, if {@code features} holds duplicates or names
     * outside F1..F14, the class value is not in the final slot and the remaining slots stay 0.0.
     * NOTE(review): callers presumably pass only distinct recognised names - confirm.
     *
     * @param isTest when {@code true} the class slot is 1.0/0.0 from {@code getIsReal()};
     *               when {@code false} it is a Weka missing value
     */
    private static double[] createAttributeFromMLInstance(MLInstance mlInstance, List<String> features,
            boolean isTest) {
        double[] vals = new double[features.size() + 1];
        int index = 0;
        for (int position = 0; position < ML_ATTRIBUTE_NAMES.length; position++) {
            if (features.contains(ML_ATTRIBUTE_NAMES[position])) {
                vals[index] = fillVal(featureValue(mlInstance, position));
                index++;
            }
        }
        if (isTest) {
            vals[index] = mlInstance.getIsReal() ? 1.0d : 0.0d;
        } else {
            vals[index] = Utils.missingValue();
        }
        return vals;
    }

    /**
     * Reads the feature stored at {@code position} of {@link #ML_ATTRIBUTE_NAMES} from the DTO.
     * Only the requested getter is invoked, so getters of unselected features are never called.
     */
    private static Double featureValue(MLInstance mlInstance, int position) {
        switch (position) {
        case 0:
            return mlInstance.getF1();
        case 1:
            return mlInstance.getF2();
        case 2:
            return mlInstance.getF3();
        case 3:
            return mlInstance.getF4();
        case 4:
            return mlInstance.getF5();
        case 5:
            return mlInstance.getF6();
        case 6:
            return mlInstance.getF7();
        case 7:
            return mlInstance.getF8();
        case 8:
            return mlInstance.getF9();
        case 9:
            return mlInstance.getF10();
        case 10:
            return mlInstance.getF11();
        case 11:
            return mlInstance.getF12();
        case 12:
            return mlInstance.getF13();
        case 13:
            return mlInstance.getF14();
        default:
            // Unreachable while callers iterate over ML_ATTRIBUTE_NAMES only.
            throw new IllegalArgumentException("No feature at position " + position);
        }
    }

    /** Substitutes Weka's missing-value marker for a {@code null} feature value. */
    private static double fillVal(Double o) {
        return o == null ? Utils.missingValue() : o;
    }

    /**
     * Creates a new, untrained Weka classifier for the given algorithm name.
     *
     * The name is matched case-insensitively. Upper-casing uses {@link Locale#ROOT} so the match
     * does not depend on the JVM default locale (e.g. under a Turkish locale "ibk" would
     * otherwise not upper-case to "IBK").
     * NOTE(review): "BOOSTING" yields a plain {@link DecisionStump}, not a boosting
     * meta-classifier - confirm this is intended.
     *
     * @param mlAlgorithm one of J48, IBK, NAIVE_BAYES, RANDOM_TREE, RANDOM_FOREST, BOOSTING,
     *                    BAGGING; first checked with {@code ValidatorUtils.notEmpty}
     * @return a fresh classifier instance
     * @throws UnsupportedOperationException if the name is not one of the supported values
     */
    public static Classifier getClassifier(String mlAlgorithm) {
        notEmpty(mlAlgorithm);
        switch (mlAlgorithm.toUpperCase(Locale.ROOT)) {
        case "J48":
            return new J48();
        case "IBK":
            return new IBk();
        case "NAIVE_BAYES":
            return new NaiveBayes();
        case "RANDOM_TREE":
            return new RandomTree();
        case "RANDOM_FOREST":
            return new RandomForest();
        case "BOOSTING":
            return new DecisionStump();
        case "BAGGING":
            return new Bagging();
        default:
            throw new UnsupportedOperationException("Classifier " + mlAlgorithm + " is not supported.");
        }
    }

    /**
     * Casts {@code o} to the concrete classifier type associated with the given algorithm name
     * and returns it as a {@link Classifier}.
     *
     * The name is matched exactly as in {@link #getClassifier(String)}, using
     * {@link Locale#ROOT} for upper-casing.
     *
     * @param o           the object to cast; first checked with {@code ValidatorUtils.notNull}
     * @param mlAlgorithm algorithm name selecting the expected concrete type; first checked with
     *                    {@code ValidatorUtils.notEmpty}
     * @return {@code o}, typed as a classifier
     * @throws ClassCastException            if {@code o} is not of the type the name selects
     * @throws UnsupportedOperationException if the name is not one of the supported values
     */
    public static Classifier castClassifier(Object o, String mlAlgorithm) {
        notEmpty(mlAlgorithm);
        notNull(o);
        switch (mlAlgorithm.toUpperCase(Locale.ROOT)) {
        case "J48":
            return (J48) o;
        case "IBK":
            return (IBk) o;
        case "NAIVE_BAYES":
            return (NaiveBayes) o;
        case "RANDOM_TREE":
            return (RandomTree) o;
        case "RANDOM_FOREST":
            return (RandomForest) o;
        case "BOOSTING":
            return (DecisionStump) o;
        case "BAGGING":
            return (Bagging) o;
        default:
            throw new UnsupportedOperationException("Classifier " + mlAlgorithm + " is not supported.");
        }
    }
}