org.jpmml.evaluator.TargetUtil.java Source code

Java tutorial

Introduction

Here is the source code for the class org.jpmml.evaluator.TargetUtil (file TargetUtil.java).

Source

/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.Target;
import org.dmg.pmml.TargetValue;
import org.dmg.pmml.Targets;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.UnsupportedFeatureException;

/**
 * Static helpers that apply a model's PMML {@link Targets} element to raw model predictions.
 *
 * <p>
 * For regression models, a missing prediction may be replaced by the target's default value. Every
 * non-missing value is then clamped to the target's {@code min}/{@code max}, rescaled, and
 * optionally cast to an integer. For classification models, a missing prediction may be replaced
 * by the prior probabilities declared on the target.
 * </p>
 */
public class TargetUtil {

    private TargetUtil() {
    }

    /**
     * Applies the {@link Targets} element to a single regression prediction for the model's target field.
     *
     * @param value the raw prediction, or {@code null} if the model produced none
     * @param context the evaluation context supplying the model evaluator
     * @return a map from the model's target field to the (possibly post-processed) prediction
     */
    static public Map<FieldName, ? extends Number> evaluateRegression(Double value,
            ModelEvaluationContext context) {
        ModelEvaluator<?> modelEvaluator = context.getModelEvaluator();

        return evaluateRegression(Collections.singletonMap(modelEvaluator.getTargetField(), value), context);
    }

    /**
     * Evaluates the {@link Targets} element for a regression model that produced no prediction.
     *
     * <p>
     * This is equivalent to {@link #evaluateRegression(Double, ModelEvaluationContext)} with a
     * {@code null} value.
     * </p>
     *
     * @param context the evaluation context supplying the model evaluator
     * @return a map from the model's target field to its default value, which may be {@code null}
     */
    static public Map<FieldName, ? extends Number> evaluateRegressionDefault(ModelEvaluationContext context) {
        return evaluateRegression((Double) null, context);
    }

    /**
     * Evaluates the {@link Targets} element for {@link MiningFunctionType#REGRESSION regression} models.
     *
     * <p>
     * If the model declares no targets, {@code predictions} is returned as-is (same instance).
     * Otherwise a new insertion-ordered map is built. For each field that has a {@link Target}:
     * a {@code null} prediction is first replaced by the target's default value, and any
     * non-{@code null} result is then passed through {@link #processValue(Target, Number)}.
     * Fields without a {@code Target} are copied unchanged.
     * </p>
     *
     * @param predictions raw predictions keyed by target field; values may be {@code null}
     * @param context the evaluation context supplying the model evaluator
     * @return the post-processed predictions
     * @throws InvalidFeatureException if a target's {@code TargetValue} list is malformed for a continuous target
     * @throws UnsupportedFeatureException if a target uses an unsupported {@code castInteger} mode
     */
    static public Map<FieldName, ? extends Number> evaluateRegression(Map<FieldName, ? extends Number> predictions,
            ModelEvaluationContext context) {
        ModelEvaluator<?> modelEvaluator = context.getModelEvaluator();

        Targets targets = modelEvaluator.getTargets();
        if (targets == null || Iterables.isEmpty(targets)) {
            return predictions;
        }

        Map<FieldName, Number> result = Maps.newLinkedHashMap();

        Collection<? extends Map.Entry<FieldName, ? extends Number>> entries = predictions.entrySet();
        for (Map.Entry<FieldName, ? extends Number> entry : entries) {
            FieldName key = entry.getKey();
            Number value = entry.getValue();

            Target target = modelEvaluator.getTarget(key);
            if (target != null) {

                if (value == null) {
                    value = getDefaultValue(target);
                } // End if

                // Note: a substituted default value is also clamped/rescaled/cast below
                if (value != null) {
                    value = processValue(target, value);
                }
            }

            result.put(key, value);
        }

        return result;
    }

    /**
     * Applies the {@link Targets} element to a single classification result for the model's target field.
     *
     * @param value the raw classification result, or {@code null} if the model produced none
     * @param context the evaluation context supplying the model evaluator
     * @return a map from the model's target field to the (possibly substituted) classification result
     */
    static public Map<FieldName, ? extends ClassificationMap<?>> evaluateClassification(ClassificationMap<?> value,
            ModelEvaluationContext context) {
        ModelEvaluator<?> modelEvaluator = context.getModelEvaluator();

        return evaluateClassification(Collections.singletonMap(modelEvaluator.getTargetField(), value), context);
    }

    /**
     * Evaluates the {@link Targets} element for a classification model that produced no prediction.
     *
     * <p>
     * This is equivalent to {@link #evaluateClassification(ClassificationMap, ModelEvaluationContext)}
     * with a {@code null} value.
     * </p>
     *
     * @param context the evaluation context supplying the model evaluator
     * @return a map from the model's target field to its prior probabilities, which may be {@code null}
     */
    static public Map<FieldName, ? extends ClassificationMap<?>> evaluateClassificationDefault(
            ModelEvaluationContext context) {
        return evaluateClassification((ClassificationMap<?>) null, context);
    }

    /**
     * Evaluates the {@link Targets} element for {@link MiningFunctionType#CLASSIFICATION classification} models.
     *
     * <p>
     * If the model declares no targets, {@code predictions} is returned as-is (same instance).
     * Otherwise a new insertion-ordered map is built, in which each {@code null} prediction for a
     * field that has a {@link Target} is replaced by that target's prior probabilities. Non-null
     * predictions are copied unchanged.
     * </p>
     *
     * @param predictions raw classification results keyed by target field; values may be {@code null}
     * @param context the evaluation context supplying the model evaluator
     * @return the predictions with missing values substituted where possible
     * @throws InvalidFeatureException if a consulted {@code TargetValue} declares a {@code defaultValue}
     */
    static public Map<FieldName, ? extends ClassificationMap<?>> evaluateClassification(
            Map<FieldName, ? extends ClassificationMap<?>> predictions, ModelEvaluationContext context) {
        ModelEvaluator<?> modelEvaluator = context.getModelEvaluator();

        Targets targets = modelEvaluator.getTargets();
        if (targets == null || Iterables.isEmpty(targets)) {
            return predictions;
        }

        Map<FieldName, ClassificationMap<?>> result = Maps.newLinkedHashMap();

        Collection<? extends Map.Entry<FieldName, ? extends ClassificationMap<?>>> entries = predictions.entrySet();
        for (Map.Entry<FieldName, ? extends ClassificationMap<?>> entry : entries) {
            FieldName key = entry.getKey();
            ClassificationMap<?> value = entry.getValue();

            Target target = modelEvaluator.getTarget(key);
            if (target != null) {

                if (value == null) {
                    value = getPriorProbabilities(target);
                }
            }

            result.put(key, value);
        }

        return result;
    }

    /**
     * Post-processes a continuous prediction according to a {@link Target}.
     *
     * <p>
     * The steps are applied in this order:
     * </p>
     * <ol>
     * <li>clamp to {@code min} (if set);</li>
     * <li>clamp to {@code max} (if set);</li>
     * <li>compute {@code value * rescaleFactor + rescaleConstant};</li>
     * <li>apply {@code castInteger} (if set).</li>
     * </ol>
     *
     * @param target the target definition
     * @param value the prediction; must not be {@code null}
     * @return a {@link Double} if no {@code castInteger} is set, otherwise an {@link Integer}. Integer
     *         results outside the {@code int} range saturate at {@link Integer#MIN_VALUE}/{@link Integer#MAX_VALUE}.
     * @throws UnsupportedFeatureException if the {@code castInteger} mode is not handled here
     */
    static public Number processValue(Target target, Number value) {
        double result = value.doubleValue();

        Double min = target.getMin();
        if (min != null) {
            result = Math.max(result, min.doubleValue());
        }

        Double max = target.getMax();
        if (max != null) {
            result = Math.min(result, max.doubleValue());
        }

        result = (result * target.getRescaleFactor()) + target.getRescaleConstant();

        Target.CastInteger castInteger = target.getCastInteger();
        if (castInteger == null) {
            return result;
        }

        switch (castInteger) {
        case ROUND:
            // Math.round(double) yields a long. A plain (int) cast would keep only the low 32 bits,
            // silently wrapping out-of-range values (possibly flipping their sign). Saturate instead,
            // matching how the double -> int casts in the CEILING/FLOOR branches behave.
            return saturatedIntCast(Math.round(result));
        case CEILING:
            // Narrowing double -> int saturates at the int bounds and maps NaN to 0
            return (int) Math.ceil(result);
        case FLOOR:
            return (int) Math.floor(result);
        default:
            throw new UnsupportedFeatureException(target, castInteger);
        }
    }

    /**
     * Finds the {@link TargetValue} whose {@code value} attribute matches the given value.
     *
     * <p>
     * Each candidate's string value is converted with {@code TypeUtil.parseOrCast} to the data type
     * of {@code value} before comparing. The exact conversion and equality semantics are those of
     * {@code TypeUtil}.
     * </p>
     *
     * @param target the target definition
     * @param value the value to look up
     * @return the first matching {@code TargetValue}, or {@code null} if none matches
     */
    static public TargetValue getTargetValue(Target target, Object value) {
        DataType dataType = TypeUtil.getDataType(value);

        List<TargetValue> targetValues = target.getTargetValues();
        for (TargetValue targetValue : targetValues) {

            if (TypeUtil.equals(dataType, value, TypeUtil.parseOrCast(dataType, targetValue.getValue()))) {
                return targetValue;
            }
        }

        return null;
    }

    /**
     * Narrows a {@code long} to an {@code int}, clamping values outside the {@code int} range to
     * {@link Integer#MIN_VALUE}/{@link Integer#MAX_VALUE}.
     */
    static private int saturatedIntCast(long value) {

        if (value > Integer.MAX_VALUE) {
            return Integer.MAX_VALUE;
        } // End if

        if (value < Integer.MIN_VALUE) {
            return Integer.MIN_VALUE;
        }

        return (int) value;
    }

    /**
     * Returns the default value declared for a continuous target.
     *
     * @param target the target definition
     * @return the sole {@code TargetValue}'s {@code defaultValue} (possibly {@code null}), or
     *         {@code null} if the target declares no {@code TargetValue}
     * @throws InvalidFeatureException if there is more than one {@code TargetValue}, or if the sole
     *         one declares a {@code value} or {@code priorProbability}
     */
    static private Double getDefaultValue(Target target) {
        List<TargetValue> values = target.getTargetValues();

        if (values.isEmpty()) {
            return null;
        } // End if

        if (values.size() != 1) {
            throw new InvalidFeatureException(target);
        }

        TargetValue value = values.get(0);

        // "Attributes value and priorProbability are used only if the optype of the field is categorical or ordinal"
        if (value.getValue() != null || value.getPriorProbability() != null) {
            throw new InvalidFeatureException(value);
        }

        return value.getDefaultValue();
    }

    /**
     * Collects the prior probabilities declared for a categorical target.
     *
     * <p>
     * A {@code TargetValue} that lacks either a {@code value} or a {@code priorProbability} is skipped.
     * </p>
     *
     * @param target the target definition
     * @return a map from category to prior probability, or {@code null} if no prior probabilities were collected
     * @throws InvalidFeatureException if any {@code TargetValue} declares a {@code defaultValue}
     */
    static private ProbabilityClassificationMap<String> getPriorProbabilities(Target target) {
        ProbabilityClassificationMap<String> result = new ProbabilityClassificationMap<String>();

        List<TargetValue> values = target.getTargetValues();
        for (TargetValue value : values) {

            // "The attribute defaultValue is used only if the optype of the field is continuous"
            if (value.getDefaultValue() != null) {
                throw new InvalidFeatureException(value);
            }

            String targetCategory = value.getValue();
            Double probability = value.getPriorProbability();

            if (targetCategory == null || probability == null) {
                continue;
            }

            result.put(targetCategory, probability);
        }

        if (result.isEmpty()) {
            return null;
        }

        return result;
    }
}