com.simiacryptus.mindseye.layers.cudnn.ImgTileSubnetLayer.java Source code

Java tutorial

Introduction

Here is the source code for com.simiacryptus.mindseye.layers.cudnn.ImgTileSubnetLayer.java

Source

/*
 * Copyright (c) 2018 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.layers.cudnn;

import com.google.gson.JsonObject;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.lang.cudnn.*;
import com.simiacryptus.mindseye.layers.java.WrapperLayer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Applies a wrapped subnetwork to an image one tile at a time. The input is cut into tiles of
 * {@code width} x {@code height} pixels placed every {@code strideX} / {@code strideY} pixels, the inner layer is
 * evaluated on each tile, and the per-tile results are handed to an {@link ImgTileAssemblyLayer} to build the output.
 * When a single tile is enough to cover the input, the inner layer is evaluated on the input directly.
 */
@SuppressWarnings("serial")
public class ImgTileSubnetLayer extends WrapperLayer implements MultiPrecision<ImgTileSubnetLayer> {

    private static final Logger logger = LoggerFactory.getLogger(ImgTileSubnetLayer.class);
    private final int height;
    private final int width;
    private final int strideX;
    private final int strideY;
    private Precision precision = Precision.Double;
    private boolean parallel = true;

    /**
     * Instantiates a new image tile subnet layer.
     *
     * @param subnetwork the layer evaluated on every tile
     * @param width      the tile width
     * @param height     the tile height
     * @param strideX    the horizontal distance between the origins of neighbouring tiles
     * @param strideY    the vertical distance between the origins of neighbouring tiles
     */
    public ImgTileSubnetLayer(final Layer subnetwork, final int width, final int height, final int strideX,
            final int strideY) {
        super(subnetwork);
        this.height = height;
        this.width = width;
        this.strideX = strideX;
        this.strideY = strideY;
    }

    /**
     * Instantiates a new image tile subnet layer whose strides equal the tile size.
     *
     * @param subnetwork the layer evaluated on every tile
     * @param width      the tile width, also used as the horizontal stride
     * @param height     the tile height, also used as the vertical stride
     */
    public ImgTileSubnetLayer(final Layer subnetwork, final int width, final int height) {
        this(subnetwork, width, height, width, height);
    }

    /**
     * Instantiates an image tile subnet layer from its JSON form. Reads the same properties that
     * {@link #getJson(Map, DataSerializer)} writes.
     *
     * @param json the serialized layer
     * @param rs   the resource map, passed through to the superclass constructor
     */
    protected ImgTileSubnetLayer(@Nonnull final JsonObject json, Map<CharSequence, byte[]> rs) {
        super(json, rs);
        this.precision = Precision.valueOf(json.getAsJsonPrimitive("precision").getAsString());
        height = json.getAsJsonPrimitive("height").getAsInt();
        width = json.getAsJsonPrimitive("width").getAsInt();
        strideX = json.getAsJsonPrimitive("strideX").getAsInt();
        strideY = json.getAsJsonPrimitive("strideY").getAsInt();
        this.parallel = json.get("parallel").getAsBoolean();
    }

    /**
     * Creates an image tile subnet layer from its JSON form.
     *
     * @param json the serialized layer
     * @param rs   the resource map
     * @return the deserialized layer
     */
    public static ImgTileSubnetLayer fromJson(@Nonnull final JsonObject json, Map<CharSequence, byte[]> rs) {
        return new ImgTileSubnetLayer(json, rs);
    }

    /**
     * Computes how many tiles are needed along one axis.
     *
     * @param inputSize the input extent along the axis
     * @param tileSize  the tile extent along the axis
     * @param stride    the distance between the origins of neighbouring tiles
     * @return the number of tiles, always at least 1
     * @throws IllegalStateException if more than one tile is required but the stride is not positive
     */
    private static int countTiles(final int inputSize, final int tileSize, final int stride) {
        // A tile at least as large as the input covers it alone. Without this guard the formula below
        // yields zero or a negative count once the tile exceeds the input by a full stride or more.
        if (tileSize >= inputSize) {
            return 1;
        }
        // A non-positive stride can never advance across the input; fail clearly rather than
        // computing an enormous or negative tile count.
        if (stride <= 0) {
            throw new IllegalStateException(String.format(
                    "Stride must be positive to tile an input of size %s with tiles of size %s, but was %s",
                    inputSize, tileSize, stride));
        }
        return (int) (Math.ceil((inputSize - tileSize) * 1.0 / stride) + 1);
    }

    @Override
    protected void _free() {
        super._free();
    }

    /**
     * Evaluates the inner layer on every tile of the single input and assembles the tile results.
     * <p>
     * For the backward pass each tile's delta is copied into one shared pass-back tensor using the negated tile
     * offset; once as many deltas have arrived as there are tiles, that tensor is accumulated into the input result.
     *
     * @param inObj exactly one input result whose data has three dimensions
     * @return the assembled result, or the inner layer's own result when one tile covers the input
     * @throws IllegalStateException if more than one tile is required along an axis whose stride is not positive
     */
    @Nullable
    @Override
    public Result evalAndFree(@Nonnull final Result... inObj) {
        assert 1 == inObj.length;
        final Result input = inObj[0];
        final TensorList inputData = input.getData();
        @Nonnull
        final int[] inputDims = inputData.getDimensions();
        assert 3 == inputDims.length;
        final int cols = countTiles(inputDims[0], width, strideX);
        final int rows = countTiles(inputDims[1], height, strideY);
        // Single-tile shortcut: decided before any GPU memory is allocated for the pass-back buffer.
        if (cols == 1 && rows == 1) {
            return getInner().evalAndFree(inObj);
        }
        final int bands = inputDims[2];
        final int length = inputData.length();
        final int tileCount = rows * cols;
        // Block lambdas are kept on purpose: an expression lambda could match both a value-returning
        // and a void functional interface if CudaSystem.run is overloaded for both.
        final CudaTensor passback = CudaSystem.run(gpu -> {
            return CudaTensor.wrap(gpu.allocate(inputData.getElements() * precision.size, MemoryType.Managed, true),
                    gpu.newTensorDescriptor(precision, length, bands, inputDims[1], inputDims[0]),
                    precision);
        });
        try {
            final AtomicInteger counter = new AtomicInteger(0);
            final int[] tileDimensions = { width, height, bands };
            final Result[][] tileResults = new Result[rows][];
            for (int row = 0; row < rows; row++) {
                tileResults[row] = new Result[cols];
                for (int col = 0; col < cols; col++) {
                    final int positionX = col * strideX;
                    final int positionY = row * strideY;
                    assert positionX >= 0;
                    assert positionY >= 0;
                    assert positionX < inputDims[0];
                    assert positionY < inputDims[1];

                    final CudaTensor tile = CudaSystem.run(gpu -> {
                        return ImgTileSelectLayer.copy(gpu, inputData, inputData.getDimensions(), tileDimensions,
                                precision, positionX, positionY, true);
                    });

                    // Each tile result holds its own reference to the shared pass-back tensor
                    // and releases it in _free() below.
                    passback.addRef();
                    tileResults[row][col] = getInner()
                            .evalAndFree(new Result(CudaTensorList.wrap(tile, length, tileDimensions, precision),
                                    (DeltaSet<UUID> ctx, TensorList delta) -> {
                                        CudaSystem.run(gpu -> {
                                            ImgTileSelectLayer.copy(gpu, delta, tileDimensions, -positionX,
                                                    -positionY, precision, passback).freeRef();
                                        });
                                        // Forward the gradient only after every tile has contributed,
                                        // then reset the counter for the next round.
                                        if (counter.incrementAndGet() >= tileCount) {
                                            counter.set(0);
                                            input.accumulate(ctx,
                                                    CudaTensorList.create(passback, length, inputDims, precision));
                                        }
                                    }) {
                                @Override
                                protected void _free() {
                                    super._free();
                                    passback.freeRef();
                                }
                            });
                }
            }
            inputData.freeRef();
            if (logger.isDebugEnabled()) {
                logger.debug("Broke input {} into {} rows, {} cols", Arrays.toString(inputDims), rows, cols);
            }
            final Result result = new ImgTileAssemblyLayer(cols, rows).setParallel(parallel).setPrecision(precision)
                    .evalAndFree(Arrays.stream(tileResults).flatMap(Arrays::stream).toArray(Result[]::new));
            return new Result(result.getData(), (ctx, delta) -> {
                result.accumulate(ctx, delta);
            }) {

                @Override
                public void accumulate(final DeltaSet<UUID> buffer, final TensorList delta) {
                    getAccumulator().accept(buffer, delta);
                }

                @Override
                protected void _free() {
                    super._free();
                    result.freeRef();
                    input.freeRef();
                }
            };
        } finally {
            // Releases this method's own reference; tile results keep theirs until they are freed.
            passback.freeRef();
        }
    }

    /**
     * Serializes this layer, adding the tile size, strides, precision and parallel flag to the JSON produced by the
     * superclass.
     *
     * @param resources      the resource map, passed through to the superclass
     * @param dataSerializer the data serializer, passed through to the superclass
     * @return the JSON form of this layer
     */
    @Nonnull
    @Override
    public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
        @Nonnull
        final JsonObject json = super.getJson(resources, dataSerializer);
        json.addProperty("height", height);
        json.addProperty("width", width);
        json.addProperty("strideX", strideX);
        json.addProperty("strideY", strideY);
        json.addProperty("precision", precision.name());
        json.addProperty("parallel", isParallel());
        return json;
    }

    /**
     * Returns the state of this layer.
     *
     * @return a new, empty list
     */
    @Nonnull
    @Override
    public List<double[]> state() {
        return new ArrayList<>();
    }

    /**
     * Gets the precision used for tile and pass-back tensors.
     *
     * @return the precision
     */
    @Override
    public Precision getPrecision() {
        return precision;
    }

    /**
     * Sets the precision used for tile and pass-back tensors.
     *
     * @param precision the precision
     * @return this layer
     */
    @Nonnull
    @Override
    public ImgTileSubnetLayer setPrecision(Precision precision) {
        this.precision = precision;
        return this;
    }

    /**
     * Sets the frozen flag on the inner layer and then on this layer.
     *
     * @param frozen the frozen flag
     * @return the value returned by the superclass implementation
     */
    @Nonnull
    @Override
    public Layer setFrozen(final boolean frozen) {
        getInner().setFrozen(frozen);
        return super.setFrozen(frozen);
    }

    /**
     * Gets the parallel flag that is passed to the {@link ImgTileAssemblyLayer} during evaluation.
     *
     * @return the parallel flag
     */
    public boolean isParallel() {
        return parallel;
    }

    /**
     * Sets the parallel flag that is passed to the {@link ImgTileAssemblyLayer} during evaluation.
     *
     * @param parallel the parallel flag
     * @return this layer
     */
    public ImgTileSubnetLayer setParallel(boolean parallel) {
        this.parallel = parallel;
        return this;
    }
}