Update README.md
Browse files
README.md
CHANGED
@@ -26,12 +26,57 @@ This is a trained model of a **A2C** agent playing **CartPole-v1**
|
|
26 |
using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
|
27 |
|
28 |
## Usage (with Stable-baselines3)
|
29 |
-
TODO: Add your code
|
30 |
|
31 |
|
32 |
```python
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
...
|
37 |
```
|
|
|
26 |
using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
|
27 |
|
28 |
## Usage (with Stable-baselines3)
|
|
|
29 |
|
30 |
|
31 |
```python
|
32 |
+
import gym
|
33 |
+
|
34 |
+
from stable_baselines3 import A2C
|
35 |
+
from stable_baselines3.common.env_util import make_vec_env
|
36 |
+
from huggingface_sb3 import package_to_hub
|
37 |
+
import wandb
|
38 |
+
from wandb.integration.sb3 import WandbCallback
|
39 |
+
|
40 |
+
# Training environment and a separate evaluation environment
|
41 |
+
env = gym.make("CartPole-v1")
|
42 |
+
eval_env = gym.make("CartPole-v1")
|
43 |
+
config = {
|
44 |
+
"policy_type": "MlpPolicy",
|
45 |
+
"total_timesteps": 25000,
|
46 |
+
"env_id": "CartPole-v1",
|
47 |
+
}
|
48 |
+
|
49 |
+
run = wandb.init(
|
50 |
+
project="cart_pole",
|
51 |
+
config=config,
|
52 |
+
sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
|
53 |
+
# monitor_gym=True, # auto-upload the videos of agents playing the game
|
54 |
+
# save_code=True, # optional
|
55 |
+
)
|
56 |
+
|
57 |
+
|
58 |
+
model = A2C("MlpPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")
|
59 |
+
model.learn(
|
60 |
+
total_timesteps=config["total_timesteps"],
|
61 |
+
callback=WandbCallback(
|
62 |
+
model_save_path=f"models/{run.id}",
|
63 |
+
verbose=2,
|
64 |
+
),
|
65 |
+
)
|
66 |
+
run.finish()
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
model.save("a2c_Cart_Pole")
|
71 |
+
|
72 |
+
|
73 |
+
package_to_hub(model=model,
|
74 |
+
model_name="a2c_Cart_Pole",
|
75 |
+
model_architecture="A2C",
|
76 |
+
env_id="CartPole-v1",
|
77 |
+
eval_env=eval_env,
|
78 |
+
repo_id="mRoszak/A2C_Cart_Pole",
|
79 |
+
commit_message="Test commit")
|
80 |
|
81 |
...
|
82 |
```
|