org.jpmml.sparkml.ConverterUtil.java Source code

Introduction

Here is the source code for org.jpmml.sparkml.ConverterUtil.java, the JPMML-SparkML utility class that converts a fitted Apache Spark ML PipelineModel into a PMML document and keeps the registry of supported transformer-to-converter mappings.

Source

/*
 * Copyright (c) 2016 Villu Ruusmann
 *
 * This file is part of JPMML-SparkML
 *
 * JPMML-SparkML is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-SparkML is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-SparkML.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.sparkml;

import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import javax.xml.bind.JAXBException;

import com.google.common.collect.Iterables;
import org.apache.commons.io.output.ByteArrayOutputStream;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.feature.Binarizer;
import org.apache.spark.ml.feature.Bucketizer;
import org.apache.spark.ml.feature.ChiSqSelectorModel;
import org.apache.spark.ml.feature.ColumnPruner;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.MinMaxScalerModel;
import org.apache.spark.ml.feature.OneHotEncoder;
import org.apache.spark.ml.feature.PCAModel;
import org.apache.spark.ml.feature.RFormulaModel;
import org.apache.spark.ml.feature.StandardScalerModel;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.feature.VectorAttributeRewriter;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.feature.VectorSlicer;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.GBTRegressionModel;
import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel;
import org.apache.spark.ml.regression.LinearRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.model.MetroJAXBUtil;
import org.jpmml.sparkml.feature.BinarizerConverter;
import org.jpmml.sparkml.feature.BucketizerConverter;
import org.jpmml.sparkml.feature.ChiSqSelectorModelConverter;
import org.jpmml.sparkml.feature.ColumnPrunerConverter;
import org.jpmml.sparkml.feature.IndexToStringConverter;
import org.jpmml.sparkml.feature.MinMaxScalerModelConverter;
import org.jpmml.sparkml.feature.OneHotEncoderConverter;
import org.jpmml.sparkml.feature.PCAModelConverter;
import org.jpmml.sparkml.feature.RFormulaModelConverter;
import org.jpmml.sparkml.feature.StandardScalerModelConverter;
import org.jpmml.sparkml.feature.StringIndexerModelConverter;
import org.jpmml.sparkml.feature.VectorAssemblerConverter;
import org.jpmml.sparkml.feature.VectorAttributeRewriterConverter;
import org.jpmml.sparkml.feature.VectorIndexerModelConverter;
import org.jpmml.sparkml.feature.VectorSlicerConverter;
import org.jpmml.sparkml.model.DecisionTreeClassificationModelConverter;
import org.jpmml.sparkml.model.DecisionTreeRegressionModelConverter;
import org.jpmml.sparkml.model.GBTClassificationModelConverter;
import org.jpmml.sparkml.model.GBTRegressionModelConverter;
import org.jpmml.sparkml.model.GeneralizedLinearRegressionModelConverter;
import org.jpmml.sparkml.model.KMeansModelConverter;
import org.jpmml.sparkml.model.LinearRegressionModelConverter;
import org.jpmml.sparkml.model.LogisticRegressionModelConverter;
import org.jpmml.sparkml.model.MultilayerPerceptronClassificationModelConverter;
import org.jpmml.sparkml.model.RandomForestClassificationModelConverter;
import org.jpmml.sparkml.model.RandomForestRegressionModelConverter;

public class ConverterUtil {

    private ConverterUtil() {
    }

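    // Converts a fitted Spark ML pipeline into a PMML document. Feature transformer
    // stages are appended to the FeatureMapper, while model stages are encoded as
    // PMML models; two or more model stages are combined into a model chain.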
    static public PMML toPMML(StructType schema, PipelineModel pipelineModel) {
        FeatureMapper featureMapper = new FeatureMapper(schema);

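        // Encoded PMML models, keyed by the prediction column of the originating model stage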
        Map<String, org.dmg.pmml.Model> models = new LinkedHashMap<>();

        Transformer[] stages = pipelineModel.stages();
        for (Transformer stage : stages) {
            TransformerConverter<?> converter = ConverterUtil.createConverter(stage);

            if (converter instanceof FeatureConverter) {
                FeatureConverter<?> featureConverter = (FeatureConverter<?>) converter;

                featureMapper.append(featureConverter);
            } else

            if (converter instanceof ModelConverter) {
                ModelConverter<?> modelConverter = (ModelConverter<?>) converter;

                Schema featureSchema = featureMapper.createSchema(modelConverter);

                org.dmg.pmml.Model model = modelConverter.encodeModel(featureSchema);

                featureMapper.append(modelConverter);

                HasPredictionCol hasPredictionCol = (HasPredictionCol) stage;

                models.put(hasPredictionCol.getPredictionCol(), model);
            } else

            {
                throw new IllegalArgumentException();
            }
        }

        org.dmg.pmml.Model rootModel;

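        // A single model stage becomes the root model as-is; two or more model stages
        // are combined into a PMML model chain.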
        if (models.size() == 1) {
            rootModel = Iterables.getOnlyElement(models.values());
        } else

        if (models.size() >= 2) {
            List<MiningField> targetMiningFields = new ArrayList<>();

            List<Map.Entry<String, org.dmg.pmml.Model>> entries = new ArrayList<>(models.entrySet());
            for (Iterator<Map.Entry<String, org.dmg.pmml.Model>> entryIt = entries.iterator(); entryIt.hasNext();) {
                Map.Entry<String, org.dmg.pmml.Model> entry = entryIt.next();

                String predictionCol = entry.getKey();
                org.dmg.pmml.Model model = entry.getValue();

                MiningSchema miningSchema = model.getMiningSchema();

                List<MiningField> miningFields = miningSchema.getMiningFields();
                for (Iterator<MiningField> miningFieldIt = miningFields.iterator(); miningFieldIt.hasNext();) {
                    MiningField miningField = miningFieldIt.next();

                    MiningField.UsageType usageType = miningField.getUsageType();
                    switch (usageType) {
                    case PREDICTED:
                    case TARGET:
                        targetMiningFields.add(miningField);
                        break;
                    default:
                        break;
                    }
                }

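                // The last model in the chain predicts the final target as-is; every
                // preceding model has its prediction column removed from the data
                // dictionary and re-declared as an output field of that model.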
                if (!entryIt.hasNext()) {
                    break;
                }

                FieldName name = FieldName.create(predictionCol);

                DataField dataField = featureMapper.getDataField(name);
                if (dataField == null) {
                    throw new IllegalArgumentException();
                }

                featureMapper.removeDataField(name);

                Output output = model.getOutput();
                if (output == null) {
                    output = new Output();

                    model.setOutput(output);
                }

                OutputField outputField = new OutputField(name, dataField.getDataType())
                        .setOpType(dataField.getOpType()).setResultFeature(ResultFeature.PREDICTED_VALUE);

                output.addOutputFields(outputField);
            }

            MiningSchema miningSchema = new MiningSchema(targetMiningFields);

            List<org.dmg.pmml.Model> memberModels = new ArrayList<>(models.values());

            MiningModel miningModel = MiningModelUtil
                    .createModelChain(null, Collections.<FieldName>emptyList(), memberModels)
                    .setMiningSchema(miningSchema);

            rootModel = miningModel;
        } else

        {
            throw new IllegalArgumentException();
        }

        PMML pmml = featureMapper.encodePMML(rootModel);

        return pmml;
    }

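    // Converts the pipeline to PMML and marshals the document into an in-memory byte array.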
    static public byte[] toPMMLByteArray(StructType schema, PipelineModel pipelineModel) {
        PMML pmml = toPMML(schema, pipelineModel);

        ByteArrayOutputStream os = new ByteArrayOutputStream(1024 * 1024);

        try {
            MetroJAXBUtil.marshalPMML(pmml, os);
        } catch (JAXBException je) {
            throw new RuntimeException(je);
        }

        return os.toByteArray();
    }

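    // Convenience methods that narrow the generic converter to the expected subtype.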
    static public FeatureConverter<?> createFeatureConverter(Transformer transformer) {
        return (FeatureConverter<?>) createConverter(transformer);
    }

    static public ModelConverter<?> createModelConverter(Transformer transformer) {
        return (ModelConverter<?>) createConverter(transformer);
    }

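    // Looks up the converter class registered for the transformer's concrete class and
    // instantiates it via its single-argument constructor.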
    static public <T extends Transformer> TransformerConverter<T> createConverter(T transformer) {
        Class<? extends Transformer> clazz = transformer.getClass();

        Class<? extends TransformerConverter> converterClazz = getConverterClazz(clazz);
        if (converterClazz == null) {
            throw new IllegalArgumentException("Transformer class " + clazz.getName() + " is not supported");
        }

        try {
            Constructor<?> constructor = converterClazz.getDeclaredConstructor(clazz);

            return (TransformerConverter) constructor.newInstance(transformer);
        } catch (Exception e) {
            throw new IllegalArgumentException(e);
        }
    }

    static public Class<? extends TransformerConverter> getConverterClazz(Class<? extends Transformer> clazz) {
        return ConverterUtil.converters.get(clazz);
    }

    static public void putConverterClazz(Class<? extends Transformer> clazz,
            Class<? extends TransformerConverter<?>> converterClazz) {
        ConverterUtil.converters.put(clazz, converterClazz);
    }

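    // Registry mapping Spark ML transformer classes to their JPMML-SparkML converter classes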
    private static final Map<Class<? extends Transformer>, Class<? extends TransformerConverter>> converters = new LinkedHashMap<>();

    static {
        // Features
        converters.put(Binarizer.class, BinarizerConverter.class);
        converters.put(Bucketizer.class, BucketizerConverter.class);
        converters.put(ChiSqSelectorModel.class, ChiSqSelectorModelConverter.class);
        converters.put(ColumnPruner.class, ColumnPrunerConverter.class);
        converters.put(IndexToString.class, IndexToStringConverter.class);
        converters.put(MinMaxScalerModel.class, MinMaxScalerModelConverter.class);
        converters.put(OneHotEncoder.class, OneHotEncoderConverter.class);
        converters.put(PCAModel.class, PCAModelConverter.class);
        converters.put(RFormulaModel.class, RFormulaModelConverter.class);
        converters.put(StandardScalerModel.class, StandardScalerModelConverter.class);
        converters.put(StringIndexerModel.class, StringIndexerModelConverter.class);
        converters.put(VectorAssembler.class, VectorAssemblerConverter.class);
        converters.put(VectorAttributeRewriter.class, VectorAttributeRewriterConverter.class);
        converters.put(VectorIndexerModel.class, VectorIndexerModelConverter.class);
        converters.put(VectorSlicer.class, VectorSlicerConverter.class);

        // Models
        converters.put(DecisionTreeClassificationModel.class, DecisionTreeClassificationModelConverter.class);
        converters.put(DecisionTreeRegressionModel.class, DecisionTreeRegressionModelConverter.class);
        converters.put(GBTClassificationModel.class, GBTClassificationModelConverter.class);
        converters.put(GBTRegressionModel.class, GBTRegressionModelConverter.class);
        converters.put(GeneralizedLinearRegressionModel.class, GeneralizedLinearRegressionModelConverter.class);
        converters.put(KMeansModel.class, KMeansModelConverter.class);
        converters.put(LinearRegressionModel.class, LinearRegressionModelConverter.class);
        converters.put(LogisticRegressionModel.class, LogisticRegressionModelConverter.class);
        converters.put(MultilayerPerceptronClassificationModel.class,
                MultilayerPerceptronClassificationModelConverter.class);
        converters.put(RandomForestClassificationModel.class, RandomForestClassificationModelConverter.class);
        converters.put(RandomForestRegressionModel.class, RandomForestRegressionModelConverter.class);
    }
}
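
Usage

The following is a minimal sketch, not part of the original listing, of how ConverterUtil could be called from driver-side code to export a fitted PipelineModel as a PMML file. The helper class name, method parameters and output path handling are assumptions made for illustration; only the ConverterUtil.toPMMLByteArray(StructType, PipelineModel) call comes from the class above.

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;
import org.jpmml.sparkml.ConverterUtil;

// Hypothetical helper class (not part of JPMML-SparkML), shown for illustration only
public class ExportToPmmlExample {

    // Fits the given pipeline on the training data and writes the resulting PMML
    // document to the given file path.
    public static void exportToPmml(Dataset<Row> trainingData, Pipeline pipeline, String path) throws IOException {
        PipelineModel pipelineModel = pipeline.fit(trainingData);

        // The schema of the training DataFrame drives the PMML data dictionary
        StructType schema = trainingData.schema();

        byte[] pmmlBytes = ConverterUtil.toPMMLByteArray(schema, pipelineModel);

        Files.write(Paths.get(path), pmmlBytes);
    }
}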