// Source: org.apache.sysml.runtime.matrix.data.LibMatrixCuDNNRnnAlgorithm.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.matrix.data;

import static jcuda.jcudnn.JCudnn.cudnnCreateFilterDescriptor;
import static jcuda.jcudnn.JCudnn.cudnnCreateTensorDescriptor;
import static jcuda.jcudnn.JCudnn.cudnnDestroyFilterDescriptor;
import static jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
import static jcuda.jcudnn.JCudnn.cudnnSetTensorNdDescriptor;
import static jcuda.jcudnn.JCudnn.cudnnDestroyDropoutDescriptor;
import static jcuda.jcudnn.JCudnn.cudnnDestroyRNNDescriptor;
import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
import static jcuda.jcudnn.JCudnn.cudnnCreateRNNDescriptor;
import static jcuda.jcudnn.cudnnRNNInputMode.CUDNN_LINEAR_INPUT;
import static jcuda.jcudnn.cudnnDirectionMode.CUDNN_UNIDIRECTIONAL;
import static jcuda.jcudnn.cudnnRNNAlgo.CUDNN_RNN_ALGO_STANDARD;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;

import jcuda.Pointer;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnDropoutDescriptor;
import jcuda.jcudnn.cudnnFilterDescriptor;
import jcuda.jcudnn.cudnnRNNDescriptor;
import jcuda.jcudnn.cudnnTensorDescriptor;

/**
 * Holder for all cuDNN descriptors and GPU scratch buffers needed by one RNN
 * (rnn_relu / rnn_tanh / lstm / gru) forward/backward invocation.
 *
 * <p>The constructor allocates per-timestep input/output tensor descriptors,
 * hidden/cell-state descriptors, the dropout and RNN descriptors, the filter
 * descriptors for the weights, and (when training) the cuDNN reserve space.
 * {@link #close()} releases everything, which makes the class usable with
 * try-with-resources.</p>
 *
 * <p>NOTE(review): thread-safety is not documented by the visible code; the
 * object wraps native handles and should be confined to one thread — confirm
 * against callers.</p>
 */
public class LibMatrixCuDNNRnnAlgorithm implements java.lang.AutoCloseable {
    private static final Log LOG = LogFactory.getLog(LibMatrixCuDNNRnnAlgorithm.class.getName());
    GPUContext gCtx;
    String instName;
    cudnnDropoutDescriptor dropoutDesc;
    cudnnRNNDescriptor rnnDesc;
    // Per-timestep descriptors for inputs/outputs and their gradients; each array has length T.
    cudnnTensorDescriptor[] xDesc, dxDesc, yDesc, dyDesc; // of length T
    // Initial/final hidden (h) and cell (c) state descriptors plus their gradient counterparts.
    cudnnTensorDescriptor hxDesc, cxDesc, hyDesc, cyDesc, dhxDesc, dcxDesc, dhyDesc, dcyDesc;
    cudnnFilterDescriptor wDesc;
    cudnnFilterDescriptor dwDesc;
    long sizeInBytes;              // size of workSpace in bytes (0 => nothing allocated)
    Pointer workSpace;
    long reserveSpaceSizeInBytes;  // size of reserveSpace in bytes (0 => nothing allocated)
    Pointer reserveSpace;
    long dropOutSizeInBytes;       // size of dropOutStateSpace in bytes (0 => nothing allocated)
    Pointer dropOutStateSpace;

    /**
     * Creates and configures all cuDNN descriptors and scratch buffers for the RNN call.
     *
     * @param ec execution context (currently unused by the visible code; kept for interface compatibility)
     * @param gCtx GPU context providing the cuDNN handle and allocator
     * @param instName instruction name used for allocation bookkeeping/logging
     * @param rnnMode one of "rnn_relu", "rnn_tanh", "lstm", "gru" (case-insensitive)
     * @param N batch size
     * @param T sequence length (number of timesteps)
     * @param M hidden size
     * @param D number of input features
     * @param isTraining when true, also allocates the cuDNN training reserve space
     * @param w pointer to the weights (currently unused by the visible code)
     * @throws DMLRuntimeException if the rnn mode is unsupported, or if (for lstm)
     *         the cuDNN-reported parameter count disagrees with (D+M+2)*4*M
     */
    public LibMatrixCuDNNRnnAlgorithm(ExecutionContext ec, GPUContext gCtx, String instName, String rnnMode, int N,
            int T, int M, int D, boolean isTraining, Pointer w) throws DMLRuntimeException {
        this.gCtx = gCtx;
        this.instName = instName;

        // Allocate input/output descriptors: one 3-d tensor descriptor per timestep.
        xDesc = new cudnnTensorDescriptor[T];
        dxDesc = new cudnnTensorDescriptor[T];
        yDesc = new cudnnTensorDescriptor[T];
        dyDesc = new cudnnTensorDescriptor[T];
        for (int t = 0; t < T; t++) {
            xDesc[t] = allocateTensorDescriptorWithStride(N, D, 1);
            dxDesc[t] = allocateTensorDescriptorWithStride(N, D, 1);
            yDesc[t] = allocateTensorDescriptorWithStride(N, M, 1);
            dyDesc[t] = allocateTensorDescriptorWithStride(N, M, 1);
        }
        // Hidden/cell state descriptors: shape (numLayers=1, N, M).
        hxDesc = allocateTensorDescriptorWithStride(1, N, M);
        dhxDesc = allocateTensorDescriptorWithStride(1, N, M);
        cxDesc = allocateTensorDescriptorWithStride(1, N, M);
        dcxDesc = allocateTensorDescriptorWithStride(1, N, M);
        hyDesc = allocateTensorDescriptorWithStride(1, N, M);
        dhyDesc = allocateTensorDescriptorWithStride(1, N, M);
        cyDesc = allocateTensorDescriptorWithStride(1, N, M);
        dcyDesc = allocateTensorDescriptorWithStride(1, N, M);

        // Initialize dropout descriptor. cuDNN requires a dropout descriptor even
        // when dropout probability is 0 (as set below).
        dropoutDesc = new cudnnDropoutDescriptor();
        JCudnn.cudnnCreateDropoutDescriptor(dropoutDesc);
        long[] _dropOutSizeInBytes = { -1 };
        JCudnn.cudnnDropoutGetStatesSize(gCtx.getCudnnHandle(), _dropOutSizeInBytes);
        dropOutSizeInBytes = _dropOutSizeInBytes[0];
        dropOutStateSpace = new Pointer();
        if (dropOutSizeInBytes != 0) {
            if (LOG.isDebugEnabled())
                LOG.debug("Allocating " + dropOutSizeInBytes + " bytes for lstm dropout space.");
            dropOutStateSpace = gCtx.allocate(instName, dropOutSizeInBytes);
        }
        // Dropout probability 0 (no dropout); 12345 is the RNG seed required by the API.
        JCudnn.cudnnSetDropoutDescriptor(dropoutDesc, gCtx.getCudnnHandle(), 0, dropOutStateSpace,
                dropOutSizeInBytes, 12345);

        // Initialize RNN descriptor: hidden size M, a single layer, unidirectional,
        // linear input projection, standard (deterministic-ish) algorithm.
        rnnDesc = new cudnnRNNDescriptor();
        cudnnCreateRNNDescriptor(rnnDesc);
        JCudnn.cudnnSetRNNDescriptor_v6(gCtx.getCudnnHandle(), rnnDesc, M, 1, dropoutDesc, CUDNN_LINEAR_INPUT,
                CUDNN_UNIDIRECTIONAL, getCuDNNRnnMode(rnnMode), CUDNN_RNN_ALGO_STANDARD,
                LibMatrixCUDA.CUDNN_DATA_TYPE);

        // Allocate filter descriptors, sanity-checking the lstm parameter count:
        // (D+M+2)*4*M = 4 gates x (input weights + recurrent weights + 2 biases) x hidden size.
        int expectedNumWeights = getExpectedNumWeights();
        if (rnnMode.equalsIgnoreCase("lstm") && (D + M + 2) * 4 * M != expectedNumWeights) {
            throw new DMLRuntimeException("Incorrect number of RNN parameters " + (D + M + 2) * 4 * M + " != "
                    + expectedNumWeights + ", where numFeatures=" + D + ", hiddenSize=" + M);
        }
        wDesc = allocateFilterDescriptor(expectedNumWeights);
        dwDesc = allocateFilterDescriptor(expectedNumWeights);

        // Setup workspace (needed for both inference and training).
        workSpace = new Pointer();
        reserveSpace = new Pointer();
        sizeInBytes = getWorkspaceSize(T);
        if (sizeInBytes != 0) {
            if (LOG.isDebugEnabled())
                LOG.debug("Allocating " + sizeInBytes + " bytes for lstm workspace.");
            workSpace = gCtx.allocate(instName, sizeInBytes);
        }
        // Reserve space is only needed to carry state from forward to backward pass.
        reserveSpaceSizeInBytes = 0;
        if (isTraining) {
            reserveSpaceSizeInBytes = getReservespaceSize(T);
            if (reserveSpaceSizeInBytes != 0) {
                if (LOG.isDebugEnabled())
                    LOG.debug("Allocating " + reserveSpaceSizeInBytes + " bytes for lstm reserve space.");
                reserveSpace = gCtx.allocate(instName, reserveSpaceSizeInBytes);
            }
        }
    }

    /**
     * Returns the number of linear layers cuDNN uses internally for the given mode
     * (2 for plain RNN, 8 for LSTM, 6 for GRU).
     *
     * @param rnnMode one of "rnn_relu", "rnn_tanh", "lstm", "gru" (case-insensitive)
     * @return the number of linear layers
     * @throws DMLRuntimeException for an unsupported mode
     */
    @SuppressWarnings("unused")
    private int getNumLinearLayers(String rnnMode) throws DMLRuntimeException {
        int ret = 0;
        if (rnnMode.equalsIgnoreCase("rnn_relu") || rnnMode.equalsIgnoreCase("rnn_tanh")) {
            ret = 2;
        } else if (rnnMode.equalsIgnoreCase("lstm")) {
            ret = 8;
        } else if (rnnMode.equalsIgnoreCase("gru")) {
            ret = 6;
        } else {
            throw new DMLRuntimeException("Unsupported rnn mode:" + rnnMode);
        }
        return ret;
    }

    /** Queries cuDNN for the forward-pass workspace size (in bytes) for the given sequence length. */
    private long getWorkspaceSize(int seqLength) {
        long[] sizeInBytesArray = new long[1];
        JCudnn.cudnnGetRNNWorkspaceSize(gCtx.getCudnnHandle(), rnnDesc, seqLength, xDesc, sizeInBytesArray);
        return sizeInBytesArray[0];
    }

    /** Queries cuDNN for the training reserve-space size (in bytes) for the given sequence length. */
    private long getReservespaceSize(int seqLength) {
        long[] sizeInBytesArray = new long[1];
        JCudnn.cudnnGetRNNTrainingReserveSize(gCtx.getCudnnHandle(), rnnDesc, seqLength, xDesc, sizeInBytesArray);
        return sizeInBytesArray[0];
    }

    /**
     * Maps the SystemML rnn mode string to the corresponding cuDNN mode constant.
     *
     * @param rnnMode one of "rnn_relu", "rnn_tanh", "lstm", "gru" (case-insensitive)
     * @return the cudnnRNNMode constant
     * @throws DMLRuntimeException for an unsupported mode
     */
    private int getCuDNNRnnMode(String rnnMode) throws DMLRuntimeException {
        int rnnModeVal = -1;
        if (rnnMode.equalsIgnoreCase("rnn_relu")) {
            rnnModeVal = jcuda.jcudnn.cudnnRNNMode.CUDNN_RNN_RELU;
        } else if (rnnMode.equalsIgnoreCase("rnn_tanh")) {
            rnnModeVal = jcuda.jcudnn.cudnnRNNMode.CUDNN_RNN_TANH;
        } else if (rnnMode.equalsIgnoreCase("lstm")) {
            rnnModeVal = jcuda.jcudnn.cudnnRNNMode.CUDNN_LSTM;
        } else if (rnnMode.equalsIgnoreCase("gru")) {
            rnnModeVal = jcuda.jcudnn.cudnnRNNMode.CUDNN_GRU;
        } else {
            throw new DMLRuntimeException("Unsupported rnn mode:" + rnnMode);
        }
        return rnnModeVal;
    }

    /**
     * Asks cuDNN how many weight elements the configured RNN expects
     * (parameter-buffer size divided by the element size of the data type).
     *
     * @return the expected number of weight elements
     * @throws DMLRuntimeException if the size does not fit in an int
     */
    private int getExpectedNumWeights() throws DMLRuntimeException {
        long[] weightSizeInBytesArray = { -1 }; // (D+M+2)*4*M
        JCudnn.cudnnGetRNNParamsSize(gCtx.getCudnnHandle(), rnnDesc, xDesc[0], weightSizeInBytesArray,
                LibMatrixCUDA.CUDNN_DATA_TYPE);
        // check if (D+M+2)*4M == weightsSize / sizeof(dataType) where weightsSize is given by 'cudnnGetRNNParamsSize'.
        return LibMatrixCUDA.toInt(weightSizeInBytesArray[0] / LibMatrixCUDA.sizeOfDataType);
    }

    /** Creates a 3-d NCHW filter descriptor of shape {numWeights, 1, 1} for the RNN weight buffer. */
    private cudnnFilterDescriptor allocateFilterDescriptor(int numWeights) {
        cudnnFilterDescriptor filterDesc = new cudnnFilterDescriptor();
        cudnnCreateFilterDescriptor(filterDesc);
        JCudnn.cudnnSetFilterNdDescriptor(filterDesc, LibMatrixCUDA.CUDNN_DATA_TYPE, CUDNN_TENSOR_NCHW, 3,
                new int[] { numWeights, 1, 1 });
        return filterDesc;
    }

    /**
     * Creates a fully-packed (row-major strides) 3-d tensor descriptor with the
     * given dimensions.
     *
     * @param firstDim outermost dimension
     * @param secondDim middle dimension
     * @param thirdDim innermost dimension
     * @return the configured tensor descriptor
     * @throws DMLRuntimeException declared for interface compatibility
     */
    private static cudnnTensorDescriptor allocateTensorDescriptorWithStride(int firstDim, int secondDim,
            int thirdDim) throws DMLRuntimeException {
        cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
        cudnnCreateTensorDescriptor(tensorDescriptor);
        int[] dimA = new int[] { firstDim, secondDim, thirdDim };
        // Fully-packed strides: stride of dim i is the product of all inner dims.
        int[] strideA = new int[] { dimA[2] * dimA[1], dimA[2], 1 };
        cudnnSetTensorNdDescriptor(tensorDescriptor, LibMatrixCUDA.CUDNN_DATA_TYPE, 3, dimA, strideA);
        return tensorDescriptor;
    }

    /**
     * Destroys all cuDNN descriptors and frees all GPU scratch buffers.
     *
     * <p>Fix over the previous version: a failure while freeing one GPU buffer no
     * longer aborts cleanup (which leaked the remaining buffers). All descriptors
     * and all buffers are released; the first free failure, if any, is rethrown
     * at the end as a RuntimeException (same exception type as before).</p>
     */
    @Override
    public void close() {
        // Descriptor destruction: these jcuda calls do not throw checked exceptions.
        if (dropoutDesc != null)
            cudnnDestroyDropoutDescriptor(dropoutDesc);
        dropoutDesc = null;
        if (rnnDesc != null)
            cudnnDestroyRNNDescriptor(rnnDesc);
        rnnDesc = null;
        if (hxDesc != null)
            cudnnDestroyTensorDescriptor(hxDesc);
        hxDesc = null;
        if (dhxDesc != null)
            cudnnDestroyTensorDescriptor(dhxDesc);
        dhxDesc = null;
        if (hyDesc != null)
            cudnnDestroyTensorDescriptor(hyDesc);
        hyDesc = null;
        if (dhyDesc != null)
            cudnnDestroyTensorDescriptor(dhyDesc);
        dhyDesc = null;
        if (cxDesc != null)
            cudnnDestroyTensorDescriptor(cxDesc);
        cxDesc = null;
        if (dcxDesc != null)
            cudnnDestroyTensorDescriptor(dcxDesc);
        dcxDesc = null;
        if (cyDesc != null)
            cudnnDestroyTensorDescriptor(cyDesc);
        cyDesc = null;
        if (dcyDesc != null)
            cudnnDestroyTensorDescriptor(dcyDesc);
        dcyDesc = null;
        if (wDesc != null)
            cudnnDestroyFilterDescriptor(wDesc);
        wDesc = null;
        if (dwDesc != null)
            cudnnDestroyFilterDescriptor(dwDesc);
        dwDesc = null;
        if (xDesc != null) {
            for (cudnnTensorDescriptor dsc : xDesc) {
                cudnnDestroyTensorDescriptor(dsc);
            }
            xDesc = null;
        }
        if (dxDesc != null) {
            for (cudnnTensorDescriptor dsc : dxDesc) {
                cudnnDestroyTensorDescriptor(dsc);
            }
            dxDesc = null;
        }
        if (yDesc != null) {
            for (cudnnTensorDescriptor dsc : yDesc) {
                cudnnDestroyTensorDescriptor(dsc);
            }
            yDesc = null;
        }
        if (dyDesc != null) {
            for (cudnnTensorDescriptor dsc : dyDesc) {
                cudnnDestroyTensorDescriptor(dsc);
            }
            dyDesc = null;
        }
        // Free GPU buffers. Attempt all three even if one fails so that a failure
        // on an earlier free does not leak the later buffers; remember the first
        // failure and rethrow it once everything has been attempted.
        DMLRuntimeException firstFreeError = null;
        if (sizeInBytes != 0) {
            try {
                gCtx.cudaFreeHelper(instName, workSpace, gCtx.EAGER_CUDA_FREE);
            } catch (DMLRuntimeException e) {
                firstFreeError = e;
            }
        }
        workSpace = null;
        if (reserveSpaceSizeInBytes != 0) {
            try {
                gCtx.cudaFreeHelper(instName, reserveSpace, gCtx.EAGER_CUDA_FREE);
            } catch (DMLRuntimeException e) {
                if (firstFreeError == null)
                    firstFreeError = e;
            }
        }
        reserveSpace = null;
        if (dropOutSizeInBytes != 0) {
            try {
                gCtx.cudaFreeHelper(instName, dropOutStateSpace, gCtx.EAGER_CUDA_FREE);
            } catch (DMLRuntimeException e) {
                if (firstFreeError == null)
                    firstFreeError = e;
            }
        }
        dropOutStateSpace = null;
        if (firstFreeError != null)
            throw new RuntimeException(firstFreeError);
    }
}