// Java tutorial
/******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ package org.deeplearning4j.examples.multigpu.rnn; import org.apache.commons.io.FileUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.BackpropType; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.api.IterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.parallelism.ParallelWrapper; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.File; import java.io.IOException; import java.net.URL; import java.nio.charset.Charset; import java.util.Random; //import org.nd4j.jita.conf.CudaEnvironment; /**LSTM Character modelling example * @author Alex Black 
Example: Train a LSTM RNN to generates text, one character at a time. This example is somewhat inspired by Andrej Karpathy's blog post, "The Unreasonable Effectiveness of Recurrent Neural Networks" http://karpathy.github.io/2015/05/21/rnn-effectiveness/ This example is set up to train on the Complete Works of William Shakespeare, downloaded from Project Gutenberg. Training on other text sources should be relatively easy to implement. For more details on RNNs in DL4J, see the following: http://deeplearning4j.org/usingrnns http://deeplearning4j.org/lstm http://deeplearning4j.org/recurrentnetwork */ public class LSTMCharModellingExample { public static void main(String[] args) throws Exception { int seed = 12345; int lstmLayerSize = 200; //Number of units in each LSTM layer int miniBatchSize = 32; //Size of mini batch to use when training int exampleLength = 1000; //Length of each training example sequence to use. This could certainly be increased int tbpttLength = 50; //Length for truncated backpropagation through time. i.e., do parameter updates ever 50 characters int nSamplesToGenerate = 4; //Number of samples to generate after each training epoch int nCharactersToSample = 300; //Length of each sample to generate String generationInitialization = null; //Optional character initialization; a random character is used if null // Above is Used to 'prime' the LSTM with a character sequence to continue/complete. // Initialization characters must all be in CharacterIterator.getMinimalCharacterSet() by default Random rng = new Random(seed); //Get a DataSetIterator that handles vectorization of text into something we can use to train // our LSTM network. 
CharacterIterator iter = getShakespeareIterator(miniBatchSize, exampleLength); int nOut = iter.totalOutcomes(); //Set up network configuration: ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l2(0.0001) .weightInit(WeightInit.XAVIER).updater(new Adam(0.005)).graphBuilder().addInputs("input") //Give the input a name. For a ComputationGraph with multiple inputs, this also defines the input array orders //First layer: name "first", with inputs from the input called "input" .addLayer("first", new LSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize).activation(Activation.TANH) .build(), "input") //Second layer, name "second", with inputs from the layer called "first" .addLayer("second", new LSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize).activation(Activation.TANH) .build(), "first") //Output layer, name "outputlayer" with inputs from the two layers called "first" and "second" .addLayer("outputLayer", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX) .nIn(2 * lstmLayerSize).nOut(nOut).build(), "first", "second") .setOutputs("outputLayer") //List the output. For a ComputationGraph with multiple outputs, this also defines the input array orders .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength) .tBPTTBackwardLength(tbpttLength).build(); ComputationGraph net = new ComputationGraph(conf); net.init(); net.setListeners(new ScoreIterationListener(1), new IterationListener() { @Override public void iterationDone(Model model, int iteration, int epoch) { if (iteration % 20 == 0) { System.out.println("--------------------"); System.out.println("Sampling characters from network given initialization \"" + (generationInitialization == null ? 
"" : generationInitialization) + "\""); String[] samples = sampleCharactersFromNetwork(generationInitialization, (ComputationGraph) model, iter, rng, nCharactersToSample, nSamplesToGenerate); for (int j = 0; j < samples.length; j++) { System.out.println("----- Sample " + j + " -----"); System.out.println(samples[j]); System.out.println(); } } } }); //Print the number of parameters in the network (and for each layer) Layer[] layers = net.getLayers(); long totalNumParams = 0; for (int i = 0; i < layers.length; i++) { long nParams = layers[i].numParams(); System.out.println("Number of parameters in layer " + i + ": " + nParams); totalNumParams += nParams; } System.out.println("Total number of network parameters: " + totalNumParams); // ParallelWrapper will take care of load balancing between GPUs. ParallelWrapper wrapper = new ParallelWrapper.Builder(net) // DataSets prefetching options. Set this value with respect to number of actual devices .prefetchBuffer(24) // set number of workers equal to number of available devices. x1-x2 are good values to start with .workers(2) .build(); wrapper.fit(iter); System.out.println("\n\nExample complete"); } /** Downloads Shakespeare training data and stores it locally (temp directory). Then set up and return a simple * DataSetIterator that does vectorization based on the text. * @param miniBatchSize Number of text segments in each training mini-batch * @param sequenceLength Number of characters in each text segment. 
*/ public static CharacterIterator getShakespeareIterator(int miniBatchSize, int sequenceLength) throws Exception { //The Complete Works of William Shakespeare //5.3MB file in UTF-8 Encoding, ~5.4 million characters //https://www.gutenberg.org/ebooks/100 String url = "https://s3.amazonaws.com/dl4j-distribution/pg100.txt"; String tempDir = System.getProperty("java.io.tmpdir"); String fileLocation = tempDir + "/Shakespeare.txt"; //Storage location from downloaded file File f = new File(fileLocation); if (!f.exists()) { FileUtils.copyURLToFile(new URL(url), f); System.out.println("File downloaded to " + f.getAbsolutePath()); } else { System.out.println("Using existing text file at " + f.getAbsolutePath()); } if (!f.exists()) throw new IOException("File does not exist: " + fileLocation); //Download problem? char[] validCharacters = CharacterIterator.getMinimalCharacterSet(); //Which characters are allowed? Others will be removed return new CharacterIterator(fileLocation, Charset.forName("UTF-8"), miniBatchSize, sequenceLength, validCharacters, new Random(12345)); } /** Generate a sample from the network, given an (optional, possibly null) initialization. Initialization * can be used to 'prime' the RNN with a sequence you want to extend/continue.<br> * Note that the initalization is used for all samples * @param initialization String, may be null. If null, select a random character as initialization for all samples * @param charactersToSample Number of characters to sample from network (excluding initialization) * @param net MultiLayerNetwork with one or more LSTM/RNN layers and a softmax output layer * @param iter CharacterIterator. Used for going from indexes back to characters */ private static String[] sampleCharactersFromNetwork(String initialization, ComputationGraph net, CharacterIterator iter, Random rng, int charactersToSample, int numSamples) { //Set up initialization. 
If no initialization: use a random character if (initialization == null) { initialization = String.valueOf(iter.getRandomCharacter()); } //Create input for initialization INDArray initializationInput = Nd4j.zeros(numSamples, iter.inputColumns(), initialization.length()); char[] init = initialization.toCharArray(); for (int i = 0; i < init.length; i++) { int idx = iter.convertCharacterToIndex(init[i]); for (int j = 0; j < numSamples; j++) { initializationInput.putScalar(new int[] { j, idx, i }, 1.0f); } } StringBuilder[] sb = new StringBuilder[numSamples]; for (int i = 0; i < numSamples; i++) sb[i] = new StringBuilder(initialization); //Sample from network (and feed samples back into input) one character at a time (for all samples) //Sampling is done in parallel here net.rnnClearPreviousState(); INDArray output = net.rnnTimeStep(initializationInput)[0]; output = output.tensorAlongDimension((int) output.size(2) - 1, 1, 0); //Gets the last time step output for (int i = 0; i < charactersToSample; i++) { //Set up next input (single time step) by sampling from previous output INDArray nextInput = Nd4j.zeros(numSamples, iter.inputColumns()); //Output is a probability distribution. 
Sample from this for each example we want to generate, and add it to the new input for (int s = 0; s < numSamples; s++) { double[] outputProbDistribution = new double[iter.totalOutcomes()]; for (int j = 0; j < outputProbDistribution.length; j++) outputProbDistribution[j] = output.getDouble(s, j); int sampledCharacterIdx = sampleFromDistribution(outputProbDistribution, rng); nextInput.putScalar(new int[] { s, sampledCharacterIdx }, 1.0f); //Prepare next time step input sb[s].append(iter.convertIndexToCharacter(sampledCharacterIdx)); //Add sampled character to StringBuilder (human readable output) } output = net.rnnTimeStep(nextInput)[0]; //Do one time step of forward pass } String[] out = new String[numSamples]; for (int i = 0; i < numSamples; i++) out[i] = sb[i].toString(); return out; } /** Given a probability distribution over discrete classes, sample from the distribution * and return the generated class index. * @param distribution Probability distribution over classes. Must sum to 1.0 */ public static int sampleFromDistribution(double[] distribution, Random rng) { double d = rng.nextDouble(); double sum = 0.0; for (int i = 0; i < distribution.length; i++) { sum += distribution[i]; if (d <= sum) return i; } //Should never happen if distribution is a valid probability distribution throw new IllegalArgumentException("Distribution is invalid? d=" + d + ", sum=" + sum); } }