/* * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ /* * SimpleLogistic.java * Copyright (C) 2003-2012 University of Waikato, Hamilton, New Zealand * */ package weka.classifiers.functions; import java.util.Collections; import java.util.Enumeration; import java.util.Vector; import weka.classifiers.AbstractClassifier; import weka.classifiers.trees.lmt.LogisticBase; import weka.core.*; import weka.core.Capabilities.Capability; import weka.core.TechnicalInformation.Field; import weka.core.TechnicalInformation.Type; import weka.filters.Filter; import weka.filters.unsupervised.attribute.NominalToBinary; import weka.filters.unsupervised.attribute.ReplaceMissingValues; /** <!-- globalinfo-start --> * Classifier for building linear logistic regression * models. LogitBoost with simple regression functions as base learners is used * for fitting the logistic models. The optimal number of LogitBoost iterations * to perform is cross-validated, which leads to automatic attribute selection. * For more information see:<br/> * Niels Landwehr, Mark Hall, Eibe Frank (2005). Logistic Model Trees.<br/> * <br/> * Marc Sumner, Eibe Frank, Mark Hall: Speeding up Logistic Model Tree * Induction. In: 9th European Conference on Principles and Practice of * Knowledge Discovery in Databases, 675-683, 2005. 
* <p/> <!-- globalinfo-end --> * <!-- technical-bibtex-start --> * BibTeX: * * <pre> * @article{Landwehr2005, * author = {Niels Landwehr and Mark Hall and Eibe Frank}, * booktitle = {Machine Learning}, * number = {1-2}, * pages = {161-205}, * title = {Logistic Model Trees}, * volume = {95}, * year = {2005} * } * * @inproceedings{Sumner2005, * author = {Marc Sumner and Eibe Frank and Mark Hall}, * booktitle = {9th European Conference on Principles and Practice of Knowledge Discovery in Databases}, * pages = {675-683}, * publisher = {Springer}, * title = {Speeding up Logistic Model Tree Induction}, * year = {2005} * } * </pre> * <p/> * <!-- technical-bibtex-end --> * * <!-- options-start --> Valid options are: * <p/> * * <pre> * -I <iterations> * Set fixed number of iterations for LogitBoost * </pre> * * <pre> * -S * Use stopping criterion on training set (instead of * cross-validation) * </pre> * * <pre> * -P * Use error on probabilities (rmse) instead of * misclassification error for stopping criterion * </pre> * * <pre> * -M <iterations> * Set maximum number of boosting iterations * </pre> * * <pre> * -H <iterations> * Set parameter for heuristic for early stopping of * LogitBoost. * If enabled, the minimum is selected greedily, stopping * if the current minimum has not changed for iter iterations. * By default, heuristic is enabled with value 50. Set to * zero to disable heuristic. * </pre> * * <pre> * -W <beta> * Set beta for weight trimming for LogitBoost. Set to 0 for no weight trimming. * </pre> * * <pre> * -A * The AIC is used to choose the best iteration (instead of CV or training error). 
* </pre> * <!-- options-end --> * * @author Niels Landwehr * @author Marc Sumner * @version $Revision$ */ public class SimpleLogistic extends AbstractClassifier implements OptionHandler, AdditionalMeasureProducer, WeightedInstancesHandler, TechnicalInformationHandler { /** for serialization */ static final long serialVersionUID = 7397710626304705059L; /** The actual logistic regression model */ protected LogisticBase m_boostedModel; /** Filter for converting nominal attributes to binary ones */ protected NominalToBinary m_NominalToBinary = null; /** Filter for replacing missing values */ protected ReplaceMissingValues m_ReplaceMissingValues = null; /** If non-negative, use this as fixed number of LogitBoost iterations */ protected int m_numBoostingIterations; /** Maximum number of iterations for LogitBoost */ protected int m_maxBoostingIterations = 500; /** Parameter for the heuristic for early stopping of LogitBoost */ protected int m_heuristicStop = 50; /** If true, cross-validate number of LogitBoost iterations */ protected boolean m_useCrossValidation; /** * If true, use minimize error on probabilities instead of misclassification * error */ protected boolean m_errorOnProbabilities; /** * Threshold for trimming weights. Instances with a weight lower than this (as * a percentage of total weights) are not included in the regression fit. */ protected double m_weightTrimBeta = 0; /** If true, the AIC is used to choose the best iteration */ private boolean m_useAIC = false; /** * Constructor for creating SimpleLogistic object with standard options. */ public SimpleLogistic() { m_numBoostingIterations = 0; m_useCrossValidation = true; m_errorOnProbabilities = false; m_weightTrimBeta = 0; m_useAIC = false; } /** * Constructor for creating SimpleLogistic object. * * @param numBoostingIterations if non-negative, use this as fixed number of * iterations for LogitBoost * @param useCrossValidation cross-validate number of LogitBoost iterations. 
* @param errorOnProbabilities minimize error on probabilities instead of * misclassification error */ public SimpleLogistic(int numBoostingIterations, boolean useCrossValidation, boolean errorOnProbabilities) { m_numBoostingIterations = numBoostingIterations; m_useCrossValidation = useCrossValidation; m_errorOnProbabilities = errorOnProbabilities; m_weightTrimBeta = 0; m_useAIC = false; } /** * Main method for testing this class * * @param argv commandline options */ public static void main(String[] argv) { runClassifier(new SimpleLogistic(), argv); } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.disableAll(); // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.DATE_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); // class result.enable(Capability.NOMINAL_CLASS); result.enable(Capability.MISSING_CLASS_VALUES); return result; } /** * Builds the logistic regression using LogitBoost. * * @param data the training data * @throws Exception if something goes wrong */ public void buildClassifier(Instances data) throws Exception { // can classifier handle the data? 
getCapabilities().testWithFail(data); // remove instances with missing class data = new Instances(data); data.deleteWithMissingClass(); // replace missing values m_ReplaceMissingValues = new ReplaceMissingValues(); m_ReplaceMissingValues.setInputFormat(data); data = Filter.useFilter(data, m_ReplaceMissingValues); // convert nominal attributes m_NominalToBinary = new NominalToBinary(); m_NominalToBinary.setInputFormat(data); data = Filter.useFilter(data, m_NominalToBinary); // create actual logistic model m_boostedModel = new LogisticBase(m_numBoostingIterations, m_useCrossValidation, m_errorOnProbabilities); m_boostedModel.setMaxIterations(m_maxBoostingIterations); m_boostedModel.setHeuristicStop(m_heuristicStop); m_boostedModel.setWeightTrimBeta(m_weightTrimBeta); m_boostedModel.setUseAIC(m_useAIC); m_boostedModel.setNumDecimalPlaces(m_numDecimalPlaces); // build logistic model m_boostedModel.buildClassifier(data); } /** * Returns class probabilities for an instance. * * @param inst the instance to compute the probabilities for * @return the probabilities * @throws Exception if distribution can't be computed successfully */ public double[] distributionForInstance(Instance inst) throws Exception { // replace missing values / convert nominal atts m_ReplaceMissingValues.input(inst); inst = m_ReplaceMissingValues.output(); m_NominalToBinary.input(inst); inst = m_NominalToBinary.output(); // obtain probs from logistic model return m_boostedModel.distributionForInstance(inst); } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. 
*/ public Enumeration<Option> listOptions() { Vector<Option> newVector = new Vector<Option>(); newVector.addElement( new Option("\tSet fixed number of iterations for LogitBoost", "I", 1, "-I <iterations>")); newVector.addElement(new Option( "\tUse stopping criterion on training set (instead of\n" + "\tcross-validation)", "S", 0, "-S")); newVector.addElement(new Option("\tUse error on probabilities (rmse) instead of\n" + "\tmisclassification error for stopping criterion", "P", 0, "-P")); newVector.addElement(new Option("\tSet maximum number of boosting iterations", "M", 1, "-M <iterations>")); newVector.addElement(new Option("\tSet parameter for heuristic for early stopping of\n" + "\tLogitBoost.\n" + "\tIf enabled, the minimum is selected greedily, stopping\n" + "\tif the current minimum has not changed for iter iterations.\n" + "\tBy default, heuristic is enabled with value 50. Set to\n" + "\tzero to disable heuristic.", "H", 1, "-H <iterations>")); newVector.addElement( new Option("\tSet beta for weight trimming for LogitBoost. Set to 0 for no weight trimming.\n", "W", 1, "-W <beta>")); newVector.addElement( new Option("\tThe AIC is used to choose the best iteration (instead of CV or training error).\n", "A", 0, "-A")); newVector.addAll(Collections.list(super.listOptions())); return newVector.elements(); } /** * Gets the current settings of the Classifier. 
* * @return an array of strings suitable for passing to setOptions */ public String[] getOptions() { Vector<String> options = new Vector<String>(); options.add("-I"); options.add("" + getNumBoostingIterations()); if (!getUseCrossValidation()) { options.add("-S"); } if (getErrorOnProbabilities()) { options.add("-P"); } options.add("-M"); options.add("" + getMaxBoostingIterations()); options.add("-H"); options.add("" + getHeuristicStop()); options.add("-W"); options.add("" + getWeightTrimBeta()); if (getUseAIC()) { options.add("-A"); } Collections.addAll(options, super.getOptions()); return options.toArray(new String[0]); } /** * Parses a given list of options. * <p/> * <!-- options-start --> * Valid options are: * <p/> * * <pre> * -I <iterations> * Set fixed number of iterations for LogitBoost * </pre> * * <pre> * -S * Use stopping criterion on training set (instead of * cross-validation) * </pre> * * <pre> * -P * Use error on probabilities (rmse) instead of * misclassification error for stopping criterion * </pre> * * <pre> * -M <iterations> * Set maximum number of boosting iterations * </pre> * * <pre> * -H <iterations> * Set parameter for heuristic for early stopping of * LogitBoost. * If enabled, the minimum is selected greedily, stopping * if the current minimum has not changed for iter iterations. * By default, heuristic is enabled with value 50. Set to * zero to disable heuristic. * </pre> * * <pre> * -W <beta> * Set beta for weight trimming for LogitBoost. Set to 0 for no weight trimming. * </pre> * * <pre> * -A * The AIC is used to choose the best iteration (instead of CV or training error). 
* </pre> * <!-- options-end --> * * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { String optionString = Utils.getOption('I', options); if (optionString.length() != 0) { setNumBoostingIterations((new Integer(optionString)).intValue()); } setUseCrossValidation(!Utils.getFlag('S', options)); setErrorOnProbabilities(Utils.getFlag('P', options)); optionString = Utils.getOption('M', options); if (optionString.length() != 0) { setMaxBoostingIterations((new Integer(optionString)).intValue()); } optionString = Utils.getOption('H', options); if (optionString.length() != 0) { setHeuristicStop((new Integer(optionString)).intValue()); } optionString = Utils.getOption('W', options); if (optionString.length() != 0) { setWeightTrimBeta((new Double(optionString)).doubleValue()); } setUseAIC(Utils.getFlag('A', options)); super.setOptions(options); Utils.checkForRemainingOptions(options); } /** * Get the value of numBoostingIterations. * * @return the number of boosting iterations */ public int getNumBoostingIterations() { return m_numBoostingIterations; } /** * Set the value of numBoostingIterations. * * @param n the number of boosting iterations */ public void setNumBoostingIterations(int n) { m_numBoostingIterations = n; } /** * Get the value of useCrossValidation. * * @return true if cross-validation is used */ public boolean getUseCrossValidation() { return m_useCrossValidation; } /** * Set the value of useCrossValidation. * * @param l whether to use cross-validation */ public void setUseCrossValidation(boolean l) { m_useCrossValidation = l; } /** * Get the value of errorOnProbabilities. * * @return If true, use minimize error on probabilities instead of * misclassification error */ public boolean getErrorOnProbabilities() { return m_errorOnProbabilities; } /** * Set the value of errorOnProbabilities. 
* * @param l If true, use minimize error on probabilities instead of * misclassification error */ public void setErrorOnProbabilities(boolean l) { m_errorOnProbabilities = l; } /** * Get the value of maxBoostingIterations. * * @return the maximum number of boosting iterations */ public int getMaxBoostingIterations() { return m_maxBoostingIterations; } /** * Set the value of maxBoostingIterations. * * @param n the maximum number of boosting iterations */ public void setMaxBoostingIterations(int n) { m_maxBoostingIterations = n; } /** * Get the value of heuristicStop. * * @return the value of heuristicStop */ public int getHeuristicStop() { return m_heuristicStop; } /** * Set the value of heuristicStop. * * @param n the value of heuristicStop */ public void setHeuristicStop(int n) { if (n == 0) m_heuristicStop = m_maxBoostingIterations; else m_heuristicStop = n; } /** * Get the value of weightTrimBeta. */ public double getWeightTrimBeta() { return m_weightTrimBeta; } /** * Set the value of weightTrimBeta. */ public void setWeightTrimBeta(double n) { m_weightTrimBeta = n; } /** * Get the value of useAIC. * * @return Value of useAIC. */ public boolean getUseAIC() { return m_useAIC; } /** * Set the value of useAIC. * * @param c Value to assign to useAIC. */ public void setUseAIC(boolean c) { m_useAIC = c; } /** * Get the number of LogitBoost iterations performed (= the number of * regression functions fit by LogitBoost). * * @return the number of LogitBoost iterations performed */ public int getNumRegressions() { return m_boostedModel.getNumRegressions(); } /** * Returns a description of the logistic model (attributes/coefficients). * * @return the model as string */ public String toString() { if (m_boostedModel == null) return "No model built"; return "SimpleLogistic:\n" + m_boostedModel.toString(); } /** * Returns the fraction of all attributes in the data that are used in the * logistic model (in percent). 
An attribute is used in the model if it is * used in any of the models for the different classes. * * @return percentage of attributes used in the model */ public double measureAttributesUsed() { return m_boostedModel.percentAttributesUsed(); } /** * Returns an enumeration of the additional measure names * * @return an enumeration of the measure names */ public Enumeration<String> enumerateMeasures() { Vector<String> newVector = new Vector<String>(3); newVector.addElement("measureAttributesUsed"); newVector.addElement("measureNumIterations"); return newVector.elements(); } /** * Returns the value of the named measure * * @param additionalMeasureName the name of the measure to query for its value * @return the value of the named measure * @throws IllegalArgumentException if the named measure is not supported */ public double getMeasure(String additionalMeasureName) { if (additionalMeasureName.compareToIgnoreCase("measureAttributesUsed") == 0) { return measureAttributesUsed(); } else if (additionalMeasureName.compareToIgnoreCase("measureNumIterations") == 0) { return getNumRegressions(); } else { throw new IllegalArgumentException(additionalMeasureName + " not supported (SimpleLogistic)"); } } /** * Returns a string describing classifier * * @return a description suitable for displaying in the explorer/experimenter * gui */ public String globalInfo() { return "Classifier for building linear logistic regression models. LogitBoost with simple regression " + "functions as base learners is used for fitting the logistic models. The optimal number of LogitBoost " + "iterations to perform is cross-validated, which leads to automatic attribute selection. " + "For more information see:\n" + getTechnicalInformation().toString(); } /** * Returns an instance of a TechnicalInformation object, containing detailed * information about the technical background of this class, e.g., paper * reference or book this class is based on. 
* * @return the technical information about this class */ public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; TechnicalInformation additional; result = new TechnicalInformation(Type.ARTICLE); result.setValue(Field.AUTHOR, "Niels Landwehr and Mark Hall and Eibe Frank"); result.setValue(Field.TITLE, "Logistic Model Trees"); result.setValue(Field.BOOKTITLE, "Machine Learning"); result.setValue(Field.YEAR, "2005"); result.setValue(Field.VOLUME, "95"); result.setValue(Field.PAGES, "161-205"); result.setValue(Field.NUMBER, "1-2"); additional = result.add(Type.INPROCEEDINGS); additional.setValue(Field.AUTHOR, "Marc Sumner and Eibe Frank and Mark Hall"); additional.setValue(Field.TITLE, "Speeding up Logistic Model Tree Induction"); additional.setValue(Field.BOOKTITLE, "9th European Conference on Principles and Practice of Knowledge Discovery in Databases"); additional.setValue(Field.YEAR, "2005"); additional.setValue(Field.PAGES, "675-683"); additional.setValue(Field.PUBLISHER, "Springer"); return result; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String numBoostingIterationsTipText() { return "Set fixed number of iterations for LogitBoost. If >= 0, this sets the number of LogitBoost iterations " + "to perform. If < 0, the number is cross-validated or a stopping criterion on the training set is used " + "(depending on the value of useCrossValidation)."; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String useCrossValidationTipText() { return "Sets whether the number of LogitBoost iterations is to be cross-validated or the stopping criterion " + "on the training set should be used. 
If not set (and no fixed number of iterations was given), " + "the number of LogitBoost iterations is used that minimizes the error on the training set " + "(misclassification error or error on probabilities depending on errorOnProbabilities)."; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String errorOnProbabilitiesTipText() { return "Use error on the probabilties as error measure when determining the best number of LogitBoost iterations. " + "If set, the number of LogitBoost iterations is chosen that minimizes the root mean squared error " + "(either on the training set or in the cross-validation, depending on useCrossValidation)."; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String maxBoostingIterationsTipText() { return "Sets the maximum number of iterations for LogitBoost. Default value is 500, for very small/large " + "datasets a lower/higher value might be preferable."; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String heuristicStopTipText() { return "If heuristicStop > 0, the heuristic for greedy stopping while cross-validating the number of " + "LogitBoost iterations is enabled. This means LogitBoost is stopped if no new error minimum " + "has been reached in the last heuristicStop iterations. It is recommended to use this heuristic, " + "it gives a large speed-up especially on small datasets. The default value is 50."; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String weightTrimBetaTipText() { return "Set the beta value used for weight trimming in LogitBoost. 
" + "Only instances carrying (1 - beta)% of the weight from previous iteration " + "are used in the next iteration. Set to 0 for no weight trimming. " + "The default value is 0."; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String useAICTipText() { return "The AIC is used to determine when to stop LogitBoost iterations " + "(instead of cross-validation or training error)."; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String numDecimalPlacesTipText() { return "The number of decimal places to be used for the output of coefficients."; } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision$"); } }