I'm using a Jetson AGX Orin running Linux from an NVMe drive, with JetPack 5.1.5 installed.
I'm trying to install JAX on the Orin. The official JAX site says NVIDIA GPUs on Linux aarch64 are supported:
JAX provides pre-built CUDA-compatible wheels for Linux x86_64 and Linux aarch64 only.
So does the Jetson Orin count as an NVIDIA GPU here, or should I follow these commands and build from source?
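For what it's worth, the Orin does have an NVIDIA Ampere GPU (compute capability 8.7), but being aarch64 may not be enough on its own: as far as I know, JAX's prebuilt aarch64 CUDA wheels are built against CUDA 12, while JetPack 5.1.x ships CUDA 11.4. A minimal sketch for telling a Jetson apart from a generic aarch64 machine, assuming (as on stock L4T images) that Jetson boards provide /etc/nv_tegra_release; describe_platform is a hypothetical helper name:

```shell
#!/usr/bin/env bash
# Sketch: classify the current machine as a Jetson (L4T) board or a generic host.
# Assumption: Jetson L4T images ship /etc/nv_tegra_release; the path can be
# overridden through the first argument (useful for testing).
describe_platform() {
  local release_file="${1:-/etc/nv_tegra_release}"
  local arch
  arch="$(uname -m)"            # "aarch64" on the Orin
  if [ -f "$release_file" ]; then
    echo "jetson-${arch}"       # L4T present: CUDA comes from JetPack
  else
    echo "generic-${arch}"
  fi
}

# Usage: describe_platform   (should print jetson-aarch64 on an AGX Orin)
```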
First, build and install jaxlib:
wget https://github.com/jax-ml/jax/archive/refs/tags/jax-v0.3.25.tar.gz
tar xfz jax-v0.3.25.tar.gz
cd jax-jax-v0.3.25
python3 build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt
Once jaxlib has been installed, install jax by running:
pip install -e .
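One thing to watch: the build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt form comes from recent JAX releases. As far as I can tell, the v0.3.25 tarball's build.py has no build subcommand and no CUDA plugin wheels; it takes flags such as --enable_cuda instead, as in the reply below. Also, build.py only writes a wheel, which still has to be pip-installed before pip install -e . A small sketch for locating that wheel, where find_jaxlib_wheel is a hypothetical helper and the directory is whatever was passed as --output_path:

```shell
#!/usr/bin/env bash
# Hypothetical helper: print the lexicographically last jaxlib wheel in a
# build output directory, or fail if there is none.
find_jaxlib_wheel() {
  local dist_dir="$1"
  local wheels=("$dist_dir"/jaxlib-*.whl)   # glob stays literal if nothing matches
  if [ ! -e "${wheels[0]}" ]; then
    echo "no jaxlib wheel found in $dist_dir" >&2
    return 1
  fi
  echo "${wheels[${#wheels[@]}-1]}"         # glob results are sorted; take the last
}

# Usage, from inside jax-jax-v0.3.25 after build.py has finished:
#   pip install "$(find_jaxlib_wheel ./jaxlib_dist)"
#   pip install -e .
```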
THX.
So the thing is: remember to install the wheel package first (pip install wheel), then build jaxlib with:
python build/build.py --python_bin_path /home/usr_name/venv/dellaPy38/bin/python --enable_cuda --noenable_mkl_dnn --noenable_rocm --noenable_tpu --noenable_remote_tpu --noenable_nccl --cuda_path=/usr/local/cuda/ --cudnn_path=/usr/lib/aarch64-linux-gnu --cuda_version=11.4 --cudnn_version=8 --cuda_compute_capabilities=8.7 --output_path=./jaxlib_dist > build_log.txt 2>&1
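Rather than hard-coding --cuda_version and --cudnn_version, they can be read off the installed toolkit. A sketch with two hypothetical helpers: one parses the "release X.Y" line that nvcc --version prints, the other reads CUDNN_MAJOR from cudnn_version.h. The header location used in the usage comment (/usr/include/cudnn_version.h) is an assumption; check where it lives on your JetPack image.

```shell
#!/usr/bin/env bash
# Hypothetical helper: extract "X.Y" from `nvcc --version` output on stdin,
# e.g. "Cuda compilation tools, release 11.4, V11.4.315" gives 11.4
cuda_release_from_nvcc() {
  sed -n 's/.*release \([0-9][0-9]*\.[0-9][0-9]*\),.*/\1/p' | head -n 1
}

# Hypothetical helper: print N from a "#define CUDNN_MAJOR N" line in a header.
cudnn_major_from_header() {
  awk '$1 == "#define" && $2 == "CUDNN_MAJOR" { print $3; exit }' "$1"
}

# Usage on the Orin (paths are assumptions for a stock JetPack 5.x install):
#   cuda_ver="$(/usr/local/cuda/bin/nvcc --version | cuda_release_from_nvcc)"
#   cudnn_ver="$(cudnn_major_from_header /usr/include/cudnn_version.h)"
#   then pass --cuda_version="$cuda_ver" --cudnn_version="$cudnn_ver" to build.py
```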
It should work after that.
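To confirm the result, one option is to ask JAX which backend it picked; jax.default_backend() should report "gpu" when the CUDA build is in use. A minimal sketch, where check_jax_backend is a hypothetical name and the argument is the venv's python:

```shell
#!/usr/bin/env bash
# Sketch: check that JAX selected its GPU backend.
# Returns 0 for "gpu", 1 for any other backend, 2 if importing jax failed.
check_jax_backend() {
  local python_bin="${1:-python3}"
  local backend
  backend="$("$python_bin" -c 'import jax; print(jax.default_backend())')" || return 2
  if [ "$backend" = "gpu" ]; then
    echo "JAX is using the GPU"
    return 0
  fi
  echo "JAX fell back to: $backend" >&2
  return 1
}

# Usage: check_jax_backend /home/usr_name/venv/dellaPy38/bin/python
```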