Use of cuDNN RNN ForwardTraining and BackwardTraining

Hello,

I have some questions about using the RNN functions of the cuDNN API.
I want to train a recurrent neural network whose outputs don't occur at every time step. It is a many-to-one RNN: I have only one label for each sequence.

After the execution of cudnnRNNForwardTraining,

cudnnStatus_t cudnnRNNForwardTraining(
    cudnnHandle_t                   handle,

    const cudnnRNNDescriptor_t      rnnDesc,
    const int                       seqLength,
    const cudnnTensorDescriptor_t  *xDesc,
    const void                     *x,
    const cudnnTensorDescriptor_t   hxDesc,
    const void                     *hx,
    const cudnnTensorDescriptor_t   cxDesc,
    const void                     *cx,
    const cudnnFilterDescriptor_t   wDesc,
    const void                     *w,
    const cudnnTensorDescriptor_t  *yDesc,
    void                           *y,// output
    const cudnnTensorDescriptor_t   hyDesc,
    void                           *hy,
    const cudnnTensorDescriptor_t   cyDesc,
    void                           *cy,
    void                           *workspace,
    size_t                          workSpaceSizeInBytes,
    void                           *reserveSpace,
    size_t                          reserveSpaceSizeInBytes);

I get the data pointer to GPU memory associated with the output tensor descriptors yDesc. For each iteration (time step) I get an output, but only one label can be used.

y = [seqLength, batchSize, hiddenSize]; // y dimensions
Label = [batchSize, labelVectorDimensions];

If I use only the last iteration of y to continue the forward pass, everything goes well. But for backpropagation,

cudnnStatus_t cudnnRNNBackwardData(
    cudnnHandle_t                   handle,
    const cudnnRNNDescriptor_t      rnnDesc,
    const int                       seqLength,
    const cudnnTensorDescriptor_t  *yDesc,
    const void                     *y,// input
    const cudnnTensorDescriptor_t  *dyDesc,
    const void                     *dy,// input
    const cudnnTensorDescriptor_t   dhyDesc,
    const void                     *dhy,
    const cudnnTensorDescriptor_t   dcyDesc,
    const void                     *dcy,
    const cudnnFilterDescriptor_t   wDesc,
    const void                     *w,
    const cudnnTensorDescriptor_t   hxDesc,
    const void                     *hx,
    const cudnnTensorDescriptor_t   cxDesc,
    const void                     *cx,
    const cudnnTensorDescriptor_t  *dxDesc,
    void                           *dx,
    const cudnnTensorDescriptor_t   dhxDesc,
    void                           *dhx,
    const cudnnTensorDescriptor_t   dcxDesc,
    void                           *dcx,
    void                           *workspace,
    size_t                          workSpaceSizeInBytes,
    const void                     *reserveSpace,
    size_t                          reserveSpaceSizeInBytes);

y and dy as inputs to cudnnRNNBackwardData should have the same dimensions. In the NVIDIA docs, y is the data pointer calculated by cudnnRNNForwardTraining. But in my program the dimensions of y and dy are:

y = [seqLength, batchSize, hiddenSize];
dy = [batchSize, hiddenSize, 1]

What should I do to make y and dy have the same dimensions?

Thanks in advance!

When you look at “cudnnRNNForwardTraining” in the docs (cuDNN 7.2.1 PDF):
[…] The second dimension of the tensor depends on the direction argument passed to the cudnnSetRNNDescriptor call used to initialize rnnDesc:
‣ If direction is CUDNN_UNIDIRECTIONAL the second dimension should match the hiddenSize argument passed to cudnnSetRNNDescriptor.
‣ If direction is CUDNN_BIDIRECTIONAL the second dimension should match double the hiddenSize argument passed to cudnnSetRNNDescriptor. […]

In the RNN_example.cu of the cuDNN 7.2 samples: […] In this example dimA[1] is constant across the whole sequence. This isn’t required, all that is required is that it does not increase. […]

So you should define each yDesc already as [batchSize, hiddenSize, 1] when passing it to cudnnRNNForwardTraining.

Look at the RNN_example.cu sample; there yDesc is defined the same way as dyDesc.
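
As a minimal sketch of that setup (following the RNN_example.cu pattern; error checking omitted, sizes are whatever your own configuration uses), the per-time-step descriptors could be built like this:

#include <cudnn.h>
#include <vector>

// Sketch: one fully packed [batchSize, hiddenSize, 1] descriptor per time step,
// identical for yDesc and dyDesc.
void buildOutputDescriptors(int seqLength, int batchSize, int hiddenSize,
                            std::vector<cudnnTensorDescriptor_t>& yDesc,
                            std::vector<cudnnTensorDescriptor_t>& dyDesc)
{
    yDesc.resize(seqLength);
    dyDesc.resize(seqLength);

    int dimA[3]    = { batchSize, hiddenSize, 1 };       // 2 * hiddenSize if bidirectional
    int strideA[3] = { dimA[1] * dimA[2], dimA[2], 1 };  // fully packed

    for (int i = 0; i < seqLength; ++i) {
        cudnnCreateTensorDescriptor(&yDesc[i]);
        cudnnSetTensorNdDescriptor(yDesc[i], CUDNN_DATA_FLOAT, 3, dimA, strideA);
        cudnnCreateTensorDescriptor(&dyDesc[i]);
        cudnnSetTensorNdDescriptor(dyDesc[i], CUDNN_DATA_FLOAT, 3, dimA, strideA);
    }
    // Pass yDesc.data() / dyDesc.data() as the *yDesc / *dyDesc arguments.
}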

Thanks for your reply, m1.

Based on your RNN for the MNIST dataset: RNNseqLength = 28, inputSize = 28.
From the NVIDIA docs, the input *xDesc is an array of seqLength tensor descriptors. Each tensor descriptor has the dimensions:

xDesc = [batchSize, inputSize, 1] // here inputSize = 28

But what about the input data x? I think the pointer x should point to an array holding seqLength * batchSize * inputSize pixels, so x is fed to cudnnRNNForwardTraining stored in this form:

x = [seqLength, batchSize, inputSize] // here seqLength = 28, inputSize = 28

And *yDesc is also an array of tensor descriptors. Each has the dimensions:

yDesc = [batchSize, hiddenSize, 1]

So the output y after cudnnRNNForwardTraining points to an array with a total of seqLength * batchSize * hiddenSize elements:

y = [seqLength, batchSize, hiddenSize] // here seqLength = 28

*dyDesc is an array of tensor descriptors with the same dimensions as *yDesc. Each descriptor:

dyDesc = [batchSize, hiddenSize, 1]

So here is my question: what is this dy actually? I think dy should point to an array with seqLength * batchSize * hiddenSize elements, maybe also stored in this form:

dy = [seqLength, batchSize, hiddenSize]

But for one sequence we have only one label. How do I get a dy for every time step as input to cudnnRNNBackwardData?

Thanks again!

After I read through all of this again, your input seems to be right. So for testing I now simply use batch size 1 (but still get the same bad classification).

For dy I apply “input_data”. When you look at cudnnActivationBackward, cudnnSoftmaxBackward, cudnnConvolutionBackwardBias, …, there dy (format dyDesc) always receives the input data from the higher layer; on the last layer it is the loss data (from the forward propagation result, calculated in my case in the backward softmax activation of the fully connected layer).

“dy” is “an array of fully packed tensor descriptors describing the gradient at the output from each recurrent iteration”.
It has the same format as “y”, with this number of elements: RNNseqLength * RNNhiddenSize * batchSize * (bidirectional ? 2 : 1),
because it is the gradient result from the next layer (during backpropagation).
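
As a sketch, allocating a fully packed dy buffer with exactly that element count (assuming float data and the variables named above) could look like this:

// dy holds one gradient vector per time step and per batch entry (fully packed).
size_t dyElems = (size_t)RNNseqLength * RNNhiddenSize * batchSize * (bidirectional ? 2 : 1);
float *dy = nullptr;
cudaMalloc((void**)&dy, dyElems * sizeof(float));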

// you can find out how they do it in CNTK:

// in RNNNodes.cpp  :
 m_transposedDOutput->DoGatherColumnsOf(0.0, *(this->m_packingIndex), this->Gradient(), 1.0);
// ...
 m_transposedOutput->RNNBackwardData(*m_transposedDOutput, paramW, *m_transposedDInput, m_rnnAttributes, *m_reserve, *m_workspace);
// https://github.com/Microsoft/CNTK/blob/987b22a8350211cb4c44278951857af1289c3666/Source/ComputationNetworkLib/RNNNodes.cpp

//  *m_transposedDOutput is "outputDY"

// in GPUMatrix.cu : 
...RNNBackwardData(const GPUMatrix<ElemType>& outputDY, ...)
//...
    m_rnnExecutor->BackwardDataCore(*this, outputDY, 

// in CuDnnRNN.cpp : 
// https://github.com/Microsoft/CNTK/blob/987b22a8350211cb4c44278951857af1289c3666/Source/Math/CuDnnRNN.cpp
BackwardDataCore(const GPUMatrix<ElemType>& outputY, const GPUMatrix<ElemType>& outputDY,  ....
// ...
    cudnnRNNBackwardData(             // cuDNN API call
            *m_cudnn, *m_rnnT,
            (int)m_seqLength,
            yDesc.data(), outputY.Data(),
            yDesc.data(), outputDY.Data(),   // this is dyDesc, dy  (CNTK uses yDesc for both)
            ...

The Python version (RNN on MNIST dataset) says: “Since there are many outputs from the RNN, we only care about the last one.”
So I assumed that might be the final hidden state, which indeed has different output dimensions.
So I tested that by enabling “OUTPUT_TO_FINAL_HIDDEN_STATE_RNN” in main.cu in my code.
It runs without API errors, although the dimensions are wrong, but gives the same 90% error.

"dy" has these elements:  RNNseqLength * RNNhiddenSize * batchSize * (bidirectional ? 2 : 1)
"hy" has these elements:   RNNnumLayers * RNNhiddenSize * batchSize * (bidirectional ? 2 : 1)

But the problem here is: there is then no data for “dy”, because when using the “hy” output as input for the next dense layer (fully connected), the format of “dhy” (dhyDesc = hyDesc) can only be used for the gradient arriving at dhyDesc. For “dy” there is then no data,
but the function does not allow “dy” to be zero.
I asked about it in this thread: https://devtalk.nvidia.com/default/topic/1005271/gpu-accelerated-libraries/use-of-cudnn-rnn/post/5284778/#5284778
So maybe my softmax/loss calculation is wrong?

Hello m1,

I think the problem is dy. I have built my own net and tried it with MNIST. It gets good results.

What I thought before was wrong. For cudnnRNNBackwardData and cudnnRNNBackwardWeights I use dy in this form as input. For example:

seqLength = 28
batchSize = 200
hiddenSize = 256
dy = [28 200 256]

From the next layer I get, via backpropagation, an array of total length 200 * 256 (dy_). In dy that means:

dy[0:26, 200, 256] // set to 0
dy[27, 200, 256] = dy_ // copy the array dy_ from backpropagation into the last time step of dy

That means the first 27 * 200 * 256 numbers of dy are zero, because for this RNN we don't have labels for every iteration, and the last 1 * 200 * 256 numbers are what I get from the next layer (e.g. the fully connected layer). Then I use dw from cudnnRNNBackwardWeights to adjust the weights of the RNN.
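
A minimal sketch of this zero-filling (assuming float data and that dy_ is a device buffer holding the batchSize * hiddenSize gradient from the next layer):

// dy layout: [seqLength, batchSize, hiddenSize], fully packed
size_t stepElems  = (size_t)batchSize * hiddenSize;       // 200 * 256
size_t totalElems = (size_t)seqLength * stepElems;        // 28 * 200 * 256
cudaMemset(dy, 0, totalElems * sizeof(float));            // time steps 0..26 stay zero
cudaMemcpy(dy + (seqLength - 1) * stepElems,              // last time step gets dy_
           dy_, stepElems * sizeof(float), cudaMemcpyDeviceToDevice);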
It easily gets more than 90% accuracy.
You can try it further in your nets.

Thank you very much! I obviously used a wrong output dimension.

I corrected the output count of the RNN by multiplying it by the sequence length, so the next layer has that size. And now it works (14% classification error after 20000 training iterations; 37% after 1000) for batch size = 1.

For higher batch sizes I’m still unsure about how to transpose the batch data. I’ll try your zero-filling approach.

hi @Peter666,

Finally I now also did it for higher batch sizes:

I now transpose the x input (batch size 16; images 28x28 = width * height, where height = seqlen) into 28 seqlen blocks, each 28x16, as input for cudnnRNNForwardInference/cudnnRNNForwardTraining.

After that I transpose the y output back into a buffer with 16 batch blocks of 256*28 (hiddenSize = 256, seqlen = 28),
so that the following fully connected layer again gets batch blocks as input.
I also keep the original y output for the backward run of the RNN (so I don't need to “un”-transpose it again in the backward pass).

For cudnnRNNBackwardData I used the original y output of the RNN (28 seqlen blocks),
but I transpose dy in the same way the y output would be “un”-transposed.
I tried the zero-filling you proposed, but discarded it because of the results.
It's right that you have no label for each seqlen step, but you should have a label for each of the batch blocks.

Finally I also transpose the dx output of cudnnRNNBackwardData from seqlen blocks to batch blocks.
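
A minimal host-side sketch of this batch-block to seqlen-block reordering (a CUDA kernel doing the same index swap would be the faster variant; the function name is just illustrative):

#include <cstring>  // memcpy

// src: [batchSize][seqLength][featSize]  (batch blocks, as images/gradients arrive)
// dst: [seqLength][batchSize][featSize]  (seqlen blocks, as the cuDNN RNN calls expect)
void toSeqlenBlocks(const float *src, float *dst,
                    int batchSize, int seqLength, int featSize)
{
    for (int b = 0; b < batchSize; ++b)
        for (int t = 0; t < seqLength; ++t)
            memcpy(dst + ((size_t)t * batchSize + b) * featSize,
                   src + ((size_t)b * seqLength + t) * featSize,
                   featSize * sizeof(float));
}
// The inverse direction (back to batch blocks for the fully connected layer)
// just swaps the two index calculations.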

result with zero-filling seq 0…26:
Batch size: 16, iterations: 60000 Classification result: 24.76% error (used 10000 images) => 75.24% accuracy
Batch size: 200, iterations: 1000 Classification result: 13.78% error (used 10000 images) => 86.22% accuracy
Batch size: 200, iterations: 10000 Classification result: 13.20% error (used 10000 images) => 86.8% accuracy
Batch size: 200, iterations: 100000 Classification result: 4.57% error (used 10000 images) => 95.43% accuracy

result (without zero-filling):
Batch size: 16, iterations: 60000 Classification result: 1.77% error (used 10000 images) => 98.23% accuracy
Batch size: 200, iterations: 1000 Classification result: 7.02% error (used 10000 images) => 92.98% accuracy
Batch size: 200, iterations: 10000 Classification result: 1.58% error (used 10000 images) => 98.42% accuracy
Batch size: 200, iterations: 50000 Classification result: 1.13% error (used 10000 images) => 98.87% accuracy
Batch size: 200, iterations: 100000 Classification result: 1.17% error (used 10000 images) => 98.83% accuracy

Compared to the Test Acc: 0.972 reported at [url]https://medium.com/machine-learning-algorithms/mnist-using-recurrent-neural-network-2d070a5915a2[/url], some of my results are even better than that.

I’m not quite sure whether I’ve done it correctly.
@Peter666 How many iterations did you run to get to > 90% accuracy?

ADDED:
All the tests above were using TANH mode.

other results:
LSTM cy output used as cx input on next iteration
Batch size: 200, iterations: 1000 Classification result: 2.83% error (used 10000 images) => 97.17% accuracy
Batch size: 200, iterations: 10000 Classification result: 1.06% error (used 10000 images) => 98.94% accuracy
Batch size: 200, iterations: 20000 Classification result: 0.90% error (used 10000 images) => 99.1% accuracy
Batch size: 200, iterations: 50000 Classification result: 0.96% error (used 10000 images) => 99.04% accuracy

GRU:
Batch size: 200, iterations: 1000 Classification result: 3.34% error (used 10000 images) => 96.66% accuracy

Hello m1,

thanks so much! But I have some questions.

1. How did you adjust dy without zero-filling? Did you feed identical labels for every time step to cudnnRNNBackwardData?

2. May I ask about the structure of your net? Mine is not complex: only one RNN layer and a fully connected layer (RNN layer with hiddenSize 256, 1 layer, unidirectional).

With my own net I cannot get results as good as yours.
LSTM:
Batch size: 200, iterations: 1000 Classification => 91.07% accuracy
Batch size: 200, iterations: 10000 Classification => 94.46% accuracy
Batch size: 200, iterations: 20000 Classification => 94.95% accuracy
Batch size: 200, iterations: 50000 Classification => 95.68% accuracy

GRU:
Batch size: 200, iterations: 1000 Classification => 88.46% accuracy

Hi Peter666,

1. I do a transpose of dy so that it has seqlen blocks instead of batch blocks: in my net this gradient data in dy (coming from the next layer during backpropagation) is organized in batch blocks (as the input images are) and is then transposed:

Example with batch size 2, seqlen 3, width 4; image size: width * seqlen (4x3):

block 1:
aaaa this is one horizontal line in the image
bbbb
cccc

block 2:
1234
5678
9012

this is transposed to 3 seqlen blocks:
aaaa // first horizontal line from first batch block
1234 // first horizontal line from second batch block

bbbb // second …
5678

cccc // third …
9012

An example of this transpose with the text “The, brown, fox, is, quick” is here: [url]https://www.tensorflow.org/tutorials/sequences/recurrent[/url]

So each seqlen block contains the related data from all batch blocks. After running cudnnRNNBackwardData, that result is transposed back to how it was organized before, because all further layers (if any) also process data in batch blocks and not in seqlen blocks.

So basically, any data entering the RNN is guaranteed to be aligned in seqlen blocks, and any data leaving the RNN is guaranteed to be aligned in batch blocks (as all other layers process the data).

I'm not sure whether I understand correctly what you mean by “feed identical labels for every time step to cudnnRNNBackwardData”; I do not feed any labels into the RNN. The labels are processed through a softmax layer in my net. But the RNN processes, in each seqlen block, one horizontal line for each batch block, so I can simply fill up all the seqlen blocks with the data from the next layer.
With LSTM I additionally fill cx with the cy output from the previous iteration. I'm not sure about that in the case of TANH, GRU, RELU, so I posted this question: [url]https://devtalk.nvidia.com/default/topic/1042710/cudnn/questions-about-cxdesc-cx-cydesc-cy/[/url]

2. The structure of my net is also one RNN layer, a fully connected layer (10 outputs) and a softmax layer (hidden size 256, 1 layer, unidirectional). No other activation.

Basically I use the fully connected and softmax layers as defined here: [url]https://github.com/tbennun/cudnn-training[/url],
but with a different learning rate and a different weight update optimizer.

learning rate: 0.001 (decaying learning rate using Learning Rate Policy Gamma: 0.0001 Learning Rate Policy Power: 0.85)

Weight update: Adam optimizer, applied also to the RNN weights (in the same way as on the fully connected layer), beta1 = 0.9, beta2 = 0.999, no momentum.
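
Not my actual code, but a minimal sketch of such a per-element Adam step over the flat RNN weight buffer w and its gradient dw (from cudnnRNNBackwardWeights), with moment buffers m, v and 1-based step count t:

// Textbook Adam update, one thread per weight element.
__global__ void adamUpdate(float *w, const float *dw, float *m, float *v,
                           int n, float lr, float beta1, float beta2,
                           float eps, int t)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    m[i] = beta1 * m[i] + (1.0f - beta1) * dw[i];
    v[i] = beta2 * v[i] + (1.0f - beta2) * dw[i] * dw[i];
    float mhat = m[i] / (1.0f - powf(beta1, (float)t));   // bias correction
    float vhat = v[i] / (1.0f - powf(beta2, (float)t));
    w[i] -= lr * mhat / (sqrtf(vhat) + eps);
}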

Hello m1,

thank you so much. You explained it very clearly, and I now understand what you did with the transposition.

But I’m still not clear about what you did with dy. Let me start with batchSize = 1.
From cudnnRNNForwardTraining I get the output y in this form:

y = [28 1 256]
Questions:
1. Did you put the whole matrix y into the fully connected layer and then get the output of the fully connected layer in this form:

y_fullyConnected = [28 1 10]

Then through the softmax function I get the output (y_softmax) in the same form as y_fullyConnected.
2. Because y_label for batchSize 1 is a 1*10 array: did you copy the label values and expand y_label to the form [28 1 10]? Then in backpropagation you calculate the subtraction of y_softmax and y_label to get the difference, and for the difference at each time step you use the same y_label. Is my understanding right?

thanks in advance!

For batch size 1 I do not do any transpose at all, because it would simply produce the same data.

1. Yes, the full y [seqlen batchsize hiddensize] is my input into the fully connected layer, transposed to [batchsize seqlen hiddensize]. The fully connected layer output y has only 10 outputs for each batch block, so it is [1 10] for batch size 1. It also has only 1 channel and only 1 dimension.
See “FC2 layer” in [url]https://github.com/tbennun/cudnn-training/blob/master/lenet.cu[/url]

For softmax I also use the same format as for the fully connected layer, [1 10] for batch size 1. It's simply a pure CUDA kernel; see SoftmaxLossBackprop
in [url]https://github.com/tbennun/cudnn-training/blob/master/lenet.cu[/url]

2. No, not expanded over the seqlen: see the kernel SoftmaxLossBackprop.
Yes, for the difference at each time step I use the same y_label (because it belongs to the same image).
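
For reference, a sketch along the lines of that SoftmaxLossBackprop kernel (assuming diff already holds the softmax output and label holds one class index per batch entry; subtracting 1 at the true class turns it into the cross-entropy gradient, and the same label is reused for every time step of the same image):

__global__ void SoftmaxLossBackpropSketch(const float *label, int numClasses,
                                          int batchSize, float *diff)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= batchSize) return;
    int labelValue = (int)label[idx];               // true class of this batch entry
    diff[idx * numClasses + labelValue] -= 1.0f;    // softmax output minus one-hot label
}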

Hello m1,
finally I understand it. Thanks so much for your help!