Java tutorial: the WEKA SGD classifier (stochastic gradient descent)
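The listing below is a WEKA classifier (here placed in the org.esa.nest.gpf package) that learns linear models by stochastic gradient descent: a binary-class SVM via hinge loss, binary-class logistic regression via log loss, or linear regression via squared loss. A short usage sketch follows the class.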
package org.esa.nest.gpf;

import weka.classifiers.RandomizableClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/*
 * SGD.java
 * Copyright (C) 2009-2011 University of Waikato, Hamilton, New Zealand
 */

/**
 * <!-- globalinfo-start -->
 * Implements stochastic gradient descent for learning various linear models
 * (binary class SVM, binary class logistic regression and linear regression).
 * Globally replaces all missing values and transforms nominal attributes into
 * binary ones. It also normalizes all attributes, so the coefficients in the
 * output are based on the normalized data.<br/>
 * For numeric class attributes, the squared loss function (2) must be used.
 * <p/>
 * <!-- globalinfo-end -->
 *
 * <!-- options-start -->
 * Valid options are:
 * <p/>
 *
 * <pre> -F
 *  Set the loss function to minimize. 0 = hinge loss (SVM),
 *  1 = log loss (logistic regression), 2 = squared loss (regression).
 *  (default = 0)</pre>
 *
 * <pre> -L
 *  The learning rate. If normalization is turned off (as it is
 *  automatically for streaming data), then the default learning rate
 *  will need to be reduced (try 0.0001).
 *  (default = 0.01)</pre>
 *
 * <pre> -R &lt;double&gt;
 *  The lambda regularization constant (default = 0.0001)</pre>
 *
 * <pre> -E &lt;integer&gt;
 *  The number of epochs to perform (batch learning only, default = 500)</pre>
 *
 * <pre> -N
 *  Don't normalize the data</pre>
 *
 * <pre> -M
 *  Don't replace missing values</pre>
 *
 * <!-- options-end -->
 *
 * @author Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz)
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: 7057 $
 */
public class SGD extends RandomizableClassifier
        implements UpdateableClassifier, OptionHandler {

    /** For serialization */
    private static final long serialVersionUID = -3732968666673530290L;

    /** Replace missing values */
    protected ReplaceMissingValues m_replaceMissing;

    /**
     * Convert nominal attributes to numerically coded binary ones. Uses
     * supervised NominalToBinary in the batch learning case.
     */
    protected Filter m_nominalToBinary;

    /** Normalize the training data */
    protected Normalize m_normalize;

    /** The regularization parameter */
    protected double m_lambda = 0.0001;

    /** The learning rate */
    protected double m_learningRate = 0.01;

    /** Stores the weights (+ bias in the last element) */
    protected double[] m_weights;

    /** Holds the current iteration number */
    protected double m_t;

    /** The number of training instances */
    protected double m_numInstances;

    /**
     * The number of epochs to perform (batch learning). Total iterations is
     * m_epochs * num instances.
     */
    protected int m_epochs = 500;

    /**
     * Turn off normalization of the input data. This option gets forced for
     * incremental training.
     */
    protected boolean m_dontNormalize = false;

    /**
     * Turn off global replacement of missing values.
     * Missing values will be ignored instead. This option gets forced for
     * incremental training.
     */
    protected boolean m_dontReplaceMissing = false;

    /** Holds the header of the training data */
    protected Instances m_data;

    /**
     * Returns default capabilities of the classifier.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        // attributes
        result.enable(Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capability.MISSING_VALUES);

        // class
        if (m_loss == SQUAREDLOSS) {
            result.enable(Capability.NUMERIC_CLASS);
        } else {
            result.enable(Capability.BINARY_CLASS);
        }
        result.enable(Capability.MISSING_CLASS_VALUES);

        // instances
        result.setMinimumNumberInstances(0);

        return result;
    }

    /**
     * Returns the tip text for this property
     *
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String lambdaTipText() {
        return "The regularization constant. (default = 0.0001)";
    }

    /**
     * Set the value of lambda to use
     *
     * @param lambda the value of lambda to use
     */
    public void setLambda(double lambda) {
        m_lambda = lambda;
    }

    /**
     * Get the current value of lambda
     *
     * @return the current value of lambda
     */
    public double getLambda() {
        return m_lambda;
    }

    /**
     * Set the learning rate.
     *
     * @param lr the learning rate to use.
     */
    public void setLearningRate(double lr) {
        m_learningRate = lr;
    }

    /**
     * Get the learning rate.
     *
     * @return the learning rate
     */
    public double getLearningRate() {
        return m_learningRate;
    }

    /**
     * Returns the tip text for this property
     *
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String learningRateTipText() {
        return "The learning rate. If normalization is turned off "
                + "(as it is automatically for streaming data), then "
                + "the default learning rate will need to be reduced "
                + "(try 0.0001).";
    }

    /**
     * Returns the tip text for this property
     *
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String epochsTipText() {
        return "The number of epochs to perform (batch learning). "
                + "The total number of iterations is epochs * num instances.";
    }

    /**
     * Set the number of epochs to use
     *
     * @param e the number of epochs to use
     */
    public void setEpochs(int e) {
        m_epochs = e;
    }

    /**
     * Get current number of epochs
     *
     * @return the current number of epochs
     */
    public int getEpochs() {
        return m_epochs;
    }

    /**
     * Turn normalization off/on.
     *
     * @param m true if normalization is to be disabled.
     */
    public void setDontNormalize(boolean m) {
        m_dontNormalize = m;
    }

    /**
     * Get whether normalization has been turned off.
     *
     * @return true if normalization has been disabled.
     */
    public boolean getDontNormalize() {
        return m_dontNormalize;
    }

    /**
     * Returns the tip text for this property
     *
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String dontNormalizeTipText() {
        return "Turn normalization off";
    }

    /**
     * Turn global replacement of missing values off/on. If turned off, then
     * missing values are effectively ignored.
     *
     * @param m true if global replacement of missing values is to be turned
     *          off.
     */
    public void setDontReplaceMissing(boolean m) {
        m_dontReplaceMissing = m;
    }

    /**
     * Get whether global replacement of missing values has been disabled.
     *
     * @return true if global replacement of missing values has been turned off
     */
    public boolean getDontReplaceMissing() {
        return m_dontReplaceMissing;
    }

    /**
     * Returns the tip text for this property
     *
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String dontReplaceMissingTipText() {
        return "Turn off global replacement of missing values";
    }

    /**
     * Set the loss function to use.
     *
     * @param function the loss function to use.
     */
    public void setLossFunction(SelectedTag function) {
        if (function.getTags() == TAGS_SELECTION) {
            m_loss = function.getSelectedTag().getID();
        }
    }

    /**
     * Get the current loss function.
     *
     * @return the current loss function.
     */
    public SelectedTag getLossFunction() {
        return new SelectedTag(m_loss, TAGS_SELECTION);
    }

    /**
     * Returns the tip text for this property
     *
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String lossFunctionTipText() {
        return "The loss function to use. Hinge loss (SVM), "
                + "log loss (logistic regression) or "
                + "squared loss (regression).";
    }

    /**
     * Returns an enumeration describing the available options.
     *
     * @return an enumeration of all the available options.
     */
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<>();

        newVector.add(new Option("\tSet the loss function to minimize. 0 = "
                + "hinge loss (SVM), 1 = log loss (logistic regression),\n\t"
                + "2 = squared loss (regression).\n\t(default = 0)", "F", 1, "-F"));
        newVector.add(new Option("\tThe learning rate. If normalization is\n"
                + "\tturned off (as it is automatically for streaming data), then the\n\t"
                + "default learning rate will need to be reduced "
                + "(try 0.0001).\n\t(default = 0.01).", "L", 1, "-L"));
        newVector.add(new Option("\tThe lambda regularization constant "
                + "(default = 0.0001)", "R", 1, "-R <double>"));
        newVector.add(new Option("\tThe number of epochs to perform ("
                + "batch learning only, default = 500)", "E", 1, "-E <integer>"));
        newVector.add(new Option("\tDon't normalize the data", "N", 0, "-N"));
        newVector.add(new Option("\tDon't replace missing values", "M", 0, "-M"));

        return newVector.elements();
    }
    /**
     * Parses a given list of options.
     * <p/>
     *
     * <!-- options-start -->
     * Valid options are:
     * <p/>
     *
     * <pre> -F
     *  Set the loss function to minimize. 0 = hinge loss (SVM),
     *  1 = log loss (logistic regression), 2 = squared loss (regression).
     *  (default = 0)</pre>
     *
     * <pre> -L
     *  The learning rate. If normalization is turned off (as it is
     *  automatically for streaming data), then the default learning rate
     *  will need to be reduced (try 0.0001).
     *  (default = 0.01)</pre>
     *
     * <pre> -R &lt;double&gt;
     *  The lambda regularization constant (default = 0.0001)</pre>
     *
     * <pre> -E &lt;integer&gt;
     *  The number of epochs to perform (batch learning only, default = 500)</pre>
     *
     * <pre> -N
     *  Don't normalize the data</pre>
     *
     * <pre> -M
     *  Don't replace missing values</pre>
     *
     * <!-- options-end -->
     *
     * @param options the list of options as an array of strings
     * @throws Exception if an option is not supported
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        reset();
        super.setOptions(options);

        String lossString = Utils.getOption('F', options);
        if (lossString.length() != 0) {
            setLossFunction(new SelectedTag(Integer.parseInt(lossString),
                    TAGS_SELECTION));
        }

        String lambdaString = Utils.getOption('R', options);
        if (lambdaString.length() > 0) {
            setLambda(Double.parseDouble(lambdaString));
        }

        String learningRateString = Utils.getOption('L', options);
        if (learningRateString.length() > 0) {
            setLearningRate(Double.parseDouble(learningRateString));
        }

        String epochsString = Utils.getOption('E', options);
        if (epochsString.length() > 0) {
            setEpochs(Integer.parseInt(epochsString));
        }

        setDontNormalize(Utils.getFlag('N', options));
        setDontReplaceMissing(Utils.getFlag('M', options));
    }

    /**
     * Gets the current settings of the classifier.
     *
     * @return an array of strings suitable for passing to setOptions
     */
    @Override
    public String[] getOptions() {
        ArrayList<String> options = new ArrayList<>();

        options.add("-F");
        options.add("" + getLossFunction().getSelectedTag().getID());
        options.add("-L");
        options.add("" + getLearningRate());
        options.add("-R");
        options.add("" + getLambda());
        options.add("-E");
        options.add("" + getEpochs());
        if (getDontNormalize()) {
            options.add("-N");
        }
        if (getDontReplaceMissing()) {
            options.add("-M");
        }

        return options.toArray(new String[0]);
    }

    /**
     * Returns a string describing classifier
     *
     * @return a description suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String globalInfo() {
        return "Implements stochastic gradient descent for learning"
                + " various linear models (binary class SVM, binary class"
                + " logistic regression and linear regression)."
                + " Globally replaces all missing values and transforms nominal"
                + " attributes into binary ones. It also normalizes all attributes,"
                + " so the coefficients in the output are based on the normalized"
                + " data.\n"
                + "For numeric class attributes, the squared loss function "
                + "(" + SQUAREDLOSS + ") must be used.";
    }

    /**
     * Reset the classifier.
     */
    public void reset() {
        m_t = 1;
        m_weights = null;
    }
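    // Batch training pipeline: optionally replace missing values, convert
    // nominal attributes to binary ones, optionally normalize, then shuffle
    // the data and run m_epochs passes of updateClassifier over it. For
    // incremental (streaming) use, normalization and missing-value
    // replacement are forced off, so the filters below are not applied.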
    /**
     * Method for building the classifier.
     *
     * @param data the set of training instances.
     * @throws Exception if the classifier can't be built successfully.
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        reset();

        // can classifier handle the data?
        getCapabilities().testWithFail(data);

        data = new Instances(data);
        data.deleteWithMissingClass();

        if (data.numInstances() > 0 && !m_dontReplaceMissing) {
            m_replaceMissing = new ReplaceMissingValues();
            m_replaceMissing.setInputFormat(data);
            data = Filter.useFilter(data, m_replaceMissing);
        }

        // check for only numeric attributes
        boolean onlyNumeric = true;
        for (int i = 0; i < data.numAttributes(); i++) {
            if (i != data.classIndex()) {
                if (!data.attribute(i).isNumeric()) {
                    onlyNumeric = false;
                    break;
                }
            }
        }

        if (!onlyNumeric) {
            if (data.numInstances() > 0) {
                m_nominalToBinary = new weka.filters.supervised.attribute.NominalToBinary();
            } else {
                m_nominalToBinary = new weka.filters.unsupervised.attribute.NominalToBinary();
            }
            m_nominalToBinary.setInputFormat(data);
            data = Filter.useFilter(data, m_nominalToBinary);
        }

        if (!m_dontNormalize && data.numInstances() > 0) {
            m_normalize = new Normalize();
            m_normalize.setInputFormat(data);
            data = Filter.useFilter(data, m_normalize);
        }

        m_numInstances = data.numInstances();
        m_weights = new double[data.numAttributes() + 1];
        m_data = new Instances(data, 0);

        if (data.numInstances() > 0) {
            data.randomize(new Random(getSeed())); // randomize the data
            train(data);
        }
    }

    /** the hinge loss function. */
    public static final int HINGE = 0;

    /** the log loss function. */
    public static final int LOGLOSS = 1;

    /** the squared loss function. */
    public static final int SQUAREDLOSS = 2;

    /** The current loss function to minimize */
    protected int m_loss = HINGE;

    /** Loss functions to choose from */
    public static final Tag[] TAGS_SELECTION = {
        new Tag(HINGE, "Hinge loss (SVM)"),
        new Tag(LOGLOSS, "Log loss (logistic regression)"),
        new Tag(SQUAREDLOSS, "Squared loss (regression)")
    };

    // Negated derivative of the chosen loss; used as the update factor.
    protected double dloss(double z) {
        if (m_loss == HINGE) {
            return (z < 1) ? 1 : 0;
        }

        if (m_loss == LOGLOSS) {
            // log loss
            if (z < 0) {
                return 1.0 / (Math.exp(z) + 1.0);
            } else {
                double t = Math.exp(-z);
                return t / (t + 1);
            }
        }

        // squared loss
        return z;
    }

    private void train(Instances data) throws Exception {
        for (int e = 0; e < m_epochs; e++) {
            for (int i = 0; i < data.numInstances(); i++) {
                updateClassifier(data.instance(i));
            }
        }
    }

    protected static double dotProd(Instance inst1, double[] weights,
                                    int classIndex) {
        double result = 0;

        int n1 = inst1.numValues();
        int n2 = weights.length - 1;

        for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) {
            int ind1 = inst1.index(p1);
            int ind2 = p2;
            if (ind1 == ind2) {
                if (ind1 != classIndex && !inst1.isMissingSparse(p1)) {
                    result += inst1.valueSparse(p1) * weights[p2];
                }
                p1++;
                p2++;
            } else if (ind1 > ind2) {
                p2++;
            } else {
                p1++;
            }
        }
        return (result);
    }
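    // The per-instance update implemented below is, with learning rate
    // eta = m_learningRate and regularization constant lambda = m_lambda:
    //
    //   w <- (1 - eta * lambda / N) * w       (L2 weight decay)
    //   w <- w + eta * y * dloss(z) * x       (gradient step, bias included)
    //
    // where N is the number of training instances in batch mode (or the
    // iteration count m_t in the streaming case), and z is the margin
    // y * f(x) for hinge/log loss, or the residual y - f(x) for squared loss.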
    /**
     * Updates the classifier with the given instance.
     *
     * @param instance the new training instance to include in the model
     * @exception Exception if the instance could not be incorporated in the
     *            model.
     */
    @Override
    public void updateClassifier(Instance instance) throws Exception {
        if (!instance.classIsMissing()) {

            double wx = dotProd(instance, m_weights, instance.classIndex());

            double y;
            double z;
            if (instance.classAttribute().isNominal()) {
                y = (instance.classValue() == 0) ? -1 : 1;
                z = y * (wx + m_weights[m_weights.length - 1]);
            } else {
                y = instance.classValue();
                z = y - (wx + m_weights[m_weights.length - 1]);
                y = 1;
            }

            // Compute multiplier for weight decay
            double multiplier = 1.0;
            if (m_numInstances == 0) {
                multiplier = 1.0 - (m_learningRate * m_lambda) / m_t;
            } else {
                multiplier = 1.0 - (m_learningRate * m_lambda) / m_numInstances;
            }
            for (int i = 0; i < m_weights.length - 1; i++) {
                m_weights[i] *= multiplier;
            }

            // Only need to do the following if the loss is non-zero
            if (m_loss != HINGE || (z < 1)) {

                // Compute factor for updates
                double factor = m_learningRate * y * dloss(z);

                // Update coefficients for attributes
                int n1 = instance.numValues();
                for (int p1 = 0; p1 < n1; p1++) {
                    int indS = instance.index(p1);
                    if (indS != instance.classIndex()
                            && !instance.isMissingSparse(p1)) {
                        m_weights[indS] += factor * instance.valueSparse(p1);
                    }
                }

                // update the bias
                m_weights[m_weights.length - 1] += factor;
            }
            m_t++;
        }
    }

    /**
     * Computes the distribution for a given instance.
     *
     * @param inst the instance for which distribution is computed
     * @return the distribution
     * @throws Exception if the distribution can't be computed successfully
     */
    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        double[] result = (inst.classAttribute().isNominal())
                ? new double[2] : new double[1];

        if (m_replaceMissing != null) {
            m_replaceMissing.input(inst);
            inst = m_replaceMissing.output();
        }

        if (m_nominalToBinary != null) {
            m_nominalToBinary.input(inst);
            inst = m_nominalToBinary.output();
        }

        if (m_normalize != null) {
            m_normalize.input(inst);
            inst = m_normalize.output();
        }

        double wx = dotProd(inst, m_weights, inst.classIndex());
        double z = (wx + m_weights[m_weights.length - 1]);

        if (inst.classAttribute().isNumeric()) {
            result[0] = z;
            return result;
        }

        if (z <= 0) {
            if (m_loss == LOGLOSS) {
                result[0] = 1.0 / (1.0 + Math.exp(z));
                result[1] = 1.0 - result[0];
            } else {
                result[0] = 1;
            }
        } else {
            if (m_loss == LOGLOSS) {
                result[1] = 1.0 / (1.0 + Math.exp(-z));
                result[0] = 1.0 - result[1];
            } else {
                result[1] = 1;
            }
        }
        return result;
    }

    /**
     * Prints out the classifier.
     *
     * @return a description of the classifier as a string
     */
    @Override
    public String toString() {
        if (m_weights == null) {
            return "SGD: No model built yet.\n";
        }
        StringBuilder buff = new StringBuilder();
        buff.append("Loss function: ");
        if (m_loss == HINGE) {
            buff.append("Hinge loss (SVM)\n\n");
        } else if (m_loss == LOGLOSS) {
            buff.append("Log loss (logistic regression)\n\n");
        } else {
            buff.append("Squared loss (linear regression)\n\n");
        }
        buff.append(m_data.classAttribute().name()).append(" = \n\n");
        int printed = 0;

        for (int i = 0; i < m_weights.length - 1; i++) {
            if (i != m_data.classIndex()) {
                if (printed > 0) {
                    buff.append(" + ");
                } else {
                    buff.append("   ");
                }

                buff.append(Utils.doubleToString(m_weights[i], 12, 4)).append(" ")
                        .append((m_normalize != null) ? "(normalized) " : "")
                        .append(m_data.attribute(i).name()).append("\n");

                printed++;
            }
        }

        if (m_weights[m_weights.length - 1] > 0) {
            buff.append(" + ").append(
                    Utils.doubleToString(m_weights[m_weights.length - 1], 12, 4));
        } else {
            buff.append(" - ").append(
                    Utils.doubleToString(-m_weights[m_weights.length - 1], 12, 4));
        }

        return buff.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 7057 $");
    }

    /**
     * Main method for testing this class.
     *
     * @param args the command-line options
     */
    public static void main(String[] args) {
        runClassifier(new SGD(), args);
    }
}
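Below is a minimal usage sketch, assuming WEKA is on the classpath. The file name weather.numeric.arff and the choice of the last attribute as the class are assumptions; adjust them for your data.

import org.esa.nest.gpf.SGD;

import weka.classifiers.Evaluation;
import weka.core.Instances;
import weka.core.SelectedTag;
import weka.core.converters.ConverterUtils.DataSource;

import java.util.Random;

public class SGDExample {

    public static void main(String[] args) throws Exception {
        // Load a dataset and mark the last attribute as the class
        // (assumed layout; adjust for your data).
        Instances data = DataSource.read("weather.numeric.arff");
        data.setClassIndex(data.numAttributes() - 1);

        // Hinge loss (linear SVM), with the defaults from the class above:
        // learning rate 0.01, lambda 0.0001, 500 epochs.
        SGD sgd = new SGD();
        sgd.setLossFunction(new SelectedTag(SGD.HINGE, SGD.TAGS_SELECTION));

        // Build on the full data and print the learned weights.
        sgd.buildClassifier(data);
        System.out.println(sgd);

        // Estimate accuracy with ten-fold cross-validation.
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(sgd, data, 10, new Random(1));
        System.out.println(eval.toSummaryString());
    }
}

Note that the classifier normalizes the training data by default, so the printed coefficients refer to the normalized attributes; pass -N (setDontNormalize(true)) to see weights on the original scale, reducing the learning rate as the class javadoc advises.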