First PyTorch CUDA pass is much slower when running in an nvidia/cuda container on an AWS g4 instance with a V100

Hi,

After launching a new AWS g4dn.xlarge instance with 1x V100 and then starting a container built from an nvidia/cuda base image using the nvidia-docker runtime, the first pass of the sample code below is really slow: the GPU memory fills very slowly, and the pass takes about 1 min 30 s while subsequent passes take about 7 s. If I launch new containers on the same instance, their first pass takes about 7 s as well; they do not see this delay.

#!/usr/bin/env python3
import torch
import torch.nn as nn
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
    def forward(self, x):
        x = self.conv1(x)
        output = self.conv2(x)
        return output
if __name__ == "__main__":
    model = Net().cuda()
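
For reference, the script above only moves the model to the GPU and does not run a forward pass. A small timing harness along these lines (the single-channel 28x28 input shape is just an assumption for illustration) would make each step's duration visible:

#!/usr/bin/env python3
# Sketch of a timing harness around the repro above; the input shape is an
# arbitrary choice for illustration.
import time
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
    def forward(self, x):
        return self.conv2(self.conv1(x))

if __name__ == "__main__":
    start = time.time()
    model = Net().cuda()                 # CUDA context creation happens here
    print(f"model to GPU: {time.time() - start:.1f}s")
    x = torch.randn(1, 1, 28, 28).cuda()
    for i in range(3):
        start = time.time()
        model(x)
        torch.cuda.synchronize()         # wait for the kernels to actually finish
        print(f"pass {i}: {time.time() - start:.1f}s")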

My instinct is that this is caused by CUDA JIT compilation: containers spawned later on the same machine do not see the delay, which suggests the compiled kernels are cached and shared across containers.
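
To sanity-check that hypothesis, something like the snippet below could be run inside a fresh container: it prints the GPU's compute capability, the architectures the installed wheel was built for (torch.cuda.get_arch_list() only exists in newer PyTorch releases, hence the guard), and the CUDA driver's JIT-cache settings (the driver caches JIT-compiled kernels on disk, by default under ~/.nv/ComputeCache):

#!/usr/bin/env python3
# Diagnostic sketch for the JIT theory: compare the GPU's compute capability
# with what the installed wheel was built for, and show the driver's JIT
# cache configuration. get_arch_list() is missing on older PyTorch releases,
# so it is guarded with hasattr.
import os
import torch

print("device:", torch.cuda.get_device_name(0))
print("compute capability:", torch.cuda.get_device_capability(0))
if hasattr(torch.cuda, "get_arch_list"):
    print("wheel built for:", torch.cuda.get_arch_list())

# The CUDA driver's on-disk JIT cache (default ~/.nv/ComputeCache) is
# controlled by these environment variables.
for var in ("CUDA_CACHE_PATH", "CUDA_CACHE_MAXSIZE", "CUDA_CACHE_DISABLE"):
    print(var, "=", os.environ.get(var, "<unset>"))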

I have tried every combination of PyTorch 1.3.1 through 1.5.1 with CUDA 9.2 through 10.1 and I still see the issue; I have been unable to find a pip wheel for PyTorch that avoids this first-pass delay.
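
If it really is driver-side JIT compilation, the only stopgap I can think of (untested sketch, assumes the theory holds) would be to pay the cost once when the container starts, e.g. by running a tiny warm-up script before the real workload:

#!/usr/bin/env python3
# Untested warm-up sketch: trigger CUDA initialization and one conv kernel
# at container start so later processes do not pay the one-off cost.
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 32, 3, 1).cuda()
conv(torch.randn(1, 1, 28, 28).cuda())
torch.cuda.synchronize()
print("CUDA warm-up done")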

Has anyone seen something similar?