import gradio as gr
import numpy as np
from PIL import ImageDraw, Image
import torch
import torch.nn.functional as F
# mm libs
from mmdet.registry import MODELS
from mmengine import Config, print_log
from mmengine.structures import InstanceData
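# LVIS category names used to label the predicted masks.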
from ext.class_names.lvis_list import LVIS_CLASSES
LVIS_NAMES = LVIS_CLASSES
# Description
title = "
Open-Vocabulary SAM"
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
model_cfg = Config.fromfile('app/configs/sam_r50x16_fpn.py')
examples = [
["app/assets/sa_01.jpg"],
["app/assets/sa_224028.jpg"],
["app/assets/sa_227490.jpg"],
["app/assets/sa_228025.jpg"],
["app/assets/sa_234958.jpg"],
["app/assets/sa_235005.jpg"],
["app/assets/sa_235032.jpg"],
["app/assets/sa_235036.jpg"],
["app/assets/sa_235086.jpg"],
["app/assets/sa_235094.jpg"],
["app/assets/sa_235113.jpg"],
["app/assets/sa_235130.jpg"],
]
model = MODELS.build(model_cfg.model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device=device)
model = model.eval()
model.init_weights()
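# Per-channel ImageNet mean/std on the 0-255 scale, shaped (3, 1, 1) to broadcast over (C, H, W) tensors.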
mean = torch.tensor([123.675, 116.28, 103.53], device=device)[:, None, None]
std = torch.tensor([58.395, 57.12, 57.375], device=device)[:, None, None]
class IMGState:
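    """Per-session state: the resized image, cached backbone features, and the current prompts.

    `available` is True while no image has been set, i.e. the state is free to accept one.
    """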
def __init__(self):
self.img = None
self.img_feat = None
self.selected_points = []
self.selected_points_labels = []
self.selected_bboxes = []
self.available_to_set = True
def set_img(self, img, img_feat):
self.img = img
self.img_feat = img_feat
self.available_to_set = False
def clear(self):
self.img = None
self.img_feat = None
self.selected_points = []
self.selected_points_labels = []
self.selected_bboxes = []
self.available_to_set = True
def clean(self):
self.selected_points = []
self.selected_points_labels = []
self.selected_bboxes = []
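    # Move cached features between CPU and GPU so GPU memory is only held during inference.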
def to_device(self, device=device):
if self.img_feat is not None:
for k in self.img_feat:
if isinstance(self.img_feat[k], torch.Tensor):
self.img_feat[k] = self.img_feat[k].to(device)
elif isinstance(self.img_feat[k], tuple):
self.img_feat[k] = tuple(v.to(device) for v in self.img_feat[k])
@property
def available(self):
return self.available_to_set
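# SAM-style encoders expect a fixed square input; images are resized so the long side is IMG_SIZE, then zero-padded.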
IMG_SIZE = 1024
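# Gradio .select() callback: record the clicked point as a positive prompt and draw it on the displayed image.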
def get_points_with_draw(image, img_state, evt: gr.SelectData):
label = 'Add Mask'
x, y = evt.index[0], evt.index[1]
print_log(f"Point: {x}_{y}", logger='current')
point_radius, point_color = 10, (97, 217, 54) if label == "Add Mask" else (237, 34, 13)
img_state.selected_points.append([x, y])
img_state.selected_points_labels.append(1 if label == "Add Mask" else 0)
draw = ImageDraw.Draw(image)
draw.ellipse(
[(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
fill=point_color,
)
return img_state, image
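# Gradio .select() callback for box mode: two successive clicks define opposite corners of a box; a third click starts a new box.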
def get_bbox_with_draw(image, img_state, evt: gr.SelectData):
x, y = evt.index[0], evt.index[1]
point_radius, point_color, box_outline = 5, (237, 34, 13), 2
box_color = (237, 34, 13)
if len(img_state.selected_bboxes) in [0, 1]:
img_state.selected_bboxes.append([x, y])
elif len(img_state.selected_bboxes) == 2:
img_state.selected_bboxes = [[x, y]]
image = Image.fromarray(img_state.img)
else:
raise ValueError(f"Cannot be {len(img_state.selected_bboxes)}")
print_log(f"box_list: {img_state.selected_bboxes}", logger='current')
draw = ImageDraw.Draw(image)
draw.ellipse(
[(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
fill=point_color,
)
if len(img_state.selected_bboxes) == 2:
box_points = img_state.selected_bboxes
bbox = (min(box_points[0][0], box_points[1][0]),
min(box_points[0][1], box_points[1][1]),
max(box_points[0][0], box_points[1][0]),
max(box_points[0][1], box_points[1][1]),
)
draw.rectangle(
bbox,
outline=box_color,
width=box_outline
)
return img_state, image
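# Decode a mask from the accumulated point prompts and classify it against the LVIS vocabulary.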
def segment_with_points(
image,
img_state,
):
    if img_state.available:
        return None, None, "No image features found. Please upload an image first."
output_img = img_state.img
h, w = output_img.shape[:2]
input_points = torch.tensor(img_state.selected_points, dtype=torch.float32, device=device)
prompts = InstanceData(
point_coords=input_points[None],
)
try:
img_state.to_device()
masks, cls_pred = model.extract_masks(img_state.img_feat, prompts)
img_state.to_device('cpu')
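        # Crop away the padded region, binarize the mask, and take the top-1 class score.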
masks = masks[0, 0, :h, :w]
masks = masks > 0.5
cls_pred = cls_pred[0][0]
scores, indices = torch.topk(cls_pred, 1)
scores, indices = scores.tolist(), indices.tolist()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
img_state.clear()
print_log(f"CUDA OOM! please try again later", logger='current')
return None, None, "CUDA OOM, please try again later."
else:
raise
names = []
for ind in indices:
names.append(LVIS_NAMES[ind].replace('_', ' '))
cls_info = ""
for name, score in zip(names, scores):
cls_info += "{} ({:.2f})".format(name, score)
rgb_shape = tuple(list(masks.shape) + [3])
color = np.zeros(rgb_shape, dtype=np.uint8)
color[masks] = np.array([97, 217, 54])
# color[masks] = np.array([217, 90, 54])
output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
output_img = Image.fromarray(output_img)
return image, output_img, cls_info
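# Decode a mask from the two-corner box prompt and classify it against the LVIS vocabulary.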
def segment_with_bbox(
image,
img_state
):
    if img_state.available:
        return None, None, "No image features found. Please upload an image first."
if len(img_state.selected_bboxes) != 2:
return image, None, ""
output_img = img_state.img
h, w = output_img.shape[:2]
box_points = img_state.selected_bboxes
bbox = (
min(box_points[0][0], box_points[1][0]),
min(box_points[0][1], box_points[1][1]),
max(box_points[0][0], box_points[1][0]),
max(box_points[0][1], box_points[1][1]),
)
input_bbox = torch.tensor(bbox, dtype=torch.float32, device=device)
prompts = InstanceData(
bboxes=input_bbox[None],
)
try:
img_state.to_device()
masks, cls_pred = model.extract_masks(img_state.img_feat, prompts)
img_state.to_device('cpu')
masks = masks[0, 0, :h, :w]
masks = masks > 0.5
cls_pred = cls_pred[0][0]
scores, indices = torch.topk(cls_pred, 1)
scores, indices = scores.tolist(), indices.tolist()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
img_state.clear()
print_log(f"CUDA OOM! please try again later", logger='current')
return None, None, "CUDA OOM, please try again later."
else:
raise
names = []
for ind in indices:
names.append(LVIS_NAMES[ind].replace('_', ' '))
cls_info = ""
for name, score in zip(names, scores):
cls_info += "{} ({:.2f})\n".format(name, score)
rgb_shape = tuple(list(masks.shape) + [3])
color = np.zeros(rgb_shape, dtype=np.uint8)
color[masks] = np.array([97, 217, 54])
# color[masks] = np.array([217, 90, 54])
output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
output_img = Image.fromarray(output_img)
return image, output_img, cls_info
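# Resize, normalize, and pad the uploaded image, then run the encoder once and cache its features
# so point/box prompts can be decoded without re-running the backbone.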
def extract_img_feat(img, img_state):
w, h = img.size
scale = IMG_SIZE / max(w, h)
new_w = int(w * scale)
new_h = int(h * scale)
img = img.resize((new_w, new_h), resample=Image.Resampling.BILINEAR)
img_numpy = np.array(img)
print_log(f"Successfully loaded an image with size {new_w} x {new_h}", logger='current')
try:
img_tensor = torch.tensor(img_numpy, device=device, dtype=torch.float32).permute((2, 0, 1))[None]
img_tensor = (img_tensor - mean) / std
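        # F.pad takes (left, right, top, bottom): pad only on the right and bottom up to IMG_SIZE.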
img_tensor = F.pad(img_tensor, (0, IMG_SIZE - new_w, 0, IMG_SIZE - new_h), 'constant', 0)
feat_dict = model.extract_feat(img_tensor)
img_state.set_img(img_numpy, feat_dict)
img_state.to_device('cpu')
print_log(f"Successfully generated the image feats.", logger='current')
except RuntimeError as e:
if "CUDA out of memory" in str(e):
img_state.clear()
print_log(f"CUDA OOM! please try again later", logger='current')
return None, None, "CUDA OOM, please try again later."
else:
raise
return img, None, "Please try to click something."
def clear_everything(img_state):
img_state.clear()
    return img_state, None, None, "Please click on the image to add a prompt."
def clean_prompts(img_state):
    img_state.clean()
    if img_state.img is None:
        img_state.clear()
        return img_state, None, None, "Please click on the image to add a prompt."
    return img_state, Image.fromarray(img_state.img), None, "Please click on the image to add a prompt."
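# Build the two-tab UI (point mode and box mode) and wire up all event handlers.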
def register_demo():
img_state_points = gr.State(value=IMGState())
img_state_bbox = gr.State(value=IMGState())
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(title)
# Point mode tab
with gr.Tab("Point mode"):
with gr.Row(variant="panel"):
with gr.Column(scale=1):
cond_img_p = gr.Image(label="Input Image", height=512, type="pil")
with gr.Column(scale=1):
segm_img_p = gr.Image(label="Segment", interactive=False, height=512, type="pil")
with gr.Row():
with gr.Column():
with gr.Row():
with gr.Column():
clean_btn_p = gr.Button("Clean Prompts", variant="secondary")
clear_btn_p = gr.Button("Restart", variant="secondary")
with gr.Column():
cls_info = gr.Textbox("", label='Labels')
with gr.Row():
with gr.Column():
gr.Markdown("Try some of the examples below ⬇️")
gr.Examples(
examples=examples,
inputs=[cond_img_p, img_state_points],
outputs=[cond_img_p, segm_img_p, cls_info],
examples_per_page=12,
fn=extract_img_feat,
run_on_click=True,
cache_examples=False,
)
    # Box mode tab
with gr.Tab("Box mode"):
with gr.Row(variant="panel"):
with gr.Column(scale=1):
cond_img_bbox = gr.Image(label="Input Image", height=512, type="pil")
with gr.Column(scale=1):
segm_img_bbox = gr.Image(label="Segment", interactive=False, height=512, type="pil")
with gr.Row():
with gr.Column():
with gr.Row():
with gr.Column():
clean_btn_bbox = gr.Button("Clean Prompts", variant="secondary")
clear_btn_bbox = gr.Button("Restart", variant="secondary")
with gr.Column():
cls_info_bbox = gr.Textbox("", label='Labels')
with gr.Row():
with gr.Column():
gr.Markdown("Try some of the examples below ⬇️")
gr.Examples(
examples=examples,
inputs=[cond_img_bbox, img_state_bbox],
outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox],
examples_per_page=12,
fn=extract_img_feat,
run_on_click=True,
cache_examples=False,
)
# extract image feature
cond_img_p.upload(
extract_img_feat,
[cond_img_p, img_state_points],
outputs=[cond_img_p, segm_img_p, cls_info]
)
cond_img_bbox.upload(
extract_img_feat,
[cond_img_bbox, img_state_bbox],
        outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox]
)
# get user added points
cond_img_p.select(
get_points_with_draw,
[cond_img_p, img_state_points],
outputs=[img_state_points, cond_img_p]
).then(
segment_with_points,
inputs=[cond_img_p, img_state_points],
outputs=[cond_img_p, segm_img_p, cls_info]
)
cond_img_bbox.select(
get_bbox_with_draw,
[cond_img_bbox, img_state_bbox],
outputs=[img_state_bbox, cond_img_bbox]
).then(
segment_with_bbox,
inputs=[cond_img_bbox, img_state_bbox],
outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox]
)
# clean prompts
clean_btn_p.click(
clean_prompts,
inputs=[img_state_points],
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
)
clean_btn_bbox.click(
clean_prompts,
inputs=[img_state_bbox],
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
)
# clear
clear_btn_p.click(
clear_everything,
inputs=[img_state_points],
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
)
cond_img_p.clear(
clear_everything,
inputs=[img_state_points],
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
)
segm_img_p.clear(
clear_everything,
inputs=[img_state_points],
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
)
clear_btn_bbox.click(
clear_everything,
inputs=[img_state_bbox],
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
)
cond_img_bbox.clear(
clear_everything,
inputs=[img_state_bbox],
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
)
segm_img_bbox.clear(
clear_everything,
inputs=[img_state_bbox],
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
)
if __name__ == '__main__':
with gr.Blocks(css=css, title="Open-Vocabulary SAM") as demo:
        register_demo()
demo.queue()
demo.launch(show_api=False)