org.jpmml.evaluator.OutputUtil.java Source code

Java tutorial

Introduction

Here is the source code for org.jpmml.evaluator.OutputUtil.java

Source

/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

import com.google.common.base.Joiner;
import com.google.common.collect.BiMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.dmg.pmml.AssociationRule;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Item;
import org.dmg.pmml.ItemRef;
import org.dmg.pmml.Itemset;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.ResultFeatureType;
import org.dmg.pmml.RuleFeatureType;
import org.dmg.pmml.Target;
import org.dmg.pmml.TargetValue;
import org.dmg.pmml.Value;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.UnsupportedFeatureException;

public class OutputUtil {

    // Private constructor: this class only exposes static helper methods and is not meant to be instantiated.
    private OutputUtil() {
    }

    /**
     * Evaluates the {@link Output} element.
     *
     * <p>
     * Output fields are processed in declaration order. For every field the mining result is loaded first
     * (either from the model predictions or from the result of the referenced segment), and then the requested
     * result feature is computed from it. Every computed value is declared in the evaluation context, so that
     * it can be referenced by output fields that are declared later.
     * </p>
     *
     * <p>
     * An output field is silently skipped (ie. it does not appear in the result) when the referenced segment
     * has no result, or when the target value it depends on is <code>null</code>.
     * </p>
     *
     * @param predictions Map of {@link Evaluator#getTargetFields() target field} values.
     * @param context The evaluation context of the model. Output fields that specify a segment identifier
     * require it to be a {@link MiningModelEvaluationContext}.
     *
     * @return Map of {@link Evaluator#getTargetFields() target field} values together with {@link Evaluator#getOutputFields() output field} values.
     * If the model does not declare an {@link Output} element, then the <code>predictions</code> argument is returned as is.
     *
     * @throws InvalidFeatureException If an output field specifies a segment identifier, but the context is not a mining model context;
     * or if a transformed value or decision output field does not declare an expression.
     * @throws UnsupportedFeatureException If the result feature is not supported.
     * @throws MissingFieldException If the target field is not available.
     */
    @SuppressWarnings(value = { "fallthrough" })
    static public Map<FieldName, ?> evaluate(Map<FieldName, ?> predictions, ModelEvaluationContext context) {
        ModelEvaluator<?> modelEvaluator = context.getModelEvaluator();

        Output output = modelEvaluator.getOutput();
        if (output == null) {
            return predictions;
        }

        // Insertion-ordered copy: target fields first, followed by output fields in declaration order
        Map<FieldName, Object> result = Maps.newLinkedHashMap(predictions);

        List<OutputField> outputFields = output.getOutputFields();

        outputFields: for (OutputField outputField : outputFields) {
            Map<FieldName, ?> segmentPredictions = predictions;

            String segmentId = outputField.getSegmentId();
            if (segmentId != null) {

                // Segment results are looked up from the mining model context; fail with a descriptive
                // exception instead of a ClassCastException when the context is of some other kind
                if (!(context instanceof MiningModelEvaluationContext)) {
                    throw new InvalidFeatureException(outputField);
                }

                MiningModelEvaluationContext miningModelContext = (MiningModelEvaluationContext) context;

                segmentPredictions = miningModelContext.getResult(segmentId);
            } // End if

            // "If there is no Segment matching segmentId or if the predicate of the matching Segment evaluated to false, then the result delivered by this OutputField is missing"
            if (segmentPredictions == null) {
                continue outputFields;
            }

            // "Attribute targetField is required in case the model has multiple target fields"
            FieldName targetField = outputField.getTargetField();
            if (targetField == null) {
                targetField = modelEvaluator.getTargetField();
            }

            Object value = null;

            ResultFeatureType resultFeature = getResultFeatureType(outputField);

            // Load the mining result
            switch (resultFeature) {
            case ENTITY_ID: {
                // The identifier of the winning segment is taken from the segment result itself, not from a target value
                if (isSegmentId(segmentPredictions, outputField)) {
                    break;
                }
            }
            // Falls through
            case PREDICTED_VALUE:
            case PREDICTED_DISPLAY_VALUE:
            case PROBABILITY:
            case RESIDUAL:
            case CLUSTER_ID:
            case AFFINITY:
            case ENTITY_AFFINITY:
            case CLUSTER_AFFINITY:
            case REASON_CODE:
            case RULE_VALUE:
            case ANTECEDENT:
            case CONSEQUENT:
            case RULE:
            case RULE_ID:
            case CONFIDENCE:
            case SUPPORT:
            case LIFT:
            case LEVERAGE: {
                if (!segmentPredictions.containsKey(targetField)) {
                    throw new MissingFieldException(targetField, outputField);
                }

                // A target value could be either simple or complex values
                value = segmentPredictions.get(targetField);

                // If the target value is missing, then the result delivered by this OutputField is missing
                if (value == null) {
                    continue outputFields;
                }
            }
                break;
            default:
                // The remaining result features do not depend on a target value
                break;
            } // End switch

            // Perform the requested computation on the mining result
            switch (resultFeature) {
            case PREDICTED_VALUE: {
                value = getPredictedValue(value);
            }
                break;
            case PREDICTED_DISPLAY_VALUE: {
                Target target = modelEvaluator.getTarget(targetField);

                DataField dataField = modelEvaluator.getDataField(targetField);

                value = getPredictedDisplayValue(value, target, dataField);
            }
                break;
            case TRANSFORMED_VALUE:
            case DECISION: {
                // Expressions are evaluated against the model context only
                if (segmentId != null) {
                    throw new UnsupportedFeatureException(outputField);
                }

                Expression expression = outputField.getExpression();
                if (expression == null) {
                    throw new InvalidFeatureException(outputField);
                }

                value = FieldValueUtil.getValue(ExpressionUtil.evaluate(expression, context));
            }
                break;
            case PROBABILITY: {
                value = getProbability(value, outputField);
            }
                break;
            case RESIDUAL: {
                // The expected value is the value of the target field as found in the evaluation context
                FieldValue expectedValue = context.getField(targetField);
                if (expectedValue == null) {
                    throw new MissingFieldException(targetField, outputField);
                }

                DataField dataField = modelEvaluator.getDataField(targetField);

                OpType opType = dataField.getOptype();
                switch (opType) {
                case CONTINUOUS:
                    value = getContinuousResidual(value, expectedValue);
                    break;
                case CATEGORICAL:
                    value = getCategoricalResidual(value, expectedValue);
                    break;
                default:
                    throw new UnsupportedFeatureException(dataField, opType);
                }
            }
                break;
            case ENTITY_ID: {
                // "Result feature entityId returns the id of the winning segment"
                if (isSegmentId(segmentPredictions, outputField)) {
                    SegmentResultMap segmentResult = (SegmentResultMap) segmentPredictions;

                    value = segmentResult.getId();

                    break;
                }

                value = getEntityId(value, outputField);
            }
                break;
            case CLUSTER_ID: {
                value = getClusterId(value);
            }
                break;
            case AFFINITY:
            case ENTITY_AFFINITY: {
                value = getAffinity(value, outputField);
            }
                break;
            case CLUSTER_AFFINITY: {
                value = getClusterAffinity(value);
            }
                break;
            case REASON_CODE: {
                value = getReasonCode(value, outputField);
            }
                break;
            case RULE_VALUE: {
                value = getRuleValue(value, outputField);
            }
                break;
            case ANTECEDENT: {
                value = getRuleValue(value, outputField, RuleFeatureType.ANTECEDENT);
            }
                break;
            case CONSEQUENT: {
                value = getRuleValue(value, outputField, RuleFeatureType.CONSEQUENT);
            }
                break;
            case RULE: {
                value = getRuleValue(value, outputField, RuleFeatureType.RULE);
            }
                break;
            case RULE_ID: {
                value = getRuleValue(value, outputField, RuleFeatureType.RULE_ID);
            }
                break;
            case CONFIDENCE: {
                value = getRuleValue(value, outputField, RuleFeatureType.CONFIDENCE);
            }
                break;
            case SUPPORT: {
                value = getRuleValue(value, outputField, RuleFeatureType.SUPPORT);
            }
                break;
            case LIFT: {
                value = getRuleValue(value, outputField, RuleFeatureType.LIFT);
            }
                break;
            case LEVERAGE: {
                value = getRuleValue(value, outputField, RuleFeatureType.LEVERAGE);
            }
                break;
            case WARNING: {
                value = context.getWarnings();
            }
                break;
            default:
                throw new UnsupportedFeatureException(outputField, resultFeature);
            }

            // The result of one output field becomes available to other output fields
            context.declare(outputField.getName(), FieldValueUtil.create(outputField, value));

            result.put(outputField.getName(), value);
        }

        return result;
    }

    /**
     * Determines the data type of an output field without evaluating it.
     *
     * <p>
     * An explicitly declared data type always wins. Otherwise the data type is derived from the result feature.
     * </p>
     *
     * @param outputField The output field.
     * @param modelEvaluator The model evaluator that is used for resolving the target field.
     *
     * @return The data type of the output field.
     *
     * @throws TypeAnalysisException If the data type cannot be determined statically.
     * @throws UnsupportedFeatureException If the result feature is not supported.
     */
    static public DataType getDataType(OutputField outputField, ModelEvaluator<?> modelEvaluator) {
        DataType declaredType = outputField.getDataType();
        if (declaredType != null) {
            return declaredType;
        }

        // Fields that refer to a segment cannot be analyzed from here
        if (outputField.getSegmentId() != null) {
            throw new TypeAnalysisException(outputField);
        }

        ResultFeatureType feature = getResultFeatureType(outputField);

        switch (feature) {
        case PREDICTED_VALUE: {
            // Same data type as the (explicit or default) target field
            FieldName name = outputField.getTargetField();
            if (name == null) {
                name = modelEvaluator.getTargetField();
            }

            return (modelEvaluator.getDataField(name)).getDataType();
        }
        case PREDICTED_DISPLAY_VALUE: // XXX
        case ENTITY_ID:
        case CLUSTER_ID:
        case REASON_CODE:
            return DataType.STRING;
        case PROBABILITY:
        case RESIDUAL:
        case STANDARD_ERROR:
        case AFFINITY:
        case ENTITY_AFFINITY:
        case CLUSTER_AFFINITY:
            return DataType.DOUBLE;
        case TRANSFORMED_VALUE:
        case DECISION:
        case WARNING:
            throw new TypeAnalysisException(outputField);
        case RULE_VALUE:
            return getRuleDataType(outputField);
        case ANTECEDENT:
            return getRuleDataType(outputField, RuleFeatureType.ANTECEDENT);
        case CONSEQUENT:
            return getRuleDataType(outputField, RuleFeatureType.CONSEQUENT);
        case RULE:
            return getRuleDataType(outputField, RuleFeatureType.RULE);
        case RULE_ID:
            return getRuleDataType(outputField, RuleFeatureType.RULE_ID);
        case SUPPORT:
            return getRuleDataType(outputField, RuleFeatureType.SUPPORT);
        case CONFIDENCE:
            return getRuleDataType(outputField, RuleFeatureType.CONFIDENCE);
        case LIFT:
            return getRuleDataType(outputField, RuleFeatureType.LIFT);
        case LEVERAGE:
            return getRuleDataType(outputField, RuleFeatureType.LEVERAGE);
        default:
            throw new UnsupportedFeatureException(outputField, feature);
        }
    }

    /**
     * Returns the result feature of an output field, defaulting to {@link ResultFeatureType#PREDICTED_VALUE}.
     */
    static private ResultFeatureType getResultFeatureType(OutputField outputField) {
        ResultFeatureType declared = outputField.getFeature();

        // "If the attribute feature is not specified, then the output value is a copy of the target field value"
        return (declared != null ? declared : ResultFeatureType.PREDICTED_VALUE);
    }

    /**
     * Checks if an entity identifier output field refers to the identifier of a segment:
     * the field must not declare a target field, and the predictions must be a {@link SegmentResultMap}.
     */
    static private boolean isSegmentId(Map<FieldName, ?> predictions, OutputField outputField) {
        boolean hasTargetField = (outputField.getTargetField() != null);

        return !hasTargetField && (predictions instanceof SegmentResultMap);
    }

    /**
     * Extracts the predicted value from a mining result by delegating to {@link EvaluatorUtil#decode(Object)}.
     */
    static private Object getPredictedValue(Object object) {
        Object decoded = EvaluatorUtil.decode(object);

        return decoded;
    }

    /**
     * Resolves the display value of a mining result.
     *
     * <p>
     * The lookup order is: the mining result itself (if it implements {@link HasDisplayValue}),
     * the matching {@link TargetValue} of the target, the matching valid {@link Value} of a
     * categorical or ordinal data field, and finally the raw predicted value.
     * </p>
     *
     * @param object The mining result.
     * @param target The target element. May be <code>null</code>.
     * @param dataField The data field of the target field.
     */
    static private Object getPredictedDisplayValue(Object object, Target target, DataField dataField) {

        if (object instanceof HasDisplayValue) {
            return ((HasDisplayValue) object).getDisplayValue();
        }

        Object predicted = getPredictedValue(object);

        TargetValue targetValue = (target != null ? TargetUtil.getTargetValue(target, predicted) : null);
        if (targetValue != null) {
            String label = targetValue.getDisplayValue();

            if (label != null) {
                return label;
            }
        }

        OpType opType = dataField.getOptype();

        boolean discrete = (opType == OpType.CATEGORICAL) || (opType == OpType.ORDINAL);
        if (discrete) {
            Value validValue = ArgumentUtil.getValidValue(dataField, predicted);

            String label = (validValue != null ? validValue.getDisplayValue() : null);
            if (label != null) {
                return label;
            }
        }

        // "If the display value is not specified explicitly, then the raw predicted value is used by default"
        return predicted;
    }

    /**
     * Returns the probability of the category that is selected by the output field.
     *
     * @throws TypeCheckException If the mining result does not implement {@link HasProbability}.
     */
    static private Double getProbability(Object object, OutputField outputField) {
        // The type check must happen before the category is resolved
        HasProbability distribution = asResultFeature(HasProbability.class, object);

        return distribution.getProbability(getCategoryValue(object, outputField));
    }

    /**
     * Returns the category that an output field refers to: the explicitly declared value if there is one,
     * the formatted predicted value otherwise.
     */
    static private String getCategoryValue(Object object, OutputField outputField) {
        String declared = outputField.getValue();
        if (declared != null) {
            return declared;
        }

        // "If the attribute value is not specified, then the predicted categorical value should be returned as a result"
        return TypeUtil.format(getPredictedValue(object));
    }

    /**
     * Computes the residual of a continuous target as the expected value minus the predicted value.
     * Both values are cast to {@link Number}.
     */
    static private Double getContinuousResidual(Object object, FieldValue expectedObject) {
        Number predicted = (Number) getPredictedValue(object);
        Number expected = (Number) FieldValueUtil.getValue(expectedObject);

        double residual = expected.doubleValue() - predicted.doubleValue();

        return Double.valueOf(residual);
    }

    /**
     * Computes the residual of a categorical target: an indicator (<code>1</code> if the predicted category
     * equals the expected category, <code>0</code> otherwise) minus the probability of the predicted category.
     *
     * @throws TypeCheckException If the mining result does not implement {@link HasProbability}.
     */
    static public Double getCategoricalResidual(Object object, FieldValue expectedObject) {
        HasProbability distribution = asResultFeature(HasProbability.class, object);

        String predicted = TypeUtil.format(getPredictedValue(object));
        String expected = TypeUtil.format(FieldValueUtil.getValue(expectedObject));

        // Categories are compared by their string representation
        double indicator = (TypeUtil.equals(DataType.STRING, predicted, expected) ? 1d : 0d);

        return Double.valueOf(indicator - distribution.getProbability(predicted));
    }

    /**
     * Returns the identifier of the entity at the rank that is requested by the output field.
     * Rank <code>1</code> is the winning entity; higher ranks require an entity identifier ranking.
     *
     * @throws TypeCheckException If the mining result does not implement the required result feature.
     * @throws InvalidFeatureException If the rank is not positive.
     */
    static private String getEntityId(Object object, OutputField outputField) {
        HasEntityId winner = asResultFeature(HasEntityId.class, object);

        int rank = outputField.getRank();
        if (rank < 1) {
            throw new InvalidFeatureException(outputField);
        }

        if (rank == 1) {
            return winner.getEntityId();
        }

        HasEntityIdRanking ranking = asResultFeature(HasEntityIdRanking.class, object);

        return getElement(ranking.getEntityIdRanking(), rank);
    }

    /**
     * Returns the cluster identifier of a mining result.
     *
     * @throws TypeCheckException If the mining result does not implement {@link HasClusterId}.
     */
    static private String getClusterId(Object object) {
        return (asResultFeature(HasClusterId.class, object)).getClusterId();
    }

    /**
     * Returns the affinity at the rank that is requested by the output field.
     * Rank <code>1</code> is resolved by category; higher ranks require an affinity ranking.
     *
     * @throws TypeCheckException If the mining result does not implement the required result feature.
     * @throws InvalidFeatureException If the rank is not positive.
     */
    static public Double getAffinity(Object object, OutputField outputField) {
        HasAffinity affinities = asResultFeature(HasAffinity.class, object);

        int rank = outputField.getRank();
        if (rank < 1) {
            throw new InvalidFeatureException(outputField);
        }

        if (rank == 1) {
            return affinities.getAffinity(getCategoryValue(object, outputField));
        }

        HasAffinityRanking ranking = asResultFeature(HasAffinityRanking.class, object);

        return getElement(ranking.getAffinityRanking(), rank);
    }

    /**
     * Returns the cluster affinity of a mining result.
     *
     * @throws TypeCheckException If the mining result does not implement {@link HasClusterAffinity}.
     */
    static public Double getClusterAffinity(Object object) {
        return (asResultFeature(HasClusterAffinity.class, object)).getClusterAffinity();
    }

    /**
     * Returns the reason code at the rank that is requested by the output field,
     * or <code>null</code> if the ranking has fewer elements than that.
     *
     * @throws TypeCheckException If the mining result does not implement {@link HasReasonCodeRanking}.
     * @throws InvalidFeatureException If the rank is not positive.
     */
    static public String getReasonCode(Object object, OutputField outputField) {
        HasReasonCodeRanking ranking = asResultFeature(HasReasonCodeRanking.class, object);

        int rank = outputField.getRank();
        if (rank < 1) {
            throw new InvalidFeatureException(outputField);
        }

        List<String> reasonCodes = ranking.getReasonCodeRanking();

        return getElement(reasonCodes, rank);
    }

    /**
     * Returns the given feature of the association rule at the rank that is requested by the output field.
     * Only single-valued output fields are supported.
     *
     * @param object The mining result.
     * @param outputField The output field; provides the algorithm, ranking and rank.
     * @param ruleFeature The rule feature to extract.
     *
     * @return The rule feature, or <code>null</code> if there are fewer association rules than the rank.
     *
     * @throws TypeCheckException If the mining result does not implement {@link HasRuleValues}.
     * @throws UnsupportedFeatureException If the output field is not single-valued.
     * @throws InvalidFeatureException If the rank is not positive.
     */
    static public Object getRuleValue(Object object, OutputField outputField, RuleFeatureType ruleFeature) {
        HasRuleValues ruleSource = asResultFeature(HasRuleValues.class, object);

        // The rules are sorted in place before any validation takes place
        List<AssociationRule> rankedRules = ruleSource.getRuleValues(outputField.getAlgorithm());
        sortRules(rankedRules, outputField);

        boolean singleValued = ("0").equals(outputField.getIsMultiValued());
        if (!singleValued) {
            throw new UnsupportedFeatureException(outputField);
        }

        int rank = outputField.getRank();
        if (rank < 1) {
            throw new InvalidFeatureException(outputField);
        }

        AssociationRule selected = getElement(rankedRules, rank);
        if (selected == null) {
            return null;
        }

        return getRuleFeature(ruleSource, selected, outputField, ruleFeature);
    }

    /**
     * Returns the rule feature(s) that are requested by a rule value output field.
     *
     * <p>
     * A single-valued field (<code>isMultiValued="0"</code>) yields the feature of the association rule at the
     * requested rank, or <code>null</code> if there is no such rule. A multi-valued field
     * (<code>isMultiValued="1"</code>) yields a list of features of the top-ranking association rules.
     * </p>
     *
     * @throws TypeCheckException If the mining result does not implement {@link HasRuleValues}.
     * @throws InvalidFeatureException If the multi-valued flag or the rank is not valid.
     */
    static public Object getRuleValue(Object object, OutputField outputField) {
        HasRuleValues ruleSource = asResultFeature(HasRuleValues.class, object);

        // The rules are sorted in place before any validation takes place
        List<AssociationRule> rankedRules = ruleSource.getRuleValues(outputField.getAlgorithm());
        sortRules(rankedRules, outputField);

        String multiValuedFlag = outputField.getIsMultiValued();

        boolean single = ("0").equals(multiValuedFlag);
        boolean multiple = ("1").equals(multiValuedFlag);

        if (!single && !multiple) {
            throw new InvalidFeatureException(outputField);
        }

        int rank = outputField.getRank();

        // Return a single result
        if (single) {

            if (rank <= 0) {
                throw new InvalidFeatureException(outputField);
            }

            AssociationRule selected = getElement(rankedRules, rank);

            return (selected != null ? getRuleFeature(ruleSource, selected, outputField) : null);
        }

        // Return multiple results
        if (rank < 0) {
            throw new InvalidFeatureException(outputField);
        }

        // "A zero value indicates that all output values are to be returned"
        // "A positive value indicates the number of output values to be returned"
        int limit = (rank == 0 ? rankedRules.size() : Math.min(rank, rankedRules.size()));

        List<Object> features = Lists.newArrayList();

        for (AssociationRule rule : rankedRules.subList(0, limit)) {
            features.add(getRuleFeature(ruleSource, rule, outputField));
        }

        return features;
    }

    /**
     * Sorts association rules in place, using the rank basis and the rank order of the output field.
     *
     * <p>
     * The comparator fails with an {@link InvalidFeatureException} if an association rule does not provide
     * the measure that is used as the rank basis, and with an {@link UnsupportedFeatureException} if the
     * rank basis or the rank order is not supported.
     * </p>
     */
    static private void sortRules(List<AssociationRule> associationRules, final OutputField outputField) {
        final OutputField.RankBasis basis = outputField.getRankBasis();
        final OutputField.RankOrder direction = outputField.getRankOrder();

        Collections.sort(associationRules, new Comparator<AssociationRule>() {

            @Override
            public int compare(AssociationRule left, AssociationRule right) {
                int ascending = (rankingMeasure(left)).compareTo(rankingMeasure(right));

                switch (direction) {
                case ASCENDING:
                    return ascending;
                case DESCENDING:
                    return -ascending;
                default:
                    throw new UnsupportedFeatureException(outputField, direction);
                }
            }

            /**
             * Reads the measure that corresponds to the rank basis, which must be present.
             */
            private Double rankingMeasure(AssociationRule rule) {
                Double measure;

                switch (basis) {
                case CONFIDENCE:
                    measure = rule.getConfidence();
                    break;
                case SUPPORT:
                    measure = rule.getSupport();
                    break;
                case LIFT:
                    measure = rule.getLift();
                    break;
                case LEVERAGE:
                    measure = rule.getLeverage();
                    break;
                case AFFINITY:
                    measure = rule.getAffinity();
                    break;
                default:
                    throw new UnsupportedFeatureException(outputField, basis);
                }

                if (measure == null) {
                    throw new InvalidFeatureException(rule);
                }

                return measure;
            }
        });
    }

    /**
     * Extracts the rule feature that is declared by the output field itself.
     */
    static private Object getRuleFeature(HasRuleValues hasRuleValues, AssociationRule associationRule,
            OutputField outputField) {
        RuleFeatureType declaredFeature = outputField.getRuleFeature();

        return getRuleFeature(hasRuleValues, associationRule, outputField, declaredFeature);
    }

    /**
     * Extracts a single feature of an association rule.
     *
     * @param hasRuleValues The mining result; provides the item, itemset and association rule registries.
     * @param associationRule The association rule.
     * @param element The element that is reported if the rule feature is not supported.
     * @param ruleFeature The rule feature to extract.
     *
     * @return A list of item values for antecedent and consequent, a string of the form
     * <code>{a,b}->{c}</code> for the rule, the (explicit or registry-assigned) identifier for the rule id,
     * or the corresponding numeric measure.
     *
     * @throws UnsupportedFeatureException If the rule feature is not supported.
     */
    static private Object getRuleFeature(HasRuleValues hasRuleValues, AssociationRule associationRule,
            PMMLObject element, RuleFeatureType ruleFeature) {

        switch (ruleFeature) {
        case ANTECEDENT:
            return getItemValues(hasRuleValues, associationRule.getAntecedent());
        case CONSEQUENT:
            return getItemValues(hasRuleValues, associationRule.getConsequent());
        case RULE: {
            Joiner commaJoiner = Joiner.on(',');

            String antecedent = commaJoiner.join(getItemValues(hasRuleValues, associationRule.getAntecedent()));
            String consequent = commaJoiner.join(getItemValues(hasRuleValues, associationRule.getConsequent()));

            return "{" + antecedent + "}" + "->" + "{" + consequent + "}";
        }
        case RULE_ID: {
            String explicitId = associationRule.getId();
            if (explicitId != null) {
                return explicitId;
            }

            // Fall back to the identifier under which the rule is registered
            return ((hasRuleValues.getAssociationRuleRegistry()).inverse()).get(associationRule);
        }
        case CONFIDENCE:
            return associationRule.getConfidence();
        case SUPPORT:
            return associationRule.getSupport();
        case LIFT:
            return associationRule.getLift();
        case LEVERAGE:
            return associationRule.getLeverage();
        case AFFINITY:
            return associationRule.getAffinity();
        default:
            throw new UnsupportedFeatureException(element, ruleFeature);
        }
    }

    /**
     * Determines the data type of the rule feature that is declared by the output field itself.
     */
    static private DataType getRuleDataType(OutputField outputField) {
        RuleFeatureType declaredFeature = outputField.getRuleFeature();

        return getRuleDataType(outputField, declaredFeature);
    }

    /**
     * Determines the data type of a rule feature of a single-valued output field.
     *
     * @throws TypeAnalysisException If the output field is not single-valued, or if the rule feature is
     * antecedent or consequent.
     * @throws UnsupportedFeatureException If the rule feature is not supported.
     */
    static private DataType getRuleDataType(OutputField outputField, RuleFeatureType ruleFeature) {
        boolean singleValued = ("0").equals(outputField.getIsMultiValued());
        if (!singleValued) {
            throw new TypeAnalysisException(outputField);
        }

        switch (ruleFeature) {
        case RULE:
        case RULE_ID:
            return DataType.STRING;
        case SUPPORT:
        case CONFIDENCE:
        case LIFT:
        case LEVERAGE:
        case AFFINITY:
            return DataType.DOUBLE;
        case ANTECEDENT:
        case CONSEQUENT:
            throw new TypeAnalysisException(outputField);
        default:
            throw new UnsupportedFeatureException(outputField, ruleFeature);
        }
    }

    /**
     * Collects the values of all items of an itemset.
     *
     * @param hasRuleValues The mining result; provides the item and itemset registries.
     * @param id The identifier of the itemset.
     *
     * @return The item values, in the order of the item references of the itemset.
     */
    static private List<String> getItemValues(HasRuleValues hasRuleValues, String id) {
        BiMap<String, Item> items = hasRuleValues.getItemRegistry();
        BiMap<String, Itemset> itemsets = hasRuleValues.getItemsetRegistry();

        List<String> values = Lists.newArrayList();

        for (ItemRef reference : (itemsets.get(id)).getItemRefs()) {
            values.add((items.get(reference.getItemRef())).getValue());
        }

        return values;
    }

    /**
     * Casts a mining result to the requested result feature type.
     *
     * @throws TypeCheckException If the mining result is not an instance of the requested type.
     */
    static private <E extends ResultFeature> E asResultFeature(Class<? extends E> clazz, Object object) {

        if (clazz.isInstance(object)) {
            return clazz.cast(object);
        }

        throw new TypeCheckException(clazz, object);
    }

    /**
     * Returns the element at the given one-based rank.
     *
     * @param elements The list of elements.
     * @param rank The one-based rank. Callers are expected to pass a positive value.
     *
     * @return The element, or <code>null</code> if the rank exceeds the size of the list.
     */
    static private <E> E getElement(List<E> elements, int rank) {
        int position = rank - 1;

        return (position < elements.size() ? elements.get(position) : null);
    }
}