A2C Agent playing Pendulum-v1
This is a trained model of an A2C agent playing Pendulum-v1 using the stable-baselines3 library.
Usage (with Stable-baselines3)
from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecNormalize
# Download the trained checkpoint and the saved normalization statistics
# from the Hugging Face Hub.
env_id = "Pendulum-v1"
repo_id = f"araffin/a2c-{env_id}"
checkpoint = load_from_hub(repo_id, f"a2c-{env_id}.zip")
vec_normalize_stats = load_from_hub(repo_id, "vec_normalize.pkl")

# Load the model
model = A2C.load(checkpoint)

# Recreate a single environment and wrap it with the downloaded
# VecNormalize statistics.
env = make_vec_env(env_id, n_envs=1)
env = VecNormalize.load(vec_normalize_stats, env)
# do not update them at test time
env.training = False
# reward normalization is not needed at test time
env.norm_reward = False

# Evaluate
print("Evaluating model")
mean_reward, std_reward = evaluate_policy(
    model,
    env,
    n_eval_episodes=20,
    deterministic=True,
)
print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")

# Start a new episode and keep stepping/rendering until interrupted.
obs = env.reset()
try:
    while True:
        action, _states = model.predict(obs, deterministic=True)
        obs, rewards, dones, info = env.step(action)
        env.render()
except KeyboardInterrupt:
    # Ctrl+C is the only way out of the loop above; exit quietly.
    pass
finally:
    # Release the environment (and any render window) on the way out.
    env.close()
Training Code
from huggingface_sb3 import package_to_hub
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize, sync_envs_normalization
# Discount factor shared by the reward normalization wrappers and the agent,
# kept in one place so the three uses cannot drift apart.
gamma = 0.9

# Create the environment
env_id = "Pendulum-v1"
repo_id = f"araffin/a2c-{env_id}"
env = make_vec_env(env_id, n_envs=8)
# Normalize
env = VecNormalize(env, gamma=gamma)

# Create the evaluation env (could be used in `EvalCallback`)
eval_env = make_vec_env(env_id, n_envs=1)
eval_env = VecNormalize(eval_env, gamma=gamma, training=False, norm_reward=False)

# Instantiate the agent
model = A2C(
    "MlpPolicy",
    env,
    n_steps=8,
    gamma=gamma,
    gae_lambda=0.9,
    use_sde=True,
    policy_kwargs=dict(log_std_init=-2),
    verbose=1,
)

# Train the agent
try:
    model.learn(total_timesteps=int(1e6))
except KeyboardInterrupt:
    # Interrupting training is allowed: fall through and publish the model
    # in whatever state it has reached.
    pass

# Synchronize stats (done automatically in `EvalCallback`)
sync_envs_normalization(env, eval_env)

package_to_hub(
    model=model,
    model_name=f"a2c-{env_id}",
    model_architecture="A2C",
    env_id=env_id,
    eval_env=eval_env,
    repo_id=repo_id,
    commit_message="Initial commit",
)
- Downloads last month
- 1
Evaluation results
- mean_reward on Pendulum-v1 (self-reported): -141.19 +/- 122.27