Java tutorial
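The listing below is the complete source of com.anhth12.nn.multilayer.MultiLayerNetwork, a multi-layer neural network that supports layer-wise pretraining, feed-forward activation, and several flavors of backpropagation: plain gradients, a preconditioned variant, and an R-operator (Gauss-Newton vector product) variant intended for Hessian-free style optimization.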
package com.anhth12.nn.multilayer;

import com.anhth12.dataset.iterator.DataSetIterator;
import com.anhth12.downpourSGD.util.Pair;
import com.anhth12.eval.Evaluation;
import com.anhth12.nn.api.Classifier;
import com.anhth12.nn.api.Layer;
import com.anhth12.nn.api.OptimizationAlgorithm;
import com.anhth12.nn.conf.MultiLayerConfiguration;
import com.anhth12.nn.conf.NeuralNetworkConfiguration;
import com.anhth12.nn.conf.OutputPreprocessor;
import com.anhth12.nn.gradient.DefaultGradient;
import com.anhth12.nn.gradient.Gradient;
import com.anhth12.nn.layers.OutputLayer;
import com.anhth12.nn.params.DefaultParamInitializer;
import com.anhth12.util.MultiLayerUtil;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.sampling.Sampling;
import org.nd4j.linalg.transformation.MatrixTransform;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.LinAlgExceptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 *
 * @author anhth12
 */
public class MultiLayerNetwork implements Serializable, Classifier {

    private static Logger log = LoggerFactory.getLogger(MultiLayerNetwork.class);

    private Layer[] layers;
    private INDArray input, labels;

    //layer - weight transform
    protected Map<Integer, MatrixTransform> weightTransforms = new HashMap<>();
    //layer - hbias transform
    protected Map<Integer, MatrixTransform> hiddenBiasTransforms = new HashMap<>();
    //layer - vbias transforms
    protected Map<Integer, MatrixTransform> visibleBiasTransform = new HashMap<>();

    protected NeuralNetworkConfiguration defaultConfiguration;
    protected MultiLayerConfiguration layerWiseConfiguration;
    protected boolean initCalled = false;
    //binary drop connect mask
    protected INDArray mask;

    public MultiLayerNetwork(MultiLayerConfiguration conf) {
        this.layerWiseConfiguration = conf;
        defaultConfiguration = conf.getConf(0);
    }

    protected void initializeConfigurations() {
        if (layerWiseConfiguration == null) {
            layerWiseConfiguration = new MultiLayerConfiguration.Builder().build();
        }
        if (layers == null) {
            layers = new Layer[layerWiseConfiguration.getHiddenLayerSizes().length + 1];
        }
        if (defaultConfiguration == null) {
            defaultConfiguration = new NeuralNetworkConfiguration();
        }
        if (layerWiseConfiguration == null || layerWiseConfiguration.getConfs().isEmpty()) {
            for (int i = 0; i < layerWiseConfiguration.getHiddenLayerSizes().length + 1; i++) {
                layerWiseConfiguration.getConfs().add(defaultConfiguration.clone());
            }
        }
    }

    public void pretrain(DataSetIterator iter) {
        if (!layerWiseConfiguration.isPretrain()) {
            return;
        }
        INDArray layerInput;
        for (int i = 0; i < getnLayers(); i++) {
            if (i == 0) {
                while (iter.hasNext()) {
                    DataSet next = iter.next();
                    this.input = next.getFeatureMatrix();
                    if (this.getInput() == null || this.getLayers() == null) {
                        setInput(input);
                        initializeLayer(input);
                    } else {
                        setInput(input);
                    }
                    log.info("Training on layer " + (i + 1));
                    getLayers()[i].fit(next.getFeatureMatrix());
                }
                iter.reset();
            } else {
                while (iter.hasNext()) {
                    DataSet next = iter.next();
                    layerInput = next.getFeatureMatrix();
                    //feed the raw features through the already trained layers below layer i
                    for (int j = 1; j <= i; j++) {
                        layerInput = activationFromPrevLayer(j - 1, layerInput);
                    }
                    log.info("Training on layer " + (i + 1));
                    getLayers()[i].fit(layerInput);
                }
                iter.reset();
            }
        }
    }
    public void pretrain(INDArray input) {
        if (!layerWiseConfiguration.isPretrain()) {
            return;
        }
        /* During pretrain, feed forward expected activations of the network;
           use activation co-occurrences during pretrain */
        if (this.getInput() == null || this.getLayers() == null) {
            setInput(input);
            initializeLayer(input);
        } else {
            setInput(input);
        }
        INDArray layerInput = null;
        for (int i = 0; i < getnLayers() - 1; i++) {
            if (i == 0) {
                layerInput = getInput();
            } else {
                layerInput = activationFromPrevLayer(i - 1, layerInput);
            }
            log.info("Training on layer " + (i + 1));
            getLayers()[i].fit(layerInput);
        }
    }

    @Override
    public int batchSize() {
        return input.rows();
    }

    @Override
    public INDArray input() {
        return input;
    }

    @Override
    public void validateInput() {
    }

    @Override
    public void setConf(NeuralNetworkConfiguration conf) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    @Override
    public NeuralNetworkConfiguration conf() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    @Override
    public INDArray transform(INDArray data) {
        return output(data);
    }

    public NeuralNetworkConfiguration getDefaultConfiguration() {
        return defaultConfiguration;
    }

    public void setDefaultConfiguration(NeuralNetworkConfiguration defaultConfiguration) {
        this.defaultConfiguration = defaultConfiguration;
    }

    public MultiLayerConfiguration getLayerWiseConfiguration() {
        return layerWiseConfiguration;
    }

    public void setLayerWiseConfiguration(MultiLayerConfiguration layerWiseConfiguration) {
        this.layerWiseConfiguration = layerWiseConfiguration;
    }

    public void initializeLayer(INDArray input) {
        if (input == null) {
            throw new IllegalArgumentException("Unable to initialize neural network with empty input");
        }
        int[] hiddenLayerSizes = layerWiseConfiguration.getHiddenLayerSizes();
        if (input.shape().length == 2) {
            for (int i = 0; i < hiddenLayerSizes.length; i++) {
                if (hiddenLayerSizes[i] < 1) {
                    throw new IllegalStateException("All hidden layer sizes must be >= 1");
                }
            }
        }
        this.input = input.dup();
        if (!initCalled) {
            init();
        }
    }

    public void init() {
        if (layerWiseConfiguration == null || layers == null) {
            initializeConfigurations();
        }
        //unnecessary
        INDArray layerInput = input;
        int inputSize;
        if (getnLayers() < 1) {
            throw new IllegalStateException("Number of layers must be >= 1");
        }
        int[] hiddenLayerSizes = layerWiseConfiguration.getHiddenLayerSizes();
        if (this.layers == null || this.layers[0] == null) {
            this.layers = new Layer[hiddenLayerSizes.length + 1];
            for (int i = 0; i < this.layers.length; i++) {
                if (i == 0) {
                    inputSize = layerWiseConfiguration.getConf(0).getnIn();
                } else {
                    inputSize = hiddenLayerSizes[i - 1];
                }
                if (i == 0) {
                    layerWiseConfiguration.getConf(i).setnIn(inputSize);
                    layerWiseConfiguration.getConf(i).setnOut(hiddenLayerSizes[i]);
                    layers[i] = layerWiseConfiguration.getConf(i).getLayerFactory()
                            .create(layerWiseConfiguration.getConf(i));
                } else if (i < getLayers().length - 1) {
                    if (input != null) {
                        //unnecessary
                        layerInput = activationFromPrevLayer(i - 1, layerInput);
                    }
                    layerWiseConfiguration.getConf(i).setnIn(inputSize);
                    layerWiseConfiguration.getConf(i).setnOut(hiddenLayerSizes[i]);
                    layers[i] = layerWiseConfiguration.getConf(i).getLayerFactory()
                            .create(layerWiseConfiguration.getConf(i));
                }
            }
            NeuralNetworkConfiguration last = layerWiseConfiguration
                    .getConf(layerWiseConfiguration.getConfs().size() - 1);
            NeuralNetworkConfiguration secondToLast = layerWiseConfiguration
                    .getConf(layerWiseConfiguration.getConfs().size() - 2);
            last.setnIn(secondToLast.getnOut());
            this.layers[layers.length - 1] = last.getLayerFactory().create(last);
            initCalled = true;
            initMask();
        }
    }
    public INDArray activate() {
        return getLayers()[getLayers().length - 1].activate();
    }

    public INDArray activate(int layer) {
        return getLayers()[layer].activate();
    }

    public INDArray activate(int layer, INDArray input) {
        return getLayers()[layer].activate(input);
    }

    public void initialize(DataSet dataSet) {
        setInput(dataSet.getFeatureMatrix());
        //feedforward
        feedForward(dataSet.getFeatureMatrix());
        this.labels = dataSet.getLabels();
        if (getOutputLayer() instanceof OutputLayer) {
            OutputLayer o = (OutputLayer) getOutputLayer();
            o.setLabels(labels);
        }
    }

    public synchronized INDArray activationFromPrevLayer(int curr, INDArray input) {
        return layers[curr].activate(input);
    }

    public List<INDArray> feedForward() {
        INDArray currInput = this.input;
        if (this.input.isMatrix() && this.input.columns() != defaultConfiguration.getnIn()) {
            throw new IllegalStateException("Illegal input length");
        }
        List<INDArray> activations = new ArrayList<>();
        activations.add(currInput);
        for (int i = 0; i < layers.length; i++) {
            currInput = activationFromPrevLayer(i, currInput);
            //pre process the activation before passing to the next layer
            OutputPreprocessor preProcessor = getLayerWiseConfiguration().getPreProcessor(i);
            if (preProcessor != null) {
                currInput = preProcessor.preProcess(currInput);
            }
            //applies drop connect to the activation
            applyDropConnectIfNecessary(currInput);
            activations.add(currInput);
        }
        return activations;
    }

    public List<INDArray> feedForward(INDArray input) {
        if (input == null) {
            throw new IllegalStateException("Unable to perform feed forward; no input found");
        } else {
            this.input = input;
        }
        return feedForward();
    }

    @Override
    public Gradient getGradient() {
        Gradient ret = new DefaultGradient();
        for (int i = 0; i < layers.length; i += 2) {
            ret.gradientLookupTable().put(String.valueOf(i), layers[i].getGradient().gradient());
        }
        return ret;
    }

    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<>(getGradient(), getOutputLayer().score());
    }

    protected void applyDropConnectIfNecessary(INDArray input) {
        if (layerWiseConfiguration.isUseDropConnect()) {
            INDArray mask = Sampling.binomial(Nd4j.valueArrayOf(input.rows(), input.columns(), 0.5),
                    1, defaultConfiguration.getRng());
            input.muli(mask);
            //apply l2 for drop connect
            if (defaultConfiguration.getL2() > 0) {
                input.muli(defaultConfiguration.getL2());
            }
        }
    }

    /* delta computation for back prop with the R operator */
    protected List<INDArray> computeDeltasR(INDArray v) {
        List<INDArray> deltaRet = new ArrayList<>();
        INDArray[] deltas = new INDArray[getnLayers() + 1];
        List<INDArray> activations = feedForward();
        List<INDArray> rActivations = feedForwardR(activations, v);
        /*
         * Precompute activations and z's (pre activation network outputs)
         */
        List<INDArray> weights = new ArrayList<>();
        List<INDArray> biases = new ArrayList<>();
        List<ActivationFunction> activationFunctions = new ArrayList<>();
        for (int j = 0; j < getLayers().length; j++) {
            weights.add(getLayers()[j].getParam(DefaultParamInitializer.WEIGHT_KEY));
            biases.add(getLayers()[j].getParam(DefaultParamInitializer.BIAS_KEY));
            activationFunctions.add(getLayers()[j].conf().getActivationFunction());
        }
        INDArray rix = rActivations.get(rActivations.size() - 1).divi((double) input.rows());
        LinAlgExceptions.assertValidNum(rix);
        //errors
        for (int i = getnLayers() - 1; i >= 0; i--) {
            //W^t * error^l + 1
            deltas[i] = activations.get(i).transpose().mmul(rix);
            applyDropConnectIfNecessary(deltas[i]);
            if (i > 0) {
                rix = rix.mmul(weights.get(i).addRowVector(biases.get(i)).transpose())
                        .muli(activationFunctions.get(i - 1).applyDerivative(activations.get(i)));
            }
        }
        for (int i = 0; i < deltas.length - 1; i++) {
            if (defaultConfiguration.isConstrainGradientToUnitNorm()) {
                double sum = deltas[i].sum(Integer.MAX_VALUE).getDouble(0);
                if (sum > 0) {
                    deltaRet.add(deltas[i].div(deltas[i].norm2(Integer.MAX_VALUE)));
                } else {
                    deltaRet.add(deltas[i]);
                }
            } else {
                deltaRet.add(deltas[i]);
            }
            LinAlgExceptions.assertValidNum(deltaRet.get(i));
        }
        return deltaRet;
    }
    public void dampingUpdate(double rho, double boost, double decrease) {
        if (rho < 0.25 || Double.isNaN(rho)) {
            layerWiseConfiguration.setDampingFactor(getLayerWiseConfiguration().getDampingFactor() * boost);
        } else if (rho > 0.75) {
            layerWiseConfiguration.setDampingFactor(getLayerWiseConfiguration().getDampingFactor() * decrease);
        }
    }

    public double reductionRation(INDArray p, double currScore, double score, INDArray gradient) {
        double currentDamp = layerWiseConfiguration.getDampingFactor();
        layerWiseConfiguration.setDampingFactor(0);
        INDArray denom = getBackPropRGradient(p);
        denom.muli(0.5).muli(p.mul(denom)).sum(0);
        denom.subi(gradient.mul(p).sum(0));
        double rho = (currScore - score) / (double) denom.getScalar(0).element();
        layerWiseConfiguration.setDampingFactor(currentDamp);
        if (score - currScore > 0) {
            return Float.NEGATIVE_INFINITY;
        }
        return rho;
    }

    /* delta computation for back prop with precon for SFH */
    protected List<Pair<INDArray, INDArray>> computeDeltas2() {
        List<Pair<INDArray, INDArray>> deltaRet = new ArrayList<>();
        List<INDArray> activations = feedForward();
        INDArray[] deltas = new INDArray[activations.size() - 1];
        INDArray[] preCons = new INDArray[activations.size() - 1];
        //- y - h
        INDArray ix = activations.get(activations.size() - 1).sub(labels).div(labels.rows());
        /*
         * Precompute activations and z's (pre activation network outputs)
         */
        List<INDArray> weights = new ArrayList<>();
        List<INDArray> biases = new ArrayList<>();
        List<ActivationFunction> activationFunctions = new ArrayList<>();
        for (int j = 0; j < getLayers().length; j++) {
            weights.add(getLayers()[j].getParam(DefaultParamInitializer.WEIGHT_KEY));
            biases.add(getLayers()[j].getParam(DefaultParamInitializer.BIAS_KEY));
            activationFunctions.add(getLayers()[j].conf().getActivationFunction());
        }
        //errors
        for (int i = weights.size() - 1; i >= 0; i--) {
            deltas[i] = activations.get(i).transpose().mmul(ix);
            preCons[i] = Transforms.pow(activations.get(i).transpose(), 2)
                    .mmul(Transforms.pow(ix, 2)).muli(labels.rows());
            applyDropConnectIfNecessary(deltas[i]);
            if (i > 0) {
                //W[i] + b[i] * f'(z[i - 1])
                ix = ix.mmul(weights.get(i).transpose())
                        .muli(activationFunctions.get(i - 1).applyDerivative(activations.get(i)));
            }
        }
        for (int i = 0; i < deltas.length; i++) {
            if (defaultConfiguration.isConstrainGradientToUnitNorm()) {
                deltaRet.add(new Pair<>(deltas[i].divi(deltas[i].norm2(Integer.MAX_VALUE)), preCons[i]));
            } else {
                deltaRet.add(new Pair<>(deltas[i], preCons[i]));
            }
        }
        return deltaRet;
    }
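    /*
     * For reference: the delta recursion implemented by the computeDeltas* methods follows the
     * usual backpropagation pattern. An output-layer error term is formed from the difference
     * between the final activations and the labels, it is propagated backwards by multiplying
     * with each layer's transposed weights and the derivative of the previous activation, and
     * the weight gradient for a layer is the previous activation (transposed) times the
     * propagated error. The variants in this class differ only in how the initial error term is
     * formed and in the extra preconditioner / R-operator terms they carry along.
     */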
    /* delta computation for back prop */
    protected List<INDArray> computeDeltas() {
        List<INDArray> deltaRet = new ArrayList<>();
        INDArray[] deltas = new INDArray[getnLayers() + 2];
        List<INDArray> activations = feedForward();
        //- y - h
        INDArray ix = labels.sub(activations.get(activations.size() - 1))
                .subi(getOutputLayer().conf().getActivationFunction()
                        .applyDerivative(activations.get(activations.size() - 1)));
        /*
         * Precompute activations and z's (pre activation network outputs)
         */
        List<INDArray> weights = new ArrayList<>();
        List<INDArray> biases = new ArrayList<>();
        List<ActivationFunction> activationFunctions = new ArrayList<>();
        for (int j = 0; j < getLayers().length; j++) {
            weights.add(getLayers()[j].getParam(DefaultParamInitializer.WEIGHT_KEY));
            biases.add(getLayers()[j].getParam(DefaultParamInitializer.BIAS_KEY));
            activationFunctions.add(getLayers()[j].conf().getActivationFunction());
        }
        weights.add(getOutputLayer().getParam(DefaultParamInitializer.WEIGHT_KEY));
        biases.add(getOutputLayer().getParam(DefaultParamInitializer.BIAS_KEY));
        activationFunctions.add(getOutputLayer().conf().getActivationFunction());
        //errors
        for (int i = getnLayers() + 1; i >= 0; i--) {
            //output layer
            if (i >= getnLayers() + 1) {
                //-(y - h) .* f'(z^l) where l is the output layer
                deltas[i] = ix;
            } else {
                INDArray delta = activations.get(i).transpose().mmul(ix);
                deltas[i] = delta;
                applyDropConnectIfNecessary(deltas[i]);
                INDArray weightsPlusBias = weights.get(i).transpose();
                INDArray activation = activations.get(i);
                if (i > 0) {
                    ix = ix.mmul(weightsPlusBias)
                            .muli(activationFunctions.get(i - 1).applyDerivative(activation));
                }
            }
        }
        for (int i = 0; i < deltas.length; i++) {
            if (defaultConfiguration.isConstrainGradientToUnitNorm()) {
                deltaRet.add(deltas[i].divi(deltas[i].norm2(Integer.MAX_VALUE)));
            } else {
                deltaRet.add(deltas[i]);
            }
        }
        return deltaRet;
    }

    public void backPropStep() {
        List<Pair<INDArray, INDArray>> deltas = backPropGradient();
        for (int i = 0; i < layers.length; i++) {
            layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY).addi(deltas.get(i).getFirst());
            layers[i].getParam(DefaultParamInitializer.BIAS_KEY).addi(deltas.get(i).getSecond());
        }
    }

    public INDArray getBackPropRGradient(INDArray v) {
        return pack(backPropGradientR(v));
    }

    /**
     * Gets the back prop gradient with the r operator (gauss vector) and the
     * associated precon matrix. This is also called computeGV.
     *
     * @return the back prop with r gradient
     */
    public Pair<INDArray, INDArray> getBackPropGradient2() {
        List<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>> deltas = backPropGradient2();
        List<Pair<INDArray, INDArray>> deltaNormal = new ArrayList<>();
        List<Pair<INDArray, INDArray>> deltasPreCon = new ArrayList<>();
        for (int i = 0; i < deltas.size(); i++) {
            deltaNormal.add(deltas.get(i).getFirst());
            deltasPreCon.add(deltas.get(i).getSecond());
        }
        return new Pair<>(pack(deltaNormal), pack(deltasPreCon));
    }

    @Override
    public MultiLayerNetwork clone() {
        MultiLayerNetwork ret;
        try {
            //there is no no-arg constructor, so build the copy from the existing configuration
            ret = new MultiLayerNetwork(layerWiseConfiguration);
            ret.update(this);
        } catch (Exception e) {
            throw new IllegalStateException("Unable to clone network", e);
        }
        return ret;
    }

    @Override
    public INDArray params() {
        List<INDArray> params = new ArrayList<>();
        for (int i = 0; i < getnLayers(); i++) {
            params.add(layers[i].params());
        }
        return Nd4j.toFlattened(params);
    }

    @Override
    public void setParams(INDArray param) {
        setParameters(param);
    }

    public INDArray pack() {
        return params();
    }

    public INDArray pack(List<Pair<INDArray, INDArray>> layers) {
        List<INDArray> list = new ArrayList<>();
        for (int i = 0; i < layers.size(); i++) {
            list.add(layers.get(i).getFirst());
            list.add(layers.get(i).getSecond());
        }
        return Nd4j.toFlattened(list);
    }
    public List<Pair<INDArray, INDArray>> backPropGradient() {
        //feedforward to compute activations
        //initial error
        //precompute deltas
        List<INDArray> deltas = computeDeltas();
        List<Pair<INDArray, INDArray>> vWvB = new ArrayList<>();
        for (int i = 0; i < layers.length; i++) {
            vWvB.add(new Pair<>(layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY),
                    layers[i].getParam(DefaultParamInitializer.BIAS_KEY)));
        }
        List<Pair<INDArray, INDArray>> list = new ArrayList<>();
        for (int l = 0; l < getnLayers() + 1; l++) {
            INDArray gradientChange = deltas.get(l);
            if (gradientChange.length() != getLayers()[l].getParam(DefaultParamInitializer.WEIGHT_KEY).length()) {
                throw new IllegalStateException("Gradient change not equal to weight change");
            }
            //update hidden bias
            INDArray deltaColumnSums = deltas.get(l).isVector() ? deltas.get(l) : deltas.get(l).mean(0);
            list.add(new Pair<>(gradientChange, deltaColumnSums));
        }
        if (mask == null) {
            initMask();
        }
        return list;
    }

    /**
     * Unpacks a flat parameter vector into layer-wise (weight, bias) pairs.
     *
     * @param param the param vector
     * @return a segmented list of the param vector
     */
    public List<Pair<INDArray, INDArray>> unPack(INDArray param) {
        //more sanity checks!
        if (param.rows() != 1) {
            param = param.reshape(1, param.length());
        }
        List<Pair<INDArray, INDArray>> ret = new ArrayList<>();
        int curr = 0;
        for (int i = 0; i < layers.length; i++) {
            int layerLength = layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY).length()
                    + layers[i].getParam(DefaultParamInitializer.BIAS_KEY).length();
            INDArray subMatrix = param.get(NDArrayIndex.interval(curr, curr + layerLength));
            INDArray weightPortion = subMatrix.get(NDArrayIndex.interval(0,
                    layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY).length()));
            int beginHBias = layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY).length();
            int endHbias = subMatrix.length();
            INDArray hBiasPortion = subMatrix.get(NDArrayIndex.interval(beginHBias, endHbias));
            int layerLengthSum = weightPortion.length() + hBiasPortion.length();
            if (layerLengthSum != layerLength) {
                if (hBiasPortion.length() != layers[i].getParam(DefaultParamInitializer.BIAS_KEY).length()) {
                    throw new IllegalStateException("Hidden bias on layer " + i + " was off");
                }
                if (weightPortion.length() != layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY).length()) {
                    throw new IllegalStateException("Weight portion on layer " + i + " was off");
                }
            }
            ret.add(new Pair<>(
                    weightPortion.reshape(layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY).rows(),
                            layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY).columns()),
                    hBiasPortion.reshape(layers[i].getParam(DefaultParamInitializer.BIAS_KEY).rows(),
                            layers[i].getParam(DefaultParamInitializer.BIAS_KEY).columns())));
            curr += layerLength;
        }
        return ret;
    }
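    /*
     * Illustration of the flat parameter layout used by pack(List) and unPack() above (the shapes
     * are assumed for this example only): for a layer whose weight matrix is 3x4 and whose bias
     * row vector is 1x4, the flattened vector holds the 12 weight entries followed by the 4 bias
     * entries for that layer, with layers concatenated in order. unPack() walks the vector using
     * exactly those per-layer lengths to rebuild each (weight, bias) pair, which is why the
     * sanity checks above compare segment lengths against the layer parameters.
     */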
    protected List<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>> backPropGradient2() {
        //feedforward to compute activations
        //initial error
        //precompute deltas
        List<Pair<INDArray, INDArray>> deltas = computeDeltas2();
        List<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>> list = new ArrayList<>();
        List<Pair<INDArray, INDArray>> grad = new ArrayList<>();
        List<Pair<INDArray, INDArray>> preCon = new ArrayList<>();
        for (int l = 0; l < deltas.size(); l++) {
            INDArray gradientChange = deltas.get(l).getFirst();
            INDArray preConGradientChange = deltas.get(l).getSecond();
            if (l < layers.length
                    && gradientChange.length() != layers[l].getParam(DefaultParamInitializer.WEIGHT_KEY).length()) {
                throw new IllegalStateException("Gradient change not equal to weight change");
            }
            //update hidden bias
            INDArray deltaColumnSums = deltas.get(l).getFirst().mean(0);
            INDArray preConColumnSums = deltas.get(l).getSecond().mean(0);
            grad.add(new Pair<>(gradientChange, deltaColumnSums));
            preCon.add(new Pair<>(preConGradientChange, preConColumnSums));
            if (l < layers.length
                    && deltaColumnSums.length() != layers[l].getParam(DefaultParamInitializer.BIAS_KEY).length()) {
                throw new IllegalStateException("Bias change not equal to weight change");
            } else if (l == getLayers().length
                    && deltaColumnSums.length() != getOutputLayer().getParam(DefaultParamInitializer.BIAS_KEY).length()) {
                throw new IllegalStateException("Bias change not equal to weight change");
            }
        }
        INDArray g = pack(grad);
        INDArray con = pack(preCon);
        INDArray theta = params();
        if (mask == null) {
            initMask();
        }
        g.addi(theta.mul(defaultConfiguration.getL2()).muli(mask));
        INDArray conAdd = Transforms.pow(
                mask.mul(defaultConfiguration.getL2())
                        .add(Nd4j.valueArrayOf(g.rows(), g.columns(), layerWiseConfiguration.getDampingFactor())),
                3.0 / 4.0);
        con.addi(conAdd);
        List<Pair<INDArray, INDArray>> gUnpacked = unPack(g);
        List<Pair<INDArray, INDArray>> conUnpacked = unPack(con);
        for (int i = 0; i < gUnpacked.size(); i++) {
            list.add(new Pair<>(gUnpacked.get(i), conUnpacked.get(i)));
        }
        return list;
    }

    public List<INDArray> feedForwardR(List<INDArray> acts, INDArray v) {
        List<INDArray> R = new ArrayList<>();
        R.add(Nd4j.zeros(input.rows(), input.columns()));
        List<Pair<INDArray, INDArray>> vWvB = unPack(v);
        List<INDArray> W = MultiLayerUtil.weightMatrices(this);
        for (int i = 0; i < layers.length; i++) {
            ActivationFunction derivative = getLayers()[i].conf().getActivationFunction();
            //R[i] * W[i] + acts[i] * (vW[i] + vB[i]) .* f'(acts[i + 1])
            R.add(R.get(i).mmul(W.get(i))
                    .addi(acts.get(i).mmul(vWvB.get(i).getFirst().addRowVector(vWvB.get(i).getSecond())))
                    .muli(derivative.applyDerivative(acts.get(i + 1))));
        }
        return R;
    }

    /**
     * Do a back prop iteration with the R operator. This computes the activations
     * and propagates the Gauss-Newton vector product through the network.
     *
     * @param v the v in the gaussian newton vector product G * v
     * @return the layer-wise (weight gradient, bias gradient) pairs
     */
    protected List<Pair<INDArray, INDArray>> backPropGradientR(INDArray v) {
        //feedforward to compute activations
        //initial error
        //log.info("Back prop step " + epoch);
        if (mask == null) {
            initMask();
        }
        //precompute deltas
        List<INDArray> deltas = computeDeltasR(v);
        //compute derivatives and gradients given activations
        List<Pair<INDArray, INDArray>> list = new ArrayList<>();
        for (int l = 0; l < getnLayers(); l++) {
            INDArray gradientChange = deltas.get(l);
            if (gradientChange.length() != getLayers()[l].getParam(DefaultParamInitializer.WEIGHT_KEY).length()) {
                throw new IllegalStateException("Gradient change not equal to weight change");
            }
            //update hidden bias
            INDArray deltaColumnSums = deltas.get(l).mean(0);
            if (deltaColumnSums.length() != layers[l].getParam(DefaultParamInitializer.BIAS_KEY).length()) {
                throw new IllegalStateException("Bias change not equal to weight change");
            }
            list.add(new Pair<>(gradientChange, deltaColumnSums));
        }
        INDArray pack = pack(list).addi(mask.mul(defaultConfiguration.getL2()).muli(v))
                .addi(v.mul(layerWiseConfiguration.getDampingFactor()));
        return unPack(pack);
    }

    @Override
    public void fit(DataSetIterator iter) {
        pretrain(iter);
        iter.reset();
        finetune(iter);
    }

    public void finetune(DataSetIterator iter) {
        iter.reset();
        while (iter.hasNext()) {
            DataSet next = iter.next();
            if (next.getFeatureMatrix() == null || next.getLabels() == null) {
                break;
            }
            setInput(next.getFeatureMatrix());
            setLabels(next.getLabels());
            if (getOutputLayer().conf().getOptimizationAlgo() != OptimizationAlgorithm.HESSIAN_FREE) {
                feedForward();
                if (getOutputLayer() instanceof OutputLayer) {
                    OutputLayer o = (OutputLayer) getOutputLayer();
                    o.fit();
                }
            } else {
                //stochastic Hessian Free
                throw new UnsupportedOperationException("Stochastic Hessian Free is not supported yet");
            }
        }
    }
    public void finetune(INDArray labels) {
        if (labels != null) {
            this.labels = labels;
        }
        if (!(getOutputLayer() instanceof OutputLayer)) {
            log.warn("Output layer is not an instance of OutputLayer; returning");
            return;
        }
        OutputLayer o = (OutputLayer) getOutputLayer();
        if (getOutputLayer().conf().getOptimizationAlgo() != OptimizationAlgorithm.HESSIAN_FREE) {
            feedForward();
            o.fit(getOutputLayer().getInput(), labels);
        } else {
            feedForward();
            o.setLabels(labels);
            throw new UnsupportedOperationException("Stochastic Hessian Free is not supported yet");
        }
    }

    @Override
    public int[] predict(INDArray examples) {
        INDArray output = output(examples);
        int[] ret = new int[examples.rows()];
        for (int i = 0; i < ret.length; i++) {
            ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
        }
        return ret;
    }

    @Override
    public INDArray labelProbabilities(INDArray examples) {
        List<INDArray> feed = feedForward(examples);
        OutputLayer o = (OutputLayer) getOutputLayer();
        return o.labelProbabilities(feed.get(feed.size() - 1));
    }

    @Override
    public void fit(INDArray examples, INDArray labels) {
        if (layerWiseConfiguration.isPretrain()) {
            pretrain(examples);
        } else {
            this.input = examples;
        }
        finetune(labels);
    }

    @Override
    public void fit(INDArray input) {
        pretrain(input);
    }

    @Override
    public void iterate(INDArray input) {
        pretrain(input);
    }

    @Override
    public void fit(DataSet dataSet) {
        fit(dataSet.getFeatureMatrix(), dataSet.getLabels());
    }

    @Override
    public void fit(INDArray examples, int[] labels) {
        fit(examples, FeatureUtil.toOutcomeMatrix(labels, getOutputLayer().conf().getnOut()));
    }

    public INDArray output(INDArray x) {
        List<INDArray> activations = feedForward(x);
        return activations.get(activations.size() - 1);
    }

    public INDArray reconstruct(INDArray x, int layerNum) {
        List<INDArray> forward = feedForward(x);
        return forward.get(layerNum - 1);
    }

    /**
     * Prints the configuration
     */
    public void printConfiguration() {
        StringBuilder sb = new StringBuilder();
        int count = 0;
        for (NeuralNetworkConfiguration conf : getLayerWiseConfiguration().getConfs()) {
            sb.append(" Layer " + count++ + " conf " + conf);
        }
        log.info(sb.toString());
    }
    /**
     * Assigns the parameters of this model to the ones specified by the given
     * network. This is used in loading from input streams, factory methods, etc.
     *
     * @param network the network to copy parameters from
     */
    public void update(MultiLayerNetwork network) {
        this.defaultConfiguration = network.defaultConfiguration;
        this.input = network.input;
        this.labels = network.labels;
        this.weightTransforms = network.weightTransforms;
        this.visibleBiasTransform = network.visibleBiasTransform;
        this.hiddenBiasTransforms = network.hiddenBiasTransforms;
        this.layers = ArrayUtils.clone(network.layers);
    }

    public Layer getOutputLayer() {
        return getLayers()[getLayers().length - 1];
    }

    private void initMask() {
        setMask(Nd4j.ones(1, pack().length()));
    }

    public int getnLayers() {
        return layerWiseConfiguration.getHiddenLayerSizes().length + 1;
    }

    public synchronized Layer[] getLayers() {
        return layers;
    }

    public void setLayers(Layer[] layers) {
        this.layers = layers;
    }

    public INDArray getMask() {
        return mask;
    }

    public void setMask(INDArray mask) {
        this.mask = mask;
    }

    @Override
    public double score(DataSet data) {
        feedForward(data.getFeatureMatrix());
        setLabels(data.getLabels());
        return score();
    }

    @Override
    public void fit() {
        fit(input, labels);
    }

    @Override
    public void update(Gradient gradient) {
    }

    @Override
    public double score() {
        if (getOutputLayer().getInput() == null) {
            feedForward();
        }
        return getOutputLayer().score();
    }

    public double score(INDArray param) {
        INDArray params = params();
        setParameters(param);
        double ret = score();
        double regCost = 0.5f * defaultConfiguration.getL2()
                * (double) Transforms.pow(mask.mul(param), 2).sum(Integer.MAX_VALUE).element();
        setParameters(params);
        return ret + regCost;
    }

    public void setParameters(INDArray params) {
        int idx = 0;
        for (int i = 0; i < getLayers().length; i++) {
            Layer layer = getLayers()[i];
            int range = layer.numParams();
            layer.setParams(params.get(NDArrayIndex.interval(idx, range + idx)));
            idx += range;
        }
    }

    public void merge(MultiLayerNetwork network, int batchSize) {
        if (network.layers.length != layers.length) {
            throw new IllegalArgumentException("Unable to merge networks that are not of equal length");
        }
        for (int i = 0; i < getnLayers(); i++) {
            Layer n = layers[i];
            Layer otherNetwork = network.layers[i];
            n.merge(otherNetwork, batchSize);
        }
        getOutputLayer().merge(network.getOutputLayer(), batchSize);
    }

    @Override
    public double score(INDArray examples, INDArray labels) {
        feedForward(examples);
        setLabels(labels);
        Evaluation eval = new Evaluation();
        eval.eval(labels, labelProbabilities(examples));
        return eval.f1();
    }

    @Override
    public int numLables() {
        return labels.columns();
    }

    @Override
    public int numParams() {
        int length = 0;
        for (int i = 0; i < layers.length; i++) {
            length += layers[i].numParams();
        }
        return length;
    }

    /**
     * GETTER and SETTER
     */
    //<editor-fold defaultstate="collapsed" desc="getter and setter">
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    public synchronized INDArray getInput() {
        return input;
    }

    public void setInput(INDArray input) {
        if (input != null && this.layers == null) {
            initializeLayer(input);
        }
        this.input = input;
    }
    //</editor-fold>
}
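A minimal usage sketch for the class above, assuming a fully populated MultiLayerConfiguration is assembled elsewhere; buildConfiguration() and trainingDataIterator() below are hypothetical placeholders, and only the MultiLayerNetwork calls themselves come from the listing:

// Hedged usage sketch (not part of the original source). buildConfiguration() and
// trainingDataIterator() are hypothetical helpers standing in for configuration and
// data-loading code that the listing above does not show.
MultiLayerConfiguration conf = buildConfiguration();        // assumed: hidden layer sizes, per-layer confs, pretrain flag
MultiLayerNetwork network = new MultiLayerNetwork(conf);

DataSetIterator trainingData = trainingDataIterator();      // assumed source of features + labels
network.fit(trainingData);                                  // layer-wise pretraining (if enabled), then output-layer finetuning

INDArray examples = trainingDataIterator().next().getFeatureMatrix();
int[] predictedClasses = network.predict(examples);         // argmax of the output activations, one per row
INDArray probabilities = network.labelProbabilities(examples);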