When I use torchvision.transforms.TrivialAugmentWide while transforming images, the memory on my Orin Nano (8 GB RAM, 16 GB swap) fills up; I can watch it happen in jtop, even without the GUI desktop mode running.
I am using the NVIDIA container l4t-ml. Is there any issue with the PyTorch package it ships? I am trying to follow the guide 04. PyTorch Custom Datasets - Zero to Mastery Learn PyTorch for Deep Learning.
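In case it matters, this is how the PyTorch and torchvision builds inside the container can be checked (I have not pasted the exact output here):

```python
# Quick check of the PyTorch / torchvision builds inside the l4t-ml container
import torch
import torchvision

print(torch.__version__)          # PyTorch build
print(torchvision.__version__)    # torchvision build
print(torch.cuda.is_available())  # confirm CUDA is visible on the Orin Nano
```

The plotting helper I am using while following the guide, which is where the problem shows up: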
```python
import random
from PIL import Image
import matplotlib.pyplot as plt

def plot_transformed_images(image_paths: list, transform, n=5):
    """Plot n random images next to their transformed versions."""
    rand_img_paths = random.sample(image_paths, k=n)
    for imgpath in rand_img_paths:
        with Image.open(imgpath) as f:
            fig, ax = plt.subplots(1, 2)
            ax[0].imshow(f)
            ax[0].axis(False)
            ax[0].set_title(f"Original\nSize: {f.size}")
            tfd = transform(f).permute(1, 2, 0)  # CHW -> HWC for imshow
            ax[1].imshow(tfd)
            ax[1].axis(False)
            ax[1].set_title(f"Transformed\nSize: {tfd.shape}")
            fig.suptitle(f"Class: {imgpath.parent.stem}", fontsize=16)
```
```python
from torchvision import transforms

transform_1 = transforms.Compose([transforms.Resize(size=(256*256)),
                                  transforms.TrivialAugmentWide(num_magnitude_bins=1),
                                  transforms.ToTensor()])

plot_transformed_images(image_path_list, transform_1, 6)  # OOM happens here
```
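For context, the memory spike does not need the plotting helper at all; it can be reproduced on a single image (`image_path_list[0]` below is just any image path from my dataset):

```python
# Minimal reproduction on one image, no plotting involved
from PIL import Image

with Image.open(image_path_list[0]) as img:  # any image from the dataset
    out = transform_1(img)                   # memory fills up during this call
    print(out.shape)
```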
Surprisingly, transform_1 works when loading the dataset, for example:

```python
from torchvision import datasets

tr = datasets.ImageFolder(root=train_path, transform=transform_1, target_transform=None)
```

runs without error.
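For completeness, the tr_dl and ts_dl that appear in the traceback below are DataLoaders built roughly like this; the batch_size and num_workers values are assumptions rather than the exact ones from my notebook, and ts is the matching test-split ImageFolder:

```python
from torch.utils.data import DataLoader

# Assumed values; the exact batch_size / num_workers in my notebook may differ
tr_dl = DataLoader(tr, batch_size=32, shuffle=True, num_workers=2)
ts_dl = DataLoader(ts, batch_size=32, shuffle=False, num_workers=2)
```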
But I cannot train the model either, even though the dataset loads with the transform. The error I get:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
File /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:1132, in _MultiProcessingDataLoaderIter._try_get_data(self, timeout)
1131 try:
-> 1132 data = self._data_queue.get(timeout=timeout)
1133 return (True, data)
File /usr/lib/python3.10/multiprocessing/queues.py:113, in Queue.get(self, block, timeout)
112 timeout = deadline - time.monotonic()
--> 113 if not self._poll(timeout):
114 raise Empty
File /usr/lib/python3.10/multiprocessing/connection.py:257, in _ConnectionBase.poll(self, timeout)
256 self._check_readable()
--> 257 return self._poll(timeout)
File /usr/lib/python3.10/multiprocessing/connection.py:424, in Connection._poll(self, timeout)
423 def _poll(self, timeout):
--> 424 r = wait([self], timeout)
425 return bool(r)
File /usr/lib/python3.10/multiprocessing/connection.py:931, in wait(object_list, timeout)
930 while True:
--> 931 ready = selector.select(timeout)
932 if ready:
File /usr/lib/python3.10/selectors.py:416, in _PollLikeSelector.select(self, timeout)
415 try:
--> 416 fd_event_list = self._selector.poll(timeout)
417 except InterruptedError:
File /usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/signal_handling.py:66, in _set_SIGCHLD_handler.<locals>.handler(signum, frame)
63 def handler(signum, frame):
64 # This following call uses `waitid` with WNOHANG from C side. Therefore,
65 # Python can still get and update the process status successfully.
---> 66 _error_if_any_worker_fails()
67 if previous_handler is not None:
RuntimeError: DataLoader worker (pid 192) is killed by signal: Killed.
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
Cell In[33], line 4
1 from tqdm.auto import tqdm
3 for epoch in tqdm(range(10)):
----> 4 train(model,tr_dl,loss_fn,optimizer,device)
5 test(model,ts_dl,loss_fn,acc_fn,device)
Cell In[31], line 4, in train(model, dataloader, loss_fn, optimizer, device)
2 model.to(device)
3 model.train()
----> 4 for X,y in dataloader:
5 X,y=X.to(device),y.to(device)
6 y_logits=model(X)
File /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
627 if self._sampler_iter is None:
628 # TODO(https://github.com/pytorch/pytorch/issues/76750)
629 self._reset() # type: ignore[call-arg]
--> 630 data = self._next_data()
631 self._num_yielded += 1
632 if self._dataset_kind == _DatasetKind.Iterable and \
633 self._IterableDataset_len_called is not None and \
634 self._num_yielded > self._IterableDataset_len_called:
File /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:1328, in _MultiProcessingDataLoaderIter._next_data(self)
1325 return self._process_data(data)
1327 assert not self._shutdown and self._tasks_outstanding > 0
-> 1328 idx, data = self._get_data()
1329 self._tasks_outstanding -= 1
1330 if self._dataset_kind == _DatasetKind.Iterable:
1331 # Check for _IterableDatasetStopIteration
File /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:1294, in _MultiProcessingDataLoaderIter._get_data(self)
1290 # In this case, `self._data_queue` is a `queue.Queue`,. But we don't
1291 # need to call `.task_done()` because we don't use `.join()`.
1292 else:
1293 while True:
-> 1294 success, data = self._try_get_data()
1295 if success:
1296 return data
File /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:1145, in _MultiProcessingDataLoaderIter._try_get_data(self, timeout)
1143 if len(failed_workers) > 0:
1144 pids_str = ', '.join(str(w.pid) for w in failed_workers)
-> 1145 raise RuntimeError(f'DataLoader worker (pid(s) {pids_str}) exited unexpectedly') from e
1146 if isinstance(e, queue.Empty):
1147 return (False, None)
RuntimeError: DataLoader worker (pid(s) 192) exited unexpectedly