org.jpmml.postgresql.PMMLUtil.java Source code

Java tutorial

Introduction

Here is the source code for org.jpmml.postgresql.PMMLUtil.java

Source

/*
 * Copyright (c) 2014 Villu Ruusmann
 *
 * This file is part of JPMML-PostgreSQL
 *
 * JPMML-PostgreSQL is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-PostgreSQL is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-PostgreSQL.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.postgresql;

import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import com.google.common.cache.CacheBuilder;
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.EvaluatorUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.runtime.ModelEvaluatorCache;

/**
 * Static helpers that evaluate a model against a request and hand the result back
 * either as a single value or by updating the columns of a response row.
 *
 * <p>A request is either a {@link ResultSet} (arguments are matched to active fields
 * by column label) or a {@link List} (arguments are matched to active fields by
 * position). Field names and column labels are compared after being lower-cased
 * with {@link Locale#ROOT}.</p>
 */
public final class PMMLUtil {

    /** Shared cache from which the evaluator for a given class is obtained. */
    private static final ModelEvaluatorCache evaluatorCache = new ModelEvaluatorCache(CacheBuilder.newBuilder());

    private PMMLUtil() {
        // Utility class; not instantiable.
    }

    /**
     * Evaluates the model identified by {@code clazz} and returns its target value.
     *
     * @param clazz the key used to look up the evaluator in the cache
     * @param request a {@link ResultSet} or a {@link List} holding the arguments; may be {@code null}
     * @return the decoded value of the evaluator's target field, or {@code null} if
     *         {@code request} is {@code null} (no evaluation is performed in that case)
     * @throws Exception if the evaluator cannot be obtained or the arguments cannot be loaded
     */
    public static Object evaluateSimple(Class<?> clazz, Object request) throws Exception {
        if (request == null) {
            return null;
        }

        Evaluator evaluator = getEvaluator(clazz);

        Map<FieldName, ?> arguments = loadArguments(evaluator, request);
        Map<FieldName, ?> result = evaluator.evaluate(arguments);

        Object targetValue = result.get(evaluator.getTargetField());

        return EvaluatorUtil.decode(targetValue);
    }

    /**
     * Evaluates the model identified by {@code clazz} and stores target and output
     * values into the matching columns of {@code response}.
     *
     * @param clazz the key used to look up the evaluator in the cache
     * @param request a {@link ResultSet} or a {@link List} holding the arguments; may be {@code null}
     * @param response the row whose columns are updated with the evaluation result
     * @return {@code false} if {@code request} is {@code null} (the response is left
     *         untouched), {@code true} once the result has been stored
     * @throws Exception if the evaluator cannot be obtained, the arguments cannot be
     *         loaded, or the response cannot be updated
     */
    public static boolean evaluateComplex(Class<?> clazz, Object request, ResultSet response) throws Exception {
        if (request == null) {
            return false;
        }

        Evaluator evaluator = getEvaluator(clazz);

        Map<FieldName, ?> arguments = loadArguments(evaluator, request);
        Map<FieldName, ?> result = evaluator.evaluate(arguments);

        storeResult(evaluator, result, response);

        return true;
    }

    /**
     * Dispatches on the runtime type of {@code request} to build the argument map.
     *
     * @throws IllegalArgumentException if {@code request} is neither a {@link ResultSet}
     *         nor a {@link List}
     */
    private static Map<FieldName, FieldValue> loadArguments(Evaluator evaluator, Object request) throws Exception {
        if (request instanceof ResultSet) {
            return loadStruct(evaluator, (ResultSet) request);
        }

        if (request instanceof List) {
            return loadScalarList(evaluator, (List<?>) request);
        }

        // Fail with a descriptive message instead of an opaque ClassCastException.
        String type = (request != null ? request.getClass().getName() : "null");

        throw new IllegalArgumentException("Unsupported request type " + type);
    }

    /**
     * Builds the argument map from a row, matching each active field to the column
     * whose normalized label equals the normalized field name. Active fields without
     * a matching column are left out of the returned map.
     */
    private static Map<FieldName, FieldValue> loadStruct(Evaluator evaluator, ResultSet request)
            throws SQLException {
        Map<FieldName, FieldValue> result = Maps.newLinkedHashMap();

        Map<String, Integer> columns = parseColumns(request);

        Iterable<FieldName> fields = evaluator.getActiveFields();
        for (FieldName field : fields) {
            String label = normalize(field.getValue());

            Integer column = columns.get(label);
            if (column == null) {
                continue;
            }

            FieldValue value = EvaluatorUtil.prepare(evaluator, field, request.getObject(column));

            result.put(field, value);
        }

        return result;
    }

    /**
     * Builds the argument map from a positional list: the i-th element of
     * {@code request} becomes the value of the i-th active field.
     *
     * @throws IllegalArgumentException if the list size differs from the number of active fields
     */
    private static Map<FieldName, FieldValue> loadScalarList(Evaluator evaluator, List<?> request) {
        Map<FieldName, FieldValue> result = Maps.newLinkedHashMap();

        List<FieldName> fields = evaluator.getActiveFields();
        if (fields.size() != request.size()) {
            throw new IllegalArgumentException("Invalid number of arguments");
        }

        int i = 0;

        for (FieldName field : fields) {
            FieldValue value = EvaluatorUtil.prepare(evaluator, field, request.get(i));

            result.put(field, value);

            i++;
        }

        return result;
    }

    /**
     * Writes the decoded values of the target fields followed by the output fields
     * into the columns of {@code response} that match by normalized label. Fields
     * without a matching column are skipped; a {@code null} decoded value is written
     * with {@link ResultSet#updateNull(int)}.
     */
    private static void storeResult(Evaluator evaluator, Map<FieldName, ?> result, ResultSet response)
            throws SQLException {
        Map<String, Integer> columns = parseColumns(response);

        Iterable<FieldName> fields = Iterables.concat(evaluator.getTargetFields(), evaluator.getOutputFields());
        for (FieldName field : fields) {
            String label = normalize(field.getValue());

            Integer column = columns.get(label);
            if (column == null) {
                continue;
            }

            Object value = EvaluatorUtil.decode(result.get(field));
            if (value != null) {
                response.updateObject(column, value);
            } else {
                response.updateNull(column);
            }
        }
    }

    /**
     * Maps each normalized column label of {@code resultSet} to its 1-based column index.
     *
     * @throws SQLException if two columns share the same label after normalization,
     *         or if the metadata cannot be read
     */
    private static Map<String, Integer> parseColumns(ResultSet resultSet) throws SQLException {
        Map<String, Integer> result = Maps.newLinkedHashMap();

        ResultSetMetaData metaData = resultSet.getMetaData();

        // Read the column count once rather than on every loop iteration.
        int columnCount = metaData.getColumnCount();

        for (int column = 1; column <= columnCount; column++) {
            String label = normalize(metaData.getColumnLabel(column));

            Integer previousColumn = result.put(label, column);
            if (previousColumn != null) {
                throw new SQLException("Duplicate column label \"" + label + "\"");
            }
        }

        return result;
    }

    /**
     * Lower-cases a field name or column label for case-insensitive matching.
     *
     * <p>{@link Locale#ROOT} is used so that the result does not depend on the JVM's
     * default locale; with a Turkish default locale the no-argument
     * {@code toLowerCase()} would turn "ID" into "\u0131d", which would not match "id".</p>
     */
    private static String normalize(String label) {
        return label.toLowerCase(Locale.ROOT);
    }

    /** Looks up the evaluator for {@code clazz} in the shared cache. */
    private static Evaluator getEvaluator(Class<?> clazz) throws Exception {
        return evaluatorCache.get(clazz);
    }
}