GPU memory leak when using TensorRT with an ONNX model

Description

GPU memory keeps increasing when running TensorRT inference in a for loop.

Environment

TensorRT Version: 7.0.0.11
GPU Type: 1080Ti
Nvidia Driver Version: 440.33.01
CUDA Version: 10.0
CUDNN Version: 7.6.3
Operating System + Version: Debian9
Python Version (if applicable): 3.7.4
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 1.3.1
Baremetal or Container (if container which image + tag): N/A

Relevant Files

model file: my_leak.tar.gz

Steps To Reproduce

1. Convert the PyTorch model to ONNX:

dummy_input = torch.randn(1, 3, 384, 384, device='cuda')
input_names = [ "input" ]
output_names = [ "output" ]
torch.onnx.export(net, dummy_input, "my_leak.onnx", verbose=True, input_names=input_names, output_names=output_names)

2. If I use this ONNX file directly in TensorRT, I get an error indicating that TensorRT only supports constant padding, so I make the following change:

import onnx
net = onnx.load('my_leak.onnx')
for node in net.graph.node:
    if node.op_type == 'Pad':
        # set the pad mode attribute to constant padding, the only mode TensorRT supports
        node.attribute[0].s = b"constant"
onnx.save(net, 'my_leak.onnx')

3. Using the ONNX file generated above and running the code below, you will see that GPU memory keeps increasing each time execute_async_v2 is called and never decreases until it hits OOM:

[TRT] ../rtSafe/cuda/caskConvolutionRunner.cpp (370) - Cuda Error in execute: 2 (out of memory)
[TRT] FAILED_EXECUTION: std::exception

code

import tensorrt as trt
import numpy as np
import pycuda.autoinit
import pycuda.driver as cuda 
import time
import cv2
from pycuda.tools import DeviceMemoryPool as DMP
from pycuda.tools import PageLockedMemoryPool as PMP

class TrtInfer():
    def __init__(self):
        self.device_pool=DMP()
        self.host_pool = PMP()
        model_path = "my_leak.onnx"
        self.input_size = 384
        self.engine = None

        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
        explicit_batch = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(explicit_batch) as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:
            builder.max_workspace_size = 1 << 25
            # config.max_workspace_size = 1 << 25
            builder.max_batch_size = 1

            with open(model_path, 'rb') as model:
                if not parser.parse(model.read()):
                    for error in range(parser.num_errors):
                        print(parser.get_error(error))

            # if the parser did not mark any outputs, mark the last layer's output
            if network.num_outputs == 0:
                last_layer = network.get_layer(network.num_layers - 1)
                network.mark_output(last_layer.get_output(0))
            self.engine = builder.build_cuda_engine(network)
        print('build engine done')

        self.context = self.engine.create_execution_context()
        print('context done')

        # host cpu mem
        h_in_size = trt.volume(self.engine.get_binding_shape(0))
        h_out_size = trt.volume(self.engine.get_binding_shape(1))
        h_in_dtype = trt.nptype(self.engine.get_binding_dtype(0))
        h_out_dtype = trt.nptype(self.engine.get_binding_dtype(1))

        #allocate host mem
        in_cpu = cuda.pagelocked_empty(h_in_size, h_in_dtype)
        out_cpu = cuda.pagelocked_empty(h_out_size, h_out_dtype)
        #in_cpu = host_pool.allocate([h_in_size], h_in_dtype)
        #out_cpu = host_pool.allocate([h_out_size], h_out_dtype)
        
        # allocate gpu mem
        in_gpu = cuda.mem_alloc(in_cpu.nbytes)
        out_gpu = cuda.mem_alloc(out_cpu.nbytes)
        #in_gpu = device_pool.allocate(in_cpu.nbytes)
        #out_gpu = device_pool.allocate(out_cpu.nbytes)
        stream = cuda.Stream()
        print('alloc done')

        self.in_cpu = in_cpu
        self.out_cpu = out_cpu
        self.in_gpu = in_gpu
        self.out_gpu = out_gpu
        self.stream = stream


    def inference(self, inputs):
        # async version
        inputs = inputs.reshape(-1)
        cuda.memcpy_htod_async(self.in_gpu, inputs, self.stream)
        #context.execute_async(1, [int(in_gpu), int(out_gpu)], stream.handle, None)
        self.context.execute_async_v2(bindings=[int(self.in_gpu), int(self.out_gpu)], stream_handle=self.stream.handle)
        cuda.memcpy_dtoh_async(self.out_cpu, self.out_gpu, self.stream)
        self.stream.synchronize()

        '''
        # sync version
        cuda.memcpy_htod(in_gpu, inputs)
        context.execute(1, [int(in_gpu), int(out_gpu)])
        cuda.memcpy_dtoh(out_cpu, out_gpu)
        '''
        return self.out_cpu


if __name__ == "__main__":
    im = cv2.imread('test.png')
    clipped_o = cv2.resize(im, (384, 384))
    clipped = cv2.cvtColor(clipped_o, cv2.COLOR_BGR2RGB)
    clipped = clipped.astype(np.float32)# / 127.5 - 1.0
    clipped = np.stack([clipped])
    clipped = np.transpose(clipped, (0, 3, 1, 2)) / 127.5 - 1.0
    inputs = clipped

    trt_infer = TrtInfer()

    for i in range(1000):
        t1 = time.time()
        res = trt_infer.inference(inputs)
        print("cost time: ", time.time()-t1)

This problem has troubled me for many days and I still have no idea what causes it. I have also tried implementing the same process in C++, but the problem still exists.
I attached the model file link above.

When I run the original PyTorch model, it works fine. I am not sure whether the memory leak is caused by the model conversion, and I would like to know what kind of change to the model can cause a memory leak.
@SunilJB I see you have replied in similar topics; could you please help figure out what the problem is?

Hi @leikang,
Can you please try the TRT 7.2.1 release?
There are additional fixes in it which may resolve the concern.
Thanks!

I am also having this type of issue.
I migrated the sample code into my ROS node, which processes the image data received from another node.
I tried to free the GPU buffers manually through the PyCUDA API (roughly the sketch below), but the GPU memory kept growing until it hit OOM.
Any update? I am using TRT 7.2.1.
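
The manual free was along these lines (just a sketch; it assumes in_gpu/out_gpu are the DeviceAllocation objects returned by cuda.mem_alloc in the code above), and it did not help:

# freeing the device buffers after each inference did not stop the growth
self.in_gpu.free()
self.out_gpu.free()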

Fixed the issue by saving the execution_context and reusing it, as sketched below. Thanks!
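
A minimal sketch of that pattern, assuming the TrtInfer class from the code above (TrtNode and image_callback are just illustrative names): build the engine, execution context, and device buffers once, and do only the copies and execute_async_v2 per frame.

# Sketch: create the engine/context/buffers once and reuse them for every frame
class TrtNode:
    def __init__(self):
        self.trt_infer = TrtInfer()  # engine, context and buffers are built once here

    def image_callback(self, image):
        # per-frame work is only the host/device copies + execute_async_v2
        return self.trt_infer.inference(image)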
what an effort!

Hi, we request you to share the ONNX model and the script so that we can assist you better.

Alongside, you can try validating your model with the snippet below.

check_model.py

import sys
import onnx

# usage: python check_model.py your_model.onnx
filename = sys.argv[1]
model = onnx.load(filename)
onnx.checker.check_model(model)

Alternatively, you can try running your model with the trtexec command:
https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/trtexec
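For example, something like trtexec --onnx=my_leak.onnx --explicitBatch (the --onnx and --explicitBatch flags are from the TRT 7 trtexec; the path is just the model file attached above).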

Thanks!