Thank you for the reply and thank you for attempting to reproduce!
All version info can be found in the GitHub issue under the "Versions" section; I’ve pasted it here:
```
Collecting environment information...
PyTorch version: 2.3.1
Is debug build: False
CUDA used to build PyTorch: 12.2
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.8.0-40-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA H100 PCIe
Nvidia driver version: 550.90.07
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 26
On-line CPU(s) list: 0-25
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8480+
CPU family: 6
Model: 143
Thread(s) per core: 1
Core(s) per socket: 1
Socket(s): 26
Stepping: 8
BogoMIPS: 4000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk avx512_fp16 arch_capabilities
Virtualization: VT-x
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 832 KiB (26 instances)
L1i cache: 832 KiB (26 instances)
L2 cache: 104 MiB (26 instances)
L3 cache: 416 MiB (26 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-25
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Unknown: No mitigations
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled
Versions of relevant libraries:
[pip3] flake8==4.0.1
[pip3] numpy==1.25.2
[pip3] torch==2.3.1
[pip3] torchvision==0.18.1
[pip3] triton==3.0.0
[conda] Could not collect
```
But I’ve just gone through and redone the repro for you. Here are the full reproduction steps; I did not run any commands other than these.
- Launch a `gpu_1x_a100_sxm4` instance on Lambda Cloud (only because they are currently out of capacity for H100s).
- `nvidia-smi` says the instance only supports CUDA 12.2, so we need to upgrade the driver: `sudo apt update && sudo apt dist-upgrade -y && sudo reboot` (this step takes quite a while).
- Reconnect after the reboot and confirm that `nvidia-smi` now reports driver version `550.90.07` and support for CUDA 12.4.
- `pip install torch --upgrade --force-reinstall --index-url https://download.pytorch.org/whl/cu124` to replace the default torch with one built against CUDA 12.4, again because conditional nodes otherwise cause a linker error. `python --version` should print `Python 3.10.12`, and `python -c "import torch; print(torch.__version__)"` should print `2.4.1+cu124` (a quick sanity-check snippet is included after the output below).
- `pip install pybind11`
- Write the four files: `nano setup.py`, `nano repro.py`, `nano repro.cu`, `nano test.bash` (the contents of each are below).
- `pip install -e .`
- `bash test.bash`
- On my machine, the segfault rate is quite high:

```
$ bash test.bash
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2229 Segmentation fault (core dumped) python repro.py
Run #1
Successes: 0
Failures: 1
Failure rate: 100.00%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2271 Segmentation fault (core dumped) python repro.py
Run #2
Successes: 0
Failures: 2
Failure rate: 100.00%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2313 Segmentation fault (core dumped) python repro.py
Run #3
Successes: 0
Failures: 3
Failure rate: 100.00%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
Run #4
Successes: 1
Failures: 3
Failure rate: 75.00%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2393 Segmentation fault (core dumped) python repro.py
Run #5
Successes: 1
Failures: 4
Failure rate: 80.00%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2435 Segmentation fault (core dumped) python repro.py
Run #6
Successes: 1
Failures: 5
Failure rate: 83.33%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2477 Segmentation fault (core dumped) python repro.py
Run #7
Successes: 1
Failures: 6
Failure rate: 85.71%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2519 Segmentation fault (core dumped) python repro.py
Run #8
Successes: 1
Failures: 7
Failure rate: 87.50%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2561 Segmentation fault (core dumped) python repro.py
Run #9
Successes: 1
Failures: 8
Failure rate: 88.89%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
Run #10
Successes: 2
Failures: 8
Failure rate: 80.00%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2641 Segmentation fault (core dumped) python repro.py
Run #11
Successes: 2
Failures: 9
Failure rate: 81.82%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2683 Segmentation fault (core dumped) python repro.py
Run #12
Successes: 2
Failures: 10
Failure rate: 83.33%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2725 Segmentation fault (core dumped) python repro.py
Run #13
Successes: 2
Failures: 11
Failure rate: 84.62%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2767 Segmentation fault (core dumped) python repro.py
Run #14
Successes: 2
Failures: 12
Failure rate: 85.71%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2809 Segmentation fault (core dumped) python repro.py
Run #15
Successes: 2
Failures: 13
Failure rate: 86.67%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2851 Segmentation fault (core dumped) python repro.py
Run #16
Successes: 2
Failures: 14
Failure rate: 87.50%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2893 Segmentation fault (core dumped) python repro.py
Run #17
Successes: 2
Failures: 15
Failure rate: 88.24%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2936 Segmentation fault (core dumped) python repro.py
Run #18
Successes: 2
Failures: 16
Failure rate: 88.89%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 2978 Segmentation fault (core dumped) python repro.py
Run #19
Successes: 2
Failures: 17
Failure rate: 89.47%
------------------------
Warming up
Profiling
Profiled something
Outside the profiler
test.bash: line 28: 3020 Segmentation fault (core dumped) python repro.py
Run #20
Successes: 2
Failures: 18
Failure rate: 90.00%
------------------------
```
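For reference, here is the quick sanity check I would run right after the `pip install` step (referenced in the steps above). It only uses standard `torch` attributes; the expected values are the ones from my setup, so treat them as assumptions about your environment:

```python
# Sanity check for the environment set up by the steps above.
# The expected values mirror my run (2.4.1+cu124 on an A100); adjust as needed.
import torch

print(torch.__version__)              # expected: 2.4.1+cu124
print(torch.version.cuda)             # expected: 12.4 (the toolkit the wheel was built against)
print(torch.cuda.is_available())      # expected: True
print(torch.cuda.get_device_name(0))  # e.g. an "NVIDIA A100-SXM4 ..." string on the Lambda instance
```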
The contents of the four files:

`setup.py`:

```python
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from setuptools import setup
import pybind11

setup(
    name="repro",
    ext_modules=[CUDAExtension(
        "repro_kernels",
        ["repro.cu"],
        include_dirs=[pybind11.get_include()],
    )],
    cmdclass={"build_ext": BuildExtension},
)
```
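If editing `setup.py` on the instance is inconvenient, the same extension can also be compiled just-in-time. This is only an optional sketch using PyTorch's `torch.utils.cpp_extension.load`, not part of the repro I actually ran:

```python
# Optional alternative (not part of the original repro): build repro.cu JIT
# instead of via `pip install -e .`. Assumes repro.cu is in the working directory.
from torch.utils.cpp_extension import load

repro_kernels = load(
    name="repro_kernels",
    sources=["repro.cu"],
    verbose=True,  # print the generated build commands
)
repro_kernels.cudaGraphDemo()
```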
`repro.py`:

```python
from torch.profiler import profile, record_function, ProfilerActivity
import repro_kernels

print("Warming up")
for i in range(1000):
    repro_kernels.cudaGraphDemo()

print("Profiling")
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack=True) as prof:
    print("Profiled something")

print("Outside the profiler")
for i in range(1000):
    repro_kernels.cudaGraphDemo()
```
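To get at least a Python-level stack for the crash, one option is to enable `faulthandler` before importing the extension (or run `python -X faulthandler repro.py`). This is a debugging aid I'm suggesting, not something used to produce the numbers above; native frames will not appear in its output:

```python
# Debugging aid (not part of the original repro): prepend this to repro.py so that
# a Python-level traceback is dumped to stderr when the SIGSEGV hits.
import faulthandler
faulthandler.enable()
```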
`repro.cu`:

```cpp
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

__global__ void loopKernel(cudaGraphConditionalHandle handle) {
    // simulate work
    __nanosleep(1000000);
    // end the loop
    cudaGraphSetConditional(handle, 0);
}

void cudaGraphDemo() {
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // Build a graph containing a single conditional WHILE node.
    cudaGraph_t graph;
    AT_CUDA_CHECK(cudaGraphCreate(&graph, 0));
    cudaGraphConditionalHandle loopHandle;
    AT_CUDA_CHECK(cudaGraphConditionalHandleCreate(&loopHandle, graph, true, cudaGraphCondAssignDefault));
    cudaGraphNodeParams whileNodeParams = { cudaGraphNodeTypeConditional };
    whileNodeParams.conditional.handle = loopHandle;
    whileNodeParams.conditional.type = cudaGraphCondTypeWhile;
    whileNodeParams.conditional.size = 1;
    cudaGraphNode_t loopNode;
    AT_CUDA_CHECK(cudaGraphAddNode(&loopNode, graph, nullptr, 0, &whileNodeParams));

    // Capture the loop body into the conditional node's body graph.
    cudaStream_t captureStream;
    AT_CUDA_CHECK(cudaStreamCreate(&captureStream));
    AT_CUDA_CHECK(cudaStreamBeginCaptureToGraph(captureStream, whileNodeParams.conditional.phGraph_out[0], nullptr, nullptr, 0, cudaStreamCaptureModeThreadLocal));
    loopKernel<<<1, 1, 0, captureStream>>>(loopHandle);
    AT_CUDA_CHECK(cudaPeekAtLastError());
    AT_CUDA_CHECK(cudaStreamEndCapture(captureStream, nullptr));
    AT_CUDA_CHECK(cudaStreamDestroy(captureStream));

    // Instantiate, launch, and clean up.
    cudaGraphExec_t instance;
    AT_CUDA_CHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
    AT_CUDA_CHECK(cudaGraphLaunch(instance, stream));
    AT_CUDA_CHECK(cudaStreamSynchronize(stream));
    AT_CUDA_CHECK(cudaGraphExecDestroy(instance));
    AT_CUDA_CHECK(cudaGraphDestroy(graph));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("cudaGraphDemo", cudaGraphDemo);
}
```
`test.bash`:

```bash
#!/bin/bash
successes=0
failures=0
total=0
while true; do
    python repro.py
    exit_code=$?
    if [ $exit_code -eq 0 ]; then
        ((successes++))
    elif [ $exit_code -eq 139 ]; then
        ((failures++))
    else
        echo "Unexpected exit code: $exit_code"
        continue
    fi
    ((total++))
    failure_percentage=$(awk "BEGIN {printf \"%.2f\", ($failures / $total) * 100}")
    echo "Run #$total"
    echo "Successes: $successes"
    echo "Failures: $failures"
    echo "Failure rate: $failure_percentage%"
    echo "------------------------"
done
```
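For completeness: the 139 the script checks for is 128 + SIGSEGV (signal 11), which is how bash reports the crash. A Python version of the same counting loop would check for a return code of -11 instead; this is just an equivalent sketch, not what produced the numbers above:

```python
# Equivalent of test.bash in Python (sketch only).
# subprocess reports a SIGSEGV as returncode -11, whereas bash reports 128 + 11 = 139.
import subprocess
import sys

successes = failures = total = 0
while True:
    result = subprocess.run([sys.executable, "repro.py"])
    if result.returncode == 0:
        successes += 1
    elif result.returncode == -11:  # terminated by SIGSEGV
        failures += 1
    else:
        print(f"Unexpected exit code: {result.returncode}")
        continue
    total += 1
    print(f"Run #{total}")
    print(f"Successes: {successes}")
    print(f"Failures: {failures}")
    print(f"Failure rate: {failures / total * 100:.2f}%")
    print("-" * 24)
```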