from __future__ import absolute_import, division, print_function, unicode_literals

from car_dqn import CarRacingDQN
import os
import tensorflow as tf
import gym
import _thread
import re
import sys
import numpy as np

# Report whether TensorFlow can see a GPU (informational only; nothing is enforced).
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

load_checkpoint = True
checkpoint_path = "data/checkpoints/train24"
train_episodes = 15000
# Periodic-save interval in episodes (~100 saves over a full run). Kept an
# integer >= 1 so the modulo test in one_episode is well-defined for short runs.
save_freq_episodes = max(1, train_episodes // 100)
finished = False  # NOTE(review): not read anywhere in this script.

# Per-episode log lines go to "<checkpoint_path>.txt".
results_path = checkpoint_path + '.txt'
# Truncate (or create) the log at start-up; one_episode appends to it.
with open(results_path, "w"):
    pass

render = False
# Frame-skip n; the same value is also used as the frame-stack depth and the
# training frequency in model_config below.
frame_skip = 3

model_config = dict(
    min_epsilon=0.05,
    max_negative_rewards=8,
    min_experience_size=100,
    experience_capacity=150000,
    num_frame_stack=frame_skip,
    frame_skip=frame_skip,
    train_freq=frame_skip,
    batchsize=64,
    epsilon_decay_steps=100000,
    # Copy the prediction network into the target network every 1000 global steps.
    target_network_update_freq=1000,
    gamma=0.95,
    render=False,
)

dqn_scores = []   # score of every episode played in this run
eps_history = []  # epsilon reported for every episode
# Running 100-episode averages; seeded with 0 so an average must reach at
# least 0 before it can count as a new high score.
avg_score_all = [0]

env = gym.make('CarRacing-v0', verbose=False)
# Start from a clean default graph before the agent builds its graph.
tf.compat.v1.reset_default_graph()
dqn_agent = CarRacingDQN(env=env, **model_config)
dqn_agent.build_graph()
sess = tf.InteractiveSession()
dqn_agent.session = sess

# Keep up to 1000 checkpoints on disk.
saver = tf.train.Saver(max_to_keep=1000)

if load_checkpoint:
    # Playback of a saved model: train_episodes is capped at 150, rendering is
    # turned on, and one_episode never saves (it only saves when
    # load_checkpoint is False).
    train_episodes = 150
    save_freq_episodes = 0
    print("loading the latest checkpoint from %s" % checkpoint_path)
    ckpt = tf.train.get_checkpoint_state(checkpoint_path)
    if not ckpt:
        # Explicit raise rather than assert: asserts vanish under `python -O`.
        raise FileNotFoundError("checkpoint path %s not found" % checkpoint_path)
    # save_checkpoint writes ".../m.ckpt-<global_counter>"; recover that step.
    global_counter = int(re.findall(r"-(\d+)$", ckpt.model_checkpoint_path)[0])
    saver.restore(sess, ckpt.model_checkpoint_path)
    dqn_agent.global_counter = global_counter
    render = True
else:
    # A fresh training run must not start in an existing checkpoint directory.
    if checkpoint_path is not None and os.path.exists(checkpoint_path):
        raise FileExistsError(
            "checkpoint path already exists but load_checkpoint is false")
    tf.global_variables_initializer().run()


def save_checkpoint():
    """Save the session to ``<checkpoint_path>/m.ckpt-<global_counter>``.

    Creates ``checkpoint_path`` if needed and prints where it saved.
    """
    os.makedirs(checkpoint_path, exist_ok=True)
    p = os.path.join(checkpoint_path, "m.ckpt")
    saver.save(sess, p, dqn_agent.global_counter)
    print("saved to %s - %d" % (p, dqn_agent.global_counter))


def one_episode(eps_history, dqn_scores, avg_score_all, render, load_checkpoint):
    """Play one episode, log its statistics and save a checkpoint if warranted.

    Args:
        eps_history: list that receives the episode's epsilon (mutated).
        dqn_scores: list that receives the episode's score (mutated).
        avg_score_all: list that receives the running 100-episode average (mutated).
        render: forwarded to ``dqn_agent.play_episode``.
        load_checkpoint: forwarded to ``play_episode``; when truthy, no
            checkpoints are saved.

    Returns:
        The same three lists, ``(eps_history, dqn_scores, avg_score_all)``.
    """
    score, _reward, _frames, epsilon = dqn_agent.play_episode(render, load_checkpoint)
    eps_history.append(epsilon)
    dqn_scores.append(score)
    i = dqn_agent.episode_counter

    # Mean of the (up to) 100 most recent scores; never empty after the append.
    avg_score = np.mean(dqn_scores[-100:])
    avg_score_all.append(avg_score)
    # avg_score is already in avg_score_all, so this is true for a new max or a tie.
    highscore = avg_score >= max(avg_score_all)
    new_max = ' => New HighScore! <= ' if highscore else ''

    strm = ("#> episode: %i | score: %.2f | total steps: %i | epsilon: %.5f | average 100 score: %.2f"
            % (i, score, dqn_agent.global_counter, epsilon, avg_score))
    line = strm + new_max
    print(line)
    with open(results_path, "a") as results_file:
        results_file.write(line + '\n')

    if not load_checkpoint:
        save_cond = (
            dqn_agent.episode_counter % save_freq_episodes == 0
            and checkpoint_path is not None
            and dqn_agent.do_training
        )
        # Also save on a new best running average once 100 episodes have passed.
        if save_cond or (highscore and dqn_agent.episode_counter > 100):
            save_checkpoint()
    return eps_history, dqn_scores, avg_score_all


def input_thread(stop_requests):
    """Block on stdin; append a token to ``stop_requests`` once Enter is pressed."""
    input("...enter to stop after current episode\n")
    stop_requests.append("OK")


def main_loop(eps_history, dqn_scores, avg_score_all, render, load_checkpoint):
    """Play episodes until Enter is pressed or the training budget is used up.

    Returns the updated ``(eps_history, dqn_scores, avg_score_all)`` so the
    caller can continue (e.g. save a final checkpoint) instead of the process
    exiting here.
    """
    stop_requests = []
    _thread.start_new_thread(input_thread, (stop_requests,))
    while not stop_requests:
        if dqn_agent.do_training and dqn_agent.episode_counter >= train_episodes:
            break
        eps_history, dqn_scores, avg_score_all = one_episode(
            eps_history, dqn_scores, avg_score_all, render, load_checkpoint)
    print("done")
    return eps_history, dqn_scores, avg_score_all


if train_episodes > 0 and dqn_agent.episode_counter < train_episodes and not load_checkpoint:
    print("now training... you can early stop with enter...")
    print("##########")
    sys.stdout.flush()
    main_loop(eps_history, dqn_scores, avg_score_all, render, load_checkpoint)
    # Reachable now that main_loop returns instead of calling exit().
    save_checkpoint()
    print("ok training done")
else:
    print("now just playing...")
    sys.stdout.flush()
    main_loop(eps_history, dqn_scores, avg_score_all, render, load_checkpoint)