com.anhth12.nn.multilayer.MultiLayerNetwork.java Source code

Introduction

Here is the source code for com.anhth12.nn.multilayer.MultiLayerNetwork.java, a multi-layer neural network built on ND4J that supports greedy layer-wise pretraining, feed-forward activation, and fine-tuning of its output layer.

Source

package com.anhth12.nn.multilayer;

import com.anhth12.dataset.iterator.DataSetIterator;
import com.anhth12.downpourSGD.util.Pair;
import com.anhth12.eval.Evaluation;
import com.anhth12.nn.api.Classifier;
import com.anhth12.nn.api.Layer;
import com.anhth12.nn.api.OptimizationAlgorithm;
import com.anhth12.nn.conf.MultiLayerConfiguration;
import com.anhth12.nn.conf.NeuralNetworkConfiguration;
import com.anhth12.nn.conf.OutputPreprocessor;
import com.anhth12.nn.gradient.DefaultGradient;
import com.anhth12.nn.gradient.Gradient;
import com.anhth12.nn.layers.OutputLayer;
import com.anhth12.nn.params.DefaultParamInitializer;
import com.anhth12.util.MultiLayerUtil;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.sampling.Sampling;
import org.nd4j.linalg.transformation.MatrixTransform;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.LinAlgExceptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 *
 * @author anhth12
 */
public class MultiLayerNetwork implements Serializable, Classifier {

    private static Logger log = LoggerFactory.getLogger(MultiLayerNetwork.class);

    private Layer[] layers;
    private INDArray input, labels;

    //layer - weight transform
    protected Map<Integer, MatrixTransform> weightTransforms = new HashMap<>();
    //layer - hbias transform
    protected Map<Integer, MatrixTransform> hiddenBiasTransforms = new HashMap<>();
    //layer - vbias transforms
    protected Map<Integer, MatrixTransform> visibleBiasTransform = new HashMap<>();

    protected NeuralNetworkConfiguration defaultConfiguration;
    protected MultiLayerConfiguration layerWiseConfiguration;
    protected boolean initCalled = false;

    //binary drop connect mask;
    protected INDArray mask;

    public MultiLayerNetwork(MultiLayerConfiguration conf) {
        this.layerWiseConfiguration = conf;
        defaultConfiguration = conf.getConf(0);
    }

    protected void initializeConfigurations() {
        if (layerWiseConfiguration == null) {
            layerWiseConfiguration = new MultiLayerConfiguration.Builder().build();
        }

        if (layers == null) {
            layers = new Layer[layerWiseConfiguration.getHiddenLayerSizes().length + 1];
        }

        if (defaultConfiguration == null) {
            defaultConfiguration = new NeuralNetworkConfiguration();
        }

        if (layerWiseConfiguration == null || layerWiseConfiguration.getConfs().isEmpty()) {
            for (int i = 0; i < layerWiseConfiguration.getHiddenLayerSizes().length + 1; i++) {
                layerWiseConfiguration.getConfs().add(defaultConfiguration.clone());
            }
        }
    }

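    /**
     * Greedy layer-wise pretraining over a dataset iterator: each layer is fit
     * on the activations produced by the layers below it, and the iterator is
     * reset between layers. Does nothing unless pretraining is enabled in the
     * layer-wise configuration.
     */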
    public void pretrain(DataSetIterator iter) {
        if (!layerWiseConfiguration.isPretrain()) {
            return;
        }

        INDArray layerInput;
        for (int i = 0; i < getnLayers(); i++) {
            if (i == 0) {
                while (iter.hasNext()) {
                    DataSet next = iter.next();
                    this.input = next.getFeatureMatrix();
                    if (this.getInput() == null || this.getLayers() == null) {
                        setInput(input);
                        initializeLayer(input);
                    } else {
                        setInput(input);
                    }
                    log.info("Training on layer " + (i + 1));
                    getLayers()[i].fit(next.getFeatureMatrix());
                }
                iter.reset();
            } else {
                while (iter.hasNext()) {
                    DataSet next = iter.next();
                    layerInput = next.getFeatureMatrix();
                    for (int j = 1; j <= i; j++) {
                        layerInput = activationFromPrevLayer(j - 1, layerInput);
                    }
                    log.info("Training on layer " + (i + 1));
                    getLayers()[i].fit(layerInput);
                }
                iter.reset();
            }
        }
    }

    public void pretrain(INDArray input) {

        if (!layerWiseConfiguration.isPretrain()) {
            return;
        }
        /* During pretraining, feed the input forward so that each layer is trained on the activations produced by the layers below it */
        if (this.getInput() == null || this.getLayers() == null) {
            setInput(input);
            initializeLayer(input);
        } else {
            setInput(input);
        }

        INDArray layerInput = null;

        for (int i = 0; i < getnLayers() - 1; i++) {
            if (i == 0) {
                layerInput = getInput();
            } else {
                layerInput = activationFromPrevLayer(i - 1, layerInput);
            }
            log.info("Training on layer " + (i + 1));
            getLayers()[i].fit(layerInput);

        }
    }

    @Override
    public int batchSize() {
        return input.rows();
    }

    @Override
    public INDArray input() {
        return input;
    }

    @Override
    public void validateInput() {

    }

    @Override
    public void setConf(NeuralNetworkConfiguration conf) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    @Override
    public NeuralNetworkConfiguration conf() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    @Override
    public INDArray transform(INDArray data) {
        return output(data);
    }

    public NeuralNetworkConfiguration getDefaultConfiguration() {
        return defaultConfiguration;
    }

    public void setDefaultConfiguration(NeuralNetworkConfiguration defaultConfiguration) {
        this.defaultConfiguration = defaultConfiguration;
    }

    public MultiLayerConfiguration getLayerWiseConfiguration() {
        return layerWiseConfiguration;
    }

    public void setLayerWiseConfiguration(MultiLayerConfiguration layerWiseConfiguration) {
        this.layerWiseConfiguration = layerWiseConfiguration;
    }

    public void initializeLayer(INDArray input) {
        if (input == null) {
            throw new IllegalArgumentException("Unable to initialize neural network with empty input");
        }

        int[] hiddenLayerSizes = layerWiseConfiguration.getHiddenLayerSizes();

        if (input.shape().length == 2) {
            for (int i = 0; i < hiddenLayerSizes.length; i++) {
                if (hiddenLayerSizes[i] < 1) {
                    throw new IllegalStateException("All hidden layer sizes must be >= 1");
                }
            }
        }
        this.input = input.dup();
        if (!initCalled) {
            init();
        }
    }

    public void init() {
        if (layerWiseConfiguration == null || layers == null) {
            initializeConfigurations();
        }

        //unnecessary
        INDArray layerInput = input;

        int inputSize;
        if (getnLayers() < 1) {
            throw new IllegalStateException("Number of layers must be >= 1");
        }
        int[] hiddenLayerSizes = layerWiseConfiguration.getHiddenLayerSizes();

        if (this.layers == null || this.layers[0] == null) {
            this.layers = new Layer[hiddenLayerSizes.length + 1];
            for (int i = 0; i < this.layers.length; i++) {
                if (i == 0) {
                    inputSize = layerWiseConfiguration.getConf(0).getnIn();
                } else {
                    inputSize = hiddenLayerSizes[i - 1];
                }

                if (i == 0) {
                    layerWiseConfiguration.getConf(i).setnIn(inputSize);
                    layerWiseConfiguration.getConf(i).setnOut(hiddenLayerSizes[i]);
                    layers[i] = layerWiseConfiguration.getConf(i).getLayerFactory()
                            .create(layerWiseConfiguration.getConf(i));
                } else if (i < getLayers().length - 1) {
                    if (input != null) {
                        //unnecessary
                        layerInput = activationFromPrevLayer(i - 1, layerInput);
                    }
                    layerWiseConfiguration.getConf(i).setnIn(inputSize);
                    layerWiseConfiguration.getConf(i).setnOut(hiddenLayerSizes[i]);
                    layers[i] = layerWiseConfiguration.getConf(i).getLayerFactory()
                            .create(layerWiseConfiguration.getConf(i));
                }
            }

            NeuralNetworkConfiguration last = layerWiseConfiguration
                    .getConf(layerWiseConfiguration.getConfs().size() - 1);
            NeuralNetworkConfiguration secondToLast = layerWiseConfiguration
                    .getConf(layerWiseConfiguration.getConfs().size() - 2);
            last.setnIn(secondToLast.getnOut());

            this.layers[layers.length - 1] = last.getLayerFactory().create(last);

            initCalled = true;
            initMask();
        }
    }

    public INDArray activate() {
        return getLayers()[getLayers().length - 1].activate();
    }

    public INDArray activate(int layer) {
        return getLayers()[layer].activate();
    }

    public INDArray activate(int layer, INDArray input) {
        return getLayers()[layer].activate(input);
    }

    public void initialize(DataSet dataSet) {
        setInput(dataSet.getFeatureMatrix());
        //feedforward
        feedForward(dataSet.getFeatureMatrix());
        this.labels = dataSet.getLabels();
        if (getOutputLayer() instanceof OutputLayer) {
            OutputLayer o = (OutputLayer) getOutputLayer();
            o.setLabels(labels);
        }
    }

    public synchronized INDArray activationFromPrevLayer(int curr, INDArray input) {
        return layers[curr].activate(input);
    }

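    /**
     * Computes the activations of every layer for the current input. The
     * returned list starts with the input itself, followed by one activation
     * per layer; output preprocessors and drop connect are applied between
     * layers when configured.
     */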
    public List<INDArray> feedForward() {
        INDArray currInput = this.input;
        if (this.input.isMatrix() && this.input.columns() != defaultConfiguration.getnIn()) {
            throw new IllegalStateException("Illegal input size: expected " + defaultConfiguration.getnIn() + " columns, got " + this.input.columns());
        }

        List<INDArray> activations = new ArrayList<>();
        activations.add(currInput);

        for (int i = 0; i < layers.length; i++) {
            currInput = activationFromPrevLayer(i, currInput);
            //pre process the activation before passing to the next layer
            OutputPreprocessor preProcessor = getLayerWiseConfiguration().getPreProcessor(i);
            if (preProcessor != null) {
                currInput = preProcessor.preProcess(currInput);
            }
            //applies drop connect to the activation
            applyDropConnectIfNecessary(currInput);
            activations.add(currInput);
        }

        return activations;
    }

    public List<INDArray> feedForward(INDArray input) {
        if (input == null) {
            throw new IllegalStateException("Unable to perform feed forward; no input found");
        } else {
            this.input = input;
        }
        return feedForward();
    }

    @Override
    public Gradient getGradient() {
        Gradient ret = new DefaultGradient();
        for (int i = 0; i < layers.length; i++) {
            ret.gradientLookupTable().put(String.valueOf(i), layers[i].getGradient().gradient());
        }

        return ret;
    }

    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<>(getGradient(), getOutputLayer().score());
    }

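    /**
     * When drop connect is enabled, multiplies the given activations in place
     * by a binomial(0.5) mask and, if an L2 coefficient is set, scales them by it.
     */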
    protected void applyDropConnectIfNecessary(INDArray input) {
        if (layerWiseConfiguration.isUseDropConnect()) {
            INDArray mask = Sampling.binomial(Nd4j.valueArrayOf(input.rows(), input.columns(), 0.5), 1,
                    defaultConfiguration.getRng());
            input.muli(mask);
            //apply l2 for drop connect
            if (defaultConfiguration.getL2() > 0) {
                input.muli(defaultConfiguration.getL2());
            }
        }
    }

    /* delta computation for back prop with the R operator */
    protected List<INDArray> computeDeltasR(INDArray v) {
        List<INDArray> deltaRet = new ArrayList<>();

        INDArray[] deltas = new INDArray[getnLayers() + 1];
        List<INDArray> activations = feedForward();
        List<INDArray> rActivations = feedForwardR(activations, v);
        /*
         * Precompute activations and z's (pre activation network outputs)
         */
        List<INDArray> weights = new ArrayList<>();
        List<INDArray> biases = new ArrayList<>();
        List<ActivationFunction> activationFunctions = new ArrayList<>();

        for (int j = 0; j < getLayers().length; j++) {
            weights.add(getLayers()[j].getParam(DefaultParamInitializer.WEIGHT_KEY));
            biases.add(getLayers()[j].getParam(DefaultParamInitializer.BIAS_KEY));
            activationFunctions.add(getLayers()[j].conf().getActivationFunction());
        }

        INDArray rix = rActivations.get(rActivations.size() - 1).divi((double) input.rows());
        LinAlgExceptions.assertValidNum(rix);

        //errors
        for (int i = getnLayers() - 1; i >= 0; i--) {
            //W^T * error^(l + 1)
            deltas[i] = activations.get(i).transpose().mmul(rix);
            applyDropConnectIfNecessary(deltas[i]);

            if (i > 0) {
                rix = rix.mmul(weights.get(i).addRowVector(biases.get(i)).transpose())
                        .muli(activationFunctions.get(i - 1).applyDerivative(activations.get(i)));
            }

        }

        for (int i = 0; i < deltas.length - 1; i++) {
            if (defaultConfiguration.isConstrainGradientToUnitNorm()) {
                double sum = deltas[i].sum(Integer.MAX_VALUE).getDouble(0);
                if (sum > 0) {
                    deltaRet.add(deltas[i].div(deltas[i].norm2(Integer.MAX_VALUE)));
                } else {
                    deltaRet.add(deltas[i]);
                }
            } else {
                deltaRet.add(deltas[i]);
            }
            LinAlgExceptions.assertValidNum(deltaRet.get(i));
        }

        return deltaRet;
    }

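    /**
     * Levenberg-Marquardt style damping heuristic: boosts the damping factor
     * when the reduction ratio rho is poor (< 0.25 or NaN) and decreases it
     * when rho is good (> 0.75).
     */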
    public void dampingUpdate(double rho, double boost, double decrease) {
        if (rho < 0.25 || Double.isNaN(rho)) {
            layerWiseConfiguration.setDampingFactor(getLayerWiseConfiguration().getDampingFactor() * boost);
        } else if (rho > 0.75) {
            layerWiseConfiguration.setDampingFactor(getLayerWiseConfiguration().getDampingFactor() * decrease);
        }
    }

    public double reductionRation(INDArray p, double currScore, double score, INDArray gradient) {
        double currentDamp = layerWiseConfiguration.getDampingFactor();
        layerWiseConfiguration.setDampingFactor(0);
        INDArray denom = getBackPropRGradient(p);
        denom.muli(0.5).muli(p.mul(denom)).sum(0);
        denom.subi(gradient.mul(p).sum(0));
        double rho = (currScore - score) / (double) denom.getScalar(0).element();
        layerWiseConfiguration.setDampingFactor(currentDamp);
        if (score - currScore > 0) {
            return Double.NEGATIVE_INFINITY;
        }
        return rho;
    }

    /* delta computation for back prop with precon for SFH */
    protected List<Pair<INDArray, INDArray>> computeDeltas2() {
        List<Pair<INDArray, INDArray>> deltaRet = new ArrayList<>();
        List<INDArray> activations = feedForward();
        INDArray[] deltas = new INDArray[activations.size() - 1];
        INDArray[] preCons = new INDArray[activations.size() - 1];

        //- y - h
        INDArray ix = activations.get(activations.size() - 1).sub(labels).div(labels.rows());

        /*
         * Precompute activations and z's (pre activation network outputs)
         */
        List<INDArray> weights = new ArrayList<>();
        List<INDArray> biases = new ArrayList<>();

        List<ActivationFunction> activationFunctions = new ArrayList<>();
        for (int j = 0; j < getLayers().length; j++) {
            weights.add(getLayers()[j].getParam(DefaultParamInitializer.WEIGHT_KEY));
            biases.add(getLayers()[j].getParam(DefaultParamInitializer.BIAS_KEY));
            activationFunctions.add(getLayers()[j].conf().getActivationFunction());
        }

        //errors
        for (int i = weights.size() - 1; i >= 0; i--) {
            deltas[i] = activations.get(i).transpose().mmul(ix);
            preCons[i] = Transforms.pow(activations.get(i).transpose(), 2).mmul(Transforms.pow(ix, 2))
                    .muli(labels.rows());
            applyDropConnectIfNecessary(deltas[i]);

            if (i > 0) {
                //W[i] + b[i] * f'(z[i - 1])
                ix = ix.mmul(weights.get(i).transpose())
                        .muli(activationFunctions.get(i - 1).applyDerivative(activations.get(i)));
            }
        }

        for (int i = 0; i < deltas.length; i++) {
            if (defaultConfiguration.isConstrainGradientToUnitNorm()) {
                deltaRet.add(new Pair<>(deltas[i].divi(deltas[i].norm2(Integer.MAX_VALUE)), preCons[i]));
            } else {
                deltaRet.add(new Pair<>(deltas[i], preCons[i]));
            }

        }

        return deltaRet;
    }

    /* delta computation for back prop */
    protected List<INDArray> computeDeltas() {
        List<INDArray> deltaRet = new ArrayList<>();
        INDArray[] deltas = new INDArray[getnLayers() + 2];

        List<INDArray> activations = feedForward();

        //- y - h
        INDArray ix = labels.sub(activations.get(activations.size() - 1)).subi(getOutputLayer().conf()
                .getActivationFunction().applyDerivative(activations.get(activations.size() - 1)));

        /*
         * Precompute activations and z's (pre activation network outputs)
         */
        List<INDArray> weights = new ArrayList<>();
        List<INDArray> biases = new ArrayList<>();

        List<ActivationFunction> activationFunctions = new ArrayList<>();
        for (int j = 0; j < getLayers().length; j++) {
            weights.add(getLayers()[j].getParam(DefaultParamInitializer.WEIGHT_KEY));
            biases.add(getLayers()[j].getParam(DefaultParamInitializer.BIAS_KEY));
            activationFunctions.add(getLayers()[j].conf().getActivationFunction());
        }

        weights.add(getOutputLayer().getParam(DefaultParamInitializer.WEIGHT_KEY));
        biases.add(getOutputLayer().getParam(DefaultParamInitializer.BIAS_KEY));
        activationFunctions.add(getOutputLayer().conf().getActivationFunction());

        //errors
        for (int i = getnLayers() + 1; i >= 0; i--) {
            //output layer
            if (i >= getnLayers() + 1) {
                //-( y - h) .* f'(z^l) where l is the output layer
                deltas[i] = ix;

            } else {
                INDArray delta = activations.get(i).transpose().mmul(ix);
                deltas[i] = delta;
                applyDropConnectIfNecessary(deltas[i]);
                INDArray weightsPlusBias = weights.get(i).transpose();
                INDArray activation = activations.get(i);
                if (i > 0) {
                    ix = ix.mmul(weightsPlusBias).muli(activationFunctions.get(i - 1).applyDerivative(activation));
                }

            }

        }

        for (int i = 0; i < deltas.length; i++) {
            if (defaultConfiguration.isConstrainGradientToUnitNorm()) {
                deltaRet.add(deltas[i].divi(deltas[i].norm2(Integer.MAX_VALUE)));
            } else {
                deltaRet.add(deltas[i]);
            }
        }

        return deltaRet;
    }

    public void backPropStep() {
        List<Pair<INDArray, INDArray>> deltas = backPropGradient();
        for (int i = 0; i < layers.length; i++) {
            layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY).addi(deltas.get(i).getFirst());
            layers[i].getParam(DefaultParamInitializer.BIAS_KEY).addi(deltas.get(i).getSecond());

        }

    }

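    /**
     * Packs the R-operator backprop gradients into a single flattened vector;
     * this is the Gauss-Newton vector product used by Hessian-free updates.
     */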
    public INDArray getBackPropRGradient(INDArray v) {
        return pack(backPropGradientR(v));
    }

    /**
     * Gets the back prop gradient with the R operator (Gauss-Newton vector)
     * and the associated preconditioner. This is also called computeGV.
     *
     * @return a pair of (packed gradient, packed preconditioner)
     */
    public Pair<INDArray, INDArray> getBackPropGradient2() {
        List<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>> deltas = backPropGradient2();
        List<Pair<INDArray, INDArray>> deltaNormal = new ArrayList<>();
        List<Pair<INDArray, INDArray>> deltasPreCon = new ArrayList<>();
        for (int i = 0; i < deltas.size(); i++) {
            deltaNormal.add(deltas.get(i).getFirst());
            deltasPreCon.add(deltas.get(i).getSecond());
        }

        return new Pair<>(pack(deltaNormal), pack(deltasPreCon));
    }

    @Override
    public MultiLayerNetwork clone() {
        //getClass().newInstance() would fail here because this class has no
        //no-arg constructor, so build the copy from the existing configuration
        MultiLayerNetwork ret = new MultiLayerNetwork(layerWiseConfiguration);
        ret.update(this);
        return ret;
    }

    @Override
    public INDArray params() {
        List<INDArray> params = new ArrayList<>();
        for (int i = 0; i < getnLayers(); i++) {
            params.add(layers[i].params());
        }

        return Nd4j.toFlattened(params);
    }

    @Override
    public void setParams(INDArray param) {
        setParameters(param);
    }

    public INDArray pack() {
        return params();

    }

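    /**
     * Flattens a list of (weight, bias) pairs into a single row vector,
     * preserving layer order.
     */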
    public INDArray pack(List<Pair<INDArray, INDArray>> layers) {
        List<INDArray> list = new ArrayList<>();

        for (int i = 0; i < layers.size(); i++) {
            list.add(layers.get(i).getFirst());
            list.add(layers.get(i).getSecond());
        }
        return Nd4j.toFlattened(list);
    }

    public List<Pair<INDArray, INDArray>> backPropGradient() {
        //feedforward to compute activations
        //initial error

        //precompute deltas
        List<INDArray> deltas = computeDeltas();

        List<Pair<INDArray, INDArray>> vWvB = new ArrayList<>();
        for (int i = 0; i < layers.length; i++)
            vWvB.add(new Pair<>(layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY),
                    layers[i].getParam(DefaultParamInitializer.BIAS_KEY)));

        List<Pair<INDArray, INDArray>> list = new ArrayList<>();

        //iterate over the layers array; getnLayers() + 1 would run past its end
        for (int l = 0; l < layers.length; l++) {
            INDArray gradientChange = deltas.get(l);
            if (gradientChange.length() != getLayers()[l].getParam(DefaultParamInitializer.WEIGHT_KEY).length())
                throw new IllegalStateException("Gradient change not equal to weight change");

            //update hidden bias
            INDArray deltaColumnSums = deltas.get(l).isVector() ? deltas.get(l) : deltas.get(l).mean(0);

            list.add(new Pair<>(gradientChange, deltaColumnSums));

        }

        if (mask == null)
            initMask();

        return list;

    }

    /**
     * Unpacks a flattened parameter vector into per-layer (weight, hidden bias)
     * pairs.
     *
     * @param param the packed parameter vector
     * @return a segmented list of (weight, bias) pairs, one per layer
     */
    public List<Pair<INDArray, INDArray>> unPack(INDArray param) {
        //more sanity checks!
        if (param.rows() != 1) {
            param = param.reshape(1, param.length());
        }
        List<Pair<INDArray, INDArray>> ret = new ArrayList<>();
        int curr = 0;
        for (int i = 0; i < layers.length; i++) {
            int layerLength = layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY).length()
                    + layers[i].getParam(DefaultParamInitializer.BIAS_KEY).length();
            INDArray subMatrix = param.get(NDArrayIndex.interval(curr, curr + layerLength));
            INDArray weightPortion = subMatrix
                    .get(NDArrayIndex.interval(0, layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY).length()));

            int beginHBias = layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY).length();
            int endHbias = subMatrix.length();
            INDArray hBiasPortion = subMatrix.get(NDArrayIndex.interval(beginHBias, endHbias));
            int layerLengthSum = weightPortion.length() + hBiasPortion.length();
            if (layerLengthSum != layerLength) {
                if (hBiasPortion.length() != layers[i].getParam(DefaultParamInitializer.BIAS_KEY).length()) {
                    throw new IllegalStateException("Hidden bias on layer " + i + " was off");
                }
                if (weightPortion.length() != layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY).length()) {
                    throw new IllegalStateException("Weight portion on layer " + i + " was off");
                }

            }

            ret.add(new Pair<>(
                    weightPortion.reshape(layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY).rows(),
                            layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY).columns()),
                    hBiasPortion.reshape(layers[i].getParam(DefaultParamInitializer.BIAS_KEY).rows(),
                            layers[i].getParam(DefaultParamInitializer.BIAS_KEY).columns())));
            curr += layerLength;
        }

        return ret;
    }

    protected List<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>> backPropGradient2() {
        //feedforward to compute activations
        //initial error

        //precompute deltas
        List<Pair<INDArray, INDArray>> deltas = computeDeltas2();

        List<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>> list = new ArrayList<>();
        List<Pair<INDArray, INDArray>> grad = new ArrayList<>();
        List<Pair<INDArray, INDArray>> preCon = new ArrayList<>();

        for (int l = 0; l < deltas.size(); l++) {
            INDArray gradientChange = deltas.get(l).getFirst();
            INDArray preConGradientChange = deltas.get(l).getSecond();

            if (l < layers.length
                    && gradientChange.length() != layers[l].getParam(DefaultParamInitializer.WEIGHT_KEY).length())
                throw new IllegalStateException("Gradient change not equal to weight change");

            //update hidden bias
            INDArray deltaColumnSums = deltas.get(l).getFirst().mean(0);
            INDArray preConColumnSums = deltas.get(l).getSecond().mean(0);

            grad.add(new Pair<>(gradientChange, deltaColumnSums));
            preCon.add(new Pair<>(preConGradientChange, preConColumnSums));
            if (l < layers.length
                    && deltaColumnSums.length() != layers[l].getParam(DefaultParamInitializer.BIAS_KEY).length())
                throw new IllegalStateException("Bias change not equal to weight change");
            else if (l == getLayers().length && deltaColumnSums.length() != getOutputLayer()
                    .getParam(DefaultParamInitializer.BIAS_KEY).length())
                throw new IllegalStateException("Bias change not equal to weight change");

        }

        INDArray g = pack(grad);
        INDArray con = pack(preCon);
        INDArray theta = params();

        if (mask == null)
            initMask();

        g.addi(theta.mul(defaultConfiguration.getL2()).muli(mask));

        INDArray conAdd = Transforms.pow(
                mask.mul(defaultConfiguration.getL2())
                        .add(Nd4j.valueArrayOf(g.rows(), g.columns(), layerWiseConfiguration.getDampingFactor())),
                3.0 / 4.0);

        con.addi(conAdd);

        List<Pair<INDArray, INDArray>> gUnpacked = unPack(g);

        List<Pair<INDArray, INDArray>> conUnpacked = unPack(con);

        for (int i = 0; i < gUnpacked.size(); i++)
            list.add(new Pair<>(gUnpacked.get(i), conUnpacked.get(i)));

        return list;

    }

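    /**
     * R-operator forward pass: propagates the direction vector v through the
     * network alongside the given activations, producing the directional
     * derivatives of each layer's activations (Pearlmutter's R{} operator).
     */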
    public List<INDArray> feedForwardR(List<INDArray> acts, INDArray v) {
        List<INDArray> R = new ArrayList<>();
        R.add(Nd4j.zeros(input.rows(), input.columns()));
        List<Pair<INDArray, INDArray>> vWvB = unPack(v);
        List<INDArray> W = MultiLayerUtil.weightMatrices(this);

        for (int i = 0; i < layers.length; i++) {
            ActivationFunction derivative = getLayers()[i].conf().getActivationFunction();
            //R[i] * W[i] + acts[i] * (vW[i] + vB[i]) .* f'(acts[i + 1])
            R.add(R.get(i).mmul(W.get(i))
                    .addi(acts.get(i).mmul(vWvB.get(i).getFirst().addRowVector(vWvB.get(i).getSecond())))
                    .muli((derivative.applyDerivative(acts.get(i + 1)))));

        }
        return R;
    }

    /**
     * Back prop with the R operator: computes the deltas for the direction v,
     * adds the L2 and damping terms, and unpacks the result into per-layer
     * updates.
     *
     * @param v the vector v in the Gauss-Newton vector product G * v
     * @return per-layer (weight gradient, bias gradient) pairs
     */
    protected List<Pair<INDArray, INDArray>> backPropGradientR(INDArray v) {
        //feedforward to compute activations
        //initial error
        //log.info("Back prop step " + epoch);
        if (mask == null) {
            initMask();
        }
        //precompute deltas
        List<INDArray> deltas = computeDeltasR(v);
        //compute derivatives and gradients given activations

        List<Pair<INDArray, INDArray>> list = new ArrayList<>();

        for (int l = 0; l < getnLayers(); l++) {
            INDArray gradientChange = deltas.get(l);

            if (gradientChange.length() != getLayers()[l].getParam(DefaultParamInitializer.WEIGHT_KEY).length()) {
                throw new IllegalStateException("Gradient change not equal to weight change");
            }

            //update hidden bias
            INDArray deltaColumnSums = deltas.get(l).mean(0);
            if (deltaColumnSums.length() != layers[l].getParam(DefaultParamInitializer.BIAS_KEY).length()) {
                throw new IllegalStateException("Bias change not equal to weight change");
            }

            list.add(new Pair<>(gradientChange, deltaColumnSums));

        }

        INDArray pack = pack(list).addi(mask.mul(defaultConfiguration.getL2()).muli(v))
                .addi(v.mul(layerWiseConfiguration.getDampingFactor()));
        return unPack(pack);

    }

    @Override
    public void fit(DataSetIterator iter) {
        pretrain(iter);

        iter.reset();
        finetune(iter);
    }

    public void finetune(DataSetIterator iter) {
        iter.reset();

        while (iter.hasNext()) {
            DataSet next = iter.next();

            if (next.getFeatureMatrix() == null || next.getLabels() == null) {
                break;
            }

            setInput(next.getFeatureMatrix());
            setLabels(next.getLabels());

            if (getOutputLayer().conf().getOptimizationAlgo() != OptimizationAlgorithm.HESSIAN_FREE) {
                feedForward();
                if (getOutputLayer() instanceof OutputLayer) {
                    OutputLayer o = (OutputLayer) getOutputLayer();
                    o.fit();
                }
            } else {
                //stochastic Hessian Free
                throw new UnsupportedOperationException("Stochastic Hessian Free is not supported yet");
            }
        }
    }

    public void finetune(INDArray labels) {
        if (labels != null) {
            this.labels = labels;
        }

        if (!(getOutputLayer() instanceof OutputLayer)) {
            log.warn("Output layer is not an instance of OutputLayer; returning");
            return;
        }

        OutputLayer o = (OutputLayer) getOutputLayer();
        if (getOutputLayer().conf().getOptimizationAlgo() != OptimizationAlgorithm.HESSIAN_FREE) {
            feedForward();
            o.fit(getOutputLayer().getInput(), labels);
        } else {
            feedForward();
            o.setLabels(labels);
            throw new UnsupportedOperationException("Stochastic Hessian Free is not supported yet");
        }
    }

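    /**
     * Returns the predicted class index for each example row, taken as the
     * index of the strongest output activation for that row.
     */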
    @Override
    public int[] predict(INDArray examples) {
        INDArray output = output(examples);

        int[] ret = new int[examples.rows()];

        for (int i = 0; i < ret.length; i++) {
            ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
        }
        return ret;
    }

    @Override
    public INDArray labelProbabilities(INDArray examples) {
        List<INDArray> feed = feedForward(examples);

        OutputLayer o = (OutputLayer) getOutputLayer();
        return o.labelProbabilities(feed.get(feed.size() - 1));
    }

    @Override
    public void fit(INDArray examples, INDArray labels) {
        if (layerWiseConfiguration.isPretrain()) {
            pretrain(examples);
        } else {
            this.input = examples;
        }
        finetune(labels);
    }

    @Override
    public void fit(INDArray input) {
        pretrain(input);
    }

    @Override
    public void iterate(INDArray input) {
        pretrain(input);
    }

    @Override
    public void fit(DataSet dataSet) {
        fit(dataSet.getFeatureMatrix(), dataSet.getLabels());
    }

    @Override
    public void fit(INDArray examples, int[] labels) {
        fit(examples, FeatureUtil.toOutcomeMatrix(labels, getOutputLayer().conf().getnOut()));
    }

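    /**
     * Runs a full forward pass and returns the output layer's activations
     * (label probabilities when the output layer applies softmax).
     */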
    public INDArray output(INDArray x) {
        List<INDArray> activations = feedForward(x);
        return activations.get(activations.size() - 1);
    }

    public INDArray reconstruct(INDArray x, int layerNum) {
        List<INDArray> forward = feedForward(x);
        return forward.get(layerNum - 1);
    }

    /**
     * Prints the configuration
     */
    public void printConfiguration() {
        StringBuilder sb = new StringBuilder();
        int count = 0;
        for (NeuralNetworkConfiguration conf : getLayerWiseConfiguration().getConfs()) {
            sb.append(" Layer " + count++ + " conf " + conf);
        }

        log.info(sb.toString());
    }

    /**
     * Copies the parameters and state of the given network into this one.
     * This is used when loading from input streams, in factory methods, etc.
     *
     * @param network the network to copy parameters from
     */
    public void update(MultiLayerNetwork network) {
        this.defaultConfiguration = network.defaultConfiguration;
        this.input = network.input;
        this.labels = network.labels;
        this.weightTransforms = network.weightTransforms;
        this.visibleBiasTransform = network.visibleBiasTransform;
        this.hiddenBiasTransforms = network.hiddenBiasTransforms;
        this.layers = ArrayUtils.clone(network.layers);

    }

    public Layer getOutputLayer() {
        return getLayers()[getLayers().length - 1];
    }

    private void initMask() {
        setMask(Nd4j.ones(1, pack().length()));
    }

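    /**
     * Total number of layers: one per configured hidden layer size plus the
     * output layer.
     */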
    public int getnLayers() {
        return layerWiseConfiguration.getHiddenLayerSizes().length + 1;
    }

    public synchronized Layer[] getLayers() {
        return layers;
    }

    public void setLayers(Layer[] layers) {
        this.layers = layers;
    }

    public INDArray getMask() {
        return mask;
    }

    public void setMask(INDArray mask) {
        this.mask = mask;
    }

    @Override
    public double score(DataSet data) {
        feedForward(data.getFeatureMatrix());
        setLabels(data.getLabels());

        return score();
    }

    @Override
    public void fit() {
        fit(input, labels);
    }

    @Override
    public void update(Gradient gradient) {

    }

    @Override
    public double score() {
        if (getOutputLayer().getInput() == null) {
            feedForward();
        }
        return getOutputLayer().score();
    }

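    /**
     * Scores a candidate parameter vector: temporarily swaps it in, computes
     * the output layer score plus an L2 penalty on the masked parameters, then
     * restores the original parameters.
     */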
    public double score(INDArray param) {
        INDArray params = params();
        setParameters(param);
        double ret = score();
        double regCost = 0.5f * defaultConfiguration.getL2()
                * (double) Transforms.pow(mask.mul(param), 2).sum(Integer.MAX_VALUE).element();
        setParameters(params);
        return ret + regCost;

    }

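    /**
     * Distributes a flattened parameter vector across the layers; each layer
     * receives a contiguous slice sized by its numParams().
     */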
    public void setParameters(INDArray params) {
        int idx = 0;
        for (int i = 0; i < getLayers().length; i++) {
            Layer layer = getLayers()[i];
            int range = layer.numParams();
            layer.setParams(params.get(NDArrayIndex.interval(idx, range + idx)));
            idx += range;
        }
    }

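    /**
     * Merges this network's parameters with those of another network of the
     * same depth, layer by layer (used, for example, when averaging worker
     * replicas in downpour SGD).
     */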
    public void merge(MultiLayerNetwork network, int batchSize) {
        if (network.layers.length != layers.length) {
            throw new IllegalArgumentException("Unable to merge networks that are not of equal length");
        }

        for (int i = 0; i < getnLayers(); i++) {
            Layer n = layers[i];
            Layer otherNetwork = network.layers[i];
            n.merge(otherNetwork, batchSize);
        }

        getOutputLayer().merge(network.getOutputLayer(), batchSize);
    }

    @Override
    public double score(INDArray examples, INDArray labels) {
        feedForward(examples);
        setLabels(labels);
        Evaluation eval = new Evaluation();
        eval.eval(labels, labelProbabilities(examples));
        return eval.f1();
    }

    @Override
    public int numLables() {
        return labels.columns();
    }

    @Override
    public int numParams() {
        int length = 0;
        for (int i = 0; i < layers.length; i++) {
            length += layers[i].numParams();
        }

        return length;
    }

    /**
     * GETTER and SETTER
     */
    //<editor-fold defaultstate="collapsed" desc="getter and setter">
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    public synchronized INDArray getInput() {
        return input;
    }

    public void setInput(INDArray input) {
        if (input != null && this.layers == null) {
            initializeLayer(input);
        }
        this.input = input;
    }

    //</editor-fold>
}
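
Below is a minimal usage sketch for this class. It only exercises methods that appear in the listing above (the constructor, fit, predict, and score); how the MultiLayerConfiguration is actually assembled depends on the project's MultiLayerConfiguration.Builder, which is not shown in this file, so the configuration line is a placeholder rather than a working setup.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

import com.anhth12.nn.conf.MultiLayerConfiguration;
import com.anhth12.nn.multilayer.MultiLayerNetwork;

public class MultiLayerNetworkExample {

    public static void main(String[] args) {
        //placeholder: the builder's options (hidden layer sizes, per-layer
        //confs, pretrain flag, etc.) live elsewhere in the project
        MultiLayerConfiguration conf = new MultiLayerConfiguration.Builder().build();

        MultiLayerNetwork network = new MultiLayerNetwork(conf);

        //random features and one-hot labels, just to exercise the API
        INDArray features = Nd4j.rand(10, conf.getConf(0).getnIn());
        int[] outcomes = {0, 1, 0, 1, 0, 1, 0, 1, 0, 1};
        INDArray labels = FeatureUtil.toOutcomeMatrix(outcomes, 2);

        //pretrains (when enabled in the configuration) and fine-tunes the output layer
        network.fit(new DataSet(features, labels));

        //predicted class index per example = argmax of the output activations
        int[] predictions = network.predict(features);

        //F1 score against the given labels
        double f1 = network.score(features, labels);
        System.out.println("f1 = " + f1 + ", first prediction = " + predictions[0]);
    }
}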