Failed to build nn.Conv2d in TensorRT 10 with a large max_shape on a dynamic-shape input

Description

I am trying to upgrade TensorRT from 9 to 10 for my customized model, and the build fails after the upgrade. After some investigation, I believe the problem reduces to failing to build a single nn.Conv2d layer whose dynamic-shape input has a large max_shape.

  • Conv2d: 3x3, 256 channels to 256 channels, padding = 1
  • input shape:
    • min_shape: (0, 256, 14, 14)
    • opt_shape: (1, 256, 14, 14)
    • max_shape: (42800, 256, 14, 14)

Notice that 42800 x 256 x 14 x 14 > 2^31 - 1; the build does not fail if max_shape is (42799, 256, 14, 14). I am not sure whether this is related to the UNet issue (Unet results wrong of TensorRT 10.x when running on GPU L40s · Issue #4351 · NVIDIA/TensorRT · GitHub).
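For reference, a quick check of the element counts shows that the failing max_shape is exactly the first one whose total volume crosses the int32 limit:

int32_max = 2**31 - 1                    # 2,147,483,647

failing = 42800 * 256 * 14 * 14          # 2,147,532,800  (> int32_max)
passing = 42799 * 256 * 14 * 14          # 2,147,482,624  (<= int32_max)

print(failing > int32_max, passing > int32_max)  # True False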

Environment

TensorRT Version: 10.13.3.9
GPU Type: RTX 6000 Ada
Nvidia Driver Version: 570.158.01
CUDA Version: 12.9
CUDNN Version: 8.9.6.50
Operating System + Version: Ubuntu 22.04.5 LTS
Python Version (if applicable): 3.10.12
TensorFlow Version (if applicable): N/A
PyTorch Version (if applicable): 2.9.1+cu129
Baremetal or Container (if container which image + tag): Container, nvcr.io/nvidia/cuda:12.9.1-cudnn-devel-ubuntu22.04

Steps To Reproduce

You can run the following Python code directly to reproduce the issue:

import numpy as np
import torch
import torch.nn as nn
import tensorrt as trt


class ConvOnly(nn.Module):
    def __init__(self, in_ch=256, out_ch=256, k=3, pad=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=pad)

    def forward(self, x):
        return self.conv(x)


def build_trt_engine_from_torch_conv(
    torch_conv: nn.Conv2d,
    min_shape=(1, 256, 14, 14),
    opt_shape=(1, 256, 14, 14),
    max_shape=(4, 256, 14, 14),
    fp16=True,
    logger_severity=trt.Logger.WARNING,
):
    """
    Build a TensorRT engine with one Conv2D layer using weights from torch_conv.
    Uses explicit batch and a dynamic optimization profile on the input.
    """
    logger = trt.Logger(logger_severity)
    builder = trt.Builder(logger)

    # Explicit batch flag is required for dynamic shapes in modern TRT
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(network_flags)
    config = builder.create_builder_config()
    config.builder_optimization_level = 0

    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)

    # You can raise this if you want more tactic choices
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 35)  # 32 GiB

    # ---- Input (dynamic N, H, W; C fixed) ----
    in_ch = torch_conv.in_channels
    x = network.add_input(
        name="input",
        dtype=trt.float16 if fp16 else trt.float32,
        shape=(-1, in_ch, 14, 14),  # (N, C, H, W)
    )

    # ---- Weights/Bias from PyTorch ----
    # TRT expects weights as CPU numpy arrays
    w = torch_conv.weight.detach().cpu().numpy().astype(np.float16 if fp16 else np.float32)
    b = None
    if torch_conv.bias is not None:
        b = torch_conv.bias.detach().cpu().numpy().astype(np.float16 if fp16 else np.float32)

    # ---- Convolution layer ----
    kH, kW = torch_conv.kernel_size
    conv = network.add_convolution_nd(
        input=x,
        num_output_maps=torch_conv.out_channels,
        kernel_shape=(kH, kW),
        kernel=w,
        bias=b
    )
    conv.stride_nd = tuple(torch_conv.stride)
    conv.padding_nd = tuple(torch_conv.padding)
    conv.dilation_nd = tuple(torch_conv.dilation)
    conv.num_groups = int(torch_conv.groups)

    # Mark output
    conv.get_output(0).name = "output"
    network.mark_output(conv.get_output(0))

    # ---- Optimization profile for dynamic shapes ----
    profile = builder.create_optimization_profile()
    profile.set_shape("input", min=min_shape, opt=opt_shape, max=max_shape)
    config.add_optimization_profile(profile)

    # ---- Build ----
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("Failed to build TensorRT engine (serialized network is None).")
    
    runtime = trt.Runtime(logger)
    engine = runtime.deserialize_cuda_engine(serialized)
    if engine is None:
        raise RuntimeError("Failed to deserialize TensorRT engine.")

    return engine, serialized


if __name__ == "__main__":
    in_ch = 256
    out_ch = 256

    # 1) Make a torch model and initialize weights
    model = ConvOnly(in_ch=in_ch, out_ch=out_ch, k=3, pad=1).eval()
    
    max_batch = 42800
    opt_batch = 1
    fp16 = True

    # 2) Build TRT engine from the torch conv layer (no ONNX, no torch2trt)
    engine, plan = build_trt_engine_from_torch_conv(
        model.conv,
        min_shape=(0, in_ch, 14, 14),
        opt_shape=(opt_batch, in_ch, 14, 14),
        max_shape=(max_batch, in_ch, 14, 14),
        fp16=fp16,
        logger_severity=trt.Logger.VERBOSE,
    )

    # 3) Save the plan
    with open("conv_dynamic.plan", "wb") as f:
        f.write(plan)

    print("Built engine:", engine.name if hasattr(engine, "name") else "<engine>")

If I use TensorRT-10.13.3.9, the log shows:

[12/16/2025-19:43:49] [TRT] [I] [MemUsageChange] Init CUDA: CPU +66, GPU +0, now: CPU 186, GPU 433 (MiB)
[12/16/2025-19:43:49] [TRT] [V] Trying to load shared library libnvinfer_builder_resource.so.10.13.3
[12/16/2025-19:43:49] [TRT] [V] Loaded shared library libnvinfer_builder_resource.so.10.13.3
[12/16/2025-19:43:50] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +1607, GPU +8, now: CPU 1995, GPU 441 (MiB)
[12/16/2025-19:43:50] [TRT] [V] CUDA lazy loading is enabled.
[12/16/2025-19:43:50] [TRT] [V] could not open /sys/fs/cgroup/memory/memory.limit_in_bytes or /sys/fs/cgroup/memory.max
[12/16/2025-19:43:50] [TRT] [V] Original: 1 layers
[12/16/2025-19:43:50] [TRT] [V] After dead-layer removal: 1 layers
[12/16/2025-19:43:50] [TRT] [V] SYMBOLIC CHECKS
[12/16/2025-19:43:50] [TRT] [V] GRAPH NODES
[12/16/2025-19:43:50] [TRT] [V] CONVOLUTION (Unnamed Layer* 0) [Convolution]
[12/16/2025-19:43:50] [TRT] [V]     Input 0
[12/16/2025-19:43:50] [TRT] [V]         [0,42800]       (# 0 (SHAPE input))
[12/16/2025-19:43:50] [TRT] [V]         256     256
[12/16/2025-19:43:50] [TRT] [V]         14      14
[12/16/2025-19:43:50] [TRT] [V]         14      14
[12/16/2025-19:43:50] [TRT] [V]     Output 0
[12/16/2025-19:43:50] [TRT] [V]         [0,42800]       (# 0 (SHAPE input))
[12/16/2025-19:43:50] [TRT] [V]         256     256
[12/16/2025-19:43:50] [TRT] [V]         14      14
[12/16/2025-19:43:50] [TRT] [V]         14      14
[12/16/2025-19:43:50] [TRT] [V] Graph construction completed in 0.000300123 seconds.
[12/16/2025-19:43:50] [TRT] [V] After adding DebugOutput nodes: 1 layers
[12/16/2025-19:43:50] [TRT] [V] After Myelin optimization: 1 layers
[12/16/2025-19:43:50] [TRT] [V] Applying ScaleNodes fusions.
[12/16/2025-19:43:50] [TRT] [V] After scale fusion: 1 layers
[12/16/2025-19:43:50] [TRT] [V] After dupe layer removal: 1 layers
[12/16/2025-19:43:50] [TRT] [V] After final dead-layer removal: 1 layers
[12/16/2025-19:43:50] [TRT] [V] After tensor merging: 1 layers
[12/16/2025-19:43:50] [TRT] [V] After vertical fusions: 1 layers
[12/16/2025-19:43:50] [TRT] [V] After dupe layer removal: 1 layers
[12/16/2025-19:43:50] [TRT] [V] After final dead-layer removal: 1 layers
[12/16/2025-19:43:50] [TRT] [V] After tensor merging: 1 layers
[12/16/2025-19:43:50] [TRT] [V] After slice removal: 1 layers
[12/16/2025-19:43:50] [TRT] [V] After concat removal: 1 layers
[12/16/2025-19:43:50] [TRT] [V] Trying to split Reshape and strided tensor
[12/16/2025-19:43:50] [TRT] [V] Graph optimization time: 0.000243988 seconds.
[12/16/2025-19:43:50] [TRT] [V] Building graph using backend strategy 0
[12/16/2025-19:43:50] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[12/16/2025-19:43:50] [TRT] [V] Constructing optimization profile number 0 [1/1].
[12/16/2025-19:43:50] [TRT] [V] Applying generic optimizations to the graph for inference.
[12/16/2025-19:43:50] [TRT] [V] Removed 2 format requirement combinations from consideration due to precision constraints.
[12/16/2025-19:43:50] [TRT] [V] Reserving memory for host IO tensors. Host: 0 bytes
[12/16/2025-19:43:50] [TRT] [V] =============== Computing costs for {ForeignNode[(Unnamed Layer* 0) [Convolution]]}
[12/16/2025-19:43:50] [TRT] [V] ForeignNode {ForeignNode[(Unnamed Layer* 0) [Convolution]]} metadata: 
[12/16/2025-19:43:50] [TRT] [V] *************** Autotuning format combination: Half(50176,196,14,1) -> Half(50176,196,14,1) ***************
[12/16/2025-19:43:50] [TRT] [V] --------------- Timing Runner: {ForeignNode[(Unnamed Layer* 0) [Convolution]]} (Myelin[0x80000023])
[12/16/2025-19:43:50] [TRT] [I] Compiler backend is used during engine build.
[12/16/2025-19:43:50] [TRT] [V] Disabling gemm+pointwise/reduce fusions at optimization level 0
[12/16/2025-19:43:50] [TRT] [V] Set number of tactics to 1 at optimization level 0
[12/16/2025-19:43:52] [TRT] [E] Error Code: 9: Skipping tactic 0x0000000000000000 due to exception [tactic.cpp:conv_tactics:1191] No conv tactic found In compileGraph at optimizer/myelin/codeGenerator.cpp:1425
[12/16/2025-19:43:52] [TRT] [V] {ForeignNode[(Unnamed Layer* 0) [Convolution]]} (Myelin[0x80000023]) profiling completed in 1.85701 seconds. Fastest Tactic: 0xd15ea5edd15ea5ed Time: inf
[12/16/2025-19:43:52] [TRT] [V] *************** Autotuning format combination: Half(25088,1:2,1792,128) -> Half(25088,1:2,1792,128) ***************
[12/16/2025-19:43:52] [TRT] [V] --------------- Timing Runner: {ForeignNode[(Unnamed Layer* 0) [Convolution]]} (Myelin[0x80000023])
[12/16/2025-19:43:52] [TRT] [V] Disabling gemm+pointwise/reduce fusions at optimization level 0
[12/16/2025-19:43:52] [TRT] [V] Set number of tactics to 1 at optimization level 0
[12/16/2025-19:43:53] [TRT] [E] Error Code: 9: Skipping tactic 0x0000000000000000 due to exception [tactic.cpp:conv_tactics:1191] No conv tactic found In compileGraph at optimizer/myelin/codeGenerator.cpp:1425
[12/16/2025-19:43:53] [TRT] [V] {ForeignNode[(Unnamed Layer* 0) [Convolution]]} (Myelin[0x80000023]) profiling completed in 0.274851 seconds. Fastest Tactic: 0xd15ea5edd15ea5ed Time: inf
[12/16/2025-19:43:53] [TRT] [V] *************** Autotuning format combination: Half(12544,1:4,896,64) -> Half(12544,1:4,896,64) ***************
[12/16/2025-19:43:53] [TRT] [V] --------------- Timing Runner: {ForeignNode[(Unnamed Layer* 0) [Convolution]]} (Myelin[0x80000023])
[12/16/2025-19:43:53] [TRT] [V] Disabling gemm+pointwise/reduce fusions at optimization level 0
[12/16/2025-19:43:53] [TRT] [V] Set number of tactics to 1 at optimization level 0
[12/16/2025-19:43:53] [TRT] [E] Error Code: 9: Skipping tactic 0x0000000000000000 due to exception [tactic.cpp:conv_tactics:1191] No conv tactic found In compileGraph at optimizer/myelin/codeGenerator.cpp:1425
[12/16/2025-19:43:53] [TRT] [V] {ForeignNode[(Unnamed Layer* 0) [Convolution]]} (Myelin[0x80000023]) profiling completed in 0.289409 seconds. Fastest Tactic: 0xd15ea5edd15ea5ed Time: inf
[12/16/2025-19:43:53] [TRT] [V] *************** Autotuning format combination: Half(6272,1:8,448,32) -> Half(6272,1:8,448,32) ***************
[12/16/2025-19:43:53] [TRT] [V] --------------- Timing Runner: {ForeignNode[(Unnamed Layer* 0) [Convolution]]} (Myelin[0x80000023])
[12/16/2025-19:43:53] [TRT] [V] Disabling gemm+pointwise/reduce fusions at optimization level 0
[12/16/2025-19:43:53] [TRT] [V] Set number of tactics to 1 at optimization level 0
[12/16/2025-19:43:53] [TRT] [E] Error Code: 9: Skipping tactic 0x0000000000000000 due to exception [tactic.cpp:conv_tactics:1191] No conv tactic found In compileGraph at optimizer/myelin/codeGenerator.cpp:1425
[12/16/2025-19:43:53] [TRT] [V] {ForeignNode[(Unnamed Layer* 0) [Convolution]]} (Myelin[0x80000023]) profiling completed in 0.283374 seconds. Fastest Tactic: 0xd15ea5edd15ea5ed Time: inf
[12/16/2025-19:43:53] [TRT] [E] IBuilder::buildSerializedNetwork: Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[(Unnamed Layer* 0) [Convolution]]}. In computeCosts at optimizer/common/tactic/optimizer.cpp:4115)
Traceback (most recent call last):
  File "/root/playground/main_tensorrt.py", line 108, in <module>
    engine, plan = build_trt_engine_from_torch_conv(
  File "/root/playground/main_tensorrt.py", line 85, in build_trt_engine_from_torch_conv
    raise RuntimeError("Failed to build TensorRT engine (serialized network is None).")
RuntimeError: Failed to build TensorRT engine (serialized network is None).

If I use TensorRT-9.2.0.5, the log shows:

[12/16/2025-19:47:16] [TRT] [I] [MemUsageChange] Init CUDA: CPU +65, GPU +0, now: CPU 179, GPU 433 (MiB)
[12/16/2025-19:47:16] [TRT] [V] Trying to load shared library libnvinfer_builder_resource.so.9.2.0
[12/16/2025-19:47:16] [TRT] [V] Loaded shared library libnvinfer_builder_resource.so.9.2.0
[12/16/2025-19:47:18] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +1820, GPU +316, now: CPU 2135, GPU 749 (MiB)
[12/16/2025-19:47:18] [TRT] [V] CUDA lazy loading is enabled.
[12/16/2025-19:47:18] [TRT] [V] Original: 1 layers
[12/16/2025-19:47:18] [TRT] [V] After dead-layer removal: 1 layers
[12/16/2025-19:47:18] [TRT] [V] Graph construction completed in 0.000206954 seconds.
[12/16/2025-19:47:18] [TRT] [V] After Myelin optimization: 1 layers
[12/16/2025-19:47:18] [TRT] [V] Applying ScaleNodes fusions.
[12/16/2025-19:47:18] [TRT] [V] After scale fusion: 1 layers
[12/16/2025-19:47:18] [TRT] [V] After dupe layer removal: 1 layers
[12/16/2025-19:47:18] [TRT] [V] After final dead-layer removal: 1 layers
[12/16/2025-19:47:18] [TRT] [V] After tensor merging: 1 layers
[12/16/2025-19:47:18] [TRT] [V] After vertical fusions: 1 layers
[12/16/2025-19:47:18] [TRT] [V] After dupe layer removal: 1 layers
[12/16/2025-19:47:18] [TRT] [V] After final dead-layer removal: 1 layers
[12/16/2025-19:47:18] [TRT] [V] After tensor merging: 1 layers
[12/16/2025-19:47:18] [TRT] [V] After slice removal: 1 layers
[12/16/2025-19:47:18] [TRT] [V] After concat removal: 1 layers
[12/16/2025-19:47:18] [TRT] [V] Trying to split Reshape and strided tensor
[12/16/2025-19:47:18] [TRT] [V] Graph optimization time: 0.000214471 seconds.
[12/16/2025-19:47:18] [TRT] [V] Building graph using backend strategy 0
[12/16/2025-19:47:18] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[12/16/2025-19:47:18] [TRT] [V] Constructing optimization profile number 0 [1/1].
[12/16/2025-19:47:18] [TRT] [V] Applying generic optimizations to the graph for inference.
[12/16/2025-19:47:18] [TRT] [V] Reserving memory for host IO tensors. Host: 0 bytes
[12/16/2025-19:47:18] [TRT] [V] =============== Computing costs for (Unnamed Layer* 0) [Convolution]
[12/16/2025-19:47:18] [TRT] [V] *************** Autotuning format combination: Float(50176,196,14,1) -> Float(50176,196,14,1) ***************
[12/16/2025-19:47:18] [TRT] [V] --------------- Timing Runner: (Unnamed Layer* 0) [Convolution] (CaskConvolution[0x80000009])
[12/16/2025-19:47:18] [TRT] [V] Tactic Name: ampere_scudnn_winograd_128x128_ldg1_ldg4_relu_tile148t_nt_v1 Tactic: 0x94119b4c514b211a Time: 31.3204
[12/16/2025-19:47:18] [TRT] [V] Tactic: 0x94119b4c514b211a A valid tactic is found. Rest of the tactics are skipped.
[12/16/2025-19:47:18] [TRT] [V] >>>>>>>>>>>>>>> Chose Runner Type: CaskConvolution Tactic: 0x94119b4c514b211a
[12/16/2025-19:47:18] [TRT] [V] *************** Autotuning format combination: Half(50176,196,14,1) -> Half(50176,196,14,1) ***************
[12/16/2025-19:47:18] [TRT] [V] Skipping CaskConvolution: No valid tactics for (Unnamed Layer* 0) [Convolution]
[12/16/2025-19:47:18] [TRT] [V] Skipping CaskFlattenConvolution: No valid tactics for (Unnamed Layer* 0) [Convolution]
[12/16/2025-19:47:18] [TRT] [V] *************** Autotuning format combination: Half(25088,1:2,1792,128) -> Half(25088,1:2,1792,128) ***************
[12/16/2025-19:47:18] [TRT] [V] Skipping CaskConvolution: No valid tactics for (Unnamed Layer* 0) [Convolution]
[12/16/2025-19:47:18] [TRT] [V] Skipping CaskFlattenConvolution: No valid tactics for (Unnamed Layer* 0) [Convolution]
[12/16/2025-19:47:18] [TRT] [V] *************** Autotuning format combination: Half(12544,1:4,896,64) -> Half(12544,1:4,896,64) ***************
[12/16/2025-19:47:18] [TRT] [V] Skipping CaskConvolution: No valid tactics for (Unnamed Layer* 0) [Convolution]
[12/16/2025-19:47:18] [TRT] [V] Skipping CaskFlattenConvolution: No valid tactics for (Unnamed Layer* 0) [Convolution]
[12/16/2025-19:47:18] [TRT] [V] *************** Autotuning format combination: Half(6272,1:8,448,32) -> Float(50176,196,14,1) ***************
[12/16/2025-19:47:18] [TRT] [V] Skipping CaskConvolution: No valid tactics for (Unnamed Layer* 0) [Convolution]
[12/16/2025-19:47:18] [TRT] [V] Skipping CaskFlattenConvolution: No valid tactics for (Unnamed Layer* 0) [Convolution]
[12/16/2025-19:47:18] [TRT] [V] *************** Autotuning format combination: Half(6272,1:8,448,32) -> Half(6272,1:8,448,32) ***************
[12/16/2025-19:47:18] [TRT] [V] Skipping CaskConvolution: No valid tactics for (Unnamed Layer* 0) [Convolution]
[12/16/2025-19:47:18] [TRT] [V] Skipping CaskFlattenConvolution: No valid tactics for (Unnamed Layer* 0) [Convolution]
[12/16/2025-19:47:18] [TRT] [V] =============== Computing reformatting costs for available format set
[12/16/2025-19:47:18] [TRT] [V] =============== Computing reformatting costs: 
[12/16/2025-19:47:18] [TRT] [V] *************** Autotuning Reformat: Half(50176,196,14,1) -> Float(50176,196,14,1) ***************
[12/16/2025-19:47:18] [TRT] [V] --------------- Timing Runner: Optimizer Reformat(input -> <out>) (Reformat[0x80000006])
[12/16/2025-19:47:18] [TRT] [V] Tactic: 0x00000000000003e8 Time: 0.058432
[12/16/2025-19:47:18] [TRT] [V] Tactic: 0x00000000000003e8 A valid tactic is found. Rest of the tactics are skipped.
[12/16/2025-19:47:18] [TRT] [V] =============== Computing reformatting costs for available format set
[12/16/2025-19:47:18] [TRT] [V] Adding reformat layer: Reformatted Input Tensor 0 to (Unnamed Layer* 0) [Convolution] (input) from Half(50176,196,14,1) to Float(50176,196,14,1)
[12/16/2025-19:47:18] [TRT] [V] Formats and tactics selection completed in 0.101433 seconds.
[12/16/2025-19:47:18] [TRT] [V] After reformat layers: 2 layers
[12/16/2025-19:47:18] [TRT] [V] Total number of blocks in pre-optimized block assignment: 2
[12/16/2025-19:47:18] [TRT] [I] Detected 1 inputs and 1 output network tensors.
[12/16/2025-19:47:18] [TRT] [V] Layer: (Unnamed Layer* 0) [Convolution] Host Persistent: 2512 Device Persistent: 0 Scratch Memory: 0
[12/16/2025-19:47:18] [TRT] [V] Skipped printing memory information for 1 layers with 0 memory size i.e. Host Persistent + Device Persistent + Scratch Memory == 0.
[12/16/2025-19:47:18] [TRT] [I] Total Host Persistent Memory: 2512
[12/16/2025-19:47:18] [TRT] [I] Total Device Persistent Memory: 0
[12/16/2025-19:47:18] [TRT] [I] Total Scratch Memory: 0
[12/16/2025-19:47:18] [TRT] [I] [BlockAssignment] Started assigning block shifts. This will take 1 steps to complete.
[12/16/2025-19:47:18] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 0.008742ms to assign 1 blocks to 1 nodes requiring 8590131200 bytes.
[12/16/2025-19:47:18] [TRT] [V] Total number of blocks in optimized block assignment: 1
[12/16/2025-19:47:18] [TRT] [I] Total Activation Memory: 8590131200
[12/16/2025-19:47:18] [TRT] [I] Total Weights Memory: 4196352
[12/16/2025-19:47:18] [TRT] [V] Finalize: (Unnamed Layer* 0) [Convolution] Set kernel index: 0
[12/16/2025-19:47:18] [TRT] [V] Total number of generated kernels selected for the engine: 1
[12/16/2025-19:47:18] [TRT] [V] Kernel: 0 CASK_STATIC
[12/16/2025-19:47:18] [TRT] [V] Disabling unused tactic source: JIT_CONVOLUTIONS
[12/16/2025-19:47:18] [TRT] [I] Engine generation completed in 0.214741 seconds.
[12/16/2025-19:47:18] [TRT] [V] Engine Layer Information:
Layer(Reformat): Reformatting CopyNode for Input Tensor 0 to (Unnamed Layer* 0) [Convolution], Tactic: 0x00000000000003e8, input (Half[-1,256,14,14]) -> Reformatted Input Tensor 0 to (Unnamed Layer* 0) [Convolution] (Float[-1,256,14,14])
Layer(CaskConvolution): (Unnamed Layer* 0) [Convolution], Tactic: 0x94119b4c514b211a, Reformatted Input Tensor 0 to (Unnamed Layer* 0) [Convolution] (Float[-1,256,14,14]) -> output (Float[-1,256,14,14])
[12/16/2025-19:47:18] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 4 MiB, GPU 16389 MiB
[12/16/2025-19:47:18] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in building engine: CPU +0, GPU +5, now: CPU 0, GPU 5 (MiB)
[12/16/2025-19:47:18] [TRT] [V] Adding 1 engine(s) to plan file.
[12/16/2025-19:47:18] [TRT] [I] [MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 3623 MiB
[12/16/2025-19:47:18] [TRT] [I] Loaded engine size: 4 MiB
[12/16/2025-19:47:18] [TRT] [V] Deserialization required 3324 microseconds.
[12/16/2025-19:47:18] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in engine deserialization: CPU +0, GPU +4, now: CPU 0, GPU 4 (MiB)
Built engine: Unnamed Network 0

I’m just curious why this builds fine with TensorRT 9 but throws an error with TensorRT 10.
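For reference, one possible interim workaround (just a sketch, based on the observation above that 42799 still builds) is to cap the profile's max batch so the total element count of the max_shape tensor stays under 2^31 - 1:

INT32_MAX = 2**31 - 1
c, h, w = 256, 14, 14

requested_max_batch = 42800
# Largest batch whose full (N, C, H, W) volume still fits in int32
safe_max_batch = min(requested_max_batch, INT32_MAX // (c * h * w))  # 42799 here

max_shape = (safe_max_batch, c, h, w)
print(max_shape)  # (42799, 256, 14, 14)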