|
|
|
|
|
|
|
|
|
|
|
import os |
|
import numpy as np |
|
import pygame |
|
import random |
|
import gymnasium as gym |
|
from stable_baselines3 import DQN |
|
from stable_baselines3.common.evaluation import evaluate_policy |
|
import gradio as gr |
|
import cv2 |
|
import imageio |
|
|
|
|
|
# Window size, in pixels.
SCREEN_WIDTH, SCREEN_HEIGHT = 640, 480

# Paddle and ball dimensions, in pixels.
PADDLE_WIDTH, PADDLE_HEIGHT = 100, 10
BALL_RADIUS = 10

# Size of a single brick in pixels, and the nominal brick-wall layout.
BRICK_WIDTH, BRICK_HEIGHT = 60, 20
BRICK_ROWS, BRICK_COLS = 5, 10

# Frames per second.
FPS = 60

# RGB colour triples.
WHITE = (255,) * 3
BLACK = (0,) * 3
RED = (255, 0, 0)
|
|
|
|
|
# Start pygame and open the game window as soon as this module is imported.
# `screen` is the module-level display surface that the rest of the file
# draws onto and captures frames from.
pygame.init()
screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
pygame.display.set_caption("Arkanoid")
|
|
|
|
|
class Paddle:
    """Player-controlled paddle near the bottom edge of the screen."""

    def __init__(self):
        # Horizontally centred, with a 10 px gap below the paddle.
        left = SCREEN_WIDTH // 2 - PADDLE_WIDTH // 2
        top = SCREEN_HEIGHT - PADDLE_HEIGHT - 10
        self.rect = pygame.Rect(left, top, PADDLE_WIDTH, PADDLE_HEIGHT)

    def move(self, direction):
        """Shift the paddle by 10 px and keep it inside the screen.

        Args:
            direction: -1 moves left, 1 moves right; any other value leaves
                the horizontal position unchanged.
        """
        shift = 0
        if direction == 1:
            shift = 10
        elif direction == -1:
            shift = -10
        self.rect.x += shift

        # Clamp unconditionally so the paddle can never leave the window.
        playfield = pygame.Rect(0, 0, SCREEN_WIDTH, SCREEN_HEIGHT)
        self.rect.clamp_ip(playfield)
|
|
|
class Ball:
    """Ball with a square bounding rect and a mutable ``[vx, vy]`` velocity."""

    def __init__(self):
        # A new ball starts in exactly the state that reset() produces.
        self.reset()

    def move(self):
        """Advance one step, then reflect off the left, right and top walls."""
        dx, dy = self.velocity
        self.rect.x += dx
        self.rect.y += dy

        touches_side = self.rect.left <= 0 or self.rect.right >= SCREEN_WIDTH
        if touches_side:
            self.velocity[0] = -dx
        if self.rect.top <= 0:
            self.velocity[1] = -dy

    def reset(self):
        """Centre the ball on screen and give it a fresh launch velocity.

        The horizontal component is randomly -5 or 5; the vertical component
        is always -5 (towards the top of the screen).
        """
        diameter = BALL_RADIUS * 2
        self.rect = pygame.Rect(
            SCREEN_WIDTH // 2 - BALL_RADIUS,
            SCREEN_HEIGHT // 2 - BALL_RADIUS,
            diameter,
            diameter,
        )
        self.velocity = [random.choice([-5, 5]), -5]
|
|
|
class Brick:
    """One brick of the wall, positioned by its top-left corner."""

    def __init__(self, x, y):
        top_left = (x, y)
        size = (BRICK_WIDTH, BRICK_HEIGHT)
        self.rect = pygame.Rect(top_left, size)
|
|
|
class ArkanoidEnv(gym.Env):
    """Gymnasium environment for a minimal Arkanoid clone.

    Actions (``Discrete(3)``): 0 = stay, 1 = paddle left, 2 = paddle right.

    Observation: ``float32`` vector of length ``5 + 2 * BRICK_ROWS * BRICK_COLS``
    holding paddle x, ball x, ball y, ball vx, ball vy, then the (x, y) of each
    remaining brick, zero-padded to the fixed length.

    Reward: +1 for destroying a brick (+10 more if it was the last one),
    -1 when the ball reaches the bottom of the screen, 0 otherwise.
    """

    def __init__(self):
        super().__init__()
        self.action_space = gym.spaces.Discrete(3)
        # The ball's velocity components are negative about half of the time,
        # so the lower bound has to be negative; with a lower bound of 0 those
        # observations would lie outside the declared space.
        self.observation_space = gym.spaces.Box(
            low=-SCREEN_WIDTH,
            high=SCREEN_WIDTH,
            shape=(5 + BRICK_ROWS * BRICK_COLS * 2,),
            dtype=np.float32,
        )
        self.seed_value = None
        # Frame limiter shared by all render() calls; created on first use.
        self._clock = None
        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: optional seed. It is forwarded to gymnasium (which seeds
                ``self.np_random``) and also applied to the ``random`` and
                ``numpy.random`` global generators, because the ball's launch
                direction is drawn from ``random``.
            options: accepted for API compatibility; unused.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            self.seed_value = seed

        self.paddle = Paddle()
        self.ball = Ball()
        # The wall is inset by one brick width on each side, so it can hold
        # fewer than BRICK_ROWS * BRICK_COLS bricks; _get_state() pads the
        # observation to the full length.
        self.bricks = [
            Brick(x, y)
            for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
            for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)
        ]
        self.done = False
        self.score = 0
        return self._get_state(), {}

    def step(self, action):
        """Advance the game by one frame.

        Args:
            action: 0 (stay), 1 (left) or 2 (right). Compared with ``==`` so
                that NumPy scalars returned by a policy are accepted too.

        Returns:
            ``(observation, reward, terminated, truncated, info)``;
            ``truncated`` is always ``False`` and ``info`` is an empty dict.
        """
        if action == 0:
            self.paddle.move(0)
        elif action == 1:
            self.paddle.move(-1)
        elif action == 2:
            self.paddle.move(1)

        ball = self.ball
        ball.move()

        # Bounce only while the ball is travelling downwards. Without this
        # guard a ball that overlaps the paddle for several frames (e.g. after
        # being hit from the side) has its velocity flipped every frame and
        # jitters inside the paddle instead of leaving it.
        if ball.velocity[1] > 0 and ball.rect.colliderect(self.paddle.rect):
            ball.velocity[1] = -ball.velocity[1]

        # At most one brick is destroyed per frame: the first one in list
        # order that the ball overlaps.
        struck = next(
            (brick for brick in self.bricks if ball.rect.colliderect(brick.rect)),
            None,
        )
        if struck is not None:
            self.bricks.remove(struck)
            ball.velocity[1] = -ball.velocity[1]
            self.score += 1
            reward = 1
            if not self.bricks:
                # Clearing the wall ends the episode with a bonus.
                reward += 10
                self.done = True
            return self._get_state(), reward, self.done, False, {}

        if ball.rect.bottom >= SCREEN_HEIGHT:
            # The ball fell past the paddle: episode lost.
            self.done = True
            reward = -1
        else:
            reward = 0
        return self._get_state(), reward, self.done, False, {}

    def _get_state(self):
        """Build the fixed-length ``float32`` observation vector."""
        state = [
            self.paddle.rect.x,
            self.ball.rect.x,
            self.ball.rect.y,
            self.ball.velocity[0],
            self.ball.velocity[1],
        ]
        for brick in self.bricks:
            state += (brick.rect.x, brick.rect.y)
        # Zero-fill the slots of bricks that are missing or already destroyed.
        missing = BRICK_ROWS * BRICK_COLS - len(self.bricks)
        state += [0, 0] * missing
        return np.array(state, dtype=np.float32)

    def render(self, mode='human'):
        """Draw the current frame to the module-level window, capped at FPS.

        Args:
            mode: accepted for API compatibility; unused.
        """
        screen.fill(BLACK)
        pygame.draw.rect(screen, WHITE, self.paddle.rect)
        pygame.draw.ellipse(screen, WHITE, self.ball.rect)
        for brick in self.bricks:
            pygame.draw.rect(screen, RED, brick.rect)
        pygame.display.flip()

        # Reuse one Clock: tick() measures the time since the previous tick,
        # so a brand-new Clock per call would always sleep a full frame
        # regardless of how long the step itself took.
        if self._clock is None:
            self._clock = pygame.time.Clock()
        self._clock.tick(FPS)

    def close(self):
        """Shut pygame down; the shared window is unusable afterwards."""
        pygame.quit()
|
|
|
|
|
def train_model(env, total_timesteps=10000):
    """Train a new DQN agent on ``env``, save it, and return it.

    Args:
        env: environment to train on.
        total_timesteps: number of environment steps to train for.

    Returns:
        The trained DQN model, which is also saved as "arkanoid_model".
    """
    agent = DQN(policy='MlpPolicy', env=env, verbose=1)
    agent.learn(total_timesteps=total_timesteps)
    agent.save("arkanoid_model")
    return agent
|
|
|
|
|
def evaluate_model(model, env):
    """Return the mean episode reward of ``model`` over 10 episodes on ``env``."""
    # evaluate_policy returns (mean_reward, std_reward); only the mean is used.
    reward_stats = evaluate_policy(model, env, n_eval_episodes=10, render=False)
    return reward_stats[0]
|
|
|
|
|
def play_game():
    """Play one episode with the saved agent and collect one image per step.

    The model is loaded from "arkanoid_model" and acts deterministically on a
    fresh ``ArkanoidEnv``. After every step the window is rendered, saved to
    "frame.png", and wrapped in a ``gr.Image``.

    Returns:
        list: the ``gr.Image`` components, in step order.
    """
    game = ArkanoidEnv()
    agent = DQN.load("arkanoid_model")
    observation, _ = game.reset()
    captured = []
    finished = False
    while not finished:
        move, _ = agent.predict(observation, deterministic=True)
        observation, _reward, finished, _truncated, _info = game.step(move)
        game.render()
        snapshot_path = "frame.png"
        pygame.image.save(screen, snapshot_path)
        captured.append(gr.Image(value=snapshot_path))
    return captured
|
|
|
|
|
def train_and_play():
    """Train a DQN agent in chunks and film one greedy episode after each chunk.

    Training runs for 10000 timesteps in chunks of 1000. After every chunk the
    agent plays one deterministic episode, and each rendered frame is streamed
    straight into "arkanoid_training.mp4".

    This is a generator so that Gradio treats it as a streaming function. It
    yields exactly once, after the video file has been finalised; a value
    passed to ``return`` inside a generator would never reach the caller's
    output component.

    Yields:
        str: path of the finished MP4 file.

    Raises:
        RuntimeError: if OpenCV cannot open the output file for writing.
    """
    env = ArkanoidEnv()
    model = DQN('MlpPolicy', env, verbose=1)
    total_timesteps = 10000
    timesteps_per_update = 1000
    # The environment never truncates on its own, so cap each filmed episode:
    # an agent that keeps the ball in play indefinitely must not hang the app.
    max_episode_steps = 10000

    video_path = "arkanoid_training.mp4"
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(video_path, fourcc, FPS, (SCREEN_WIDTH, SCREEN_HEIGHT))
    if not video_writer.isOpened():
        raise RuntimeError(f"could not open {video_path!r} for writing")

    try:
        for _ in range(0, total_timesteps, timesteps_per_update):
            model.learn(total_timesteps=timesteps_per_update)

            obs, _info = env.reset()
            for _step in range(max_episode_steps):
                action, _states = model.predict(obs, deterministic=True)
                obs, _reward, done, truncated, _info = env.step(action)
                env.render()

                # surfarray arrays are indexed (x, y, channel), i.e. shaped
                # (width, height, 3). The writer was opened for frames of
                # SCREEN_WIDTH x SCREEN_HEIGHT, so it needs
                # (height, width, 3) arrays: swap the first two axes.
                rgb = pygame.surfarray.array3d(pygame.display.get_surface())
                rgb = np.ascontiguousarray(rgb.transpose(1, 0, 2))
                # Write each frame immediately instead of keeping every frame
                # of every episode in memory until the end.
                video_writer.write(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))

                if done or truncated:
                    break
    finally:
        # Always finalise the file, even if training or playback fails.
        video_writer.release()

    # NOTE: env.close() is deliberately not called here. It shuts pygame down,
    # which would invalidate the module-level window for any later call.
    yield video_path
|
|
|
|
|
def main():
    """Build the Gradio interface around ``train_and_play`` and launch it."""
    interface_options = {
        "fn": train_and_play,
        "inputs": None,
        "outputs": "video",
        "live": True,
    }
    gr.Interface(**interface_options).launch()
|
|
|
# Script entry point: launch the app only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|