Java tutorial
/** * Copyright [2012-2014] eBay Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.dtrain; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; import java.lang.Thread.UncaughtExceptionHandler; import java.net.URL; import java.net.URLDecoder; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Enumeration; import java.util.HashMap; import java.util.List; import java.util.Map; import ml.shifu.core.container.fieldMeta.Field; import ml.shifu.core.container.fieldMeta.FieldBasics.OpType; import ml.shifu.core.container.fieldMeta.FieldControl.UsageType; import ml.shifu.core.container.fieldMeta.FieldMeta; import ml.shifu.core.container.request.Params; import ml.shifu.core.container.request.ShifuRequest; import ml.shifu.core.di.spi.ShifuRequestProcessor; import ml.shifu.core.util.Constants; import ml.shifu.core.util.Environment; import ml.shifu.core.util.JSONUtils; import ml.shifu.dtrain.util.HDFSUtils; import ml.shifu.guagua.GuaguaConstants; import ml.shifu.guagua.mapreduce.GuaguaMapReduceClient; import ml.shifu.guagua.mapreduce.GuaguaMapReduceConstants; import ml.shifu.shifu.container.obj.ColumnBinning; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ColumnConfig.ColumnFlag; import ml.shifu.shifu.container.obj.ColumnConfig.ColumnType; import ml.shifu.shifu.container.obj.ColumnStats; import ml.shifu.shifu.container.obj.ModelBasicConf; import ml.shifu.shifu.container.obj.ModelBasicConf.RunMode; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.ModelSourceDataConf; import ml.shifu.shifu.container.obj.ModelTrainConf; import ml.shifu.shifu.container.obj.RawSourceData.SourceType; import ml.shifu.train.TrainStep; import org.apache.commons.collections.ListUtils; import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream; import org.apache.commons.io.FileUtils; import org.apache.commons.lang.StringUtils; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IOUtils; import org.apache.zookeeper.ZooKeeper; import org.encog.ml.data.MLDataSet; import org.jboss.netty.bootstrap.ServerBootstrap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Splitter; /** * A processor to process distributed training request on neural network algorithm. * * <p> * {@link #runDistributedNNTrain(Params, String)} calls guagua client to invoke d-training on neural network. Before * that, all parameters should be prepared well in input json configurations files. * * <p> * Guagua and MapReduce parameters can be configured in input json file or shifuconfig env file. Configurations in json * will override configurations in shifuconfig. * * <p> * To call this request processor, only hadoop distributed mode are supported, suggest using our own bash train.sh when * you build a tar file by using 'mvn package'. * * <p> * Models are written into HDFS setting output folders. After training jobs finished, remote models in HDFS are copied * into local folder. * * <p> * 'shifu.dtrain.zkservers' is option, if zookeeper servers are not set, by default an embedded zookeeper server in * master process will be started. THe feature is started from guagua 0.6.0. * * <p> * Tree ensemble model is added by call open source Shifu API. By constructing ModelConfig and ColumnConfig in open * source Shifu, call GBT training smoothly. * * @author Zhang David (pengzhang@paypal.com) */ public class DTrainRequestProcessor implements ShifuRequestProcessor { private static final Logger LOG = LoggerFactory.getLogger(DTrainRequestProcessor.class); @Override public int exec(ShifuRequest req) throws Exception { Params params = req.getParams(); Configuration conf = new Configuration(); FileSystem fs = FileSystem.get(conf); String pathModelsHdfs = params.getString(DtrainConstants.SHIFU_DTRAIN_MODEL_OUTPUT); String pathModelsLocal = params.getString("pathModelsLocal", pathModelsHdfs); Path hdfsModels = new Path(pathModelsHdfs); // Path localModels = new Path(pathModelsLocal); if (fs.exists(hdfsModels)) { fs.delete(hdfsModels, true); } File localModels = new File(pathModelsLocal); if (localModels.exists()) { localModels.delete(); } String pathFieldMeta = params.getRequiredString("pathFieldMeta"); FieldMeta fieldMeta = JSONUtils.readValue(new File(pathFieldMeta), FieldMeta.class); params.put(DtrainConstants.SHIFU_DTRAIN_NN_INPUT_NODES, fieldMeta.getActiveFields().size()); LOG.info(" - # Active Fields: " + fieldMeta.getActiveFields().size()); params.put(DtrainConstants.SHIFU_DTRAIN_NN_OUTPUT_NODES, fieldMeta.getTargetFields().size()); LOG.info(" - # Target Fields: " + fieldMeta.getTargetFields().size()); String name = req.getName(); // if not set, by default it is nn model String algorithm = params.getString(DtrainConstants.SHIFU_DTRAIN_ALGORITHM, "nn"); if ("nn".equalsIgnoreCase(algorithm)) { validateNNParams(params); runDistributedNNTrain(params, name); copyModelsToLocal(pathModelsHdfs, pathModelsLocal); } else if ("gbt".equalsIgnoreCase(algorithm) || "rf".equalsIgnoreCase(algorithm)) { validateTreeParams(params); ModelConfig modelConfig = buildModelConfig(params, name, algorithm); List<ColumnConfig> columnConfigList = buildColumnConfig(fieldMeta); Map<String, Object> otherConfigs = buildOtherConfigs(params); TrainStep step = new TrainStep(modelConfig, columnConfigList, otherConfigs); step.process(); copyModelsToLocal(step.getPathFinder().getModelsPath(), pathModelsLocal); LOG.info("Local models can be found in {}", pathModelsLocal); } else { throw new IllegalArgumentException( "Only 'nn' or 'gbt' training algorithms are supported in 'shifu.dtrain.algorithm'."); } return 0; } /** * Build Shifu ModelConfig and */ private ModelConfig buildModelConfig(Params params, String name, String algorithm) { ModelConfig modelConfig = new ModelConfig(); ModelBasicConf basic = new ModelBasicConf(); basic.setName(name); basic.setRunMode(RunMode.DIST); basic.setVersion("2.0"); basic.setAuthor(System.getProperty("user.name")); modelConfig.setBasic(basic); ModelSourceDataConf dataSet = new ModelSourceDataConf(); dataSet.setDataPath(params.getString(DtrainConstants.SHIFU_DTRAIN_INPUT)); dataSet.setDataDelimiter(params.getString(DtrainConstants.SHIFU_DTRAIN_INPUT_DELIMETER)); dataSet.setSource(SourceType.HDFS); dataSet.setTargetColumnName(params.getString(DtrainConstants.SHIFU_DTRAIN_TARGET_COLUMN_NAME)); dataSet.setPosTags(getTagList(params.getString(DtrainConstants.SHIFU_DTRAIN_POSITIVE_TAGS))); dataSet.setNegTags(getTagList(params.getString(DtrainConstants.SHIFU_DTRAIN_NEGATIVE_TAGS))); modelConfig.setDataSet(dataSet); ModelTrainConf trainConf = new ModelTrainConf(); trainConf.setAlgorithm(algorithm); trainConf.setBaggingNum(Integer.parseInt(params.getString(DtrainConstants.SHIFU_DTRAIN_BAGGING_NUM, "1"))); trainConf.setBaggingSampleRate( Double.parseDouble(params.getString(DtrainConstants.SHIFU_DTRAIN_BAGGING_SAMPLE_RATE))); trainConf.setBaggingWithReplacement( Boolean.parseBoolean(params.getString(DtrainConstants.SHIFU_DTRAIN_IS_BAGGING_WITH_REPLACEMENT))); trainConf.setValidSetRate( Double.parseDouble(params.getString(DtrainConstants.SHIFU_DTRAIN_CROSS_VALIDATION_RATE))); trainConf.setNumTrainEpochs(Integer.parseInt(params.get(DtrainConstants.SHIFU_DTRAIN_EPOCH).toString())); // TODO hard code trainConf.setIsContinuous(false); trainConf.setWorkerThreadCount(4); Map<String, Object> trainParams = new HashMap<String, Object>(); trainParams.put("TreeNum", Integer.parseInt(params.getString(DtrainConstants.SHIFU_DTRAIN_TREE_TREENUM))); trainParams.put("FeatureSubsetStrategy", params.getString(DtrainConstants.SHIFU_DTRAIN_TREE_FEATURESUBSETSTRATEGY)); trainParams.put("MaxDepth", Integer.parseInt(params.getString(DtrainConstants.SHIFU_DTRAIN_TREE_MAXDEPTH))); trainParams.put("Impurity", params.getString(DtrainConstants.SHIFU_DTRAIN_TREE_IMPURITY)); trainParams.put("LearningRate", Double.parseDouble(params.getString(DtrainConstants.SHIFU_DTRAIN_TREE_LEARNINGRATE))); trainParams.put("MinInstancesPerNode", Integer.parseInt(params.getString(DtrainConstants.SHIFU_DTRAIN_TREE_MININSTANCESPERNODE))); trainParams.put("MinInfoGain", Double.parseDouble(params.getString(DtrainConstants.SHIFU_DTRAIN_TREE_MININFOGAIN))); trainParams.put("Loss", params.getString(DtrainConstants.SHIFU_DTRAIN_TREE_LOSS)); trainConf.setParams(trainParams); modelConfig.setTrain(trainConf); return modelConfig; } private List<String> getTagList(String str) { List<String> list = new ArrayList<String>(); for (String tag : str.split(",")) { list.add(tag.trim()); } return list; } /** * Other configs like pig scripts folder and jar folder */ private Map<String, Object> buildOtherConfigs(Params params) { return new HashMap<String, Object>(); } /** * Build ColumnConfig list from FieldMeta. */ private List<ColumnConfig> buildColumnConfig(FieldMeta fieldMeta) { List<ColumnConfig> columnConfigList = new ArrayList<ColumnConfig>(); List<Field> fields = fieldMeta.getFields(); for (Field field : fields) { ColumnConfig columnConfig = new ColumnConfig(); columnConfig.setColumnNum(field.getFieldBasics().getNum() - 1); columnConfig.setColumnName(field.getFieldBasics().getName()); ColumnBinning binning = new ColumnBinning(); OpType opType = field.getFieldBasics().getOpType(); switch (opType) { case CATEGORICAL: columnConfig.setColumnType(ColumnType.C); if (field.getFieldControl().getUsageType() == UsageType.ACTIVE) { List<String> categories = field.getFieldStats().getDiscreteStats().getCategories(); // first one is missing value category, remove it List<String> subList = categories.subList(1, categories.size()); binning.setBinCategory(subList); binning.setLength(subList.size() + 1); } break; case CONTINUOUS: default: columnConfig.setColumnType(ColumnType.N); if (field.getFieldControl().getUsageType() == UsageType.ACTIVE) { List<Double> binBoundaries = field.getFieldStats().getContinuousStats().getBinBoundaries(); // first one is missing value bin, no need set it binning.setBinBoundary(binBoundaries.subList(1, binBoundaries.size())); binning.setLength(binning.getBinBoundary().size() + 1); } break; } columnConfig.setColumnFlag( field.getFieldControl().getUsageType() == UsageType.TARGET ? ColumnFlag.Target : null); columnConfig.setFinalSelect(field.getFieldControl().getUsageType() == UsageType.ACTIVE); columnConfig.setColumnBinning(binning); ColumnStats stats = new ColumnStats(); if (field.getFieldControl().getUsageType() == UsageType.ACTIVE) { stats.setMean(field.getFieldStats().getContinuousStats().getMean()); stats.setMean(stats.getMean() == null ? 0d : stats.getMean()); } else { stats.setMean(0d); } columnConfig.setColumnStats(stats); columnConfigList.add(columnConfig); } return columnConfigList; } private void copyModelsToLocal(String fromHdfs, String toLocal) throws IOException { FileSystem fs = FileSystem.get(new Configuration()); Path hdfsModels = new Path(fromHdfs); File localModels = new File(toLocal); // delete recursiveley FileUtils.deleteQuietly(localModels); if (fs.exists(hdfsModels)) { fs.copyToLocalFile(hdfsModels, new Path(toLocal)); LOG.info("Copying models to local: " + toLocal); } else { LOG.error("Models not found on HDFS: " + toLocal); } } private void validateTreeParams(Params params) { validateInt(params, DtrainConstants.SHIFU_DTRAIN_TREE_TREENUM); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TREE_TREENUM); // TODO validate SHIFU_DTRAIN_TREE_FEATURESUBSETSTRATEGY in list of choices validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TREE_FEATURESUBSETSTRATEGY); validateInt(params, DtrainConstants.SHIFU_DTRAIN_TREE_MAXDEPTH); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TREE_MAXDEPTH); // TODO validate SHIFU_DTRAIN_TREE_IMPURITY in list of choices validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TREE_IMPURITY); validateDoubleAndRange(params, DtrainConstants.SHIFU_DTRAIN_TREE_LEARNINGRATE, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TREE_LEARNINGRATE); validateInt(params, DtrainConstants.SHIFU_DTRAIN_TREE_MININSTANCESPERNODE); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TREE_MININSTANCESPERNODE); validateDoubleAndRange(params, DtrainConstants.SHIFU_DTRAIN_TREE_MININFOGAIN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TREE_MININFOGAIN); // TODO validate SHIFU_DTRAIN_TREE_LOSS in list of choices validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TREE_LOSS); // TODO validate SHIFU_DTRAIN_TARGET_COLUMN_NAME in list of choices validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TARGET_COLUMN_NAME); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_NEGATIVE_TAGS); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_POSITIVE_TAGS); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_EPOCH); validateInt(params, DtrainConstants.SHIFU_DTRAIN_EPOCH); validateBoolean(params, DtrainConstants.SHIFU_DTRAIN_IS_TRAIN_ON_DISK); validateBoolean(params, DtrainConstants.SHIFU_DTRAIN_IS_BAGGING_WITH_REPLACEMENT); validateBoolean(params, DtrainConstants.SHIFU_DTRAIN_IS_FIX_INITIAL_INPUT); validateBoolean(params, DtrainConstants.SHIFU_DTRAIN_PARALLEL); validateDoubleAndRange(params, DtrainConstants.SHIFU_DTRAIN_CROSS_VALIDATION_RATE, 0.0d, 1.0d); validateDoubleAndRange(params, DtrainConstants.SHIFU_DTRAIN_BAGGING_SAMPLE_RATE, 0.0d, 1.0d); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_INPUT); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_INPUT_DELIMETER); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_ALGORITHM); } private void validateNNParams(Params params) { validateBoolean(params, DtrainConstants.SHIFU_DTRAIN_IS_TRAIN_ON_DISK); validateBoolean(params, DtrainConstants.SHIFU_DTRAIN_IS_BAGGING_WITH_REPLACEMENT); validateBoolean(params, DtrainConstants.SHIFU_DTRAIN_IS_FIX_INITIAL_INPUT); validateBoolean(params, DtrainConstants.SHIFU_DTRAIN_PARALLEL); validateDoubleAndRange(params, DtrainConstants.SHIFU_DTRAIN_CROSS_VALIDATION_RATE, 0.0d, 1.0d); validateDoubleAndRange(params, DtrainConstants.SHIFU_DTRAIN_BAGGING_SAMPLE_RATE, 0.0d, 1.0d); validateDoubleAndRange(params, DtrainConstants.SHIFU_DTRAIN_NN_LEARNING_RATE, 0.0d, Double.MAX_VALUE); validateInt(params, DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_LAYERS); validateInt(params, DtrainConstants.SHIFU_DTRAIN_NN_INPUT_NODES); validateInt(params, DtrainConstants.SHIFU_DTRAIN_NN_OUTPUT_NODES); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_EPOCH); validateInt(params, DtrainConstants.SHIFU_DTRAIN_EPOCH); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_BAGGING_NUM); validateInt(params, DtrainConstants.SHIFU_DTRAIN_BAGGING_NUM); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_INPUT); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_MODEL_OUTPUT); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_TMP_MODEL_OUTPUT); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_NN_PROPAGATION); String propagation = params.get(DtrainConstants.SHIFU_DTRAIN_NN_PROPAGATION).toString(); if (!"Q".equalsIgnoreCase(propagation) && !"B".equalsIgnoreCase(propagation) && !"R".equalsIgnoreCase(propagation) && !"M".equalsIgnoreCase(propagation) && !"S".equalsIgnoreCase(propagation)) { throw new IllegalArgumentException( String.format("%s, should one of 'Q R M S B'", DtrainConstants.SHIFU_DTRAIN_NN_PROPAGATION)); } validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_NN_ACT_FUNCS); validateNotNull(params, DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_NODES); String actFuncs = params.get(DtrainConstants.SHIFU_DTRAIN_NN_ACT_FUNCS).toString(); String hiddenNodes = params.get(DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_NODES).toString(); if (actFuncs.split(",").length != hiddenNodes.split(",").length) { throw new IllegalArgumentException(String.format("%s and %s should have same length.", DtrainConstants.SHIFU_DTRAIN_NN_ACT_FUNCS, DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_NODES)); } } private void validateBoolean(Params params, String str) { Object object = params.get(str); if (object != null) { try { Boolean.parseBoolean(object.toString()); } catch (NumberFormatException e) { throw new IllegalArgumentException(String.format("%s should be a boolean field.", str)); } } } private void validateDoubleAndRange(Params params, String str, Double start, Double end) { Object object = params.get(str); if (object != null) { try { Double dValue = Double.valueOf(object.toString()); if (Double.compare(dValue, start) < 0 || Double.compare(dValue, end) > 0) { throw new IllegalArgumentException( String.format("%s should be in range (%s, %s).", str, start, end)); } } catch (NumberFormatException e) { throw new IllegalArgumentException(String.format("%s should be a double field.", str)); } } } private void validateInt(Params params, String str) { Object object = params.get(str); if (object != null) { try { Integer.parseInt(object.toString()); } catch (NumberFormatException e) { throw new IllegalArgumentException(String.format("%s should be a int field.", str)); } } } private void validateNotNull(Params params, String str) { Object object = params.get(str); if (object == null) { throw new IllegalArgumentException(String.format("%s should not be null.", str)); } } /** * Find a jar that contains a class of the same name, if any. It will return a jar file, even if that is not the * first thing on the class path that has a class with the same name. * * @param myClass * the class to find * @return a jar file that contains the class, or null */ @SuppressWarnings("rawtypes") private static String findContainingJar(Class myClass) { ClassLoader loader = myClass.getClassLoader(); String classFile = myClass.getName().replaceAll("\\.", "/") + ".class"; try { for (Enumeration itr = loader.getResources(classFile); itr.hasMoreElements();) { URL url = (URL) itr.nextElement(); if ("jar".equals(url.getProtocol())) { String toReturn = url.getPath(); if (toReturn.startsWith("file:")) { toReturn = toReturn.substring("file:".length()); } // URLDecoder is a misnamed class, since it actually decodes // x-www-form-urlencoded MIME type rather than actual // URL encoding (which the file path has). Therefore it // would // decode +s to ' 's which is incorrect (spaces are actually // either unencoded or encoded as "%20"). Replace +s first, // so // that they are kept sacred during the decoding process. toReturn = toReturn.replaceAll("\\+", "%2B"); toReturn = URLDecoder.decode(toReturn, "UTF-8"); return toReturn.replaceAll("!.*$", ""); } else if ("file".equals(url.getProtocol())) { String toReturn = url.getPath(); toReturn = toReturn.replaceAll("\\+", "%2B"); toReturn = URLDecoder.decode(toReturn, "UTF-8"); return toReturn.replaceAll("!.*$", ""); } } } catch (IOException e) { throw new RuntimeException(e); } return null; } /** * GuaguaOptionsParser doesn't to support *.jar currently. We need to find all jars we used in guagua MapReduce * jobs. */ private void addRuntimeJars(final List<String> args) { List<String> jars = new ArrayList<String>(16); // jackson-databind-*.jar jars.add(findContainingJar(ObjectMapper.class)); // jackson-core-*.jar jars.add(findContainingJar(JsonParser.class)); // jackson-annotations-*.jar jars.add(findContainingJar(JsonIgnore.class)); // commons-compress-*.jar jars.add(findContainingJar(BZip2CompressorInputStream.class)); // commons-lang-*.jar jars.add(findContainingJar(StringUtils.class)); // commons-collections-*.jar jars.add(findContainingJar(ListUtils.class)); // common-io-*.jar jars.add(findContainingJar(org.apache.commons.io.IOUtils.class)); // guava-*.jar jars.add(findContainingJar(Splitter.class)); // encog-core-*.jar jars.add(findContainingJar(MLDataSet.class)); // shifu-*.jar jars.add(findContainingJar(getClass())); // guagua-core-*.jar jars.add(findContainingJar(GuaguaConstants.class)); // guagua-mapreduce-*.jar jars.add(findContainingJar(GuaguaMapReduceConstants.class)); // zookeeper-*.jar jars.add(findContainingJar(ZooKeeper.class)); // netty-*.jar jars.add(findContainingJar(ServerBootstrap.class)); args.add(StringUtils.join(jars, DtrainConstants.LIB_JAR_SEPARATOR)); } /** * Reading common parameters and parameters from shifuconfig file. */ private void prepareCommonParams(final List<String> args, final Params params) { args.add("-libjars"); addRuntimeJars(args); args.add("-i"); args.add(HDFSUtils.getFS() .makeQualified(new Path(params.get(DtrainConstants.SHIFU_DTRAIN_INPUT).toString())).toString()); Object zkServersObj = params.get(DtrainConstants.SHIFU_DTRAIN_ZKSERVERS); String zkServers = zkServersObj == null ? "" : zkServersObj.toString(); if (StringUtils.isEmpty(zkServers)) { zkServers = Environment.getConfig(Environment.KEY_SHIFU_ZKSERVERS); if (StringUtils.isEmpty(zkServers)) { LOG.warn( "No specified zookeeper settings from zookeeperServers in shifuConfig file, Guagua will set " + "embedded zookeeper server in client process or master process."); } else { args.add("-z"); args.add(zkServers); } } else { args.add("-z"); args.add(zkServers); } args.add("-w"); args.add(NNWorker.class.getName()); args.add("-m"); args.add(NNMaster.class.getName()); args.add("-c"); // the reason to add 1 is that the first iteration in D-NN // implementation is used for training preparation. int numTrainEpochs = Integer.parseInt(params.get(DtrainConstants.SHIFU_DTRAIN_EPOCH).toString()) + 1; args.add(String.valueOf(numTrainEpochs)); args.add("-mr"); args.add(NNParams.class.getName()); args.add("-wr"); args.add(NNParams.class.getName()); // test // args.add("-Dguagua.worker.number=10"); // test args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.MAPRED_JOB_QUEUE_NAME, params.getString("queueName", Constants.DEFAULT_JOB_QUEUE))); args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_MASTER_INTERCEPTERS, NNOutput.class.getName())); // hard code set computation threshold for 40s, set it in shifuconfig or // json can override this value. args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_COMPUTATION_TIME_THRESHOLD, 60 * 1000L)); setHeapSizeAndSplitSize(args); args.add("-Ddfs.mapred.max.split.size=268435456"); // special tuning parameters for shifu, 0.99 means each iteation master // wait for 99% workers and then can go to // next iteration. args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_MIN_WORKERS_RATIO, 0.99)); // 20 seconds if waiting over 20, consider 99% workers // these two can be overrided in shifuconfig args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_MIN_WORKERS_TIMEOUT, 20 * 1000L)); args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_NODES, params.get(DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_NODES))); args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_NN_ACT_FUNCS, params.get(DtrainConstants.SHIFU_DTRAIN_NN_ACT_FUNCS))); args.add( String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_LAYERS, params.get(DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_LAYERS))); args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_NN_OUTPUT_NODES, params.get(DtrainConstants.SHIFU_DTRAIN_NN_OUTPUT_NODES))); args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_NN_INPUT_NODES, params.get(DtrainConstants.SHIFU_DTRAIN_NN_INPUT_NODES))); args.add( String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_NN_LEARNING_RATE, params.get(DtrainConstants.SHIFU_DTRAIN_NN_LEARNING_RATE))); args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_NN_PROPAGATION, params.get(DtrainConstants.SHIFU_DTRAIN_NN_PROPAGATION))); args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_CROSS_VALIDATION_RATE, params.get(DtrainConstants.SHIFU_DTRAIN_CROSS_VALIDATION_RATE))); args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_IS_BAGGING_WITH_REPLACEMENT, params.get(DtrainConstants.SHIFU_DTRAIN_IS_BAGGING_WITH_REPLACEMENT))); args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_IS_FIX_INITIAL_INPUT, params.get(DtrainConstants.SHIFU_DTRAIN_IS_FIX_INITIAL_INPUT))); args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_BAGGING_SAMPLE_RATE, params.get(DtrainConstants.SHIFU_DTRAIN_BAGGING_SAMPLE_RATE))); args.add( String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.SHIFU_DTRAIN_IS_TRAIN_ON_DISK, params.get(DtrainConstants.SHIFU_DTRAIN_IS_TRAIN_ON_DISK))); args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_ZK_EMBEDBED_IS_IN_CLIENT, "false")); args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.NN_DATA_DELIMITER, params.getString("delimiter", "|"))); // one can set guagua conf in shifuconfig for (Map.Entry<Object, Object> entry : Environment.getConfig().entrySet()) { if (entry.getKey().toString().startsWith("nn") || entry.getKey().toString().startsWith("guagua") || entry.getKey().toString().startsWith("mapred")) { args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, entry.getKey().toString(), entry.getValue().toString())); } } // one can set guagua conf in current configuration fiels for (Map.Entry<String, Object> entry : params.entrySet()) { if (isHadoopConfigurationInjected(entry.getKey())) { args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, entry.getKey(), entry.getValue().toString())); } } } public static boolean isHadoopConfigurationInjected(String key) { return key.startsWith("nn") || key.startsWith("guagua") || key.startsWith("shifu") || key.startsWith("mapred") || key.startsWith("io") || key.startsWith("hadoop") || key.startsWith("yarn"); } private void setHeapSizeAndSplitSize(final List<String> args) { args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, GuaguaMapReduceConstants.MAPRED_CHILD_JAVA_OPTS, "-Xmn128m -Xms1G -Xmx1G")); args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_SPLIT_COMBINABLE, Environment.getConfig(GuaguaConstants.GUAGUA_SPLIT_COMBINABLE, "true"))); args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, GuaguaConstants.GUAGUA_SPLIT_MAX_COMBINED_SPLIT_SIZE, Environment.getConfig(GuaguaConstants.GUAGUA_SPLIT_MAX_COMBINED_SPLIT_SIZE, "536870912"))); } /** * Transforming json configuration parameters to guagua configurations and then run guagua jobs to invoke * distributed neural network training. */ protected void runDistributedNNTrain(Params params, String name) throws IOException, InterruptedException, ClassNotFoundException { LOG.info("Start distributed training."); final List<String> args = new ArrayList<String>(); prepareCommonParams(args, params); // add tmp models folder to config Path tmpModelsPath = HDFSUtils.getFS().makeQualified( new Path(params.get(DtrainConstants.SHIFU_DTRAIN_TMP_MODEL_OUTPUT, "tmp").toString())); args.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.NN_TMP_MODELS_FOLDER, tmpModelsPath.toString())); int baggingNum = Integer.parseInt(params.get(DtrainConstants.SHIFU_DTRAIN_BAGGING_NUM).toString()); long start = System.currentTimeMillis(); LOG.info("Distributed training with baggingNum: {}", baggingNum); boolean isParallel = params.getBoolean(DtrainConstants.SHIFU_DTRAIN_PARALLEL, true); GuaguaMapReduceClient guaguaClient = new GuaguaMapReduceClient(); List<String> progressLogList = new ArrayList<String>(baggingNum); for (int i = 0; i < baggingNum; i++) { List<String> localArgs = new ArrayList<String>(args); // set name for each bagging job. localArgs.add("-n"); localArgs.add(String.format("Shifu Master-Workers NN Iteration: %s id:%s", name, i + 1)); LOG.info("Start trainer with id: {}", i + 1); String modelName = getModelName(i + 1); Path modelPath = HDFSUtils.getFS().makeQualified( new Path(params.get(DtrainConstants.SHIFU_DTRAIN_MODEL_OUTPUT, "tmp").toString(), modelName)); localArgs.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.GUAGUA_NN_OUTPUT, modelPath.toString())); localArgs.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.NN_TRAINER_ID, String.valueOf(i + 1))); final String progressLogFile = getProgressLogFile(i + 1); progressLogList.add(progressLogFile); localArgs.add(String.format(DtrainConstants.MAPREDUCE_PARAM_FORMAT, DtrainConstants.NN_PROGRESS_FILE, progressLogFile)); if (isParallel) { guaguaClient.addJob(localArgs.toArray(new String[0])); } else { TailThread tailThread = startTailThread(new String[] { progressLogFile }); guaguaClient.createJob(localArgs.toArray(new String[0])).waitForCompletion(true); stopTailThread(tailThread); } } if (isParallel) { TailThread tailThread = startTailThread(progressLogList.toArray(new String[0])); guaguaClient.run(); stopTailThread(tailThread); } LOG.info("Distributed training finished in {}ms.", System.currentTimeMillis() - start); } private void stopTailThread(TailThread thread) throws IOException { thread.interrupt(); try { thread.join(DtrainConstants.DEFAULT_JOIN_TIME); } catch (InterruptedException e) { LOG.error("Thread stopped!", e); Thread.currentThread().interrupt(); } // delete progress file at last thread.deleteProgressFiles(); } private TailThread startTailThread(final String[] progressLog) { TailThread thread = new TailThread(progressLog); thread.setName("Tail Progress Thread"); thread.setDaemon(true); thread.setUncaughtExceptionHandler(new UncaughtExceptionHandler() { @Override public void uncaughtException(Thread t, Throwable e) { LOG.warn(String.format("Error in thread %s: %s", t.getName(), e.getMessage())); } }); thread.start(); return thread; } private String getProgressLogFile(int i) { return String.format("tmp/%s_%s.log", System.currentTimeMillis(), i); } /** * Get NN model name * * @param i * index for model name */ private static String getModelName(int i) { return String.format("model%s.nn", i); } /** * A thread used to tail progress log from hdfs log file. */ private static class TailThread extends Thread { private long[] offset; private String[] progressLogs; public TailThread(String[] progressLogs) { this.progressLogs = progressLogs; this.offset = new long[this.progressLogs.length]; for (String progressLog : progressLogs) { try { // delete it firstly, it will be updated from master HDFSUtils.getFS().delete(new Path(progressLog), true); } catch (IOException e) { LOG.error("Error in delete progressLog", e); } } } @Override public void run() { while (!Thread.currentThread().isInterrupted()) { for (int i = 0; i < this.progressLogs.length; i++) { try { this.offset[i] = dumpFromOffset(new Path(this.progressLogs[i]), this.offset[i]); } catch (FileNotFoundException e) { // ignore because of not created in worker. } catch (IOException e) { LOG.warn(String.format("Error in dump progress log %s: %s", getName(), e.getMessage())); } } try { Thread.sleep(2000); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } } LOG.debug("DEBUG: Exit from tail thread."); } private long dumpFromOffset(Path item, long offset) throws IOException { FSDataInputStream in = HDFSUtils.getFS().open(item); ByteArrayOutputStream out = null; DataOutputStream dataOut = null; try { out = new ByteArrayOutputStream(); dataOut = new DataOutputStream(out); in.seek(offset); // use conf so the system configured io block size is used IOUtils.copyBytes(in, out, HDFSUtils.getFS().getConf(), false); String msgs = new String(out.toByteArray(), Charset.forName("UTF-8")).trim(); if (StringUtils.isNotEmpty(msgs)) { for (String msg : Splitter.on('\n').split(msgs)) { LOG.info(msg.trim()); } } offset = in.getPos(); } catch (IOException e) { if (!e.getMessage().contains("Cannot seek after EOF")) { throw e; } else { LOG.warn(e.getMessage()); } } finally { IOUtils.closeStream(in); IOUtils.closeStream(dataOut); } return offset; } public void deleteProgressFiles() throws IOException { for (String progressFile : this.progressLogs) { HDFSUtils.getFS().delete(new Path(progressFile), true); } } } }