Help running Nemotron 3 Nano 30B-A3B-FP8 on DGX Spark (GB10)

Hi,

I’m trying to run nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 on my DGX Spark (GB10 GPU). Running into several issues:

Setup:

  • DGX Spark founder edition with GB10 GPU (compute capability 12.1)

  • SGLang docker images (tried :spark and :latest)
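For reference, a quick way to confirm what compute capability and driver the containers will see (assuming a reasonably recent nvidia-smi that supports the compute_cap query field):

# Host-side sanity check of GPU name, compute capability, and driver
nvidia-smi --query-gpu=name,compute_cap,driver_version --format=csv
# On the Spark this reports compute capability 12.1 for the GB10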

Issues Encountered:

1. :spark image - Config attribute error:

AttributeError: 'NemotronHConfig' object has no attribute 'rms_norm_eps'

The :spark image (5 weeks old) doesn’t support Nemotron 3’s new config format.

2. :latest image with FP8 - CUDA kernel error:

torch.AcceleratorError: CUDA error: no kernel image is available for execution on the device

FP8 quantization kernels don’t support compute capability 12.1 (GB10).

3. :latest image with BF16 - Currently trying, downloading model (~65GB).
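A rough way to confirm the capability mismatch behind issue 2 (a sketch, not an official diagnostic - note this only covers PyTorch's own kernels, custom extensions like sgl_kernel ship their own compiled architecture list):

sudo docker run --rm --gpus all lmsysorg/sglang:latest \
  python3 -c "import torch; print(torch.cuda.get_device_capability(0), torch.cuda.get_arch_list())"
# If the device reports (12, 1) but sm_121/sm_121a is absent from the arch list,
# the "no kernel image is available" error is expected.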

Questions:

1. Is there an updated SGLang :spark image that supports Nemotron 3 Nano 30B?

2. Are there specific flags needed for GB10 + FP8?

3. Any official NIM container for Nemotron 3 Nano 30B on DGX Spark?

Thanks!

I’ve been trying to deploy nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B on my DGX Spark and hit several compatibility issues with the GB10 GPU (compute capability 12.1 / sm_121a).

What I’ve tried:

1. FP8 version (NVIDIA-Nemotron-3-Nano-30B-A3B-FP8)

- Error: CUDA error: no kernel image is available for execution on the device

- FP8 quantization kernels don't support sm_121a yet

2. SGLang :spark image (5 weeks old)

- Error: 'NemotronHConfig' object has no attribute 'rms_norm_eps'

- The image predates Nemotron 3's config format

3. SGLang :latest with BF16 - Model loads successfully (57GB)!

- Server starts, KV cache allocated

- **But fails on first inference** with Triton PTXAS error:

ptxas fatal: Value 'sm_121a' is not defined for option 'gpu-name'

Root cause: Triton’s bundled ptxas compiler doesn’t support the GB10’s sm_121a architecture yet.

Currently testing: Setting TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas to use the system CUDA 12.9 toolkit instead.
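Concretely, what I'm checking inside the container (a sketch - the find path is simply where the bundled ptxas happens to live in this image):

# Locate Triton's bundled ptxas and compare it with the system toolkit's copy
TRITON_PTXAS=$(find /usr/local/lib -name ptxas -path "*/triton/*" | head -1)
"$TRITON_PTXAS" --version              # bundled copy (no sm_121a support)
/usr/local/cuda/bin/ptxas --version    # system CUDA toolkit copy
export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas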


Update 2: Setting TRITON_PTXAS_PATH doesn’t help - Triton ignores it in newer versions and uses its bundled ptxas which doesn’t support sm_121a.

Conclusion: Nemotron 3 Nano 30B currently cannot run on DGX Spark via SGLang due to Triton’s lack of sm_121a support.

Working alternative: Nemotron Nano 9B v2 via the official NIM container (nvcr.io/nim/nvidia/nvidia-nemotron-nano-9b-v2-dgx-spark:1.0.0-variant) works perfectly.
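For reference, I'm launching the 9B NIM roughly like this (the API key variable and port are the usual NIM defaults, not Spark-specific instructions - check the container docs for the exact invocation):

sudo docker run --rm --gpus all \
  -e NGC_API_KEY=$NGC_API_KEY \
  -p 8000:8000 \
  nvcr.io/nim/nvidia/nvidia-nemotron-nano-9b-v2-dgx-spark:1.0.0-variant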

I ran Nemotron Nano VL v2 with vLLM before. You could try to adapt that one:

https://forums.developer.nvidia.com/t/running-nvidia-nemotron-nano-vl-12b-v2-nvfp4-qad-on-your-spark/350349/5

Update after system updates (Dec 15, 2025 23:40 CET)

Ran full system updates and rebooted:

sudo apt update && sudo apt dist-upgrade

sudo fwupdmgr refresh && sudo fwupdmgr upgrade

sudo reboot

New kernel: 6.14.0-1015-nvidia (was 6.14.0-1013-nvidia)

CUDA: 13.0.88 (unchanged)

Driver: 580.95.05


What we tried after the update:

1. BF16 + :latest with patched Triton ptxas

sudo docker run --gpus all \
  -v /usr/local/cuda/bin/ptxas:/usr/local/cuda-host/bin/ptxas:ro \
  lmsysorg/sglang:latest \
  bash -c '
    TRITON_PTXAS=$(find /usr/local/lib -name ptxas -path "*/triton/*" | head -1)
    cp /usr/local/cuda-host/bin/ptxas "$TRITON_PTXAS"
    python3 -m sglang.launch_server --model-path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 ...
  '

Result: Model loads successfully! (57.62 GB, ~5 min load time)

[2025-12-15 22:23:51] Load weight end. type=NemotronHForCausalLM, dtype=torch.bfloat16, mem usage=57.62 GB

[2025-12-15 22:24:00] Uvicorn running on http://0.0.0.0:30000

But crashes on first inference:

RuntimeError: RMSNorm failed with error code no kernel image is available for execution on the device

Full traceback:

File "sglang/srt/layers/layernorm.py", line 125, in forward_cuda
    out = rmsnorm(x, self.weight.data, self.variance_epsilon)
File "sgl_kernel/elementwise.py", line 45, in rmsnorm
    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
RuntimeError: RMSNorm failed with error code no kernel image is available for execution on the device


2. FP8 + :latest with triton backend

python3 -m sglang.launch_server \
  --model-path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 \
  --attention-backend triton \
  --sampling-backend pytorch \
  ...

Result: Crashes during weight loading

File "sglang/srt/layers/quantization/utils.py", line 157, in requantize_with_max_scale
    weight[start:end, :], _ = scaled_fp8_quant(weight_dq, max_w_scale)
torch.AcceleratorError: CUDA error: no kernel image is available for execution on the device


3. FP8 + :spark image

Result: Config attribute error (old image)

AttributeError: 'NemotronHConfig' object has no attribute 'rms_norm_eps'


Root Cause Analysis

The issue is not Triton’s ptxas (we patched that successfully). The real blocker is:

sgl_kernel precompiled CUDA operators don’t have sm_121a support

The SGLang docker images include precompiled CUDA kernels (sgl_kernel) for operations like:

  • rmsnorm (RMSNorm layer normalization)

  • scaled_fp8_quant (FP8 quantization)

  • Other custom ops

These are compiled for architectures up to sm_120 (compute capability 12.0), but GB10 requires sm_121a (compute capability 12.1).
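You can check this yourself by dumping the architectures embedded in the sgl_kernel shared object (a sketch - the site-packages path and .so name are assumptions, and cuobjdump has to be on PATH inside the container):

SGL_SO=$(find /usr/local/lib/python3*/dist-packages/sgl_kernel -name "*.so" | head -1)
cuobjdump --list-elf "$SGL_SO" | grep -o "sm_[0-9a]*" | sort -u
# No sm_121a in the output -> "no kernel image is available" on GB10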

Evidence:

  • System ptxas (CUDA 13.0) supports sm_121a: ✅

  • Triton ptxas can be patched to use system version: ✅

  • sgl_kernel precompiled ops: ❌ No sm_121a kernels


Next, I'm going to try:

1. llama.cpp + GGUF - unsloth/Nemotron-3-Nano-30B-A3B-GGUF on Hugging Face (llama.cpp PR merged, no CUDA kernel dependency)

You’ll probably load in less than 30-40s with my container. Please evaluate it.

Thanks @raphael.amorim! I’ll definitely try your container.

But I have to be honest - I just read NVIDIA’s blog post about Nemotron 3 and got really excited. They literally say:

“This 30B total 3B active parameter model is specifically designed for DGX Spark, H100, and B200 GPUs”

So I thought: “Perfect! This is exactly what my €4,500 Spark is for!”

Several hours of debugging later… here I am. The official FP8 model doesn’t work. The BF16 loads but crashes on inference. The :spark container is outdated. The root cause? Precompiled CUDA kernels without sm_121a support.

A bit disappointing to say the least.

Before trying your container, I want to test llama.cpp + GGUF first since:

  1. No dependency on precompiled CUDA kernels (our actual root cause)

  2. PR #18058 for Nemotron 3 just merged

  3. Unsloth GGUF quantizations are already available

If that works, it confirms the issue is purely sgl_kernel architecture support, not GB10 itself.

Then I’ll try your container and compare. Will report back with results!

1 Like

Yeah, it's part of playing with a new ecosystem. We've come a long way since October 15. I'd suggest taking a look at these errors in other threads, because they've come up a lot. For single node I would definitely try llama.cpp - best performance for single-node, non-concurrent use on the Spark.

Those scripts are for multi-node deployment of TRT-LLM, SGLang and vLLM. You might get some insights there for the runtime adjustments.

1 Like

Thanks @raphael.amorim! Good to know llama.cpp is the recommended path for single-node. That’s exactly my use case.

And fair point about the ecosystem being new, I get it, early adopter pains. Just wish the marketing (“specifically designed for Spark”) matched the current reality a bit better. ;)

Going to try llama.cpp now. Will check out those mark-ramsey-ri scripts too for future reference when I want to explore multi-node setups.

Will report back once I have llama.cpp running!

1 Like

Update: Nemotron 3 Nano 30B running on DGX Spark via llama.cpp! ✅

After the system updates (kernel 6.14.0-1015-nvidia), I tested all SGLang approaches again. Here’s the full breakdown:


SGLang Status (Still Blocked)

BF16 + :latest with patched Triton ptxas:

  • ✅ Model loads successfully (57.62 GB)

  • ✅ Server starts, KV cache allocated

  • ❌ Crashes on first inference:

RuntimeError: RMSNorm failed with error code no kernel image is available for execution on the device

Root cause: Even after patching Triton’s ptxas to use system CUDA 13.0, the precompiled sgl_kernel CUDA operators (rmsnorm, scaled_fp8_quant, etc.) don’t have sm_121a kernels. This is separate from the Triton issue and cannot be fixed without rebuilding sgl_kernel from source.

FP8 + :latest: Same kernel error during weight quantization.

FP8 + :spark: Config error (‘NemotronHConfig’ object has no attribute ‘rms_norm_eps’) - image too old.
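For completeness, rebuilding sgl_kernel with sm_121a included would look roughly like this (untested sketch - the repo layout is from the sglang project, but the CMAKE_CUDA_ARCHITECTURES knob is an assumption; check the sgl-kernel build docs for the variable it actually reads):

git clone https://github.com/sgl-project/sglang
cd sglang/sgl-kernel
# Assumption: the CMake-based build honors CMAKE_ARGS / CMAKE_CUDA_ARCHITECTURES
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=121a"
pip install . --no-build-isolation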


Working Solution: llama.cpp + GGUF

llama.cpp compiles CUDA kernels at build time for your specific architecture, avoiding the precompiled kernel issue entirely.

Steps:

# 1. Clone llama.cpp

cd ~

git clone https://github.com/ggml-org/llama.cpp

cd llama.cpp

# 2. Checkout the Nemotron H support PR (not merged to master yet)

git fetch origin pull/18058/head:nemotron-h

git checkout nemotron-h

# 3. Build with CUDA for GB10 (sm_121)

mkdir build && cd build

cmake .. -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="121" -DLLAMA_CURL=OFF

make -j8

# 4. Download the Q8 GGUF model (~38GB)

huggingface-cli download unsloth/Nemotron-3-Nano-30B-A3B-GGUF \
  Nemotron-3-Nano-30B-A3B-UD-Q8_K_XL.gguf \
  --local-dir ~/models/nemotron3-gguf

# 5. Run the server

./bin/llama-server \
  --model ~/models/nemotron3-gguf/Nemotron-3-Nano-30B-A3B-UD-Q8_K_XL.gguf \
  --host 0.0.0.0 \
  --port 30000 \
  --n-gpu-layers 99 \
  --ctx-size 8192 \
  --threads 8

Test it:

curl http://localhost:30000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model":"nemotron","messages":[{"role":"user","content":"What is 2+2?"}],"max_tokens":100}'

Performance on DGX Spark (GB10)

| Metric           | Value                        |
|------------------|------------------------------|
| Quantization     | Q8_K_XL                      |
| Model Size       | 37.66 GiB                    |
| GPU VRAM Used    | ~38 GiB                      |
| Prompt Speed     | 86 tokens/sec                |
| Generation Speed | 43.7 tokens/sec              |
| Context Size     | 8192 (configurable up to 1M) |

The model includes built-in reasoning (thinking mode) and tool calling support via the chat template.
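If you want to play with the reasoning behavior, here is a hedged variant of the launch command (assumes a llama.cpp build where --jinja and --chat-template-kwargs are available, and that the Nemotron chat template honors a reasoning_effort kwarg - values other than "low" are untested here):

./bin/llama-server \
  --model ~/models/nemotron3-gguf/Nemotron-3-Nano-30B-A3B-UD-Q8_K_XL.gguf \
  --host 0.0.0.0 --port 30000 \
  --n-gpu-layers 99 --ctx-size 8192 \
  --jinja \
  --chat-template-kwargs '{"reasoning_effort": "low"}'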


Summary

| Method                | Status | Issue                             |
|-----------------------|--------|-----------------------------------|
| SGLang :latest (BF16) | ❌     | sgl_kernel missing sm_121a        |
| SGLang :latest (FP8)  | ❌     | FP8 quant kernels missing sm_121a |
| SGLang :spark         | ❌     | Old image, no Nemotron 3 support  |
| llama.cpp + GGUF      | ✅     | Works perfectly!                  |

Request to NVIDIA/SGLang team: Please rebuild sgl_kernel with sm_121a support for the :spark image, or provide a NIM container for Nemotron 3 30B optimized for DGX Spark.


Hope this helps others with DGX Spark!

1 Like

Tried to run it on my Docker setup with a freshly built vLLM from the main branch - no bueno. It loads (with flashinfer errors during loading), but produces complete garbage as output. Errors look like this:

(EngineCore_DP0 pid=294) 2025-12-16 07:15:12,481 - WARNING - autotuner.py:490 - flashinfer.jit: [Autotuner]: Skipping tactic <flashinfer.fused_moe.core.get_cutlass_fused_moe_module.<locals>.MoERunner object at 0xfb9e3269aea0> 41, due to failure while profiling: [TensorRT-LLM][ERROR] Assertion failed: GPU lacks the shared memory resources to run GroupedGEMM kernel (/workspace/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h:175)
(EngineCore_DP0 pid=294) 1       0xfb9e194a5eb4 tensorrt_llm::common::throwRuntimeError(char const*, int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 84
(EngineCore_DP0 pid=294) 2       0xfb9e195d04f0 void tensorrt_llm::kernels::cutlass_kernels_oss::dispatchGemmConfig<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, cutlass::arch::Sm89, tensorrt_llm::cutlass_extensions::EpilogueOpDefault, cutlass::gemm::GemmShape<16, 256, 128>, cutlass::gemm::GemmShape<16, 64, 128>, (void*)0>(tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>, int) + 2400
(EngineCore_DP0 pid=294) 3       0xfb9e195d8394 void tensorrt_llm::kernels::cutlass_kernels_oss::dispatchMoeGemmToCutlass<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, cutlass::arch::Sm89, tensorrt_llm::cutlass_extensions::EpilogueOpDefault, (void*)0>(tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>, int) + 356
(EngineCore_DP0 pid=294) 4       0xfb9e195d86f0 void tensorrt_llm::kernels::cutlass_kernels::MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>::dispatchToArch<tensorrt_llm::cutlass_extensions::EpilogueOpDefault>(tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>, tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput) + 320
(EngineCore_DP0 pid=294) 5       0xfb9e195d8ffc tensorrt_llm::kernels::cutlass_kernels::MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>::moeGemmBiasAct(tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>, tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput) + 508
(EngineCore_DP0 pid=294) 6       0xfb9e19954dcc tensorrt_llm::kernels::cutlass_kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16, void>::gemm2(tensorrt_llm::kernels::cutlass_kernels::MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>&, tensorrt_llm::kernels::fp8_blockscale_gemm::CutlassFp8BlockScaleGemmRunnerInterface*, __nv_fp8_e4m3 const*, void*, __nv_bfloat16*, long const*, tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput, __nv_fp8_e4m3 const*, __nv_bfloat16 const*, __nv_bfloat16 const*, float const*, unsigned char const*, tensorrt_llm::kernels::cutlass_kernels::QuantParams, float const*, float const*, int const*, int const*, int const*, long const*, long, long, long, long, long, int, long, float const**, bool, void*, CUstream_st*, tensorrt_llm::kernels::cutlass_kernels::MOEParallelismConfig, bool, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, bool, int*, int*, bool) + 716
(EngineCore_DP0 pid=294) 7       0xfb9e19955240 tensorrt_llm::kernels::cutlass_kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16, void>::gemm2(void const*, void*, void*, long const*, tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput, void const*, void const*, void const*, float const*, unsigned char const*, tensorrt_llm::kernels::cutlass_kernels::QuantParams, float const*, float const*, int const*, int const*, int const*, long const*, long, long, long, long, long, int, long, float const**, bool, void*, bool, CUstream_st*, tensorrt_llm::kernels::cutlass_kernels::MOEParallelismConfig, bool, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, bool, int*, int*, bool) + 400
(EngineCore_DP0 pid=294) 8       0xfb9e198ec9f8 tensorrt_llm::kernels::cutlass_kernels::GemmProfilerBackend::runProfiler(int, tensorrt_llm::cutlass_extensions::CutlassGemmConfig const&, char*, void const*, bool, CUstream_st* const&) + 2552
(EngineCore_DP0 pid=294) 9       0xfb9e19897318 /usr/local/lib/python3.12/dist-packages/flashinfer_jit_cache/jit_cache/fused_moe_120/fused_moe_120.so(+0x567318) [0xfb9e19897318]
(EngineCore_DP0 pid=294) 10      0xfb9e198bf074 /usr/local/lib/python3.12/dist-packages/flashinfer_jit_cache/jit_cache/fused_moe_120/fused_moe_120.so(+0x58f074) [0xfb9e198bf074]
(EngineCore_DP0 pid=294) 11      0xfb9e19894608 tvm::ffi::details::FunctionObjImpl<tvm::ffi::Function::FromTyped<FusedMoeRunner::GetFunction(tvm::ffi::String const&)::{lambda(tvm::ffi::TensorView, tvm::ffi::TensorView, tvm::ffi::Optional<tvm::ffi::TensorView, void>, tvm::ffi::TensorView, tvm::ffi::Optional<tvm::ffi::TensorView, void>, long, long, long, long, long, long, long, bool, bool, long, long, bool, bool, long)#1}>(FusedMoeRunner::GetFunction(tvm::ffi::String const&)::{lambda(tvm::ffi::TensorView, tvm::ffi::TensorView, tvm::ffi::Optional<tvm::ffi::TensorView, void>, tvm::ffi::TensorView, tvm::ffi::Optional<tvm::ffi::TensorView, void>, long, long, long, long, long, long, long, bool, bool, long, long, bool, bool, long)#1}&&)::{lambda(tvm::ffi::AnyView const*, int, tvm::ffi::Any*)#1}>::SafeCall(void*, TVMFFIAny const*, int, TVMFFIAny*) + 696
(EngineCore_DP0 pid=294) 12      0xfb9e9c43231c /usr/local/lib/python3.12/dist-packages/tvm_ffi/core.abi3.so(+0x5231c) [0xfb9e9c43231c]
(EngineCore_DP0 pid=294) 13            0x4c2e78 _PyObject_MakeTpCall + 120
(EngineCore_DP0 pid=294) 14            0x564f28 _PyEval_EvalFrameDefault + 2292
(EngineCore_DP0 pid=294) 15            0x4c7088 VLLM::EngineCore() [0x4c7088]
(EngineCore_DP0 pid=294) 16            0x4c5408 PyObject_Call + 280
(EngineCore_DP0 pid=294) 17            0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 18            0x4c4a44 _PyObject_Call_Prepend + 436
(EngineCore_DP0 pid=294) 19            0x528f50 VLLM::EngineCore() [0x528f50]
(EngineCore_DP0 pid=294) 20            0x4c535c PyObject_Call + 108
(EngineCore_DP0 pid=294) 21            0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 22            0x4c7088 VLLM::EngineCore() [0x4c7088]
(EngineCore_DP0 pid=294) 23            0x4c5408 PyObject_Call + 280
(EngineCore_DP0 pid=294) 24            0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 25            0x4c7088 VLLM::EngineCore() [0x4c7088]
(EngineCore_DP0 pid=294) 26            0x4c5408 PyObject_Call + 280
(EngineCore_DP0 pid=294) 27            0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 28            0x4c7088 VLLM::EngineCore() [0x4c7088]
(EngineCore_DP0 pid=294) 29            0x4c5408 PyObject_Call + 280
(EngineCore_DP0 pid=294) 30            0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 31            0x4c4a44 _PyObject_Call_Prepend + 436
(EngineCore_DP0 pid=294) 32            0x528f50 VLLM::EngineCore() [0x528f50]
(EngineCore_DP0 pid=294) 33            0x4c2f30 _PyObject_MakeTpCall + 304
(EngineCore_DP0 pid=294) 34            0x564f28 _PyEval_EvalFrameDefault + 2292
(EngineCore_DP0 pid=294) 35      0xfba179d6d020 /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so(+0xccd020) [0xfba179d6d020]
(EngineCore_DP0 pid=294) 36      0xfba17a117ed4 /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so(+0x1077ed4) [0xfba17a117ed4]
(EngineCore_DP0 pid=294) 37      0xfba17497b75c /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so(+0x636b75c) [0xfba17497b75c]
(EngineCore_DP0 pid=294) 38      0xfba179e82b60 /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so(+0xde2b60) [0xfba179e82b60]
(EngineCore_DP0 pid=294) 39      0xfba179e830f0 /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so(+0xde30f0) [0xfba179e830f0]
(EngineCore_DP0 pid=294) 40      0xfba179d61860 /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so(+0xcc1860) [0xfba179d61860]
(EngineCore_DP0 pid=294) 41      0xfba179677cac /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so(+0x5d7cac) [0xfba179677cac]
(EngineCore_DP0 pid=294) 42            0x503884 VLLM::EngineCore() [0x503884]
(EngineCore_DP0 pid=294) 43            0x4c535c PyObject_Call + 108
(EngineCore_DP0 pid=294) 44            0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 45            0x4c4954 _PyObject_Call_Prepend + 196
(EngineCore_DP0 pid=294) 46            0x528f50 VLLM::EngineCore() [0x528f50]
(EngineCore_DP0 pid=294) 47            0x4c2e78 _PyObject_MakeTpCall + 120
(EngineCore_DP0 pid=294) 48            0x564f28 _PyEval_EvalFrameDefault + 2292
(EngineCore_DP0 pid=294) 49            0x4c4954 _PyObject_Call_Prepend + 196
(EngineCore_DP0 pid=294) 50            0x528f50 VLLM::EngineCore() [0x528f50]
(EngineCore_DP0 pid=294) 51            0x4c2e78 _PyObject_MakeTpCall + 120
(EngineCore_DP0 pid=294) 52            0x564f28 _PyEval_EvalFrameDefault + 2292
(EngineCore_DP0 pid=294) 53            0x4c4954 _PyObject_Call_Prepend + 196
(EngineCore_DP0 pid=294) 54            0x528f50 VLLM::EngineCore() [0x528f50]
(EngineCore_DP0 pid=294) 55            0x4c2e78 _PyObject_MakeTpCall + 120
(EngineCore_DP0 pid=294) 56            0x564f28 _PyEval_EvalFrameDefault + 2292
(EngineCore_DP0 pid=294) 57            0x4c4954 _PyObject_Call_Prepend + 196
(EngineCore_DP0 pid=294) 58            0x528f50 VLLM::EngineCore() [0x528f50]
(EngineCore_DP0 pid=294) 59            0x4c535c PyObject_Call + 108
(EngineCore_DP0 pid=294) 60            0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 61            0x4c4954 _PyObject_Call_Prepend + 196
(EngineCore_DP0 pid=294) 62            0x528f50 VLLM::EngineCore() [0x528f50]
(EngineCore_DP0 pid=294) 63            0x4c535c PyObject_Call + 108
(EngineCore_DP0 pid=294) 64            0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 65            0x4c4954 _PyObject_Call_Prepend + 196
(EngineCore_DP0 pid=294) 66            0x528f50 VLLM::EngineCore() [0x528f50]
(EngineCore_DP0 pid=294) 67            0x4c2e78 _PyObject_MakeTpCall + 120
(EngineCore_DP0 pid=294) 68            0x564f28 _PyEval_EvalFrameDefault + 2292
(EngineCore_DP0 pid=294) 69            0x4c6fac VLLM::EngineCore() [0x4c6fac]
(EngineCore_DP0 pid=294) 70            0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 71            0x4c6fac VLLM::EngineCore() [0x4c6fac]
(EngineCore_DP0 pid=294) 72            0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 73            0x4c6fac VLLM::EngineCore() [0x4c6fac]
(EngineCore_DP0 pid=294) 74            0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 75            0x4c4954 _PyObject_Call_Prepend + 196
(EngineCore_DP0 pid=294) 76            0x528f50 VLLM::EngineCore() [0x528f50]
(EngineCore_DP0 pid=294) 77            0x4c535c PyObject_Call + 108
(EngineCore_DP0 pid=294) 78            0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 79            0x4c4954 _PyObject_Call_Prepend + 196
(EngineCore_DP0 pid=294) 80            0x528f50 VLLM::EngineCore() [0x528f50]
(EngineCore_DP0 pid=294) 81            0x4c535c PyObject_Call + 108
(EngineCore_DP0 pid=294) 82            0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 83            0x4c4954 _PyObject_Call_Prepend + 196
(EngineCore_DP0 pid=294) 84            0x528f50 VLLM::EngineCore() [0x528f50]
(EngineCore_DP0 pid=294) 85            0x4c2e78 _PyObject_MakeTpCall + 120
(EngineCore_DP0 pid=294) 86            0x564f28 _PyEval_EvalFrameDefault + 2292
(EngineCore_DP0 pid=294) 87            0x4c4954 _PyObject_Call_Prepend + 196
(EngineCore_DP0 pid=294) 88            0x528f50 VLLM::EngineCore() [0x528f50]
(EngineCore_DP0 pid=294) 89            0x4c535c PyObject_Call + 108
(EngineCore_DP0 pid=294) 90            0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 91            0x4c4954 _PyObject_Call_Prepend + 196
(EngineCore_DP0 pid=294) 92            0x528f50 VLLM::EngineCore() [0x528f50]
(EngineCore_DP0 pid=294) 93            0x4c2e78 _PyObject_MakeTpCall + 120
(EngineCore_DP0 pid=294) 94            0x564f28 _PyEval_EvalFrameDefault + 2292
(EngineCore_DP0 pid=294) 95            0x4c7088 VLLM::EngineCore() [0x4c7088]
(EngineCore_DP0 pid=294) 96            0x4c5408 PyObject_Call + 280
(EngineCore_DP0 pid=294) 97            0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 98            0x4c7088 VLLM::EngineCore() [0x4c7088]
(EngineCore_DP0 pid=294) 99            0x4c5408 PyObject_Call + 280
(EngineCore_DP0 pid=294) 100           0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 101           0x4c4a44 _PyObject_Call_Prepend + 436
(EngineCore_DP0 pid=294) 102           0x528f50 VLLM::EngineCore() [0x528f50]
(EngineCore_DP0 pid=294) 103           0x4c535c PyObject_Call + 108
(EngineCore_DP0 pid=294) 104           0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 105           0x4c4a44 _PyObject_Call_Prepend + 436
(EngineCore_DP0 pid=294) 106           0x528f50 VLLM::EngineCore() [0x528f50]
(EngineCore_DP0 pid=294) 107           0x4c535c PyObject_Call + 108
(EngineCore_DP0 pid=294) 108           0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 109           0x4c7024 VLLM::EngineCore() [0x4c7024]
(EngineCore_DP0 pid=294) 110           0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 111           0x4c4a44 _PyObject_Call_Prepend + 436
(EngineCore_DP0 pid=294) 112           0x5228b0 VLLM::EngineCore() [0x5228b0]
(EngineCore_DP0 pid=294) 113           0x51df74 VLLM::EngineCore() [0x51df74]
(EngineCore_DP0 pid=294) 114           0x4c535c PyObject_Call + 108
(EngineCore_DP0 pid=294) 115           0x56832c _PyEval_EvalFrameDefault + 15608
(EngineCore_DP0 pid=294) 116           0x563224 PyEval_EvalCode + 304
(EngineCore_DP0 pid=294) 117           0x59bfb0 PyRun_StringFlags + 224
(EngineCore_DP0 pid=294) 118           0x67f0d4 PyRun_SimpleStringFlags + 68
(EngineCore_DP0 pid=294) 119           0x68b890 Py_RunMain + 912
(EngineCore_DP0 pid=294) 120           0x68b398 Py_BytesMain + 40
(EngineCore_DP0 pid=294) 121     0xfba2081a84c4 /usr/lib/aarch64-linux-gnu/libc.so.6(+0x284c4) [0xfba2081a84c4]
(EngineCore_DP0 pid=294) 122     0xfba2081a8598 __libc_start_main + 152
(EngineCore_DP0 pid=294) 123           0x5f6bb0 _start + 48

Even tried nightly cu130 pytorch + flashinfer - same result.

1 Like

LM Studio is a convenient wrapper for llama.cpp as well

1 Like

Can confirm - I have a similar setup and tried both my image and your image on GH, but it seems every environment that gets Nemotron 3 at FP8 quant to "successfully" run inference returns nonsense output. I'll toy around again later today to see if I can get it working.

Note that running BF16 works just fine though. I’m getting 25-30 tps.

1 Like

Thanks - I'm running into the same issue too; would be nice to see this work.

1 Like

LM Studio with the NVIDIA Q4_K_M GGUF is working fine on the Spark at 60 TPS.

2 Likes

Interesting. In theory, it should give better performance than that, given that gpt-oss-120b reaches the same speed in llama.cpp with 5.1B active parameters. I get 83 t/s from Qwen3-30B-A3B in AWQ 4-bit quant in vLLM, and that's the model Nemotron is based on.

Are you using the latest llama.cpp build, or the one that comes with LM Studio?

1 Like

LM Studio. Will try llama.cpp latest

3 Likes

Around the same here, @eugr, for the exact same GGUF file.

./build/bin/llama-server -m ~/.cache/llama.cpp/nvidia_Nemotron-3-Nano-30B-A3B-Q4_K_M.gguf \
  -fa on -ngl 999 \
  --jinja \
  --ctx-size 0 \
  -b 2048 -ub 2048 \
  --no-mmap \
  --temp 1.0 \
  --top-p 1.0 \
  --top-k 0 \
  --reasoning-format auto \
  --chat-template-kwargs '{"reasoning_effort": "low"}'

2 Likes

I see, probably not well optimized yet.

2 Likes

@ooze.orb Thanks for fighting through this. I too was excited to see it’s designed for the DGX.

1 Like

@john337 Anytime! 😄 Have to say, it’s a lot of fun being part of the NVIDIA community, especially when we’re all debugging together!

1 Like