We have observed an interesting behavior of the CUDA runtime. We are trying to run some computation on two concurrent streams. What we observe is that while some kernel X is executing on stream A, we cannot concurrently run a new kernel (i.e. a kernel that has not been run previously within the CUDA context) on stream B. In our particular application, this leads to a deadlock. A minimal script that reproduces the issue (tested on H100-NVL) is attached below. If the kernel warmup step in main() is skipped, the program deadlocks. Is this behavior something fundamental? Is there some way to work around it? Thanks!
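For reference, the workaround we currently rely on is a one-time warmup launch of every kernel before the two threads start issuing concurrent work. A minimal sketch of that pattern (these are the same launches that appear in main() in the script below, with the stream and pointer names taken from there):

    // Warmup: run each kernel once in this context before the concurrent phase.
    unlockKernel<<<1, 1, 0, streamSender>>>(d_lock_flag);   // sets the flag so lockKernel does not spin
    lockKernel<<<1, 1, 0, streamSender>>>(d_lock_flag);     // consumes the flag again
    computeKernelSender<<<1, 1, 0, streamSender>>>(d_shared_data, 1, 20);
    computeKernelReceiver<<<1, 1, 0, streamReceiver>>>(d_shared_data, 1, 20);
    CHECK_CUDA_ERROR(cudaStreamSynchronize(streamSender));
    CHECK_CUDA_ERROR(cudaStreamSynchronize(streamReceiver));

The full reproduction script: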
#include <cuda_runtime.h>
#include <iostream>
#include <thread>
#include <mutex>
#include <chrono> // For sleep
#include <string> // For std::string
// Error checking macro
#define CHECK_CUDA_ERROR(val) check_cuda((val), #val, __FILE__, __LINE__)
inline void check_cuda(cudaError_t result, const char* func, const char* file, int line) {
    if (result != cudaSuccess) {
        std::cerr << "CUDA error at " << file << ":" << line << " code=" << static_cast<unsigned int>(result)
                  << " (" << cudaGetErrorName(result) << ") \"" << func << "\" \n"
                  << cudaGetErrorString(result) << std::endl;
        // cudaDeviceReset(); // Optional: Reset device state on fatal error
        exit(EXIT_FAILURE);
    }
}
// --- Synchronization Kernels ---
// Kernel to wait for the flag to become 1 (acquire lock)
__global__ void lockKernel(int* gpu_lock_flag) {
    // Only one thread needs to poll
    if (threadIdx.x == 0 && blockIdx.x == 0) {
        printf("Lock kernel waiting for flag at time %lld\n", clock64());
        // Use volatile to prevent compiler caching the flag value in registers
        // and ensure we re-read from global memory.
        volatile int* v_gpu_lock_flag = gpu_lock_flag;
        // Spin-wait until the flag is set to 1
        while (*v_gpu_lock_flag == 0) {
            // Yielding or adding delays here is complex on GPU and often
            // less efficient than simple spinning for short waits.
            // For very long waits, this GPU-side spin is inefficient.
        }
        // Consume the flag (reset to 0) so the lock is held again for the next round.
        *v_gpu_lock_flag = 0;
        printf("Lock kernel saw flag set and reset it (flag=%d) at time %lld\n", *v_gpu_lock_flag, clock64());
        // Memory fence: Ensures that the read of the flag happens, and is visible
        // globally, before any subsequent memory operations *in later kernels on the same stream*
        // are allowed to start. This guarantees that we see the sender's data *after*
        // seeing the flag change.
        __threadfence();
    }
}
// Kernel to set the flag to 1 (release lock)
__global__ void unlockKernel(int* gpu_lock_flag) {
    // Only one thread needs to set the flag
    if (threadIdx.x == 0 && blockIdx.x == 0) {
        // Memory fence: Ensures that all previous global memory writes
        // *from preceding kernels on the same stream* (i.e., the sender's work)
        // are visible globally before the flag is set.
        __threadfence();
        printf("Unlock kernel setting flag to 1 at time %lld\n", clock64());
        // Set the flag to signal completion/unlock
        *gpu_lock_flag = 1;
        // Optional: Another fence here ensures the write to the flag itself is globally visible.
        // __threadfence(); // Often implicitly handled by kernel launch semantics, but can be added for extra safety.
    }
}
// --- Computation Kernels (Simplified) ---
// Simple computation kernel - Sender
__global__ void computeKernelSender(int* data, int size, int value) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx == 0 && blockIdx.x == 0 && threadIdx.x == 0) { // Print only once
        printf("Sender compute kernel starting work (value: %d) at time %lld\n", value, clock64());
    }
    // --- Main Kernel Work ---
    if (idx < size) {
        // Simulate work
        for (int i = 0; i < 100000; i++) { // Reduced loop iterations for faster testing
            data[idx] += value; // Use atomicAdd if multiple threads could write same location
        }
    }
    // --- Synchronization logic removed ---
    __syncthreads(); // Wait for all threads in the block to finish *this* kernel's work
    if (idx == 0 && blockIdx.x == 0 && threadIdx.x == 0) { // Print only once
        printf("Sender compute kernel finished work (value: %d) at time %lld\n", value, clock64());
    }
}
// Simple computation kernel - Receiver
__global__ void computeKernelReceiver(int* data, int size, int value) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // --- Synchronization logic removed ---
    if (idx == 0 && blockIdx.x == 0 && threadIdx.x == 0) { // Print only once
        printf("Receiver compute kernel starting work (value: %d) at time %lld\n", value, clock64());
    }
    // --- Main Kernel Work ---
    if (idx < size) {
        // Simulate work
        for (int i = 0; i < 100000; i++) { // Reduced loop iterations
            data[idx] += value; // Use atomicAdd if multiple threads could write same location
        }
    }
    __syncthreads(); // Wait for all threads in the block to finish *this* kernel's work
    if (idx == 0 && blockIdx.x == 0 && threadIdx.x == 0) { // Print only once
        printf("Receiver compute kernel finished execution (value: %d) at time %lld\n", value, clock64());
    }
}
// Mutex for thread-safe printing
std::mutex printMutex;
// Thread-safe printing function
void safePrint(const std::string& message) {
    std::lock_guard<std::mutex> lock(printMutex);
    std::cout << message << std::endl;
}
// Thread function for the sender
void senderThread(cudaStream_t stream, int* d_lock_flag, int* d_shared_data) {
safePrint("Sender thread starting.");
CHECK_CUDA_ERROR(cudaSetDevice(0)); // Good practice to set device in each thread
const int dataSize = 256; // Example data size
const int blockSize = 128;
int numBlocks = (dataSize + blockSize - 1) / blockSize;
const int senderValue = 10;
// Optional: Allocate specific data for this thread if needed, or use shared data
// int* d_senderData;
// CHECK_CUDA_ERROR(cudaMalloc(&d_senderData, dataSize * sizeof(int)));
// CHECK_CUDA_ERROR(cudaMemsetAsync(d_senderData, 0, dataSize * sizeof(int), stream));
safePrint("Sender thread: Sleeping for 10 seconds..."); // Shorter sleep
std::this_thread::sleep_for(std::chrono::seconds(10));
safePrint("Sender thread: Woke up.");
safePrint("Sender thread: Launching sender compute kernel...");
computeKernelSender<<<numBlocks, blockSize, 0, stream>>>(
d_shared_data, dataSize, senderValue);
CHECK_CUDA_ERROR(cudaGetLastError());
safePrint("Sender thread: Launching unlock kernel...");
unlockKernel<<<1, 1, 0, stream>>>(d_lock_flag); // Launch unlock *after* compute
CHECK_CUDA_ERROR(cudaGetLastError());
safePrint("Sender thread: Kernels launched.");
// Wait for this stream's completion before exiting thread
CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
safePrint("Sender thread: Stream synchronized, exiting.");
// CHECK_CUDA_ERROR(cudaFree(d_senderData)); // Free thread-specific data if allocated
}
// Thread function for the receiver
void receiverThread(cudaStream_t stream, int* d_lock_flag, int* d_shared_data) {
safePrint("Receiver thread starting.");
CHECK_CUDA_ERROR(cudaSetDevice(0)); // Good practice to set device in each thread
const int dataSize = 256; // Example data size
const int blockSize = 128;
int numBlocks = (dataSize + blockSize - 1) / blockSize;
const int receiverValue = 20;
// Optional: Allocate specific data for this thread if needed, or use shared data
// int* d_receiverData;
// CHECK_CUDA_ERROR(cudaMalloc(&d_receiverData, dataSize * sizeof(int)));
// CHECK_CUDA_ERROR(cudaMemsetAsync(d_receiverData, 0, dataSize * sizeof(int), stream));
safePrint("Receiver thread: Launching lock kernel...");
lockKernel<<<1, 1, 0, stream>>>(d_lock_flag); // Launch lock *before* compute
CHECK_CUDA_ERROR(cudaGetLastError());
safePrint("Receiver thread: Launching receiver compute kernel...");
computeKernelReceiver<<<numBlocks, blockSize, 0, stream>>>(
d_shared_data, dataSize, receiverValue);
CHECK_CUDA_ERROR(cudaGetLastError());
safePrint("Receiver thread: Kernels launched.");
// Wait for this stream's completion before exiting thread
CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
safePrint("Receiver thread: Stream synchronized, exiting.");
// CHECK_CUDA_ERROR(cudaFree(d_receiverData)); // Free thread-specific data if allocated
}
int main() {
    CHECK_CUDA_ERROR(cudaSetDevice(0)); // Set device for main thread
    cudaStream_t streamSender, streamReceiver;
    // Using non-blocking streams allows potential concurrency between streams if hardware supports it
    // and work is available on both. Here, dependency forces sequential execution for the lock/unlock part.
    CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&streamSender, cudaStreamNonBlocking));
    CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&streamReceiver, cudaStreamNonBlocking));
    // --- Allocate the shared GPU lock flag ---
    int* d_lock_flag = nullptr;
    CHECK_CUDA_ERROR(cudaMalloc(&d_lock_flag, sizeof(int)));
    // Initialize the flag to 0 (locked state) synchronously or on one stream before use.
    // Doing it synchronously here is simplest.
    int initial_flag_value = 0;
    CHECK_CUDA_ERROR(cudaMemcpy(d_lock_flag, &initial_flag_value, sizeof(int), cudaMemcpyHostToDevice));
    // Alternatively, use cudaMemsetAsync on one of the streams before launching kernels that depend on it.
    // CHECK_CUDA_ERROR(cudaMemsetAsync(d_lock_flag, 0, sizeof(int), streamReceiver)); // Init on receiver stream
    // --- Allocate shared data buffer (optional, for kernels to interact) ---
    const int dataSize = 256;
    int* d_shared_data = nullptr;
    CHECK_CUDA_ERROR(cudaMalloc(&d_shared_data, dataSize * sizeof(int)));
    CHECK_CUDA_ERROR(cudaMemset(d_shared_data, 0, dataSize * sizeof(int))); // Sync memset ok here
    std::cout << "Starting sender and receiver threads..." << std::endl;
    // --- Kernel warmup (the workaround) ---
    // Launch each kernel once so it has already been executed within this CUDA context
    // before the concurrent phase. Skipping this block reproduces the deadlock described above.
    // Warm up the lock and unlock kernels (unlock first so the lock kernel does not spin).
    unlockKernel<<<1, 1, 0, streamSender>>>(d_lock_flag);
    CHECK_CUDA_ERROR(cudaGetLastError());
    lockKernel<<<1, 1, 0, streamSender>>>(d_lock_flag);
    CHECK_CUDA_ERROR(cudaGetLastError());
    // Warm up the sender and receiver compute kernels.
    computeKernelSender<<<1, 1, 0, streamSender>>>(d_shared_data, 1, 20);
    CHECK_CUDA_ERROR(cudaGetLastError());
    computeKernelReceiver<<<1, 1, 0, streamReceiver>>>(d_shared_data, 1, 20);
    CHECK_CUDA_ERROR(cudaGetLastError());
    // Wait for the warmup launches to complete before starting the worker threads.
    CHECK_CUDA_ERROR(cudaStreamSynchronize(streamSender));
    CHECK_CUDA_ERROR(cudaStreamSynchronize(streamReceiver));
    // Launch threads concurrently
    // Pass the streams, the lock flag, and any shared data to the threads
    std::thread t1(senderThread, streamSender, d_lock_flag, d_shared_data);
    std::thread t2(receiverThread, streamReceiver, d_lock_flag, d_shared_data);
    // Wait for threads to complete
    t1.join();
    t2.join();
    std::cout << "All threads completed." << std::endl;
    // (Optional) Verify results by copying d_shared_data back to host
    int* h_data = new int[dataSize];
    CHECK_CUDA_ERROR(cudaMemcpy(h_data, d_shared_data, dataSize * sizeof(int), cudaMemcpyDeviceToHost));
    std::cout << "Result data[0] = " << h_data[0] << " (Expected something like 10+20 = 30 if dataSize=1 and loops run once)" << std::endl;
    delete[] h_data;
    // Clean up
    CHECK_CUDA_ERROR(cudaFree(d_shared_data));
    CHECK_CUDA_ERROR(cudaFree(d_lock_flag)); // Free the lock flag
    CHECK_CUDA_ERROR(cudaStreamDestroy(streamSender));
    CHECK_CUDA_ERROR(cudaStreamDestroy(streamReceiver));
    std::cout << "Resources cleaned up." << std::endl;
    // Optional: Reset device state at the very end
    // CHECK_CUDA_ERROR(cudaDeviceReset());
    return 0;
}
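For anyone trying to reproduce: a plain nvcc invocation along these lines should be enough to build and run the script (the file name is just a placeholder; sm_90 matches the H100-NVL we tested on):

    nvcc -O2 -arch=sm_90 -o stream_deadlock_repro stream_deadlock_repro.cu
    ./stream_deadlock_repro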