CTC (connectionist temporal classification) example code

Where can I find example code for the cudnnCTCLoss() API in cuDNN 7?

////////////////////////////////////////////////////////////////////////////////
// Example of using CUDNN implementation of CTC
// This example was written and tested against CUDNN v7.05
// If you have used other implementations of CTC loss and gradient calculations
// bear the following in mind:
// 1. Maximum number of labels in any mini-batch member = 256
// 2. The number of time steps (for a given minibatch sample) must be at least
// the number of labels plus the number of repeats
// 3. You must provide probabilities which have been calculated by softmax
// 4. The gradient values are with respect to pre-softmax activations and not
// with respect to the probabilities
////////////////////////////////////////////////////////////////////////////////
#include <iostream>
#include <vector>
#include <cstdlib>
#include <cuda_runtime.h>
#include "cudnn.h"

using namespace std;

#define CUDNN_CALL(func)                                             \
{                                                                    \
    auto e = (func);                                                 \
    if (e != CUDNN_STATUS_SUCCESS)                                   \
    {                                                                \
        cerr << "cuDNN error in " << __FILE__ << ":" << __LINE__;    \
        cerr << " : " << cudnnGetErrorString(e) << endl;             \
    }                                                                \
}

#define CUDA_CALL(func)                                              \
{                                                                    \
    auto e = (func);                                                 \
    if (e != cudaSuccess)                                            \
    {                                                                \
        cerr << "CUDA error in " << __FILE__ << ":" << __LINE__;     \
        cerr << " : " << cudaGetErrorString(e) << endl;              \
    }                                                                \
}

int main()
{
// Define some minibatch parameters (the values are purely exemplary)
const int B = 5; // Batch size
const int T = 100; // Number of time steps (must exceed L + R, where R is the number of repeats)
const int A = 10; // Alphabet size (number of different labels plus blank)
const int L = 4; // Number of labels (*** must not exceed 256)
const int D = 3; // Dimensions of tensor data

// Test data
// cuDNN CTC uses the label value 0 for "blank"
int simpleLabels[L] = { 1, 2, 3, 4 };
vector<int> flatLabels;
vector<int> labelLengths;
vector<int> inputLengths;

// Create label data and input lengths
for (int b = 0; b < B; b++)
{
for (int l = 0; l < L; l++)
flatLabels.push_back(simpleLabels[l]);
labelLengths.push_back(L);
inputLengths.push_back(T);
}

// Create cudnn library handle
cudnnHandle_t handle;
CUDNN_CALL(cudnnCreate(&handle));

// Create probability tensor descriptor and data
cudnnTensorDescriptor_t gpuProbsDesc;
float* gpuProbsData;
{
CUDNN_CALL(cudnnCreateTensorDescriptor(&gpuProbsDesc));
const int dims[D] = {T, B, A};
const int strides[D] = {B*A, A, 1};
CUDNN_CALL(cudnnSetTensorNdDescriptor(gpuProbsDesc, CUDNN_DATA_FLOAT, D, dims, strides));
CUDA_CALL(cudaMallocManaged(&gpuProbsData, sizeof(float) * A * T));

// Fill the probs tensor with values. The probabilities are laid out TxBxA.
// Here we fill in uniform probabilities, but these should be
// the values output from softmax().
// *** When you write your own code you must use the outputs from softmax()
// *** because the CTC loss function calculates gradients with respect to
// *** activations which it assumes were used to generate softmax probabilities
for (size_t t = 0; t < T; t++)
{
  for (size_t b = 0; b < B; b++)
  {
    for (size_t a = 0; a < A; a++)
      gpuProbsData[t*B*A + b*A + a] = 1.0/A;
  }
}

}

// Create gradient tensor descriptor and data
cudnnTensorDescriptor_t gpuGradsDesc;
float* gpuGradsData;
{
CUDNN_CALL(cudnnCreateTensorDescriptor(&gpuGradsDesc));
const int dims[D] = {T, B, A};
const int strides[D] = {B*A, A, 1};
CUDNN_CALL(cudnnSetTensorNdDescriptor(gpuGradsDesc, CUDNN_DATA_FLOAT, D, dims, strides));
CUDA_CALL(cudaMalloc(&gpuGradsData, sizeof(float) * T * B * A));
// Zero out gradients
CUDA_CALL(cudaMemset(gpuGradsData, 0, sizeof(float) * T * B * A));
}

// Create loss descriptor
cudnnCTCLossDescriptor_t ctcLossDesc;
CUDNN_CALL(cudnnCreateCTCLossDescriptor(&ctcLossDesc));
CUDNN_CALL(cudnnSetCTCLossDescriptor(ctcLossDesc, CUDNN_DATA_FLOAT));

// Get gpuWorkSpace size
size_t gpuWorkSpaceSize;
CUDNN_CALL(cudnnGetCTCLossWorkspaceSize(handle,
gpuProbsDesc,
gpuGradsDesc,
flatLabels.data(),
labelLengths.data(),
inputLengths.data(),
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
ctcLossDesc,
&gpuWorkSpaceSize));

// Allocate gpuWorkSpace
void *gpuWorkSpace;
CUDA_CALL(cudaMalloc(&gpuWorkSpace, gpuWorkSpaceSize));

// Allocate memory for losses
float *gpuLosses;
CUDA_CALL(cudaMalloc(&gpuLosses, sizeof(float) * B));

// Run CTC calculation
auto status = cudnnCTCLoss(handle,
gpuProbsDesc,
gpuProbsData,
flatLabels.data(),
labelLengths.data(),
inputLengths.data(),
gpuLosses,
gpuGradsDesc,
gpuGradsData,
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
ctcLossDesc,
gpuWorkSpace,
gpuWorkSpaceSize);

switch(status)
{
case CUDNN_STATUS_SUCCESS:
break;
case CUDNN_STATUS_BAD_PARAM:
if (L > 256)
cout << "Too many labels (limit is 256)" << endl;
else
cout << cudnnGetErrorString(status) << endl;
break;
case CUDNN_STATUS_NOT_SUPPORTED:
case CUDNN_STATUS_EXECUTION_FAILED:
cout << cudnnGetErrorString(status) << endl;
break;
}

// Optionally, copy losses back to host (you probably will want to do this for logging purposes)
float cpuLosses[B];
CUDA_CALL(cudaMemcpy(cpuLosses, gpuLosses, B * sizeof(float), cudaMemcpyDeviceToHost));
for (int b = 0; b < B; b++)
cout << "Loss[" << b << "] = " << cpuLosses[b] << endl;

// Optionally, copy grads back to host (you probably won’t want to do this)
// *** These values are with respect to pre-softmax activations
float* cpuGrads = (float*)calloc(T * B * A, sizeof(float));
CUDA_CALL(cudaMemcpy(cpuGrads, gpuGradsData, sizeof(float) * T * B * A, cudaMemcpyDeviceToHost));

// Free resources
cudaFree(gpuLosses);
cudaFree(gpuWorkSpace);
cudaFree(gpuProbsData);
cudaFree(gpuGradsData);
free(cpuGrads);
CUDNN_CALL(cudnnDestroyCTCLossDescriptor(ctcLossDesc));
CUDNN_CALL(cudnnDestroyTensorDescriptor(gpuProbsDesc));
CUDNN_CALL(cudnnDestroyTensorDescriptor(gpuGradsDesc));
CUDNN_CALL(cudnnDestroy(handle));

return 0;
}

Thank you!

You're welcome.

I tested your code with CUDA 9.0 (using cuDNN 7) on Win10 64-bit and got this output with batch size B = 5:

Loss[0] = inf
Loss[1] = inf
Loss[2] = inf
Loss[3] = inf
Loss[4] = inf

when using batch size B = 1:
Loss[0] = 203.984

then I changed the line "CUDA_CALL(cudaMallocManaged(&gpuProbsData, sizeof(float) * A * T));"
into "CUDA_CALL(cudaMallocManaged(&gpuProbsData, sizeof(float) * A * B * T));"

and output is:
Loss[0] = 203.984
Loss[1] = 203.984
Loss[2] = 203.984
Loss[3] = 203.984
Loss[4] = 203.984

is this the expected output?

Your code correction (adding the minibatch size) looks correct to me and my output values agree with yours.
Apologies for introducing an error when I pasted code into the web page.

@ccorfield007

I have another question about your code: I'm a bit unsure which input values would receive which data:

On https://github.com/stardut/ctc-ocr-tensorflow/blob/master/train.py I found Python scripts which use an RNN and a CTC loss.
The inputs are 28x28-pixel images. The test pictures are 224x28 (so a "word" is 8 handwritten MNIST digits).
The output format of the RNN is [batchsize, seqlen, hiddensize].

input_size = 28 : OK, it's the input for the RNN; known
num_class = 10 + 1 => A = alphabet size (number of different labels plus blank); right?
num_layers = 2 : OK, it's for the RNN; I know where to use it
seq_len = 28 => the seqlen from the RNN, right?
batch_size = 64 => B = 64; OK, known
num_units = 256 => num_units is the hidden size of the RNN, right?
word_size = 8 => target labels = L = number of labels, right?

step = 100000 => T = number of time steps, right?

The model is defined here: https://github.com/stardut/ctc-ocr-tensorflow/blob/master/model.py
RNN + CTC
Instead of the transposes I added a fully-connected layer with 10 outputs and a softmax activation after the RNN.
The alphabet size A is then the number of outputs of that layer (here 10), right?
And L is 1 when running with word_size = 1, right? And then T = 1 as well, for one 28x28 image?

thank you!

Hi m1:

From my reading of the code on GitHub, you are basically correct, with a couple of corrections.

The number of classes in CTC is the number of different symbols/labels plus one for blank. Different implementations of CTC choose different values for blank: sometimes it is zero and sometimes it is #classes-1 (the highest symbol/label value). CTC handles messy problems for which you would otherwise use NLL/CrossEntropy, except that you don't know in advance how the labels line up against the input data. In the case here, there are 8 labels, or digits, but the precise alignment to the image data is not known, even though each digit does live somewhere within its own 28x28-pixel image.

As it relates to the code on this web page, A = 11 (ten digits plus blank), B = batch_size = 64, L = number of digits = 8, and T = 28*8 = 224, not 100,000, which is the number of training iterations or steps. If you were decoding one digit at a time you would use L = 1, T = 28.
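To make that concrete, here is a rough sketch (untested, purely illustrative, not from the original example) of how the constants and label buffers in the example program above would change for the 8-digit MNIST case, assuming blank = 0 and digits 0-9 encoded as label values 1-10:

#include <vector>

// Illustrative constants for the 8-digit MNIST case (replacing the ones in the example)
const int B = 64;      // batch_size
const int A = 11;      // ten digits plus blank
const int L = 8;       // eight digit labels per image
const int T = 8 * 28;  // one time step per pixel column of a 224x28 image

// Flattened labels plus the per-sample length arrays cudnnCTCLoss() expects
std::vector<int> flatLabels;          // B*L entries, each digit d stored as d+1 (0 is blank)
std::vector<int> labelLengths(B, L);  // every sample has 8 labels
std::vector<int> inputLengths(B, T);  // every sample has 224 time steps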

In terms of overall model design, it is better to ask the author of the model(s) on GitHub, since this bulletin board is really for NVIDIA-related questions and some of your questions are about TensorFlow components. For understanding CTC you might want to check out Baidu's "Warp CTC" on GitHub; it has a bit of a tutorial. The CTC algorithm is not an easy one to understand; indeed, NVIDIA's own implementation suggests that whoever implemented it there was less than comfortable with it too.

Hi ccorfield007,

thank you very much for your answer. It is very helpful.

Since I don't really want to use TensorFlow, I am trying to use cuDNN with an RNN and CTC.
Do you know of any C++ example code where cuDNN's CTC is directly connected to an RNN layer? Unfortunately I haven't found any after many searches.
I will also go through the Python example and ask the author.

Thank you very much!

Hi m1:

Off the top of my head, I can't point you at straightforward C-code examples using cuDNN, RNNs and CTC. I would suggest searching GitHub, but you've obviously done that already. My advice for prototyping a network is to do it in a higher-level framework. Currently, TensorFlow has the most users, but (to me) it is a mess. PyTorch is newer and (to me) feels a lot cleaner. It conveniently runs inside the Eclipse/PyDev SDK and has numerous prebuilt widgets, including recurrent layers backed by cuDNN. It does not yet have visualization tools such as TensorBoard, but I have not found that to be a hindrance to progress. You are not losing much performance with the higher-level Python scripting, and if you want to migrate to C, NVIDIA's libraries and interfaces are well laid out and work well with debugging tools such as nsight or cuda-gdb (nsight is a GUI on top of cuda-gdb and other command line tools). The PyTorch 1.0 RC ("Release Candidate") includes an interface to NVIDIA's implementation of CTC, so the notes on this page may help you with a PyTorch implementation.

Happy coding!

Thank you very much for the information.

I ended up setting up TensorFlow 1.5.0 + OpenCV and can run the sample https://github.com/stardut/ctc-ocr-tensorflow.
Unfortunately it loses a lot of performance: TensorFlow on my system only uses the CPU, while my C++ code can use the GPU, so the Python version is much slower. But I think I figured out which model is used there, so I applied it to my C++ code.
I also did some tests; see https://github.com/stardut/ctc-ocr-tensorflow/issues/2. However, I'm still not sure whether I did it right.

With a learning rate of 0.001 the CTC loss fluctuates as it should, I think (yet it increases in the end, but that seems to be a tuning problem, right?).

But with a learning rate of 0.01 the (cuDNN) CTC loss turns to "inf" and MaxGradient is constantly 1.0 (with the same model + input data).
How can this happen?

Dear m1:

Due to time constraints I have not analyzed your comments on github or tried to run your code, so the following comments may not be quite on target.

Here are some general thoughts on how you can approach your network:

Your image – 8 handwritten digits – is a sequence which measures 8*28 x 28 (8 concatenated images of dimensions 28x28). You can think of this as a temporal sequence and use temporal, or 1D, convolutions.

Temporal convolutions are a little different from the usual 2D image-processing convolutions. With temporal convolutions, pick a kernel width (a small integer, such as 3, 5, or 7) and define the number of incoming channels to be 28. If you have not used temporal convolutions, setting the number of channels to the height of your images may seem strange, but the logic is that you will get the benefit of a "dense" (in TF terms) or "linear" (in PT terms) layer across all the rows in your images.

At the beginning of your network, place several temporal convolution layers, each followed by a ReLU (or your favorite non-linearity) and a normalization layer (batch is the most commonly used, but there are others). The first convolutional layer will have something like channelsIn=28, channelsOut=256 (or whatever internal dimension you are using for your network). The other convolutional layers can all be channelsIn=256, channelsOut=256.
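If you end up writing that part directly against cuDNN rather than TF/PT, the following is an untested sketch (the sizes, names and kernel width are just placeholder assumptions) of one way to describe a temporal convolution: treat the time axis as the W dimension of a 2D convolution with H = 1.

#include <cudnn.h>

// Describe a temporal convolution of width kernelW over a sequence of length T
// with 28 input channels and 256 output channels (error checking omitted).
void describeTemporalConv(int batch, int T)
{
    const int channelsIn  = 28;          // image height used as input channels
    const int channelsOut = 256;         // internal network dimension
    const int kernelW     = 5;           // temporal kernel width
    const int padW        = kernelW / 2; // "same" padding along time

    // Input: N x C x 1 x T (time mapped onto the W dimension)
    cudnnTensorDescriptor_t xDesc;
    cudnnCreateTensorDescriptor(&xDesc);
    cudnnSetTensor4dDescriptor(xDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
                               batch, channelsIn, 1, T);

    // Filters: K x C x 1 x kernelW
    cudnnFilterDescriptor_t wDesc;
    cudnnCreateFilterDescriptor(&wDesc);
    cudnnSetFilter4dDescriptor(wDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW,
                               channelsOut, channelsIn, 1, kernelW);

    // Convolution: no padding/stride along H, pad and stride 1 along W (time)
    cudnnConvolutionDescriptor_t convDesc;
    cudnnCreateConvolutionDescriptor(&convDesc);
    cudnnSetConvolution2dDescriptor(convDesc, 0, padW, 1, 1, 1, 1,
                                    CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT);

    // The output shape should come back as batch x channelsOut x 1 x T
    int n, c, h, w;
    cudnnGetConvolution2dForwardOutputDim(convDesc, xDesc, wDesc, &n, &c, &h, &w);

    // ... create the output tensor descriptor, pick an algorithm and call
    // cudnnConvolutionForward(), then ReLU and normalization as described above.

    cudnnDestroyConvolutionDescriptor(convDesc);
    cudnnDestroyFilterDescriptor(wDesc);
    cudnnDestroyTensorDescriptor(xDesc);
}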

After the convolutions, add your RNN layers. Make sure that you have rearranged your data dimensions appropriately – a classic mistake is that one type of layer wants BxTxF and another wants TxBxF and you forget to transpose the dimensions.

After your recurrent layers, place a dense/linear layer to project from the network’s internal dimension (e.g., 256) down to your output dimension = 11 = #digits + blank. Rearrange the data so that it is TxBxF, which is what the CTC loss function (usually) expects. Make sure that you know what value your CTC loss function uses for blank, it will either be zero or #labels-1. When you train a CTC network, the first class it learns to predict is blank, so you should find the network’s output for the blank class rises rapidly while the other classes stay low. Once the blank class has peaked, the outputs for the other classes will start to rise.
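The BxTxF to TxBxF rearrangement is just an index shuffle; here is a minimal host-side sketch (illustrative only, with A as the feature/alphabet dimension, matching the layout the example program above uses):

#include <vector>

// Rearrange a BxTxA buffer into the TxBxA layout expected by the CTC loss
std::vector<float> btaToTba(const std::vector<float>& bta, int B, int T, int A)
{
    std::vector<float> tba(bta.size());
    for (int b = 0; b < B; b++)
        for (int t = 0; t < T; t++)
            for (int a = 0; a < A; a++)
                // source index: b*T*A + t*A + a, destination index: t*B*A + b*A + a
                tba[t*B*A + b*A + a] = bta[b*T*A + t*A + a];
    return tba;
}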

After the dense/linear layer you may need a softmax layer. This depends on what toolkit you are using. If you are writing to NVIDIA’s CTC you definitely need a softmax layer (see below). If you are writing to Baidu’s WarpCTC you don’t. If you are using TF or PT, check the documentation.

In your C-code implementation, be warned that NVIDIA's CTC function is a little strange. The usual behavior for functions which calculate losses and gradients is to return the gradients with respect to the activations you gave the loss function. NVIDIA's CTC loss function is asymmetric: it takes softmax probabilities and returns gradients with respect to the pre-softmax activations. This means that your C-code needs to include a softmax function to generate the values for NVIDIA's CTC function, but you back-propagate the returned gradients through the layer just before the softmax. Baidu's Warp CTC is symmetric: it takes the activations from your net and returns gradients with respect to those activations (it does the softmax calculation internally).
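For reference, the softmax you need in front of cudnnCTCLoss() is an ordinary softmax over the alphabet dimension of the TxBxA buffer. Here is an untested host-side sketch (in a real network you would of course compute this on the GPU):

#include <algorithm>
#include <cmath>
#include <vector>

// In-place softmax over the alphabet dimension of a TxBxA activation buffer;
// the result is what you pass to cudnnCTCLoss() as the probability tensor,
// while the returned gradients apply to the pre-softmax activations.
void softmaxOverAlphabet(std::vector<float>& acts, int T, int B, int A)
{
    for (int t = 0; t < T; t++)
        for (int b = 0; b < B; b++)
        {
            float* row = &acts[t*B*A + b*A];
            float maxVal = row[0];
            for (int a = 1; a < A; a++) maxVal = std::max(maxVal, row[a]);
            float sum = 0.0f;
            for (int a = 0; a < A; a++) { row[a] = std::exp(row[a] - maxVal); sum += row[a]; }
            for (int a = 0; a < A; a++) row[a] /= sum;
        }
}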

When training networks, especially ones including recurrent layers or CTC, you may encounter the classic problems of vanishing or exploding gradients. Vanishing gradients can be addressed by using “Residual Net” techniques – sum the input of a layer with its output (where layer = transformation + nonlinearity + normalization). Exploding gradients can be handled by using gradient clipping – intercept the gradients before calling your optimizer and try one of the (many) gradient clipping techniques out there.
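Clipping by global norm is one of the simplest of those techniques; here is a rough sketch (maxNorm is an assumed hyperparameter you would have to tune):

#include <cmath>
#include <vector>

// Rescale a flat gradient buffer so its global L2 norm does not exceed maxNorm;
// call this just before the optimizer step.
void clipGradientsByGlobalNorm(std::vector<float>& grads, float maxNorm)
{
    double sumSq = 0.0;
    for (float g : grads) sumSq += double(g) * g;
    const double norm = std::sqrt(sumSq);
    if (norm > maxNorm)
    {
        const float scale = float(maxNorm / norm);
        for (float& g : grads) g *= scale;
    }
}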

Your bonus project will be to replace the recurrent layers with attention layers, which will yield equal or better results.

Hi ccorfield007,

“NVIDIA’s CTC loss function is asymmetric”
This was exactly what I did wrong. Thank you very much for your nice explanation!!!

So I now correctly use the fully-connected layer (28*11 = 308 outputs) with B=64, T=28, A=11, and then transpose to TxBxA for the CTC loss. This was the major problem, I think, and it is solved now! Great.
And Baidu's WarpCTC (added to my C++ project) never gave me "inf" results, and finally the CTC loss fluctuates at low values:

LearningRate:=0.0001 Adam Decaying Learning Rate GammaPolicy=0.0001 PowerPolicy=0.75
iteration= 17950 CTCLoss=4.301215 MaxGradient=0.969473
iteration= 29930 CTCLoss=2.660810 MaxGradient=0.937475
iteration= 38310 CTCLoss=4.197363 MaxGradient=0.901569
iteration= 54870 CTCLoss=2.681414 MaxGradient=0.952714
iteration= 66470 CTCLoss=3.438653 MaxGradient=0.922492
iteration= 72950 CTCLoss=2.838113 MaxGradient=0.899031
iteration= 84940 CTCLoss=3.332845 MaxGradient=0.934275
iteration= 90730 CTCLoss=3.135019 MaxGradient=0.946567
iteration= 99080 CTCLoss=3.443978 MaxGradient=0.945578

However, at inference the CTC score is always 3.668164 => chosen=4, which is still wrong.
But I think the training system itself is working correctly; I will discard this image-based sample for now.

I will try out audio input (then of course also with conv layers) and variable-length sequences, always updating the labelLengths array according to the label count. Then I will tweak the settings using vanishing/exploding-gradient techniques.

Your information is very helpful for me. So thank you again for all your help!!!

One other thing to check: you wrote T=28, but it might need to be T=28*8 to reflect the total number of "frames" or "time steps" across the width of the combined image: T = width = 8*28 = 224, F = #channels = #planes = height = 28. All the best with your project.