""" |
|
The environment is inspired from https://github.com/google/dopamine/blob/master/dopamine/discrete_domains/atari_lib.py |
|
""" |
|
|
|
import os |
|
from typing import Tuple, Dict |
|
from gym.wrappers.monitoring import video_recorder |
|
import gym |
|
import numpy as np |
|
import jax |
|
import jax.numpy as jnp |
|
import cv2 |
|
|
|
|
|
class AtariEnv:
    def __init__(
        self,
        name: str,
    ) -> None:
        self.name = name
        self.state_height, self.state_width = (84, 84)
        self.n_stacked_frames = 4
        self.n_skipped_frames = 4

        # Frame skipping is handled manually below (frameskip=1); sticky actions
        # are kept at the ALE default of 0.25.
        self.env = gym.make(
            f"ALE/{self.name}-v5",
            full_action_space=False,
            frameskip=1,
            repeat_action_probability=0.25,
            render_mode="rgb_array",
        ).env

        self.n_actions = self.env.action_space.n
        self.original_state_height, self.original_state_width, _ = self.env.observation_space.shape
        # Two raw screen buffers so that the last two frames can be max-pooled.
        self.screen_buffer = [
            np.empty((self.original_state_height, self.original_state_width), dtype=np.uint8),
            np.empty((self.original_state_height, self.original_state_width), dtype=np.uint8),
        ]

    @property
    def observation(self) -> np.ndarray:
        # Most recent (uint8) frame of the stack.
        return np.copy(self.state_[:, :, -1])

    @property
    def state(self) -> jnp.ndarray:
        return jnp.array(self.state_, dtype=jnp.float32)

    def reset(self) -> None:
        self.env.reset()

        self.n_steps = 0

        # Grab the initial grayscale screen; the second buffer is zeroed so that
        # max-pooling on the first step is a no-op.
        self.env.ale.getScreenGrayscale(self.screen_buffer[0])
        self.screen_buffer[1].fill(0)

        self.state_ = np.zeros((self.state_height, self.state_width, self.n_stacked_frames), dtype=np.uint8)
        self.state_[:, :, -1] = self.resize()

    def step(self, action: jnp.int8) -> Tuple[float, bool, Dict]:
        reward = 0

        for idx_frame in range(self.n_skipped_frames):
            _, reward_, terminal, info = self.env.step(action)
            reward += reward_

            # Only the last two skipped frames are captured, for max-pooling.
            if idx_frame >= self.n_skipped_frames - 2:
                t = idx_frame - (self.n_skipped_frames - 2)
                self.env.ale.getScreenGrayscale(self.screen_buffer[t])

            if terminal:
                break

        self.state_ = np.roll(self.state_, -1, axis=-1)
        self.state_[:, :, -1] = self.pool_and_resize()

        self.n_steps += 1

        return reward, terminal, info

    def pool_and_resize(self) -> np.ndarray:
        # Max over the two most recent raw frames to remove Atari sprite flickering.
        np.maximum(self.screen_buffer[0], self.screen_buffer[1], out=self.screen_buffer[0])

        return self.resize()

    def resize(self) -> np.ndarray:
        return np.asarray(
            cv2.resize(self.screen_buffer[0], (self.state_width, self.state_height), interpolation=cv2.INTER_AREA),
            dtype=np.uint8,
        )

    def evaluate_one_simulation(
        self,
        q,
        q_params: Dict,
        horizon: int,
        eps_eval: float,
        exploration_key: jax.random.PRNGKey,
        video_path: str,
    ) -> Tuple[float, bool]:
        video = video_recorder.VideoRecorder(
            self.env, path=f"{video_path}.mp4", enabled=video_path is not None
        )
        sum_reward = 0
        terminal = False
        self.reset()

        while not terminal and self.n_steps < horizon:
            self.env.render(mode="rgb_array")
            video.capture_frame()

            # Epsilon-greedy action selection during evaluation.
            exploration_key, key = jax.random.split(exploration_key)
            if jax.random.uniform(key) < eps_eval:
                action = jax.random.choice(key, jnp.arange(self.n_actions)).astype(jnp.int8)
            else:
                action = q.best_action(q_params, self.state, key)

            reward, terminal, _ = self.step(action)

            sum_reward += reward

        video.close()
        if video_path is not None:
            os.remove(f"{video_path}.meta.json")

        return sum_reward, terminal
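

# Minimal usage sketch (illustrative only): the game name "Breakout", the step
# budget, and the random-action loop are assumptions for demonstration, not part
# of the class above. A learned Q-function would replace the random policy.
if __name__ == "__main__":
    env = AtariEnv("Breakout")
    env.reset()

    terminal = False
    total_reward = 0.0
    rng = np.random.default_rng(0)

    while not terminal and env.n_steps < 100:
        # Sample a random action index; a real agent would call q.best_action here.
        action = int(rng.integers(env.n_actions))
        reward, terminal, _ = env.step(action)
        total_reward += reward

    # env.state is an (84, 84, 4) float32 stack of the last four preprocessed frames.
    print(total_reward, env.state.shape)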