|
from argparse import ArgumentParser, Namespace |
|
from typing import Dict, List, Tuple |
|
import codecs |
|
import yaml |
|
import numpy as np |
|
import cv2 |
|
from PIL import Image |
|
import torch |
|
import torch.nn.functional as F |
|
from torchvision.transforms.functional import to_tensor, normalize, resize |
|
import gradio as gr |
|
from utils import get_network, colourise_mask |
|
import os |
|
|
|
# NOTE(review): presumably a workaround for the Intel OpenMP "duplicate
# runtime" abort when several libraries bundle their own copy — confirm.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
|
|
|
|
|
# Command-line interface: a single --config option naming a YAML file whose
# top-level mapping is merged into the parsed arguments below.
parser = ArgumentParser("NamedMask demo")
parser.add_argument(
    "--config",
    type=str,
    default="voc_val_n500_cp2_ex.yaml",
)
args: Namespace = parser.parse_args()

# Use a context manager so the config file handle is closed promptly.
# safe_load (rather than yaml.load) avoids constructing arbitrary Python objects.
# An empty YAML file yields None, so fall back to an empty mapping.
with open(args.config, "r") as config_file:
    base_args: dict = yaml.safe_load(config_file) or {}

# "dataset_name" is excluded from the merged settings; tolerate its absence
# instead of raising KeyError.
base_args.pop("dataset_name", None)

# YAML values override CLI values on key clashes (same precedence as a
# dict.update of the CLI namespace with the YAML mapping).
args = Namespace(**{**vars(args), **base_args})
|
|
|
# Build the segmentation network and move its parameters to the selected device.
model = get_network()
model = model.to(device)

# Switch to inference mode (assuming get_network returns a torch.nn.Module,
# this disables dropout and makes batch-norm use its running statistics).
model.eval()
|
|
|
# Resize policy passed to torchvision's `resize`: the shorter image side
# becomes `size`, and the longer side is capped at `max_size`.
size: int = 384
max_size: int = 512

# Per-channel (R, G, B) normalisation statistics; these are the widely used
# ImageNet mean / standard deviation values.
mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
|
|
|
|
|
@torch.no_grad()
def main(image: Image.Image) -> np.ndarray:
    """Segment ``image`` and return the prediction blended over it.

    Args:
        image: Input picture supplied by the Gradio ``Image`` component
            (configured with ``type="pil"``).

    Returns:
        A 50/50 blend of the colourised predicted mask and the resized input
        image, at the resized resolution.
    """
    # Force three channels: RGBA / greyscale / palette inputs would otherwise
    # break the 3-element mean/std normalisation and the overlay below.
    # For an image that is already RGB this leaves the pixel values unchanged.
    pil_image: Image.Image = resize(image.convert("RGB"), size=size, max_size=max_size)

    # (3, H, W) float tensor in [0, 1], then per-channel normalisation.
    tensor: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std))

    # Add a batch dimension before running the network.
    logits: torch.Tensor = model(tensor[None].to(device))

    # If the network predicts at a lower resolution, resample the (N, C, h, w)
    # logits to the input size so the mask aligns with the image; otherwise
    # cv2.addWeighted would fail on mismatched shapes. Skipped when the sizes
    # already match or the output is not 4-D.
    if logits.dim() == 4 and tuple(logits.shape[-2:]) != tuple(tensor.shape[-2:]):
        logits = F.interpolate(
            logits, size=tuple(tensor.shape[-2:]), mode="bilinear", align_corners=False
        )

    # Per-pixel class index: drop the batch dim, argmax over the class dim.
    pred: np.ndarray = logits.squeeze(dim=0).argmax(dim=0).cpu().numpy()
    coloured_pred: np.ndarray = colourise_mask(mask=pred)

    # cv2.addWeighted requires both inputs to share shape and dtype.
    # NOTE(review): assumes colourise_mask returns an (H, W, 3) uint8 array — confirm.
    super_imposed_img = cv2.addWeighted(coloured_pred, 0.5, np.array(pil_image), 0.5, 0)
    return super_imposed_img
|
|
|
|
|
# Load the HTML shown under the demo title. The context manager guarantees
# the file handle is closed; the encoding is explicit so it does not depend
# on the platform's locale.
with open("description.html", "r", encoding="utf-8") as description_file:
    description_html: str = description_file.read()

# Single-image-in / single-image-out Gradio UI around `main`.
# NOTE(review): gr.inputs / gr.outputs are deprecated namespaces (removed in
# Gradio 4); migrate to gr.Image if the pinned Gradio version is upgraded.
demo = gr.Interface(
    fn=main,
    inputs=gr.inputs.Image(type="pil", source="upload", tool="editor"),
    outputs=gr.outputs.Image(type="numpy", label="prediction"),
    examples=[f"images/{fname}.jpg" for fname in [
        "2007_002260",
        "2008_002536",
        "2008_003499",
        "2008_007814",
        "2010_001079",
        "2010_005063",
    ]],
    examples_per_page=10,
    description=description_html,
    title="NamedMask: Distilling Segmenters from Complementary Foundation Models",
    allow_flagging="never",
    analytics_enabled=False,
)
|
|
|
# Start serving the Gradio app with its default launch settings.
demo.launch()