ml.shifu.shifu.core.pmml.PMMLVerifySuit.java Source code

Java tutorial

Introduction

Here is the source code for ml.shifu.shifu.core.pmml.PMMLVerifySuit.java

Source

/*
 * Copyright [2013-2015] PayPal Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.shifu.core.pmml;

import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.neural_network.NeuralNetworkEvaluator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ml.shifu.shifu.ShifuCLI;
import ml.shifu.shifu.container.obj.ModelTrainConf;
import ml.shifu.shifu.core.pmml.builder.creator.AbstractSpecifCreator;
import ml.shifu.shifu.core.pmml.builder.impl.NNSpecifCreator;
import ml.shifu.shifu.core.processor.ExportModelProcessor;

/**
 * Verifies that the PMML files exported by Shifu produce the same scores as Shifu's own evaluation.
 *
 * <p>
 * The suite copies the given model configuration, column configuration and models directory into the
 * current working directory, runs the Shifu evaluation and the PMML export, scores the evaluation data
 * with every exported PMML file and compares those scores against the Shifu {@code EvalScore} output.
 *
 * <p>
 * Created by zhanhu on 7/15/16.
 */
public class PMMLVerifySuit {

    private static final Logger logger = LoggerFactory.getLogger(PMMLVerifySuit.class);

    /** Orders output fields by their name so that multi-target scores are written in a stable order. */
    private static final Comparator<FieldName> FIELD_NAME_ORDER = new Comparator<FieldName>() {
        @Override
        public int compare(FieldName a, FieldName b) {
            return a.getValue().compareTo(b.getValue());
        }
    };

    private final String modelName;
    private final String modelConfigPath;
    private final String columnConfigpath;
    private final String modelsPath;
    private final ModelTrainConf.ALGORITHM algorithm;
    private final int modelCnt;
    private final String evalSetName;
    private final String evalDataPath;
    private final String delimiter;
    private final double scoreDiff;
    private final boolean isConcisePmml;

    /**
     * Creates a verification suite for NN models (the default algorithm).
     *
     * @param modelName
     *            prefix of the exported PMML file names ({@code pmmls/<modelName><index>.pmml})
     * @param modelConfigPath
     *            model configuration file, copied to {@code ModelConfig.json} in the working directory
     * @param columnConfigpath
     *            column configuration file, copied to {@code ColumnConfig.json} in the working directory
     * @param modelsPath
     *            directory with the trained models, copied to {@code models} in the working directory
     * @param modelCnt
     *            number of exported PMML files to verify (indexes {@code 0..modelCnt-1})
     * @param evalSetName
     *            name of the evaluation set to run
     * @param evalDataPath
     *            data file that is scored with the PMML evaluator
     * @param delimiter
     *            field delimiter handed to {@code CsvUtil.load} when reading {@code evalDataPath}
     * @param scoreDiff
     *            maximum tolerated absolute difference between a Shifu score and a PMML score
     * @param isConcisePmml
     *            value passed to the exporter as {@link ExportModelProcessor#IS_CONCISE}
     */
    public PMMLVerifySuit(String modelName, String modelConfigPath, String columnConfigpath, String modelsPath,
            int modelCnt, String evalSetName, String evalDataPath, String delimiter, double scoreDiff,
            boolean isConcisePmml) {
        this(modelName, modelConfigPath, columnConfigpath, modelsPath, ModelTrainConf.ALGORITHM.NN, modelCnt,
                evalSetName, evalDataPath, delimiter, scoreDiff, isConcisePmml);
    }

    /**
     * Creates a verification suite for the given algorithm. Only {@code NN} and {@code LR} are handled by
     * {@link #doVerification()}; see the other constructor for the meaning of the remaining parameters.
     *
     * @param algorithm
     *            algorithm of the models under verification
     */
    public PMMLVerifySuit(String modelName, String modelConfigPath, String columnConfigpath, String modelsPath,
            ModelTrainConf.ALGORITHM algorithm, int modelCnt, String evalSetName, String evalDataPath,
            String delimiter, double scoreDiff, boolean isConcisePmml) {
        this.modelName = modelName;
        this.modelConfigPath = modelConfigPath;
        this.columnConfigpath = columnConfigpath;
        this.modelsPath = modelsPath;
        this.algorithm = algorithm;
        this.modelCnt = modelCnt;
        this.evalSetName = evalSetName;
        this.evalDataPath = evalDataPath;
        this.delimiter = delimiter;
        this.scoreDiff = scoreDiff;
        this.isConcisePmml = isConcisePmml;
    }

    /**
     * Runs the Shifu evaluation and the PMML export, then scores the evaluation data with each exported
     * PMML file and compares the result with the Shifu scores.
     *
     * <p>
     * All files created in the working directory are removed before returning, whether the verification
     * passed, failed or threw.
     *
     * @return {@code true} if every model's PMML scores are within {@code scoreDiff} of the Shifu scores;
     *         {@code false} on the first mismatch or if the algorithm is not supported
     * @throws Exception
     *             if copying the inputs, running Shifu, or evaluating a PMML file fails
     */
    public boolean doVerification() throws Exception {
        File tmpModel = new File("ModelConfig.json");
        File tmpColumn = new File("ColumnConfig.json");
        File tmpModelsDir = new File("models");
        String outPath = "./pmml_out.dat";
        File pmmlOut = new File(outPath);

        // Step 1. Eval the scores using SHIFU
        // The copies stay outside the try block on purpose: if an input is missing nothing of ours has
        // been created yet, so the cleanup below must not touch files that may already exist.
        FileUtils.copyFile(new File(this.modelConfigPath), tmpModel);
        FileUtils.copyFile(new File(this.columnConfigpath), tmpColumn);
        FileUtils.copyDirectory(new File(this.modelsPath), tmpModelsDir);

        try {
            // run evaluation set
            ShifuCLI.runEvalScore(this.evalSetName, null);
            File evalScore = new File(
                    "evals" + File.separator + this.evalSetName + File.separator + "EvalScore");

            Map<String, Object> params = new HashMap<String, Object>();
            params.put(ExportModelProcessor.IS_CONCISE, this.isConcisePmml);
            ShifuCLI.exportModel(null, params);

            // Step 2. Eval the scores using PMML and compare it with SHIFU output
            for (int index = 0; index < modelCnt; index++) {
                String num = Integer.toString(index);
                String pmmlPath = "pmmls" + File.separator + this.modelName + num + ".pmml";
                String scoreName = "model" + num;

                try {
                    if (ModelTrainConf.ALGORITHM.NN.equals(algorithm)) {
                        evalNNPmml(pmmlPath, this.evalDataPath, outPath, this.delimiter, scoreName);
                    } else if (ModelTrainConf.ALGORITHM.LR.equals(algorithm)) {
                        evalLRPmml(pmmlPath, this.evalDataPath, outPath, this.delimiter, scoreName);
                    } else {
                        logger.error("The algorithm - {} is not supported yet.", algorithm);
                        return false;
                    }

                    if (!compareScore(evalScore, pmmlOut, scoreName, "\\|", this.scoreDiff)) {
                        return false;
                    }
                } finally {
                    // Never let the scores of one model leak into the next iteration or the next run.
                    FileUtils.deleteQuietly(pmmlOut);
                }
            }

            return true;
        } finally {
            // deleteQuietly is used throughout so that a cleanup problem cannot mask the real outcome.
            FileUtils.deleteQuietly(tmpModel);
            FileUtils.deleteQuietly(tmpColumn);
            FileUtils.deleteQuietly(tmpModelsDir);

            FileUtils.deleteQuietly(new File("./pmmls"));
            FileUtils.deleteQuietly(new File("evals"));
        }
    }

    /**
     * Compares, row by row, every score column of the PMML output with the column of the same name in the
     * Shifu output. The first line of each file is treated as the header.
     *
     * @param test
     *            Shifu score file
     * @param control
     *            PMML score file; its header decides which columns are compared
     * @param scoreName
     *            name of the model score; currently unused because all columns of {@code control} are
     *            compared
     * @param sep
     *            regular expression that separates the fields of both files
     * @param errorRange
     *            maximum tolerated absolute difference between two scores
     * @return {@code true} if all compared scores are within {@code errorRange}; {@code false} on the first
     *         mismatch, or if a file is empty, the Shifu file has fewer rows than the PMML file, or a
     *         compared column is missing from a row
     * @throws Exception
     *             if a file cannot be read or a score is not a valid number
     */
    private boolean compareScore(File test, File control, String scoreName, String sep, Double errorRange)
            throws Exception {
        List<String> testData = FileUtils.readLines(test);
        List<String> controlData = FileUtils.readLines(control);

        if (testData.isEmpty() || controlData.isEmpty()) {
            logger.error("Cannot compare scores, at least one of {} and {} is empty.", test, control);
            return false;
        }
        if (testData.size() < controlData.size()) {
            logger.error("The Shifu output has {} lines but the PMML output has {} lines.", testData.size(),
                    controlData.size());
            return false;
        }

        String[] testSchema = testData.get(0).trim().split(sep);
        String[] controlSchema = controlData.get(0).trim().split(sep);

        for (int row = 1; row < controlData.size(); row++) {
            Map<String, String> testRow = toRowMap(testSchema, testData.get(row).split(sep, testSchema.length));
            Map<String, String> controlRow = toRowMap(controlSchema,
                    controlData.get(row).split(sep, controlSchema.length));

            for (String realScoreName : controlSchema) {
                String controlValue = controlRow.get(realScoreName);
                String testValue = testRow.get(realScoreName);
                if (controlValue == null || testValue == null) {
                    logger.error("Row {} - The score column {} is missing.", row, realScoreName);
                    return false;
                }

                double controlScore = Double.parseDouble(controlValue);
                double testScore = Double.parseDouble(testValue);
                if (Math.abs(controlScore - testScore) > errorRange) {
                    logger.error("Row {} - The score {} doesn't match {} vs {}.",
                            new Object[] { row, realScoreName, controlScore, testScore });
                    return false;
                }
            }
        }

        return true;
    }

    /**
     * Maps each header name to the value at the same position. Positions that exist in only one of the two
     * arrays are skipped, so a short row simply lacks the trailing columns.
     */
    private static Map<String, String> toRowMap(String[] schema, String[] values) {
        Map<String, String> rowMap = new HashMap<String, String>();
        int columns = Math.min(schema.length, values.length);
        for (int index = 0; index < columns; index++) {
            rowMap.put(schema[index], values[index]);
        }
        return rowMap;
    }

    /**
     * Scores the data file with a neural network PMML and writes the scores to {@code outPath}.
     *
     * <p>
     * The output starts with a header line: {@code scoreName} for a single target field, otherwise
     * {@code scoreName_tag_<i>} for each target field joined by {@code |}. For regression models one line
     * per input record follows: the integer part of the final result for a single target, otherwise all
     * output fields whose name starts with the final-result prefix, sorted by name and joined by {@code |}.
     *
     * @param pmmlPath
     *            PMML file to load
     * @param dataPath
     *            data file to score
     * @param outPath
     *            file the scores are written to (UTF-8)
     * @param sep
     *            field delimiter of the data file
     * @param scoreName
     *            column name (or column name prefix) written to the header
     * @throws Exception
     *             if the PMML or the data cannot be loaded, or the evaluation fails
     */
    @SuppressWarnings("unchecked")
    private void evalNNPmml(String pmmlPath, String dataPath, String outPath, String sep, String scoreName)
            throws Exception {
        PMML pmml = PMMLUtils.loadPMML(pmmlPath);
        NeuralNetworkEvaluator evaluator = new NeuralNetworkEvaluator(pmml);

        PrintWriter writer = new PrintWriter(outPath, "UTF-8");
        try {
            List<TargetField> targetFields = evaluator.getTargetFields();
            boolean singleTarget = targetFields.size() == 1;
            if (singleTarget) {
                writer.println(scoreName);
            } else {
                for (int i = 0; i < targetFields.size(); i++) {
                    if (i > 0) {
                        writer.print("|");
                    }
                    writer.print(scoreName + "_tag_" + i);
                }
                writer.println();
            }

            List<Map<FieldName, FieldValue>> input = CsvUtil.load(evaluator, dataPath, sep);

            for (Map<FieldName, FieldValue> maps : input) {
                switch (evaluator.getModel().getMiningFunction()) {
                case REGRESSION:
                    Map<FieldName, Double> regressionTerm = (Map<FieldName, Double>) evaluator.evaluate(maps);
                    if (singleTarget) {
                        writer.println(
                                regressionTerm.get(new FieldName(AbstractSpecifCreator.FINAL_RESULT)).intValue());
                    } else {
                        List<FieldName> outputFieldList = new ArrayList<FieldName>(regressionTerm.keySet());
                        Collections.sort(outputFieldList, FIELD_NAME_ORDER);
                        int written = 0;
                        for (FieldName fieldName : outputFieldList) {
                            if (fieldName.getValue().startsWith(AbstractSpecifCreator.FINAL_RESULT)) {
                                if (written++ > 0) {
                                    writer.print("|");
                                }
                                writer.print(regressionTerm.get(fieldName));
                            }
                        }
                        writer.println();
                    }
                    break;
                case CLASSIFICATION:
                    // NOTE(review): these values go to stdout, not to the score file, so nothing is
                    // written for classification models - confirm whether that is intended.
                    Map<FieldName, Classification<Double>> classificationTerm = (Map<FieldName, Classification<Double>>) evaluator
                            .evaluate(maps);
                    for (Classification<Double> cMap : classificationTerm.values()) {
                        for (Map.Entry<String, Value<Double>> entry : cMap.getValues().entrySet()) {
                            System.out.println(entry.getValue().getValue() * 1000);
                        }
                    }
                    break;
                default:
                    break;
                }
            }
        } finally {
            // Close the file even if loading or evaluation fails.
            IOUtils.closeQuietly(writer);
        }
    }

    /**
     * Scores the data file with the first model of an LR PMML and writes the scores to {@code outPath}:
     * a header line with {@code scoreName} followed by the integer part of the final result of every
     * input record.
     *
     * @param pmmlPath
     *            PMML file to load
     * @param dataPath
     *            data file to score
     * @param outPath
     *            file the scores are written to (UTF-8)
     * @param sep
     *            field delimiter of the data file
     * @param scoreName
     *            column name written to the header
     * @throws Exception
     *             if the PMML or the data cannot be loaded, or the evaluation fails
     */
    @SuppressWarnings("unchecked")
    private void evalLRPmml(String pmmlPath, String dataPath, String outPath, String sep, String scoreName)
            throws Exception {
        PMML pmml = PMMLUtils.loadPMML(pmmlPath);
        Model m = pmml.getModels().get(0);
        ModelEvaluator<?> evaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml, m);

        PrintWriter writer = new PrintWriter(outPath, "UTF-8");
        try {
            writer.println(scoreName);
            List<Map<FieldName, FieldValue>> input = CsvUtil.load(evaluator, dataPath, sep);

            for (Map<FieldName, FieldValue> maps : input) {
                Map<FieldName, Double> regressionTerm = (Map<FieldName, Double>) evaluator.evaluate(maps);
                writer.println(regressionTerm.get(new FieldName(NNSpecifCreator.FINAL_RESULT)).intValue());
            }
        } finally {
            // Close the file even if loading or evaluation fails.
            IOUtils.closeQuietly(writer);
        }
    }
}