PyTorch + CUDA consumes all Linux memory

This modest PyTorch program runs happily in my local WSL2 env:

import torch
x = torch.tensor([[1.], [4.]])
w = torch.tensor([[2.]])
y = torch.nn.functional.linear(x, w)
print(y)
# Output: tensor([[2.], [8.]])

This one gets killed by Linux after slurping up all the available memory. There is no error message other than “Killed”.

import torch
x0 = torch.tensor([[1.], [4.]], device='cuda')
w0 = torch.tensor([[2.]], device='cuda')
torch.nn.functional.linear(x0, w0)  # <--- crashes here

I’m running an RTX 3060 Ti on a freshly installed WSL2/Ubuntu under Windows 11. I’ve tried the latest NVIDIA Windows driver (528.49) as well as several previous versions, and I get the same behavior with all of them.
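
In case it helps anyone trying to reproduce this, here is a minimal sketch for collecting the version details that usually matter for this kind of report. The helper name `collect_env` is my own and purely illustrative. It only reads version attributes and calls `torch.cuda.is_available()`, so I would not expect it to trigger the crash, but I have not confirmed that. PyTorch also ships a fuller report via `python -m torch.utils.collect_env`.

```python
import platform


def collect_env():
    """Gather version details relevant to a CUDA bug report (illustrative helper)."""
    info = {
        "python": platform.python_version(),
        "platform": platform.platform(),  # on WSL2 this usually mentions "microsoft"
    }
    try:
        import torch
    except ImportError:
        # PyTorch is not installed in this interpreter.
        info["torch"] = None
        return info

    info["torch"] = torch.__version__
    # CUDA version PyTorch was built against; None on CPU-only builds.
    info["torch_cuda_build"] = torch.version.cuda
    # Whether PyTorch can see a usable CUDA device.
    info["cuda_available"] = torch.cuda.is_available()
    return info


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"{key}: {value}")
```

Comparing `torch_cuda_build` against what the installed driver supports is the first thing I would check, since a mismatch there is a common source of odd CUDA behavior.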