com.simiacryptus.mindseye.layers.cudnn.ImgLinearSubnetLayer.java Source code

Java tutorial

Introduction

Here is the source code for com.simiacryptus.mindseye.layers.cudnn.ImgLinearSubnetLayer.java

Source

/*
 * Copyright (c) 2018 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.layers.cudnn;

import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.lang.cudnn.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * This key works as a scaling function, similar to a father wavelet. Allows convolutional and pooling layers to work
 * across larger png regions.
 */
@SuppressWarnings("serial")
public class ImgLinearSubnetLayer extends LayerBase implements MultiPrecision<ImgLinearSubnetLayer> {

    private static final Logger logger = LoggerFactory.getLogger(ImgLinearSubnetLayer.class);
    private final List<SubnetLeg> legs = new ArrayList<>();
    private Precision precision = Precision.Double;
    private boolean parallel = true;

    /**
     * Instantiates a new Rescaled subnet key.
     */
    public ImgLinearSubnetLayer() {
        super();
    }

    /**
     * Instantiates a new Rescaled subnet key.
     *
     * @param json the json
     * @param rs   the rs
     */
    protected ImgLinearSubnetLayer(@Nonnull final JsonObject json, Map<CharSequence, byte[]> rs) {
        super(json);
        this.precision = Precision.valueOf(json.getAsJsonPrimitive("precision").getAsString());
        setParallel(json.get("parallel").getAsBoolean());
        JsonArray jsonArray = json.get("legs").getAsJsonArray();
        for (int i = 0; i < jsonArray.size(); i++) {
            legs.add(new SubnetLeg(jsonArray.get(i).getAsJsonObject(), rs));
        }
    }

    /**
     * From json rescaled subnet key.
     *
     * @param json the json
     * @param rs   the rs
     * @return the rescaled subnet key
     */
    public static ImgLinearSubnetLayer fromJson(@Nonnull final JsonObject json, Map<CharSequence, byte[]> rs) {
        return new ImgLinearSubnetLayer(json, rs);
    }

    /**
     * Gets legs.
     *
     * @return the legs
     */
    public List<SubnetLeg> getLegs() {
        return legs;
    }

    /**
     * Add img linear subnet key.
     *
     * @param from  the from
     * @param to    the to
     * @param layer the key
     * @return the img linear subnet key
     */
    public ImgLinearSubnetLayer add(final int from, final int to, final Layer layer) {
        getLegs().add(new SubnetLeg(layer, from, to));
        return this;
    }

    @Override
    protected void _free() {
        super._free();
        legs.stream().forEach(ReferenceCounting::freeRef);
    }

    @Nullable
    @Override
    public Result evalAndFree(@Nonnull final Result... inObj) {
        assert 1 == inObj.length;
        Result input = inObj[0];
        TensorList inputData = input.getData();
        @Nonnull
        final int[] inputDims = inputData.getDimensions();
        assert 3 == inputDims.length;
        int length = inputData.length();
        int maxBand = legs.stream().mapToInt(x -> x.toBand).max().getAsInt();
        assert maxBand == inputDims[2] : maxBand + " != " + inputDims[2];
        assert IntStream.range(0, maxBand)
                .allMatch(i -> 1 == legs.stream().filter(x -> x.fromBand <= i && x.toBand > i).count());
        CudaTensor passback = CudaSystem.run(gpu -> {
            return CudaTensor.wrap(gpu.allocate(inputData.getElements() * precision.size, MemoryType.Device, true),
                    gpu.newTensorDescriptor(precision, length, inputDims[2], inputDims[1], inputDims[0]),
                    precision);
        });
        try {
            AtomicInteger counter = new AtomicInteger(0);
            SumInputsLayer sumInputsLayer = new SumInputsLayer();
            try {
                Result[] legResults = legs.stream().map(leg -> {
                    passback.addRef();
                    ImgBandSelectLayer imgBandSelectLayer = new ImgBandSelectLayer(leg.fromBand, leg.toBand);
                    input.addRef();
                    TensorList legData = imgBandSelectLayer.eval(input).getDataAndFree();
                    imgBandSelectLayer.freeRef();
                    return leg.inner.evalAndFree(new Result(legData, (DeltaSet<UUID> ctx, TensorList delta) -> {
                        int[] outputDimensions = delta.getDimensions();
                        int[] inputDimensions = inputDims;
                        synchronized (passback) {
                            CudaSystem.run(gpu -> {
                                @Nonnull
                                final CudaDevice.CudaTensorDescriptor viewDescriptor = gpu.newTensorDescriptor(
                                        precision, length, outputDimensions[2], outputDimensions[1],
                                        outputDimensions[0], //
                                        inputDimensions[2] * inputDimensions[1] * inputDimensions[0], //
                                        inputDimensions[1] * inputDimensions[0], //
                                        inputDimensions[0], //
                                        1);
                                final int byteOffset = viewDescriptor.cStride * leg.fromBand * precision.size;
                                assert delta.length() == inputData.length();
                                assert passback.getDeviceId() == gpu.getDeviceId();
                                //assert error.stream().flatMapToDouble(x-> Arrays.stream(x.getData())).allMatch(Double::isFinite);
                                @Nullable
                                final CudaTensor deltaTensor = gpu.getTensor(delta, precision, MemoryType.Device,
                                        true);
                                @Nonnull
                                final CudaMemory passbackBuffer = passback.getMemory(gpu);
                                CudaMemory errorPtrMemory = deltaTensor.getMemory(gpu);
                                passbackBuffer.synchronize();
                                gpu.cudnnTransformTensor(precision.getPointer(1.0), deltaTensor.descriptor.getPtr(),
                                        errorPtrMemory.getPtr(), precision.getPointer(0.0), viewDescriptor.getPtr(),
                                        passbackBuffer.getPtr().withByteOffset(byteOffset));
                                errorPtrMemory.dirty();
                                passbackBuffer.dirty();
                                Stream.<ReferenceCounting>of(deltaTensor, viewDescriptor, passbackBuffer,
                                        errorPtrMemory).forEach(ReferenceCounting::freeRef);
                            }, passback);
                        }
                        if (counter.incrementAndGet() >= legs.size()) {
                            counter.set(0);
                            input.accumulate(ctx, CudaTensorList.create(passback, length, inputDims, precision));
                        }
                    }) {
                        @Override
                        protected void _free() {
                            super._free();
                            input.freeRef();
                            passback.freeRef();
                        }
                    });
                }).toArray(i -> new Result[i]);
                return sumInputsLayer.setParallel(parallel).setPrecision(precision).evalAndFree(legResults);
            } finally {
                sumInputsLayer.freeRef();
                input.freeRef();
            }
        } finally {
            passback.freeRef();
        }
    }

    @Nonnull
    @Override
    public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
        @Nonnull
        final JsonObject json = super.getJsonStub();
        json.addProperty("precision", precision.name());
        json.addProperty("parallel", isParallel());
        JsonArray jsonArray = new JsonArray();
        legs.stream().map(x -> x.getJson(resources, dataSerializer)).forEach(jsonArray::add);
        json.add("legs", jsonArray);
        return json;
    }

    @Nonnull
    @Override
    public List<double[]> state() {
        return new ArrayList<>();
    }

    @Override
    public Precision getPrecision() {
        return precision;
    }

    @Nonnull
    @Override
    public ImgLinearSubnetLayer setPrecision(Precision precision) {
        this.precision = precision;
        return this;
    }

    @Nonnull
    @Override
    public Layer setFrozen(final boolean frozen) {
        legs.stream().map(x -> x.inner).forEach(x -> x.setFrozen(frozen));
        return super.setFrozen(frozen);
    }

    /**
     * Is parallel boolean.
     *
     * @return the boolean
     */
    public boolean isParallel() {
        return parallel;
    }

    /**
     * Sets parallel.
     *
     * @param parallel the parallel
     * @return the parallel
     */
    public ImgLinearSubnetLayer setParallel(boolean parallel) {
        this.parallel = parallel;
        return this;
    }

    /**
     * The type Subnet leg.
     */
    public static class SubnetLeg extends ReferenceCountingBase {

        private final Layer inner;
        private final int fromBand;
        private final int toBand;

        /**
         * Instantiates a new Subnet leg.
         *
         * @param inner    the inner
         * @param fromBand the from band
         * @param toBand   the to band
         */
        public SubnetLeg(final Layer inner, final int fromBand, final int toBand) {
            this.inner = inner;
            this.fromBand = fromBand;
            this.toBand = toBand;
            this.inner.addRef();
        }

        /**
         * Instantiates a new Rescaled subnet key.
         *
         * @param json the json
         * @param rs   the rs
         */
        protected SubnetLeg(@Nonnull final JsonObject json, Map<CharSequence, byte[]> rs) {
            fromBand = json.getAsJsonPrimitive("fromBand").getAsInt();
            toBand = json.getAsJsonPrimitive("toBand").getAsInt();
            inner = Layer.fromJson(json.getAsJsonObject("network"), rs);
        }

        @Override
        protected void _free() {
            super._free();
            inner.freeRef();
        }

        /**
         * Gets json.
         *
         * @param resources      the resources
         * @param dataSerializer the data serializer
         * @return the json
         */
        @Nonnull
        public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
            @Nonnull
            final JsonObject json = new JsonObject();
            json.addProperty("fromBand", fromBand);
            json.addProperty("toBand", toBand);
            json.add("network", inner.getJson(resources, dataSerializer));
            return json;
        }

    }
}