liuyizhang
commited on
Commit
Β·
2a71ebd
1
Parent(s):
5c28041
add time cost by step (ms)
Browse files- app.py +40 -12
- kosmos_utils.py +1 -1
- requirements.txt +1 -1
app.py
CHANGED
@@ -519,24 +519,42 @@ def relate_anything(input_image, k):
|
|
519 |
mask_source_draw = "draw a mask on input image"
|
520 |
mask_source_segment = "type what to detect below"
|
521 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
522 |
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
523 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
|
|
|
|
|
|
|
|
|
|
|
524 |
if (task_type == 'Kosmos-2'):
|
525 |
global kosmos_model, kosmos_processor
|
526 |
kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(input_image, kosmos_input, kosmos_model, kosmos_processor)
|
527 |
-
|
|
|
528 |
|
529 |
if (task_type == 'relate anything'):
|
530 |
output_images = relate_anything(input_image['image'], num_relation)
|
531 |
-
|
|
|
532 |
|
533 |
text_prompt = text_prompt.strip()
|
534 |
if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
|
535 |
if text_prompt == '':
|
536 |
-
return [], gr.Gallery.update(label='Detection prompt is not found!ππππ'), None, None, None
|
537 |
|
538 |
if input_image is None:
|
539 |
-
return [], gr.Gallery.update(label='Please upload a image!ππππ'), None, None, None
|
540 |
|
541 |
file_temp = int(time.time())
|
542 |
logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
|
@@ -552,10 +570,12 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
552 |
image_pil, image = load_image(input_image['image'].convert("RGB"))
|
553 |
input_img = input_image['image']
|
554 |
output_images.append(input_image['image'])
|
|
|
555 |
else:
|
556 |
image_pil, image = load_image(input_image.convert("RGB"))
|
557 |
input_img = input_image
|
558 |
output_images.append(input_image)
|
|
|
559 |
|
560 |
size = image_pil.size
|
561 |
|
@@ -576,7 +596,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
576 |
)
|
577 |
if boxes_filt.size(0) == 0:
|
578 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
|
579 |
-
return [], gr.Gallery.update(label='No objects detected, please try others.ππππ'), None, None, None
|
580 |
boxes_filt_ori = copy.deepcopy(boxes_filt)
|
581 |
|
582 |
pred_dict = {
|
@@ -587,6 +607,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
587 |
|
588 |
image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
|
589 |
output_images.append(image_with_box)
|
|
|
590 |
|
591 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
|
592 |
if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
|
@@ -622,12 +643,13 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
622 |
plt.savefig(image_path, bbox_inches="tight")
|
623 |
segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
624 |
os.remove(image_path)
|
625 |
-
output_images.append(segment_image_result)
|
|
|
626 |
|
627 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
|
628 |
if task_type == 'detection' or task_type == 'segment':
|
629 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
630 |
-
return output_images, gr.Gallery.update(label='result images'), None, None, None
|
631 |
elif task_type == 'inpainting' or task_type == 'remove':
|
632 |
if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
|
633 |
task_type = 'remove'
|
@@ -644,6 +666,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
644 |
mask = masks[0][0].cpu().numpy()
|
645 |
mask_pil = Image.fromarray(mask)
|
646 |
output_images.append(mask_pil.convert("RGB"))
|
|
|
647 |
|
648 |
if task_type == 'inpainting':
|
649 |
# inpainting pipeline
|
@@ -682,21 +705,24 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
682 |
extend_pixels=remove_mask_extend, useRectangle=useRectangle)
|
683 |
mask_imgs.append(mask_pil_exp)
|
684 |
mask_pil = mix_masks(mask_imgs)
|
685 |
-
output_images.append(mask_pil.convert("RGB"))
|
|
|
686 |
|
687 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
|
688 |
image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
|
689 |
# output_images.append(image_inpainting)
|
|
|
690 |
|
691 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
|
692 |
image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
|
693 |
output_images.append(image_inpainting)
|
|
|
694 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
695 |
-
return output_images, gr.Gallery.update(label='result images'), None, None, None
|
696 |
else:
|
697 |
logger.info(f"task_type:{task_type} error!")
|
698 |
logger.info(f'run_anything_task_[{file_temp}]_9_9_')
|
699 |
-
return output_images, gr.Gallery.update(label='result images'), None, None, None
|
700 |
|
701 |
def change_radio_display(task_type, mask_source_radio):
|
702 |
text_prompt_visible = True
|
@@ -828,7 +854,9 @@ if __name__ == "__main__":
|
|
828 |
|
829 |
with gr.Column():
|
830 |
image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
|
831 |
-
).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
|
|
|
|
|
832 |
kosmos_output = gr.Image(type="pil", label="result images", visible=False)
|
833 |
kosmos_text_output = gr.HighlightedText(
|
834 |
label="Generated Description",
|
@@ -860,7 +888,7 @@ if __name__ == "__main__":
|
|
860 |
run_button.click(fn=run_anything_task, inputs=[
|
861 |
input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
862 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
|
863 |
-
outputs=[image_gallery, image_gallery, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
|
864 |
|
865 |
mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
|
866 |
outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
|
|
|
519 |
mask_source_draw = "draw a mask on input image"
|
520 |
mask_source_segment = "type what to detect below"
|
521 |
|
522 |
+
def get_time_cost(run_task_time, time_cost_str):
|
523 |
+
now_time = int(time.time()*1000)
|
524 |
+
if run_task_time == 0:
|
525 |
+
time_cost_str = 'start'
|
526 |
+
else:
|
527 |
+
if time_cost_str != '':
|
528 |
+
time_cost_str += f'-->'
|
529 |
+
time_cost_str += f'{now_time - run_task_time}'
|
530 |
+
run_task_time = now_time
|
531 |
+
return run_task_time, time_cost_str
|
532 |
+
|
533 |
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
534 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
|
535 |
+
|
536 |
+
run_task_time = 0
|
537 |
+
time_cost_str = ''
|
538 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
539 |
+
|
540 |
if (task_type == 'Kosmos-2'):
|
541 |
global kosmos_model, kosmos_processor
|
542 |
kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(input_image, kosmos_input, kosmos_model, kosmos_processor)
|
543 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
544 |
+
return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
|
545 |
|
546 |
if (task_type == 'relate anything'):
|
547 |
output_images = relate_anything(input_image['image'], num_relation)
|
548 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
549 |
+
return output_images, gr.Gallery.update(label='relate images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
550 |
|
551 |
text_prompt = text_prompt.strip()
|
552 |
if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
|
553 |
if text_prompt == '':
|
554 |
+
return [], gr.Gallery.update(label='Detection prompt is not found!ππππ'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
555 |
|
556 |
if input_image is None:
|
557 |
+
return [], gr.Gallery.update(label='Please upload a image!ππππ'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
558 |
|
559 |
file_temp = int(time.time())
|
560 |
logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
|
|
|
570 |
image_pil, image = load_image(input_image['image'].convert("RGB"))
|
571 |
input_img = input_image['image']
|
572 |
output_images.append(input_image['image'])
|
573 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
574 |
else:
|
575 |
image_pil, image = load_image(input_image.convert("RGB"))
|
576 |
input_img = input_image
|
577 |
output_images.append(input_image)
|
578 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
579 |
|
580 |
size = image_pil.size
|
581 |
|
|
|
596 |
)
|
597 |
if boxes_filt.size(0) == 0:
|
598 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
|
599 |
+
return [], gr.Gallery.update(label='No objects detected, please try others.ππππ'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
600 |
boxes_filt_ori = copy.deepcopy(boxes_filt)
|
601 |
|
602 |
pred_dict = {
|
|
|
607 |
|
608 |
image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
|
609 |
output_images.append(image_with_box)
|
610 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
611 |
|
612 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
|
613 |
if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
|
|
|
643 |
plt.savefig(image_path, bbox_inches="tight")
|
644 |
segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
645 |
os.remove(image_path)
|
646 |
+
output_images.append(segment_image_result)
|
647 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
648 |
|
649 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
|
650 |
if task_type == 'detection' or task_type == 'segment':
|
651 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
652 |
+
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
653 |
elif task_type == 'inpainting' or task_type == 'remove':
|
654 |
if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
|
655 |
task_type = 'remove'
|
|
|
666 |
mask = masks[0][0].cpu().numpy()
|
667 |
mask_pil = Image.fromarray(mask)
|
668 |
output_images.append(mask_pil.convert("RGB"))
|
669 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
670 |
|
671 |
if task_type == 'inpainting':
|
672 |
# inpainting pipeline
|
|
|
705 |
extend_pixels=remove_mask_extend, useRectangle=useRectangle)
|
706 |
mask_imgs.append(mask_pil_exp)
|
707 |
mask_pil = mix_masks(mask_imgs)
|
708 |
+
output_images.append(mask_pil.convert("RGB"))
|
709 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
710 |
|
711 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
|
712 |
image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
|
713 |
# output_images.append(image_inpainting)
|
714 |
+
# run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
715 |
|
716 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
|
717 |
image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
|
718 |
output_images.append(image_inpainting)
|
719 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
720 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
721 |
+
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
722 |
else:
|
723 |
logger.info(f"task_type:{task_type} error!")
|
724 |
logger.info(f'run_anything_task_[{file_temp}]_9_9_')
|
725 |
+
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
726 |
|
727 |
def change_radio_display(task_type, mask_source_radio):
|
728 |
text_prompt_visible = True
|
|
|
854 |
|
855 |
with gr.Column():
|
856 |
image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
|
857 |
+
).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
|
858 |
+
time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)
|
859 |
+
|
860 |
kosmos_output = gr.Image(type="pil", label="result images", visible=False)
|
861 |
kosmos_text_output = gr.HighlightedText(
|
862 |
label="Generated Description",
|
|
|
888 |
run_button.click(fn=run_anything_task, inputs=[
|
889 |
input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
890 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
|
891 |
+
outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
|
892 |
|
893 |
mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
|
894 |
outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
|
kosmos_utils.py
CHANGED
@@ -230,4 +230,4 @@ def kosmos_generate_predictions(image_input, text_input, kosmos_model, kosmos_pr
|
|
230 |
if end < len(processed_text):
|
231 |
colored_text.append((processed_text[end:len(processed_text)], None))
|
232 |
|
233 |
-
return annotated_image, colored_text, str(filtered_entities)
|
|
|
230 |
if end < len(processed_text):
|
231 |
colored_text.append((processed_text[end:len(processed_text)], None))
|
232 |
|
233 |
+
return annotated_image, colored_text, str(filtered_entities)
|
requirements.txt
CHANGED
@@ -17,7 +17,7 @@ termcolor
|
|
17 |
timm
|
18 |
torch
|
19 |
torchvision
|
20 |
-
transformers
|
21 |
yapf
|
22 |
numba
|
23 |
scipy
|
|
|
17 |
timm
|
18 |
torch
|
19 |
torchvision
|
20 |
+
transformers==4.27.4
|
21 |
yapf
|
22 |
numba
|
23 |
scipy
|