Segmentation fault on RTX 5090 with CUDA 13 during repeated PyTorch CUDA forward passes

I am seeing a reproducible but nondeterministic segmentation fault when running repeated neural network forward passes on an RTX 5090 GPU.

The crash happens after a random number of iterations. In the example below, it crashed at iteration 537, but on reruns it crashes at different iterations.

Environment :

  • OS: Arch Linux (up to date)

  • GPU: NVIDIA RTX 5090

  • NVIDIA driver: 595.71.05

  • System CUDA driver/runtime: 13.2

  • Python environment: micromamba

    • PyTorch: 2.11.0

    • PyTorch CUDA toolkit: 13.0

Here is a basic example script that triggers the error, using the DINOv3 encoder:

# debug_dino_only.py
import torch

# Placeholders from the report: replace them with the path of a local DINOv3
# repository checkout and the path of a local ViT-L/16 checkpoint file.
REPO_DIR = <DINO_REPO_DIR>
CKPT_PATH = <LOCAL DINO VITL16 CKPT PATH>
device = "cuda:0"

# Build the 'dinov3_vitl16' entry point from the local repository
# (source='local') with pretrained=False; weights are loaded manually below.
net = torch.hub.load(
   REPO_DIR,
   'dinov3_vitl16',
   source='local',
   pretrained=False
)
# Load the checkpoint onto the CPU. If it is a dict with a "model" entry that
# entry is used as the state dict, otherwise the loaded object itself is used.
ckpt = torch.load(CKPT_PATH, map_location="cpu")
state_dict = ckpt.get("model", ckpt)

# strict=False: key mismatches are reported below instead of raising.
msg = net.load_state_dict(state_dict, strict=False)
print("missing keys:", len(msg.missing_keys), msg.missing_keys[:30], flush=True)
print("unexpected keys:", len(msg.unexpected_keys), msg.unexpected_keys[:30], flush=True)

# Move the model to the GPU and put it in train mode; autograd is not disabled
# anywhere in this script.
net = net.to(device).train()
for i in range(10_000):
   # flush=True so the last printed iteration is visible if the process dies.
   print("iter", i, flush=True)
   # Fresh random batch every iteration; the forward output is never used.
   x = torch.randn(2, 3, 768, 768, device=device)
   y = net.forward_features(x)["x_norm_patchtokens"]

Segmentation fault         (core dumped) python debug_dino_only.py

I tried to reproduce this, but wasn’t successful.
What CPU are you using? And what is the kernel version of your arch linux distro?

Ideally, if you have a docker container that reproduces this and can share it (or just the docker file to build and run), it would be great. Plus, if you are able to reproduce it without the checkpoint (which requires a license), it would help. Thanks !

Thanks for your reply,

I’m using an Intel(R) Core™ Ultra 9 285K CPU.

The linux kernel version is 7.0.3-arch1-2

Sure, I’ll try to create a Docker-based reproducer that does not require the licensed checkpoint.

I tried to reproduce the above debug_dino_only.py script on the pytorch/pytorch:2.11.0-cuda13.0-cudnn9-devel Docker image and it did not reproduce the segmentation fault after multiple attempts.

However, I managed to reproduce a segmentation fault on the same docker image using a simplified version of my training script, without using a ckpt, with a custom Vision Transformer implementation :

# debug_train_vit.py

import os
import time
import math

import faulthandler
faulthandler.enable()

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader

from typing import Optional
from typing import Dict, Tuple

from pathlib import Path

class DummyLinearHeadDataset(Dataset):
    """Synthetic dataset that yields random (image, distance map, mask) triples.

    Every access draws fresh random tensors, so two reads of the same index
    return different data. The index is only used for bounds checking.

    Args:
        num_samples: Number of samples reported by ``len()``.
        H: Height of the generated tensors.
        W: Width of the generated tensors.
    """

    def __init__(self, num_samples: int, H: int = 768, W: int = 768):
        self.num_samples = num_samples
        self.H = H
        self.W = W

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return one random sample.

        Args:
            idx: Sample index; negative indices count from the end.

        Returns:
            Tuple ``(color, dist, valid_mask)`` with shapes ``(3, H, W)``
            float32, ``(H, W)`` float32 and ``(H, W)`` bool (all ``True``).

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Python's sequence protocol stops iterating only on IndexError, so
        # without this check `for sample in dataset` would never terminate.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for a dataset of size {self.num_samples}"
            )
        color = torch.randn(3, self.H, self.W, dtype=torch.float32)
        dist = torch.randn(self.H, self.W, dtype=torch.float32)
        valid_mask = torch.ones(self.H, self.W, dtype=torch.bool)
        return color, dist, valid_mask


class PositionGetter:
    """Generates and caches 2D spatial positions for patches in a grid.

    The coordinate table for a grid is built once per (height, width, device)
    combination and reused on later calls.

    Attributes:
        position_cache: Dictionary storing precomputed position tensors, keyed
            by ``(height, width, device)``.
    """

    def __init__(self):
        """Initializes the position generator with an empty cache."""
        self.position_cache: Dict[Tuple[int, int, torch.device], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Generates spatial positions for a batch of patches.

        Args:
            batch_size: Number of samples in the batch.
            height: Height of the grid in patches.
            width: Width of the grid in patches.
            device: Target device for the position tensor.

        Returns:
            Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
            for each position in the grid, repeated for each batch item. The
            result is a fresh copy, so callers may modify it freely.
        """
        # The device is part of the key: keying on the grid size alone would
        # hand back a tensor living on whichever device was requested first.
        device = torch.device(device)
        cache_key = (height, width, device)
        cached_positions = self.position_cache.get(cache_key)
        if cached_positions is None:
            y_coords = torch.arange(height, device=device)
            x_coords = torch.arange(width, device=device)
            cached_positions = torch.cartesian_prod(y_coords, x_coords)
            self.position_cache[cache_key] = cached_positions

        # clone() detaches the result from the cached table.
        return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()


class RotaryPositionEmbedding2D(nn.Module):
    """2D Rotary Position Embedding implementation.

    The feature dimension of the input tokens is split into two halves; the
    first half is rotated according to the y coordinate of each token and the
    second half according to its x coordinate.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Scaling factor for frequency computation. Default: 1.0

    Attributes:
        base_frequency: Base frequency for computing position embeddings.
        scaling_factor: Stored value of the ``scaling_factor`` argument; none
            of the methods of this class read it.
        frequency_cache: Cache for storing precomputed frequency components,
            keyed by ``(dim, seq_len, device, dtype)``.
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        """Initializes the 2D RoPE module."""
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes (and caches) the cos/sin lookup tables for one direction.

        Args:
            dim: Feature dimension (must be even).
            seq_len: Number of positions to tabulate (rows of the tables).
            device: Target device for computations.
            dtype: Data type for the computed tensors.

        Returns:
            Tuple of (cosine, sine) tensors, each of shape (seq_len, dim).
        """
        cache_key = (dim, seq_len, device, dtype)
        if cache_key not in self.frequency_cache:
            # Compute frequency bands
            exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency**exponents)

            # Generate position-dependent frequencies
            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", positions, inv_freq)

            # Compute and cache frequency components.
            # NOTE: the angles are cast to `dtype` before cos/sin are taken, so
            # with a reduced-precision dtype the angles are rounded first.
            angles = angles.to(dtype)
            # Duplicate the bands so the table width matches `dim`.
            angles = torch.cat((angles, angles), dim=-1)
            cos_components = angles.cos().to(dtype)
            sin_components = angles.sin().to(dtype)
            self.frequency_cache[cache_key] = (cos_components, sin_components)

        return self.frequency_cache[cache_key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Maps the two halves ``(x1, x2)`` of the last dimension to ``(-x2, x1)``.

        Args:
            x: Input tensor to rotate.

        Returns:
            Rotated feature tensor of the same shape as ``x``.
        """
        feature_dim = x.shape[-1]
        x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope(
        self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
    ) -> torch.Tensor:
        """Applies 1D rotary position embeddings along one dimension.

        Args:
            tokens: Input token features.
            positions: Integer position indices used to look up table rows.
            cos_comp: Cosine components for rotation.
            sin_comp: Sine components for rotation.

        Returns:
            Tokens with applied rotary position embeddings.
        """
        # Look up the table rows per token; the inserted axis broadcasts the
        # same rotation over the heads dimension.
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]

        # Apply rotation
        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Applies 2D rotary position embeddings to input tokens.

        Args:
            tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
                   The feature dimension (dim) must be divisible by 4.
            positions: Position tensor of shape (batch_size, n_tokens, 2) containing
                      the y and x coordinates for each token.

        Returns:
            Tensor of same shape as input with applied 2D rotary position embeddings.

        Raises:
            AssertionError: If input dimensions are invalid or positions are malformed.
        """
        # Validate inputs. Each spatial direction gets dim // 2 features and
        # those are split in half again for the rotation, hence the factor 4.
        # A dimension that is merely even would fail later with a shape
        # mismatch inside the rotation.
        assert tokens.size(-1) % 4 == 0, "Feature dimension must be divisible by 4"
        assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"

        # Compute feature dimension for each spatial direction
        feature_dim = tokens.size(-1) // 2

        # Get frequency components. int(...) reads the maximum back to the
        # host, so the tables cover every position index that occurs.
        max_position = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)

        # Split features for vertical and horizontal processing
        vertical_features, horizontal_features = tokens.chunk(2, dim=-1)

        # Apply RoPE separately for each dimension
        vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
        horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)

        # Combine processed features
        return torch.cat((vertical_features, horizontal_features), dim=-1)

class LayerNorm(nn.Module):
    """Layer normalization over the channel axis of a (B,C,H,W) feature map."""

    def __init__(self, num_channels):
        super().__init__()
        self.norm = nn.LayerNorm(num_channels, eps=1e-6)

    def forward(self, x):
        """
        input  : (B,C,H,W)
        output : (B,C,H,W)
        """
        # nn.LayerNorm normalizes the trailing axis, so channels go last first.
        channels_last_view = x.permute(0, 2, 3, 1)      # (B,H,W,C)
        normalized = self.norm(channels_last_view)      # (B,H,W,C)
        return normalized.permute(0, 3, 1, 2)           # (B,C,H,W)
    
class LayerScale(nn.Module):
    """Multiplies the input by a learnable vector ``gamma`` of length ``dim``.

    Args:
        dim: Length of the scale vector.
        init_values: Initial value of every entry of ``gamma``.
    """

    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        initial_scale = torch.full((dim,), init_values)
        self.gamma = nn.Parameter(initial_scale)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled by ``gamma`` (broadcast against ``x``)."""
        return x.mul(self.gamma)

class DropPath(nn.Module):
    """Randomly zeroes whole samples of a batch during training.

    In training mode each sample is kept with probability ``1 - drop_prob``
    and kept samples are divided by that probability. In eval mode, or when
    ``drop_prob`` is 0, the input is returned unchanged.
    """

    def __init__(self, drop_prob):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1 - self.drop_prob
        # One random draw per sample, broadcast over all remaining axes.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        noise = torch.rand(mask_shape, device=x.device, dtype=x.dtype)
        # floor(keep_prob + U[0,1)) is 1 with probability keep_prob, else 0.
        keep_mask = (keep_prob + noise).floor_()
        return x / keep_prob * keep_mask

class MLP(nn.Module):
    """Two-layer perceptron: Linear -> GELU -> Linear.

    Args:
        input_dim: Size of the input features.
        hidden_dim: Size of the hidden layer; falls back to ``input_dim``.
        out_dim: Size of the output features; falls back to ``input_dim``.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim: Optional[int] = None,
        out_dim: Optional[int] = None,
    ):
        super().__init__()
        if not out_dim:
            out_dim = input_dim
        if not hidden_dim:
            hidden_dim = input_dim
        expand = nn.Linear(input_dim, hidden_dim)
        activation = nn.GELU()
        project = nn.Linear(hidden_dim, out_dim)
        self.layers = nn.Sequential(expand, activation, project)

    def forward(self, x):
        """Run ``x`` through the three stacked layers."""
        return self.layers(x)

class Attention(nn.Module):
    """Multi-head self-attention with per-head q/k LayerNorm and 2D RoPE.

    Args:
        dim: Token feature size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, dim: int, num_heads: int) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.rope_freq = 100
        self.rope = RotaryPositionEmbedding2D(frequency=self.rope_freq)

        # One fused projection produces queries, keys and values.
        self.qkv = nn.Linear(dim, dim * 3)
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        """Attend over the tokens of ``x``.

        Args:
            x: Tokens of shape (B, N, D).
            pos: Positions handed to the rotary embedding.

        Returns:
            Tensor of shape (B, N, D).
        """
        batch, n_tokens, channels = x.shape

        # (B, N, 3*D) -> (3, B, heads, N, head_dim)
        fused = self.qkv(x).reshape(batch, n_tokens, 3, self.num_heads, self.head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        q = self.q_norm(q).contiguous()
        k = self.k_norm(k).contiguous()
        v = v.contiguous()

        q = self.rope(q, pos).contiguous()
        k = self.rope(k, pos).contiguous()

        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=0.)

        # (B, heads, N, head_dim) -> (B, N, D)
        merged = attended.transpose(1, 2).reshape(batch, n_tokens, channels)
        return self.proj(merged)

class TransformerBlock(nn.Module):
    """Pre-norm transformer block with learnable skip weights.

    Each of the two sub-layers (attention, MLP) computes
    ``skip_alpha * x + drop_path(layer_scale(sublayer(norm(x))))``.

    Args:
        dim: Token feature size.
        num_heads: Number of attention heads.
        drop_path_prob: Drop probability used by both DropPath modules.
        mlp_ratio: Hidden size of the MLP as a multiple of ``dim``.
        init_skip: Initial value of both skip weights.
    """

    def __init__(self, dim, num_heads, drop_path_prob, mlp_ratio=4, init_skip=1.0):
        super().__init__()
        # Attention branch.
        self.norm1 = nn.LayerNorm(dim)
        self.attention = Attention(dim, num_heads)
        self.ls1 = LayerScale(dim)

        # MLP branch.
        self.norm2 = nn.LayerNorm(dim)
        self.mlp_hidden_dim = dim * mlp_ratio
        self.mlp = MLP(input_dim=dim, hidden_dim=self.mlp_hidden_dim)
        self.ls2 = LayerScale(dim)

        # Learnable scalar weights on the two residual paths.
        self.skip_alpha1 = nn.Parameter(torch.ones(1) * init_skip)
        self.skip_alpha2 = nn.Parameter(torch.ones(1) * init_skip)

        self.drop_path1 = DropPath(drop_path_prob)
        self.drop_path2 = DropPath(drop_path_prob)

    def forward(self, x, pos):
        """Apply the attention sub-layer, then the MLP sub-layer."""
        attn_branch = self.attention(self.norm1(x), pos)
        x = self.skip_alpha1 * x + self.drop_path1(self.ls1(attn_branch))

        mlp_branch = self.mlp(self.norm2(x))
        x = self.skip_alpha2 * x + self.drop_path2(self.ls2(mlp_branch))
        return x

class ViT(nn.Module):
    """Vision transformer encoder: patch embedding followed by transformer blocks.

    Args:
        H: Image height used to derive the patch grid.
        W: Image width used to derive the patch grid.
        P: Patch size (kernel size and stride of the embedding convolution).
        D: Token feature size.
        max_drop_prob: Drop-path probability of the last block; the blocks
            use a linear ramp from 0 to this value.
        num_heads: Number of attention heads per block.
        num_blocks: Number of transformer blocks.
    """

    def __init__(self, H, W, P, D, max_drop_prob, num_heads, num_blocks):
        super().__init__()
        self.H, self.W = H, W
        self.C_img = 3
        self.P = P
        self.D = D
        self.grid_H = H // P
        self.grid_W = W // P
        self.num_blocks = num_blocks

        self.img_embedding = nn.Conv2d(self.C_img, D, kernel_size=P, stride=P)
        self.layer_norm = LayerNorm(D)

        drop_schedule = torch.linspace(0, max_drop_prob, num_blocks).tolist()
        self.blocks = nn.ModuleList()
        for drop_rate in drop_schedule:
            self.blocks.append(TransformerBlock(D, num_heads, drop_rate))

        self.position_getter = PositionGetter()

    def forward(self, img):
        """Encode an image batch into patch tokens of shape (B,Np,D)."""
        patches = self.img_embedding(img)                 # (B,D,H/P,W/P)
        patches = self.layer_norm(patches)                # (B,D,H/P,W/P)
        tokens = patches.flatten(2).transpose(1, 2)       # (B,Np,D)
        pos = self.position_getter(
            tokens.size(0), self.grid_H, self.grid_W, device=tokens.device
        )                                                 # (B,Np,2)
        for block in self.blocks:
            tokens = block(tokens, pos)                   # (B,Np,D)
        return tokens

class linear_head(nn.Module):
    """1x1 convolution head that turns patch tokens into a dense map.

    Every token is projected to ``P**2`` values which ``pixel_shuffle``
    rearranges into a ``P x P`` tile of the output.

    Args:
        Hn: Output height; must be a multiple of the patch size (16).
        Wn: Output width; must be a multiple of the patch size (16).
    """

    def __init__(self, Hn: int, Wn: int):
        super().__init__()
        self.Hn = Hn
        self.Wn = Wn
        self.P = 16
        self.D = 1024
        for extent in (self.Hn, self.Wn):
            assert extent % self.P == 0
        self.grid_H, self.grid_W = Hn // self.P, Wn // self.P
        self.head = nn.Conv2d(self.D, self.P**2, kernel_size=1)

    def forward(self, x):
        """
        input x : (B,Np,D) tokens
        output  : (B,1,Hn,Wn) dist
        """
        batch = x.size(0)
        token_grid = x.reshape(batch, self.grid_H, self.grid_W, self.D)
        feature_map = token_grid.permute(0, 3, 1, 2)    # (B,D,Hn/P,Wn/P)
        per_patch = self.head(feature_map)              # (B,P**2,Hn/P,Wn/P)
        return F.pixel_shuffle(per_patch, upscale_factor=self.P)  # (B,1,Hn,Wn)

class LinearHeadTrainer:
    """Trains a linear prediction head on top of a frozen ViT encoder.

    The encoder runs in eval mode with gradients disabled; only the head's
    parameters are handed to the optimizer. Data comes from the synthetic
    ``DummyLinearHeadDataset``.

    Args:
        return_mode: Accepted for interface compatibility; not used here.
        global_batch_size: Batch size summed over all ranks.
        world_size: Number of ranks; the per-loader batch size is
            ``global_batch_size // world_size``.
        num_workers: Number of DataLoader worker processes.
        prefetch_factor: DataLoader prefetch factor; only forwarded when
            ``num_workers > 0`` (DataLoader rejects it otherwise).
        use_torch_compile: Wrap encoder and head with ``torch.compile``.
        base_lr: Peak learning rate of AdamW.
        weight_decay: AdamW weight decay.
        warmup_steps: Number of linear warmup steps of the LR schedule.
        max_epochs: Number of epochs to train.
        Hn: Image height (also used for the encoder's patch grid).
        Wn: Image width (also used for the encoder's patch grid).
        run_log_dir: Directory for the TensorBoard event files.
        precision: Autocast dtype, ``"fp16"`` or ``"bf16"``.
        channels_last: Convert models and images to channels-last layout.
    """

    def __init__(
            self,
            return_mode : str,
            global_batch_size : int,
            world_size : int,
            num_workers : int,
            prefetch_factor : int,
            use_torch_compile   : bool,
            base_lr : float,
            weight_decay : float,
            warmup_steps : int,
            max_epochs : int,
            Hn : int,
            Wn : int,
            run_log_dir : Path,
            precision : str,
            channels_last : bool,
    ):
        self.max_epochs = max_epochs
        self.warmup_steps = warmup_steps
        self.Hn = Hn
        self.Wn = Wn
        self.base_lr = base_lr
        self.weight_decay = weight_decay
        self.channels_last = channels_last
        assert precision in ["fp16", "bf16"]
        self.precision = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }[precision]

        self.local_rank = 0

        # Only rank 0 owns a TensorBoard writer.
        if self.local_rank == 0:
            self.writer = SummaryWriter(log_dir=run_log_dir)
        else:
            self.writer = None

        train_dataset = DummyLinearHeadDataset(num_samples=64, H=Hn, W=Wn)
        val_dataset = DummyLinearHeadDataset(num_samples=16, H=Hn, W=Wn)

        loader_kwargs = dict(
            batch_size=global_batch_size // world_size,
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available(),
            persistent_workers=(num_workers > 0),
            # DataLoader raises if prefetch_factor is set without workers.
            prefetch_factor=prefetch_factor if num_workers > 0 else None,
        )
        self.train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
        self.val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

        P = 16
        D = 1024
        max_drop_prob = 0.4
        num_heads = 16
        num_blocks = 24
        # The encoder's patch grid must match the images the loaders produce,
        # so it is derived from Hn/Wn rather than from a fixed resolution.
        self.vit  = ViT(self.Hn, self.Wn, P, D, max_drop_prob, num_heads, num_blocks)
        self.head = linear_head(self.Hn, self.Wn)

        # Freeze the encoder.
        self.vit.eval()
        for p in self.vit.parameters():
            p.requires_grad_(False)

        self.device = torch.device(f"cuda:{self.local_rank}")

        self.vit = self.vit.to(self.device)
        self.head = self.head.to(self.device)

        if channels_last:
            self.vit = self.vit.to(memory_format=torch.channels_last)
            self.head = self.head.to(memory_format=torch.channels_last)

        if use_torch_compile:
            self.vit = torch.compile(self.vit, mode="default", dynamic=False)
            self.head = torch.compile(self.head, mode="default", dynamic=False)

        self.optimizer = torch.optim.AdamW(
            self.head.parameters(),
            lr = self.base_lr,
            weight_decay = self.weight_decay,
        )

        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer,
            self._lr_lambda,
        )

    # Cosine annealing with warmup
    def _lr_lambda(self, step):
        """Return the LR multiplier for ``step``.

        Linear ramp from 0 to 1 during the first ``warmup_steps`` steps, then
        a half cosine from 1 to 0 over the remaining steps.
        """
        if step < self.warmup_steps:
            return step / self.warmup_steps

        max_steps = len(self.train_loader) * self.max_epochs
        progress = (step - self.warmup_steps) / max(1, (max_steps - self.warmup_steps))

        return 0.5 * (1 + math.cos(math.pi * progress))

    def loss(self, pred_dist, targ_dist, valid_mask):
        """Masked mean absolute error.

        pred_dist  : (B,1,Hn,Wn)
        targ_dist  : (B,1,Hn,Wn) or (B,Hn,Wn)
        valid_mask : (B,Hn,Wn) boolean of non-padded cells

        Returns 0 when the mask selects no cell.
        """
        # A (B,Hn,Wn) target would broadcast against the (B,1,Hn,Wn)
        # prediction to (B,B,Hn,Wn), pairing every prediction with every
        # target of the batch. Give the target an explicit channel axis.
        if targ_dist.ndim == pred_dist.ndim - 1:
            targ_dist = targ_dist[:, None, :, :]
        # Guard against an all-false mask (0 / 0 would be NaN).
        num_valid_elements = valid_mask.sum().clamp(min=1)
        diff = (pred_dist - targ_dist) * valid_mask[:,None,:,:]
        return (torch.abs(diff)).sum() / num_valid_elements

    def step(self, batch, split):
        """Run one batch; update the head when ``split == "train"``.

        Args:
            batch: Tuple ``(img, targ_dist, valid_mask)`` from a loader.
            split: ``"train"`` enables gradients and the optimizer update;
                any other value only evaluates.

        Returns:
            The loss of the batch as a Python float.
        """
        img, targ_dist, valid_mask = batch

        img = img.to(self.device, non_blocking=True)

        if self.channels_last:
            img = img.contiguous(memory_format=torch.channels_last)

        targ_dist = targ_dist.to(self.device, non_blocking=True)
        valid_mask = valid_mask.to(self.device, non_blocking=True)

        is_train = split == "train"
        with torch.autocast(
            device_type=self.device.type,
            dtype=self.precision,
            enabled=(self.device.type == "cuda")
        ):
            # The encoder is frozen, so its forward never needs a graph.
            with torch.no_grad():
                patch_feats = self.vit(img) # (B,Np,D)

            with torch.set_grad_enabled(is_train):
                pred_dist = self.head(patch_feats)
                loss = self.loss(pred_dist, targ_dist, valid_mask)

        if is_train:
            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

        return loss.item()

    def train(self):
        """Train for ``max_epochs`` epochs with a validation pass after each.

        Per-epoch average losses are printed and written to TensorBoard. The
        writer is closed even if training raises.
        """
        try:
            for e in range(self.max_epochs):

                # --- Train ---
                self.vit.eval()
                self.head.train()
                train_loss_sum = 0.0
                train_batches = 0
                for batch in self.train_loader:
                    train_loss_sum += self.step(batch, "train")
                    train_batches += 1

                # --- Val ---
                self.vit.eval()
                self.head.eval()
                val_loss_sum = 0.0
                val_batches = 0
                for batch in self.val_loader:
                    val_loss_sum += self.step(batch, "val")
                    val_batches += 1

                # The sums are already host floats, so average on the host;
                # an empty loader yields NaN.
                avg_train_loss = train_loss_sum / train_batches if train_batches else float("nan")
                avg_val_loss = val_loss_sum / val_batches if val_batches else float("nan")

                # --- Logging (Rank 0 Only) ---
                if self.local_rank == 0:
                    print(f"epoch : {e}, avg train loss : {avg_train_loss:.4f}")
                    print(f"epoch : {e}, avg val loss : {avg_val_loss:.4f}")

                    self.writer.add_scalar("Loss/train", avg_train_loss, e)
                    self.writer.add_scalar("Loss/val", avg_val_loss, e)

                    # Force TensorBoard to write to disk immediately
                    self.writer.flush()
        finally:
            # --- Cleanup ---
            if self.writer is not None:
                self.writer.close()

if __name__ == "__main__":

    # Every run logs into its own timestamped directory under logs/.
    logs_dir = Path("logs")
    version_name = "debug"
    timestamp = time.strftime("%Y%m%d_%H%M%S")

    # Debug configuration handed to the trainer as keyword arguments.
    config = dict(
        return_mode="linear_head_train_data",
        global_batch_size=8,
        world_size=int(os.environ.get("WORLD_SIZE",1)),
        num_workers=0,           # alternative value noted by the author: 8
        prefetch_factor=None,    # alternative value noted by the author: 2
        use_torch_compile=False,
        base_lr=0.0003,
        weight_decay=0.01,
        warmup_steps=100,
        max_epochs=100,
        Hn=768,
        Wn=768,
        run_log_dir=logs_dir / f"{version_name}_{timestamp}",
        precision="bf16",        # one of ["fp16", "bf16"]
        channels_last=True,
    )

    # Request the 'spawn' start method; a RuntimeError from the call is
    # deliberately ignored.
    try:
        mp.set_start_method('spawn', force=True)
    except RuntimeError:
        pass

    trainer = LinearHeadTrainer(**config)
    trainer.train()

Error message :

root@8583fcbfbd27:/workspace/repo# PYTHONFAULTHANDLER=1 CUDA_LAUNCH_BLOCKING=1 TORCH_SHOW_CPP_STACKTRACES=1 python debug_train_vit.py
epoch : 0, avg train loss : 7.3723
epoch : 0, avg val loss : 7.3655
epoch : 1, avg train loss : 7.3528
epoch : 1, avg val loss : 7.3324
epoch : 2, avg train loss : 7.3108
epoch : 2, avg val loss : 7.2783
epoch : 3, avg train loss : 7.2476
epoch : 3, avg val loss : 7.2051
epoch : 4, avg train loss : 7.1702
epoch : 4, avg val loss : 7.1217
epoch : 5, avg train loss : 7.0820
epoch : 5, avg val loss : 7.0256
epoch : 6, avg train loss : 6.9836
epoch : 6, avg val loss : 6.9295
epoch : 7, avg train loss : 6.8891
epoch : 7, avg val loss : 6.8358
epoch : 8, avg train loss : 6.7969
epoch : 8, avg val loss : 6.7471
epoch : 9, avg train loss : 6.7156
epoch : 9, avg val loss : 6.6714
Fatal Python error: Segmentation fault

Thread 0x00007f717261f6c0 (most recent call first):
  <no Python frame>

Thread 0x00007f747da2a6c0 (most recent call first):
  File "/usr/lib/python3.12/threading.py", line 359 in wait
  File "/usr/lib/python3.12/queue.py", line 180 in get
  File "/usr/local/lib/python3.12/dist-packages/tensorboard/summary/writer/event_file_writer.py", line 269 in _run
  File "/usr/local/lib/python3.12/dist-packages/tensorboard/summary/writer/event_file_writer.py", line 244 in run
  File "/usr/lib/python3.12/threading.py", line 1073 in _bootstrap_inner
  File "/usr/lib/python3.12/threading.py", line 1030 in _bootstrap

Current thread 0x00007f75764f7300 (most recent call first):
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/functional.py", line 2935 in layer_norm
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/normalization.py", line 229 in forward
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1790 in _call_impl
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1779 in _wrapped_call_impl
  File "/workspace/repo/debug_train_vit.py", line 284 in forward
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1790 in _call_impl
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1779 in _wrapped_call_impl
  File "/workspace/repo/debug_train_vit.py", line 317 in forward
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1790 in _call_impl
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1779 in _wrapped_call_impl
  File "/workspace/repo/debug_train_vit.py", line 344 in forward
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1790 in _call_impl
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1779 in _wrapped_call_impl
  File "/workspace/repo/debug_train_vit.py", line 513 in step
  File "/workspace/repo/debug_train_vit.py", line 537 in train
  File "/workspace/repo/debug_train_vit.py", line 619 in <module>

Extension modules: numpy._core._multiarray_umath, numpy.linalg._umath_linalg, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, google._upb._message, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._pcg64, numpy.random._generator, numpy.random._mt19937, numpy.random._philox, numpy.random._sfc64, numpy.random.mtrand (total: 23)
Segmentation fault (core dumped)

Thank you very much for the simple reproducer !
Unfortunately, I wasn’t able to reproduce your segmentation fault. However, I only have access to machines with AMD CPUs + 5070/5080/5090 or RTX 6000 GPUs, plus the kernel is a bit older.
So my guess is that your combination of CPU + kernel plays some role here.

Are you able to run this on a different system, or the same system but using a slightly older Linux kernel (e.g. I’m using 6.11) ?
I know this might not be easily possible, but it will be impossible to debug if we cannot reproduce it. For example, I cannot verify whether this is even related to PyTorch and/or CUDA, it might be a more generic bug with python threading and/or tensorboard.

Thanks !

Thanks for your reply !
I managed to reproduce the bug on the Linux-LTS kernel (6.18.26-2-lts) (I did not succeed in installing Linux kernel 6.11), with a version of the script where I removed tensorboard:

# debug_train_vit.py

import os
import math

import faulthandler
faulthandler.enable()

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch import Tensor
from torch.utils.data import Dataset, DataLoader

from typing import Optional
from typing import Dict, Tuple


class DummyLinearHeadDataset(Dataset):
    """Synthetic dataset that yields random (image, distance map, mask) triples.

    Every access draws fresh random tensors, so two reads of the same index
    return different data. The index is only used for bounds checking.

    Args:
        num_samples: Number of samples reported by ``len()``.
        H: Height of the generated tensors.
        W: Width of the generated tensors.
    """

    def __init__(self, num_samples: int, H: int = 768, W: int = 768):
        self.num_samples = num_samples
        self.H = H
        self.W = W

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return one random sample.

        Args:
            idx: Sample index; negative indices count from the end.

        Returns:
            Tuple ``(color, dist, valid_mask)`` with shapes ``(3, H, W)``
            float32, ``(H, W)`` float32 and ``(H, W)`` bool (all ``True``).

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Python's sequence protocol stops iterating only on IndexError, so
        # without this check `for sample in dataset` would never terminate.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for a dataset of size {self.num_samples}"
            )
        color = torch.randn(3, self.H, self.W, dtype=torch.float32)
        dist = torch.randn(self.H, self.W, dtype=torch.float32)
        valid_mask = torch.ones(self.H, self.W, dtype=torch.bool)
        return color, dist, valid_mask


class PositionGetter:
    """Generates and caches 2D spatial positions for patches in a grid.

    The coordinate table for a grid is built once per (height, width, device)
    combination and reused on later calls.

    Attributes:
        position_cache: Dictionary storing precomputed position tensors, keyed
            by ``(height, width, device)``.
    """

    def __init__(self):
        """Initializes the position generator with an empty cache."""
        self.position_cache: Dict[Tuple[int, int, torch.device], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Generates spatial positions for a batch of patches.

        Args:
            batch_size: Number of samples in the batch.
            height: Height of the grid in patches.
            width: Width of the grid in patches.
            device: Target device for the position tensor.

        Returns:
            Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
            for each position in the grid, repeated for each batch item. The
            result is a fresh copy, so callers may modify it freely.
        """
        # The device is part of the key: keying on the grid size alone would
        # hand back a tensor living on whichever device was requested first.
        device = torch.device(device)
        cache_key = (height, width, device)
        cached_positions = self.position_cache.get(cache_key)
        if cached_positions is None:
            y_coords = torch.arange(height, device=device)
            x_coords = torch.arange(width, device=device)
            cached_positions = torch.cartesian_prod(y_coords, x_coords)
            self.position_cache[cache_key] = cached_positions

        # clone() detaches the result from the cached table.
        return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()


class RotaryPositionEmbedding2D(nn.Module):
    """2D Rotary Position Embedding implementation.

    The feature vector of every token is split into two halves: the first half
    is rotated according to the token's y coordinate and the second half
    according to its x coordinate. Both halves use the same cosine/sine tables.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Stored on the module; none of the methods of this class
            read it. Default: 1.0

    Attributes:
        base_frequency: Base frequency for computing position embeddings.
        scaling_factor: Value passed at construction (not used by this class).
        frequency_cache: Maps ``(dim, seq_len, device, dtype)`` to the
            precomputed ``(cosine, sine)`` tables.
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        """Initializes the 2D RoPE module."""
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes frequency components for rotary embeddings.

        Args:
            dim: Feature dimension (must be even).
            seq_len: Maximum sequence length.
            device: Target device for computations.
            dtype: Data type for the computed tensors.

        Returns:
            Tuple of (cosine, sine) tensors, each of shape ``(seq_len, dim)``.
            Results are cached per ``(dim, seq_len, device, dtype)``.
        """
        cache_key = (dim, seq_len, device, dtype)
        if cache_key not in self.frequency_cache:
            # Compute frequency bands
            exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency**exponents)

            # Generate position-dependent frequencies
            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", positions, inv_freq)

            # Compute and cache frequency components. The angles are duplicated so
            # that the table width equals `dim` (dim/2 bands, each used twice).
            angles = angles.to(dtype)
            angles = torch.cat((angles, angles), dim=-1)
            cos_components = angles.cos().to(dtype)
            sin_components = angles.sin().to(dtype)
            self.frequency_cache[cache_key] = (cos_components, sin_components)

        return self.frequency_cache[cache_key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Performs feature rotation by splitting and recombining feature dimensions.

        Args:
            x: Input tensor to rotate.

        Returns:
            Tensor of the same shape where the last dimension ``(x1, x2)`` has
            been replaced by ``(-x2, x1)``.
        """
        feature_dim = x.shape[-1]
        x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope(
        self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
    ) -> torch.Tensor:
        """Applies 1D rotary position embeddings along one dimension.

        Args:
            tokens: Input token features.
            positions: Position indices.
            cos_comp: Cosine components for rotation.
            sin_comp: Sine components for rotation.

        Returns:
            Tokens with applied rotary position embeddings.
        """
        # Look up the per-position tables and add a broadcast axis for the heads.
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]

        # Apply rotation
        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Applies 2D rotary position embeddings to input tokens.

        Args:
            tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
                   The feature dimension (dim) must be divisible by 4.
            positions: Position tensor of shape (batch_size, n_tokens, 2) containing
                      the y and x coordinates for each token.

        Returns:
            Tensor of same shape as input with applied 2D rotary position embeddings.

        Raises:
            AssertionError: If input dimensions are invalid or positions are malformed.
        """
        # Validate inputs. Each spatial half must itself be even: the cos/sin tables
        # are built from dim/2 bands duplicated once, so a dim that is only a
        # multiple of 2 would broadcast to a wrong shape or fail later.
        assert tokens.size(-1) % 4 == 0, "Feature dimension must be divisible by 4"
        assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"

        # Compute feature dimension for each spatial direction
        feature_dim = tokens.size(-1) // 2

        # Get frequency components; the table must cover the largest coordinate.
        max_position = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)

        # Split features for vertical and horizontal processing
        vertical_features, horizontal_features = tokens.chunk(2, dim=-1)

        # Apply RoPE separately for each dimension
        vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
        horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)

        # Combine processed features
        return torch.cat((vertical_features, horizontal_features), dim=-1)

class LayerNorm(nn.Module):
    """Layer normalization over the channel axis of a channels-first map."""

    def __init__(self, num_channels):
        super().__init__()
        self.norm = nn.LayerNorm(num_channels, eps=1e-6)

    def forward(self,x):
        """
        input  : (B,C,H,W)
        output : (B,C,H,W)
        """
        # nn.LayerNorm normalizes the last axis, so move channels there and back.
        channels_last_view = x.permute(0, 2, 3, 1)   # (B,H,W,C)
        normalized = self.norm(channels_last_view)   # (B,H,W,C)
        return normalized.permute(0, 3, 1, 2)        # (B,C,H,W)
    
class LayerScale(nn.Module):
    """Multiplies the last axis of the input by a learnable per-channel gain."""

    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        initial_gain = torch.full((dim,), init_values)
        self.gamma = nn.Parameter(initial_gain)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled by ``gamma`` (broadcast over leading axes)."""
        return torch.mul(x, self.gamma)

class DropPath(nn.Module):
    """Randomly zeroes the whole input for individual batch items during training.

    Args:
        drop_prob: Probability of dropping a given batch item.
    """

    def __init__(self, drop_prob):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply per-sample dropping to ``x``.

        Returns ``x`` itself when ``drop_prob`` is 0 or the module is in eval
        mode. Otherwise kept items are divided by ``1 - drop_prob`` and dropped
        items are set to zero.
        """
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.0:
            # Every item is dropped. Dividing by keep_prob == 0 would produce
            # NaN/inf instead of zeros, so short-circuit.
            return torch.zeros_like(x)
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # broadcast to all dims except batch
        # keep_prob + U[0,1) floored is 1 with probability keep_prob, else 0.
        random_tensor = keep_prob + torch.rand(shape, device=x.device, dtype=x.dtype)
        random_tensor.floor_()
        return x / keep_prob * random_tensor

class MLP(nn.Module):
    """Two linear layers with a GELU in between.

    Args:
        input_dim: Size of the input features.
        hidden_dim: Size of the hidden layer; falls back to ``input_dim``.
        out_dim: Size of the output features; falls back to ``input_dim``.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim: Optional[int] = None,
        out_dim: Optional[int] = None,
    ):
        super().__init__()
        if not out_dim:
            out_dim = input_dim
        if not hidden_dim:
            hidden_dim = input_dim
        expand = nn.Linear(input_dim, hidden_dim)
        activation = nn.GELU()
        project = nn.Linear(hidden_dim, out_dim)
        self.layers = nn.Sequential(expand, activation, project)

    def forward(self, x):
        """Run ``x`` through linear -> GELU -> linear."""
        return self.layers(x)

class Attention(nn.Module):
    """Multi-head self-attention with per-head q/k LayerNorm and 2D RoPE.

    Args:
        dim: Token feature size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.rope_freq = 100
        self.rope = RotaryPositionEmbedding2D(frequency=self.rope_freq)

        # Fused q/k/v projection, per-head normalization, output projection.
        self.qkv = nn.Linear(dim, dim * 3)
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        """Attend over the tokens of ``x``.

        Args:
            x: Tokens of shape (B, N, D).
            pos: Positions forwarded to the rotary embedding for q and k.

        Returns:
            Tensor of shape (B, N, D).
        """
        B, N, D = x.shape

        # (B, N, 3*D) -> (3, B, heads, N, head_dim), then split into q, k, v.
        fused = self.qkv(x)
        fused = fused.reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        q = self.q_norm(q).contiguous()
        k = self.k_norm(k).contiguous()
        v = v.contiguous()

        q = self.rope(q, pos).contiguous()
        k = self.rope(k, pos).contiguous()

        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=0.)

        # (B, heads, N, head_dim) -> (B, N, D)
        merged = attended.transpose(1, 2).reshape(B, N, D)
        return self.proj(merged)

class TransformerBlock(nn.Module):
    """Pre-norm transformer block with LayerScale, DropPath and learnable skip gains.

    Args:
        dim: Token feature size.
        num_heads: Number of attention heads.
        drop_path_prob: Drop probability used by both DropPath modules.
        mlp_ratio: Hidden size of the MLP as a multiple of ``dim``.
        init_skip: Initial value of the two skip-connection gains.
    """

    def __init__(self, dim, num_heads, drop_path_prob, mlp_ratio = 4, init_skip=1.0):
        super().__init__()
        # Sub-modules are registered in the same order as before so that
        # parameter initialization and state_dict ordering are unchanged.
        self.norm1 = nn.LayerNorm(dim)
        self.attention = Attention(dim, num_heads)
        self.ls1 = LayerScale(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp_hidden_dim = dim*mlp_ratio
        self.mlp = MLP(input_dim=dim, hidden_dim = self.mlp_hidden_dim)
        self.ls2 = LayerScale(dim)

        self.skip_alpha1 = nn.Parameter(init_skip * torch.ones(1))
        self.skip_alpha2 = nn.Parameter(init_skip * torch.ones(1))

        self.drop_path1 = DropPath(drop_path_prob)
        self.drop_path2 = DropPath(drop_path_prob)

    def forward(self, x, pos):
        """Apply the attention sub-block followed by the MLP sub-block."""
        # Attention sub-block: scaled skip + regularized attention output.
        skip = self.skip_alpha1 * x
        attn_out = self.attention(self.norm1(x), pos)
        x = skip + self.drop_path1(self.ls1(attn_out))

        # MLP sub-block, same structure.
        skip = self.skip_alpha2 * x
        mlp_out = self.mlp(self.norm2(x))
        x = skip + self.drop_path2(self.ls2(mlp_out))
        return x

class ViT(nn.Module):
    """Patch-embedding transformer encoder.

    Args:
        H: Image height used to derive the patch grid.
        W: Image width used to derive the patch grid.
        P: Patch size (kernel and stride of the embedding convolution).
        D: Token feature size.
        max_drop_prob: DropPath probability of the last block; earlier blocks
            get linearly spaced values starting at 0.
        num_heads: Number of attention heads per block.
        num_blocks: Number of transformer blocks.
    """

    def __init__(self, H, W, P, D, max_drop_prob, num_heads, num_blocks):
        super().__init__()
        self.H, self.W = H, W
        self.C_img = 3
        self.P = P
        self.D = D
        self.grid_H, self.grid_W = H // self.P, W // self.P
        self.num_blocks = num_blocks

        self.img_embedding = nn.Conv2d(self.C_img, D, kernel_size=self.P, stride=self.P)
        self.layer_norm = LayerNorm(self.D)

        block_drop_probs = torch.linspace(0, max_drop_prob, num_blocks).tolist()
        self.blocks = nn.ModuleList(
            [TransformerBlock(D, num_heads, drop_prob) for drop_prob in block_drop_probs]
        )
        self.position_getter = PositionGetter()

    def forward(self, img):
        """Encode ``img`` into patch tokens of shape (B, Np, D)."""
        feature_map = self.layer_norm(self.img_embedding(img))  # (B,D,H/P,W/P)
        tokens = feature_map.flatten(2).transpose(1, 2)         # (B,Np,D)
        batch = tokens.size(0)
        pos = self.position_getter(batch, self.grid_H, self.grid_W, device=tokens.device)  # (B,Np,2)
        for block in self.blocks:
            tokens = block(tokens, pos)                         # (B,Np,D)
        return tokens

class linear_head(nn.Module):
    """Maps patch tokens to a dense single-channel map.

    A 1x1 convolution produces ``P**2`` values per patch, which pixel-shuffle
    rearranges into a ``P x P`` tile. Patch size (16) and token width (1024)
    are fixed.

    Args:
        Hn: Output height; must be a multiple of 16.
        Wn: Output width; must be a multiple of 16.
    """

    def __init__(
            self,
            Hn : int,
            Wn : int,
    ):
        super().__init__()
        self.Hn, self.Wn = Hn, Wn
        self.P = 16
        self.D = 1024

        rows, row_rest = divmod(self.Hn, self.P)
        assert row_rest == 0
        cols, col_rest = divmod(self.Wn, self.P)
        assert col_rest == 0
        self.grid_H, self.grid_W = rows, cols

        self.head = nn.Conv2d(self.D, self.P**2, kernel_size=1)

    def forward(self,x):
        """
        input x : (B,Np,D) tokens
        output  : (B,1,Hn,Wn) dist
        """
        batch = x.size(0)
        token_grid = x.reshape(batch, self.grid_H, self.grid_W, self.D)
        feature_map = token_grid.permute(0, 3, 1, 2)                   # (B,D,Hn/P,Wn/P)
        per_patch = self.head(feature_map)                             # (B,P**2,Hn/P,Wn/P)
        return F.pixel_shuffle(per_patch, upscale_factor=self.P)       # (B,1,Hn,Wn)

class LinearHeadTrainer:
    """Trains a ``linear_head`` on top of a frozen, randomly initialized ``ViT``.

    Data comes from ``DummyLinearHeadDataset`` (64 training and 16 validation
    samples). Only the head's parameters are optimized; the ViT is kept in
    eval mode with gradients disabled.
    """

    def __init__(
            self,
            return_mode : str,
            global_batch_size : int,
            world_size : int,
            num_workers : int,
            prefetch_factor : int,
            use_torch_compile   : bool,
            base_lr : float,
            weight_decay : float,
            warmup_steps : int,
            max_epochs : int,
            Hn : int,
            Wn : int,
            precision : str,
            channels_last : bool,
    ):
        """Build the data loaders, models, optimizer and LR scheduler.

        Args:
            return_mode: Accepted but not read by this class.
            global_batch_size: Divided by ``world_size`` to get the loader batch size.
            world_size: Divisor applied to ``global_batch_size``.
            num_workers: Worker count for both data loaders.
            prefetch_factor: Accepted but not forwarded to the data loaders.
            use_torch_compile: Wrap both models with ``torch.compile`` when True.
            base_lr: AdamW learning rate before scheduling.
            weight_decay: AdamW weight decay.
            warmup_steps: Number of linear warmup steps.
            max_epochs: Number of epochs run by ``train``.
            Hn: Sample height (also the head's output height).
            Wn: Sample width (also the head's output width).
            precision: Autocast dtype, ``"fp16"`` or ``"bf16"``.
            channels_last: Use ``torch.channels_last`` for models and images.

        Raises:
            AssertionError: If ``precision`` is not ``"fp16"`` or ``"bf16"``.
        """
        self.max_epochs = max_epochs
        self.warmup_steps = warmup_steps
        self.Hn = Hn
        self.Wn = Wn
        self.base_lr = base_lr
        self.weight_decay = weight_decay
        self.channels_last = channels_last
        assert precision in ["fp16", "bf16"]
        self.precision = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }[precision]

        self.local_rank = 0

        train_dataset = DummyLinearHeadDataset(num_samples=64, H=Hn, W=Wn)
        val_dataset = DummyLinearHeadDataset(num_samples=16, H=Hn, W=Wn)

        self.train_loader = DataLoader(
            train_dataset,
            batch_size=global_batch_size // world_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available(),
            persistent_workers=(num_workers > 0),
        )

        self.val_loader = DataLoader(
            val_dataset,
            batch_size=global_batch_size // world_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available(),
            persistent_workers=(num_workers > 0),
        )

        P = 16
        D = 1024
        max_drop_prob = 0.4
        num_heads = 16
        num_blocks = 24
        # The ViT derives its position grid from the size given here, so it must
        # match the size of the images the dataset produces (Hn x Wn).
        self.vit  = ViT(self.Hn, self.Wn, P, D, max_drop_prob, num_heads, num_blocks)
        self.head = linear_head(self.Hn, self.Wn)

        # Freeze the backbone: eval mode and no gradients.
        self.vit.eval()
        for p in self.vit.parameters():
            p.requires_grad_(False)

        self.device = torch.device(f"cuda:{self.local_rank}")

        self.vit = self.vit.to(self.device)
        self.head = self.head.to(self.device)

        if channels_last:
            self.vit = self.vit.to(memory_format=torch.channels_last)
            self.head = self.head.to(memory_format=torch.channels_last)

        if use_torch_compile:
            self.vit = torch.compile(self.vit, mode="default", dynamic=False)
            self.head = torch.compile(self.head, mode="default", dynamic=False)

        # Only the head is optimized.
        self.optimizer = torch.optim.AdamW(
            self.head.parameters(),
            lr = self.base_lr,
            weight_decay = self.weight_decay,
        )

        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer,
            self._lr_lambda,
        )

    # Cosine annealing with warmup
    def _lr_lambda(self, step):
        """Return the LR multiplier for ``step``.

        Linear ramp from 0 to 1 over ``warmup_steps``, then a half-cosine decay
        from 1 to 0 over the remaining ``len(train_loader) * max_epochs`` steps.
        """
        if step < self.warmup_steps:
            return step / self.warmup_steps

        max_steps = len(self.train_loader) * self.max_epochs
        progress = (step - self.warmup_steps) / max(1, (max_steps - self.warmup_steps))

        return 0.5 * (1 + math.cos(math.pi * progress))

    def loss(self, pred_dist, targ_dist, valid_mask):
        """Masked mean absolute error.

        Args:
            pred_dist: Prediction of shape (B,1,Hn,Wn).
            targ_dist: Target of shape (B,1,Hn,Wn) or (B,Hn,Wn).
            valid_mask: (B,Hn,Wn) boolean of non-padded cells.

        Returns:
            Scalar tensor: sum of absolute errors over valid cells divided by
            the number of valid cells (0 when no cell is valid).
        """
        # A (B,Hn,Wn) target would broadcast against (B,1,Hn,Wn) into
        # (B,B,Hn,Wn), pairing every prediction with every target in the batch.
        # Give it an explicit channel axis first.
        if targ_dist.ndim == pred_dist.ndim - 1:
            targ_dist = targ_dist.unsqueeze(1)

        # Avoid 0/0 when the mask is entirely False.
        num_valid_elements = valid_mask.sum().clamp(min=1)
        diff = (pred_dist - targ_dist) * valid_mask[:,None,:,:]
        return (torch.abs(diff)).sum() / num_valid_elements

    def step(self, batch, split):
        """Run one forward pass (and one optimizer update when training).

        Args:
            batch: Tuple ``(img, targ_dist, valid_mask)`` from a data loader.
            split: ``"train"`` enables gradients and the optimizer/scheduler
                update; any other value only evaluates.

        Returns:
            The loss as a Python float.
        """
        img, targ_dist, valid_mask = batch

        img = img.to(self.device, non_blocking=True)

        if self.channels_last:
            img = img.contiguous(memory_format=torch.channels_last)

        targ_dist = targ_dist.to(self.device, non_blocking=True)
        valid_mask = valid_mask.to(self.device, non_blocking=True)

        is_train = split == "train"
        with torch.autocast(
            device_type=self.device.type,
            dtype=self.precision,
            enabled=(self.device.type == "cuda")
        ):
            # The backbone is frozen, so its forward never records a graph.
            with torch.no_grad():
                patch_feats = self.vit(img) # (B,Np,D)

            with torch.set_grad_enabled(is_train):
                pred_dist = self.head(patch_feats)
                loss = self.loss(pred_dist, targ_dist, valid_mask)

        if is_train:
            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

        return loss.item()

    def train(self):
        """Run ``max_epochs`` epochs of training and validation, printing averages."""
        for e in range(self.max_epochs):

            # --- Train ---
            self.vit.eval()
            self.head.train()
            train_loss_sum = 0.0
            train_batches = 0
            for batch in self.train_loader:
                train_loss_sum += self.step(batch, "train")
                train_batches += 1

            # --- Val ---
            self.vit.eval()
            self.head.eval()
            val_loss_sum = 0.0
            val_batches = 0
            for batch in self.val_loader:
                val_loss_sum += self.step(batch, "val")
                val_batches += 1

            metrics = torch.tensor(
                [train_loss_sum, train_batches, val_loss_sum, val_batches],
                device=self.device,
                dtype=torch.float32
            )

            # Compute averages.
            # NOTE(review): no cross-process reduction happens here, so these
            # are per-process values.
            avg_train_loss = (metrics[0] / metrics[1]).item()
            avg_val_loss   = (metrics[2] / metrics[3]).item()

            # --- Logging (Rank 0 Only) ---
            if self.local_rank == 0:
                print(f"epoch : {e}, avg train loss : {avg_train_loss:.4f}")
                print(f"epoch : {e}, avg val loss : {avg_val_loss:.4f}")
                            

if __name__ == "__main__":

    # Run configuration, passed to the trainer as keyword arguments.
    trainer_kwargs = dict(
        return_mode="linear_head_train_data",
        global_batch_size=8,
        world_size=int(os.environ.get("WORLD_SIZE",1)),
        num_workers=0,          # alternative noted by the author: 8
        prefetch_factor=None,   # alternative noted by the author: 2
        use_torch_compile=False,
        base_lr=0.0003,
        weight_decay=0.01,
        warmup_steps=100,
        max_epochs=100,
        Hn=768,
        Wn=768,
        precision="bf16",       # ["fp16", "bf16"]
        channels_last=True,
    )

    # Select the 'spawn' start method; a RuntimeError from
    # set_start_method is deliberately ignored.
    try:
        mp.set_start_method('spawn', force=True)
    except RuntimeError:
        pass

    trainer = LinearHeadTrainer(**trainer_kwargs)
    trainer.train()
docker run --gpus all --rm -it \
  --ipc=host \
  -v "$PWD":/workspace/repo \
  -w /workspace/repo \
  pytorch/pytorch:2.11.0-cuda13.0-cudnn9-devel \
  bash
root@0a8f2f9a1d40:/workspace/repo# python - <<'PY'
import torch
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))
PY

python debug_train_vit.py
2.11.0+cu130
True
NVIDIA GeForce RTX 5090
epoch : 0, avg train loss : 7.3662
epoch : 0, avg val loss : 7.3611
epoch : 1, avg train loss : 7.3469
epoch : 1, avg val loss : 7.3278
epoch : 2, avg train loss : 7.3052
epoch : 2, avg val loss : 7.2706
epoch : 3, avg train loss : 7.2425
epoch : 3, avg val loss : 7.2000
epoch : 4, avg train loss : 7.1651
epoch : 4, avg val loss : 7.1158
epoch : 5, avg train loss : 7.0763
epoch : 5, avg val loss : 7.0230
Fatal Python error: Segmentation fault

Thread 0x00007f461d3ff6c0 (most recent call first):
  <no Python frame>

Current thread 0x00007f49e8c96300 (most recent call first):
  File "/workspace/repo/debug_train_vit.py", line 189 in forward
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1790 in _call_impl
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1779 in _wrapped_call_impl
  File "/workspace/repo/debug_train_vit.py", line 284 in forward
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1790 in _call_impl
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1779 in _wrapped_call_impl
  File "/workspace/repo/debug_train_vit.py", line 314 in forward
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1790 in _call_impl
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1779 in _wrapped_call_impl
  File "/workspace/repo/debug_train_vit.py", line 341 in forward
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1790 in _call_impl
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1779 in _wrapped_call_impl
  File "/workspace/repo/debug_train_vit.py", line 504 in step
  File "/workspace/repo/debug_train_vit.py", line 538 in train
  File "/workspace/repo/debug_train_vit.py", line 596 in <module>

Extension modules: numpy._core._multiarray_umath, numpy.linalg._umath_linalg, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._pcg64, numpy.random._generator, numpy.random._mt19937, numpy.random._philox, numpy.random._sfc64, numpy.random.mtrand (total: 22)
Segmentation fault (core dumped)

Hello, here is a potentially interesting update: I downgraded the NVIDIA driver to version 590.48.01 and, after hours of testing, I no longer observe the segfault with this new configuration:

  • pytorch docker image (same as before)
  • Linux-LTS kernel (6.18.26-2-lts)
  • Nvidia drivers 590.48.01

thank you, that is very interesting !
I can try with the exact driver you used, hopefully tomorrow

Hello @mjoux, here is an update that I also posted on the issue opened for the same topic on the PyTorch GitHub. It could provide much more information about the bug!

I upgraded the NVIDIA drivers back to the recent version and the segfault started happening again when I run debug_train_vit.py, which further suggests that the issue is related to the recent NVIDIA drivers. Please find below and attached:

  • the config of my arch system

  • commands I ran

  • exact python script : debug_train_vit.py

  • gdb backtrace : gdb_backtrace.txt

Arch System :

  • Nvidia Driver 595.71.05

  • nvidia-open-dkms 595.71.05-2

  • nvidia-utils 595.71.05-2

  • Kernel 6.18.30-1-lts

  • CUDA Version: 13.2

  • RTX 5090

docker run --gpus all --rm -it \
  --ipc=host \
  -v "$PWD":/workspace/repo \
  -w /workspace/repo \
  pytorch/pytorch:2.11.0-cuda13.0-cudnn9-devel \
  bash

root@0be2eaa33379:/workspace/repo# gdb --args python debug_train_vit.py
GNU gdb (Ubuntu 15.1-1ubuntu1~24.04.1) 15.1
Copyright (C) 2024 Free Software Foundation, Inc.
License GPLv3+: GNU GPL version 3 or later <http://gnu.org/licenses/gpl.html>
This is free software: you are free to change and redistribute it.
There is NO WARRANTY, to the extent permitted by law.
Type "show copying" and "show warranty" for details.
This GDB was configured as "x86_64-linux-gnu".
Type "show configuration" for configuration details.
For bug reporting instructions, please see:
<https://www.gnu.org/software/gdb/bugs/>.
Find the GDB manual and other documentation resources online at:
    <http://www.gnu.org/software/gdb/documentation/>.

For help, type "help".
Type "apropos word" to search for commands related to "word"...
Reading symbols from python...
(No debugging symbols found in python)
(gdb) run
Starting program: /usr/bin/python debug_train_vit.py
warning: Error disabling address space randomization: Operation not permitted
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
[New Thread 0x7fc0b1dff6c0 (LWP 480)]
[New Thread 0x7fc0b15fe6c0 (LWP 481)]
[New Thread 0x7fc0b0dfd6c0 (LWP 482)]
[New Thread 0x7fc0b05fc6c0 (LWP 483)]
[New Thread 0x7fc0afdfb6c0 (LWP 484)]
[New Thread 0x7fc0af5fa6c0 (LWP 485)]
[New Thread 0x7fc0aedf96c0 (LWP 486)]
[New Thread 0x7fc0ae5f86c0 (LWP 487)]
[New Thread 0x7fc0addf76c0 (LWP 488)]
[New Thread 0x7fc0ad5f66c0 (LWP 489)]
[New Thread 0x7fc0acdf56c0 (LWP 490)]
[New Thread 0x7fc0ac5f46c0 (LWP 491)]
[New Thread 0x7fc0abdf36c0 (LWP 492)]
[New Thread 0x7fc0ab5f26c0 (LWP 493)]
[New Thread 0x7fc0aadf16c0 (LWP 494)]
[New Thread 0x7fc0aa5f06c0 (LWP 495)]
[New Thread 0x7fc0a9def6c0 (LWP 496)]
[New Thread 0x7fc0a95ee6c0 (LWP 497)]
[New Thread 0x7fc0a8ded6c0 (LWP 498)]
[New Thread 0x7fc0a85ec6c0 (LWP 499)]
[New Thread 0x7fc0a7deb6c0 (LWP 500)]
[New Thread 0x7fc0a75ea6c0 (LWP 501)]
[New Thread 0x7fc0a6de96c0 (LWP 502)]
torch: 2.11.0+cu130
[New Thread 0x7fc0a192a6c0 (LWP 503)]
gpu: True
device: NVIDIA GeForce RTX 5090
torch num threads: 1
torch interop threads: 1
[New Thread 0x7fbe4e3a16c0 (LWP 504)]
[New Thread 0x7fbe4dba06c0 (LWP 505)]
[New Thread 0x7fbe455436c0 (LWP 506)]
[New Thread 0x7fbe4f7ff6c0 (LWP 507)]
epoch : 0, avg train loss : 7.3689
epoch : 0, avg val loss : 7.3609
epoch : 1, avg train loss : 7.3495
epoch : 1, avg val loss : 7.3314
epoch : 2, avg train loss : 7.3080
epoch : 2, avg val loss : 7.2755
epoch : 3, avg train loss : 7.2456
epoch : 3, avg val loss : 7.2042
epoch : 4, avg train loss : 7.1669
epoch : 4, avg val loss : 7.1203
epoch : 5, avg train loss : 7.0772
epoch : 5, avg val loss : 7.0249
epoch : 6, avg train loss : 6.9826
epoch : 6, avg val loss : 6.9301
epoch : 7, avg train loss : 6.8872
epoch : 7, avg val loss : 6.8354
epoch : 8, avg train loss : 6.7959
epoch : 8, avg val loss : 6.7506
epoch : 9, avg train loss : 6.7146
epoch : 9, avg val loss : 6.6702
epoch : 10, avg train loss : 6.6422
epoch : 10, avg val loss : 6.6042
epoch : 11, avg train loss : 6.5814
epoch : 11, avg val loss : 6.5542
epoch : 12, avg train loss : 6.5331
epoch : 12, avg val loss : 6.5111
epoch : 13, avg train loss : 6.4963
epoch : 13, avg val loss : 6.4787
epoch : 14, avg train loss : 6.4689
epoch : 14, avg val loss : 6.4564
epoch : 15, avg train loss : 6.4521
epoch : 15, avg val loss : 6.4425
epoch : 16, avg train loss : 6.4360
epoch : 16, avg val loss : 6.4284
epoch : 17, avg train loss : 6.4261
epoch : 17, avg val loss : 6.4212
epoch : 18, avg train loss : 6.4190
epoch : 18, avg val loss : 6.4137
epoch : 19, avg train loss : 6.4134
epoch : 19, avg val loss : 6.4112
epoch : 20, avg train loss : 6.4088
epoch : 20, avg val loss : 6.4066
epoch : 21, avg train loss : 6.4050
epoch : 21, avg val loss : 6.3987
epoch : 22, avg train loss : 6.4018
epoch : 22, avg val loss : 6.3985
epoch : 23, avg train loss : 6.4000
epoch : 23, avg val loss : 6.3974
epoch : 24, avg train loss : 6.3982
epoch : 24, avg val loss : 6.3956
epoch : 25, avg train loss : 6.3954
epoch : 25, avg val loss : 6.3939

Thread 1 "python" received signal SIGSEGV, Segmentation fault.
0x00007fc16de15c48 in ?? () from /lib/libcuda.so.1


# debug_train_vit.py

import os
import math

import faulthandler
faulthandler.enable()

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch import Tensor
from torch.utils.data import Dataset, DataLoader

from typing import Optional
from typing import Dict, Tuple


class DummyLinearHeadDataset(Dataset):
    """Synthetic dataset that generates random samples on every access.

    Args:
        num_samples: Number of items reported by ``__len__``.
        H: Height of each generated sample.
        W: Width of each generated sample.
    """

    def __init__(self, num_samples: int, H: int = 768, W: int = 768):
        self.H, self.W = H, W
        self.num_samples = num_samples

    def __len__(self):
        """Return the number of synthetic samples configured at construction."""
        n_items = self.num_samples
        return n_items

    def __getitem__(self, idx):
        """Build one random sample; ``idx`` is not used.

        Returns:
            Tuple ``(color, dist, valid_mask)``: a float32 ``(3, H, W)`` tensor and
            a float32 ``(H, W)`` tensor drawn from a standard normal, plus an
            all-True boolean ``(H, W)`` mask.
        """
        h, w = self.H, self.W
        # Draw order (color first, then dist) is kept so seeded runs are unchanged.
        color = torch.randn((3, h, w), dtype=torch.float32)
        dist = torch.randn((h, w), dtype=torch.float32)
        valid_mask = torch.ones((h, w), dtype=torch.bool)
        return color, dist, valid_mask


class PositionGetter:
    """Generates and caches 2D spatial positions for patches in a grid.

    Positions are computed once per grid size and reused on later calls.

    Attributes:
        position_cache: Dictionary mapping ``(height, width)`` to the position
            tensor of shape ``(height * width, 2)`` built on the first call for
            that grid size.
    """

    def __init__(self):
        """Initializes the position generator with an empty cache."""
        self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Generates spatial positions for a batch of patches.

        Args:
            batch_size: Number of samples in the batch.
            height: Height of the grid in patches.
            width: Width of the grid in patches.
            device: Target device for the position tensor.

        Returns:
            Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
            for each position in the grid, repeated for each batch item. The result
            is a fresh copy located on ``device``.
        """
        cache_key = (height, width)
        cached_positions = self.position_cache.get(cache_key)
        if cached_positions is None:
            y_coords = torch.arange(height, device=device)
            x_coords = torch.arange(width, device=device)
            cached_positions = torch.cartesian_prod(y_coords, x_coords)
            self.position_cache[cache_key] = cached_positions

        # The cache is keyed by grid size only, so an entry may have been built for
        # another device. `.to` returns the same tensor when the device already
        # matches, so the common single-device path is unchanged.
        cached_positions = cached_positions.to(device)
        return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()


class RotaryPositionEmbedding2D(nn.Module):
    """2D Rotary Position Embedding implementation.

    The feature vector of every token is split into two halves: the first half
    is rotated according to the token's y coordinate and the second half
    according to its x coordinate. Both halves use the same cosine/sine tables.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Stored on the module; none of the methods of this class
            read it. Default: 1.0

    Attributes:
        base_frequency: Base frequency for computing position embeddings.
        scaling_factor: Value passed at construction (not used by this class).
        frequency_cache: Maps ``(dim, seq_len, device, dtype)`` to the
            precomputed ``(cosine, sine)`` tables.
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        """Initializes the 2D RoPE module."""
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes frequency components for rotary embeddings.

        Args:
            dim: Feature dimension (must be even).
            seq_len: Maximum sequence length.
            device: Target device for computations.
            dtype: Data type for the computed tensors.

        Returns:
            Tuple of (cosine, sine) tensors, each of shape ``(seq_len, dim)``.
            Results are cached per ``(dim, seq_len, device, dtype)``.
        """
        cache_key = (dim, seq_len, device, dtype)
        if cache_key not in self.frequency_cache:
            # Compute frequency bands
            exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency**exponents)

            # Generate position-dependent frequencies
            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", positions, inv_freq)

            # Compute and cache frequency components. The angles are duplicated so
            # that the table width equals `dim` (dim/2 bands, each used twice).
            angles = angles.to(dtype)
            angles = torch.cat((angles, angles), dim=-1)
            cos_components = angles.cos().to(dtype)
            sin_components = angles.sin().to(dtype)
            self.frequency_cache[cache_key] = (cos_components, sin_components)

        return self.frequency_cache[cache_key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Performs feature rotation by splitting and recombining feature dimensions.

        Args:
            x: Input tensor to rotate.

        Returns:
            Tensor of the same shape where the last dimension ``(x1, x2)`` has
            been replaced by ``(-x2, x1)``.
        """
        feature_dim = x.shape[-1]
        x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope(
        self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
    ) -> torch.Tensor:
        """Applies 1D rotary position embeddings along one dimension.

        Args:
            tokens: Input token features.
            positions: Position indices.
            cos_comp: Cosine components for rotation.
            sin_comp: Sine components for rotation.

        Returns:
            Tokens with applied rotary position embeddings.
        """
        # Look up the per-position tables and add a broadcast axis for the heads.
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]

        # Apply rotation
        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Applies 2D rotary position embeddings to input tokens.

        Args:
            tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
                   The feature dimension (dim) must be divisible by 4.
            positions: Position tensor of shape (batch_size, n_tokens, 2) containing
                      the y and x coordinates for each token.

        Returns:
            Tensor of same shape as input with applied 2D rotary position embeddings.

        Raises:
            AssertionError: If input dimensions are invalid or positions are malformed.
        """
        # Validate inputs. Each spatial half must itself be even: the cos/sin tables
        # are built from dim/2 bands duplicated once, so a dim that is only a
        # multiple of 2 would broadcast to a wrong shape or fail later.
        assert tokens.size(-1) % 4 == 0, "Feature dimension must be divisible by 4"
        assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"

        # Compute feature dimension for each spatial direction
        feature_dim = tokens.size(-1) // 2

        # Get frequency components; the table must cover the largest coordinate.
        max_position = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)

        # Split features for vertical and horizontal processing
        vertical_features, horizontal_features = tokens.chunk(2, dim=-1)

        # Apply RoPE separately for each dimension
        vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
        horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)

        # Combine processed features
        return torch.cat((vertical_features, horizontal_features), dim=-1)

class LayerNorm(nn.Module):
    """LayerNorm over the channel axis of a channels-first (B,C,H,W) tensor.

    ``nn.LayerNorm`` normalises the trailing dimension, so the input is
    permuted to channels-last, normalised, and permuted back.

    Args:
        num_channels: size of the channel axis C.
        eps: numerical-stability epsilon passed to ``nn.LayerNorm``
            (defaults to the previously hard-coded 1e-6).
    """

    def __init__(self, num_channels, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(num_channels, eps=eps)

    def forward(self, x):
        """
        input  : (B,C,H,W)
        output : (B,C,H,W)
        """
        x = x.permute(0, 2, 3, 1)  # (B,H,W,C)
        x = self.norm(x)           # (B,H,W,C)
        x = x.permute(0, 3, 1, 2)  # (B,C,H,W)
        return x
    
class LayerScale(nn.Module):
    """Learnable per-feature scaling applied along the last dimension.

    ``gamma`` is a parameter of length ``dim`` whose entries all start at
    ``init_values``; the forward pass multiplies the input by it.
    """

    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        initial_scale = torch.full(size=(dim,), fill_value=init_values)
        self.gamma = nn.Parameter(initial_scale)

    def forward(self, x: Tensor) -> Tensor:
        # gamma has shape (dim,), so it broadcasts over every leading axis of x
        return x.mul(self.gamma)

class DropPath(nn.Module):
    """Per-sample stochastic depth: zeroes whole samples of a residual branch.

    During training each batch element is kept with probability
    ``1 - drop_prob`` and the kept elements are divided by that probability.
    In eval mode, or when ``drop_prob`` is 0, the input is returned untouched.

    Args:
        drop_prob: probability of dropping a sample's branch.
    """

    def __init__(self, drop_prob):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.0:
            # Every sample is dropped. The general formula below would compute
            # x / 0 * 0 here, which is NaN, so return the zeros directly.
            return torch.zeros_like(x)
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # broadcast to all dims except batch
        # floor(keep_prob + U[0,1)) is 1 with probability keep_prob, else 0
        random_tensor = keep_prob + torch.rand(shape, device=x.device, dtype=x.dtype)
        random_tensor.floor_()
        return x / keep_prob * random_tensor

class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear.

    ``hidden_dim`` and ``out_dim`` fall back to ``input_dim`` when they are
    left unset (or falsy).
    """

    def __init__(
        self,
        input_dim,
        hidden_dim: Optional[int] = None,
        out_dim: Optional[int] = None,
    ):
        super().__init__()
        if not out_dim:
            out_dim = input_dim
        if not hidden_dim:
            hidden_dim = input_dim

        # Built in forward order so parameter initialisation happens in the
        # same sequence as the layers are applied.
        expand = nn.Linear(input_dim, hidden_dim)
        activation = nn.GELU()
        project = nn.Linear(hidden_dim, out_dim)
        self.layers = nn.Sequential(expand, activation, project)

    def forward(self, x):
        out = self.layers(x)
        return out

class Attention(nn.Module):
    """Multi-head self-attention with normalised queries/keys and 2D RoPE.

    A single linear layer produces q, k and v. q and k are layer-normalised
    per head and rotated by the rotary position embedding before scaled
    dot-product attention; the heads are then merged and projected.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        per_head = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = per_head
        self.rope_freq = 100
        self.rope = RotaryPositionEmbedding2D(frequency=self.rope_freq)

        self.qkv = nn.Linear(dim, 3 * dim)
        self.q_norm = nn.LayerNorm(per_head)
        self.k_norm = nn.LayerNorm(per_head)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        """Attend over tokens ``x`` of shape (B, N, D) using positions ``pos``."""
        batch, n_tokens, width = x.shape

        # (B, N, 3*D) -> (3, B, heads, N, head_dim), then split along the first axis
        fused = self.qkv(x).reshape(batch, n_tokens, 3, self.num_heads, self.head_dim)
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(0)

        query = self.q_norm(query).contiguous()
        key = self.k_norm(key).contiguous()
        value = value.contiguous()

        # Rotary embedding goes on queries and keys only
        query = self.rope(query, pos).contiguous()
        key = self.rope(key, pos).contiguous()

        attended = F.scaled_dot_product_attention(
            query, key, value,
            dropout_p=0.,
        )

        # (B, heads, N, head_dim) -> (B, N, D)
        merged = attended.transpose(1, 2).reshape(batch, n_tokens, width)
        return self.proj(merged)

class TransformerBlock(nn.Module):
    """Pre-norm transformer block with LayerScale, DropPath and learnable skip gains.

    Both residual updates have the form
    ``x = skip_alpha * x + drop_path(layer_scale(branch(norm(x))))``,
    first with the attention branch and then with the MLP branch.
    """

    def __init__(self, dim, num_heads, drop_path_prob, mlp_ratio = 4, init_skip=1.0):
        super().__init__()
        # Attention branch
        self.norm1 = nn.LayerNorm(dim)
        self.attention = Attention(dim, num_heads)
        self.ls1 = LayerScale(dim)

        # Feed-forward branch
        self.norm2 = nn.LayerNorm(dim)
        self.mlp_hidden_dim = dim * mlp_ratio
        self.mlp = MLP(input_dim=dim, hidden_dim=self.mlp_hidden_dim)
        self.ls2 = LayerScale(dim)

        # One learnable scalar gain on each skip connection
        self.skip_alpha1 = nn.Parameter(init_skip * torch.ones(1))
        self.skip_alpha2 = nn.Parameter(init_skip * torch.ones(1))

        self.drop_path1 = DropPath(drop_path_prob)
        self.drop_path2 = DropPath(drop_path_prob)

    def forward(self, x, pos):
        attn_update = self.attention(self.norm1(x), pos)
        attn_update = self.drop_path1(self.ls1(attn_update))
        x = self.skip_alpha1 * x + attn_update

        mlp_update = self.mlp(self.norm2(x))
        mlp_update = self.drop_path2(self.ls2(mlp_update))
        x = self.skip_alpha2 * x + mlp_update
        return x

class ViT(nn.Module):
    """Vision transformer backbone returning per-patch tokens.

    The image is patch-embedded with a stride-P convolution, layer-normalised
    over channels, flattened to a token sequence and passed through
    ``num_blocks`` transformer blocks. The drop-path probability rises
    linearly from 0 in the first block to ``max_drop_prob`` in the last.
    """

    def __init__(self, H, W, P, D, max_drop_prob, num_heads, num_blocks):
        super().__init__()
        self.H, self.W = H, W
        self.C_img = 3
        self.P = P
        self.D = D
        # Number of patches along each image side
        self.grid_H = H // self.P
        self.grid_W = W // self.P
        self.num_blocks = num_blocks

        self.img_embedding = nn.Conv2d(self.C_img, D, kernel_size=self.P, stride=self.P)
        self.layer_norm = LayerNorm(self.D)

        drop_schedule = torch.linspace(0, max_drop_prob, num_blocks)
        block_list = [
            TransformerBlock(D, num_heads, rate.item())
            for rate in drop_schedule
        ]
        self.blocks = nn.ModuleList(block_list)
        self.position_getter = PositionGetter()

    def forward(self, img):
        patches = self.img_embedding(img)             # (B,D,H/P,W/P)
        patches = self.layer_norm(patches)            # (B,D,H/P,W/P)
        tokens = patches.flatten(2).transpose(1, 2)   # (B,Np,D)

        # Positions come from the grid size fixed at construction time
        pos = self.position_getter(
            tokens.size(0), self.grid_H, self.grid_W, device=tokens.device
        )                                             # (B,Np,2)

        for block in self.blocks:
            tokens = block(tokens, pos)               # (B,Np,D)
        return tokens

class linear_head(nn.Module):
    """Per-pixel prediction head: one linear map per token, then pixel shuffle.

    Each token is mapped by a 1x1 convolution to ``P**2`` values, which
    ``pixel_shuffle`` lays out as the token's P x P patch of a single-channel
    (Hn, Wn) map.

    Args:
        Hn: output height, must be a multiple of ``P``.
        Wn: output width, must be a multiple of ``P``.
        D: token feature dimension.
        P: patch size. Defaults to 16, the value that used to be hard-coded;
            it has to match the patch size of the backbone producing the tokens.

    Raises:
        AssertionError: if ``Hn`` or ``Wn`` is not divisible by ``P``.
    """

    def __init__(
            self,
            Hn : int,
            Wn : int,
            D  : int,
            P  : int = 16,
    ):
        super().__init__()
        self.Hn = Hn
        self.Wn = Wn
        self.P = P
        self.D = D
        assert self.Hn % self.P == 0, "Hn must be divisible by the patch size"
        assert self.Wn % self.P == 0, "Wn must be divisible by the patch size"
        self.grid_H = Hn // self.P
        self.grid_W = Wn // self.P
        self.head = nn.Conv2d(self.D, self.P**2, kernel_size=1)

    def forward(self, x):
        """
        input x : (B,Np,D) tokens
        output  : (B,1,Hn,Wn) dist
        """
        B = x.size(0)
        # Tokens are laid back out on the patch grid, channels first
        x = x.reshape(B, self.grid_H, self.grid_W, self.D).permute(0, 3, 1, 2)  # (B,D,Hn/P,Wn/P)
        x = self.head(x)                                # (B,P**2,Hn/P,Wn/P)
        x = F.pixel_shuffle(x, upscale_factor=self.P)   # (B,1,Hn,Wn)
        return x

class LinearHeadTrainer:
    """Trains a `linear_head` on top of a frozen `ViT` using dummy datasets.

    The backbone is built here without loading any weights, put in eval mode
    and has gradients disabled; only the head's parameters are handed to
    AdamW. The learning rate follows a linear warmup and then a cosine decay.

    Args:
        return_mode: accepted for interface compatibility; not used by this class.
        global_batch_size: total batch size; each loader uses
            ``global_batch_size // world_size``.
        world_size: number of processes the global batch is divided across.
        num_workers: DataLoader worker processes (0 loads in the main process).
        prefetch_factor: batches prefetched per worker. Only forwarded to the
            DataLoaders when ``num_workers > 0`` and a value is given.
        use_torch_compile: wrap backbone and head with ``torch.compile``.
        base_lr: peak learning rate for AdamW.
        weight_decay: AdamW weight decay.
        warmup_steps: optimizer steps of linear warmup.
        max_epochs: number of epochs run by ``train``.
        Hn, Wn: input image / output map height and width.
        D: token feature dimension.
        precision: autocast dtype, ``"fp16"`` or ``"bf16"``.
        channels_last: convert models and images to channels-last memory format.
        accelerator: ``"gpu"`` or ``"cpu"``.

    Raises:
        AssertionError: if ``precision`` is not one of the supported names.
        ValueError: if ``accelerator`` is not ``"gpu"`` or ``"cpu"``.
    """

    def __init__(
            self,
            return_mode : str,
            global_batch_size : int,
            world_size : int,
            num_workers : int,
            prefetch_factor : Optional[int],
            use_torch_compile   : bool,
            base_lr : float,
            weight_decay : float,
            warmup_steps : int,
            max_epochs : int,
            Hn : int,
            Wn : int,
            D  : int,
            precision : str,
            channels_last : bool,
            accelerator : str,
    ):
        self.max_epochs = max_epochs
        self.warmup_steps = warmup_steps
        self.Hn = Hn
        self.Wn = Wn
        self.D  = D
        self.base_lr = base_lr
        self.weight_decay = weight_decay
        self.channels_last = channels_last
        assert precision in ["fp16", "bf16"]
        # NOTE(review): no gradient scaler is used in `step`, so "fp16" relies on
        # the loss/gradients staying in range -- confirm if fp16 is actually needed.
        self.precision = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }[precision]

        self.local_rank = 0

        train_dataset = DummyLinearHeadDataset(num_samples=64, H=Hn, W=Wn)
        val_dataset = DummyLinearHeadDataset(num_samples=16, H=Hn, W=Wn)

        use_cuda = accelerator == "gpu" and torch.cuda.is_available()

        # Settings shared by both loaders.
        loader_kwargs = dict(
            batch_size=global_batch_size // world_size,
            num_workers=num_workers,
            pin_memory=use_cuda,
            persistent_workers=(num_workers > 0),
        )
        # prefetch_factor is only meaningful with worker processes, so it is
        # forwarded only in that case; otherwise the DataLoader default applies.
        if num_workers > 0 and prefetch_factor is not None:
            loader_kwargs["prefetch_factor"] = prefetch_factor

        self.train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
        self.val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

        P = 16
        max_drop_prob = 0.4
        num_heads = 16
        num_blocks = 24
        self.vit  = ViT(self.Hn, self.Wn, P, self.D, max_drop_prob, num_heads, num_blocks)
        self.head = linear_head(self.Hn, self.Wn, self.D)

        # The backbone is a fixed feature extractor: eval mode, no gradients.
        self.vit.eval()
        for p in self.vit.parameters():
            p.requires_grad_(False)

        if accelerator == "gpu":
            self.device = torch.device(f"cuda:{self.local_rank}")
        elif accelerator == "cpu":
            self.device = torch.device("cpu")
        else:
            raise ValueError(f"Unknown device: {accelerator}")

        self.vit = self.vit.to(self.device)
        self.head = self.head.to(self.device)

        if channels_last:
            self.vit = self.vit.to(memory_format=torch.channels_last)
            self.head = self.head.to(memory_format=torch.channels_last)

        if use_torch_compile:
            self.vit = torch.compile(self.vit, mode="default", dynamic=False)
            self.head = torch.compile(self.head, mode="default", dynamic=False)

        # Only the head is optimised.
        self.optimizer = torch.optim.AdamW(
            self.head.parameters(),
            lr = self.base_lr,
            weight_decay = self.weight_decay,
        )

        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer,
            self._lr_lambda,
        )

    # Cosine annealing with warmup
    def _lr_lambda(self, step):
        """Return the LR multiplier for optimizer step ``step``.

        Rises linearly from 0 to 1 over ``warmup_steps``, then follows half a
        cosine from 1 down to 0 over the remaining
        ``len(train_loader) * max_epochs - warmup_steps`` steps.
        """
        if step < self.warmup_steps:
            return step / self.warmup_steps

        max_steps = len(self.train_loader) * self.max_epochs
        progress = (step - self.warmup_steps) / max(1, (max_steps - self.warmup_steps))

        return 0.5 * (1 + math.cos(math.pi * progress))

    def loss(self, pred_dist, targ_dist, valid_mask):
        """Mean absolute error over the valid cells.

        pred, target : (B,1,Hn,Wn)
        valid_mask   : (B,Hn,Wn) boolean of non-padded cells

        Returns a scalar tensor. When no cell is valid the result is 0
        instead of the NaN that a 0/0 division would give.
        """
        # Clamp the count so an all-False mask cannot divide by zero.
        num_valid_elements = valid_mask.sum().clamp(min=1)
        diff = (pred_dist - targ_dist) * valid_mask[:, None, :, :]
        return torch.abs(diff).sum() / num_valid_elements

    def step(self, batch, split):
        """Run one batch and return its loss as a Python float.

        Args:
            batch: ``(img, targ_dist, valid_mask)`` as produced by the loaders.
            split: ``"train"`` also runs backward and the optimizer/scheduler
                step; any other value only evaluates the loss.
        """
        img, targ_dist, valid_mask = batch

        img = img.to(self.device, non_blocking=True)

        if self.channels_last:
            img = img.contiguous(memory_format=torch.channels_last)

        targ_dist = targ_dist.to(self.device, non_blocking=True)
        valid_mask = valid_mask.to(self.device, non_blocking=True)

        is_train = split == "train"
        # Autocast is only switched on for CUDA devices.
        with torch.autocast(
            device_type=self.device.type,
            dtype=self.precision,
            enabled=(self.device.type == "cuda")
        ):
            # The backbone never needs gradients.
            with torch.no_grad():
                patch_feats = self.vit(img) # (B,Np,D)

            with torch.set_grad_enabled(is_train):
                pred_dist = self.head(patch_feats)
                loss = self.loss(pred_dist, targ_dist, valid_mask)

        if is_train:
            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

        return loss.item()

    def train(self):
        """Run ``max_epochs`` epochs of training and validation, printing mean losses."""
        for e in range(self.max_epochs):

            # --- Train ---
            self.vit.eval()
            self.head.train()
            train_loss_sum = 0.0
            train_batches = 0
            for batch in self.train_loader:
                train_loss_sum += self.step(batch, "train")
                train_batches += 1

            # --- Val ---
            self.vit.eval()
            self.head.eval()
            val_loss_sum = 0.0
            val_batches = 0
            for batch in self.val_loader:
                val_loss_sum += self.step(batch, "val")
                val_batches += 1

            metrics = torch.tensor(
                [train_loss_sum, train_batches, val_loss_sum, val_batches],
                device=self.device,
                dtype=torch.float32
            )

            # Compute averages from the accumulated sums and batch counts
            avg_train_loss = (metrics[0] / metrics[1]).item()
            avg_val_loss   = (metrics[2] / metrics[3]).item()

            # --- Logging (Rank 0 Only) ---
            if self.local_rank == 0:
                print(f"epoch : {e}, avg train loss : {avg_train_loss:.4f}")
                print(f"epoch : {e}, avg val loss : {avg_val_loss:.4f}")
                            

if __name__ == "__main__":

    # Every run setting lives in one mapping and is forwarded unchanged to the
    # trainer. Trailing notes record the alternative values from the original.
    settings = dict(
        return_mode = "linear_head_train_data",
        global_batch_size = 8,
        world_size = int(os.environ.get("WORLD_SIZE",1)),
        num_workers = 0,          # alternative: 8
        prefetch_factor = None,   # alternative: 2
        use_torch_compile = False,
        base_lr = 0.0003,
        weight_decay = 0.01,
        warmup_steps = 100,
        max_epochs = 1000,
        Hn = 768,
        Wn = 768,
        D = 1024,
        precision = "bf16",       # one of "fp16", "bf16"
        channels_last = True,
        accelerator = "gpu",      # one of "gpu", "cpu"
    )

    # Force the 'spawn' start method; a RuntimeError from this call is ignored.
    try:
        mp.set_start_method('spawn', force=True)
    except RuntimeError:
        pass

    # Keep CPU-side parallelism at a single thread
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)

    # Environment report
    print("torch:", torch.__version__)
    print("gpu:", torch.cuda.is_available())
    gpu_requested = settings["accelerator"] == "gpu"
    if gpu_requested and torch.cuda.is_available():
        print("device:", torch.cuda.get_device_name(0))
    else:
        print("device: CPU")
    print("torch num threads:", torch.get_num_threads())
    print("torch interop threads:", torch.get_num_interop_threads())

    trainer = LinearHeadTrainer(**settings)
    trainer.train()

gdb_backtrace.txt (49.6 KB)

Here is a cuda-gdb backtrace from another run, launched with:
CUDA_LAUNCH_BLOCKING=1 cuda-gdb --args python debug_train_vit.py

Interestingly, cuda-gdb repeatedly reports:
“CUDA Driver error detected: No CUDA context is current to the calling thread”
“Returning 201 (CUDA_ERROR_INVALID_CONTEXT) from cuCtxGetDevice_v2”

Then the program still segfaults in /lib/libcuda.so.1

cuda_gdb_backtrace.txt (26.6 KB)