org.deeplearning4j.patent.LocalTraining.java Source code

Java tutorial

Introduction

Here is the source code for org.deeplearning4j.patent.LocalTraining.java

Source

/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.patent;

import com.beust.jcommander.Parameter;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.loader.DataSetLoaderIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.patent.preprocessing.PatentLabelGenerator;
import org.deeplearning4j.patent.utils.JCommanderUtils;
import org.deeplearning4j.patent.utils.data.LoadDataSetsFunction;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.api.loader.Loader;
import org.nd4j.api.loader.LocalFileSourceFactory;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.nio.charset.Charset;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * This is a local (single machine, no Spark) implementation of the example.
 * It's main purpose is to provide a single machine performance comparison on the same network.
 * This allows for comparing performance for a single machine vs. cluster, in terms of "time to accuracy X"
 *
 * Before running this script, you will need to run the Spark preprocessing, and the train/test data (or a subset there-of)
 * to the local machine
 *
 * @author Alex Black
 */
public class LocalTraining {
    private static final Logger log = LoggerFactory.getLogger(LocalTraining.class);
    public static final int MILLISEC_PER_SEC = 1000;

    @Parameter(names = {
            "--outputPath" }, description = "Local output path/directory to write results to", required = true)
    private String outputPath;

    @Parameter(names = {
            "--dataDir" }, description = "Directory containing the training data - locally", required = true)
    private String dataDir;

    @Parameter(names = { "--w2vPath" }, description = "Path to the Word2Vec vectors - locally", required = true)
    private String w2vPath;

    @Parameter(names = { "--numEpochs" }, description = "Number of epochs for training")
    private int numEpochs = 1;

    @Parameter(names = { "--rngSeed" }, description = "Random number generator seed (for repeatability)")
    private int rngSeed = 12345;

    @Parameter(names = {
            "--totalExamplesTrain" }, description = "Total number of examples for training", required = true)
    private int totalExamplesTrain = -1;

    @Parameter(names = { "--totalExamplesTest" }, description = "Total number of examples for testing")
    private int totalExamplesTest = 5000;

    @Parameter(names = {
            "--saveConvergenceNets" }, description = "If true, save networks at each evaluation point during training")
    private boolean saveConvergenceNets = true;

    @Parameter(names = { "--batchSize" }, description = "Batch size for training")
    private int batchSize = 32;

    @Parameter(names = { "--listenerFrequency" }, description = "Listener Frequency")
    private int listenerFrequency = 10;

    @Parameter(names = {
            "--convergenceEvalFrequencyBatches" }, description = "Perform convergence evaluation every N minibatches total")
    private int convergenceEvalFrequencyBatches = 400;

    @Parameter(names = {
            "--maxTrainingTimeMin" }, description = "Maximum time for training, in minutes (or < 0 to use no limit)")
    private int maxTrainingTimeMin = -1;

    /**
     * Main function
     *
     * @param args
     * @throws Exception
     */
    public static void main(String... args) throws Exception {
        new LocalTraining().entryPoint(args);
    }

    /**
     * JCommander entry point
     *
     * @param args
     * @throws Exception
     */
    protected void entryPoint(String[] args) throws Exception {
        JCommanderUtils.parseArgs(this, args);
        File resultFile = new File(outputPath, "results.txt"); //Output will be written here
        Preconditions.checkArgument(convergenceEvalFrequencyBatches > 0,
                "convergenceEvalFrequencyBatches must be positive: got %s", convergenceEvalFrequencyBatches);

        Nd4j.getMemoryManager().setAutoGcWindow(15000);

        // Prepare neural net
        ComputationGraph net = new ComputationGraph(NetworkConfiguration.getConf());
        net.init();
        net.setListeners(new PerformanceListener(listenerFrequency, true));
        log.info("Parameters: {}", net.params().length());

        //Write configuration
        writeConfig();

        // Train neural net
        DateTimeFormatter dtf = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
        int subsetCount = 0;
        boolean firstSave = true;
        long overallStart = System.currentTimeMillis();
        boolean exit = false;
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            // Training
            log.info("epoch {} training begin: {}", epoch + 1, dtf.format(LocalDateTime.now()));
            // Prepare training data. Note we'll get this again for each epoch in case we are using early termination iterator
            // plus randomization. This is to ensure consistency between epochs
            DataSetIterator trainData = getDataIterator(dataDir, true, totalExamplesTrain, batchSize, rngSeed);

            //For convergence purposes: want to split into subsets, time training on each subset, an evaluate
            while (trainData.hasNext()) {
                subsetCount++;
                log.info("Starting training: epoch {} of {}, subset {} ({} minibatches)", (epoch + 1), numEpochs,
                        subsetCount, convergenceEvalFrequencyBatches);
                DataSetIterator subset = new EarlyTerminationDataSetIterator(trainData,
                        convergenceEvalFrequencyBatches);
                int itersBefore = net.getIterationCount();
                long start = System.currentTimeMillis();
                net.fit(subset);
                long end = System.currentTimeMillis();
                int iterAfter = net.getIterationCount();

                //Save model
                if (saveConvergenceNets) {
                    String fileName = "net_" + System.currentTimeMillis() + "_epoch" + epoch + "_subset"
                            + subsetCount + ".zip";
                    String outpath = FilenameUtils.concat(outputPath, "nets/" + fileName);
                    File f = new File(outpath);
                    if (firstSave) {
                        firstSave = false;
                        f.getParentFile().mkdirs();
                    }
                    ModelSerializer.writeModel(net, f, true);
                    log.info("Saved network to {}", outpath);
                }

                DataSetIterator test = getDataIterator(dataDir, false, totalExamplesTrain, batchSize, rngSeed);
                long startEval = System.currentTimeMillis();
                IEvaluation[] evals = net.doEvaluation(test, new Evaluation(), new ROCMultiClass());
                long endEval = System.currentTimeMillis();

                StringBuilder sb = new StringBuilder();
                Evaluation e = (Evaluation) evals[0];
                ROCMultiClass r = (ROCMultiClass) evals[1];
                sb.append("epoch ").append(epoch + 1).append(" of ").append(numEpochs).append(" subset ")
                        .append(subsetCount).append(" subsetMiniBatches ").append(iterAfter - itersBefore) //Note: "end of epoch" effect - may be smaller than subset size
                        .append(" trainMS ").append(end - start).append(" evalMS ").append(endEval - startEval)
                        .append(" accuracy ").append(e.accuracy()).append(" f1 ").append(e.f1()).append(" AvgAUC ")
                        .append(r.calculateAverageAUC()).append(" AvgAUPRC ").append(r.calculateAverageAUCPR())
                        .append("\n");

                FileUtils.writeStringToFile(resultFile, sb.toString(), Charset.forName("UTF-8"), true); //Append new output to file
                saveEvaluation(false, evals);
                log.info("Evaluation: {}", sb.toString());

                if (maxTrainingTimeMin > 0
                        && (System.currentTimeMillis() - overallStart) / 60000 > maxTrainingTimeMin) {
                    log.info("Exceeded max training time of {} minutes - terminating", maxTrainingTimeMin);
                    exit = true;
                    break;
                }
            }
            if (exit)
                break;
        }

        File dir = new File(outputPath, "trainedModel.bin");
        net.save(dir, true);
    }

    private void writeConfig() throws Exception {
        long time = System.currentTimeMillis();

        StringBuilder sb = new StringBuilder();
        sb.append("Output Path: ").append(outputPath).append("\n").append("Time: ").append(time).append("\n")
                .append("RNG Seed: ").append(rngSeed).append("\n").append("Total Examples (Train): ")
                .append(totalExamplesTrain).append("\n").append("Total Examples (Test): ").append(totalExamplesTest)
                .append("\n").append("numEpoch: ").append(numEpochs).append("\n").append("BatchSize: ")
                .append(batchSize).append("\n").append("Listener Frequency: ").append(listenerFrequency)
                .append("\n").append("\n");

        String str = sb.toString();
        log.info(str);

        //Write to file:
        String toWrite = sb.toString();

        String path = FilenameUtils.concat(outputPath, "experimentConfig.txt");
        log.info("Writing experiment config and info to file: {}", path);
        FileUtils.writeStringToFile(new File(outputPath, "experimentConfig.txt"), toWrite,
                Charset.forName("UTF-8"));
    }

    private void saveEvaluation(boolean train, IEvaluation[] evaluations) throws Exception {
        String evalPath = FilenameUtils.concat(outputPath, ("evaluation_" + (train ? "train" : "test")));
        //Write evaluations to disk
        for (int i = 0; i < evaluations.length; i++) {
            String path = FilenameUtils.concat(evalPath, "evaluation_" + i + ".txt");
            FileUtils.writeStringToFile(new File(path), evaluations[i].stats(), Charset.forName("UTF-8"));
        }
    }

    public DataSetIterator getDataIterator(String dataRootDir, boolean train, int totalExamples, int batchSize,
            int seed) {
        File root = new File(dataRootDir, train ? "train" : "test");
        List<String> all = new ArrayList<>();
        File[] files = root.listFiles();
        if (files == null || files.length == 0) {
            throw new IllegalStateException("Did not find files in directory " + root.getAbsolutePath());
        }
        for (File f : files) {
            all.add(f.getAbsolutePath());
        }
        Collections.sort(all);
        int totalBatches = (totalExamples < 0 ? -1 : totalExamples / batchSize);
        if (totalBatches > 0 && totalBatches < all.size()) {
            Random r = new Random(seed);
            int[] order = new int[all.size()];
            for (int i = 0; i < order.length; i++) {
                order[i] = i;
            }
            MathUtils.shuffleArray(order, r);
            List<String> from = all;
            all = new ArrayList<>();
            for (int i = 0; i < totalBatches; i++) {
                all.add(from.get(order[i]));
            }
        }

        Loader<DataSet> loader = new LoadDataSetsFunction(w2vPath, PatentLabelGenerator.classLabelToIndex().size(),
                300);
        return new DataSetLoaderIterator(all, loader, new LocalFileSourceFactory());
    }
}