// RotationForest — Weka meta classifier (rotation-based ensemble)
/* * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ /* * RotationForest.java * Copyright (C) 2008 Juan Jose Rodriguez * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand * */ package weka.classifiers.meta; import weka.classifiers.RandomizableParallelIteratedSingleClassifierEnhancer; import weka.core.Attribute; import weka.core.FastVector; import weka.core.Instance; import weka.core.DenseInstance; import weka.core.Instances; import weka.core.Option; import weka.core.OptionHandler; import weka.core.Randomizable; import weka.core.RevisionUtils; import weka.core.TechnicalInformation; import weka.core.TechnicalInformationHandler; import weka.core.Utils; import weka.core.WeightedInstancesHandler; import weka.core.TechnicalInformation.Field; import weka.core.TechnicalInformation.Type; import weka.filters.Filter; import weka.filters.unsupervised.attribute.Normalize; import weka.filters.unsupervised.attribute.PrincipalComponents; import weka.filters.unsupervised.attribute.RemoveUseless; import weka.filters.unsupervised.instance.RemovePercentage; import java.util.Enumeration; import java.util.Random; import java.util.Vector; /** <!-- globalinfo-start --> * Class for construction a Rotation Forest. Can do classification and regression depending on the base learner. <br/> * <br/> * For more information, see<br/> * <br/> * Juan J. Rodriguez, Ludmila I. Kuncheva, Carlos J. Alonso (2006). 
Rotation Forest: A new classifier ensemble method. IEEE Transactions on Pattern Analysis and Machine Intelligence. 28(10):1619-1630. URL http://doi.ieeecomputersociety.org/10.1109/TPAMI.2006.211. * <p/> <!-- globalinfo-end --> * <!-- technical-bibtex-start --> * BibTeX: * <pre> * @article{Rodriguez2006, * author = {Juan J. Rodriguez and Ludmila I. Kuncheva and Carlos J. Alonso}, * journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence}, * number = {10}, * pages = {1619-1630}, * title = {Rotation Forest: A new classifier ensemble method}, * volume = {28}, * year = {2006}, * ISSN = {0162-8828}, * URL = {http://doi.ieeecomputersociety.org/10.1109/TPAMI.2006.211} * } * </pre> * <p/> <!-- technical-bibtex-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -N * Whether minGroup (-G) and maxGroup (-H) refer to * the number of groups or their size. * (default: false)</pre> * * <pre> -G <num> * Minimum size of a group of attributes: * if numberOfGroups is true, the minimum number * of groups. * (default: 3)</pre> * * <pre> -H <num> * Maximum size of a group of attributes: * if numberOfGroups is true, the maximum number * of groups. * (default: 3)</pre> * * <pre> -P <num> * Percentage of instances to be removed. * (default: 50)</pre> * * <pre> -F <filter specification> * Full class name of filter to use, followed * by filter options. * eg: "weka.filters.unsupervised.attribute.PrincipalComponents-R 1.0"</pre> * * <pre> -S <num> * Random number seed. * (default 1)</pre> * * <pre> -I <num> * Number of iterations. * (default 10)</pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * * <pre> -W * Full name of base classifier. * (default: weka.classifiers.trees.J48)</pre> * * <pre> * Options specific to classifier weka.classifiers.trees.J48: * </pre> * * <pre> -U * Use unpruned tree.</pre> * * <pre> -C <pruning confidence> * Set confidence threshold for pruning. 
 * (default 0.25)</pre>
 *
 * <pre> -M <minimum number of instances>
 * Set minimum number of instances per leaf.
 * (default 2)</pre>
 *
 * <pre> -R
 * Use reduced error pruning.</pre>
 *
 * <pre> -N <number of folds>
 * Set number of folds for reduced error
 * pruning. One fold is used as pruning set.
 * (default 3)</pre>
 *
 * <pre> -B
 * Use binary splits only.</pre>
 *
 * <pre> -S
 * Don't perform subtree raising.</pre>
 *
 * <pre> -L
 * Do not clean up after the tree has been built.</pre>
 *
 * <pre> -A
 * Laplace smoothing for predicted probabilities.</pre>
 *
 * <pre> -Q <seed>
 * Seed for random data shuffling (default 1).</pre>
 * <!-- options-end -->
 *
 * @author Juan Jose Rodriguez (jjrodriguez@ubu.es)
 * @version $Revision$
 */
public class RotationForest
  extends RandomizableParallelIteratedSingleClassifierEnhancer
  implements WeightedInstancesHandler, TechnicalInformationHandler {

  // This class implements WeightedInstancesHandler because the base
  // classifier may implement that interface; the ensemble itself does not
  // make use of the instance weights.

  /** for serialization */
  static final long serialVersionUID = -3255631880798499936L;

  /** The minimum size of an attribute group (or, if m_NumberOfGroups is
   * true, the minimum number of groups) */
  protected int m_MinGroup = 3;

  /** The maximum size of an attribute group (or, if m_NumberOfGroups is
   * true, the maximum number of groups) */
  protected int m_MaxGroup = 3;

  /**
   * Whether minGroup and maxGroup refer to the number of groups or their
   * size
   */
  protected boolean m_NumberOfGroups = false;

  /** The percentage of instances to be removed before projecting each
   * group (see the -P option) */
  protected int m_RemovedPercentage = 50;

  /** The attribute indices of each group: [classifier][group][attribute] */
  protected int[][][] m_Groups = null;

  /** The type of projection filter (template copied for each group) */
  protected Filter m_ProjectionFilter = null;

  /** The projection filters, one per classifier and group */
  protected Filter[][] m_ProjectionFilters = null;

  /** Headers of the transformed (rotated) dataset, one per classifier */
  protected Instances[] m_Headers = null;

  /** Headers of the reduced per-group datasets: [classifier][group] */
  protected Instances[][] m_ReducedHeaders = null;

  /** Filter that removes useless attributes (applied before rotation) */
  protected RemoveUseless m_RemoveUseless = null;

  /** Filter that normalizes the attributes */
protected Normalize m_Normalize = null; /** Training data */ protected Instances m_data; protected Instances[] m_instancesOfClasses; protected Random m_random; /** * Constructor. */ public RotationForest() { m_Classifier = new weka.classifiers.trees.J48(); m_ProjectionFilter = defaultFilter(); } /** * Default projection method. */ protected Filter defaultFilter() { PrincipalComponents filter = new PrincipalComponents(); // filter.setNormalize(false); filter.setVarianceCovered(1.0); return filter; } /** * Returns a string describing classifier * @return a description suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Class for construction a Rotation Forest. Can do classification " + "and regression depending on the base learner. \n\n" + "For more information, see\n\n" + getTechnicalInformation().toString(); } /** * Returns an instance of a TechnicalInformation object, containing * detailed information about the technical background of this class, * e.g., paper reference or book this class is based on. * * @return the technical information about this class */ public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; result = new TechnicalInformation(Type.ARTICLE); result.setValue(Field.AUTHOR, "Juan J. Rodriguez and Ludmila I. Kuncheva and Carlos J. Alonso"); result.setValue(Field.YEAR, "2006"); result.setValue(Field.TITLE, "Rotation Forest: A new classifier ensemble method"); result.setValue(Field.JOURNAL, "IEEE Transactions on Pattern Analysis and Machine Intelligence"); result.setValue(Field.VOLUME, "28"); result.setValue(Field.NUMBER, "10"); result.setValue(Field.PAGES, "1619-1630"); result.setValue(Field.ISSN, "0162-8828"); result.setValue(Field.URL, "http://doi.ieeecomputersociety.org/10.1109/TPAMI.2006.211"); return result; } /** * String describing default classifier. 
* * @return the default classifier classname */ protected String defaultClassifierString() { return "weka.classifiers.trees.J48"; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ public Enumeration listOptions() { Vector newVector = new Vector(5); newVector.addElement(new Option("\tWhether minGroup (-G) and maxGroup (-H) refer to" + "\n\tthe number of groups or their size." + "\n\t(default: false)", "N", 0, "-N")); newVector.addElement(new Option("\tMinimum size of a group of attributes:" + "\n\t\tif numberOfGroups is true, the minimum number" + "\n\t\tof groups." + "\n\t\t(default: 3)", "G", 1, "-G <num>")); newVector.addElement(new Option("\tMaximum size of a group of attributes:" + "\n\t\tif numberOfGroups is true, the maximum number" + "\n\t\tof groups." + "\n\t\t(default: 3)", "H", 1, "-H <num>")); newVector.addElement( new Option("\tPercentage of instances to be removed." + "\n\t\t(default: 50)", "P", 1, "-P <num>")); newVector.addElement(new Option( "\tFull class name of filter to use, followed\n" + "\tby filter options.\n" + "\teg: \"weka.filters.unsupervised.attribute.PrincipalComponents-R 1.0\"", "F", 1, "-F <filter specification>")); Enumeration enu = super.listOptions(); while (enu.hasMoreElements()) { newVector.addElement(enu.nextElement()); } return newVector.elements(); } /** * Parses a given list of options. <p/> * <!-- options-start --> * Valid options are: <p/> * * <pre> -N * Whether minGroup (-G) and maxGroup (-H) refer to * the number of groups or their size. * (default: false)</pre> * * <pre> -G <num> * Minimum size of a group of attributes: * if numberOfGroups is true, the minimum number * of groups. * (default: 3)</pre> * * <pre> -H <num> * Maximum size of a group of attributes: * if numberOfGroups is true, the maximum number * of groups. * (default: 3)</pre> * * <pre> -P <num> * Percentage of instances to be removed. 
* (default: 50)</pre> * * <pre> -F <filter specification> * Full class name of filter to use, followed * by filter options. * eg: "weka.filters.unsupervised.attribute.PrincipalComponents-R 1.0"</pre> * * <pre> -S <num> * Random number seed. * (default 1)</pre> * * <pre> -I <num> * Number of iterations. * (default 10)</pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * * <pre> -W * Full name of base classifier. * (default: weka.classifiers.trees.J48)</pre> * * <pre> * Options specific to classifier weka.classifiers.trees.J48: * </pre> * * <pre> -U * Use unpruned tree.</pre> * * <pre> -C <pruning confidence> * Set confidence threshold for pruning. * (default 0.25)</pre> * * <pre> -M <minimum number of instances> * Set minimum number of instances per leaf. * (default 2)</pre> * * <pre> -R * Use reduced error pruning.</pre> * * <pre> -N <number of folds> * Set number of folds for reduced error * pruning. One fold is used as pruning set. 
* (default 3)</pre> * * <pre> -B * Use binary splits only.</pre> * * <pre> -S * Don't perform subtree raising.</pre> * * <pre> -L * Do not clean up after the tree has been built.</pre> * * <pre> -A * Laplace smoothing for predicted probabilities.</pre> * * <pre> -Q <seed> * Seed for random data shuffling (default 1).</pre> * <!-- options-end --> * * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { /* Taken from FilteredClassifier */ String filterString = Utils.getOption('F', options); if (filterString.length() > 0) { String[] filterSpec = Utils.splitOptions(filterString); if (filterSpec.length == 0) { throw new IllegalArgumentException("Invalid filter specification string"); } String filterName = filterSpec[0]; filterSpec[0] = ""; setProjectionFilter((Filter) Utils.forName(Filter.class, filterName, filterSpec)); } else { setProjectionFilter(defaultFilter()); } String tmpStr; tmpStr = Utils.getOption('G', options); if (tmpStr.length() != 0) setMinGroup(Integer.parseInt(tmpStr)); else setMinGroup(3); tmpStr = Utils.getOption('H', options); if (tmpStr.length() != 0) setMaxGroup(Integer.parseInt(tmpStr)); else setMaxGroup(3); tmpStr = Utils.getOption('P', options); if (tmpStr.length() != 0) setRemovedPercentage(Integer.parseInt(tmpStr)); else setRemovedPercentage(50); setNumberOfGroups(Utils.getFlag('N', options)); super.setOptions(options); } /** * Gets the current settings of the Classifier. 
* * @return an array of strings suitable for passing to setOptions */ public String[] getOptions() { String[] superOptions = super.getOptions(); String[] options = new String[superOptions.length + 9]; int current = 0; if (getNumberOfGroups()) { options[current++] = "-N"; } options[current++] = "-G"; options[current++] = "" + getMinGroup(); options[current++] = "-H"; options[current++] = "" + getMaxGroup(); options[current++] = "-P"; options[current++] = "" + getRemovedPercentage(); options[current++] = "-F"; options[current++] = getProjectionFilterSpec(); System.arraycopy(superOptions, 0, options, current, superOptions.length); current += superOptions.length; while (current < options.length) { options[current++] = ""; } return options; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String numberOfGroupsTipText() { return "Whether minGroup and maxGroup refer to the number of groups or their size."; } /** * Set whether minGroup and maxGroup refer to the number of groups or their * size * * @param numberOfGroups whether minGroup and maxGroup refer to the number * of groups or their size */ public void setNumberOfGroups(boolean numberOfGroups) { m_NumberOfGroups = numberOfGroups; } /** * Get whether minGroup and maxGroup refer to the number of groups or their * size * * @return whether minGroup and maxGroup refer to the number of groups or * their size */ public boolean getNumberOfGroups() { return m_NumberOfGroups; } /** * Returns the tip text for this property * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String minGroupTipText() { return "Minimum size of a group (if numberOfGrups is true, the minimum number of groups."; } /** * Sets the minimum size of a group. * * @param minGroup the minimum value. * of attributes. 
*/ public void setMinGroup(int minGroup) throws IllegalArgumentException { if (minGroup <= 0) throw new IllegalArgumentException("MinGroup has to be positive."); m_MinGroup = minGroup; } /** * Gets the minimum size of a group. * * @return the minimum value. */ public int getMinGroup() { return m_MinGroup; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String maxGroupTipText() { return "Maximum size of a group (if numberOfGrups is true, the maximum number of groups."; } /** * Sets the maximum size of a group. * * @param maxGroup the maximum value. * of attributes. */ public void setMaxGroup(int maxGroup) throws IllegalArgumentException { if (maxGroup <= 0) throw new IllegalArgumentException("MaxGroup has to be positive."); m_MaxGroup = maxGroup; } /** * Gets the maximum size of a group. * * @return the maximum value. */ public int getMaxGroup() { return m_MaxGroup; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String removedPercentageTipText() { return "The percentage of instances to be removed."; } /** * Sets the percentage of instance to be removed * * @param removedPercentage the percentage. */ public void setRemovedPercentage(int removedPercentage) throws IllegalArgumentException { if (removedPercentage < 0) throw new IllegalArgumentException("RemovedPercentage has to be >=0."); if (removedPercentage >= 100) throw new IllegalArgumentException("RemovedPercentage has to be <100."); m_RemovedPercentage = removedPercentage; } /** * Gets the percentage of instances to be removed * * @return the percentage. 
*/ public int getRemovedPercentage() { return m_RemovedPercentage; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String projectionFilterTipText() { return "The filter used to project the data (e.g., PrincipalComponents)."; } /** * Sets the filter used to project the data. * * @param projectionFilter the filter. */ public void setProjectionFilter(Filter projectionFilter) { m_ProjectionFilter = projectionFilter; } /** * Gets the filter used to project the data. * * @return the filter. */ public Filter getProjectionFilter() { return m_ProjectionFilter; } /** * Gets the filter specification string, which contains the class name of * the filter and any options to the filter * * @return the filter string. */ /* Taken from FilteredClassifier */ protected String getProjectionFilterSpec() { Filter c = getProjectionFilter(); if (c instanceof OptionHandler) { return c.getClass().getName() + " " + Utils.joinOptions(((OptionHandler) c).getOptions()); } return c.getClass().getName(); } /** * Returns description of the Rotation Forest classifier. * * @return description of the Rotation Forest classifier as a string */ public String toString() { if (m_Classifiers == null) { return "RotationForest: No model built yet."; } StringBuffer text = new StringBuffer(); text.append("All the base classifiers: \n\n"); for (int i = 0; i < m_Classifiers.length; i++) text.append(m_Classifiers[i].toString() + "\n\n"); return text.toString(); } /** * Returns the revision string. 
* * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision$"); } protected class ClassifierWrapper extends weka.classifiers.AbstractClassifier { /** For serialization */ private static final long serialVersionUID = 2327175798869994435L; protected weka.classifiers.Classifier m_wrappedClassifier; protected int m_classifierNumber; public ClassifierWrapper(weka.classifiers.Classifier classifier, int classifierNumber) { super(); m_wrappedClassifier = classifier; m_classifierNumber = classifierNumber; } public void buildClassifier(Instances data) throws Exception { m_ReducedHeaders[m_classifierNumber] = new Instances[m_Groups[m_classifierNumber].length]; FastVector transformedAttributes = new FastVector(m_data.numAttributes()); // Construction of the dataset for each group of attributes for (int j = 0; j < m_Groups[m_classifierNumber].length; j++) { FastVector fv = new FastVector(m_Groups[m_classifierNumber][j].length + 1); for (int k = 0; k < m_Groups[m_classifierNumber][j].length; k++) { String newName = m_data.attribute(m_Groups[m_classifierNumber][j][k]).name() + "_" + k; fv.addElement(m_data.attribute(m_Groups[m_classifierNumber][j][k]).copy(newName)); } fv.addElement(m_data.classAttribute().copy()); Instances dataSubSet = new Instances("rotated-" + m_classifierNumber + "-" + j + "-", fv, 0); dataSubSet.setClassIndex(dataSubSet.numAttributes() - 1); // Select instances for the dataset m_ReducedHeaders[m_classifierNumber][j] = new Instances(dataSubSet, 0); boolean[] selectedClasses = selectClasses(m_instancesOfClasses.length, m_random); for (int c = 0; c < selectedClasses.length; c++) { if (!selectedClasses[c]) continue; Enumeration enu = m_instancesOfClasses[c].enumerateInstances(); while (enu.hasMoreElements()) { Instance instance = (Instance) enu.nextElement(); Instance newInstance = new DenseInstance(dataSubSet.numAttributes()); newInstance.setDataset(dataSubSet); for (int k = 0; k < m_Groups[m_classifierNumber][j].length; 
k++) { newInstance.setValue(k, instance.value(m_Groups[m_classifierNumber][j][k])); } newInstance.setClassValue(instance.classValue()); dataSubSet.add(newInstance); } } dataSubSet.randomize(m_random); // Remove a percentage of the instances Instances originalDataSubSet = dataSubSet; dataSubSet.randomize(m_random); RemovePercentage rp = new RemovePercentage(); rp.setPercentage(m_RemovedPercentage); rp.setInputFormat(dataSubSet); dataSubSet = Filter.useFilter(dataSubSet, rp); if (dataSubSet.numInstances() < 2) { dataSubSet = originalDataSubSet; } // Project de data m_ProjectionFilters[m_classifierNumber][j].setInputFormat(dataSubSet); Instances projectedData = null; do { try { projectedData = Filter.useFilter(dataSubSet, m_ProjectionFilters[m_classifierNumber][j]); } catch (Exception e) { // The data could not be projected, we add some random instances addRandomInstances(dataSubSet, 10, m_random); } } while (projectedData == null); // Include the projected attributes in the attributes of the // transformed dataset for (int a = 0; a < projectedData.numAttributes() - 1; a++) { String newName = projectedData.attribute(a).name() + "_" + j; transformedAttributes.addElement(projectedData.attribute(a).copy(newName)); } } transformedAttributes.addElement(m_data.classAttribute().copy()); Instances transformedData = new Instances("rotated-" + m_classifierNumber + "-", transformedAttributes, 0); transformedData.setClassIndex(transformedData.numAttributes() - 1); m_Headers[m_classifierNumber] = new Instances(transformedData, 0); // Project all the training data Enumeration enu = m_data.enumerateInstances(); while (enu.hasMoreElements()) { Instance instance = (Instance) enu.nextElement(); Instance newInstance = convertInstance(instance, m_classifierNumber); transformedData.add(newInstance); } // Build the base classifier if (m_wrappedClassifier instanceof Randomizable) { ((Randomizable) m_wrappedClassifier).setSeed(m_random.nextInt()); } 
m_wrappedClassifier.buildClassifier(transformedData); } public double classifierInstance(Instance instance) throws Exception { return m_wrappedClassifier.classifyInstance(instance); } public double[] distributionForInstance(Instance instance) throws Exception { return m_wrappedClassifier.distributionForInstance(instance); } public String toString() { return m_wrappedClassifier.toString(); } } protected Instances getTrainingSet(int iteration) throws Exception { // The wrapped base classifiers' buildClassifier method creates the // transformed training data return m_data; } /** * builds the classifier. * * @param data the training data to be used for generating the * classifier. * @throws Exception if the classifier could not be built successfully */ public void buildClassifier(Instances data) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(data); m_data = new Instances(data); super.buildClassifier(m_data); // Wrap up the base classifiers for (int i = 0; i < m_Classifiers.length; i++) { ClassifierWrapper cw = new ClassifierWrapper(m_Classifiers[i], i); m_Classifiers[i] = cw; } checkMinMax(m_data); if (m_data.numInstances() > 0) { // This function fails if there are 0 instances m_random = m_data.getRandomNumberGenerator(m_Seed); } else { m_random = new Random(m_Seed); } m_RemoveUseless = new RemoveUseless(); m_RemoveUseless.setInputFormat(m_data); m_data = Filter.useFilter(data, m_RemoveUseless); m_Normalize = new Normalize(); m_Normalize.setInputFormat(m_data); m_data = Filter.useFilter(m_data, m_Normalize); if (m_NumberOfGroups) { generateGroupsFromNumbers(m_data, m_random); } else { generateGroupsFromSizes(m_data, m_random); } m_ProjectionFilters = new Filter[m_Groups.length][]; for (int i = 0; i < m_ProjectionFilters.length; i++) { m_ProjectionFilters[i] = Filter.makeCopies(m_ProjectionFilter, m_Groups[i].length); } int numClasses = m_data.numClasses(); m_instancesOfClasses = new Instances[numClasses + 1]; if 
(m_data.classAttribute().isNumeric()) { m_instancesOfClasses = new Instances[numClasses]; m_instancesOfClasses[0] = m_data; } else { m_instancesOfClasses = new Instances[numClasses + 1]; for (int i = 0; i < m_instancesOfClasses.length; i++) { m_instancesOfClasses[i] = new Instances(m_data, 0); } Enumeration enu = m_data.enumerateInstances(); while (enu.hasMoreElements()) { Instance instance = (Instance) enu.nextElement(); if (instance.classIsMissing()) { m_instancesOfClasses[numClasses].add(instance); } else { int c = (int) instance.classValue(); m_instancesOfClasses[c].add(instance); } } // If there are not instances with a missing class, we do not need to // consider them if (m_instancesOfClasses[numClasses].numInstances() == 0) { Instances[] tmp = m_instancesOfClasses; m_instancesOfClasses = new Instances[numClasses]; System.arraycopy(tmp, 0, m_instancesOfClasses, 0, numClasses); } } // These arrays keep the information of the transformed data set m_Headers = new Instances[m_Classifiers.length]; m_ReducedHeaders = new Instances[m_Classifiers.length][]; buildClassifiers(); if (m_Debug) { printGroups(); } // save memory m_data = null; m_instancesOfClasses = null; m_random = null; } /** * Adds random instances to the dataset. 
* * @param dataset the dataset * @param numInstances the number of instances * @param random a random number generator */ protected void addRandomInstances(Instances dataset, int numInstances, Random random) { int n = dataset.numAttributes(); double[] v = new double[n]; for (int i = 0; i < numInstances; i++) { for (int j = 0; j < n; j++) { Attribute att = dataset.attribute(j); if (att.isNumeric()) { v[j] = random.nextDouble(); } else if (att.isNominal()) { v[j] = random.nextInt(att.numValues()); } } dataset.add(new DenseInstance(1, v)); } } /** * Checks m_MinGroup and m_MaxGroup * * @param data the dataset */ protected void checkMinMax(Instances data) { if (m_MinGroup > m_MaxGroup) { int tmp = m_MaxGroup; m_MaxGroup = m_MinGroup; m_MinGroup = tmp; } int n = data.numAttributes(); if (m_MaxGroup >= n) m_MaxGroup = n - 1; if (m_MinGroup >= n) m_MinGroup = n - 1; } /** * Selects a non-empty subset of the classes * * @param numClasses the number of classes * @param random the random number generator. * @return a random subset of classes */ protected boolean[] selectClasses(int numClasses, Random random) { int numSelected = 0; boolean selected[] = new boolean[numClasses]; for (int i = 0; i < selected.length; i++) { if (random.nextBoolean()) { selected[i] = true; numSelected++; } } if (numSelected == 0) { selected[random.nextInt(selected.length)] = true; } return selected; } /** * generates the groups of attributes, given their minimum and maximum * sizes. * * @param data the training data to be used for generating the * groups. * @param random the random number generator. 
 */
  protected void generateGroupsFromSizes(Instances data, Random random) {
    // One group layout per ensemble member: [classifier][group][attribute]
    m_Groups = new int[m_Classifiers.length][][];
    for (int i = 0; i < m_Classifiers.length; i++) {
      int[] permutation = attributesPermutation(data.numAttributes(),
        data.classIndex(), random);

      // The number of groups that have a given size; index n corresponds
      // to group size m_MinGroup + n
      int[] numGroupsOfSize = new int[m_MaxGroup - m_MinGroup + 1];

      int numAttributes = 0;
      int numGroups;

      // Select the size of each group, drawing sizes at random until all
      // attributes of the permutation are accounted for (the last group
      // may overshoot)
      for (numGroups = 0; numAttributes < permutation.length; numGroups++) {
        int n = random.nextInt(numGroupsOfSize.length);
        numGroupsOfSize[n]++;
        numAttributes += m_MinGroup + n;
      }

      m_Groups[i] = new int[numGroups][];
      int currentAttribute = 0;
      int currentSize = 0;
      for (int j = 0; j < numGroups; j++) {
        // advance to the next group size that still has groups pending
        while (numGroupsOfSize[currentSize] == 0)
          currentSize++;
        numGroupsOfSize[currentSize]--;
        int n = m_MinGroup + currentSize;
        m_Groups[i][j] = new int[n];
        for (int k = 0; k < n; k++) {
          if (currentAttribute < permutation.length)
            m_Groups[i][j][k] = permutation[currentAttribute];
          else
            // For the last group, it can be necessary to reuse some attributes
            m_Groups[i][j][k] = permutation[random.nextInt(
              permutation.length)];
          currentAttribute++;
        }
      }
    }
  }

  /**
   * generates the groups of attributes, given their minimum and maximum
   * numbers.
   *
   * @param data the training data to be used for generating the groups.
   * @param random the random number generator.
*/ protected void generateGroupsFromNumbers(Instances data, Random random) { m_Groups = new int[m_Classifiers.length][][]; for (int i = 0; i < m_Classifiers.length; i++) { int[] permutation = attributesPermutation(data.numAttributes(), data.classIndex(), random); int numGroups = m_MinGroup + random.nextInt(m_MaxGroup - m_MinGroup + 1); m_Groups[i] = new int[numGroups][]; int groupSize = permutation.length / numGroups; // Some groups will have an additional attribute int numBiggerGroups = permutation.length % numGroups; // Distribute the attributes in the groups int currentAttribute = 0; for (int j = 0; j < numGroups; j++) { if (j < numBiggerGroups) { m_Groups[i][j] = new int[groupSize + 1]; } else { m_Groups[i][j] = new int[groupSize]; } for (int k = 0; k < m_Groups[i][j].length; k++) { m_Groups[i][j][k] = permutation[currentAttribute++]; } } } } /** * generates a permutation of the attributes. * * @param numAttributes the number of attributes. * @param classAttributes the index of the class attribute. * @param random the random number generator. * @return a permutation of the attributes */ protected int[] attributesPermutation(int numAttributes, int classAttribute, Random random) { int[] permutation = new int[numAttributes - 1]; int i = 0; for (; i < classAttribute; i++) { permutation[i] = i; } for (; i < permutation.length; i++) { permutation[i] = i + 1; } permute(permutation, random); return permutation; } /** * permutes the elements of a given array. * * @param v the array to permute * @param random the random number generator. */ protected void permute(int v[], Random random) { for (int i = v.length - 1; i > 0; i--) { int j = random.nextInt(i + 1); if (i != j) { int tmp = v[i]; v[i] = v[j]; v[j] = tmp; } } } /** * prints the groups. 
 */
  protected void printGroups() {
    // writes one line per ensemble member, groups in parentheses
    for (int i = 0; i < m_Groups.length; i++) {
      for (int j = 0; j < m_Groups[i].length; j++) {
        System.err.print("( ");
        for (int k = 0; k < m_Groups[i][j].length; k++) {
          System.err.print(m_Groups[i][j][k]);
          System.err.print(" ");
        }
        System.err.print(") ");
      }
      System.err.println();
    }
  }

  /**
   * Transforms an instance for the i-th classifier: each attribute group
   * is projected through its trained projection filter and the projected
   * values are concatenated into one rotated instance.
   *
   * @param instance the instance to be transformed
   * @param i the base classifier number
   * @return the transformed instance
   * @throws Exception if the instance can't be converted successfully
   */
  protected Instance convertInstance(Instance instance, int i)
    throws Exception {
    Instance newInstance = new DenseInstance(m_Headers[i].numAttributes());
    newInstance.setWeight(instance.weight());
    newInstance.setDataset(m_Headers[i]);
    int currentAttribute = 0;
    // Project the data for each group
    for (int j = 0; j < m_Groups[i].length; j++) {
      // build the reduced instance holding this group's values + class
      Instance auxInstance = new DenseInstance(m_Groups[i][j].length + 1);
      int k;
      for (k = 0; k < m_Groups[i][j].length; k++) {
        auxInstance.setValue(k, instance.value(m_Groups[i][j][k]));
      }
      auxInstance.setValue(k, instance.classValue());
      auxInstance.setDataset(m_ReducedHeaders[i][j]);
      // push it through the group's projection filter
      m_ProjectionFilters[i][j].input(auxInstance);
      auxInstance = m_ProjectionFilters[i][j].output();
      m_ProjectionFilters[i][j].batchFinished();
      // copy the projected values (class attribute excluded)
      for (int a = 0; a < auxInstance.numAttributes() - 1; a++) {
        newInstance.setValue(currentAttribute++, auxInstance.value(a));
      }
    }
    newInstance.setClassValue(instance.classValue());
    return newInstance;
  }

  /**
   * Calculates the class membership probabilities for the given test
   * instance.
* * @param instance the instance to be classified * @return preedicted class probability distribution * @throws Exception if distribution can't be computed successfully */ public double[] distributionForInstance(Instance instance) throws Exception { m_RemoveUseless.input(instance); instance = m_RemoveUseless.output(); m_RemoveUseless.batchFinished(); m_Normalize.input(instance); instance = m_Normalize.output(); m_Normalize.batchFinished(); double[] sums = new double[instance.numClasses()], newProbs; for (int i = 0; i < m_Classifiers.length; i++) { Instance convertedInstance = convertInstance(instance, i); if (instance.classAttribute().isNumeric() == true) { sums[0] += m_Classifiers[i].classifyInstance(convertedInstance); } else { newProbs = m_Classifiers[i].distributionForInstance(convertedInstance); for (int j = 0; j < newProbs.length; j++) sums[j] += newProbs[j]; } } if (instance.classAttribute().isNumeric() == true) { sums[0] /= (double) m_NumIterations; return sums; } else if (Utils.eq(Utils.sum(sums), 0)) { return sums; } else { Utils.normalize(sums); return sums; } } /** * Main method for testing this class. * * @param argv the options */ public static void main(String[] argv) { runClassifier(new RotationForest(), argv); } }