Koshti10 commited on
Commit
56ba131
·
verified ·
1 Parent(s): 6a98aa1

Upload 10 files

Browse files
Files changed (4) hide show
  1. app.py +76 -10
  2. requirements.txt +2 -1
  3. src/boards.py +7 -7
  4. src/utils.py +12 -0
app.py CHANGED
@@ -1,26 +1,92 @@
1
  import gradio as gr
2
  import numpy as np
3
- import matplotlib.pyplot as plt
4
 
5
- from src.boards import GenerateBoard
6
 
7
  TITLE = """<h1 align="center" id="space-title"> Pento-LLaVA 🤖🎯🎮</h1>"""
8
 
9
- initial_board_image, target_positions, info = GenerateBoard('easy', 18).setup_initial_board()
 
10
 
11
- # Convert initial_board_image to a matplotlib figure
12
- fig, ax = plt.subplots()
13
- ax.imshow(initial_board_image)
14
- ax.axis('off')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- pento_llava_app = gr.Blocks()
17
 
18
  with pento_llava_app:
19
 
20
  gr.HTML(TITLE)
21
- gr.Plot(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  pento_llava_app.load()
24
 
25
  pento_llava_app.queue()
26
- pento_llava_app.launch()
 
1
  import gradio as gr
2
  import numpy as np
3
+ import plotly.graph_objects as go
4
 
5
+ from src.utils import select_board
6
 
7
  TITLE = """<h1 align="center" id="space-title"> Pento-LLaVA 🤖🎯🎮</h1>"""
8
 
9
+ pento_llava_app = gr.Blocks()
10
+ fig, targets, info = select_board('easy', 18, 0)
11
 
12
+ # Update the figure to include a boundary/frame
13
+ fig.update_layout(
14
+ xaxis=dict(showline=True, linecolor='black', linewidth=2, showgrid=True, zeroline=True),
15
+ yaxis=dict(showline=True, linecolor='black', linewidth=2, showgrid=True, zeroline=True),
16
+ margin=dict(l=0, r=0, t=0, b=0),
17
+ height=512,
18
+ width=512
19
+ )
20
+
21
+ target_strs = []
22
+ for t in targets:
23
+ target_strs.append(t['target_str'])
24
+
25
+ def gen_new_board(value):
26
+ value = int(value)
27
+ fig, _, _ = select_board('easy', 18, value)
28
+
29
+ # Update the figure to include a boundary/frame
30
+ fig.update_layout(
31
+ xaxis=dict(showline=True, linecolor='black', linewidth=2, showgrid=True, zeroline=True),
32
+ yaxis=dict(showline=True, linecolor='black', linewidth=2, showgrid=True, zeroline=True),
33
+ margin=dict(l=0, r=0, t=0, b=0),
34
+ height=512,
35
+ width=512
36
+ )
37
+
38
+ return fig
39
+
40
+ def gen_info(value):
41
+ value = int(value)
42
+ _, _, info = select_board('easy', 18, value)
43
+
44
+ return info
45
+
46
+ def gen_target_str(value):
47
+ value = int(value)
48
+ _, _, info = select_board('easy', 18, value)
49
+ target_info = info[0]
50
+ target_str = f"<span>Target piece for this episode is <span style='color: {target_info['piece_colour']};'>{target_info['piece_colour']}</span> <span style='color: {target_info['piece_colour']};'>{target_info['piece_shape']}</span> located at <span style='color: {target_info['piece_colour']};'>{target_info['piece_region']}</span></span>"
51
+
52
+ return target_str
53
 
 
54
 
55
  with pento_llava_app:
56
 
57
  gr.HTML(TITLE)
58
+
59
+ with gr.Row():
60
+ with gr.Column():
61
+ main_board = gr.Plot(fig)
62
+
63
+ with gr.Column():
64
+ with gr.Row():
65
+ select_board_items = gr.Dropdown(
66
+ choices=range(512),
67
+ interactive=True
68
+ )
69
+
70
+ with gr.Row():
71
+ display_string = gr.HTML(value=gen_target_str(0))
72
+
73
+
74
+ select_board_items.change(
75
+ fn=gen_new_board,
76
+ inputs=[select_board_items],
77
+ outputs=[main_board],
78
+ queue=True
79
+ )
80
+
81
+ select_board_items.change(
82
+ fn=gen_target_str,
83
+ inputs=[select_board_items],
84
+ outputs=[display_string],
85
+ queue=True
86
+ )
87
+
88
 
89
  pento_llava_app.load()
90
 
91
  pento_llava_app.queue()
92
+ pento_llava_app.launch()
requirements.txt CHANGED
@@ -14,4 +14,5 @@ accelerate==0.31.0
14
  bitsandbytes
15
  datasets==3.0.1
16
  gradio==5.5.0
17
- matplotlib==3.9.2
 
 
14
  bitsandbytes
15
  datasets==3.0.1
16
  gradio==5.5.0
17
+ matplotlib==3.9.2
18
+ plotly==5.24.1
src/boards.py CHANGED
@@ -7,7 +7,7 @@ import numpy as np
7
 
8
  class GenerateBoard():
9
 
10
- def __init__(self, level: str, board_size: int):
11
  self.level = level
12
  self.board_size = board_size
13
 
@@ -16,9 +16,8 @@ class GenerateBoard():
16
  with open(metadata_path, 'r') as f:
17
  metadata = json.load(f)
18
 
19
- num_boards = len(metadata)
20
- random_board_num = np.random.randint(0, num_boards)
21
- self.board_data = metadata[random_board_num]
22
 
23
  def setup_initial_board(self):
24
 
@@ -29,10 +28,11 @@ class GenerateBoard():
29
  info = metadata_obj['info']
30
  target_options = []
31
  for piece in info:
 
32
  target = f"{piece['piece_colour']} {piece['piece_shape']} at {piece['piece_region']}"
33
- target_options.append(target)
34
-
35
-
36
 
37
  env = GridWorldEnv(render_mode="rgb_array", size=self.board_size, grid_info=info, agent_pos=default_start_pos, target_pos=default_target_pos)
38
  env.reset()
 
7
 
8
  class GenerateBoard():
9
 
10
+ def __init__(self, level: str, board_size: int, board_number: int):
11
  self.level = level
12
  self.board_size = board_size
13
 
 
16
  with open(metadata_path, 'r') as f:
17
  metadata = json.load(f)
18
 
19
+ self.board_num = board_number
20
+ self.board_data = metadata[self.board_num]
 
21
 
22
  def setup_initial_board(self):
23
 
 
28
  info = metadata_obj['info']
29
  target_options = []
30
  for piece in info:
31
+ target_info_dict = {}
32
  target = f"{piece['piece_colour']} {piece['piece_shape']} at {piece['piece_region']}"
33
+ target_info_dict['target_str'] = target
34
+ target_info_dict['piece_info'] = piece
35
+ target_options.append(target_info_dict)
36
 
37
  env = GridWorldEnv(render_mode="rgb_array", size=self.board_size, grid_info=info, agent_pos=default_start_pos, target_pos=default_target_pos)
38
  env.reset()
src/utils.py CHANGED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.boards import GenerateBoard
2
+ import plotly.express as px # Import Plotly Express
3
+
4
+ def select_board(level: str = 'easy', size: int = 18, board_number: int = 0):
5
+ initial_board_image, target_positions, info = GenerateBoard(level, size, board_number).setup_initial_board()
6
+
7
+ # Convert initial_board_image to a Plotly figure
8
+ fig = px.imshow(initial_board_image) # Use Plotly's imshow
9
+ fig.update_xaxes(showticklabels=False) # Hide x-axis ticks
10
+ fig.update_yaxes(showticklabels=False) # Hide y-axis ticks
11
+
12
+ return fig, target_positions, info