import os

# set before the scientific imports so the OpenMP duplicate-library workaround
# takes effect when numpy/torch load their runtimes
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from argparse import ArgumentParser, Namespace
from typing import Tuple
import codecs
import yaml
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, normalize, resize
import gradio as gr
from utils import get_network, colourise_mask
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# fetch the released VOC2012 checkpoint from the authors' server
state_dict: dict = torch.hub.load_state_dict_from_url(
    "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt",
    map_location=device
)["model"]
parser = ArgumentParser("NamedMask demo")
parser.add_argument(
    "--config",
    type=str,
    default="voc_val_n500_cp2_ex.yaml"
)
args: Namespace = parser.parse_args()

# merge the YAML config into the CLI arguments; the CLI itself only takes --config
with open(args.config, "r") as f:
    base_args: dict = yaml.safe_load(f)
base_args.pop("dataset_name")
args: dict = vars(args)
args.update(base_args)
args: Namespace = Namespace(**args)
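# For illustration only (hypothetical keys, not the shipped YAML): a config such
# as {"dataset_name": "voc2012", "n_categories": 21} would, after the pop above,
# leave args.config and args.n_categories on the merged Namespace.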
model = get_network().to(device)
model.load_state_dict(state_dict)
model.eval()

# preprocessing constants: shorter side resized to 384 px (longer side capped at
# 512 px), followed by ImageNet normalisation
size: int = 384
max_size: int = 512
mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
@torch.no_grad()
def main(image: Image.Image) -> np.ndarray:
    # force 3 channels so greyscale/RGBA uploads do not break to_tensor
    pil_image: Image.Image = resize(image.convert("RGB"), size=size, max_size=max_size)
    image: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std))  # 3 x H x W
    # logits: b (=1) x n_categories x H x W, torch.float32
    logits: torch.Tensor = model(image[None].to(device))
    # pred: H x W, integer category indices
    pred: np.ndarray = logits.squeeze(dim=0).argmax(dim=0).cpu().numpy()
    coloured_pred: np.ndarray = colourise_mask(mask=pred)
    # blend the colour-coded mask with the (resized) input image, 50/50
    super_imposed_img = cv2.addWeighted(coloured_pred, 0.5, np.array(pil_image), 0.5, 0)
    # note: upsampling the prediction to the original resolution (by 4, cutting
    # the padded region) would give a better result; the demo returns the blend
    # at the resized resolution instead
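    # Hedged sketch of that upsampling step (not wired into the demo; bilinear
    # interpolation of the soft logits is an assumption, and `orig_size` would
    # need to be captured as image.size before the resize above):
    # logits_up: torch.Tensor = F.interpolate(
    #     logits, size=orig_size[::-1], mode="bilinear", align_corners=False
    # )
    # pred_full: np.ndarray = logits_up.squeeze(dim=0).argmax(dim=0).cpu().numpy()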
    return super_imposed_img
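# Local smoke test (hedged; the path is one of the example images below):
# out = main(Image.open("images/2007_002260.jpg"))
# Image.fromarray(out).save("prediction.png")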
demo = gr.Interface(
    fn=main,
    # gr.inputs / gr.outputs are deprecated; the components below are the
    # Gradio 3+ equivalents of the original Image(type="pil") input and
    # Image(type="numpy") output
    inputs=gr.Image(type="pil", label="input"),
    outputs=gr.Image(type="numpy", label="prediction"),
    examples=[f"images/{fname}.jpg" for fname in [
        "2007_002260",
        "2008_002536",
        "2008_003499",
        "2008_007814",
        "2010_001079",
        "2010_005063"
    ]],
    examples_per_page=10,
    description=codecs.open("description.html", 'r', "utf-8").read(),
    title="NamedMask: Distilling Segmenters from Complementary Foundation Models",
    allow_flagging="never",
    analytics_enabled=False
)
demo.launch(
    # share=True
)
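# Optional (not part of the original demo): launch() also accepts server
# settings for remote hosting, e.g.
# demo.launch(server_name="0.0.0.0", server_port=7860)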