This modest PyTorch program runs happily in my local WSL2 environment:
import torch

x = torch.tensor([[1.], [4.]])
w = torch.tensor([[2.]])
y = torch.nn.functional.linear(x, w)
print(y)
This one, which runs the same operation on the GPU, gets killed by Linux after slurping up all the available memory. There is no error message other than “Killed”.
import torch

x0 = torch.tensor([[1.], [4.]], device='cuda')
w0 = torch.tensor([[2.]], device='cuda')
torch.nn.functional.linear(x0, w0)  # <--- crashes here
I’m running an RTX 3060 Ti on a freshly installed WSL2/Ubuntu under Win11. I’ve installed the latest NVIDIA Windows driver (528.49) as well as several previous versions, and I get the same behavior with all of them.
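For reference, `torch.nn.functional.linear(x, w)` computes `x @ w.T` (plus an optional bias), so with these inputs both the CPU and CUDA versions should return `[[2.], [8.]]`. The snippet below is a minimal pure-Python sketch of that math, useful as a known-good answer to compare against. `linear_reference` is a hypothetical helper for illustration, not a PyTorch API.

```python
# Hypothetical pure-Python reference for what torch.nn.functional.linear computes:
# y = x @ w.T (+ b), where x is (N, in_features) and w is (out_features, in_features).
from typing import List, Optional

Matrix = List[List[float]]


def linear_reference(x: Matrix, w: Matrix, b: Optional[List[float]] = None) -> Matrix:
    """Compute y[n][o] = sum_i x[n][i] * w[o][i] + b[o]."""
    in_features = len(w[0])
    for row in x:
        if len(row) != in_features:
            raise ValueError("x and w disagree on in_features")
    out = []
    for row in x:
        out_row = []
        for o, w_row in enumerate(w):
            # Dot product of one input row with one weight row.
            acc = sum(xi * wi for xi, wi in zip(row, w_row))
            if b is not None:
                acc += b[o]
            out_row.append(acc)
        out.append(out_row)
    return out


# Same inputs as the programs above: x is 2x1, w is 1x1.
print(linear_reference([[1.0], [4.0]], [[2.0]]))  # [[2.0], [8.0]]
```

Because the weight is stored as `(out_features, in_features)`, each output element is a dot product with a row of `w`, which is why the transpose appears in `x @ w.T`.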