import os

# Allow duplicate OpenMP runtimes; set before importing numpy/torch, which is
# when the conflicting libraries get loaded.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from argparse import ArgumentParser, Namespace
from typing import Tuple
import yaml
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, normalize, resize
import gradio as gr
from utils import get_network, colourise_mask

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# pretrained NamedMask weights for PASCAL VOC 2012
state_dict: dict = torch.hub.load_state_dict_from_url(
    "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt",
    map_location=device
)["model"]

parser = ArgumentParser("NamedMask demo")
parser.add_argument(
    "--config",
    type=str,
    default="voc_val_n500_cp2_ex.yaml",
    help="path to a YAML config with the demo settings"
)

# merge CLI arguments with the YAML config (config values take precedence)
args: Namespace = parser.parse_args()
with open(args.config, 'r') as f:
    base_args: dict = yaml.safe_load(f)
base_args.pop("dataset_name")
args = Namespace(**{**vars(args), **base_args})

model = get_network().to(device)
model.load_state_dict(state_dict)
model.eval()

# preprocessing: shorter side resized to 384 px (longer side capped at 512),
# normalised with ImageNet mean/std
size: int = 384
max_size: int = 512
mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
std: Tuple[float, float, float] = (0.229, 0.224, 0.225)


@torch.no_grad()
def main(image: Image.Image) -> np.ndarray:
    """Segment an input image and overlay the coloured prediction on it."""
    pil_image: Image.Image = resize(image, size=size, max_size=max_size)
    image_tensor: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std))  # 3 x H x W

    # logits: b (=1) x n_categories x h x w, torch.float32
    logits: torch.Tensor = model(image_tensor[None].to(device))

    # upsample logits to the resized image resolution so the overlay shapes
    # match (a no-op if the network already predicts at input resolution)
    W, H = pil_image.size
    logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)

    # pred: H x W, a category index per pixel
    pred: np.ndarray = logits.squeeze(dim=0).argmax(dim=0).cpu().numpy()
    coloured_pred: np.ndarray = colourise_mask(mask=pred)

    # blend the coloured mask with the (resized) input image
    super_imposed_img: np.ndarray = cv2.addWeighted(coloured_pred, 0.5, np.array(pil_image), 0.5, 0)
    return super_imposed_img
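
# Quick offline check without launching the UI -- a minimal sketch using one of
# the example images bundled with the demo (any RGB image works):
#
#   overlay = main(Image.open("images/2007_002260.jpg").convert("RGB"))
#   Image.fromarray(overlay).save("overlay.png")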


demo = gr.Interface(
    fn=main,
    # gr.inputs/gr.outputs were deprecated in Gradio 3.x; the top-level
    # components below work on current versions
    inputs=gr.Image(type="pil", label="input image"),
    outputs=gr.Image(type="numpy", label="prediction"),
    examples=[f"images/{fname}.jpg" for fname in [
        "2007_002260",
        "2008_002536",
        "2008_003499",
        "2008_007814",
        "2010_001079",
        "2010_005063"
    ]],
    examples_per_page=10,
    description=open("description.html", 'r', encoding="utf-8").read(),
    title="NamedMask: Distilling Segmenters from Complementary Foundation Models",
    allow_flagging="never",
    analytics_enabled=False
)

demo.launch(
    # share=True  # uncomment for a temporary public link
)