"""
The environment is inspired by https://github.com/google/dopamine/blob/master/dopamine/discrete_domains/atari_lib.py
"""
import os
from typing import Tuple, Dict
from gym.wrappers.monitoring import video_recorder
import gym
import numpy as np
import jax
import jax.numpy as jnp
import cv2
class AtariEnv:
def __init__(
self,
name: str,
) -> None:
self.name = name
self.state_height, self.state_width = (84, 84)
self.n_stacked_frames = 4
self.n_skipped_frames = 4
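        # Sticky actions: the emulator repeats the previous action with probability 0.25,
        # and frameskip=1 because frame skipping is handled manually in step().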
self.env = gym.make(
f"ALE/{self.name}-v5",
full_action_space=False,
frameskip=1,
repeat_action_probability=0.25,
render_mode="rgb_array",
).env
self.n_actions = self.env.action_space.n
        self.original_state_height, self.original_state_width, _ = self.env.observation_space.shape
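        # Two raw grayscale frame buffers; the most recent pair is max-pooled in
        # pool_and_resize() to remove sprite flickering before resizing to 84x84.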
self.screen_buffer = [
np.empty((self.original_state_height, self.original_state_width), dtype=np.uint8),
np.empty((self.original_state_height, self.original_state_width), dtype=np.uint8),
]
@property
def observation(self) -> np.ndarray:
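        """Most recent 84x84 grayscale frame of the stack, as uint8."""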
return np.copy(self.state_[:, :, -1])
@property
    def state(self) -> jnp.ndarray:
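        """Full (84, 84, 4) stacked-frame state as a float32 JAX array."""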
return jnp.array(self.state_, dtype=jnp.float32)
def reset(self) -> None:
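        """Resets the underlying environment and the frame stack; only the most
        recent slot of the stack holds the initial screen, the rest is zeroed."""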
self.env.reset()
self.n_steps = 0
self.env.ale.getScreenGrayscale(self.screen_buffer[0])
self.screen_buffer[1].fill(0)
self.state_ = np.zeros((self.state_height, self.state_width, self.n_stacked_frames), dtype=np.uint8)
self.state_[:, :, -1] = self.resize()
def step(self, action: jnp.int8) -> Tuple[float, bool, Dict]:
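        """Repeats the action for n_skipped_frames ALE frames, accumulates the reward,
        max-pools the last two raw frames, and appends the result to the frame stack."""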
reward = 0
for idx_frame in range(self.n_skipped_frames):
            _, reward_, terminal, info = self.env.step(action)
reward += reward_
if idx_frame >= self.n_skipped_frames - 2:
t = idx_frame - (self.n_skipped_frames - 2)
self.env.ale.getScreenGrayscale(self.screen_buffer[t])
if terminal:
break
self.state_ = np.roll(self.state_, -1, axis=-1)
self.state_[:, :, -1] = self.pool_and_resize()
self.n_steps += 1
        return reward, terminal, info
def pool_and_resize(self) -> np.ndarray:
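        """Max-pools the two buffered frames in place, then resizes the result to 84x84."""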
np.maximum(self.screen_buffer[0], self.screen_buffer[1], out=self.screen_buffer[0])
return self.resize()
def resize(self):
return np.asarray(
cv2.resize(self.screen_buffer[0], (self.state_width, self.state_height), interpolation=cv2.INTER_AREA),
dtype=np.uint8,
)
def evaluate_one_simulation(
self,
q,
q_params: Dict,
horizon: int,
eps_eval: float,
exploration_key: jax.random.PRNGKey,
video_path: str,
    ) -> Tuple[float, bool]:
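        """Runs one epsilon-greedy evaluation episode for at most `horizon` wrapper steps,
        optionally recording a video, and returns the undiscounted return together with
        the final terminal flag."""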
        video = video_recorder.VideoRecorder(
            self.env, path=f"{video_path}.mp4", enabled=video_path is not None
        )
        sum_reward = 0
terminal = False
self.reset()
while not terminal and self.n_steps < horizon:
self.env.render(mode="rgb_array")
video.capture_frame()
            exploration_key, key, action_key = jax.random.split(exploration_key, 3)
            if jax.random.uniform(key) < eps_eval:
                action = jax.random.choice(action_key, jnp.arange(self.n_actions)).astype(jnp.int8)
else:
action = q.best_action(q_params, self.state, key)
reward, terminal, _ = self.step(action)
            sum_reward += reward
video.close()
if video_path is not None:
os.remove(f"{video_path}.meta.json")
        return sum_reward, terminal
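

# Minimal usage sketch: drives the wrapper with random actions. It assumes the
# ALE Atari ROMs are installed and that the "Breakout" ROM is available; any other
# installed game name works the same way.
if __name__ == "__main__":
    env = AtariEnv("Breakout")
    env.reset()

    key = jax.random.PRNGKey(0)
    terminal = False

    while not terminal and env.n_steps < 100:
        key, action_key = jax.random.split(key)
        action = jax.random.choice(action_key, jnp.arange(env.n_actions)).astype(jnp.int8)
        reward, terminal, _ = env.step(action)

    print(f"Episode ended after {env.n_steps} steps.")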