ml.shifu.dtrain.DTrainRequestProcessor.java Source code

Introduction

Here is the source code for ml.shifu.dtrain.DTrainRequestProcessor.java
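
The processor reads all training parameters from the request's Params map, validates them, and then dispatches either to the Guagua-based neural network trainer or to Shifu's tree ensemble (GBT/RF) training. Below is a minimal usage sketch, not part of this file: the wiring of ShifuRequest and Params is assumed, only a few of the keys checked by validateNNParams()/validateTreeParams() are shown, the paths are hypothetical placeholders, and the only parameter key values quoted literally in the source are 'shifu.dtrain.algorithm' and 'shifu.dtrain.zkservers'.

// Hypothetical driver code (sketch only); the caller supplies a fully populated ShifuRequest.
static int runDtrain(ShifuRequest request) throws Exception {
    Params params = request.getParams();
    params.put(DtrainConstants.SHIFU_DTRAIN_ALGORITHM, "nn");                 // or "gbt" / "rf"
    params.put(DtrainConstants.SHIFU_DTRAIN_MODEL_OUTPUT, "hdfs:///user/me/models");  // placeholder path
    params.put("pathFieldMeta", "/local/path/FieldMeta.json");                // local FieldMeta json read by exec()
    return new DTrainRequestProcessor().exec(request);
}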

Source

/**
 * Copyright [2012-2014] eBay Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.dtrain;

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.Thread.UncaughtExceptionHandler;
import java.net.URL;
import java.net.URLDecoder;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import ml.shifu.core.container.fieldMeta.Field;
import ml.shifu.core.container.fieldMeta.FieldBasics.OpType;
import ml.shifu.core.container.fieldMeta.FieldControl.UsageType;
import ml.shifu.core.container.fieldMeta.FieldMeta;
import ml.shifu.core.container.request.Params;
import ml.shifu.core.container.request.ShifuRequest;
import ml.shifu.core.di.spi.ShifuRequestProcessor;
import ml.shifu.core.util.Constants;
import ml.shifu.core.util.Environment;
import ml.shifu.core.util.JSONUtils;
import ml.shifu.dtrain.util.HDFSUtils;
import ml.shifu.guagua.GuaguaConstants;
import ml.shifu.guagua.mapreduce.GuaguaMapReduceClient;
import ml.shifu.guagua.mapreduce.GuaguaMapReduceConstants;
import ml.shifu.shifu.container.obj.ColumnBinning;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ColumnConfig.ColumnFlag;
import ml.shifu.shifu.container.obj.ColumnConfig.ColumnType;
import ml.shifu.shifu.container.obj.ColumnStats;
import ml.shifu.shifu.container.obj.ModelBasicConf;
import ml.shifu.shifu.container.obj.ModelBasicConf.RunMode;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.ModelSourceDataConf;
import ml.shifu.shifu.container.obj.ModelTrainConf;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.train.TrainStep;

import org.apache.commons.collections.ListUtils;
import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IOUtils;
import org.apache.zookeeper.ZooKeeper;
import org.encog.ml.data.MLDataSet;
import org.jboss.netty.bootstrap.ServerBootstrap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Splitter;

/**
 * A processor to process distributed training request on neural network algorithm.
 * 
 * <p>
 * {@link #runDistributedNNTrain(Params, String)} calls the guagua client to invoke d-training on a neural network.
 * Before that, all parameters should be prepared in the input json configuration files.
 * 
 * <p>
 * Guagua and MapReduce parameters can be configured in the input json file or the shifuconfig env file.
 * Configurations in json override configurations in shifuconfig.
 * 
 * <p>
 * To call this request processor, only hadoop distributed mode is supported; we suggest using our own bash train.sh
 * after you build a tar file with 'mvn package'.
 * 
 * <p>
 * Models are written into the configured HDFS output folders. After training jobs finish, remote models on HDFS are
 * copied into the local folder.
 * 
 * <p>
 * 'shifu.dtrain.zkservers' is optional; if zookeeper servers are not set, by default an embedded zookeeper server is
 * started in the master process. This feature is available since guagua 0.6.0.
 * 
 * <p>
 * Tree ensemble models are supported by calling the open source Shifu API: by constructing ModelConfig and
 * ColumnConfig from open source Shifu, GBT/RF training can be invoked directly.
 * 
 * @author Zhang David (pengzhang@paypal.com)
 */
public class DTrainRequestProcessor implements ShifuRequestProcessor {

    private static final Logger LOG = LoggerFactory.getLogger(DTrainRequestProcessor.class);

    @Override
    public int exec(ShifuRequest req) throws Exception {
        Params params = req.getParams();

        Configuration conf = new Configuration();

        FileSystem fs = FileSystem.get(conf);
        String pathModelsHdfs = params.getString(DtrainConstants.SHIFU_DTRAIN_MODEL_OUTPUT);
        String pathModelsLocal = params.getString("pathModelsLocal", pathModelsHdfs);
        Path hdfsModels = new Path(pathModelsHdfs);
        // Path localModels = new Path(pathModelsLocal);

        if (fs.exists(hdfsModels)) {
            fs.delete(hdfsModels, true);
        }

        File localModels = new File(pathModelsLocal);
        if (localModels.exists()) {
            localModels.delete();
        }

        String pathFieldMeta = params.getRequiredString("pathFieldMeta");
        FieldMeta fieldMeta = JSONUtils.readValue(new File(pathFieldMeta), FieldMeta.class);
        params.put(DtrainConstants.SHIFU_DTRAIN_NN_INPUT_NODES, fieldMeta.getActiveFields().size());
        LOG.info("    - # Active Fields: " + fieldMeta.getActiveFields().size());

        params.put(DtrainConstants.SHIFU_DTRAIN_NN_OUTPUT_NODES, fieldMeta.getTargetFields().size());
        LOG.info("    - # Target Fields: " + fieldMeta.getTargetFields().size());

        String name = req.getName();

        // if not set, by default it is nn model
        String algorithm = params.getString(DtrainConstants.SHIFU_DTRAIN_ALGORITHM, "nn");

        if ("nn".equalsIgnoreCase(algorithm)) {
            validateNNParams(params);
            runDistributedNNTrain(params, name);
            copyModelsToLocal(pathModelsHdfs, pathModelsLocal);
        } else if ("gbt".equalsIgnoreCase(algorithm) || "rf".equalsIgnoreCase(algorithm)) {
            validateTreeParams(params);

            ModelConfig modelConfig = buildModelConfig(params, name, algorithm);
            List<ColumnConfig> columnConfigList = buildColumnConfig(fieldMeta);
            Map<String, Object> otherConfigs = buildOtherConfigs(params);
            TrainStep step = new TrainStep(modelConfig, columnConfigList, otherConfigs);
            step.process();

            copyModelsToLocal(step.getPathFinder().getModelsPath(), pathModelsLocal);
            LOG.info("Local models can be found in {}", pathModelsLocal);

        } else {
            throw new IllegalArgumentException(
                    "Only 'nn' or 'gbt' training algorithms are supported in 'shifu.dtrain.algorithm'.");
        }

        return 0;
    }

    /**
     * Build Shifu ModelConfig from request parameters for tree model (GBT/RF) training.
     */
    private ModelConfig buildModelConfig(Params params, String name, String algorithm) {
        ModelConfig modelConfig = new ModelConfig();

        ModelBasicConf basic = new ModelBasicConf();
        basic.setName(name);
        basic.setRunMode(RunMode.DIST);
        basic.setVersion("2.0");
        basic.setAuthor(System.getProperty("user.name"));
        modelConfig.setBasic(basic);

        ModelSourceDataConf dataSet = new ModelSourceDataConf();
        dataSet.setDataPath(params.getString(DtrainConstants.SHIFU_DTRAIN_INPUT));
        dataSet.setDataDelimiter(params.getString(DtrainConstants.SHIFU_DTRAIN_INPUT_DELIMETER));
        dataSet.setSource(SourceType.HDFS);
        dataSet.setTargetColumnName(params.getString(DtrainConstants.SHIFU_DTRAIN_TARGET_COLUMN_NAME));

        dataSet.setPosTags(getTagList(params.getString(DtrainConstants.SHIFU_DTRAIN_POSITIVE_TAGS)));
        dataSet.setNegTags(getTagList(params.getString(DtrainConstants.SHIFU_DTRAIN_NEGATIVE_TAGS)));
        modelConfig.setDataSet(dataSet);

        ModelTrainConf trainConf = new ModelTrainConf();
        trainConf.setAlgorithm(algorithm);
        trainConf.setBaggingNum(Integer.parseInt(params.getString(DtrainConstants.SHIFU_DTRAIN_BAGGING_NUM, "1")));
        trainConf.setBaggingSampleRate(
                Double.parseDouble(params.getString(DtrainConstants.SHIFU_DTRAIN_BAGGING_SAMPLE_RATE)));
        trainConf.setBaggingWithReplacement(
                Boolean.parseBoolean(params.getString(DtrainConstants.SHIFU_DTRAIN_IS_BAGGING_WITH_REPLACEMENT)));
        trainConf.setValidSetRate(
                Double.parseDouble(params.getString(DtrainConstants.SHIFU_DTRAIN_CROSS_VALIDATION_RATE)));

        trainConf.setNumTrainEpochs(Integer.parseInt(params.get(DtrainConstants.SHIFU_DTRAIN_EPOCH).toString()));
        // TODO hard-coded values below
        trainConf.setIsContinuous(false);
        trainConf.setWorkerThreadCount(4);

        Map<String, Object> trainParams = new HashMap<String, Object>();
        trainParams.put("TreeNum", Integer.parseInt(params.getString(DtrainConstants.SHIFU_DTRAIN_TREE_TREENUM)));
        trainParams.put("FeatureSubsetStrategy",
                params.getString(DtrainConstants.SHIFU_DTRAIN_TREE_FEATURESUBSETSTRATEGY));
        trainParams.put("MaxDepth", Integer.parseInt(params.getString(DtrainConstants.SHIFU_DTRAIN_TREE_MAXDEPTH)));
        trainParams.put("Impurity", params.getString(DtrainConstants.SHIFU_DTRAIN_TREE_IMPURITY));
        trainParams.put("LearningRate",
                Double.parseDouble(params.getString(DtrainConstants.SHIFU_DTRAIN_TREE_LEARNINGRATE)));
        trainParams.put("MinInstancesPerNode",
                Integer.parseInt(params.getString(DtrainConstants.SHIFU_DTRAIN_TREE_MININSTANCESPERNODE)));
        trainParams.put("MinInfoGain",
                Double.parseDouble(params.getString(DtrainConstants.SHIFU_DTRAIN_TREE_MININFOGAIN)));
        trainParams.put("Loss", params.getString(DtrainConstants.SHIFU_DTRAIN_TREE_LOSS));

        trainConf.setParams(trainParams);

        modelConfig.setTrain(trainConf);
        return modelConfig;
    }

    private List<String> getTagList(String str) {
        List<String> list = new ArrayList<String>();
        for (String tag : str.split(",")) {
            list.add(tag.trim());
        }
        return list;
    }

    /**
     * Other configs like pig scripts folder and jar folder; currently no extra configs are set.
     */
    private Map<String, Object> buildOtherConfigs(Params params) {
        return new HashMap<String, Object>();
    }

    /**
     * Build ColumnConfig list from FieldMeta.
     */
    private List<ColumnConfig> buildColumnConfig(FieldMeta fieldMeta) {
        List<ColumnConfig> columnConfigList = new ArrayList<ColumnConfig>();
        List<Field> fields = fieldMeta.getFields();
        for (Field field : fields) {
            ColumnConfig columnConfig = new ColumnConfig();
            columnConfig.setColumnNum(field.getFieldBasics().getNum() - 1);
            columnConfig.setColumnName(field.getFieldBasics().getName());

            ColumnBinning binning = new ColumnBinning();
            OpType opType = field.getFieldBasics().getOpType();
            switch (opType) {
            case CATEGORICAL:
                columnConfig.setColumnType(ColumnType.C);
                if (field.getFieldControl().getUsageType() == UsageType.ACTIVE) {
                    List<String> categories = field.getFieldStats().getDiscreteStats().getCategories();
                    // first one is missing value category, remove it
                    List<String> subList = categories.subList(1, categories.size());
                    binning.setBinCategory(subList);
                    binning.setLength(subList.size() + 1);
                }

                break;
            case CONTINUOUS:
            default:
                columnConfig.setColumnType(ColumnType.N);
                if (field.getFieldControl().getUsageType() == UsageType.ACTIVE) {
                    List<Double> binBoundaries = field.getFieldStats().getContinuousStats().getBinBoundaries();
                    // first one is the missing value bin, no need to set it
                    binning.setBinBoundary(binBoundaries.subList(1, binBoundaries.size()));
                    binning.setLength(binning.getBinBoundary().size() + 1);
                }

                break;
            }

            columnConfig.setColumnFlag(
                    field.getFieldControl().getUsageType() == UsageType.TARGET ? ColumnFlag.Target : null);

            columnConfig.setFinalSelect(field.getFieldControl().getUsageType() == UsageType.ACTIVE);

            columnConfig.setColumnBinning(binning);

            ColumnStats stats = new ColumnStats();
            if (field.getFieldControl().getUsageType() == UsageType.ACTIVE) {
                stats.setMean(field.getFieldStats().getContinuousStats().getMean());
                stats.setMean(stats.getMean() == null ? 0d : stats.getMean());
            } else {
                stats.setMean(0d);
            }
            columnConfig.setColumnStats(stats);

            columnConfigList.add(columnConfig);
        }
        return columnConfigList;
    }
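
    // Illustrative example of the bin layout assumed above: if FieldMeta reports categories
    // [<missing>, "a", "b", "c"] for an active categorical field, binCategory becomes
    // ["a", "b", "c"] and the binning length is 4 (the extra slot accounts for the missing-value bin).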

    private void copyModelsToLocal(String fromHdfs, String toLocal) throws IOException {
        FileSystem fs = FileSystem.get(new Configuration());

        Path hdfsModels = new Path(fromHdfs);

        File localModels = new File(toLocal);
        // delete recursively
        FileUtils.deleteQuietly(localModels);

        if (fs.exists(hdfsModels)) {
            fs.copyToLocalFile(hdfsModels, new Path(toLocal));
            LOG.info("Copying models to local: " + toLocal);
        } else {
            LOG.error("Models not found on HDFS: " + toLocal);
        }
    }

    private void validateTreeParams(Params params) {
        validateInt(params, DtrainConstants.SHIFU_DTRAIN_TREE_TREENUM);
        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TREE_TREENUM);

        // TODO validate SHIFU_DTRAIN_TREE_FEATURESUBSETSTRATEGY in list of choices
        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TREE_FEATURESUBSETSTRATEGY);

        validateInt(params, DtrainConstants.SHIFU_DTRAIN_TREE_MAXDEPTH);
        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TREE_MAXDEPTH);

        // TODO validate SHIFU_DTRAIN_TREE_IMPURITY in list of choices
        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TREE_IMPURITY);

        validateDoubleAndRange(params, DtrainConstants.SHIFU_DTRAIN_TREE_LEARNINGRATE, Double.NEGATIVE_INFINITY,
                Double.POSITIVE_INFINITY);
        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TREE_LEARNINGRATE);

        validateInt(params, DtrainConstants.SHIFU_DTRAIN_TREE_MININSTANCESPERNODE);
        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TREE_MININSTANCESPERNODE);

        validateDoubleAndRange(params, DtrainConstants.SHIFU_DTRAIN_TREE_MININFOGAIN, Double.NEGATIVE_INFINITY,
                Double.POSITIVE_INFINITY);
        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TREE_MININFOGAIN);

        // TODO validate SHIFU_DTRAIN_TREE_LOSS in list of choices
        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TREE_LOSS);

        // TODO validate SHIFU_DTRAIN_TARGET_COLUMN_NAME in list of choices
        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TARGET_COLUMN_NAME);

        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_NEGATIVE_TAGS);
        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_POSITIVE_TAGS);

        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_EPOCH);
        validateInt(params, DtrainConstants.SHIFU_DTRAIN_EPOCH);

        validateBoolean(params, DtrainConstants.SHIFU_DTRAIN_IS_TRAIN_ON_DISK);
        validateBoolean(params, DtrainConstants.SHIFU_DTRAIN_IS_BAGGING_WITH_REPLACEMENT);
        validateBoolean(params, DtrainConstants.SHIFU_DTRAIN_IS_FIX_INITIAL_INPUT);
        validateBoolean(params, DtrainConstants.SHIFU_DTRAIN_PARALLEL);

        validateDoubleAndRange(params, DtrainConstants.SHIFU_DTRAIN_CROSS_VALIDATION_RATE, 0.0d, 1.0d);
        validateDoubleAndRange(params, DtrainConstants.SHIFU_DTRAIN_BAGGING_SAMPLE_RATE, 0.0d, 1.0d);

        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_INPUT);
        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_INPUT_DELIMETER);
        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_ALGORITHM);
    }

    private void validateNNParams(Params params) {
        validateBoolean(params, DtrainConstants.SHIFU_DTRAIN_IS_TRAIN_ON_DISK);
        validateBoolean(params, DtrainConstants.SHIFU_DTRAIN_IS_BAGGING_WITH_REPLACEMENT);
        validateBoolean(params, DtrainConstants.SHIFU_DTRAIN_IS_FIX_INITIAL_INPUT);
        validateBoolean(params, DtrainConstants.SHIFU_DTRAIN_PARALLEL);

        validateDoubleAndRange(params, DtrainConstants.SHIFU_DTRAIN_CROSS_VALIDATION_RATE, 0.0d, 1.0d);
        validateDoubleAndRange(params, DtrainConstants.SHIFU_DTRAIN_BAGGING_SAMPLE_RATE, 0.0d, 1.0d);
        validateDoubleAndRange(params, DtrainConstants.SHIFU_DTRAIN_NN_LEARNING_RATE, 0.0d, Double.MAX_VALUE);

        validateInt(params, DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_LAYERS);
        validateInt(params, DtrainConstants.SHIFU_DTRAIN_NN_INPUT_NODES);
        validateInt(params, DtrainConstants.SHIFU_DTRAIN_NN_OUTPUT_NODES);

        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_EPOCH);
        validateInt(params, DtrainConstants.SHIFU_DTRAIN_EPOCH);

        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_BAGGING_NUM);
        validateInt(params, DtrainConstants.SHIFU_DTRAIN_BAGGING_NUM);

        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_INPUT);

        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_MODEL_OUTPUT);

        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TMP_MODEL_OUTPUT);

        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_NN_PROPAGATION);
        String propagation = params.get(DtrainConstants.SHIFU_DTRAIN_NN_PROPAGATION).toString();
        if (!"Q".equalsIgnoreCase(propagation) && !"B".equalsIgnoreCase(propagation)
                && !"R".equalsIgnoreCase(propagation) && !"M".equalsIgnoreCase(propagation)
                && !"S".equalsIgnoreCase(propagation)) {
            throw new IllegalArgumentException(
                    String.format("%s, should one of 'Q R M S B'", DtrainConstants.SHIFU_DTRAIN_NN_PROPAGATION));
        }

        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_NN_ACT_FUNCS);
        validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_NODES);
        String actFuncs = params.get(DtrainConstants.SHIFU_DTRAIN_NN_ACT_FUNCS).toString();
        String hiddenNodes = params.get(DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_NODES).toString();
        if (actFuncs.split(",").length != hiddenNodes.split(",").length) {
            throw new IllegalArgumentException(String.format("%s and %s should have same length.",
                    DtrainConstants.SHIFU_DTRAIN_NN_ACT_FUNCS, DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_NODES));
        }
    }

    private void validateBoolean(Params params, String str) {
        Object object = params.get(str);
        if (object != null) {
            String value = object.toString();
            // Boolean.parseBoolean never throws NumberFormatException, so check the literal value explicitly
            if (!"true".equalsIgnoreCase(value) && !"false".equalsIgnoreCase(value)) {
                throw new IllegalArgumentException(String.format("%s should be a boolean field.", str));
            }
        }
    }

    private void validateDoubleAndRange(Params params, String str, Double start, Double end) {
        Object object = params.get(str);
        if (object != null) {
            try {
                Double dValue = Double.valueOf(object.toString());
                if (Double.compare(dValue, start) < 0 || Double.compare(dValue, end) > 0) {
                    throw new IllegalArgumentException(
                            String.format("%s should be in range (%s, %s).", str, start, end));
                }
            } catch (NumberFormatException e) {
                throw new IllegalArgumentException(String.format("%s should be a double field.", str));
            }
        }
    }

    private void validateInt(Params params, String str) {
        Object object = params.get(str);
        if (object != null) {
            try {
                Integer.parseInt(object.toString());
            } catch (NumberFormatException e) {
                throw new IllegalArgumentException(String.format("%s should be a int field.", str));
            }
        }
    }

    private void validateNotNull(Params params, String str) {
        Object object = params.get(str);
        if (object == null) {
            throw new IllegalArgumentException(String.format("%s should not be null.", str));
        }
    }

    /**
     * Find a jar that contains a class of the same name, if any. It will return a jar file, even if that is not the
     * first thing on the class path that has a class with the same name.
     * 
     * @param myClass
     *            the class to find
     * @return a jar file that contains the class, or null
     */
    @SuppressWarnings("rawtypes")
    private static String findContainingJar(Class myClass) {
        ClassLoader loader = myClass.getClassLoader();
        String classFile = myClass.getName().replaceAll("\\.", "/") + ".class";
        try {
            for (Enumeration itr = loader.getResources(classFile); itr.hasMoreElements();) {
                URL url = (URL) itr.nextElement();
                if ("jar".equals(url.getProtocol())) {
                    String toReturn = url.getPath();
                    if (toReturn.startsWith("file:")) {
                        toReturn = toReturn.substring("file:".length());
                    }
                    // URLDecoder is a misnamed class, since it actually decodes
                    // x-www-form-urlencoded MIME type rather than actual URL
                    // encoding (which the file path has). Therefore it would
                    // decode +s to ' 's, which is incorrect (spaces are actually
                    // either unencoded or encoded as "%20"). Replace +s first,
                    // so that they are kept intact during the decoding process.
                    toReturn = toReturn.replaceAll("\\+", "%2B");
                    toReturn = URLDecoder.decode(toReturn, "UTF-8");
                    return toReturn.replaceAll("!.*$", "");
                } else if ("file".equals(url.getProtocol())) {
                    String toReturn = url.getPath();
                    toReturn = toReturn.replaceAll("\\+", "%2B");
                    toReturn = URLDecoder.decode(toReturn, "UTF-8");
                    return toReturn.replaceAll("!.*$", "");
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return null;
    }
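
    // For example, findContainingJar(ObjectMapper.class) typically returns a local path such as
    // ".../jackson-databind-<version>.jar"; the exact location and version depend on the runtime classpath.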

    /**
     * GuaguaOptionsParser doesn't support *.jar wildcards currently. We need to find all jars used in guagua
     * MapReduce jobs.
     */
    private void addRuntimeJars(final List<String> args) {
        List<String> jars = new ArrayList<String>(16);
        // jackson-databind-*.jar
        jars.add(findContainingJar(ObjectMapper.class));
        // jackson-core-*.jar
        jars.add(findContainingJar(JsonParser.class));
        // jackson-annotations-*.jar
        jars.add(findContainingJar(JsonIgnore.class));
        // commons-compress-*.jar
        jars.add(findContainingJar(BZip2CompressorInputStream.class));
        // commons-lang-*.jar
        jars.add(findContainingJar(StringUtils.class));
        // commons-collections-*.jar
        jars.add(findContainingJar(ListUtils.class));
        // commons-io-*.jar
        jars.add(findContainingJar(org.apache.commons.io.IOUtils.class));
        // guava-*.jar
        jars.add(findContainingJar(Splitter.class));
        // encog-core-*.jar
        jars.add(findContainingJar(MLDataSet.class));
        // shifu-*.jar
        jars.add(findContainingJar(getClass()));
        // guagua-core-*.jar
        jars.add(findContainingJar(GuaguaConstants.class));
        // guagua-mapreduce-*.jar
        jars.add(findContainingJar(GuaguaMapReduceConstants.class));
        // zookeeper-*.jar
        jars.add(findContainingJar(ZooKeeper.class));
        // netty-*.jar
        jars.add(findContainingJar(ServerBootstrap.class));

        args.add(StringUtils.join(jars, DtrainConstants.LIB_JAR_SEPARATOR));
    }
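
    // The joined string becomes the value of the "-libjars" argument added in prepareCommonParams().
    // Illustrative form: /path/jackson-databind-x.jar,/path/guagua-core-x.jar,... assuming
    // DtrainConstants.LIB_JAR_SEPARATOR is the comma that Hadoop's -libjars option expects.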

    /**
     * Read common parameters, including parameters from the shifuconfig file.
     */
    private void prepareCommonParams(final List<String> args, final Params params) {
        args.add("-libjars");
        addRuntimeJars(args);

        args.add("-i");
        args.add(HDFSUtils.getFS()
                .makeQualified(new Path(params.get(DtrainConstants.SHIFU_DTRAIN_INPUT).toString())).toString());

        Object zkServersObj = params.get(DtrainConstants.SHIFU_DTRAIN_ZKSERVERS);
        String zkServers = zkServersObj == null ? "" : zkServersObj.toString();
        if (StringUtils.isEmpty(zkServers)) {
            zkServers = Environment.getConfig(Environment.KEY_SHIFU_ZKSERVERS);
            if (StringUtils.isEmpty(zkServers)) {
                LOG.warn(
                        "No zookeeper servers specified via zookeeperServers in shifuConfig file, Guagua will start "
                                + "an embedded zookeeper server in the client process or master process.");
            } else {
                args.add("-z");
                args.add(zkServers);
            }
        } else {
            args.add("-z");
            args.add(zkServers);
        }

        args.add("-w");
        args.add(NNWorker.class.getName());

        args.add("-m");
        args.add(NNMaster.class.getName());

        args.add("-c");
        // the reason to add 1 is that the first iteration in D-NN
        // implementation is used for training preparation.
        int numTrainEpochs = Integer.parseInt(params.get(DtrainConstants.SHIFU_DTRAIN_EPOCH).toString()) + 1;
        args.add(String.valueOf(numTrainEpochs));
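        // e.g. with SHIFU_DTRAIN_EPOCH set to 100 this adds "-c 101": one preparation
        // iteration plus 100 training iterations.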

        args.add("-mr");
        args.add(NNParams.class.getName());

        args.add("-wr");
        args.add(NNParams.class.getName());

        // test
        // args.add("-Dguagua.worker.number=10");
        // test

        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.MAPRED_JOB_QUEUE_NAME,
                params.getString("queueName", Constants.DEFAULT_JOB_QUEUE)));
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_MASTER_INTERCEPTERS,
                NNOutput.class.getName()));
        // hard code computation threshold to 60s; setting it in shifuconfig or
        // json can override this value.
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT,
                GuaguaConstants.GUAGUA_COMPUTATION_TIME_THRESHOLD, 60 * 1000L));
        setHeapSizeAndSplitSize(args);
        args.add("-Ddfs.mapred.max.split.size=268435456");

        // special tuning parameters for shifu: 0.99 means in each iteration the master
        // waits for 99% of workers before moving to the next iteration.
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_MIN_WORKERS_RATIO,
                0.99));
        // 20 seconds: if waiting over 20 seconds, consider the 99% of workers sufficient.
        // These two settings can be overridden in shifuconfig.
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_MIN_WORKERS_TIMEOUT,
                20 * 1000L));

        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_NODES,
                params.get(DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_NODES)));
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_NN_ACT_FUNCS,
                params.get(DtrainConstants.SHIFU_DTRAIN_NN_ACT_FUNCS)));
        args.add(
                String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_LAYERS,
                        params.get(DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_LAYERS)));
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_NN_OUTPUT_NODES,
                params.get(DtrainConstants.SHIFU_DTRAIN_NN_OUTPUT_NODES)));
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_NN_INPUT_NODES,
                params.get(DtrainConstants.SHIFU_DTRAIN_NN_INPUT_NODES)));
        args.add(
                String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_NN_LEARNING_RATE,
                        params.get(DtrainConstants.SHIFU_DTRAIN_NN_LEARNING_RATE)));
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_NN_PROPAGATION,
                params.get(DtrainConstants.SHIFU_DTRAIN_NN_PROPAGATION)));
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT,
                DtrainConstants.SHIFU_DTRAIN_CROSS_VALIDATION_RATE,
                params.get(DtrainConstants.SHIFU_DTRAIN_CROSS_VALIDATION_RATE)));
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT,
                DtrainConstants.SHIFU_DTRAIN_IS_BAGGING_WITH_REPLACEMENT,
                params.get(DtrainConstants.SHIFU_DTRAIN_IS_BAGGING_WITH_REPLACEMENT)));
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT,
                DtrainConstants.SHIFU_DTRAIN_IS_FIX_INITIAL_INPUT,
                params.get(DtrainConstants.SHIFU_DTRAIN_IS_FIX_INITIAL_INPUT)));
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT,
                DtrainConstants.SHIFU_DTRAIN_BAGGING_SAMPLE_RATE,
                params.get(DtrainConstants.SHIFU_DTRAIN_BAGGING_SAMPLE_RATE)));
        args.add(
                String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_IS_TRAIN_ON_DISK,
                        params.get(DtrainConstants.SHIFU_DTRAIN_IS_TRAIN_ON_DISK)));
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT,
                GuaguaConstants.GUAGUA_ZK_EMBEDBED_IS_IN_CLIENT, "false"));
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.NN_DATA_DELIMITER,
                params.getString("delimiter", "|")));

        // one can set guagua conf in shifuconfig
        for (Map.Entry<Object, Object> entry : Environment.getConfig().entrySet()) {
            if (entry.getKey().toString().startsWith("nn") || entry.getKey().toString().startsWith("guagua")
                    || entry.getKey().toString().startsWith("mapred")) {
                args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, entry.getKey().toString(),
                        entry.getValue().toString()));
            }
        }

        // one can set guagua conf in the current configuration files
        for (Map.Entry<String, Object> entry : params.entrySet()) {
            if (isHadoopConfigurationInjected(entry.getKey())) {
                args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, entry.getKey(),
                        entry.getValue().toString()));
            }
        }
    }
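
    // The assembled argument list roughly mirrors a guagua command line (illustrative only; the exact
    // -D key strings come from DtrainConstants/GuaguaConstants and are not spelled out here):
    //   -libjars a.jar,b.jar -i hdfs://.../input [-z zk1:2181,...] -w NNWorker -m NNMaster
    //   -c <epochs + 1> -mr NNParams -wr NNParams -D<key>=<value> ...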

    public static boolean isHadoopConfigurationInjected(String key) {
        return key.startsWith("nn") || key.startsWith("guagua") || key.startsWith("shifu")
                || key.startsWith("mapred") || key.startsWith("io") || key.startsWith("hadoop")
                || key.startsWith("yarn");
    }
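
    // For example, keys like "mapred.job.queue.name" or "guagua.worker.number" are injected into the
    // job configuration as -D options, while a key such as "pathFieldMeta" matches none of the
    // prefixes above and is kept out of the job configuration.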

    private void setHeapSizeAndSplitSize(final List<String> args) {
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT,
                GuaguaMapReduceConstants.MAPRED_CHILD_JAVA_OPTS, "-Xmn128m -Xms1G -Xmx1G"));
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_SPLIT_COMBINABLE,
                Environment.getConfig(GuaguaConstants.GUAGUA_SPLIT_COMBINABLE, "true")));
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT,
                GuaguaConstants.GUAGUA_SPLIT_MAX_COMBINED_SPLIT_SIZE,
                Environment.getConfig(GuaguaConstants.GUAGUA_SPLIT_MAX_COMBINED_SPLIT_SIZE, "536870912")));
    }

    /**
     * Transform json configuration parameters into guagua configurations and then run guagua jobs to invoke
     * distributed neural network training.
     */
    protected void runDistributedNNTrain(Params params, String name)
            throws IOException, InterruptedException, ClassNotFoundException {
        LOG.info("Start distributed training.");

        final List<String> args = new ArrayList<String>();

        prepareCommonParams(args, params);

        // add tmp models folder to config
        Path tmpModelsPath = HDFSUtils.getFS().makeQualified(
                new Path(params.get(DtrainConstants.SHIFU_DTRAIN_TMP_MODEL_OUTPUT, "tmp").toString()));
        args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.NN_TMP_MODELS_FOLDER,
                tmpModelsPath.toString()));
        int baggingNum = Integer.parseInt(params.get(DtrainConstants.SHIFU_DTRAIN_BAGGING_NUM).toString());

        long start = System.currentTimeMillis();
        LOG.info("Distributed training with baggingNum: {}", baggingNum);
        boolean isParallel = params.getBoolean(DtrainConstants.SHIFU_DTRAIN_PARALLEL, true);
        GuaguaMapReduceClient guaguaClient = new GuaguaMapReduceClient();
        List<String> progressLogList = new ArrayList<String>(baggingNum);
        for (int i = 0; i < baggingNum; i++) {
            List<String> localArgs = new ArrayList<String>(args);
            // set name for each bagging job.
            localArgs.add("-n");
            localArgs.add(String.format("Shifu Master-Workers NN Iteration: %s id:%s", name, i + 1));
            LOG.info("Start trainer with id: {}", i + 1);
            String modelName = getModelName(i + 1);
            Path modelPath = HDFSUtils.getFS().makeQualified(
                    new Path(params.get(DtrainConstants.SHIFU_DTRAIN_MODEL_OUTPUT, "tmp").toString(), modelName));
            localArgs.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.GUAGUA_NN_OUTPUT,
                    modelPath.toString()));
            localArgs.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.NN_TRAINER_ID,
                    String.valueOf(i + 1)));
            final String progressLogFile = getProgressLogFile(i + 1);
            progressLogList.add(progressLogFile);
            localArgs.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.NN_PROGRESS_FILE,
                    progressLogFile));

            if (isParallel) {
                guaguaClient.addJob(localArgs.toArray(new String[0]));
            } else {
                TailThread tailThread = startTailThread(new String[] { progressLogFile });
                guaguaClient.createJob(localArgs.toArray(new String[0])).waitForCompletion(true);
                stopTailThread(tailThread);
            }
        }

        if (isParallel) {
            TailThread tailThread = startTailThread(progressLogList.toArray(new String[0]));
            guaguaClient.run();
            stopTailThread(tailThread);
        }

        LOG.info("Distributed training finished in {}ms.", System.currentTimeMillis() - start);
    }
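
    // For example, with SHIFU_DTRAIN_BAGGING_NUM set to 3, three guagua jobs are launched and models
    // model1.nn, model2.nn and model3.nn are written under SHIFU_DTRAIN_MODEL_OUTPUT on HDFS,
    // before exec() copies them to the local folder.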

    private void stopTailThread(TailThread thread) throws IOException {
        thread.interrupt();
        try {
            thread.join(DtrainConstants.DEFAULT_JOIN_TIME);
        } catch (InterruptedException e) {
            LOG.error("Thread stopped!", e);
            Thread.currentThread().interrupt();
        }
        // delete progress files at the end
        thread.deleteProgressFiles();
    }

    private TailThread startTailThread(final String[] progressLog) {
        TailThread thread = new TailThread(progressLog);
        thread.setName("Tail Progress Thread");
        thread.setDaemon(true);
        thread.setUncaughtExceptionHandler(new UncaughtExceptionHandler() {
            @Override
            public void uncaughtException(Thread t, Throwable e) {
                LOG.warn(String.format("Error in thread %s: %s", t.getName(), e.getMessage()));
            }
        });
        thread.start();
        return thread;
    }

    private String getProgressLogFile(int i) {
        return String.format("tmp/%s_%s.log", System.currentTimeMillis(), i);
    }

    /**
     * Get NN model name by trainer index.
     * 
     * @param i
     *            index for model name
     * @return model file name such as "model1.nn"
     */
    private static String getModelName(int i) {
        return String.format("model%s.nn", i);
    }

    /**
     * A thread used to tail progress log from hdfs log file.
     */
    private static class TailThread extends Thread {
        private long[] offset;
        private String[] progressLogs;

        public TailThread(String[] progressLogs) {
            this.progressLogs = progressLogs;
            this.offset = new long[this.progressLogs.length];
            for (String progressLog : progressLogs) {
                try {
                    // delete it first; it will be recreated and updated by the master
                    HDFSUtils.getFS().delete(new Path(progressLog), true);
                } catch (IOException e) {
                    LOG.error("Error in delete progressLog", e);
                }
            }
        }

        @Override
        public void run() {
            while (!Thread.currentThread().isInterrupted()) {
                for (int i = 0; i < this.progressLogs.length; i++) {
                    try {
                        this.offset[i] = dumpFromOffset(new Path(this.progressLogs[i]), this.offset[i]);
                    } catch (FileNotFoundException e) {
                        // ignore because the progress file has not been created yet
                    } catch (IOException e) {
                        LOG.warn(String.format("Error in dump progress log %s: %s", getName(), e.getMessage()));
                    }
                }

                try {
                    Thread.sleep(2000);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }
            LOG.debug("DEBUG: Exit from tail thread.");
        }

        private long dumpFromOffset(Path item, long offset) throws IOException {
            FSDataInputStream in = HDFSUtils.getFS().open(item);
            ByteArrayOutputStream out = null;
            DataOutputStream dataOut = null;
            try {
                out = new ByteArrayOutputStream();
                dataOut = new DataOutputStream(out);
                in.seek(offset);
                // use conf so the system configured io block size is used
                IOUtils.copyBytes(in, out, HDFSUtils.getFS().getConf(), false);
                String msgs = new String(out.toByteArray(), Charset.forName("UTF-8")).trim();
                if (StringUtils.isNotEmpty(msgs)) {
                    for (String msg : Splitter.on('\n').split(msgs)) {
                        LOG.info(msg.trim());
                    }
                }
                offset = in.getPos();
            } catch (IOException e) {
                if (!e.getMessage().contains("Cannot seek after EOF")) {
                    throw e;
                } else {
                    LOG.warn(e.getMessage());
                }
            } finally {
                IOUtils.closeStream(in);
                IOUtils.closeStream(dataOut);
            }
            return offset;
        }
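
        // The returned position is fed back in as the next poll's offset, so each pass only
        // logs bytes appended to the progress file since the previous pass.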

        public void deleteProgressFiles() throws IOException {
            for (String progressFile : this.progressLogs) {
                HDFSUtils.getFS().delete(new Path(progressFile), true);
            }
        }
    }

}