The following code, run on a Jetson device (aarch64, JetPack 6.0):
import torch
def test(a, b):
return a * b
# Random tensor of shape (3, 3, 3) (default floating dtype), moved to the CUDA device.
a = torch.randn(3,3,3).to("cuda")
# Shape (1, 1, 3) tensor on the same device and dtype as `a`; it broadcasts
# against `a` along the last dimension.
b = torch.tensor([[[1,-1,-1]]], device=a.device, dtype=a.dtype)
# No backend is passed, so torch.compile uses the default "inductor" backend
# (the traceback below reports backend='inductor').
test_opt = torch.compile(test)
# The first call triggers Dynamo tracing and Inductor compilation. On this
# aarch64 Jetson environment that step raises BackendCompilerFailed, because
# Inductor cannot find a working Triton installation.
test_opt(a, b)
raises the following error:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 454, in _fn
return fn(*args, **kwargs)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 904, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 769, in _convert_frame
result = inner_convert(
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 398, in _convert_frame_assert
return _compile(
File "/home/arl/miniconda3/envs/torch/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 669, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 249, in time_wrapper
r = func(*args, **kwargs)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 542, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
transformations(instructions, code_options)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 163, in _fn
return fn(*args, **kwargs)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 507, in transform
tracer.run()
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2122, in run
super().run()
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 785, in run
and self.step()
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 748, in step
getattr(self, inst.opname)(inst)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2241, in RETURN_VALUE
self.output.compile_subgraph(
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 931, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1102, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 249, in time_wrapper
r = func(*args, **kwargs)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1175, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1156, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/__init__.py", line 1730, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1321, in compile_fx
return aot_autograd(
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 57, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 891, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 249, in time_wrapper
r = func(*args, **kwargs)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 604, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 434, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 639, in aot_wrapper_synthetic_base
return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 97, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 249, in time_wrapper
r = func(*args, **kwargs)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1249, in fw_compiler_base
return inner_compile(
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_inductor/debug.py", line 304, in inner
return fn(*args, **kwargs)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 423, in compile_fx_inner
compiled_graph = fx_codegen_and_compile(
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 689, in fx_codegen_and_compile
compiled_fn = graph.compile_to_fn()
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1227, in compile_to_fn
return self.compile_to_module().call
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 249, in time_wrapper
r = func(*args, **kwargs)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1175, in compile_to_module
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1150, in codegen
self.scheduler = Scheduler(self.buffers)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 249, in time_wrapper
r = func(*args, **kwargs)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 1217, in __init__
self.nodes = [self.create_scheduler_node(n) for n in nodes]
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 1217, in <listcomp>
self.nodes = [self.create_scheduler_node(n) for n in nodes]
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 1308, in create_scheduler_node
group_fn = self.get_backend(node.get_device()).group_fn
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 2225, in get_backend
self.backends[device] = self.create_backend(device)
File "/home/arl/miniconda3/envs/torch/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 2217, in create_backend
raise RuntimeError(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Cannot find a working triton installation. More information on installing Triton can be found at https://github.com/openai/triton
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Package versions (`conda list` output):
(torch) arl@ubuntu:~$ conda list
# packages in environment at /home/arl/miniconda3/envs/torch:
#
# Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 51_gnu
_sysroot_linux-aarch64_curr_repodata_hack 4 h57d6b7b_14 conda-forge
binutils_impl_linux-aarch64 2.38 h0c9fd12_1
bzip2 1.0.8 h998d150_6
ca-certificates 2024.6.2 hcefe29a_0 conda-forge
filelock 3.15.4 pypi_0 pypi
fsspec 2024.6.0 pypi_0 pypi
gcc 12.1.0 h75b0f6e_10 conda-forge
gcc_impl_linux-aarch64 12.1.0 h9c21524_17 conda-forge
jinja2 3.1.4 pypi_0 pypi
kernel-headers_linux-aarch64 4.18.0 h5b4a56d_14 conda-forge
ld_impl_linux-aarch64 2.38 h8131f2d_1
libffi 3.4.4 h419075a_1
libgcc-devel_linux-aarch64 12.1.0 hf2ffb8d_17 conda-forge
libgcc-ng 13.2.0 he277a41_13 conda-forge
libgomp 13.2.0 he277a41_13 conda-forge
libsanitizer 12.1.0 hd01590b_17 conda-forge
libstdcxx-ng 13.2.0 h3f4de04_13 conda-forge
libuuid 1.41.5 h998d150_0
markupsafe 2.1.5 pypi_0 pypi
mpmath 1.3.0 pypi_0 pypi
ncurses 6.4 h419075a_0
networkx 3.3 pypi_0 pypi
numpy 1.26.4 pypi_0 pypi
openssl 3.3.1 h68df207_0 conda-forge
pip 24.0 py310hd43f75c_0
python 3.10.14 h4bb2201_1
readline 8.2 h998d150_0
setuptools 69.5.1 py310hd43f75c_0
sqlite 3.45.3 h998d150_0
sympy 1.12.1 pypi_0 pypi
sysroot_linux-aarch64 2.17 h5b4a56d_14 conda-forge
tk 8.6.14 h987d8db_0
torch 2.3.0a0+40ec155e58.nv24.3 pypi_0 pypi
typing-extensions 4.12.2 pypi_0 pypi
tzdata 2024a h04d1e81_0
wheel 0.43.0 py310hd43f75c_0
xz 5.4.6 h998d150_1
zlib 1.2.13 h998d150_1
JetPack version:
Package: nvidia-jetpack
Source: nvidia-jetpack (6.0)
Version: 6.0+b106
Architecture: arm64
Maintainer: NVIDIA Corporation
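This works fine on x86-64 machines with the same PyTorch version. Note that no `triton` package appears in the environment listed above, which is consistent with the "Cannot find a working triton installation" error.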