|
import torch |
|
import streamlit as st |
|
import pygame |
|
import os |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
import pandas as pd |
|
import numpy as np |
|
from collections import deque |
|
import random |
|
from typing import List |
|
from argparse import Action |
|
import random |
|
import sys |
|
from sqlalchemy import asc |
|
import math |
|
import time |
|
from tqdm import tqdm |
|
from datetime import datetime |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# --- Display geometry (pixels) ---------------------------------------------
SCREEN_WIDTH = 1100
SCREEN_HEIGHT = 600

# --- Scrolling track --------------------------------------------------------
INIT_GAME_SPEED = 14   # game speed at the start of every run
X_POS_BG_INIT = 0      # initial x offset of the background track
Y_POS_BG = 380         # y position at which the track image is drawn

# --- Replay memory ----------------------------------------------------------
INIT_REPLAY_MEM_SIZE = 5000      # capacity of the buffer that is filled first
REPLAY_MEMORY_SIZE = 45000       # capacity of the buffer used once the first is full
MIN_REPLAY_MEMORY_SIZE = 1000    # transitions required before training starts

# --- Learning ---------------------------------------------------------------
MODEL_NAME = "DINO"
MINIBATCH_SIZE = 64        # transitions sampled per training step
DISCOUNT = 0.95            # weight applied to the best future Q-value
UPDATE_TARGET_THRESH = 5   # terminal training steps between target-network syncs

# --- Exploration schedule ---------------------------------------------------
EPSILON_INIT = 0.25    # starting probability of taking a random action
EPSILON_DECAY = 0.75   # multiplicative decay applied after each episode
MIN_EPSILON = 0.05     # floor for the exploration probability
NUM_EPISODES = 100     # default number of training episodes
|
|
|
def _load_image(folder, filename):
    """Load a single sprite image located at ``<folder>/<filename>``."""
    return pygame.image.load(os.path.join(folder, filename))


# Every sprite is loaded exactly once at import time. The previous version
# repeated this whole section a second time, reading each file twice and
# rebinding the same names to equivalent surfaces.

# Dinosaur sprites: two-frame run and duck animations plus a single jump frame.
RUNNING = [
    _load_image("Assets/Dino", "DinoRun1.png"),
    _load_image("Assets/Dino", "DinoRun2.png"),
]
DUCKING = [
    _load_image("Assets/Dino", "DinoDuck1.png"),
    _load_image("Assets/Dino", "DinoDuck2.png"),
]
JUMPING = _load_image("Assets/Dino", "DinoJump.png")

# Obstacle sprites: three variants per cactus size, two animation frames for the bird.
SMALL_CACTUS = [
    _load_image("Assets/Cactus", "SmallCactus1.png"),
    _load_image("Assets/Cactus", "SmallCactus2.png"),
    _load_image("Assets/Cactus", "SmallCactus3.png"),
]
LARGE_CACTUS = [
    _load_image("Assets/Cactus", "LargeCactus1.png"),
    _load_image("Assets/Cactus", "LargeCactus2.png"),
    _load_image("Assets/Cactus", "LargeCactus3.png"),
]
BIRD = [
    _load_image("Assets/Bird", "Bird1.png"),
    _load_image("Assets/Bird", "Bird2.png"),
]

# Scenery.
CLOUD = _load_image("Assets/Other", "Cloud.png")
BACKGROUND = _load_image("Assets/Other", "Track.png")
|
|
|
class NeuralNetwork(nn.Module):
    """Two-layer fully connected network mapping 7 inputs to 3 outputs.

    The layer attribute names (``fc1``, ``fc2``) are the keys of the saved
    state dict and must stay as they are for checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        # 7 input features -> 4 hidden units -> 3 output values.
        self.fc1 = nn.Linear(7, 4)
        self.fc2 = nn.Linear(4, 3)

    def forward(self, x):
        """Return the raw (un-activated) outputs of the second layer."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden)
|
|
|
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
class DQNAgent:
    """Deep Q-learning agent with an online network, a target network and replay memory.

    Transitions are 5-tuples ``(state, action, reward, next_state, done)``.
    ``done`` must be truthy for terminal transitions: for those the target is
    the bare reward, otherwise it is ``reward + DISCOUNT * max_a Q_target(next_state, a)``.
    """

    def __init__(self):
        self.model = NeuralNetwork().to(device)
        self.target_model = NeuralNetwork().to(device)
        # Start with both networks holding identical weights.
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.loss_function = nn.MSELoss()

        # Two bounded buffers: the first is filled to capacity and then left
        # untouched; every later transition goes into the second one.
        self.init_replay_memory = deque(maxlen=INIT_REPLAY_MEM_SIZE)
        self.late_replay_memory = deque(maxlen=REPLAY_MEMORY_SIZE)
        # Number of terminal training steps since the last target sync.
        self.target_update_counter = 0

    def update_replay_memory(self, transition):
        """Store ``transition`` in the initial buffer until it is full, then in the late one."""
        if len(self.init_replay_memory) < INIT_REPLAY_MEM_SIZE:
            self.init_replay_memory.append(transition)
        else:
            self.late_replay_memory.append(transition)

    def get_qs(self, state):
        """Return the online network's outputs for ``state`` as a NumPy array.

        ``state`` may be a sequence of floats or a tensor.
        """
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=device)
        with torch.no_grad():
            return self.model(state_tensor).cpu().numpy()

    def train(self, terminal_state, step):
        """Run one optimisation step on a random minibatch of stored transitions.

        Does nothing until the initial buffer holds at least
        ``MIN_REPLAY_MEMORY_SIZE`` transitions.

        Args:
            terminal_state: truthy when the step that triggered this call ended
                an episode; only such calls advance the target-sync counter.
            step: step index within the episode (currently unused).
        """
        if len(self.init_replay_memory) < MIN_REPLAY_MEMORY_SIZE:
            return

        total_mem = list(self.init_replay_memory)
        total_mem.extend(self.late_replay_memory)
        if len(total_mem) < MINIBATCH_SIZE:
            # random.sample raises ValueError when asked for more items than exist.
            return
        minibatch = random.sample(total_mem, MINIBATCH_SIZE)

        states = torch.tensor(
            np.array([t[0] for t in minibatch], dtype=np.float32), device=device
        )
        next_states = torch.tensor(
            np.array([t[3] for t in minibatch], dtype=np.float32), device=device
        )
        actions = torch.tensor(
            [int(t[1]) for t in minibatch], dtype=torch.long, device=device
        )
        rewards = torch.tensor(
            [float(t[2]) for t in minibatch], dtype=torch.float32, device=device
        )
        dones = torch.tensor(
            [bool(t[4]) for t in minibatch], dtype=torch.bool, device=device
        )

        # Targets are constants for the optimiser, so build them without
        # recording an autograd graph.
        with torch.no_grad():
            max_future_q = self.target_model(next_states).max(dim=1).values
            new_q = torch.where(dones, rewards, rewards + DISCOUNT * max_future_q)
            # Start from the current predictions so that only the taken
            # action's entry contributes to the loss.
            targets = self.model(states)
            rows = torch.arange(len(minibatch), device=device)
            targets[rows, actions] = new_q

        self.optimizer.zero_grad()
        output = self.model(states)
        loss = self.loss_function(output, targets)
        loss.backward()
        self.optimizer.step()

        if terminal_state:
            self.target_update_counter += 1
            if self.target_update_counter > UPDATE_TARGET_THRESH:
                self.target_model.load_state_dict(self.model.state_dict())
                self.target_update_counter = 0
|
|
|
|
|
class Obstacle:
    """Base class for obstacles that scroll from the right edge towards the dino.

    Args:
        image: list of surfaces this obstacle can be drawn with.
        type: index into ``image`` selecting the surface used for the
            collision rectangle and by the default ``draw``.
    """

    def __init__(self, image: "List[pygame.Surface]", type: int) -> None:
        self.image = image
        self.type = type
        self.rect = self.image[self.type].get_rect()
        # Spawn at the right edge of the screen.
        self.rect.x = SCREEN_WIDTH

    def update(self, obstacles: list, game_speed: int):
        """Move left by ``game_speed`` and remove this obstacle once it is fully off-screen.

        The obstacle removes *itself* from ``obstacles``; a plain ``pop()``
        would discard whichever obstacle happens to be last in the list.
        """
        self.rect.x -= game_speed
        if self.rect.x < -self.rect.width and self in obstacles:
            obstacles.remove(self)

    def draw(self, SCREEN: "pygame.Surface"):
        """Blit the selected surface at the obstacle's current rectangle."""
        SCREEN.blit(self.image[self.type], self.rect)
|
|
|
class Dino(DQNAgent):
    """The player character: sprite state and movement on top of the DQN agent.

    Exactly one of ``dino_run``, ``dino_duck`` and ``dino_jump`` is set at a
    time; ``step_index`` drives the two-frame run/duck animations.
    """

    X_POS = 80
    Y_POS = 310
    Y_DUCK_POS = 340
    JUMP_VEL = 8.5

    def __init__(self) -> None:
        self.duck_img, self.run_img, self.jump_img = DUCKING, RUNNING, JUMPING

        # The dino starts out running.
        self.dino_duck, self.dino_run, self.dino_jump = False, True, False

        self.step_index = 0
        self.jump_vel = self.JUMP_VEL
        self.image = self.run_img[0]
        self.dino_rect = self.image.get_rect()
        self.dino_rect.x, self.dino_rect.y = self.X_POS, self.Y_POS

        self.score = 0

        # Build the networks, optimizer and replay buffers last.
        super().__init__()

    def _animate(self) -> None:
        """Advance whichever pose is active and wrap the animation counter."""
        if self.dino_duck:
            self.duck()
        if self.dino_jump:
            self.jump()
        if self.dino_run:
            self.run()
        if self.step_index >= 20:
            self.step_index = 0

    def _set_pose(self, *, jump: bool, duck: bool, run: bool) -> None:
        """Set the three mutually exclusive pose flags in one go."""
        self.dino_jump, self.dino_duck, self.dino_run = jump, duck, run

    def _steer(self, wants_jump, wants_duck) -> None:
        """Switch pose for the next frame; requests are ignored while airborne."""
        if self.dino_jump:
            return
        if wants_jump:
            self._set_pose(jump=True, duck=False, run=False)
        elif wants_duck:
            self._set_pose(jump=False, duck=True, run=False)
        else:
            self._set_pose(jump=False, duck=False, run=True)

    def update(self, move: pygame.key.ScancodeWrapper):
        """Advance one frame using keyboard state (UP jumps, DOWN ducks)."""
        self._animate()
        self._steer(move[pygame.K_UP], move[pygame.K_DOWN])

    def update_auto(self, move):
        """Advance one frame using an action id: 0 jumps, 1 ducks, anything else runs."""
        self._animate()
        self._steer(move == 0, move == 1)

    def _ground_pose(self, frames, y_pos) -> None:
        """Show the current frame of a two-frame ground animation at ``y_pos``."""
        self.image = frames[self.step_index // 10]
        rect = self.image.get_rect()
        rect.x = self.X_POS
        rect.y = y_pos
        self.dino_rect = rect
        self.step_index += 1

    def duck(self) -> None:
        """Display the ducking animation at the ducking height."""
        self._ground_pose(self.duck_img, self.Y_DUCK_POS)

    def run(self) -> None:
        """Display the running animation at the standing height."""
        self._ground_pose(self.run_img, self.Y_POS)

    def jump(self) -> None:
        """Advance the jump arc; switch back to running once the arc completes."""
        self.image = self.jump_img
        if self.dino_jump:
            # Velocity starts positive (moving up) and decreases every frame.
            self.dino_rect.y -= self.jump_vel * 3
            self.jump_vel -= 0.6

        if self.jump_vel < -self.JUMP_VEL:
            self.dino_jump = False
            self.dino_run = True
            self.jump_vel = self.JUMP_VEL

    def draw(self, SCREEN: pygame.Surface):
        """Blit the current sprite at the dino's position."""
        SCREEN.blit(self.image, (self.dino_rect.x, self.dino_rect.y))
|
|
|
class LargeCactus(Obstacle):
    """Large cactus obstacle using one of three randomly chosen sprites."""

    def __init__(self, image: List[pygame.Surface]) -> None:
        variant = random.randint(0, 2)
        super().__init__(image, variant)
        # Vertical position of the sprite's top edge.
        self.rect.y = 300
|
|
|
|
|
class SmallCactus(Obstacle):
    """Small cactus obstacle using one of three randomly chosen sprites."""

    def __init__(self, image: List[pygame.Surface]) -> None:
        variant = random.randint(0, 2)
        super().__init__(image, variant)
        # Vertical position of the sprite's top edge.
        self.rect.y = 325
|
|
|
class Bird(Obstacle):
    """Flying obstacle animated with two alternating frames."""

    def __init__(self, image: List[pygame.Surface]) -> None:
        # The first frame defines the collision rectangle.
        super().__init__(image, 0)
        self.rect.y = SCREEN_HEIGHT - 340
        # Animation counter; the frame shown is ``index // 10``.
        self.index = 0

    def draw(self, SCREEN: pygame.Surface):
        """Blit the current animation frame, then advance the counter."""
        if self.index >= 19:
            self.index = 0

        frame = self.image[self.index // 10]
        SCREEN.blit(frame, self.rect)
        self.index += 1
|
|
|
class Cloud:
    """Decorative cloud that drifts left and respawns beyond the right edge."""

    def __init__(self) -> None:
        self._respawn()
        self.image = CLOUD
        self.width = self.image.get_width()

    def _respawn(self) -> None:
        """Place the cloud at a random spot off the right side of the screen."""
        self.x = SCREEN_WIDTH + random.randint(800, 1000)
        self.y = random.randint(50, 100)

    def update(self, game_speed: int):
        """Move left by ``game_speed``; respawn once fully past the left edge."""
        self.x -= game_speed
        if self.x < -self.width:
            self._respawn()

    def draw(self, SCREEN: "pygame.Surface"):
        """Blit the cloud at its current position."""
        SCREEN.blit(self.image, (self.x, self.y))
|
|
|
class Game:
    """Dino runner game loop with manual play and DQN self-play training.

    Args:
        epsilon: initial probability of taking a random action in ``play_auto``.
        load_model: when true and ``model_path`` is given, load network weights
            from that file.
        model_path: path of a state dict saved with ``torch.save``.
    """

    def __init__(self, epsilon, load_model=False, model_path=None):
        # Use SDL's dummy video driver so no real window is required.
        os.environ["SDL_VIDEODRIVER"] = "dummy"
        pygame.init()
        self.SCREEN = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))

        self.obstacles = []
        self.run = True
        self.clock = pygame.time.Clock()
        self.cloud = Cloud()
        self.game_speed = INIT_GAME_SPEED
        self.font = pygame.font.Font("freesansbold.ttf", 20)
        self.dino = Dino()

        if load_model and model_path:
            self.dino.model.load_state_dict(torch.load(model_path, map_location=device))
            # Keep the target network in step with the loaded weights; otherwise
            # it would keep its random initialisation until the first sync.
            self.dino.target_model.load_state_dict(self.dino.model.state_dict())

        self.x_pos_bg = X_POS_BG_INIT
        self.points = 0
        self.epsilon = epsilon
        self.ep_rewards = [-200]
        self.high_score = 0
        self.best_score = 0

    def reset(self):
        """Start a fresh run while keeping everything the agent has learned."""
        self.game_speed = INIT_GAME_SPEED
        old_dino = self.dino
        self.dino = Dino()

        # Carry over replay memory and the target-sync counter.
        self.dino.init_replay_memory = old_dino.init_replay_memory
        self.dino.late_replay_memory = old_dino.late_replay_memory
        self.dino.target_update_counter = old_dino.target_update_counter

        # Reuse the trained networks together with their optimizer. Copying
        # only the weights into fresh networks would silently discard the
        # optimizer's accumulated state on every episode.
        self.dino.model = old_dino.model
        self.dino.target_model = old_dino.target_model
        self.dino.optimizer = old_dino.optimizer

        self.x_pos_bg = X_POS_BG_INIT
        self.points = 0
        self.SCREEN = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        self.clock = pygame.time.Clock()

    def get_dist(self, pos_a: tuple, pos_b: tuple):
        """Return the Euclidean distance between two ``(x, y)`` points."""
        dx = pos_a[0] - pos_b[0]
        dy = pos_a[1] - pos_b[1]
        return math.sqrt(dx**2 + dy**2)

    def update_background(self):
        """Draw the scrolling track, advance its offset and return the new offset."""
        image_width = BACKGROUND.get_width()

        # Two copies side by side so the track appears continuous.
        self.SCREEN.blit(BACKGROUND, (self.x_pos_bg, Y_POS_BG))
        self.SCREEN.blit(BACKGROUND, (self.x_pos_bg + image_width, Y_POS_BG))

        if self.x_pos_bg <= -image_width:
            self.x_pos_bg = 0

        self.x_pos_bg -= self.game_speed
        return self.x_pos_bg

    def get_state(self):
        """Build the 7-element feature list fed to the network.

        Order: dino height term, distance to the first obstacle divided by the
        screen diagonal, obstacle top divided by ``Y_DUCK_POS``, game speed
        divided by 24, obstacle width relative to ``SMALL_CACTUS[2]``, cactus
        flag, bird flag. With no obstacle on screen the obstacle features are
        0 and the distance is measured to a point just past the right edge.
        """
        dino_rect = self.dino.dino_rect
        # NOTE(review): operator precedence makes this (y / Y_DUCK_POS) + 10;
        # possibly y / (Y_DUCK_POS + 10) was intended. Left unchanged because
        # it defines the input features existing checkpoints were trained on.
        state = [dino_rect.y / self.dino.Y_DUCK_POS + 10]

        pos_a = (dino_rect.x, dino_rect.y)
        screen_diag = math.sqrt(SCREEN_HEIGHT**2 + SCREEN_WIDTH**2)
        bird = 0
        cactus = 0

        if not self.obstacles:
            dist = self.get_dist(pos_a, (SCREEN_WIDTH + 10, self.dino.Y_POS)) / screen_diag
            obs_height = 0
            obj_width = 0
        else:
            nearest = self.obstacles[0]
            dist = self.get_dist(pos_a, nearest.rect.midtop) / screen_diag
            obs_height = nearest.rect.midtop[1] / self.dino.Y_DUCK_POS
            obj_width = nearest.rect.width / SMALL_CACTUS[2].get_rect().width
            # isinstance avoids constructing throwaway obstacles just to
            # compare their classes.
            if isinstance(nearest, (SmallCactus, LargeCactus)):
                cactus = 1
            else:
                bird = 1

        state.extend([dist, obs_height, self.game_speed / 24, obj_width, cactus, bird])
        return state

    def update_score(self):
        """Add a point, speed the game up every 200 points and draw the score text."""
        self.points += 1
        if self.points % 200 == 0:
            self.game_speed += 1

        if self.points > self.high_score:
            self.high_score = self.points

        text = self.font.render(f"Points: {self.points} Highscore: {self.high_score}", True, (0, 0, 0))
        textRect = text.get_rect()
        textRect.center = (SCREEN_WIDTH - textRect.width // 2 - 10, 40)
        self.SCREEN.blit(text, textRect)

    def create_obstacle(self):
        """Possibly spawn an obstacle; most calls spawn nothing.

        A draw of 0 or 1 out of 0..50 spawns a small or large cactus; a draw
        of 2 spawns a bird, but only once more than 300 points were scored.
        """
        obstacle_prob = random.randint(0, 50)
        if obstacle_prob == 0:
            self.obstacles.append(SmallCactus(SMALL_CACTUS))
        elif obstacle_prob == 1:
            self.obstacles.append(LargeCactus(LARGE_CACTUS))
        elif obstacle_prob == 2 and self.points > 300:
            self.obstacles.append(Bird(BIRD))

    def update_game(self, moves, user_input=None):
        """Draw and advance dino, track, cloud and score, then cap the loop at 30 FPS.

        Args:
            moves: action id passed to ``Dino.update_auto`` when no keyboard
                state is supplied.
            user_input: pressed-key state; when given it takes precedence.
        """
        self.dino.draw(self.SCREEN)
        if user_input is not None:
            self.dino.update(user_input)
        else:
            self.dino.update_auto(moves)

        self.update_background()
        self.cloud.draw(self.SCREEN)
        self.cloud.update(self.game_speed)
        self.update_score()
        self.clock.tick(30)

    def play_manual(self):
        """Run the keyboard-controlled game until the dino hits an obstacle."""
        while self.run is True:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    sys.exit()

            self.SCREEN.fill((255, 255, 255))
            user_input = pygame.key.get_pressed()

            if not self.obstacles:
                self.create_obstacle()

            # Iterate over a copy: update() may remove obstacles from the list.
            for obstacle in list(self.obstacles):
                obstacle.draw(SCREEN=self.SCREEN)
                obstacle.update(self.obstacles, self.game_speed)
                if self.dino.dino_rect.colliderect(obstacle.rect):
                    self.dino.score = self.points
                    pygame.quit()
                    self.obstacles.clear()
                    print("Game over!")
                    return

            self.update_game(user_input=user_input, moves=2)
            pygame.display.update()

    def play_auto(self, episode_info, num_episodes=NUM_EPISODES, epsilon_decay=EPSILON_DECAY):
        """Train the agent by self-play.

        One episode lasts until the dino collides with an obstacle. The online
        model is saved whenever the running score beats the best score so far
        and every 50 episodes. When the method exits, the latest best-score
        checkpoint (if any) is copied to a timestamped file.

        Args:
            episode_info: object with a ``text(str)`` method (for example a
                Streamlit placeholder) used to report progress.
            num_episodes: number of episodes to play.
            epsilon_decay: factor applied to epsilon after every episode;
                epsilon never drops below ``MIN_EPSILON``.
        """
        try:
            # Make sure the checkpoint directories exist before the first save.
            os.makedirs('models/highscore', exist_ok=True)
            os.makedirs('models/episodes', exist_ok=True)

            points_label = 0
            for episode in tqdm(range(1, num_episodes + 1), ascii=True, unit='episodes'):
                episode_reward = 0
                step = 1
                current_state = self.get_state()
                self.run = True

                episode_info.text(f'Escenario: {episode}, Puntuación actual: {self.points}, Recompensa del episodio: {episode_reward}')

                while self.run is True:
                    for event in pygame.event.get():
                        if event.type == pygame.QUIT:
                            sys.exit()

                    self.SCREEN.fill((255, 255, 255))

                    if not self.obstacles:
                        self.create_obstacle()

                    # Epsilon-greedy action selection.
                    if np.random.random() > self.epsilon:
                        q_values = self.dino.get_qs(torch.Tensor(current_state))
                        action = np.argmax(q_values)
                    else:
                        # Random action weighted 10% jump, 30% duck, 60% run.
                        num = np.random.randint(0, 10)
                        if num == 0:
                            action = num
                        elif num <= 3:
                            action = 1
                        else:
                            action = 2

                    self.update_game(moves=action)

                    next_state = self.get_state()
                    reward = 0

                    # Iterate over a copy: update() may remove obstacles from the list.
                    for obstacle in list(self.obstacles):
                        obstacle.draw(SCREEN=self.SCREEN)
                        obstacle.update(self.obstacles, self.game_speed)
                        next_state = self.get_state()

                        # The obstacle is entirely behind the dino.
                        if self.dino.dino_rect.x > obstacle.rect.x + obstacle.rect.width:
                            reward = 3

                        # Penalise jumping while the obstacle is still far away.
                        if action == 0 and obstacle.rect.x > SCREEN_WIDTH // 2:
                            reward = -1

                        if self.dino.dino_rect.colliderect(obstacle.rect):
                            self.dino.score = self.points
                            self.obstacles.clear()
                            points_label = self.points
                            self.reset()
                            reward = -10
                            self.run = False
                            break

                    episode_reward += reward

                    # The fifth element of a transition is the *done* flag.
                    # Storing ``self.run`` here inverted it, so the agent
                    # bootstrapped only on terminal steps.
                    done = not self.run
                    self.dino.update_replay_memory((current_state, action, reward, next_state, done))
                    self.dino.train(done, step=step)

                    current_state = next_state
                    step += 1

                    if self.points > self.best_score:
                        self.best_score = self.points
                        self.best_model_filename = 'models/highscore/BestScore_model.pth'
                        torch.save(self.dino.model.state_dict(), self.best_model_filename)

                    pygame.display.update()

                self.ep_rewards.append(episode_reward)

                current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

                if episode % 50 == 0:
                    filename = f'models/episodes/{points_label}_Points,Episode_{episode}_Date_{current_time}_model.pth'
                    torch.save(self.dino.model.state_dict(), filename)

                # Decay exploration but never below the floor. The previous
                # logic let epsilon fall to 0 for one episode before snapping
                # back up to MIN_EPSILON.
                self.epsilon = max(MIN_EPSILON, self.epsilon * epsilon_decay)

        finally:
            # Keep a timestamped copy of the best checkpoint, even when
            # training is interrupted.
            if hasattr(self, 'best_model_filename'):
                current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
                final_model_filename = f'models/highscore/{self.best_score}_BestScore_Final_{current_time}_model.pth'
                import shutil
                shutil.copy(self.best_model_filename, final_model_filename)
                print(f"Modelo duplicado guardado como: {final_model_filename}")
|
|
|
|
|
def plot_rewards(ep_rewards):
    """Render the per-episode reward curve in the Streamlit page.

    Args:
        ep_rewards: sequence of total rewards, one entry per episode.
    """
    # Draw on an explicit figure rather than pyplot's global state, and close
    # it afterwards so figures do not pile up across script reruns.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(ep_rewards)
    ax.set_title("Recompensas por Episodio")
    ax.set_xlabel("Episodio")
    ax.set_ylabel("Recompensa")
    st.pyplot(fig)
    plt.close(fig)
|
|
|
|
|
|
|
def streamlit_ui():
    """Build the Streamlit page: settings sidebar, model picker and training trigger."""
    st.title('Juego del Dinosaurio con IA')

    with st.sidebar:
        st.header("Configuraciones")
        epsilon_init = st.slider("Epsilon Inicial", 0.025, 0.975, EPSILON_INIT)
        # NOTE(review): the next two values are collected but not passed on;
        # wiring them up requires Game.play_auto to accept them.
        epsilon_decay = st.slider("Epsilon Decay", 0.025, 0.975, EPSILON_DECAY)
        num_episodes = st.slider("Número de Episodios", 1, 500, NUM_EPISODES)

    model_directory = 'models/highscore/'
    # A fresh checkout has no models yet; create the directory instead of
    # letting os.listdir raise.
    os.makedirs(model_directory, exist_ok=True)
    model_files = sorted(os.listdir(model_directory))
    selected_model_file = st.selectbox('Elige un modelo para cargar', model_files)

    score_col, highscore_col = st.columns(2)
    with score_col:
        score = st.empty()
    with highscore_col:
        high_score = st.empty()

    episode_info = st.empty()

    if st.button('Iniciar Juego con IA'):
        if selected_model_file:
            model_path = os.path.join(model_directory, selected_model_file)
            game = Game(epsilon_init, load_model=True, model_path=model_path)
        else:
            # Nothing to load: start from freshly initialised weights.
            game = Game(epsilon_init)
        game.play_auto(episode_info)

        if len(game.ep_rewards) > 0:
            plot_rewards(game.ep_rewards)
|
|
|
|
|
# Build the page only when executed as a script (``streamlit run`` executes the
# file as ``__main__``), so importing this module does not start the UI.
if __name__ == "__main__":
    streamlit_ui()