Jetson Nano running out of memory when running PyTorch, even with 6 GB swap file

I’m attempting to train a model using PyTorch and the Transformers library with the bert-base-uncased model. Running this on an out-of-the-box Jetson Nano resulted in the process being killed due to lack of memory.

I then set up a 6 GB swap file and attempted to train again. Memory usage peaked above 99%, hovering between 98.5% and 99.5%, while swap usage climbed to a maximum of approximately 30%. Still, the process was killed with an out-of-memory error. Any ideas on how to resolve this?

Hi aaronbriel, CUDA device memory isn’t swapped out, so my guess is that the process has already consumed all of the available physical memory for the GPU; when it requests more, it runs out and gets killed.
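If it helps to confirm where the memory is going, here is a minimal sketch of how you could watch PyTorch's own CUDA allocations during training (assuming a CUDA-enabled PyTorch build; the report_cuda_memory helper and the call sites are just illustrative):

```python
import torch

# Print how much CUDA memory PyTorch has claimed so far. On Jetson the
# GPU shares physical RAM with the CPU, and CUDA allocations cannot be
# paged out to the swap file, so these numbers show how close the model
# itself is to exhausting physical memory.
def report_cuda_memory(tag):
    allocated = torch.cuda.memory_allocated() / 1024 ** 2       # MiB currently held by tensors
    peak = torch.cuda.max_memory_allocated() / 1024 ** 2        # MiB high-water mark so far
    print(f"[{tag}] allocated: {allocated:.1f} MiB, peak: {peak:.1f} MiB")

# e.g. call report_cuda_memory("after forward") and
# report_cuda_memory("after backward") inside the training loop
```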

Is it possible to reduce the batch size, model size, or similar to lower the memory usage? If not, BERT may be too big of a model to train onboard the Nano.
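For example, a smaller batch size combined with gradient accumulation keeps the effective batch size while cutting peak memory. A rough sketch below, where train_dataset, model, and optimizer are placeholders for whatever your script already defines, and the loss-first return value assumes a pytorch-transformers model called with labels:

```python
from torch.utils.data import DataLoader

# Much smaller per-step batch; gradients are accumulated over several
# steps so the optimizer still sees an effective batch of 2 x 8 = 16.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
accumulation_steps = 8

optimizer.zero_grad()
for step, batch in enumerate(train_loader):
    input_ids, attention_mask, labels = [t.cuda() for t in batch]
    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    loss = outputs[0] / accumulation_steps   # scale so accumulated gradients average out
    loss.backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

Shortening the maximum sequence length when tokenizing is another knob that directly reduces activation memory.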

Is it possible to distribute the memory load through clustering (for example, with https://github.com/NVIDIA/k8s-device-plugin)?

If your PyTorch script is set up for distributed training (see the PyTorch ImageNet example), it may be possible; however, I’m not sure whether this would reduce the memory usage of a single instance. It may already be at its minimum for that particular model.

Note that the PyTorch wheels I built from this thread were not built with NCCL enabled, so you would need to use a different distributed backend.
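To illustrate the backend choice, here is a minimal sketch using the standard torch.distributed API with the gloo backend instead of NCCL; MyModel is a placeholder, and RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT would be set by whatever launcher or cluster tooling you use:

```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# These wheels were built without NCCL, so pick the gloo backend.
# init_method="env://" reads rank/world size/master address from the
# environment variables set by the process launcher.
dist.init_process_group(backend="gloo", init_method="env://")

model = MyModel().cuda()                                  # one GPU (device 0) per Nano
model = DistributedDataParallel(model, device_ids=[0])    # gradients are all-reduced over gloo
```

Note that this is data parallelism: each node still holds a full copy of the model, so it spreads the batch across nodes rather than shrinking the per-device footprint.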

Although I don’t think it would automatically split up the processing load (i.e., you would still need to use PyTorch distributed training), there have been some resources posted for running K8s on the Nano; see here: https://elinux.org/Jetson_Zoo#Kubernetes