mRoszak committed on
Commit f74111e · 1 Parent(s): f4f003f

Update README.md

Files changed (1)
  1. README.md +48 -3
README.md CHANGED
@@ -26,12 +26,57 @@ This is a trained model of a **A2C** agent playing **CartPole-v1**
 using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
 
 ## Usage (with Stable-baselines3)
-TODO: Add your code
 
 
 ```python
-from stable_baselines3 import ...
-from huggingface_sb3 import load_from_hub
+import gym
+
+from stable_baselines3 import A2C
+from stable_baselines3.common.env_util import make_vec_env
+from huggingface_sb3 import package_to_hub
+import wandb
+from wandb.integration.sb3 import WandbCallback
+
+# Parallel environments
+env = gym.make("CartPole-v1")
+eval_env = gym.make("CartPole-v1")
+config = {
+    "policy_type": "MlpPolicy",
+    "total_timesteps": 25000,
+    "env_id": "CartPole-v1",
+}
+
+run = wandb.init(
+    project="cart_pole",
+    config=config,
+    sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
+    # monitor_gym=True, # auto-upload the videos of agents playing the game
+    # save_code=True, # optional
+)
+
+
+model = A2C("MlpPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")
+model.learn(
+    total_timesteps=config["total_timesteps"],
+    callback=WandbCallback(
+        model_save_path=f"models/{run.id}",
+        verbose=2,
+    ),
+)
+run.finish()
+
+
+
+model.save("a2c_Cart_Pole")
+
+
+package_to_hub(model=model,
+               model_name="a2c_Cart_Pole",
+               model_architecture="A2C",
+               env_id="CartPole-v1",
+               eval_env=eval_env,
+               repo_id="mRoszak/A2C_Cart_Pole",
+               commit_message="Test commit")
 
 ...
 ```
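
The removed placeholder imported `load_from_hub`, while the committed snippet covers only training and upload. For readers who want to use the published agent, the following is a minimal sketch of the loading side, not the author's own code. It assumes the checkpoint is stored in the repo as `a2c_Cart_Pole.zip` (the `model_name` from the diff plus `.zip`). The helper names `hub_filename` and `load_and_evaluate` are hypothetical. An evaluation environment built with `make_vec_env` is also an assumption, chosen so the sketch does not depend on whether `gym` or `gymnasium` is installed.

```python
def hub_filename(model_name: str) -> str:
    """Return the checkpoint filename assumed to be stored in the Hub repo."""
    # Assumption: the upload stores the checkpoint as "<model_name>.zip".
    return f"{model_name}.zip"


def load_and_evaluate(
    repo_id: str = "mRoszak/A2C_Cart_Pole",
    model_name: str = "a2c_Cart_Pole",
    env_id: str = "CartPole-v1",
    n_eval_episodes: int = 10,
):
    """Download the checkpoint, load it, and return (mean_reward, std_reward)."""
    # Imported lazily so this module can be imported without the RL stack installed.
    from huggingface_sb3 import load_from_hub
    from stable_baselines3 import A2C
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.evaluation import evaluate_policy

    # Download the checkpoint from the Hub and get its local path.
    checkpoint_path = load_from_hub(
        repo_id=repo_id,
        filename=hub_filename(model_name),
    )
    model = A2C.load(checkpoint_path)

    # A single-environment VecEnv is enough for evaluation.
    eval_env = make_vec_env(env_id, n_envs=1)
    try:
        return evaluate_policy(
            model,
            eval_env,
            n_eval_episodes=n_eval_episodes,
            deterministic=True,
        )
    finally:
        eval_env.close()
```

Calling `load_and_evaluate()` needs network access and the `stable-baselines3` and `huggingface_sb3` packages. If the repo stores the checkpoint under a different filename, pass the matching `model_name`.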