{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate one stored i-DQN / i-IQN checkpoint on a single Atari game.\n",
    "# Edit the constants between the START/END markers, then run the cell.\n",
    "import pickle\n",
    "\n",
    "import jax\n",
    "\n",
    "from atari import AtariEnv\n",
    "from networks import AtariiDQN, AtariiIQN\n",
    "\n",
    "# ------- START TO MODIFY ------- #\n",
    "IDQN_ALGO = True  # if False then i-IQN is evaluated\n",
    "GAME = \"Alien\"\n",
    "NETWORK_SEED = 1  # seed in [1, 2, 3, 4, 5]\n",
    "EVALUATION_SEED = 0\n",
    "HORIZON = 27000\n",
    "ENDING_EPS = 0.01\n",
    "RECORD_VIDEO = False\n",
    "\n",
    "### 56 games are available for i-DQN with 5 seeds each:\n",
    "# Alien, Amidar, Assault, Asterix, Asteroids, Atlantis,\n",
    "# BankHeist, BattleZone, BeamRider, Berzerk, Bowling, Boxing, Breakout, Centipede,\n",
    "# ChopperCommand, CrazyClimber, DemonAttack, DoubleDunk, Enduro, FishingDerby,\n",
    "# Freeway, Frostbite, Gopher, Gravitar, Hero, IceHockey, Jamesbond, Kangaroo,\n",
    "# Krull, KungFuMaster, MontezumaRevenge, MsPacman, NameThisGame, Phoenix, Pitfall,\n",
    "# Pong, Pooyan, PrivateEye, Qbert, Riverraid, RoadRunner, Robotank, Seaquest, Skiing,\n",
    "# Solaris, SpaceInvaders, StarGunner, Tennis, TimePilot, Tutankham, UpNDown, Venture,\n",
    "# VideoPinball, WizardOfWor, YarsRevenge, Zaxxon\n",
    "\n",
    "## 20 games are available for i-IQN with 5 seeds each:\n",
    "# Alien, Assault, BankHeist, Berzerk, Breakout, Centipede,\n",
    "# ChopperCommand, DemonAttack, Enduro, Frostbite, Gopher,\n",
    "# Gravitar, IceHockey, Jamesbond, Krull, KungFuMaster,\n",
    "# Riverraid, Seaquest, Skiing, StarGunner\n",
    "# ------- END TO MODIFY ------- #\n",
    "\n",
    "\n",
    "# Checkpoints are stored as parameters/<GAME>/<algo>/<prefix>_Q_<seed>_best_online_params,\n",
    "# where the numeric prefix is 5 for i-DQN and 3 for i-IQN.\n",
    "algo_name = \"iDQN\" if IDQN_ALGO else \"iIQN\"\n",
    "checkpoint_prefix = 5 if IDQN_ALGO else 3\n",
    "params_path = f\"parameters/{GAME}/{algo_name}/{checkpoint_prefix}_Q_{NETWORK_SEED}_best_online_params\"\n",
    "\n",
    "env = AtariEnv(GAME)\n",
    "\n",
    "# Build the network matching the selected algorithm; idx_head picks which head is evaluated.\n",
    "if IDQN_ALGO:\n",
    "    q = AtariiDQN(env.n_actions, idx_head=0)  # idx_head in [0, 1, 2, 3, 4, 5]\n",
    "else:\n",
    "    q = AtariiIQN(env.n_actions, idx_head=0)  # idx_head in [0, 1, 2, 3]\n",
    "\n",
    "# SECURITY: pickle.load can execute arbitrary code embedded in the file.\n",
    "# Only load parameter files obtained from a trusted source.\n",
    "with open(params_path, \"rb\") as handle:\n",
    "    q_params = pickle.load(handle)\n",
    "\n",
    "# The last argument is the video path: params_path when recording is enabled, otherwise None.\n",
    "reward, absorbing = env.evaluate_one_simulation(\n",
    "    q, q_params, HORIZON, ENDING_EPS, jax.random.PRNGKey(EVALUATION_SEED), params_path if RECORD_VIDEO else None\n",
    ")\n",
    "print(\"Undiscounted reward:\", reward)\n",
    "print(\"N steps\", env.n_steps, \"; Horizon\", HORIZON, \"; Absorbing\", absorbing)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}