Question about installing jax==0.3.25 on Jetson AGX Orin

I’m using a Jetson AGX Orin running Linux from an NVMe drive, with JetPack 5.1.5 installed.
I’m trying to install JAX on the Orin. The official JAX site says that NVIDIA GPUs on Linux aarch64 are supported:

JAX provides pre-built CUDA-compatible wheels for Linux x86_64 and Linux aarch64 only

So does the Jetson Orin count as a supported NVIDIA GPU here, or should I follow the commands below and build from source?
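The Orin's GPU is an integrated NVIDIA Ampere GPU (compute capability 8.7), but whether a pre-built wheel runs on it also depends on the CUDA version JetPack ships. A minimal sketch for checking what the board actually has, assuming a JetPack/L4T install that writes /etc/nv_tegra_release and puts nvcc under /usr/local/cuda/bin. The function names parse_l4t_release and check_platform are hypothetical:

```shell
#!/usr/bin/env bash
# Hypothetical helpers for inspecting a Jetson (L4T) system before installing JAX.

# Read an nv_tegra_release-style line on stdin, e.g.
#   "# R35 (release), REVISION: 4.1, GCID: ..., BOARD: t186ref, EABI: aarch64, ..."
# and print the L4T version as "R35.4.1". Prints nothing if the line doesn't match.
parse_l4t_release() {
  sed -n 's/^# R\([0-9][0-9]*\) (release), REVISION: \([0-9.][0-9.]*\).*/R\1.\2/p' | head -n 1
}

# Print architecture, L4T release and CUDA toolkit version (paths are assumptions).
check_platform() {
  echo "Architecture: $(uname -m)"
  if [ -r /etc/nv_tegra_release ]; then
    echo "L4T release: $(parse_l4t_release < /etc/nv_tegra_release)"
  else
    echo "No /etc/nv_tegra_release: probably not a Jetson/L4T system"
  fi
  if [ -x /usr/local/cuda/bin/nvcc ]; then
    # The "release X.Y" line is what --cuda_version should match in a source build.
    /usr/local/cuda/bin/nvcc --version | grep -i release
  else
    echo "nvcc not found under /usr/local/cuda/bin"
  fi
}
```

After sourcing the script, run check_platform. On JetPack 5.1.x the L4T release should start with R35 and nvcc should report CUDA 11.4.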

First build and install jaxlib

wget https://github.com/jax-ml/jax/archive/refs/tags/jax-v0.3.25.tar.gz
tar xfz jax-v0.3.25.tar.gz
cd jax-jax-v0.3.25

python3 build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt

# Once jaxlib has been installed, install jax by running:

pip install -e .

THX.

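So it turned out you need to remember to pip install wheel first, and then build with the older build.py flags that 0.3.25 expects (the build subcommand and --wheels option in the question belong to newer JAX releases):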

python build/build.py --python_bin_path /home/usr_name/venv/dellaPy38/bin/python \
  --enable_cuda --noenable_mkl_dnn --noenable_rocm --noenable_tpu --noenable_remote_tpu --noenable_nccl \
  --cuda_path=/usr/local/cuda/ --cudnn_path=/usr/lib/aarch64-linux-gnu \
  --cuda_version=11.4 --cudnn_version=8 --cuda_compute_capabilities=8.7 \
  --output_path=./jaxlib_dist > build_log.txt 2>&1
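The steps can also be kept in one script so that installing wheel, capturing the log, and installing the built jaxlib wheel before jax (the "once jaxlib has been installed" step from the question) are not forgotten. This is a sketch, not an official script: build_jaxlib is a hypothetical helper, the flags are copied from the command above, the venv path is yours to supply, and DRY_RUN=1 only prints the command.

```shell
#!/usr/bin/env bash
# Hypothetical wrapper around the jax 0.3.25 build.py invocation above.
# Run from the top of the extracted jax-jax-v0.3.25 source tree.
# Usage: build_jaxlib /path/to/venv/bin/python [output_dir]
build_jaxlib() {
  local py="${1:?usage: build_jaxlib /path/to/venv/bin/python [output_dir]}"
  local out="${2:-./jaxlib_dist}"
  # Same flags as the command above: CUDA 11.4, cuDNN 8, Orin's compute capability 8.7.
  local cmd=("$py" build/build.py --python_bin_path "$py"
    --enable_cuda --noenable_mkl_dnn --noenable_rocm --noenable_tpu
    --noenable_remote_tpu --noenable_nccl
    --cuda_path=/usr/local/cuda/ --cudnn_path=/usr/lib/aarch64-linux-gnu
    --cuda_version=11.4 --cudnn_version=8 --cuda_compute_capabilities=8.7
    --output_path="$out")

  if [ "${DRY_RUN:-0}" = 1 ]; then
    echo "${cmd[*]}"   # print the command instead of running it
    return 0
  fi

  "$py" -m pip install wheel || return 1
  # stdout and stderr both go to build_log.txt, like "> build_log.txt 2>&1".
  if ! "${cmd[@]}" > build_log.txt 2>&1; then
    echo "jaxlib build failed; see build_log.txt" >&2
    return 1
  fi
  # Install the freshly built jaxlib wheel, then jax itself from this tree.
  "$py" -m pip install "$out"/jaxlib-*.whl || return 1
  "$py" -m pip install -e .
}
```

For example, DRY_RUN=1 build_jaxlib /home/usr_name/venv/dellaPy38/bin/python prints the full command, which is a cheap check before committing to a long build on the Orin.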

With that, the build should go through.
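Once jaxlib and jax are installed, a quick import check confirms the CUDA build is actually in use rather than a CPU fallback. A minimal sketch, assuming it is run with the same venv Python used for the build (check_jax is a hypothetical helper name). jax.__version__, jax.default_backend() and jax.devices() are standard JAX calls:

```shell
#!/usr/bin/env bash
# Hypothetical helper: import jax with the given interpreter and report what it sees.
# Usage: check_jax [/path/to/venv/bin/python]
check_jax() {
  local py="${1:-python3}"
  "$py" - <<'EOF'
import jax
import jax.numpy as jnp

print("jax version:", jax.__version__)
print("default backend:", jax.default_backend())  # "gpu" for a working CUDA build
print("devices:", jax.devices())
# Tiny computation to make sure work actually runs on the default device.
x = jnp.arange(4.0)
print("sum of squares:", float(jnp.sum(x * x)))
EOF
}
```

If the default backend prints cpu instead of gpu, JAX did not pick up the CUDA-enabled jaxlib, and build_log.txt or the CUDA/cuDNN paths passed to build.py are the first places to look.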
