com.github.tteofili.calabrize.impl.RNN.java Source code

Java tutorial

Introduction

Here is the source code for com.github.tteofili.calabrize.impl.RNN.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package com.github.tteofili.calabrize.impl;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.math3.distribution.EnumeratedDistribution;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.SetRange;
import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * A minimal character- or word-level vanilla RNN language model, ported from Andrej Karpathy's
 * Python/numpy code.
 *
 * <p>During training each input token is passed through a {@code V2HCalabrianEncoder} before being
 * fed to the network. The target is the next token in its original (un-encoded) form. The vocabulary
 * therefore contains both the original tokens and their encoded forms. Parameters are updated with
 * RMSprop.
 *
 * @see <a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness">The Unreasonable Effectiveness of Recurrent Neural Networks</a>
 * @see <a href="https://gist.github.com/karpathy/d4dee566867f8291f086">Minimal character-level language model with a Vanilla Recurrent Neural Network, in Python/numpy</a>
 */
public class RNN {

    // hyperparameters
    protected float learningRate;
    protected final int seqLength; // no. of steps to unroll the RNN for
    protected final int hiddenLayerSize;
    protected final int epochs; // number of full passes over the data performed by learn()
    protected final boolean useChars; // true: tokens are characters; false: space-separated words
    protected final int batch; // number of iterations whose gradients are summed per parameter update
    protected final int vocabSize;
    protected final Map<String, Integer> charToIx;
    protected final Map<Integer, String> ixToChar;
    protected final List<String> data; // original (un-encoded) training tokens, in text order
    private static final double EPS = 1e-8; // keeps the RMSprop denominator away from zero
    private static final double DECAY = 0.9; // RMSprop decay rate of the squared-gradient cache
    private static final int DEFAULT_SAMPLE_SIZE = 144;
    private final V2HCalabrianEncoder encoder = new V2HCalabrianEncoder();

    // model parameters
    private final INDArray wxh; // input to hidden
    private final INDArray whh; // hidden to hidden
    private final INDArray why; // hidden to output
    private final INDArray bh; // hidden bias
    private final INDArray by; // output bias

    private INDArray hPrev = null; // memory state; null until learn() has started

    /**
     * Creates a character-level model (batch size 1).
     *
     * @param learningRate    RMSprop step size
     * @param seqLength       number of steps to unroll the RNN for; must be positive
     * @param hiddenLayerSize number of hidden units; must be positive
     * @param epochs          number of full passes over the data performed by {@link #learn()}
     * @param text            training text
     * @throws IllegalArgumentException if a size argument is not positive
     */
    public RNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text) {
        this(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, true);
    }

    /**
     * Creates a model over {@code text}, tokenized into characters or space-separated words.
     *
     * @param learningRate    RMSprop step size
     * @param seqLength       number of steps to unroll the RNN for; must be positive
     * @param hiddenLayerSize number of hidden units; must be positive
     * @param epochs          number of full passes over the data performed by {@link #learn()}
     * @param text            training text
     * @param batch           number of iterations whose gradients are summed before each parameter
     *                        update; must be positive
     * @param useChars        {@code true} to tokenize into characters, {@code false} to split on ' '
     * @throws IllegalArgumentException if a size argument is not positive
     */
    public RNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, int batch,
            boolean useChars) {
        if (seqLength <= 0) {
            throw new IllegalArgumentException("seqLength must be positive, got " + seqLength);
        }
        if (hiddenLayerSize <= 0) {
            throw new IllegalArgumentException("hiddenLayerSize must be positive, got " + hiddenLayerSize);
        }
        if (batch <= 0) {
            throw new IllegalArgumentException("batch must be positive, got " + batch);
        }
        this.learningRate = learningRate;
        this.seqLength = seqLength;
        this.hiddenLayerSize = hiddenLayerSize;
        this.epochs = epochs;
        this.batch = batch;
        this.useChars = useChars;

        data = Arrays.asList(useChars ? toStrings(text.toCharArray()) : text.split(" "));

        // getSequence(p, true) looks up the *encoded* form of each input token, so the vocabulary must
        // hold every original token plus its encoding; data itself keeps only the original tokens.
        Set<String> tokens = new HashSet<>(data);
        for (String d : data) {
            tokens.add(encoder.encode(d));
        }
        vocabSize = tokens.size();

        System.out.printf("data has %d tokens, %d unique.\n", data.size(), vocabSize);
        charToIx = new HashMap<>();
        ixToChar = new HashMap<>();
        int i = 0;
        for (String c : tokens) {
            charToIx.put(c, i);
            ixToChar.put(i, c);
            i++;
        }

        wxh = Nd4j.randn(hiddenLayerSize, vocabSize).mul(0.01);
        whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).mul(0.01);
        why = Nd4j.randn(vocabSize, hiddenLayerSize).mul(0.01);
        bh = Nd4j.zeros(hiddenLayerSize, 1);
        by = Nd4j.zeros(vocabSize, 1);
    }

    /**
     * Splits {@code chars} into one string per Unicode code point, so surrogate pairs are kept
     * together instead of being split into two unpaired halves.
     */
    private static String[] toStrings(char[] chars) {
        List<String> strings = new ArrayList<>(chars.length);
        int i = 0;
        while (i < chars.length) {
            int count = Character.charCount(Character.codePointAt(chars, i));
            strings.add(new String(chars, i, count));
            i += count;
        }
        return strings.toArray(new String[0]);
    }

    /**
     * Trains the model by sweeping over the data left to right in windows of {@code seqLength}
     * tokens, for {@code epochs} full passes, with RMSprop updates every {@code batch} iterations.
     * Progress is printed to stdout; training also stops early if the smoothed loss becomes NaN.
     *
     * @throws IllegalStateException if the data has fewer than {@code seqLength + 1} tokens
     */
    public void learn() {
        if (data.size() < seqLength + 1) {
            throw new IllegalStateException("need at least " + (seqLength + 1) + " tokens to train, got "
                    + data.size());
        }

        int currentEpoch = 0;

        int n = 0; // iteration counter
        int p = 0; // data pointer

        // RMSprop caches: running averages of squared gradients
        INDArray mWxh = Nd4j.zerosLike(wxh);
        INDArray mWhh = Nd4j.zerosLike(whh);
        INDArray mWhy = Nd4j.zerosLike(why);

        INDArray mbh = Nd4j.zerosLike(bh);
        INDArray mby = Nd4j.zerosLike(by);

        // gradient accumulators: lossFun adds into these; they are summed over `batch` iterations
        INDArray dWxh = Nd4j.zerosLike(wxh);
        INDArray dWhh = Nd4j.zerosLike(whh);
        INDArray dWhy = Nd4j.zerosLike(why);

        INDArray dbh = Nd4j.zerosLike(bh);
        INDArray dby = Nd4j.zerosLike(by);

        // loss at iteration 0
        double smoothLoss = -Math.log(1.0 / vocabSize) * seqLength;

        while (true) {
            // prepare inputs (we're sweeping from left to right in steps seqLength long)
            if (p + seqLength + 1 >= data.size() || n == 0) {
                hPrev = Nd4j.zeros(hiddenLayerSize, 1); // reset RNN memory
                p = 0; // go from start of data
                // incremented once on entry (n == 0) and once per completed pass, so using '>'
                // stops after exactly `epochs` passes (and immediately for epochs <= 0)
                currentEpoch++;
                if (currentEpoch > epochs) {
                    System.out.println("training finished: e:" + epochs + ", l: " + smoothLoss + ", h:("
                            + learningRate + ", " + seqLength + ", " + hiddenLayerSize + ")");
                    break;
                }
            }

            INDArray inputs = getSequence(p, true);
            INDArray targets = getSequence(p + 1, false);

            // sample from the model every now and then
            if (n % 1000 == 0 && n > 0) {
                String txt = sample(inputs.getInt(0));
                System.out.printf("\n---\n %s \n----\n", txt);
            }

            // forward seqLength characters through the net and accumulate gradients
            double loss = lossFun(inputs, targets, dWxh, dWhh, dWhy, dbh, dby);
            smoothLoss = smoothLoss * 0.999 + loss * 0.001;
            if (Double.isNaN(smoothLoss)) {
                System.out.println("loss is NaN (over/underflow occured, try adjusting hyperparameters)");
                break;
            }
            if (n % 100 == 0) {
                System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress
            }

            // update once every `batch` iterations, using the gradients summed since the last update
            if ((n + 1) % batch == 0) {
                mWxh = rmsPropStep(wxh, dWxh, mWxh);
                mWhh = rmsPropStep(whh, dWhh, mWhh);
                mWhy = rmsPropStep(why, dWhy, mWhy);
                mbh = rmsPropStep(bh, dbh, mbh);
                mby = rmsPropStep(by, dby, mby);

                // start accumulating the next batch from zero
                dWxh = Nd4j.zerosLike(wxh);
                dWhh = Nd4j.zerosLike(whh);
                dWhy = Nd4j.zerosLike(why);
                dbh = Nd4j.zerosLike(bh);
                dby = Nd4j.zerosLike(by);
            }

            p += seqLength; // move data pointer
            n++; // iteration counter
        }
    }

    /**
     * Applies one RMSprop step to {@code param} in place and returns the updated cache.
     * The cache becomes {@code DECAY * cache + (1 - DECAY) * grad^2}. Then
     * {@code param -= learningRate * grad / (sqrt(newCache) + EPS)}.
     */
    private INDArray rmsPropStep(INDArray param, INDArray grad, INDArray cache) {
        INDArray updatedCache = cache.mul(DECAY).add(grad.mul(grad).mul(1 - DECAY));
        param.subi(grad.mul(learningRate).div(Transforms.sqrt(updatedCache).add(EPS)));
        return updatedCache;
    }

    /**
     * Returns the vocabulary indices of the {@code seqLength} tokens starting at {@code p}.
     *
     * @param p         start offset into {@link #data}
     * @param translate whether to encode each token with the encoder before the lookup
     * @throws IllegalStateException if a (possibly encoded) token is not in the vocabulary
     */
    private INDArray getSequence(int p, boolean translate) {
        INDArray inputs = Nd4j.create(seqLength);
        int c = 0;
        for (String ch : data.subList(p, p + seqLength)) {
            String token = translate ? encoder.encode(ch) : ch;
            Integer ix = charToIx.get(token);
            if (ix == null) {
                throw new IllegalStateException("token not in vocabulary: '" + token + "'");
            }
            inputs.putScalar(c, ix);
            c++;
        }
        return inputs;
    }

    /**
     * Runs a forward and backward pass over one sequence.
     *
     * <p>{@code inputs} and {@code targets} hold vocabulary indices, as built by {@code getSequence}.
     * The pass starts from the current {@code hPrev} hidden state, which must be non-null. Parameter
     * gradients are added into {@code dWxh}, {@code dWhh}, {@code dWhy}, {@code dbh} and {@code dby}.
     * {@code hPrev} is then set to the last hidden state.
     *
     * @return the summed cross-entropy loss {@code -log p(target)} over the sequence
     */
    private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWhy,
            INDArray dbh, INDArray dby) {

        INDArray xs = Nd4j.zeros(inputs.length(), vocabSize);
        INDArray hs = null;
        INDArray ys = null;
        INDArray ps = null;

        INDArray hs1 = Nd4j.create(hPrev.shape());
        Nd4j.copy(hPrev, hs1);

        double loss = 0;

        // forward pass
        for (int t = 0; t < inputs.length(); t++) {
            int tIndex = inputs.getScalar(t).getInt(0);
            xs.putScalar(t, tIndex, 1); // encode in 1-of-k representation
            INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
            INDArray hst = Transforms.tanh(wxh.mmul(xs.getRow(t).transpose()).add(whh.mmul(hsRow)).add(bh)); // hidden state
            if (hs == null) {
                hs = init(inputs.length(), hst.shape());
            }
            hs.putRow(t, hst);

            INDArray yst = (why.mmul(hst)).add(by); // unnormalized log probabilities for next chars
            if (ys == null) {
                ys = init(inputs.length(), yst.shape());
            }
            ys.putRow(t, yst);
            INDArray pst = Nd4j.getExecutioner().execAndReturn(new SoftMax(yst)); // probabilities for next chars
            if (ps == null) {
                ps = init(inputs.length(), pst.shape());
            }
            ps.putRow(t, pst);
            loss += -Math.log(pst.getDouble(targets.getInt(t), 0)); // softmax (cross-entropy loss)
        }

        // backward pass: compute gradients going backwards
        INDArray dhNext = Nd4j.zerosLike(hPrev);
        for (int t = inputs.length() - 1; t >= 0; t--) {
            INDArray dy = ps.getRow(t);
            dy.putRow(targets.getInt(t), dy.getRow(targets.getInt(t)).sub(1)); // backprop into y
            INDArray hst = hs.getRow(t);
            dWhy.addi(dy.mmul(hst.transpose())); // derivative of hy layer
            dby.addi(dy);
            INDArray dh = why.transpose().mmul(dy).add(dhNext); // backprop into h
            INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst.mul(hst))).mul(dh); // backprop through tanh nonlinearity
            dbh.addi(dhraw);
            dWxh.addi(dhraw.mmul(xs.getRow(t)));
            INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
            dWhh.addi(dhraw.mmul(hsRow.transpose()));
            dhNext = whh.transpose().mmul(dhraw);
        }

        this.hPrev = hs.getRow(inputs.length() - 1);

        return loss;
    }

    /**
     * Allocates an array of shape {@code [t, aShape...]}, i.e. {@code aShape} with {@code t} prepended
     * as the leading dimension.
     */
    protected INDArray init(int t, int[] aShape) {
        INDArray as;
        int[] shape = new int[1 + aShape.length];
        shape[0] = t;
        System.arraycopy(aShape, 0, shape, 1, aShape.length);
        as = Nd4j.create(shape);
        return as;
    }

    /**
     * Samples 144 tokens from the model, seeded with token index {@code seedIx}.
     *
     * @see #sample(int, int)
     */
    public String sample(int seedIx) {
        return sample(seedIx, DEFAULT_SAMPLE_SIZE);
    }

    /**
     * Samples a sequence of {@code sampleSize} tokens from the model, starting from the current memory
     * state (or from a zero state if {@link #learn()} has not run yet). {@code seedIx} is the seed
     * token for the first time step. The model's memory state is not modified.
     *
     * @param seedIx     vocabulary index of the seed token
     * @param sampleSize number of tokens to generate; must be positive
     * @return the sampled tokens, joined by ' ' in word mode and concatenated in char mode
     * @throws IllegalArgumentException if {@code sampleSize} is not positive
     */
    public String sample(int seedIx, int sampleSize) {
        if (sampleSize <= 0) {
            throw new IllegalArgumentException("sampleSize must be positive, got " + sampleSize);
        }

        INDArray x = Nd4j.zeros(vocabSize, 1);
        x.putScalar(seedIx, 1);
        INDArray ixes = Nd4j.create(sampleSize);

        INDArray h = hPrev != null ? hPrev.dup() : Nd4j.zeros(hiddenLayerSize, 1);

        for (int t = 0; t < sampleSize; t++) {
            h = Transforms.tanh(wxh.mmul(x).add(whh.mmul(h)).add(bh));
            INDArray y = (why.mmul(h)).add(by);
            INDArray pm = Nd4j.getExecutioner().execAndReturn(new SoftMax(y)).ravel();

            List<Pair<Integer, Double>> d = new ArrayList<>(vocabSize);
            for (int pi = 0; pi < vocabSize; pi++) {
                d.add(new Pair<>(pi, pm.getDouble(0, pi)));
            }
            try {
                EnumeratedDistribution<Integer> distribution = new EnumeratedDistribution<>(d);

                int ix = distribution.sample();

                x = Nd4j.zeros(vocabSize, 1);
                x.putScalar(ix, 1);
                ixes.putScalar(t, ix);
            } catch (IllegalArgumentException | ArithmeticException e) {
                // commons-math rejects negative, NaN/infinite or zero-sum probability vectors (e.g. after
                // numeric overflow). Sampling is best-effort: this step's entry keeps its initial value
                // and the previous input vector is reused for the next step.
            }
        }

        return getSampleString(ixes);
    }

    /**
     * Maps the vocabulary indices in {@code ixes} back to tokens. Tokens are joined with ' ' in word
     * mode and concatenated in character mode.
     */
    protected String getSampleString(INDArray ixes) {
        StringBuilder txt = new StringBuilder();

        boolean first = true;
        NdIndexIterator ndIndexIterator = new NdIndexIterator(ixes.shape());
        while (ndIndexIterator.hasNext()) {
            int[] next = ndIndexIterator.next();
            if (!useChars && !first) {
                txt.append(' ');
            }
            txt.append(ixToChar.get(ixes.getInt(next)));
            first = false;
        }
        return txt.toString();
    }

    /** Returns the number of distinct tokens, counting original tokens and their encodings. */
    public int getVocabSize() {
        return vocabSize;
    }

    @Override
    public String toString() {
        return getClass().getName() + "{" + "learningRate=" + learningRate + ", seqLength=" + seqLength
                + ", hiddenLayerSize=" + hiddenLayerSize + ", epochs=" + epochs + ", vocabSize=" + vocabSize
                + ", useChars=" + useChars + ", batch=" + batch + '}';
    }

    /**
     * Writes the model parameters to a new text file named {@code prefix + <timestamp> + ".txt"}. The
     * timestamp has the form {@code yyyyMMdd-HHmmss}, which is valid in file names on all common
     * platforms. Each parameter is written as its name followed by the array's {@code toString()}.
     * This class has no reader for that format.
     *
     * @param prefix path/name prefix of the output file
     * @throws IOException if the file cannot be created or written
     */
    public void serialize(String prefix) throws IOException {
        String timestamp = String.format("%1$tY%1$tm%1$td-%1$tH%1$tM%1$tS", new Date());
        File out = new File(prefix + timestamp + ".txt");
        try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(out))) {
            bufferedWriter.write("wxh");
            bufferedWriter.write(wxh.toString());
            bufferedWriter.write("whh");
            bufferedWriter.write(whh.toString());
            bufferedWriter.write("why");
            bufferedWriter.write(why.toString());
            bufferedWriter.write("bh");
            bufferedWriter.write(bh.toString());
            bufferedWriter.write("by");
            bufferedWriter.write(by.toString());
        }
    }
}