# NOTE: page residue captured with this file read "Spaces: Runtime error".
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional

import cv2
import mmcv
import numpy as np
import torch
from mmengine.dist import master_only
from mmengine.structures import PixelData
from mmengine.visualization import Visualizer

from mmseg.registry import VISUALIZERS
from mmseg.structures import SegDataSample
from mmseg.utils import get_classes, get_palette


class SegLocalVisualizer(Visualizer):
    """Local Visualizer.

    Args:
        name (str): Name of the instance. Defaults to 'visualizer'.
        image (np.ndarray, optional): the origin image to draw. The format
            should be RGB. Defaults to None.
        vis_backends (list, optional): Visual backend config list.
            Defaults to None.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        classes (list, optional): Input classes for result rendering, as the
            prediction of segmentation model is a segment map with label
            indices, `classes` is a list which includes items corresponding
            to the label indices. If classes is not defined, visualizer will
            take `cityscapes` classes by default. Defaults to None.
        palette (list, optional): Input palette for result rendering, which is
            a list of colors corresponding to the classes. Defaults to None.
        dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/utils/class_names.py#L302-L317>`_
            visualizer will use the meta information of the dataset i.e.
            classes and palette, but the `classes` and `palette` have higher
            priority. Defaults to None.
        alpha (int, float): The transparency of segmentation mask.
            Defaults to 0.8.

    Examples:
        >>> import numpy as np
        >>> import torch
        >>> from mmengine.structures import PixelData
        >>> from mmseg.structures import SegDataSample
        >>> from mmseg.visualization import SegLocalVisualizer

        >>> seg_local_visualizer = SegLocalVisualizer()
        >>> image = np.random.randint(0, 256,
        ...                     size=(10, 12, 3)).astype('uint8')
        >>> gt_sem_seg_data = dict(data=torch.randint(0, 2, (1, 10, 12)))
        >>> gt_sem_seg = PixelData(**gt_sem_seg_data)
        >>> gt_seg_data_sample = SegDataSample()
        >>> gt_seg_data_sample.gt_sem_seg = gt_sem_seg
        >>> seg_local_visualizer.dataset_meta = dict(
        ...     classes=('background', 'foreground'),
        ...     palette=[[120, 120, 120], [6, 230, 230]])
        >>> seg_local_visualizer.add_datasample('visualizer_example',
        ...                         image, gt_seg_data_sample)
        >>> seg_local_visualizer.add_datasample(
        ...                        'visualizer_example', image,
        ...                        gt_seg_data_sample, show=True)
    """  # noqa

    # NOTE(review): `VISUALIZERS` and `master_only` are imported at file level
    # but never applied in this class. Confirm whether a
    # `@VISUALIZERS.register_module()` class decorator and a `@master_only`
    # decorator on `add_datasample` were dropped unintentionally.

    def __init__(self,
                 name: str = 'visualizer',
                 image: Optional[np.ndarray] = None,
                 vis_backends: Optional[List[Dict]] = None,
                 save_dir: Optional[str] = None,
                 classes: Optional[List] = None,
                 palette: Optional[List] = None,
                 dataset_name: Optional[str] = None,
                 alpha: float = 0.8,
                 **kwargs):
        super().__init__(name, image, vis_backends, save_dir, **kwargs)
        self.alpha: float = alpha
        # Pass by keyword: `set_dataset_meta` takes `classes` before
        # `palette`, so a positional (palette, classes) call would store the
        # two swapped.
        self.set_dataset_meta(
            classes=classes, palette=palette, dataset_name=dataset_name)

    def _get_center_loc(self, mask: np.ndarray) -> np.ndarray:
        """Get the anchor coordinate used to place a class label.

        The anchor is the middle pixel of the image row that contains the
        largest number of mask pixels (the topmost such row on ties).

        Args:
            mask (np.ndarray): 2-D mask taken from the semantic segmentation
                map; pixels equal to 1 belong to the class.

        Returns:
            np.ndarray: The coordinate as ``[x, y]``.

        Raises:
            ValueError: If ``mask`` contains no pixel equal to 1.
        """
        # `np.argwhere` already yields (row, col) pairs in row-major order,
        # so no additional sorting is required.
        loc = np.argwhere(mask == 1)
        if loc.shape[0] == 0:
            raise ValueError('mask contains no pixel equal to 1')
        rows = loc[:, 0]
        unique_rows, counts = np.unique(rows, return_counts=True)
        # `argmax` returns the first maximum, i.e. the smallest row index
        # among the rows with the most mask pixels.
        y_loc = unique_rows[counts.argmax()]
        row_pixels = loc[rows == y_loc]
        center = row_pixels[len(row_pixels) // 2]
        return np.array([center[1], center[0]])

    def _draw_sem_seg(self,
                      image: np.ndarray,
                      sem_seg: PixelData,
                      classes: Optional[List],
                      palette: Optional[List],
                      with_labels: Optional[bool] = True) -> np.ndarray:
        """Draw semantic seg of GT or prediction.

        Args:
            image (np.ndarray): The image to draw.
            sem_seg (:obj:`PixelData`): Data structure for pixel-level
                annotations or predictions. Only the first channel of its
                ``data`` is drawn.
            classes (list, optional): Class names indexed by label id.
            palette (list, optional): Colors indexed by label id,
                corresponding to ``classes``.
            with_labels (bool, optional): Add semantic labels in
                visualization result. Defaults to True.

        Returns:
            np.ndarray: the drawn image which channel is RGB.
        """
        num_classes = len(classes)
        seg_map = sem_seg.cpu().data[0]
        if isinstance(seg_map, torch.Tensor):
            seg_map = seg_map.numpy()

        # Descending label order is kept from the original implementation.
        ids = np.unique(seg_map)[::-1]
        # Drop ids without a class entry; negative ids would otherwise index
        # the palette from its end.
        ids = ids[(ids >= 0) & (ids < num_classes)]
        labels = np.array(ids, dtype=np.int64)
        colors = [palette[label] for label in labels]

        mask = np.zeros_like(image, dtype=np.uint8)
        for label, color in zip(labels, colors):
            mask[seg_map == label, :] = color

        if with_labels:
            font = cv2.FONT_HERSHEY_SIMPLEX
            # (0,1] to change the size of the text relative to the image
            scale = 0.05
            font_scale = min(image.shape[0], image.shape[1]) / (25 / scale)
            font_color = (255, 255, 255)
            # Thinner strokes on small images keep the text readable.
            is_small = image.shape[0] < 300 or image.shape[1] < 300
            thickness = 1 if is_small else 2
            rectangle_thickness = 1 if is_small else 2
            line_type = 2

            for label, color in zip(labels, colors):
                class_mask = (seg_map == label).astype(np.uint8)
                # Plain Python ints for the OpenCV drawing calls.
                x, y = (int(v) for v in self._get_center_loc(class_mask))
                text = classes[label]
                (label_width, label_height), baseline = cv2.getTextSize(
                    text, font, font_scale, thickness)
                top_left = (x, y)
                bottom_right = (x + label_width + baseline,
                                y + label_height + baseline)
                box_color = tuple(int(c) for c in color)
                # Filled box in the class color, then a black outline, then
                # the class name on top.
                mask = cv2.rectangle(mask, top_left, bottom_right, box_color,
                                     -1)
                mask = cv2.rectangle(mask, top_left, bottom_right, (0, 0, 0),
                                     rectangle_thickness)
                mask = cv2.putText(mask, text, (x, y + label_height), font,
                                   font_scale, font_color, thickness,
                                   line_type)

        color_seg = (image * (1 - self.alpha) + mask * self.alpha).astype(
            np.uint8)
        self.set_image(color_seg)
        return color_seg

    def _draw_depth_map(self, image: np.ndarray,
                        depth_map: PixelData) -> np.ndarray:
        """Draws a depth map on a given image.

        This function takes an image and a depth map as input,
        renders the depth map, and concatenates it with the original image.
        Finally, it updates the internal image state of the visualizer with
        the concatenated result.

        Args:
            image (np.ndarray): The original image where the depth map will
                be drawn. The array should be in the format HxWx3 where H is
                the height, W is the width.
            depth_map (PixelData): Depth map to be drawn. The depth map
                should be in the form of a PixelData object. Its data is
                converted to a torch tensor if it is a numpy array.

        Returns:
            np.ndarray: The concatenated image with the depth map drawn.

        Example:
            >>> depth_map_data = PixelData(data=torch.rand(1, 10, 10))
            >>> image = np.random.randint(0, 256,
            ...                           size=(10, 10, 3)).astype('uint8')
            >>> visualizer = SegLocalVisualizer()
            >>> visualizer._draw_depth_map(image, depth_map_data)
        """
        depth_map = depth_map.cpu().data
        if isinstance(depth_map, np.ndarray):
            depth_map = torch.from_numpy(depth_map)
        if depth_map.ndim == 2:
            # Add a leading channel axis before rendering.
            depth_map = depth_map[None]

        depth_map = self.draw_featmap(depth_map, resize_shape=image.shape[:2])
        # Stack vertically: image on top, rendered depth map below.
        out_image = np.concatenate((image, depth_map), axis=0)
        self.set_image(out_image)
        return out_image

    def set_dataset_meta(self,
                         classes: Optional[List] = None,
                         palette: Optional[List] = None,
                         dataset_name: Optional[str] = None) -> None:
        """Set meta information to visualizer.

        Args:
            classes (list, optional): Input classes for result rendering, as
                the prediction of segmentation model is a segment map with
                label indices, `classes` is a list which includes items
                corresponding to the label indices. If classes is not
                defined, visualizer will take `cityscapes` classes by
                default. Defaults to None.
            palette (list, optional): Input palette for result rendering, which
                is a list of colors corresponding to the classes.
                Defaults to None.
            dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/utils/class_names.py#L302-L317>`_
                visualizer will use the meta information of the dataset i.e.
                classes and palette, but the `classes` and `palette` have
                higher priority. Defaults to None.
        """  # noqa
        # Set default value. When calling
        # `SegLocalVisualizer().dataset_meta=xxx`,
        # it will override the default value.
        if dataset_name is None:
            dataset_name = 'cityscapes'
        classes = classes if classes else get_classes(dataset_name)
        palette = palette if palette else get_palette(dataset_name)
        assert len(classes) == len(
            palette), 'The length of classes should be equal to palette'
        self.dataset_meta: dict = {'classes': classes, 'palette': palette}

    def _draw_sample_fields(self, image: np.ndarray,
                            data_sample: SegDataSample, prefix: str,
                            classes: Optional[List],
                            palette: Optional[List],
                            with_labels: Optional[bool]
                            ) -> Optional[np.ndarray]:
        """Draw the ``<prefix>_sem_seg`` / ``<prefix>_depth_map`` fields.

        Args:
            image (np.ndarray): The image to draw on.
            data_sample (:obj:`SegDataSample`): Sample holding the fields.
            prefix (str): Field prefix, ``'gt'`` or ``'pred'``.
            classes (list, optional): Class names indexed by label id.
            palette (list, optional): Colors indexed by label id.
            with_labels (bool, optional): Whether to add semantic labels.

        Returns:
            np.ndarray or None: The drawn image, or None when the sample
            contains neither field.
        """
        drawn = None
        sem_seg_key = f'{prefix}_sem_seg'
        if sem_seg_key in data_sample:
            assert classes is not None, 'class information is ' \
                                        'not provided when ' \
                                        'visualizing semantic ' \
                                        'segmentation results.'
            drawn = self._draw_sem_seg(image,
                                       getattr(data_sample, sem_seg_key),
                                       classes, palette, with_labels)

        depth_key = f'{prefix}_depth_map'
        if depth_key in data_sample:
            # The depth map is appended to the segmentation drawing when
            # there is one, otherwise to the raw image.
            base = drawn if drawn is not None else image
            drawn = self._draw_depth_map(base, getattr(data_sample, depth_key))
        return drawn

    def add_datasample(
            self,
            name: str,
            image: np.ndarray,
            data_sample: Optional[SegDataSample] = None,
            draw_gt: bool = True,
            draw_pred: bool = True,
            show: bool = False,
            wait_time: float = 0,
            # TODO: Supported in mmengine's Visualizer.
            out_file: Optional[str] = None,
            step: int = 0,
            with_labels: Optional[bool] = True) -> None:
        """Draw datasample and save to all backends.

        - If GT and prediction are plotted at the same time, they are
          displayed in a stitched image where the left image is the
          ground truth and the right image is the prediction.
        - If ``show`` is True, the drawn image is displayed in a local
          window.
        - If ``out_file`` is specified, the drawn image will be
          saved to ``out_file`` instead of being sent to the storage
          backends. It is usually used when the display is not available.
        - If nothing is drawn (no ``data_sample``, drawing disabled, or no
          supported field in the sample), the input ``image`` is used as is.

        Args:
            name (str): The image identifier.
            image (np.ndarray): The image to draw.
            data_sample (:obj:`SegDataSample`, optional): Sample that may
                hold ``gt_sem_seg`` / ``gt_depth_map`` and ``pred_sem_seg`` /
                ``pred_depth_map``. Defaults to None.
            draw_gt (bool): Whether to draw the GT fields. Defaults to True.
            draw_pred (bool): Whether to draw the prediction fields.
                Defaults to True.
            show (bool): Whether to display the drawn image.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            out_file (str): Path to output file. Defaults to None.
            step (int): Global step value to record. Defaults to 0.
            with_labels (bool, optional): Add semantic labels in
                visualization result. Defaults to True.
        """
        classes = self.dataset_meta.get('classes', None)
        palette = self.dataset_meta.get('palette', None)

        gt_img_data = None
        pred_img_data = None
        if data_sample is not None:
            if draw_gt:
                gt_img_data = self._draw_sample_fields(
                    image, data_sample, 'gt', classes, palette, with_labels)
            if draw_pred:
                pred_img_data = self._draw_sample_fields(
                    image, data_sample, 'pred', classes, palette, with_labels)

        if gt_img_data is not None and pred_img_data is not None:
            drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
        elif gt_img_data is not None:
            drawn_img = gt_img_data
        elif pred_img_data is not None:
            drawn_img = pred_img_data
        else:
            # Nothing was drawn: fall back to the raw image so that show /
            # save below do not receive None.
            drawn_img = image

        if show:
            self.show(drawn_img, win_name=name, wait_time=wait_time)

        if out_file is not None:
            mmcv.imwrite(mmcv.rgb2bgr(drawn_img), out_file)
        else:
            self.add_image(name, drawn_img, step)