odiaz1066 commited on
Commit
2f53815
1 Parent(s): 18dd671

Modify to show only 1 model

Browse files
Files changed (1) hide show
  1. app.py +12 -27
app.py CHANGED
@@ -5,7 +5,7 @@ from huggingface_hub.repocard import metadata_load
5
 
6
  app = gr.Blocks()
7
 
8
- def load_agent(model_id_1, model_id_2):
9
  """
10
  This function load the agent's video and results
11
  :return: video_path
@@ -19,16 +19,7 @@ def load_agent(model_id_1, model_id_2):
19
  # Load the video
20
  video_path_1 = hf_hub_download(model_id_1, filename="replay.mp4")
21
 
22
- # Load the metrics
23
- metadata_2 = get_metadata(model_id_2)
24
-
25
- # Get the accuracy
26
- results_2 = parse_metrics_accuracy(metadata_2)
27
-
28
- # Load the video
29
- video_path_2 = hf_hub_download(model_id_2, filename="replay.mp4")
30
-
31
- return model_id_1, video_path_1, results_1, model_id_2, video_path_2, results_2
32
 
33
  def parse_metrics_accuracy(meta):
34
  if "model-index" not in meta:
@@ -58,32 +49,26 @@ def get_metadata(model_id):
58
  with app:
59
  gr.Markdown(
60
  """
61
- # Compare Deep Reinforcement Learning Agents 🤖
62
 
63
- Type two models id you want to compare or check examples below.
64
  """)
65
  with gr.Row():
66
- model1_input = gr.Textbox(label="Model 1")
67
- model2_input = gr.Textbox(label="Model 2")
68
  with gr.Row():
69
- app_button = gr.Button("Compare models")
70
  with gr.Row():
71
  with gr.Column():
72
  model1_name = gr.Markdown()
73
  model1_video_output = gr.Video()
74
- model1_score_output = gr.Textbox(label="Mean Reward +/- Std Reward")
75
- with gr.Column():
76
- model2_name = gr.Markdown()
77
- model2_video_output = gr.Video()
78
- model2_score_output = gr.Textbox(label="Mean Reward +/- Std Reward")
79
 
80
- app_button.click(load_agent, inputs=[model1_input, model2_input], outputs=[model1_name, model1_video_output, model1_score_output, model2_name, model2_video_output, model2_score_output])
81
 
82
- examples = gr.Examples(examples=[["sb3/a2c-AntBulletEnv-v0","sb3/ppo-AntBulletEnv-v0"],
83
- ["ThomasSimonini/a2c-AntBulletEnv-v0", "sb3/a2c-AntBulletEnv-v0"],
84
- ["sb3/dqn-SpaceInvadersNoFrameskip-v4", "sb3/a2c-SpaceInvadersNoFrameskip-v4"],
85
- ["ThomasSimonini/ppo-QbertNoFrameskip-v4","sb3/ppo-QbertNoFrameskip-v4"]],
86
- inputs=[model1_input, model2_input])
87
 
88
 
89
  app.launch()
 
5
 
6
  app = gr.Blocks()
7
 
8
+ def load_agent(model_id_1):
9
  """
10
  This function load the agent's video and results
11
  :return: video_path
 
19
  # Load the video
20
  video_path_1 = hf_hub_download(model_id_1, filename="replay.mp4")
21
 
22
+ return model_id_1, video_path_1, results_1
 
 
 
 
 
 
 
 
 
23
 
24
  def parse_metrics_accuracy(meta):
25
  if "model-index" not in meta:
 
49
  with app:
50
  gr.Markdown(
51
  """
52
+ # Observa los agentes de Lagomorph en acción 🤖
53
 
54
+ Selecciona el modelo a observar.
55
  """)
56
  with gr.Row():
57
+ model1_input = gr.Textbox(label="Modelo")
 
58
  with gr.Row():
59
+ app_button = gr.Button("Mostrar modelo")
60
  with gr.Row():
61
  with gr.Column():
62
  model1_name = gr.Markdown()
63
  model1_video_output = gr.Video()
64
+ model1_score_output = gr.Textbox(label="Recompensa +/- Variación")
 
 
 
 
65
 
66
+ app_button.click(load_agent, inputs=[model1_input], outputs=[model1_name, model1_video_output, model1_score_output])
67
 
68
+ examples = gr.Examples(examples=[["odiaz1066/CartPoleGoal-Lagomorph-seed42"],
69
+ ["odiaz1066/BreakoutNoFrameskip-v4-dqn_atari-seed1"],
70
+ ["odiaz1066/PongNoFrameskip-v4-dqn_atari-seed1"]],
71
+ inputs=[model1_input])
 
72
 
73
 
74
  app.launch()