안녕하세요.
DataLoader에서 dataset을 class별로 나누어 transforms 적용 및 train을 하는 것이 가능한가요?
즉, 기존처럼 dataset 정의 후 batch size로 끊어서 순서대로 진행되는 게 아니라, dataset을 class 단위로 끊어서 transforms를 적용하고 시각화해 보고 싶은데 어디서 손을 대야 할지 모르겠네요.
# Read the source-domain list file inside a context manager so the file
# handle is closed as soon as its lines are in memory (the bare
# open(...).readlines() form leaves the handle open until garbage collection).
with open(args.s_dset_path) as list_file:
    image_list = list_file.readlines()

# Wrap the entries in the dataset; source_transform is applied per image.
dataset = ImageList(image_list, transform=source_transform)

# shuffle=False keeps batches in list-file order; drop_last=True discards a
# final batch that is smaller than batch_size.
train_loader = DataLoader(dataset, batch_size=args.batch_size,
                          shuffle=False, num_workers=args.workers,
                          drop_last=True)
class ImageList(Dataset):
    """Dataset that serves images described by a list of entries.

    Each item is loaded from its path with a mode-specific loader, passed
    through the optional transforms, and returned together with its target
    and its position in the dataset.
    """

    def __init__(self, image_list, labels=None, transform=None,
                 target_transform=None, mode='RGB'):
        """Build the dataset from ``image_list``.

        Args:
            image_list: Entries handed to ``make_dataset`` (presumably lines
                of "path label" text -- confirm against ``make_dataset``).
            labels: Optional labels forwarded to ``make_dataset``.
            transform: Optional callable applied to each loaded image.
            target_transform: Optional callable applied to each target.
            mode: ``'RGB'`` to load with ``rgb_loader`` or ``'L'`` to load
                with ``l_loader``.

        Raises:
            ValueError: If ``mode`` is neither ``'RGB'`` nor ``'L'``.
            RuntimeError: If ``make_dataset`` yields no images.
        """
        # Reject unsupported modes up front; otherwise self.loader would stay
        # unset and the failure would only surface later in __getitem__.
        if mode not in ('RGB', 'L'):
            raise ValueError(
                "Unsupported mode: {!r} (expected 'RGB' or 'L')".format(mode))

        imgs = make_dataset(image_list, labels)
        if len(imgs) == 0:
            # The message no longer references `root`, which is not defined
            # in this scope and would raise NameError instead of this error.
            raise RuntimeError(
                "Found 0 images in the given image list.\n"
                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = rgb_loader if mode == 'RGB' else l_loader

    def __getitem__(self, index):
        """Return ``(img, target, index)`` for the entry at ``index``."""
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        # The index is returned as well so callers can map a sample back to
        # its position in the dataset.
        return img, target, index

    def __len__(self):
        """Return the number of entries in the dataset."""
        return len(self.imgs)