Spaces:
Running
on
T4
Running
on
T4
liuyizhang
commited on
Commit
·
b902809
1
Parent(s):
779c33a
update app.py
Browse files
app.py
CHANGED
@@ -323,10 +323,10 @@ mask_source_segment = "type what to detect below"
|
|
323 |
|
324 |
def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
325 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend):
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
|
331 |
file_temp = int(time.time())
|
332 |
logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_[{text_prompt}]_1_')
|
@@ -361,6 +361,9 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
|
|
361 |
boxes_filt, pred_phrases = get_grounding_output(
|
362 |
groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
|
363 |
)
|
|
|
|
|
|
|
364 |
boxes_filt_ori = copy.deepcopy(boxes_filt)
|
365 |
|
366 |
pred_dict = {
|
@@ -414,7 +417,7 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
|
|
414 |
logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_3_')
|
415 |
if task_type == 'detection' or task_type == 'segment':
|
416 |
logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_9_')
|
417 |
-
return output_images
|
418 |
elif task_type == 'inpainting' or task_type == 'remove':
|
419 |
if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
|
420 |
task_type = 'remove'
|
@@ -488,11 +491,11 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
|
|
488 |
os.remove(image_path)
|
489 |
logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_9_')
|
490 |
output_images.append(image_result)
|
491 |
-
return output_images
|
492 |
else:
|
493 |
logger.info(f"task_type:{task_type} error!")
|
494 |
logger.info(f'run_grounded_sam_[{file_temp}]_9_9_')
|
495 |
-
return output_images
|
496 |
|
497 |
def change_radio_display(task_type, mask_source_radio):
|
498 |
text_prompt_visible = True
|
@@ -524,7 +527,7 @@ if __name__ == "__main__":
|
|
524 |
mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
|
525 |
value=mask_source_segment, label="Mask from",
|
526 |
interactive=True, visible=False)
|
527 |
-
text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each name with '.'
|
528 |
inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
|
529 |
run_button = gr.Button(label="Run")
|
530 |
with gr.Accordion("Advanced options", open=False):
|
@@ -546,11 +549,11 @@ if __name__ == "__main__":
|
|
546 |
|
547 |
with gr.Column():
|
548 |
gallery = gr.Gallery(
|
549 |
-
label="
|
550 |
).style(grid=[2], full_width=True, full_height=True)
|
551 |
|
552 |
run_button.click(fn=run_grounded_sam, inputs=[
|
553 |
-
input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend], outputs=[gallery])
|
554 |
task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio])
|
555 |
mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio])
|
556 |
|
|
|
323 |
|
324 |
def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
325 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend):
|
326 |
+
text_prompt = text_prompt.strip()
|
327 |
+
if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
|
328 |
+
if text_prompt == '':
|
329 |
+
return [], gr.Gallery.update(label='Detection prompt is not found!')
|
330 |
|
331 |
file_temp = int(time.time())
|
332 |
logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_[{text_prompt}]_1_')
|
|
|
361 |
boxes_filt, pred_phrases = get_grounding_output(
|
362 |
groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
|
363 |
)
|
364 |
+
if boxes_filt.size(0) == 0:
|
365 |
+
logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_[{text_prompt}]_1_[No objects detected]_')
|
366 |
+
return [], gr.Gallery.update(label='No objects detected, please try others.')
|
367 |
boxes_filt_ori = copy.deepcopy(boxes_filt)
|
368 |
|
369 |
pred_dict = {
|
|
|
417 |
logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_3_')
|
418 |
if task_type == 'detection' or task_type == 'segment':
|
419 |
logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_9_')
|
420 |
+
return output_images, gr.Gallery.update(label='result images')
|
421 |
elif task_type == 'inpainting' or task_type == 'remove':
|
422 |
if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
|
423 |
task_type = 'remove'
|
|
|
491 |
os.remove(image_path)
|
492 |
logger.info(f'run_grounded_sam_[{file_temp}]_{task_type}_9_')
|
493 |
output_images.append(image_result)
|
494 |
+
return output_images, gr.Gallery.update(label='result images')
|
495 |
else:
|
496 |
logger.info(f"task_type:{task_type} error!")
|
497 |
logger.info(f'run_grounded_sam_[{file_temp}]_9_9_')
|
498 |
+
return output_images, gr.Gallery.update(label='result images')
|
499 |
|
500 |
def change_radio_display(task_type, mask_source_radio):
|
501 |
text_prompt_visible = True
|
|
|
527 |
mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
|
528 |
value=mask_source_segment, label="Mask from",
|
529 |
interactive=True, visible=False)
|
530 |
+
text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
|
531 |
inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
|
532 |
run_button = gr.Button(label="Run")
|
533 |
with gr.Accordion("Advanced options", open=False):
|
|
|
549 |
|
550 |
with gr.Column():
|
551 |
gallery = gr.Gallery(
|
552 |
+
label="result images", show_label=True, elem_id="gallery"
|
553 |
).style(grid=[2], full_width=True, full_height=True)
|
554 |
|
555 |
run_button.click(fn=run_grounded_sam, inputs=[
|
556 |
+
input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend], outputs=[gallery, gallery])
|
557 |
task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio])
|
558 |
mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio])
|
559 |
|