|
import os |
|
os.environ["XDG_RUNTIME_DIR"] = "/tmp" |
|
import numpy as np |
|
import pygame |
|
import random |
|
import gymnasium as gym |
|
from stable_baselines3 import DQN |
|
from stable_baselines3.common.evaluation import evaluate_policy |
|
import gradio as gr |
|
import cv2 |
|
|
|
|
|
# Playfield size in pixels.
SCREEN_WIDTH, SCREEN_HEIGHT = 640, 480

# Paddle size in pixels.
PADDLE_WIDTH, PADDLE_HEIGHT = 100, 10

# The ball occupies a (2 * BALL_RADIUS) x (2 * BALL_RADIUS) square.
BALL_RADIUS = 10

# Brick cell size and grid dimensions.
# NOTE(review): ArkanoidEnv.reset() places bricks at
# x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH), which with
# these values yields 9 columns, not BRICK_COLS (10). The observation vector is
# sized for BRICK_ROWS * BRICK_COLS bricks and zero-padded, so this is harmless,
# but the two should be kept in mind together.
BRICK_WIDTH, BRICK_HEIGHT = 60, 20
BRICK_ROWS, BRICK_COLS = 5, 10

# Frame rate of the recorded training video.
FPS = 40

# RGB colours used when rendering.
WHITE, BLACK, RED = (255, 255, 255), (0, 0, 0), (255, 0, 0)
|
|
|
|
|
# Initialise all pygame modules once, at import time.
# NOTE(review): ArkanoidEnv.close() calls pygame.quit(), which is process-wide
# and undoes this initialisation for every later env in the same process
# (e.g. a second Gradio request) — confirm rendering still works afterwards.
pygame.init()
|
|
|
|
|
class Paddle:
    """Player-controlled paddle that slides horizontally near the bottom edge.

    The paddle starts horizontally centred, with its bottom edge 10 px above
    the bottom of the screen.
    """

    def __init__(self):
        left = SCREEN_WIDTH // 2 - PADDLE_WIDTH // 2
        top = SCREEN_HEIGHT - PADDLE_HEIGHT - 10
        self.rect = pygame.Rect(left, top, PADDLE_WIDTH, PADDLE_HEIGHT)

    def move(self, direction):
        """Shift the paddle by 10 px.

        Args:
            direction: -1 moves left, 1 moves right; any other value leaves
                the x position unchanged.

        The paddle is clamped back inside the screen after every call.
        """
        dx = -10 if direction == -1 else (10 if direction == 1 else 0)
        self.rect.x += dx
        screen_bounds = pygame.Rect(0, 0, SCREEN_WIDTH, SCREEN_HEIGHT)
        self.rect.clamp_ip(screen_bounds)
|
|
|
class Ball:
    """The ball: a (2 * BALL_RADIUS)-sided square rect with an integer velocity.

    ``velocity`` is a mutable ``[vx, vy]`` list. Other code flips its
    components in place, so it is only ever mutated, never reassigned, except
    in ``reset``.
    """

    def __init__(self):
        # A freshly built ball is identical to a freshly served one.
        self.reset()

    def move(self):
        """Advance one step and bounce off the left, right and top walls.

        The bottom edge is not handled here; the caller decides what a ball
        leaving the bottom means.
        """
        vx, vy = self.velocity
        self.rect.x += vx
        self.rect.y += vy

        touches_side = self.rect.left <= 0 or self.rect.right >= SCREEN_WIDTH
        if touches_side:
            self.velocity[0] = -vx
        if self.rect.top <= 0:
            self.velocity[1] = -vy

    def reset(self):
        """Centre the ball and serve it upwards, left or right at random."""
        side = BALL_RADIUS * 2
        left = SCREEN_WIDTH // 2 - BALL_RADIUS
        top = SCREEN_HEIGHT // 2 - BALL_RADIUS
        self.rect = pygame.Rect(left, top, side, side)
        self.velocity = [random.choice([-5, 5]), -5]
|
|
|
class Brick:
    """A single destructible brick whose top-left corner is at (x, y)."""

    def __init__(self, x, y):
        # Bricks are laid out on a BRICK_WIDTH x BRICK_HEIGHT grid; shrinking
        # each rect by a few pixels leaves a visible gap between neighbours.
        inset = 5
        width = BRICK_WIDTH - inset
        height = BRICK_HEIGHT - inset
        self.rect = pygame.Rect(x, y, width, height)
|
|
|
class ArkanoidEnv(gym.Env):
    """Gymnasium environment for a minimal Arkanoid/Breakout game.

    Actions (``Discrete(3)``): 0 = keep the paddle still, 1 = move left,
    2 = move right.

    Observation: float32 vector of length ``5 + 2 * BRICK_ROWS * BRICK_COLS``::

        [paddle_x, ball_x, ball_y, ball_vx, ball_vy,
         brick_0_x, brick_0_y, brick_1_x, brick_1_y, ...]

    Remaining bricks come first; destroyed ones are replaced by trailing
    ``(0, 0)`` padding so the length never changes.

    Reward for a step is the sum of:
        * ``platform_reward`` if the ball bounced off the paddle,
        * ``reward_size`` if a brick was destroyed, plus ``10 * reward_size``
          if it was the last one (the episode then terminates),
        * ``penalty_size`` if the ball reached the bottom edge (the episode
          then terminates).

    The environment itself never truncates an episode.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": FPS}

    def __init__(self, reward_size=1, penalty_size=-1, platform_reward=5):
        super().__init__()
        self.action_space = gym.spaces.Discrete(3)

        obs_size = 5 + BRICK_ROWS * BRICK_COLS * 2
        high = np.full(obs_size, SCREEN_WIDTH, dtype=np.float32)
        low = np.zeros(obs_size, dtype=np.float32)
        # Entries 3 and 4 are the ball velocity, which is negative whenever the
        # ball moves left or up; a lower bound of 0 would put ordinary
        # observations outside the declared space.
        low[3:5] = -SCREEN_WIDTH
        self.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

        self.reward_size = reward_size
        self.penalty_size = penalty_size
        self.platform_reward = platform_reward
        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new game: fresh paddle, served ball and full brick wall.

        Args:
            seed: Optional seed. It seeds ``Env.np_random`` and the global
                ``random`` / ``numpy.random`` generators.
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(observation, info)`` with an empty ``info`` dict.
        """
        # Seeds Env.np_random as the Gymnasium API expects.
        super().reset(seed=seed)
        if seed is not None:
            # Ball serves use the stdlib `random` module, so seed it too.
            random.seed(seed)
            np.random.seed(seed)

        self.paddle = Paddle()
        self.ball = Ball()
        self.bricks = [
            Brick(x, y)
            for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
            for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)
        ]
        self.done = False
        self.score = 0
        return self._get_state(), {}

    def step(self, action):
        """Advance the game by one frame.

        Args:
            action: 0 (stay), 1 (left) or 2 (right); any other value leaves
                the paddle in place.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is always False and ``info`` is empty.
        """
        if action == 1:
            self.paddle.move(-1)
        elif action == 2:
            self.paddle.move(1)
        else:
            self.paddle.move(0)

        self.ball.move()
        reward = 0

        # Only bounce while the ball is travelling down. Without this guard, a
        # ball that overlaps the paddle on consecutive frames (e.g. the paddle
        # slides into it from the side) flips direction every frame, gets
        # stuck, and collects the paddle bonus over and over.
        if self.ball.velocity[1] > 0 and self.ball.rect.colliderect(self.paddle.rect):
            self.ball.velocity[1] = -self.ball.velocity[1]
            self.score += self.platform_reward
            # Feed the paddle bonus into the returned reward as well; otherwise
            # the platform_reward parameter never influences learning.
            reward += self.platform_reward

        # Index of the first brick the ball overlaps, or -1. At most one brick
        # is destroyed per frame.
        hit = self.ball.rect.collidelist([brick.rect for brick in self.bricks])
        if hit != -1:
            del self.bricks[hit]
            self.ball.velocity[1] = -self.ball.velocity[1]
            self.score += 1
            reward += self.reward_size
            if not self.bricks:
                # Clearing the wall wins the game.
                reward += self.reward_size * 10
                self.done = True
        elif self.ball.rect.bottom >= SCREEN_HEIGHT:
            # Ball got past the paddle.
            self.done = True
            reward += self.penalty_size

        return self._get_state(), reward, self.done, False, {}

    def _get_state(self):
        """Build the fixed-length observation vector described on the class."""
        state = [
            self.paddle.rect.x,
            self.ball.rect.x,
            self.ball.rect.y,
            self.ball.velocity[0],
            self.ball.velocity[1],
        ]
        for brick in self.bricks:
            state.extend((brick.rect.x, brick.rect.y))
        # Pad destroyed bricks with (0, 0) so the vector length stays constant.
        state.extend([0, 0] * (BRICK_ROWS * BRICK_COLS - len(self.bricks)))
        return np.array(state, dtype=np.float32)

    def render(self, mode='rgb_array'):
        """Draw the current frame.

        Args:
            mode: ``'rgb_array'`` returns the frame from
                ``pygame.surfarray.array3d`` (x-major layout, i.e. shape
                ``(SCREEN_WIDTH, SCREEN_HEIGHT, 3)``). ``'human'`` blits the
                frame to the pygame display, creating a window if none
                exists, and returns None.
        """
        surface = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
        surface.fill(BLACK)
        pygame.draw.rect(surface, WHITE, self.paddle.rect)
        pygame.draw.ellipse(surface, WHITE, self.ball.rect)
        for brick in self.bricks:
            pygame.draw.rect(surface, RED, brick.rect)

        if mode == 'rgb_array':
            return pygame.surfarray.array3d(surface)
        if mode == 'human':
            screen = pygame.display.get_surface()
            if screen is None:
                # get_surface() is None until a display mode has been set.
                screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
            screen.blit(surface, (0, 0))
            pygame.display.flip()
        return None

    def close(self):
        """Shut pygame down.

        NOTE: pygame.quit() is process-wide, so this also affects any other
        ArkanoidEnv instance that is still alive.
        """
        pygame.quit()
|
|
|
|
|
def _record_episode(model, env, video_writer, max_steps):
    """Play one greedy episode of ``model`` on ``env``, writing every frame.

    The episode stops on termination, truncation, or after ``max_steps``
    steps, whichever comes first.
    """
    obs, _ = env.reset()
    for _ in range(max_steps):
        action, _ = model.predict(obs, deterministic=True)
        obs, _, terminated, truncated, _ = env.step(action)

        frame = env.render(mode='rgb_array')
        # pygame's surfarray is x-major: (width, height, 3). Swap the first two
        # axes to get a row-major (height, width, 3) image. np.rot90 would also
        # flip the picture upside down.
        frame = np.ascontiguousarray(frame.transpose(1, 0, 2))
        video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

        if terminated or truncated:
            break


def train_and_play(reward_size, penalty_size, platform_reward, iterations, max_eval_steps=1000):
    """Train a DQN agent on ArkanoidEnv and record its progress as an MP4.

    Training runs in chunks of at most 1000 timesteps. After each chunk, the
    current greedy policy plays one episode on a separate evaluation env, and
    that episode's frames are appended to the video.

    Args:
        reward_size: Forwarded to ``ArkanoidEnv(reward_size=...)``.
        penalty_size: Forwarded to ``ArkanoidEnv(penalty_size=...)``.
        platform_reward: Forwarded to ``ArkanoidEnv(platform_reward=...)``.
        iterations: Total number of training timesteps. UI widgets may pass a
            float, so it is truncated to ``int``.
        max_eval_steps: Maximum length of each recorded evaluation episode.
            This keeps a policy that never loses the ball from hanging the
            app or producing an unbounded video.

    Returns:
        Path of the written MP4 file.
    """
    iterations = int(iterations)
    env_kwargs = dict(
        reward_size=reward_size,
        penalty_size=penalty_size,
        platform_reward=platform_reward,
    )
    env = ArkanoidEnv(**env_kwargs)
    # Record on a separate env. The training env's state persists across
    # learn() calls (reset_num_timesteps=False), so stepping it here would
    # corrupt the next transition the agent stores.
    eval_env = ArkanoidEnv(**env_kwargs)
    model = DQN('MlpPolicy', env, verbose=1)
    timesteps_per_update = min(1000, iterations)

    video_path = "arkanoid_training.mp4"
    # Frames are streamed straight to disk instead of buffered in memory.
    video_writer = cv2.VideoWriter(
        video_path, cv2.VideoWriter_fourcc(*'mp4v'), FPS, (SCREEN_WIDTH, SCREEN_HEIGHT)
    )
    try:
        completed_iterations = 0
        while completed_iterations < iterations:
            steps = min(timesteps_per_update, iterations - completed_iterations)
            # Continue training instead of restarting the timestep counter,
            # which would reset the exploration schedule and the
            # learning_starts warm-up at every chunk.
            model.learn(total_timesteps=steps, reset_num_timesteps=False)
            completed_iterations += steps
            _record_episode(model, eval_env, video_writer, max_eval_steps)
    finally:
        video_writer.release()
        eval_env.close()
        env.close()

    return video_path
|
|
|
|
|
def main():
    """Build the Gradio UI around train_and_play and launch it."""
    controls = [
        gr.Number(label="Reward Size", value=1),
        gr.Number(label="Penalty Size", value=-1),
        gr.Number(label="Platform Reward", value=5),
        gr.Slider(label="Iterations", minimum=10, maximum=100000, step=10, value=10000),
    ]
    app = gr.Interface(fn=train_and_play, inputs=controls, outputs="video", live=False)
    app.launch()
|
|
|
# Start the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()
|
|