|
|
|
|
|
|
|
import os |
|
# Install the Mask Transfiner package from GitHub at start-up (presumably this
# provides the detectron2 modules imported further down -- TODO confirm).
# pip is run as a module of the *current* interpreter so the package lands in
# the environment this script actually runs in, and the command is passed as
# an argument list so no shell parsing is involved.
import subprocess
import sys

_pip_install = subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "git+https://github.com/SysCV/transfiner.git",
    ],
    check=False,
)
if _pip_install.returncode != 0:
    # Installation is best-effort: keep going, but make the failure visible
    # instead of silently discarding the exit status.
    print(
        "warning: 'pip install git+https://github.com/SysCV/transfiner.git' "
        f"exited with status {_pip_install.returncode}",
        file=sys.stderr,
    )
|
|
|
from matplotlib.pyplot import axis |
|
import gradio as gr |
|
import requests |
|
import numpy as np |
|
from torch import nn |
|
import requests |
|
|
|
import torch |
|
|
|
from detectron2 import model_zoo |
|
from detectron2.engine import DefaultPredictor |
|
from detectron2.config import get_cfg |
|
from detectron2.utils.visualizer import Visualizer |
|
from detectron2.data import MetadataCatalog |
|
|
|
|
|
# Build the detectron2 config from the Mask Transfiner YAML file and create
# the module-level predictor that `inference` calls for every request.
model_name = './configs/transfiner/mask_rcnn_R_101_FPN_3x_deform.yaml'

cfg = get_cfg()
cfg.merge_from_file(model_name)

# Overrides applied on top of the values loaded from the YAML file.
cfg.MODEL.WEIGHTS = './output_3x_transfiner_r101_deform.pth'
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.VIS_PERIOD = 100

# Without CUDA, force the CPU device; otherwise the configured device is kept.
if not torch.cuda.is_available():
    cfg.MODEL.DEVICE = 'cpu'

predictor = DefaultPredictor(cfg)
|
|
|
|
|
def inference(image):
    """Run the module-level predictor on a PIL image and draw the result.

    Args:
        image: PIL image, as delivered by the Gradio ``type="pil"`` input.

    Returns:
        The image returned by the detectron2 ``Visualizer`` after drawing the
        predicted instances on the (possibly downscaled) input.
    """
    max_width = 1300

    # The code below indexes three colour channels; normalise grayscale,
    # palette and RGBA inputs so they do not produce unexpected array shapes.
    if image.mode != "RGB":
        image = image.convert("RGB")

    width, height = image.size
    if width > max_width:
        # Downscale wide images while keeping the aspect ratio. Clamp the
        # height to 1 so an extremely wide, flat image cannot end up with a
        # zero-pixel dimension.
        ratio = float(height) / float(width)
        height = max(1, int(ratio * max_width))
        image = image.resize((max_width, height))

    rgb = np.asarray(image)

    # detectron2's DefaultPredictor documents its input as BGR (OpenCV order),
    # whereas PIL yields RGB: reverse the channel axis before inference.
    # ascontiguousarray materialises the reversed view as a regular array.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    outputs = predictor(bgr)

    # Draw on the RGB array so the rendered colours match the uploaded image.
    v = Visualizer(rgb, MetadataCatalog.get(cfg.DATASETS.TRAIN[0]))
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))

    return out.get_image()
|
|
|
|
|
|
|
# Texts rendered on the demo page (title, header description, footer article).
title = "Mask Transfiner [CVPR, 2022]"
description = "Demo for <a target='_blank' href='https://arxiv.org/abs/2111.13673'>Mask Transfiner for High-Quality Instance Segmentation, CVPR 2022</a> based on R101-FPN. To use it, simply upload your image, or click one of the examples to load them. Note that it runs in <b>CPU environment</b> provided by Hugging Face so the processing speed may be slow."
article = "<p style='text-align: center'><a target='_blank' href='https://arxiv.org/abs/2111.13673'>Mask Transfiner for High-Quality Instance Segmentation, CVPR 2022</a> | <a target='_blank' href='https://github.com/SysCV/transfiner'>Mask Transfiner Github Code</a></p>"

# Sample images offered as one-click examples in the UI.
_EXAMPLE_IMAGES = (
    "demo/sample_imgs/000000131444.jpg",
    "demo/sample_imgs/000000157365.jpg",
    "demo/sample_imgs/000000176037.jpg",
    "demo/sample_imgs/000000018737.jpg",
    "demo/sample_imgs/000000224200.jpg",
    "demo/sample_imgs/000000558073.jpg",
    "demo/sample_imgs/000000404922.jpg",
    "demo/sample_imgs/000000252776.jpg",
    "demo/sample_imgs/000000482477.jpg",
    "demo/sample_imgs/000000344909.jpg",
)

# Wire `inference` to one image input and one image output, then start the
# app. Each example row is a one-element list because there is one input.
gr.Interface(
    fn=inference,
    inputs=[gr.inputs.Image(type="pil", label="Input")],
    outputs=gr.outputs.Image(type="numpy", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[[sample] for sample in _EXAMPLE_IMAGES],
).launch()
|
|
|
|