Problem with SDXL model with MaxDiffusion on GCP

Hi, I’m trying to run MaxDiffusion (SDXL) on GCP and it runs out of memory. I tried both ICI DP=8 and FSDP=8 with per_device_batch_size=1, and it OOMs on a single H100 node (8xH100). SDXL is a relatively small model with around 4B parameters, so it should not OOM on a single node.
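As a rough sanity check on the "should not OOM" claim, here is a back-of-envelope estimate, assuming the ~4B parameter count from above and plain Adam in float32 (the actual dtypes and optimizer in the MaxDiffusion config may differ):

```python
# Rough training-memory estimate for ~4B params with Adam in float32:
# weights + gradients + two Adam moments = 4 float32 copies of the parameters.
params = 4e9
bytes_per_param = 4            # float32
copies = 4                     # weights, grads, Adam m, Adam v
total_gb = params * bytes_per_param * copies / 1e9
print(f"{total_gb:.0f} GB")    # prints "64 GB"
```

64 GB of parameter/optimizer state fits comfortably in 8x80 GB of H100 memory, even before sharding, so any OOM at batch size 1 likely comes from activations, compilation workspace, or a sharding misconfiguration.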

Here is the repo:

Here is the config & command:
export CHECKPOINTS=gs://{MY_GCS_BUCKET}/maxdiffusion_gpu/config_only/models--stabilityai--stable-diffusion-xl-base-1.0
export CHECKPOINTS_LOCAL_DIR=/tmp/maxdiffusion_gpu/config_only
export CHECKPOINTS_LOCAL=/tmp/maxdiffusion_gpu/config_only/models--stabilityai--stable-diffusion-xl-base-1.0

export LD_LIBRARY_PATH=/usr/local/cuda-12.6/compat:$LD_LIBRARY_PATH
pip install .[training]
mkdir -p $CHECKPOINTS_LOCAL_DIR
gsutil -m cp -R $CHECKPOINTS $CHECKPOINTS_LOCAL_DIR
ls $CHECKPOINTS_LOCAL/unet  # sanity-check that the checkpoint copied
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
  hardware=gpu \
  run_name=$RUN_NAME \
  output_dir=gs://$OUTPUT_PATH \
  train_new_unet=true \
  train_text_encoder=false \
  cache_latents_text_encoder_outputs=true \
  max_train_steps=20 \
  ici_fsdp_parallelism=1 \
  pretrained_model_name_or_path=$CHECKPOINTS_LOCAL \
  per_device_batch_size=1

Hi Lance,

Thanks for your report. Filing this as a bug on the MaxDiffusion repo itself might also be a good idea. To debug this, some more information would be useful:

  • What’s the exact error message?
  • Did this work for you in any other configurations? Did this work previously?
  • If the error message mentions XLA, could you create an HLO dump? You’ll need to set the XLA_FLAGS environment variable to --xla_dump_to=/tmp/foo --xla_dump_hlo_pass_re=.*
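For the last point, a minimal way to enable the dump before relaunching training would be something like the following (the /tmp/foo path is just the placeholder from the flags above; any writable directory works):

```shell
# Ask XLA to dump HLO before/after every compiler pass into /tmp/foo
export XLA_FLAGS="--xla_dump_to=/tmp/foo --xla_dump_hlo_pass_re=.*"
echo "$XLA_FLAGS"
# Then rerun the training command; /tmp/foo will contain the HLO dump files
# to attach to the bug report.
```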

Best,
Johannes