// Java tutorial  (NOTE(review): stray text before the license header — commented out so the file compiles)
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * FastRfBagging.java * Copyright (C) 1999 University of Waikato, Hamilton, NZ (original code, * Bagging.java ) * Copyright (C) 2013 Fran Supek (adapted code) */ package com.tum.classifiertest; import weka.classifiers.Classifier; import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer; import weka.core.*; import java.util.*; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import android.util.Log; /** * Based on the "weka.classifiers.meta.Bagging" class, revision 1.39, * by Kirkby, Frank and Trigg, with modifications: * <ul> * <p/> * <li>Instead of Instances, produces DataCaches; consequently, FastRfBagging * is compatible only with FastRandomTree as base classifier * <p/> * <li>The function for resampling the data is removed; this is a responsibility * of the DataCache objects now * <p/> * <li>Not a TechnicalInformationHandler anymore * <p/> * <li>The classifiers are trained in separate "tasks" which are handled by * an ExecutorService (the FixedThreadPool) which runs the tasks in * more threads in parallel. If the number of threads is not specified, * it will be set automatically to the available number of cores. 
*   <li>Estimating the out-of-bag (OOB) error is also multithreaded, using
 *       the VotesCollector class</li>
 * <p/>
 *   <li>OOB estimation in Weka's Bagging is one tree - one vote. In FastRF 0.97
 *       onwards, some trees will have a heavier weight in the overall vote
 *       depending on the averaged weights of instances that ended in the
 *       specific leaf.</li>
 * <p/>
 * </ul>
 * This class should be used only from within the FastRandomForest classifier.
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz) - original code
 * @author Len Trigg (len@reeltwo.com) - original code
 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) - original code
 * @author Fran Supek (fran.supek[AT]irb.hr) - adapted code
 * @version $Revision: 0.99$
 */
class FastRfBagging extends RandomizableIteratedSingleClassifierEnhancer
    implements WeightedInstancesHandler, AdditionalMeasureProducer {

  /** for serialization */
  static final long serialVersionUID = -505879962237199702L;

  /**
   * Bagging method. Produces DataCache objects with bootstrap samples of
   * the original data, and feeds them to the base classifier (which can only
   * be a FastRandomTree).
   *
   * @param data         The training set to be used for generating the
   *                     bagged classifier.
   * @param numThreads   The number of simultaneous threads to use for
   *                     computation. Pass zero (0) for autodetection.
   * @param motherForest A reference to the FastRandomForest object that
   *                     invoked this.
   *
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data, int numThreads,
                              FastRandomForest motherForest) throws Exception {

    // can classifier handle the vals?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    if (!(m_Classifier instanceof FastRandomTree))
      throw new IllegalArgumentException("The FastRfBagging class accepts " +
        "only FastRandomTree as its base classifier.");

    /* We fill the m_Classifiers array by creating lots of trees with new()
     * because this is much faster than using serialization to deep-copy the
     * one tree in m_Classifier - this is what the super.buildClassifier(data)
     * normally does. */
    m_Classifiers = new Classifier[m_NumIterations];
    for (int i = 0; i < m_Classifiers.length; i++) {
      FastRandomTree curTree = new FastRandomTree();
      // all parameters for training will be looked up in the motherForest (maxDepth, k_Value)
      curTree.m_MotherForest = motherForest;
      // 0.99: reference to these arrays will get passed down all nodes so the array can be re-used
      // 0.99: this array is of size two as now all splits are binary - even categorical ones
      curTree.tempProps = new double[2];
      curTree.tempDists = new double[2][];
      curTree.tempDists[0] = new double[data.numClasses()];
      curTree.tempDists[1] = new double[data.numClasses()];
      curTree.tempDistsOther = new double[2][];
      curTree.tempDistsOther[0] = new double[data.numClasses()];
      curTree.tempDistsOther[1] = new double[data.numClasses()];
      m_Classifiers[i] = curTree;
    }

    // this was SLOW.. takes approx 1/2 time as training the forest afterwards (!!!)
    // super.buildClassifier(data);

    if (m_CalcOutOfBag && (m_BagSizePercent != 100)) {
      throw new IllegalArgumentException("Bag size needs to be 100% if " +
        "out-of-bag error is to be calculated!");
    }

    // sorting is performed inside this constructor
    DataCache myData = new DataCache(data);

    int bagSize = data.numInstances() * m_BagSizePercent / 100;
    Random random = new Random(m_Seed);

    // one in-bag membership mask per tree, kept for OOB error estimation
    boolean[][] inBag = new boolean[m_Classifiers.length][];

    // thread management; 0 threads means "use all available cores"
    ExecutorService threadPool = Executors.newFixedThreadPool(
      numThreads > 0 ? numThreads : Runtime.getRuntime().availableProcessors());
    List<Future<?>> futures = new ArrayList<Future<?>>(m_Classifiers.length);

    try {
      for (int treeIdx = 0; treeIdx < m_Classifiers.length; treeIdx++) {

        // create the in-bag dataset (and be sure to remember what's in bag)
        // for computing the out-of-bag error later
        DataCache bagData = myData.resample(bagSize, random);
        bagData.reusableRandomGenerator = bagData.getRandomNumberGenerator(random.nextInt());
        inBag[treeIdx] = bagData.inBag; // store later for OOB error calculation

        // build the classifier: each FastRandomTree is a Runnable trained
        // asynchronously on its own bootstrap sample
        if (m_Classifiers[treeIdx] instanceof FastRandomTree) {
          FastRandomTree aTree = (FastRandomTree) m_Classifiers[treeIdx];
          aTree.data = bagData;
          Future<?> future = threadPool.submit(aTree);
          futures.add(future);
        } else {
          throw new IllegalArgumentException("The FastRfBagging class accepts " +
            "only FastRandomTree as its base classifier.");
        }
      }

      // make sure all trees have been trained before proceeding
      // (Future.get() also rethrows any exception from a training task)
      for (int treeIdx = 0; treeIdx < m_Classifiers.length; treeIdx++) {
        futures.get(treeIdx).get();
      }

      // calc OOB error?
      if (getCalcOutOfBag() || getComputeImportances()) {
        m_OutOfBagError = computeOOBError(myData, inBag, threadPool);
      } else {
        m_OutOfBagError = 0;
      }

      // calc feature importances: permutation importance — scramble one
      // attribute at a time and measure how much the OOB error degrades
      m_FeatureImportances = null;
      if (getComputeImportances()) {
        m_FeatureImportances = new double[data.numAttributes()];
        for (int j = 0; j < data.numAttributes(); j++) {
          if (j != data.classIndex()) {
            float[] unscrambled = myData.scrambleOneAttribute(j, random);
            double sError = computeOOBError(myData, inBag, threadPool);
            myData.vals[j] = unscrambled; // restore the original state
            m_FeatureImportances[j] = sError - m_OutOfBagError;
          }
        }
      }

      threadPool.shutdown();
    } finally {
      // idempotent if shutdown() already ran; cancels queued tasks on failure
      threadPool.shutdownNow();
    }
  }

  /**
   * Compute the out-of-bag error for a set of instances.
* * @param data the instances * @param inBag numTrees x numInstances indicating out-of-bag instances * @param threadPool the pool of threads * * @return the oob error */ private double computeOOBError(Instances data, boolean[][] inBag, ExecutorService threadPool) throws InterruptedException, ExecutionException { boolean numeric = data.classAttribute().isNumeric(); List<Future<Double>> votes = new ArrayList<Future<Double>>(data.numInstances()); for (int i = 0; i < data.numInstances(); i++) { VotesCollector aCollector = new VotesCollector(m_Classifiers, i, data, inBag); votes.add(threadPool.submit(aCollector)); } double outOfBagCount = 0.0; double errorSum = 0.0; for (int i = 0; i < data.numInstances(); i++) { double vote = votes.get(i).get(); // error for instance outOfBagCount += data.instance(i).weight(); if (numeric) { errorSum += StrictMath.abs(vote - data.instance(i).classValue()) * data.instance(i).weight(); } else { if (vote != data.instance(i).classValue()) errorSum += data.instance(i).weight(); } } return errorSum / outOfBagCount; } /** * Compute the out-of-bag error on the instances in a DataCache. This must * be the datacache used for training the FastRandomForest (this is not * checked in the function!). 
* * @param data the instances (as a DataCache) * @param inBag numTrees x numInstances indicating out-of-bag instances * @param threadPool the pool of threads * * @return the oob error */ private double computeOOBError(DataCache data, boolean[][] inBag, ExecutorService threadPool) throws InterruptedException, ExecutionException { List<Future<Double>> votes = new ArrayList<Future<Double>>(data.numInstances); for (int i = 0; i < data.numInstances; i++) { VotesCollectorDataCache aCollector = new VotesCollectorDataCache(m_Classifiers, i, data, inBag); votes.add(threadPool.submit(aCollector)); } double outOfBagCount = 0.0; double errorSum = 0.0; for (int i = 0; i < data.numInstances; i++) { double vote = votes.get(i).get(); // error for instance outOfBagCount += data.instWeights[i]; if ((int) vote != data.instClassValues[i]) { errorSum += data.instWeights[i]; } } return errorSum / outOfBagCount; } //////////////////////////// // Feature importances stuff //////////////////////////// /** * The value of the features importances. */ private double[] m_FeatureImportances; /** * Whether compute the importances or not. */ private boolean m_computeImportances = true; /** * @return compute feature importances? */ public boolean getComputeImportances() { return m_computeImportances; } /** * @param computeImportances compute feature importances? */ public void setComputeImportances(boolean computeImportances) { m_computeImportances = computeImportances; } /** * @return unnormalized feature importances */ public double[] getFeatureImportances() { return m_FeatureImportances; } /** Used when displaying feature importances. */ //private String[] m_FeatureNames; /** Available only if feature importances have been computed. */ //public String[] getFeatureNames() { // return m_FeatureNames; //} //////////////////////////// // /Feature importances stuff //////////////////////////// /** * Not supported. 
*/ @Override public void buildClassifier(Instances data) throws Exception { throw new Exception("FastRfBagging can be built only from within a FastRandomForest."); } /** * The size of each bag sample, as a percentage of the training size */ protected int m_BagSizePercent = 100; /** * Whether to calculate the out of bag error */ protected boolean m_CalcOutOfBag = true; /** * The out of bag error that has been calculated */ protected double m_OutOfBagError; /** * Constructor. */ public FastRfBagging() { m_Classifier = new FastRandomTree(); } /** * Returns a string describing classifier * * @return a description suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Class for bagging a classifier to reduce variance. Can do classification " + "and regression depending on the base learner. \n\n"; } /** * String describing default classifier. * * @return the default classifier classname */ @Override protected String defaultClassifierString() { return "hr.irb.fastRandomForest.FastRfTree"; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ @Override public Enumeration listOptions() { Vector newVector = new Vector(2); newVector.addElement( new Option("\tSize of each bag, as a percentage of the\n" + "\ttraining set size. (default 100)", "P", 1, "-P")); newVector.addElement(new Option("\tCalculate the out of bag error.", "O", 0, "-O")); Enumeration enu = super.listOptions(); while (enu.hasMoreElements()) { newVector.addElement(enu.nextElement()); } return newVector.elements(); } /** * Parses a given list of options. <p/> * <p/> * <!-- options-start --> * Valid options are: <p/> * <p/> * <pre> -P * Size of each bag, as a percentage of the * training set size. (default 100)</pre> * <p/> * <pre> -O * Calculate the out of bag error.</pre> * <p/> * <pre> -S <num> * Random number seed. * (default 1)</pre> * <p/> * <pre> -I <num> * Number of iterations. 
* (default 10)</pre> * <p/> * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * <p/> * <pre> -W * Full name of base classifier. * (default: fastRandomForest.classifiers.FastRandomTree)</pre> * <p/> * <!-- options-end --> * <p/> * Options after -- are passed to the designated classifier.<p> * * @param options the list of options as an array of strings * * @throws Exception if an option is not supported */ @Override public void setOptions(String[] options) throws Exception { String bagSize = Utils.getOption('P', options); if (bagSize.length() != 0) { setBagSizePercent(Integer.parseInt(bagSize)); } else { setBagSizePercent(100); } setCalcOutOfBag(Utils.getFlag('O', options)); super.setOptions(options); } /** * Gets the current settings of the Classifier. * * @return an array of strings suitable for passing to setOptions */ @Override public String[] getOptions() { String[] superOptions = super.getOptions(); String[] options = new String[superOptions.length + 3]; int current = 0; options[current++] = "-P"; options[current++] = "" + getBagSizePercent(); if (getCalcOutOfBag()) { options[current++] = "-O"; } System.arraycopy(superOptions, 0, options, current, superOptions.length); current += superOptions.length; while (current < options.length) { options[current++] = ""; } return options; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String bagSizePercentTipText() { return "Size of each bag, as a percentage of the training set size."; } /** * Gets the size of each bag, as a percentage of the training set size. * * @return the bag size, as a percentage. */ public int getBagSizePercent() { return m_BagSizePercent; } /** * Sets the size of each bag, as a percentage of the training set size. * * @param newBagSizePercent the bag size, as a percentage. 
*/
public void setBagSizePercent(int newBagSizePercent) {
  m_BagSizePercent = newBagSizePercent;
}

/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for
 *         displaying in the explorer/experimenter gui
 */
public String calcOutOfBagTipText() {
  return "Whether the out-of-bag error is calculated.";
}

/**
 * Set whether the out of bag error is calculated.
 *
 * @param calcOutOfBag whether to calculate the out of bag error
 */
public void setCalcOutOfBag(boolean calcOutOfBag) {
  m_CalcOutOfBag = calcOutOfBag;
}

/**
 * Get whether the out of bag error is calculated.
 *
 * @return whether the out of bag error is calculated
 */
public boolean getCalcOutOfBag() {
  return m_CalcOutOfBag;
}

/**
 * Gets the out of bag error that was calculated as the classifier
 * was built.
 *
 * @return the out of bag error
 */
public double measureOutOfBagError() {
  return m_OutOfBagError;
}

/**
 * Returns an enumeration of the additional measure names.
 *
 * @return an enumeration of the measure names
 */
public Enumeration enumerateMeasures() {
  Vector measureNames = new Vector(1);
  measureNames.addElement("measureOutOfBagError");
  return measureNames.elements();
}

/**
 * Returns the value of the named measure.
 *
 * @param additionalMeasureName the name of the measure to query for its value
 *
 * @return the value of the named measure
 *
 * @throws IllegalArgumentException if the named measure is not supported
 */
public double getMeasure(String additionalMeasureName) {
  // guard clause instead of if/else: only one measure is supported
  if (!additionalMeasureName.equalsIgnoreCase("measureOutOfBagError")) {
    throw new IllegalArgumentException(additionalMeasureName + " not supported (Bagging)");
  }
  return measureOutOfBagError();
}

/**
 * Calculates the class membership probabilities for the given test
 * instance.
* * @param instance the instance to be classified * * @return preedicted class probability distribution * * @throws Exception if distribution can't be computed successfully */ @Override public double[] distributionForInstance(Instance instance) throws Exception { double[] sums = new double[instance.numClasses()], newProbs; //Log.i("FastRfBagging", " sums length : " + sums.length);//problem is that it is 1 instead of 2. coz instance is not nominal //Log.i("FastRfBagging", "m_NumIterations : " + m_NumIterations); for (int i = 0; i < m_NumIterations; i++) { if (instance.classAttribute().isNumeric()) { sums[0] += m_Classifiers[i].classifyInstance(instance); //Log.i("FastRfBagging", "yesNumeric , sum value : " + sums[0]); } else { newProbs = m_Classifiers[i].distributionForInstance(instance); //Log.i("FastRfBagging", "notNumeric , newProbs length : " + newProbs.length); for (int j = 0; j < newProbs.length; j++) sums[j] += newProbs[j]; } } if (instance.classAttribute().isNumeric()) { sums[0] /= (double) m_NumIterations; return sums; } else if (Utils.eq(Utils.sum(sums), 0)) { return sums; } else { Utils.normalize(sums); return sums; } } /** * Returns description of the bagged classifier. * * @return description of the bagged classifier as a string */ @Override public String toString() { if (m_Classifiers == null) { return "FastRfBagging: No model built yet."; } StringBuffer text = new StringBuffer(); text.append("All the base classifiers: \n\n"); for (int i = 0; i < m_Classifiers.length; i++) text.append(m_Classifiers[i].toString() + "\n\n"); if (m_CalcOutOfBag) { text.append("Out of bag error: " + Utils.doubleToString(m_OutOfBagError, 4) + "\n\n"); } return text.toString(); } /** * Main method for testing this class. * * @param argv the options */ public static void main(String[] argv) { runClassifier(new FastRfBagging(), argv); } public String getRevision() { return RevisionUtils.extract("$Revision: 0.99$"); } }