PyTorch and Python 3.8 on Jetson NX

I built a new Docker container based on nvcr.io/nvidia/l4t-ml:r32.5.0-py3, copied the yolov5 files into it, and upgraded the Python version to 3.8.
It runs fine, but it only uses the CPU, which slows everything down heavily. Unfortunately, I couldn't figure out how to make PyTorch use the GPU on the Jetson NX.
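
A minimal check inside the container (using the python3.8 that the Dockerfile below installs) confirms that PyTorch can't see the GPU:

python3.8 -c "import torch; print(torch.cuda.is_available())"
# prints False when PyTorch runs CPU-only

Here is the Dockerfile: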

FROM nvcr.io/nvidia/l4t-ml:r32.5.0-py3

RUN apt-get update && apt-get install -y python3.8 python3.8-dev python3.8-venv curl python3-tk

RUN python3.8 -m pip install -U pip setuptools

RUN python3.8 -m pip install --ignore-installed PyYAML

COPY requirements.txt requirements.txt

RUN python3.8 -m pip install -r requirements.txt

The requirements.txt is slightly modified from the YOLOv5 one:

# base ----------------------------------------
matplotlib>=3.2.2
numpy>=1.18.5
opencv-python>=4.1.2
Pillow
#PyYAML>=5.3.1
scipy>=1.4.1
torch>=1.7.0
torchvision>=0.8.1
tqdm>=4.41.0

# logging -------------------------------------
tensorboard>=2.4.1
# wandb

# plotting ------------------------------------
seaborn>=0.11.0
pandas

Do any of you have hints for me on how to accomplish that?

Hi @ppn, if you are installing PyTorch from pip, it won’t be built with CUDA support (it will be CPU only). We have pre-built PyTorch wheels for Python 3.6 (with GPU support) in this thread, but for Python 3.8 you will need to build PyTorch from source.
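
(A quick way to tell which kind of wheel you have: torch.version.cuda is None on a CPU-only build.)

python3.8 -c "import torch; print(torch.version.cuda)"
# None -> built without CUDA; a CUDA build prints e.g. 10.2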

See this thread for further info about building PyTorch for Python 3.8:

Yeah, I found all that stuff for 3.6, and I was somehow hoping I wouldn't need to compile it myself. At least it's just PyTorch to compile. Thanks for the answer.

Hm, I compiled it, but it's still using the CPU. What could I have done wrong in the container setup? I am running it with docker run -it --runtime nvidia image /bin/bash. Do I need to install drivers?

Did you compile the container or the PyTorch wheel? You will need to build the PyTorch wheel itself with Python 3.8. Check the config summary soon after the build starts to make sure it found CUDA.
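
In rough terms, the wheel build looks like this (a sketch; the exact wheel filename will differ):

python3.8 setup.py bdist_wheel        # run inside the pytorch source tree
python3.8 -m pip install dist/torch-*.whl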

I am a bit confused, I guess. Compile the container? I used a container to compile PyTorch in, since I did not want to install Python 3.8 on the host machine, to keep it clean.

FROM nvcr.io/nvidia/l4t-ml:r32.5.0-py3

RUN apt-get update && apt-get install -y python3.8 python3.8-dev python3.8-venv curl python3-tk

RUN python3.8 -m pip install -U pip setuptools

and then ran in that container what was written in your linked thread:

git clone --recursive --branch v1.9.0 https://github.com/pytorch/pytorch
cd pytorch
python3.8 -m pip install -r requirements.txt
python3.8 setup.py install

How do I check the CUDA thing?
I tested /usr/local/cuda/bin/nvcc --version on both the host and in the container, and both return the same:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_21:14:42_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89

I feel like I missed a switch or trigger somewhere to enable CUDA.

I see, ok gotcha - that makes sense.

Have you set your Docker daemon’s default-runtime to nvidia? You may need this in order for PyTorch to detect CUDA inside your container during docker build operations. See here in order to set it:
https://github.com/dusty-nv/jetson-containers#docker-default-runtime (remember to reboot or restart your Docker service after making this change in order for it to take effect)
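
With that change, /etc/docker/daemon.json should end up looking like this (per the link above):

{
    "runtimes": {
        "nvidia": {
            "path": "nvidia-container-runtime",
            "runtimeArgs": []
        }
    },
    "default-runtime": "nvidia"
}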

Also, are you setting the environment variables from the thread I linked to?

$ export USE_NCCL=0
$ export USE_DISTRIBUTED=0                # skip setting this if you want to enable OpenMPI backend
$ export USE_QNNPACK=0
$ export USE_PYTORCH_QNNPACK=0
$ export TORCH_CUDA_ARCH_LIST="5.3;6.2;7.2"
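
One more tip: if an earlier configure run already cached USE_CUDA=OFF, clean the build tree first so the new environment variables take effect:

$ python3.8 setup.py clean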

A minute or two after you start the build, you should see a build summary - make sure it prints out USE_CUDA = ON and your CUDA/cuDNN paths:

-- ******** Summary ********
-- General:
--   CMake version         : 3.10.2
--   CMake command         : /usr/bin/cmake
--   System                : Linux
--   C++ compiler          : /usr/bin/c++
--   C++ compiler id       : GNU
--   C++ compiler version  : 7.5.0
--   Using ccache if found : ON
--   Found ccache          : CCACHE_PROGRAM-NOTFOUND
--   CXX flags             :  -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -DMISSING_ARM_VST1 -DMISSING_ARM_VLD1 -Wno-stringop-overflow
--   Build type            : Release
--   Compile definitions   : ONNX_ML=1;ONNXIFI_ENABLE_EXT=1;ONNX_NAMESPACE=onnx_torch;HAVE_MMAP=1;_FILE_OFFSET_BITS=64;HAVE_SHM_OPEN=1;HAVE_SHM_UNLINK=1;HAVE_MALLOC_USABLE_SIZE=1;USE_EXTERNAL_MZCRC;MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS
--   CMAKE_PREFIX_PATH     : /usr/lib/python3.6/site-packages;/usr/local/cuda-10.2
--   CMAKE_INSTALL_PREFIX  : /media/nvidia/NVME/pytorch/pytorch-v1.9.0/torch
--   USE_GOLD_LINKER       : OFF
-- 
--   TORCH_VERSION         : 1.9.0
--   CAFFE2_VERSION        : 1.9.0
--   BUILD_CAFFE2          : ON
--   BUILD_CAFFE2_OPS      : ON
--   BUILD_CAFFE2_MOBILE   : OFF
--   BUILD_STATIC_RUNTIME_BENCHMARK: OFF
--   BUILD_TENSOREXPR_BENCHMARK: OFF
--   BUILD_BINARY          : OFF
--   BUILD_CUSTOM_PROTOBUF : ON
--     Link local protobuf : ON
--   BUILD_DOCS            : OFF
--   BUILD_PYTHON          : True
--     Python version      : 3.6.9
--     Python executable   : /usr/bin/python3
--     Pythonlibs version  : 3.6.9
--     Python library      : /usr/lib/libpython3.6m.so.1.0
--     Python includes     : /usr/include/python3.6m
--     Python site-packages: lib/python3.6/site-packages
--   BUILD_SHARED_LIBS     : ON
--   CAFFE2_USE_MSVC_STATIC_RUNTIME     : OFF
--   BUILD_TEST            : True
--   BUILD_JNI             : OFF
--   BUILD_MOBILE_AUTOGRAD : OFF
--   BUILD_LITE_INTERPRETER: OFF
--   INTERN_BUILD_MOBILE   : 
--   USE_BLAS              : 1
--     BLAS                : open
--   USE_LAPACK            : 1
--     LAPACK              : open
--   USE_ASAN              : OFF
--   USE_CPP_CODE_COVERAGE : OFF
--   USE_CUDA              : ON
--     Split CUDA          : OFF
--     CUDA static link    : OFF
--     USE_CUDNN           : ON
--     CUDA version        : 10.2
--     cuDNN version       : 8.0.0
--     CUDA root directory : /usr/local/cuda-10.2
--     CUDA library        : /usr/local/cuda-10.2/lib64/stubs/libcuda.so
--     cudart library      : /usr/local/cuda-10.2/lib64/libcudart.so
--     cublas library      : /usr/lib/aarch64-linux-gnu/libcublas.so
--     cufft library       : /usr/local/cuda-10.2/lib64/libcufft.so
--     curand library      : /usr/local/cuda-10.2/lib64/libcurand.so
--     cuDNN library       : /usr/lib/aarch64-linux-gnu/libcudnn.so
--     nvrtc               : /usr/local/cuda-10.2/lib64/libnvrtc.so
--     CUDA include path   : /usr/local/cuda-10.2/include
--     NVCC executable     : /usr/local/cuda-10.2/bin/nvcc
--     NVCC flags          : -Xfatbin;-compress-all;-DONNX_NAMESPACE=onnx_torch;-gencode;arch=compute_53,code=sm_53;-gencode;arch=compute_62,code=sm_62;-gencode;arch=compute_72,code=sm_72;-Xcudafe;--diag_suppress=cc_clobber_ignored,--diag_suppress=integer_sign_change,--diag_suppress=useless_using_declaration,--diag_suppress=set_but_not_used,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=implicit_return_from_non_void_function,--diag_suppress=unsigned_compare_with_zero,--diag_suppress=declared_but_not_referenced,--diag_suppress=bad_friend_decl;-std=c++14;-Xcompiler;-fPIC;--expt-relaxed-constexpr;--expt-extended-lambda;-Wno-deprecated-gpu-targets;--expt-extended-lambda;-Xcompiler;-fPIC;-DCUDA_HAS_FP16=1;-D__CUDA_NO_HALF_OPERATORS__;-D__CUDA_NO_HALF_CONVERSIONS__;-D__CUDA_NO_BFLOAT16_CONVERSIONS__;-D__CUDA_NO_HALF2_OPERATORS__
--     CUDA host compiler  : /usr/bin/cc
--     NVCC --device-c     : OFF
--     USE_TENSORRT        : OFF
--   USE_ROCM              : OFF
--   USE_EIGEN_FOR_BLAS    : ON
--   USE_FBGEMM            : OFF
--     USE_FAKELOWP          : OFF
--   USE_KINETO            : ON
--   USE_FFMPEG            : OFF
--   USE_GFLAGS            : OFF
--   USE_GLOG              : OFF
--   USE_LEVELDB           : OFF
--   USE_LITE_PROTO        : OFF
--   USE_LMDB              : OFF
--   USE_METAL             : OFF
--   USE_PYTORCH_METAL     : OFF
--   USE_FFTW              : OFF
--   USE_MKL               : OFF
--   USE_MKLDNN            : OFF
--   USE_NCCL              : 0
--   USE_NNPACK            : ON
--   USE_NUMPY             : ON
--   USE_OBSERVERS         : ON
--   USE_OPENCL            : OFF
--   USE_OPENCV            : OFF
--   USE_OPENMP            : ON
--   USE_TBB               : OFF
--   USE_VULKAN            : OFF
--   USE_PROF              : OFF
--   USE_QNNPACK           : 0
--   USE_PYTORCH_QNNPACK   : 0
--   USE_REDIS             : OFF
--   USE_ROCKSDB           : OFF
--   USE_ZMQ               : OFF
--   USE_DISTRIBUTED       : ON
--     USE_MPI             : ON
--     USE_GLOO            : ON
--     USE_TENSORPIPE      : ON
--   USE_DEPLOY           : OFF
--   Public Dependencies  : Threads::Threads
--   Private Dependencies : 

You can terminate the build if you don't see this, because if CUDA isn't detected at the beginning, PyTorch won't be built with it.

Unfortunately, it's disabled even after setting the ENVs. That's what I gathered from the logs:

--   TORCH_VERSION         : 1.9.0
--   CAFFE2_VERSION        : 1.9.0
--   BUILD_CAFFE2          : ON
--   BUILD_CAFFE2_OPS      : ON
--   BUILD_CAFFE2_MOBILE   : OFF
--   BUILD_STATIC_RUNTIME_BENCHMARK: OFF
--   BUILD_TENSOREXPR_BENCHMARK: OFF
--   BUILD_BINARY          : OFF
--   BUILD_CUSTOM_PROTOBUF : ON
--     Link local protobuf : ON
--   BUILD_DOCS            : OFF
--   BUILD_PYTHON          : True
--     Python version      : 3.8
--     Python executable   : /usr/bin/python3.8
--     Pythonlibs version  : 3.8.0
--     Python library      : /usr/lib/libpython3.8.so.1.0
--     Python includes     : /usr/include/python3.8
--     Python site-packages: lib/python3.8/site-packages
--   BUILD_SHARED_LIBS     : ON
--   CAFFE2_USE_MSVC_STATIC_RUNTIME     : OFF
--   BUILD_TEST            : True
--   BUILD_JNI             : OFF
--   BUILD_MOBILE_AUTOGRAD : OFF
--   BUILD_LITE_INTERPRETER: OFF
--   INTERN_BUILD_MOBILE   :
--   USE_BLAS              : 1
--     BLAS                : open
--   USE_LAPACK            : 1
--     LAPACK              : open
--   USE_ASAN              : OFF
--   USE_CPP_CODE_COVERAGE : OFF
--   USE_CUDA              : OFF
-- Could NOT find CUDA (missing: CUDA_CUDART_LIBRARY) (found version "10.2")
CMake Warning at cmake/public/cuda.cmake:31 (message):
  Caffe2: CUDA cannot be found.  Depending on whether you are building Caffe2
  or a Caffe2 dependent library, the next warning / error will give you more
  info.

CMake Warning at cmake/Dependencies.cmake:1178 (message):
  Not compiling with CUDA.  Suppress this warning with -DUSE_CUDA=OFF.
Call Stack (most recent call first):
  CMakeLists.txt:621 (include)
-- Found CUDA with FP16 support, compiling with torch.cuda.HalfTensor

disabling CUDA because NOT USE_CUDA is set
-- USE_CUDNN is set to 0. Compiling without cuDNN support

Hmm…have you set your Docker default-runtime to nvidia and rebooted?
https://github.com/dusty-nv/jetson-containers#docker-default-runtime

It seems it is having trouble finding CUDA in the container. Can you also set these in your Dockerfile, near the top?

ENV PATH="/usr/local/cuda/bin:${PATH}"
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
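
You can also sanity-check that the build stage actually sees CUDA with a temporary debug step (a sketch; adjust the path if your CUDA install differs):

RUN nvcc --version && ls /usr/local/cuda/lib64/libcudart.so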

If the issue persists, I would recommend trying to build the PyTorch wheel for Python 3.8 outside of the container, to narrow down whether it's a problem with the PyTorch build script discovering CUDA inside the container.

Seems like we made it:

--   USE_CUDA              : ON

I will let you know if it works with the compiled version.

YOLOv5 now reports the GPU on startup:

2021-7-4 torch 1.9.0a0+gitd69c22d CUDA:0 (Xavier, 7773.5546875MB)

@dusty_nv thanks so much. It's working now.

OK great, glad to hear that you got it built with CUDA on!
