CUPTI usage appears to occasionally cause CUDA graphs with conditional nodes to segfault during instantiation

Originally reported to PyTorch; more detail here: Segfault in Torch profiler when CUDA Graph Conditional Nodes are used · Issue #134308 · pytorch/pytorch · GitHub

Steps to reproduce (a minimal Python sketch of the same pattern follows the list):

  1. Profile something using CUPTI
  2. Outside the section being profiled, attempt to instantiate a CUDA graph that contains conditional nodes. Without conditional nodes I was unable to reproduce the issue.
  3. About 0.5% of the time, this will segfault. If CUPTI has not been used yet in this process, I have never seen it segfault despite running it ~millions of times.
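
Conceptually, the trigger boils down to the pattern below. This is only a sketch: it assumes the compiled repro_kernels extension from the full repro further down, whose cudaGraphDemo() builds, instantiates, and launches the conditional-node graph.

from torch.profiler import profile, ProfilerActivity
import repro_kernels  # compiled extension from the full repro below

# Before CUPTI has ever been loaded in this process, instantiation never
# segfaults for me.
for _ in range(1000):
    repro_kernels.cudaGraphDemo()

# Step 1: profile something; the CUDA activity pulls in libcupti.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]):
    pass

# Steps 2-3: outside the profiled region, instantiate the conditional-node
# graph again; a small fraction of these calls now segfault.
for _ in range(1000):
    repro_kernels.cudaGraphDemo()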

Here is the GDB backtrace, implicating CUPTI. However, note that profiling was not occurring at the time of this segfault.

(gdb) bt
#0  0x00007fffb1c3188e in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#1  0x00007fff4e513ee9 in ?? () from /usr/local/cuda/targets/x86_64-linux/lib/libcupti.so.12
#2  0x00007fff4e50d3a8 in ?? () from /usr/local/cuda/targets/x86_64-linux/lib/libcupti.so.12
#3  0x00007fffb1a03ead in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007fffb19233b4 in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#5  0x00007fffb1abfbcf in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#6  0x00007fff50221825 in ?? () from /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12
#7  0x00007fff5026b48e in cudaGraphInstantiate () from /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12
#8  0x00007fffba9abc17 in cudaGraphInstantiate(CUgraphExec_st**, CUgraph_st*, CUgraphNode_st**, char*, unsigned long) ()
   from /home/ubuntu/cudagraphissue/repro_kernels.cpython-310-x86_64-linux-gnu.so

This is the code that segfaults rarely. When I run it in a loop 100 times before profiling, it never segfaults; when I run it in a loop 100 times after (or during) profiling, it segfaults about half the time. Hence I roughly estimate this code's chance of segfaulting at 1/200 per call.

__global__ void loopKernel(cudaGraphConditionalHandle handle) {
	// simulate work
	__nanosleep(1000000);
	// end the loop
	cudaGraphSetConditional(handle, 0);
}

void cudaGraphDemo() {
	auto stream = at::cuda::getCurrentCUDAStream().stream();
	cudaGraph_t graph;
	AT_CUDA_CHECK(cudaGraphCreate(&graph, 0));
	cudaGraphConditionalHandle loopHandle;
	AT_CUDA_CHECK(cudaGraphConditionalHandleCreate(&loopHandle, graph, true, cudaGraphCondAssignDefault));
	cudaGraphNodeParams whileNodeParams = { cudaGraphNodeTypeConditional };
	whileNodeParams.conditional.handle = loopHandle;
	whileNodeParams.conditional.type = cudaGraphCondTypeWhile;
	whileNodeParams.conditional.size = 1;
	cudaGraphNode_t loopNode;
	AT_CUDA_CHECK(cudaGraphAddNode(&loopNode, graph, nullptr, 0, &whileNodeParams));
	cudaStream_t captureStream;
	AT_CUDA_CHECK(cudaStreamCreate(&captureStream));
	AT_CUDA_CHECK(cudaStreamBeginCaptureToGraph(captureStream, whileNodeParams.conditional.phGraph_out[0], nullptr, nullptr, 0, cudaStreamCaptureModeThreadLocal));
	loopKernel<<<1, 1, 0, captureStream>>>(loopHandle);
	AT_CUDA_CHECK(cudaPeekAtLastError());
	AT_CUDA_CHECK(cudaStreamEndCapture(captureStream, nullptr));
	AT_CUDA_CHECK(cudaStreamDestroy(captureStream));
	cudaGraphExec_t instance;
	AT_CUDA_CHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
	AT_CUDA_CHECK(cudaGraphLaunch(instance, stream));
	AT_CUDA_CHECK(cudaStreamSynchronize(stream));
	AT_CUDA_CHECK(cudaGraphExecDestroy(instance));
	AT_CUDA_CHECK(cudaGraphDestroy(graph));
}
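
As a sanity check on that estimate: inverting the observed roughly-50% failure rate over 100 calls, and assuming each call fails independently, lands in the same ballpark. A quick check of the arithmetic in Python:

# P(at least one segfault in 100 independent calls) = 1 - (1 - p)**100 ≈ 0.5
p = 1 - 0.5 ** (1 / 100)
print(f"per-call segfault probability ≈ {p:.2%}")  # ≈ 0.69%, i.e. roughly 1 in 145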

Hi, @leif5

Sorry for the issue you're encountering.
We tried executing "python repro.py", but no segmentation fault occurred.

(torch-env) test@gqa-r21-6:~/ticket_tracking/conditional_forum_305571/seg_code$ python repro.py
Warming up
Profiling
Loop inside profiler 0
Loop inside profiler 1

Loop inside profiler 98
Loop inside profiler 99
(torch-env) test@gqa-r21-6:~/ticket_tracking/conditional_forum_305571/seg_code$ python --version
Python 3.10.0
(torch-env) test@gqa-r21-6:~/ticket_tracking/conditional_forum_305571/seg_code$ python -c "import torch; print(torch.__version__)"
2.4.1+cu124

Anything else needed to repro?
Also, could you tell us which CUDA/driver/GPU you are using?

Thank you for the reply and thank you for attempting to reproduce!

All version info can be found in the GitHub issue under the Versions section; I've pasted it here:

Collecting environment information...
PyTorch version: 2.3.1
Is debug build: False
CUDA used to build PyTorch: 12.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.8.0-40-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA H100 PCIe
Nvidia driver version: 550.90.07
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 57 bits virtual
Byte Order:                           Little Endian
CPU(s):                               26
On-line CPU(s) list:                  0-25
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Xeon(R) Platinum 8480+
CPU family:                           6
Model:                                143
Thread(s) per core:                   1
Core(s) per socket:                   1
Socket(s):                            26
Stepping:                             8
BogoMIPS:                             4000.00
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk avx512_fp16 arch_capabilities
Virtualization:                       VT-x
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            832 KiB (26 instances)
L1i cache:                            832 KiB (26 instances)
L2 cache:                             104 MiB (26 instances)
L3 cache:                             416 MiB (26 instances)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-25
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Unknown: No mitigations
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] flake8==4.0.1
[pip3] numpy==1.25.2
[pip3] torch==2.3.1
[pip3] torchvision==0.18.1
[pip3] triton==3.0.0
[conda] Could not collect

But I've just gone through and redone the repro for you. Here are the full reproduction steps; I did not run any commands other than these.

  1. Launch a gpu_1x_a100_sxm4 instance on Lambda Cloud (only because they are currently out of H100 capacity).
  2. nvidia-smi says it only supports CUDA 12.2, so we need to upgrade the driver: sudo apt update && sudo apt dist-upgrade -y && sudo reboot (this step takes quite a while)
  3. Reconnect after the reboot and confirm that nvidia-smi now reports driver version 550.90.07 and support for CUDA 12.4.
  4. pip install torch --upgrade --force-reinstall --index-url https://download.pytorch.org/whl/cu124 (replaces the default torch with one built against CUDA 12.4, because conditional nodes otherwise cause a linker error). python --version should report Python 3.10.12 and python -c "import torch; print(torch.__version__)" should print 2.4.1+cu124
  5. pip install pybind11
  6. Write out the files: nano setup.py, nano repro.py, nano repro.cu, nano test.bash (contents of each are below)
  7. pip install -e .
  8. bash test.bash
  9. On my machine, the segfault rate is quite high.
$ bash test.bash
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2229 Segmentation fault      (core dumped) python repro.py
Run #1
Successes: 0
Failures: 1
Failure rate: 100.00%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2271 Segmentation fault      (core dumped) python repro.py
Run #2
Successes: 0
Failures: 2
Failure rate: 100.00%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2313 Segmentation fault      (core dumped) python repro.py
Run #3
Successes: 0
Failures: 3
Failure rate: 100.00%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
Run #4
Successes: 1
Failures: 3
Failure rate: 75.00%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2393 Segmentation fault      (core dumped) python repro.py
Run #5
Successes: 1
Failures: 4
Failure rate: 80.00%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2435 Segmentation fault      (core dumped) python repro.py
Run #6
Successes: 1
Failures: 5
Failure rate: 83.33%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2477 Segmentation fault      (core dumped) python repro.py
Run #7
Successes: 1
Failures: 6
Failure rate: 85.71%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2519 Segmentation fault      (core dumped) python repro.py
Run #8
Successes: 1
Failures: 7
Failure rate: 87.50%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2561 Segmentation fault      (core dumped) python repro.py
Run #9
Successes: 1
Failures: 8
Failure rate: 88.89%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
Run #10
Successes: 2
Failures: 8
Failure rate: 80.00%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2641 Segmentation fault      (core dumped) python repro.py
Run #11
Successes: 2
Failures: 9
Failure rate: 81.82%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2683 Segmentation fault      (core dumped) python repro.py
Run #12
Successes: 2
Failures: 10
Failure rate: 83.33%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2725 Segmentation fault      (core dumped) python repro.py
Run #13
Successes: 2
Failures: 11
Failure rate: 84.62%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2767 Segmentation fault      (core dumped) python repro.py
Run #14
Successes: 2
Failures: 12
Failure rate: 85.71%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2809 Segmentation fault      (core dumped) python repro.py
Run #15
Successes: 2
Failures: 13
Failure rate: 86.67%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2851 Segmentation fault      (core dumped) python repro.py
Run #16
Successes: 2
Failures: 14
Failure rate: 87.50%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2893 Segmentation fault      (core dumped) python repro.py
Run #17
Successes: 2
Failures: 15
Failure rate: 88.24%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2936 Segmentation fault      (core dumped) python repro.py
Run #18
Successes: 2
Failures: 16
Failure rate: 88.89%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  2978 Segmentation fault      (core dumped) python repro.py
Run #19
Successes: 2
Failures: 17
Failure rate: 89.47%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28:  3020 Segmentation fault      (core dumped) python repro.py
Run #20
Successes: 2
Failures: 18
Failure rate: 90.00%
------------------------

Contents of each file should be:

setup.py:

from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from setuptools import setup
import pybind11

setup(
    name="repro",
    ext_modules=[CUDAExtension(
        "repro_kernels",
        ["repro.cu"],
        include_dirs=[pybind11.get_include()],
    )],
    cmdclass={"build_ext": BuildExtension},
)

repro.py:

from torch.profiler import profile, record_function, ProfilerActivity
import repro_kernels

print("Warming up")
for i in range(1000):
    repro_kernels.cudaGraphDemo()

print("Profiling")
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack=True) as prof:
    print("Profiled something")

print("Outside the profiler")
for i in range(1000):
    repro_kernels.cudaGraphDemo()

repro.cu:

#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

__global__ void loopKernel(cudaGraphConditionalHandle handle) {
	// simulate work
	__nanosleep(1000000);
	// end the loop
	cudaGraphSetConditional(handle, 0);
}

void cudaGraphDemo() {
	auto stream = at::cuda::getCurrentCUDAStream().stream();
	cudaGraph_t graph;
	AT_CUDA_CHECK(cudaGraphCreate(&graph, 0));
	cudaGraphConditionalHandle loopHandle;
	AT_CUDA_CHECK(cudaGraphConditionalHandleCreate(&loopHandle, graph, true, cudaGraphCondAssignDefault));
	cudaGraphNodeParams whileNodeParams = { cudaGraphNodeTypeConditional };
	whileNodeParams.conditional.handle = loopHandle;
	whileNodeParams.conditional.type = cudaGraphCondTypeWhile;
	whileNodeParams.conditional.size = 1;
	cudaGraphNode_t loopNode;
	AT_CUDA_CHECK(cudaGraphAddNode(&loopNode, graph, nullptr, 0, &whileNodeParams));
	cudaStream_t captureStream;
	AT_CUDA_CHECK(cudaStreamCreate(&captureStream));
	AT_CUDA_CHECK(cudaStreamBeginCaptureToGraph(captureStream, whileNodeParams.conditional.phGraph_out[0], nullptr, nullptr, 0, cudaStreamCaptureModeThreadLocal));
	loopKernel<<<1, 1, 0, captureStream>>>(loopHandle);
	AT_CUDA_CHECK(cudaPeekAtLastError());
	AT_CUDA_CHECK(cudaStreamEndCapture(captureStream, nullptr));
	AT_CUDA_CHECK(cudaStreamDestroy(captureStream));
	cudaGraphExec_t instance;
	AT_CUDA_CHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
	AT_CUDA_CHECK(cudaGraphLaunch(instance, stream));
	AT_CUDA_CHECK(cudaStreamSynchronize(stream));
	AT_CUDA_CHECK(cudaGraphExecDestroy(instance));
	AT_CUDA_CHECK(cudaGraphDestroy(graph));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
	m.def("cudaGraphDemo", cudaGraphDemo);
}

test.bash:

#!/bin/bash

successes=0
failures=0
total=0

while true; do
    python repro.py
    exit_code=$?
    
    if [ $exit_code -eq 0 ]; then
        ((successes++))
    elif [ $exit_code -eq 139 ]; then
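        # exit code 139 = 128 + 11 (SIGSEGV): python repro.py was killed by a segfault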
        ((failures++))
    else
        echo "Unexpected exit code: $exit_code"
        continue
    fi
    
    ((total++))
    failure_percentage=$(awk "BEGIN {printf \"%.2f\", ($failures / $total) * 100}")
    
    echo "Run #$total"
    echo "Successes: $successes"
    echo "Failures: $failures"
    echo "Failure rate: $failure_percentage%"
    echo "------------------------"
done

Here’s a full terminal log of reproducing the issue from a fresh Lambda H100 instance. I have only modified the log to redact my last login IP address. The only other command run was half of an apt update.

The log is too big to paste here, so I've uploaded it as a GitHub Gist ("reproduction of cuda graph issue") and to Pastebin.