안녕하세요,
오랜만에 파이토치를 연습할겸 직접 이미지 데이터셋을 제작하고 DCGAN으로 돌려보는데 훈련 중 제 실력으론 해결할 수 없는 오류가 발생해 질문드립니다.
이미지는 96x96x3이고, 레이블은 없습니다.
코랩에서 찍어본 구동 환경은 다음과 같습니다.
Python 3.7.15
torch 1.12.1
다음은 제가 작성한 코드입니다.
import torch
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
def __init__(self, img_path: str):
self.imgs = img_path
def __len__(self) -> int:
return len(self.imgs)
def __getitem__(self, idx: int):
if torch.is_tensor(idx):
idx = idx.tolist()
image = self.imgs[idx]
return image
transpose = np.transpose(
np.float32(load_data), # `load_data.shape` => (1133, 96, 96, 3)
(0, 3, 1, 2)
)
dataset = CustomDataset(load_data)
dataloader = DataLoader(
dataset=dataset,
batch_size=BATCH_SIZE,
shuffle=True,
)
다음은 훈련을 돌렸던 GAN 모델의 구조입니다.
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Conv2d(in_channels = 3, out_channels = 96, kernel_size = 4, stride = 2, padding = 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels = 96, out_channels = 96*2, kernel_size = 4, stride = 2, padding = 1, bias=False),
nn.BatchNorm2d(96 * 2),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels = 96*2, out_channels = 96*4, kernel_size = 4, stride = 2, padding = 1, bias=False),
nn.BatchNorm2d(96 * 4),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels = 96*4, out_channels = 96*8, kernel_size = 4, stride = 2, padding = 1, bias=False),
nn.BatchNorm2d(96 * 8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels = 96*8, out_channels = 3, kernel_size = 2, stride = 2, padding = 0, bias=False),
nn.Sigmoid()
)
def forward(self, input):
return self.model(input)
class Generator(nn.Module):
def __init__(self):
# calling constructor of parent class
super().__init__()
self.gen = nn.Sequential(
nn.ConvTranspose2d(in_channels = 100, out_channels = 768, kernel_size = 4, stride = 1, padding = 0, bias = False),
nn.BatchNorm2d(num_features = 768),
nn.LeakyReLU(inplace = True),
nn.ConvTranspose2d(in_channels = 768, out_channels = 384, kernel_size = 4, stride = 2, padding = 1, bias = False),
nn.BatchNorm2d(num_features = 384),
nn.LeakyReLU(inplace = True),
nn.ConvTranspose2d(in_channels = 384, out_channels = 192, kernel_size = 4, stride = 2, padding = 1, bias = False),
nn.BatchNorm2d(num_features = 192),
nn.LeakyReLU(inplace = True),
nn.ConvTranspose2d(in_channels = 192, out_channels = 96, kernel_size = 4, stride = 2, padding = 1, bias = False),
nn.BatchNorm2d(num_features = 96),
nn.LeakyReLU(inplace = True),
nn.ConvTranspose2d(in_channels = 96, out_channels = 3, kernel_size = 4, stride = 2, padding = 1, bias = False),
nn.Tanh()
)
def forward(self, input):
return self.gen(input)
추가로 훈련 루프 코드입니다.
for epoch in range(EPOCHS):
for i, b in enumerate(dataloader):
## Discriminator
opt_D.zero_grad()
yhat = netD(b.to(device)).view(-1) # <- 여기서 버그 발생
target = torch.ones(len(b), dtype=torch.float, device=device)
loss_real = loss(yhat, target)
loss_real.backward()
z = torch.randn(len(b), 100, 1, 1, device=device)
fake_img = netG(z)
yhat = netD.cuda()(fake_img.detach()).view(-1)
target = torch.zeros(len(b), dtype=torch.float, device=device)
loss_fake = loss(yhat, target)
loss_fake.backward()
loss_disc = loss_real + loss_fake
opt_D.step()
## Generator
opt_G.zero_grad()
y_hat = netD.cuda()(fake_img).view(-1)
target = torch.ones(len(b), dtype=torch.float, device=device)
loss_gen = loss(yhat, target)
loss_gen.backward()
opt_G.step()
실행 중 "RuntimeError: Given groups=1, weight of size [96, 3, 4, 4], expected input[32, 96, 96, 3] to have 3 channels, but got 96 channels instead"이라는 에러가 발생하는데 permute()
로 조작해도 딱히 통하지는 않더라구요.
아무래도 제가 코드를 이상하게 짠게 확실하지만 도통 감이 안 잡힙니다. 혹시 도움을 요청해도 될까요?