org.deeplearning4j.examples.recurrent.character.melodl4j.MelodyModelingExample.java Source code

Java tutorial

Introduction

Here is the source code for org.deeplearning4j.examples.recurrent.character.melodl4j.MelodyModelingExample.java

Source

/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.examples.recurrent.character.melodl4j;

import org.apache.commons.io.FileUtils;
import org.deeplearning4j.examples.recurrent.character.CharacterIterator;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.*;
import java.net.URL;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * LSTM  Symbolic melody modelling example, to compose music from symbolic melodies extracted from MIDI.
 * Based closely on LSTMCharModellingExample.java.
 * See the README file in this directory for documentation.
 *
 * @author Alex Black, Donald A. Smith.
 */
public class MelodyModelingExample {
    // Input file of symbolic melody strings (produced by MidiMelodyExtractor.java).
    // Examples: bach-melodies-input.txt, beatles-melodies-input.txt, pop-melodies-input.txt (large)
    final static String inputSymbolicMelodiesFilename = "bach-melodies-input.txt";

    final static String tmpDir = System.getProperty("java.io.tmpdir");

    // Points to melodies created by MidiMelodyExtractor.java
    final static String symbolicMelodiesInputFilePath = tmpDir + "/" + inputSymbolicMelodiesFilename;
    // You can listen to these melodies by running PlayMelodyStrings.java against this file.
    final static String composedMelodiesOutputFilePath = tmpDir + "/composition.txt";

    /**
     * Entry point.
     * <p>
     * With no arguments: trains a new two-layer LSTM on the input melodies, periodically sampling
     * (and playing) generated melodies, then saves the trained model and compositions.
     * With two arguments {@code <networkPath> <initialization>}: loads a previously saved network
     * and only generates/plays samples, skipping training.
     *
     * @throws Exception on I/O, download, or MIDI-playback failure
     */
    public static void main(String[] args) throws Exception {
        String loadNetworkPath = null; // e.g. "/tmp/MelodyModel-bach.zip"; null means "train a new network"
        String generationInitialization = null; // Optional character initialization; a random character is used if null
        if (args.length == 2) {
            loadNetworkPath = args[0];
            generationInitialization = args[1];
        }

        int lstmLayerSize = 200;                   // Number of units in each LSTM layer
        int miniBatchSize = 32;                    // Size of mini batch to use when training
        int exampleLength = 500;                   // Length of each training example sequence to use
        int tbpttLength = 50;                      // Truncated-BPTT length, i.e. do parameter updates every 50 characters
        int numEpochs = 50;                        // Total number of training epochs
        int generateSamplesEveryNMinibatches = 20; // How frequently to generate samples from the network
        int nSamplesToGenerate = 10;               // Number of samples to generate after each training epoch
        int nCharactersToSample = 300;             // Length of each sample to generate

        // The initialization (if any) is used to 'prime' the LSTM with a character sequence to continue.
        // Initialization characters must all be valid characters of the CharacterIterator.
        Random rng = new Random(12345);
        long startTime = System.currentTimeMillis();

        System.out.println("Using " + tmpDir + " as the temporary directory");
        // DataSetIterator that handles vectorization of the melody text into something we can use
        // to train our LSTM network.
        CharacterIterator iter = getMidiIterator(miniBatchSize, exampleLength);

        if (loadNetworkPath != null) {
            // Generation-only mode: load an existing model, print and play its compositions, then exit.
            MultiLayerNetwork net = MultiLayerNetwork.load(new File(loadNetworkPath), true);
            String[] samples = sampleCharactersFromNetwork(generationInitialization, net, iter, rng,
                    nCharactersToSample, nSamplesToGenerate);
            for (String melody : samples) {
                System.out.println(melody);
                PlayMelodyStrings.playMelody(melody, 10);
                System.out.println();
            }
            System.exit(0);
        }

        int nOut = iter.totalOutcomes();

        // Network configuration: two LSTM layers feeding a softmax RNN output layer
        // (MCXENT + softmax for classification over the character vocabulary).
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp(0.1)).seed(12345)
                .l2(0.001).weightInit(WeightInit.XAVIER).list()
                .layer(0,
                        new LSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize).activation(Activation.TANH)
                                .build())
                .layer(1,
                        new LSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize).activation(Activation.TANH)
                                .build())
                .layer(2,
                        new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX)
                                .nIn(lstmLayerSize).nOut(nOut).build())
                .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength)
                .tBPTTBackwardLength(tbpttLength).build();

        learn(miniBatchSize, exampleLength, numEpochs, generateSamplesEveryNMinibatches, nSamplesToGenerate,
                nCharactersToSample, generationInitialization, rng, startTime, iter, conf);
    }

    /**
     * Serializes the CharacterIterator to a file in the temporary directory (debugging aid;
     * not called from the normal training path).
     *
     * @param iter the iterator to serialize
     * @throws IOException if the file cannot be written
     */
    private static void save(CharacterIterator iter) throws IOException {
        // try-with-resources ensures the streams are closed even if writeObject throws.
        try (ObjectOutputStream oos = new ObjectOutputStream(
                new FileOutputStream(tmpDir + "/midi-character-iterator.jobj"))) {
            oos.writeObject(iter);
        }
    }

    /**
     * Trains the network, periodically sampling melodies (printing them and playing the most recent
     * one after each epoch). Afterwards, saves the trained model to the temporary directory and
     * writes all sampled melodies to {@link #composedMelodiesOutputFilePath} in reverse order, so
     * that the melodies sampled last (from the most-trained network) appear first in the file.
     *
     * @throws Exception on training, I/O, or MIDI-playback failure
     */
    private static void learn(int miniBatchSize, int exampleLength, int numEpochs,
            int generateSamplesEveryNMinibatches, int nSamplesToGenerate, int nCharactersToSample,
            String generationInitialization, Random rng, long startTime, CharacterIterator iter,
            MultiLayerConfiguration conf) throws Exception {
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(100));

        // Print the number of parameters in the network (and for each layer).
        Layer[] layers = net.getLayers();
        long totalNumParams = 0;
        for (int i = 0; i < layers.length; i++) {
            long nParams = layers[i].numParams();
            System.out.println("Number of parameters in layer " + i + ": " + nParams);
            totalNumParams += nParams;
        }
        System.out.println("Total number of network parameters: " + totalNumParams);

        // Collected samples; later written out in reverse order so the melodies generated by the
        // most-trained network are at the start of the output file.
        List<String> melodies = new ArrayList<>();

        // Do training, and generate/print samples from the network as we go.
        int miniBatchNumber = 0;
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            System.out.println("Starting epoch " + epoch);
            while (iter.hasNext()) {
                DataSet ds = iter.next();
                net.fit(ds);
                if (++miniBatchNumber % generateSamplesEveryNMinibatches == 0) {
                    System.out.println("---------- epoch " + epoch + " --------------");
                    System.out.println("Completed " + miniBatchNumber + " minibatches of size " + miniBatchSize
                            + "x" + exampleLength + " characters");
                    System.out.println("Sampling characters from network given initialization \""
                            + (generationInitialization == null ? "" : generationInitialization) + "\"");
                    String[] samples = sampleCharactersFromNetwork(generationInitialization, net, iter, rng,
                            nCharactersToSample, nSamplesToGenerate);
                    for (int j = 0; j < samples.length; j++) {
                        System.out.println("----- Sample " + j + " ----- of epoch " + epoch);
                        System.out.println(samples[j]);
                        melodies.add(samples[j]);
                        System.out.println();
                    }
                }
            }
            iter.reset(); // Reset iterator for another epoch
            if (!melodies.isEmpty()) {
                // Play the start of the most recently generated melody so progress is audible.
                String melody = melodies.get(melodies.size() - 1);
                int seconds = 25;
                System.out.println("\nFirst " + seconds + " seconds of " + melody);
                PlayMelodyStrings.playMelody(melody, seconds);
            }
        }

        // Save the trained model, named after the input file (extension stripped), in the
        // platform temporary directory.
        int indexOfLastPeriod = inputSymbolicMelodiesFilename.lastIndexOf('.');
        String saveFileName = inputSymbolicMelodiesFilename.substring(0,
                indexOfLastPeriod > 0 ? indexOfLastPeriod : inputSymbolicMelodiesFilename.length());
        ModelSerializer.writeModel(net, tmpDir + "/" + saveFileName + ".zip", false);

        // Write all melodies to the output file, in reverse order (so that the best melodies are
        // at the start of the file). try-with-resources guarantees the writer is flushed/closed.
        try (PrintWriter printWriter = new PrintWriter(composedMelodiesOutputFilePath)) {
            for (int i = melodies.size() - 1; i >= 0; i--) {
                printWriter.println(melodies.get(i));
            }
        }
        double seconds = 0.001 * (System.currentTimeMillis() - startTime);

        System.out.println("\n\nExample complete in " + seconds + " seconds");
        // Exit explicitly, as in the original example; background (e.g. audio/ND4J) threads may
        // otherwise keep the JVM alive -- TODO confirm which threads require this.
        System.exit(0);
    }

    /**
     * Ensures the named melody file exists in the temporary directory, downloading it from
     * truthsite.org if absent.
     *
     * @param filename the bare file name (no directory) to check/download
     * @throws RuntimeException if the download fails or the file still does not exist afterwards
     */
    public static void makeSureFileIsInTmpDir(String filename) {
        final File f = new File(tmpDir + "/" + filename);
        if (!f.exists()) {
            URL url = null;
            try {
                url = new URL("http://truthsite.org/music/" + filename);
                FileUtils.copyURLToFile(url, f);
            } catch (Exception exc) {
                System.err.println("Error copying " + url + " to " + f);
                throw new RuntimeException(exc);
            }
            if (!f.exists()) {
                throw new RuntimeException(f.getAbsolutePath() + " does not exist");
            }
            System.out.println("File downloaded to " + f.getAbsolutePath());
        } else {
            System.out.println("Using existing text file at " + f.getAbsolutePath());
        }
    }

    /**
     * Sets up and returns a simple DataSetIterator that does vectorization based on the melody sample.
     *
     * @param miniBatchSize  Number of text segments in each training mini-batch
     * @param sequenceLength Number of characters in each text segment.
     * @throws Exception if the input file cannot be downloaded or read
     */
    public static CharacterIterator getMidiIterator(int miniBatchSize, int sequenceLength) throws Exception {
        makeSureFileIsInTmpDir(inputSymbolicMelodiesFilename);
        // Which characters are allowed? Others will be removed.
        final char[] validCharacters = MelodyStrings.allValidCharacters.toCharArray();
        return new CharacterIterator(symbolicMelodiesInputFilePath, StandardCharsets.UTF_8, miniBatchSize,
                sequenceLength, validCharacters, new Random(12345), MelodyStrings.COMMENT_STRING);
    }

    /**
     * Generate a sample from the network, given an (optional, possibly null) initialization. Initialization
     * can be used to 'prime' the RNN with a sequence you want to extend/continue.<br>
     * Note that the initialization is used for all samples.
     *
     * @param initialization     String, may be null. If null, select a random character as initialization for all samples
     * @param net                MultiLayerNetwork with one or more LSTM/RNN layers and a softmax output layer
     * @param iter               CharacterIterator. Used for going from indexes back to characters
     * @param rng                random source used when sampling from the output distribution
     * @param charactersToSample Number of characters to sample from network (excluding initialization)
     * @param numSamples         number of independent melodies to generate in parallel
     * @return one generated melody string per sample, each beginning with the initialization
     */
    public static String[] sampleCharactersFromNetwork(String initialization, MultiLayerNetwork net,
            CharacterIterator iter, Random rng, int charactersToSample, int numSamples) {
        // Set up initialization. If no initialization: use a random character.
        if (initialization == null) {
            initialization = String.valueOf(iter.getRandomCharacter());
        }

        // Create one-hot input [numSamples, vocabSize, initLength] for the initialization.
        INDArray initializationInput = Nd4j.zeros(numSamples, iter.inputColumns(), initialization.length());
        char[] init = initialization.toCharArray();
        for (int i = 0; i < init.length; i++) {
            int idx = iter.convertCharacterToIndex(init[i]);
            for (int j = 0; j < numSamples; j++) {
                initializationInput.putScalar(new int[] { j, idx, i }, 1.0f);
            }
        }

        StringBuilder[] sb = new StringBuilder[numSamples];
        for (int i = 0; i < numSamples; i++)
            sb[i] = new StringBuilder(initialization);

        // Sample from network (and feed samples back into input) one character at a time
        // (for all samples). Sampling is done in parallel here.
        net.rnnClearPreviousState();
        INDArray output = net.rnnTimeStep(initializationInput);
        output = output.tensorAlongDimension((int) output.size(2) - 1, 1, 0); // Gets the last time step output

        for (int i = 0; i < charactersToSample; i++) {
            // Set up next input (single time step) by sampling from the previous output.
            INDArray nextInput = Nd4j.zeros(numSamples, iter.inputColumns());
            // Output is a probability distribution. Sample from it for each example we want to
            // generate, and add the result to the new input.
            for (int s = 0; s < numSamples; s++) {
                double[] outputProbDistribution = new double[iter.totalOutcomes()];
                for (int j = 0; j < outputProbDistribution.length; j++)
                    outputProbDistribution[j] = output.getDouble(s, j);
                int sampledCharacterIdx = sampleFromDistribution(outputProbDistribution, rng);

                nextInput.putScalar(new int[] { s, sampledCharacterIdx }, 1.0f); // Prepare next time step input
                sb[s].append(iter.convertIndexToCharacter(sampledCharacterIdx)); // Add sampled character (human readable output)
            }

            output = net.rnnTimeStep(nextInput); // Do one time step of forward pass
        }

        String[] out = new String[numSamples];
        for (int i = 0; i < numSamples; i++)
            out[i] = sb[i].toString();
        return out;
    }

    /**
     * Given a probability distribution over discrete classes, sample from the distribution
     * and return the generated class index.
     *
     * @param distribution Probability distribution over classes. Must sum to 1.0
     * @param rng          random source for sampling
     * @return the sampled class index
     * @throws IllegalArgumentException if no index can be selected after repeated tries
     *         (i.e. the distribution does not sum to approximately 1.0)
     */
    public static int sampleFromDistribution(double[] distribution, Random rng) {
        double d = 0.0;
        double sum = 0.0;
        // Retry a few times: floating-point rounding can make the cumulative sum come out
        // slightly below 1.0, in which case a draw very close to 1.0 finds no index.
        for (int t = 0; t < 10; t++) {
            d = rng.nextDouble();
            sum = 0.0;
            for (int i = 0; i < distribution.length; i++) {
                sum += distribution[i];
                if (d <= sum)
                    return i;
            }
        }
        // Should be extremely unlikely to happen if distribution is a valid probability distribution.
        throw new IllegalArgumentException("Distribution is invalid? d=" + d + ", sum=" + sum);
    }
}