OOM when using TrivialAugmentWide

When I use torchvision.transforms.TrivialAugmentWide to transform an image, I see the memory fill up in jtop, even without running the GUI desktop mode. My Orin Nano has 8GB of RAM with 16GB of swap.
I use the NVIDIA container l4t-ml. Is there any issue with its PyTorch package? I am trying to follow the guide "04. PyTorch Custom Datasets - Zero to Mastery Learn PyTorch for Deep Learning".

def plot_transformed_images(image_paths:list,transform,n=5):
    """Show randomly chosen images side by side with their transformed versions.

    One two-panel figure is created per sampled image: the left panel shows
    the image as opened, the right panel shows the output of ``transform``.

    Args:
        image_paths: Paths to sample from. Each path is opened with
            ``Image.open`` and its ``parent.stem`` is used as the class label
            in the figure title.
        transform: Callable applied to the opened image. Its result is
            permuted with ``(1, 2, 0)`` before being drawn (presumably a
            channels-first tensor moved to channels-last -- confirm).
        n: How many paths are drawn with ``random.sample``.
    """
    for path in random.sample(image_paths, k=n):
        with Image.open(path) as original:
            figure, (left, right) = plt.subplots(1, 2)

            # Left panel: the untouched image.
            left.imshow(original)
            left.axis(False)
            left.set_title(f"Orginal\nSize:{original.size}")

            # Right panel: the transformed image, reordered for imshow.
            transformed = transform(original).permute(1, 2, 0)
            right.imshow(transformed)
            right.axis(False)
            right.set_title(f"Transformed\nSize:{transformed.shape}")

            figure.suptitle(f"Class: {path.parent.stem}", fontsize=16)

# Resize needs a (height, width) pair here. Writing `256*256` instead of
# `256, 256` passes the single int 65536, and an int size makes Resize scale
# the image's shorter edge to that many pixels -- far more than 8GB of RAM
# plus swap can hold, which is what triggers the OOM kill.
transform_1 = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.TrivialAugmentWide(num_magnitude_bins=1),
    transforms.ToTensor(),
])

plot_transformed_images(image_path_list, transform_1, 6)  # this call hit the OOM while size was 256*256

Surprisingly, transform_1 works when loading datasets.
For example:
# NOTE(review): constructing the dataset presumably does not run the transform
# yet -- it would only run when samples are fetched, which would match the
# failure showing up later inside the DataLoader worker. Confirm.
tr=datasets.ImageFolder(root=train_path,transform=transform_1,target_transform=None)
This works without error.

However, I cannot train the model, even though the dataset loaded with the transform.

error I got:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:1132, in _MultiProcessingDataLoaderIter._try_get_data(self, timeout)
   1131 try:
-> 1132     data = self._data_queue.get(timeout=timeout)
   1133     return (True, data)

File /usr/lib/python3.10/multiprocessing/queues.py:113, in Queue.get(self, block, timeout)
    112 timeout = deadline - time.monotonic()
--> 113 if not self._poll(timeout):
    114     raise Empty

File /usr/lib/python3.10/multiprocessing/connection.py:257, in _ConnectionBase.poll(self, timeout)
    256 self._check_readable()
--> 257 return self._poll(timeout)

File /usr/lib/python3.10/multiprocessing/connection.py:424, in Connection._poll(self, timeout)
    423 def _poll(self, timeout):
--> 424     r = wait([self], timeout)
    425     return bool(r)

File /usr/lib/python3.10/multiprocessing/connection.py:931, in wait(object_list, timeout)
    930 while True:
--> 931     ready = selector.select(timeout)
    932     if ready:

File /usr/lib/python3.10/selectors.py:416, in _PollLikeSelector.select(self, timeout)
    415 try:
--> 416     fd_event_list = self._selector.poll(timeout)
    417 except InterruptedError:

File /usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/signal_handling.py:66, in _set_SIGCHLD_handler.<locals>.handler(signum, frame)
     63 def handler(signum, frame):
     64     # This following call uses `waitid` with WNOHANG from C side. Therefore,
     65     # Python can still get and update the process status successfully.
---> 66     _error_if_any_worker_fails()
     67     if previous_handler is not None:

RuntimeError: DataLoader worker (pid 192) is killed by signal: Killed. 

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[33], line 4
      1 from tqdm.auto import tqdm
      3 for epoch in tqdm(range(10)):
----> 4     train(model,tr_dl,loss_fn,optimizer,device)
      5     test(model,ts_dl,loss_fn,acc_fn,device)

Cell In[31], line 4, in train(model, dataloader, loss_fn, optimizer, device)
      2 model.to(device)
      3 model.train()
----> 4 for X,y in dataloader:
      5     X,y=X.to(device),y.to(device)
      6     y_logits=model(X)

File /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
    627 if self._sampler_iter is None:
    628     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    629     self._reset()  # type: ignore[call-arg]
--> 630 data = self._next_data()
    631 self._num_yielded += 1
    632 if self._dataset_kind == _DatasetKind.Iterable and \
    633         self._IterableDataset_len_called is not None and \
    634         self._num_yielded > self._IterableDataset_len_called:

File /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:1328, in _MultiProcessingDataLoaderIter._next_data(self)
   1325     return self._process_data(data)
   1327 assert not self._shutdown and self._tasks_outstanding > 0
-> 1328 idx, data = self._get_data()
   1329 self._tasks_outstanding -= 1
   1330 if self._dataset_kind == _DatasetKind.Iterable:
   1331     # Check for _IterableDatasetStopIteration

File /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:1294, in _MultiProcessingDataLoaderIter._get_data(self)
   1290     # In this case, `self._data_queue` is a `queue.Queue`,. But we don't
   1291     # need to call `.task_done()` because we don't use `.join()`.
   1292 else:
   1293     while True:
-> 1294         success, data = self._try_get_data()
   1295         if success:
   1296             return data

File /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:1145, in _MultiProcessingDataLoaderIter._try_get_data(self, timeout)
   1143 if len(failed_workers) > 0:
   1144     pids_str = ', '.join(str(w.pid) for w in failed_workers)
-> 1145     raise RuntimeError(f'DataLoader worker (pid(s) {pids_str}) exited unexpectedly') from e
   1146 if isinstance(e, queue.Empty):
   1147     return (False, None)

RuntimeError: DataLoader worker (pid(s) 192) exited unexpectedly

Hi,

DataLoader worker (pid 192) is killed by signal: Killed. 

"Killed" is usually caused by running out of memory.
Could you lower the batch size to see if it helps?

For example:

# Setup batch size and number of workers.
# Small values are used here to check whether the "Killed" worker was caused
# by running out of memory.
BATCH_SIZE = 8
NUM_WORKERS = 1
print(f"Creating DataLoader's with batch size {BATCH_SIZE} and {NUM_WORKERS} workers.")

# Create DataLoader's
# The training loader is built from `train_data_simple` (defined elsewhere)
# using the two settings above, with shuffling enabled.
train_dataloader_simple = DataLoader(train_data_simple, 
                                     batch_size=BATCH_SIZE, 
                                     shuffle=True, 
                                     num_workers=NUM_WORKERS)
...

Thanks.

Because only a small number (<100) of custom samples is available for each class, BATCH_SIZE is already 1.
Thanks.

Sorry, it was a silly mistake in my code. Mods: please close this.

My mistake

Code completion switched the comma (,) with an asterisk (*), so the intended Resize size (256, 256) became (256*256).

Okay~ Thanks for updating.