org.apache.horn.core.RecurrentLayeredNeuralNetwork.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.horn.core.RecurrentLayeredNeuralNetwork.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.horn.core;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.commons.lang.math.RandomUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hama.Constants;
import org.apache.hama.HamaConfiguration;
import org.apache.hama.bsp.BSPJob;
import org.apache.hama.commons.io.FloatMatrixWritable;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseFloatMatrix;
import org.apache.hama.commons.math.DenseFloatVector;
import org.apache.hama.commons.math.FloatFunction;
import org.apache.hama.commons.math.FloatMatrix;
import org.apache.hama.commons.math.FloatVector;
import org.apache.hama.util.ReflectionUtils;
import org.apache.horn.core.Constants.LearningStyle;
import org.apache.horn.core.Constants.TrainingMethod;
import org.apache.horn.examples.MultiLayerPerceptron.StandardNeuron;
import org.apache.horn.examples.RecurrentDropoutNeuron;
import org.apache.horn.funcs.FunctionFactory;
import org.apache.horn.funcs.IdentityFunction;
import org.apache.horn.funcs.SoftMax;
import org.apache.horn.utils.MathUtils;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;

/**
 * SmallLayeredNeuralNetwork defines the general operations for derivative
 * layered models, include Linear Regression, Logistic Regression, Multilayer
 * Perceptron, Autoencoder, and Restricted Boltzmann Machine, etc. For
 * SmallLayeredNeuralNetwork, the training can be conducted in parallel, but the
 * parameters of the models are assumes to be stored in a single machine.
 * 
 * In general, these models consist of neurons which are aligned in layers.
 * Between layers, for any two adjacent layers, the neurons are connected to
 * form a bipartite weighted graph.
 * 
 */
public class RecurrentLayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {

    private static final Log LOG = LogFactory.getLog(RecurrentLayeredNeuralNetwork.class);

    /* Weights between neurons at adjacent layers */
    protected List<FloatMatrix> weightMatrixList;
    /* Weights between neurons at adjacent layers */
    protected List<List<FloatMatrix>> weightMatrixLists;
    /* Previous weight updates between neurons at adjacent layers */
    protected List<FloatMatrix> prevWeightUpdatesList;
    protected List<List<FloatMatrix>> prevWeightUpdatesLists;
    /* Different layers can have different squashing function */
    protected List<FloatFunction> squashingFunctionList;
    protected List<Class<? extends Neuron>> neuronClassList;
    /* Record the recurrent layer */
    protected List<Boolean> recurrentLayerList;
    /* Recurrent step size */
    protected int recurrentStepSize;
    protected int finalLayerIdx;
    private List<Neuron[]> neurons;
    private List<List<Neuron[]>> neuronLists;
    private float dropRate;
    private long iterations;

    private int numOutCells;

    public RecurrentLayeredNeuralNetwork() {
        this.layerSizeList = Lists.newArrayList();
        this.weightMatrixList = Lists.newArrayList();
        this.prevWeightUpdatesList = Lists.newArrayList();
        this.squashingFunctionList = Lists.newArrayList();
        this.neuronClassList = Lists.newArrayList();
        this.weightMatrixLists = Lists.newArrayList();
        this.prevWeightUpdatesLists = Lists.newArrayList();
        this.neuronLists = Lists.newArrayList();
        this.recurrentLayerList = Lists.newArrayList();
    }

    public RecurrentLayeredNeuralNetwork(HamaConfiguration conf, String modelPath) {
        super(conf, modelPath);
        initializeNeurons(false);
        initializeWeightMatrixLists();
    }

    public RecurrentLayeredNeuralNetwork(HamaConfiguration conf, String modelPath, boolean isTraining) {
        super(conf, modelPath);
        initializeNeurons(isTraining);
        initializeWeightMatrixLists();
    }

    /**
     *  initialize neuron objects
     * @param isTraining
     */
    private void initializeNeurons(boolean isTraining) {
        this.neuronLists = Lists.newArrayListWithExpectedSize(recurrentStepSize);
        for (int stepIdx = 0; stepIdx < this.recurrentStepSize; stepIdx++) {
            neurons = new ArrayList<Neuron[]>();

            int expectedNeuronsSize = this.layerSizeList.size();
            if (stepIdx < this.recurrentStepSize - this.numOutCells) {
                expectedNeuronsSize--;
            }
            for (int neuronLayerIdx = 0; neuronLayerIdx < expectedNeuronsSize; neuronLayerIdx++) {
                int numOfNeurons = layerSizeList.get(neuronLayerIdx);
                // if not final layer and next layer is recurrent
                if (stepIdx > 0 && neuronLayerIdx < layerSizeList.size() - 1
                        && this.recurrentLayerList.get(neuronLayerIdx + 1)) {
                    numOfNeurons = numOfNeurons + layerSizeList.get(neuronLayerIdx + 1) - 1;
                }
                Class<? extends Neuron> neuronClass;
                if (neuronLayerIdx == 0)
                    neuronClass = StandardNeuron.class; // actually doesn't needed
                else
                    neuronClass = neuronClassList.get(neuronLayerIdx - 1);

                Neuron[] tmp = new Neuron[numOfNeurons];
                for (int neuronIdx = 0; neuronIdx < numOfNeurons; neuronIdx++) {
                    Neuron n = newNeuronInstance(neuronClass);
                    if (n instanceof RecurrentDropoutNeuron)
                        ((RecurrentDropoutNeuron) n).setDropRate(dropRate);
                    if (neuronLayerIdx > 0 && neuronIdx < layerSizeList.get(neuronLayerIdx))
                        n.setSquashingFunction(squashingFunctionList.get(neuronLayerIdx - 1));
                    else
                        n.setSquashingFunction(new IdentityFunction());
                    n.setLayerIndex(neuronLayerIdx);
                    n.setNeuronID(neuronIdx);
                    n.setLearningRate(this.learningRate);
                    n.setMomentumWeight(this.momentumWeight);
                    n.setTraining(isTraining);
                    tmp[neuronIdx] = n;
                }
                neurons.add(tmp);
            }
            this.neuronLists.add(neurons);
        }
    }

    /**
     * Initialize WeightMatrixLists
     */
    public void initializeWeightMatrixLists() {
        this.numOutCells = (numOutCells == 0 ? this.recurrentStepSize : numOutCells);
        this.weightMatrixLists.clear();
        this.weightMatrixLists = Lists.newArrayListWithExpectedSize(this.recurrentStepSize);
        this.prevWeightUpdatesLists.clear();
        this.prevWeightUpdatesLists = Lists.newArrayListWithExpectedSize(this.recurrentStepSize);

        for (int stepIdx = 0; stepIdx < recurrentStepSize - 1; stepIdx++) {
            int expectedMatrixListSize = this.layerSizeList.size() - 1;
            if (stepIdx < this.recurrentStepSize - this.numOutCells) {
                expectedMatrixListSize--;
            }
            List<FloatMatrix> aWeightMatrixList = Lists.newArrayListWithExpectedSize(expectedMatrixListSize);
            List<FloatMatrix> aPrevWeightUpdatesList = Lists.newArrayListWithExpectedSize(expectedMatrixListSize);
            for (int matrixIdx = 0; matrixIdx < expectedMatrixListSize; matrixIdx++) {
                int rows = this.weightMatrixList.get(matrixIdx).getRowCount();
                int cols = this.weightMatrixList.get(matrixIdx).getColumnCount();
                if (stepIdx == 0)
                    cols = this.layerSizeList.get(matrixIdx);
                FloatMatrix weightMatrix = new DenseFloatMatrix(rows, cols);
                weightMatrix.applyToElements(new FloatFunction() {
                    @Override
                    public float apply(float value) {
                        return RandomUtils.nextFloat() - 0.5f;
                    }

                    @Override
                    public float applyDerivative(float value) {
                        throw new UnsupportedOperationException("");
                    }
                });
                aWeightMatrixList.add(weightMatrix);
                aPrevWeightUpdatesList
                        .add(new DenseFloatMatrix(this.prevWeightUpdatesList.get(matrixIdx).getRowCount(),
                                this.prevWeightUpdatesList.get(matrixIdx).getColumnCount()));
            }
            this.weightMatrixLists.add(aWeightMatrixList);
            this.prevWeightUpdatesLists.add(aPrevWeightUpdatesList);
        }
        // add matrix of last step
        this.weightMatrixLists.add(this.weightMatrixList);
        this.prevWeightUpdatesLists.add(this.prevWeightUpdatesList);
        this.weightMatrixList = Lists.newArrayList();
        this.prevWeightUpdatesList = Lists.newArrayList();
    }

    @Override
    /**
     * {@inheritDoc}
     */
    public int addLayer(int size, boolean isFinalLayer, FloatFunction squashingFunction,
            Class<? extends Neuron> neuronClass) {
        return addLayer(size, isFinalLayer, squashingFunction, neuronClass, null, true);
    }

    public int addLayer(int size, boolean isFinalLayer, FloatFunction squashingFunction,
            Class<? extends Neuron> neuronClass, int numOutCells) {
        if (isFinalLayer)
            this.numOutCells = (numOutCells == 0 ? this.recurrentStepSize : numOutCells);
        return addLayer(size, isFinalLayer, squashingFunction, neuronClass, null, false);
    }

    public int addLayer(int size, boolean isFinalLayer, FloatFunction squashingFunction,
            Class<? extends Neuron> neuronClass, Class<? extends IntermediateOutput> interlayer,
            boolean isRecurrent) {
        Preconditions.checkArgument(size > 0, "Size of layer must be larger than 0.");
        if (!isFinalLayer) {
            if (this.layerSizeList.size() == 0) {
                this.recurrentLayerList.add(false);
                LOG.info("add input layer: " + size + " neurons");
            } else {
                this.recurrentLayerList.add(isRecurrent);
                LOG.info("add hidden layer: " + size + " neurons");
            }
            size += 1;
        } else {
            this.recurrentLayerList.add(false);
        }

        this.layerSizeList.add(size);
        int layerIdx = this.layerSizeList.size() - 1;
        if (isFinalLayer) {
            this.finalLayerIdx = layerIdx;
            LOG.info("add output layer: " + size + " neurons");
        }

        // add weights between current layer and previous layer, and input layer has
        // no squashing function
        if (layerIdx > 0) {
            int sizePrevLayer = this.layerSizeList.get(layerIdx - 1);
            // row count equals to size of current size and column count equals to
            // size of previous layer
            int row = isFinalLayer ? size : size - 1;
            // expand matrix for recurrent layer
            int col = !(this.recurrentLayerList.get(layerIdx)) ? sizePrevLayer
                    : sizePrevLayer + this.layerSizeList.get(layerIdx) - 1;

            FloatMatrix weightMatrix = new DenseFloatMatrix(row, col);
            // initialize weights
            weightMatrix.applyToElements(new FloatFunction() {
                @Override
                public float apply(float value) {
                    return RandomUtils.nextFloat() - 0.5f;
                }

                @Override
                public float applyDerivative(float value) {
                    throw new UnsupportedOperationException("");
                }
            });
            this.weightMatrixList.add(weightMatrix);
            this.prevWeightUpdatesList.add(new DenseFloatMatrix(row, col));
            this.squashingFunctionList.add(squashingFunction);

            this.neuronClassList.add(neuronClass);
        }
        return layerIdx;
    }

    /**
     * Update the weight matrices with given matrices.
     * 
     * @param matrices
     */
    public void updateWeightMatrices(FloatMatrix[] matrices) {
        int matrixIdx = 0;
        for (List<FloatMatrix> aWeightMatrixList : this.weightMatrixLists) {
            for (int weightMatrixIdx = 0; weightMatrixIdx < aWeightMatrixList.size(); weightMatrixIdx++) {
                FloatMatrix matrix = aWeightMatrixList.get(weightMatrixIdx);
                aWeightMatrixList.set(weightMatrixIdx, matrix.add(matrices[matrixIdx++]));
            }
        }
    }

    /**
     * Set the previous weight matrices.
     * 
     * @param prevUpdates
     */
    void setPrevWeightMatrices(FloatMatrix[] prevUpdates) {
        int matrixIdx = 0;
        for (List<FloatMatrix> aWeightUpdateMatrixList : this.prevWeightUpdatesLists) {
            for (int weightMatrixIdx = 0; weightMatrixIdx < aWeightUpdateMatrixList.size(); weightMatrixIdx++) {
                aWeightUpdateMatrixList.set(weightMatrixIdx, prevUpdates[matrixIdx++]);
            }
        }
    }

    /**
     * Add a batch of matrices onto the given destination matrices.
     * 
     * @param destMatrices
     * @param sourceMatrices
     */
    static void matricesAdd(FloatMatrix[] destMatrices, FloatMatrix[] sourceMatrices) {
        for (int i = 0; i < destMatrices.length; ++i) {
            destMatrices[i] = destMatrices[i].add(sourceMatrices[i]);
        }
    }

    /**
     * Get all the weight matrices.
     * 
     * @return The matrices in form of matrix array.
     */
    FloatMatrix[] getWeightMatrices() {
        FloatMatrix[] matrices = new FloatMatrix[this.getSizeOfWeightmatrix()];
        int matrixIdx = 0;
        for (List<FloatMatrix> aWeightMatrixList : this.weightMatrixLists) {
            for (FloatMatrix aWeightMatrix : aWeightMatrixList) {
                matrices[matrixIdx++] = aWeightMatrix;
            }
        }
        return matrices;
    }

    /**
     * Set the weight matrices.
     * 
     * @param matrices
     */
    public void setWeightMatrices(FloatMatrix[] matrices) {
        int matrixIdx = 0;
        for (List<FloatMatrix> aWeightMatrixList : this.weightMatrixLists) {
            for (int weightMatrixIdx = 0; weightMatrixIdx < aWeightMatrixList.size(); weightMatrixIdx++) {
                aWeightMatrixList.set(weightMatrixIdx, matrices[matrixIdx++]);
            }
        }
    }

    /**
     * Get the previous matrices updates in form of array.
     * 
     * @return The matrices in form of matrix array.
     */
    public FloatMatrix[] getPrevMatricesUpdates() {
        FloatMatrix[] matrices = new FloatMatrix[this.getSizeOfWeightmatrix()];
        int matrixIdx = 0;
        for (List<FloatMatrix> aWeightMatrixList : this.prevWeightUpdatesLists) {
            for (FloatMatrix aWeightMatrix : aWeightMatrixList) {
                matrices[matrixIdx++] = aWeightMatrix;
            }
        }
        return matrices;
    }

    public void setWeightMatrix(int index, FloatMatrix matrix) {
        Preconditions.checkArgument(0 <= index && index < this.weightMatrixList.size(),
                String.format("index [%d] should be in range[%d, %d].", index, 0, this.weightMatrixList.size()));
        this.weightMatrixList.set(index, matrix);
    }

    @Override
    public void readFields(DataInput input) throws IOException {
        super.readFields(input);

        this.finalLayerIdx = input.readInt();
        this.dropRate = input.readFloat();

        // read neuron classes
        int neuronClasses = input.readInt();
        this.neuronClassList = Lists.newArrayList();
        for (int i = 0; i < neuronClasses; ++i) {
            try {
                Class<? extends Neuron> clazz = (Class<? extends Neuron>) Class.forName(input.readUTF());
                neuronClassList.add(clazz);
            } catch (ClassNotFoundException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
        }

        // read squash functions
        int squashingFunctionSize = input.readInt();
        this.squashingFunctionList = Lists.newArrayList();
        for (int i = 0; i < squashingFunctionSize; ++i) {
            this.squashingFunctionList.add(FunctionFactory.createFloatFunction(WritableUtils.readString(input)));
        }

        this.recurrentStepSize = input.readInt();
        this.numOutCells = input.readInt();
        int recurrentLayerListSize = input.readInt();
        this.recurrentLayerList = Lists.newArrayList();
        for (int i = 0; i < recurrentLayerListSize; i++) {
            this.recurrentLayerList.add(input.readBoolean());
        }

        // read weights and construct matrices of previous updates
        int numOfMatrices = input.readInt();
        this.weightMatrixLists = Lists.newArrayListWithExpectedSize(this.recurrentStepSize);
        this.prevWeightUpdatesLists = Lists.newArrayList();

        for (int step = 0; step < this.recurrentStepSize; step++) {
            this.weightMatrixList = Lists.newArrayList();
            this.prevWeightUpdatesList = Lists.newArrayList();

            for (int j = 0; j < this.layerSizeList.size() - 2; j++) {
                FloatMatrix matrix = FloatMatrixWritable.read(input);
                this.weightMatrixList.add(matrix);
                this.prevWeightUpdatesList.add(new DenseFloatMatrix(matrix.getRowCount(), matrix.getColumnCount()));
            }
            // if the cell has output layer, read from input
            if (step >= this.recurrentStepSize - this.numOutCells) {
                FloatMatrix matrix = FloatMatrixWritable.read(input);
                this.weightMatrixList.add(matrix);
                this.prevWeightUpdatesList.add(new DenseFloatMatrix(matrix.getRowCount(), matrix.getColumnCount()));
            }
            this.weightMatrixLists.add(this.weightMatrixList);
            this.prevWeightUpdatesLists.add(this.prevWeightUpdatesList);
        }
    }

    //  }
    protected int getSizeOfWeightmatrix() {
        return this.recurrentStepSize * (this.layerSizeList.size() - 2) + this.numOutCells;
    }

    @Override
    public void write(DataOutput output) throws IOException {
        super.write(output);
        output.writeInt(finalLayerIdx);
        output.writeFloat(dropRate);

        // write neuron classes
        output.writeInt(this.neuronClassList.size());
        for (Class<? extends Neuron> clazz : this.neuronClassList) {
            output.writeUTF(clazz.getName());
        }

        // write squashing functions
        output.writeInt(this.squashingFunctionList.size());
        for (FloatFunction aSquashingFunctionList : this.squashingFunctionList) {
            WritableUtils.writeString(output, aSquashingFunctionList.getFunctionName());
        }

        // write recurrent step size
        output.writeInt(this.recurrentStepSize);

        // write recurrent step size
        output.writeInt(this.numOutCells);

        // write recurrent layer list
        output.writeInt(this.recurrentLayerList.size());
        for (Boolean isReccurentLayer : recurrentLayerList) {
            output.writeBoolean(isReccurentLayer);
        }

        // write weight matrices
        output.writeInt(this.getSizeOfWeightmatrix());
        for (List<FloatMatrix> aWeightMatrixLists : this.weightMatrixLists) {
            for (FloatMatrix aWeightMatrixList : aWeightMatrixLists) {
                FloatMatrixWritable.write(aWeightMatrixList, output);
            }
        }

        // DO NOT WRITE WEIGHT UPDATE
    }

    @Override
    public FloatMatrix getWeightsByLayer(int layerIdx) {
        return this.weightMatrixList.get(layerIdx);
    }

    public FloatMatrix getWeightsByLayer(int stepIdx, int layerIdx) {
        return this.weightMatrixLists.get(stepIdx).get(layerIdx);
    }

    /**
     * Get the output of the model according to given feature instance.
     */
    @Override
    public FloatVector getOutput(FloatVector instance) {
        Preconditions.checkArgument(
                (this.layerSizeList.get(0) - 1) * this.recurrentStepSize == instance.getDimension(),
                String.format("The dimension of input instance should be %d.", this.layerSizeList.get(0) - 1));
        // transform the features to another space
        FloatVector transformedInstance = this.featureTransformer.transform(instance);
        // add bias feature
        FloatVector instanceWithBias = new DenseFloatVector(transformedInstance.getDimension() + 1);
        instanceWithBias.set(0, 0.99999f); // set bias to be a little bit less than
                                           // 1.0
        for (int i = 1; i < instanceWithBias.getDimension(); ++i) {
            instanceWithBias.set(i, transformedInstance.get(i - 1));
        }
        // return the output of the last layer
        return getOutputInternal(instanceWithBias);
    }

    public void setDropRateOfInputLayer(float dropRate) {
        this.dropRate = dropRate;
    }

    /**
     * Calculate output internally, the intermediate output of each layer will be
     * stored.
     * 
     * @param instanceWithBias The instance contains the features.
     * @return Cached output of each layer.
     */
    public FloatVector getOutputInternal(FloatVector instanceWithBias) {
        // sets the output of input layer
        Neuron[] inputLayer;
        for (int stepIdx = 0; stepIdx < this.weightMatrixLists.size(); stepIdx++) {
            inputLayer = neuronLists.get(stepIdx).get(0);
            for (int inputNeuronIdx = 0; inputNeuronIdx < this.layerSizeList.get(0); inputNeuronIdx++) {
                float m2 = MathUtils.getBinomial(1, dropRate);
                if (m2 == 0)
                    inputLayer[inputNeuronIdx].setDrop(true);
                else
                    inputLayer[inputNeuronIdx].setDrop(false);
                inputLayer[inputNeuronIdx]
                        .setOutput(instanceWithBias.get(stepIdx * this.layerSizeList.get(0) + inputNeuronIdx) * m2);
            }
            // loop forward as much as recurrent step size
            this.weightMatrixList = this.weightMatrixLists.get(stepIdx);
            for (int layerIdx = 0; layerIdx < weightMatrixList.size(); ++layerIdx) {
                forward(stepIdx, layerIdx);
            }
        }

        // output for each recurrent step
        int singleOutputLength = neuronLists.get(this.recurrentStepSize - 1).get(this.finalLayerIdx).length;
        FloatVector output = new DenseFloatVector(singleOutputLength * this.numOutCells);
        int outputNeuronIdx = 0;
        for (int step = this.recurrentStepSize - this.numOutCells; step < this.recurrentStepSize; step++) {
            neurons = neuronLists.get(step);
            for (int neuronIdx = 0; neuronIdx < singleOutputLength; neuronIdx++) {
                output.set(outputNeuronIdx, neurons.get(this.finalLayerIdx)[neuronIdx].getOutput());
                outputNeuronIdx++;
            }
        }

        return output;
    }

    /**
     * @param neuronClass
     * @return a new neuron instance
     */
    @SuppressWarnings({ "rawtypes" })
    public static Neuron newNeuronInstance(Class<? extends Neuron> neuronClass) {
        return (Neuron) ReflectionUtils.newInstance(neuronClass);
    }

    public class InputMessageIterable implements Iterable<Synapse<FloatWritable, FloatWritable>> {
        private int currNeuronID;
        private int prevNeuronID;
        private int end;
        private FloatMatrix weightMat;
        private Neuron[] layer;

        public InputMessageIterable(int fromLayer, int row) {
            this.currNeuronID = row;
            this.prevNeuronID = -1;
            this.end = weightMatrixList.get(fromLayer).getColumnCount() - 1;
            this.weightMat = weightMatrixList.get(fromLayer);
            this.layer = neurons.get(fromLayer);
        }

        @Override
        public Iterator<Synapse<FloatWritable, FloatWritable>> iterator() {
            return new MessageIterator();
        }

        private class MessageIterator implements Iterator<Synapse<FloatWritable, FloatWritable>> {

            @Override
            public boolean hasNext() {
                if (prevNeuronID < end) {
                    return true;
                } else {
                    return false;
                }
            }

            private FloatWritable i = new FloatWritable();
            private FloatWritable w = new FloatWritable();
            private Synapse<FloatWritable, FloatWritable> msg = new Synapse<FloatWritable, FloatWritable>();

            @Override
            public Synapse<FloatWritable, FloatWritable> next() {
                prevNeuronID++;
                i.set(layer[prevNeuronID].getOutput());
                w.set(weightMat.get(currNeuronID, prevNeuronID));
                msg.set(prevNeuronID, i, w);
                return new Synapse<FloatWritable, FloatWritable>(prevNeuronID, i, w);
            }

            @Override
            public void remove() {
            }
        }
    }

    /**
     * Forward the calculation for one layer.
     * 
     * @param fromLayerIdx The index of the previous layer.
     */
    protected void forward(int stepIdx, int fromLayerIdx) {
        neurons = this.neuronLists.get(stepIdx);
        int curLayerIdx = fromLayerIdx + 1;
        // weight matrix for current layer
        FloatMatrix weightMatrix = this.weightMatrixList.get(fromLayerIdx);
        FloatFunction squashingFunction = getSquashingFunction(fromLayerIdx);
        FloatVector vec = new DenseFloatVector(weightMatrix.getRowCount());

        FloatVector inputVector = new DenseFloatVector(neurons.get(fromLayerIdx).length);
        for (int i = 0; i < neurons.get(fromLayerIdx).length; i++) {
            inputVector.set(i, neurons.get(fromLayerIdx)[i].getOutput());
        }

        for (int row = 0; row < weightMatrix.getRowCount(); row++) {
            Neuron n;
            if (curLayerIdx == finalLayerIdx)
                n = neurons.get(curLayerIdx)[row];
            else
                n = neurons.get(curLayerIdx)[row + 1];

            try {
                FloatVector weightVector = weightMatrix.getRowVector(row);
                n.setWeightVector(weightVector);

                ((RecurrentDropoutNeuron) n).setRecurrentDelta(0);
                n.setIterationNumber(iterations);
                n.forward(inputVector);
            } catch (IOException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
            vec.set(row, n.getOutput());
        }

        if (squashingFunction.getFunctionName().equalsIgnoreCase(SoftMax.class.getSimpleName())) {
            IntermediateOutput interlayer = (IntermediateOutput) ReflectionUtils
                    .newInstance(SoftMax.SoftMaxOutputComputer.class);
            try {
                vec = interlayer.interlayer(vec);

                for (int i = 0; i < vec.getDimension(); i++) {
                    neurons.get(curLayerIdx)[i].setOutput(vec.get(i));
                }
            } catch (IOException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
        }

        // add bias
        if (curLayerIdx != finalLayerIdx)
            neurons.get(curLayerIdx)[0].setOutput(1);

        // copy output to next recurrent layer
        if (this.recurrentLayerList.get(curLayerIdx) && stepIdx < this.recurrentStepSize - 1) {
            for (int i = 0; i < vec.getDimension(); i++) {
                this.neuronLists.get(stepIdx + 1).get(fromLayerIdx)[this.layerSizeList.get(fromLayerIdx) + i]
                        .setOutput(vec.get(i));
            }
        }
    }

    /**
     * Train the model online.
     * 
     * @param trainingInstance
     */
    public void trainOnline(FloatVector trainingInstance) {
        FloatMatrix[] updateMatrices = this.trainByInstance(trainingInstance);
        this.updateWeightMatrices(updateMatrices);
    }

    @Override
    public FloatMatrix[] trainByInstance(FloatVector trainingInstance) {
        int inputDimension = (this.layerSizeList.get(0) - 1) * this.recurrentStepSize;
        FloatVector transformedVector = this.featureTransformer
                .transform(trainingInstance.sliceUnsafe(inputDimension));
        int outputDimension;
        FloatVector inputInstance = null;
        FloatVector labels = null;
        if (this.learningStyle == LearningStyle.SUPERVISED) {
            outputDimension = this.layerSizeList.get(this.layerSizeList.size() - 1);
            // validate training instance
            Preconditions.checkArgument((inputDimension + outputDimension == trainingInstance.getDimension()
                    || inputDimension + outputDimension * recurrentStepSize == trainingInstance.getDimension()),
                    String.format("The dimension of training instance is %d, but requires %d.",
                            trainingInstance.getDimension(), inputDimension + outputDimension));

            inputInstance = new DenseFloatVector(this.layerSizeList.get(0) * this.recurrentStepSize);
            // get the features from the transformed vector
            int vecIdx = 0;
            for (int i = 0; i < inputInstance.getLength(); ++i) {
                if (i % this.layerSizeList.get(0) == 0) {
                    inputInstance.set(i, 1); // add bias
                } else {
                    inputInstance.set(i, transformedVector.get(vecIdx));
                    vecIdx++;
                }
            }
            // get the labels from the original training instance
            labels = trainingInstance.sliceUnsafe(transformedVector.getDimension(),
                    trainingInstance.getDimension() - 1);
        } else if (this.learningStyle == LearningStyle.UNSUPERVISED) {
            // labels are identical to input features
            outputDimension = inputDimension;
            // validate training instance
            Preconditions.checkArgument(inputDimension == trainingInstance.getDimension(),
                    String.format("The dimension of training instance is %d, but requires %d.",
                            trainingInstance.getDimension(), inputDimension));

            inputInstance = new DenseFloatVector(this.layerSizeList.get(0) * this.recurrentStepSize);
            // get the features from the transformed vector
            int vecIdx = 0;
            for (int i = 0; i < inputInstance.getLength(); ++i) {
                if (i % this.layerSizeList.get(0) == 0) {
                    inputInstance.set(i, 1); // add bias
                } else {
                    inputInstance.set(i, transformedVector.get(vecIdx));
                    vecIdx++;
                }
            }
            // get the labels by copying the transformed vector
            labels = transformedVector.deepCopy();
        }
        FloatVector output = this.getOutputInternal(inputInstance);
        // get the training error
        calculateTrainingError(labels, output);
        if (this.trainingMethod.equals(TrainingMethod.GRADIENT_DESCENT)) {
            FloatMatrix[] updates = this.trainByInstanceGradientDescent(labels);
            return updates;
        } else {
            throw new IllegalArgumentException(String.format("Training method is not supported."));
        }
    }

    /**
     * Train by gradient descent. Get the updated weights using one training
     * instance.
     * 
     * @param trainingInstance
     * @return The weight update matrices.
     */
    private FloatMatrix[] trainByInstanceGradientDescent(FloatVector labels) {

        // initialize weight update matrices
        DenseFloatMatrix[] weightUpdateMatrices = new DenseFloatMatrix[this.getSizeOfWeightmatrix()];
        int matrixIdx = 0;
        for (List<FloatMatrix> aWeightMatrixList : this.weightMatrixLists) {
            for (FloatMatrix aWeightMatrix : aWeightMatrixList) {
                weightUpdateMatrices[matrixIdx++] = new DenseFloatMatrix(aWeightMatrix.getRowCount(),
                        aWeightMatrix.getColumnCount());
            }
        }
        FloatVector deltaVec = new DenseFloatVector(
                this.layerSizeList.get(layerSizeList.size() - 1) * this.numOutCells);

        FloatFunction squashingFunction = this.squashingFunctionList.get(this.squashingFunctionList.size() - 1);

        int labelIdx = 0;
        // start from last recurrent step to first recurrent step
        for (int step = this.recurrentStepSize - this.numOutCells; step < this.recurrentStepSize; step++) {

            FloatMatrix lastWeightMatrix = this.weightMatrixLists.get(step)
                    .get(this.weightMatrixLists.get(step).size() - 1);
            int neuronIdx = 0;
            for (Neuron aNeurons : this.neuronLists.get(step).get(this.finalLayerIdx)) {
                float finalOut = aNeurons.getOutput();
                float costFuncDerivative = this.costFunction.applyDerivative(labels.get(labelIdx), finalOut);
                // add regularization
                costFuncDerivative += this.regularizationWeight * lastWeightMatrix.getRowVector(neuronIdx).sum();

                if (!squashingFunction.getFunctionName().equalsIgnoreCase(SoftMax.class.getSimpleName())) {
                    costFuncDerivative *= squashingFunction.applyDerivative(finalOut);
                }
                aNeurons.backpropagate(costFuncDerivative);
                deltaVec.set(labelIdx, costFuncDerivative);
                neuronIdx++;
                labelIdx++;
            }
        }

        // start from last recurrent step to first recurrent step
        boolean skipLastLayer = false;
        int weightMatrixIdx = weightUpdateMatrices.length - 1;
        for (int step = this.recurrentStepSize - 1; step >= 0; --step) {
            this.weightMatrixList = this.weightMatrixLists.get(step);
            this.prevWeightUpdatesList = this.prevWeightUpdatesLists.get(step);
            this.neurons = this.neuronLists.get(step);
            if (step < this.recurrentStepSize - this.numOutCells)
                skipLastLayer = true;
            // start from previous layer of output layer
            for (int layer = this.layerSizeList.size() - 2; layer >= 0; --layer) {
                if (skipLastLayer) {
                    skipLastLayer = false;
                    continue;
                }
                backpropagate(step, layer, weightUpdateMatrices[weightMatrixIdx--]);
            }
        }
        // TODO eliminate non-output cells from weightUpdateLists
        this.setPrevWeightMatrices(weightUpdateMatrices);
        return weightUpdateMatrices;
    }

    public class ErrorMessageIterable implements Iterable<Synapse<FloatWritable, FloatWritable>> {
        private int row;
        private int neuronID;
        private int end;
        private FloatMatrix weightMat;
        private FloatMatrix prevWeightMat;

        private float[] nextLayerDelta;

        public ErrorMessageIterable(int recurrentStepIdx, int curLayerIdx, int row) {
            this.row = row;
            this.neuronID = -1;
            this.weightMat = weightMatrixLists.get(recurrentStepIdx).get(curLayerIdx);
            this.end = weightMat.getRowCount() - 1;
            this.prevWeightMat = prevWeightUpdatesLists.get(recurrentStepIdx).get(curLayerIdx);

            Neuron[] nextLayer = neuronLists.get(recurrentStepIdx).get(curLayerIdx + 1);
            nextLayerDelta = new float[weightMat.getRowCount()];

            for (int i = 0; i <= end; ++i) {
                if (curLayerIdx + 1 == finalLayerIdx) {
                    nextLayerDelta[i] = nextLayer[i].getDelta();
                } else {
                    nextLayerDelta[i] = nextLayer[i + 1].getDelta();
                }
            }
        }

        @Override
        public Iterator<Synapse<FloatWritable, FloatWritable>> iterator() {
            return new MessageIterator();
        }

        private class MessageIterator implements Iterator<Synapse<FloatWritable, FloatWritable>> {

            @Override
            public boolean hasNext() {
                if (neuronID < end) {
                    return true;
                } else {
                    return false;
                }
            }

            private FloatWritable d = new FloatWritable();
            private FloatWritable w = new FloatWritable();
            private FloatWritable p = new FloatWritable();
            private Synapse<FloatWritable, FloatWritable> msg = new Synapse<FloatWritable, FloatWritable>();

            @Override
            public Synapse<FloatWritable, FloatWritable> next() {
                neuronID++;

                d.set(nextLayerDelta[neuronID]);
                w.set(weightMat.get(neuronID, row));
                p.set(prevWeightMat.get(neuronID, row));
                msg.set(neuronID, d, w, p);
                return msg;
            }

            @Override
            public void remove() {
            }

        }
    }

    /**
     * Back-propagate the errors to from next layer to current layer. The weight
     * updated information will be stored in the weightUpdateMatrices, and the
     * delta of the prevLayer would be returned.
     * 
     * @param layer Index of current layer.
     */
    private void backpropagate(int recurrentStepIdx, int curLayerIdx, DenseFloatMatrix weightUpdateMatrix) {

        FloatMatrix weightMat = weightMatrixLists.get(recurrentStepIdx).get(curLayerIdx);
        FloatMatrix prevWeightMat = prevWeightUpdatesLists.get(recurrentStepIdx).get(curLayerIdx);

        // get layer related information
        int x = this.weightMatrixList.get(curLayerIdx).getColumnCount();
        int y = this.weightMatrixList.get(curLayerIdx).getRowCount();

        Neuron[] ns = this.neuronLists.get(recurrentStepIdx).get(curLayerIdx);

        for (int row = 0; row < x; ++row) {
            Neuron n = ns[row];
            n.setWeightVector(y);

            try {
                FloatVector weightVector = weightMat.getColumnVector(row);
                n.setWeightVector(weightVector);

                Neuron[] nextLayer = neuronLists.get(recurrentStepIdx).get(curLayerIdx + 1);
                FloatVector deltaVector = new DenseFloatVector(weightVector.getDimension());

                for (int i = 0; i < weightVector.getDimension(); ++i) {
                    if (curLayerIdx + 1 == finalLayerIdx) {
                        deltaVector.set(i, nextLayer[i].getDelta());
                    } else {
                        deltaVector.set(i, nextLayer[i + 1].getDelta());
                    }
                }
                n.setDeltaVector(deltaVector);
                n.setPrevWeightVector(prevWeightMat.getColumnVector(row));

                n.backward(deltaVector);

                if (row >= layerSizeList.get(curLayerIdx) && recurrentStepIdx > 0
                        && recurrentLayerList.get(curLayerIdx + 1)) {
                    Neuron recurrentNeuron = neuronLists.get(recurrentStepIdx - 1).get(curLayerIdx + 1)[row
                            - layerSizeList.get(curLayerIdx) + 1];
                    recurrentNeuron.backpropagate(n.getDelta());
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
            // update weights
            weightUpdateMatrix.setColumn(row, n.getUpdates());
        }
    }

    @Override
    protected BSPJob trainInternal(HamaConfiguration conf)
            throws IOException, InterruptedException, ClassNotFoundException {
        this.conf = conf;
        this.fs = FileSystem.get(conf);

        String modelPath = conf.get("model.path");
        if (modelPath != null) {
            this.modelPath = modelPath;
        }
        // modelPath must be set before training
        if (this.modelPath == null) {
            throw new IllegalArgumentException("Please specify the modelPath for model, "
                    + "either through setModelPath() or add 'modelPath' to the training parameters.");
        }
        this.setRecurrentStepSize(conf.getInt("training.recurrent.step.size", 1));
        this.initializeWeightMatrixLists();
        this.writeModelToFile();

        // create job
        BSPJob job = new BSPJob(conf, RecurrentLayeredNeuralNetworkTrainer.class);
        job.setJobName("Neural Network training");
        job.setJarByClass(RecurrentLayeredNeuralNetworkTrainer.class);
        job.setBspClass(RecurrentLayeredNeuralNetworkTrainer.class);

        job.getConfiguration().setInt(Constants.ADDITIONAL_BSP_TASKS, 1);

        job.setBoolean("training.mode", true);
        job.setInputPath(new Path(conf.get("training.input.path")));
        job.setInputFormat(org.apache.hama.bsp.SequenceFileInputFormat.class);
        job.setInputKeyClass(LongWritable.class);
        job.setInputValueClass(VectorWritable.class);
        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(NullWritable.class);
        job.setOutputFormat(org.apache.hama.bsp.NullOutputFormat.class);

        return job;
    }

    @Override
    protected void calculateTrainingError(FloatVector labels, FloatVector output) {
        FloatVector errors = labels.deepCopy().applyToElements(output, this.costFunction);
        this.trainingError = errors.sum();
    }

    /**
     * Get the squashing function of a specified layer.
     * 
     * @param idx
     * @return a new vector with the result of the operation.
     */
    public FloatFunction getSquashingFunction(int idx) {
        return this.squashingFunctionList.get(idx);
    }

    public void setIterationNumber(long iterations) {
        this.iterations = iterations;
    }

    public void setRecurrentStepSize(int recurrentStepSize) {
        this.recurrentStepSize = recurrentStepSize;
    }

}