Edward Beeching committed on
Commit
b481772
1 Parent(s): bf95e48

created live model eval example

Browse files
Files changed (2) hide show
  1. app.py +65 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import time
3
+
4
+ from huggingface_sb3 import load_from_hub
5
+
6
+ from stable_baselines3 import PPO
7
+ from stable_baselines3.common.env_util import make_atari_env
8
+ from stable_baselines3.common.vec_env import VecFrameStack
9
+
10
+ from stable_baselines3.common.env_util import make_atari_env
11
+
12
+ st.title("Atari Environments Live Model")
13
+
14
+ # @st.cache This is not cacheable :(
15
def load_env(env_name):
    """Create the Atari environment the agent is evaluated in.

    Args:
        env_name: Gym id of the Atari game, e.g. "PongNoFrameskip-v4".

    Returns:
        The result of ``make_atari_env(env_name, n_envs=1)`` wrapped in a
        ``VecFrameStack`` with ``n_stack=4``.
    """
    atari_env = make_atari_env(env_name, n_envs=1)
    return VecFrameStack(atari_env, n_stack=4)
19
+
20
+
21
+ # @st.cache This is not cacheable :(
22
def load_model(env_name):
    """Fetch the pretrained PPO agent for ``env_name`` from the Hub.

    Args:
        env_name: Gym id of the Atari game; it is interpolated into both the
            repository id and the checkpoint filename.

    Returns:
        The ``PPO`` model loaded from the downloaded checkpoint.
    """

    def _always_zero(_):
        # Constant schedule: ignores its argument and returns 0.0.
        return 0.0

    checkpoint_path = load_from_hub(
        f"ThomasSimonini/ppo-{env_name}",
        f"ppo-{env_name}.zip",
    )

    # Substitute constant zeros for the stored learning-rate and clip-range
    # objects at load time; this script only calls ``predict`` on the model.
    # NOTE(review): presumably this also sidesteps deserializing the saved
    # schedules across library/Python versions -- confirm.
    overrides = {
        "learning_rate": 0.0,
        "lr_schedule": _always_zero,
        "clip_range": _always_zero,
    }
    return PPO.load(checkpoint_path, custom_objects=overrides)
37
+
38
+
39
# Streamlit page body: pick a game, load its pretrained agent, and replay a
# number of episodes frame by frame.
env_name = st.selectbox(
    "Select environment",
    (
        "SpaceInvadersNoFrameskip-v4",
        "PongNoFrameskip-v4",
        "SeaquestNoFrameskip-v4",
        "QbertNoFrameskip-v4",
    ),
)

num_episodes = st.slider("Number of Episodes", 1, 20, 5)
env = load_env(env_name)

try:
    model = load_model(env_name)

    # Every st.image call inside this st.empty() container is meant to
    # overwrite the previous frame rather than append a new image.
    with st.empty():
        for _ in range(num_episodes):
            # Each episode starts from a fresh reset; no extra reset is
            # needed before the loop.
            obs = env.reset()
            done = False
            while not done:
                frame = env.render(mode="rgb_array")
                st.image(frame, width=400)

                # For a vectorized env, predict() already returns one action
                # per environment, so it is passed to step() unchanged.
                action, _states = model.predict(obs)
                obs, _rewards, dones, _infos = env.step(action)

                # load_env builds exactly one environment (n_envs=1), so the
                # episode-termination flag is the first entry.
                done = bool(dones[0])

                # Throttle playback so frames are watchable.
                time.sleep(0.1)
finally:
    # Streamlit reruns this script on every widget change; release the
    # emulator instead of leaking one environment per rerun.
    env.close()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ huggingface_sb3
2
+ gym
3
+ stable-baselines3[extra]