Butanium committed on
Commit
68b09a5
1 Parent(s): a326370

Update README.md

Files changed (1)
  1. README.md +98 -2
README.md CHANGED
@@ -3,11 +3,107 @@ pipeline_tag: reinforcement-learning
  tags:
  - ppo
  ---

- PPO agents trained in a selfplay settings. The agent were trained on observation as left player only. This repo include checkpoints collected during training for
  4 experiments:
  - Shared weights for actor and critic
  - No shared weights
  - Resume training for extra steps for both shared and no shared setup

- Please check our [wandb report](https://wandb.ai/dumas/SPAR_RL_ELK/) for more details
 
  tags:
  - ppo
  ---
+ # Environment
+ Multiplayer pong_v3 from PettingZoo with:
+ - 4 stacked frames
+ - The agent is trained to predict the left player's policy (the observation is mirrored for the right player)

+ ```py
+ def pong_obs_modification(obs, _space, player_id):
+     obs[:9, :, :] = 0
+     if "second" in player_id:
+         # Mirror the image
+         obs = obs[:, ::-1, :]
+     return obs
+
+
+ def get_env(args, run_name):
+     env = importlib.import_module(f"pettingzoo.atari.{args.env_id}").parallel_env()
+     env = ss.max_observation_v0(env, 2)
+     env = ss.frame_skip_v0(env, 4)
+     env = ss.clip_reward_v0(env, lower_bound=-1, upper_bound=1)
+     env = ss.color_reduction_v0(env, mode="B")
+     env = ss.resize_v1(env, x_size=84, y_size=84)
+     env = ss.frame_stack_v1(env, 4)
+     # Remove the score from the observation
+     if "pong" in args.env_id:
+         env = ss.lambda_wrappers.observation_lambda_v0(
+             env,
+             pong_obs_modification,
+         )
+     # env = ss.agent_indicator_v0(env, type_only=False)
+     env = ss.pettingzoo_env_to_vec_env_v1(env)
+     envs = ss.concat_vec_envs_v1(env, args.num_envs // 2, num_cpus=0, base_class="gym")
+     envs.single_observation_space = envs.observation_space
+     envs.single_action_space = envs.action_space
+     envs.is_vector_env = True
+     envs = gym.wrappers.RecordEpisodeStatistics(envs)
+     if args.capture_video:
+         envs = gym.wrappers.RecordVideo(envs, f"videos/{run_name}")
+     assert isinstance(
+         envs.single_action_space, gym.spaces.Discrete
+     ), "only discrete action space is supported"
+     return envs
+ ```
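For reference, a minimal sketch of how `get_env` above might be driven (not part of the commit). It assumes the imports from the linked cleanrl training script (`importlib`, `gym`, `supersuit as ss`) and an `args` object providing only the fields the function reads; the values below are illustrative.

```py
# Hypothetical usage sketch, not part of the committed README.
# Assumes pong_obs_modification / get_env from the snippet above plus the
# training script's imports (importlib, gym, supersuit as ss).
from types import SimpleNamespace

args = SimpleNamespace(
    env_id="pong_v3",     # PettingZoo Atari module loaded via importlib above
    num_envs=16,          # halved inside get_env: each Pong env already hosts two players
    capture_video=False,  # skip the RecordVideo wrapper
)
envs = get_env(args, run_name="selfplay-demo")
print(envs.single_observation_space)  # 84x84 frames stacked 4 deep after the wrappers
print(envs.single_action_space)       # must be Discrete, as asserted in get_env
```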
+
+ # Experiment
+ PPO agents were trained in a selfplay setting. This repo includes checkpoints collected during training for
  4 experiments:
  - Shared weights for actor and critic
  - No shared weights
  - Resume training for extra steps for both shared and no shared setup
+ Please check our [wandb report](https://wandb.ai/dumas/SPAR_RL_ELK/) for more details and the training code on [our GitHub](https://github.com/Butanium/cleanrl/blob/master/multiplayer_pong/ppo_pettingzoo_ma_atari.py).
+
+ # Model architecture
+ ```py
+ def atari_network(orth_init=False):
+     init = layer_init if orth_init else lambda m: m
+     return nn.Sequential(
+         init(nn.Conv2d(4, 32, 8, stride=4)),
+         nn.ReLU(),
+         init(nn.Conv2d(32, 64, 4, stride=2)),
+         nn.ReLU(),
+         init(nn.Conv2d(64, 64, 3, stride=1)),
+         nn.ReLU(),
+         nn.Flatten(),
+         init(nn.Linear(64 * 7 * 7, 512)),
+         nn.ReLU(),
+     )
+
+ class Agent(nn.Module):
+     def __init__(self, envs, share_network=False):
+         super().__init__()
+         self.actor_network = atari_network(orth_init=True)
+         self.share_network = share_network
+         if share_network:
+             self.critic_network = self.actor_network
+         else:
+             self.critic_network = atari_network(orth_init=True)
+         self.actor = layer_init(nn.Linear(512, envs.single_action_space.n), std=0.01)
+         self.critic = layer_init(nn.Linear(512, 1), std=1)
+
+     def get_value(self, x):
+         x = x.clone()
+         x[:, :, :, [0, 1, 2, 3]] /= 255.0
+         return self.critic(self.critic_network(x.permute((0, 3, 1, 2))))
+
+     def get_action_and_value(self, x, action=None):
+         x = x.clone()
+         x[:, :, :, [0, 1, 2, 3]] /= 255.0
+         logits = self.actor(self.actor_network(x.permute((0, 3, 1, 2))))
+         probs = Categorical(logits=logits)
+         if action is None:
+             action = probs.sample()
+         return (
+             action,
+             probs.log_prob(action),
+             probs.entropy(),
+             self.critic(self.critic_network(x.permute((0, 3, 1, 2)))),
+         )
+
+     def load(self, path):
+         self.load_state_dict(torch.load(path))
+         if self.share_network:
+             self.critic_network = self.actor_network
+ ```
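For reference, a hedged sketch of how one of the released checkpoints could be restored with the `Agent` class above (not part of the commit). It assumes the imports and the `layer_init` helper from the linked training script, an `envs` object built by `get_env` as in the earlier sketch, and a placeholder checkpoint filename, since the actual file names are not listed in this README.

```py
# Hypothetical loading sketch, not part of the committed README.
# Requires torch, torch.nn as nn, torch.distributions.Categorical and the
# layer_init helper from the linked cleanrl training script.
import torch

envs = get_env(args, run_name="eval")     # args as in the earlier sketch
agent = Agent(envs, share_network=True)   # match the experiment the checkpoint came from
agent.load("checkpoint.pt")               # placeholder path; Agent.load expects a saved state_dict
agent.eval()

dummy_obs = torch.zeros((1, 84, 84, 4))   # (batch, height, width, stacked frames)
action, logprob, entropy, value = agent.get_action_and_value(dummy_obs)
```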