Update app.py
Browse files
app.py
CHANGED
@@ -11,6 +11,8 @@ import gymnasium as gym
|
|
11 |
from stable_baselines3 import DQN
|
12 |
from stable_baselines3.common.evaluation import evaluate_policy
|
13 |
import gradio as gr
|
|
|
|
|
14 |
|
15 |
# Constants
|
16 |
SCREEN_WIDTH = 640
|
@@ -182,6 +184,7 @@ def train_and_play():
|
|
182 |
total_timesteps = 10000
|
183 |
timesteps_per_update = 1000
|
184 |
frames = []
|
|
|
185 |
|
186 |
for i in range(0, total_timesteps, timesteps_per_update):
|
187 |
model.learn(total_timesteps=timesteps_per_update)
|
@@ -193,18 +196,32 @@ def train_and_play():
|
|
193 |
action, _states = model.predict(obs, deterministic=True)
|
194 |
obs, reward, done, truncated, info = env.step(action)
|
195 |
env.render()
|
196 |
-
|
|
|
|
|
|
|
197 |
episode_frames.append(gr.Image(value="frame.png"))
|
198 |
frames.extend(episode_frames)
|
199 |
yield frames
|
200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
# Main function
|
202 |
def main():
|
203 |
# Gradio interface
|
204 |
iface = gr.Interface(
|
205 |
fn=train_and_play,
|
206 |
inputs=None,
|
207 |
-
outputs="
|
208 |
live=True
|
209 |
)
|
210 |
iface.launch()
|
@@ -220,6 +237,8 @@ if __name__ == "__main__":
|
|
220 |
# - torch
|
221 |
# - gradio
|
222 |
# - gymnasium
|
|
|
|
|
223 |
#
|
224 |
# You can install these dependencies using pip:
|
225 |
-
# pip install pygame stable-baselines3 torch gradio gymnasium
|
|
|
11 |
from stable_baselines3 import DQN
|
12 |
from stable_baselines3.common.evaluation import evaluate_policy
|
13 |
import gradio as gr
|
14 |
+
import cv2
|
15 |
+
import imageio
|
16 |
|
17 |
# Constants
|
18 |
SCREEN_WIDTH = 640
|
|
|
184 |
total_timesteps = 10000
|
185 |
timesteps_per_update = 1000
|
186 |
frames = []
|
187 |
+
video_frames = []
|
188 |
|
189 |
for i in range(0, total_timesteps, timesteps_per_update):
|
190 |
model.learn(total_timesteps=timesteps_per_update)
|
|
|
196 |
action, _states = model.predict(obs, deterministic=True)
|
197 |
obs, reward, done, truncated, info = env.step(action)
|
198 |
env.render()
|
199 |
+
# Capture the current frame
|
200 |
+
frame = pygame.surfarray.array3d(pygame.display.get_surface())
|
201 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
202 |
+
video_frames.append(frame)
|
203 |
episode_frames.append(gr.Image(value="frame.png"))
|
204 |
frames.extend(episode_frames)
|
205 |
yield frames
|
206 |
|
207 |
+
# Save the video
|
208 |
+
video_path = "arkanoid_training.mp4"
|
209 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
210 |
+
video_writer = cv2.VideoWriter(video_path, fourcc, FPS, (SCREEN_WIDTH, SCREEN_HEIGHT))
|
211 |
+
for frame in video_frames:
|
212 |
+
video_writer.write(frame)
|
213 |
+
video_writer.release()
|
214 |
+
|
215 |
+
# Return the video path
|
216 |
+
return gr.Video(video_path)
|
217 |
+
|
218 |
# Main function
|
219 |
def main():
    """Build the Gradio interface around ``train_and_play`` and launch it.

    The interface takes no inputs and declares a single ``"video"`` output.
    """
    # NOTE(review): train_and_play both yields `frames` and returns a
    # gr.Video; confirm that what it yields is compatible with the "video"
    # output declared here.
    # Gradio interface
    iface = gr.Interface(
        fn=train_and_play,
        inputs=None,
        outputs="video",
        live=True,
    )
    iface.launch()
|
|
|
237 |
# - torch
|
238 |
# - gradio
|
239 |
# - gymnasium
|
240 |
+
# - opencv-python
|
241 |
+
# - imageio
|
242 |
#
|
243 |
# You can install these dependencies using pip:
|
244 |
+
# pip install pygame stable-baselines3 torch gradio gymnasium opencv-python imageio
|