"RuntimeError: Given groups=1, weight of size [96, 3, 4, 4], expected input[32, 96, 96, 3] to have 3 channels, but got 96 channels instead" 오류의 해결방안이 있을까요?

안녕하세요,

오랜만에 파이토치를 연습할겸 직접 이미지 데이터셋을 제작하고 DCGAN으로 돌려보는데 훈련 중 제 실력으론 해결할 수 없는 오류가 발생해 질문드립니다.

이미지는 96x96x3이고, 레이블은 없습니다.

코랩에서 찍어본 구동 환경은 다음과 같습니다.
Python 3.7.15
torch 1.12.1

다음은 제가 작성한 코드입니다.

import torch 
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
  def __init__(self, img_path: str):
    self.imgs = img_path
  
  def __len__(self) -> int:
    return len(self.imgs)
  
  def __getitem__(self, idx: int):
    if torch.is_tensor(idx):
      idx = idx.tolist()
    
    image = self.imgs[idx]

    return image 


transpose = np.transpose(
    np.float32(load_data), # `load_data.shape` => (1133, 96, 96, 3)
    (0, 3, 1, 2)
)

dataset = CustomDataset(load_data)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

다음은 훈련을 돌렸던 GAN 모델의 구조입니다.

class Discriminator(nn.Module):
    def __init__(self):
 
        super().__init__()
        self.model = nn.Sequential(
 
            nn.Conv2d(in_channels = 3, out_channels = 96, kernel_size = 4, stride = 2, padding = 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
 
            nn.Conv2d(in_channels = 96, out_channels = 96*2, kernel_size = 4, stride = 2, padding = 1, bias=False),
            nn.BatchNorm2d(96 * 2),
            nn.LeakyReLU(0.2, inplace=True),
 
            nn.Conv2d(in_channels = 96*2, out_channels = 96*4, kernel_size = 4, stride = 2, padding = 1, bias=False),
            nn.BatchNorm2d(96 * 4),
            nn.LeakyReLU(0.2, inplace=True),
 
            nn.Conv2d(in_channels = 96*4, out_channels = 96*8, kernel_size = 4, stride = 2, padding = 1, bias=False),
            nn.BatchNorm2d(96 * 8),
            nn.LeakyReLU(0.2, inplace=True),
 
            nn.Conv2d(in_channels = 96*8, out_channels = 3, kernel_size = 2, stride = 2, padding = 0, bias=False),
            nn.Sigmoid()
        )
     
    def forward(self, input):
        return self.model(input)
class Generator(nn.Module):
    def __init__(self):
 
        # calling constructor of parent class
        super().__init__()
 
        self.gen = nn.Sequential(
          nn.ConvTranspose2d(in_channels = 100, out_channels = 768, kernel_size = 4, stride = 1, padding = 0, bias = False),
          nn.BatchNorm2d(num_features = 768),
          nn.LeakyReLU(inplace = True),
 
          nn.ConvTranspose2d(in_channels = 768, out_channels = 384, kernel_size = 4, stride = 2, padding = 1, bias = False),
          nn.BatchNorm2d(num_features = 384),
          nn.LeakyReLU(inplace = True),
 
          nn.ConvTranspose2d(in_channels = 384, out_channels = 192, kernel_size = 4, stride = 2, padding = 1, bias = False),
          nn.BatchNorm2d(num_features = 192),
          nn.LeakyReLU(inplace = True),
 
          nn.ConvTranspose2d(in_channels = 192, out_channels = 96, kernel_size = 4, stride = 2, padding = 1, bias = False),
          nn.BatchNorm2d(num_features = 96),
          nn.LeakyReLU(inplace = True),

          nn.ConvTranspose2d(in_channels = 96, out_channels = 3, kernel_size = 4, stride = 2, padding = 1, bias = False),
          nn.Tanh()
        )
 
    def forward(self, input):
        return self.gen(input)

추가로 훈련 루프 코드입니다.

for epoch in range(EPOCHS):
  for i, b in enumerate(dataloader):
    ## Discriminator
    opt_D.zero_grad()
    yhat = netD(b.to(device)).view(-1) # <- 여기서 버그 발생 
    target = torch.ones(len(b), dtype=torch.float, device=device)
    loss_real = loss(yhat, target)
    loss_real.backward()

    z = torch.randn(len(b), 100, 1, 1, device=device)
    fake_img = netG(z)

    yhat = netD.cuda()(fake_img.detach()).view(-1)
    target = torch.zeros(len(b), dtype=torch.float, device=device)
    loss_fake = loss(yhat, target)
    loss_fake.backward()
    loss_disc = loss_real + loss_fake
    opt_D.step()

    ## Generator
    opt_G.zero_grad()
    y_hat = netD.cuda()(fake_img).view(-1)
    target = torch.ones(len(b), dtype=torch.float, device=device)
    loss_gen = loss(yhat, target)
    loss_gen.backward()
    opt_G.step()

실행 중 "RuntimeError: Given groups=1, weight of size [96, 3, 4, 4], expected input[32, 96, 96, 3] to have 3 channels, but got 96 channels instead"이라는 에러가 발생하는데 permute()로 조작해도 딱히 통하지는 않더라구요.

아무래도 제가 코드를 이상하게 짠게 확실하지만 도통 감이 안 잡힙니다. 혹시 도움을 요청해도 될까요?

파이토치의 경우 NCHW dimension으로 데이터를 처리하게 되는데요. 지금 데이터로더 쪽에서 NHWC 로 전달이 되는 것 같습니다.(32x96x96x3, 여기서 32는 배치사이즈, 96은 H, W, 3은 채널값)
이걸 32x3x96x96 으로 전달될 수 있게 수정이 필요해 보입니다.

올려주신 코드의 transpose 쪽에서 그런 작업을 하려고 하시는 것 같긴 한데 올바르게 반영되는지 확인 해보시는게 좋을것 같습니다.

1개의 좋아요

확인해보니까 확실히 transpose 쪽에서 반영이 잘 안된거 같네요. 답변 참고해서 한번 수정해봐야겠습니다. 감사합니다.