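"""Utilities for a Mask2Former segmentation demo.

Loads a pretrained Mask2Former checkpoint and runs semantic, instance, or
panoptic segmentation on an input image, returning a rendered mask overlay.
"""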
import torch
import random
import numpy as np
from PIL import Image
from collections import defaultdict
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer
from color_palette import ade_palette
from transformers import MaskFormerImageProcessor, Mask2FormerForUniversalSegmentation


def load_model_and_processor(model_ckpt: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
    model.eval()
    image_preprocessor = MaskFormerImageProcessor.from_pretrained(model_ckpt)
    return model, image_preprocessor


def load_default_ckpt(segmentation_task: str):
    if segmentation_task == "semantic":
        default_pretrained_ckpt = "facebook/mask2former-swin-tiny-ade-semantic"
    elif segmentation_task == "instance":
        default_pretrained_ckpt = "facebook/mask2former-swin-small-coco-instance"
    else:
        default_pretrained_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
    return default_pretrained_ckpt


def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    for res in seg_info:
        res["category_id"] = res.pop("label_id")
        pred_class = res["category_id"]
        isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
        res["isthing"] = bool(isthing)
    visualizer = Visualizer(np.array(image)[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
    out = visualizer.draw_panoptic_seg_predictions(
        predicted_segmentation_map.cpu(), seg_info, alpha=0.5
    )
    output_img = Image.fromarray(out.get_image())
    return output_img


def draw_semantic_segmentation(segmentation_map, image, palette):
    color_segmentation_map = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8)  # height, width, 3
    for label, color in enumerate(palette):
        color_segmentation_map[segmentation_map - 1 == label, :] = color
    # Convert to BGR
    ground_truth_color_seg = color_segmentation_map[..., ::-1]
    img = np.array(image) * 0.5 + ground_truth_color_seg * 0.5
    img = img.astype(np.uint8)
    return img


def visualize_instance_seg_mask(mask):
    image = np.zeros((mask.shape[0], mask.shape[1], 3))
    labels = np.unique(mask)
    # Assign a random RGB color to each instance id present in the mask
    label2color = {label: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) for label in labels}
    for i in range(image.shape[0]):
        for j in range(image.shape[1]):
            image[i, j, :] = label2color[mask[i, j]]
    image = image / 255
    return image


def predict_masks(input_img_path: str, segmentation_task: str):
    # load model and image processor
    default_pretrained_ckpt = load_default_ckpt(segmentation_task)
    model, image_processor = load_model_and_processor(default_pretrained_ckpt)

    # pass input image through image processor
    image = Image.open(input_img_path)
    inputs = image_processor(images=image, return_tensors="pt")
    # move inputs to the same device as the model (GPU if available)
    inputs = inputs.to(model.device)

    # pass inputs to model for prediction
    with torch.no_grad():
        outputs = model(**inputs)

    # pass outputs to processor for postprocessing
    if segmentation_task == "semantic":
        result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        predicted_segmentation_map = result.cpu().numpy()
        palette = ade_palette()
        output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
    elif segmentation_task == "instance":
        # Instance post-processing (sketch): use the image processor's
        # post_process_instance_segmentation and render each instance id
        # with a random color.
        result = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        predicted_segmentation_map = result["segmentation"].cpu().numpy()
        output_result = visualize_instance_seg_mask(predicted_segmentation_map)
    else:
        result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        predicted_segmentation_map = result["segmentation"]
        seg_info = result["segments_info"]
        output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)

    return output_result
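

if __name__ == "__main__":
    # Minimal usage sketch, assuming a local test image at the hypothetical
    # path below. The panoptic task returns a PIL Image; the semantic task
    # returns a uint8 numpy array (wrap with Image.fromarray before saving).
    demo_image_path = "example.jpg"  # hypothetical path, replace with a real image
    rendered = predict_masks(demo_image_path, segmentation_task="panoptic")
    rendered.save("segmentation_output.png")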