Ivan000 commited on
Commit
a5cda12
·
verified ·
1 Parent(s): bd793a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -3
app.py CHANGED
@@ -11,6 +11,8 @@ import gymnasium as gym
11
  from stable_baselines3 import DQN
12
  from stable_baselines3.common.evaluation import evaluate_policy
13
  import gradio as gr
 
 
14
 
15
  # Constants
16
  SCREEN_WIDTH = 640
@@ -182,6 +184,7 @@ def train_and_play():
182
  total_timesteps = 10000
183
  timesteps_per_update = 1000
184
  frames = []
 
185
 
186
  for i in range(0, total_timesteps, timesteps_per_update):
187
  model.learn(total_timesteps=timesteps_per_update)
@@ -193,18 +196,32 @@ def train_and_play():
193
  action, _states = model.predict(obs, deterministic=True)
194
  obs, reward, done, truncated, info = env.step(action)
195
  env.render()
196
- pygame.image.save(screen, "frame.png")
 
 
 
197
  episode_frames.append(gr.Image(value="frame.png"))
198
  frames.extend(episode_frames)
199
  yield frames
200
 
 
 
 
 
 
 
 
 
 
 
 
201
  # Main function
202
  def main():
203
  # Gradio interface
204
  iface = gr.Interface(
205
  fn=train_and_play,
206
  inputs=None,
207
- outputs="image",
208
  live=True
209
  )
210
  iface.launch()
@@ -220,6 +237,8 @@ if __name__ == "__main__":
220
  # - torch
221
  # - gradio
222
  # - gymnasium
 
 
223
  #
224
  # You can install these dependencies using pip:
225
- # pip install pygame stable-baselines3 torch gradio gymnasium
 
11
  from stable_baselines3 import DQN
12
  from stable_baselines3.common.evaluation import evaluate_policy
13
  import gradio as gr
14
+ import cv2
15
+ import imageio
16
 
17
  # Constants
18
  SCREEN_WIDTH = 640
 
184
  total_timesteps = 10000
185
  timesteps_per_update = 1000
186
  frames = []
187
+ video_frames = []
188
 
189
  for i in range(0, total_timesteps, timesteps_per_update):
190
  model.learn(total_timesteps=timesteps_per_update)
 
196
  action, _states = model.predict(obs, deterministic=True)
197
  obs, reward, done, truncated, info = env.step(action)
198
  env.render()
199
+ # Capture the current frame
200
+ frame = pygame.surfarray.array3d(pygame.display.get_surface())
201
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
202
+ video_frames.append(frame)
203
  episode_frames.append(gr.Image(value="frame.png"))
204
  frames.extend(episode_frames)
205
  yield frames
206
 
207
+ # Save the video
208
+ video_path = "arkanoid_training.mp4"
209
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
210
+ video_writer = cv2.VideoWriter(video_path, fourcc, FPS, (SCREEN_WIDTH, SCREEN_HEIGHT))
211
+ for frame in video_frames:
212
+ video_writer.write(frame)
213
+ video_writer.release()
214
+
215
+ # Return the video path
216
+ return gr.Video(video_path)
217
+
218
  # Main function
219
  def main():
220
  # Gradio interface
221
  iface = gr.Interface(
222
  fn=train_and_play,
223
  inputs=None,
224
+ outputs="video",
225
  live=True
226
  )
227
  iface.launch()
 
237
  # - torch
238
  # - gradio
239
  # - gymnasium
240
+ # - opencv-python
241
+ # - imageio
242
  #
243
  # You can install these dependencies using pip:
244
+ # pip install pygame stable-baselines3 torch gradio gymnasium opencv-python imageio