Processing an image with a CUDA kernel gives me a different result than a seemingly equivalent CPU function

Hello guys, I feel like I’m missing something quite basic here. I have image data on which I’m performing a depacking operation (the details are not that important). I have a CPU version of the processing and a (seemingly) equivalent GPU version. The two functions produce different results: the CPU result is correct and the GPU result is incorrect.

//GPU
__global__
void depacking_kernel(uint8_t *src, uint8_t *dst)
{
  unsigned int index = blockIdx.x * blockDim.x + threadIdx.x;
  unsigned int src_it = (index * 3);
  unsigned int dst_it = (index * 4);
  

  dst[dst_it]=src[src_it];
  dst[dst_it+1]=src[src_it+1] & 0b00001111;
  dst[dst_it+2]=((src[src_it+1] & 0b11110000) >> 4) | ((src[src_it+2] & 0b00001111) << 4);
  dst[dst_it+3]=((src[src_it+2] & 0b11110000) >> 4);
}

void depack(uint8_t *srcGpu, uint8_t *dstGpu,unsigned int src_length)
{
    int num_threads = 64;
    //from every three bytes, four bytes are made
    int num_blocks = (src_length/3)/num_threads;
    

    if (num_blocks>65535)
    {
      cerr << "too many blocks";
      exit(-1);
    }

    depacking_kernel<<<num_blocks,num_threads>>>(srcGpu,dstGpu);
    
    cudaDeviceSynchronize(); 

}

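//CPU
// forward declaration: depack_cpu_kernel is defined after depack_cpu
void depack_cpu_kernel(uint8_t* src, uint8_t* dst,int index);

void depack_cpu (uint8_t* src, unsigned int src_datalen, uint8_t* dst,unsigned int dst_datalen)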
{

  
  for (auto it=0; it<src_datalen/3;it++) 
  {
    depack_cpu_kernel(src, dst,it);
  }

  return;
}

void depack_cpu_kernel(uint8_t* src, uint8_t* dst,int index)
{
  unsigned int src_it = (index * 3);
  unsigned int dst_it = (index * 4);
  dst[dst_it]=src[src_it];
  dst[dst_it+1]=src[src_it+1] & 0b00001111;
  dst[dst_it+2]=((src[src_it+1] & 0b11110000) >> 4) | ((src[src_it+2] & 0b00001111) << 4);
  dst[dst_it+3]=((src[src_it+2] & 0b11110000) >> 4);

}

I’m at a loss as to why this is happening.

full_code.zip (2.6 MB)
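A side note on the launch arithmetic in depack(): (src_length/3)/num_threads rounds down, so any groups beyond the last full block of 64 threads would be skipped, and the kernel has no bounds check. For the 2056×1504 image used later in this thread the group count is 1,546,112 = 24,158 × 64, so this does not explain the mismatch. The host-side C++ sketch below shows the usual pattern (round the grid up, guard inside the per-thread body); num_groups, num_blocks_for and depack_group are hypothetical helper names, and the per-group body just copies the bit manipulation from the question.

```cpp
#include <cstdint>

// Each group of 3 packed source bytes expands to 4 destination bytes.
constexpr unsigned int num_groups(unsigned int src_length) { return src_length / 3; }

// Ceiling division: rounds the grid up so a partially filled last block
// still covers the trailing groups (plain division would drop them).
constexpr unsigned int num_blocks_for(unsigned int groups, unsigned int threads_per_block) {
    return (groups + threads_per_block - 1) / threads_per_block;
}

// Host-side emulation of one kernel thread, including the bounds check a
// kernel would need once the grid is rounded up (hypothetical helper).
void depack_group(const std::uint8_t* src, std::uint8_t* dst,
                  unsigned int index, unsigned int groups) {
    if (index >= groups) return;  // surplus threads in the last block do nothing
    const unsigned int s = index * 3;
    const unsigned int d = index * 4;
    dst[d]     = src[s];
    dst[d + 1] = src[s + 1] & 0b00001111;
    dst[d + 2] = ((src[s + 1] & 0b11110000) >> 4) | ((src[s + 2] & 0b00001111) << 4);
    dst[d + 3] = (src[s + 2] & 0b11110000) >> 4;
}
```

In the CUDA version the guard would compare the global thread index against the group count before touching src or dst.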

I don’t have OpenCV, and it should not be needed to test your claim if we focus specifically on the code you have excerpted. I modified your kernel_test.cpp as follows to do a byte-by-byte comparison of the buffers, flagging any differences. No difference is reported:

# cat kernel_test.cpp
#include <iostream>
#include <thread>
#include <chrono>
#include <time.h>
#include <stdexcept>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cuda_runtime.h>
#include "kernel.h"

int main()
{
  int ret;

  FILE *f_image;
  void *srcGpu;
  void *dstGpu;
  void *dstCpu;

  const int width=2056;
  const int height=1504;

  const int src_len=width*height*1.5;
  const int dst_len=width*height*2;

  f_image = fopen("image.dat","rb");
  if (f_image == NULL)
  {
   std::cerr << "Failed to open image.dat" << std::endl;
   return 1;
  }

  ret = cudaMallocManaged (&srcGpu, src_len, cudaMemAttachGlobal);
  if (ret != cudaSuccess)
  {
   std::cerr << "Failed to allocate cuda unified memory 1" << std::endl;
  }

  ret = cudaMallocManaged (&dstGpu, dst_len, cudaMemAttachGlobal);
  if (ret != cudaSuccess)
  {
   std::cerr << "Failed to allocate cuda unified memory 2" << std::endl;
  }

  dstCpu=malloc(dst_len);

  fread(srcGpu,src_len,1,f_image);

  depack_cpu((uint8_t*)srcGpu,src_len,(uint8_t*)dstCpu,dst_len);//CPU
  depack((uint8_t*)srcGpu,(uint8_t*)dstGpu,src_len);//GPU
  for (int i = 0; i < dst_len; i++)
  {
    // int copies so the byte values print as numbers rather than characters
    int gpu = ((unsigned char *)dstGpu)[i];
    int cpu = ((unsigned char *)dstCpu)[i];
    if (gpu != cpu)
    {
      std::cout << "Diff at: " << i << " was: " << gpu << " should be: " << cpu << std::endl;
      return 0;
    }
  }
  fclose(f_image);

  cudaFree(srcGpu);
  cudaFree(dstGpu);
  free(dstCpu);
}

# nvcc -o test  kernel_test.cpp kernel.cu
# compute-sanitizer ./test
========= COMPUTE-SANITIZER
========= ERROR SUMMARY: 0 errors
#

My conclusion would be that the problem lies elsewhere.

All right, thanks for the sanity check.

When I tried displaying the image, the same issue occurred with both OpenCV and SDL2. Can you think of any reason why using CUDA kernels and unified memory the way I’m using them could interfere with displaying data on the screen?

In your code (kernel_test.cpp) you had this:

  //difference
  unsigned long diff = 0;
  for (auto it = 0;it<dst_len;it++)
  {
    diff+=abs( ((uint8_t*)dstGpu)[it] - ((uint8_t*)dstCpu)[it] );
  }

  std::cout << "Difference between buffers: " << diff << std::endl;

I noticed in your posting you did not include that printout. Was the printout something like:

Difference between buffers: 0

in your case? If so, including it would have saved me some time. I had simply assumed it was some non-zero value.

It actually wasn’t zero; I had assumed the check was incorrect when I received your answer.

Running the code again and adding your check (with modification to display integer values) gives:

Difference between buffers: 48184800
Diff at: 522359 was: 251 should be: 11
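The “modification to display integer values” matters because std::ostream prints a uint8_t (an unsigned char) as a character, not as a number. A minimal sketch of a comparison helper that reports the first mismatch as integers and also accumulates the mismatch count and the sum of absolute differences used in the check above; CompareResult and compare_buffers are hypothetical names, not part of the posted code.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>

struct CompareResult {
    std::size_t mismatches = 0;       // number of differing bytes
    unsigned long long abs_diff = 0;  // sum of |gpu - cpu| over all bytes
    long long first = -1;             // index of the first mismatch, -1 if none
};

CompareResult compare_buffers(const std::uint8_t* gpu, const std::uint8_t* cpu, std::size_t len) {
    CompareResult r;
    for (std::size_t i = 0; i < len; ++i) {
        if (gpu[i] == cpu[i]) continue;
        if (r.first < 0) {
            r.first = static_cast<long long>(i);
            // unsigned casts: a raw uint8_t would be printed as a character
            std::cout << "Diff at: " << i << " was: " << static_cast<unsigned>(gpu[i])
                      << " should be: " << static_cast<unsigned>(cpu[i]) << '\n';
        }
        ++r.mismatches;
        r.abs_diff += static_cast<unsigned long long>(std::abs(int(gpu[i]) - int(cpu[i])));
    }
    return r;
}
```

Reporting the mismatch count alongside the summed difference makes it easier to tell a few badly wrong bytes from many slightly wrong ones.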

Well, there are some differences between what I am seeing and what you are seeing, which I cannot explain at the moment. But if there is a difference, that is obviously the area to focus on. Start debugging there.

The original system was:
Jetson Orin NX, JetPack 5.1.2 (L4T 35.4.1), CUDA 11.4
With the result:

Difference between buffers: 48184800

When I tried this on another system:
Jetson AGX Xavier, JetPack 4.6.3 (L4T 32.7.3), CUDA 10.2
the result was

Difference between buffers: 166339

which was good enough to be usable.

I assume you tested this on a conventional x86-64 system, which produced a difference of zero. Do you or anyone else at NVIDIA have any idea what’s causing this issue?

Given that the unpacking consists of pure re-shuffling of bits, there is no reason for the CPU version and the GPU version to differ. None. So anything other than zero mismatches indicates that the code is broken somewhere. The fact that different Jetson platforms show different results likewise tells us that something is very wrong.

What I would do here to debug is try to find the smallest image size that reliably exhibits mismatches. Also, make the launch configuration as small as possible; ideally down to one thread. Then start instrumenting the code with simple printf() calls. Inadvertent overlap of source and destination data? Accesses out of bounds (if they are of the off-by-one kind, automated checkers may not find them)? Eventually some data will jump out as obviously wrong, providing a starting point for deeper digging.
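Two of the questions above (overlap of source and destination data, out-of-bounds accesses) can be turned into host-side checks run before the launch. A sketch assuming the buffers are plain pointers with known byte lengths; required_dst_len and ranges_overlap are illustrative names, not part of the original code. For the sizes in kernel_test.cpp, required_dst_len(src_len) equals width*height*2, so dst_len is exactly large enough.

```cpp
#include <cstddef>
#include <cstdint>

// Destination bytes needed for a packed source buffer: 4 output bytes per 3 input bytes.
constexpr std::size_t required_dst_len(std::size_t src_len) { return (src_len / 3) * 4; }

// True if [a, a + a_len) and [b, b + b_len) share at least one byte.
// Addresses are compared as integers, because relational comparison of
// pointers into different allocations is unspecified in standard C++.
bool ranges_overlap(const void* a, std::size_t a_len, const void* b, std::size_t b_len) {
    const auto pa = reinterpret_cast<std::uintptr_t>(a);
    const auto pb = reinterpret_cast<std::uintptr_t>(b);
    return pa < pb + b_len && pb < pa + a_len;
}
```

Usage would be along the lines of asserting dst_len >= required_dst_len(src_len) and !ranges_overlap(srcGpu, src_len, dstGpu, dst_len) right before calling depack().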

I stared at the code shown above for some time, and nothing in terms of undefined C++ behavior or race conditions jumped out at me. I think there is a distinct possibility that the image gets corrupted outside the GPU processing code shown here, so keep an open mind when deciding where to instrument.

So, the conclusion:

The GPU code, when running above a certain number of blocks*threads (or perhaps once the kernel has been launched a certain number of times), does in fact produce different values. The right-shift operator behaved differently from the CPU code.

Loosely inspired by forum posts:

I modified the kernel as follows:

__global__
void depacking_kernel(uint8_t *src, uint8_t *dst)
{
  unsigned int index = blockIdx.x * blockDim.x + threadIdx.x;
  unsigned int src_it = (index * 3);
  unsigned int dst_it = (index * 4);

  dst[dst_it]=src[src_it];
  dst[dst_it+1]=src[src_it+1] & 0b00001111;
  dst[dst_it+2]=(uint8_t)((((uint16_t)(src[src_it+1] & 0b11110000)) >> 4) | (((uint16_t)(src[src_it+2] & 0b00001111)) << 4));
  dst[dst_it+3]=(uint8_t)(((uint16_t)(src[src_it+2] & 0b11110000)) >> 4);
}

This solved the issue.

Thank you @njuffa for suggesting “try to find the smallest image size that reliably exhibits mismatches”, which is how I managed to debug the issue.

This should not be the case. There is a caveat with right shifts in C++ (prior to C++20, from what I understand) when a signed integer of negative value is shifted to the right. When I stared at the code earlier, I noticed that src[] has an unsigned type that is narrower than int, that the standard promotions defined in the language therefore widen this data to int when evaluating the expression, and that the data shifted therefore is never negative. Therefore neither undefined nor implementation-defined behavior should occur. From the ISO C++11 standard (section 5.8 Shift operators):

The value of E1 >> E2 is E1 right-shifted E2 bit positions. If E1 has an unsigned type or if E1 has a signed type and a non-negative value, the value of the result is the integral part of the quotient of E1/2^E2. If E1 has a signed type and a negative value, the resulting value is implementation-defined.
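The promotion argument can be checked at compile time. A minimal C++ sketch (C++14 or later, for the binary literals); high_nibble is a hypothetical name for the CPU reference of the fourth output byte.

```cpp
#include <cstdint>
#include <type_traits>

// The integral promotions widen uint8_t to int before & and >> are evaluated.
static_assert(std::is_same<decltype(std::uint8_t{} & 0b11110000), int>::value,
              "uint8_t & int yields int");
static_assert(std::is_same<decltype(std::uint8_t{} >> 4), int>::value,
              "a shift has the promoted type of its left operand, here int");

// CPU reference for the fourth output byte: the promoted operand lies in
// [0, 240], so masked >> 4 is the well-defined quotient masked / 16.
std::uint8_t high_nibble(std::uint8_t b) {
    const int masked = b & 0b11110000;
    return static_cast<std::uint8_t>(masked >> 4);
}
```

Because the shifted value is a non-negative int, neither the implementation-defined case of the quoted rule nor any undefined behavior applies.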

What CUDA version do you use in your work, and what is the host platform?

But have you really found the reason for the error? A minimal example makes it easier to reason about and debug the actual cause. For example, your kernel still has three shifts and several logical operators. Can you reduce the kernel code further (it can calculate something simpler than your actual kernel) and still exhibit the mismatch with the CPU code? What happens if you increase the number of threads but have each thread perform the same operation? Can you somehow check for the error (e.g. by doing the shift as a division) and print debug output when the error happens?
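One way to follow the last suggestion is to cross-check the shift against an equivalent division and print only when the two disagree. The sketch below is plain host C++; the assumption is that in a CUDA build the same function would be declared __host__ __device__ and called from inside the kernel (device-side printf is supported). check_high_nibble is a hypothetical name.

```cpp
#include <cstdint>
#include <cstdio>

// For a non-negative value, x >> 4 and x / 16 must agree. Prints a message
// and returns false when the computed byte differs from the division result.
bool check_high_nibble(std::uint8_t src_byte, std::uint8_t shifted, unsigned int dst_index) {
    const unsigned int expected = static_cast<unsigned int>(src_byte) / 16u;
    if (shifted != expected) {
        std::printf("mismatch at %u: src %u gave %u, expected %u\n",
                    dst_index, static_cast<unsigned>(src_byte),
                    static_cast<unsigned>(shifted), expected);
        return false;
    }
    return true;
}
```

Printing only on mismatch keeps the output small even when many threads run.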

Jetson Orin NX (16 GB RAM), L4T 35.4.1, CUDA 11.4.315, compiled as C++17

With the kernel modified:

__global__
void depacking_kernel(uint8_t *src, uint8_t *dst)
{
  unsigned int index = blockIdx.x * blockDim.x + threadIdx.x;
  unsigned int src_it = (index * 3);
  unsigned int dst_it = (index * 4);
  
  dst[dst_it]=src[src_it];
  dst[dst_it+1]=src[src_it+1] & 0b00001111;
  dst[dst_it+2]= (uint8_t)((((uint16_t)(src[src_it+1] & 0b11110000)) >> 4) | (((uint16_t)(src[src_it+2] & 0b00001111)) << 4));
  
  //THIS WILL NOT WORK
  dst[dst_it+3]=(uint8_t) (((uint8_t)(src[src_it+2] & 0b11110000)) >> 4);
  
  //THIS WORKS
  //dst[dst_it+3]=(uint8_t) (((uint16_t)(src[src_it+2] & 0b11110000)) >> 4);
  if (dst_it+3==7379)
  {
    printf("From %d, masked into %d did %d should have done %d\n", src[src_it+2],
                                                                 src[src_it+2] & 0b11110000,
                                                                 dst[dst_it+3],
                                                                 (uint8_t) (((uint8_t)(src[src_it+2] & 0b11110000)) >> 4)
                                                                 /*for reasons unknown,this gives correct output when inside printf*/ );
  }
}

The output is :

From 170, masked into 160 did 250 should have done 10
Diff at: 7379 of 240
Difference between buffers: 240

Basically it’s attempting to shift 160, (msb) 10100000 (lsb), right by four places.
The correct result is 10, (msb) 00001010 (lsb),
but under certain circumstances it arrives at 250, (msb) 11111010 (lsb), which differs from the correct value by the 240 reported above.
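The wrong values follow a pattern: 250 is what 0xA0 becomes if its top bit is replicated into the high nibble during the shift, as if the byte had been treated as signed and shifted arithmetically. The same holds for the earlier “was: 251 should be: 11” (source nibble 0xB0). For any byte with bit 7 set the two results differ by exactly 240 (0xF0), and the earlier total of 48184800 is 240 × 200,770, which would be consistent with that pattern (an inference from the numbers, not something verified in this thread). A sketch using only unsigned operations, so nothing depends on implementation-defined behavior; the function names are illustrative.

```cpp
#include <cstdint>

// Correct result: logical shift of the unsigned byte (what the CPU computes).
std::uint8_t high_nibble_logical(std::uint8_t b) {
    return static_cast<std::uint8_t>(b >> 4);
}

// Emulates a sign-extending shift of the same byte, i.e. what one would get
// if the byte were treated as signed; written with unsigned operations only.
std::uint8_t high_nibble_sign_extended(std::uint8_t b) {
    std::uint8_t r = static_cast<std::uint8_t>(b >> 4);
    if (b & 0x80) r = static_cast<std::uint8_t>(r | 0xF0);  // replicate bit 7 into the top nibble
    return r;
}
```

Bytes below 0x80 give identical results either way, which is why only some output bytes were affected.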

full_code_feb_11.zip (4.9 KB)

I am looking at the code generated by nvcc from CUDA 11.4 for an sm_87 target. The PTX code looks different from what I expected, since the uint8_t data appears to be widened only to uint16_t rather than to int, as the literal constant 0b11110000 should default to type int as far as I know. However, the PTX code does not look wrong at first glance. I will look some more at this and then compare it to the SASS generated from the PTX.

Even if there is some compiler issue, the next question is does this work correctly in the current CUDA version 12.3?

ld.global.u8    %rs1, [%rd6];    // dst[dst_it]
st.global.u8    [%rd8], %rs1;

ld.global.u8    %rs2, [%rd6+1];  // dst[dst_it+1]
and.b16         %rs3, %rs2, 15;
st.global.u8    [%rd8+1], %rs3;

ld.global.u8    %rs4, [%rd6+1];  // dst[dst_it+2]
shr.u16         %rs5, %rs4, 4;
ld.global.u8    %rs6, [%rd6+2];
shl.b16         %rs7, %rs6, 4;
or.b16          %rs8, %rs7, %rs5;
st.global.u8    [%rd8+2], %rs8;

ld.global.u8    %rs9, [%rd6+2];  // dst[dst_it+3]
shr.u16         %rs10, %rs9, 4;
st.global.u8    [%rd8+3], %rs10;

CUDA 12 is planned for JetPack 6, which has a stable release planned for March 2024. I’m using JetPack 5.1.2, as JetPack 6 only has a developer preview, which I had issues flashing onto the Jetson. (JetPack is the OS package for the Jetsons.)

In other words I personally cannot test it right now, but I will be able in the future.

Given that the masking in the code is actually redundant with the subsequent shifts, as the result goes back into a uint8_t, the PTX generated by CUDA 11.4.3 seems correct. The generated machine code (SASS) looks correct, too. I am a little puzzled by the use of the .HI variant of SHF (funnel shift), but ptxas seems to use that pattern consistently across architectures and CUDA versions.

LDG.E.U8 R7, [R2.64]                 // src[src_it]
STG.E.U8 [R4.64], R7                 // dst[dst_it]

LDG.E.U8 R0, [R2.64+0x1]             // src[src_it+1]
LOP3.LUT R9, R0, 0xf, RZ, 0xc0, !PT  // src[src_it+1] & 15
STG.E.U8 [R4.64+0x1], R9             // dst[dst_it+1]

LDG.E.U8 R6, [R2.64+0x2]             // src[src_it+2]
LDG.E.U8 R0, [R2.64+0x1]             // src[src_it+1]
IMAD.SHL.U32 R11, R6, 0x10, RZ       // src[src_it+2] * 0x10
SHF.R.U32.HI R0, RZ, 0x4, R0         // src[src_it+1] >> 4
LOP3.LUT R11, R11, R0, RZ, 0xfc, !PT // (src[src_it+2] * 0x10) | (src[src_it+1] >> 4) 
STG.E.U8 [R4.64+0x2], R11            // dst[dst_it+2]

LDG.E.U8 R0, [R2.64+0x2]             // src[src_it+2]  
SHF.R.U32.HI R7, RZ, 0x4, R0         // src[src_it+2] >> 4
STG.E.U8 [R4.64+0x3], R7             // dst[dst_it+3]
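The remark that the masks are redundant can be verified exhaustively on the host. A sketch comparing the masked form from the original kernel with the mask-free form that matches the PTX/SASS above, over all 65,536 pairs of input bytes; pack_tail_masked, pack_tail_unmasked and formulations_agree are hypothetical names.

```cpp
#include <cstdint>

// Original formulation of output bytes 2 and 3, with the masks.
void pack_tail_masked(std::uint8_t b1, std::uint8_t b2, std::uint8_t& out2, std::uint8_t& out3) {
    out2 = static_cast<std::uint8_t>(((b1 & 0b11110000) >> 4) | ((b2 & 0b00001111) << 4));
    out3 = static_cast<std::uint8_t>((b2 & 0b11110000) >> 4);
}

// Mask-free formulation: >> 4 discards the low nibble anyway, and narrowing
// to uint8_t discards whatever << 4 moves above bit 7.
void pack_tail_unmasked(std::uint8_t b1, std::uint8_t b2, std::uint8_t& out2, std::uint8_t& out3) {
    out2 = static_cast<std::uint8_t>((b1 >> 4) | (b2 << 4));
    out3 = static_cast<std::uint8_t>(b2 >> 4);
}

// Exhaustively compares both formulations over all (b1, b2) pairs.
bool formulations_agree() {
    for (unsigned b1 = 0; b1 < 256; ++b1) {
        for (unsigned b2 = 0; b2 < 256; ++b2) {
            std::uint8_t m2, m3, u2, u3;
            pack_tail_masked(static_cast<std::uint8_t>(b1), static_cast<std::uint8_t>(b2), m2, m3);
            pack_tail_unmasked(static_cast<std::uint8_t>(b1), static_cast<std::uint8_t>(b2), u2, u3);
            if (m2 != u2 || m3 != u3) return false;
        }
    }
    return true;
}
```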

For what it is worth, the original kernel code and the modified kernel code (with the additional uint16_t casts) compile to exactly the same PTX code, as expected. In conclusion, I am unable to reproduce the issue seen with the original post.

Now, I ran with a toolchain hosted on an x86-64 system, so maybe there is an issue with the ARM hosted toolchain? I have not used NVIDIA’s embedded platforms. When developing for the Jetson Orin NX, does that involve cross-compilation, with the toolchain running on an x86-64 system, or does that use a toolchain natively hosted on the Jetson? I assume it is the latter.

It’s the latter.

Given that CUDA 12 is not available to you (yet), you may want to consider filing a bug with NVIDIA. There is nothing wrong with your original kernel code (all operations are well defined) and those additional uint16_t casts you used in the workaround should not make a difference. But since you observe that these casts do make a difference, that would appear to point at a compiler bug of some sort, which may or may not have been reported by someone else.

Could you replace the expression in the inner parentheses (i.e. the value before the shift) with a constant? E.g. directly put 0xA0 there with different types, e.g. int8_t, uint8_t, uint16_t, int32_t?
Which types give 10, and which give 250 after the shift?
Could you assign the result of the AND operation (with or without a cast to uint8_t) before the shift to an int32_t variable and print it? Is it 160 or is it -96? The printf with ’ %d’ could mask this. I would not trust the %d of printf as much as I would trust an assignment to int32_t (followed by the printf with %d).
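The type experiment can be run on the host first to establish the expected baseline. A sketch with widen_through and shift_through as hypothetical helpers: only a signed 8-bit intermediate turns 0xA0 into -96, and shifting that right by 4 gives -6, which narrows to 250. Converting 160 to int8_t and right-shifting a negative value are both implementation-defined before C++20; GCC and Clang produce -96 and an arithmetic shift, which is what the int8_t case assumes.

```cpp
#include <cstdint>

// Converts value to T, then assigns it to an int32_t: is it 160 or -96?
template <typename T>
std::int32_t widen_through(int value) {
    return static_cast<T>(value);
}

// Mirrors "(uint8_t)((T)(0xA0) >> 4)" for a chosen intermediate type T.
template <typename T>
std::uint8_t shift_through(int value) {
    return static_cast<std::uint8_t>(widen_through<T>(value) >> 4);
}
```

On the host, shift_through<std::uint8_t>, <std::uint16_t> and <std::int32_t> all yield 10 for 0xA0, and only <std::int8_t> yields 250; a GPU result of 250 for an unsigned intermediate would therefore point at the generated code rather than at the source.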

Can you share the (relevant) .ptx and SASS output of your toolchain for uint8_t and uint16_t?