# Spaces: Runtime error
import codecs
import os
from argparse import ArgumentParser, Namespace
from typing import Dict, List, Tuple

import yaml
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, normalize, resize
import gradio as gr

from utils import get_model
from bilateral_solver import bilateral_solver_output
# KMP_DUPLICATE_LIB_OK is an Intel OpenMP runtime switch: 'True' lets the
# process continue when more than one copy of the OpenMP runtime gets loaded.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained weights, fetched by URL and mapped onto the selected device.
state_dict: dict = torch.hub.load_state_dict_from_url(
    "https://www.robots.ox.ac.uk/~vgg/research/selfmask/shared_files/selfmask_nq20.pt",
    map_location=device,
)

# The only command-line option is the path of the YAML model configuration.
parser = ArgumentParser("SelfMask demo")
parser.add_argument(
    "--config",
    type=str,
    default="duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml",
)
cli_args: Namespace = parser.parse_args()

# Read the YAML inside a context manager so the file handle is always closed.
with open(f"{cli_args.config}", 'r') as config_file:
    base_args = yaml.safe_load(config_file)
# "dataset_name" is dropped from the config before it is merged into the
# arguments handed to the model factory.
base_args.pop("dataset_name")

# Merge CLI arguments with the YAML entries (YAML wins on key clashes) and
# expose the result as a single Namespace named `args`.
merged_args: dict = vars(cli_args)
merged_args.update(base_args)
args: Namespace = Namespace(**merged_args)

# Build the network, load the pretrained weights and switch to eval mode.
model = get_model(arch="maskformer", configs=args).to(device)
model.load_state_dict(state_dict)
model.eval()

# `size` / `max_size` are the arguments given to torchvision's `resize`
# when an input image is preprocessed.
size: int = 384
max_size: int = 512

# Per-channel normalisation statistics (the values match the widely used
# ImageNet mean / std).
mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
def main(image: Image.Image):
    """Overlay the model's predicted salient-object mask on an input image.

    The image is resized, normalised and passed through the module-level
    ``model``. The query with the highest objectness score is selected, its
    mask is thresholded at 0.5, refined with ``bilateral_solver_output`` and
    colour-mapped (viridis), then blended 50/50 with the resized image.

    Args:
        image: The input PIL image. Non-RGB images are converted to RGB.

    Returns:
        The blended visualisation as a NumPy array (RGB channel order) at the
        resolution of the resized image.

    Raises:
        ValueError: If the model output has no ``"objectness"`` entry, since
            the best query cannot be chosen without it.
    """
    # The normalisation statistics and the final blend both assume three
    # colour channels, so bring grayscale / RGBA / palette inputs to RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")

    pil_image: Image.Image = resize(image, size=size, max_size=max_size)
    image_tensor: torch.Tensor = normalize(
        to_tensor(pil_image), mean=list(mean), std=list(std)
    )  # 3 x H x W

    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        dict_outputs = model(image_tensor[None].to(device))

    batch_pred_masks: torch.Tensor = dict_outputs["mask_pred"]  # [0, 1]
    batch_objectness = dict_outputs.get("objectness", None)  # [0, 1]
    if batch_objectness is None:
        # Fail with a clear message instead of an opaque "NoneType is not
        # subscriptable" further down.
        raise ValueError(
            "model output does not contain 'objectness'; cannot select a mask"
        )

    if len(batch_pred_masks.shape) == 5:
        # b x n_layers x n_queries x h x w -> b x n_queries x h x w
        # (keep only the output of the last decoder layer)
        batch_pred_masks = batch_pred_masks[:, -1, ...]
        # b x n_layers x n_queries x 1 -> b x n_queries x 1
        batch_objectness = batch_objectness[:, -1, ...]

    # resize prediction to original resolution
    # note: upsampling by 4 and cutting the padded region allows for a better result
    H, W = image_tensor.shape[-2:]
    batch_pred_masks = F.interpolate(
        batch_pred_masks, scale_factor=4, mode="bilinear", align_corners=False
    )[..., :H, :W]

    # Exactly one image was fed to the model, so only batch element 0 exists.
    pred_masks: torch.Tensor = batch_pred_masks[0]  # n_queries x H x W
    # n_queries x 1 -> n_queries
    objectness: torch.Tensor = batch_objectness[0].squeeze(dim=-1)
    ranks = torch.argsort(objectness, descending=True)  # n_queries

    # Binarise the mask of the top-ranked query into a {0, 255} uint8 map.
    best_mask: torch.Tensor = pred_masks[ranks[0]]  # H x W
    binary_mask: np.ndarray = (best_mask > 0.5).cpu().numpy().astype(np.uint8) * 255

    # Refine the mask against the image; the original author notes the solver
    # output is float64, hence the clip and cast back to uint8.
    refined_mask, _ = bilateral_solver_output(img=pil_image, target=binary_mask)
    refined_mask = np.clip(refined_mask, 0, 255).astype(np.uint8)

    # OpenCV colour maps are produced in BGR order; convert to RGB before
    # blending with the RGB input image.
    attn_map = cv2.cvtColor(
        cv2.applyColorMap(refined_mask, cv2.COLORMAP_VIRIDIS), cv2.COLOR_BGR2RGB
    )
    return cv2.addWeighted(attn_map, 0.5, np.array(pil_image), 0.5, 0)
# Base names (without extension) of the example images under resources/.
example_names = [
    "0053",
    "0236",
    "0239",
    "0403",
    "0412",
    "ILSVRC2012_test_00005309",
    "ILSVRC2012_test_00012622",
    "ILSVRC2012_test_00022698",
    "ILSVRC2012_test_00040725",
    "ILSVRC2012_test_00075738",
    "ILSVRC2012_test_00080683",
    "ILSVRC2012_test_00085874",
    "im052",
    "sun_ainjbonxmervsvpv",
    "sun_alfntqzssslakmss",
    "sun_amnrcxhisjfrliwa",
    "sun_bvyxpvkouzlfwwod",
]

# Read the HTML description inside a context manager so the file handle is
# closed once its contents have been loaded.
with codecs.open("description.html", 'r', "utf-8") as description_file:
    description_html = description_file.read()

# NOTE(review): the `gr.inputs` / `gr.outputs` namespaces are deprecated in
# newer Gradio releases and may be missing entirely -- confirm the Gradio
# version this app is pinned to before upgrading.
demo = gr.Interface(
    fn=main,
    inputs=gr.inputs.Image(type="pil", source="upload", tool="editor"),
    outputs=gr.outputs.Image(type="numpy", label="saliency map"),
    examples=[f"resources/{fname}.jpg" for fname in example_names],
    examples_per_page=20,
    description=description_html,
    title="Unsupervised Salient Object Detection with Spectral Cluster Voting",
    allow_flagging="never",
    analytics_enabled=False,
)

# Start the web UI (the author left `share=True` disabled).
demo.launch()