# Provenance (from the hosting page header): Thiago Hersan,
# maskformer-swin-large-coco, revision 0547194, 1.55 kB.
import gradio as gr
import torch
import random
import numpy as np
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
# One checkpoint name drives both the feature extractor and the model so the
# two cannot drift apart. Alternative checkpoint that was previously wired in
# here: "facebook/maskformer-swin-tiny-ade".
CHECKPOINT = "facebook/maskformer-swin-large-coco"
preprocessor = MaskFormerFeatureExtractor.from_pretrained(CHECKPOINT)
model = MaskFormerForInstanceSegmentation.from_pretrained(CHECKPOINT)
def visualize_instance_seg_mask(mask):
    """Render a 2-D label mask as a randomly coloured RGB image.

    Every distinct value found in ``mask`` is assigned one random colour
    (drawn with the ``random`` module), and each pixel is painted with the
    colour of its label.

    Args:
        mask: 2-D array of labels, indexed as ``mask[row, col]``.

    Returns:
        A float array of shape ``(rows, cols, 3)`` whose channel values are
        scaled into ``[0, 1]``.
    """
    mask = np.asarray(mask)
    # ``inverse`` holds, for every pixel, the index of its label in ``labels``.
    labels, inverse = np.unique(mask, return_inverse=True)
    # One RGB triple per label; all three channels use the full 0..255 range.
    # The reshape keeps the palette 2-D even when there are no labels at all.
    palette = np.array(
        [[random.randint(0, 255) for _ in range(3)] for _ in labels],
        dtype=float,
    ).reshape(-1, 3)
    # The shape of ``inverse`` for N-D input differs between NumPy versions,
    # so normalise it to the mask's shape before the palette lookup.
    image = palette[inverse.reshape(mask.shape)]
    return image / 255
def query_image(img):
    """Segment an image with the loaded MaskFormer model and colourise it.

    Args:
        img: Image whose first two dimensions are height and width
            (presumably the NumPy array produced by the ``gr.Image()``
            input), or ``None`` when no image was provided.

    Returns:
        The coloured mask produced by ``visualize_instance_seg_mask``, or
        ``None`` when ``img`` is ``None``.
    """
    # Submitting the form without an image hands us None; return nothing
    # instead of failing on ``img.shape``.
    if img is None:
        return None

    target_size = (img.shape[0], img.shape[1])
    inputs = preprocessor(images=img, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
        # NOTE(review): post_process_segmentation may be deprecated in newer
        # transformers releases — confirm against the pinned version.
        results = preprocessor.post_process_segmentation(
            outputs=outputs, target_size=target_size
        )[0]
        # Collapse the first axis to a single winning index per pixel.
        results = torch.argmax(results, dim=0).numpy()
    return visualize_instance_seg_mask(results)
# Gradio UI: a single image input fed to ``query_image``, with the returned
# coloured mask shown as an image.
demo = gr.Interface(
    query_image,
    inputs=[gr.Image()],
    outputs="image",
    # Title matches the checkpoint this script actually loads; it previously
    # still named the tiny-ade checkpoint.
    title="maskformer-swin-large-coco results",
    allow_flagging="never",
    analytics_enabled=None,
)
demo.launch(show_api=False)