I tried to reproduce the above debug_dino_only.py script on the pytorch/pytorch:2.11.0-cuda13.0-cudnn9-devel Docker image and it did not reproduce the segmentation fault after multiple attempts.
However, I managed to reproduce a segmentation fault on the same docker image using a simplified version of my training script, without using a ckpt, with a custom Vision Transformer implementation :
# debug_train_vit.py
import os
import time
import math
import faulthandler
faulthandler.enable()
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from typing import Optional
from typing import Dict, Tuple
from pathlib import Path
class DummyLinearHeadDataset(Dataset):
def __init__(self, num_samples: int, H: int = 768, W: int = 768):
self.num_samples = num_samples
self.H = H
self.W = W
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
color = torch.randn(3, self.H, self.W, dtype=torch.float32)
dist = torch.randn(self.H, self.W, dtype=torch.float32)
valid_mask = torch.ones(self.H, self.W, dtype=torch.bool)
return color, dist, valid_mask
class PositionGetter:
"""Generates and caches 2D spatial positions for patches in a grid.
This class efficiently manages the generation of spatial coordinates for patches
in a 2D grid, caching results to avoid redundant computations.
Attributes:
position_cache: Dictionary storing precomputed position tensors for different
grid dimensions.
"""
def __init__(self):
"""Initializes the position generator with an empty cache."""
self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
"""Generates spatial positions for a batch of patches.
Args:
batch_size: Number of samples in the batch.
height: Height of the grid in patches.
width: Width of the grid in patches.
device: Target device for the position tensor.
Returns:
Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
for each position in the grid, repeated for each batch item.
"""
if (height, width) not in self.position_cache:
y_coords = torch.arange(height, device=device)
x_coords = torch.arange(width, device=device)
positions = torch.cartesian_prod(y_coords, x_coords)
self.position_cache[height, width] = positions
cached_positions = self.position_cache[height, width]
return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
class RotaryPositionEmbedding2D(nn.Module):
"""2D Rotary Position Embedding implementation.
This module applies rotary position embeddings to input tokens based on their
2D spatial positions. It handles the position-dependent rotation of features
separately for vertical and horizontal dimensions.
Args:
frequency: Base frequency for the position embeddings. Default: 100.0
scaling_factor: Scaling factor for frequency computation. Default: 1.0
Attributes:
base_frequency: Base frequency for computing position embeddings.
scaling_factor: Factor to scale the computed frequencies.
frequency_cache: Cache for storing precomputed frequency components.
"""
def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
"""Initializes the 2D RoPE module."""
super().__init__()
self.base_frequency = frequency
self.scaling_factor = scaling_factor
self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
def _compute_frequency_components(
self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Computes frequency components for rotary embeddings.
Args:
dim: Feature dimension (must be even).
seq_len: Maximum sequence length.
device: Target device for computations.
dtype: Data type for the computed tensors.
Returns:
Tuple of (cosine, sine) tensors for frequency components.
"""
cache_key = (dim, seq_len, device, dtype)
if cache_key not in self.frequency_cache:
# Compute frequency bands
exponents = torch.arange(0, dim, 2, device=device).float() / dim
inv_freq = 1.0 / (self.base_frequency**exponents)
# Generate position-dependent frequencies
positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
angles = torch.einsum("i,j->ij", positions, inv_freq)
# Compute and cache frequency components
angles = angles.to(dtype)
angles = torch.cat((angles, angles), dim=-1)
cos_components = angles.cos().to(dtype)
sin_components = angles.sin().to(dtype)
self.frequency_cache[cache_key] = (cos_components, sin_components)
return self.frequency_cache[cache_key]
@staticmethod
def _rotate_features(x: torch.Tensor) -> torch.Tensor:
"""Performs feature rotation by splitting and recombining feature dimensions.
Args:
x: Input tensor to rotate.
Returns:
Rotated feature tensor.
"""
feature_dim = x.shape[-1]
x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _apply_1d_rope(
self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
) -> torch.Tensor:
"""Applies 1D rotary position embeddings along one dimension.
Args:
tokens: Input token features.
positions: Position indices.
cos_comp: Cosine components for rotation.
sin_comp: Sine components for rotation.
Returns:
Tokens with applied rotary position embeddings.
"""
# Embed positions with frequency components
cos = F.embedding(positions, cos_comp)[:, None, :, :]
sin = F.embedding(positions, sin_comp)[:, None, :, :]
# Apply rotation
return (tokens * cos) + (self._rotate_features(tokens) * sin)
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
"""Applies 2D rotary position embeddings to input tokens.
Args:
tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
The feature dimension (dim) must be divisible by 4.
positions: Position tensor of shape (batch_size, n_tokens, 2) containing
the y and x coordinates for each token.
Returns:
Tensor of same shape as input with applied 2D rotary position embeddings.
Raises:
AssertionError: If input dimensions are invalid or positions are malformed.
"""
# Validate inputs
assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
# Compute feature dimension for each spatial direction
feature_dim = tokens.size(-1) // 2
# Get frequency components
max_position = int(positions.max()) + 1
cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
# Split features for vertical and horizontal processing
vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
# Apply RoPE separately for each dimension
vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
# Combine processed features
return torch.cat((vertical_features, horizontal_features), dim=-1)
class LayerNorm(nn.Module):
def __init__(self, num_channels):
super().__init__()
eps = 1e-6
self.norm = nn.LayerNorm(num_channels, eps=eps)
def forward(self,x):
"""
input : (B,C,H,W)
output : (B,C,H,W)
"""
x = x.permute(0,2,3,1) # (B,H,W,C)
x = self.norm(x) # (B,H,W,C)
x = x.permute(0,3,1,2) # (B,C,H,W)
return x
class LayerScale(nn.Module):
def __init__(self, dim: int, init_values: float = 1e-5):
super().__init__()
self.gamma = nn.Parameter(torch.full((dim,), init_values))
def forward(self, x: Tensor) -> Tensor:
return x * self.gamma
class DropPath(nn.Module):
def __init__(self, drop_prob):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # broadcast to all dims except batch
random_tensor = keep_prob + torch.rand(shape, device=x.device, dtype=x.dtype)
random_tensor.floor_()
return x / keep_prob * random_tensor
class MLP(nn.Module):
def __init__(
self,
input_dim,
hidden_dim: Optional[int] = None,
out_dim: Optional[int] = None,
):
super().__init__()
out_dim = out_dim or input_dim
hidden_dim = hidden_dim or input_dim
self.layers = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, out_dim),
)
def forward(self, x):
return self.layers(x)
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.rope_freq = 100
self.rope = RotaryPositionEmbedding2D(frequency=self.rope_freq)
self.qkv = nn.Linear(dim, dim * 3)
self.q_norm = nn.LayerNorm(self.head_dim)
self.k_norm = nn.LayerNorm(self.head_dim)
self.proj = nn.Linear(dim, dim)
def forward(self, x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
B, N, D = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q).contiguous(), self.k_norm(k).contiguous()
v = v.contiguous()
q = self.rope(q, pos).contiguous()
k = self.rope(k, pos).contiguous()
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=0.,
)
x = x.transpose(1, 2).reshape(B, N, D)
x = self.proj(x)
return x
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, drop_path_prob, mlp_ratio = 4, init_skip=1.0):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attention = Attention(dim, num_heads)
self.ls1 = LayerScale(dim)
self.norm2 = nn.LayerNorm(dim)
self.mlp_hidden_dim = dim*mlp_ratio
self.mlp = MLP(input_dim=dim, hidden_dim = self.mlp_hidden_dim)
self.ls2 = LayerScale(dim)
self.skip_alpha1 = nn.Parameter(torch.ones(1) * init_skip)
self.skip_alpha2 = nn.Parameter(torch.ones(1) * init_skip)
self.drop_path1 = DropPath(drop_path_prob)
self.drop_path2 = DropPath(drop_path_prob)
def forward(self, x, pos):
x = self.skip_alpha1*x + self.drop_path1(self.ls1(self.attention(self.norm1(x), pos)))
x = self.skip_alpha2*x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
class ViT(nn.Module):
def __init__(self, H, W, P, D, max_drop_prob, num_heads, num_blocks):
super().__init__()
self.H = H
self.W = W
self.C_img = 3
self.P = P
self.D = D
self.grid_H = H // self.P
self.grid_W = W // self.P
self.num_blocks = num_blocks
self.img_embedding = nn.Conv2d(self.C_img, D, kernel_size=self.P, stride=self.P)
self.layer_norm = LayerNorm(self.D)
dp_probs = [x.item() for x in torch.linspace(0, max_drop_prob, num_blocks)]
self.blocks = nn.ModuleList([TransformerBlock(D, num_heads, dp) for dp in dp_probs])
self.position_getter = PositionGetter()
def forward(self, img):
x = self.img_embedding(img) # (B,D,H/P,W/P)
x = self.layer_norm(x) # (B,D,H/P,W/P)
x = x.flatten(2).transpose(1,2) # (B,Np,D)
pos = self.position_getter(x.size(0), self.grid_H, self.grid_W, device=x.device) # (B,Np,2)
for block in self.blocks:
x = block(x, pos) # (B,Np,D)
return x
class linear_head(nn.Module):
def __init__(
self,
Hn : int,
Wn : int,
):
super().__init__()
self.Hn = Hn
self.Wn = Wn
self.P = 16
self.D = 1024
assert self.Hn%self.P == 0
assert self.Wn%self.P == 0
self.grid_H = Hn // self.P
self.grid_W = Wn // self.P
self.head = nn.Conv2d(self.D, self.P**2, kernel_size=1)
def forward(self,x):
"""
input x : (B,Np,D) tokens
output : (B,1,Hn,Wn) dist
"""
B = x.size(0)
x = x.reshape(B,self.grid_H, self.grid_W, self.D).permute(0,3,1,2) # (B,D,Hn/P,Wn/P)
x = self.head(x) # (B,P**2,Hn/P,Wn/P)
x = F.pixel_shuffle(x, upscale_factor=self.P) # (B,1,Hn,Wn)
return x
class LinearHeadTrainer:
def __init__(
self,
return_mode : str,
global_batch_size : int,
world_size : int,
num_workers : int,
prefetch_factor : int,
use_torch_compile : bool,
base_lr : float,
weight_decay : float,
warmup_steps : int,
max_epochs : int,
Hn : int,
Wn : int,
run_log_dir : Path,
precision : str,
channels_last : bool,
):
self.max_epochs = max_epochs
self.warmup_steps = warmup_steps
self.Hn = Hn
self.Wn = Wn
self.base_lr = base_lr
self.weight_decay = weight_decay
self.channels_last = channels_last
assert precision in ["fp16", "bf16"]
self.precision = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
}[precision]
self.local_rank = 0
if self.local_rank == 0:
self.writer = SummaryWriter(log_dir=run_log_dir)
else:
self.writer = None
train_dataset = DummyLinearHeadDataset(num_samples=64, H=Hn, W=Wn)
val_dataset = DummyLinearHeadDataset(num_samples=16, H=Hn, W=Wn)
self.train_loader = DataLoader(
train_dataset,
batch_size=global_batch_size // world_size,
shuffle=True,
num_workers=num_workers,
pin_memory=torch.cuda.is_available(),
persistent_workers=(num_workers > 0),
)
self.val_loader = DataLoader(
val_dataset,
batch_size=global_batch_size // world_size,
shuffle=False,
num_workers=num_workers,
pin_memory=torch.cuda.is_available(),
persistent_workers=(num_workers > 0),
)
H = 768
W = 768
P = 16
D = 1024
max_drop_prob = 0.4
num_heads = 16
num_blocks = 24
self.vit = ViT(H, W, P, D, max_drop_prob, num_heads, num_blocks)
self.head = linear_head(self.Hn, self.Wn)
self.vit.eval()
for p in self.vit.parameters():
p.requires_grad_(False)
self.device = torch.device(f"cuda:{self.local_rank}")
self.vit = self.vit.to(self.device)
self.head = self.head.to(self.device)
if channels_last:
self.vit = self.vit.to(memory_format=torch.channels_last)
self.head = self.head.to(memory_format=torch.channels_last)
if use_torch_compile:
self.vit = torch.compile(self.vit, mode="default", dynamic=False)
self.head = torch.compile(self.head, mode="default", dynamic=False)
self.optimizer = torch.optim.AdamW(
self.head.parameters(),
lr = self.base_lr,
weight_decay = self.weight_decay,
)
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
self.optimizer,
self._lr_lambda,
)
# Cosine annealing with warmup
def _lr_lambda(self, step):
if step < self.warmup_steps:
return step / self.warmup_steps
max_steps = len(self.train_loader) * self.max_epochs
progress = (step - self.warmup_steps) / max(1, (max_steps - self.warmup_steps))
return 0.5 * (1 + math.cos(math.pi * progress))
def loss(self, pred_dist, targ_dist, valid_mask):
"""
pred, target : (B,1,Hn,Wn)
valid_mask : (B,Hn,Wn) boolean of non-padded cells
"""
num_valid_elements = valid_mask.sum()
diff = (pred_dist - targ_dist) * valid_mask[:,None,:,:]
return (torch.abs(diff)).sum() / num_valid_elements
def step(self, batch, split):
img, targ_dist, valid_mask = batch
img = img.to(self.device, non_blocking=True)
if self.channels_last:
img = img.contiguous(memory_format=torch.channels_last)
targ_dist = targ_dist.to(self.device, non_blocking=True)
valid_mask = valid_mask.to(self.device, non_blocking=True)
is_train = split == "train"
with torch.autocast(
device_type=self.device.type,
dtype=self.precision,
enabled=(self.device.type == "cuda")
):
with torch.no_grad():
patch_feats = self.vit(img) # (B,Np,D)
with torch.set_grad_enabled(is_train):
pred_dist = self.head(patch_feats)
loss = self.loss(pred_dist, targ_dist, valid_mask)
if is_train:
self.optimizer.zero_grad(set_to_none=True)
loss.backward()
self.optimizer.step()
self.scheduler.step()
return loss.item()
def train(self):
for e in range(self.max_epochs):
# --- Train ---
self.vit.eval()
self.head.train()
train_loss_sum = 0.0
train_batches = 0
for batch_idx, batch in enumerate(self.train_loader):
img, targ_dist, valid_mask = batch
loss = self.step(batch, "train")
train_loss_sum += loss
train_batches += 1
# --- Val ---
self.vit.eval()
self.head.eval()
val_loss_sum = 0.0
val_batches = 0
for batch_idx, batch in enumerate(self.val_loader):
val_loss_sum += self.step(batch, "val")
val_batches += 1
metrics = torch.tensor(
[train_loss_sum, train_batches, val_loss_sum, val_batches],
device=self.device,
dtype=torch.float32
)
# Compute global averages
avg_train_loss = (metrics[0] / metrics[1]).item()
avg_val_loss = (metrics[2] / metrics[3]).item()
# --- Logging (Rank 0 Only) ---
if self.local_rank == 0:
print(f"epoch : {e}, avg train loss : {avg_train_loss:.4f}")
print(f"epoch : {e}, avg val loss : {avg_val_loss:.4f}")
self.writer.add_scalar("Loss/train", avg_train_loss, e)
self.writer.add_scalar("Loss/val", avg_val_loss, e)
# Force TensorBoard to write to disk immediately
self.writer.flush()
# --- Cleanup ---
if self.local_rank == 0:
self.writer.close()
if __name__ == "__main__":
use_torch_compile = False
return_mode = "linear_head_train_data"
global_batch_size = 8
world_size = int(os.environ.get("WORLD_SIZE",1))
num_workers = 0 #8
prefetch_factor = None #2
base_lr = 0.0003
weight_decay = 0.01
warmup_steps = 100
max_epochs = 100
Hn = 768
Wn = 768
logs_dir = Path("logs")
version_name = "debug"
timestamp = time.strftime("%Y%m%d_%H%M%S")
run_log_dir = logs_dir / f"{version_name}_{timestamp}"
precision = "bf16" # ["fp16", "bf16"]
channels_last = True
try:
mp.set_start_method('spawn', force=True)
except RuntimeError:
pass
trainer = LinearHeadTrainer(
return_mode = return_mode,
global_batch_size = global_batch_size,
world_size = world_size,
num_workers = num_workers,
prefetch_factor = prefetch_factor,
use_torch_compile = use_torch_compile,
base_lr = base_lr,
weight_decay = weight_decay,
warmup_steps = warmup_steps,
max_epochs = max_epochs,
Hn = Hn,
Wn = Wn,
run_log_dir = run_log_dir,
precision = precision,
channels_last = channels_last,
)
trainer.train()
Error message :
root@8583fcbfbd27:/workspace/repo# PYTHONFAULTHANDLER=1 CUDA_LAUNCH_BLOCKING=1 TORCH_SHOW_CPP_STACKTRACES=1 python debug_train_vit.py
epoch : 0, avg train loss : 7.3723
epoch : 0, avg val loss : 7.3655
epoch : 1, avg train loss : 7.3528
epoch : 1, avg val loss : 7.3324
epoch : 2, avg train loss : 7.3108
epoch : 2, avg val loss : 7.2783
epoch : 3, avg train loss : 7.2476
epoch : 3, avg val loss : 7.2051
epoch : 4, avg train loss : 7.1702
epoch : 4, avg val loss : 7.1217
epoch : 5, avg train loss : 7.0820
epoch : 5, avg val loss : 7.0256
epoch : 6, avg train loss : 6.9836
epoch : 6, avg val loss : 6.9295
epoch : 7, avg train loss : 6.8891
epoch : 7, avg val loss : 6.8358
epoch : 8, avg train loss : 6.7969
epoch : 8, avg val loss : 6.7471
epoch : 9, avg train loss : 6.7156
epoch : 9, avg val loss : 6.6714
Fatal Python error: Segmentation fault
Thread 0x00007f717261f6c0 (most recent call first):
<no Python frame>
Thread 0x00007f747da2a6c0 (most recent call first):
File "/usr/lib/python3.12/threading.py", line 359 in wait
File "/usr/lib/python3.12/queue.py", line 180 in get
File "/usr/local/lib/python3.12/dist-packages/tensorboard/summary/writer/event_file_writer.py", line 269 in _run
File "/usr/local/lib/python3.12/dist-packages/tensorboard/summary/writer/event_file_writer.py", line 244 in run
File "/usr/lib/python3.12/threading.py", line 1073 in _bootstrap_inner
File "/usr/lib/python3.12/threading.py", line 1030 in _bootstrap
Current thread 0x00007f75764f7300 (most recent call first):
File "/usr/local/lib/python3.12/dist-packages/torch/nn/functional.py", line 2935 in layer_norm
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/normalization.py", line 229 in forward
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1790 in _call_impl
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1779 in _wrapped_call_impl
File "/workspace/repo/debug_train_vit.py", line 284 in forward
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1790 in _call_impl
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1779 in _wrapped_call_impl
File "/workspace/repo/debug_train_vit.py", line 317 in forward
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1790 in _call_impl
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1779 in _wrapped_call_impl
File "/workspace/repo/debug_train_vit.py", line 344 in forward
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1790 in _call_impl
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1779 in _wrapped_call_impl
File "/workspace/repo/debug_train_vit.py", line 513 in step
File "/workspace/repo/debug_train_vit.py", line 537 in train
File "/workspace/repo/debug_train_vit.py", line 619 in <module>
Extension modules: numpy._core._multiarray_umath, numpy.linalg._umath_linalg, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, google._upb._message, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._pcg64, numpy.random._generator, numpy.random._mt19937, numpy.random._philox, numpy.random._sfc64, numpy.random.mtrand (total: 23)
Segmentation fault (core dumped)