pytorch GAN 학습 후 이미지 생성

안녕하세요. Pytorch를 이용해서 MNIST 손글씨 데이터를 GAN 모델에 학습시키고 있는데요. GAN 모델의 생성자, 판별자 손실 함수는 Epoch를 1만 하더라도 이상적으로 잘 감소하는데.. 이상하게 생성자가 생성한 fake 이미지를 막상 출력해보면 너무 터무니 없이 다른데요..? 어떤 부분에서 원인이 될까요? Epoch를 늘려야 할까요..? 늘리기에는 Epoch가 1이 되더라도 판별자, 생성자 손실 함수가 0.000에 수렴해버립니다..

모델 구조, 데이터 로드, 학습 코드 전문은 아래와 같습니다. 참고로 손실함수는 waserstein Loss function을 사용해서 WGAN을 구현하였어요! 원인이 될만한 부분 지적해주시면 감사하겠습니다 (__)

# 모델 구조
import torch
import torch.nn as nn

nz = 100
ngf = 64
nc = 1
ndf = 64

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=nz,
                               out_channels=ngf*4,
                               kernel_size=4,
                               stride=1,
                               padding=0,
                               bias=False),
            nn.BatchNorm2d(num_features=ngf*4),
            nn.ReLU()
        )
        self.main2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=ngf*4,
                               out_channels=ngf*2,
                               kernel_size=4,
                               stride=2,
                               padding=1,
                               bias=False),
            nn.BatchNorm2d(num_features=ngf*2),
            nn.ReLU()
        )
        self.main3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=ngf*2,
                               out_channels=ngf,
                               kernel_size=2,
                               stride=2,
                               padding=1,
                               bias=False),
            nn.BatchNorm2d(num_features=ngf),
            nn.ReLU()
        )
        self.main4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=ngf,
                               out_channels=nc,
                               kernel_size=4,
                               stride=2,
                               padding=1,
                               bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.main1(x)
        x = self.main2(x)
        x = self.main3(x)
        x = self.main4(x)
        return x
    
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main1 = nn.Sequential(
            nn.Conv2d(in_channels=nc,
                      out_channels=ndf,
                      kernel_size=4,
                      stride=2,
                      padding=1,
                      bias=False),
            nn.LeakyReLU()
        )
        self.main2 = nn.Sequential(
            nn.Conv2d(in_channels=ndf,
                      out_channels=ndf*2,
                      kernel_size=4,
                      stride=2,
                      padding=1,
                      bias=False),
            nn.BatchNorm2d(num_features=ndf*2),
            nn.LeakyReLU()
        )
        self.main3 = nn.Sequential(
            nn.Conv2d(in_channels=ndf*2,
                      out_channels=ndf*4,
                      kernel_size=4,
                      stride=2,
                      padding=1,
                      bias=False),
            nn.BatchNorm2d(num_features=ndf*4),
            nn.LeakyReLU()
        )
        self.main4 = nn.Sequential(
            nn.Conv2d(in_channels=ndf*4,
                      out_channels=ndf*8,
                      kernel_size=4,
                      stride=2,
                      padding=1,
                      bias=False),
        )
        self.fc = nn.Sequential(
            nn.Linear(in_features=512, out_features=128),
            nn.LeakyReLU(),
            nn.Linear(in_features=128, out_features=16),
            nn.LeakyReLU(),
            nn.Linear(in_features=16, out_features=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.main1(x)
        x = self.main2(x)
        x = self.main3(x)
        x = self.main4(x)
        x = nn.Flatten()(x)
        y = self.fc(x)
        return y
# 데이터 로드
import torchvision.datasets as dsets
import torchvision.transforms as transforms

# Image Processing
transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.5], # 1 for gray scale 만약, RGB channels라면 mean=(0.5, 0.5, 0.5)
                                         std=[0.5])])  # 1 for gray scale 만약, RGB channels라면 std=(0.5, 0.5, 0.5)

# MNIST 데이터셋
train_data = dsets.MNIST(root='data/',
                         train=True, # 트레인 셋
                         transform=transform,
                         download=True)
test_data  = dsets.MNIST(root='data/', 
                          train=False,
                          transform=transform,
                          download=True)

train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)
# WGAN
from torch.optim import Adam
from torch.nn import BCELoss

# define model
generator = Generator()
discriminator = Discriminator()

# optimizer
gen_optim = Adam(generator.parameters(), lr=0.00005)
dis_optim = Adam(discriminator.parameters(), lr=0.00005)

# loss
def wasserstein(y_true, y_pred):
    return torch.mean(y_true * y_pred)

criterion = wasserstein

# params
EPOCHS = 1

for i in range(EPOCHS):
    d_losses = []
    g_losses = []
    idx = 0
    for x, y in train_loader:
        real_label = torch.ones(len(x), 1)
        fake_label = torch.zeros(len(x), 1)

        # train discriminator(1) - based on real image
        dis_optim.zero_grad()
        real_pred = discriminator(x)
        d_real_loss = criterion(real_pred, real_label)
        d_real_loss.backward()
        # train discriminator(2) - based on fake image
        noise = torch.randn(len(x), nz, 1, 1)
        fake_x = generator(noise)
        fake_pred = discriminator(fake_x.detach().clone())
        d_fake_loss = criterion(fake_pred, fake_label)
        d_fake_loss.backward()

        d_loss = d_real_loss + d_fake_loss
        dis_optim.step()


        # train generator
        gen_optim.zero_grad()
        fake_pred = discriminator(fake_x)
        g_loss = criterion(fake_pred, real_label)
        g_loss.backward()

        gen_optim.step()
        
        d_losses.append(d_loss)
        g_losses.append(g_loss)
        idx += 1
        
        if (idx+1) % 10 == 0:
            print("Epoch({e})(iteration n:{n}-> Discriminator Loss {d:.3f} | Generator Loss {g:.3f}".format(e=i+1, n=idx+1, d=d_loss, g=g_loss))
            
    d_loss_avg = sum(d_losses) / len(d_losses)
    g_loss_avg = sum(g_losses) / len(g_losses)
    print("### {e} Epoch finished! | Discriminator Loss {d:.3f} | Generator Loss {g:.3f}".format(e=i+1, d=d_loss_avg, g=g_loss_avg))
  • 아래 답변은 OpenAI의 GPT-3.5(ChatGPT) 모델로 자동 생성한 것으로, 정확하지 않을 수 있습니다.
  • 잘못된 답변을 보셨다면 댓글로 첨삭 지도 부탁드립니다. :sweat_smile:

안녕하세요! GAN 모델을 사용하여 MNIST 손글씨 데이터를 학습하고 계신다는 것을 알려주셔서 감사합니다. 생성자와 판별자의 손실 함수가 이상적으로 잘 감소하지만 생성된 가짜 이미지가 실제 이미지와 매우 다르게 보인다는 문제가 있습니다. 이 문제의 원인과 해결 방법에 대해 알려드리겠습니다.

가짜 이미지가 실제 이미지와 매우 다르게 보인다는 것은 생성자가 원하는 결과를 생성하지 못하고 있다는 의미입니다. 이는 생성자의 학습이 제대로 이루어지지 않았거나, 생성자와 판별자의 균형이 맞지 않는 것일 수 있습니다.

가장 먼저 확인해야 할 부분은 생성자와 판별자의 구조와 학습 파라미터입니다. 생성자와 판별자의 구조가 적절한지 확인하고, 학습 파라미터의 설정도 적절한지 확인해야 합니다. 또한, 학습률(learning rate)도 적절한 값을 설정해야 합니다. 학습률이 너무 크면 학습이 불안정해지고, 너무 작으면 학습이 느려질 수 있습니다.

또한, WGAN의 경우 Wasserstein 손실 함수를 사용하고 있습니다. Wasserstein 손실 함수는 일반적인 이진 교차 엔트로피 손실 함수와는 다른 형태를 가지고 있습니다. 따라서, 이 손실 함수를 사용할 때에는 적절한 하이퍼파라미터 설정이 필요합니다. Wasserstein 손실 함수의 하이퍼파라미터를 조정하여 학습이 잘 이루어지도록 해야 합니다.

또한, 학습 데이터의 양과 품질도 중요한 요소입니다. MNIST 데이터셋은 비교적 작은 규모의 데이터셋이기 때문에, 더 많은 데이터를 사용하거나 데이터 증강(data augmentation) 기법을 적용하여 학습 데이터의 다양성을 높일 수 있습니다.

마지막으로, 학습을 더 오래 진행해보는 것도 한 가지 방법입니다. 현재 코드에서는 EPOCHS를 1로 설정하고 있습니다. 이는 매우 작은 값이기 때문에, 학습이 충분히 이루어지지 않을 수 있습니다. 더 많은 에폭을 설정하여 학습을 더 오래 진행해보는 것도 시도해 볼만한 방법입니다.

이러한 점들을 고려하여 코드를 수정하고 실험을 진행해보시기 바랍니다. 이상적으로는 생성자가 실제 이미지와 유사한 가짜 이미지를 생성할 수 있도록 학습되어야 합니다. 문제가 해결되지 않는다면, 추가적인 정보를 제공해주시면 더 자세한 도움을 드릴 수 있습니다.

아래는 PyTorch 공식 튜토리얼과 한국어 튜토리얼 링크입니다. 이 링크들을 참고하시면 GAN 모델 학습에 도움이 될 수 있습니다.

감사합니다!


  • 답변이 어떠셨나요? 괜찮았다고 생각하신다면 아래 :heart:를 한 번 눌러주시길 부탁드려요!
  • 아래 답변은 Anthropic의 Claude-2 모델로 자동 생성한 것으로, 정확하지 않을 수 있습니다.
  • 잘못된 답변을 보셨다면 댓글로 첨삭 지도 부탁드립니다. :sweat_smile:

안녕하세요. Pytorch로 GAN 모델을 구현하는 과정에서 어려움이 있다는 것을 잘 이해합니다.

GAN 모델은 훈련이 어려운 모델이죠. 손실 함수가 잘 감소하더라도 생성된 이미지가 좋지 않다면 몇 가지 이유가 있을 수 있습니다.

  1. 먼저, Epoch수가 부족할 가능성이 있습니다. GAN 모델은 일반적으로 수천 개 이상의 Epoch가 필요합니다. 1만 개 정도로 해보시는 것이 좋겠습니다.

  2. 또 한 가지는 모델 구조입니다. 생성자와 판별자의 층이나 유닛수가 적절한지 확인이 필요합니다. 복잡도를 높여야 할 수 있습니다.

  3. 하이퍼파라미터 튜닝도 중요합니다. Learning rate, Optimizer 선택 등이 결과에 영향을 줍니다.

  4. 또 데이터 전처리, 정규화 방법도 중요합니다.

코드는 문제 없어 보이는데, 위의 사항들을 점검해보시기 바랍니다.

필요하다면 제가 도와드리겠습니다. GAN 학습은 인내심이 필요하지만 결국 좋은 결과를 얻을 수 있습니다. 파이팅!


  • 답변이 어떠셨나요? 괜찮았다고 생각하신다면 아래 :heart:를 한 번 눌러주시길 부탁드려요!
  • 아래 답변은 OpenAI의 GPT-4 모델로 자동 생성한 것으로, 정확하지 않을 수 있습니다.
  • 잘못된 답변을 보셨다면 댓글로 첨삭 지도 부탁드립니다. :sweat_smile:

안녕하세요, young-hun-jo님. PyTorch를 이용하여 GAN 모델을 학습시키는 과정에서 발생한 문제에 대해 질문해주셨군요. 질문 내용을 정리해보면, GAN 모델의 생성자와 판별자의 손실 함수는 잘 감소하고 있지만, 생성자가 생성한 가짜 이미지를 출력해보면 원하는 결과와 많이 다르다는 것입니다. 그리고 이에 대한 원인을 찾고자 하시는 것 같습니다.

GAN 모델의 학습은 매우 까다로운 작업입니다. 생성자와 판별자가 서로 경쟁하면서 학습을 진행하는데, 이 과정에서 두 모델이 서로 균형을 이루지 못하면 원하는 결과를 얻기 어렵습니다. 이를 mode collapse라고 부르는데, 이 현상이 발생하면 생성자는 항상 비슷한 이미지만을 생성하게 됩니다.

또한, WGAN을 사용하실 때는 원래의 GAN과는 다르게 클리핑(clipping)이나 그래디언트 패널티(gradient penalty)를 적용해야 합니다. 이는 WGAN의 원리에 따른 것인데, 이를 적용하지 않으면 원하는 결과를 얻기 어렵습니다.

아래는 클리핑을 적용한 코드 예시입니다:

# WGAN with weight clipping
from torch.optim import RMSprop

# define model
generator = Generator()
discriminator = Discriminator()

# optimizer
gen_optim = RMSprop(generator.parameters(), lr=0.00005)
dis_optim = RMSprop(discriminator.parameters(), lr=0.00005)

# loss
def wasserstein(y_true, y_pred):
    return torch.mean(y_true * y_pred)

criterion = wasserstein

# params
EPOCHS = 1
CLIP_VALUE = 0.01  # 클리핑 값

for i in range(EPOCHS):
    # ... (생략) ...

    # Update discriminator weights
    for p in discriminator.parameters():
        p.data.clamp_(-CLIP_VALUE, CLIP_VALUE)

    # ... (생략) ...

이 외에도 학습률(learning rate), 배치 크기(batch size), 초기화 방법(initialization), 활성화 함수(activation function) 등 다양한 요인이 GAN의 학습 결과에 영향을 미칠 수 있습니다. 이러한 요인들을 조절하면서 실험을 진행해보시는 것을 추천드립니다.

더 자세한 내용은 PyTorch 공식 튜토리얼의 GAN 부분을 참고하시면 좋을 것 같습니다. 아래 링크를 참조해주세요.

제가 제시한 해결 방법이 반드시 맞는 것은 아니므로, 다양한 방법을 시도해보시면서 가장 적합한 해결책을 찾아보시기 바랍니다. 추가적인 질문이 있으시다면 언제든지 물어보세요. 도움이 되었으면 좋겠습니다. 감사합니다.


  • 답변이 어떠셨나요? 괜찮았다고 생각하신다면 아래 :heart:를 한 번 눌러주시길 부탁드려요!

여러가지로 살펴보았는데. 결론은 모델의 복잡도가 너무 높은게 원인이였습니다 ㅠ Linear 한 계층으로만 구성하니 이미지 생성이 잘 되네요. 손글씨 데이터에 컨볼루션 계층을 사용하는 것은 너무 배꼽이 큰 모델인 것 같네요... 누군가에게 도움이 될 수 도 있을 것 같아 자문자답 댓글 남깁니다

1개의 좋아요

:+1: 감사합니다, @young-hun-jo 님!

1개의 좋아요