Creating a JAX container and deploying a Flask app on Kubernetes with GPU access

  • Our model requires an NVIDIA GPU, is integrated with JAX, and works fine in a VM
  • We have now built a container and deployed it in K8s
  • Installing the JAX library on a plain Python base image causes a lot of dependency issues
  • We are looking at these Docker images for reference: https://github.com/NVIDIA/JAX-Toolbox/tree/main
  • We are looking for guidance on how to build a container image with JAX and NVIDIA support and deploy it in K8s so that Flask can serve requests
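A minimal sketch of the Kubernetes side of this setup, assuming a hypothetical image name `registry.example.com/flask-jax:latest`, a Flask app listening on port 5000 (Flask's default), and a cluster where the NVIDIA device plugin advertises the `nvidia.com/gpu` resource. It builds a Deployment and a Service as plain Python dicts and prints them as a JSON `List`, which `kubectl apply -f` accepts just like YAML:

```python
import json

# Hypothetical values for illustration; replace with your own image, name and port.
IMAGE = "registry.example.com/flask-jax:latest"
APP_NAME = "flask-jax"
CONTAINER_PORT = 5000  # Flask's default development-server port


def build_manifests(image: str = IMAGE, name: str = APP_NAME,
                    port: int = CONTAINER_PORT, gpus: int = 1) -> dict:
    """Return a v1 List holding a Deployment and a Service for a GPU-backed Flask app."""
    labels = {"app": name}
    deployment = {
        "apiVersion": "apps/v1",
        "kind": "Deployment",
        "metadata": {"name": name, "labels": labels},
        "spec": {
            "replicas": 1,
            "selector": {"matchLabels": labels},
            "template": {
                "metadata": {"labels": labels},
                "spec": {
                    "containers": [{
                        "name": name,
                        "image": image,
                        "ports": [{"containerPort": port}],
                        "resources": {
                            # GPUs are requested through the extended resource exposed
                            # by the NVIDIA device plugin, and are set under limits.
                            "limits": {"nvidia.com/gpu": str(gpus)},
                        },
                    }],
                },
            },
        },
    }
    service = {
        "apiVersion": "v1",
        "kind": "Service",
        "metadata": {"name": name},
        "spec": {
            "selector": labels,
            # Cluster-internal port 80 forwards to the Flask port in the pod.
            "ports": [{"port": 80, "targetPort": port}],
        },
    }
    return {"apiVersion": "v1", "kind": "List", "items": [deployment, service]}


if __name__ == "__main__":
    # Save the output to a file and apply it with: kubectl apply -f <file>.json
    print(json.dumps(build_manifests(), indent=2))
```

Inside the container the Flask app has to bind to all interfaces (for example `app.run(host="0.0.0.0", port=5000)`) rather than the default 127.0.0.1, otherwise the Service cannot reach it.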

Hi, I think you’re on the right track using a Docker image (e.g. ghcr.io/nvidia/jax:jax) from JAX Toolbox.

If your use case is inference, I would also recommend Triton Inference Server. You can use the Python backend with JAX to run inference.

  • Can you point me to the Dockerfile in the repository that builds the JAX image, so I can understand which versions it uses?
  • Can you provide a sample showing how to build the image, deploy it to K8s on an NVIDIA GPU instance, and test it?
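Besides reading the image's Dockerfile, a quick way to see which versions a pulled image actually contains is to run a small script with `python` inside the container. This sketch uses only the standard library's `importlib.metadata`; the package list is an assumption based on the packages mentioned in this thread and can be edited:

```python
from importlib import metadata

# Packages mentioned in this thread (assumed PyPI distribution names); edit as needed.
PACKAGES = ["jax", "jaxlib", "flax", "transformers", "torch", "flask"]


def installed_versions(packages=PACKAGES) -> dict:
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:14} {version or 'not installed'}")
```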

Here are some additional resources:

  • JAX Dockerfile
  • JAX Toolbox model server
  • JAX on GKE tutorial

Thank you for the references. Let me test this in my cluster; if I run into any issues with the docs, I'll ask you.

I've updated my base image and added the necessary steps for my application to my Dockerfile. Here is my current Dockerfile:


FROM ghcr.io/nvidia/jax:jax

RUN apt-get update && apt-get install -y --no-install-recommends net-tools && \
    apt-get clean && rm -rf /var/lib/apt/lists/*

# The base image already ships JAX and jaxlib built against its bundled CUDA, so they are
# not reinstalled here: layering a nightly jaxlib plus both the cuda12_local and
# cuda12_pip extras on top can leave mismatched JAX/CUDA libraries in the image.
RUN pip install --no-cache-dir --upgrade transformers flax

# For GPU support with PyTorch
RUN pip3 install --no-cache-dir torch torchvision torchaudio

# Work from /app so that requirements.txt and app.py resolve where they are copied
WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . /app

CMD ["python", "app.py"]

I need to ensure that my Docker container can access the GPU. Here are my questions:

  1. Running the Image with GPU Access: How should I run my Docker image to ensure it can access the GPU? Do I need to use specific Docker runtime options?
  2. Kubernetes Setup: Do I need to install any NVIDIA drivers or perform additional setup on my Kubernetes worker nodes to enable GPU access for my container?
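For context on both questions, under the usual NVIDIA setup: plain Docker needs the NVIDIA Container Toolkit on the host and a run flag such as `docker run --gpus all <image>`; in Kubernetes, the worker nodes need the NVIDIA driver plus the NVIDIA device plugin (the NVIDIA GPU Operator can install these), and the pod must request `nvidia.com/gpu`. Either way, a small check run inside the container shows whether the GPU actually got through. This is a sketch; the file name `gpu_check.py` is hypothetical:

```python
import os
import shutil
import subprocess


def gpu_report() -> dict:
    """Collect a few signals showing whether this container can see an NVIDIA GPU."""
    report = {
        # Read by the NVIDIA container runtime to decide which GPUs to expose.
        "NVIDIA_VISIBLE_DEVICES": os.environ.get("NVIDIA_VISIBLE_DEVICES"),
        "nvidia_smi": None,
        "jax_devices": None,
    }
    smi = shutil.which("nvidia-smi")
    if smi:
        try:
            # "nvidia-smi -L" lists one line per visible GPU.
            result = subprocess.run([smi, "-L"], capture_output=True, text=True, timeout=30)
            report["nvidia_smi"] = result.stdout.strip() or result.stderr.strip()
        except subprocess.TimeoutExpired:
            report["nvidia_smi"] = "timed out"
    try:
        import jax  # only present inside the JAX image

        report["jax_devices"] = [str(d) for d in jax.devices()]
    except ImportError:
        pass  # JAX not installed in this environment
    except RuntimeError as err:  # e.g. a requested backend failed to initialize
        report["jax_devices"] = f"error: {err}"
    return report


if __name__ == "__main__":
    for key, value in gpu_report().items():
        print(f"{key}: {value}")
```

In Kubernetes it can be run with `kubectl exec <pod> -- python gpu_check.py` once the file is in the image; if `jax_devices` lists only CPU devices, the GPU most likely was not passed through to the container.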

Any guidance or suggestions would be greatly appreciated. Thank you.

OMG, does this represent the mountain I have to climb? I have been in IT my entire career and I do not understand any of this alien gorgon tech talk. What’s in the flask? What gubbernets? JAX? But I do know what Python is.