현재 gym 라이브러리의 FrozenLake-v1 게임을 강화학습으로 해결하기 위해 Q-Network를 pytorch로 구현하고 있습니다. 학습을 진행하면서 Loss 값이 감소하고 게임 종료가 끝났는지 여부를 반환하는 done
변수도 True
로 잘 받아오면서 종료가 되는데, 이상하게 Reward 값이 계속 0.0으로 출력이 됩니다. 코드 전문은 아래와 같습니다.
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
# 모델 정의
class Network(nn.Module):
def __init__(self, n_states, n_actions):
super(Network, self).__init__()
self.n_states = n_states
self.n_actions = n_actions
self.linear = nn.Linear(in_features=self.n_states, out_features=self.n_actions)
def forward(self, x):
x = torch.Tensor([x])
x = F.one_hot(x.to(torch.int64), num_classes=self.n_states)
x = x.to(torch.float32)
y = self.linear(x)
return y
# setting environment
env = gym.make("FrozenLake-v1")
# params
n_states = env.observation_space.n
n_actions = env.action_space.n
num_episodes = 1
exploration_prob = 0.5
gamma = 0.99
lr = 0.1
# model, criterion, optimizer
network = Network(n_states, n_actions)
criterion = nn.MSELoss()
optimizer = SGD(params=network.parameters(), lr=lr)
# train
for i in range(num_episodes):
done = False
all_reward = 0
# init state
current_state, current_prob = env.reset()
while not done:
# init gradients in Q-network
optimizer.zero_grad()
# create label from Q-network using `current_state`
Qs = network(current_state)
# exploration & exploitation
if torch.rand(1) < exploration_prob:
action = env.action_space.sample()
else:
action = torch.argmax(Qs).item()
# step to next-state
next_state, reward, done, _, _ = env.step(action)
if done:
Qs[0, action] = reward
else:
# prediction from Q-network using `next_state`
Qs_pred = network(next_state)
Qs[0, action] = reward + gamma * torch.max(Qs_pred).item()
# Loss
loss = criterion(Qs, Qs_pred)
loss.backward()
# update params
optimizer.step()
# update next_state to current_state
current_state = next_state
all_reward += reward
print("Loss:", loss.item(), "Reward:", reward, "done:", done)
위 코드의 마지막 줄에 print
문을 보면 보상을 출력시키도록 했는데, 코드를 실행시키면 게임 종료가 끝날 때까지 계속 보상은 0.0 으로 출력됩니다. 위 코드가 제대로 된 코드인 걸까요? 왜 보상이 계속 0.0이 나오는지 잘 모르겠습니다. 원인이 무엇인지 답변 부탁드립니다!