import torch
import torch.nn as nn
import torch.cuda.amp as amp
import ctypes
class DoubleConv(nn.Module):
    """Two consecutive Conv2d -> BatchNorm2d -> ReLU stages.

    Both convolutions use a 3x3 kernel with padding 1. The first stage maps
    ``in_channels`` to ``out_channels``; the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Same layer order as a hand-written Sequential, so submodule indices
        # (and therefore state_dict keys) stay 0..5.
        for stage_in in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1)
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages and return the result."""
        features = self.double_conv(x)
        return features
class UNet(nn.Module):
    """Encoder-decoder network with skip connections built from DoubleConv.

    The encoder applies four DoubleConv stages (64, 128, 256, 512 channels),
    each followed by 2x max pooling, then a 1024-channel bridge. The decoder
    upsamples, concatenates the matching encoder output along the channel
    dimension, and applies a DoubleConv at each level. A final 1x1
    convolution maps the 64 decoder channels to ``out_channels``.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the final 1x1 conv.

    The forward pass concatenates on ``dim=1``, so it expects batched
    ``(N, C, H, W)`` input. Inputs smaller than 16 pixels on a side cannot
    survive the four 2x poolings.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # Encoder
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        self.enc5 = DoubleConv(512, 1024)
        # Decoder (input channels = upsampled features + skip features)
        self.dec4 = DoubleConv(1024 + 512, 512)
        self.dec3 = DoubleConv(512 + 256, 256)
        self.dec2 = DoubleConv(256 + 128, 128)
        self.dec1 = DoubleConv(128 + 64, 64)
        # Final convolution
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
        # Pooling and upsampling
        self.pool = nn.MaxPool2d(2)
        # Kept as an attribute for backward compatibility; its mode and
        # align_corners settings drive the decoder's resampling.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def _upsample_and_concat(self, x, skip):
        """Resize ``x`` to ``skip``'s spatial size and concatenate on channels.

        Resizing to the skip tensor's size, rather than by a fixed factor of
        2, keeps the concatenation valid when a pooled dimension was odd
        (max pooling floors it, so doubling would come up one pixel short).
        When every dimension halves evenly the target size is exactly 2x.
        """
        x = nn.functional.interpolate(
            x,
            size=skip.shape[-2:],
            mode=self.up.mode,
            align_corners=self.up.align_corners,
        )
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Return the network output at the input's spatial resolution."""
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        # Bridge
        x = self.enc5(self.pool(enc4))
        # Decoder
        x = self.dec4(self._upsample_and_concat(x, enc4))
        x = self.dec3(self._upsample_and_concat(x, enc3))
        x = self.dec2(self._upsample_and_concat(x, enc2))
        x = self.dec1(self._upsample_and_concat(x, enc1))
        return self.final_conv(x)
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

model = UNet()
model = model.to(device)

# Dummy input: one 3-channel 512x512 tensor of Gaussian noise.
batch_size, channels, height, width = 1, 3, 512, 512
x = torch.randn((batch_size, channels, height, width)).to(device)

# Gradient scaler for mixed precision (created here; not used below).
scaler = amp.GradScaler()

# Forward pass inside the autocast context.
with amp.autocast():
    output = model(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Model using FP16: {next(model.parameters()).dtype == torch.float16}")
print(f"Output dtype: {output.dtype}")
# If we run cuda_profiler.cu on this example, the reported symbolName address
# is 0x20000000001.