PPO playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
af2eea2
import logging | |
from dataclasses import asdict, dataclass | |
from time import perf_counter | |
from typing import List, NamedTuple, Optional, TypeVar | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torch.optim import Adam | |
from torch.utils.tensorboard.writer import SummaryWriter | |
from rl_algo_impls.shared.algorithm import Algorithm | |
from rl_algo_impls.shared.callbacks import Callback | |
from rl_algo_impls.shared.gae import compute_advantages | |
from rl_algo_impls.shared.policy.actor_critic import ActorCritic | |
from rl_algo_impls.shared.schedule import ( | |
constant_schedule, | |
linear_schedule, | |
schedule, | |
update_learning_rate, | |
) | |
from rl_algo_impls.shared.stats import log_scalars | |
from rl_algo_impls.wrappers.vectorable_wrapper import ( | |
VecEnv, | |
single_action_space, | |
single_observation_space, | |
) | |
class TrainStepStats(NamedTuple):
    """Scalar diagnostics recorded for a single minibatch optimization step.

    One instance is appended per minibatch inside ``PPO.learn``; ``TrainStats``
    later averages a list of them field by field.
    """

    # Total objective that was backpropagated for the minibatch.
    loss: float
    # Policy (clipped surrogate) component of the loss.
    pi_loss: float
    # Value-function component of the loss.
    v_loss: float
    # Negated mean entropy of the policy on the minibatch.
    entropy_loss: float
    # Mean of ``(ratio - 1) - log(ratio)`` between new and old action probabilities.
    approx_kl: float
    # Fraction of samples whose probability ratio fell outside the clip range.
    clipped_frac: float
    # Fraction of samples whose value change exceeded the value clip range
    # (0 when value clipping is disabled).
    val_clipped_frac: float
@dataclass
class TrainStats:
    """Training statistics aggregated over a list of minibatch steps.

    Every field except ``explained_var`` is the arithmetic mean of the
    same-named field across the supplied step stats. ``explained_var`` is
    stored exactly as passed in.

    The class is a dataclass so that ``dataclasses.asdict`` can enumerate its
    fields in ``write_to_tensorboard``; the hand-written ``__init__`` and
    ``__repr__`` below are kept because ``@dataclass`` does not overwrite
    methods that the class defines itself.
    """

    loss: float
    pi_loss: float
    v_loss: float
    entropy_loss: float
    approx_kl: float
    clipped_frac: float
    val_clipped_frac: float
    explained_var: float

    def __init__(
        self, step_stats: List["TrainStepStats"], explained_var: float
    ) -> None:
        """Average each per-step statistic and store the explained variance.

        Args:
            step_stats: Per-minibatch statistics to average.
            explained_var: Explained-variance value computed by the caller.
        """
        # ``.item()`` converts the NumPy scalar to a plain Python float.
        self.loss = np.mean([s.loss for s in step_stats]).item()
        self.pi_loss = np.mean([s.pi_loss for s in step_stats]).item()
        self.v_loss = np.mean([s.v_loss for s in step_stats]).item()
        self.entropy_loss = np.mean([s.entropy_loss for s in step_stats]).item()
        self.approx_kl = np.mean([s.approx_kl for s in step_stats]).item()
        self.clipped_frac = np.mean([s.clipped_frac for s in step_stats]).item()
        self.val_clipped_frac = np.mean(
            [s.val_clipped_frac for s in step_stats]
        ).item()
        self.explained_var = explained_var

    def write_to_tensorboard(
        self, tb_writer: "SummaryWriter", global_step: int
    ) -> None:
        """Log every field as a scalar under the ``losses/`` tag prefix.

        Args:
            tb_writer: Writer whose ``add_scalar`` receives each field.
            global_step: Step value attached to every logged scalar.
        """
        for name, value in asdict(self).items():
            tb_writer.add_scalar(f"losses/{name}", value, global_step=global_step)

    def __repr__(self) -> str:
        """Return a one-line summary of the averaged stats, rounded to 2 places.

        ``explained_var`` is not part of the summary.
        """
        return " | ".join(
            [
                f"Loss: {round(self.loss, 2)}",
                f"Pi L: {round(self.pi_loss, 2)}",
                f"V L: {round(self.v_loss, 2)}",
                f"E L: {round(self.entropy_loss, 2)}",
                f"Apx KL Div: {round(self.approx_kl, 2)}",
                f"Clip Frac: {round(self.clipped_frac, 2)}",
                f"Val Clip Frac: {round(self.val_clipped_frac, 2)}",
            ]
        )
# Type variable bound to PPO so ``learn`` can be annotated as returning the
# same (sub)class it was called on.
PPOSelf = TypeVar("PPOSelf", bound="PPO")


class PPO(Algorithm):
    """Proximal Policy Optimization with a clipped surrogate objective.

    Each training iteration collects ``n_steps`` transitions from every
    environment of a vectorized env, computes advantages with
    ``compute_advantages``, and then runs ``n_epochs`` passes of shuffled
    minibatch updates using a clipped policy loss, an optionally clipped value
    loss and an entropy term.
    """

    def __init__(
        self,
        policy: ActorCritic,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        learning_rate: float = 3e-4,
        learning_rate_decay: str = "none",
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: float = 0.2,
        clip_range_decay: str = "none",
        clip_range_vf: Optional[float] = None,
        clip_range_vf_decay: str = "none",
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        ent_coef_decay: str = "none",
        vf_coef: float = 0.5,
        ppo2_vf_coef_halving: bool = False,
        max_grad_norm: float = 0.5,
        sde_sample_freq: int = -1,
        update_advantage_between_epochs: bool = True,
        update_returns_between_epochs: bool = False,
        gamma_end: Optional[float] = None,
    ) -> None:
        """Configure the optimizer, schedules and rollout/update sizes.

        Args:
            policy: Actor-critic whose parameters are optimized with Adam.
            env: Vectorized environment. If it exposes ``get_action_mask``,
                masks are fetched after every reset/step and given to the policy.
            device: Device the rollout tensors are moved to for the updates.
            tb_writer: TensorBoard writer used for all scalar logging.
            learning_rate: Initial Adam learning rate; also the value handed to
                ``schedule`` together with ``learning_rate_decay``.
            learning_rate_decay: Decay name passed to ``schedule``.
            n_steps: Steps collected per environment in each rollout.
            batch_size: Number of samples per minibatch update.
            n_epochs: Passes over the rollout per training iteration.
            gamma: Discount value given to ``compute_advantages`` (start value
                of the gamma schedule when ``gamma_end`` is set).
            gae_lambda: Lambda value given to ``compute_advantages``.
            clip_range: Clip range for the probability ratio.
            clip_range_decay: Decay name passed to ``schedule`` for ``clip_range``.
            clip_range_vf: Clip range for value updates; a falsy value (``None``
                or ``0``) disables value clipping.
            clip_range_vf_decay: Decay name passed to ``schedule`` for
                ``clip_range_vf``.
            normalize_advantage: Standardize advantages within each minibatch.
            ent_coef: Weight of the entropy loss term.
            ent_coef_decay: Decay name passed to ``schedule`` for ``ent_coef``.
            vf_coef: Weight of the value loss term.
            ppo2_vf_coef_halving: Multiply the value loss by 0.5 when true.
            max_grad_norm: Maximum gradient norm used for gradient clipping.
            sde_sample_freq: When positive, the policy noise is reset every this
                many rollout steps (in addition to the reset before each rollout).
            update_advantage_between_epochs: Recompute advantages at the start
                of every epoch instead of only the first one.
            update_returns_between_epochs: Recompute returns at the start of
                every epoch instead of only the first one.
            gamma_end: If given, gamma follows ``linear_schedule(gamma, gamma_end)``
                instead of staying constant.

        Raises:
            AssertionError: If ``normalize_advantage`` is set and some minibatch
                would contain fewer than two samples.
        """
        super().__init__(policy, env, device, tb_writer)
        self.policy = policy
        # Optional hook: environments without action masking simply lack it.
        self.get_action_mask = getattr(env, "get_action_mask", None)

        self.gamma_schedule = (
            linear_schedule(gamma, gamma_end)
            if gamma_end is not None
            else constant_schedule(gamma)
        )
        self.gae_lambda = gae_lambda
        self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
        self.lr_schedule = schedule(learning_rate_decay, learning_rate)
        self.max_grad_norm = max_grad_norm
        self.clip_range_schedule = schedule(clip_range_decay, clip_range)
        self.clip_range_vf_schedule = None
        if clip_range_vf:
            self.clip_range_vf_schedule = schedule(clip_range_vf_decay, clip_range_vf)

        if normalize_advantage:
            rollout_size = env.num_envs * n_steps
            # The final minibatch of an epoch holds ``rollout_size % batch_size``
            # samples. A remainder of exactly 1 would make ``mb_adv.std()`` NaN
            # (sample std of a single element) and poison the whole update.
            assert (
                rollout_size > 1
                and batch_size > 1
                and rollout_size % batch_size != 1
            ), "Each minibatch must be larger than 1 to support normalization"
        self.normalize_advantage = normalize_advantage

        self.ent_coef_schedule = schedule(ent_coef_decay, ent_coef)
        self.vf_coef = vf_coef
        self.ppo2_vf_coef_halving = ppo2_vf_coef_halving

        self.n_steps = n_steps
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.sde_sample_freq = sde_sample_freq

        self.update_advantage_between_epochs = update_advantage_between_epochs
        self.update_returns_between_epochs = update_returns_between_epochs

    def learn(
        self: PPOSelf,
        train_timesteps: int,
        callbacks: Optional[List[Callback]] = None,
        total_timesteps: Optional[int] = None,
        start_timesteps: int = 0,
    ) -> PPOSelf:
        """Alternate rollout collection and policy updates.

        Args:
            train_timesteps: Number of environment timesteps to train for in
                this call.
            callbacks: Optional callbacks invoked after every training
                iteration with the number of timesteps just collected; training
                stops as soon as one of them returns a falsy value.
            total_timesteps: Timestep horizon used to compute schedule progress.
                Defaults to ``train_timesteps``.
            start_timesteps: Timesteps already elapsed before this call.

        Returns:
            ``self``, to allow call chaining.

        Raises:
            AssertionError: If ``start_timesteps + train_timesteps`` exceeds
                ``total_timesteps``.
        """
        if total_timesteps is None:
            total_timesteps = train_timesteps
        assert start_timesteps + train_timesteps <= total_timesteps

        epoch_dim = (self.n_steps, self.env.num_envs)
        step_dim = (self.env.num_envs,)
        obs_space = single_observation_space(self.env)
        act_space = single_action_space(self.env)
        act_shape = self.policy.action_shape

        next_obs = self.env.reset()
        next_action_masks = self.get_action_mask() if self.get_action_mask else None
        next_episode_starts = np.full(step_dim, True, dtype=np.bool_)

        # Rollout buffers, allocated once and overwritten every iteration.
        obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)  # type: ignore
        actions = np.zeros(epoch_dim + act_shape, dtype=act_space.dtype)  # type: ignore
        rewards = np.zeros(epoch_dim, dtype=np.float32)
        episode_starts = np.zeros(epoch_dim, dtype=np.bool_)
        values = np.zeros(epoch_dim, dtype=np.float32)
        logprobs = np.zeros(epoch_dim, dtype=np.float32)
        action_masks = (
            np.zeros(
                (self.n_steps,) + next_action_masks.shape, dtype=next_action_masks.dtype
            )
            if next_action_masks is not None
            else None
        )

        timesteps_elapsed = start_timesteps
        while timesteps_elapsed < start_timesteps + train_timesteps:
            start_time = perf_counter()

            # Schedules are evaluated once per iteration from overall progress.
            progress = timesteps_elapsed / total_timesteps
            ent_coef = self.ent_coef_schedule(progress)
            learning_rate = self.lr_schedule(progress)
            update_learning_rate(self.optimizer, learning_rate)
            pi_clip = self.clip_range_schedule(progress)
            gamma = self.gamma_schedule(progress)
            chart_scalars = {
                "learning_rate": self.optimizer.param_groups[0]["lr"],
                "ent_coef": ent_coef,
                "pi_clip": pi_clip,
                "gamma": gamma,
            }
            if self.clip_range_vf_schedule:
                v_clip = self.clip_range_vf_schedule(progress)
                chart_scalars["v_clip"] = v_clip
            else:
                v_clip = None
            log_scalars(self.tb_writer, "charts", chart_scalars, timesteps_elapsed)

            # ---- Rollout collection ----
            self.policy.eval()
            self.policy.reset_noise()
            for s in range(self.n_steps):
                # Every step advances all environments by one timestep.
                timesteps_elapsed += self.env.num_envs
                if self.sde_sample_freq > 0 and s > 0 and s % self.sde_sample_freq == 0:
                    self.policy.reset_noise()

                obs[s] = next_obs
                episode_starts[s] = next_episode_starts
                if action_masks is not None:
                    action_masks[s] = next_action_masks

                (
                    actions[s],
                    values[s],
                    logprobs[s],
                    clamped_action,
                ) = self.policy.step(next_obs, action_masks=next_action_masks)
                next_obs, rewards[s], next_episode_starts, _ = self.env.step(
                    clamped_action
                )
                next_action_masks = (
                    self.get_action_mask() if self.get_action_mask else None
                )

            # ---- Policy update ----
            self.policy.train()

            # Flatten the (step, env) leading axes into a single batch axis.
            b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)  # type: ignore
            b_actions = torch.tensor(actions.reshape((-1,) + act_shape)).to(  # type: ignore
                self.device
            )
            b_logprobs = torch.tensor(logprobs.reshape(-1)).to(self.device)
            b_action_masks = (
                torch.tensor(action_masks.reshape((-1,) + next_action_masks.shape[1:])).to(  # type: ignore
                    self.device
                )
                if action_masks is not None
                else None
            )

            y_pred = values.reshape(-1)
            b_values = torch.tensor(y_pred).to(self.device)

            step_stats = []
            # Define variables that will definitely be set through the first epoch
            advantages: np.ndarray = None  # type: ignore
            b_advantages: torch.Tensor = None  # type: ignore
            y_true: np.ndarray = None  # type: ignore
            b_returns: torch.Tensor = None  # type: ignore
            for e in range(self.n_epochs):
                if e == 0 or self.update_advantage_between_epochs:
                    advantages = compute_advantages(
                        rewards,
                        values,
                        episode_starts,
                        next_episode_starts,
                        next_obs,
                        self.policy,
                        gamma,
                        self.gae_lambda,
                    )
                    b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
                if e == 0 or self.update_returns_between_epochs:
                    returns = advantages + values
                    y_true = returns.reshape(-1)
                    b_returns = torch.tensor(y_true).to(self.device)

                # Fresh shuffle of the rollout for every epoch.
                b_idxs = torch.randperm(len(b_obs))
                # Only record last epoch's stats
                step_stats.clear()
                for i in range(0, len(b_obs), self.batch_size):
                    self.policy.reset_noise(self.batch_size)

                    mb_idxs = b_idxs[i : i + self.batch_size]

                    mb_obs = b_obs[mb_idxs]
                    mb_actions = b_actions[mb_idxs]
                    mb_values = b_values[mb_idxs]
                    mb_logprobs = b_logprobs[mb_idxs]
                    mb_action_masks = (
                        b_action_masks[mb_idxs] if b_action_masks is not None else None
                    )

                    mb_adv = b_advantages[mb_idxs]
                    if self.normalize_advantage:
                        mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)
                    mb_returns = b_returns[mb_idxs]

                    new_logprobs, entropy, new_values = self.policy(
                        mb_obs, mb_actions, action_masks=mb_action_masks
                    )

                    # Clipped surrogate objective, written as a loss to minimize.
                    logratio = new_logprobs - mb_logprobs
                    ratio = torch.exp(logratio)
                    clipped_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
                    pi_loss = torch.max(-ratio * mb_adv, -clipped_ratio * mb_adv).mean()

                    v_loss_unclipped = (new_values - mb_returns) ** 2
                    if v_clip:
                        # Limit how far the new value may move from the rollout
                        # value, then take the larger of the two squared errors.
                        v_loss_clipped = (
                            mb_values
                            + torch.clamp(new_values - mb_values, -v_clip, v_clip)
                            - mb_returns
                        ) ** 2
                        v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
                    else:
                        v_loss = v_loss_unclipped.mean()

                    if self.ppo2_vf_coef_halving:
                        v_loss *= 0.5

                    entropy_loss = -entropy.mean()

                    loss = pi_loss + ent_coef * entropy_loss + self.vf_coef * v_loss

                    self.optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(
                        self.policy.parameters(), self.max_grad_norm
                    )
                    self.optimizer.step()

                    # Diagnostics only; no gradients needed.
                    with torch.no_grad():
                        approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
                        clipped_frac = (
                            ((ratio - 1).abs() > pi_clip)
                            .float()
                            .mean()
                            .cpu()
                            .numpy()
                            .item()
                        )
                        val_clipped_frac = (
                            ((new_values - mb_values).abs() > v_clip)
                            .float()
                            .mean()
                            .cpu()
                            .numpy()
                            .item()
                            if v_clip
                            else 0
                        )

                    step_stats.append(
                        TrainStepStats(
                            loss.item(),
                            pi_loss.item(),
                            v_loss.item(),
                            entropy_loss.item(),
                            approx_kl,
                            clipped_frac,
                            val_clipped_frac,
                        )
                    )

            # Explained variance of the returns by the rollout value estimates;
            # undefined (NaN) when the returns have zero variance.
            var_y = np.var(y_true).item()
            explained_var = (
                np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
            )
            TrainStats(step_stats, explained_var).write_to_tensorboard(
                self.tb_writer, timesteps_elapsed
            )

            end_time = perf_counter()
            rollout_steps = self.n_steps * self.env.num_envs
            self.tb_writer.add_scalar(
                "train/steps_per_second",
                rollout_steps / (end_time - start_time),
                timesteps_elapsed,
            )

            if callbacks:
                # Callbacks receive the number of timesteps collected in this
                # iteration; any falsy return stops training.
                if not all(
                    c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
                ):
                    logging.info(
                        f"Callback terminated training at {timesteps_elapsed} timesteps"
                    )
                    break

        return self