com.cloudera.oryx.app.pmml.AppPMMLUtils.java Source code

Java tutorial

Introduction

Here is the source code for com.cloudera.oryx.app.pmml.AppPMMLUtils.java

Source

/*
 * Copyright (c) 2014, Cloudera and Intel, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */

package com.cloudera.oryx.app.pmml;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import com.google.common.base.Preconditions;
import com.google.common.io.CharStreams;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.dmg.pmml.Array;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Extension;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.TypeDefinitionField;
import org.dmg.pmml.Value;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.cloudera.oryx.app.schema.CategoricalValueEncodings;
import com.cloudera.oryx.app.schema.InputSchema;
import com.cloudera.oryx.common.pmml.PMMLUtils;
import com.cloudera.oryx.common.text.TextUtils;

/**
 * General app tier PMML-related utility methods: reading and writing PMML extensions,
 * and converting between {@link InputSchema} / {@link CategoricalValueEncodings} and the
 * PMML {@link MiningSchema} / {@link DataDictionary} elements.
 */
public final class AppPMMLUtils {

    private static final Logger log = LoggerFactory.getLogger(AppPMMLUtils.class);

    private AppPMMLUtils() {
    }

    /**
     * @param pmml PMML model to query for extensions
     * @param name name of extension to query
     * @return value of the first extension whose name equals {@code name}, or {@code null}
     *  if there is no such extension (or it has no value)
     */
    public static String getExtensionValue(PMML pmml, String name) {
        return pmml.getExtensions().stream()
                .filter(extension -> name.equals(extension.getName()))
                .findFirst()
                .map(Extension::getValue)
                .orElse(null);
    }

    /**
     * @param pmml PMML model to query for extensions
     * @param name name of extension to query
     * @return content of the first extension whose name equals {@code name}, parsed as if it
     *  were a PMML {@link Array}: space-separated values, with PMML quoting rules. Returns
     *  {@code null} if there is no such extension, and an empty list if the extension has
     *  no content. A non-empty result is a fixed-size list.
     * @throws IllegalArgumentException if the extension has more than one content element
     */
    public static List<String> getExtensionContent(PMML pmml, String name) {
        return pmml.getExtensions().stream()
                .filter(extension -> name.equals(extension.getName()))
                .findFirst()
                .map(extension -> {
                    List<?> content = extension.getContent();
                    Preconditions.checkArgument(content.size() <= 1,
                            "Extension %s has %s content elements; expected at most 1",
                            name, content.size());
                    return content.isEmpty() ? Collections.<String>emptyList()
                            : Arrays.asList(TextUtils.parsePMMLDelimited(content.get(0).toString()));
                })
                .orElse(null);
    }

    /**
     * Adds an extension with a value and no content. It may duplicate existing extensions
     * with the same name.
     *
     * @param pmml PMML model to add extension to
     * @param key extension key
     * @param value extension value; its {@code toString()} is stored
     */
    public static void addExtension(PMML pmml, String key, Object value) {
        pmml.addExtensions(new Extension().setName(key).setValue(value.toString()));
    }

    /**
     * Adds an extension with a single {@code String} content and no value. The values are
     * encoded as if they were being added to a PMML {@link Array}: space-separated, with
     * PMML quoting rules. If {@code content} is empty, no extension is added at all, so
     * {@link #getExtensionContent(PMML, String)} will later return {@code null} for this key
     * rather than an empty list.
     *
     * @param pmml PMML model to add extension to
     * @param key extension key
     * @param content values to encode into the extension's content
     */
    public static void addExtensionContent(PMML pmml, String key, Collection<?> content) {
        if (content.isEmpty()) {
            return;
        }
        String joined = TextUtils.joinPMMLDelimited(content);
        pmml.addExtensions(new Extension().setName(key).addContent(joined));
    }

    /**
     * @param values {@code double} values to make into a PMML {@link Array}
     * @return PMML {@link Array} of type {@link Array.Type#REAL} representing {@code values},
     *  with {@code n} set to the number of values
     */
    public static Array toArray(double... values) {
        List<Double> valueList = new ArrayList<>(values.length);
        for (double value : values) {
            valueList.add(value);
        }
        String arrayValue = TextUtils.joinPMMLDelimitedNumbers(valueList);
        return new Array(Array.Type.REAL, arrayValue).setN(valueList.size());
    }

    /**
     * @param schema {@link InputSchema} whose information should be encoded in PMML
     * @return a {@link MiningSchema} representing the information contained in an
     *  {@link InputSchema}, without feature importances
     */
    public static MiningSchema buildMiningSchema(InputSchema schema) {
        return buildMiningSchema(schema, null);
    }

    /**
     * @param schema {@link InputSchema} whose information should be encoded in PMML
     * @param importances optional feature importances. May be {@code null}, or else the size
     *  of the array must match the number of predictors in the schema, which may be
     *  less than the total number of features.
     * @return a {@link MiningSchema} with one {@link MiningField} per feature, in schema order
     * @throws IllegalArgumentException if {@code importances} is non-null and its length
     *  differs from the number of predictors
     */
    public static MiningSchema buildMiningSchema(InputSchema schema, double[] importances) {
        if (importances != null) {
            // Guarded: the message arguments are evaluated eagerly, so importances must be non-null
            Preconditions.checkArgument(importances.length == schema.getNumPredictors(),
                    "Number of importances (%s) does not match number of predictors (%s)",
                    importances.length, schema.getNumPredictors());
        }
        List<String> featureNames = schema.getFeatureNames();
        List<MiningField> miningFields = new ArrayList<>(featureNames.size());
        for (int featureIndex = 0; featureIndex < featureNames.size(); featureIndex++) {
            String featureName = featureNames.get(featureIndex);
            MiningField field = new MiningField(FieldName.create(featureName));
            if (schema.isNumeric(featureName)) {
                field.setOpType(OpType.CONTINUOUS);
                field.setUsageType(MiningField.UsageType.ACTIVE);
            } else if (schema.isCategorical(featureName)) {
                field.setOpType(OpType.CATEGORICAL);
                field.setUsageType(MiningField.UsageType.ACTIVE);
            } else {
                // ID, or ignored
                field.setUsageType(MiningField.UsageType.SUPPLEMENTARY);
            }
            if (schema.hasTarget() && schema.isTarget(featureName)) {
                // Override to PREDICTED
                field.setUsageType(MiningField.UsageType.PREDICTED);
            }
            // Will be active if and only if it's a predictor
            if (field.getUsageType() == MiningField.UsageType.ACTIVE && importances != null) {
                int predictorIndex = schema.featureToPredictorIndex(featureIndex);
                field.setImportance(importances[predictorIndex]);
            }
            miningFields.add(field);
        }
        return new MiningSchema(miningFields);
    }

    /**
     * @param miningSchema {@link MiningSchema} from a model
     * @return names of features in order
     */
    public static List<String> getFeatureNames(MiningSchema miningSchema) {
        return miningSchema.getMiningFields().stream()
                .map(field -> field.getName().getValue())
                .collect(Collectors.toList());
    }

    /**
     * @param miningSchema {@link MiningSchema} from a model
     * @return index of the first {@link MiningField.UsageType#PREDICTED} feature, or
     *  {@code null} if there is none
     */
    public static Integer findTargetIndex(MiningSchema miningSchema) {
        List<MiningField> miningFields = miningSchema.getMiningFields();
        for (int i = 0; i < miningFields.size(); i++) {
            if (miningFields.get(i).getUsageType() == MiningField.UsageType.PREDICTED) {
                return i;
            }
        }
        return null;
    }

    /**
     * @param schema {@link InputSchema} whose features should be described in PMML
     * @param categoricalValueEncodings encodings of categorical feature values; must be
     *  non-null if the schema has any categorical feature
     * @return a {@link DataDictionary} with one {@link DataField} per feature, in schema order.
     *  Numeric features are {@link OpType#CONTINUOUS} / {@link DataType#DOUBLE}; categorical
     *  features are {@link OpType#CATEGORICAL} / {@link DataType#STRING} and list their
     *  values; other features get no op type or data type.
     * @throws NullPointerException if a categorical feature exists but
     *  {@code categoricalValueEncodings} is {@code null}
     */
    public static DataDictionary buildDataDictionary(InputSchema schema,
            CategoricalValueEncodings categoricalValueEncodings) {
        List<String> featureNames = schema.getFeatureNames();

        List<DataField> dataFields = new ArrayList<>(featureNames.size());
        for (int featureIndex = 0; featureIndex < featureNames.size(); featureIndex++) {
            String featureName = featureNames.get(featureIndex);
            boolean categorical = schema.isCategorical(featureName);
            OpType opType;
            DataType dataType;
            if (schema.isNumeric(featureName)) {
                opType = OpType.CONTINUOUS;
                dataType = DataType.DOUBLE;
            } else if (categorical) {
                opType = OpType.CATEGORICAL;
                dataType = DataType.STRING;
            } else {
                // Don't know
                opType = null;
                dataType = null;
            }
            DataField field = new DataField(FieldName.create(featureName), opType, dataType);
            if (categorical) {
                Preconditions.checkNotNull(categoricalValueEncodings,
                        "No categorical value encodings given for categorical feature %s", featureName);
                // Values are emitted in ascending key order of the encoding map (presumably the
                // integer encoding -- confirm), so buildCategoricalValueEncodings can read
                // them back in the same order
                categoricalValueEncodings.getEncodingValueMap(featureIndex).entrySet().stream()
                        .sorted(Comparator.comparing(Map.Entry::getKey))
                        .map(Map.Entry::getValue)
                        .forEach(value -> field.addValues(new Value(value)));
            }
            dataFields.add(field);
        }

        return new DataDictionary(dataFields).setNumberOfFields(dataFields.size());
    }

    /**
     * @param dictionary {@link DataDictionary} from model
     * @return names of features in order
     * @throws IllegalArgumentException if the dictionary has no fields
     */
    public static List<String> getFeatureNames(DataDictionary dictionary) {
        List<DataField> dataFields = dictionary.getDataFields();
        Preconditions.checkArgument(dataFields != null && !dataFields.isEmpty(), "No fields in DataDictionary");
        return dataFields.stream().map(field -> field.getName().getValue()).collect(Collectors.toList());
    }

    /**
     * @param dictionary {@link DataDictionary} from model
     * @return {@link CategoricalValueEncodings} built from the {@link Value}s listed on each
     *  {@link DataField}, in listed order, keyed by the field's index in the dictionary;
     *  fields that list no values are omitted
     */
    public static CategoricalValueEncodings buildCategoricalValueEncodings(DataDictionary dictionary) {
        Map<Integer, Collection<String>> indexToValues = new HashMap<>();
        List<DataField> dataFields = dictionary.getDataFields();
        for (int featureIndex = 0; featureIndex < dataFields.size(); featureIndex++) {
            TypeDefinitionField field = dataFields.get(featureIndex);
            Collection<Value> values = field.getValues();
            if (values != null && !values.isEmpty()) {
                Collection<String> categoricalValues = values.stream()
                        .map(Value::getValue)
                        .collect(Collectors.toList());
                indexToValues.put(featureIndex, categoricalValues);
            }
        }
        return new CategoricalValueEncodings(indexToValues);
    }

    /**
     * Parses a PMML model carried by an update message.
     *
     * @param key message key: {@code "MODEL"} if {@code message} is the PMML document itself,
     *  or {@code "MODEL-REF"} if {@code message} is a path to a file containing it
     * @param message PMML document, or a path to one, depending on {@code key}
     * @param hadoopConf Hadoop configuration used to open the referenced file for
     *  {@code "MODEL-REF"}; if {@code null}, a default {@link Configuration} is created
     * @return parsed PMML model, or {@code null} if a referenced model file does not exist
     * @throws IOException if an I/O error occurs, e.g. while reading a referenced model file
     * @throws IllegalArgumentException if {@code key} is not recognized
     */
    public static PMML readPMMLFromUpdateKeyMessage(String key, String message, Configuration hadoopConf)
            throws IOException {
        String pmmlString;
        switch (key) {
        case "MODEL":
            pmmlString = message;
            break;
        case "MODEL-REF":
            // Allowing null is mostly for integration tests
            Configuration conf = hadoopConf == null ? new Configuration() : hadoopConf;
            Path messagePath = new Path(message);
            FileSystem fs = FileSystem.get(messagePath.toUri(), conf);
            // fs.open() is part of the resource spec, so a missing file is also caught below
            try (InputStreamReader in = new InputStreamReader(fs.open(messagePath), StandardCharsets.UTF_8)) {
                pmmlString = CharStreams.toString(in);
            } catch (FileNotFoundException fnfe) {
                log.warn("Unable to load model file at {}; ignoring", messagePath);
                return null;
            }
            break;
        default:
            throw new IllegalArgumentException("Unknown key " + key);
        }
        return PMMLUtils.fromString(pmmlString);
    }

}