yizhangliu
commited on
Commit
Β·
1ef3fca
1
Parent(s):
26b428d
update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ warnings.filterwarnings('ignore')
|
|
4 |
|
5 |
import subprocess, io, os, sys, time
|
6 |
|
7 |
-
os.system("pip install gradio==3.50.2")
|
8 |
|
9 |
import gradio as gr
|
10 |
from loguru import logger
|
@@ -123,6 +123,8 @@ ram_model = None
|
|
123 |
kosmos_model = None
|
124 |
kosmos_processor = None
|
125 |
|
|
|
|
|
126 |
def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
|
127 |
args = SLConfig.fromfile(model_config_path)
|
128 |
model = build_model(args)
|
@@ -593,6 +595,17 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
593 |
run_task_time = 0
|
594 |
time_cost_str = ''
|
595 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
596 |
|
597 |
if (task_type == 'Kosmos-2'):
|
598 |
global kosmos_model, kosmos_processor
|
@@ -605,20 +618,20 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
605 |
|
606 |
kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(image_pil, kosmos_input, kosmos_model, kosmos_processor)
|
607 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
608 |
-
return None, None, time_cost_str, kosmos_image, gr.
|
609 |
|
610 |
if (task_type == 'relate anything'):
|
611 |
output_images = relate_anything(input_image['image'], num_relation)
|
612 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
613 |
-
return output_images, gr.
|
614 |
|
615 |
text_prompt = text_prompt.strip()
|
616 |
if not ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw):
|
617 |
if text_prompt == '':
|
618 |
-
return [], gr.
|
619 |
|
620 |
if input_image is None:
|
621 |
-
return [], gr.
|
622 |
|
623 |
file_temp = int(time.time())
|
624 |
logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
|
@@ -661,7 +674,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
661 |
)
|
662 |
if boxes_filt.size(0) == 0:
|
663 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
|
664 |
-
return [], gr.
|
665 |
boxes_filt_ori = copy.deepcopy(boxes_filt)
|
666 |
|
667 |
pred_dict = {
|
@@ -726,7 +739,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
726 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
|
727 |
if task_type == 'detection' or task_type == 'segment':
|
728 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
729 |
-
return output_images, gr.
|
730 |
elif task_type in ['inpainting', 'outpainting'] or task_type == 'remove':
|
731 |
if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
|
732 |
task_type = 'remove'
|
@@ -804,11 +817,11 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
804 |
output_images.append(image_inpainting)
|
805 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
806 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
807 |
-
return output_images, gr.
|
808 |
else:
|
809 |
logger.info(f"task_type:{task_type} error!")
|
810 |
logger.info(f'run_anything_task_[{file_temp}]_9_9_')
|
811 |
-
return output_images, gr.
|
812 |
|
813 |
def change_radio_display(task_type, mask_source_radio):
|
814 |
text_prompt_visible = True
|
@@ -839,14 +852,14 @@ def change_radio_display(task_type, mask_source_radio):
|
|
839 |
text_prompt_visible = False
|
840 |
num_relation_visible = True
|
841 |
|
842 |
-
return (gr.
|
843 |
-
gr.
|
844 |
-
gr.
|
845 |
-
gr.
|
846 |
-
gr.
|
847 |
-
gr.
|
848 |
-
gr.
|
849 |
-
gr.
|
850 |
|
851 |
def get_model_device(module):
|
852 |
try:
|
@@ -883,9 +896,12 @@ def main_gradio(args):
|
|
883 |
task_types.append("relate anything")
|
884 |
if kosmos_enable:
|
885 |
task_types.append("Kosmos-2")
|
886 |
-
|
887 |
-
input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload",
|
888 |
-
|
|
|
|
|
|
|
889 |
task_type = gr.Radio(task_types, value="detection",
|
890 |
label='Task type', visible=True)
|
891 |
mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
|
@@ -897,7 +913,7 @@ def main_gradio(args):
|
|
897 |
|
898 |
kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
|
899 |
|
900 |
-
run_button = gr.Button(
|
901 |
with gr.Accordion("Advanced options", open=False) as advanced_options:
|
902 |
box_threshold = gr.Slider(
|
903 |
label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
|
@@ -917,7 +933,7 @@ def main_gradio(args):
|
|
917 |
|
918 |
with gr.Column():
|
919 |
image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
|
920 |
-
)
|
921 |
time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)
|
922 |
|
923 |
kosmos_output = gr.Image(type="pil", label="result images", visible=False)
|
@@ -926,9 +942,9 @@ def main_gradio(args):
|
|
926 |
combine_adjacent=False,
|
927 |
show_legend=True,
|
928 |
visible=False,
|
929 |
-
).style(color_map=color_map)
|
930 |
# record which text span (label) is selected
|
931 |
-
selected = gr.Number(-1, show_label=False,
|
932 |
|
933 |
# record the current `entities`
|
934 |
entity_output = gr.Textbox(visible=False)
|
|
|
4 |
|
5 |
import subprocess, io, os, sys, time
|
6 |
|
7 |
+
# os.system("pip install gradio==3.50.2")
|
8 |
|
9 |
import gradio as gr
|
10 |
from loguru import logger
|
|
|
123 |
kosmos_model = None
|
124 |
kosmos_processor = None
|
125 |
|
126 |
+
brush_color = "#00FFFF"
|
127 |
+
|
128 |
def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
|
129 |
args = SLConfig.fromfile(model_config_path)
|
130 |
model = build_model(args)
|
|
|
595 |
run_task_time = 0
|
596 |
time_cost_str = ''
|
597 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
598 |
+
|
599 |
+
logger.info(f"input_image==={input_image}")
|
600 |
+
if 'background' in input_image.keys():
|
601 |
+
input_image['image'] = input_image['background']
|
602 |
+
if len(input_image['layers']) > 0:
|
603 |
+
# input_image['mask'] = input_image['layers'][0] #brush_color
|
604 |
+
img_arr = np.array(input_image['layers'][0].convert("L"))
|
605 |
+
logger.info(f"img_arr==={img_arr.shape}, {img_arr[760][640]}, {img_arr[0][0]}")
|
606 |
+
img_arr = np.where(img_arr > 0, 1, img_arr)
|
607 |
+
# img_arr = 1 - img_arr
|
608 |
+
input_image['mask'] = Image.fromarray(255*img_arr.astype('uint8'))
|
609 |
|
610 |
if (task_type == 'Kosmos-2'):
|
611 |
global kosmos_model, kosmos_processor
|
|
|
618 |
|
619 |
kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(image_pil, kosmos_input, kosmos_model, kosmos_processor)
|
620 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
621 |
+
return None, None, time_cost_str, kosmos_image, gr.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
|
622 |
|
623 |
if (task_type == 'relate anything'):
|
624 |
output_images = relate_anything(input_image['image'], num_relation)
|
625 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
626 |
+
return output_images, gr.update(label='relate images'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
|
627 |
|
628 |
text_prompt = text_prompt.strip()
|
629 |
if not ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw):
|
630 |
if text_prompt == '':
|
631 |
+
return [], gr.update(label='Detection prompt is not found!ππππ'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
|
632 |
|
633 |
if input_image is None:
|
634 |
+
return [], gr.update(label='Please upload a image!ππππ'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
|
635 |
|
636 |
file_temp = int(time.time())
|
637 |
logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
|
|
|
674 |
)
|
675 |
if boxes_filt.size(0) == 0:
|
676 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
|
677 |
+
return [], gr.update(label='No objects detected, please try others.ππππ'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
|
678 |
boxes_filt_ori = copy.deepcopy(boxes_filt)
|
679 |
|
680 |
pred_dict = {
|
|
|
739 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
|
740 |
if task_type == 'detection' or task_type == 'segment':
|
741 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
742 |
+
return output_images, gr.update(label='result images'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
|
743 |
elif task_type in ['inpainting', 'outpainting'] or task_type == 'remove':
|
744 |
if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
|
745 |
task_type = 'remove'
|
|
|
817 |
output_images.append(image_inpainting)
|
818 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
819 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
820 |
+
return output_images, gr.update(label='result images'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
|
821 |
else:
|
822 |
logger.info(f"task_type:{task_type} error!")
|
823 |
logger.info(f'run_anything_task_[{file_temp}]_9_9_')
|
824 |
+
return output_images, gr.update(label='result images'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
|
825 |
|
826 |
def change_radio_display(task_type, mask_source_radio):
|
827 |
text_prompt_visible = True
|
|
|
852 |
text_prompt_visible = False
|
853 |
num_relation_visible = True
|
854 |
|
855 |
+
return (gr.update(visible=text_prompt_visible),
|
856 |
+
gr.update(visible=inpaint_prompt_visible),
|
857 |
+
gr.update(visible=mask_source_radio_visible),
|
858 |
+
gr.update(visible=num_relation_visible),
|
859 |
+
gr.update(visible=image_gallery_visible),
|
860 |
+
gr.update(visible=kosmos_input_visible),
|
861 |
+
gr.update(visible=kosmos_output_visible),
|
862 |
+
gr.update(visible=kosmos_text_output_visible))
|
863 |
|
864 |
def get_model_device(module):
|
865 |
try:
|
|
|
896 |
task_types.append("relate anything")
|
897 |
if kosmos_enable:
|
898 |
task_types.append("Kosmos-2")
|
899 |
+
|
900 |
+
# input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload",
|
901 |
+
# height=512, brush_color='#00FFFF', mask_opacity=0.6)
|
902 |
+
|
903 |
+
input_image = gr.ImageMask(sources='upload', elem_id="image_upload", type='pil', label="Upload",
|
904 |
+
brush=gr.Brush(colors=[brush_color], color_mode="fixed"))
|
905 |
task_type = gr.Radio(task_types, value="detection",
|
906 |
label='Task type', visible=True)
|
907 |
mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
|
|
|
913 |
|
914 |
kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
|
915 |
|
916 |
+
run_button = gr.Button(value="Run", visible=True)
|
917 |
with gr.Accordion("Advanced options", open=False) as advanced_options:
|
918 |
box_threshold = gr.Slider(
|
919 |
label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
|
|
|
933 |
|
934 |
with gr.Column():
|
935 |
image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
|
936 |
+
) #.style(preview=True, columns=[5], object_fit="scale-down", height="auto")
|
937 |
time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)
|
938 |
|
939 |
kosmos_output = gr.Image(type="pil", label="result images", visible=False)
|
|
|
942 |
combine_adjacent=False,
|
943 |
show_legend=True,
|
944 |
visible=False,
|
945 |
+
) # .style(color_map=color_map)
|
946 |
# record which text span (label) is selected
|
947 |
+
selected = gr.Number(-1, show_label=False, visible=False)
|
948 |
|
949 |
# record the current `entities`
|
950 |
entity_output = gr.Textbox(visible=False)
|