torchvision.transforms에 관해 궁금한게 있습니다

transform 을 compose로 묶을때,
예를들면
transform.compose = [randomaffine, horizontalflip, colorjitter]
이런 compose가 있다고 했을때 해당 트랜스폼을 실행 시 무조건 내부의 세 트랜스폼이 다 실행되는건가요?
아니면 확률적으로 적용되서 일부 트랜스폼만 실행이 되는건가요?

또, Dataset에 transform을 넘겨줄 때, 기존 데이터셋과 변조된 데이터셋을 같이 학습에 활용하고 싶다면
D1 = CustomDataset(transform = None)
D2 = CustomDataset(transform = transform)

D1+D2 이 방법 밖에는 없나요?

transforms.compose는 주어진 변환들을 모두 수행합니다. (확률적으로 적용하시려면 RandomApply를 이용)
https://pytorch.org/vision/stable/generated/torchvision.transforms.RandomApply.html#torchvision.transforms.RandomApply

말씀하신 D1+D2 외에
torch.utils.data.ConcatDataset도 있습니다.
https://pytorch.org/docs/stable/data.html#torch.utils.data.ConcatDataset

또는 커스텀 데이터셋을 작성해서 인덱스에 따라 transform을 적용하는 방법도 좋을 것 같습니다.

import torch
from typing import Callable, Optional
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torchvision.transforms import Compose, RandomApply


class CustomDataset(Dataset):
    """{m} ~ {M} 범위의 정수 반환"""
    def __init__(self, m: int, M: int, transform: Optional[Callable] = None):
        super().__init__()
        self.transform = transform
        self.nums = list(range(m, M+1))

    def __len__(self):
        return len(self.nums)

    def __getitem__(self, idx):
        n = self.nums[idx]
        if self.transform:
            n = self.transform(n)
        return n


class AddDot5:
    """0.5를 더함"""
    def __call__(self, n):
        return n + .5


class Negative:
    """-1을 곱함"""
    def __call__(self, n):
        return -n


if __name__ == '__main__':

    def print_dataset(ds):
        for n in ds:
            print(n, end=' ')
        print()

    ds0 = CustomDataset(0, 3)
    ds1 = CustomDataset(3, 6, transform=AddDot5())

    print('데이터셋 연결\n')

    print('0. ds0 + ds1')
    print_dataset(ds0 + ds1)
    # 0 1 2 3 3.5 4.5 5.5 6.5

    print('\n1. torch.utils.data.ConcatDataset([ds0, ds1])')
    print_dataset(ConcatDataset([ds0, ds1]))
    # 0 1 2 3 3.5 4.5 5.5 6.5


    print('############################################################# \n')

    print('transform 적용')

    print('\nCompose')

    ds2 = CustomDataset(0, 5, transform=Compose([AddDot5()]))
    print_dataset(ds2)
    # 0.5 1.5 2.5 3.5 4.5 5.5 -> 모든 변환이 모든 아이템에 적용

    print('\nCompose + RandomApply(all transforms)')

    ds3 = CustomDataset(0, 10, transform=Compose([RandomApply([AddDot5(), Negative()], p=.5)]))
    print_dataset(ds3)
    # 0 1 -2.5 -3.5 -4.5 -5.5 6 -7.5 8 9 10 -> 모든 변환이 확률적으로 적용

    print('\nCompose + RandomApply(transform)')

    t = Compose([
        RandomApply([AddDot5()], p=.5),
        RandomApply([Negative()], p=.5),
    ])
    ds4 = CustomDataset(0, 10, transform=t)
    print_dataset(ds4)
    # 0 -1 2.5 -3.5 -4.5 -5.5 6 -7 -8.5 9 -10.5 -> 개별 변환이 각각 확률적으로 적용
좋아요 2

감사합니다!!

좋아요 1