Cross-thread pageable D2H copy appears to delay `cudaLaunchKernel` in another thread

Hi, I am trying to understand a concurrency behavior I observed with CUDA from Python/PyTorch.

I have a minimal setup with two Python threads in the same process:

  • Thread A repeatedly performs a GPU-to-CPU copy into pageable host memory
  • Thread B repeatedly launches a small CUDA kernel

In Nsight Systems, I observe that:

  • Thread A is inside a pageable D2H path (`copy_` to CPU)
  • At the same time, Thread B appears stuck in `cudaLaunchKernel`
  • The long duration shows up on the host-side API call `cudaLaunchKernel`, not on the GPU kernel execution itself

So it looks like a D2H copy issued from one host thread is delaying another host thread’s kernel launch at the CUDA runtime/API level.

My questions are:

  1. Is this expected behavior for pageable D2H copies?
  2. Can a pageable D2H copy introduce process-wide or context-wide serialization that delays `cudaLaunchKernel` from another host thread?
  3. Is this mainly due to:
    • internal CUDA runtime locks,
    • pageable-memory staging/synchronization,
    • or something else?

By the way, I have since found that the simplest fix is to use pinned memory so that the copy runs asynchronously. Below is a minimal PyTorch repro that also exports a trace:

import time
import threading

import torch
from torch.profiler import profile, ProfilerActivity


start_event = threading.Event()
stop_event = threading.Event()


def d2h_worker(src: torch.Tensor):
    """
    Thread A:
    Repeatedly perform GPU -> CPU copies into pageable host memory.
    """
    start_event.wait()
    iters = 0
    while not stop_event.is_set():
        dst = torch.empty_like(src, device="cpu")   # pageable host memory
        dst.copy_(src, non_blocking=False)          # blocking D2H copy
        iters += 1
    print(f"[d2h_worker] iterations = {iters}")


def launch_worker(x: torch.Tensor):
    """
    Thread B:
    Repeatedly launch a small CUDA kernel.
    """
    start_event.wait()
    iters = 0
    while not stop_event.is_set():
        y = torch.sin(x)                            # small CUDA kernel
        _ = y
        iters += 1
    print(f"[launch_worker] iterations = {iters}")


def main():
    assert torch.cuda.is_available(), "CUDA is required."

    print("PyTorch version:", torch.__version__)
    print("GPU:", torch.cuda.get_device_name(0))

    # Large tensor for D2H so the copy takes noticeable time.
    # 64M float32 elements ~= 256 MB.
    d2h_src = torch.empty(64 * 1024 * 1024, dtype=torch.float32, device="cuda")
    d2h_src.uniform_(0, 1)

    # Small tensor for kernel launch so the kernel itself is short.
    launch_x = torch.randn(1024, device="cuda")

    # Warmup.
    for _ in range(5):
        tmp = torch.empty_like(d2h_src, device="cpu")
        tmp.copy_(d2h_src, non_blocking=False)
        _ = torch.sin(launch_x)
    torch.cuda.synchronize()

    t1 = threading.Thread(target=d2h_worker, args=(d2h_src,), daemon=True)
    t2 = threading.Thread(target=launch_worker, args=(launch_x,), daemon=True)

    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=False,
        profile_memory=False,
        with_stack=True,
    ) as prof:
        t1.start()
        t2.start()

        start_event.set()
        time.sleep(3.0)
        stop_event.set()

        t1.join()
        t2.join()
        torch.cuda.synchronize()

    output_path = "repro_d2h_pageable_vs_launch.json"
    prof.export_chrome_trace(output_path)
    print(f"Trace saved to: {output_path}")
    print("Open it with chrome://tracing or Perfetto.")


if __name__ == "__main__":
    main()
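For completeness, here is a sketch of the pinned-memory workaround I mentioned above (`async_d2h` is just an illustrative helper name, and the claim about `cudaMemcpyAsync` reflects my understanding of why the host stops blocking):

```python
import torch


def async_d2h(src: torch.Tensor) -> torch.Tensor:
    """Copy a CUDA tensor to a pinned CPU buffer without blocking the host.

    With a pinned destination and non_blocking=True, PyTorch can issue a
    true cudaMemcpyAsync: the host thread returns immediately instead of
    sitting in the pageable staging path for the whole transfer.
    """
    dst = torch.empty_like(src, device="cpu", pin_memory=True)
    dst.copy_(src, non_blocking=True)  # enqueued on the current stream
    return dst                         # synchronize before reading dst!


if torch.cuda.is_available():
    src = torch.rand(64 * 1024 * 1024, device="cuda")  # ~256 MB
    out = async_d2h(src)
    torch.cuda.synchronize()  # wait only at the point the data is needed
    assert out.is_pinned()
```

In the threaded repro above, this replaces the `torch.empty_like` + blocking `copy_` pair in `d2h_worker`; reusing one pinned buffer across iterations also avoids repeated page-locking overhead.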

Thanks.

Pageable copies should generally be avoided for best throughput in various concurrency scenarios.

The basic notion that a CUDA API call may incur variable latency in a concurrency scenario is covered here.

That is almost certainly due to API locking behavior, as mentioned in the link above. As far as I know, the exact cause/dependency of that lock is not discoverable from NVIDIA documentation. You may also wish to scan the programming guide for mentions of “implicit synchronization,” such as the one in section 3.3.

I don’t know exactly which CUDA API PyTorch uses in this case, but we can assume it is something like `cudaMemcpyAsync`. If so, the link above includes the following:

Asynchronous

  1. For transfers between device memory and pageable host memory, the function might be synchronous with respect to host.

Since your question revolves around the host-side latency of the `cudaLaunchKernel` call, this seems directly connected.
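You can see this host-blocking behavior directly from PyTorch with a quick timing sketch (a minimal illustration, assuming a CUDA device; exact timings will vary by system):

```python
import time

import torch


def host_time_of_copy(dst: torch.Tensor, src: torch.Tensor) -> float:
    """Return how long the *host-side* copy_(non_blocking=True) call takes."""
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    dst.copy_(src, non_blocking=True)
    elapsed = time.perf_counter() - t0  # host returns early only if the
    torch.cuda.synchronize()            # copy was truly asynchronous
    return elapsed


if torch.cuda.is_available():
    src = torch.rand(64 * 1024 * 1024, device="cuda")  # ~256 MB
    pageable = torch.empty_like(src, device="cpu")
    pinned = torch.empty_like(src, device="cpu", pin_memory=True)

    # Expectation: the pageable call blocks the host for roughly the whole
    # transfer, while the pinned call returns almost immediately.
    print("pageable host time:", host_time_of_copy(pageable, src))
    print("pinned   host time:", host_time_of_copy(pinned, src))
```

The long host-side duration of the pageable call is the same window during which another thread's `cudaLaunchKernel` can appear to stall in the trace.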
