"""Gradio demo for MobileSAM ("Faster Segment Anything").

The UI currently exposes only the points mode: the user clicks foreground /
background points on an image and the model predicts a mask for them. The
"everything" mode is kept below as commented-out layout code.
"""

import os

import gradio as gr
import numpy as np
import torch
from PIL import ImageDraw

from mobile_sam import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from utils.tools import box_prompt, format_results, point_prompt
from utils.tools_gradio import fast_process

# Most of our demo code is from [FastSAM Demo](https://huggingface.co/spaces/An-619/FastSAM). Huge thanks for AN-619.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained model
sam_checkpoint = "./mobile_sam.pt"
model_type = "vit_t"

mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
mobile_sam = mobile_sam.to(device=device)
mobile_sam.eval()

mask_generator = SamAutomaticMaskGenerator(mobile_sam)
predictor = SamPredictor(mobile_sam)

# Description
# The title used to be a single-quoted literal broken across physical lines,
# which is a syntax error; the same characters are now spelled with escapes.
title = "\nFaster Segment Anything(MobileSAM)\n"

# NOTE(review): whitespace inside the two description strings is reproduced
# exactly as found; the original line breaks may have been lost upstream.
description_e = (
    "This is a demo of [Faster Segment Anything(MobileSAM) Model]"
    "(https://github.com/ChaoningZhang/MobileSAM). "
    "We will provide box mode soon. Enjoy! "
)

description_p = """ # Instructions for points mode You can use your own image or sample images. """

examples = [
    ["assets/picture1.jpg"],
    ["assets/picture2.jpg"],
]

default_example = examples[0]

css = (
    "h1 { text-align: center } "
    ".about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
)


def _resize_for_model(image, input_size):
    """Scale a PIL image by ``input_size / max(width, height)``.

    Returns:
        A tuple ``(resized_image, scale, new_w, new_h)`` where ``new_w`` and
        ``new_h`` are the truncated scaled dimensions.

    Raises:
        gr.Error: if ``image`` is ``None`` (nothing loaded in the UI).
    """
    if image is None:
        raise gr.Error("Please upload or select an image first.")
    input_size = int(input_size)
    w, h = image.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    return image.resize((new_w, new_h)), scale, new_w, new_h


@torch.no_grad()
def segment_everything(
    image,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    """Run the automatic mask generator on ``image`` and render the result.

    Args:
        image: PIL image to segment.
        input_size: target size used to compute the resize scale.
        better_quality, withContours, use_retina, mask_random_color:
            forwarded unchanged to ``fast_process``.

    Returns:
        Whatever ``fast_process`` returns for the generated annotations.

    Raises:
        gr.Error: if ``image`` is ``None``.
    """
    input_size = int(input_size)
    image, _, _, _ = _resize_for_model(image, input_size)

    nd_image = np.array(image)
    annotations = mask_generator.generate(nd_image)

    fig = fast_process(
        annotations=annotations,
        image=image,
        device=device,
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )
    return fig


def segment_with_points(
    image,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    """Predict a mask from the points collected by ``get_points_with_draw``.

    The clicked points are read from the module-level ``global_points`` /
    ``global_point_label`` lists, rescaled to the resized image, and passed to
    the SAM predictor. Both lists are emptied after a successful run.

    Args:
        image: PIL image currently shown in the input component.
        input_size: target size used to compute the resize scale.
        better_quality, withContours, use_retina, mask_random_color:
            forwarded unchanged to ``fast_process``.

    Returns:
        ``(fig, image)``: the rendering returned by ``fast_process`` and the
        resized input image.

    Raises:
        gr.Error: if no image is loaded or no point has been clicked yet.
    """
    global global_points
    global global_point_label

    if not global_points:
        raise gr.Error("Please click at least one point on the image first.")

    input_size = int(input_size)
    image, scale, new_w, new_h = _resize_for_model(image, input_size)

    scaled_points = np.array(
        [[int(x * scale) for x in point] for point in global_points]
    )
    # Keep the labels in a local array: rebinding the global list to an
    # ndarray would break the next ``.append`` if prediction raised below.
    point_labels = np.array(global_point_label)

    nd_image = np.array(image)
    predictor.set_image(nd_image)
    masks, scores, logits = predictor.predict(
        point_coords=scaled_points,
        point_labels=point_labels,
        multimask_output=True,
    )

    results = format_results(masks, scores, logits, 0)

    annotations, _ = point_prompt(
        results, scaled_points, point_labels, new_h, new_w
    )
    annotations = np.array([annotations])

    fig = fast_process(
        annotations=annotations,
        image=image,
        device=device,
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )

    # Start the next prompt from a clean slate.
    global_points = []
    global_point_label = []
    # return fig, None
    return fig, image


def get_points_with_draw(image, label, evt: gr.SelectData):
    """Record a clicked point and draw it on the image.

    Args:
        image: PIL image that was clicked; it is drawn on in place.
        label: radio value; ``"Add Mask"`` marks a foreground point (label 1),
            anything else a background point (label 0).
        evt: Gradio select event; ``evt.index`` holds the click ``(x, y)``.

    Returns:
        The same image with a filled circle at the clicked position.
    """
    global global_points
    global global_point_label

    x, y = evt.index[0], evt.index[1]
    is_foreground = label == "Add Mask"
    point_radius = 15
    point_color = (255, 255, 0) if is_foreground else (255, 0, 255)
    global_points.append([x, y])
    global_point_label.append(1 if is_foreground else 0)

    print(x, y, is_foreground)

    # Create an object that can draw on the image
    draw = ImageDraw.Draw(image)
    draw.ellipse(
        [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
        fill=point_color,
    )
    return image


cond_img_e = gr.Image(label="Input", value=default_example[0], type="pil")
cond_img_p = gr.Image(label="Input with points", value=default_example[0], type="pil")

segm_img_e = gr.Image(label="Segmented Image", interactive=False, type="pil")
segm_img_p = gr.Image(
    label="Segmented Image with points", interactive=False, type="pil"
)

# Click state shared between get_points_with_draw and segment_with_points.
# NOTE(review): these are module-level, so they are not per-session; consider
# gr.State if several users may use the demo at once.
global_points = []
global_point_label = []

input_size_slider = gr.components.Slider(
    minimum=512,
    maximum=1024,
    value=1024,
    step=64,
    label="Input_size",
    info="Our model was trained on a size of 1024",
)

with gr.Blocks(css=css, title="Faster Segment Anything(MobileSAM)") as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Title
            gr.Markdown(title)

    # with gr.Tab("Everything mode"):
    #     # Images
    #     with gr.Row(variant="panel"):
    #         with gr.Column(scale=1):
    #             cond_img_e.render()
    #
    #         with gr.Column(scale=1):
    #             segm_img_e.render()
    #
    #     # Submit & Clear
    #     with gr.Row():
    #         with gr.Column():
    #             input_size_slider.render()
    #
    #             with gr.Row():
    #                 contour_check = gr.Checkbox(
    #                     value=True,
    #                     label="withContours",
    #                     info="draw the edges of the masks",
    #                 )
    #
    #                 with gr.Column():
    #                     segment_btn_e = gr.Button(
    #                         "Segment Everything", variant="primary"
    #                     )
    #                     clear_btn_e = gr.Button("Clear", variant="secondary")
    #
    #             gr.Markdown("Try some of the examples below ⬇️")
    #             gr.Examples(
    #                 examples=examples,
    #                 inputs=[cond_img_e],
    #                 outputs=segm_img_e,
    #                 fn=segment_everything,
    #                 cache_examples=True,
    #                 examples_per_page=4,
    #             )
    #
    #         with gr.Column():
    #             with gr.Accordion("Advanced options", open=False):
    #                 # text_box = gr.Textbox(label="text prompt")
    #                 with gr.Row():
    #                     mor_check = gr.Checkbox(
    #                         value=False,
    #                         label="better_visual_quality",
    #                         info="better quality using morphologyEx",
    #                     )
    #                     with gr.Column():
    #                         retina_check = gr.Checkbox(
    #                             value=True,
    #                             label="use_retina",
    #                             info="draw high-resolution segmentation masks",
    #                         )
    #             # Description
    #             gr.Markdown(description_e)
    #
    # with gr.Tab("Points mode"):
    # Images
    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            cond_img_p.render()

        with gr.Column(scale=1):
            segm_img_p.render()

    # Submit & Clear
    with gr.Row():
        with gr.Column():
            with gr.Row():
                add_or_remove = gr.Radio(
                    ["Add Mask", "Remove Area"],
                    value="Add Mask",
                    label="Point_label (foreground/background)",
                )

                with gr.Column():
                    segment_btn_p = gr.Button(
                        "Segment with points prompt", variant="primary"
                    )
                    clear_btn_p = gr.Button("Clear points", variant="secondary")

            gr.Markdown("Try some of the examples below ⬇️")
            gr.Examples(
                examples=examples,
                inputs=[cond_img_p],
                # outputs=segm_img_p,
                # fn=segment_with_points,
                # cache_examples=True,
                examples_per_page=4,
            )

        with gr.Column():
            # Description
            gr.Markdown(description_p)

    cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)

    # segment_btn_e.click(
    #     segment_everything,
    #     inputs=[
    #         cond_img_e,
    #         input_size_slider,
    #         mor_check,
    #         contour_check,
    #         retina_check,
    #     ],
    #     outputs=segm_img_e,
    # )

    segment_btn_p.click(
        segment_with_points, inputs=[cond_img_p], outputs=[segm_img_p, cond_img_p]
    )

    def clear():
        """Reset both image components and forget any clicked points."""
        global global_points
        global global_point_label

        # Without this, points clicked on the cleared image would be reused
        # for whatever image is loaded next.
        global_points = []
        global_point_label = []
        return None, None

    def clear_text():
        """Return three ``None`` values to blank three output components."""
        return None, None, None

    # clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
    clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])

demo.queue()
demo.launch()