com.cloudera.oryx.rdf.common.tree.DecisionForest.java Source code

Introduction

Here is the source code for com.cloudera.oryx.rdf.common.tree.DecisionForest.java. DecisionForest is an ensemble classifier built from many DecisionTrees: it trains the trees in parallel on hash-sampled subsets of the examples, weights each tree by its evaluation on held-out examples, and classifies by a weighted vote across all trees.

Source

/*
 * Copyright (c) 2013, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */

package com.cloudera.oryx.rdf.common.tree;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.math.IntMath;
import com.typesafe.config.Config;
import org.apache.commons.math3.util.FastMath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import com.cloudera.oryx.common.iterator.ArrayIterator;
import com.cloudera.oryx.common.parallel.ExecutorUtils;
import com.cloudera.oryx.common.settings.ConfigUtils;
import com.cloudera.oryx.rdf.common.eval.Evaluation;
import com.cloudera.oryx.rdf.common.eval.WeightedPrediction;
import com.cloudera.oryx.rdf.common.example.Example;
import com.cloudera.oryx.rdf.common.example.ExampleSet;
import com.cloudera.oryx.rdf.common.rule.Prediction;

/**
 * An ensemble classifier based on many {@link DecisionTree}s.
 *
 * @author Sean Owen
 * @see DecisionTree
 */
public final class DecisionForest implements Iterable<DecisionTree>, TreeBasedClassifier {

    private static final Logger log = LoggerFactory.getLogger(DecisionForest.class);

    private final DecisionTree[] trees;
    private final double[] weights;
    private final double[] evaluations;
    private final double[] featureImportances;

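    /**
     * Builds a {@link DecisionForest} from the given examples, reading the number of trees,
     * per-split settings and the sample rate from the {@code model.*} keys of the
     * default configuration.
     */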
    public static DecisionForest fromExamplesWithDefault(List<Example> examples) {
        Config config = ConfigUtils.getDefaultConfig();
        int numTrees = config.getInt("model.num-trees");
        double fractionOfFeaturesToTry = config.getDouble("model.fraction-features-to-try");
        int minNodeSize = config.getInt("model.min-node-size");
        double minInfoGainNats = config.getDouble("model.min-info-gain-nats");
        int suggestedMaxSplitCandidates = config.getInt("model.max-split-candidates");
        int maxDepth = config.getInt("model.max-depth");
        double sampleRate = config.getDouble("model.sample-rate");
        ExampleSet exampleSet = new ExampleSet(examples);
        return new DecisionForest(numTrees, fractionOfFeaturesToTry, minNodeSize, minInfoGainNats,
                suggestedMaxSplitCandidates, maxDepth, sampleRate, exampleSet);
    }

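    /**
     * Trains {@code numTrees} decision trees in parallel. Each tree trains on a
     * deterministic, hash-based subset of the examples and is then weighted and
     * evaluated on the examples held out from its training subset.
     *
     * @param numTrees number of trees in the ensemble; must be at least 2
     * @param fractionOfFeaturesToTry fraction of features to consider at each split, in (0,1]
     * @param minNodeSize minimum number of examples a node must have to be split
     * @param minInfoGainNats minimum information gain, in nats, required to split a node
     * @param suggestedMaxSplitCandidates suggested maximum number of split candidates to evaluate
     * @param maxDepth maximum depth of each tree
     * @param sampleRate approximate fraction of examples each tree trains on, in (0,1]
     * @param examples the full set of training examples
     */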
    public DecisionForest(final int numTrees, double fractionOfFeaturesToTry, final int minNodeSize,
            final double minInfoGainNats, final int suggestedMaxSplitCandidates, final int maxDepth,
            final double sampleRate, final ExampleSet examples) {
        Preconditions.checkArgument(numTrees > 1);
        final int numFeatures = examples.getNumFeatures();
        Preconditions.checkArgument(numFeatures >= 1);
        Preconditions.checkArgument(fractionOfFeaturesToTry > 0.0 && fractionOfFeaturesToTry <= 1.0);
        Preconditions.checkArgument(minNodeSize >= 1);
        Preconditions.checkArgument(minInfoGainNats >= 0.0);
        Preconditions.checkArgument(suggestedMaxSplitCandidates >= 1);
        Preconditions.checkArgument(maxDepth >= 1);
        Preconditions.checkArgument(sampleRate > 0.0 && sampleRate <= 1.0);
        final int featuresToTry = FastMath.max(1, (int) (fractionOfFeaturesToTry * numFeatures));

        weights = new double[numTrees];
        Arrays.fill(weights, 1.0);
        evaluations = new double[numTrees];
        Arrays.fill(evaluations, Double.NaN);
        final double[][] perTreeFeatureImportances = new double[numTrees][];

        // Each example trains `folds` of the numTrees trees (via the hash-based assignment
        // below), so each tree trains on roughly folds/numTrees of the examples.
        // Cap the training fraction at about 90%, leaving at least about 10% for CV...
        int maxFolds = FastMath.min(numTrees - 1, (int) (0.9 * numTrees));
        // ...and keep the training fraction at least about 10%
        int minFolds = FastMath.max(1, (int) (0.1 * numTrees));
        final int folds = FastMath.min(maxFolds, FastMath.max(minFolds, (int) (sampleRate * numTrees)));

        trees = new DecisionTree[numTrees];

        ExecutorService executor = Executors.newFixedThreadPool(determineParallelism(trees.length));
        try {
            Collection<Future<Object>> futures = Lists.newArrayListWithCapacity(trees.length);
            for (int i = 0; i < numTrees; i++) {
                final int treeID = i;
                futures.add(executor.submit(new Callable<Object>() {
                    @Override
                    public Void call() throws Exception {
                        Collection<Example> allExamples = examples.getExamples();
                        int totalExamples = allExamples.size();
                        int expectedTrainingSize = (int) (totalExamples * sampleRate);
                        int expectedCVSize = totalExamples - expectedTrainingSize;
                        List<Example> trainingExamples = Lists.newArrayListWithExpectedSize(expectedTrainingSize);
                        List<Example> cvExamples = Lists.newArrayListWithExpectedSize(expectedCVSize);
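                        // Deterministic, hash-based sampling: an example's bucket is
                        // hashCode mod numTrees, and the example trains this tree when
                        // its bucket falls within `folds` positions of treeID (mod numTrees).
                        // Each example therefore trains exactly `folds` of the trees and
                        // is held out for CV on the rest.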
                        for (Example example : allExamples) {
                            if (IntMath.mod(IntMath.mod(example.hashCode(), numTrees) - treeID, numTrees) < folds) {
                                trainingExamples.add(example);
                            } else {
                                cvExamples.add(example);
                            }
                        }

                        Preconditions.checkState(!trainingExamples.isEmpty(), "No training examples sampled?");
                        Preconditions.checkState(!cvExamples.isEmpty(), "No CV examples sampled?");

                        trees[treeID] = new DecisionTree(numFeatures, featuresToTry, minNodeSize, minInfoGainNats,
                                suggestedMaxSplitCandidates, maxDepth, examples.subset(trainingExamples));
                        log.info("Finished tree {}", treeID);
                        ExampleSet cvExampleSet = examples.subset(cvExamples);
                        double[] weightEval = Evaluation.evaluateToWeight(trees[treeID], cvExampleSet);
                        weights[treeID] = weightEval[0];
                        evaluations[treeID] = weightEval[1];
                        perTreeFeatureImportances[treeID] = trees[treeID].featureImportance(cvExampleSet);
                        log.info("Tree {} eval: {}", treeID, weightEval[1]);
                        return null;
                    }
                }));
            }
            ExecutorUtils.checkExceptions(futures);
        } finally {
            ExecutorUtils.shutdownNowAndAwait(executor);
        }

        featureImportances = new double[numFeatures];
        // Sum per-tree feature importances before averaging across trees
        for (double[] perTreeFeatureImportance : perTreeFeatureImportances) {
            for (int i = 0; i < numFeatures; i++) {
                featureImportances[i] += perTreeFeatureImportance[i];
            }
        }
        for (int i = 0; i < numFeatures; i++) {
            featureImportances[i] /= numTrees;
        }
    }

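    /**
     * Constructs a forest directly from already-trained trees, weights and feature
     * importances. Per-tree evaluations are not available here and are left at 0.
     */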
    public DecisionForest(DecisionTree[] trees, double[] weights, double[] featureImportances) {
        this.trees = trees;
        this.weights = weights;
        this.evaluations = new double[weights.length];
        this.featureImportances = featureImportances;
    }

    @Override
    public Iterator<DecisionTree> iterator() {
        return ArrayIterator.forArray(trees);
    }

    /**
     * @return {@link DecisionTree}s in the ensemble forest
     */
    public DecisionTree[] getTrees() {
        return trees;
    }

    public double[] getWeights() {
        return weights;
    }

    public double[] getEvaluations() {
        return evaluations;
    }

    public double[] getFeatureImportances() {
        return featureImportances;
    }

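    /**
     * Classifies the example by a weighted vote over all trees' predictions.
     */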
    @Override
    public Prediction classify(Example test) {
        return WeightedPrediction
                .voteOnFeature(Lists.transform(Arrays.asList(trees), new TreeToPredictionFunction(test)), weights);
    }

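    /**
     * Updates every tree in the ensemble with the new training example.
     */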
    @Override
    public void update(Example train) {
        for (TreeBasedClassifier tree : trees) {
            tree.update(train);
        }
    }

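    /**
     * Chooses one thread per tree when enough cores are available; otherwise tries
     * to pick a thread count, up to twice the core count, that divides numTrees evenly.
     */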
    private static int determineParallelism(int numTrees) {
        int numCores = ExecutorUtils.getParallelism();
        if (numCores >= numTrees) {
            return numTrees;
        }
        // Otherwise, try rounding the thread count up, to at most 2x the core count,
        // so that numTrees is a multiple of it
        int numThreads = numCores;
        while (numTrees % numThreads != 0 && numThreads < 2 * numCores) {
            numThreads++;
        }
        return numThreads;
    }

    @Override
    public String toString() {
        StringBuilder result = new StringBuilder();
        for (DecisionTree tree : trees) {
            result.append(tree).append('\n');
        }
        return result.toString();
    }

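    /**
     * Adapts a {@link DecisionTree} to a Guava {@link Function} that returns the
     * tree's {@link Prediction} for a fixed test example.
     */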
    private static final class TreeToPredictionFunction implements Function<DecisionTree, Prediction> {
        private final Example test;

        TreeToPredictionFunction(Example test) {
            this.test = test;
        }

        @Override
        public Prediction apply(DecisionTree tree) {
            return tree.classify(test);
        }
    }
}
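
Usage

The following is a minimal usage sketch built only from the public methods shown above. The loadExamples() helper is hypothetical: constructing Example instances depends on Oryx feature APIs outside this listing, and fromExamplesWithDefault assumes the model.* keys it reads are present in the default configuration.

import java.util.List;

import com.cloudera.oryx.rdf.common.example.Example;
import com.cloudera.oryx.rdf.common.rule.Prediction;
import com.cloudera.oryx.rdf.common.tree.DecisionForest;

public final class DecisionForestExample {

    public static void main(String[] args) {
        // Hypothetical helper: build Example instances from your own data
        List<Example> examples = loadExamples();

        // Train an ensemble using the model.* settings from the default config
        DecisionForest forest = DecisionForest.fromExamplesWithDefault(examples);

        // Per-tree vote weights and held-out evaluations computed during training
        double[] weights = forest.getWeights();
        double[] evaluations = forest.getEvaluations();

        // Per-feature importances, averaged across all trees
        double[] importances = forest.getFeatureImportances();

        // Weighted-vote classification of a single example
        Prediction prediction = forest.classify(examples.get(0));
        System.out.println(prediction);
    }

    private static List<Example> loadExamples() {
        // Hypothetical: supply Example instances built from your data
        throw new UnsupportedOperationException("build Example instances here");
    }
}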