Spaces:
Runtime error
Runtime error
import streamlit as st | |
import time | |
from huggingface_sb3 import load_from_hub | |
from stable_baselines3 import PPO | |
from stable_baselines3.common.env_util import make_atari_env | |
from stable_baselines3.common.vec_env import VecFrameStack | |
from stable_baselines3.common.env_util import make_atari_env | |
# Page heading rendered at the top of the Streamlit app.
st.title("Atari Environments Live Model")
# @st.cache This is not cachable :( | |
def load_env(env_name):
    """Build the environment the agent is played in.

    Creates a single Atari environment for ``env_name`` with
    ``make_atari_env`` (``n_envs=1``) and wraps it in ``VecFrameStack``
    with ``n_stack=4``.

    Args:
        env_name: Environment id passed straight to ``make_atari_env``.

    Returns:
        The ``VecFrameStack``-wrapped environment.
    """
    return VecFrameStack(make_atari_env(env_name, n_envs=1), n_stack=4)
# @st.cache This is not cachable :( | |
def load_model(env_name):
    """Fetch the PPO checkpoint for ``env_name`` from the Hub and load it.

    The checkpoint file ``ppo-{env_name}.zip`` is requested from the
    ``ThomasSimonini/ppo-{env_name}`` repository via ``load_from_hub`` and
    then handed to ``PPO.load``.

    Args:
        env_name: Environment id used to build the repository and file names.

    Returns:
        The object returned by ``PPO.load``.
    """
    checkpoint_path = load_from_hub(
        f"ThomasSimonini/ppo-{env_name}",
        f"ppo-{env_name}.zip",
    )
    # Values handed to PPO.load as ``custom_objects``.
    # NOTE(review): presumably these replace training-only fields stored in
    # the checkpoint so it loads for inference -- confirm against SB3 docs.
    overrides = dict(
        learning_rate=0.0,
        lr_schedule=lambda _: 0.0,
        clip_range=lambda _: 0.0,
    )
    return PPO.load(checkpoint_path, custom_objects=overrides)
# --- User controls ---------------------------------------------------------
env_name = st.selectbox(
    "Select environment",
    (
        "SpaceInvadersNoFrameskip-v4",
        "PongNoFrameskip-v4",
        "SeaquestNoFrameskip-v4",
        "QbertNoFrameskip-v4",
    ),
)
num_episodes = st.slider("Number of Episodes", 1, 20, 5)

# --- Environment and agent -------------------------------------------------
env = load_env(env_name)
model = load_model(env_name)

# --- Playback --------------------------------------------------------------
try:
    # All frames are drawn inside one st.empty() placeholder so the image is
    # rendered in a single slot rather than appended down the page.
    with st.empty():
        for _ in range(num_episodes):
            obs = env.reset()
            done = False
            while not done:
                frame = env.render(mode="rgb_array")
                st.image(frame, width=400)
                action, _states = model.predict(obs)
                # ``obs`` comes from a vectorized env, so ``predict`` already
                # returns one action per env; pass it through unchanged.
                # Wrapping it in another list would hand the inner env an
                # array instead of a single action.
                obs, reward, dones, info = env.step(action)
                # load_env builds the env with n_envs=1, so the only
                # sub-environment is at index 0.
                done = bool(dones[0])
                # Throttle playback so frames are visible in the browser.
                time.sleep(0.1)
finally:
    # Streamlit re-executes this script on every widget change; release the
    # environment so resources do not accumulate across reruns.
    env.close()