Spaces:
Runtime error
Runtime error
File size: 4,305 Bytes
35188e4 7b03ec2 35188e4 b3dac8e 35188e4 200320e 35188e4 7b03ec2 35188e4 7b03ec2 35188e4 7b03ec2 35188e4 7b03ec2 35188e4 7b03ec2 35188e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
from argparse import ArgumentParser, Namespace
from typing import Dict, List, Tuple
import codecs
import yaml
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, normalize, resize
import gradio as gr
from utils import get_model
from bilateral_solver import bilateral_solver_output
import os
# Let the process continue when more than one OpenMP runtime is loaded instead of aborting.
# NOTE(review): this variable is consulted when the OpenMP runtime initialises; setting it
# after `import torch` / `import cv2` above may be too late to have an effect — confirm,
# or move it above those imports.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained SelfMask weights, mapped straight onto the inference device.
# torch.hub caches the download, so only the first start-up hits the network.
state_dict: dict = torch.hub.load_state_dict_from_url(
    "https://www.robots.ox.ac.uk/~vgg/research/selfmask/shared_files/selfmask_nq20.pt",
    map_location=device,
)

parser = ArgumentParser("SelfMask demo")
parser.add_argument(
    "--config",
    type=str,
    default="duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml",
)
cli_args: Namespace = parser.parse_args()

# Load the YAML model config; the context manager guarantees the file is closed.
with open(cli_args.config, "r", encoding="utf-8") as config_file:
    base_args = yaml.safe_load(config_file)
# `dataset_name` is not forwarded to the model config; tolerate configs that omit it.
base_args.pop("dataset_name", None)
# Merge CLI and YAML options; YAML values win on key collisions (same as dict.update).
args: Namespace = Namespace(**{**vars(cli_args), **base_args})

model = get_model(arch="maskformer", configs=args).to(device)
model.load_state_dict(state_dict)
model.eval()  # inference mode (e.g. dropout / batch-norm use eval behaviour)

# Preprocessing parameters used by `main`: torchvision resize arguments and the
# per-channel normalisation statistics (the standard ImageNet mean/std values).
size: int = 384
max_size: int = 512
mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
@torch.no_grad()
def main(image: Image.Image):
    """Predict the most salient object in *image* and overlay its mask on the image.

    Steps:

    1. Convert the image to RGB.
    2. Resize it with torchvision ``resize(size=size, max_size=max_size)``.
    3. Normalise it with the module-level ``mean``/``std`` and run the module-level ``model``.
    4. Select the query with the highest objectness score.
    5. Binarise that query's mask at 0.5 and refine it with the bilateral solver.
    6. Colour-map it with VIRIDIS and alpha-blend it 50/50 with the resized image.

    Args:
        image: Input image from the Gradio upload widget (``type="pil"``), or
            ``None`` when nothing was uploaded.

    Returns:
        An RGB ``uint8`` array (size of the resized image) with the saliency
        map overlaid, or ``None`` when *image* is ``None``.

    Raises:
        RuntimeError: If the model output has no ``"objectness"`` scores, which
            are required to choose which mask query to display.
    """
    if image is None:
        # Nothing to process (e.g. the form was submitted without an upload).
        return None

    # Force 3 channels: an RGBA or greyscale upload would yield a tensor that cannot be
    # normalised with the 3-element mean/std, and cv2.addWeighted below needs the
    # image to have the same channel count as the 3-channel colour map.
    pil_image: Image.Image = resize(image.convert("RGB"), size=size, max_size=max_size)
    x: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std))  # 3 x H x W

    dict_outputs = model(x[None].to(device))
    batch_pred_masks: torch.Tensor = dict_outputs["mask_pred"]  # [0, 1]
    batch_objectness = dict_outputs.get("objectness", None)  # [0, 1]
    if batch_objectness is None:
        raise RuntimeError(
            "model output has no 'objectness' scores; cannot select a mask query"
        )

    if batch_pred_masks.dim() == 5:
        # Keep only the output of the last decoder layer.
        # b x n_layers x n_queries x h x w -> b x n_queries x h x w
        batch_pred_masks = batch_pred_masks[:, -1, ...]
        # b x n_layers x n_queries x 1 -> b x n_queries x 1
        batch_objectness = batch_objectness[:, -1, ...]

    # Bring the masks back to the resized input's resolution. Upsampling by 4 and
    # cropping the padded region gives a better result than resizing straight to H x W.
    H, W = x.shape[-2:]
    batch_pred_masks = F.interpolate(
        batch_pred_masks, scale_factor=4, mode="bilinear", align_corners=False
    )[..., :H, :W]

    # The batch holds exactly one image (x[None] above).
    pred_masks = batch_pred_masks[0]  # n_queries x H x W
    objectness = batch_objectness[0].squeeze(dim=-1)  # n_queries
    best_query = int(torch.argmax(objectness))

    pred_mask: np.ndarray = (pred_masks[best_query] > 0.5).cpu().numpy().astype(np.uint8) * 255
    pred_mask_bi, _ = bilateral_solver_output(img=pil_image, target=pred_mask)  # float64
    pred_mask_bi = np.clip(pred_mask_bi, 0, 255).astype(np.uint8)

    # applyColorMap produces BGR; convert to RGB so it matches the PIL image channels.
    attn_map = cv2.cvtColor(cv2.applyColorMap(pred_mask_bi, cv2.COLORMAP_VIRIDIS), cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(attn_map, 0.5, np.array(pil_image), 0.5, 0)
# HTML shown above the demo; read once at start-up and closed immediately.
with open("description.html", "r", encoding="utf-8") as description_file:
    description_html = description_file.read()

# NOTE(review): `gr.inputs` / `gr.outputs` and the `source=` / `tool=` arguments are the
# legacy Gradio component API (deprecated in Gradio 3.x, removed in 4.x); confirm the
# Gradio version pinned for this Space before upgrading.
demo = gr.Interface(
    fn=main,
    inputs=gr.inputs.Image(type="pil", source="upload", tool="editor"),
    outputs=gr.outputs.Image(type="numpy", label="saliency map"),
    examples=[f"resources/{fname}.jpg" for fname in [
        "0053",
        "0236",
        "0239",
        "0403",
        "0412",
        "ILSVRC2012_test_00005309",
        "ILSVRC2012_test_00012622",
        "ILSVRC2012_test_00022698",
        "ILSVRC2012_test_00040725",
        "ILSVRC2012_test_00075738",
        "ILSVRC2012_test_00080683",
        "ILSVRC2012_test_00085874",
        "im052",
        "sun_ainjbonxmervsvpv",
        "sun_alfntqzssslakmss",
        "sun_amnrcxhisjfrliwa",
        "sun_bvyxpvkouzlfwwod",
    ]],
    examples_per_page=20,
    description=description_html,
    title="Unsupervised Salient Object Detection with Spectral Cluster Voting",
    allow_flagging="never",
    analytics_enabled=False,
)
demo.launch()