SEAGULL / demo /mask_utils.py
Zevin2023's picture
add online demo
8fa1f84
raw
history blame
5.13 kB
import cv2
from PIL import Image
import numpy as np
import torch
import gradio as gr
class ImageSketcher(gr.Image):
    """
    gradio.Image variant that always runs with tool='boxes' and accepts plain uploads.

    Adapted from
    https://github.com/jshilong/GPT4RoI/blob/7c157b5f33914f21cfbc804fb301d3ce06324193/gpt4roi/app.py#L365
    The upstream note describes this as a workaround for gradio.Image failing
    to upload when a drawing tool is active.
    """

    # Required for the subclass to work inside gradio Blocks; do not remove
    # unless you know what you're doing.
    is_template = True

    def __init__(self, **kwargs):
        """Forward every keyword argument to gr.Image, forcing tool='boxes'."""
        super().__init__(tool='boxes', **kwargs)

    def preprocess(self, x):
        """
        Normalise the incoming payload, then delegate to gr.Image.preprocess.

        For the 'boxes' tool with an 'upload' or 'webcam' source, a bare string
        payload is wrapped into {'image': <str>, 'boxes': []}; any other payload
        must already be a dict with a str 'image' and a list 'boxes'.
        None is returned unchanged.
        """
        if x is None:
            return None

        wants_box_payload = self.tool == 'boxes' and self.source in ('upload', 'webcam')
        if wants_box_payload:
            if not isinstance(x, str):
                # Already structured: sanity-check the expected layout.
                assert isinstance(x, dict)
                assert isinstance(x['image'], str)
                assert isinstance(x['boxes'], list)
            else:
                # Plain upload without any boxes drawn yet.
                x = {'image': x, 'boxes': []}

        return super().preprocess(x)
def process_mask_to_show(mask):
    """
    Turn a mask into a 3-channel black/white uint8 array for gradio.Image.

    Values strictly greater than 0.1 become 255, everything else 0, and the
    result is replicated along a new trailing axis of length 3.
    """
    binary = np.array(mask > 0.1, dtype=np.uint8)
    channel = binary * 255
    return np.stack((channel, channel, channel), axis=-1)
def img_add_masks(img_, colored_mask, mask, linewidth=2):
    """
    Paste a coloured mask layer and white mask outlines onto a copy of an image.

    Args:
        img_: Base image. A NumPy array (or ndarray subclass) is read as RGB
            data and converted to RGBA; any other object is duplicated with
            its own ``copy()`` and must expose ``height``, ``width`` and
            ``paste`` (e.g. a PIL image).
        colored_mask: Array that is cast to uint8 and interpreted as an RGBA
            layer; it is pasted using itself as the paste mask.
        mask: Array whose contours are extracted (after a uint8 cast) and
            drawn as the outline.
        linewidth: Thickness passed to ``cv2.drawContours``.

    Returns:
        The new image with both layers pasted on top.
    """
    # isinstance (rather than an exact type comparison) so ndarray subclasses
    # take the array branch instead of failing later on `.height`.
    if isinstance(img_, np.ndarray):
        canvas = Image.fromarray(img_, mode='RGB').convert('RGBA')
    else:
        canvas = img_.copy()

    # Draw the contours on a single-channel float canvas ...
    outline = np.zeros((canvas.height, canvas.width, 1))
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(outline, contours, -1, (255, 255, 255), linewidth)
    # ... then replicate it into all four RGBA channels, which makes the
    # outline opaque white and everything else fully transparent.
    contour_mask = outline * np.ones((1, 1, 4))

    # Fill layer first, outline second, so the outline stays on top.
    for layer in (colored_mask, contour_mask):
        overlay = Image.fromarray(layer.astype(np.uint8), 'RGBA')
        canvas.paste(overlay, (0, 0), overlay)
    return canvas
def gen_colored_masks(
    annotation,
    random_color=False,
):
    """
    Flatten a stack of masks into one RGBA overlay, smallest masks on top.

    Code is largely based on https://github.com/CASIA-IVA-Lab/FastSAM/blob/4d153e909f0ad9c8ecd7632566e5a24e21cf0071/utils/tools_gradio.py#L130

    Args:
        annotation: torch tensor indexed as (mask, row, column); a non-zero
            entry marks a pixel as belonging to that mask.
        random_color: when true every mask gets its own random RGB colour,
            otherwise all masks share (30, 144, 255) / 255.

    Returns:
        A tuple ``(overlay, sorted_indices)``. ``overlay`` is a NumPy array of
        shape (rows, columns, 4) holding, per pixel, the mask value times the
        RGBA colour (alpha 0.6) of the smallest-area mask covering it; pixels
        covered by no mask are all zeros. ``sorted_indices`` is the tensor
        that orders the input masks by ascending area.
    """
    device = annotation.device
    mask_count = annotation.shape[0]
    height = annotation.shape[1]
    width = annotation.shape[2]

    if mask_count == 0:
        # Nothing to draw: argmax over an empty mask axis would raise, so
        # return a fully transparent overlay and an empty ordering instead.
        empty_order = torch.empty(0, dtype=torch.long, device=device)
        return torch.zeros((height, width, 4)).numpy(), empty_order

    areas = torch.sum(annotation, dim=(1, 2))
    sorted_indices = torch.argsort(areas, descending=False)
    annotation = annotation[sorted_indices]

    # After the ascending sort, the first non-zero entry along the mask axis
    # is the smallest mask covering the pixel (index 0 if none covers it, in
    # which case the gathered value below is zero anyway).
    top_mask = (annotation != 0).to(torch.long).argmax(dim=0)

    if random_color:
        # Sampled on the CPU and moved afterwards, as before, so the values
        # keep coming from the CPU random generator.
        color = torch.rand((mask_count, 1, 1, 3)).to(device)
    else:
        color = torch.ones((mask_count, 1, 1, 3), device=device) * torch.tensor(
            [30 / 255, 144 / 255, 255 / 255], device=device
        )
    transparency = torch.ones((mask_count, 1, 1, 1), device=device) * 0.6
    visual = torch.cat([color, transparency], dim=-1)
    mask_image = torch.unsqueeze(annotation, -1) * visual

    # Pick, for every pixel and channel, the entry of the chosen mask. A
    # single gather along the mask axis replaces the meshgrid-based fancy
    # indexing and stays on `device`.
    gather_index = top_mask[None, :, :, None].expand(1, height, width, 4)
    mask = torch.gather(mask_image, 0, gather_index).squeeze(0).to(visual.dtype)

    mask_cpu = mask.cpu().numpy()
    return mask_cpu, sorted_indices
def mask_foreground(mask, trans=0.6, random_color=True):
    """
    Colour a mask: every pixel becomes its mask value times an RGBA colour.

    The colour is either three random values in [0, 255) or the fixed
    (30, 144, 255); its alpha component is ``trans * 255``. The mask's last
    two dimensions give the output height and width, and the result has
    shape (height, width, 4).
    """
    alpha = trans * 255
    if not random_color:
        rgba = np.array([30, 144, 255, alpha])
    else:
        rgba = np.concatenate((np.random.random(3) * 255, np.array([alpha])), axis=0)

    height, width = mask.shape[-2:]
    # (height, width, 1) broadcast against (1, 1, 4).
    return mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
def mask_background(mask, trans=0.5):
    """
    Build a black RGBA layer that covers everything outside the mask.

    Each pixel is ``(1 - mask value)`` times (0, 0, 0, trans * 255), so a mask
    value of 1 leaves the pixel fully transparent. The mask's last two
    dimensions give the output height and width; the result has shape
    (height, width, 4).
    """
    rows, cols = mask.shape[-2:]
    outside = 1 - mask.reshape(rows, cols, 1)
    shade = np.array([0, 0, 0, trans * 255])
    return outside * shade
def mask_select_point(all_masks, output_mask_2_raw, mask_order, evt: gr.SelectData):
    """
    Highlight the first mask (in ``mask_order``) that contains the clicked pixel.

    Args:
        all_masks: Indexable collection of 2-D mask arrays; an entry equal to
            1 marks a pixel as inside the mask.
        output_mask_2_raw: Image to draw on. It must expose ``height``,
            ``width``, ``copy()`` and ``paste()`` (e.g. a PIL image) and is
            not modified.
        mask_order: Indices into ``all_masks`` in the order they are tested.
        evt: gradio select event; ``evt.index[0]`` is used as the column and
            ``evt.index[1]`` as the row of the click.

    Returns:
        ``(image, mask)``: a copy of ``output_mask_2_raw`` with everything
        outside the selected mask darkened and the mask outlined in white,
        plus a copy of the selected mask. If no mask contains the click the
        original image and ``None`` are returned.
    """
    h, w = output_mask_2_raw.height, output_mask_2_raw.width

    pointed_mask = None
    for idx in mask_order:
        candidate = all_masks[idx]
        if candidate[evt.index[1], evt.index[0]] == 1:
            pointed_mask = candidate.copy()
            break

    if pointed_mask is None:
        return output_mask_2_raw, None

    ret = output_mask_2_raw.copy()

    # White outline of the selected mask: drawn on one channel, then
    # replicated into all four RGBA channels so it is pasted opaque.
    temp = np.zeros((h, w, 1))
    contours, _ = cv2.findContours(pointed_mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(temp, contours, -1, (255, 255, 255), 3)
    contour_mask = temp * np.ones((1, 1, 4))

    # Semi-transparent black over everything outside the mask (same
    # computation as mask_background with its default trans=0.5).
    colored_mask = (1 - pointed_mask.reshape(*pointed_mask.shape[-2:], 1)) * np.array([0, 0, 0, 0.5 * 255])

    # Darkening layer first, outline second, so the outline stays on top.
    for layer in (colored_mask, contour_mask):
        overlay = Image.fromarray(layer.astype(np.uint8), 'RGBA')
        ret.paste(overlay, (0, 0), overlay)
    return ret, pointed_mask