Java tutorial
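The test class below shows how Shifu's distributed neural-network trainer is wired together: it configures Guagua's in-process unit-test driver (GuaguaMRUnitDriver) with the NNMaster and NNWorker computable classes, trains a network with two hidden layers (30 and 20 nodes) on the bundled, normalized WDBC dataset, and then asserts on the training progress file and on the Encog model persisted through NNOutput.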
/*
 * Copyright [2013-2015] eBay Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.dtrain;

import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.io.FileUtils;
import org.encog.neural.networks.BasicNetwork;
import org.encog.persist.EncogDirectoryPersistence;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.AfterTest;
import org.testng.annotations.Test;

import ml.shifu.dtrain.NNMaster;
import ml.shifu.dtrain.NNWorker;
import ml.shifu.dtrain.NNParams;
import ml.shifu.guagua.GuaguaConstants;
import ml.shifu.guagua.hadoop.GuaguaMRUnitDriver;
import ml.shifu.guagua.unit.GuaguaUnitDriver;

/**
 * Class for NN trainer test.
 *
 * @author xiaobzheng (zheng.xiaobin.roubao@gmail.com)
 */
public class NNTest {

    private static final Logger LOG = LoggerFactory.getLogger(NNTest.class);

    private static final String NN_TEST = System.getProperty("user.dir") + File.separator + "nn_test_tmp";
    private static final String OUTPUT = NN_TEST + File.separator + "nn_test_output";
    private static final String PROGRESS_FILE_STRING = NN_TEST + File.separator + "nn_test_progress";
    private static final String TMP_MODELS_FOLDER = NN_TEST + File.separator + "nn_tmp_models";

    @Test
    public void testNNApp() throws IOException {
        Properties props = new Properties();

        LOG.info("Set property for Guagua driver");
        props.setProperty(GuaguaConstants.MASTER_COMPUTABLE_CLASS, NNMaster.class.getName());
        props.setProperty(GuaguaConstants.WORKER_COMPUTABLE_CLASS, NNWorker.class.getName());
        props.setProperty(GuaguaConstants.GUAGUA_ITERATION_COUNT, "30");
        props.setProperty(GuaguaConstants.GUAGUA_MASTER_RESULT_CLASS, NNParams.class.getName());
        props.setProperty(GuaguaConstants.GUAGUA_WORKER_RESULT_CLASS, NNParams.class.getName());
        props.setProperty(GuaguaConstants.GUAGUA_INPUT_DIR,
                getClass().getResource("/data/wdbc/wdbc.normalized").toString());
        props.setProperty(GuaguaConstants.GUAGUA_MASTER_INTERCEPTERS, NNOutput.class.getName());
        props.setProperty(DtrainConstants.GUAGUA_NN_OUTPUT, OUTPUT);
        props.setProperty(DtrainConstants.NN_PROGRESS_FILE, PROGRESS_FILE_STRING);
        props.setProperty(DtrainConstants.NN_TRAINER_ID, "#1");
        props.setProperty(DtrainConstants.NN_TMP_MODELS_FOLDER, TMP_MODELS_FOLDER);

        // Many parameters in NNMaster/NNWorker have no default value, so all of them must be
        // specified here or the master/worker will not work properly. The settings below cover
        // both the indispensable and the optional parameters.
        LOG.info("Set property for NN trainer");
        props.setProperty(DtrainConstants.NN_DATA_DELIMITER, ",");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_BAGGING_NUM, "1");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_IS_TRAIN_ON_DISK, "false");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_BAGGING_SAMPLE_RATE, "1.0");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_CROSS_VALIDATION_RATE, "0.2");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_NN_INPUT_NODES, "30");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_NN_OUTPUT_NODES, "1");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_LAYERS, "2");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_NODES, "30,20");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_NN_ACT_FUNCS,
                DtrainConstants.NN_SIGMOID + "," + DtrainConstants.NN_SIGMOID);
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_NN_PROPAGATION, NNUtils.QUICK_PROPAGATION);
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_NN_LEARNING_RATE, "0.2");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_PARALLEL, "true");

        GuaguaUnitDriver<NNParams, NNParams> driver = new GuaguaMRUnitDriver<NNParams, NNParams>(props);
        driver.run();

        // Check that the output files exist.
        File finalModel = new File(NN_TEST);
        Assert.assertTrue(finalModel.exists());

        File progressFile = new File(PROGRESS_FILE_STRING);
        Assert.assertTrue(progressFile.exists());

        // Check that the final output error is below the threshold, using the last line
        // of the progress file.
        List<String> errorList = FileUtils.readLines(progressFile);
        String errorLine = errorList.get(errorList.size() - 1);
        Pattern resultPattern = Pattern
                .compile("Train\\s+Error:(\\d+\\.\\d+)\\s+Validation\\s+Error:(\\d+\\.\\d+)");
        Matcher errorMatcher = resultPattern.matcher(errorLine);
        Assert.assertTrue(errorMatcher.find());

        double trainErr = Double.parseDouble(errorMatcher.group(1));
        double testErr = Double.parseDouble(errorMatcher.group(2));
        double threshold = 0.2;
        Assert.assertTrue(Double.compare((trainErr + testErr) / 2, threshold) <= 0);

        // Check the final model. Here we only check the size of the weight array:
        // (30 inputs + bias) * 30 + (30 hidden + bias) * 20 + (20 hidden + bias) * 1.
        BasicNetwork model = (BasicNetwork) EncogDirectoryPersistence.loadObject(new File(OUTPUT));
        Assert.assertEquals(model.getFlat().getWeights().length, 31 * 30 + 31 * 20 + 21);
    }

    @AfterTest
    public void clearUp() throws IOException {
        FileUtils.deleteQuietly(new File(NN_TEST));
    }
}
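Once the test (or a real training run) has finished, the persisted network can be loaded back and used for scoring with plain Encog calls. The sketch below is a minimal, hypothetical example and is not part of the Shifu test suite: it assumes the model was written to the nn_test_output path used above and that the record being scored is already normalized to the 30 input nodes configured via SHIFU_DTRAIN_NN_INPUT_NODES; the class name, file path, and feature values are placeholders rather than real WDBC data.

import java.io.File;

import org.encog.ml.data.MLData;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.neural.networks.BasicNetwork;
import org.encog.persist.EncogDirectoryPersistence;

public class NNScoreExample {

    public static void main(String[] args) {
        // Path to the model persisted by the training run above; adjust to your environment.
        File modelFile = new File("nn_test_tmp/nn_test_output");

        // Encog persists BasicNetwork objects, so a plain cast after loading is enough.
        BasicNetwork network = (BasicNetwork) EncogDirectoryPersistence.loadObject(modelFile);

        // One already-normalized record with 30 features, matching SHIFU_DTRAIN_NN_INPUT_NODES.
        // The values here are placeholders, not real WDBC data.
        double[] features = new double[30];
        for(int i = 0; i < features.length; i++) {
            features[i] = 0.5d;
        }

        MLData input = new BasicMLData(features);
        MLData output = network.compute(input);

        System.out.println("Model score: " + output.getData(0));
    }
}

Because the output layer was configured with a single sigmoid node, output.getData(0) returns a score between 0 and 1.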