com.simiacryptus.mindseye.layers.java.FullyConnectedLayer.java Source code

Java tutorial

Introduction

Here is the source code for the class com.simiacryptus.mindseye.layers.java.FullyConnectedLayer (file FullyConnectedLayer.java).

Source

/*
 * Copyright (c) 2018 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.layers.java;

import com.google.gson.JsonObject;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.util.FastRandom;
import com.simiacryptus.util.JsonUtil;
import com.simiacryptus.util.Util;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.function.DoubleSupplier;
import java.util.function.IntToDoubleFunction;
import java.util.function.ToDoubleBiFunction;
import java.util.function.ToDoubleFunction;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * A dense matrix operator using vector-matrix multiplication. Represents a fully connected layer of synapses, where
 * all inputs are connected to all outputs via separate coefficients.
 *
 * <p>The weights live in a single {@link Tensor} of shape {@code (inputLength, outputLength)}, where
 * {@code inputLength = Tensor.length(inputDims)} and {@code outputLength = Tensor.length(outputDims)}. The forward
 * pass wraps that data as a jblas {@link DoubleMatrix} with {@code inputLength} rows and {@code outputLength} columns
 * and multiplies its transpose by each flattened input vector.
 */
@SuppressWarnings("serial")
public class FullyConnectedLayer extends LayerBase {

    @SuppressWarnings("unused")
    private static final Logger log = LoggerFactory.getLogger(FullyConnectedLayer.class);
    /**
     * The input dimensions (a defensive copy of the array passed to the constructor, or read from JSON).
     */
    @Nullable
    public final int[] inputDims;
    /**
     * The output dimensions (a defensive copy of the array passed to the constructor, or read from JSON).
     */
    @Nullable
    public final int[] outputDims;
    @Nullable
    private final Tensor weights;

    /**
     * Instantiates an empty layer with no dimensions and no weights (all fields are {@code null}).
     * NOTE(review): presumably intended only for subclasses / reflective construction — confirm.
     */
    protected FullyConnectedLayer() {
        super();
        outputDims = null;
        weights = null;
        inputDims = null;
    }

    /**
     * Instantiates a new fully connected layer with randomly initialized weights.
     *
     * <p>Each weight is drawn as {@code (1 - 2 * u) * sqrt(6 / (inputs + outs + 1))} with {@code u} from
     * {@code Util.R.get().nextDouble()}, where {@code inputs} and {@code outs} are the element counts of the input
     * and output shapes.
     *
     * @param inputDims  the input dims (copied)
     * @param outputDims the output dims (copied)
     */
    public FullyConnectedLayer(@Nonnull final int[] inputDims, @Nonnull final int[] outputDims) {
        final int inputs = Tensor.length(inputDims);
        this.inputDims = Arrays.copyOf(inputDims, inputDims.length);
        this.outputDims = Arrays.copyOf(outputDims, outputDims.length);
        final int outs = Tensor.length(outputDims);
        weights = new Tensor(inputs, outs);
        // The scale only depends on the layer shape, so compute it once rather than once per weight.
        final double ratio = Math.sqrt(6. / (inputs + outs + 1));
        // Fill the weights directly instead of via the overridable set(DoubleSupplier), so no subclass override
        // can run against a partially constructed instance.
        Arrays.parallelSetAll(weights.getData(), i -> {
            final double fate = Util.R.get().nextDouble();
            return (1 - 2 * fate) * ratio;
        });
    }

    /**
     * Instantiates a fully connected layer from its JSON representation (see {@link #getJson}).
     *
     * @param json      the json holding "outputDims", "inputDims" and "weights"
     * @param resources the binary resources referenced by the serialized weights
     */
    protected FullyConnectedLayer(@Nonnull final JsonObject json, Map<CharSequence, byte[]> resources) {
        super(json);
        outputDims = JsonUtil.getIntArray(json.getAsJsonArray("outputDims"));
        inputDims = JsonUtil.getIntArray(json.getAsJsonArray("inputDims"));
        weights = Tensor.fromJson(json.get("weights"), resources);
    }

    /**
     * Writes the outer product of {@code rows} and {@code cols} into {@code matrix}, iterating {@code cols} in the
     * outer loop: {@code matrix[c * rows.length + r] = rows[r] * cols[c]} (a column-major
     * {@code rows.length x cols.length} layout).
     *
     * @param rows   the rows
     * @param cols   the cols
     * @param matrix the output buffer; must hold at least {@code rows.length * cols.length} elements
     */
    public static void crossMultiply(@Nonnull final double[] rows, @Nonnull final double[] cols,
            final double[] matrix) {
        int i = 0;
        for (final double c : cols) {
            for (final double r : rows) {
                matrix[i++] = r * c;
            }
        }
    }

    /**
     * Writes the outer product of {@code rows} and {@code cols} into {@code matrix}, iterating {@code rows} in the
     * outer loop: {@code matrix[r * cols.length + c] = rows[r] * cols[c]}.
     *
     * @param rows   the rows
     * @param cols   the cols
     * @param matrix the output buffer; must hold at least {@code rows.length * cols.length} elements
     */
    public static void crossMultiplyT(@Nonnull final double[] rows, @Nonnull final double[] cols,
            final double[] matrix) {
        int i = 0;
        for (final double r : rows) {
            for (final double c : cols) {
                matrix[i++] = r * c;
            }
        }
    }

    /**
     * Deserializes a fully connected layer from JSON.
     *
     * @param json the json
     * @param rs   the binary resources referenced by the json
     * @return the fully connected layer
     */
    public static FullyConnectedLayer fromJson(@Nonnull final JsonObject json, Map<CharSequence, byte[]> rs) {
        return new FullyConnectedLayer(json, rs);
    }

    /**
     * Computes {@code out = M * in}, where {@code M} is {@code matrix} viewed as an
     * {@code out.length x in.length} jblas matrix. The result is written into {@code out} in place.
     *
     * @param matrix the matrix data
     * @param in     the input vector
     * @param out    the output vector, overwritten
     */
    public static void multiply(final double[] matrix, @Nonnull final double[] in, @Nonnull final double[] out) {
        @Nonnull
        final DoubleMatrix matrixObj = new DoubleMatrix(out.length, in.length, matrix);
        matrixObj.mmuli(new DoubleMatrix(in.length, 1, in), new DoubleMatrix(out.length, 1, out));
    }

    /**
     * Computes {@code out = M^T * in}, where {@code M} is {@code matrix} viewed as an
     * {@code in.length x out.length} jblas matrix. The result is written into {@code out} in place; the temporary
     * transpose buffer is returned to {@code RecycleBin.DOUBLES} afterwards.
     *
     * @param matrix the matrix data
     * @param in     the input vector
     * @param out    the output vector, overwritten
     */
    public static void multiplyT(final double[] matrix, @Nonnull final double[] in, @Nonnull final double[] out) {
        @Nonnull
        DoubleMatrix doubleMatrix = new DoubleMatrix(in.length, out.length, matrix);
        @Nonnull
        final DoubleMatrix matrixObj = FullyConnectedLayer.transpose(doubleMatrix);
        matrixObj.mmuli(new DoubleMatrix(in.length, 1, in), new DoubleMatrix(out.length, 1, out));
        RecycleBin.DOUBLES.recycle(matrixObj.data, matrixObj.data.length);
    }

    /**
     * Returns the transpose of {@code doubleMatrix} as a new matrix. The result is backed by an array obtained from
     * {@code RecycleBin.DOUBLES}; callers should recycle {@code result.data} when done (as {@link #multiplyT} and
     * {@link #eval} do).
     *
     * @param doubleMatrix the matrix to transpose (not modified)
     * @return the transposed matrix
     */
    @Nonnull
    public static DoubleMatrix transpose(@Nonnull final DoubleMatrix doubleMatrix) {
        @Nonnull
        final DoubleMatrix result = new DoubleMatrix(doubleMatrix.columns, doubleMatrix.rows,
                RecycleBin.DOUBLES.obtain(doubleMatrix.length));
        for (int i = 0; i < doubleMatrix.rows; ++i) {
            for (int j = 0; j < doubleMatrix.columns; ++j) {
                result.put(j, i, doubleMatrix.get(i, j));
            }
        }
        return result;
    }

    /**
     * Releases this layer's reference to its weight tensor, then delegates to the superclass.
     */
    @Override
    protected void _free() {
        // weights is null for instances created through the protected no-arg constructor.
        if (null != weights) {
            weights.freeRef();
        }
        super._free();
    }

    /**
     * Applies the layer to a batch. Only {@code inObj[0]} is read; each element of its data is flattened and
     * multiplied by the transposed weight matrix, producing a tensor shaped like {@link #outputDims}.
     *
     * <p>The returned {@link Result} keeps references to the input data, all input results, this layer and its
     * weights until it is freed. Its backward callback (a) unless this layer is frozen, adds the weight gradient
     * (the sum over the batch of the outer products of each output delta with its input) into the
     * {@link DeltaSet} entry keyed by this layer's id, and (b) if {@code inObj[0]} is alive, passes
     * {@code W * delta} back to it.
     *
     * @param inObj the input results; only the first is used
     * @return the layer output together with its backward callback
     */
    @Nonnull
    @Override
    public Result eval(@Nonnull final Result... inObj) {
        final TensorList indata = inObj[0].getData();
        // Retain everything the backward callback may touch; released again in the Result's _free().
        indata.addRef();
        for (@Nonnull
        Result result : inObj) {
            result.addRef();
        }
        FullyConnectedLayer.this.addRef();
        assert Tensor.length(indata.getDimensions()) == Tensor.length(this.inputDims) : Arrays
                .toString(indata.getDimensions()) + " == " + Arrays.toString(this.inputDims);
        // Wraps (does not copy) the weight data as an inputLength x outputLength matrix.
        @Nonnull
        DoubleMatrix doubleMatrix = new DoubleMatrix(Tensor.length(indata.getDimensions()),
                Tensor.length(outputDims), this.weights.getData());
        // Backed by a recycled buffer that is handed back once the forward pass completes.
        @Nonnull
        final DoubleMatrix matrixObj = FullyConnectedLayer.transpose(doubleMatrix);
        @Nonnull
        TensorArray tensorArray = TensorArray
                .wrap(IntStream.range(0, indata.length()).parallel().mapToObj(dataIndex -> {
                    @Nullable
                    final Tensor input = indata.get(dataIndex);
                    @Nullable
                    final Tensor output = new Tensor(outputDims);
                    // output = W^T * input, written directly into the output tensor's data array.
                    matrixObj.mmuli(new DoubleMatrix(input.length(), 1, input.getData()),
                            new DoubleMatrix(output.length(), 1, output.getData()));
                    input.freeRef();
                    return output;
                }).toArray(i -> new Tensor[i]));
        RecycleBin.DOUBLES.recycle(matrixObj.data, matrixObj.data.length);
        this.weights.addRef();
        return new Result(tensorArray, (@Nonnull final DeltaSet<UUID> buffer, @Nonnull final TensorList delta) -> {
            if (!isFrozen()) {
                final Delta<UUID> deltaBuffer = buffer.get(FullyConnectedLayer.this.getId(),
                        this.weights.getData());
                final int inputLength = Tensor.length(inputDims);
                final int outputLength = Tensor.length(outputDims);
                final int threads = 4;
                // Partition the batch into `threads` interleaved slices (index mod threads); each batch item
                // yields one outer-product tensor and all of them are summed pairwise by the reduction.
                IntStream.range(0, threads).parallel().mapToObj(x -> x).flatMap(thread -> {
                    @Nullable
                    Stream<Tensor> stream = IntStream.range(0, indata.length()).filter(i -> thread == i % threads)
                            .mapToObj(dataIndex -> {
                                @Nonnull
                                final Tensor weightDelta = new Tensor(inputLength, outputLength);
                                Tensor deltaTensor = delta.get(dataIndex);
                                Tensor inputTensor = indata.get(dataIndex);
                                FullyConnectedLayer.crossMultiplyT(deltaTensor.getData(), inputTensor.getData(),
                                        weightDelta.getData());
                                inputTensor.freeRef();
                                deltaTensor.freeRef();
                                return weightDelta;
                            });
                    return stream;
                }).reduce((a, b) -> {
                    @Nullable
                    Tensor c = a.addAndFree(b);
                    b.freeRef();
                    return c;
                }).ifPresent(data -> {
                    // Pure side effect: accumulate the summed gradient, then release it.
                    deltaBuffer.addInPlace(data.getData());
                    data.freeRef();
                });
                deltaBuffer.freeRef();
            }
            if (inObj[0].isAlive()) {
                @Nonnull
                final TensorList tensorList = TensorArray
                        .wrap(IntStream.range(0, indata.length()).parallel().mapToObj(dataIndex -> {
                            Tensor deltaTensor = delta.get(dataIndex);
                            @Nonnull
                            final Tensor passback = new Tensor(indata.getDimensions());
                            // passback = W * delta, with W viewed as inputLength x outputLength.
                            FullyConnectedLayer.multiply(this.weights.getData(), deltaTensor.getData(),
                                    passback.getData());
                            deltaTensor.freeRef();
                            return passback;
                        }).toArray(i -> new Tensor[i]));
                inObj[0].accumulate(buffer, tensorList);
            }
        }) {

            /**
             * Releases every reference retained by {@code eval}.
             */
            @Override
            protected void _free() {
                indata.freeRef();
                FullyConnectedLayer.this.freeRef();
                for (@Nonnull
                Result result : inObj) {
                    result.freeRef();
                }
                FullyConnectedLayer.this.weights.freeRef();
            }

            /**
             * Alive when this layer is trainable (not frozen) or any input still wants gradients.
             */
            @Override
            public boolean isAlive() {
                return !isFrozen() || Arrays.stream(inObj).anyMatch(x -> x.isAlive());
            }

        };
    }

    /**
     * Serializes this layer's dimensions and weights to JSON, readable by {@link #fromJson}.
     *
     * @param resources      the resource map the weight serializer may write binary data into
     * @param dataSerializer the serializer used for the weight tensor
     * @return the json
     */
    @Nonnull
    @Override
    public JsonObject getJson(Map<CharSequence, byte[]> resources, @Nonnull DataSerializer dataSerializer) {
        @Nonnull
        final JsonObject json = super.getJsonStub();
        json.add("outputDims", JsonUtil.getJson(outputDims));
        json.add("inputDims", JsonUtil.getJson(inputDims));
        json.add("weights", getWeights().toJson(resources, dataSerializer));
        return json;
    }

    /**
     * Gets the transpose of this layer. Not implemented.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Nonnull
    public Layer getTranspose() {
        throw new UnsupportedOperationException("Not Implemented");
    }

    /**
     * Gets the weight tensor (the live instance, not a copy), shaped {@code (inputLength, outputLength)}.
     *
     * @return the weights; {@code null} for instances built with the protected no-arg constructor
     */
    @Nullable
    public Tensor getWeights() {
        return weights;
    }

    /**
     * Sets every weight to a fresh value drawn from {@code f}. The array is filled in parallel, so {@code f} may be
     * called concurrently.
     *
     * @param f the value supplier
     * @return this layer
     */
    @Nonnull
    public FullyConnectedLayer set(@Nonnull final DoubleSupplier f) {
        Arrays.parallelSetAll(getWeights().getData(), i -> f.getAsDouble());
        return this;
    }

    /**
     * Sets the weights from a function of the flat weight index, via {@link Tensor#set(IntToDoubleFunction)}.
     *
     * @param f the function of the weight index
     * @return this layer
     */
    @Nonnull
    public FullyConnectedLayer set(@Nonnull final IntToDoubleFunction f) {
        getWeights().set(f);
        return this;
    }

    /**
     * Sets each weight from {@code f} applied to that weight's coordinate within the weight tensor.
     *
     * @param f the function of the weight coordinate
     * @return this layer
     */
    @Nonnull
    public FullyConnectedLayer setByCoord(@Nonnull final ToDoubleFunction<Coordinate> f) {
        getWeights().coordStream(true).forEach(c -> {
            getWeights().set(c, f.applyAsDouble(c));
        });
        return this;
    }

    /**
     * Initializes the weights by spatial proximity. For each (input, output) coordinate pair, each coordinate is
     * normalized by its layer's dims, the Euclidean distance {@code dist} between them is taken, and the weight is
     * set to {@code peak * (1 + tanh(stiffness * (radius - dist))) / 2}.
     *
     * <p>NOTE(review): the loop runs over the input coordinate's rank and indexes the output coordinate and
     * {@code outputDims} with the same index, so an output rank smaller than the input rank throws
     * {@link ArrayIndexOutOfBoundsException}.
     *
     * @param radius    the distance at which the weight is {@code peak / 2}
     * @param stiffness the steepness of the tanh falloff
     * @param peak      the maximum weight
     */
    public void initSpacial(final double radius, final double stiffness, final double peak) {
        setByCoord((@Nonnull final Coordinate in, @Nonnull final Coordinate out) -> {
            final double[] doubleCoords = IntStream.range(0, in.getCoords().length).mapToDouble(d -> {
                final double from = in.getCoords()[d] * 1.0 / FullyConnectedLayer.this.inputDims[d];
                final double to = out.getCoords()[d] * 1.0 / FullyConnectedLayer.this.outputDims[d];
                return from - to;
            }).toArray();
            final double dist = Math.sqrt(Arrays.stream(doubleCoords).map(x -> x * x).sum());
            final double factor = (1 + Math.tanh(stiffness * (radius - dist))) / 2;
            return peak * factor;
        });
    }

    /**
     * Sets the weights from a raw data array, via {@link Tensor#set(double[])}.
     *
     * @param data the data
     * @return this layer
     */
    @Nonnull
    public FullyConnectedLayer set(final double[] data) {
        getWeights().set(data);
        return this;
    }

    /**
     * Sets the weights from another tensor, via {@link Tensor#set(Tensor)}.
     *
     * @param data the data
     * @return this layer
     */
    @Nonnull
    public FullyConnectedLayer set(@Nonnull final Tensor data) {
        getWeights().set(data);
        return this;
    }

    /**
     * Sets each weight from {@code f(in, out)}, where {@code in} ranges over the coordinates of a tensor shaped like
     * {@link #inputDims} and {@code out} over one shaped like {@link #outputDims}. The weight at
     * {@code [in.getIndex(), out.getIndex()]} receives the value. Both loops run in parallel, so {@code f} may be
     * called concurrently.
     *
     * @param f the function of (input coordinate, output coordinate)
     * @return this layer
     */
    @Nonnull
    public FullyConnectedLayer setByCoord(@Nonnull final ToDoubleBiFunction<Coordinate, Coordinate> f) {
        // Shape-only templates used to enumerate coordinates: allocated once and released when done so that
        // they are neither leaked nor re-allocated for every input coordinate.
        @Nonnull
        final Tensor inputTemplate = new Tensor(inputDims);
        @Nonnull
        final Tensor outputTemplate = new Tensor(outputDims);
        try {
            inputTemplate.coordStream(true).forEach(in -> {
                outputTemplate.coordStream(true).forEach(out -> {
                    getWeights().set(new int[] { in.getIndex(), out.getIndex() }, f.applyAsDouble(in, out));
                });
            });
        } finally {
            inputTemplate.freeRef();
            outputTemplate.freeRef();
        }
        return this;
    }

    /**
     * Randomizes every weight as {@code (FastRandom.INSTANCE.random() - 0.5) * 10^value}, i.e. {@code value} is the
     * base-10 log of the spread. Assuming {@code random()} is uniform on [0, 1) (TODO confirm), weights fall in
     * {@code [-0.5, 0.5) * 10^value}.
     *
     * @param value the base-10 logarithm of the weight scale
     * @return this layer
     */
    @Nonnull
    public FullyConnectedLayer setWeightsLog(final double value) {
        getWeights().coordStream(false).forEach(c -> {
            getWeights().set(c, (FastRandom.INSTANCE.random() - 0.5) * Math.pow(10, value));
        });
        return this;
    }

    /**
     * Returns this layer's trainable state: a single-element list holding the weight tensor's data array itself, so
     * writes into it modify the weights.
     *
     * @return the state list
     */
    @Nonnull
    @Override
    public List<double[]> state() {
        return Arrays.asList(getWeights().getData());
    }

}