File size: 1,546 Bytes
4fba7a2
f254011
 
 
 
4fba7a2
 
0547194
 
 
 
f254011
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import gradio as gr
import torch
import random
import numpy as np
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation


# Hub checkpoint shared by the preprocessor and the model so the two always match.
# A smaller alternative tried earlier: "facebook/maskformer-swin-tiny-ade".
_CHECKPOINT = "facebook/maskformer-swin-large-coco"

# Both objects are created once at import time and reused for every request.
preprocessor = MaskFormerFeatureExtractor.from_pretrained(_CHECKPOINT)
model = MaskFormerForInstanceSegmentation.from_pretrained(_CHECKPOINT)

def visualize_instance_seg_mask(mask):
    """Render a 2-D label mask as an RGB image, one random color per label.

    Args:
        mask: 2-D NumPy array of labels with shape (H, W).

    Returns:
        A float64 array of shape (H, W, 3) with values in [0, 1]. Pixels that
        share a label share a color. Colors are drawn from the ``random``
        module on every call, so they differ between calls unless the
        caller seeds it.
    """
    # `inverse` holds, for every pixel, the index of its label in `labels`.
    labels, inverse = np.unique(mask, return_inverse=True)

    # One random 8-bit RGB triple per distinct label. All three channels span
    # the full 0-255 range so that colors are well spread out.
    palette = np.array(
        [[random.randint(0, 255) for _ in range(3)] for _ in labels],
        dtype=np.float64,
    ).reshape(len(labels), 3)

    # Look up all pixel colors at once instead of looping pixel by pixel.
    # The reshape makes the index array match the mask's shape regardless of
    # how the installed NumPy version shapes `inverse`.
    image = palette[inverse.reshape(mask.shape)]

    # Scale 8-bit channel values into the [0, 1] float range.
    return image / 255

def query_image(img):
    """Segment an image with the module-level model and return a colorized mask.

    Args:
        img: Input image as an array whose first two dimensions are height
            and width, or ``None`` when no image was supplied.

    Returns:
        The array produced by ``visualize_instance_seg_mask`` for the
        predicted label map, or ``None`` if ``img`` is ``None``.
    """
    # No image was supplied; return nothing instead of failing on `img.shape`.
    if img is None:
        return None

    # Ask post-processing to resize predictions back to the input resolution.
    target_size = (img.shape[0], img.shape[1])
    inputs = preprocessor(images=img, return_tensors="pt")

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # NOTE(review): presumably one score map per class along dim 0 - confirm
    # against the installed transformers version.
    segmentation = preprocessor.post_process_segmentation(
        outputs=outputs, target_size=target_size
    )[0]

    # Keep the index of the highest-scoring map at every pixel.
    label_map = torch.argmax(segmentation, dim=0).numpy()
    return visualize_instance_seg_mask(label_map)


# Gradio UI: one image input, one image output produced by `query_image`.
demo = gr.Interface(
    query_image,
    inputs=[gr.Image()],
    outputs="image",
    # Title names the checkpoint that is actually loaded above.
    title="maskformer-swin-large-coco results",
    allow_flagging="never",
    analytics_enabled=None
)

# Start the web server without exposing the auto-generated API page.
demo.launch(show_api=False)