Deadlock on CUDA kernel launch

We have observed an interesting behavior from the CUDA runtime. We are trying to run some computation on two concurrent streams. What we observe is that while some kernel X is executing on stream A, we cannot concurrently launch a new kernel (i.e. a kernel which hasn’t been run previously within the CUDA context) on stream B. In our particular application, this leads to a deadlock. A minimal script to reproduce this issue (tested on H100-NVL) is attached below. If the kernel warmup step (lines 223-233) is skipped, it leads to a deadlock. Is this issue something fundamental? Is there some way to work around it? Thanks!

#include <cuda_runtime.h>
#include <iostream>
#include <thread>
#include <mutex>
#include <chrono> // For sleep
#include <string> // For std::string

// Error checking macro
#define CHECK_CUDA_ERROR(val) check_cuda((val), #val, __FILE__, __LINE__)
inline void check_cuda(cudaError_t result, const char* func, const char* file, int line) {
    if (result != cudaSuccess) {
        std::cerr << "CUDA error at " << file << ":" << line << " code=" << static_cast<unsigned int>(result)
                  << " (" << cudaGetErrorName(result) << ") \"" << func << "\" \n"
                  << cudaGetErrorString(result) << std::endl;
        // cudaDeviceReset(); // Optional: Reset device state on fatal error
        exit(EXIT_FAILURE);
    }
}

// --- Synchronization Kernels ---

// Kernel to wait for the flag to become 1 (acquire lock)
__global__ void lockKernel(int* gpu_lock_flag) {
    // Only one thread needs to poll
    if (threadIdx.x == 0 && blockIdx.x == 0) {
         printf("Lock kernel waiting for flag at time %lld\n", clock64());
        // Use volatile to prevent compiler caching the flag value in registers
        // and ensure we re-read from global memory.
        volatile int* v_gpu_lock_flag = gpu_lock_flag;

        // Spin-wait until the flag is set to 1
        while (*v_gpu_lock_flag == 0) {
            // Yielding or adding delays here is complex on GPU and often
            // less efficient than simple spinning for short waits.
            // For very long waits, this GPU-side spin is inefficient.
        }
        printf("Lock kernel saw flag set (flag=%d) at time %lld\n", *v_gpu_lock_flag, clock64());

        // Re-arm (clear) the flag now that it has been observed set.
        *v_gpu_lock_flag = 0;

        // Memory fence: Ensures that the read of the flag happens, and is visible
        // globally, before any subsequent memory operations *in later kernels on the same stream*
        // are allowed to start. This guarantees that we see the sender's data *after*
        // seeing the flag change.
        __threadfence();
    }
}

// Kernel to set the flag to 1 (release lock)
__global__ void unlockKernel(int* gpu_lock_flag) {
    // Only one thread needs to set the flag
    if (threadIdx.x == 0 && blockIdx.x == 0) {
        // Memory fence: Ensures that all previous global memory writes
        // *from preceding kernels on the same stream* (i.e., the sender's work)
        // are visible globally before the flag is set.
        __threadfence();

        printf("Unlock kernel setting flag to 1 at time %lld\n", clock64());
        // Set the flag to signal completion/unlock
        *gpu_lock_flag = 1;

        // Optional: Another fence here ensures the write to the flag itself is globally visible.
        // __threadfence(); // Often implicitly handled by kernel launch semantics, but can be added for extra safety.
    }
}

// --- Computation Kernels (Simplified) ---

// Simple computation kernel - Sender
__global__ void computeKernelSender(int* data, int size, int value) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx == 0 && blockIdx.x == 0 && threadIdx.x == 0) { // Print only once
        printf("Sender compute kernel starting work (value: %d) at time %lld\n", value, clock64());
    }

    // --- Main Kernel Work ---
    if (idx < size) {
        // Simulate work
        for (int i = 0; i < 100000; i++) { // Reduced loop iterations for faster testing
            data[idx] += value; // Use atomicAdd if multiple threads could write same location
        }
    }
    // --- Synchronization logic removed ---

    __syncthreads(); // Wait for all threads in the block to finish *this* kernel's work
     if (idx == 0 && blockIdx.x == 0 && threadIdx.x == 0) { // Print only once
         printf("Sender compute kernel finished work (value: %d) at time %lld\n", value, clock64());
     }
}

// Simple computation kernel - Receiver
__global__ void computeKernelReceiver(int* data, int size, int value) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // --- Synchronization logic removed ---

    if (idx == 0 && blockIdx.x == 0 && threadIdx.x == 0) { // Print only once
        printf("Receiver compute kernel starting work (value: %d) at time %lld\n", value, clock64());
    }

    // --- Main Kernel Work ---
    if (idx < size) {
        // Simulate work
        for (int i = 0; i < 100000; i++) { // Reduced loop iterations
            data[idx] += value; // Use atomicAdd if multiple threads could write same location
        }
    }

     __syncthreads(); // Wait for all threads in the block to finish *this* kernel's work
     if (idx == 0 && blockIdx.x == 0 && threadIdx.x == 0) { // Print only once
         printf("Receiver compute kernel finished execution (value: %d) at time %lld\n", value, clock64());
     }
}


// Mutex for thread-safe printing
std::mutex printMutex;

// Thread-safe printing function
void safePrint(const std::string& message) {
    std::lock_guard<std::mutex> lock(printMutex);
    std::cout << message << std::endl;
}

// Thread function for the sender
void senderThread(cudaStream_t stream, int* d_lock_flag, int* d_shared_data) {
    safePrint("Sender thread starting.");
    CHECK_CUDA_ERROR(cudaSetDevice(0)); // Good practice to set device in each thread

    const int dataSize = 256; // Example data size
    const int blockSize = 128;
    int numBlocks = (dataSize + blockSize - 1) / blockSize;
    const int senderValue = 10;

    // Optional: Allocate specific data for this thread if needed, or use shared data
    // int* d_senderData;
    // CHECK_CUDA_ERROR(cudaMalloc(&d_senderData, dataSize * sizeof(int)));
    // CHECK_CUDA_ERROR(cudaMemsetAsync(d_senderData, 0, dataSize * sizeof(int), stream));

    safePrint("Sender thread: Sleeping for 10 seconds..."); // Shorter sleep
    std::this_thread::sleep_for(std::chrono::seconds(10));
    safePrint("Sender thread: Woke up.");

    safePrint("Sender thread: Launching sender compute kernel...");
    computeKernelSender<<<numBlocks, blockSize, 0, stream>>>(
        d_shared_data, dataSize, senderValue);
    CHECK_CUDA_ERROR(cudaGetLastError());

    safePrint("Sender thread: Launching unlock kernel...");
    unlockKernel<<<1, 1, 0, stream>>>(d_lock_flag); // Launch unlock *after* compute
    CHECK_CUDA_ERROR(cudaGetLastError());

    safePrint("Sender thread: Kernels launched.");

    // Wait for this stream's completion before exiting thread
    CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
    safePrint("Sender thread: Stream synchronized, exiting.");
    // CHECK_CUDA_ERROR(cudaFree(d_senderData)); // Free thread-specific data if allocated
}

// Thread function for the receiver
void receiverThread(cudaStream_t stream, int* d_lock_flag, int* d_shared_data) {
    safePrint("Receiver thread starting.");
    CHECK_CUDA_ERROR(cudaSetDevice(0)); // Good practice to set device in each thread

    const int dataSize = 256; // Example data size
    const int blockSize = 128;
    int numBlocks = (dataSize + blockSize - 1) / blockSize;
    const int receiverValue = 20;

    // Optional: Allocate specific data for this thread if needed, or use shared data
    // int* d_receiverData;
    // CHECK_CUDA_ERROR(cudaMalloc(&d_receiverData, dataSize * sizeof(int)));
    // CHECK_CUDA_ERROR(cudaMemsetAsync(d_receiverData, 0, dataSize * sizeof(int), stream));

    safePrint("Receiver thread: Launching lock kernel...");
    lockKernel<<<1, 1, 0, stream>>>(d_lock_flag); // Launch lock *before* compute
    CHECK_CUDA_ERROR(cudaGetLastError());

    safePrint("Receiver thread: Launching receiver compute kernel...");
    computeKernelReceiver<<<numBlocks, blockSize, 0, stream>>>(
        d_shared_data, dataSize, receiverValue);
    CHECK_CUDA_ERROR(cudaGetLastError());

    safePrint("Receiver thread: Kernels launched.");

    // Wait for this stream's completion before exiting thread
    CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
    safePrint("Receiver thread: Stream synchronized, exiting.");
    // CHECK_CUDA_ERROR(cudaFree(d_receiverData)); // Free thread-specific data if allocated
}

int main() {
    CHECK_CUDA_ERROR(cudaSetDevice(0)); // Set device for main thread

    cudaStream_t streamSender, streamReceiver;
    // Using non-blocking streams allows potential concurrency between streams if hardware supports it
    // and work is available on both. Here, dependency forces sequential execution for the lock/unlock part.
    CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&streamSender, cudaStreamNonBlocking));
    CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&streamReceiver, cudaStreamNonBlocking));

    // --- Allocate the shared GPU lock flag ---
    int* d_lock_flag = nullptr;
    CHECK_CUDA_ERROR(cudaMalloc(&d_lock_flag, sizeof(int)));
    // Initialize the flag to 0 (locked state) synchronously or on one stream before use.
    // Doing it synchronously here is simplest.
    int initial_flag_value = 0;
    CHECK_CUDA_ERROR(cudaMemcpy(d_lock_flag, &initial_flag_value, sizeof(int), cudaMemcpyHostToDevice));
    // Alternatively, use cudaMemsetAsync on one of the streams before launching kernels that depend on it.
    // CHECK_CUDA_ERROR(cudaMemsetAsync(d_lock_flag, 0, sizeof(int), streamReceiver)); // Init on receiver stream


    // --- Allocate shared data buffer (optional, for kernels to interact) ---
    const int dataSize = 256;
    int* d_shared_data = nullptr;
    CHECK_CUDA_ERROR(cudaMalloc(&d_shared_data, dataSize * sizeof(int)));
    CHECK_CUDA_ERROR(cudaMemset(d_shared_data, 0, dataSize * sizeof(int))); // Sync memset ok here


    std::cout << "Starting sender and receiver threads..." << std::endl;

    // warmup the lock and unlock
    unlockKernel<<<1, 1, 0, streamSender>>>(d_lock_flag);
    CHECK_CUDA_ERROR(cudaGetLastError());
    lockKernel<<<1, 1, 0, streamSender>>>(d_lock_flag);
    CHECK_CUDA_ERROR(cudaGetLastError());

    // warmup the sender and receiver
    computeKernelSender<<<1, 1, 0, streamSender>>>(d_shared_data, 1, 20);
    CHECK_CUDA_ERROR(cudaGetLastError());
    computeKernelReceiver<<<1, 1, 0, streamReceiver>>>(d_shared_data, 1, 20);
    CHECK_CUDA_ERROR(cudaGetLastError());

    // sync
    CHECK_CUDA_ERROR(cudaStreamSynchronize(streamSender));
    CHECK_CUDA_ERROR(cudaStreamSynchronize(streamReceiver));

    // Launch threads concurrently
    // Pass the streams, the lock flag, and any shared data to the threads
    std::thread t1(senderThread, streamSender, d_lock_flag, d_shared_data);
    std::thread t2(receiverThread, streamReceiver, d_lock_flag, d_shared_data);

    // Wait for threads to complete
    t1.join();
    t2.join();

    std::cout << "All threads completed." << std::endl;

    // (Optional) Verify results by copying d_shared_data back to host
    int* h_data = new int[dataSize];
    CHECK_CUDA_ERROR(cudaMemcpy(h_data, d_shared_data, dataSize * sizeof(int), cudaMemcpyDeviceToHost));
    std::cout << "Result data[0] = " << h_data[0] << " (Expected something like 10+20 = 30 if dataSize=1 and loops run once)" << std::endl;
    delete[] h_data;


    // Clean up
    CHECK_CUDA_ERROR(cudaFree(d_shared_data));
    CHECK_CUDA_ERROR(cudaFree(d_lock_flag)); // Free the lock flag
    CHECK_CUDA_ERROR(cudaStreamDestroy(streamSender));
    CHECK_CUDA_ERROR(cudaStreamDestroy(streamReceiver));

    std::cout << "Resources cleaned up." << std::endl;

    // Optional: Reset device state at the very end
    // CHECK_CUDA_ERROR(cudaDeviceReset());

    return 0;
}

It’s an example of lazy loading. It’s a documented mechanism, and you have some options for workarounds. A likely workaround would be to follow the suggestion for emulating eager loading with cudaFuncGetAttributes(). (Also, kernel-to-kernel communication that requires concurrency is a frowned-upon design practice, as CUDA does not guarantee concurrency.)
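
In rough terms, that amounts to touching each kernel once at initialization, something like this (a sketch against the kernel names and error macro from your posted code; cudaFuncGetAttributes() forces the kernel to be loaded without running it):

// Force-load each kernel under lazy loading, instead of doing warmup launches.
cudaFuncAttributes attr;
CHECK_CUDA_ERROR(cudaFuncGetAttributes(&attr, (const void*)lockKernel));
CHECK_CUDA_ERROR(cudaFuncGetAttributes(&attr, (const void*)unlockKernel));
CHECK_CUDA_ERROR(cudaFuncGetAttributes(&attr, (const void*)computeKernelSender));
CHECK_CUDA_ERROR(cudaFuncGetAttributes(&attr, (const void*)computeKernelReceiver));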

Hi @Robert_Crovella,

Thanks a lot for the response. This makes sense; I agree that this approach is fragile. Here is some more context around what I am trying to do: I want to work around kernel dispatch latency. I have two streams, A and B. Say stream A is executing batch 1; I want to start dispatching batch 2 on stream B before stream A completes execution/dispatch. In typical situations we would just use CUDA events for this (roughly the pattern sketched below): record an event at the end of stream A's dispatch, and then wait on that event on stream B. But since we want to overlap the dispatch on streams A and B, we would essentially need to wait on the event before it can be recorded. If you have any other suggestions to achieve a similar outcome, it would be really helpful. Thanks!
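
For reference, the usual event pattern I mean is roughly this (a minimal sketch with a placeholder batchKernel; it reuses the includes from the code above):

// Hypothetical helper showing the standard cross-stream event pattern.
__global__ void batchKernel(int* data) { if (threadIdx.x == 0 && blockIdx.x == 0) data[0] += 1; }

void orderedDispatch(cudaStream_t streamA, cudaStream_t streamB, int* d_data) {
    cudaEvent_t batch1Done;
    cudaEventCreateWithFlags(&batch1Done, cudaEventDisableTiming);

    // Stream A: dispatch batch 1, then record the event marking the end of its dispatch.
    batchKernel<<<1, 32, 0, streamA>>>(d_data);
    cudaEventRecord(batch1Done, streamA);

    // Stream B: wait on that event, then dispatch batch 2.
    // The wait is only meaningful once the record above has been issued, which is
    // why this pattern cannot overlap the two dispatch phases.
    cudaStreamWaitEvent(streamB, batch1Done, 0);
    batchKernel<<<1, 32, 0, streamB>>>(d_data);

    cudaEventDestroy(batch1Done);
}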

Do the kernels have an actual dependency, e.g. reading data from each other, or do you just want to optimize, when the kernels start for optimizing performance and resource usage?

@Curefab, yes, execution on the two streams has a logical dependency: stream B should not execute before stream A.

What about the following alternatives?

  • Fuse kernels A and B so that after each batch the further processing happens automatically (rough sketch after this list)
  • Invoke the kernels separately for each batch, so you can get events and stream-level synchronization after each batch
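
For the first alternative, a rough sketch (placeholder per-element work loops; this assumes stage B’s per-element work only depends on stage A’s result for the same element, otherwise a grid-wide sync is needed):

// One fused kernel does the "A" work and then the "B" work for a batch,
// so no cross-stream ordering is needed at all.
__global__ void fusedBatchKernel(int* data, int size, int valueA, int valueB) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        for (int i = 0; i < 100000; i++) data[idx] += valueA;  // stage "A" work
        for (int i = 0; i < 100000; i++) data[idx] += valueB;  // stage "B" work
    }
}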

What are the reasons (e.g. number of batches / latency overhead, mismatched thread counts between A and B) for not doing so?

I don’t want to change the kernel definitions for this use case; the mechanism is supposed to operate in a transparent manner.

@Robert_Crovella,

I tried an alternate approach using the cuStreamWriteValue32/cuStreamWaitValue32 APIs, which I believe are supposed to match stream event semantics. But I am still running into the same deadlock without warmup. I also ran with CUDA_MODULE_LOADING=EAGER. Thank you!

#include <cuda_runtime.h>
#include <cuda.h>
#include <iostream>
#include <thread>
#include <mutex>
#include <chrono>
#include <string>
#include <vector>
#include <iomanip>

// Error checking macros
#define CHECK_CUDA_ERROR(val) check_cuda((val), #val, __FILE__, __LINE__)
#define CHECK_CU_ERROR(val) check_cu((val), #val, __FILE__, __LINE__)

inline void check_cuda(cudaError_t result, const char* func, const char* file, int line) {
    if (result != cudaSuccess) {
        std::cerr << "CUDA error at " << file << ":" << line << " code=" << static_cast<unsigned int>(result)
                  << " (" << cudaGetErrorName(result) << ") \"" << func << "\" \n"
                  << cudaGetErrorString(result) << std::endl;
        exit(EXIT_FAILURE);
    }
}

inline void check_cu(CUresult result, const char* func, const char* file, int line) {
    if (result != CUDA_SUCCESS) {
        const char* errorStr;
        cuGetErrorString(result, &errorStr);
        std::cerr << "CUDA Driver error at " << file << ":" << line << " code=" << static_cast<unsigned int>(result)
                  << " (" << errorStr << ") \"" << func << "\" " << std::endl;
        exit(EXIT_FAILURE);
    }
}

// Simple computation kernel - Sender
__global__ void computeKernelSender(int* data, int size, int value, volatile int32_t* sync_flag) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx == 0 && blockIdx.x == 0 && threadIdx.x == 0) {
        printf("Sender: flag value before work = %u at time %lld\n", *sync_flag, clock64());
        printf("Sender: starting work (value: %d) at time %lld\n", value, clock64());
    }

    // Main kernel work
    if (idx < size) {
        for (int i = 0; i < 100000; i++) {
            data[idx] += value;
        }
    }

    __syncthreads();
    
    if (idx == 0 && blockIdx.x == 0 && threadIdx.x == 0) {
        printf("Sender: finished work at time %lld\n", clock64());
        // Note: We don't write the flag here - cuStreamWriteValue32 will do it
    }
}

// Simple computation kernel - Receiver
__global__ void computeKernelReceiver(int* data, int size, int value, volatile int32_t* sync_flag) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx == 0 && blockIdx.x == 0 && threadIdx.x == 0) {
        printf("Receiver: flag value = %u, starting work (value: %d) at time %lld\n", 
               *sync_flag, value, clock64());
    }

    // Main kernel work
    if (idx < size) {
        for (int i = 0; i < 100000; i++) {
            data[idx] += value;
        }
    }

    __syncthreads();
    
    if (idx == 0 && blockIdx.x == 0 && threadIdx.x == 0) {
        printf("Receiver: finished work at time %lld\n", clock64());
    }
}

// Mutex for thread-safe printing
std::mutex printMutex;

void safePrint(const std::string& message) {
    std::lock_guard<std::mutex> lock(printMutex);
    std::cout << message << std::endl;
}

// Thread function for the sender
void senderThread(cudaStream_t stream, int32_t* d_sync_flag, int* d_shared_data) {
    safePrint("Sender thread starting.");
    CHECK_CUDA_ERROR(cudaSetDevice(0));

    const int dataSize = 256;
    const int blockSize = 128;
    int numBlocks = (dataSize + blockSize - 1) / blockSize;
    const int senderValue = 10;

    safePrint("Sender thread: Sleeping for 10 seconds...");
    std::this_thread::sleep_for(std::chrono::seconds(10));
    safePrint("Sender thread: Woke up.");

    safePrint("Sender thread: Launching compute kernel...");
    computeKernelSender<<<numBlocks, blockSize, 0, stream>>>(
        d_shared_data, dataSize, senderValue, d_sync_flag);
    CHECK_CUDA_ERROR(cudaGetLastError());

    // Use cuStreamWriteValue32 to signal completion
    // First, get the driver API stream handle
    CUstream cuStream = reinterpret_cast<CUstream>(stream);

    safePrint("Sender thread: Writing value 1 to sync flag after kernel completion...");
    CHECK_CU_ERROR(cuStreamWriteValue32(cuStream, (CUdeviceptr)d_sync_flag, (cuuint32_t)1, CU_STREAM_WRITE_VALUE_DEFAULT));

    safePrint("Sender thread: Operations enqueued.");

    // Wait for stream completion
    CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
    safePrint("Sender thread: Stream synchronized, exiting.");
}

// Thread function for the receiver
void receiverThread(cudaStream_t stream, int32_t* d_sync_flag, int* d_shared_data) {
    safePrint("Receiver thread starting.");
    CHECK_CUDA_ERROR(cudaSetDevice(0));

    const int dataSize = 256;
    const int blockSize = 128;
    int numBlocks = (dataSize + blockSize - 1) / blockSize;
    const int receiverValue = 20;

    // Get driver API stream handle
    CUstream cuStream = reinterpret_cast<CUstream>(stream);

    // Wait for flag to become 1
    safePrint("Receiver thread: Setting up wait for sync flag to become 1...");
    CHECK_CU_ERROR(cuStreamWaitValue32(cuStream, (CUdeviceptr)d_sync_flag, (cuuint32_t)1, CU_STREAM_WAIT_VALUE_EQ));

    safePrint("Receiver thread: Wait enqueued, launching compute kernel after wait...");
    computeKernelReceiver<<<numBlocks, blockSize, 0, stream>>>(
        d_shared_data, dataSize, receiverValue, d_sync_flag);
    CHECK_CUDA_ERROR(cudaGetLastError());

    safePrint("Receiver thread: Operations enqueued.");

    // Wait for stream completion
    CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
    safePrint("Receiver thread: Stream synchronized, exiting.");
}


int main() {
    // Initialize CUDA
    CHECK_CU_ERROR(cuInit(0));
    CHECK_CUDA_ERROR(cudaSetDevice(0));

    // Query device properties
    cudaDeviceProp prop;
    CHECK_CUDA_ERROR(cudaGetDeviceProperties(&prop, 0));
    std::cout << "Device: " << prop.name << std::endl;
    std::cout << "Compute Capability: " << prop.major << "." << prop.minor << std::endl;

    // Create streams
    cudaStream_t streamSender, streamReceiver;
    // Using non-blocking streams allows potential concurrency between streams if hardware supports it
    // and work is available on both. Here, the wait/write-value operations create the cross-stream dependency.
    CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&streamSender, cudaStreamNonBlocking));
    CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&streamReceiver, cudaStreamNonBlocking));

    // Allocate synchronization flag
    int32_t* d_sync_flag;
    CHECK_CUDA_ERROR(cudaMalloc(&d_sync_flag, sizeof(int32_t)));
    CHECK_CUDA_ERROR(cudaMemset(d_sync_flag, 0, sizeof(int32_t)));

    // Allocate shared data buffer
    const int dataSize = 256;
    int* d_shared_data = nullptr;
    CHECK_CUDA_ERROR(cudaMalloc(&d_shared_data, dataSize * sizeof(int)));
    CHECK_CUDA_ERROR(cudaMemset(d_shared_data, 0, dataSize * sizeof(int)));

    // warmup the kernels
    computeKernelSender<<<1, 1, 0, streamSender>>>(d_shared_data, 0, 10, d_sync_flag);
    computeKernelReceiver<<<1, 1, 0, streamReceiver>>>(d_shared_data, 0, 20, d_sync_flag);
    CHECK_CUDA_ERROR(cudaStreamSynchronize(streamSender));
    CHECK_CUDA_ERROR(cudaStreamSynchronize(streamReceiver));
    CHECK_CUDA_ERROR(cudaGetLastError());

    // Launch threads
    std::thread t1(senderThread, streamSender, d_sync_flag, d_shared_data);
    std::thread t2(receiverThread, streamReceiver, d_sync_flag, d_shared_data);

    // Wait for threads to complete
    t1.join();
    t2.join();

    std::cout << "All threads completed." << std::endl;

    // Verify results
    int* h_data = new int[dataSize];
    CHECK_CUDA_ERROR(cudaMemcpy(h_data, d_shared_data, dataSize * sizeof(int), cudaMemcpyDeviceToHost));
    std::cout << "Result data[0] = " << h_data[0] << " (Expected: 10*100000 + 20*100000 = 3000000)" << std::endl;
    delete[] h_data;

    // Clean up
    CHECK_CUDA_ERROR(cudaFree(d_shared_data));
    CHECK_CUDA_ERROR(cudaFree(d_sync_flag));
    CHECK_CUDA_ERROR(cudaStreamDestroy(streamSender));
    CHECK_CUDA_ERROR(cudaStreamDestroy(streamReceiver));

    std::cout << "\nAll tests completed successfully!" << std::endl;
    return 0;
}

@Robert_Crovella, by the way, I do notice a slight difference in behavior between eager and lazy loading, but both end up in a deadlock. With lazy loading, the sender thread gets stuck at “Sender thread: Launching compute kernel…” (so the CPU-side dispatch is stuck). With the eager flag, the CPU-side dispatch proceeds fine and we see “Sender thread: Operations enqueued.”, but neither of the stream sync ops completes.

As you do not get any guarantee that the kernels are actually running at the same time (even if it works now, it can fail at any invocation in production), you should at least code defensively with timeouts.
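
Something along these lines on the host side, for example (a rough sketch using the <chrono>/<thread> includes from the code above; if the timeout fires because a kernel is spinning forever, recovery typically still means tearing down the context):

// Poll the stream against a deadline instead of blocking in cudaStreamSynchronize().
bool syncWithTimeout(cudaStream_t stream, std::chrono::seconds timeout) {
    auto deadline = std::chrono::steady_clock::now() + timeout;
    while (std::chrono::steady_clock::now() < deadline) {
        cudaError_t status = cudaStreamQuery(stream);   // non-blocking completion check
        if (status == cudaSuccess) return true;          // all work on the stream finished
        if (status != cudaErrorNotReady) return false;   // a real error occurred
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
    return false;  // timed out
}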

There are ways to get both kernels running at the same time by starting a cooperative kernel and letting some blocks run function A and some blocks run function B.
With it, you also get official grid synchronization.
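
A rough sketch of that idea (names are placeholders; the whole grid must be co-resident on the device for a cooperative launch, and grid sync typically needs -rdc=true when compiling):

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// One cooperative launch: the first half of the blocks does the "sender" work,
// the second half does the "receiver" work, with a guaranteed grid-wide sync in between.
__global__ void fusedCoopKernel(int* data, int size, int senderValue, int receiverValue) {
    cg::grid_group grid = cg::this_grid();
    int halfBlocks = gridDim.x / 2;
    bool isSender = (blockIdx.x < halfBlocks);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (isSender && idx < size) {
        for (int i = 0; i < 100000; i++) data[idx] += senderValue;
    }
    grid.sync();  // official grid-wide synchronization (cooperative launch only)
    if (!isSender) {
        int ridx = idx - halfBlocks * blockDim.x;
        if (ridx < size) {
            for (int i = 0; i < 100000; i++) data[ridx] += receiverValue;
        }
    }
}

// Cooperative kernels must be launched via cudaLaunchCooperativeKernel.
void launchFused(int* d_data, int size, cudaStream_t stream) {
    int blockSize = 128;
    int numBlocks = 2 * ((size + blockSize - 1) / blockSize);  // half the blocks per role
    int senderValue = 10, receiverValue = 20;
    void* args[] = { &d_data, &size, &senderValue, &receiverValue };
    cudaLaunchCooperativeKernel((void*)fusedCoopKernel, dim3(numBlocks), dim3(blockSize),
                                args, 0, stream);
}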

I haven’t studied your case carefully, but programmatic dependent launch may possibly be of interest.
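
Roughly, along the lines of the programmatic dependent launch example in the programming guide (a sketch only; it needs a Hopper-class GPU and a recent CUDA toolkit, and notably both kernels are launched into the same stream):

// Primary kernel signals when its dependent results are ready.
__global__ void primaryKernel(int* data) {
    if (threadIdx.x == 0 && blockIdx.x == 0) data[0] += 1;   // work the secondary depends on
    cudaTriggerProgrammaticLaunchCompletion();               // allow the secondary to start launching
    // ... work that may overlap with the secondary kernel ...
}

// Secondary kernel blocks until the primary's results are visible.
__global__ void secondaryKernel(int* data) {
    // ... independent preamble work can run here ...
    cudaGridDependencySynchronize();
    if (threadIdx.x == 0 && blockIdx.x == 0) data[0] += 1;
}

void launchDependent(int* d_data, cudaStream_t stream) {
    primaryKernel<<<1, 32, 0, stream>>>(d_data);

    cudaLaunchAttribute attribute[1];
    attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attribute[0].val.programmaticStreamSerializationAllowed = 1;

    cudaLaunchConfig_t config = {};
    config.gridDim = dim3(1);
    config.blockDim = dim3(32);
    config.stream = stream;          // same stream as the primary kernel
    config.attrs = attribute;
    config.numAttrs = 1;
    cudaLaunchKernelEx(&config, secondaryKernel, d_data);
}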

Regarding your most recent code posting, when I run that on my L4 GPU on CUDA 12.2, I don’t witness any hang. The output looks like this:

# nvcc -o t396 t396.cu -lcuda
# ./t396
Device: NVIDIA L4
Compute Capability: 8.9
Sender: flag value before work = 0 at time 649837039
Sender: starting work (value: 10) at time 649922296
Sender: finished work at time 649991584
Receiver: flag value = 0, starting work (value: 20) at time 650060657
Receiver: finished work at time 650149442
Sender thread starting.
Receiver thread starting.
Receiver thread: Setting up wait for sync flag to become 1...
Sender thread: Sleeping for 10 seconds...
Receiver thread: Wait enqueued, launching compute kernel after wait...
Receiver thread: Operations enqueued.
Sender thread: Woke up.
Sender thread: Launching compute kernel...
Sender thread: Writing value 1 to sync flag after kernel completion...
Sender thread: Operations enqueued.
Sender: flag value before work = 0 at time 1863473608
Sender: starting work (value: 10) at time 1863547016
Sender: finished work at time 1863616607
Sender thread: Stream synchronized, exiting.
Receiver: flag value = 1, starting work (value: 20) at time 1863685243
Receiver: finished work at time 1863772214
Receiver thread: Stream synchronized, exiting.
All threads completed.
Result data[0] = 3000000 (Expected: 10*100000 + 20*100000 = 3000000)

All tests completed successfully!
#

The code I have shared does the warmup by default; comment out the warmup lines to reproduce the deadlock. I am looking for solutions where I don’t need to modify the existing kernels. Thanks!

You can achieve an effect similar to warm-up using the method I already suggested, without doing a warm-up run and without modification to your kernel code.
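
For example, something like this at startup, in place of the warmup launches (a sketch against the kernels in your second posting):

cudaFuncAttributes attr;
CHECK_CUDA_ERROR(cudaFuncGetAttributes(&attr, (const void*)computeKernelSender));   // forces module load
CHECK_CUDA_ERROR(cudaFuncGetAttributes(&attr, (const void*)computeKernelReceiver));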