Hello! I'm training a GAN on the MNIST handwritten-digit dataset with PyTorch. The generator and discriminator losses both decrease nicely, even with just one epoch, but strangely, when I actually plot the fake images the generator produces, they look nothing like digits. What could be causing this? Should I increase the number of epochs? The problem with increasing it is that even at one epoch, both the discriminator and generator losses already converge to 0.000.
The full model architecture, data loading, and training code are below. For reference, I implemented a WGAN using the Wasserstein loss function! I'd appreciate it if you could point out anything that might be the cause. Thanks in advance!
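For reference, this is my understanding of the Wasserstein loss formulation I was aiming for, as a toy sketch of just the loss arithmetic (real_scores and fake_scores here are made-up critic outputs, not from my model):
# Wasserstein loss as I understand it (toy example, not my training code)
import torch
real_scores = torch.randn(8, 1)  # hypothetical critic scores on real images
fake_scores = torch.randn(8, 1)  # hypothetical critic scores on generated images
d_loss = -(real_scores.mean() - fake_scores.mean())  # critic maximizes E[D(real)] - E[D(fake)]
g_loss = -fake_scores.mean()                         # generator maximizes E[D(fake)]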
# Model architecture
import torch
import torch.nn as nn
nz = 100
ngf = 64
nc = 1
ndf = 64
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # (nz, 1, 1) -> (ngf*4, 4, 4)
        self.main1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=nz,
                               out_channels=ngf*4,
                               kernel_size=4,
                               stride=1,
                               padding=0,
                               bias=False),
            nn.BatchNorm2d(num_features=ngf*4),
            nn.ReLU()
        )
        # (ngf*4, 4, 4) -> (ngf*2, 8, 8)
        self.main2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=ngf*4,
                               out_channels=ngf*2,
                               kernel_size=4,
                               stride=2,
                               padding=1,
                               bias=False),
            nn.BatchNorm2d(num_features=ngf*2),
            nn.ReLU()
        )
        # (ngf*2, 8, 8) -> (ngf, 14, 14)
        self.main3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=ngf*2,
                               out_channels=ngf,
                               kernel_size=2,
                               stride=2,
                               padding=1,
                               bias=False),
            nn.BatchNorm2d(num_features=ngf),
            nn.ReLU()
        )
        # (ngf, 14, 14) -> (nc, 28, 28); Tanh squashes output to [-1, 1]
        self.main4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=ngf,
                               out_channels=nc,
                               kernel_size=4,
                               stride=2,
                               padding=1,
                               bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.main1(x)
        x = self.main2(x)
        x = self.main3(x)
        x = self.main4(x)
        return x
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # (nc, 28, 28) -> (ndf, 14, 14)
        self.main1 = nn.Sequential(
            nn.Conv2d(in_channels=nc,
                      out_channels=ndf,
                      kernel_size=4,
                      stride=2,
                      padding=1,
                      bias=False),
            nn.LeakyReLU()
        )
        # (ndf, 14, 14) -> (ndf*2, 7, 7)
        self.main2 = nn.Sequential(
            nn.Conv2d(in_channels=ndf,
                      out_channels=ndf*2,
                      kernel_size=4,
                      stride=2,
                      padding=1,
                      bias=False),
            nn.BatchNorm2d(num_features=ndf*2),
            nn.LeakyReLU()
        )
        # (ndf*2, 7, 7) -> (ndf*4, 3, 3)
        self.main3 = nn.Sequential(
            nn.Conv2d(in_channels=ndf*2,
                      out_channels=ndf*4,
                      kernel_size=4,
                      stride=2,
                      padding=1,
                      bias=False),
            nn.BatchNorm2d(num_features=ndf*4),
            nn.LeakyReLU()
        )
        # (ndf*4, 3, 3) -> (ndf*8, 1, 1)
        self.main4 = nn.Sequential(
            nn.Conv2d(in_channels=ndf*4,
                      out_channels=ndf*8,
                      kernel_size=4,
                      stride=2,
                      padding=1,
                      bias=False),
        )
        # ndf*8 = 512 flattened features -> single score in (0, 1)
        self.fc = nn.Sequential(
            nn.Linear(in_features=512, out_features=128),
            nn.LeakyReLU(),
            nn.Linear(in_features=128, out_features=16),
            nn.LeakyReLU(),
            nn.Linear(in_features=16, out_features=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.main1(x)
        x = self.main2(x)
        x = self.main3(x)
        x = self.main4(x)
        x = torch.flatten(x, 1)
        y = self.fc(x)
        return y
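To rule out shape mismatches, I sanity-checked both forward passes roughly like this (this is scratch code, with the shapes I got in the comments), so I don't think the architecture's output sizes are the problem:
# quick shape check (scratch code, not part of training)
g = Generator()
d = Discriminator()
z = torch.randn(2, nz, 1, 1)
fake = g(z)
print(fake.shape)     # torch.Size([2, 1, 28, 28])
print(d(fake).shape)  # torch.Size([2, 1])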
# Data loading
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Image preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5],   # single channel for grayscale; for RGB use mean=(0.5, 0.5, 0.5)
                         std=[0.5])])  # single channel for grayscale; for RGB use std=(0.5, 0.5, 0.5)

# MNIST dataset
train_data = dsets.MNIST(root='data/',
                         train=True,   # training set
                         transform=transform,
                         download=True)
test_data = dsets.MNIST(root='data/',
                        train=False,
                        transform=transform,
                        download=True)

train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)
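Since the generator ends in Tanh, I normalized the images to [-1, 1]; a quick scratch check confirms the loaded batches land in that range:
# pixel range after Normalize(mean=0.5, std=0.5)
x, _ = next(iter(train_loader))
print(x.min().item(), x.max().item())  # -1.0 1.0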
# WGAN
from torch.optim import Adam
# define model
generator = Generator()
discriminator = Discriminator()
# optimizer
gen_optim = Adam(generator.parameters(), lr=0.00005)
dis_optim = Adam(discriminator.parameters(), lr=0.00005)
# loss
def wasserstein(y_true, y_pred):
    return torch.mean(y_true * y_pred)

criterion = wasserstein
# params
EPOCHS = 1
for i in range(EPOCHS):
    d_losses = []
    g_losses = []
    for idx, (x, y) in enumerate(train_loader):
        real_label = torch.ones(len(x), 1)
        fake_label = torch.zeros(len(x), 1)

        # train discriminator (1) - on real images
        dis_optim.zero_grad()
        real_pred = discriminator(x)
        d_real_loss = criterion(real_pred, real_label)
        d_real_loss.backward()

        # train discriminator (2) - on fake images
        noise = torch.randn(len(x), nz, 1, 1)
        fake_x = generator(noise)
        fake_pred = discriminator(fake_x.detach())
        d_fake_loss = criterion(fake_pred, fake_label)
        d_fake_loss.backward()
        d_loss = d_real_loss + d_fake_loss
        dis_optim.step()

        # train generator
        gen_optim.zero_grad()
        fake_pred = discriminator(fake_x)
        g_loss = criterion(fake_pred, real_label)
        g_loss.backward()
        gen_optim.step()

        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())
        if (idx + 1) % 10 == 0:
            print("Epoch {e} (iteration {n}) -> Discriminator Loss {d:.3f} | Generator Loss {g:.3f}".format(
                e=i+1, n=idx+1, d=d_loss, g=g_loss))

    d_loss_avg = sum(d_losses) / len(d_losses)
    g_loss_avg = sum(g_losses) / len(g_losses)
    print("### Epoch {e} finished! | Discriminator Loss {d:.3f} | Generator Loss {g:.3f}".format(
        e=i+1, d=d_loss_avg, g=g_loss_avg))