org.jpmml.evaluator.ScorecardEvaluator.java Source code

Java tutorial

Introduction

Here is the source code for org.jpmml.evaluator.ScorecardEvaluator.java.

Source

/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.Collections;
import java.util.List;
import java.util.Map;

import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import org.dmg.pmml.Attribute;
import org.dmg.pmml.Characteristic;
import org.dmg.pmml.Characteristics;
import org.dmg.pmml.ComplexPartialScore;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.Scorecard;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.UnsupportedFeatureException;

public class ScorecardEvaluator extends ModelEvaluator<Scorecard> {

    public ScorecardEvaluator(PMML pmml) {
        this(pmml, find(pmml.getModels(), Scorecard.class));
    }

    public ScorecardEvaluator(PMML pmml, Scorecard scorecard) {
        super(pmml, scorecard);
    }

    @Override
    public String getSummary() {
        return "Scorecard";
    }

    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext context) {
        Scorecard scorecard = getModel();
        if (!scorecard.isScorable()) {
            throw new InvalidResultException(scorecard);
        }

        Map<FieldName, ?> predictions;

        MiningFunctionType miningFunction = scorecard.getFunctionName();
        switch (miningFunction) {
        case REGRESSION:
            predictions = evaluateRegression(context);
            break;
        default:
            throw new UnsupportedFeatureException(scorecard, miningFunction);
        }

        return OutputUtil.evaluate(predictions, context);
    }

    private Map<FieldName, ?> evaluateRegression(ModelEvaluationContext context) {
        Scorecard scorecard = getModel();

        double score = scorecard.getInitialScore();

        boolean useReasonCodes = scorecard.isUseReasonCodes();

        VoteCounter<String> reasonCodePoints = new VoteCounter<String>();

        Characteristics characteristics = scorecard.getCharacteristics();
        for (Characteristic characteristic : characteristics) {
            Double baselineScore = characteristic.getBaselineScore();
            if (baselineScore == null) {
                baselineScore = scorecard.getBaselineScore();
            } // End if

            if (useReasonCodes) {

                if (baselineScore == null) {
                    throw new InvalidFeatureException(characteristic);
                }
            }

            boolean hasTrueAttribute = false;

            List<Attribute> attributes = characteristic.getAttributes();
            for (Attribute attribute : attributes) {
                Predicate predicate = attribute.getPredicate();
                if (predicate == null) {
                    throw new InvalidFeatureException(attribute);
                }

                Boolean status = PredicateUtil.evaluate(predicate, context);
                if (status == null || !status.booleanValue()) {
                    continue;
                }

                Double partialScore = null;

                ComplexPartialScore complexPartialScore = attribute.getComplexPartialScore();
                if (complexPartialScore != null) {
                    Expression expression = complexPartialScore.getExpression();
                    if (expression == null) {
                        throw new InvalidFeatureException(complexPartialScore);
                    }

                    FieldValue computedValue = ExpressionUtil.evaluate(expression, context);
                    if (computedValue == null) {
                        return TargetUtil.evaluateRegressionDefault(context);
                    }

                    partialScore = (computedValue.asNumber()).doubleValue();
                } else

                {
                    partialScore = attribute.getPartialScore();
                } // End if

                if (partialScore == null) {
                    throw new InvalidFeatureException(attribute);
                }

                score += partialScore.doubleValue();

                String reasonCode = attribute.getReasonCode();
                if (reasonCode == null) {
                    reasonCode = characteristic.getReasonCode();
                } // End if

                if (useReasonCodes) {

                    if (reasonCode == null) {
                        throw new InvalidFeatureException(attribute);
                    }

                    Double difference;

                    Scorecard.ReasonCodeAlgorithm reasonCodeAlgorithm = scorecard.getReasonCodeAlgorithm();
                    switch (reasonCodeAlgorithm) {
                    case POINTS_ABOVE:
                        difference = (partialScore - baselineScore);
                        break;
                    case POINTS_BELOW:
                        difference = (baselineScore - partialScore);
                        break;
                    default:
                        throw new UnsupportedFeatureException(scorecard, reasonCodeAlgorithm);
                    }

                    reasonCodePoints.increment(reasonCode, difference);
                }

                hasTrueAttribute = true;

                break;
            }

            // "If not even a single Attribute evaluates to "true" for a given Characteristic, the scorecard as a whole returns an invalid value"
            if (!hasTrueAttribute) {
                throw new InvalidResultException(characteristic);
            }
        }

        Map<FieldName, ? extends Number> result = TargetUtil.evaluateRegression(score, context);

        if (useReasonCodes) {
            Map.Entry<FieldName, ? extends Number> resultEntry = Iterables.getOnlyElement(result.entrySet());

            return Collections.singletonMap(resultEntry.getKey(),
                    createScoreMap(resultEntry.getValue(), reasonCodePoints));
        }

        return result;
    }

    static private ScoreClassificationMap createScoreMap(Number value, Map<String, Double> reasonCodePoints) {
        ScoreClassificationMap result = new ScoreClassificationMap(value);

        // Filter out meaningless (ie. negative values) explanations
        com.google.common.base.Predicate<Map.Entry<String, Double>> predicate = new com.google.common.base.Predicate<Map.Entry<String, Double>>() {

            @Override
            public boolean apply(Map.Entry<String, Double> entry) {
                return Double.compare(entry.getValue(), 0) >= 0;
            }
        };
        result.putAll(Maps.filterEntries(reasonCodePoints, predicate));

        return result;
    }
}