Java tutorial
/* * Copyright 2017 ROKITT Inc. * (https://www.rokittastra.com) * This program is a free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. * This program also uses Spring software that is licensed under * the Apache License, Version 2.0 (the "License"); * you may not use Spring files except in compliance with the License. * * We are using Spring software with Apache 2 license according to the * recommendations of ASF: * * https://www.apache.org/licenses/GPL-compatibility.html * * You may obtain a copy of the Apache 2 License at * http://www.apache.org/licenses/LICENSE-2.0 * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License and Apache 2 License for more details. * */ package com.rokittech.ml.server.utils; import com.rokittech.ml.server.dto.MLInstance; import weka.classifiers.Classifier; import weka.classifiers.bayes.NaiveBayes; import weka.classifiers.lazy.IBk; import weka.classifiers.meta.Bagging; import weka.classifiers.trees.DecisionStump; import weka.classifiers.trees.J48; import weka.classifiers.trees.RandomForest; import weka.classifiers.trees.RandomTree; import weka.core.Instance; import weka.core.SparseInstance; import weka.core.Utils; import java.util.List; import static com.rokittech.ml.server.utils.ValidatorUtils.notEmpty; import static com.rokittech.ml.server.utils.ValidatorUtils.notNull; /** * Created by andrii.zavalnyi on 3/11/17. */ public class MLUtils { private static final String[] ML_ATTRIBUTE_NAMES = new String[] { "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14" }; private static final int ML_ATTR_SIZE = ML_ATTRIBUTE_NAMES.length + 1; public static Instance toTesInstance(MLInstance mlInstance, List<String> features) { return new SparseInstance(1, toTestAttribute(mlInstance, features)); } private static double[] toTestAttribute(MLInstance mlInstance, List<String> features) { return createAttributeFromMLInstance(mlInstance, features, true); } public static Instance toInstance(MLInstance mlInstance, List<String> features) { return new SparseInstance(1, toAttribute(mlInstance, features)); } public static double[] toAttribute(MLInstance mlInstance, List<String> features) { return createAttributeFromMLInstance(mlInstance, features, false); } private static double[] createAttributeFromMLInstance(MLInstance mlInstance, List<String> features, boolean isTest) { int index = 0; double[] vals = new double[features.size() + 1]; if (features.contains(ML_ATTRIBUTE_NAMES[0])) { vals[index] = fillVal(mlInstance.getF1()); index++; } if (features.contains(ML_ATTRIBUTE_NAMES[1])) { vals[index] = fillVal(mlInstance.getF2()); index++; } if (features.contains(ML_ATTRIBUTE_NAMES[2])) { vals[index] = fillVal(mlInstance.getF3()); index++; } if (features.contains(ML_ATTRIBUTE_NAMES[3])) { vals[index] = fillVal(mlInstance.getF4()); index++; } if (features.contains(ML_ATTRIBUTE_NAMES[4])) { vals[index] = fillVal(mlInstance.getF5()); index++; } if (features.contains(ML_ATTRIBUTE_NAMES[5])) { vals[index] = fillVal(mlInstance.getF6()); index++; } if (features.contains(ML_ATTRIBUTE_NAMES[6])) { vals[index] = fillVal(mlInstance.getF7()); index++; } if (features.contains(ML_ATTRIBUTE_NAMES[7])) { vals[index] = fillVal(mlInstance.getF8()); index++; } if (features.contains(ML_ATTRIBUTE_NAMES[8])) { vals[index] = fillVal(mlInstance.getF9()); index++; } if (features.contains(ML_ATTRIBUTE_NAMES[9])) { vals[index] = fillVal(mlInstance.getF10()); index++; } if (features.contains(ML_ATTRIBUTE_NAMES[10])) { vals[index] = fillVal(mlInstance.getF11()); index++; } if (features.contains(ML_ATTRIBUTE_NAMES[11])) { vals[index] = fillVal(mlInstance.getF12()); index++; } if (features.contains(ML_ATTRIBUTE_NAMES[12])) { vals[index] = fillVal(mlInstance.getF13()); index++; } if (features.contains(ML_ATTRIBUTE_NAMES[13])) { vals[index] = fillVal(mlInstance.getF14()); index++; } if (isTest) { vals[index] = mlInstance.getIsReal() ? 1.0d : 0.0d; } else { vals[index] = Utils.missingValue(); } return vals; } private static Double fillVal(Double o) { return o == null ? Utils.missingValue() : o; } public static Classifier getClassifier(String mlAlgorithm) { notEmpty(mlAlgorithm); Classifier classifier; switch (mlAlgorithm.toUpperCase()) { case "J48": { classifier = new J48(); break; } case "IBK": { classifier = new IBk(); break; } case "NAIVE_BAYES": { classifier = new NaiveBayes(); break; } case "RANDOM_TREE": { classifier = new RandomTree(); break; } case "RANDOM_FOREST": { classifier = new RandomForest(); break; } case "BOOSTING": { classifier = new DecisionStump(); break; } case "BAGGING": { classifier = new Bagging(); break; } default: throw new UnsupportedOperationException("Classifier " + mlAlgorithm + " is not supported."); } return classifier; } public static Classifier castClassifier(Object o, String mlAlgorithm) { notEmpty(mlAlgorithm); notNull(o); Classifier classifier; switch (mlAlgorithm.toUpperCase()) { case "J48": { classifier = (J48) o; break; } case "IBK": { classifier = (IBk) o; break; } case "NAIVE_BAYES": { classifier = (NaiveBayes) o; break; } case "RANDOM_TREE": { classifier = (RandomTree) o; break; } case "RANDOM_FOREST": { classifier = (RandomForest) o; break; } case "BOOSTING": { classifier = (DecisionStump) o; break; } case "BAGGING": { classifier = (Bagging) o; break; } default: throw new UnsupportedOperationException("Classifier " + mlAlgorithm + " is not supported."); } return classifier; } }