|
|
|
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation |
|
from PIL import Image |
|
import spaces |
|
import torch |
|
from collections import defaultdict |
|
import matplotlib.pyplot as plt |
|
from matplotlib import cm |
|
import matplotlib.patches as mpatches |
|
import os |
|
import numpy as np |
|
import argparse |
|
import matplotlib |
|
import gradio as gr |
|
|
|
|
|
def load_image(image_path, left=0, right=0, top=0, bottom=0, size=512):
    """Load an image, trim edge margins, center-crop to a square and resize.

    Args:
        image_path: Path to an image file (``str`` or ``os.PathLike``), or an
            already-loaded image array indexed as ``(h, w, c)``. When a path is
            given, only the first three channels are kept.
        left, right, top, bottom: Number of pixels to trim from each edge.
            Each value is clamped so that at least one pixel column/row
            remains after trimming.
        size: Side length, in pixels, of the returned square image.

    Returns:
        ``np.ndarray`` holding the cropped image resized to ``(size, size)``.
    """
    if isinstance(image_path, (str, os.PathLike)):
        # Close the file handle as soon as the pixels are copied out.
        with Image.open(image_path) as img:
            image = np.array(img)[:, :, :3]
    else:
        image = image_path

    h, w, _ = image.shape

    # Clamp the margins so the crop never becomes empty. Horizontal margins are
    # bounded by the width, vertical margins by the height (the original
    # bounded `top` by `h - left - 1`, mixing the two axes).
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]

    # Center-crop the longer side so the image becomes square.
    h, w = image.shape[:2]
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]

    return np.array(Image.fromarray(image).resize((size, size)))
|
|
|
def draw_panoptic_segmentation(segmentation, segments_info, save_folder=None, noseg=False, model=None):
    """Plot a panoptic segmentation map and return one boolean mask per segment.

    Args:
        segmentation: Integer ``torch.Tensor`` of segment ids, as found under
            the ``"segmentation"`` key of
            ``processor.post_process_panoptic_segmentation`` (presumably
            ``(height, width)`` — confirm). A map equal to ``-1`` everywhere is
            treated as "nothing detected". Pixels with id ``0`` are reported
            as a ``"rest"`` segment.
        segments_info: Sequence of dicts with at least ``"id"`` (value in
            ``segmentation``) and ``"label_id"`` (key into
            ``model.config.id2label``).
        save_folder: Directory in which ``seg_init.png`` is written. If
            ``None``, the figure is rendered but not saved.
        noseg: If true (or if nothing was detected), skip the per-segment split
            and return a single all-``True`` mask labelled ``"all-0"``.
        model: Model whose ``config.id2label`` maps label ids to class names.
            Required when ``noseg`` is false and ``segments_info`` is not empty.

    Returns:
        ``(mask_np_list, label_list)``: boolean NumPy masks and their
        ``"<class>-<index>"`` labels, in matching order.
    """
    seg_min = int(torch.min(segmentation))
    seg_max = int(torch.max(segmentation))

    if seg_min == seg_max == -1:
        print("nothing is detected!")
        noseg = True
        viridis = matplotlib.colormaps['viridis'].resampled(1)
    else:
        # One colormap entry per distinct id in [seg_min, seg_max].
        viridis = matplotlib.colormaps['viridis'].resampled(seg_max - seg_min + 1)

    seg_np = segmentation.cpu().detach().numpy()

    handles = []
    label_list = []
    mask_np_list = []

    def _add_segment(mask, label, color_index):
        # Record one mask together with its legend entry and label.
        print(mask.shape)
        mask_np_list.append(mask)
        handles.append(mpatches.Patch(color=viridis(color_index), label=label))
        label_list.append(label)

    if not noseg:
        if seg_min == 0:
            # Id 0 marks pixels not covered by any detected segment.
            _add_segment(seg_np == 0, "rest-0", 0)

        # When there is no id-0 region, ids start at 1; shift them so the
        # first segment still maps to colormap index 0.
        offset = 0 if seg_min == 0 else 1
        for segment in segments_info:
            mask = seg_np == segment['id']
            segment_id = segment['id'] - offset
            segment_label = model.config.id2label[segment['label_id']]
            _add_segment(mask, f"{segment_label}-{segment_id}", segment_id)
    else:
        _add_segment(np.full(seg_np.shape, True), "all-0", 0)

    fig, ax = plt.subplots()
    try:
        ax.imshow(seg_np)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.legend(handles=handles)
        if save_folder is not None:
            fig.savefig(os.path.join(save_folder, 'seg_init.png'), dpi=500)
    finally:
        # Always release the figure; otherwise each call leaks an open figure.
        plt.close(fig)

    print("; ".join(label_list))
    return mask_np_list, label_list
|
|
|
|
|
_MASK2FORMER_CHECKPOINT = "facebook/mask2former-swin-base-coco-panoptic"

# checkpoint name -> (processor, model); filled lazily by _get_mask2former.
_MASK2FORMER_CACHE = {}


def _get_mask2former(checkpoint=_MASK2FORMER_CHECKPOINT):
    """Return a cached ``(processor, model)`` pair, loading it on first use."""
    # NOTE(review): if `spaces.GPU` runs the decorated call in a separate
    # worker process, this cache may not persist between calls; loading at
    # import time may be preferable there — confirm.
    if checkpoint not in _MASK2FORMER_CACHE:
        processor = AutoImageProcessor.from_pretrained(checkpoint)
        model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint)
        _MASK2FORMER_CACHE[checkpoint] = (processor, model)
    return _MASK2FORMER_CACHE[checkpoint]


@spaces.GPU(duration=10)
def run_segmentation(image, name="example_tmp", size=512, noseg=False):
    """Run Mask2Former panoptic segmentation on an image.

    Args:
        image: Image as an array accepted by ``PIL.Image.fromarray``, or a
            ``PIL.Image.Image``.
        name: Output directory (relative to the working directory) in which
            ``seg_init.png`` is saved. It is created if missing.
        size: The image is resized to ``(size, size)`` before segmentation.
        noseg: Forwarded to ``draw_panoptic_segmentation``. When true, a
            single all-``True`` mask is returned.

    Returns:
        ``(image, mask_list, label_list)``: the resized PIL image, the boolean
        NumPy masks, and their labels.
    """
    base_folder_path = "."

    processor, model = _get_mask2former()

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.resize((size, size))

    save_folder = os.path.join(base_folder_path, name)
    os.makedirs(save_folder, exist_ok=True)

    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # target_sizes takes (height, width); PIL's Image.size is (width, height).
    panoptic_segmentation = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]

    mask_list, label_list = draw_panoptic_segmentation(
        **panoptic_segmentation, save_folder=save_folder, noseg=noseg, model=model
    )
    print("Finish segment")

    return image, mask_list, label_list
|
|