---
library_name: stable-baselines3
tags:
- CartPole-v1
- deep-reinforcement-learning
- reinforcement-learning
- stable-baselines3
model-index:
- name: A2C
  results:
  - task:
      type: reinforcement-learning
      name: reinforcement-learning
    dataset:
      name: CartPole-v1
      type: CartPole-v1
    metrics:
    - type: mean_reward
      value: 9.80 +/- 0.60
      name: mean_reward
      verified: false
---

# **A2C** Agent playing **CartPole-v1**

This is a trained model of an **A2C** agent playing **CartPole-v1**
using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).

## Usage (with Stable-baselines3)

```python
import gym

from stable_baselines3 import A2C
from huggingface_sb3 import package_to_hub
import wandb
from wandb.integration.sb3 import WandbCallback

# Training and evaluation environments
env = gym.make("CartPole-v1")
eval_env = gym.make("CartPole-v1")

config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": 25000,
    "env_id": "CartPole-v1",
}

run = wandb.init(
    project="cart_pole",
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    # monitor_gym=True,  # auto-upload the videos of agents playing the game
    # save_code=True,  # optional
)

model = A2C("MlpPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=WandbCallback(
        model_save_path=f"models/{run.id}",
        verbose=2,
    ),
)
run.finish()

model.save("a2c_Cart_Pole")

# Package and push the trained agent to the Hugging Face Hub
package_to_hub(model=model,
               model_name="a2c_Cart_Pole",
               model_architecture="A2C",
               env_id="CartPole-v1",
               eval_env=eval_env,
               repo_id="mRoszak/A2C_Cart_Pole",
               commit_message="Test commit")

...
```