Crash with too large a batch size leaves GPU stuck at 100% utilization

We are running the NVIDIA Kubernetes device plugin. When we use a large enough batch size (64), our system crashes, which we believe is because the GPU ran out of memory.
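Until the underlying memory problem is fixed, one common mitigation is to catch the runtime error and retry with a smaller batch. A minimal sketch of that idea, using a hypothetical `run_batch` callable (not from our codebase) that raises `RuntimeError` on failure, the way PyTorch surfaces CUDA OOM:

```python
def run_with_backoff(run_batch, batch, min_size=1):
    """Retry run_batch with progressively smaller chunks on failure.

    `run_batch` is a hypothetical callable that processes a list of
    samples and raises RuntimeError when the batch is too large
    (as PyTorch does on a CUDA out-of-memory error).
    """
    size = len(batch)
    while size >= min_size:
        try:
            # Process the batch in chunks of the current size.
            return [run_batch(batch[i:i + size])
                    for i in range(0, len(batch), size)]
        except RuntimeError:
            # Halve the chunk size and retry. In real code you would
            # also check the error message for "out of memory" and
            # call torch.cuda.empty_cache() before retrying.
            size //= 2
    raise RuntimeError("batch failed even at minimum size")
```

This only helps with a clean OOM, though; in our case the process is left in a state where the GPU itself never recovers.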

RuntimeError: Unable to find a valid cuDNN algorithm to run convolution
THCudaCheck FAIL file=/pytorch/aten/src/THC/THCCachingHostAllocator.cpp line=278 error=700 : an illegal memory access was encountered
Exception in thread Thread-4:
Traceback (most recent call last):
  File "/opt/conda/envs/vtm/lib/python3.9/", line 954, in _bootstrap_inner
  File "/opt/conda/envs/vtm/lib/python3.9/", line 892, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/envs/vtm/lib/python3.9/site-packages/torch/utils/data/_utils/", line 28, in _pin_memory_loop
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/opt/conda/envs/vtm/lib/python3.9/multiprocessing/", line 122, in get
    return _ForkingPickler.loads(res)
  File "/opt/conda/envs/vtm/lib/python3.9/site-packages/torch/multiprocessing/", line 289, in rebuild_storage_fd
    fd = df.detach()
  File "/opt/conda/envs/vtm/lib/python3.9/multiprocessing/", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/opt/conda/envs/vtm/lib/python3.9/multiprocessing/", line 86, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
  File "/opt/conda/envs/vtm/lib/python3.9/multiprocessing/", line 507, in Client
    c = SocketClient(address)
  File "/opt/conda/envs/vtm/lib/python3.9/multiprocessing/", line 635, in SocketClient

After this crash, the GPU appears to be stuck at 100% utilization indefinitely. I have tried to reset the GPU with nvidia-smi, but it complains that the GPU is in use and therefore can't reset it.

Any insight on recovering from this without rebooting the machine?