This modest PyTorch program runs happily in my local WSL2 environment:
import torch
x = torch.tensor([[1.], [4.]])
w = torch.tensor([[2.]])
y = torch.nn.functional.linear(x, w)
print(y)
tensor([[2.], [8.]])
This one gets killed by Linux after slurping up all the available memory. There is no error message other than “Killed”.
import torch
x0 = torch.tensor([[1.], [4.]], device='cuda')
w0 = torch.tensor([[2.]], device='cuda')
torch.nn.functional.linear(x0, w0)  # <--- crashes here
I’m running an RTX 3060 Ti on a freshly installed WSL2/Ubuntu under Windows 11. I’ve installed the latest NVIDIA Windows driver (528.49) as well as several previous versions, and I get the same behavior with all of them.
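In case it helps anyone reproduce this, here is a minimal sketch of a script for collecting the version details that usually matter for a report like this one. The helper name `collect_env_report` and the dictionary keys are my own invention, not part of PyTorch; the script only reads version and device information and does not run any CUDA computation.

```python
import platform


def collect_env_report():
    """Gather basic version info for a bug report (hypothetical helper)."""
    report = {
        "python": platform.python_version(),
        # Kernel release string as reported by the OS.
        "kernel": platform.release(),
    }
    try:
        import torch
    except ImportError:
        # PyTorch is not installed in this interpreter.
        report["torch"] = None
        return report

    report["torch"] = torch.__version__
    # CUDA version PyTorch was built against (None for CPU-only builds).
    report["torch_cuda_build"] = torch.version.cuda
    report["cuda_available"] = torch.cuda.is_available()
    if report["cuda_available"]:
        report["device_name"] = torch.cuda.get_device_name(0)
        report["device_capability"] = torch.cuda.get_device_capability(0)
    return report


if __name__ == "__main__":
    for key, value in collect_env_report().items():
        print(f"{key}: {value}")
```

Pasting the printed lines alongside the driver version should make it easier to tell whether the PyTorch build and the driver are a mismatched pair.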