Java tutorial
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 51 Franklin Street, Suite 500, Boston, MA 02110. */ /* * FastRandomForest.java * Copyright (C) 2001 University of Waikato, Hamilton, NZ (original code, * RandomForest.java ) * Copyright (C) 2009 Fran Supek (adapted code) */ package hr.irb.fastRandomForest; import weka.classifiers.AbstractClassifier; import weka.core.*; import weka.core.TechnicalInformation.Field; import weka.core.TechnicalInformation.Type; import java.util.Enumeration; import java.util.Vector; /** * Based on the "weka.classifiers.trees.RandomForest" class, revision 1.12, * by Richard Kirkby, with minor modifications: * <p/> * - uses FastRfBagger with FastRandomTree, instead of Bagger with RandomTree. * - stores dataset header (instead of every Tree storing its own header) * - checks if only ZeroR model is possible (instead of each Tree checking) * - added "-threads" option * <p/> * <!-- globalinfo-start --> * Class for constructing a forest of random trees.<br/> * <br/> * For more information see: <br/> * <br/> * Leo Breiman (2001). Random Forests. Machine Learning. 45(1):5-32. * <p/> * <!-- globalinfo-end --> * <p/> * <!-- technical-bibtex-start --> * BibTeX: * <pre> * @article{Breiman2001, * author = {Leo Breiman}, * journal = {Machine Learning}, * number = {1}, * pages = {5-32}, * title = {Random Forests}, * volume = {45}, * year = {2001} * } * </pre> * <p/> * <!-- technical-bibtex-end --> * <p/> * <!-- options-start --> * Valid options are: <p/> * <p/> * <pre> -I <number of trees> * Number of trees to build.</pre> * <p/> * <pre> -K <number of features> * Number of features to consider (<1=int(logM+1)).</pre> * <p/> * <pre> -S * Seed for random number generator. * (default 1)</pre> * <p/> * <pre> -depth <num> * The maximum depth of the trees, 0 for unlimited. * (default 0)</pre> * <p/> * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * <p/> * <!-- options-end --> * * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) - original code * @author Fran Supek (fran.supek[AT]irb.hr) - adapted code * @version $Revision: 0.99$ */ public class FastRandomForest extends AbstractClassifier implements OptionHandler, Randomizable, WeightedInstancesHandler, AdditionalMeasureProducer, TechnicalInformationHandler { /** for serialization */ static final long serialVersionUID = 4216839470751428700L; /** Number of trees in forest. */ protected int m_numTrees = 100; /** * Number of features to consider in random feature selection. * If less than 1 will use int(logM+1) ) */ protected int m_numFeatures = 0; /** The random seed. */ protected int m_randomSeed = 1; /** Final number of features that were considered in last build. */ protected int m_KValue = 0; /** Number of simultaneous threads to use in computation (0 = autodetect). */ protected int m_NumThreads = 0; /** The bagger. */ protected FastRfBagging m_bagger = null; /** The maximum depth of the trees (0 = unlimited) */ protected int m_MaxDepth = 0; /** The header information. */ protected Instances m_Info = null; /** a ZeroR model in case no model can be built from the data */ protected AbstractClassifier m_ZeroR; /** * Returns a string describing classifier * * @return a description suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Class for constructing a forest of random trees.\n\n" + "For more information see: \n\n" + getTechnicalInformation().toString(); } /** * Returns an instance of a TechnicalInformation object, containing * detailed information about the technical background of this class, * e.g., paper reference or book this class is based on. * * @return the technical information about this class */ public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; result = new TechnicalInformation(Type.ARTICLE); result.setValue(Field.AUTHOR, "Leo Breiman"); result.setValue(Field.YEAR, "2001"); result.setValue(Field.TITLE, "Random Forests"); result.setValue(Field.JOURNAL, "Machine Learning"); result.setValue(Field.VOLUME, "45"); result.setValue(Field.NUMBER, "1"); result.setValue(Field.PAGES, "5-32"); return result; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String numTreesTipText() { return "The number of trees to be generated."; } /** * Get the value of numTrees. * * @return Value of numTrees. */ public int getNumTrees() { return m_numTrees; } /** * Set the value of numTrees. * * @param newNumTrees Value to assign to numTrees. */ public void setNumTrees(int newNumTrees) { m_numTrees = newNumTrees; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String numFeaturesTipText() { return "The number of attributes to be used in random selection (see RandomTree2)."; } /** * Get the number of features used in random selection. * * @return Value of numFeatures. */ public int getNumFeatures() { return m_numFeatures; } /** * Set the number of features to use in random selection. * * @param newNumFeatures Value to assign to numFeatures. */ public void setNumFeatures(int newNumFeatures) { m_numFeatures = newNumFeatures; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String seedTipText() { return "The random number seed to be used."; } /** * Set the seed for random number generation. * * @param seed the seed */ public void setSeed(int seed) { m_randomSeed = seed; } /** * Gets the seed for the random number generations * * @return the seed for the random number generation */ public int getSeed() { return m_randomSeed; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String maxDepthTipText() { return "The maximum depth of the trees, 0 for unlimited."; } /** * Get the maximum depth of trh tree, 0 for unlimited. * * @return the maximum depth. */ public int getMaxDepth() { return m_MaxDepth; } /** * Set the maximum depth of the tree, 0 for unlimited. * * @param value the maximum depth. */ public void setMaxDepth(int value) { m_MaxDepth = value; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String numThreadsTipText() { return "Number of simultaneous threads to use in computation (0 = autodetect)."; } /** * Get the number of simultaneous threads used in training, 0 for autodetect. * * @return the maximum depth. */ public int getNumThreads() { return m_NumThreads; } /** * Set the number of simultaneous threads used in training, 0 for autodetect. * * @param value the maximum depth. */ public void setNumThreads(int value) { m_NumThreads = value; } //////////////////////////// // Feature importances stuff //////////////////////////// /** * The value of the features importances. */ private double[] m_FeatureImportances; /** * Whether to compute the importances or not. */ private boolean m_computeImportances = false; /** * @return compute feature importances? */ public boolean getComputeImportances() { return m_computeImportances; } /** * @param computeImportances compute feature importances? */ public void setComputeImportances(boolean computeImportances) { m_computeImportances = computeImportances; } /** * Gets the out of bag error that was calculated as the classifier was built. * * @return the out of bag error */ public double measureOutOfBagError() { if (m_bagger != null) { return m_bagger.measureOutOfBagError(); } else return Double.NaN; } /** * Returns an enumeration of the additional measure names. * * @return an enumeration of the measure names */ public Enumeration enumerateMeasures() { Vector newVector = new Vector(1); newVector.addElement("measureOutOfBagError"); return newVector.elements(); } /** * Returns the value of the named measure. * * @param additionalMeasureName the name of the measure to query for its value * * @return the value of the named measure * * @throws IllegalArgumentException if the named measure is not supported */ public double getMeasure(String additionalMeasureName) { if (additionalMeasureName.equalsIgnoreCase("measureOutOfBagError")) { return measureOutOfBagError(); } else { throw new IllegalArgumentException(additionalMeasureName + " not supported (FastRandomForest)"); } } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options */ public Enumeration listOptions() { Vector newVector = new Vector(); newVector.addElement(new Option("\tNumber of trees to build.", "I", 1, "-I <number of trees>")); newVector.addElement(new Option("\tNumber of features to consider (<1=int(logM+1)).", "K", 1, "-K <number of features>")); newVector.addElement(new Option("\tSeed for random number generator.\n" + "\t(default 1)", "S", 1, "-S")); newVector.addElement(new Option("\tThe maximum depth of the trees, 0 for unlimited.\n" + "\t(default 0)", "depth", 1, "-depth <num>")); newVector.addElement( new Option("\tThe number of simultaneous threads to use for computation, 0 for autodetect.\n" + "\t(default 0)", "threads", 1, "-threads <num>")); newVector.addElement(new Option("\tWhether to compute feature importances.\n", "import", 0, "-import")); Enumeration enu = super.listOptions(); while (enu.hasMoreElements()) { newVector.addElement(enu.nextElement()); } return newVector.elements(); } /** * Gets the current settings of the forest. * * @return an array of strings suitable for passing to setOptions() */ public String[] getOptions() { Vector result; String[] options; int i; result = new Vector(); result.add("-I"); result.add("" + getNumTrees()); result.add("-K"); result.add("" + getNumFeatures()); result.add("-S"); result.add("" + getSeed()); if (getMaxDepth() > 0) { result.add("-depth"); result.add("" + getMaxDepth()); } if (getNumThreads() > 0) { result.add("-threads"); result.add("" + getNumThreads()); } if (getComputeImportances()) { result.add("-import"); } options = super.getOptions(); for (i = 0; i < options.length; i++) result.add(options[i]); return (String[]) result.toArray(new String[result.size()]); } /** * Parses a given list of options. <p/> * <p/> * <!-- options-start --> * Valid options are: <p/> * <p/> * <pre> -I <number of trees> * Number of trees to build.</pre> * <p/> * <pre> -K <number of features> * Number of features to consider (<1=int(logM+1)).</pre> * <p/> * <pre> -S * Seed for random number generator. * (default 1)</pre> * <p/> * <pre> -depth <num> * The maximum depth of the trees, 0 for unlimited. * (default 0)</pre> * <p/> * <pre> -threads * Number of simultaneous threads to use. * (default 0 = autodetect number of available cores)</pre> * <p/> * <pre> -import * Compute and output RF feature importances (slow).</pre> * <p/> * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * <p/> * <!-- options-end --> * * @param options the list of options as an array of strings * * @throws Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { String tmpStr; tmpStr = Utils.getOption('I', options); if (tmpStr.length() != 0) { m_numTrees = Integer.parseInt(tmpStr); } else { m_numTrees = 10; } tmpStr = Utils.getOption('K', options); if (tmpStr.length() != 0) { m_numFeatures = Integer.parseInt(tmpStr); } else { m_numFeatures = 0; } tmpStr = Utils.getOption('S', options); if (tmpStr.length() != 0) { setSeed(Integer.parseInt(tmpStr)); } else { setSeed(1); } tmpStr = Utils.getOption("depth", options); if (tmpStr.length() != 0) { setMaxDepth(Integer.parseInt(tmpStr)); } else { setMaxDepth(0); } tmpStr = Utils.getOption("threads", options); if (tmpStr.length() != 0) { setNumThreads(Integer.parseInt(tmpStr)); } else { setNumThreads(0); } setComputeImportances(Utils.getFlag("import", options)); super.setOptions(options); Utils.checkForRemainingOptions(options); } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ public Capabilities getCapabilities() { return new FastRandomTree().getCapabilities(); } /** * Builds a classifier for a set of instances. * * @param data the instances to train the classifier with * * @throws Exception if something goes wrong */ public void buildClassifier(Instances data) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(data); // remove instances with missing class data = new Instances(data); data.deleteWithMissingClass(); // only class? -> build ZeroR model if (data.numAttributes() == 1) { System.err.println( "Cannot build model (only class attribute present in data!), " + "using ZeroR model instead!"); m_ZeroR = new weka.classifiers.rules.ZeroR(); m_ZeroR.buildClassifier(data); return; } else { m_ZeroR = null; } /* Save header with attribute info. Can be accessed later by FastRfTrees * through their m_MotherForest field. */ m_Info = new Instances(data, 0); m_bagger = new FastRfBagging(); // Set up the tree options which are held in the motherForest. m_KValue = m_numFeatures; if (m_KValue > data.numAttributes() - 1) m_KValue = data.numAttributes() - 1; if (m_KValue < 1) m_KValue = (int) Utils.log2(data.numAttributes()) + 1; FastRandomTree rTree = new FastRandomTree(); rTree.m_MotherForest = this; // allows to retrieve KValue and MaxDepth // some temporary arrays which need to be separate for every tree, so // that the trees can be trained in parallel in different threads // set up the bagger and build the forest m_bagger.setClassifier(rTree); m_bagger.setSeed(m_randomSeed); m_bagger.setNumIterations(m_numTrees); m_bagger.setCalcOutOfBag(true); m_bagger.setComputeImportances(this.getComputeImportances()); m_bagger.buildClassifier(data, m_NumThreads, this); } /** * Returns the class probability distribution for an instance. * * @param instance the instance to be classified * * @return the distribution the forest generates for the instance * * @throws Exception if computation fails */ public double[] distributionForInstance(Instance instance) throws Exception { if (m_ZeroR != null) { // default model? return m_ZeroR.distributionForInstance(instance); } return m_bagger.distributionForInstance(instance); } /** * Outputs a description of this classifier. * * @return a string containing a description of the classifier */ public String toString() { StringBuilder sb = new StringBuilder(); if (m_bagger == null) sb.append("FastRandomForest not built yet"); else { sb.append("FastRandomForest of " + m_numTrees + " trees, each constructed while considering " + m_KValue + " random feature" + (m_KValue == 1 ? "" : "s") + ".\n" + "Out of bag error: " + Utils.doubleToString(m_bagger.measureOutOfBagError() * 100.0, 3) + "%\n" + (getMaxDepth() > 0 ? ("Max. depth of trees: " + getMaxDepth() + "\n") : ("")) + "\n"); if (getComputeImportances()) { sb.append( "Feature importances - increase in out-of-bag error (as % misclassified instances) after feature permuted:\n"); double[] importances = m_bagger.getFeatureImportances(); for (int i = 0; i < importances.length; i++) { sb.append(String.format("%d\t%s\t%6.4f%%\n", i + 1, this.m_Info.attribute(i).name(), i == m_Info.classIndex() ? Double.NaN : importances[i] * 100.0)); //bagger.getFeatureNames()[i] ); } } } return sb.toString(); } /** * Main method for this class. * * @param argv the options */ public static void main(String[] argv) { runClassifier(new FastRandomForest(), argv); } public String getRevision() { return RevisionUtils.extract("$Revision: 0.99$"); } //////////////////////////// // Feature importances stuff //////////////////////////// /** @return the feature importances or <code>null</code> if the importances haven't been computed */ public double[] getFeatureImportances() { return m_bagger.getFeatureImportances(); } //////////////////////////// // /Feature importances stuff //////////////////////////// }