com.simiacryptus.mindseye.layers.java.MaxPoolingLayer.java Source code

Java tutorial

Introduction

Here is the source code for com.simiacryptus.mindseye.layers.java.MaxPoolingLayer.java

Source

/*
 * Copyright (c) 2018 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.layers.java;

import com.google.gson.JsonObject;
import com.simiacryptus.lang.Tuple2;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.util.JsonUtil;
import com.simiacryptus.util.Util;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.function.Function;
import java.util.function.IntToDoubleFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Max-pooling layer. Each output element is the maximum of a rectangular window of the
 * first input's tensors; the window size per dimension is given by {@code kernelDims}.
 * The output size along each dimension is {@code ceil(inputDim / kernelDim)}. When an
 * input dimension is not a multiple of the kernel dimension, the last window is shifted
 * back so that it still lies inside the input (it then overlaps the previous window).
 */
@SuppressWarnings("serial")
public class MaxPoolingLayer extends LayerBase {

    // Region tables depend only on (inputDims, kernelDims); they are looked up through
    // Util.cache, keyed by CalcRegionsParameter (presumably memoized - see Util.cache).
    private static final Function<MaxPoolingLayer.CalcRegionsParameter, List<Tuple2<Integer, int[]>>> calcRegionsCache = Util
            .cache(MaxPoolingLayer::calcRegions);
    @SuppressWarnings("unused")
    private static final Logger log = LoggerFactory.getLogger(MaxPoolingLayer.class);
    private int[] kernelDims;

    /**
     * Creates a layer without kernel dimensions; {@code kernelDims} stays {@code null}.
     */
    protected MaxPoolingLayer() {
        super();
    }

    /**
     * Creates a layer with the given pooling window.
     *
     * @param kernelDims the window size along each input dimension; the array is copied
     */
    public MaxPoolingLayer(@Nonnull final int... kernelDims) {

        this.kernelDims = Arrays.copyOf(kernelDims, kernelDims.length);
    }

    /**
     * Creates a layer from its serialized form.
     *
     * @param id         the JSON object passed to the {@code LayerBase} constructor
     * @param kernelDims the window size along each input dimension; the array is copied
     */
    protected MaxPoolingLayer(@Nonnull final JsonObject id, @Nonnull final int... kernelDims) {
        super(id);
        this.kernelDims = Arrays.copyOf(kernelDims, kernelDims.length);
    }

    /**
     * Computes the pooled size of every dimension as {@code ceil(inputDim / kernelDim)}.
     *
     * @param inputDims  the input dimensions
     * @param kernelDims the kernel dimensions; must have at least as many entries as {@code inputDims}
     * @return a new array with one pooled size per input dimension
     */
    private static int[] pooledDims(final int[] inputDims, final int[] kernelDims) {
        final int[] pooled = new int[inputDims.length];
        for (int i = 0; i < inputDims.length; i++) {
            pooled[i] = (int) Math.ceil(inputDims[i] * 1.0 / kernelDims[i]);
        }
        return pooled;
    }

    /**
     * Builds the pooling table for one (input shape, kernel shape) pair.
     *
     * @param p the input and kernel dimensions
     * @return one tuple per output element: the output element's index and the flat input
     * indices (as computed by {@code Tensor.index}) of the window it pools over
     */
    private static List<Tuple2<Integer, int[]>> calcRegions(@Nonnull final MaxPoolingLayer.CalcRegionsParameter p) {
        // The tensors created here are used only for coordinate iteration and index
        // computation; no element values are read or written.
        @Nonnull
        final Tensor input = new Tensor(p.inputDims);
        @Nonnull
        final Tensor output = new Tensor(pooledDims(p.inputDims, p.kernelDims));

        final List<Tuple2<Integer, int[]>> regions = output.coordStream(true).map(o -> {
            final Tensor kernel = new Tensor(p.kernelDims);
            final int[] inCoords = kernel.coordStream(true).mapToInt(kernelCoord -> {
                @Nonnull
                final int[] result = new int[o.getCoords().length];
                for (int index = 0; index < o.getCoords().length; index++) {
                    final int inputSize = p.inputDims[index];
                    final int kernelSize = p.kernelDims[index];
                    // Shift the last window back so it ends at the input boundary. The lower
                    // clamp at 0 covers inputs narrower than the kernel, where
                    // inputSize - kernelSize would otherwise be negative.
                    final int baseCoordinate = Math.max(0,
                            Math.min(o.getCoords()[index] * kernelSize, inputSize - kernelSize));
                    // The upper clamp keeps the coordinate inside a narrower-than-kernel input;
                    // it has no effect when inputSize >= kernelSize. Clamping may repeat an
                    // index within a window, which does not change the maximum.
                    result[index] = Math.min(baseCoordinate + kernelCoord.getCoords()[index], inputSize - 1);
                }
                return input.index(result);
            }).toArray();
            kernel.freeRef();
            return new Tuple2<>(o.getIndex(), inCoords);
        }).collect(Collectors.toList());
        input.freeRef();
        output.freeRef();
        return regions;
    }

    /**
     * Restores a layer from JSON; the kernel dimensions are read from the {@code "heapCopy"}
     * array, the same key that {@link #getJson(Map, DataSerializer)} writes.
     *
     * @param json the serialized layer
     * @param rs   serialized resources; not used by this implementation
     * @return the restored layer
     */
    public static MaxPoolingLayer fromJson(@Nonnull final JsonObject json, Map<CharSequence, byte[]> rs) {
        return new MaxPoolingLayer(json, JsonUtil.getIntArray(json.getAsJsonArray("heapCopy")));
    }

    /**
     * Applies max pooling to every tensor of the first input.
     * <p>
     * {@code addRef()} is called on every element of {@code inObj}; the matching
     * {@code freeRef()} calls happen when the returned result is freed. Only
     * {@code inObj[0]} is read. For each output element the input index that supplied the
     * maximum is recorded; on feedback, each incoming delta element is added to a zero-shaped
     * input tensor at that recorded index and passed to the input's {@code accumulate}.
     * When several window elements compare equal, the first one in window order is kept
     * (the comparison is a strict {@code <}).
     *
     * @param inObj the layer inputs; must contain at least one element
     * @return a result holding one pooled tensor per input tensor
     */
    @Nonnull
    @Override
    public Result eval(@Nonnull final Result... inObj) {

        Arrays.stream(inObj).forEach(nnResult -> nnResult.addRef());

        final Result in = inObj[0];
        final int itemCount = in.getData().length();

        @Nonnull
        final int[] inputDims = in.getData().getDimensions();
        final List<Tuple2<Integer, int[]>> regions = MaxPoolingLayer.calcRegionsCache
                .apply(new MaxPoolingLayer.CalcRegionsParameter(inputDims, kernelDims));
        final Tensor[] outputA = IntStream.range(0, itemCount)
                .mapToObj(dataIndex -> new Tensor(pooledDims(inputDims, kernelDims)))
                .toArray(i -> new Tensor[i]);
        // gradientMapA[item][outputIndex] = flat input index that supplied the maximum.
        @Nonnull
        final int[][] gradientMapA = new int[itemCount][];
        for (int dataIndex = 0; dataIndex < itemCount; dataIndex++) {
            @Nullable
            final Tensor input = in.getData().get(dataIndex);
            final Tensor output = outputA[dataIndex];
            // Sized by the input length but indexed by output index; only the first
            // output.length() entries are written.
            @Nonnull
            final int[] gradientMap = new int[input.length()];
            // Each region writes a distinct output index, so the parallel writes do not collide.
            regions.parallelStream().forEach(tuple -> {
                final int from = tuple.getFirst();
                int toMax = -1;
                double bestValue = Double.NEGATIVE_INFINITY;
                for (final int c : tuple.getSecond()) {
                    final double value = input.get(c);
                    if (-1 == toMax || bestValue < value) {
                        bestValue = value;
                        toMax = c;
                    }
                }
                gradientMap[from] = toMax;
                output.set(from, input.get(toMax));
            });
            input.freeRef();
            gradientMapA[dataIndex] = gradientMap;
        }
        return new Result(TensorArray.wrap(outputA),
                (@Nonnull final DeltaSet<UUID> buffer, @Nonnull final TensorList data) -> {
                    if (in.isAlive()) {
                        @Nonnull
                        TensorArray tensorArray = TensorArray
                                .wrap(IntStream.range(0, in.getData().length()).parallel().mapToObj(dataIndex -> {
                                    @Nonnull
                                    final Tensor backSignal = new Tensor(inputDims);
                                    final int[] ints = gradientMapA[dataIndex];
                                    @Nullable
                                    final Tensor datum = data.get(dataIndex);
                                    // Route every output delta to the input element that won the max.
                                    for (int i = 0; i < datum.length(); i++) {
                                        backSignal.add(ints[i], datum.get(i));
                                    }
                                    datum.freeRef();
                                    return backSignal;
                                }).toArray(i -> new Tensor[i]));
                        in.accumulate(buffer, tensorArray);
                    }
                }) {

            @Override
            protected void _free() {
                // Releases the references taken at the top of eval.
                Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
            }

            @Override
            public boolean isAlive() {
                return in.isAlive();
            }
        };
    }

    /**
     * Serializes this layer: the stub from {@code getJsonStub()} plus the kernel dimensions
     * under the key {@code "heapCopy"}.
     *
     * @param resources      not used by this implementation
     * @param dataSerializer not used by this implementation
     * @return the JSON representation
     */
    @Nonnull
    @Override
    public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
        @Nonnull
        final JsonObject json = super.getJsonStub();
        json.add("heapCopy", JsonUtil.getJson(kernelDims));
        return json;
    }

    /**
     * Returns the layer state, which is always an empty list for this layer.
     *
     * @return an empty list
     */
    @Nonnull
    @Override
    public List<double[]> state() {
        return Arrays.asList();
    }

    /**
     * Key for the region-table cache: an (input dimensions, kernel dimensions) pair with
     * value-based {@code equals} and {@code hashCode}.
     */
    public static class CalcRegionsParameter {
        /**
         * The input dimensions (a private copy of the constructor argument).
         */
        public int[] inputDims;
        /**
         * The kernel dimensions (a private copy of the constructor argument).
         */
        public int[] kernelDims;

        /**
         * Creates a key. Both arrays are copied so that later changes to the caller's arrays
         * cannot alter the hash code or equality of a key already handed to the cache.
         *
         * @param inputDims  the input dims; may be {@code null}
         * @param kernelDims the kernel dims; may be {@code null}
         */
        public CalcRegionsParameter(final int[] inputDims, final int[] kernelDims) {
            this.inputDims = null == inputDims ? null : inputDims.clone();
            this.kernelDims = null == kernelDims ? null : kernelDims.clone();
        }

        @Override
        public boolean equals(@Nullable final Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            @Nonnull
            final MaxPoolingLayer.CalcRegionsParameter other = (MaxPoolingLayer.CalcRegionsParameter) obj;
            return Arrays.equals(inputDims, other.inputDims) && Arrays.equals(kernelDims, other.kernelDims);
        }

        @Override
        public int hashCode() {
            // Same value as the classic "prime = 31, result = 1" accumulation over both arrays.
            return 31 * (31 + Arrays.hashCode(inputDims)) + Arrays.hashCode(kernelDims);
        }

    }
}