/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.horn.core;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.commons.lang.math.RandomUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hama.Constants;
import org.apache.hama.HamaConfiguration;
import org.apache.hama.bsp.BSPJob;
import org.apache.hama.commons.io.FloatMatrixWritable;
import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseFloatMatrix;
import org.apache.hama.commons.math.DenseFloatVector;
import org.apache.hama.commons.math.FloatFunction;
import org.apache.hama.commons.math.FloatMatrix;
import org.apache.hama.commons.math.FloatVector;
import org.apache.hama.util.ReflectionUtils;
import org.apache.horn.core.Constants.LearningStyle;
import org.apache.horn.core.Constants.TrainingMethod;
import org.apache.horn.examples.MultiLayerPerceptron.StandardNeuron;
import org.apache.horn.examples.RecurrentDropoutNeuron;
import org.apache.horn.funcs.FunctionFactory;
import org.apache.horn.funcs.IdentityFunction;
import org.apache.horn.funcs.SoftMax;
import org.apache.horn.utils.MathUtils;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;

/**
 * RecurrentLayeredNeuralNetwork defines the general operations for recurrent
 * derivative layered models such as recurrent multilayer perceptrons. The
 * training can be conducted in parallel, but the parameters of the model are
 * assumed to be stored on a single machine.
 *
 * In general, these models consist of neurons which are aligned in layers.
 * Between any two adjacent layers, the neurons are connected to form a
 * bipartite weighted graph; layers marked as recurrent additionally feed
 * their activations forward into the next recurrent step.
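 *
 * A minimal construction sketch (illustrative only; the layer sizes, the
 * "Sigmoid"/"SoftMax" function names, and the neuron classes used here are
 * assumptions of this example, not requirements of the class):
 *
 * <pre>
 * RecurrentLayeredNeuralNetwork ann = new RecurrentLayeredNeuralNetwork();
 * ann.setRecurrentStepSize(3);
 * // input layer: 4 features
 * ann.addLayer(4, false, FunctionFactory.createFloatFunction("Sigmoid"),
 *     RecurrentDropoutNeuron.class, null, false);
 * // recurrent hidden layer: 8 neurons
 * ann.addLayer(8, false, FunctionFactory.createFloatFunction("Sigmoid"),
 *     RecurrentDropoutNeuron.class, null, true);
 * // output layer: 2 neurons, emitted by the last cell only
 * ann.addLayer(2, true, FunctionFactory.createFloatFunction("SoftMax"),
 *     RecurrentDropoutNeuron.class, 1);
 * </pre>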
 */
public class RecurrentLayeredNeuralNetwork extends
    AbstractLayeredNeuralNetwork {

  private static final Log LOG = LogFactory
      .getLog(RecurrentLayeredNeuralNetwork.class);

  /* Weights between neurons at adjacent layers (current recurrent step) */
  protected List<FloatMatrix> weightMatrixList;

  /* Weights between neurons at adjacent layers, one list per recurrent step */
  protected List<List<FloatMatrix>> weightMatrixLists;

  /* Previous weight updates between neurons at adjacent layers */
  protected List<FloatMatrix> prevWeightUpdatesList;
  protected List<List<FloatMatrix>> prevWeightUpdatesLists;

  /* Different layers can have different squashing functions */
  protected List<FloatFunction> squashingFunctionList;

  protected List<Class<? extends Neuron>> neuronClassList;

  /* Records which layers are recurrent */
  protected List<Boolean> recurrentLayerList;

  /* Recurrent step size */
  protected int recurrentStepSize;

  protected int finalLayerIdx;

  private List<Neuron[]> neurons;
  private List<List<Neuron[]>> neuronLists;

  private float dropRate;
  private long iterations;
  private int numOutCells;

  public RecurrentLayeredNeuralNetwork() {
    this.layerSizeList = Lists.newArrayList();
    this.weightMatrixList = Lists.newArrayList();
    this.prevWeightUpdatesList = Lists.newArrayList();
    this.squashingFunctionList = Lists.newArrayList();
    this.neuronClassList = Lists.newArrayList();
    this.weightMatrixLists = Lists.newArrayList();
    this.prevWeightUpdatesLists = Lists.newArrayList();
    this.neuronLists = Lists.newArrayList();
    this.recurrentLayerList = Lists.newArrayList();
  }

  public RecurrentLayeredNeuralNetwork(HamaConfiguration conf, String modelPath) {
    super(conf, modelPath);
    initializeNeurons(false);
    initializeWeightMatrixLists();
  }

  public RecurrentLayeredNeuralNetwork(HamaConfiguration conf,
      String modelPath, boolean isTraining) {
    super(conf, modelPath);
    initializeNeurons(isTraining);
    initializeWeightMatrixLists();
  }

  /**
   * Initialize neuron objects for every recurrent step.
   *
   * @param isTraining whether the neurons are created for training
   */
  private void initializeNeurons(boolean isTraining) {
    this.neuronLists = Lists.newArrayListWithExpectedSize(recurrentStepSize);
    for (int stepIdx = 0; stepIdx < this.recurrentStepSize; stepIdx++) {
      neurons = new ArrayList<Neuron[]>();
      int expectedNeuronsSize = this.layerSizeList.size();
      // cells without an output layer have one layer less
      if (stepIdx < this.recurrentStepSize - this.numOutCells) {
        expectedNeuronsSize--;
      }

      for (int neuronLayerIdx = 0; neuronLayerIdx < expectedNeuronsSize; neuronLayerIdx++) {
        int numOfNeurons = layerSizeList.get(neuronLayerIdx);
        // if not the final layer and the next layer is recurrent, reserve
        // extra slots for the state copied over from the previous step
        if (stepIdx > 0 && neuronLayerIdx < layerSizeList.size() - 1
            && this.recurrentLayerList.get(neuronLayerIdx + 1)) {
          numOfNeurons = numOfNeurons + layerSizeList.get(neuronLayerIdx + 1) - 1;
        }
        Class<? extends Neuron> neuronClass;
        if (neuronLayerIdx == 0)
          neuronClass = StandardNeuron.class; // not actually used by the input layer
        else
          neuronClass = neuronClassList.get(neuronLayerIdx - 1);

        Neuron[] tmp = new Neuron[numOfNeurons];
        for (int neuronIdx = 0; neuronIdx < numOfNeurons; neuronIdx++) {
          Neuron n = newNeuronInstance(neuronClass);
          if (n instanceof RecurrentDropoutNeuron)
            ((RecurrentDropoutNeuron) n).setDropRate(dropRate);

          if (neuronLayerIdx > 0 && neuronIdx < layerSizeList.get(neuronLayerIdx))
            n.setSquashingFunction(squashingFunctionList.get(neuronLayerIdx - 1));
          else
            n.setSquashingFunction(new IdentityFunction());

          n.setLayerIndex(neuronLayerIdx);
          n.setNeuronID(neuronIdx);
          n.setLearningRate(this.learningRate);
          n.setMomentumWeight(this.momentumWeight);
          n.setTraining(isTraining);
          tmp[neuronIdx] = n;
        }
        neurons.add(tmp);
      }
      this.neuronLists.add(neurons);
    }
  }

  /**
   * Initialize the per-step weight matrix lists.
   */
  public void initializeWeightMatrixLists() {
    this.numOutCells = (numOutCells == 0 ? this.recurrentStepSize : numOutCells);
    this.weightMatrixLists.clear();
    this.weightMatrixLists = Lists
        .newArrayListWithExpectedSize(this.recurrentStepSize);
    this.prevWeightUpdatesLists.clear();
    this.prevWeightUpdatesLists = Lists
        .newArrayListWithExpectedSize(this.recurrentStepSize);

    for (int stepIdx = 0; stepIdx < recurrentStepSize - 1; stepIdx++) {
      int expectedMatrixListSize = this.layerSizeList.size() - 1;
      // cells without an output layer have one weight matrix less
      if (stepIdx < this.recurrentStepSize - this.numOutCells) {
        expectedMatrixListSize--;
      }

      List<FloatMatrix> aWeightMatrixList = Lists
          .newArrayListWithExpectedSize(expectedMatrixListSize);
      List<FloatMatrix> aPrevWeightUpdatesList = Lists
          .newArrayListWithExpectedSize(expectedMatrixListSize);

      for (int matrixIdx = 0; matrixIdx < expectedMatrixListSize; matrixIdx++) {
        int rows = this.weightMatrixList.get(matrixIdx).getRowCount();
        int cols = this.weightMatrixList.get(matrixIdx).getColumnCount();
        // the first step has no recurrent input, so its matrices are narrower
        if (stepIdx == 0)
          cols = this.layerSizeList.get(matrixIdx);

        FloatMatrix weightMatrix = new DenseFloatMatrix(rows, cols);
        // initialize weights uniformly in [-0.5, 0.5)
        weightMatrix.applyToElements(new FloatFunction() {
          @Override
          public float apply(float value) {
            return RandomUtils.nextFloat() - 0.5f;
          }

          @Override
          public float applyDerivative(float value) {
            throw new UnsupportedOperationException("");
          }
        });
        aWeightMatrixList.add(weightMatrix);
        aPrevWeightUpdatesList.add(new DenseFloatMatrix(
            this.prevWeightUpdatesList.get(matrixIdx).getRowCount(),
            this.prevWeightUpdatesList.get(matrixIdx).getColumnCount()));
      }
      this.weightMatrixLists.add(aWeightMatrixList);
      this.prevWeightUpdatesLists.add(aPrevWeightUpdatesList);
    }

    // add the matrices of the last step
    this.weightMatrixLists.add(this.weightMatrixList);
    this.prevWeightUpdatesLists.add(this.prevWeightUpdatesList);
    this.weightMatrixList = Lists.newArrayList();
    this.prevWeightUpdatesList = Lists.newArrayList();
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public int addLayer(int size, boolean isFinalLayer,
      FloatFunction squashingFunction, Class<? extends Neuron> neuronClass) {
    return addLayer(size, isFinalLayer, squashingFunction, neuronClass, null,
        true);
  }

  public int addLayer(int size, boolean isFinalLayer,
      FloatFunction squashingFunction, Class<? extends Neuron> neuronClass,
      int numOutCells) {
    if (isFinalLayer)
      this.numOutCells = (numOutCells == 0 ? this.recurrentStepSize
          : numOutCells);
    return addLayer(size, isFinalLayer, squashingFunction, neuronClass, null,
        false);
  }
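  /*
   * Weight matrix shape, as computed in the method below: a non-final layer
   * added with s neurons is stored with size s + 1 (bias included), and its
   * matrix has s rows (the bias output needs no weights) and one column per
   * neuron of the previous stored layer. A recurrent layer gains s extra
   * columns for the previous step's state. Worked example (illustrative
   * numbers only): a recurrent hidden layer added with 8 neurons after an
   * input layer added with 4 features yields an 8 x (5 + 8) = 8 x 13 matrix.
   */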
  public int addLayer(int size, boolean isFinalLayer,
      FloatFunction squashingFunction, Class<? extends Neuron> neuronClass,
      Class<? extends IntermediateOutput> interlayer, boolean isRecurrent) {
    Preconditions.checkArgument(size > 0,
        "Size of layer must be larger than 0.");
    if (!isFinalLayer) {
      if (this.layerSizeList.size() == 0) {
        this.recurrentLayerList.add(false);
        LOG.info("add input layer: " + size + " neurons");
      } else {
        this.recurrentLayerList.add(isRecurrent);
        LOG.info("add hidden layer: " + size + " neurons");
      }
      size += 1; // add a bias unit
    } else {
      this.recurrentLayerList.add(false);
    }

    this.layerSizeList.add(size);
    int layerIdx = this.layerSizeList.size() - 1;
    if (isFinalLayer) {
      this.finalLayerIdx = layerIdx;
      LOG.info("add output layer: " + size + " neurons");
    }

    // add weights between the current layer and the previous layer; the
    // input layer has no squashing function
    if (layerIdx > 0) {
      int sizePrevLayer = this.layerSizeList.get(layerIdx - 1);
      // row count equals the size of the current layer (minus the bias) and
      // column count equals the size of the previous layer
      int row = isFinalLayer ? size : size - 1;
      // expand the matrix for a recurrent layer
      int col = !(this.recurrentLayerList.get(layerIdx)) ? sizePrevLayer
          : sizePrevLayer + this.layerSizeList.get(layerIdx) - 1;
      FloatMatrix weightMatrix = new DenseFloatMatrix(row, col);
      // initialize weights uniformly in [-0.5, 0.5)
      weightMatrix.applyToElements(new FloatFunction() {
        @Override
        public float apply(float value) {
          return RandomUtils.nextFloat() - 0.5f;
        }

        @Override
        public float applyDerivative(float value) {
          throw new UnsupportedOperationException("");
        }
      });
      this.weightMatrixList.add(weightMatrix);
      this.prevWeightUpdatesList.add(new DenseFloatMatrix(row, col));
      this.squashingFunctionList.add(squashingFunction);
      this.neuronClassList.add(neuronClass);
    }
    return layerIdx;
  }

  /**
   * Update the weight matrices with the given matrices.
   *
   * @param matrices the updates to add, in the order of
   *          {@link #getWeightMatrices()}
   */
  public void updateWeightMatrices(FloatMatrix[] matrices) {
    int matrixIdx = 0;
    for (List<FloatMatrix> aWeightMatrixList : this.weightMatrixLists) {
      for (int weightMatrixIdx = 0; weightMatrixIdx < aWeightMatrixList.size(); weightMatrixIdx++) {
        FloatMatrix matrix = aWeightMatrixList.get(weightMatrixIdx);
        aWeightMatrixList.set(weightMatrixIdx,
            matrix.add(matrices[matrixIdx++]));
      }
    }
  }

  /**
   * Set the previous weight update matrices.
   *
   * @param prevUpdates
   */
  void setPrevWeightMatrices(FloatMatrix[] prevUpdates) {
    int matrixIdx = 0;
    for (List<FloatMatrix> aWeightUpdateMatrixList : this.prevWeightUpdatesLists) {
      for (int weightMatrixIdx = 0; weightMatrixIdx < aWeightUpdateMatrixList
          .size(); weightMatrixIdx++) {
        aWeightUpdateMatrixList.set(weightMatrixIdx, prevUpdates[matrixIdx++]);
      }
    }
  }

  /**
   * Add a batch of matrices onto the given destination matrices.
   *
   * @param destMatrices
   * @param sourceMatrices
   */
  static void matricesAdd(FloatMatrix[] destMatrices,
      FloatMatrix[] sourceMatrices) {
    for (int i = 0; i < destMatrices.length; ++i) {
      destMatrices[i] = destMatrices[i].add(sourceMatrices[i]);
    }
  }

  /**
   * Get all the weight matrices, flattened step by step.
   *
   * @return The matrices in form of a matrix array.
   */
  FloatMatrix[] getWeightMatrices() {
    FloatMatrix[] matrices = new FloatMatrix[this.getSizeOfWeightmatrix()];
    int matrixIdx = 0;
    for (List<FloatMatrix> aWeightMatrixList : this.weightMatrixLists) {
      for (FloatMatrix aWeightMatrix : aWeightMatrixList) {
        matrices[matrixIdx++] = aWeightMatrix;
      }
    }
    return matrices;
  }
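  /*
   * Illustrative round trip through the flattened weight array (a sketch;
   * the "updates" array coming from a peer task is an assumption about the
   * surrounding BSP trainer, not something this class provides):
   *
   *   FloatMatrix[] weights = ann.getWeightMatrices(); // step-major order
   *   matricesAdd(weights, updates);                   // merge a batch of updates
   *   ann.setWeightMatrices(weights);
   */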
  /**
   * Set the weight matrices.
   *
   * @param matrices
   */
  public void setWeightMatrices(FloatMatrix[] matrices) {
    int matrixIdx = 0;
    for (List<FloatMatrix> aWeightMatrixList : this.weightMatrixLists) {
      for (int weightMatrixIdx = 0; weightMatrixIdx < aWeightMatrixList.size(); weightMatrixIdx++) {
        aWeightMatrixList.set(weightMatrixIdx, matrices[matrixIdx++]);
      }
    }
  }

  /**
   * Get the previous matrix updates in form of an array.
   *
   * @return The matrices in form of a matrix array.
   */
  public FloatMatrix[] getPrevMatricesUpdates() {
    FloatMatrix[] matrices = new FloatMatrix[this.getSizeOfWeightmatrix()];
    int matrixIdx = 0;
    for (List<FloatMatrix> aWeightMatrixList : this.prevWeightUpdatesLists) {
      for (FloatMatrix aWeightMatrix : aWeightMatrixList) {
        matrices[matrixIdx++] = aWeightMatrix;
      }
    }
    return matrices;
  }

  public void setWeightMatrix(int index, FloatMatrix matrix) {
    Preconditions.checkArgument(
        0 <= index && index < this.weightMatrixList.size(), String.format(
            "index [%d] should be in range [%d, %d).", index, 0,
            this.weightMatrixList.size()));
    this.weightMatrixList.set(index, matrix);
  }

  @SuppressWarnings("unchecked")
  @Override
  public void readFields(DataInput input) throws IOException {
    super.readFields(input);

    this.finalLayerIdx = input.readInt();
    this.dropRate = input.readFloat();

    // read neuron classes
    int neuronClasses = input.readInt();
    this.neuronClassList = Lists.newArrayList();
    for (int i = 0; i < neuronClasses; ++i) {
      try {
        Class<? extends Neuron> clazz = (Class<? extends Neuron>) Class
            .forName(input.readUTF());
        neuronClassList.add(clazz);
      } catch (ClassNotFoundException e) {
        e.printStackTrace();
      }
    }

    // read squashing functions
    int squashingFunctionSize = input.readInt();
    this.squashingFunctionList = Lists.newArrayList();
    for (int i = 0; i < squashingFunctionSize; ++i) {
      this.squashingFunctionList.add(FunctionFactory
          .createFloatFunction(WritableUtils.readString(input)));
    }

    this.recurrentStepSize = input.readInt();
    this.numOutCells = input.readInt();

    int recurrentLayerListSize = input.readInt();
    this.recurrentLayerList = Lists.newArrayList();
    for (int i = 0; i < recurrentLayerListSize; i++) {
      this.recurrentLayerList.add(input.readBoolean());
    }

    // read weights and construct matrices of previous updates; the total
    // count is consumed here, the layout is reconstructed per step below
    int numOfMatrices = input.readInt();
    this.weightMatrixLists = Lists
        .newArrayListWithExpectedSize(this.recurrentStepSize);
    this.prevWeightUpdatesLists = Lists.newArrayList();
    for (int step = 0; step < this.recurrentStepSize; step++) {
      this.weightMatrixList = Lists.newArrayList();
      this.prevWeightUpdatesList = Lists.newArrayList();
      for (int j = 0; j < this.layerSizeList.size() - 2; j++) {
        FloatMatrix matrix = FloatMatrixWritable.read(input);
        this.weightMatrixList.add(matrix);
        this.prevWeightUpdatesList.add(new DenseFloatMatrix(matrix
            .getRowCount(), matrix.getColumnCount()));
      }
      // if the cell has an output layer, read its matrix as well
      if (step >= this.recurrentStepSize - this.numOutCells) {
        FloatMatrix matrix = FloatMatrixWritable.read(input);
        this.weightMatrixList.add(matrix);
        this.prevWeightUpdatesList.add(new DenseFloatMatrix(matrix
            .getRowCount(), matrix.getColumnCount()));
      }
      this.weightMatrixLists.add(this.weightMatrixList);
      this.prevWeightUpdatesLists.add(this.prevWeightUpdatesList);
    }
  }

  protected int getSizeOfWeightmatrix() {
    return this.recurrentStepSize * (this.layerSizeList.size() - 2)
        + this.numOutCells;
  }
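  /*
   * Worked example for getSizeOfWeightmatrix() above (illustrative numbers):
   * with 3 layers (input, hidden, output), recurrentStepSize = 3 and
   * numOutCells = 1, there are 3 * (3 - 2) = 3 input-to-hidden matrices (one
   * per step) plus 1 hidden-to-output matrix, i.e. 4 matrices in total. This
   * is also the number of matrices written by write() below.
   */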
  @Override
  public void write(DataOutput output) throws IOException {
    super.write(output);

    output.writeInt(finalLayerIdx);
    output.writeFloat(dropRate);

    // write neuron classes
    output.writeInt(this.neuronClassList.size());
    for (Class<? extends Neuron> clazz : this.neuronClassList) {
      output.writeUTF(clazz.getName());
    }

    // write squashing functions
    output.writeInt(this.squashingFunctionList.size());
    for (FloatFunction aSquashingFunction : this.squashingFunctionList) {
      WritableUtils.writeString(output, aSquashingFunction.getFunctionName());
    }

    // write recurrent step size
    output.writeInt(this.recurrentStepSize);
    // write number of output cells
    output.writeInt(this.numOutCells);

    // write recurrent layer list
    output.writeInt(this.recurrentLayerList.size());
    for (Boolean isRecurrentLayer : recurrentLayerList) {
      output.writeBoolean(isRecurrentLayer);
    }

    // write weight matrices
    output.writeInt(this.getSizeOfWeightmatrix());
    for (List<FloatMatrix> aWeightMatrixList : this.weightMatrixLists) {
      for (FloatMatrix aWeightMatrix : aWeightMatrixList) {
        FloatMatrixWritable.write(aWeightMatrix, output);
      }
    }
    // DO NOT WRITE WEIGHT UPDATES
  }

  @Override
  public FloatMatrix getWeightsByLayer(int layerIdx) {
    return this.weightMatrixList.get(layerIdx);
  }

  public FloatMatrix getWeightsByLayer(int stepIdx, int layerIdx) {
    return this.weightMatrixLists.get(stepIdx).get(layerIdx);
  }

  /**
   * Get the output of the model according to a given feature instance.
   */
  @Override
  public FloatVector getOutput(FloatVector instance) {
    Preconditions.checkArgument(
        (this.layerSizeList.get(0) - 1) * this.recurrentStepSize == instance
            .getDimension(), String.format(
            "The dimension of the input instance should be %d.",
            (this.layerSizeList.get(0) - 1) * this.recurrentStepSize));

    // transform the features to another space
    FloatVector transformedInstance = this.featureTransformer
        .transform(instance);
    // add the bias feature; set the bias to be a little bit less than 1.0
    FloatVector instanceWithBias = new DenseFloatVector(
        transformedInstance.getDimension() + 1);
    instanceWithBias.set(0, 0.99999f);
    for (int i = 1; i < instanceWithBias.getDimension(); ++i) {
      instanceWithBias.set(i, transformedInstance.get(i - 1));
    }

    // return the output of the last layer
    return getOutputInternal(instanceWithBias);
  }

  public void setDropRateOfInputLayer(float dropRate) {
    this.dropRate = dropRate;
  }
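  /*
   * Input layout expected by getOutputInternal(), with illustrative numbers:
   * an input layer added with 4 features is stored as layer size 5 (4 +
   * bias), so with recurrentStepSize = 3 the vector holds 3 slices of 5
   * elements, and slot stepIdx * 5 + 0 of each slice is the bias unit (see
   * how trainByInstance() assembles inputInstance).
   */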
  /**
   * Calculate the output internally; the intermediate output of each layer
   * will be stored.
   *
   * @param instanceWithBias The instance containing the features.
   * @return The output of the final layer for each output cell.
   */
  public FloatVector getOutputInternal(FloatVector instanceWithBias) {
    // set the output of the input layer for every recurrent step
    Neuron[] inputLayer;
    for (int stepIdx = 0; stepIdx < this.weightMatrixLists.size(); stepIdx++) {
      inputLayer = neuronLists.get(stepIdx).get(0);
      for (int inputNeuronIdx = 0; inputNeuronIdx < this.layerSizeList.get(0); inputNeuronIdx++) {
        // randomly drop input neurons according to the drop rate
        float m2 = MathUtils.getBinomial(1, dropRate);
        if (m2 == 0)
          inputLayer[inputNeuronIdx].setDrop(true);
        else
          inputLayer[inputNeuronIdx].setDrop(false);

        inputLayer[inputNeuronIdx].setOutput(instanceWithBias.get(stepIdx
            * this.layerSizeList.get(0) + inputNeuronIdx) * m2);
      }

      // loop forward through as many layers as this recurrent step has
      this.weightMatrixList = this.weightMatrixLists.get(stepIdx);
      for (int layerIdx = 0; layerIdx < weightMatrixList.size(); ++layerIdx) {
        forward(stepIdx, layerIdx);
      }
    }

    // collect the output of each output cell
    int singleOutputLength = neuronLists.get(this.recurrentStepSize - 1).get(
        this.finalLayerIdx).length;
    FloatVector output = new DenseFloatVector(singleOutputLength
        * this.numOutCells);
    int outputNeuronIdx = 0;
    for (int step = this.recurrentStepSize - this.numOutCells; step < this.recurrentStepSize; step++) {
      neurons = neuronLists.get(step);
      for (int neuronIdx = 0; neuronIdx < singleOutputLength; neuronIdx++) {
        output.set(outputNeuronIdx,
            neurons.get(this.finalLayerIdx)[neuronIdx].getOutput());
        outputNeuronIdx++;
      }
    }
    return output;
  }

  /**
   * @param neuronClass
   * @return a new neuron instance
   */
  @SuppressWarnings({ "rawtypes" })
  public static Neuron newNeuronInstance(Class<? extends Neuron> neuronClass) {
    return (Neuron) ReflectionUtils.newInstance(neuronClass);
  }

  public class InputMessageIterable implements
      Iterable<Synapse<FloatWritable, FloatWritable>> {
    private int currNeuronID;
    private int prevNeuronID;
    private int end;
    private FloatMatrix weightMat;
    private Neuron[] layer;

    public InputMessageIterable(int fromLayer, int row) {
      this.currNeuronID = row;
      this.prevNeuronID = -1;
      this.end = weightMatrixList.get(fromLayer).getColumnCount() - 1;
      this.weightMat = weightMatrixList.get(fromLayer);
      this.layer = neurons.get(fromLayer);
    }

    @Override
    public Iterator<Synapse<FloatWritable, FloatWritable>> iterator() {
      return new MessageIterator();
    }

    private class MessageIterator implements
        Iterator<Synapse<FloatWritable, FloatWritable>> {
      private FloatWritable i = new FloatWritable();
      private FloatWritable w = new FloatWritable();

      @Override
      public boolean hasNext() {
        return prevNeuronID < end;
      }

      @Override
      public Synapse<FloatWritable, FloatWritable> next() {
        prevNeuronID++;
        i.set(layer[prevNeuronID].getOutput());
        w.set(weightMat.get(currNeuronID, prevNeuronID));
        // a fresh Synapse is returned per message
        return new Synapse<FloatWritable, FloatWritable>(prevNeuronID, i, w);
      }

      @Override
      public void remove() {
      }
    }
  }
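  /*
   * Conceptually, forward(stepIdx, fromLayerIdx) computes, for each neuron r
   * of the next layer,
   *
   *   out[r] = f( sum_i W[r][i] * in[i] )
   *
   * where W is the weight matrix between the two layers, in[] holds the
   * cached outputs of the previous layer (plus, for recurrent layers, the
   * previous step's state), and f is the layer's squashing function; the
   * per-neuron work is delegated to Neuron.forward().
   */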
  /**
   * Forward the calculation for one layer.
   *
   * @param stepIdx The index of the current recurrent step.
   * @param fromLayerIdx The index of the previous layer.
   */
  protected void forward(int stepIdx, int fromLayerIdx) {
    neurons = this.neuronLists.get(stepIdx);
    int curLayerIdx = fromLayerIdx + 1;
    // weight matrix for the current layer
    FloatMatrix weightMatrix = this.weightMatrixList.get(fromLayerIdx);
    FloatFunction squashingFunction = getSquashingFunction(fromLayerIdx);
    FloatVector vec = new DenseFloatVector(weightMatrix.getRowCount());

    FloatVector inputVector = new DenseFloatVector(
        neurons.get(fromLayerIdx).length);
    for (int i = 0; i < neurons.get(fromLayerIdx).length; i++) {
      inputVector.set(i, neurons.get(fromLayerIdx)[i].getOutput());
    }

    for (int row = 0; row < weightMatrix.getRowCount(); row++) {
      Neuron n;
      // the final layer has no bias neuron, so its rows are not shifted
      if (curLayerIdx == finalLayerIdx)
        n = neurons.get(curLayerIdx)[row];
      else
        n = neurons.get(curLayerIdx)[row + 1];

      try {
        FloatVector weightVector = weightMatrix.getRowVector(row);
        n.setWeightVector(weightVector);
        ((RecurrentDropoutNeuron) n).setRecurrentDelta(0);
        n.setIterationNumber(iterations);
        n.forward(inputVector);
      } catch (IOException e) {
        e.printStackTrace();
      }
      vec.set(row, n.getOutput());
    }

    if (squashingFunction.getFunctionName().equalsIgnoreCase(
        SoftMax.class.getSimpleName())) {
      IntermediateOutput interlayer = (IntermediateOutput) ReflectionUtils
          .newInstance(SoftMax.SoftMaxOutputComputer.class);
      try {
        vec = interlayer.interlayer(vec);
        for (int i = 0; i < vec.getDimension(); i++) {
          neurons.get(curLayerIdx)[i].setOutput(vec.get(i));
        }
      } catch (IOException e) {
        e.printStackTrace();
      }
    }

    // add bias
    if (curLayerIdx != finalLayerIdx)
      neurons.get(curLayerIdx)[0].setOutput(1);

    // copy the output to the next recurrent step's layer
    if (this.recurrentLayerList.get(curLayerIdx)
        && stepIdx < this.recurrentStepSize - 1) {
      for (int i = 0; i < vec.getDimension(); i++) {
        this.neuronLists.get(stepIdx + 1).get(fromLayerIdx)[this.layerSizeList
            .get(fromLayerIdx) + i].setOutput(vec.get(i));
      }
    }
  }
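  /*
   * Illustrative online-training loop (a sketch; the data source and the
   * single pass over it are assumptions of this example):
   *
   *   ann.setIterationNumber(0);
   *   for (FloatVector instance : trainingInstances) {
   *     ann.trainOnline(instance); // one gradient-descent update per instance
   *   }
   */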
  /**
   * Train the model online.
   *
   * @param trainingInstance
   */
  public void trainOnline(FloatVector trainingInstance) {
    FloatMatrix[] updateMatrices = this.trainByInstance(trainingInstance);
    this.updateWeightMatrices(updateMatrices);
  }

  @Override
  public FloatMatrix[] trainByInstance(FloatVector trainingInstance) {
    int inputDimension = (this.layerSizeList.get(0) - 1)
        * this.recurrentStepSize;
    FloatVector transformedVector = this.featureTransformer
        .transform(trainingInstance.sliceUnsafe(inputDimension));

    int outputDimension;
    FloatVector inputInstance = null;
    FloatVector labels = null;
    if (this.learningStyle == LearningStyle.SUPERVISED) {
      outputDimension = this.layerSizeList.get(this.layerSizeList.size() - 1);
      // validate the training instance
      Preconditions.checkArgument(
          (inputDimension + outputDimension == trainingInstance.getDimension() || inputDimension
              + outputDimension * recurrentStepSize == trainingInstance
              .getDimension()), String.format(
              "The dimension of the training instance is %d, but %d is required.",
              trainingInstance.getDimension(), inputDimension
                  + outputDimension));

      inputInstance = new DenseFloatVector(this.layerSizeList.get(0)
          * this.recurrentStepSize);
      // get the features from the transformed vector
      int vecIdx = 0;
      for (int i = 0; i < inputInstance.getLength(); ++i) {
        if (i % this.layerSizeList.get(0) == 0) {
          inputInstance.set(i, 1); // add the bias
        } else {
          inputInstance.set(i, transformedVector.get(vecIdx));
          vecIdx++;
        }
      }
      // get the labels from the original training instance
      labels = trainingInstance.sliceUnsafe(transformedVector.getDimension(),
          trainingInstance.getDimension() - 1);
    } else if (this.learningStyle == LearningStyle.UNSUPERVISED) {
      // labels are identical to the input features
      outputDimension = inputDimension;
      // validate the training instance
      Preconditions.checkArgument(
          inputDimension == trainingInstance.getDimension(), String.format(
              "The dimension of the training instance is %d, but %d is required.",
              trainingInstance.getDimension(), inputDimension));

      inputInstance = new DenseFloatVector(this.layerSizeList.get(0)
          * this.recurrentStepSize);
      // get the features from the transformed vector
      int vecIdx = 0;
      for (int i = 0; i < inputInstance.getLength(); ++i) {
        if (i % this.layerSizeList.get(0) == 0) {
          inputInstance.set(i, 1); // add the bias
        } else {
          inputInstance.set(i, transformedVector.get(vecIdx));
          vecIdx++;
        }
      }
      // get the labels by copying the transformed vector
      labels = transformedVector.deepCopy();
    }

    FloatVector output = this.getOutputInternal(inputInstance);
    // get the training error
    calculateTrainingError(labels, output);

    if (this.trainingMethod.equals(TrainingMethod.GRADIENT_DESCENT)) {
      return this.trainByInstanceGradientDescent(labels);
    } else {
      throw new IllegalArgumentException("Training method is not supported.");
    }
  }
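  /*
   * Instance layout for trainByInstance() above, with illustrative numbers:
   * with a 4-feature input layer (stored size 5) and recurrentStepSize = 3,
   * inputDimension is (5 - 1) * 3 = 12; a supervised instance then carries
   * 12 features followed by either a single label block or one label block
   * per recurrent step, as checked by the precondition.
   */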
  /**
   * Train by gradient descent. Get the updated weights using one training
   * instance.
   *
   * @param labels The labels of the current training instance.
   * @return The weight update matrices.
   */
  private FloatMatrix[] trainByInstanceGradientDescent(FloatVector labels) {
    // initialize the weight update matrices
    DenseFloatMatrix[] weightUpdateMatrices = new DenseFloatMatrix[this
        .getSizeOfWeightmatrix()];
    int matrixIdx = 0;
    for (List<FloatMatrix> aWeightMatrixList : this.weightMatrixLists) {
      for (FloatMatrix aWeightMatrix : aWeightMatrixList) {
        weightUpdateMatrices[matrixIdx++] = new DenseFloatMatrix(
            aWeightMatrix.getRowCount(), aWeightMatrix.getColumnCount());
      }
    }

    FloatVector deltaVec = new DenseFloatVector(
        this.layerSizeList.get(layerSizeList.size() - 1) * this.numOutCells);
    FloatFunction squashingFunction = this.squashingFunctionList
        .get(this.squashingFunctionList.size() - 1);

    int labelIdx = 0;
    // compute the output deltas for every output cell
    for (int step = this.recurrentStepSize - this.numOutCells; step < this.recurrentStepSize; step++) {
      FloatMatrix lastWeightMatrix = this.weightMatrixLists.get(step).get(
          this.weightMatrixLists.get(step).size() - 1);

      int neuronIdx = 0;
      for (Neuron aNeuron : this.neuronLists.get(step).get(this.finalLayerIdx)) {
        float finalOut = aNeuron.getOutput();
        float costFuncDerivative = this.costFunction.applyDerivative(
            labels.get(labelIdx), finalOut);
        // add regularization
        costFuncDerivative += this.regularizationWeight
            * lastWeightMatrix.getRowVector(neuronIdx).sum();

        if (!squashingFunction.getFunctionName().equalsIgnoreCase(
            SoftMax.class.getSimpleName())) {
          costFuncDerivative *= squashingFunction.applyDerivative(finalOut);
        }

        aNeuron.backpropagate(costFuncDerivative);
        deltaVec.set(labelIdx, costFuncDerivative);
        neuronIdx++;
        labelIdx++;
      }
    }

    // back-propagate from the last recurrent step to the first
    boolean skipLastLayer = false;
    int weightMatrixIdx = weightUpdateMatrices.length - 1;
    for (int step = this.recurrentStepSize - 1; step >= 0; --step) {
      this.weightMatrixList = this.weightMatrixLists.get(step);
      this.prevWeightUpdatesList = this.prevWeightUpdatesLists.get(step);
      this.neurons = this.neuronLists.get(step);

      // cells without an output layer skip the top-most weight matrix
      if (step < this.recurrentStepSize - this.numOutCells)
        skipLastLayer = true;

      // start from the layer below the output layer
      for (int layer = this.layerSizeList.size() - 2; layer >= 0; --layer) {
        if (skipLastLayer) {
          skipLastLayer = false;
          continue;
        }
        backpropagate(step, layer, weightUpdateMatrices[weightMatrixIdx--]);
      }
    }

    // TODO eliminate non-output cells from weightUpdateLists
    this.setPrevWeightMatrices(weightUpdateMatrices);
    return weightUpdateMatrices;
  }

  public class ErrorMessageIterable implements
      Iterable<Synapse<FloatWritable, FloatWritable>> {
    private int row;
    private int neuronID;
    private int end;
    private FloatMatrix weightMat;
    private FloatMatrix prevWeightMat;
    private float[] nextLayerDelta;

    public ErrorMessageIterable(int recurrentStepIdx, int curLayerIdx, int row) {
      this.row = row;
      this.neuronID = -1;
      this.weightMat = weightMatrixLists.get(recurrentStepIdx).get(curLayerIdx);
      this.end = weightMat.getRowCount() - 1;
      this.prevWeightMat = prevWeightUpdatesLists.get(recurrentStepIdx).get(
          curLayerIdx);

      Neuron[] nextLayer = neuronLists.get(recurrentStepIdx).get(
          curLayerIdx + 1);
      nextLayerDelta = new float[weightMat.getRowCount()];
      for (int i = 0; i <= end; ++i) {
        // the final layer has no bias neuron, so its deltas are not shifted
        if (curLayerIdx + 1 == finalLayerIdx) {
          nextLayerDelta[i] = nextLayer[i].getDelta();
        } else {
          nextLayerDelta[i] = nextLayer[i + 1].getDelta();
        }
      }
    }

    @Override
    public Iterator<Synapse<FloatWritable, FloatWritable>> iterator() {
      return new MessageIterator();
    }
    private class MessageIterator implements
        Iterator<Synapse<FloatWritable, FloatWritable>> {
      private FloatWritable d = new FloatWritable();
      private FloatWritable w = new FloatWritable();
      private FloatWritable p = new FloatWritable();
      private Synapse<FloatWritable, FloatWritable> msg = new Synapse<FloatWritable, FloatWritable>();

      @Override
      public boolean hasNext() {
        return neuronID < end;
      }

      @Override
      public Synapse<FloatWritable, FloatWritable> next() {
        neuronID++;
        d.set(nextLayerDelta[neuronID]);
        w.set(weightMat.get(neuronID, row));
        p.set(prevWeightMat.get(neuronID, row));
        msg.set(neuronID, d, w, p);
        return msg;
      }

      @Override
      public void remove() {
      }
    }
  }

  /**
   * Back-propagate the errors from the next layer to the current layer. The
   * weight update information is stored in weightUpdateMatrix, and the deltas
   * of the current layer are set on its neurons.
   *
   * @param recurrentStepIdx Index of the current recurrent step.
   * @param curLayerIdx Index of the current layer.
   * @param weightUpdateMatrix The matrix to receive the weight updates.
   */
  private void backpropagate(int recurrentStepIdx, int curLayerIdx,
      DenseFloatMatrix weightUpdateMatrix) {
    FloatMatrix weightMat = weightMatrixLists.get(recurrentStepIdx).get(
        curLayerIdx);
    FloatMatrix prevWeightMat = prevWeightUpdatesLists.get(recurrentStepIdx)
        .get(curLayerIdx);

    // get layer-related information
    int x = this.weightMatrixList.get(curLayerIdx).getColumnCount();
    int y = this.weightMatrixList.get(curLayerIdx).getRowCount();
    Neuron[] ns = this.neuronLists.get(recurrentStepIdx).get(curLayerIdx);

    for (int row = 0; row < x; ++row) {
      Neuron n = ns[row];
      n.setWeightVector(y);

      try {
        FloatVector weightVector = weightMat.getColumnVector(row);
        n.setWeightVector(weightVector);

        Neuron[] nextLayer = neuronLists.get(recurrentStepIdx).get(
            curLayerIdx + 1);
        FloatVector deltaVector = new DenseFloatVector(
            weightVector.getDimension());
        for (int i = 0; i < weightVector.getDimension(); ++i) {
          if (curLayerIdx + 1 == finalLayerIdx) {
            deltaVector.set(i, nextLayer[i].getDelta());
          } else {
            deltaVector.set(i, nextLayer[i + 1].getDelta());
          }
        }
        n.setDeltaVector(deltaVector);
        n.setPrevWeightVector(prevWeightMat.getColumnVector(row));
        n.backward(deltaVector);

        // propagate the recurrent delta back to the previous step
        if (row >= layerSizeList.get(curLayerIdx) && recurrentStepIdx > 0
            && recurrentLayerList.get(curLayerIdx + 1)) {
          Neuron recurrentNeuron = neuronLists.get(recurrentStepIdx - 1).get(
              curLayerIdx + 1)[row - layerSizeList.get(curLayerIdx) + 1];
          recurrentNeuron.backpropagate(n.getDelta());
        }
      } catch (IOException e) {
        e.printStackTrace();
      }

      // update weights
      weightUpdateMatrix.setColumn(row, n.getUpdates());
    }
  }
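  /*
   * Illustrative driver for trainInternal() below (a sketch; the paths and
   * the step size value are assumptions of this example, while the property
   * names are the ones trainInternal() actually reads; training is normally
   * launched through the superclass's training entry point):
   *
   *   HamaConfiguration conf = new HamaConfiguration();
   *   conf.set("model.path", "/tmp/rnn.model");
   *   conf.set("training.input.path", "/tmp/train.seq");
   *   conf.setInt("training.recurrent.step.size", 3);
   */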
  @Override
  protected BSPJob trainInternal(HamaConfiguration conf) throws IOException,
      InterruptedException, ClassNotFoundException {
    this.conf = conf;
    this.fs = FileSystem.get(conf);

    String modelPath = conf.get("model.path");
    if (modelPath != null) {
      this.modelPath = modelPath;
    }
    // modelPath must be set before training
    if (this.modelPath == null) {
      throw new IllegalArgumentException(
          "Please specify the modelPath for the model, either through "
              + "setModelPath() or by adding 'modelPath' to the training parameters.");
    }

    this.setRecurrentStepSize(conf.getInt("training.recurrent.step.size", 1));
    this.initializeWeightMatrixLists();
    this.writeModelToFile();

    // create the job
    BSPJob job = new BSPJob(conf, RecurrentLayeredNeuralNetworkTrainer.class);
    job.setJobName("Neural Network training");
    job.setJarByClass(RecurrentLayeredNeuralNetworkTrainer.class);
    job.setBspClass(RecurrentLayeredNeuralNetworkTrainer.class);
    job.getConfiguration().setInt(Constants.ADDITIONAL_BSP_TASKS, 1);
    job.setBoolean("training.mode", true);

    job.setInputPath(new Path(conf.get("training.input.path")));
    job.setInputFormat(org.apache.hama.bsp.SequenceFileInputFormat.class);
    job.setInputKeyClass(LongWritable.class);
    job.setInputValueClass(VectorWritable.class);
    job.setOutputKeyClass(NullWritable.class);
    job.setOutputValueClass(NullWritable.class);
    job.setOutputFormat(org.apache.hama.bsp.NullOutputFormat.class);
    return job;
  }

  @Override
  protected void calculateTrainingError(FloatVector labels, FloatVector output) {
    FloatVector errors = labels.deepCopy().applyToElements(output,
        this.costFunction);
    this.trainingError = errors.sum();
  }

  /**
   * Get the squashing function of a specified layer.
   *
   * @param idx The index of the layer.
   * @return The squashing function of the specified layer.
   */
  public FloatFunction getSquashingFunction(int idx) {
    return this.squashingFunctionList.get(idx);
  }

  public void setIterationNumber(long iterations) {
    this.iterations = iterations;
  }

  public void setRecurrentStepSize(int recurrentStepSize) {
    this.recurrentStepSize = recurrentStepSize;
  }

}