# NOTE: page-scrape residue, not part of the program. The hosting page showed
# "Spaces: Runtime error" when this file was captured.
import os | |
import time | |
import torch | |
import numpy as np | |
import gradio as gr | |
from segment_anything import build_sam, SamAutomaticMaskGenerator | |
from segment_anything.utils.amg import ( | |
batch_iterator, | |
MaskData, | |
calculate_stability_score, | |
batched_mask_to_box, | |
is_box_near_crop_edge, | |
) | |
# Fetch the SAM checkpoint that predict() passes to build_sam. The download is
# skipped when the file is already in the working directory, so a restart does
# not repeat it. A non-zero exit status is reported instead of being ignored;
# the app still starts, matching the original best-effort behavior.
if not os.path.isfile("sam_vit_h_4b8939.pth"):
    _download_status = os.system(r'python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth')
    if _download_status != 0:
        print(
            "WARNING: downloading sam_vit_h_4b8939.pth failed "
            f"(exit status {_download_status}); predictions will not work "
            "until the checkpoint is available."
        )
# Speed-mode presets. The keys are the labels offered in the "Speed Mode"
# dropdown; each value is the set of extra keyword arguments that predict()
# forwards to build_sam for that mode ("baseline" forwards none).
hourglass_args = {
    "baseline": dict(),
    "1.2x faster": dict(
        use_hourglass=True,
        hourglass_clustering_location=14,
        hourglass_num_cluster=100,
    ),
    "1.5x faster": dict(
        use_hourglass=True,
        hourglass_clustering_location=6,
        hourglass_num_cluster=81,
    ),
}
def generate_mask(image, generator: SamAutomaticMaskGenerator):
    """Segment ``image`` by prompting the generator's predictor with its point grid.

    Only the first point grid (``generator.point_grids[0]``) is used, on the
    whole image. Candidate masks are filtered by predicted IoU and by
    stability score, then binarized. No box computation or de-duplication is
    performed here.

    Masks from *every* batch of prompt points are collected. Previously the
    records were built from a single batch's data, so grids larger than
    ``generator.points_per_batch`` silently lost most of their masks.

    Args:
        image: Image array handed to ``generator.predictor.set_image``;
            ``image.shape[:2]`` is taken as its size (presumably
            height x width -- confirm against the caller).
        generator: Configured ``SamAutomaticMaskGenerator`` supplying the
            predictor, point grid, batch size and thresholds.

    Returns:
        A list of dicts, one per kept mask, with keys ``"segmentation"``
        (boolean NumPy array) and ``"area"`` (number of true pixels). The
        list is empty when no mask passes the filters.
    """
    predictor = generator.predictor
    predictor.set_image(image)
    image_size = image.shape[:2]

    # Scale the grid by the image size with its two entries swapped
    # (presumably turning normalized (x, y) points into pixel coordinates).
    points_scale = np.array(image_size)[None, ::-1]
    points_for_image = generator.point_grids[0] * points_scale

    mask_threshold = predictor.model.mask_threshold
    annotations = []

    for (points,) in batch_iterator(generator.points_per_batch, points_for_image):
        transformed_points = predictor.transform.apply_coords(points, image_size)
        in_points = torch.as_tensor(transformed_points, device=predictor.device)
        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
        masks, iou_preds, _ = predictor.predict_torch(
            in_points[:, None, :],
            in_labels[:, None],
            multimask_output=True,
            return_logits=True,
        )

        # Flatten the (point, mask-per-point) axes so each row is one
        # candidate mask; points are repeated to stay aligned with the rows.
        data = MaskData(
            masks=masks.flatten(0, 1),
            iou_preds=iou_preds.flatten(0, 1),
            points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
        )
        del masks

        # Filter by predicted IoU
        if generator.pred_iou_thresh > 0.0:
            data.filter(data["iou_preds"] > generator.pred_iou_thresh)

        # Filter by stability score
        data["stability_score"] = calculate_stability_score(
            data["masks"], mask_threshold, generator.stability_score_offset
        )
        if generator.stability_score_thresh > 0.0:
            data.filter(data["stability_score"] >= generator.stability_score_thresh)

        # Binarize the surviving mask logits and record them before the next
        # batch replaces ``data``. ``.cpu()`` keeps the NumPy conversion valid
        # when the predictor runs on a non-CPU device.
        binary_masks = data["masks"] > mask_threshold
        for mask in binary_masks:
            annotations.append(
                {
                    "segmentation": mask.cpu().numpy(),
                    "area": mask.sum().item(),
                }
            )
        del data, binary_masks

    return annotations
def predict(image, speed_mode, points_per_side):
    """Segment ``image`` and return a colored overlay plus a timing message.

    A SAM model is built for the selected speed mode on every call, masks are
    produced with ``generate_mask``, and each mask is painted with a random
    color blended 50/50 with the input image.

    Args:
        image: Input image as a NumPy array, or ``None`` when nothing was
            uploaded.
        speed_mode: Key of ``hourglass_args`` selecting the extra
            ``build_sam`` keyword arguments.
        points_per_side: Number of prompt points per grid side; converted
            with ``int``.

    Returns:
        A ``(image, text)`` pair on every path, matching the two output
        components the UI wires to this function. ``image`` is the overlay
        (``uint8``), or the unchanged input when no mask was found. ``text``
        reports the elapsed mask-generation time.
    """
    if image is None:
        # Nothing to segment: answer before paying for the model build.
        return None, "Please upload an image first."

    points_per_side = int(points_per_side)
    # Grids of up to 12 x 12 points go through in a single batch; larger
    # grids are processed 64 points at a time.
    if points_per_side > 12:
        points_per_batch = 64
    else:
        points_per_batch = points_per_side * points_per_side

    mask_generator = SamAutomaticMaskGenerator(
        build_sam(checkpoint="sam_vit_h_4b8939.pth", **hourglass_args[speed_mode]),
        points_per_side=points_per_side,
        points_per_batch=points_per_batch,
    )

    start = time.perf_counter()
    with torch.no_grad():
        masks = generate_mask(image, mask_generator)
    eta = time.perf_counter() - start
    eta_text = f"Time of generation: {eta:.2f} seconds"

    if not masks:
        # Return both outputs here too; a bare image does not match the two
        # output components bound to this callback.
        return image, eta_text

    # Paint the largest masks first so smaller ones end up on top.
    sorted_masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
    img = np.ones(image.shape)
    for mask in sorted_masks:
        m = mask['segmentation'][..., None]
        color_mask = np.random.random((1, 1, 3))
        img = img * (1 - m) + color_mask * m

    # Average the color layer (scaled to 0-255) with the original image.
    image = ((image + img * 255) / 2).astype(np.uint8)
    return image, eta_text
# Markdown rendered at the top of the demo page (see gr.Markdown in main()).
description = """
# <center>Expedit-SAM (Expedite Segment Anything Model without any training)</center>
Github link: [Link](https://github.com/Expedit-LargeScale-Vision-Transformer/Expedit-SAM)
You can select the speed mode you want to use from the "Speed Mode" dropdown menu and click "Run" to segment the image you uploaded to the "Input Image" box.
"""

# When the SPACE_ID environment variable is set, append a banner whose link
# duplicates that Space.
SPACE_ID = os.getenv('SPACE_ID')
if SPACE_ID is not None:
    description += f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
def main():
    """Build the Gradio interface, wire its callbacks and launch the demo.

    Layout: the description on top, then a row with the input controls
    (image, "Points per Side", "Speed Mode", Run/Clear buttons) next to the
    outputs (image and timing label), followed by example images.
    """
    default_points_per_side = 12
    default_speed_mode = "baseline"

    def _run_example(image):
        # Examples supply only an image, while predict needs all three
        # arguments and returns two values; run them with the default settings.
        return predict(image, default_speed_mode, default_points_per_side)

    with gr.Blocks() as demo:
        gr.Markdown(description)
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image")
                    points_per_side = gr.Dropdown(
                        choices=[4, 6, 8, 12, 16, 32],
                        value=default_points_per_side,
                        label="Points per Side",
                    )
                    speed_mode = gr.Dropdown(
                        choices=list(hourglass_args.keys()),
                        value=default_speed_mode,
                        label="Speed Mode",
                        multiselect=False,
                    )
                    with gr.Row():
                        # Button takes its caption via ``value`` and its DOM
                        # id via ``elem_id``; ``label``/``id`` are not Button
                        # parameters.
                        run_btn = gr.Button(value="Run", elem_id="run")
                        clear_btn = gr.Button(value="Clear", elem_id="clear")
                with gr.Column():
                    output_image = gr.Image(label="Output Image")
                    eta_label = gr.Label(label="ETA")
            gr.Examples(
                examples=[
                    ["./notebooks/images/dog.jpg"],
                    ["notebooks/images/groceries.jpg"],
                    ["notebooks/images/truck.jpg"],
                ],
                inputs=[input_image],
                outputs=[output_image, eta_label],
                fn=_run_example,
            )
        run_btn.click(
            fn=predict,
            inputs=[input_image, speed_mode, points_per_side],
            outputs=[output_image, eta_label],
        )
        # Clear also resets the timing label so no stale result stays visible.
        clear_btn.click(
            fn=lambda: [None, None, None],
            inputs=None,
            outputs=[input_image, output_image, eta_label],
            queue=False,
        )
    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()