# namedmask / app.py
# Last commit by noelshin: 41bfd36 "change sdk_version"
from argparse import ArgumentParser, Namespace
from typing import Dict, List, Tuple
import codecs
import yaml
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, normalize, resize
import gradio as gr
from utils import get_network, colourise_mask
import os
# Tell the Intel OpenMP runtime not to abort when it finds a second copy of
# itself already loaded (KMP_DUPLICATE_LIB_OK).
# NOTE(review): this runs after torch/numpy/cv2 are imported above; confirm it
# is still early enough to take effect.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): checkpoint loading is commented out here and at
# `model.load_state_dict` below, so the model runs with whatever weights
# `get_network()` gives it. Confirm that `get_network` loads the trained
# NamedMask weights; otherwise the demo predictions are meaningless.
# state_dict: dict = torch.hub.load_state_dict_from_url(
#     "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt",
#     map_location=device  # "cuda" if torch.cuda.is_available() else "cpu"
# )["model"]

parser = ArgumentParser("NamedMask demo")
parser.add_argument(
    "--config",
    type=str,
    default="voc_val_n500_cp2_ex.yaml"
)
args: Namespace = parser.parse_args()

# Load the YAML config. The `with` block closes the file; `or {}` covers an
# empty YAML file, for which safe_load returns None.
with open(args.config, "r", encoding="utf-8") as config_file:
    base_args: dict = yaml.safe_load(config_file) or {}
# "dataset_name" is dropped from the config before merging. A missing key is
# tolerated rather than raising KeyError.
base_args.pop("dataset_name", None)
# Merge the config on top of the CLI arguments; YAML values win on key clashes.
args = Namespace(**{**vars(args), **base_args})

# Build the network, move it to `device` and switch it to inference mode.
model = get_network().to(device)
# model.load_state_dict(state_dict)
model.eval()
# Pre-processing settings read by `main`.
# torchvision `resize` is called with size=size and max_size=max_size: the
# shorter image edge is scaled to `size`, with the longer edge capped at
# `max_size`.
max_size: int = 512
size: int = 384

# Per-channel RGB normalisation statistics handed to `normalize` in `main`
# (the values match the widely used ImageNet mean/std).
std: Tuple[float, float, float] = (
    0.229, 0.224, 0.225,
)
mean: Tuple[float, float, float] = (
    0.485, 0.456, 0.406,
)
@torch.no_grad()
def main(image: Image.Image) -> np.ndarray:
    """Segment ``image`` with ``model`` and return the prediction overlaid on it.

    Processing steps:

    1. Convert the image to RGB.
    2. Resize it with torchvision ``resize``, using the module constants
       ``size``/``max_size``.
    3. Normalise it with ``mean``/``std`` and run it through the module-level
       ``model`` without autograd (``torch.no_grad``).
    4. Colourise the per-pixel argmax with ``colourise_mask``.
    5. Alpha-blend the mask 50/50 over the resized image.

    Args:
        image: Input PIL image, as delivered by the Gradio ``type="pil"`` input.

    Returns:
        The blended overlay produced by ``cv2.addWeighted``. Its spatial size is
        that of the resized image, not of the original upload.
    """
    # Uploads may be RGBA / greyscale / palette images. The 3-element mean/std
    # and the blend with the (presumably 3-channel) colour mask both need RGB.
    pil_image: Image.Image = resize(image.convert("RGB"), size=size, max_size=max_size)

    # 1 x 3 x H x W float tensor, normalised per channel.
    batch: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std))[None]

    # logits: b (=1) x n_categories x h x w, torch.float32
    logits: torch.Tensor = model(batch.to(device))
    # If the network predicts at a different resolution, upsample the logits so
    # the mask lines up with `pil_image`. When the sizes already match, this is
    # skipped entirely.
    if logits.shape[-2:] != batch.shape[-2:]:
        logits = F.interpolate(
            logits, size=tuple(batch.shape[-2:]), mode="bilinear", align_corners=False
        )

    # pred: H x W array of per-pixel category indices.
    pred: np.ndarray = logits.squeeze(dim=0).argmax(dim=0).cpu().numpy()
    coloured_pred: np.ndarray = colourise_mask(mask=pred)

    # Equal-weight blend of the colourised prediction and the resized input.
    return cv2.addWeighted(coloured_pred, 0.5, np.array(pil_image), 0.5, 0)
# Read the HTML description shown on the demo page. The context manager closes
# the file; the previous `codecs.open(...).read()` left the handle open.
with open("description.html", "r", encoding="utf-8") as description_file:
    description: str = description_file.read()

# NOTE(review): `gr.inputs` / `gr.outputs` are the legacy Gradio component
# namespaces. They were deprecated in Gradio 3 and removed in Gradio 4, so this
# script needs a Gradio 3.x runtime.
demo = gr.Interface(
    fn=main,
    inputs=gr.inputs.Image(type="pil", source="upload", tool="editor"),
    outputs=gr.outputs.Image(type="numpy", label="prediction"),  # "image",
    examples=[f"images/{fname}.jpg" for fname in [
        "2007_002260",
        "2008_002536",
        "2008_003499",
        "2008_007814",
        "2010_001079",
        "2010_005063"
    ]],
    examples_per_page=10,
    description=description,
    title="NamedMask: Distilling Segmenters from Complementary Foundation Models",
    allow_flagging="never",
    analytics_enabled=False
)
# Start the Gradio server (blocking call).
demo.launch(
    # share=True
)