안녕하세요? TorchRL PPO을 공부하고 있는 배움지기입니다.
Inverted pendulum 환경에서 Tutorial로 구성된 Reinforcement Learning (PPO) with TorchRL Tutorial을 CartPole_v1 환경으로 변경하여 학습하고자 하는데,
observation states와 action 간에 일관성이 없고, states가 한 방향으로만 작동하고 있습니다.
혹시 어디가 문제인지 문의드립니다.
*************************************************
!pip3 install torchrl
!pip3 install "gymnasium[classic-control]"
!pip3 install tqdm
import warnings
warnings.filterwarnings("ignore")
from torch import multiprocessing
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
from torch.distributions import Categorical
import tensordict
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.collectors import SyncDataCollector
import torch.nn.functional as F
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (Compose, DoubleToFloat, ObservationNorm, StepCounter, TransformedEnv)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.ppo import GAE
from tqdm import tqdm
from torchrl.envs.transforms import AutoResetTransform
from torchrl.envs import AutoResetEnv
from tensordict import TensorDict
#from torchrl.collectors import SyncDataCollector
from torchrl.envs.transforms import Transform
from torchrl.data import DiscreteTensorSpec
class ActionTracer(Transform):
    """Debug transform that logs every action written by the policy."""

    def _call(self, td):
        # Peek at the action tensor before the environment consumes it.
        act = td["action"]
        print("POLICY action:", act, act.shape, act.dtype)
        return td
class EnvActionTracer(Transform):
    """Debug transform that logs the action as seen by the env step."""

    def _step(self, td):
        # Show what actually reaches the environment at step time.
        act = td["action"]
        print("ENV step got:", act, act.shape, act.dtype)
        return td
class DebugPolicy(torch.nn.Module):
    """Thin wrapper around a policy that prints each sampled action."""

    def __init__(self, policy):
        super().__init__()
        # Keep a handle on the wrapped policy; forwarding is unchanged.
        self.policy = policy

    def forward(self, td):
        out = self.policy(td)
        print(">>> policy action:", out["action"])
        return out
class DebugGymEnv(GymEnv):
    """GymEnv subclass that prints the raw action handed to the backend."""

    def _step(self, tensordict):
        # Log the exact tensor that the underlying gym env will decode.
        print(">>> raw action to gym:", tensordict["action"])
        return super()._step(tensordict)
# --- Run-time configuration and PPO hyper-parameters ---------------------
is_fork = multiprocessing.get_start_method() == "fork"

# Prefer the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

num_cells = 256        # width of every hidden layer (i.e. output dim)
lr = 3e-4              # Adam learning rate
max_grad_norm = 1.0    # gradient-clipping threshold

# Collection schedule.
frames_per_batch = 1000
# For a complete training, bring the number of frames up to 1M.
total_frames = 50_000

# Inner-loop optimisation.
sub_batch_size = 64    # size of the sub-samples drawn from each collected batch
num_epochs = 10        # optimisation passes over each batch of collected data

# PPO loss settings.
clip_epsilon = 0.2     # clipping range of the surrogate objective
gamma = 0.99           # discount factor
lmbda = 0.95           # GAE lambda
entropy_eps = 0.02     # entropy bonus coefficient (tutorial default was 1e-4)
# --- Environment ---------------------------------------------------------
# CartPole-v1 has a *discrete* action space.  TorchRL's GymEnv encodes
# discrete actions as ONE-HOT vectors by default, while the policy below
# samples integer actions from a Categorical distribution.  Overwriting
# `env.action_spec` after construction does NOT change how GymEnv decodes
# actions, so the integer actions were being mis-interpreted by the
# one-hot decoding path — which is why the cart was only ever pushed in
# one direction.  The fix is to request categorical (integer) action
# encoding at construction time.
base_env = DebugGymEnv("CartPole-v1", device="cpu", categorical_action_encoding=True)

env = TransformedEnv(
    base_env,
    Compose(
        # normalize observations
        #ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),
        StepCounter(),
    ),
)

print(base_env._env.spec.max_episode_steps)
#env.transform[0].init_stats(num_iter=1200, reduce_dim=0, cat_dim=0)
#print("normalization constant shape:", env.transform[0].loc.shape)

# Sanity-check that specs, resets and steps are mutually consistent.
check_env_specs(env)

print("env.action_spec: ", env.action_spec)
print("observation_spec:", env.observation_spec)
print("reward_spec:", env.reward_spec)
print("input_spec:", env.input_spec)
print("action_spec (as defined by input_spec):", env.action_spec)
# Disabled alternative debug transform: it would store a copy of the
# action actually applied to the env under the "applied_action" key.
#class SaveAppliedAction(Transform):
# def _call(self, tensordict):
# tensordict.set(
# "applied_action",
# tensordict["action"].clone()
# )
# return tensordict
#env.append_transform(SaveAppliedAction())
# Smoke-test: roll the env for three steps and inspect the TensorDict
# structure (keys, shapes, batch size) before building the models.
rollout = env.rollout(3)
print("rollout of three steps:", rollout)
print("Shape of the rollout TensorDict:", rollout.batch_size)
# Policy network: three tanh-activated hidden layers followed by a linear
# head producing one logit per discrete action.  LazyLinear infers the
# input width (the observation size) on the first forward pass.
_policy_layers = []
for _ in range(3):
    _policy_layers.append(nn.LazyLinear(num_cells, device=device))
    _policy_layers.append(nn.Tanh())
_policy_layers.append(nn.LazyLinear(env.action_space.n, device=device))
actor_net = nn.Sequential(*_policy_layers)

# Expose the network as a TensorDict module: reads "observation", writes "logits".
actor_module = TensorDictModule(actor_net, in_keys=["observation"], out_keys=["logits"])
# Probabilistic actor: samples an integer action from Categorical(logits).
# Passing the env's action spec (as the TorchRL PPO tutorial does) lets the
# actor validate/project the sampled action against the spec; spec=None
# silently drops that safety check.
policy_module = ProbabilisticActor(
    module=actor_module,
    spec=env.action_spec,
    in_keys=["logits"],
    out_keys=["action"],
    distribution_class=Categorical,
    # we'll need the log-prob for the numerator of the importance weights
    return_log_prob=True,
)
# Critic: same trunk shape as the actor, but with a scalar state-value head.
_value_layers = []
for _ in range(3):
    _value_layers.append(nn.LazyLinear(num_cells, device=device))
    _value_layers.append(nn.Tanh())
_value_layers.append(nn.LazyLinear(1, device=device))
value_net = nn.Sequential(*_value_layers)

# Reads "observation", writes the default "state_value" key.
value_module = ValueOperator(module=value_net, in_keys=["observation"])
# Smoke-test both modules on a freshly reset state (this also triggers the
# LazyLinear shape inference) before starting collection.
print("Running policy:", policy_module(env.reset()))
print("Running value function:", value_module(env.reset()))
# Wrap the actor so every sampled action is printed during collection.
policy = DebugPolicy(policy_module)
# Synchronous collector: runs `policy` in `env` and yields batches of
# `frames_per_batch` frames until `total_frames` have been gathered.
collector = SyncDataCollector(
env,
#policy_module,
policy,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
split_trajs=False,
device=device,
)
# Replay buffer holding exactly one collection batch; sampling without
# replacement means each frame is visited once per inner epoch.
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)

# Generalized Advantage Estimation over the critic's value predictions.
advantage_module = GAE(
    gamma=gamma,
    lmbda=lmbda,
    value_network=value_module,
    average_gae=True,
    device=device,
)

# Clipped-surrogate PPO loss with entropy bonus and smooth-L1 critic loss.
loss_module = ClipPPOLoss(
    actor_network=policy_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    # these keys match by default but we set this for completeness
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)

optim = torch.optim.Adam(loss_module.parameters(), lr)
# Anneal the learning rate to zero over the whole run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, total_frames // frames_per_batch, 0.0
)
logs = defaultdict(list)
pbar = tqdm(total=total_frames)
eval_str = ""

for i, tensordict_data in enumerate(collector):
    #############################
    # --- Debug dump of the freshly collected batch ----------------------
    print("i: ", i)
    print(tensordict_data.keys())
    print(tensordict_data["logits"].std())
    print("env max steps:", env.base_env._max_episode_steps)
    print("Check_reward sum: ", tensordict_data["next", "reward"].sum().item())
    print("DDstep_count:", tensordict_data["next", "step_count"].max().item())

    # Pull the tensors needed for the step-by-step dump.
    with torch.no_grad():
        values = value_module(tensordict_data["observation"])
    actions = tensordict_data["action"]
    states = tensordict_data["next", "observation"]
    rewards = tensordict_data["next", "reward"].squeeze(-1)
    terminated = tensordict_data["next", "terminated"]
    truncated = tensordict_data["next", "truncated"]

    # Flatten the batch dimension so we can iterate step by step.
    states_flat = states.reshape(-1, states.shape[-1])
    actions_flat = actions.reshape(-1)
    rewards_flat = rewards.reshape(-1)
    terminated_flat = terminated.reshape(-1)
    truncated_flat = truncated.reshape(-1)
    values_flat = values.reshape(-1)

    # Per-step dump.  NOTE: CartPole observations are ordered
    # [cart position, cart velocity, pole angle, pole angular velocity];
    # the original code labelled index 0 as the pole angle, which is wrong
    # (the printed column order was already correct, only the names/comments
    # were swapped).
    print("Action 0 : to Left, 1 : to right")
    print("Step | Pos | Pos_dot | Theta | Theta_dot | Action | Reward | Done | Truncated | V(s)")
    print("-"*60)
    for j in range(states_flat.shape[0]):
        pos = states_flat[j, 0].item()        # cart position
        pos_dot = states_flat[j, 1].item()    # cart velocity
        theta = states_flat[j, 2].item()      # pole angle
        theta_dot = states_flat[j, 3].item()  # pole angular velocity
        action = actions_flat[j].item()
        reward = rewards_flat[j].item()
        done = terminated_flat[j].item()
        trunc = truncated_flat[j].item()
        value = values_flat[j].item()
        print(f"{j:4d} | {pos:+.3f} | {pos_dot:+.3f} | {theta:+.3f} | {theta_dot:+.3f} | {action:2d} | "
              f"{reward:+.3f} | {done} | {trunc} | {value:+.3f}")

    # Policy-distribution diagnostics: mean action probabilities and entropy.
    logits = tensordict_data["logits"].reshape(-1, env.action_space.n)
    probs = torch.softmax(logits, dim=-1)
    dist = Categorical(logits=logits)
    entropy = dist.entropy().mean()
    print("\nPolicy check:")
    print("Action probabilities mean:", probs.mean(dim=0))
    print("Policy entropy:", entropy.item())

    # Advantage diagnostics (only present after GAE has run on this batch).
    if 'advantage' in tensordict_data.keys():
        adv = tensordict_data["advantage"].reshape(-1)
        print("Advantage mean:", adv.mean().item())
        print("Advantage std :", adv.std().item())
    #############################

    for _ in range(num_epochs):
        # Recompute GAE every epoch: the critic changes during the inner
        # loop, so the advantage estimates go stale.
        advantage_module(tensordict_data)
        data_view = tensordict_data.reshape(-1)
        replay_buffer.extend(data_view.cpu())
        for _ in range(frames_per_batch // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size)
            loss_vals = loss_module(subdata.to(device))
            loss_value = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )
            # Optimization: backward, grad clipping and optimization step
            loss_value.backward()
            # this is not strictly mandatory but it's good practice to keep
            # your gradient norm bounded
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()

    # --- Logging ---------------------------------------------------------
    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    pbar.update(tensordict_data.numel())
    cum_reward_str = (
        f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
    )
    logs["step_count"].append(tensordict_data["step_count"].max().item())
    stepcount_str = f"step count (max): {logs['step_count'][-1]}"
    logs["lr"].append(optim.param_groups[0]["lr"])
    lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"

    if i % 10 == 0:
        # Periodic deterministic evaluation rollout (no exploration noise).
        with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
            eval_rollout = env.rollout(1000, policy_module)
            logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
            logs["eval reward (sum)"].append(
                eval_rollout["next", "reward"].sum().item()
            )
            logs["eval step_count"].append(eval_rollout["step_count"].max().item())
            eval_str = (
                f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                f"(init: {logs['eval reward (sum)'][0]: 4.4f}), "
                f"eval step-count: {logs['eval step_count'][-1]}"
            )
            del eval_rollout
    pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))
    scheduler.step()