INSTRUCTIONS FOR INSTALLING GOOGLE JAX ML ON 4GB JETSON NANO WITH SSD
SORTING OUT SWAP SPACE
Before you compile JAX, it’s really important to have plenty of swap space, because the Jetson’s 4GB of RAM is nowhere near enough. Even 5GB of swap is not enough: with that much, gcc got killed for running out of memory. I’ve gone with 10GB. For this reason you will almost certainly have to run your Jetson off a USB 3.0 SSD rather than the SD card. (If you want to do this, JetsonHacks has a good tutorial and helper utilities: Jetson Nano - Boot from USB - JetsonHacks. HOWEVER, do NOT take the final bit of advice to remove the SD card. With the latest JetPack this causes SSD boot to fail. You have to keep the SD card in while the SSD is plugged into USB.)
Unfortunately, the Jetson is natively configured to use 2 GB of ZRAM swap, which is almost useless (other than a mild amount of compression). You need to get rid of ZRAM:
sudo systemctl disable nvzramconfig
Probably best to reboot the Nano after this step, just in case (I did). Then use htop to check that you have no swap. Finally, create a 10GB swap file by following the linuxize.com tutorial (28 Nov 18) on adding a swap file to Ubuntu 18.04. Make sure you set the size to 10G in the fallocate stage.
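For reference, the swap-file steps can be sketched as a small script. This is a sketch under assumptions, not a copy of the tutorial: the /swapfile path, the SWAPFILE, SWAPSIZE and DRY_RUN variables, and the run, make_swapfile and swap_total_mb helpers are all names made up here for illustration. By default it only prints the commands, so you can review them before running anything.

```shell
# Assumed defaults (not from the original post): swap file at /swapfile, 10G in size.
SWAPFILE="${SWAPFILE:-/swapfile}"
SWAPSIZE="${SWAPSIZE:-10G}"

# With DRY_RUN=1 (the default) a command is only printed; with DRY_RUN=0 it is executed.
run() {
  if [ "${DRY_RUN:-1}" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Allocate the file, restrict its permissions, format it as swap, and enable it.
make_swapfile() {
  run sudo fallocate -l "$SWAPSIZE" "$SWAPFILE"
  run sudo chmod 600 "$SWAPFILE"
  run sudo mkswap "$SWAPFILE"
  run sudo swapon "$SWAPFILE"
}

# Print the total swap in MB, read from /proc/meminfo (or from a file given as $1).
swap_total_mb() {
  awk '/^SwapTotal:/ {print int($2/1024)}' "${1:-/proc/meminfo}"
}

make_swapfile
```

Running it with DRY_RUN=0 executes the commands for real, after which swap_total_mb should report roughly the size you asked for. To keep the swap file across reboots you would also need the /etc/fstab entry described in the tutorial.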
Now that you have a 10GB swap file (check it with htop), there’s one more thing I recommend: switch the Jetson Nano to its low-power 5 watt mode. That’s because the MAXN 10 watt mode will run the compilation on all 4 cores at 100% for many hours, and especially if you have a fan and an SSD attached, the power draw might be more than your power supply can handle. That was certainly the case for me. So use the Nvidia menu in the taskbar to switch to 5 watt mode.
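If you prefer the terminal to the taskbar menu, the power mode can also be switched with the nvpmodel tool. A minimal sketch, assuming the usual Jetson Nano numbering in which mode 1 is the 5 watt profile and mode 0 is MAXN; MODE_5W is just an illustrative variable name. Check the query output on your own board before relying on it.

```shell
# Assumption: on the Jetson Nano, nvpmodel mode 1 is the 5 watt profile (mode 0 is MAXN).
MODE_5W=1

if command -v nvpmodel >/dev/null 2>&1; then
  # Switch to the assumed 5 watt mode, then print the active mode to confirm.
  sudo nvpmodel -m "$MODE_5W"
  sudo nvpmodel -q
else
  echo "nvpmodel not found; this does not look like a Jetson"
fi
```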
Now ensure your PATH and LD_LIBRARY_PATH are set for CUDA support. Add the corresponding export lines to your ~/.bashrc, then re-source it with:
source ~/.bashrc
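The exact lines depend on where CUDA lives on your system. As a sketch only, assuming the toolkit is under /usr/local/cuda (an assumption, so check your own install), and using CUDA_HOME as an illustrative variable name, the additions to ~/.bashrc might look like this:

```shell
# Assumption: the CUDA toolkit is installed under /usr/local/cuda.
export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}"

# Put the CUDA binaries (nvcc etc.) first on the PATH.
export PATH="$CUDA_HOME/bin:$PATH"

# Put the CUDA libraries first on the library path, avoiding a trailing colon if it was empty.
export LD_LIBRARY_PATH="$CUDA_HOME/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
```

After re-sourcing, running nvcc --version is a quick way to see whether the PATH part took effect.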
Building against the Jetson’s Python 3.6.9 and its related stack will produce a build, but it won’t work. You have to run a later version of Python, and I succeeded using Python 3.9. However, note that you don’t want to replace the Jetson’s system Python 3.6.9 as the default python3, as this will break a lot of stuff. Instead, we’ll use a virtualenv. But before we do that, let’s install Python 3.9 without making it the default Python:
sudo add-apt-repository ppa:deadsnakes/ppa
sudo apt update
sudo apt install python3.9 python3.9-dev
This will install python3.9 in your /usr/bin directory. However, you’ll notice that when you type python3 you still get the system Python 3.6.9, which is what we want. To use Python 3.9 you have to explicitly run python3.9. We’ll also need a bunch of other things for Python 3.9, and for that we’ll use a virtualenv. To install virtualenv we must first install pip3:
sudo apt install python3-pip
sudo pip3 install virtualenv
Remember to sudo here because we need it for all users.
Next we create a Python 3.9 virtualenv; here I’ll call it “py39”:
virtualenv -p /usr/bin/python3.9 py39
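This is now virgin territory with no numpy, scipy or anything else (even though they’re installed for the system Python). Activate the virtualenv first, so that python3 and pip3 refer to its Python 3.9, then install the basics:
source py39/bin/activate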
python3 -m pip install numpy scipy six wheel
Now you’re ready to install jaxlib from source. Apt install the necessary prerequisites:
sudo apt install g++
Clone the jax repository and cd into it:
git clone https://github.com/google/jax
cd jax
You’re finally ready to compile:
python3 build/build.py --enable_cuda
This step will take about 12 hours. It’s a humongous codebase. Your Jetson’s swap usage will go above 5GB at times. Once the build finishes, install the resulting wheel:
pip3 install dist/*.whl
… but all this was just to get jaxlib up and running, not jax itself. There’s one last step:
pip3 install -e .
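A quick sanity check, run inside the activated virtualenv, is to import jax and ask which devices it can see. This is only a sketch: check_jax is a hypothetical helper name, and what jax.devices() prints will depend on your build.

```shell
# Try to import jax with the given interpreter (default: python3) and list its devices.
# Returns non-zero if the interpreter is missing or the import fails.
check_jax() {
  "${1:-python3}" -c 'import jax; print(jax.devices())'
}

if check_jax python3; then
  echo "jax imported successfully"
else
  echo "jax could not be imported; is the py39 virtualenv activated?"
fi
```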