I want to solve the CartPole problem with DQN
Versions
Torch: 2.2.2
gymnasium: 0.28.1
cartpole.py
import torch
from torch import nn, optim
from random import sample
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss
from torchrl.data import BinaryDiscreteTensorSpec
import gymnasium as gym
from matplotlib import pyplot as plt
env = gym.make('CartPole-v1')
model = TensorDictSequential(
    TensorDictModule(nn.Linear(4, 10), in_keys=['observation'], out_keys=['in1']),
    TensorDictModule(nn.ReLU(), in_keys=['in1'], out_keys=['in2']),
    TensorDictModule(nn.Linear(10, 6), in_keys=['in2'], out_keys=['in3']),
    TensorDictModule(nn.ReLU(), in_keys=['in3'], out_keys=['in4']),
    TensorDictModule(nn.Linear(6, 2), in_keys=['in4'], out_keys=['action_value']))
actor = QValueActor(module=model, spec=BinaryDiscreteTensorSpec(1))
loss_fn = DQNLoss(actor, action_space=BinaryDiscreteTensorSpec(1))
optimizer = optim.Adam(model.parameters())
replay_buffer = []
EPOCHS = 1000
MAX_STEPS = 500
FAIL_REWARD = torch.tensor(-1000.0)
GAMMA = 0.95
train_steps = []
for i in range(EPOCHS):
    model.train()
    state = env.reset()
    t = MAX_STEPS
    # roll out one episode and store the transitions in the replay buffer
    with torch.no_grad():
        for j in range(MAX_STEPS):
            state = torch.tensor(state[0])
            action = torch.argmax(actor(TensorDict({'observation' : state})))
            next_state, reward, finish = list(map(torch.tensor, list(env.step(state))[:3]))
            if finish.item():
                reward = FAIL_REWARD
                replay_buffer.append([state, action, next_state, finish, reward])
                t = j
                break
            replay_buffer.append([state, action, next_state, finish, reward])
            state = next_state
    # propagate discounted returns backwards through the stored rewards
    reward_backward = torch.FloatTensor(0.0)
    for j in range(MAX_STEPS-1, -1, -1):
        reward_backward, replay_buffer[j][4] = (replay_buffer[j][4] + reward_backward) * GAMMA, replay_buffer[j][4] + reward_backward
    # train on a random half of the replay buffer
    model.zero_grad()
    for j in sample(replay_buffer, len(replay_buffer)//2):
        loss = loss_fn(TensorDict({
            "observation": j[0],
            "action": j[1],
            ("next", "observation"): j[2],
            ("next", "done"): j[3],
            ("next", "reward"): j[4]}),[1,])
        loss.backward()
    # evaluation rollout: count how many steps the pole stays up
    model.eval()
    env.reset()
    state = env.reset()
    t = MAX_STEPS
    for j in range(MAX_STEPS):
        state = torch.FloatTensor([state])
        action = torch.argmax(actor(state))
        next_state, reward, finish = list(map(torch.tensor, list(env.step(state))[:3]))
        if finish.item():
            reward = FAIL_REWARD
            replay_buffer.append([state, action, next_state, finish, reward])
            t = j
            break
    train_steps.append(t)
    print(i)
plt.plot([i for i in range(EPOCHS)], train_steps)
plt.show()
Error message
Traceback (most recent call last):
File "C:\Users\MSI\PycharmProjects\asdf.venv\cartpole.py", line 39, in
action = torch.argmax(actor(TensorDict({'observation' : state})))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\MSI\anaconda3\envs\cs_ar\Lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\MSI\anaconda3\envs\cs_ar\Lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\MSI\anaconda3\envs\cs_ar\Lib\site-packages\tensordict\nn\common.py", line 289, in wrapper
return func(_self, tensordict, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\MSI\anaconda3\envs\cs_ar\Lib\site-packages\tensordict_contextlib.py", line 126, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\MSI\anaconda3\envs\cs_ar\Lib\site-packages\tensordict\nn\utils.py", line 261, in wrapper
return func(_self, tensordict, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\MSI\anaconda3\envs\cs_ar\Lib\site-packages\tensordict\nn\sequence.py", line 428, in forward
tensordict = self._run_module(module, tensordict, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\MSI\anaconda3\envs\cs_ar\Lib\site-packages\tensordict\nn\sequence.py", line 409, in _run_module
tensordict = module(tensordict, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\MSI\anaconda3\envs\cs_ar\Lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\MSI\anaconda3\envs\cs_ar\Lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\MSI\anaconda3\envs\cs_ar\Lib\site-packages\tensordict\nn\common.py", line 289, in wrapper
return func(_self, tensordict, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\MSI\anaconda3\envs\cs_ar\Lib\site-packages\torchrl\modules\tensordict_module\actors.py", line 527, in forward
action = self.action_func_mapping[self.action_space](action_values)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: QValueModule._binary() missing 1 required positional argument: 'support'
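Reading the last frames, the failure seems to come from the action-space dispatch: with a BinaryDiscreteTensorSpec, QValueActor routes the Q-values to QValueModule._binary(), which also expects a "support" tensor that is never supplied, hence the TypeError. Since CartPole-v1 has a Discrete(2) action space, the actor and loss are normally declared with a one-hot spec instead. Below is a minimal sketch of just those lines, reusing model and the imports from the script above and assuming OneHotDiscreteTensorSpec from torchrl.data (an educated guess, not verified against these exact versions):

from torchrl.data import OneHotDiscreteTensorSpec

# CartPole-v1 has Discrete(2) actions -> a one-hot spec with 2 entries
actor = QValueActor(module=model, spec=OneHotDiscreteTensorSpec(2))
loss_fn = DQNLoss(actor, action_space="one_hot")

# QValueActor writes "action", "action_value" and "chosen_action_value" into the
# tensordict, so the greedy action is read back by key rather than with torch.argmax
td = actor(TensorDict({'observation': torch.zeros(4)}, batch_size=[]))
action = td['action']  # one-hot tensor of shape [2]

If that route is taken, the replay buffer would presumably also need to store the one-hot "action" the actor emits, since DQNLoss with action_space="one_hot" selects the chosen Q-value by multiplying with a one-hot action.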