com.cloudera.oryx.rdf.common.pmml.DecisionForestPMML.java Source code

Java tutorial

Introduction

Here is the source code for com.cloudera.oryx.rdf.common.pmml.DecisionForestPMML.java

Source

/*
 * Copyright (c) 2013, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */

package com.cloudera.oryx.rdf.common.pmml;

import com.cloudera.oryx.common.pmml.PMMLUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.BiMap;
import com.google.common.collect.Lists;
import org.apache.commons.math3.util.Pair;
import org.dmg.pmml.Array;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.IOUtil;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.MissingValueStrategyType;
import org.dmg.pmml.Model;
import org.dmg.pmml.MultipleModelMethodType;
import org.dmg.pmml.Node;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.Segment;
import org.dmg.pmml.Segmentation;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.SimpleSetPredicate;
import org.dmg.pmml.TreeModel;
import org.dmg.pmml.True;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import javax.xml.transform.stream.StreamResult;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.Writer;
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Queue;

import com.cloudera.oryx.common.collection.BitSet;
import com.cloudera.oryx.common.io.DelimitedDataUtils;
import com.cloudera.oryx.common.io.IOUtils;
import com.cloudera.oryx.common.settings.ConfigUtils;
import com.cloudera.oryx.common.settings.InboundSettings;
import com.cloudera.oryx.rdf.common.example.FeatureType;
import com.cloudera.oryx.rdf.common.rule.CategoricalDecision;
import com.cloudera.oryx.rdf.common.rule.CategoricalPrediction;
import com.cloudera.oryx.rdf.common.rule.Decision;
import com.cloudera.oryx.rdf.common.rule.NumericDecision;
import com.cloudera.oryx.rdf.common.rule.NumericPrediction;
import com.cloudera.oryx.rdf.common.rule.Prediction;
import com.cloudera.oryx.rdf.common.tree.DecisionForest;
import com.cloudera.oryx.rdf.common.tree.DecisionNode;
import com.cloudera.oryx.rdf.common.tree.DecisionTree;
import com.cloudera.oryx.rdf.common.tree.TerminalNode;
import com.cloudera.oryx.rdf.common.tree.TreeNode;

/**
 * Contains utility methods for writing a {@link DecisionForest} as a PMML file, and reading it back.
 *
 * @author Sean Owen
 */
public final class DecisionForestPMML {

    // Every member of this class is static; the private constructor keeps outside code from instantiating it.
    private DecisionForestPMML() {
    }

    // Write PMML

    /**
     * Convenience overload that serializes to a {@link File} instead of a {@link Writer}. The file is opened
     * through {@code IOUtils.buildGZIPWriter(File)}, and the resulting writer is closed whether or not
     * serialization succeeds.
     *
     * @param pmmlFile file to write the PMML representation to
     * @param forest {@link DecisionForest} to encode as PMML
     * @param columnToCategoryNameToIDMapping per-column mapping between category value names and IDs, handed
     *  unchanged to {@link #write(Writer, DecisionForest, Map)}
     * @throws IOException if opening, writing or closing the output fails
     * @see #write(Writer, DecisionForest, Map)
     */
    public static void write(File pmmlFile, DecisionForest forest,
            Map<Integer, BiMap<String, Integer>> columnToCategoryNameToIDMapping) throws IOException {
        Writer fileOut = IOUtils.buildGZIPWriter(pmmlFile);
        try {
            write(fileOut, forest, columnToCategoryNameToIDMapping);
        } finally {
            // Close even when serialization failed so the file handle is not leaked
            fileOut.close();
        }
    }

    /**
     * Encodes a {@link DecisionForest} as a PMML document (created with version string {@code "4.1"}) and
     * marshals it to the given {@link Writer}. The document holds one {@link MiningModel} whose
     * {@link Segmentation} receives one {@link Segment} per tree in the forest. The model is a classification
     * model combined by weighted majority vote when the target column of the default configuration is
     * categorical, and a regression model combined by weighted average otherwise.
     *
     * @param pmmlOut stream to write PMML representation to
     * @param forest {@link DecisionForest} to encode as PMML
     * @param columnToCategoryNameToIDMapping {@link Map} from column number in the input, to a {@link BiMap}
     *  mapping between category value names and category value IDs (for categorical feature columns only). This
     *  is necessary because the {@link DecisionForest} operates in terms of value IDs, but the PMML encoding
     *  should encode the names of these category values -- {@code female}, not {@code 2}
     * @throws IOException if marshalling fails; a {@link JAXBException} is wrapped as the cause
     */
    public static void write(Writer pmmlOut, DecisionForest forest,
            Map<Integer, BiMap<String, Integer>> columnToCategoryNameToIDMapping) throws IOException {

        InboundSettings settings = InboundSettings.create(ConfigUtils.getDefaultConfig());
        int targetColumn = settings.getTargetColumn();

        // The target column's type decides both the mining function and how tree outputs are combined
        MiningFunctionType functionType;
        MultipleModelMethodType combinationMethod;
        if (settings.isCategorical(targetColumn)) {
            functionType = MiningFunctionType.CLASSIFICATION;
            combinationMethod = MultipleModelMethodType.WEIGHTED_MAJORITY_VOTE;
        } else {
            functionType = MiningFunctionType.REGRESSION;
            combinationMethod = MultipleModelMethodType.WEIGHTED_AVERAGE;
        }

        MiningSchema schema = PMMLUtils.buildMiningSchema(settings, forest.getFeatureImportances());
        MiningModel forestModel = new MiningModel(schema, functionType);
        Segmentation segmentation = new Segmentation(combinationMethod);
        forestModel.setSegmentation(segmentation);

        // Segment IDs are the position of each tree in the forest's iteration order
        int nextTreeID = 0;
        for (DecisionTree tree : forest) {
            Segment treeSegment = buildTreeModel(forest, columnToCategoryNameToIDMapping, functionType, schema,
                    nextTreeID, tree, settings);
            segmentation.getSegments().add(treeSegment);
            nextTreeID++;
        }

        DataDictionary dataDictionary = PMMLUtils.buildDataDictionary(settings, columnToCategoryNameToIDMapping);
        PMML document = new PMML(null, dataDictionary, "4.1");
        document.getModels().add(forestModel);

        try {
            IOUtil.marshal(document, new StreamResult(pmmlOut));
        } catch (JAXBException jaxbe) {
            throw new IOException(jaxbe);
        }
    }

    /**
     * Converts one {@link DecisionTree} of the forest into a PMML {@link Segment} wrapping a {@link TreeModel}.
     * The tree is walked breadth-first. The PMML root node has ID {@code "r"}; the two children of a decision
     * node take the parent's ID with {@code '+'} (positive / right branch) or {@code '-'} (negative / left
     * branch) appended. The positive child carries the predicate built from the parent's decision and is added
     * first; the negative child carries a {@link True} predicate.
     *
     * @param forest forest the tree belongs to; supplies the segment weight at index {@code treeID}
     * @param columnToCategoryNameToIDMapping per-column mapping between category value names and IDs
     * @param miningFunctionType mining function to record on the {@link TreeModel}
     * @param miningSchema mining schema to record on the {@link TreeModel}
     * @param treeID index of the tree in the forest, also used as the segment ID
     * @param tree tree to convert
     * @param settings source of the column names and the target column number
     * @return segment holding the converted tree, with a {@link True} predicate and the tree's weight
     */
    private static Segment buildTreeModel(DecisionForest forest,
            Map<Integer, BiMap<String, Integer>> columnToCategoryNameToIDMapping,
            MiningFunctionType miningFunctionType, MiningSchema miningSchema, int treeID, DecisionTree tree,
            InboundSettings settings) {

        List<String> columnNames = settings.getColumnNames();
        int targetColumn = settings.getTargetColumn();

        Node pmmlRoot = new Node();
        pmmlRoot.setId("r");

        // Two queues advanced in lock step: the head of one always corresponds to the head of the other.
        // Each tree node is paired with the decision that leads to it from its parent (null if none).
        Queue<Node> pendingPmmlNodes = new ArrayDeque<Node>();
        Queue<Pair<TreeNode, Decision>> pendingTreeNodes = new ArrayDeque<Pair<TreeNode, Decision>>();
        pendingPmmlNodes.add(pmmlRoot);
        pendingTreeNodes.add(new Pair<TreeNode, Decision>(tree.getRoot(), null));

        while (!pendingTreeNodes.isEmpty()) {

            Pair<TreeNode, Decision> nodeAndIncomingDecision = pendingTreeNodes.remove();
            Node pmmlNode = pendingPmmlNodes.remove();

            // The predicate describes how this node was reached from its parent, not this node's own split
            Decision incomingDecision = nodeAndIncomingDecision.getSecond();
            pmmlNode.setPredicate(buildPredicate(incomingDecision, columnNames, columnToCategoryNameToIDMapping));

            TreeNode currentNode = nodeAndIncomingDecision.getFirst();

            if (!currentNode.isTerminal()) {
                DecisionNode decisionNode = (DecisionNode) currentNode;
                Decision decision = decisionNode.getDecision();

                Node positiveChild = new Node();
                positiveChild.setId(pmmlNode.getId() + '+');
                pmmlNode.getNodes().add(positiveChild);
                Node negativeChild = new Node();
                negativeChild.setId(pmmlNode.getId() + '-');
                pmmlNode.getNodes().add(negativeChild);

                // Branch to follow when the tested value is missing
                String defaultChildID;
                if (decision.getDefaultDecision()) {
                    defaultChildID = positiveChild.getId();
                } else {
                    defaultChildID = negativeChild.getId();
                }
                pmmlNode.setDefaultChild(defaultChildID);

                // Enqueue in the same order on both queues: positive/right first, then negative/left
                pendingPmmlNodes.add(positiveChild);
                pendingPmmlNodes.add(negativeChild);
                pendingTreeNodes.add(new Pair<TreeNode, Decision>(decisionNode.getRight(), decision));
                pendingTreeNodes.add(new Pair<TreeNode, Decision>(decisionNode.getLeft(), null));
                continue;
            }

            TerminalNode leaf = (TerminalNode) currentNode;
            pmmlNode.setRecordCount((double) leaf.getCount());

            Prediction prediction = leaf.getPrediction();

            if (prediction.getFeatureType() != FeatureType.CATEGORICAL) {
                NumericPrediction numericPrediction = (NumericPrediction) prediction;
                pmmlNode.setScore(Float.toString(numericPrediction.getPrediction()));
                continue;
            }

            // Categorical target: emit one score distribution per category that was actually observed
            Map<Integer, String> targetIDToName = columnToCategoryNameToIDMapping.get(targetColumn).inverse();
            CategoricalPrediction categoricalPrediction = (CategoricalPrediction) prediction;
            int[] counts = categoricalPrediction.getCategoryCounts();
            float[] probabilities = categoricalPrediction.getCategoryProbabilities();
            for (int id = 0; id < probabilities.length; id++) {
                int count = counts[id];
                float probability = probabilities[id];
                if (!(count > 0 && probability > 0.0f)) {
                    continue;
                }
                String name = targetIDToName.get(id);
                ScoreDistribution distribution = new ScoreDistribution(name, count);
                distribution.setProbability((double) probability);
                pmmlNode.getScoreDistributions().add(distribution);
            }

        }

        TreeModel treeModel = new TreeModel(miningSchema, pmmlRoot, miningFunctionType);
        treeModel.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
        treeModel.setMissingValueStrategy(MissingValueStrategyType.DEFAULT_CHILD);

        Segment treeSegment = new Segment();
        treeSegment.setId(Integer.toString(treeID));
        treeSegment.setPredicate(new True());
        treeSegment.setModel(treeModel);
        treeSegment.setWeight(forest.getWeights()[treeID]);

        return treeSegment;
    }

    /**
     * Translates the decision that leads into a node into the PMML predicate stored on that node.
     * <ul>
     *  <li>No decision ({@code null}) yields a {@link True} predicate.</li>
     *  <li>A categorical decision yields a {@link SimpleSetPredicate} with operator {@code IS_IN} over the
     *   names of the categories whose IDs are set in the decision, encoded as a space-delimited string.</li>
     *  <li>Any other decision is treated as numeric and yields a {@link SimplePredicate} with operator
     *   {@code GREATER_OR_EQUAL} and the decision's threshold as value.</li>
     * </ul>
     *
     * @param decision decision to translate, or {@code null} if there is none
     * @param columnNames column names, indexed by feature number
     * @param columnToCategoryNameToIDMapping per-column mapping between category value names and IDs
     * @return predicate equivalent to the decision
     */
    private static Predicate buildPredicate(Decision decision, List<String> columnNames,
            Map<Integer, BiMap<String, Integer>> columnToCategoryNameToIDMapping) {
        if (decision == null) {
            return new True();
        }

        int featureNumber = decision.getFeatureNumber();
        FieldName field = new FieldName(columnNames.get(featureNumber));

        if (decision.getType() != FeatureType.CATEGORICAL) {
            NumericDecision numericDecision = (NumericDecision) decision;
            SimplePredicate thresholdTest = new SimplePredicate(field, SimplePredicate.Operator.GREATER_OR_EQUAL);
            thresholdTest.setValue(Float.toString(numericDecision.getThreshold()));
            return thresholdTest;
        }

        CategoricalDecision categoricalDecision = (CategoricalDecision) decision;
        Map<Integer, String> idToName = columnToCategoryNameToIDMapping.get(featureNumber).inverse();
        BitSet includedIDs = categoricalDecision.getCategoryIDs();
        List<String> includedNames = Lists.newArrayList();
        // Visit every set bit in ascending order; nextSetBit returns a negative value when none remain
        for (int id = includedIDs.nextSetBit(0); id >= 0; id = includedIDs.nextSetBit(id + 1)) {
            includedNames.add(idToName.get(id));
        }
        Array encodedNames = new Array(DelimitedDataUtils.encode(includedNames, ' '), Array.Type.STRING);
        return new SimpleSetPredicate(encodedNames, field, SimpleSetPredicate.BooleanOperator.IS_IN);
    }

    // Read PMML

    /**
     * Reads a PMML-encoded decision forest from a file. The file is opened with
     * {@code IOUtils.openMaybeDecompressing(File)}. The first model in the document must be a
     * {@link MiningModel} whose segmentation holds one {@link TreeModel} segment per tree; segment weights
     * become the tree weights. Column names and the target column come from the default configuration, so
     * they must agree with the field names used in the PMML document.
     *
     * @param pmmlFile file to read PMML encoding from
     * @return a {@link Pair} of the {@link DecisionForest} representation of the PMML encoded model and the
     *  column-to-category mapping built from the document's data dictionary
     * @throws IOException if the file cannot be read or parsed; parser exceptions are wrapped as the cause
     * @throws IllegalArgumentException if the document has no models, its first model is not a
     *  {@link MiningModel}, it has no segments, a segment does not hold a {@link TreeModel}, or a mining
     *  field carrying an importance does not correspond to a usable column index
     * @throws NullPointerException if the model list, the segmentation or the segment list is missing
     */
    public static Pair<DecisionForest, Map<Integer, BiMap<String, Integer>>> read(File pmmlFile)
            throws IOException {

        PMML pmml;
        InputStream in = IOUtils.openMaybeDecompressing(pmmlFile);
        try {
            pmml = IOUtil.unmarshal(in);
        } catch (SAXException e) {
            throw new IOException(e);
        } catch (JAXBException e) {
            throw new IOException(e);
        } finally {
            in.close();
        }

        List<Model> models = pmml.getModels();
        Preconditions.checkNotNull(models, "PMML document has no model list");
        Preconditions.checkArgument(!models.isEmpty(), "PMML document contains no models");
        Preconditions.checkArgument(models.get(0) instanceof MiningModel,
                "First model in PMML document is not a MiningModel");
        MiningModel miningModel = (MiningModel) models.get(0);

        Segmentation segmentation = miningModel.getSegmentation();
        Preconditions.checkNotNull(segmentation, "MiningModel has no segmentation");

        List<Segment> segments = segmentation.getSegments();
        Preconditions.checkNotNull(segments, "Segmentation has no segment list");
        Preconditions.checkArgument(!segments.isEmpty(), "Segmentation contains no segments");

        Map<Integer, BiMap<String, Integer>> columnToCategoryNameToIDMapping = PMMLUtils
                .buildColumnCategoryMapping(pmml.getDataDictionary());
        InboundSettings settings = InboundSettings.create(ConfigUtils.getDefaultConfig());
        DecisionTree[] trees = new DecisionTree[segments.size()];
        double[] weights = new double[trees.length];
        for (int i = 0; i < trees.length; i++) {
            Segment segment = segments.get(i);
            weights[i] = segment.getWeight();
            // Validate before casting, in the same way the top-level model is validated above
            Model segmentModel = segment.getModel();
            Preconditions.checkArgument(segmentModel instanceof TreeModel,
                    "Segment %s does not contain a TreeModel", i);
            TreeModel treeModel = (TreeModel) segmentModel;
            TreeNode root = translateFromPMML(treeModel.getNode(), columnToCategoryNameToIDMapping, settings);
            trees[i] = new DecisionTree(root);
        }

        List<String> columnNames = settings.getColumnNames();
        List<MiningField> fields = miningModel.getMiningSchema().getMiningFields();
        double[] featureImportances = new double[fields.size()];
        for (MiningField field : fields) {
            Double importance = field.getImportance();
            if (importance != null) {
                String fieldName = field.getName().getValue();
                // indexOf yields -1 for a name the configuration does not know, and the array is sized by
                // the mining field count, so reject anything that cannot be used as an index
                int featureNumber = columnNames.indexOf(fieldName);
                Preconditions.checkArgument(featureNumber >= 0 && featureNumber < featureImportances.length,
                        "Mining field '%s' maps to column index %s, which is not in [0, %s)",
                        fieldName, featureNumber, featureImportances.length);
                featureImportances[featureNumber] = importance;
            }
        }

        return new Pair<DecisionForest, Map<Integer, BiMap<String, Integer>>>(
                new DecisionForest(trees, weights, featureImportances), columnToCategoryNameToIDMapping);
    }

    /**
     * Recursively rebuilds a decision (sub)tree from a PMML {@link Node}.
     * <p>A node without children becomes a {@link TerminalNode}: categorical if it carries score distributions
     * (their record counts become per-category counts), numeric otherwise (its score and record count are
     * used). A node with children must have exactly two, one of which has a {@link True} predicate. That child
     * is the negative / left branch, and the other child's predicate defines the decision.</p>
     *
     * @param root PMML node at the root of the (sub)tree to translate
     * @param columnToCategoryNameToIDMapping per-column mapping between category value names and IDs
     * @param settings source of the column names and the target column number
     * @return equivalent {@link TreeNode}
     * @throws IllegalArgumentException if a non-terminal node does not have the expected structure
     */
    private static TreeNode translateFromPMML(Node root,
            Map<Integer, BiMap<String, Integer>> columnToCategoryNameToIDMapping, InboundSettings settings) {

        List<String> columnNames = settings.getColumnNames();
        int targetColumn = settings.getTargetColumn();

        List<Node> children = root.getNodes();
        if (children.isEmpty()) {
            // Terminal node
            Collection<ScoreDistribution> distributions = root.getScoreDistributions();
            Prediction prediction;
            if (distributions == null || distributions.isEmpty()) {
                // Numeric target
                prediction = new NumericPrediction(Float.parseFloat(root.getScore()),
                        (int) Math.round(root.getRecordCount()));
            } else {
                // Categorical target
                Map<String, Integer> targetNameToID = columnToCategoryNameToIDMapping.get(targetColumn);
                int[] countsByID = new int[targetNameToID.size()];
                for (ScoreDistribution distribution : distributions) {
                    int id = targetNameToID.get(distribution.getValue());
                    countsByID[id] = (int) Math.round(distribution.getRecordCount());
                }
                prediction = new CategoricalPrediction(countsByID);
            }
            return new TerminalNode(prediction);
        }

        // Decision node
        Preconditions.checkArgument(children.size() == 2);
        Node firstChild = children.get(0);
        Node secondChild = children.get(1);

        // Whichever child has the unconditional True predicate is the negative / left branch
        boolean firstIsNegative = firstChild.getPredicate().getClass().equals(True.class);
        if (!firstIsNegative) {
            Preconditions.checkArgument(secondChild.getPredicate().getClass().equals(True.class));
        }
        Node negativeLeftChild = firstIsNegative ? firstChild : secondChild;
        Node positiveRightChild = firstIsNegative ? secondChild : firstChild;

        Predicate predicate = positiveRightChild.getPredicate();
        boolean defaultDecision = positiveRightChild.getId().equals(root.getDefaultChild());

        Decision decision;
        if (predicate instanceof SimplePredicate) {
            decision = toNumericDecision((SimplePredicate) predicate, columnNames, defaultDecision);
        } else {
            Preconditions.checkArgument(predicate instanceof SimpleSetPredicate);
            decision = toCategoricalDecision((SimpleSetPredicate) predicate, columnNames,
                    columnToCategoryNameToIDMapping, defaultDecision);
        }

        return new DecisionNode(decision,
                translateFromPMML(negativeLeftChild, columnToCategoryNameToIDMapping, settings),
                translateFromPMML(positiveRightChild, columnToCategoryNameToIDMapping, settings));
    }

    /** Decodes a {@code GREATER_OR_EQUAL} threshold predicate into a {@link NumericDecision}. */
    private static Decision toNumericDecision(SimplePredicate simplePredicate, List<String> columnNames,
            boolean defaultDecision) {
        Preconditions.checkArgument(simplePredicate.getOperator() == SimplePredicate.Operator.GREATER_OR_EQUAL);
        float threshold = Float.parseFloat(simplePredicate.getValue());
        int featureNumber = columnNames.indexOf(simplePredicate.getField().getValue());
        return new NumericDecision(featureNumber, threshold, defaultDecision);
    }

    /** Decodes an {@code IS_IN} set predicate over category names into a {@link CategoricalDecision}. */
    private static Decision toCategoricalDecision(SimpleSetPredicate simpleSetPredicate,
            List<String> columnNames, Map<Integer, BiMap<String, Integer>> columnToCategoryNameToIDMapping,
            boolean defaultDecision) {
        Preconditions.checkArgument(
                simpleSetPredicate.getBooleanOperator() == SimpleSetPredicate.BooleanOperator.IS_IN);
        int featureNumber = columnNames.indexOf(simpleSetPredicate.getField().getValue());
        Map<String, Integer> categoryNameToID = columnToCategoryNameToIDMapping.get(featureNumber);
        String[] categoryNames = DelimitedDataUtils.decode(simpleSetPredicate.getArray().getValue(), ' ');
        BitSet activeCategories = new BitSet(categoryNameToID.size());
        for (String categoryName : categoryNames) {
            int categoryID = categoryNameToID.get(categoryName);
            activeCategories.set(categoryID);
        }
        return new CategoricalDecision(featureNumber, activeCategories, defaultDecision);
    }

}