Hi,
I am looking to train a semantic segmentation model with a Fully Convolutional Network (FCN) in PyTorch.
I found this semantic segmentation train.py in the dusty-nv repo:
#
# Note -- this training script is tweaked from the original version at:
#
# https://github.com/pytorch/vision/tree/v0.3.0/references/segmentation
#
# It's also meant to be used against this fork of torchvision, which includes
# some patches for exporting to ONNX and adds fcn_resnet18 and fcn_resnet34:
#
# https://github.com/dusty-nv/vision/tree/v0.3.0
#
import argparse
import datetime
import time
import math
import os
import shutil
import torch
import torch.utils.data
from torch import nn
This file has been truncated.
I went through train.py and found that the train_one_epoch function takes an image and a target from the data loader. What format and shape does the function expect for the image and the target?
I want to define a dataset class similar to the PennFudanDataset class from the following link:
https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
Thank You.
Hi @harsha.tejas2002, this code originally comes from the torchvision example: vision/references/segmentation at main · pytorch/vision · GitHub
The input is an RGB image and the target is a single-channel class ID mask, where each pixel's value is the class ID of that pixel.
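To make those shapes concrete, here is a minimal sketch, assuming the unmodified torchvision reference pipeline, where a batch reaches the loss as a float image tensor of shape [N, 3, H, W] and an int64 target of shape [N, H, W]. The sizes, class count, and random tensors are placeholders for illustration, not values taken from train.py.

```python
import torch
import torch.nn.functional as F

num_classes = 2        # placeholder: background + one foreground class
n, h, w = 4, 64, 64    # placeholder batch size and image size

# Image batch: float tensor with 3 RGB channels, shape [N, 3, H, W]
image = torch.rand(n, 3, h, w)

# Target batch: int64 class IDs with NO channel dimension, shape [N, H, W]
target = torch.randint(0, num_classes, (n, h, w), dtype=torch.int64)

# Stand-in for the model's per-pixel logits, shape [N, num_classes, H, W]
logits = torch.randn(n, num_classes, h, w)

# Pixels labelled 255 are skipped by the loss via ignore_index
loss = F.cross_entropy(logits, target, ignore_index=255)
print(image.shape, target.shape, loss.item())
```

A per-sample target returned by a dataset would then be [H, W]; the batch dimension is added by the data loader.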
Hey,
Thanks for the reply.
import os
import numpy as np
import torch
from PIL import Image

class PennFundanPedDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms = None):
        self.root = root
        self.transforms = transforms
        # Sort both listings so that images and masks line up by index
        self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
        self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))

    def __getitem__(self, idx):
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        # The mask is loaded as-is and converted to an int64 tensor of shape [H, W]
        target = Image.open(mask_path)
        target = np.array(target)
        target = torch.as_tensor(target, dtype=torch.int64)
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        return len(self.imgs)
Please let me know if I should make any changes in the code.
I am not completely sure about the shape of the target.
On the surface it looks okay, but I'm not familiar with this dataset, so it's hard for me to say for sure at a glance. I recommend that you compare it against the other segmentation dataloaders to aid in your debugging.
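One way to run that kind of check is to inspect a single target directly. The sketch below is only an illustration: check_target is a hypothetical helper, the 3x3 mask is made-up data, and the two-class setup is an assumption. It also assumes the masks may encode instances (the torchvision tutorial describes Penn-Fudan masks as 0 for background with a separate ID per pedestrian), in which case the IDs would need collapsing to class IDs for semantic segmentation.

```python
import numpy as np
import torch

def check_target(target, num_classes, ignore_index=255):
    """Hypothetical helper: assert that a per-sample mask looks like a class-ID mask."""
    assert target.dtype == torch.int64, f"expected int64, got {target.dtype}"
    assert target.dim() == 2, f"expected [H, W], got {tuple(target.shape)}"
    ids = torch.unique(target).tolist()
    # Every ID must be a valid class index or the ignore label
    bad = [i for i in ids if i >= num_classes and i != ignore_index]
    assert not bad, f"IDs outside [0, {num_classes}): {bad}"
    return ids

# Made-up instance-style mask: 0 = background, 1 and 2 = two separate objects
instance_mask = np.array([[0, 1, 1],
                          [0, 0, 2],
                          [2, 2, 0]], dtype=np.uint8)

# For a two-class (background / foreground) setup, collapse all instances to class 1
semantic_mask = torch.as_tensor((instance_mask > 0).astype(np.int64))
print(check_target(semantic_mask, num_classes=2))
```

Calling check_target on the uncollapsed mask with num_classes=2 would fail because of the ID 2, which is the kind of mismatch worth ruling out before training.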