game / app.py
Ivan000's picture
Update app.py
a5cda12 verified
raw
history blame
7.7 kB
# app.py
# =============
# This is a complete app.py file for an Arkanoid game that a neural network will play and learn using reinforcement learning.
# The game is built using pygame, and the neural network is trained using stable-baselines3. Gradio is used for the interface.
import os
import numpy as np
import pygame
import random
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
import gradio as gr
import cv2
import imageio
# Constants
# All sizes are in pixels; paired dimensions are declared together.
SCREEN_WIDTH, SCREEN_HEIGHT = 640, 480
PADDLE_WIDTH, PADDLE_HEIGHT = 100, 10
BALL_RADIUS = 10
BRICK_WIDTH, BRICK_HEIGHT = 60, 20
# Nominal brick grid; also used to size the observation vector.
BRICK_ROWS, BRICK_COLS = 5, 10
# Target frame rate for rendering and for the exported video.
FPS = 60

# Colors (RGB triples)
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
RED = (255, 0, 0)
# Initialize Pygame
# On a Linux host without an X11/Wayland display (containers, CI, hosted
# demos) SDL cannot open a window and `set_mode` raises. In that case fall
# back to SDL's off-screen "dummy" driver so the surface can still be drawn
# to and read back. `setdefault` keeps any driver the user chose explicitly.
if (
    os.name == "posix"
    and os.uname().sysname == "Linux"
    and not (os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"))
):
    os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
pygame.init()
# Module-level display surface shared by the environment's render() and by
# the frame-capture code.
screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
pygame.display.set_caption("Arkanoid")
# Game classes
class Paddle:
    """Player-controlled bar that sits 10 px above the bottom of the screen."""

    def __init__(self):
        # Start horizontally centred.
        left = SCREEN_WIDTH // 2 - PADDLE_WIDTH // 2
        top = SCREEN_HEIGHT - PADDLE_HEIGHT - 10
        self.rect = pygame.Rect(left, top, PADDLE_WIDTH, PADDLE_HEIGHT)

    def move(self, direction):
        """Shift the paddle 10 px left (-1) or right (1); other values do not move it.

        The paddle is always kept inside the screen rectangle.
        """
        if direction == 1:
            self.rect.move_ip(10, 0)
        elif direction == -1:
            self.rect.move_ip(-10, 0)
        screen_bounds = pygame.Rect(0, 0, SCREEN_WIDTH, SCREEN_HEIGHT)
        self.rect.clamp_ip(screen_bounds)
class Ball:
    """Square-bounded ball that bounces off the side walls and the ceiling."""

    def __init__(self):
        # A new ball starts in exactly the state produced by reset().
        self.reset()

    def move(self):
        """Advance one frame and reflect off the left, right and top edges."""
        dx, dy = self.velocity
        self.rect.move_ip(dx, dy)
        touches_side = self.rect.left <= 0 or self.rect.right >= SCREEN_WIDTH
        if touches_side:
            self.velocity[0] = -self.velocity[0]
        touches_ceiling = self.rect.top <= 0
        if touches_ceiling:
            self.velocity[1] = -self.velocity[1]

    def reset(self):
        """Put the ball at screen centre, heading upward and randomly left or right."""
        diameter = BALL_RADIUS * 2
        self.rect = pygame.Rect(
            SCREEN_WIDTH // 2 - BALL_RADIUS,
            SCREEN_HEIGHT // 2 - BALL_RADIUS,
            diameter,
            diameter,
        )
        horizontal_speed = random.choice([-5, 5])
        self.velocity = [horizontal_speed, -5]
class Brick:
    """A single destructible brick, represented only by its rectangle."""

    def __init__(self, x, y):
        # (x, y) is the top-left corner; every brick has the same size.
        top_left = (x, y)
        size = (BRICK_WIDTH, BRICK_HEIGHT)
        self.rect = pygame.Rect(top_left, size)
class ArkanoidEnv(gym.Env):
    """Gymnasium environment for a minimal Arkanoid game.

    Actions (``Discrete(3)``): 0 = stay, 1 = move left, 2 = move right.

    Observation: a flat ``float32`` vector of length
    ``5 + BRICK_ROWS * BRICK_COLS * 2`` containing the paddle x, the ball
    x and y, the ball velocity x and y, then one ``(x, y)`` pair per
    remaining brick, zero-padded at the end.

    Reward: +1 for every brick destroyed in the step, an extra +10 when the
    last brick is cleared, and -1 when the ball reaches the bottom edge.
    """

    def __init__(self):
        super().__init__()
        self.action_space = gym.spaces.Discrete(3)  # 0: stay, 1: move left, 2: move right
        # NOTE(review): ball velocity components can be negative, i.e. below
        # `low=0`. The bounds are left as they were so existing saved models
        # keep matching this space; consider widening `low` deliberately.
        self.observation_space = gym.spaces.Box(
            low=0,
            high=SCREEN_WIDTH,
            shape=(5 + BRICK_ROWS * BRICK_COLS * 2,),
            dtype=np.float32,
        )
        self.seed_value = None
        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed. When given it seeds gymnasium's RNG as well
                as the ``random`` and ``numpy.random`` global generators.
            options: Accepted for API compatibility; not used.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        # Gymnasium expects subclasses to forward the seed to the base class.
        super().reset(seed=seed)
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            self.seed_value = seed
        self.paddle = Paddle()
        self.ball = Ball()
        # With the current constants this yields 9 columns per row; the
        # observation padding in _get_state() covers the unused slots.
        self.bricks = [
            Brick(x, y)
            for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
            for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)
        ]
        self.done = False
        self.score = 0
        return self._get_state(), {}

    def step(self, action):
        """Advance the game by one frame.

        Args:
            action: 0 (stay), 1 (move left) or 2 (move right).

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is always ``False`` and ``info`` is always empty.
        """
        if action == 0:
            self.paddle.move(0)
        elif action == 1:
            self.paddle.move(-1)
        elif action == 2:
            self.paddle.move(1)

        self.ball.move()

        # Bounce only while the ball is travelling downward. Flipping on
        # every overlapping frame would reverse the ball again while it is
        # still inside the paddle and trap it there.
        if self.ball.velocity[1] > 0 and self.ball.rect.colliderect(self.paddle.rect):
            self.ball.velocity[1] = -self.ball.velocity[1]

        reward = 0
        truncated = False

        hit_bricks = [
            brick for brick in self.bricks
            if self.ball.rect.colliderect(brick.rect)
        ]
        if hit_bricks:
            for brick in hit_bricks:
                self.bricks.remove(brick)
            # Reflect once per frame: flipping once per brick would cancel
            # out whenever an even number of bricks is hit together.
            self.ball.velocity[1] = -self.ball.velocity[1]
            self.score += len(hit_bricks)
            reward += len(hit_bricks)
            if not self.bricks:
                reward += 10  # Bonus reward for breaking all bricks
                self.done = True
                return self._get_state(), reward, self.done, truncated, {}

        # The brick reward accumulated above is kept; only losing the ball
        # subtracts from it.
        if self.ball.rect.bottom >= SCREEN_HEIGHT:
            self.done = True
            reward -= 1
        return self._get_state(), reward, self.done, truncated, {}

    def _get_state(self):
        """Build the observation vector described in the class docstring."""
        state = [
            self.paddle.rect.x,
            self.ball.rect.x,
            self.ball.rect.y,
            self.ball.velocity[0],
            self.ball.velocity[1],
        ]
        for brick in self.bricks:
            state.extend([brick.rect.x, brick.rect.y])
        # Padding for missing bricks keeps the vector length constant.
        state.extend([0, 0] * (BRICK_ROWS * BRICK_COLS - len(self.bricks)))
        return np.array(state, dtype=np.float32)

    def render(self, mode='human'):
        """Draw the current frame onto the module-level ``screen`` surface.

        Args:
            mode: Accepted for call compatibility; not used.
        """
        screen.fill(BLACK)
        pygame.draw.rect(screen, WHITE, self.paddle.rect)
        pygame.draw.ellipse(screen, WHITE, self.ball.rect)
        for brick in self.bricks:
            pygame.draw.rect(screen, RED, brick.rect)
        pygame.display.flip()
        # Throttle rendering to roughly FPS frames per second.
        pygame.time.Clock().tick(FPS)

    def close(self):
        """Shut pygame down; the shared display surface is unusable afterwards."""
        pygame.quit()
# Training function
def train_model(env, total_timesteps=10000):
    """Train a new DQN agent on ``env``, save it as ``arkanoid_model`` and return it."""
    agent = DQN(policy='MlpPolicy', env=env, verbose=1)
    agent.learn(total_timesteps=total_timesteps)
    agent.save("arkanoid_model")
    return agent
# Evaluation function
def evaluate_model(model, env):
    """Return the mean episode reward of ``model`` over 10 evaluation episodes."""
    # evaluate_policy returns (mean_reward, std_reward); only the mean is used.
    reward_stats = evaluate_policy(model, env, n_eval_episodes=10, render=False)
    return reward_stats[0]
# Gradio interface
def play_game(max_steps=10000):
    """Play one episode with the saved ``arkanoid_model`` and collect its frames.

    Args:
        max_steps: Upper bound on environment steps, so that a policy which
            never loses the ball cannot keep this loop running forever.

    Returns:
        A list of ``gr.Image`` components, one per rendered step, in order.
    """
    env = ArkanoidEnv()
    model = DQN.load("arkanoid_model")
    obs = env.reset()[0]
    frames = []
    for _ in range(max_steps):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
        env.render()
        # Capture each frame in memory. Saving every frame to the same file
        # made all returned images show whatever was written last.
        # surfarray is indexed (x, y, channel); images are (row, col, channel).
        pixels = pygame.surfarray.array3d(screen).transpose(1, 0, 2)
        frames.append(gr.Image(value=np.ascontiguousarray(pixels)))
        if done or truncated:
            break
    return frames
# Real-time training function
def _open_video_writer(video_path):
    """Open an mp4 writer sized to the game screen, raising if it cannot be opened."""
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(video_path, fourcc, FPS, (SCREEN_WIDTH, SCREEN_HEIGHT))
    if not writer.isOpened():
        raise RuntimeError(f"Could not open video writer for {video_path!r}")
    return writer


def _capture_bgr_frame():
    """Return the current display surface as a (height, width, 3) BGR image."""
    rgb = pygame.surfarray.array3d(pygame.display.get_surface())
    # surfarray is indexed (x, y, channel) while OpenCV expects
    # (row, col, channel); frames of the wrong shape are not encoded.
    rgb = np.ascontiguousarray(rgb.transpose(1, 0, 2))
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


def train_and_play(total_timesteps=10000, timesteps_per_update=1000, max_episode_steps=5000):
    """Train a DQN agent in chunks and record one demo episode after each chunk.

    This is a generator. After every training chunk it yields the path of a
    short video showing the episode just played; once training is finished
    it yields the path of a video containing all recorded episodes.

    Args:
        total_timesteps: Total number of training steps.
        timesteps_per_update: Training steps between two recorded episodes.
        max_episode_steps: Cap on the length of a recorded episode, so a
            policy that never loses the ball cannot record forever.

    Yields:
        ``"arkanoid_training_latest.mp4"`` after each chunk, then
        ``"arkanoid_training.mp4"`` at the end.

    Raises:
        RuntimeError: If a video file cannot be opened for writing.
    """
    env = ArkanoidEnv()
    model = DQN('MlpPolicy', env, verbose=1)
    video_path = "arkanoid_training.mp4"
    latest_path = "arkanoid_training_latest.mp4"

    # Frames are streamed straight to disk instead of being accumulated in
    # memory (each frame is close to 1 MB uncompressed).
    full_writer = _open_video_writer(video_path)
    try:
        for start in range(0, total_timesteps, timesteps_per_update):
            chunk = min(timesteps_per_update, total_timesteps - start)
            # Keep the step counter running across chunks; resetting it
            # restarts the exploration schedule and warm-up every time.
            model.learn(total_timesteps=chunk, reset_num_timesteps=(start == 0))

            latest_writer = _open_video_writer(latest_path)
            try:
                obs = env.reset()[0]
                for _ in range(max_episode_steps):
                    action, _states = model.predict(obs, deterministic=True)
                    obs, reward, done, truncated, info = env.step(action)
                    env.render()
                    frame = _capture_bgr_frame()
                    full_writer.write(frame)
                    latest_writer.write(frame)
                    if done or truncated:
                        break
            finally:
                # A video is only playable once its writer is released.
                latest_writer.release()
            yield latest_path
    finally:
        # Also runs when the consumer stops iterating early.
        full_writer.release()

    # A generator's `return` value is discarded by consumers, so the final
    # video is yielded instead.
    yield video_path
# Main function
def main():
    """Create the Gradio interface around ``train_and_play`` and launch it."""
    # Gradio interface
    interface_options = {
        "fn": train_and_play,
        "inputs": None,
        "outputs": "video",
        "live": True,
    }
    gr.Interface(**interface_options).launch()


if __name__ == "__main__":
    main()
# Dependencies
# =============
# The following dependencies are required to run this app:
# - pygame
# - stable-baselines3
# - torch
# - gradio
# - gymnasium
# - opencv-python
# - imageio
#
# You can install these dependencies using pip:
# pip install pygame stable-baselines3 torch gradio gymnasium opencv-python imageio