org.apache.hama.ml.ann.SmallLayeredNeuralNetwork.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.hama.ml.ann.SmallLayeredNeuralNetwork.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hama.ml.ann;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang.math.RandomUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hama.HamaConfiguration;
import org.apache.hama.bsp.BSPJob;
import org.apache.hama.commons.io.MatrixWritable;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseDoubleMatrix;
import org.apache.hama.commons.math.DenseDoubleVector;
import org.apache.hama.commons.math.DoubleFunction;
import org.apache.hama.commons.math.DoubleMatrix;
import org.apache.hama.commons.math.DoubleVector;
import org.apache.hama.commons.math.FunctionFactory;
import org.mortbay.log.Log;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;

/**
 * SmallLayeredNeuralNetwork defines the general operations for derivative
 * layered models, including Linear Regression, Logistic Regression, Multilayer
 * Perceptron, Autoencoder, and Restricted Boltzmann Machine, etc. For
 * SmallLayeredNeuralNetwork, the training can be conducted in parallel, but the
 * parameters of the model are assumed to be stored in a single machine.
 * 
 * In general, these models consist of neurons which are aligned in layers.
 * Between layers, for any two adjacent layers, the neurons are connected to
 * form a bipartite weighted graph.
 * 
 */
public class SmallLayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {

    /** Weights between neurons at adjacent layers; entry i connects layer i to layer i + 1. */
    protected List<DoubleMatrix> weightMatrixList;

    /** Previous weight updates (used for the momentum term), one per weight matrix. */
    protected List<DoubleMatrix> prevWeightUpdatesList;

    /**
     * Squashing functions; entry i is applied to the output of weight matrix i,
     * i.e. to the neurons of layer i + 1. The input layer has none.
     */
    protected List<DoubleFunction> squashingFunctionList;

    /** Index of the layer most recently added as the final (output) layer. */
    protected int finalLayerIdx;

    /** Creates an empty network without any layers. */
    public SmallLayeredNeuralNetwork() {
        this.layerSizeList = Lists.newArrayList();
        this.weightMatrixList = Lists.newArrayList();
        this.prevWeightUpdatesList = Lists.newArrayList();
        this.squashingFunctionList = Lists.newArrayList();
    }

    /**
     * Creates a network associated with the model stored at the given path.
     * NOTE(review): the layer/weight lists are presumably populated by the
     * superclass constructor (via readFields); confirm.
     *
     * @param modelPath location of a previously saved model
     */
    public SmallLayeredNeuralNetwork(String modelPath) {
        super(modelPath);
    }

    /**
     * {@inheritDoc}
     *
     * <p>A non-final layer is given one extra neuron for the bias. For every layer
     * after the input layer, a weight matrix connecting it to the previous layer is
     * appended, initialized with {@code RandomUtils.nextDouble() - 0.5}, together
     * with a zero "previous update" matrix and the given squashing function.
     */
    @Override
    public int addLayer(int size, boolean isFinalLayer, DoubleFunction squashingFunction) {
        Preconditions.checkArgument(size > 0, "Size of layer must be larger than 0.");
        if (!isFinalLayer) {
            size += 1; // reserve one neuron for the bias
        }

        this.layerSizeList.add(size);
        int layerIdx = this.layerSizeList.size() - 1;
        if (isFinalLayer) {
            this.finalLayerIdx = layerIdx;
        }

        // add weights between current layer and previous layer; the input layer has
        // no incoming weights and no squashing function
        if (layerIdx > 0) {
            int sizePrevLayer = this.layerSizeList.get(layerIdx - 1);
            // rows: neurons of this layer that receive input (the bias neuron does not);
            // columns: every neuron of the previous layer, including its bias
            int row = isFinalLayer ? size : size - 1;
            int col = sizePrevLayer;
            DoubleMatrix weightMatrix = new DenseDoubleMatrix(row, col);
            // initialize weights
            weightMatrix.applyToElements(new DoubleFunction() {
                @Override
                public double apply(double value) {
                    return RandomUtils.nextDouble() - 0.5;
                }

                @Override
                public double applyDerivative(double value) {
                    throw new UnsupportedOperationException("");
                }
            });
            this.weightMatrixList.add(weightMatrix);
            this.prevWeightUpdatesList.add(new DenseDoubleMatrix(row, col));
            this.squashingFunctionList.add(squashingFunction);
        }
        return layerIdx;
    }

    /**
     * Adds the given update matrices element-wise onto the current weight matrices.
     * 
     * @param matrices update matrices; entry i is added to weight matrix i
     */
    public void updateWeightMatrices(DoubleMatrix[] matrices) {
        for (int i = 0; i < matrices.length; ++i) {
            DoubleMatrix matrix = this.weightMatrixList.get(i);
            this.weightMatrixList.set(i, matrix.add(matrices[i]));
        }
    }

    /**
     * Replaces the previous weight updates (used for momentum) with the given matrices.
     *
     * @param prevUpdates the new previous-update matrices
     */
    void setPrevWeightMatrices(DoubleMatrix[] prevUpdates) {
        this.prevWeightUpdatesList.clear();
        Collections.addAll(this.prevWeightUpdatesList, prevUpdates);
    }

    /**
     * Adds a batch of matrices onto the given destination matrices, in place in the
     * destination array.
     * 
     * @param destMatrices matrices to accumulate into; entries are replaced by the sums
     * @param sourceMatrices matrices to add; must have at least as many entries
     */
    static void matricesAdd(DoubleMatrix[] destMatrices, DoubleMatrix[] sourceMatrices) {
        for (int i = 0; i < destMatrices.length; ++i) {
            destMatrices[i] = destMatrices[i].add(sourceMatrices[i]);
        }
    }

    /**
     * Gets all the weight matrices.
     * 
     * @return the matrices in form of a matrix array (the matrices themselves are not copied)
     */
    DoubleMatrix[] getWeightMatrices() {
        DoubleMatrix[] matrices = new DoubleMatrix[this.weightMatrixList.size()];
        this.weightMatrixList.toArray(matrices);
        return matrices;
    }

    /**
     * Replaces all weight matrices with the given ones.
     * 
     * @param matrices the new weight matrices, in layer order
     */
    public void setWeightMatrices(DoubleMatrix[] matrices) {
        this.weightMatrixList = new ArrayList<DoubleMatrix>();
        Collections.addAll(this.weightMatrixList, matrices);
    }

    /**
     * Gets the previous weight updates in form of an array.
     * 
     * @return the matrices in form of a matrix array (the matrices themselves are not copied)
     */
    public DoubleMatrix[] getPrevMatricesUpdates() {
        return this.prevWeightUpdatesList.toArray(new DoubleMatrix[this.prevWeightUpdatesList.size()]);
    }

    /**
     * Replaces a single weight matrix.
     *
     * @param index index of the weight matrix, in {@code [0, number of weight matrices)}
     * @param matrix the new weight matrix
     * @throws IllegalArgumentException if {@code index} is out of range
     */
    public void setWeightMatrix(int index, DoubleMatrix matrix) {
        Preconditions.checkArgument(0 <= index && index < this.weightMatrixList.size(),
                "index [%s] should be in range [%s, %s).", index, 0, this.weightMatrixList.size());
        this.weightMatrixList.set(index, matrix);
    }

    @Override
    public void readFields(DataInput input) throws IOException {
        super.readFields(input);

        // read squash functions
        int squashingFunctionSize = input.readInt();
        this.squashingFunctionList = Lists.newArrayList();
        for (int i = 0; i < squashingFunctionSize; ++i) {
            this.squashingFunctionList.add(FunctionFactory.createDoubleFunction(WritableUtils.readString(input)));
        }

        // read weights and construct zeroed matrices of previous updates
        // (previous updates are never serialized, see write())
        int numOfMatrices = input.readInt();
        this.weightMatrixList = Lists.newArrayList();
        this.prevWeightUpdatesList = Lists.newArrayList();
        for (int i = 0; i < numOfMatrices; ++i) {
            DoubleMatrix matrix = MatrixWritable.read(input);
            this.weightMatrixList.add(matrix);
            this.prevWeightUpdatesList.add(new DenseDoubleMatrix(matrix.getRowCount(), matrix.getColumnCount()));
        }

    }

    @Override
    public void write(DataOutput output) throws IOException {
        super.write(output);

        // write squashing functions
        output.writeInt(this.squashingFunctionList.size());
        for (DoubleFunction aSquashingFunctionList : this.squashingFunctionList) {
            WritableUtils.writeString(output, aSquashingFunctionList.getFunctionName());
        }

        // write weight matrices
        output.writeInt(this.weightMatrixList.size());
        for (DoubleMatrix aWeightMatrixList : this.weightMatrixList) {
            MatrixWritable.write(aWeightMatrixList, output);
        }

        // DO NOT WRITE WEIGHT UPDATE
    }

    @Override
    public DoubleMatrix getWeightsByLayer(int layerIdx) {
        return this.weightMatrixList.get(layerIdx);
    }

    /**
     * Gets the output of the model for the given feature instance.
     *
     * @param instance raw features; its dimension must equal the input layer size
     *        without the bias neuron
     * @return the output of the final layer, without the bias element
     * @throws IllegalArgumentException if the instance has the wrong dimension
     */
    @Override
    public DoubleVector getOutput(DoubleVector instance) {
        int inputDimension = this.layerSizeList.get(0) - 1;
        // message template is formatted lazily, keeping the per-instance path cheap
        Preconditions.checkArgument(inputDimension == instance.getDimension(),
                "The dimension of input instance should be %s.", inputDimension);
        // transform the features to another space
        DoubleVector transformedInstance = this.featureTransformer.transform(instance);
        // add bias feature
        DoubleVector instanceWithBias = new DenseDoubleVector(transformedInstance.getDimension() + 1);
        // NOTE(review): training (trainByInstance) and the hidden layers (forward) use a
        // bias of exactly 1.0; this slightly smaller value is kept from the original code.
        instanceWithBias.set(0, 0.99999);
        for (int i = 1; i < instanceWithBias.getDimension(); ++i) {
            instanceWithBias.set(i, transformedInstance.get(i - 1));
        }

        List<DoubleVector> outputCache = getOutputInternal(instanceWithBias);
        // return the output of the last layer
        DoubleVector result = outputCache.get(outputCache.size() - 1);
        // remove bias
        return result.sliceUnsafe(1, result.getDimension() - 1);
    }

    /**
     * Calculates the output internally; the intermediate output of each layer is
     * stored.
     * 
     * @param instanceWithBias the instance features, with the bias at index 0
     * @return the output of each layer, starting with the input itself; every
     *         entry after the input has a bias element prepended by forward()
     */
    public List<DoubleVector> getOutputInternal(DoubleVector instanceWithBias) {
        List<DoubleVector> outputCache = new ArrayList<DoubleVector>();
        // fill with instance
        DoubleVector intermediateOutput = instanceWithBias;
        outputCache.add(intermediateOutput);

        for (int i = 0; i < this.layerSizeList.size() - 1; ++i) {
            intermediateOutput = forward(i, intermediateOutput);
            outputCache.add(intermediateOutput);
        }
        return outputCache;
    }

    /**
     * Forwards the calculation for one layer: multiplies by the weight matrix,
     * applies the squashing function and prepends a bias element of 1.
     * 
     * @param fromLayer the index of the previous layer
     * @param intermediateOutput the output of the previous layer, including its bias
     * @return a new vector with the result of the operation
     */
    protected DoubleVector forward(int fromLayer, DoubleVector intermediateOutput) {
        DoubleMatrix weightMatrix = this.weightMatrixList.get(fromLayer);

        DoubleVector vec = weightMatrix.multiplyVectorUnsafe(intermediateOutput);
        vec = vec.applyToElements(this.squashingFunctionList.get(fromLayer));

        // add bias
        DoubleVector vecWithBias = new DenseDoubleVector(vec.getDimension() + 1);
        vecWithBias.set(0, 1);
        for (int i = 0; i < vec.getDimension(); ++i) {
            vecWithBias.set(i + 1, vec.get(i));
        }
        return vecWithBias;
    }

    /**
     * Trains the model online: computes the updates for one instance and applies
     * them immediately.
     * 
     * @param trainingInstance the training instance (features, plus labels when supervised)
     */
    public void trainOnline(DoubleVector trainingInstance) {
        DoubleMatrix[] updateMatrices = this.trainByInstance(trainingInstance);
        this.updateWeightMatrices(updateMatrices);
    }

    /**
     * {@inheritDoc}
     *
     * <p>For supervised learning the instance holds the features followed by the
     * labels; for unsupervised learning it holds only the features, which also
     * serve (after transformation) as the labels.
     *
     * @throws IllegalArgumentException if the instance has the wrong dimension or the
     *         training method is not supported
     * @throws IllegalStateException if the learning style is not supported
     */
    @Override
    public DoubleMatrix[] trainByInstance(DoubleVector trainingInstance) {
        int inputDimension = this.layerSizeList.get(0) - 1;
        boolean supervised = this.learningStyle == LearningStyle.SUPERVISED;
        int expectedDimension;
        if (supervised) {
            expectedDimension = inputDimension + this.layerSizeList.get(this.layerSizeList.size() - 1);
        } else if (this.learningStyle == LearningStyle.UNSUPERVISED) {
            // labels are identical to input features
            expectedDimension = inputDimension;
        } else {
            throw new IllegalStateException("Unsupported learning style: " + this.learningStyle);
        }
        // validate before slicing so a short instance yields this message instead of
        // an index error
        Preconditions.checkArgument(expectedDimension == trainingInstance.getDimension(),
                "The dimension of training instance is %s, but requires %s.",
                trainingInstance.getDimension(), expectedDimension);

        DoubleVector transformedVector = this.featureTransformer
                .transform(trainingInstance.sliceUnsafe(inputDimension));
        DoubleVector inputInstance = buildInputWithBias(transformedVector, inputDimension);
        DoubleVector labels = supervised
                // labels come from the original (untransformed) training instance
                ? trainingInstance.sliceUnsafe(inputDimension, trainingInstance.getDimension() - 1)
                // labels are a copy of the transformed features
                : transformedVector.deepCopy();

        List<DoubleVector> internalResults = this.getOutputInternal(inputInstance);
        DoubleVector output = internalResults.get(internalResults.size() - 1);

        // get the training error
        calculateTrainingError(labels, output.deepCopy().sliceUnsafe(1, output.getDimension() - 1));

        if (this.trainingMethod.equals(TrainingMethod.GRADIENT_DESCENT)) {
            return this.trainByInstanceGradientDescent(labels, internalResults);
        } else {
            throw new IllegalArgumentException("Training method is not supported.");
        }
    }

    /**
     * Builds the input-layer vector: a bias of 1 followed by the first
     * {@code inputDimension} entries of {@code features}.
     */
    private DoubleVector buildInputWithBias(DoubleVector features, int inputDimension) {
        DoubleVector inputInstance = new DenseDoubleVector(inputDimension + 1);
        inputInstance.set(0, 1); // add bias
        for (int i = 0; i < inputDimension; ++i) {
            inputInstance.set(i + 1, features.get(i));
        }
        return inputInstance;
    }

    /**
     * Trains by gradient descent: computes the weight updates for one training
     * instance and records them as the previous updates (for momentum).
     * 
     * @param labels the expected output of the final layer
     * @param internalResults the output of each layer as produced by
     *        getOutputInternal; every entry starts with a bias element
     * @return the weight update matrices, one per weight matrix
     */
    private DoubleMatrix[] trainByInstanceGradientDescent(DoubleVector labels, List<DoubleVector> internalResults) {

        DoubleVector output = internalResults.get(internalResults.size() - 1);
        // initialize weight update matrices
        DenseDoubleMatrix[] weightUpdateMatrices = new DenseDoubleMatrix[this.weightMatrixList.size()];
        for (int m = 0; m < weightUpdateMatrices.length; ++m) {
            weightUpdateMatrices[m] = new DenseDoubleMatrix(this.weightMatrixList.get(m).getRowCount(),
                    this.weightMatrixList.get(m).getColumnCount());
        }
        DoubleVector deltaVec = new DenseDoubleVector(this.layerSizeList.get(this.layerSizeList.size() - 1));

        DoubleFunction squashingFunction = this.squashingFunctionList.get(this.squashingFunctionList.size() - 1);

        DoubleMatrix lastWeightMatrix = this.weightMatrixList.get(this.weightMatrixList.size() - 1);
        // delta of the output layer; output index i + 1 skips the bias element
        for (int i = 0; i < deltaVec.getDimension(); ++i) {
            double costFuncDerivative = this.costFunction.applyDerivative(labels.get(i), output.get(i + 1));
            // add regularization
            costFuncDerivative += this.regularizationWeight * lastWeightMatrix.getRowVector(i).sum();
            deltaVec.set(i, costFuncDerivative * squashingFunction.applyDerivative(output.get(i + 1)));
        }

        // start from previous layer of output layer
        for (int layer = this.layerSizeList.size() - 2; layer >= 0; --layer) {
            deltaVec = backpropagate(layer, deltaVec, internalResults, weightUpdateMatrices[layer]);
        }

        this.setPrevWeightMatrices(weightUpdateMatrices);

        return weightUpdateMatrices;
    }

    /**
     * Back-propagates the error from layer {@code curLayerIdx + 1} to layer
     * {@code curLayerIdx}. The weight updates for the matrix between these two
     * layers are written into {@code weightUpdateMatrix}.
     * 
     * @param curLayerIdx index of the current layer
     * @param nextLayerDelta delta of the next layer; contains a bias entry unless
     *        the next layer is the output layer
     * @param outputCache output of each layer as produced by getOutputInternal
     * @param weightUpdateMatrix destination for the computed weight updates
     * @return the delta of the current layer, including its bias entry
     */
    private DoubleVector backpropagate(int curLayerIdx, DoubleVector nextLayerDelta, List<DoubleVector> outputCache,
            DenseDoubleMatrix weightUpdateMatrix) {

        // get layer related information
        // NOTE(review): squashingFunctionList.get(curLayerIdx) is the function of layer
        // curLayerIdx + 1, so the derivative below uses the next layer's function; this
        // only matters when layers use different squashing functions — confirm intent.
        DoubleFunction squashingFunction = this.squashingFunctionList.get(curLayerIdx);
        DoubleVector curLayerOutput = outputCache.get(curLayerIdx);
        DoubleMatrix weightMatrix = this.weightMatrixList.get(curLayerIdx);
        DoubleMatrix prevWeightMatrix = this.prevWeightUpdatesList.get(curLayerIdx);

        // next layer is not output layer, remove the delta of bias neuron
        if (curLayerIdx != this.layerSizeList.size() - 2) {
            nextLayerDelta = nextLayerDelta.slice(1, nextLayerDelta.getDimension() - 1);
        }

        DoubleVector delta = weightMatrix.transpose().multiplyVector(nextLayerDelta);
        for (int i = 0; i < delta.getDimension(); ++i) {
            delta.set(i, delta.get(i) * squashingFunction.applyDerivative(curLayerOutput.get(i)));
        }

        // update weights: gradient step plus momentum from the previous update
        for (int i = 0; i < weightUpdateMatrix.getRowCount(); ++i) {
            for (int j = 0; j < weightUpdateMatrix.getColumnCount(); ++j) {
                weightUpdateMatrix.set(i, j, -learningRate * nextLayerDelta.get(i) * curLayerOutput.get(j)
                        + this.momentumWeight * prevWeightMatrix.get(i, j));
            }
        }

        return delta;
    }

    /**
     * Writes the current model to the model path, runs a BSP training job over the
     * given data, and reloads the learned model.
     *
     * @throws IllegalArgumentException if no model path is available
     * @throws IOException if the training job does not complete successfully, or on I/O failure
     */
    @Override
    protected void trainInternal(Path dataInputPath, Map<String, String> trainingParams)
            throws IOException, InterruptedException, ClassNotFoundException {
        // add all training parameters to configuration
        Configuration conf = new Configuration();
        for (Map.Entry<String, String> entry : trainingParams.entrySet()) {
            conf.set(entry.getKey(), entry.getValue());
        }

        // if training parameters contains the model path, update the model path
        String modelPath = trainingParams.get("modelPath");
        if (modelPath != null) {
            this.modelPath = modelPath;
        }
        // modelPath must be set before training
        if (this.modelPath == null) {
            throw new IllegalArgumentException("Please specify the modelPath for model, "
                    + "either through setModelPath() or add 'modelPath' to the training parameters.");
        }

        conf.set("modelPath", this.modelPath);
        this.writeModelToFile();

        HamaConfiguration hamaConf = new HamaConfiguration(conf);

        // create job
        BSPJob job = new BSPJob(hamaConf, SmallLayeredNeuralNetworkTrainer.class);
        job.setJobName("Small scale Neural Network training");
        job.setJarByClass(SmallLayeredNeuralNetworkTrainer.class);
        job.setBspClass(SmallLayeredNeuralNetworkTrainer.class);
        job.setInputPath(dataInputPath);
        job.setInputFormat(org.apache.hama.bsp.SequenceFileInputFormat.class);
        job.setInputKeyClass(LongWritable.class);
        job.setInputValueClass(VectorWritable.class);
        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(NullWritable.class);
        job.setOutputFormat(org.apache.hama.bsp.NullOutputFormat.class);

        int numTasks = conf.getInt("tasks", 1);
        Log.info(String.format("Number of tasks: %d\n", numTasks));
        job.setNumBspTask(numTasks);
        // a failed job would otherwise silently reload the untrained model written above
        if (!job.waitForCompletion(true)) {
            throw new IOException(String.format(
                    "Neural network training job failed; model at %s was not updated.", this.modelPath));
        }

        // reload learned model
        Log.info(String.format("Reload model from %s.", this.modelPath));
        this.readFromModel();

    }

    /**
     * Sets the training error to the sum of the cost function applied element-wise
     * to the labels and the output (presumably cost(label_i, output_i)).
     */
    @Override
    protected void calculateTrainingError(DoubleVector labels, DoubleVector output) {
        DoubleVector errors = labels.deepCopy().applyToElements(output, this.costFunction);
        this.trainingError = errors.sum();
    }

    /**
     * Gets the squashing function at the given position.
     * 
     * @param idx index into the squashing function list; entry idx belongs to
     *        weight matrix idx, i.e. to layer idx + 1
     * @return the squashing function at that position
     */
    public DoubleFunction getSquashingFunction(int idx) {
        return this.squashingFunctionList.get(idx);
    }

}