Source code for ml.shifu.dtrain.NNTest.java (Java).

/*
 * Copyright [2013-2015] eBay Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.dtrain;

import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.io.FileUtils;
import org.encog.neural.networks.BasicNetwork;
import org.encog.persist.EncogDirectoryPersistence;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.AfterTest;
import org.testng.annotations.Test;

import ml.shifu.dtrain.NNMaster;
import ml.shifu.dtrain.NNParams;
import ml.shifu.dtrain.NNWorker;
import ml.shifu.guagua.GuaguaConstants;
import ml.shifu.guagua.hadoop.GuaguaMRUnitDriver;
import ml.shifu.guagua.unit.GuaguaUnitDriver;

/**
 * Integration test for the distributed NN trainer.
 *
 * <p>Runs a full master/worker training job against the bundled WDBC data set
 * through the in-process Guagua MR unit driver, then verifies that the
 * progress file and final model are written, that the reported error is below
 * a threshold, and that the persisted network has the expected weight count.
 *
 * @author xiaobzheng (zheng.xiaobin.roubao@gmail.com)
 */
public class NNTest {

    private static final Logger LOG = LoggerFactory.getLogger(NNTest.class);

    // All test artifacts live under a temp folder in the working directory
    // and are removed again in clearUp().
    private static final String NN_TEST = System.getProperty("user.dir") + File.separator + "nn_test_tmp";
    private static final String OUTPUT = NN_TEST + File.separator + "nn_test_output";
    private static final String PROGRESS_FILE_STRING = NN_TEST + File.separator + "nn_test_progress";
    private static final String TMP_MODELS_FOLDER = NN_TEST + File.separator + "nn_tmp_models";

    /** Matches the trainer's per-iteration progress line, e.g. "Train Error:0.123 Validation Error:0.145". */
    private static final Pattern RESULT_PATTERN = Pattern
            .compile("Train\\s+Error:(\\d+\\.\\d+)\\s+Validation\\s+Error:(\\d+\\.\\d+)");

    @Test
    public void testNNApp() throws IOException {
        Properties props = new Properties();

        LOG.info("Set property for Guagua driver");
        props.setProperty(GuaguaConstants.MASTER_COMPUTABLE_CLASS, NNMaster.class.getName());
        props.setProperty(GuaguaConstants.WORKER_COMPUTABLE_CLASS, NNWorker.class.getName());
        props.setProperty(GuaguaConstants.GUAGUA_ITERATION_COUNT, "30");
        props.setProperty(GuaguaConstants.GUAGUA_MASTER_RESULT_CLASS, NNParams.class.getName());
        props.setProperty(GuaguaConstants.GUAGUA_WORKER_RESULT_CLASS, NNParams.class.getName());
        props.setProperty(GuaguaConstants.GUAGUA_INPUT_DIR,
                getClass().getResource("/data/wdbc/wdbc.normalized").toString());
        props.setProperty(GuaguaConstants.GUAGUA_MASTER_INTERCEPTERS, NNOutput.class.getName());
        props.setProperty(DtrainConstants.GUAGUA_NN_OUTPUT, OUTPUT);
        props.setProperty(DtrainConstants.NN_PROGRESS_FILE, PROGRESS_FILE_STRING);
        props.setProperty(DtrainConstants.NN_TRAINER_ID, "#1");
        props.setProperty(DtrainConstants.NN_TMP_MODELS_FOLDER, TMP_MODELS_FOLDER);

        // Since many parameter settings in NNMaster/NNWorker lack default values,
        // we have to specify all these parameters, or the master/worker won't work
        // properly. Settings below contain all indispensable and optional parameters.
        LOG.info("Set property for NN trainer");
        props.setProperty(DtrainConstants.NN_DATA_DELIMITER, ",");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_BAGGING_NUM, "1");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_IS_TRAIN_ON_DISK, "false");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_BAGGING_SAMPLE_RATE, "1.0");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_CROSS_VALIDATION_RATE, "0.2");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_NN_INPUT_NODES, "30");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_NN_OUTPUT_NODES, "1");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_LAYERS, "2");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_NN_HIDDEN_NODES, "30,20");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_NN_ACT_FUNCS,
                DtrainConstants.NN_SIGMOID + "," + DtrainConstants.NN_SIGMOID);
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_NN_PROPAGATION, NNUtils.QUICK_PROPAGATION);
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_NN_LEARNING_RATE, "0.2");
        props.setProperty(DtrainConstants.SHIFU_DTRAIN_PARALLEL, "true");

        GuaguaUnitDriver<NNParams, NNParams> driver = new GuaguaMRUnitDriver<NNParams, NNParams>(props);
        driver.run();

        // Check output files exist. Assert on the concrete model file the trainer
        // emits (OUTPUT), not merely its parent folder, so a missing model is
        // caught here rather than when Encog tries to load it below.
        File finalModel = new File(OUTPUT);
        Assert.assertTrue(finalModel.exists());

        File progressFile = new File(PROGRESS_FILE_STRING);
        Assert.assertTrue(progressFile.exists());

        // Check final output error is less than the threshold. Use the last
        // progress line, i.e. the error after the final training iteration.
        List<String> errorList = FileUtils.readLines(progressFile, StandardCharsets.UTF_8);
        Assert.assertFalse(errorList.isEmpty(), "Progress file is empty: " + PROGRESS_FILE_STRING);
        String errorLine = errorList.get(errorList.size() - 1);

        Matcher errorMatcher = RESULT_PATTERN.matcher(errorLine);
        Assert.assertTrue(errorMatcher.find());

        double trainErr = Double.parseDouble(errorMatcher.group(1));
        double testErr = Double.parseDouble(errorMatcher.group(2));
        double threshold = 0.2;
        Assert.assertTrue(Double.compare((trainErr + testErr) / 2, threshold) <= 0);

        // Check final model. Here we only simply check the output weight size:
        // (30 inputs + bias) * 30 + (30 hidden + bias) * 20 + (20 hidden + bias) * 1.
        BasicNetwork model = (BasicNetwork) EncogDirectoryPersistence.loadObject(finalModel);
        Assert.assertEquals(model.getFlat().getWeights().length, 31 * 30 + 31 * 20 + 21);
    }

    /** Removes the whole temp test folder (models, progress file) after the test run. */
    @AfterTest
    public void clearUp() throws IOException {
        FileUtils.deleteQuietly(new File(NN_TEST));
    }

}