JAX on Jetson Nano

Is there any plan to support Google’s JAX library, which is rapidly gaining traction at the expense of TensorFlow, on the Jetson Nano or other boards?

Building from source on the standard JetPack (R32 Rev 5.1) is problematic in two ways: a) NumPy is too old, and b) Python 3.6.9 is not supported; JAX needs 3.7+.

I tried replacing the standard Python with 3.9, but this screwed everything up with pip3.

Hi,

Sorry, we don’t have a prebuilt package for it.
Maybe others can share their experience.

Thanks.

Okay, that’s fine, but is there a recommended “safe” way to upgrade from Python 3.6.9 to something newer? I would also need a new pip3 and new libraries such as NumPy.

Good to know it can work on Jetson.
Thanks for sharing!

Yeah, unfortunately it actually did not work with 3.6.9. I installed 3.9 and it worked. I’m going to collate some virtualenv steps so that one can use 3.9 without breaking the rest of JetPack, and will update the above post.

I can confirm, though, that a 4000x4000 matrix dot product is 410x faster using jax.numpy on the Nano than using jax.numpy on the Raspberry Pi (which doesn’t have a supported GPU). Moreover, it’s also 800x faster than NumPy, so we’re talking about a very serious math speedup.
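A rough way to run this kind of comparison yourself is sketched below; the helper name best_time is hypothetical, not part of JAX. One detail matters when timing JAX: it dispatches work asynchronously, so the clock should stop only after block_until_ready() returns, and the first call (which may include compilation) is kept out of the timing as a warm-up.

```python
import time
from typing import Any, Callable


def _run_to_completion(fn: Callable[..., Any], args: tuple) -> Any:
    """Call fn(*args) and wait for the result if it is computed asynchronously."""
    result = fn(*args)
    # JAX arrays expose block_until_ready(); plain Python/NumPy results do not.
    wait = getattr(result, "block_until_ready", None)
    if callable(wait):
        wait()
    return result


def best_time(fn: Callable[..., Any], *args: Any, repeats: int = 5, warmup: bool = True) -> float:
    """Return the fastest wall-clock time in seconds of fn(*args) over `repeats` runs."""
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    if warmup:
        # The first call may include tracing/compilation, so keep it out of the timing.
        _run_to_completion(fn, args)
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        _run_to_completion(fn, args)
        best = min(best, time.perf_counter() - start)
    return best
```

For example, compare best_time(jnp.dot, x, x) with x = jnp.ones((4000, 4000), dtype=jnp.float32) against best_time(np.dot, y, y) with y = np.ones((4000, 4000), dtype=np.float32). jax.numpy defaults to float32 while NumPy defaults to float64, so fixing the dtype on both sides keeps the comparison fair.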

INSTRUCTIONS FOR INSTALLING GOOGLE JAX ML ON 4GB JETSON NANO WITH SSD

SORTING OUT SWAP SPACE

Before you do any compiling of JAX, it’s really important to have plenty of swap space, because the Jetson’s 4GB of RAM is not nearly enough. Even 5GB of swap was not enough: gcc got killed for running out of memory. I’ve gone with 10GB. For this reason you will almost certainly have to run your Jetson off a USB 3.0 SSD rather than the SD card. If you want to do this, JetsonHacks has a good tutorial and helper utilities (“Jetson Nano - Boot from USB”). HOWEVER, do NOT take the final bit of advice to remove the SD card: with the latest JetPack this causes SSD boot to fail. You have to keep the SD card in while the SSD is plugged into USB.

Unfortunately, the Jetson is configured out of the box to use 2GB of ZRAM swap, which is almost useless here: ZRAM is compressed swap held in RAM, so it only gains you whatever the compression saves. You need to get rid of ZRAM:

sudo systemctl disable nvzramconfig
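After the reboot, instead of eyeballing htop you could read /proc/swaps directly. The sketch below is one way to do that; the function names are hypothetical. It assumes the usual five-column /proc/swaps layout (Filename, Type, Size, Used, Priority, sizes in KiB) and treats any device whose name starts with "zram" as ZRAM.

```python
import os
from typing import List, NamedTuple


class SwapEntry(NamedTuple):
    name: str
    kind: str
    size_kib: int  # /proc/swaps reports sizes in KiB
    used_kib: int
    priority: int


def parse_proc_swaps(text: str) -> List[SwapEntry]:
    """Parse the contents of /proc/swaps into one entry per active swap area."""
    entries = []
    for line in text.splitlines()[1:]:  # skip the column-header line
        fields = line.split()
        if len(fields) >= 5:
            name, kind, size, used, prio = fields[:5]
            entries.append(SwapEntry(name, kind, int(size), int(used), int(prio)))
    return entries


def zram_devices(entries: List[SwapEntry]) -> List[str]:
    """Names of swap areas backed by zram (e.g. /dev/zram0)."""
    return [e.name for e in entries if os.path.basename(e.name).startswith("zram")]


if __name__ == "__main__" and os.path.exists("/proc/swaps"):
    with open("/proc/swaps") as f:
        active = parse_proc_swaps(f.read())
    print("active swap areas:", [e.name for e in active] or "none")
    print("zram still enabled:", bool(zram_devices(active)))
```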

Probably best to reboot the Nano after this step, just in case (I did). Then use htop to check that you have no swap. Finally, create a 10GB swap file using these instructions:

How to Add Swap Space on Ubuntu 18.04 (linuxize.com, 28 Nov 2018)

Make sure you replace 1G with 10G in the fallocate stage.
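To confirm the new swap file is active without htop, a small check of /proc/meminfo is sketched below; the function names and the 9.5 GiB threshold are assumptions of this sketch. A 10G swap file shows up slightly under 10 GiB in SwapTotal because mkswap reserves a header page.

```python
import os
from typing import Dict


def meminfo_kib(text: str) -> Dict[str, int]:
    """Map each /proc/meminfo field to its numeric value (mostly in kB)."""
    values = {}
    for line in text.splitlines():
        key, sep, rest = line.partition(":")
        parts = rest.split()
        if sep and parts and parts[0].isdigit():
            values[key.strip()] = int(parts[0])
    return values


def swap_total_gib(text: str) -> float:
    """Total configured swap in GiB (meminfo's "kB" is really KiB)."""
    return meminfo_kib(text).get("SwapTotal", 0) / (1024 * 1024)


def has_enough_swap(text: str, min_gib: float = 9.5) -> bool:
    # Assumed threshold: just below the 10GB used in this guide.
    return swap_total_gib(text) >= min_gib


if __name__ == "__main__" and os.path.exists("/proc/meminfo"):
    with open("/proc/meminfo") as f:
        info = f.read()
    print("swap total: %.2f GiB, enough for the build: %s"
          % (swap_total_gib(info), has_enough_swap(info)))
```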

Now that you have a 10GB swap file (check it with htop), there’s one more thing I recommend: switch the Jetson Nano to its low-power 5 watt mode. The MAXN 10 watt mode runs the compilation on all 4 cores at 100% for many hours, and, especially if you also have a fan and an SSD attached, the power draw might be more than your power supply can handle. That was certainly the case for me. So use the NVIDIA power-mode menu in the taskbar to switch it to 5 watt mode.
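If you prefer the command line to the taskbar menu, the nvpmodel tool does the same job. The sketch below only builds (and by default just prints) the command; the mode IDs 0 = MAXN and 1 = 5W are assumed from the Nano’s default /etc/nvpmodel.conf, so run nvpmodel -q first to confirm them on your board. The helper names are hypothetical.

```python
import subprocess
from typing import List

# Assumed mode IDs from the Jetson Nano's default /etc/nvpmodel.conf:
# 0 = MAXN (10 W), 1 = 5 W. Check with `nvpmodel -q` on your board.
NANO_POWER_MODES = {"MAXN": 0, "5W": 1}


def nvpmodel_command(mode: str) -> List[str]:
    """Build the argv that switches the Nano to the named power mode."""
    if mode not in NANO_POWER_MODES:
        raise ValueError("unknown mode %r; expected one of %s"
                         % (mode, sorted(NANO_POWER_MODES)))
    return ["sudo", "nvpmodel", "-m", str(NANO_POWER_MODES[mode])]


def set_power_mode(mode: str, dry_run: bool = True) -> List[str]:
    """Print the command (dry run) or execute it; returns the argv either way."""
    cmd = nvpmodel_command(mode)
    if dry_run:
        print(" ".join(cmd))
    else:
        subprocess.run(cmd, check=True)
    return cmd


if __name__ == "__main__":
    set_power_mode("5W")  # dry run: prints the command without running it
```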

CUDA

Now make sure your PATH and LD_LIBRARY_PATH include the CUDA directories. Add these lines to your ~/.bashrc:

export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}

Then reload it with source ~/.bashrc.
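To double-check that a new shell really picked up those exports, a quick sketch follows; missing_cuda_paths is a hypothetical helper, and the two directories are the ones from the export lines above.

```python
import os
import shutil
from typing import List, Mapping

CUDA_BIN = "/usr/local/cuda/bin"
CUDA_LIB = "/usr/local/cuda/lib64"


def missing_cuda_paths(env: Mapping[str, str]) -> List[str]:
    """List the CUDA directories absent from PATH / LD_LIBRARY_PATH in `env`."""
    problems = []
    if CUDA_BIN not in env.get("PATH", "").split(":"):
        problems.append("PATH is missing " + CUDA_BIN)
    if CUDA_LIB not in env.get("LD_LIBRARY_PATH", "").split(":"):
        problems.append("LD_LIBRARY_PATH is missing " + CUDA_LIB)
    return problems


if __name__ == "__main__":
    for problem in missing_cuda_paths(os.environ) or ["CUDA paths look OK"]:
        print(problem)
    # nvcc lives in /usr/local/cuda/bin, so this is None until PATH is fixed.
    print("nvcc found at:", shutil.which("nvcc"))
```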

FIXING PYTHON

Python 3.6.9 and the related stack on the Jetson will produce a build, but it won’t work. You have to run a later version of Python; I succeeded with Python 3.9. Note, however, that you don’t want to replace the Jetson’s system Python 3.6.9 as the default python3, as this will break a lot of stuff. Instead, we’ll use a virtualenv. But first, let’s install Python 3.9 without making it the default Python:

sudo add-apt-repository ppa:deadsnakes/ppa 
sudo apt update
sudo apt install python3.9 python3.9-dev
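One way to confirm that python3 still points at the system interpreter while python3.9 is now available is sketched below; the helper names are hypothetical, and the 3.7 minimum comes from this thread. It sticks to subprocess arguments that also work on the system Python 3.6.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

Version = Tuple[int, int, int]


def parse_python_version(output: str) -> Optional[Version]:
    """Extract (major, minor, micro) from text such as 'Python 3.6.9'."""
    match = re.search(r"Python (\d+)\.(\d+)\.(\d+)", output)
    if match is None:
        return None
    major, minor, micro = (int(g) for g in match.groups())
    return (major, minor, micro)


def interpreter_version(name: str) -> Optional[Version]:
    """Run `name --version` if it is on PATH; None if it is not installed."""
    exe = shutil.which(name)
    if exe is None:
        return None
    # stderr is merged because Python 2 printed its version there;
    # stdout=PIPE / universal_newlines keep this working on Python 3.6.
    proc = subprocess.run([exe, "--version"], stdout=subprocess.PIPE,
                          stderr=subprocess.STDOUT, universal_newlines=True)
    return parse_python_version(proc.stdout)


def meets_minimum(version: Optional[Version], minimum: Tuple[int, int] = (3, 7)) -> bool:
    """True if `version` is at least `minimum` (major, minor)."""
    return version is not None and version[:2] >= minimum


if __name__ == "__main__":
    for name in ("python3", "python3.9"):
        version = interpreter_version(name)
        verdict = "OK for building JAX" if meets_minimum(version) else "too old or missing"
        print(name, version, verdict)
```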

This will install python3.9 in your /usr/bin directory. You’ll notice that typing python3 still gives you the system Python 3.6.9, which is what we want; to use Python 3.9 you have to run python3.9 explicitly. We’ll need a bunch of other packages for Python 3.9, and for that we’ll use a virtualenv. To install virtualenv we must first install pip3:

sudo apt install python3-pip

Now:

sudo pip3 install virtualenv

Remember to use sudo here, because we want virtualenv available to all users.

Next, create a Python 3.9 virtualenv; here I’ll call it “py39”:

virtualenv -p /usr/bin/python3.9 py39

Activate it:

source ./py39/bin/activate
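To make sure the shell is really using the py39 environment rather than the system Python, a small check is sketched below; the function names are hypothetical. It relies on the documented behaviour that the stdlib venv module and virtualenv 20+ make sys.prefix differ from sys.base_prefix, while legacy virtualenv (before 20) set sys.real_prefix.

```python
import sys


def is_virtualenv(prefix: str, base_prefix: str, has_real_prefix: bool = False) -> bool:
    """True if an interpreter with these attributes runs inside a virtualenv."""
    # venv and virtualenv >= 20: sys.prefix != sys.base_prefix.
    # Legacy virtualenv (< 20): sys.real_prefix exists instead.
    return has_real_prefix or prefix != base_prefix


def in_virtualenv() -> bool:
    """Apply is_virtualenv to the interpreter running this script."""
    return is_virtualenv(sys.prefix, getattr(sys, "base_prefix", sys.prefix),
                         hasattr(sys, "real_prefix"))


if __name__ == "__main__":
    print("interpreter:", sys.executable, sys.version.split()[0])
    print("inside a virtualenv:", in_virtualenv())
```

Run it with python3 after activating: inside py39 it should report a 3.9 interpreter under the py39 directory.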

This is now virgin territory, with no NumPy, SciPy or anything else (even though they’re installed for the system Python). So:

python3 -m pip install numpy scipy six wheel
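Before starting a 12-hour build, it may be worth confirming that the packages landed in the virtualenv’s interpreter. The sketch below uses importlib.util.find_spec, which returns None for a top-level module it cannot locate; missing_modules is a hypothetical name.

```python
import importlib.util
from typing import Iterable, List


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names this interpreter cannot find."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(["numpy", "scipy", "six", "wheel"])
    print("missing:", missing if missing else "none")
```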

BUILD

Now you’re ready to build jaxlib from source. Install the necessary prerequisites with apt:

sudo apt install g++

Clone the jax repository and cd into it:

git clone https://github.com/google/jax
cd jax

You’re finally ready to compile:

python3 build/build.py --enable_cuda

This step will take about 12 hours; it’s a humongous codebase. Your Jetson’s swap usage will go above 5GB at times.

Once it has compiled, install the resulting jaxlib wheel:

pip3 install dist/*.whl

… but all this was just to get jaxlib up and running, not jax itself. There’s one last step, run from the root of the cloned jax repository:

pip3 install -e .

Done.
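To check the result from inside the py39 virtualenv, a small sketch follows; the helper names are hypothetical, while jax.devices() and jax.numpy are real JAX APIs. A jaxlib built with --enable_cuda should list a GPU device rather than only a CPU one; the exact platform string can vary between JAX versions.

```python
import importlib
from typing import List, Optional


def jax_device_platforms() -> Optional[List[str]]:
    """Platforms of the devices JAX can see, or None if jax/jaxlib won't import."""
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return None
    return [device.platform for device in jax.devices()]


def matmul_smoke_test() -> float:
    """Tiny 2x2 matrix product through jax.numpy; should return 22.0."""
    import jax.numpy as jnp
    x = jnp.arange(4.0).reshape(2, 2)  # [[0, 1], [2, 3]]
    return float(jnp.sum(x @ x))       # x @ x == [[2, 3], [6, 11]]


if __name__ == "__main__":
    platforms = jax_device_platforms()
    if platforms is None:
        print("jax is not importable from this interpreter")
    else:
        print("devices:", platforms)
        print("smoke test:", matmul_smoke_test())
```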

I have posted updated instructions above, which use a virtualenv.

ERRATA on my instructions: the second-to-last line should read

pip3 install -e .

I left out the “.”, which makes all the difference. I would edit the post, but these forums seem to have a time limit on editing, and erasing/reposting gives me warnings.
