/* * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ /* * BFTree.java * Copyright (C) 2007 Haijian Shi * */ package weka.classifiers.trees; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Enumeration; import java.util.Random; import java.util.Vector; import weka.classifiers.Evaluation; import weka.classifiers.RandomizableClassifier; import weka.core.AdditionalMeasureProducer; import weka.core.Attribute; import weka.core.Capabilities; import weka.core.Capabilities.Capability; import weka.core.Instance; import weka.core.Instances; import weka.core.Option; import weka.core.RevisionUtils; import weka.core.SelectedTag; import weka.core.Tag; import weka.core.TechnicalInformation; import weka.core.TechnicalInformation.Field; import weka.core.TechnicalInformation.Type; import weka.core.TechnicalInformationHandler; import weka.core.Utils; import weka.core.matrix.Matrix; /** * <!-- globalinfo-start --> Class for building a best-first decision tree * classifier. This class uses binary split for both nominal and numeric * attributes. For missing values, the method of 'fractional' instances is used.<br/> * <br/> * For more information, see:<br/> * <br/> * Haijian Shi (2007). Best-first decision tree learning. Hamilton, NZ.<br/> * <br/> * Jerome Friedman, Trevor Hastie, Robert Tibshirani (2000). Additive logistic * regression : A statistical view of boosting. Annals of statistics. * 28(2):337-407. * <p/> * <!-- globalinfo-end --> * * <!-- technical-bibtex-start --> BibTeX: * * <pre> * @mastersthesis{Shi2007, * address = {Hamilton, NZ}, * author = {Haijian Shi}, * note = {COMP594}, * school = {University of Waikato}, * title = {Best-first decision tree learning}, * year = {2007} * } * * @article{Friedman2000, * author = {Jerome Friedman and Trevor Hastie and Robert Tibshirani}, * journal = {Annals of statistics}, * number = {2}, * pages = {337-407}, * title = {Additive logistic regression : A statistical view of boosting}, * volume = {28}, * year = {2000}, * ISSN = {0090-5364} * } * </pre> * <p/> * <!-- technical-bibtex-end --> * * <!-- options-start --> Valid options are: * <p/> * * <pre> * -S <num> * Random number seed. * (default 1) * </pre> * * <pre> * -D * If set, classifier is run in debug mode and * may output additional info to the console * </pre> * * <pre> * -P <UNPRUNED|POSTPRUNED|PREPRUNED> * The pruning strategy. * (default: POSTPRUNED) * </pre> * * <pre> * -M <min no> * The minimal number of instances at the terminal nodes. * (default 2) * </pre> * * <pre> * -N <num folds> * The number of folds used in the pruning. * (default 5) * </pre> * * <pre> * -H * Don't use heuristic search for nominal attributes in multi-class * problem (default yes). * </pre> * * <pre> * -G * Don't use Gini index for splitting (default yes), * if not information is used. * </pre> * * <pre> * -R * Don't use error rate in internal cross-validation (default yes), * but root mean squared error. 
* </pre> * * <pre> * -A * Use the 1 SE rule to make pruning decision. * (default no). * </pre> * * <pre> * -C * Percentage of training data size (0-1] * (default 1). * </pre> * * <!-- options-end --> * * @author Haijian Shi (hs69@cs.waikato.ac.nz) * @version $Revision$ */ public class BFTree extends RandomizableClassifier implements AdditionalMeasureProducer, TechnicalInformationHandler { /** For serialization. */ private static final long serialVersionUID = -7035607375962528217L; /** pruning strategy: un-pruned */ public static final int PRUNING_UNPRUNED = 0; /** pruning strategy: post-pruning */ public static final int PRUNING_POSTPRUNING = 1; /** pruning strategy: pre-pruning */ public static final int PRUNING_PREPRUNING = 2; /** pruning strategy */ public static final Tag[] TAGS_PRUNING = { new Tag(PRUNING_UNPRUNED, "unpruned", "Un-pruned"), new Tag(PRUNING_POSTPRUNING, "postpruned", "Post-pruning"), new Tag(PRUNING_PREPRUNING, "prepruned", "Pre-pruning") }; /** the pruning strategy */ protected int m_PruningStrategy = PRUNING_POSTPRUNING; /** Successor nodes. */ protected BFTree[] m_Successors; /** Attribute used for splitting. */ protected Attribute m_Attribute; /** Split point (for numeric attributes). */ protected double m_SplitValue; /** Split subset (for nominal attributes). */ protected String m_SplitString; /** Class value for a node. */ protected double m_ClassValue; /** Class attribute of a dataset. */ protected Attribute m_ClassAttribute; /** Minimum number of instances at leaf nodes. */ protected int m_minNumObj = 2; /** Number of folds for the pruning. */ protected int m_numFoldsPruning = 5; /** Whether the node is a leaf node. */ protected boolean m_isLeaf; /** Number of expansions. */ protected static int m_Expansion; /** * Fixed number of expansions (if no pruning method is used, its value is -1. * Otherwise, its value is obtained from internal cross-validation). */ protected int m_FixedExpansion = -1; /** * Whether to use heuristic search for binary splits (default true). Note that even * if its value is true, it is only used when the number of values of a nominal * attribute is larger than 4. */ protected boolean m_Heuristic = true; /** * Whether to use the Gini index as the splitting criterion - default (if not, * information gain is used). */ protected boolean m_UseGini = true; /** * Whether to use the error rate in internal cross-validation to fix the number of * expansions - default (if not, root mean squared error is used). */ protected boolean m_UseErrorRate = true; /** Whether to use the 1 SE rule to make the pruning decision. */ protected boolean m_UseOneSE = false; /** Class distributions. */ protected double[] m_Distribution; /** Branch proportions. */ protected double[] m_Props; /** Sorted indices. */ protected int[][] m_SortedIndices; /** Sorted weights. */ protected double[][] m_Weights; /** Distributions of each attribute for two successor nodes. */ protected double[][][] m_Dists; /** Class probabilities. */ protected double[] m_ClassProbs; /** Total weights. */ protected double m_TotalWeight; /** The training data size (0-1]. Default 1. */ protected double m_SizePer = 1; /** * Returns a string describing the classifier * * @return a description suitable for displaying in the explorer/experimenter * gui */ public String globalInfo() { return "Class for building a best-first decision tree classifier. " + "This class uses binary split for both nominal and numeric attributes. 
" + "For missing values, the method of 'fractional' instances is used.\n\n" + "For more information, see:\n\n" + getTechnicalInformation().toString(); } /** * Returns an instance of a TechnicalInformation object, containing detailed * information about the technical background of this class, e.g., paper * reference or book this class is based on. * * @return the technical information about this class */ @Override public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; TechnicalInformation additional; result = new TechnicalInformation(Type.MASTERSTHESIS); result.setValue(Field.AUTHOR, "Haijian Shi"); result.setValue(Field.YEAR, "2007"); result.setValue(Field.TITLE, "Best-first decision tree learning"); result.setValue(Field.SCHOOL, "University of Waikato"); result.setValue(Field.ADDRESS, "Hamilton, NZ"); result.setValue(Field.NOTE, "COMP594"); additional = result.add(Type.ARTICLE); additional.setValue(Field.AUTHOR, "Jerome Friedman and Trevor Hastie and Robert Tibshirani"); additional.setValue(Field.YEAR, "2000"); additional.setValue(Field.TITLE, "Additive logistic regression : A statistical view of boosting"); additional.setValue(Field.JOURNAL, "Annals of statistics"); additional.setValue(Field.VOLUME, "28"); additional.setValue(Field.NUMBER, "2"); additional.setValue(Field.PAGES, "337-407"); additional.setValue(Field.ISSN, "0090-5364"); return result; } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ @Override public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.disableAll(); // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); // class result.enable(Capability.NOMINAL_CLASS); return result; } /** * Method for building a BestFirst decision tree classifier. * * @param data set of instances serving as training data * @throws Exception if decision tree cannot be built successfully */ @Override public void buildClassifier(Instances data) throws Exception { getCapabilities().testWithFail(data); data = new Instances(data); data.deleteWithMissingClass(); // build an unpruned tree if (m_PruningStrategy == PRUNING_UNPRUNED) { // calculate sorted indices, weights and initial class probabilities int[][] sortedIndices = new int[data.numAttributes()][0]; double[][] weights = new double[data.numAttributes()][0]; double[] classProbs = new double[data.numClasses()]; double totalWeight = computeSortedInfo(data, sortedIndices, weights, classProbs); // Compute information of the best split for this node (include split // attribute, // split value and gini gain (or information gain)). At the same time, // compute // variables dists, props and totalSubsetWeights. double[][][] dists = new double[data.numAttributes()][2][data.numClasses()]; double[][] props = new double[data.numAttributes()][2]; double[][] totalSubsetWeights = new double[data.numAttributes()][2]; ArrayList<Object> nodeInfo = computeSplitInfo(this, data, sortedIndices, weights, dists, props, totalSubsetWeights, m_Heuristic, m_UseGini); // add the node (with all split info) into BestFirstElements ArrayList<ArrayList<Object>> BestFirstElements = new ArrayList<ArrayList<Object>>(); BestFirstElements.add(nodeInfo); // Make the best-first decision tree. 
int attIndex = ((Attribute) nodeInfo.get(1)).index(); m_Expansion = 0; makeTree(BestFirstElements, data, sortedIndices, weights, dists, classProbs, totalWeight, props[attIndex], m_minNumObj, m_Heuristic, m_UseGini, m_FixedExpansion); return; } // the following code is for pre-pruning and post-pruning methods // Compute train data, test data, sorted indices, sorted weights, total // weights, // class probabilities, class distributions, branch proportions and total // subset // weights for root nodes of each fold for prepruning and postpruning. int expansion = 0; Random random = new Random(m_Seed); Instances cvData = new Instances(data); cvData.randomize(random); cvData = new Instances(cvData, 0, (int) (cvData.numInstances() * m_SizePer) - 1); cvData.stratify(m_numFoldsPruning); Instances[] train = new Instances[m_numFoldsPruning]; Instances[] test = new Instances[m_numFoldsPruning]; @SuppressWarnings("unchecked") ArrayList<ArrayList<Object>>[] parallelBFElements = new ArrayList[m_numFoldsPruning]; BFTree[] m_roots = new BFTree[m_numFoldsPruning]; int[][][] sortedIndices = new int[m_numFoldsPruning][data.numAttributes()][0]; double[][][] weights = new double[m_numFoldsPruning][data.numAttributes()][0]; double[][] classProbs = new double[m_numFoldsPruning][data.numClasses()]; double[] totalWeight = new double[m_numFoldsPruning]; double[][][][] dists = new double[m_numFoldsPruning][data.numAttributes()][2][data.numClasses()]; double[][][] props = new double[m_numFoldsPruning][data.numAttributes()][2]; double[][][] totalSubsetWeights = new double[m_numFoldsPruning][data.numAttributes()][2]; @SuppressWarnings("unchecked") ArrayList<Object>[] nodeInfo = new ArrayList[m_numFoldsPruning]; for (int i = 0; i < m_numFoldsPruning; i++) { train[i] = cvData.trainCV(m_numFoldsPruning, i); test[i] = cvData.testCV(m_numFoldsPruning, i); parallelBFElements[i] = new ArrayList<ArrayList<Object>>(); m_roots[i] = new BFTree(); // calculate sorted indices, weights, initial class counts and total // weights for each training data totalWeight[i] = computeSortedInfo(train[i], sortedIndices[i], weights[i], classProbs[i]); // compute information of the best split for this node (include split // attribute, // split value and gini gain (or information gain)) in this fold nodeInfo[i] = computeSplitInfo(m_roots[i], train[i], sortedIndices[i], weights[i], dists[i], props[i], totalSubsetWeights[i], m_Heuristic, m_UseGini); // compute information for root nodes int attIndex = ((Attribute) nodeInfo[i].get(1)).index(); m_roots[i].m_SortedIndices = new int[sortedIndices[i].length][0]; m_roots[i].m_Weights = new double[weights[i].length][0]; m_roots[i].m_Dists = new double[dists[i].length][0][0]; m_roots[i].m_ClassProbs = new double[classProbs[i].length]; m_roots[i].m_Distribution = new double[classProbs[i].length]; m_roots[i].m_Props = new double[2]; for (int j = 0; j < m_roots[i].m_SortedIndices.length; j++) { m_roots[i].m_SortedIndices[j] = sortedIndices[i][j]; m_roots[i].m_Weights[j] = weights[i][j]; m_roots[i].m_Dists[j] = dists[i][j]; } System.arraycopy(classProbs[i], 0, m_roots[i].m_ClassProbs, 0, classProbs[i].length); if (Utils.sum(m_roots[i].m_ClassProbs) != 0) { Utils.normalize(m_roots[i].m_ClassProbs); } System.arraycopy(classProbs[i], 0, m_roots[i].m_Distribution, 0, classProbs[i].length); System.arraycopy(props[i][attIndex], 0, m_roots[i].m_Props, 0, props[i][attIndex].length); m_roots[i].m_TotalWeight = totalWeight[i]; parallelBFElements[i].add(nodeInfo[i]); } // build a pre-pruned tree if (m_PruningStrategy == 
PRUNING_PREPRUNING) { double previousError = Double.MAX_VALUE; double currentError = previousError; double minError = Double.MAX_VALUE; ArrayList<Double> errorList = new ArrayList<Double>(); while (true) { // compute average error double expansionError = 0; int count = 0; for (int i = 0; i < m_numFoldsPruning; i++) { Evaluation eval; // calculate error rate if only root node if (expansion == 0) { m_roots[i].m_isLeaf = true; eval = new Evaluation(test[i]); eval.evaluateModel(m_roots[i], test[i]); if (m_UseErrorRate) { expansionError += eval.errorRate(); } else { expansionError += eval.rootMeanSquaredError(); } count++; } // make tree - expand one node at a time else { if (m_roots[i] == null) { continue; // if the tree cannot be expanded, go to next fold } m_roots[i].m_isLeaf = false; BFTree nodeToSplit = (BFTree) ((parallelBFElements[i].get(0)).get(0)); if (!m_roots[i].makeTree(parallelBFElements[i], m_roots[i], train[i], nodeToSplit.m_SortedIndices, nodeToSplit.m_Weights, nodeToSplit.m_Dists, nodeToSplit.m_ClassProbs, nodeToSplit.m_TotalWeight, nodeToSplit.m_Props, m_minNumObj, m_Heuristic, m_UseGini)) { m_roots[i] = null; // cannot be expanded continue; } eval = new Evaluation(test[i]); eval.evaluateModel(m_roots[i], test[i]); if (m_UseErrorRate) { expansionError += eval.errorRate(); } else { expansionError += eval.rootMeanSquaredError(); } count++; } } // no tree can be expanded any more if (count == 0) { break; } expansionError /= count; errorList.add(new Double(expansionError)); currentError = expansionError; if (!m_UseOneSE) { if (currentError > previousError) { break; } } else { if (expansionError < minError) { minError = expansionError; } if (currentError > previousError) { double oneSE = Math.sqrt(minError * (1 - minError) / data.numInstances()); if (currentError > minError + oneSE) { break; } } } expansion++; previousError = currentError; } if (!m_UseOneSE) { expansion = expansion - 1; } else { double oneSE = Math.sqrt(minError * (1 - minError) / data.numInstances()); for (int i = 0; i < errorList.size(); i++) { double error = (errorList.get(i)).doubleValue(); if (error <= minError + oneSE) { // && counts[i]>=m_numFoldsPruning/2) // { expansion = i; break; } } } } // build a postpruned tree else { @SuppressWarnings("unchecked") ArrayList<Double>[] modelError = new ArrayList[m_numFoldsPruning]; // calculate error of each expansion for each fold for (int i = 0; i < m_numFoldsPruning; i++) { modelError[i] = new ArrayList<Double>(); m_roots[i].m_isLeaf = true; Evaluation eval = new Evaluation(test[i]); eval.evaluateModel(m_roots[i], test[i]); double error; if (m_UseErrorRate) { error = eval.errorRate(); } else { error = eval.rootMeanSquaredError(); } modelError[i].add(new Double(error)); m_roots[i].m_isLeaf = false; BFTree nodeToSplit = (BFTree) ((parallelBFElements[i].get(0)).get(0)); m_roots[i].makeTree(parallelBFElements[i], m_roots[i], train[i], test[i], modelError[i], nodeToSplit.m_SortedIndices, nodeToSplit.m_Weights, nodeToSplit.m_Dists, nodeToSplit.m_ClassProbs, nodeToSplit.m_TotalWeight, nodeToSplit.m_Props, m_minNumObj, m_Heuristic, m_UseGini, m_UseErrorRate); m_roots[i] = null; } // find the expansion with minimal error rate double minError = Double.MAX_VALUE; int maxExpansion = modelError[0].size(); for (int i = 1; i < modelError.length; i++) { if (modelError[i].size() > maxExpansion) { maxExpansion = modelError[i].size(); } } double[] error = new double[maxExpansion]; int[] counts = new int[maxExpansion]; for (int i = 0; i < maxExpansion; i++) { counts[i] = 0; error[i] = 
0; for (int j = 0; j < m_numFoldsPruning; j++) { if (i < modelError[j].size()) { error[i] += modelError[j].get(i).doubleValue(); counts[i]++; } } error[i] = error[i] / counts[i]; // average error for each expansion if (error[i] < minError) {// && counts[i]>=m_numFoldsPruning/2) { minError = error[i]; expansion = i; } } // the 1 SE rule chosen if (m_UseOneSE) { double oneSE = Math.sqrt(minError * (1 - minError) / data.numInstances()); for (int i = 0; i < maxExpansion; i++) { if (error[i] <= minError + oneSE) { // && // counts[i]>=m_numFoldsPruning/2) // { expansion = i; break; } } } } // make tree on all data based on the expansion calculated // from cross-validation // calculate sorted indices, weights and initial class counts int[][] prune_sortedIndices = new int[data.numAttributes()][0]; double[][] prune_weights = new double[data.numAttributes()][0]; double[] prune_classProbs = new double[data.numClasses()]; double prune_totalWeight = computeSortedInfo(data, prune_sortedIndices, prune_weights, prune_classProbs); // compute information of the best split for this node (include split // attribute, // split value and gini gain) double[][][] prune_dists = new double[data.numAttributes()][2][data.numClasses()]; double[][] prune_props = new double[data.numAttributes()][2]; double[][] prune_totalSubsetWeights = new double[data.numAttributes()][2]; ArrayList<Object> prune_nodeInfo = computeSplitInfo(this, data, prune_sortedIndices, prune_weights, prune_dists, prune_props, prune_totalSubsetWeights, m_Heuristic, m_UseGini); // add the root node (with its split info) to BestFirstElements ArrayList<ArrayList<Object>> BestFirstElements = new ArrayList<ArrayList<Object>>(); BestFirstElements.add(prune_nodeInfo); int attIndex = ((Attribute) prune_nodeInfo.get(1)).index(); m_Expansion = 0; makeTree(BestFirstElements, data, prune_sortedIndices, prune_weights, prune_dists, prune_classProbs, prune_totalWeight, prune_props[attIndex], m_minNumObj, m_Heuristic, m_UseGini, expansion); } /** * Recursively build a best-first decision tree. Method for building a * Best-First tree for a given number of expansions. A preExpansion of -1 means * that no expansion is specified (just for a tree without any pruning * method). Pre-pruning and post-pruning methods also use this method to build * the final tree on all training data based on the expansion calculated from * internal cross-validation. * * @param BestFirstElements list to store BFTree nodes * @param data training data * @param sortedIndices sorted indices of the instances * @param weights weights of the instances * @param dists class distributions for each attribute * @param classProbs class probabilities of this node * @param totalWeight total weight of this node (note that if the node cannot * split, this value is not calculated.) 
* @param branchProps proportions of two subbranches * @param minNumObj minimal number of instances at leaf nodes * @param useHeuristic whether to use heuristic search for nominal attributes in * multi-class problems * @param useGini whether to use the Gini index as splitting criterion * @param preExpansion the number of expansions the tree is to be expanded by * @throws Exception if something goes wrong */ protected void makeTree(ArrayList<ArrayList<Object>> BestFirstElements, Instances data, int[][] sortedIndices, double[][] weights, double[][][] dists, double[] classProbs, double totalWeight, double[] branchProps, int minNumObj, boolean useHeuristic, boolean useGini, int preExpansion) throws Exception { if (BestFirstElements.size() == 0) { return; } // ///////////////////////////////////////////////////////////////////// // All information about the node to split (the first BestFirst object in // BestFirstElements) ArrayList<Object> firstElement = BestFirstElements.get(0); // split attribute Attribute att = (Attribute) firstElement.get(1); // info of split value or split string double splitValue = Double.NaN; String splitStr = null; if (att.isNumeric()) { splitValue = ((Double) firstElement.get(2)).doubleValue(); } else { splitStr = ((String) firstElement.get(2)).toString(); } // the best gini gain or information gain of this node double gain = ((Double) firstElement.get(3)).doubleValue(); // ///////////////////////////////////////////////////////////////////// if (m_ClassProbs == null) { m_SortedIndices = new int[sortedIndices.length][0]; m_Weights = new double[weights.length][0]; m_Dists = new double[dists.length][0][0]; m_ClassProbs = new double[classProbs.length]; m_Distribution = new double[classProbs.length]; m_Props = new double[2]; for (int i = 0; i < m_SortedIndices.length; i++) { m_SortedIndices[i] = sortedIndices[i]; m_Weights[i] = weights[i]; m_Dists[i] = dists[i]; } System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length); System.arraycopy(classProbs, 0, m_Distribution, 0, classProbs.length); System.arraycopy(branchProps, 0, m_Props, 0, m_Props.length); m_TotalWeight = totalWeight; if (Utils.sum(m_ClassProbs) != 0) { Utils.normalize(m_ClassProbs); } } // If there is not enough data or this node cannot be split, find the next node to split. if (totalWeight < 2 * minNumObj || branchProps[0] == 0 || branchProps[1] == 0) { // remove the first element BestFirstElements.remove(0); makeLeaf(data); if (BestFirstElements.size() != 0) { ArrayList<Object> nextSplitElement = BestFirstElements.get(0); BFTree nextSplitNode = (BFTree) nextSplitElement.get(0); nextSplitNode.makeTree(BestFirstElements, data, nextSplitNode.m_SortedIndices, nextSplitNode.m_Weights, nextSplitNode.m_Dists, nextSplitNode.m_ClassProbs, nextSplitNode.m_TotalWeight, nextSplitNode.m_Props, minNumObj, useHeuristic, useGini, preExpansion); } return; } // If gini gain or information gain is 0, make all nodes in the // BestFirstElements leaf nodes // because these nodes are sorted in descending order of gini gain or // information gain // (namely, gini gain or information gain of all nodes in BestFirstElements // is 0). 
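// (Illustrative example: if preExpansion is 3, then once three splits have
// been performed (m_Expansion == 3) the branch below turns every remaining
// element of BestFirstElements into a leaf, so the finished tree contains
// exactly the number of expansions chosen beforehand, e.g. by internal
// cross-validation.)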
if (gain == 0 || preExpansion == m_Expansion) { for (int i = 0; i < BestFirstElements.size(); i++) { ArrayList<Object> element = BestFirstElements.get(i); BFTree node = (BFTree) element.get(0); node.makeLeaf(data); } BestFirstElements.clear(); } // gain is not 0 else { // remove the first element BestFirstElements.remove(0); m_Attribute = att; if (m_Attribute.isNumeric()) { m_SplitValue = splitValue; } else { m_SplitString = splitStr; } int[][][] subsetIndices = new int[2][data.numAttributes()][0]; double[][][] subsetWeights = new double[2][data.numAttributes()][0]; splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitValue, m_SplitString, sortedIndices, weights, data); // If the split would generate node(s) whose total weights are less than // m_minNumObj, // do not split. int attIndex = att.index(); if (subsetIndices[0][attIndex].length < minNumObj || subsetIndices[1][attIndex].length < minNumObj) { makeLeaf(data); } // split the node else { m_isLeaf = false; m_Attribute = att; // if expansion is specified (if pruning method used) if ((m_PruningStrategy == PRUNING_PREPRUNING) || (m_PruningStrategy == PRUNING_POSTPRUNING) || (preExpansion != -1)) { m_Expansion++; } makeSuccessors(BestFirstElements, data, subsetIndices, subsetWeights, dists, att, useHeuristic, useGini); } // choose next node to split if (BestFirstElements.size() != 0) { ArrayList<Object> nextSplitElement = BestFirstElements.get(0); BFTree nextSplitNode = (BFTree) nextSplitElement.get(0); nextSplitNode.makeTree(BestFirstElements, data, nextSplitNode.m_SortedIndices, nextSplitNode.m_Weights, nextSplitNode.m_Dists, nextSplitNode.m_ClassProbs, nextSplitNode.m_TotalWeight, nextSplitNode.m_Props, minNumObj, useHeuristic, useGini, preExpansion); } } } /** * This method finds the number of expansions based on internal * cross-validation, for pre-pruning only. It expands the first BestFirst node * in the BestFirstElements if it is expansible, otherwise it looks for the next * expansible node. If it finds an expansible node, it expands it, then * returns true. (Note it expands just one node at a time.) * * @param BestFirstElements list to store BFTree nodes * @param root root node of tree in each fold * @param train training data * @param sortedIndices sorted indices of the instances * @param weights weights of the instances * @param dists class distributions for each attribute * @param classProbs class probabilities of this node * @param totalWeight total weight of this node (note that if the node cannot * split, this value is not calculated.) * @param branchProps proportions of two subbranches * @param minNumObj minimal number of instances at leaf nodes * @param useHeuristic whether to use heuristic search for nominal attributes in * multi-class problems * @param useGini whether to use the Gini index as splitting criterion * @return true if a node was expanded successfully, otherwise false (no node in * BestFirstElements can be expanded). 
* @throws Exception if something goes wrong */ protected boolean makeTree(ArrayList<ArrayList<Object>> BestFirstElements, BFTree root, Instances train, int[][] sortedIndices, double[][] weights, double[][][] dists, double[] classProbs, double totalWeight, double[] branchProps, int minNumObj, boolean useHeuristic, boolean useGini) throws Exception { if (BestFirstElements.size() == 0) { return false; } // ///////////////////////////////////////////////////////////////////// // All information about the node to split (first BestFirst object in // BestFirstElements) ArrayList<Object> firstElement = BestFirstElements.get(0); // node to split BFTree nodeToSplit = (BFTree) firstElement.get(0); // split attribute Attribute att = (Attribute) firstElement.get(1); // info of split value or split string double splitValue = Double.NaN; String splitStr = null; if (att.isNumeric()) { splitValue = ((Double) firstElement.get(2)).doubleValue(); } else { splitStr = ((String) firstElement.get(2)).toString(); } // the best gini gain or information gain of this node double gain = ((Double) firstElement.get(3)).doubleValue(); // ///////////////////////////////////////////////////////////////////// // If there is not enough data to split for this node or this node cannot be split, // find the next node to split. if (totalWeight < 2 * minNumObj || branchProps[0] == 0 || branchProps[1] == 0) { // remove the first element BestFirstElements.remove(0); nodeToSplit.makeLeaf(train); if (BestFirstElements.size() == 0) { return false; } BFTree nextNode = (BFTree) BestFirstElements.get(0).get(0); return root.makeTree(BestFirstElements, root, train, nextNode.m_SortedIndices, nextNode.m_Weights, nextNode.m_Dists, nextNode.m_ClassProbs, nextNode.m_TotalWeight, nextNode.m_Props, minNumObj, useHeuristic, useGini); } // If gini gain or information gain is 0, make all nodes in the BestFirstElements // leaf nodes // because these nodes are sorted in descending order of gini gain or // information gain // (namely, gini gain or information gain of all nodes in BestFirstElements // is 0). 
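// (Note: unlike the recursive variant above, this pre-pruning helper performs
// at most one expansion per call and reports success through its return
// value; buildClassifier invokes it repeatedly, measuring the cross-validated
// error after each expansion to decide when to stop.)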
if (gain == 0) { for (int i = 0; i < BestFirstElements.size(); i++) { ArrayList<Object> element = BestFirstElements.get(i); BFTree node = (BFTree) element.get(0); node.makeLeaf(train); } BestFirstElements.clear(); return false; } else { // remove the first element BestFirstElements.remove(0); nodeToSplit.m_Attribute = att; if (att.isNumeric()) { nodeToSplit.m_SplitValue = splitValue; } else { nodeToSplit.m_SplitString = splitStr; } int[][][] subsetIndices = new int[2][train.numAttributes()][0]; double[][][] subsetWeights = new double[2][train.numAttributes()][0]; splitData(subsetIndices, subsetWeights, nodeToSplit.m_Attribute, nodeToSplit.m_SplitValue, nodeToSplit.m_SplitString, nodeToSplit.m_SortedIndices, nodeToSplit.m_Weights, train); // if the split would generate node(s) whose total weights are less than // m_minNumObj, // do not split int attIndex = att.index(); if (subsetIndices[0][attIndex].length < minNumObj || subsetIndices[1][attIndex].length < minNumObj) { nodeToSplit.makeLeaf(train); BFTree nextNode = (BFTree) BestFirstElements.get(0).get(0); return root.makeTree(BestFirstElements, root, train, nextNode.m_SortedIndices, nextNode.m_Weights, nextNode.m_Dists, nextNode.m_ClassProbs, nextNode.m_TotalWeight, nextNode.m_Props, minNumObj, useHeuristic, useGini); } // split the node else { nodeToSplit.m_isLeaf = false; nodeToSplit.m_Attribute = att; nodeToSplit.makeSuccessors(BestFirstElements, train, subsetIndices, subsetWeights, dists, nodeToSplit.m_Attribute, useHeuristic, useGini); for (int i = 0; i < 2; i++) { nodeToSplit.m_Successors[i].makeLeaf(train); } return true; } } } /** * This method finds the number of expansions based on internal * cross-validation, for post-pruning only. It expands the first BestFirst node * in the BestFirstElements until no node can be split. When building the * tree, it stores the error of each temporary tree, namely of each expansion. * * @param BestFirstElements list to store BFTree nodes * @param root root node of tree in each fold * @param train training data in each fold * @param test test data in each fold * @param modelError list to store error for each expansion in each fold * @param sortedIndices sorted indices of the instances * @param weights weights of the instances * @param dists class distributions for each attribute * @param classProbs class probabilities of this node * @param totalWeight total weight of this node (note that if the node cannot * split, this value is not calculated.) 
* @param branchProps proportions of two subbranches * @param minNumObj minimal number of instances at leaf nodes * @param useHeuristic whether to use heuristic search for nominal attributes in * multi-class problems * @param useGini whether to use the Gini index as splitting criterion * @param useErrorRate whether to use the error rate in internal cross-validation * @throws Exception if something goes wrong */ protected void makeTree(ArrayList<ArrayList<Object>> BestFirstElements, BFTree root, Instances train, Instances test, ArrayList<Double> modelError, int[][] sortedIndices, double[][] weights, double[][][] dists, double[] classProbs, double totalWeight, double[] branchProps, int minNumObj, boolean useHeuristic, boolean useGini, boolean useErrorRate) throws Exception { if (BestFirstElements.size() == 0) { return; } // ///////////////////////////////////////////////////////////////////// // All information about the node to split (first BestFirst object in // BestFirstElements) ArrayList<Object> firstElement = BestFirstElements.get(0); // node to split // BFTree nodeToSplit = (BFTree)firstElement.elementAt(0); // split attribute Attribute att = (Attribute) firstElement.get(1); // info of split value or split string double splitValue = Double.NaN; String splitStr = null; if (att.isNumeric()) { splitValue = ((Double) firstElement.get(2)).doubleValue(); } else { splitStr = ((String) firstElement.get(2)).toString(); } // the best gini gain or information gain of this node double gain = ((Double) firstElement.get(3)).doubleValue(); // ///////////////////////////////////////////////////////////////////// if (totalWeight < 2 * minNumObj || branchProps[0] == 0 || branchProps[1] == 0) { // remove the first element BestFirstElements.remove(0); makeLeaf(train); if (BestFirstElements.size() == 0) { return; } BFTree nextSplitNode = (BFTree) BestFirstElements.get(0).get(0); nextSplitNode.makeTree(BestFirstElements, root, train, test, modelError, nextSplitNode.m_SortedIndices, nextSplitNode.m_Weights, nextSplitNode.m_Dists, nextSplitNode.m_ClassProbs, nextSplitNode.m_TotalWeight, nextSplitNode.m_Props, minNumObj, useHeuristic, useGini, useErrorRate); return; } // If gini gain or information gain is 0, make all nodes in the // BestFirstElements leaf nodes // because these nodes are sorted in descending order of gini gain or // information gain // (namely, gini gain or information gain of all nodes in BestFirstElements // is 0). 
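// (This post-pruning variant keeps expanding until no node can be split,
// appending the test error of each intermediate tree to modelError;
// buildClassifier then averages these per-expansion errors across folds and
// picks the expansion with the smallest average. As a worked example of the
// optional 1 SE rule: minError = 0.10 over 100 instances gives
// oneSE = sqrt(0.10 * 0.90 / 100) = 0.03, so the earliest expansion with
// average error <= 0.13 would be chosen instead.)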
if (gain == 0) { for (int i = 0; i < BestFirstElements.size(); i++) { ArrayList<Object> element = BestFirstElements.get(i); BFTree node = (BFTree) element.get(0); node.makeLeaf(train); } BestFirstElements.clear(); } // gini gain or information gain is not 0 else { // remove the first element BestFirstElements.remove(0); m_Attribute = att; if (att.isNumeric()) { m_SplitValue = splitValue; } else { m_SplitString = splitStr; } int[][][] subsetIndices = new int[2][train.numAttributes()][0]; double[][][] subsetWeights = new double[2][train.numAttributes()][0]; splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitValue, m_SplitString, sortedIndices, weights, train); // if the split would generate node(s) whose total weights are less than // m_minNumObj, // do not split int attIndex = att.index(); if (subsetIndices[0][attIndex].length < minNumObj || subsetIndices[1][attIndex].length < minNumObj) { makeLeaf(train); } // split the node and calculate the error rate of this temporary tree else { m_isLeaf = false; m_Attribute = att; makeSuccessors(BestFirstElements, train, subsetIndices, subsetWeights, dists, m_Attribute, useHeuristic, useGini); for (int i = 0; i < 2; i++) { m_Successors[i].makeLeaf(train); } Evaluation eval = new Evaluation(test); eval.evaluateModel(root, test); double error; if (useErrorRate) { error = eval.errorRate(); } else { error = eval.rootMeanSquaredError(); } modelError.add(new Double(error)); } if (BestFirstElements.size() != 0) { ArrayList<Object> nextSplitElement = BestFirstElements.get(0); BFTree nextSplitNode = (BFTree) nextSplitElement.get(0); nextSplitNode.makeTree(BestFirstElements, root, train, test, modelError, nextSplitNode.m_SortedIndices, nextSplitNode.m_Weights, nextSplitNode.m_Dists, nextSplitNode.m_ClassProbs, nextSplitNode.m_TotalWeight, nextSplitNode.m_Props, minNumObj, useHeuristic, useGini, useErrorRate); } } } /** * Generate successor nodes for a node and put them into BestFirstElements * according to gini gain or information gain in descending order. 
* * @param BestFirstElements list to store BestFirst nodes * @param data training instances * @param subsetSortedIndices sorted indices of instances of successor nodes * @param subsetWeights weights of instances of successor nodes * @param dists class distributions of successor nodes * @param att attribute used to split the node * @param useHeuristic whether to use heuristic search for nominal attributes in * multi-class problems * @param useGini whether to use the Gini index as splitting criterion * @throws Exception if something goes wrong */ protected void makeSuccessors(ArrayList<ArrayList<Object>> BestFirstElements, Instances data, int[][][] subsetSortedIndices, double[][][] subsetWeights, double[][][] dists, Attribute att, boolean useHeuristic, boolean useGini) throws Exception { m_Successors = new BFTree[2]; for (int i = 0; i < 2; i++) { m_Successors[i] = new BFTree(); m_Successors[i].m_isLeaf = true; // class probability and distribution for this successor node m_Successors[i].m_ClassProbs = new double[data.numClasses()]; m_Successors[i].m_Distribution = new double[data.numClasses()]; System.arraycopy(dists[att.index()][i], 0, m_Successors[i].m_ClassProbs, 0, m_Successors[i].m_ClassProbs.length); System.arraycopy(dists[att.index()][i], 0, m_Successors[i].m_Distribution, 0, m_Successors[i].m_Distribution.length); if (Utils.sum(m_Successors[i].m_ClassProbs) != 0) { Utils.normalize(m_Successors[i].m_ClassProbs); } // split information for this successor node double[][] props = new double[data.numAttributes()][2]; double[][][] subDists = new double[data.numAttributes()][2][data.numClasses()]; double[][] totalSubsetWeights = new double[data.numAttributes()][2]; ArrayList<Object> splitInfo = m_Successors[i].computeSplitInfo(m_Successors[i], data, subsetSortedIndices[i], subsetWeights[i], subDists, props, totalSubsetWeights, useHeuristic, useGini); // branch proportion for this successor node int splitIndex = ((Attribute) splitInfo.get(1)).index(); m_Successors[i].m_Props = new double[2]; System.arraycopy(props[splitIndex], 0, m_Successors[i].m_Props, 0, m_Successors[i].m_Props.length); // sorted indices and weights of each attribute for this successor node m_Successors[i].m_SortedIndices = new int[data.numAttributes()][0]; m_Successors[i].m_Weights = new double[data.numAttributes()][0]; for (int j = 0; j < m_Successors[i].m_SortedIndices.length; j++) { m_Successors[i].m_SortedIndices[j] = subsetSortedIndices[i][j]; m_Successors[i].m_Weights[j] = subsetWeights[i][j]; } // distribution of each attribute for this successor node m_Successors[i].m_Dists = new double[data.numAttributes()][2][data.numClasses()]; for (int j = 0; j < subDists.length; j++) { m_Successors[i].m_Dists[j] = subDists[j]; } // total weights for this successor node. 
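// (The total weight set below is the sum of the two subset weights under the
// successor's own best split. The successor's split info is then inserted so
// that BestFirstElements stays sorted by gain in descending order; e.g. with
// queued gains {0.30, 0.12, 0.05}, a new node with gain 0.12 is inserted at
// index 1, ahead of the existing 0.12.)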
m_Successors[i].m_TotalWeight = Utils.sum(totalSubsetWeights[splitIndex]); // insert this successor node into BestFirstElements in descending order // of gini gain or information gain if (BestFirstElements.size() == 0) { BestFirstElements.add(splitInfo); } else { double gGain = ((Double) (splitInfo.get(3))).doubleValue(); int vectorSize = BestFirstElements.size(); ArrayList<Object> lastNode = BestFirstElements.get(vectorSize - 1); // If gini gain is less than that of the last node in the list if (gGain < ((Double) (lastNode.get(3))).doubleValue()) { BestFirstElements.add(vectorSize, splitInfo); } else { for (int j = 0; j < vectorSize; j++) { ArrayList<Object> node = BestFirstElements.get(j); double nodeGain = ((Double) (node.get(3))).doubleValue(); if (gGain >= nodeGain) { BestFirstElements.add(j, splitInfo); break; } } } } } } /** * Compute sorted indices, weights and class probabilities for a given * dataset. Return total weights of the data at the node. * * @param data training data * @param sortedIndices sorted indices of instances at the node * @param weights weights of instances at the node * @param classProbs class probabilities at the node * @return total weights of instances at the node * @throws Exception if something goes wrong */ protected double computeSortedInfo(Instances data, int[][] sortedIndices, double[][] weights, double[] classProbs) throws Exception { // Create array of sorted indices and weights double[] vals = new double[data.numInstances()]; for (int j = 0; j < data.numAttributes(); j++) { if (j == data.classIndex()) { continue; } weights[j] = new double[data.numInstances()]; if (data.attribute(j).isNominal()) { // Handling nominal attributes. Putting indices of // instances with missing values at the end. sortedIndices[j] = new int[data.numInstances()]; int count = 0; for (int i = 0; i < data.numInstances(); i++) { Instance inst = data.instance(i); if (!inst.isMissing(j)) { sortedIndices[j][count] = i; weights[j][count] = inst.weight(); count++; } } for (int i = 0; i < data.numInstances(); i++) { Instance inst = data.instance(i); if (inst.isMissing(j)) { sortedIndices[j][count] = i; weights[j][count] = inst.weight(); count++; } } } else { // Sorted indices are computed for numeric attributes; // instances with missing values are put at the end (through the Utils.sort() method) for (int i = 0; i < data.numInstances(); i++) { Instance inst = data.instance(i); vals[i] = inst.value(j); } sortedIndices[j] = Utils.sort(vals); for (int i = 0; i < data.numInstances(); i++) { weights[j][i] = data.instance(sortedIndices[j][i]).weight(); } } } // Compute initial class counts and total weight double totalWeight = 0; for (int i = 0; i < data.numInstances(); i++) { Instance inst = data.instance(i); classProbs[(int) inst.classValue()] += inst.weight(); totalWeight += inst.weight(); } return totalWeight; } /** * Compute the best splitting attribute, split point or subset and the best * gini gain or information gain for a given dataset. 
* * @param node node to be split * @param data training data * @param sortedIndices sorted indices of the instances * @param weights weights of the instances * @param dists class distributions for each attribute * @param props proportions of two branches * @param totalSubsetWeights total weight of two subsets * @param useHeuristic whether to use heuristic search for nominal attributes in * multi-class problems * @param useGini whether to use the Gini index as splitting criterion * @return split information about the node * @throws Exception if something goes wrong */ protected ArrayList<Object> computeSplitInfo(BFTree node, Instances data, int[][] sortedIndices, double[][] weights, double[][][] dists, double[][] props, double[][] totalSubsetWeights, boolean useHeuristic, boolean useGini) throws Exception { double[] splits = new double[data.numAttributes()]; String[] splitString = new String[data.numAttributes()]; double[] gains = new double[data.numAttributes()]; for (int i = 0; i < data.numAttributes(); i++) { if (i == data.classIndex()) { continue; } Attribute att = data.attribute(i); if (att.isNumeric()) { // numeric attribute splits[i] = numericDistribution(props, dists, att, sortedIndices[i], weights[i], totalSubsetWeights, gains, data, useGini); } else { // nominal attribute splitString[i] = nominalDistribution(props, dists, att, sortedIndices[i], weights[i], totalSubsetWeights, gains, data, useHeuristic, useGini); } } int index = Utils.maxIndex(gains); double mBestGain = gains[index]; Attribute att = data.attribute(index); double mValue = Double.NaN; String mString = null; if (att.isNumeric()) { mValue = splits[index]; } else { mString = splitString[index]; if (mString == null) { mString = ""; } } // split information ArrayList<Object> splitInfo = new ArrayList<Object>(); splitInfo.add(node); splitInfo.add(att); if (att.isNumeric()) { splitInfo.add(new Double(mValue)); } else { splitInfo.add(mString); } splitInfo.add(new Double(mBestGain)); return splitInfo; } /** * Compute distributions, proportions and total weights of two successor nodes * for a given numeric attribute. 
* * @param props proportions of the two branches for each attribute * @param dists class distributions of two branches for each attribute * @param att numeric attribute to split on * @param sortedIndices sorted indices of instances for the attribute * @param weights weights of instances for the attribute * @param subsetWeights total weight of the two branches split based on the * attribute * @param gains Gini gains or information gains for each attribute * @param data training instances * @param useGini whether to use the Gini index as splitting criterion * @return Gini gain or information gain for the given attribute * @throws Exception if something goes wrong */ protected double numericDistribution(double[][] props, double[][][] dists, Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights, double[] gains, Instances data, boolean useGini) throws Exception { double splitPoint = Double.NaN; double[][] dist = null; int numClasses = data.numClasses(); int i; // separate instances with and without missing values double[][] currDist = new double[2][numClasses]; dist = new double[2][numClasses]; // Move all instances without missing values into second subset double[] parentDist = new double[numClasses]; int missingStart = 0; for (int j = 0; j < sortedIndices.length; j++) { Instance inst = data.instance(sortedIndices[j]); if (!inst.isMissing(att)) { missingStart++; currDist[1][(int) inst.classValue()] += weights[j]; } parentDist[(int) inst.classValue()] += weights[j]; } System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length); // Try all possible split points double currSplit = data.instance(sortedIndices[0]).value(att); double currGain; double bestGain = -Double.MAX_VALUE; for (i = 0; i < sortedIndices.length; i++) { Instance inst = data.instance(sortedIndices[i]); if (inst.isMissing(att)) { break; } if (inst.value(att) > currSplit) { double[][] tempDist = new double[2][numClasses]; for (int k = 0; k < 2; k++) { // tempDist[k] = currDist[k]; System.arraycopy(currDist[k], 0, tempDist[k], 0, tempDist[k].length); } double[] tempProps = new double[2]; for (int k = 0; k < 2; k++) { tempProps[k] = Utils.sum(tempDist[k]); } if (Utils.sum(tempProps) != 0) { Utils.normalize(tempProps); } // split missing values int index = missingStart; while (index < sortedIndices.length) { Instance insta = data.instance(sortedIndices[index]); for (int j = 0; j < 2; j++) { tempDist[j][(int) insta.classValue()] += tempProps[j] * weights[index]; } index++; } if (useGini) { currGain = computeGiniGain(parentDist, tempDist); } else { currGain = computeInfoGain(parentDist, tempDist); } if (currGain > bestGain) { bestGain = currGain; // round the split point splitPoint = Math.rint((inst.value(att) + currSplit) / 2.0 * 100000) / 100000.0; for (int j = 0; j < currDist.length; j++) { System.arraycopy(tempDist[j], 0, dist[j], 0, dist[j].length); } } } currSplit = inst.value(att); currDist[0][(int) inst.classValue()] += weights[i]; currDist[1][(int) inst.classValue()] -= weights[i]; } // Compute weights int attIndex = att.index(); props[attIndex] = new double[2]; for (int k = 0; k < 2; k++) { props[attIndex][k] = Utils.sum(dist[k]); } if (Utils.sum(props[attIndex]) != 0) { Utils.normalize(props[attIndex]); } // Compute subset weights subsetWeights[attIndex] = new double[2]; for (int j = 0; j < 2; j++) { subsetWeights[attIndex][j] += Utils.sum(dist[j]); } // round the gain gains[attIndex] = Math.rint(bestGain * 10000000) / 10000000.0; dists[attIndex] = dist; return splitPoint; } /** * Compute distributions, proportions and total weights 
of two successor nodes * for a given nominal attribute. * * @param props proportions of the two branches for each attribute * @param dists class distributions of two branches for each attribute * @param att nominal attribute to split on * @param sortedIndices sorted indices of instances for the attribute * @param weights weights of instances for the attribute * @param subsetWeights total weight of the two branches split based on the * attribute * @param gains Gini gains for each attribute * @param data training instances * @param useHeuristic whether to use heuristic search * @param useGini whether to use the Gini index as splitting criterion * @return Gini gain for the given attribute * @throws Exception if something goes wrong */ protected String nominalDistribution(double[][] props, double[][][] dists, Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights, double[] gains, Instances data, boolean useHeuristic, boolean useGini) throws Exception { String[] values = new String[att.numValues()]; int numCat = values.length; // number of values of the attribute int numClasses = data.numClasses(); String bestSplitString = ""; double bestGain = -Double.MAX_VALUE; // class frequency for each value int[] classFreq = new int[numCat]; for (int j = 0; j < numCat; j++) { classFreq[j] = 0; } double[] parentDist = new double[numClasses]; double[][] currDist = new double[2][numClasses]; double[][] dist = new double[2][numClasses]; int missingStart = 0; for (int i = 0; i < sortedIndices.length; i++) { Instance inst = data.instance(sortedIndices[i]); if (!inst.isMissing(att)) { missingStart++; classFreq[(int) inst.value(att)]++; } parentDist[(int) inst.classValue()] += weights[i]; } // count the number of values whose class frequency is not 0 int nonEmpty = 0; for (int j = 0; j < numCat; j++) { if (classFreq[j] != 0) { nonEmpty++; } } // attribute values whose class frequency is not 0 String[] nonEmptyValues = new String[nonEmpty]; int nonEmptyIndex = 0; for (int j = 0; j < numCat; j++) { if (classFreq[j] != 0) { nonEmptyValues[nonEmptyIndex] = att.value(j); nonEmptyIndex++; } } // attribute values whose class frequency is 0 int empty = numCat - nonEmpty; String[] emptyValues = new String[empty]; int emptyIndex = 0; for (int j = 0; j < numCat; j++) { if (classFreq[j] == 0) { emptyValues[emptyIndex] = att.value(j); emptyIndex++; } } if (nonEmpty <= 1) { gains[att.index()] = 0; return ""; } // for two-class problems if (data.numClasses() == 2) { // // Firstly, for attribute values whose class frequency is not zero // probability of class 0 for each attribute value double[] pClass0 = new double[nonEmpty]; // class distribution for each attribute value double[][] valDist = new double[nonEmpty][2]; for (int j = 0; j < nonEmpty; j++) { for (int k = 0; k < 2; k++) { valDist[j][k] = 0; } } for (int sortedIndice : sortedIndices) { Instance inst = data.instance(sortedIndice); if (inst.isMissing(att)) { break; } for (int j = 0; j < nonEmpty; j++) { if (att.value((int) inst.value(att)).compareTo(nonEmptyValues[j]) == 0) { valDist[j][(int) inst.classValue()] += inst.weight(); break; } } } for (int j = 0; j < nonEmpty; j++) { double distSum = Utils.sum(valDist[j]); if (distSum == 0) { pClass0[j] = 0; } else { pClass0[j] = valDist[j][0] / distSum; } } // sort categories according to the probability of class 0 String[] sortedValues = new String[nonEmpty]; for (int j = 0; j < nonEmpty; j++) { sortedValues[j] = nonEmptyValues[Utils.minIndex(pClass0)]; pClass0[Utils.minIndex(pClass0)] = Double.MAX_VALUE; } // Find a subset of 
attribute values that maximize impurity decrease // for the attribute values whose class frequency is not 0 String tempStr = ""; for (int j = 0; j < nonEmpty - 1; j++) { currDist = new double[2][numClasses]; if (tempStr == "") { tempStr = "(" + sortedValues[j] + ")"; } else { tempStr += "|" + "(" + sortedValues[j] + ")"; } // System.out.println(sortedValues[j]); for (int i = 0; i < sortedIndices.length; i++) { Instance inst = data.instance(sortedIndices[i]); if (inst.isMissing(att)) { break; } if (tempStr.indexOf("(" + att.value((int) inst.value(att)) + ")") != -1) { currDist[0][(int) inst.classValue()] += weights[i]; } else { currDist[1][(int) inst.classValue()] += weights[i]; } } double[][] tempDist = new double[2][numClasses]; for (int kk = 0; kk < 2; kk++) { tempDist[kk] = currDist[kk]; } double[] tempProps = new double[2]; for (int kk = 0; kk < 2; kk++) { tempProps[kk] = Utils.sum(tempDist[kk]); } if (Utils.sum(tempProps) != 0) { Utils.normalize(tempProps); } // split missing values int mstart = missingStart; while (mstart < sortedIndices.length) { Instance insta = data.instance(sortedIndices[mstart]); for (int jj = 0; jj < 2; jj++) { tempDist[jj][(int) insta.classValue()] += tempProps[jj] * weights[mstart]; } mstart++; } double currGain; if (useGini) { currGain = computeGiniGain(parentDist, tempDist); } else { currGain = computeInfoGain(parentDist, tempDist); } if (currGain > bestGain) { bestGain = currGain; bestSplitString = tempStr; for (int jj = 0; jj < 2; jj++) { System.arraycopy(tempDist[jj], 0, dist[jj], 0, dist[jj].length); } } } } // multi-class problems (exhaustive search) else if (!useHeuristic || nonEmpty <= 4) { // else if (!useHeuristic || nonEmpty==2) { // Firstly, for attribute values whose class frequency is not zero for (int i = 0; i < (int) Math.pow(2, nonEmpty - 1); i++) { String tempStr = ""; currDist = new double[2][numClasses]; int mod; int bit10 = i; for (int j = nonEmpty - 1; j >= 0; j--) { mod = bit10 % 2; // convert from decimal to binary if (mod == 1) { if (tempStr == "") { tempStr = "(" + nonEmptyValues[j] + ")"; } else { tempStr += "|" + "(" + nonEmptyValues[j] + ")"; } } bit10 = bit10 / 2; } for (int j = 0; j < sortedIndices.length; j++) { Instance inst = data.instance(sortedIndices[j]); if (inst.isMissing(att)) { break; } if (tempStr.indexOf("(" + att.value((int) inst.value(att)) + ")") != -1) { currDist[0][(int) inst.classValue()] += weights[j]; } else { currDist[1][(int) inst.classValue()] += weights[j]; } } double[][] tempDist = new double[2][numClasses]; for (int k = 0; k < 2; k++) { tempDist[k] = currDist[k]; } double[] tempProps = new double[2]; for (int k = 0; k < 2; k++) { tempProps[k] = Utils.sum(tempDist[k]); } if (Utils.sum(tempProps) != 0) { Utils.normalize(tempProps); } // split missing values int index = missingStart; while (index < sortedIndices.length) { Instance insta = data.instance(sortedIndices[index]); for (int j = 0; j < 2; j++) { tempDist[j][(int) insta.classValue()] += tempProps[j] * weights[index]; } index++; } double currGain; if (useGini) { currGain = computeGiniGain(parentDist, tempDist); } else { currGain = computeInfoGain(parentDist, tempDist); } if (currGain > bestGain) { bestGain = currGain; bestSplitString = tempStr; for (int j = 0; j < 2; j++) { // dist[jj] = new double[currDist[jj].length]; System.arraycopy(tempDist[j], 0, dist[j], 0, dist[j].length); } } } } // heuristic method to solve multi-class problems else { // Firstly, for attribute values whose class frequency is not zero int n = nonEmpty; int k = 
data.numClasses(); // number of classes of the data double[][] P = new double[n][k]; // class probability matrix int[] numInstancesValue = new int[n]; // number of instances for an // attribute value double[] meanClass = new double[k]; // vector of mean class probability int numInstances = data.numInstances(); // total number of instances // initialize the vector of mean class probability for (int j = 0; j < meanClass.length; j++) { meanClass[j] = 0; } for (int j = 0; j < numInstances; j++) { Instance inst = data.instance(j); int valueIndex = 0; // attribute value index in nonEmptyValues for (int i = 0; i < nonEmpty; i++) { if (att.value((int) inst.value(att)).compareToIgnoreCase(nonEmptyValues[i]) == 0) { valueIndex = i; break; } } P[valueIndex][(int) inst.classValue()]++; numInstancesValue[valueIndex]++; meanClass[(int) inst.classValue()]++; } // calculate the class probability matrix for (int i = 0; i < P.length; i++) { for (int j = 0; j < P[0].length; j++) { if (numInstancesValue[i] == 0) { P[i][j] = 0; } else { P[i][j] /= numInstancesValue[i]; } } } // calculate the vector of mean class probability for (int i = 0; i < meanClass.length; i++) { meanClass[i] /= numInstances; } // calculate the covariance matrix double[][] covariance = new double[k][k]; for (int i1 = 0; i1 < k; i1++) { for (int i2 = 0; i2 < k; i2++) { double element = 0; for (int j = 0; j < n; j++) { element += (P[j][i2] - meanClass[i2]) * (P[j][i1] - meanClass[i1]) * numInstancesValue[j]; } covariance[i1][i2] = element; } } Matrix matrix = new Matrix(covariance); weka.core.matrix.EigenvalueDecomposition eigen = new weka.core.matrix.EigenvalueDecomposition(matrix); double[] eigenValues = eigen.getRealEigenvalues(); // find index of the largest eigenvalue int index = 0; double largest = eigenValues[0]; for (int i = 1; i < eigenValues.length; i++) { if (eigenValues[i] > largest) { index = i; largest = eigenValues[i]; } } // calculate the first principal component double[] FPC = new double[k]; Matrix eigenVector = eigen.getV(); double[][] vectorArray = eigenVector.getArray(); for (int i = 0; i < FPC.length; i++) { FPC[i] = vectorArray[i][index]; } // calculate the first principal component scores double[] Sa = new double[n]; for (int i = 0; i < Sa.length; i++) { Sa[i] = 0; for (int j = 0; j < k; j++) { Sa[i] += FPC[j] * P[i][j]; } } // sort categories according to their Sa scores double[] pCopy = new double[n]; System.arraycopy(Sa, 0, pCopy, 0, n); String[] sortedValues = new String[n]; Arrays.sort(Sa); for (int j = 0; j < n; j++) { sortedValues[j] = nonEmptyValues[Utils.minIndex(pCopy)]; pCopy[Utils.minIndex(pCopy)] = Double.MAX_VALUE; } // for the attribute values whose class frequency is not 0 String tempStr = ""; for (int j = 0; j < nonEmpty - 1; j++) { currDist = new double[2][numClasses]; if (tempStr == "") { tempStr = "(" + sortedValues[j] + ")"; } else { tempStr += "|" + "(" + sortedValues[j] + ")"; } for (int i = 0; i < sortedIndices.length; i++) { Instance inst = data.instance(sortedIndices[i]); if (inst.isMissing(att)) { break; } if (tempStr.indexOf("(" + att.value((int) inst.value(att)) + ")") != -1) { currDist[0][(int) inst.classValue()] += weights[i]; } else { currDist[1][(int) inst.classValue()] += weights[i]; } } double[][] tempDist = new double[2][numClasses]; for (int kk = 0; kk < 2; kk++) { tempDist[kk] = currDist[kk]; } double[] tempProps = new double[2]; for (int kk = 0; kk < 2; kk++) { tempProps[kk] = Utils.sum(tempDist[kk]); } if (Utils.sum(tempProps) != 0) { Utils.normalize(tempProps); } // split missing values 
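// (This is the 'fractional' instances method mentioned in globalInfo: each
// instance whose split value is missing is sent down both branches, weighted
// by the branch proportions in tempProps. E.g. with proportions {0.6, 0.4},
// an instance of weight 1 adds 0.6 to the left and 0.4 to the right class
// distribution.)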
  /**
   * Split data into two subsets and store sorted indices and weights for the
   * two successor nodes.
   *
   * @param subsetIndices sorted indices of the instances in each attribute for
   *          the two successor nodes
   * @param subsetWeights weights of the instances in each attribute for the
   *          two successor nodes
   * @param att attribute the split is based on
   * @param splitPoint split point the split is based on if att is numeric
   * @param splitStr split subset the split is based on if att is nominal
   * @param sortedIndices sorted indices of the instances to be split
   * @param weights weights of the instances to be split
   * @param data training data
   * @throws Exception if something goes wrong
   */
  protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights,
    Attribute att, double splitPoint, String splitStr, int[][] sortedIndices,
    double[][] weights, Instances data) throws Exception {

    int j;

    // For each attribute
    for (int i = 0; i < data.numAttributes(); i++) {
      if (i == data.classIndex()) {
        continue;
      }
      int[] num = new int[2];
      for (int k = 0; k < 2; k++) {
        subsetIndices[k][i] = new int[sortedIndices[i].length];
        subsetWeights[k][i] = new double[weights[i].length];
      }

      for (j = 0; j < sortedIndices[i].length; j++) {
        Instance inst = data.instance(sortedIndices[i][j]);
        if (inst.isMissing(att)) {
          // Split instance up
          for (int k = 0; k < 2; k++) {
            if (m_Props[k] > 0) {
              subsetIndices[k][i][num[k]] = sortedIndices[i][j];
              subsetWeights[k][i][num[k]] = m_Props[k] * weights[i][j];
              num[k]++;
            }
          }
        } else {
          int subset;
          if (att.isNumeric()) {
            subset = (inst.value(att) < splitPoint) ? 0 : 1;
          } else { // nominal attribute
            if (splitStr.indexOf("(" + att.value((int) inst.value(att.index())) + ")") != -1) {
              subset = 0;
            } else {
              subset = 1;
            }
          }
          subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
          subsetWeights[subset][i][num[subset]] = weights[i][j];
          num[subset]++;
        }
      }

      // Trim arrays
      for (int k = 0; k < 2; k++) {
        int[] copy = new int[num[k]];
        System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
        subsetIndices[k][i] = copy;
        double[] copyWeights = new double[num[k]];
        System.arraycopy(subsetWeights[k][i], 0, copyWeights, 0, num[k]);
        subsetWeights[k][i] = copyWeights;
      }
    }
  }
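  // ------------------------------------------------------------------
  // Illustrative note (not part of the original source): splitData()
  // implements the 'fractional instances' treatment of missing values
  // mentioned in the class javadoc. An instance whose split-attribute
  // value is missing is sent down BOTH branches, weighted by the branch
  // proportions m_Props. For example, with m_Props = {0.7, 0.3}, an
  // instance of weight 1.0 becomes a weight-0.7 copy in the left subset
  // and a weight-0.3 copy in the right subset; all subsequent statistics
  // are computed from these fractional weights.
  // ------------------------------------------------------------------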
  /**
   * Compute and return the Gini gain for the given distributions of a node and
   * its successor nodes.
   *
   * @param parentDist class distribution of the parent node
   * @param childDist class distributions of the successor nodes
   * @return Gini gain computed
   */
  protected double computeGiniGain(double[] parentDist, double[][] childDist) {
    double totalWeight = Utils.sum(parentDist);
    if (totalWeight == 0) {
      return 0;
    }

    double leftWeight = Utils.sum(childDist[0]);
    double rightWeight = Utils.sum(childDist[1]);

    double parentGini = computeGini(parentDist, totalWeight);
    double leftGini = computeGini(childDist[0], leftWeight);
    double rightGini = computeGini(childDist[1], rightWeight);

    return parentGini - leftWeight / totalWeight * leftGini
      - rightWeight / totalWeight * rightGini;
  }

  /**
   * Compute and return the Gini index for a given distribution of a node.
   *
   * @param dist class distribution
   * @param total total weight of the class distribution
   * @return Gini index of the class distribution
   */
  protected double computeGini(double[] dist, double total) {
    if (total == 0) {
      return 0;
    }
    double val = 0;
    for (double element : dist) {
      val += (element / total) * (element / total);
    }
    return 1 - val;
  }

  /**
   * Compute and return the information gain for the given distributions of a
   * node and its successor nodes.
   *
   * @param parentDist class distribution of the parent node
   * @param childDist class distributions of the successor nodes
   * @return information gain computed
   */
  protected double computeInfoGain(double[] parentDist, double[][] childDist) {
    double totalWeight = Utils.sum(parentDist);
    if (totalWeight == 0) {
      return 0;
    }

    double leftWeight = Utils.sum(childDist[0]);
    double rightWeight = Utils.sum(childDist[1]);

    double parentInfo = computeEntropy(parentDist, totalWeight);
    double leftInfo = computeEntropy(childDist[0], leftWeight);
    double rightInfo = computeEntropy(childDist[1], rightWeight);

    return parentInfo - leftWeight / totalWeight * leftInfo
      - rightWeight / totalWeight * rightInfo;
  }

  /**
   * Compute and return the entropy for a given distribution of a node.
   *
   * @param dist class distribution
   * @param total total weight of the class distribution
   * @return entropy of the class distribution
   */
  protected double computeEntropy(double[] dist, double total) {
    if (total == 0) {
      return 0;
    }
    double entropy = 0;
    for (double element : dist) {
      if (element != 0) {
        entropy -= element / total * Utils.log2(element / total);
      }
    }
    return entropy;
  }

  /**
   * Make the node a leaf node.
   *
   * @param data training data
   */
  protected void makeLeaf(Instances data) {
    m_Attribute = null;
    m_isLeaf = true;
    m_ClassValue = Utils.maxIndex(m_ClassProbs);
    m_ClassAttribute = data.classAttribute();
  }
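  // ------------------------------------------------------------------
  // Worked example (not part of the original source): for a parent
  // class distribution {6, 2} split into children {4, 0} and {2, 2}:
  //   Gini(parent) = 1 - (0.75^2 + 0.25^2) = 0.375
  //   Gini(left)   = 1 - (1.0^2  + 0.0^2)  = 0.0   (weight 4/8)
  //   Gini(right)  = 1 - (0.5^2  + 0.5^2)  = 0.5   (weight 4/8)
  //   GiniGain     = 0.375 - 0.5 * 0.0 - 0.5 * 0.5 = 0.125
  // computeInfoGain() applies the same weighting to entropy instead of
  // the Gini index.
  // ------------------------------------------------------------------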
  /**
   * Computes class probabilities for an instance using the decision tree.
   *
   * @param instance the instance for which class probabilities are to be
   *          computed
   * @return the class probabilities for the given instance
   * @throws Exception if something goes wrong
   */
  @Override
  public double[] distributionForInstance(Instance instance) throws Exception {
    if (!m_isLeaf) {
      // value of split attribute is missing
      if (instance.isMissing(m_Attribute)) {
        double[] returnedDist = new double[m_ClassProbs.length];
        for (int i = 0; i < m_Successors.length; i++) {
          double[] help = m_Successors[i].distributionForInstance(instance);
          if (help != null) {
            for (int j = 0; j < help.length; j++) {
              returnedDist[j] += m_Props[i] * help[j];
            }
          }
        }
        return returnedDist;
      }
      // split attribute is nominal
      else if (m_Attribute.isNominal()) {
        if (m_SplitString.indexOf("(" + m_Attribute.value((int) instance.value(m_Attribute)) + ")") != -1) {
          return m_Successors[0].distributionForInstance(instance);
        } else {
          return m_Successors[1].distributionForInstance(instance);
        }
      }
      // split attribute is numeric
      else {
        if (instance.value(m_Attribute) < m_SplitValue) {
          return m_Successors[0].distributionForInstance(instance);
        } else {
          return m_Successors[1].distributionForInstance(instance);
        }
      }
    } else {
      return m_ClassProbs;
    }
  }

  /**
   * Prints the decision tree using the protected toString method below.
   *
   * @return a textual description of the classifier
   */
  @Override
  public String toString() {
    if ((m_Distribution == null) && (m_Successors == null)) {
      return "Best-First: No model built yet.";
    }
    return "Best-First Decision Tree\n" + toString(0) + "\n\n"
      + "Size of the Tree: " + numNodes() + "\n\n"
      + "Number of Leaf Nodes: " + numLeaves();
  }

  /**
   * Outputs the tree at a certain level.
   *
   * @param level the level at which the tree is to be printed
   * @return the tree at the given level
   */
  protected String toString(int level) {
    StringBuffer text = new StringBuffer();

    // leaf nodes
    if (m_Attribute == null) {
      if (Utils.isMissingValue(m_ClassValue)) {
        text.append(": null");
      } else {
        double correctNum =
          Math.rint(m_Distribution[Utils.maxIndex(m_Distribution)] * 100) / 100.0;
        double wrongNum = Math.rint((Utils.sum(m_Distribution)
          - m_Distribution[Utils.maxIndex(m_Distribution)]) * 100) / 100.0;
        String str = "(" + correctNum + "/" + wrongNum + ")";
        text.append(": " + m_ClassAttribute.value((int) m_ClassValue) + str);
      }
    } else {
      for (int j = 0; j < 2; j++) {
        text.append("\n");
        for (int i = 0; i < level; i++) {
          text.append("| ");
        }
        if (j == 0) {
          if (m_Attribute.isNumeric()) {
            text.append(m_Attribute.name() + " < " + m_SplitValue);
          } else {
            text.append(m_Attribute.name() + "=" + m_SplitString);
          }
        } else {
          if (m_Attribute.isNumeric()) {
            text.append(m_Attribute.name() + " >= " + m_SplitValue);
          } else {
            text.append(m_Attribute.name() + "!=" + m_SplitString);
          }
        }
        text.append(m_Successors[j].toString(level + 1));
      }
    }
    return text.toString();
  }

  /**
   * Compute the size of the tree.
   *
   * @return size of the tree
   */
  public int numNodes() {
    if (m_isLeaf) {
      return 1;
    } else {
      int size = 1;
      for (BFTree m_Successor : m_Successors) {
        size += m_Successor.numNodes();
      }
      return size;
    }
  }

  /**
   * Compute the number of leaf nodes.
   *
   * @return number of leaf nodes
   */
  public int numLeaves() {
    if (m_isLeaf) {
      return 1;
    } else {
      int size = 0;
      for (BFTree m_Successor : m_Successors) {
        size += m_Successor.numLeaves();
      }
      return size;
    }
  }
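  // ------------------------------------------------------------------
  // Illustrative note (not part of the original source): on a data set
  // like UCI iris, toString() produces output shaped roughly like:
  //
  //   Best-First Decision Tree
  //
  //   petalwidth < 0.8: Iris-setosa(50.0/0.0)
  //   petalwidth >= 0.8
  //   | petalwidth < 1.75: Iris-versicolor(49.0/5.0)
  //   | petalwidth >= 1.75: Iris-virginica(45.0/1.0)
  //
  //   Size of the Tree: 5
  //   Number of Leaf Nodes: 3
  //
  // The attributes, thresholds and counts depend on the data and the
  // chosen options; the figures above are only an example.
  // ------------------------------------------------------------------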
  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration describing the available options.
   */
  @Override
  public Enumeration<Option> listOptions() {
    Vector<Option> result = new Vector<Option>();

    result.addElement(new Option(
      "\tThe pruning strategy.\n"
        + "\t(default: " + new SelectedTag(PRUNING_POSTPRUNING, TAGS_PRUNING) + ")",
      "P", 1, "-P " + Tag.toOptionList(TAGS_PRUNING)));

    result.addElement(new Option(
      "\tThe minimal number of instances at the terminal nodes.\n"
        + "\t(default 2)", "M", 1, "-M <min no>"));

    result.addElement(new Option(
      "\tThe number of folds used in the pruning.\n"
        + "\t(default 5)", "N", 5, "-N <num folds>"));

    result.addElement(new Option(
      "\tDon't use heuristic search for nominal attributes in multi-class\n"
        + "\tproblem (default yes).\n", "H", 0, "-H"));

    result.addElement(new Option(
      "\tDon't use Gini index for splitting (default yes),\n"
        + "\tif not information is used.", "G", 0, "-G"));

    result.addElement(new Option(
      "\tDon't use error rate in internal cross-validation (default yes), \n"
        + "\tbut root mean squared error.", "R", 0, "-R"));

    result.addElement(new Option(
      "\tUse the 1 SE rule to make pruning decision.\n"
        + "\t(default no).", "A", 0, "-A"));

    result.addElement(new Option(
      "\tPercentage of training data size (0-1]\n"
        + "\t(default 1).", "C", 0, "-C"));

    result.addAll(Collections.list(super.listOptions()));

    return result.elements();
  }

  /**
   * Parses the options for this object.
   * <p/>
   *
   * <!-- options-start --> Valid options are:
   * <p/>
   *
   * <pre>
   * -S <num>
   *  Random number seed.
   *  (default 1)
   * </pre>
   *
   * <pre>
   * -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console
   * </pre>
   *
   * <pre>
   * -P <UNPRUNED|POSTPRUNED|PREPRUNED>
   *  The pruning strategy.
   *  (default: POSTPRUNED)
   * </pre>
   *
   * <pre>
   * -M <min no>
   *  The minimal number of instances at the terminal nodes.
   *  (default 2)
   * </pre>
   *
   * <pre>
   * -N <num folds>
   *  The number of folds used in the pruning.
   *  (default 5)
   * </pre>
   *
   * <pre>
   * -H
   *  Don't use heuristic search for nominal attributes in multi-class
   *  problem (default yes).
   * </pre>
   *
   * <pre>
   * -G
   *  Don't use Gini index for splitting (default yes),
   *  if not information is used.
   * </pre>
   *
   * <pre>
   * -R
   *  Don't use error rate in internal cross-validation (default yes),
   *  but root mean squared error.
   * </pre>
   *
   * <pre>
   * -A
   *  Use the 1 SE rule to make pruning decision.
   *  (default no).
   * </pre>
   *
   * <pre>
   * -C
   *  Percentage of training data size (0-1]
   *  (default 1).
   * </pre>
   *
   * <!-- options-end -->
   *
   * @param options the options to use
   * @throws Exception if setting of options fails
   */
  @Override
  public void setOptions(String[] options) throws Exception {
    String tmpStr;

    tmpStr = Utils.getOption('M', options);
    if (tmpStr.length() != 0) {
      setMinNumObj(Integer.parseInt(tmpStr));
    } else {
      setMinNumObj(2);
    }

    tmpStr = Utils.getOption('N', options);
    if (tmpStr.length() != 0) {
      setNumFoldsPruning(Integer.parseInt(tmpStr));
    } else {
      setNumFoldsPruning(5);
    }

    tmpStr = Utils.getOption('C', options);
    if (tmpStr.length() != 0) {
      setSizePer(Double.parseDouble(tmpStr));
    } else {
      setSizePer(1);
    }

    tmpStr = Utils.getOption('P', options);
    if (tmpStr.length() != 0) {
      setPruningStrategy(new SelectedTag(tmpStr, TAGS_PRUNING));
    } else {
      setPruningStrategy(new SelectedTag(PRUNING_POSTPRUNING, TAGS_PRUNING));
    }

    setHeuristic(!Utils.getFlag('H', options));
    setUseGini(!Utils.getFlag('G', options));
    setUseErrorRate(!Utils.getFlag('R', options));
    setUseOneSE(Utils.getFlag('A', options));

    super.setOptions(options);

    Utils.checkForRemainingOptions(options);
  }
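  // ------------------------------------------------------------------
  // Illustrative note (not part of the original source): the options can
  // also be set programmatically, mirroring the command line, e.g.:
  //
  //   BFTree tree = new BFTree();
  //   tree.setOptions(weka.core.Utils.splitOptions("-M 5 -N 10 -P POSTPRUNED"));
  //
  // Note that -H, -G and -R are inverse flags: passing them turns the
  // corresponding feature OFF (heuristic search, the Gini index, and the
  // error-rate estimate, respectively).
  // ------------------------------------------------------------------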
  /**
   * Gets the current settings of the Classifier.
   *
   * @return the current settings of the Classifier
   */
  @Override
  public String[] getOptions() {
    Vector<String> result = new Vector<String>();

    result.add("-M");
    result.add("" + getMinNumObj());

    result.add("-N");
    result.add("" + getNumFoldsPruning());

    if (!getHeuristic()) {
      result.add("-H");
    }

    if (!getUseGini()) {
      result.add("-G");
    }

    if (!getUseErrorRate()) {
      result.add("-R");
    }

    if (getUseOneSE()) {
      result.add("-A");
    }

    result.add("-C");
    result.add("" + getSizePer());

    result.add("-P");
    result.add("" + getPruningStrategy());

    Collections.addAll(result, super.getOptions());

    return result.toArray(new String[result.size()]);
  }

  /**
   * Return an enumeration of the measure names.
   *
   * @return an enumeration of the measure names
   */
  @Override
  public Enumeration<String> enumerateMeasures() {
    Vector<String> result = new Vector<String>();
    result.addElement("measureTreeSize");
    return result.elements();
  }

  /**
   * Return the size of the tree.
   *
   * @return size of the tree
   */
  public double measureTreeSize() {
    return numNodes();
  }

  /**
   * Returns the value of the named measure.
   *
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @throws IllegalArgumentException if the named measure is not supported
   */
  @Override
  public double getMeasure(String additionalMeasureName) {
    if (additionalMeasureName.compareToIgnoreCase("measureTreeSize") == 0) {
      return measureTreeSize();
    } else {
      throw new IllegalArgumentException(additionalMeasureName
        + " not supported (Best-First)");
    }
  }
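  // ------------------------------------------------------------------
  // Illustrative note (not part of the original source): the measure
  // exposed through the AdditionalMeasureProducer interface can be read
  // generically, e.g.:
  //
  //   double size = tree.getMeasure("measureTreeSize");
  //
  // which is how Weka's Experimenter collects extra per-classifier
  // metrics without knowing the concrete class.
  // ------------------------------------------------------------------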
  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String pruningStrategyTipText() {
    return "Sets the pruning strategy.";
  }

  /**
   * Sets the pruning strategy.
   *
   * @param value the strategy
   */
  public void setPruningStrategy(SelectedTag value) {
    if (value.getTags() == TAGS_PRUNING) {
      m_PruningStrategy = value.getSelectedTag().getID();
    }
  }

  /**
   * Gets the pruning strategy.
   *
   * @return the current strategy.
   */
  public SelectedTag getPruningStrategy() {
    return new SelectedTag(m_PruningStrategy, TAGS_PRUNING);
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String minNumObjTipText() {
    return "Set minimal number of instances at the terminal nodes.";
  }

  /**
   * Set the minimal number of instances at the terminal nodes.
   *
   * @param value minimal number of instances at the terminal nodes
   */
  public void setMinNumObj(int value) {
    m_minNumObj = value;
  }

  /**
   * Get the minimal number of instances at the terminal nodes.
   *
   * @return minimal number of instances at the terminal nodes
   */
  public int getMinNumObj() {
    return m_minNumObj;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui.
   */
  public String numFoldsPruningTipText() {
    return "Number of folds in internal cross-validation.";
  }

  /**
   * Set the number of folds in internal cross-validation.
   *
   * @param value the number of folds
   */
  public void setNumFoldsPruning(int value) {
    m_numFoldsPruning = value;
  }

  /**
   * Get the number of folds in internal cross-validation.
   *
   * @return number of folds in internal cross-validation
   */
  public int getNumFoldsPruning() {
    return m_numFoldsPruning;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui.
   */
  public String heuristicTipText() {
    return "If heuristic search is used for binary split for nominal attributes.";
  }

  /**
   * Set whether heuristic search is used for nominal attributes in multi-class
   * problems.
   *
   * @param value whether to use heuristic search for nominal attributes in
   *          multi-class problems
   */
  public void setHeuristic(boolean value) {
    m_Heuristic = value;
  }

  /**
   * Get whether heuristic search is used for nominal attributes in multi-class
   * problems.
   *
   * @return whether heuristic search is used for nominal attributes in
   *         multi-class problems
   */
  public boolean getHeuristic() {
    return m_Heuristic;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui.
   */
  public String useGiniTipText() {
    return "If true the Gini index is used for splitting criterion, otherwise the information is used.";
  }

  /**
   * Set whether the Gini index is used as the splitting criterion.
   *
   * @param value whether to use the Gini index as the splitting criterion
   */
  public void setUseGini(boolean value) {
    m_UseGini = value;
  }

  /**
   * Get whether the Gini index is used as the splitting criterion.
   *
   * @return whether the Gini index is used as the splitting criterion
   */
  public boolean getUseGini() {
    return m_UseGini;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui.
   */
  public String useErrorRateTipText() {
    return "If error rate is used as the error estimate; if not, root mean squared error is used.";
  }

  /**
   * Set whether the error rate is used in internal cross-validation.
   *
   * @param value whether to use the error rate in internal cross-validation
   */
  public void setUseErrorRate(boolean value) {
    m_UseErrorRate = value;
  }

  /**
   * Get whether the error rate is used in internal cross-validation.
   *
   * @return whether the error rate is used in internal cross-validation.
   */
  public boolean getUseErrorRate() {
    return m_UseErrorRate;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui.
   */
  public String useOneSETipText() {
    return "Use the 1SE rule to make pruning decision.";
  }

  /**
   * Set whether the 1SE rule is used to choose the final model.
   *
   * @param value whether to use the 1SE rule to choose the final model
   */
  public void setUseOneSE(boolean value) {
    m_UseOneSE = value;
  }

  /**
   * Get whether the 1SE rule is used to choose the final model.
   *
   * @return whether the 1SE rule is used to choose the final model
   */
  public boolean getUseOneSE() {
    return m_UseOneSE;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui.
   */
  public String sizePerTipText() {
    return "The percentage of the training set size (0-1, 0 not included).";
  }

  /**
   * Set the training set size.
   *
   * @param value training set size
   */
  public void setSizePer(double value) {
    if ((value <= 0) || (value > 1)) {
      System.err.println("The percentage of the training set size must be in range 0 to 1 "
        + "(0 not included) - ignored!");
    } else {
      m_SizePer = value;
    }
  }

  /**
   * Get the training set size.
   *
   * @return training set size
   */
  public double getSizePer() {
    return m_SizePer;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision$");
  }

  /**
   * Main method.
   *
   * @param args the options for the classifier
   */
  public static void main(String[] args) {
    runClassifier(new BFTree(), args);
  }
}
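A minimal usage sketch (not part of BFTree.java above). It assumes Weka is on the classpath and that a local ARFF file named iris.arff exists; both the file name and the demo class are only examples. It trains a best-first tree with default settings and prints the model via toString():

import weka.classifiers.trees.BFTree;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class BFTreeDemo {
  public static void main(String[] args) throws Exception {
    // load the data set and declare the last attribute as the class
    Instances data = DataSource.read("iris.arff"); // hypothetical file name
    data.setClassIndex(data.numAttributes() - 1);

    // configure and build the tree (the options shown are the defaults)
    BFTree tree = new BFTree();
    tree.setOptions(weka.core.Utils.splitOptions("-M 2 -N 5"));
    tree.buildClassifier(data);

    // print the tree, its size and its number of leaves
    System.out.println(tree);
  }
}

The same behaviour is available from the command line through the main() method above, e.g. java weka.classifiers.trees.BFTree -t iris.arff, where -t is Weka's standard training-file option.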