"""
The environment is inspired from https://github.com/google/dopamine/blob/master/dopamine/discrete_domains/atari_lib.py
"""

import os
from typing import Tuple, Dict
from gym.wrappers.monitoring import video_recorder
import gym
import numpy as np
import jax
import jax.numpy as jnp
import cv2


class AtariEnv:
    def __init__(
        self,
        name: str,
    ) -> None:
        self.name = name
        self.state_height, self.state_width = (84, 84)
        self.n_stacked_frames = 4
        self.n_skipped_frames = 4

        # Sticky actions (repeat_action_probability=0.25) and frameskip=1, so that frame
        # skipping and max-pooling are handled by this wrapper rather than inside ALE.
        self.env = gym.make(
            f"ALE/{self.name}-v5",
            full_action_space=False,
            frameskip=1,
            repeat_action_probability=0.25,
            render_mode="rgb_array",
        ).env

        self.n_actions = self.env.action_space.n
        self.original_state_height, self.original_state_width, _ = self.env.observation_space.shape
        self.screen_buffer = [
            np.empty((self.original_state_height, self.original_state_width), dtype=np.uint8),
            np.empty((self.original_state_height, self.original_state_width), dtype=np.uint8),
        ]

    @property
    def observation(self) -> np.ndarray:
        return np.copy(self.state_[:, :, -1])

    @property
    def state(self) -> np.ndarray:
        return jnp.array(self.state_, dtype=jnp.float32)

    def reset(self) -> None:
        self.env.reset()

        self.n_steps = 0

        # Grab the initial grayscale screen; the second buffer stays empty until the first step.
        self.env.ale.getScreenGrayscale(self.screen_buffer[0])
        self.screen_buffer[1].fill(0)

        self.state_ = np.zeros((self.state_height, self.state_width, self.n_stacked_frames), dtype=np.uint8)
        self.state_[:, :, -1] = self.resize()

    def step(self, action: jnp.int8) -> Tuple[float, bool, Dict]:
        reward = 0

        # Repeat the action over the skipped frames, accumulating the reward.
        for idx_frame in range(self.n_skipped_frames):
            _, reward_, terminal, info = self.env.step(action)

            reward += reward_

            # Keep the grayscale screens of the last two frames for max-pooling.
            if idx_frame >= self.n_skipped_frames - 2:
                t = idx_frame - (self.n_skipped_frames - 2)
                self.env.ale.getScreenGrayscale(self.screen_buffer[t])

            if terminal:
                break

        # Shift the frame stack and append the newly pooled frame.
        self.state_ = np.roll(self.state_, -1, axis=-1)
        self.state_[:, :, -1] = self.pool_and_resize()

        self.n_steps += 1

        return reward, terminal, info

    def pool_and_resize(self) -> np.ndarray:
        # Max-pool over the two most recent screens to remove sprite flickering.
        np.maximum(self.screen_buffer[0], self.screen_buffer[1], out=self.screen_buffer[0])

        return self.resize()

    def resize(self) -> np.ndarray:
        return np.asarray(
            cv2.resize(self.screen_buffer[0], (self.state_width, self.state_height), interpolation=cv2.INTER_AREA),
            dtype=np.uint8,
        )

    def evaluate_one_simulation(
        self,
        q,
        q_params: Dict,
        horizon: int,
        eps_eval: float,
        exploration_key: jax.random.PRNGKey,
        video_path: str,
    ) -> Tuple[float, bool]:
        video = video_recorder.VideoRecorder(
            self.env, path=f"{video_path}.mp4", enabled=video_path is not None
        )
        sum_reward = 0
        terminal = False
        self.reset()

        while not terminal and self.n_steps < horizon:
            self.env.render(mode="rgb_array")
            video.capture_frame()

            exploration_key, key = jax.random.split(exploration_key)
            if jax.random.uniform(key) < eps_eval:
                action = jax.random.choice(key, jnp.arange(self.n_actions)).astype(jnp.int8)
            else:
                action = q.best_action(q_params, self.state, key)

            reward, terminal, _ = self.step(action)

            sum_reward += reward

        video.close()
        if video_path is not None:
            os.remove(f"{video_path}.meta.json")

        return sum_reward, terminal
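

if __name__ == "__main__":
    # Minimal usage sketch, assuming the Breakout ROM is installed (e.g. via ale-py / AutoROM);
    # the environment name and the 100-step cap below are illustrative only. It runs a short
    # random rollout through the wrapper; the Q-network used by `evaluate_one_simulation`
    # lives elsewhere in the repository, so it is not exercised here.
    env = AtariEnv("Breakout")
    env.reset()

    episode_return = 0.0
    terminal = False
    while not terminal and env.n_steps < 100:
        action = np.random.randint(env.n_actions)
        reward, terminal, _ = env.step(action)
        episode_return += reward

    print(f"state shape: {env.state.shape}, steps: {env.n_steps}, return: {episode_return}")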