imp.lstm.main.Driver.java Source code

Java tutorial

Introduction

Here is the source code for imp.lstm.main.Driver.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package imp.lstm.main;

import imp.lstm.architecture.NetworkConnectomeLoader;
import java.io.File;
import imp.lstm.io.leadsheet.LeadSheetDataSequence;
import imp.lstm.architecture.FragmentedNeuralQueue;
import imp.lstm.architecture.FullyConnectedLayer;
import imp.lstm.architecture.LSTM;
import imp.lstm.architecture.Loadable;
import imp.lstm.architecture.poex.ProductCompressingAutoencoder;
import imp.lstm.encoding.EncodingParameters;
import imp.lstm.encoding.Group;
import imp.lstm.filters.GroupedSoftMaxSampler;
import imp.lstm.filters.Operations;
import imp.lstm.io.leadsheet.LeadSheetIO;
import java.io.FileNotFoundException;
import java.io.FilenameFilter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import imp.lstm.architecture.InvalidParametersException;
import mikera.arrayz.INDArray;
import mikera.vectorz.AVector;
import mikera.vectorz.Vector;
import org.apache.commons.configuration2.Configuration;
import org.apache.commons.configuration2.PropertiesConfiguration;
import org.apache.commons.configuration2.builder.FileBasedConfigurationBuilder;
import org.apache.commons.configuration2.builder.fluent.Parameters;
import org.apache.commons.configuration2.convert.DefaultListDelimiterHandler;
import org.apache.commons.configuration2.ex.ConfigurationException;

/**
 * Class Driver is an implementation test for CompressingAutoEncoder which reads
 * a LeadSheet file and produces an equivalent length LeadSheet file.
 *
 * <p>The single command-line argument is the path to a properties file. The keys
 * read from it are {@code input_song}, {@code output_folder},
 * {@code auto_encoder_params}, {@code name_generator_params}, {@code queue_folder},
 * {@code reference_queue} (optional), {@code input_corpus_folder},
 * {@code should_write_generated_queue}, {@code queue_tests_frankenstein},
 * {@code queue_tests_interpolation}, {@code iterate_over_corpus} (optional),
 * {@code generate_song_title}, {@code generate_leadsheet} and the optional
 * {@code frankenstein_num_components}, {@code frankenstein_num_combinations} and
 * {@code frankenstein_magnitude}.
 *
 * @author Nicholas Weintraut
 */
public class Driver {

    private static final boolean advanceDecoding = false; //should we start decoding as soon as possible?

    /** Alphabet used to map a sampled output index of the title network to a character. */
    private static final String TITLE_CHARACTERS = " !\"'[],-.01245679:?ABCDEFGHIJKLMNOPQRSTUVWYZabcdefghijklmnopqrstuvwxyz";

    /** Number of sampling steps performed when generating a title. */
    private static final int MAX_TITLE_LENGTH = 50;

    /** Title used when title generation is switched off. */
    private static final String DEFAULT_SONG_TITLE = "The Song We Never Name";

    /** Number of interpolated renditions produced by the interpolation test. */
    private static final int NUM_INTERPOLATION_DIVISIONS = 5;

    /** Name passed to {@code NetworkConnectomeLoader.refresh} before each encode/decode run. */
    private static final String INITIAL_STATE = "initialstate";

    /**
     * Loads the title network and the autoencoder, then encodes each input lead
     * sheet and, depending on the configuration, writes its queue and/or a decoded
     * output lead sheet.
     *
     * @param args {@code args[0]} is the path of the properties configuration file
     * @throws IllegalArgumentException if no configuration path is given
     * @throws FileNotFoundException if the corpus folder cannot be listed, or
     *         propagated from the I/O helpers
     * @throws IOException propagated from the I/O helpers
     * @throws ConfigurationException if the configuration cannot be loaded
     * @throws InvalidParametersException propagated from the network classes
     */
    public static void main(String[] args)
            throws FileNotFoundException, IOException, ConfigurationException, InvalidParametersException {
        if (args == null || args.length < 1) {
            throw new IllegalArgumentException("Usage: Driver <path to properties configuration file>");
        }
        FileBasedConfigurationBuilder<PropertiesConfiguration> builder = new FileBasedConfigurationBuilder<>(
                PropertiesConfiguration.class).configure(
                        new Parameters().properties().setFileName(args[0]).setThrowExceptionOnMissing(true)
                                .setListDelimiterHandler(new DefaultListDelimiterHandler(';'))
                                .setIncludesAllowed(false));
        Configuration config = builder.getConfiguration();

        String inputSongPath = config.getString("input_song");
        String outputFolderPath = config.getString("output_folder");
        String autoEncoderParamsPath = config.getString("auto_encoder_params");
        String nameGeneratorParamsPath = config.getString("name_generator_params");
        String queueFolderPath = config.getString("queue_folder");
        String referenceQueuePath = config.getString("reference_queue", "nil");
        String inputCorpusFolder = config.getString("input_corpus_folder");
        boolean shouldWriteQueue = config.getBoolean("should_write_generated_queue");
        boolean frankensteinTest = config.getBoolean("queue_tests_frankenstein");
        boolean interpolateTest = config.getBoolean("queue_tests_interpolation");
        boolean iterateOverCorpus = config.getBoolean("iterate_over_corpus", false);
        boolean shouldGenerateSongTitle = config.getBoolean("generate_song_title");
        boolean shouldGenerateSong = config.getBoolean("generate_leadsheet");

        LogTimer.initStartTime(); //start our logging timer to keep track of our execution time
        LogTimer.log("Creating name generator...");

        //network used to generate a song name: an LSTM feeding a fully connected layer
        final LSTM lstm = new LSTM();
        final FullyConnectedLayer fullLayer = new FullyConnectedLayer(Operations.None);
        Loadable titleNetLoader = new Loadable() {
            @Override
            public boolean load(INDArray array, String path) {
                //route each parameter array to the sub-network named by the first path component
                String car = pathCar(path);
                String cdr = pathCdr(path);
                switch (car) {
                case "full":
                    return fullLayer.load(array, cdr);
                case "lstm":
                    return lstm.load(array, cdr);
                default:
                    return false;
                }
            }
        };

        LogTimer.log("Packing name generator from files...");
        (new NetworkConnectomeLoader()).load(nameGeneratorParamsPath, titleNetLoader);

        //Initialization
        LogTimer.log("Creating autoencoder...");
        // NOTE(review): the value is unused; the call is retained in case it has
        // initialization side effects - confirm and remove.
        int outputSize = EncodingParameters.noteEncoder.getNoteLength();
        ProductCompressingAutoencoder autoencoder = new ProductCompressingAutoencoder(24, 48, 84 + 1, false); //create our network

        //"pack" the network from weights and biases file directory
        LogTimer.log("Packing autoencoder from files");
        (new NetworkConnectomeLoader()).load(autoEncoderParamsPath, autoencoder);

        File[] songFiles;
        if (iterateOverCorpus) {
            songFiles = new File(inputCorpusFolder).listFiles();
            if (songFiles == null) {
                //listFiles() returns null when the path is not a directory or cannot be read
                throw new FileNotFoundException(
                        "input_corpus_folder is not a readable directory: " + inputCorpusFolder);
            }
        } else {
            songFiles = new File[] { new File(inputSongPath) };
        }
        for (File inputFile : songFiles) {
            (new NetworkConnectomeLoader()).refresh(autoEncoderParamsPath, autoencoder, INITIAL_STATE);
            String songTitle;
            if (shouldGenerateSongTitle) { //governed by generate_song_title, not generate_leadsheet
                songTitle = generateSongTitle(lstm, fullLayer);
                LogTimer.log("Generated song name: " + songTitle);
            } else {
                songTitle = DEFAULT_SONG_TITLE;
            }
            LogTimer.log("Reading file...");
            LeadSheetDataSequence inputSequence = LeadSheetIO.readLeadSheet(inputFile); //read our leadsheet to get a data vessel as retrieved in rbm-provisor
            LeadSheetDataSequence outputSequence = inputSequence.copy();

            outputSequence.clearMelody();
            if (interpolateTest) {
                //make room for one extra rendition per interpolation division
                LeadSheetDataSequence additionalOutput = outputSequence.copy();
                for (int i = 0; i < NUM_INTERPOLATION_DIVISIONS; i++) {
                    outputSequence.concat(additionalOutput.copy());
                }
            }
            LeadSheetDataSequence decoderInputSequence = outputSequence.copy();

            LogTimer.startLog("Encoding data...");
            while (inputSequence.hasNext()) { //iterate through time steps in input data
                autoencoder.encodeStep(inputSequence.retrieve()); //feed the resultant input vector into the network
                if (advanceDecoding) { //if we are using advance decoding (we start decoding as soon as we can)
                    if (autoencoder.canDecode()) { //if queue has enough data to decode from
                        outputSequence.pushStep(null, null,
                                autoencoder.decodeStep(decoderInputSequence.retrieve())); //take sampled data for a timestep from autoencoder
                    }
                }
            }
            LogTimer.endLog();

            if (shouldWriteQueue) {
                String queueFileName = inputFile.getName().replace(".ls", ".q");
                String queueFilePath = queueFolderPath + java.io.File.separator + queueFileName;
                FragmentedNeuralQueue currQueue = autoencoder.getQueue();
                currQueue.writeToFile(queueFilePath);
                LogTimer.log("Wrote queue " + queueFileName + " to file...");
            }
            if (shouldGenerateSong) {
                if (interpolateTest) {

                    FragmentedNeuralQueue refQueue = new FragmentedNeuralQueue();
                    refQueue.initFromFile(referenceQueuePath);

                    FragmentedNeuralQueue currQueue = autoencoder.getQueue();

                    //first rendition: decode a copy of the unmodified queue
                    autoencoder.setQueue(currQueue.copy());
                    decodeRemaining(autoencoder, outputSequence, decoderInputSequence);

                    for (int i = 1; i <= NUM_INTERPOLATION_DIVISIONS; i++) {
                        double interpolationAmount = (1.0 / NUM_INTERPOLATION_DIVISIONS) * (i);
                        System.out.println("Starting interpolation " + interpolationAmount);
                        (new NetworkConnectomeLoader()).refresh(autoEncoderParamsPath, autoencoder, INITIAL_STATE);
                        FragmentedNeuralQueue currCopy = currQueue.copy();
                        currCopy.basicInterpolate(refQueue, interpolationAmount);
                        autoencoder.setQueue(currCopy);
                        int timeStep = 0;
                        while (autoencoder.hasDataStepsLeft()) { //we are done encoding all time steps, so just finish decoding!
                            System.out.println("interpolation " + i + " step " + ++timeStep);
                            outputSequence.pushStep(null, null,
                                    autoencoder.decodeStep(decoderInputSequence.retrieve())); //take sampled data for a timestep from autoencoder
                        }
                    }

                }
                if (frankensteinTest) {
                    LogTimer.startLog("Loading queues");
                    File queueFolder = new File(queueFolderPath);
                    int numComponents = config.getInt("frankenstein_num_components", 5);
                    int numCombinations = config.getInt("frankenstein_num_combinations", 6);
                    double interpolationMagnitude = config.getDouble("frankenstein_magnitude", 2.0);
                    FilenameFilter queueFilter = (dir, name) -> name.contains(".q");
                    //listFiles() may return null even for a directory (I/O error), so guard both cases
                    File[] queueFiles = queueFolder.isDirectory() ? queueFolder.listFiles(queueFilter) : null;
                    if (queueFiles == null) {
                        LogTimer.endLog(); //balance the startLog above
                        LogTimer.log("Skipping frankenstein test, cannot list queue folder: " + queueFolderPath);
                    } else {
                        //pick a random subset of at most numComponents queue files
                        List<File> fileList = new ArrayList<>(queueFiles.length);
                        Collections.addAll(fileList, queueFiles);
                        Collections.shuffle(fileList);
                        int numSelectedFiles = Math.max(0, Math.min(numComponents, fileList.size()));
                        fileList.subList(numSelectedFiles, fileList.size()).clear();

                        List<FragmentedNeuralQueue> queuePopulation = new ArrayList<>(fileList.size());
                        songTitle += " - a mix of ";
                        for (File file : fileList) {
                            FragmentedNeuralQueue newQueue = new FragmentedNeuralQueue();
                            newQueue.initFromFile(file.getPath());
                            queuePopulation.add(newQueue);
                            //literal replace: replaceAll would treat '.' as "any character"
                            songTitle += file.getName().replace(".ls", "") + ", ";
                        }
                        LogTimer.endLog();

                        //extend the output so it can hold one rendition per combination
                        LeadSheetDataSequence additionalOutput = outputSequence.copy();
                        for (int i = 1; i < numCombinations; i++) {
                            outputSequence.concat(additionalOutput.copy());
                        }
                        decoderInputSequence = outputSequence.copy();

                        FragmentedNeuralQueue origQueue = autoencoder.getQueue();

                        for (int i = 0; i < numCombinations; i++) {

                            LogTimer.startLog("Performing queue interpolation...");
                            //random weights, seeded by the combination index, normalized to sum to 1
                            AVector combinationStrengths = Vector.createLength(queuePopulation.size());
                            Random vectorRand = new Random(i);
                            for (int j = 0; j < combinationStrengths.length(); j++) {
                                combinationStrengths.set(j, vectorRand.nextDouble());
                            }
                            combinationStrengths.divide(combinationStrengths.elementSum());
                            FragmentedNeuralQueue currQueue = origQueue.copy();
                            for (int k = 0; k < combinationStrengths.length(); k++) {
                                currQueue.basicInterpolate(queuePopulation.get(k),
                                        combinationStrengths.get(k) * interpolationMagnitude);
                            }
                            LogTimer.endLog();
                            autoencoder.setQueue(currQueue);
                            LogTimer.startLog("Refreshing autoencoder state...");
                            (new NetworkConnectomeLoader()).refresh(autoEncoderParamsPath, autoencoder,
                                    INITIAL_STATE);
                            LogTimer.endLog();
                            LogTimer.startLog("Decoding segment...");
                            decodeRemaining(autoencoder, outputSequence, decoderInputSequence);
                            LogTimer.endLog();
                        }

                    }
                }

                //we are done encoding all time steps, so just finish decoding!
                decodeRemaining(autoencoder, outputSequence, decoderInputSequence);
                LogTimer.log("Writing file...");

                String outputFilename = outputFolderPath + java.io.File.separator
                        + inputFile.getName().replace(".ls", "_Output"); //we'll write our generated file with the same name plus "_Output"
                LeadSheetIO.writeLeadSheet(outputSequence, outputFilename, songTitle);
                System.out.println(outputFilename);
            } else {
                //nothing is decoded, so give the autoencoder a fresh queue for the next song
                autoencoder.setQueue(new FragmentedNeuralQueue());
            }
        }
        LogTimer.log("Process finished"); //Done!

    }

    /**
     * Samples a song title by repeatedly stepping the LSTM, passing its output
     * through the fully connected layer and sampling one character per step.
     *
     * <p>A step whose sampled vector contains no element equal to {@code 1.0}, or
     * whose index falls outside {@link #TITLE_CHARACTERS}, contributes no character.
     *
     * @param lstm the loaded recurrent network
     * @param fullLayer the loaded output layer
     * @return the generated title with leading and trailing whitespace trimmed
     * @throws IOException declared because {@code main} propagates it; which
     *         network call (if any) raises it is not visible here
     * @throws InvalidParametersException declared for the same reason
     */
    private static String generateSongTitle(LSTM lstm, FullyConnectedLayer fullLayer)
            throws IOException, InvalidParametersException {
        AVector charOut = Vector.createLength(TITLE_CHARACTERS.length());
        GroupedSoftMaxSampler sampler = new GroupedSoftMaxSampler(
                new Group[] { new Group(0, TITLE_CHARACTERS.length(), true) });
        StringBuilder title = new StringBuilder(MAX_TITLE_LENGTH);
        for (int i = 0; i < MAX_TITLE_LENGTH; i++) {
            charOut = fullLayer.forward(lstm.step(charOut));
            charOut = sampler.filter(charOut);
            //find the first element the sampler set to exactly 1.0
            int charIndex = 0;
            while (charIndex < charOut.length() && charOut.get(charIndex) != 1.0) {
                charIndex++;
            }
            if (charIndex < charOut.length() && charIndex < TITLE_CHARACTERS.length()) {
                title.append(TITLE_CHARACTERS.charAt(charIndex));
            }
        }
        return title.toString().trim();
    }

    /**
     * Decodes time steps until the autoencoder reports no data steps left,
     * pushing each decoded step onto the output sequence.
     *
     * @param autoencoder the network whose queue is being drained
     * @param outputSequence the sequence receiving the decoded steps
     * @param decoderInputSequence the sequence supplying the decoder input for each step
     * @throws IOException declared because {@code main} propagates it; which
     *         network call (if any) raises it is not visible here
     * @throws InvalidParametersException declared for the same reason
     */
    private static void decodeRemaining(ProductCompressingAutoencoder autoencoder,
            LeadSheetDataSequence outputSequence, LeadSheetDataSequence decoderInputSequence)
            throws IOException, InvalidParametersException {
        while (autoencoder.hasDataStepsLeft()) {
            outputSequence.pushStep(null, null, autoencoder.decodeStep(decoderInputSequence.retrieve())); //take sampled data for a timestep from autoencoder
        }
    }
}