com.simiacryptus.mindseye.layers.java.AvgPoolingLayer.java Source code

Java tutorial

Introduction

Here is the source code for com.simiacryptus.mindseye.layers.java.AvgPoolingLayer.java

Source

/*
 * Copyright (c) 2018 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.layers.java;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.gson.JsonObject;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.util.JsonUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * A local-pooling layer which replaces each kernel-sized block of the input
 * with the average value of that block.
 */
@SuppressWarnings("serial")
public class AvgPoolingLayer extends LayerBase {

    /**
     * Shared cache of pooling index maps, keyed by an {@link IndexMapKey}
     * (kernel dimensions plus output dimensions). Each cached value maps an
     * output coordinate to the list of input coordinates pooled into it.
     * Entries that are not yet present are built by {@code LayerCacheLoader}.
     */
    public static final LoadingCache<AvgPoolingLayer.IndexMapKey, Map<Coordinate, List<int[]>>> indexMapCache = CacheBuilder
            .newBuilder().build(new LayerCacheLoader());
    // Class logger; not referenced by any code visible in this file.
    @SuppressWarnings("unused")
    private static final Logger log = LoggerFactory.getLogger(AvgPoolingLayer.class);
    // Pooling window extent per dimension. Assigned by the constructors that take
    // kernel dims; the no-argument constructor leaves it unset.
    private int[] kernelDims;

    /**
     * Instantiates a new Avg subsample key.
     */
    protected AvgPoolingLayer() {
        super();
    }

    /**
     * Creates a pooling layer with the given window extent per dimension.
     *
     * @param kernelDims the pooling window size along each dimension; the array is
     *                   copied, so later changes made by the caller do not affect
     *                   this layer
     */
    public AvgPoolingLayer(@Nonnull final int... kernelDims) {
        super();
        this.kernelDims = kernelDims.clone();
    }

    /**
     * Instantiates a new Avg subsample key.
     *
     * @param id         the id
     * @param kernelDims the kernel dims
     */
    protected AvgPoolingLayer(@Nonnull final JsonObject id, @Nonnull final int... kernelDims) {
        super(id);
        this.kernelDims = Arrays.copyOf(kernelDims, kernelDims.length);
    }

    /**
     * Rebuilds a layer from its JSON form.
     *
     * The kernel dimensions are read from the {@code "heapCopy"} array of the
     * JSON object, which is the same property that {@code getJson} writes.
     *
     * @param json the serialized layer
     * @param rs   resource map; not used by this method
     * @return a new layer constructed from {@code json}
     */
    public static AvgPoolingLayer fromJson(@Nonnull final JsonObject json, Map<CharSequence, byte[]> rs) {
        final int[] storedKernelDims = JsonUtil.getIntArray(json.getAsJsonArray("heapCopy"));
        return new AvgPoolingLayer(json, storedKernelDims);
    }

    /**
     * Looks up (loading on first use) the output-to-input coordinate map for the
     * given kernel and output dimensions in {@link #indexMapCache}.
     *
     * @param kernelDims the pooling window size per dimension
     * @param outDims    the output tensor dimensions
     * @return map from each output coordinate to the input coordinates pooled into it
     * @throws RuntimeException wrapping the {@link ExecutionException} if the cache
     *                          lookup fails with one
     */
    private static synchronized Map<Coordinate, List<int[]>> getCoordMap(final int[] kernelDims,
            final int[] outDims) {
        final AvgPoolingLayer.IndexMapKey cacheKey = new AvgPoolingLayer.IndexMapKey(kernelDims, outDims);
        final Map<Coordinate, List<int[]>> mapping;
        try {
            mapping = AvgPoolingLayer.indexMapCache.get(cacheKey);
        } catch (@Nonnull final ExecutionException loadFailure) {
            throw new RuntimeException(loadFailure);
        }
        return mapping;
    }

    /**
     * Average-pools the data of the first input over non-overlapping windows of
     * size {@code kernelDims}.
     *
     * Forward pass: the output dimensions are the input dimensions divided by the
     * kernel dimensions. Each output cell receives the sum of its window divided
     * by the kernel element count; windows whose sum is not finite are skipped.
     *
     * Backward pass: when the first input is alive, each delta value is divided
     * by the kernel element count and added to every input coordinate of the
     * corresponding window, and the result is accumulated into the first input.
     *
     * A reference is added to every element of {@code inObj} before the result is
     * created and released again in the result's {@code _free()}.
     *
     * @param inObj the inputs; only {@code inObj[0]} supplies data and receives gradients
     * @return the pooled result, whose liveness mirrors that of {@code inObj[0]}
     */
    @Nonnull
    @SuppressWarnings("unchecked")
    @Override
    public Result eval(@Nonnull final Result... inObj) {
        final int kernelSize = Tensor.length(kernelDims);
        final TensorList data = inObj[0].getData();
        @Nonnull
        final int[] inputDims = data.getDimensions();

        // Output extent along each dimension is the input extent divided by the kernel extent.
        final int[] newDims = new int[inputDims.length];
        for (int d = 0; d < inputDims.length; d++) {
            assert 0 == inputDims[d] % kernelDims[d] : inputDims[d] + ":" + kernelDims[d];
            newDims[d] = inputDims[d] / kernelDims[d];
        }

        final Map<Coordinate, List<int[]>> coordMap = AvgPoolingLayer.getCoordMap(kernelDims, newDims);

        final int batchSize = data.length();
        final Tensor[] outputValues = new Tensor[batchSize];
        for (int dataIndex = 0; dataIndex < batchSize; dataIndex++) {
            @Nullable
            final Tensor input = data.get(dataIndex);
            @Nonnull
            final Tensor output = new Tensor(newDims);
            for (final Entry<Coordinate, List<int[]>> window : coordMap.entrySet()) {
                // Summation stays on DoubleStream.sum() so the numeric result is unchanged.
                final double sum = window.getValue().stream().mapToDouble(inputCoord -> input.get(inputCoord)).sum();
                if (!Double.isFinite(sum)) {
                    // Non-finite windows contribute nothing to the output cell.
                    continue;
                }
                output.add(window.getKey(), sum / kernelSize);
            }
            input.freeRef();
            outputValues[dataIndex] = output;
        }

        // Keep the inputs referenced for as long as the returned Result exists.
        for (final Result nnResult : inObj) {
            nnResult.addRef();
        }

        return new Result(TensorArray.wrap(outputValues),
                (@Nonnull final DeltaSet<UUID> buffer, @Nonnull final TensorList delta) -> {
                    if (!inObj[0].isAlive()) {
                        return;
                    }
                    final int deltaCount = delta.length();
                    final Tensor[] passback = new Tensor[deltaCount];
                    for (int dataIndex = 0; dataIndex < deltaCount; dataIndex++) {
                        @Nullable
                        final Tensor tensor = delta.get(dataIndex);
                        @Nonnull
                        final Tensor backSignal = new Tensor(inputDims);
                        for (final Entry<Coordinate, List<int[]>> outputMapping : coordMap.entrySet()) {
                            // Every input cell of a window receives an equal share of the output delta.
                            final double share = tensor.get(outputMapping.getKey()) / kernelSize;
                            for (final int[] inputCoord : outputMapping.getValue()) {
                                backSignal.add(inputCoord, share);
                            }
                        }
                        tensor.freeRef();
                        passback[dataIndex] = backSignal;
                    }
                    @Nonnull
                    final TensorArray tensorArray = TensorArray.wrap(passback);
                    inObj[0].accumulate(buffer, tensorArray);
                }) {

            @Override
            protected void _free() {
                for (final Result nnResult : inObj) {
                    nnResult.freeRef();
                }
            }

            @Override
            public boolean isAlive() {
                return inObj[0].isAlive();
            }
        };
    }

    /**
     * Serializes this layer to JSON.
     *
     * Starts from the superclass JSON stub and stores the kernel dimensions under
     * the {@code "heapCopy"} property, the same property {@code fromJson} reads.
     *
     * @param resources      resource map; not used by this method
     * @param dataSerializer serializer; not used by this method
     * @return the JSON representation of this layer
     */
    @Nonnull
    @Override
    public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
        @Nonnull
        final JsonObject stub = super.getJsonStub();
        stub.add("heapCopy", JsonUtil.getJson(kernelDims));
        return stub;
    }

    /**
     * Reports this layer's state arrays.
     *
     * @return an empty list; this method never returns any arrays
     */
    @Nonnull
    @Override
    public List<double[]> state() {
        final List<double[]> noState = Arrays.<double[]>asList();
        return noState;
    }

    /**
     * Cache key for {@link AvgPoolingLayer#indexMapCache}: a pair of kernel
     * dimensions and output dimensions, compared by array content.
     *
     * Both arrays are copied on construction so that a caller mutating its own
     * array afterwards cannot change this key's {@code hashCode}/{@code equals}
     * result while the key sits in the cache.
     */
    public static final class IndexMapKey {
        /**
         * Pooling window size per dimension (private copy of the constructor argument).
         */
        int[] kernel;
        /**
         * Output tensor dimensions (private copy of the constructor argument).
         */
        int[] output;

        /**
         * Creates a key from raw dimension arrays.
         *
         * @param kernel the kernel dimensions; copied, may be {@code null}
         * @param output the output dimensions; copied, may be {@code null}
         */
        public IndexMapKey(final int[] kernel, final int[] output) {
            super();
            this.kernel = IndexMapKey.copyOrNull(kernel);
            this.output = IndexMapKey.copyOrNull(output);
        }

        /**
         * Creates a key from the dimensions of a kernel tensor and an output tensor.
         *
         * @param kernel the tensor whose dimensions are used as kernel dimensions
         * @param input  not used by this constructor
         * @param output the tensor whose dimensions are used as output dimensions
         */
        public IndexMapKey(@Nonnull final Tensor kernel, final Tensor input, @Nonnull final Tensor output) {
            this(kernel.getDimensions(), output.getDimensions());
        }

        /**
         * Returns a copy of {@code dims}, or {@code null} when {@code dims} is
         * {@code null}, so that null arguments stay accepted as before.
         */
        private static int[] copyOrNull(final int[] dims) {
            return dims == null ? null : dims.clone();
        }

        /**
         * Two keys are equal when they are of the same class and both their
         * kernel and output arrays have equal contents.
         */
        @Override
        public boolean equals(@Nullable final Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            final AvgPoolingLayer.IndexMapKey other = (AvgPoolingLayer.IndexMapKey) obj;
            return Arrays.equals(kernel, other.kernel) && Arrays.equals(output, other.output);
        }

        /**
         * Content-based hash over both arrays, consistent with {@link #equals(Object)}.
         */
        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + Arrays.hashCode(kernel);
            result = prime * result + Arrays.hashCode(output);
            return result;
        }
    }

    /**
     * Cache loader that builds, for a given {@link IndexMapKey}, the map from
     * every output coordinate to the input coordinates inside that output cell's
     * pooling window.
     */
    private static class LayerCacheLoader extends CacheLoader<IndexMapKey, Map<Coordinate, List<int[]>>> {
        /**
         * Builds the coordinate map for {@code key}.
         *
         * A temporary tensor with the output dimensions is used only to enumerate
         * output coordinates, and is released before returning.
         *
         * @param key the kernel/output dimension pair
         * @return map from each output coordinate to its list of input coordinates
         */
        @Override
        public Map<Coordinate, List<int[]>> load(final IndexMapKey key) {
            final int[] kernelShape = key.kernel;
            final Tensor outputShape = new Tensor(key.output);
            // NOTE(review): the `true` flag is passed through unchanged; its meaning is
            // defined by Tensor.coordStream, which is not visible in this file.
            final Map<Coordinate, List<int[]>> coordMap = outputShape.coordStream(true)
                    .collect(Collectors.toMap(outCoord -> outCoord,
                            outCoord -> LayerCacheLoader.inputCoordsFor(outCoord, kernelShape)));
            outputShape.freeRef();
            return coordMap;
        }

        /**
         * Lists the input coordinates covered by the pooling window of one output
         * coordinate: along every axis,
         * {@code input = output * kernelExtent + offsetWithinKernel}.
         */
        private static List<int[]> inputCoordsFor(final Coordinate outCoord, final int[] kernelShape) {
            // Temporary tensor with the kernel shape, used only to enumerate window offsets.
            @Nonnull
            final Tensor window = new Tensor(kernelShape);
            final List<int[]> inputCoords = window.coordStream(true).map(kernelCoord -> {
                final int[] outPos = outCoord.getCoords();
                @Nonnull
                final int[] inPos = new int[outPos.length];
                for (int axis = 0; axis < outPos.length; axis++) {
                    inPos[axis] = outPos[axis] * kernelShape[axis] + kernelCoord.getCoords()[axis];
                }
                return inPos;
            }).collect(Collectors.toList());
            window.freeRef();
            return inputCoords;
        }
    }
}