Dataloader에서 class별 분류 가능여부

안녕하세요.

Dataloader에서 dataset class별로 transforms 및 train이 가능한가요?
즉, 기존처럼 dataset 정의 후 batch size로 끊어서 순서대로 진행되는게 아닌 dataset을 class단위로 끊어서 transforms & 시각화해보고 싶은데 어디서 손을 대야할지 모르겠네요.


dataset = ImageList(open(args.s_dset_path).readlines(), transform=source_transform) train_loader = DataLoader(dataset , batch_size=args.batch_size, shuffle=False, num_workers=args.workers, drop_last=True)

class ImageList(Dataset):
def init(self, image_list, labels=None, transform=None, target_transform=None, mode=‘RGB’):
imgs = make_dataset(image_list, labels)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in subfolders of: " + root + “\n”
"Supported image extensions are: " + “,”.join(IMG_EXTENSIONS)))

    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    if mode == 'RGB':
        self.loader = rgb_loader
    elif mode == 'L':
        self.loader = l_loader

def __getitem__(self, index):
    path, target = self.imgs[index]
    img = self.loader(path)
    if self.transform is not None:
        img = self.transform(img)
    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target, index

def __len__(self):
    return len(self.imgs)
1개의 좋아요

안녕하세요, @lmg 님.

DataLoader에서 데이터를 가져올 때 Class에 따라서 서로 다른 transform을 적용하고 싶다는 말씀이실까요?

그렇다면 __init__()에 서로 다른 transform들을 만드신 뒤에, __getitem__()에서 데이터를 반환하기 전에 해당 index의 Class를 확인하고 서로 다른 transform을 적용하는 식으로 하면 어떨까요?

제가 맞게 이해했는지 모르겠네요 ^^;
혹시 다른걸 원하셨다면 말씀해주시기를 부탁드립니다~