ydshieh (HF staff) committed
Commit 1ff1aaf
1 Parent(s): ad630c0

Update app.py

Files changed (1): app.py (+46 -14)

app.py CHANGED
```diff
@@ -8,6 +8,7 @@ import torchvision.transforms as T
 from PIL import Image
 from transformers import AutoProcessor, AutoModelForVision2Seq
 import cv2
+import ast
 
 colors = [
     (0, 255, 0),
@@ -39,7 +40,7 @@ def is_overlapping(rect1, rect2):
     return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
 
 
-def draw_entity_boxes_on_image(image, entities, show=False, save_path=None):
+def draw_entity_boxes_on_image(image, entities, show=False, save_path=None, entity_index=-1):
     """_summary_
     Args:
         image (_type_): image or image path
```
```diff
@@ -69,10 +70,14 @@ def draw_entity_boxes_on_image(image, entities, show=False, save_path=None):
         image = np.array(pil_img)[:, :, [2, 1, 0]]
     else:
         raise ValueError(f"invaild image format, {type(image)} for {image}")
-
+
     if len(entities) == 0:
         return image
 
+    indices = list(range(len(entities)))
+    if entity_index >= 0:
+        indices = [entity_index]
+
     # Not to show too many bboxes
     entities = entities[:len(color_map)]
```
 
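The new `indices` list is what makes single-entity drawing possible: by default every entity is drawn, but a non-negative `entity_index` narrows drawing to that one entity. A minimal sketch of the behavior, using hypothetical entities in the `(name, (start, end), bboxes)` format the app passes around:

```python
# Hypothetical entities in the (name, (start, end), bboxes) format.
entities = [
    ("a snowman", (0, 9), [(0.39, 0.22, 0.61, 0.78)]),
    ("a fire", (14, 20), [(0.11, 0.60, 0.32, 0.93)]),
]

def visible_indices(entities, entity_index=-1):
    # entity_index < 0 means "no selection": draw every entity.
    if entity_index >= 0:
        return [entity_index]
    return list(range(len(entities)))

print(visible_indices(entities))      # [0, 1] -> draw all boxes
print(visible_indices(entities, 1))   # [1]    -> draw only "a fire"
```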
```diff
@@ -92,11 +97,13 @@ def draw_entity_boxes_on_image(image, entities, show=False, save_path=None):
     used_colors = colors  # random.sample(colors, k=num_bboxes)
 
     color_id = -1
-    for entity_name, (start, end), bboxes in entities:
+    for entity_idx, (entity_name, (start, end), bboxes) in enumerate(entities):
         color_id += 1
+        if entity_idx not in indices:
+            continue
         for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
-            if start is None and bbox_id > 0:
-                color_id += 1
+            # if start is None and bbox_id > 0:
+            #     color_id += 1
             orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm * image_w), int(y1_norm * image_h), int(x2_norm * image_w), int(y2_norm * image_h)
 
             # draw bbox
```
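Note that `color_id += 1` still executes before the `if entity_idx not in indices: continue` check, so every entity keeps its position-based color whether or not it is drawn; selecting a span hides the other boxes without recoloring the remaining one.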
```diff
@@ -199,13 +206,17 @@ def main():
 
         color_id = -1
         entity_info = []
-        for entity_name, (start, end), bboxes in entities:
+        filtered_entities = []
+        for entity in entities:
+            entity_name, (start, end), bboxes = entity
+            if start is None:
+                continue
             color_id += 1
-            for bbox_id, _ in enumerate(bboxes):
-                if start is None and bbox_id > 0:
-                    color_id += 1
-            if start is not None:
-                entity_info.append(((start, end), color_id))
+            # for bbox_id, _ in enumerate(bboxes):
+            #     if start is None and bbox_id > 0:
+            #         color_id += 1
+            entity_info.append(((start, end), color_id))
+            filtered_entities.append(entity)
 
         colored_text = []
         prev_start = 0
@@ -219,7 +230,7 @@ def main():
         if end < len(processed_text):
            colored_text.append((processed_text[end:len(processed_text)], None))
 
-        return annotated_image, colored_text
+        return annotated_image, colored_text, str(filtered_entities)
 
     term_of_use = """
     ### Terms of use
```
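The third return value serializes the filtered entities with `str()` so they can be parked in a hidden `gr.Textbox` between calls; `update_output_image` below recovers them with `ast.literal_eval` (hence the new `import ast`). A quick round-trip check with a hypothetical entity list:

```python
import ast

# Hypothetical filtered entities, as generate_predictions would return them.
filtered_entities = [("a snowman", (0, 9), [(0.39, 0.22, 0.61, 0.78)])]

serialized = str(filtered_entities)      # stored in the hidden gr.Textbox
restored = ast.literal_eval(serialized)  # recovered in update_output_image
assert restored == filtered_entities
```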
```diff
@@ -271,12 +282,33 @@ def main():
         ], inputs=[image_input, text_input, do_sample])
         gr.Markdown(term_of_use)
 
+        # record which text span (label) is selected
+        selected = gr.Number(-1, show_label=False, placeholder="Selected", visible=False)
+
+        # record the current `entities`
+        entity_output = gr.Textbox(visible=False)
+
+        # get the current selected span label
+        def get_text_span_label(evt: gr.SelectData):
+            if evt.value[-1] is None:
+                return -1
+            return int(evt.value[-1])
+        # and set this information to `selected`
+        text_output1.select(get_text_span_label, None, selected)
+
+        # update output image when we change the span (entity) selection
+        def update_output_image(img_input, image_output, entities, idx):
+            entities = ast.literal_eval(entities)
+            updated_image = draw_entity_boxes_on_image(img_input, entities, entity_index=idx)
+            return updated_image
+        selected.change(update_output_image, [image_input, image_output, entity_output, selected], [image_output])
+
         run_button.click(fn=generate_predictions,
                          inputs=[image_input, text_input, do_sample, sampling_topp, sampling_temperature],
-                         outputs=[image_output, text_output1],
+                         outputs=[image_output, text_output1, entity_output],
                          show_progress=True, queue=True)
 
-    demo.launch()
+    demo.launch(share=True)
 
 
 if __name__ == "__main__":
```
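Taken together, the wiring is: clicking a span in `text_output1` fires `get_text_span_label`, which writes the span's numeric label into the hidden `selected` component, whose `change` event redraws the image with only that entity's boxes. Below is a minimal, self-contained sketch of the same select-to-redraw pattern with a stub in place of the real box drawing; the component names and toy data are illustrative, and it assumes a Gradio version in which `HighlightedText` supports the `select` event with `gr.SelectData`:

```python
import gradio as gr

def get_text_span_label(evt: gr.SelectData):
    # evt.value ends with the clicked span's label; plain text has label None.
    if evt.value[-1] is None:
        return -1
    return int(evt.value[-1])

with gr.Blocks() as demo:
    # Spans labeled with their entity index, mirroring colored_text in app.py.
    spans = gr.HighlightedText(
        value=[("a snowman", "0"), (" warming by ", None), ("a fire", "1")]
    )
    selected = gr.Number(-1, visible=False)  # hidden selection state
    status = gr.Textbox(label="would redraw with entity_index =")

    spans.select(get_text_span_label, None, selected)
    # app.py re-runs draw_entity_boxes_on_image here instead of echoing.
    selected.change(lambda idx: str(int(idx)), selected, status)

demo.launch()
```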