File size: 2,639 Bytes
19b3da3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import abc
from typing import Dict, List
import numpy as np
import torch
from skimage import color
from skimage.segmentation import mark_boundaries
from . import colors
# Palette passed as `colors=` to `color.label2rgb` when a multi-channel class map is rendered.
# Only the first value returned by `colors.generate_colors` is kept; the second is discarded.
COLORS, _ = colors.generate_colors(151)  # 151 - max classes for semantic segmentation
class BaseVisualizer(abc.ABC):
    """Abstract base class for callables that visualize a batch.

    The class derives from ``abc.ABC`` so that ``@abc.abstractmethod`` is
    actually enforced: a subclass that does not override ``__call__`` cannot
    be instantiated (``TypeError``), instead of failing only when it is
    first called.
    """

    @abc.abstractmethod
    def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
        """
        Take a batch, make an image from it and visualize.

        Concrete subclasses define what each argument means; judging by the
        names, ``epoch_i`` / ``batch_i`` are presumably the epoch and batch
        indices and ``rank`` a distributed-process rank (confirm against the
        subclasses and their callers).

        Raises:
            NotImplementedError: always, when the base implementation is
                reached (e.g. through ``super().__call__(...)``).
        """
        raise NotImplementedError()
def visualize_mask_and_images(images_dict: Dict[str, np.ndarray], keys: List[str],
                              last_without_mask=True, rescale_keys=None, mask_only_first=None,
                              black_mask=False) -> np.ndarray:
    """Render the images named by ``keys`` side by side, optionally outlining the mask.

    Args:
        images_dict: Maps names to arrays. Must contain ``'mask'``; it is
            binarized with ``> 0.5`` and its first slice ``mask[0]`` is used
            as the 2-D label image (so the mask is assumed to be laid out as
            ``(1, H, W)`` - confirm against callers). Every other entry is
            treated as channel-first ``(C, H, W)``; a plain ``(H, W)`` array
            is accepted as a single-channel image.
        keys: Names of the entries to draw, in left-to-right order.
        last_without_mask: When true, the last image in ``keys`` is left
            without the mask outline (ignored if ``mask_only_first`` is truthy).
        rescale_keys: Optional collection of keys whose images are min-max
            rescaled to roughly ``[0, 1)`` before drawing.
        mask_only_first: When truthy, only the first image gets the mask outline.
        black_mask: When true, pixels under the mask are zeroed before the
            outline is drawn.

    Returns:
        The channel-last images concatenated along axis 1 (width).

    Raises:
        KeyError: if ``'mask'`` or one of ``keys`` is missing from ``images_dict``.
        ValueError: if ``keys`` is empty (``np.concatenate`` has nothing to join).
    """
    mask = images_dict['mask'] > 0.5
    last_index = len(keys) - 1
    result = []
    for i, k in enumerate(keys):
        img = images_dict[k]

        # Give a bare (H, W) map a leading channel axis *before* the transpose;
        # transposing a 2-D array with three axes would raise otherwise.
        if np.ndim(img) == 2:
            img = np.expand_dims(img, 0)
        img = np.transpose(img, (1, 2, 0))  # channel-first -> channel-last

        if rescale_keys is not None and k in rescale_keys:
            img = img - img.min()
            # Out-of-place true division: also valid for integer inputs, where
            # an in-place `/=` cannot store the float result. The epsilon
            # guards against division by zero for constant images.
            img = img / (img.max() + 1e-5)

        if img.shape[2] == 1:
            # Grayscale -> three identical channels.
            img = np.repeat(img, 3, axis=2)
        elif img.shape[2] > 3:
            # More than three channels is treated as per-class scores:
            # take the winning class per pixel and colour it with the palette.
            img_classes = img.argmax(2)
            img = color.label2rgb(img_classes, colors=COLORS)

        if mask_only_first:
            need_mark_boundaries = i == 0
        else:
            need_mark_boundaries = i < last_index or not last_without_mask

        if need_mark_boundaries:
            if black_mask:
                img = img * (1 - mask[0][..., None])
            img = mark_boundaries(img,
                                  mask[0],
                                  color=(1., 0., 0.),
                                  outline_color=(1., 1., 1.),
                                  mode='thick')
        result.append(img)
    return np.concatenate(result, axis=1)
def visualize_mask_and_images_batch(batch: Dict[str, torch.Tensor], keys: List[str], max_items=10,
                                    last_without_mask=True, rescale_keys=None) -> np.ndarray:
    """Build one visualization image for the first items of a batch.

    Only the tensors named in ``keys`` plus ``'mask'`` are used; each is
    detached, moved to CPU and converted to a NumPy array. For up to
    ``max_items`` items (indexing along the first axis) a row is rendered
    with ``visualize_mask_and_images``, and the rows are concatenated
    along axis 0.

    Args:
        batch: Maps names to tensors whose first axis indexes batch items.
        keys: Names of the entries to draw in each row.
        max_items: Upper bound on the number of rows.
        last_without_mask: Forwarded to ``visualize_mask_and_images``.
        rescale_keys: Forwarded to ``visualize_mask_and_images``.

    Returns:
        The per-item rows concatenated along axis 0.
    """
    arrays = {}
    for name, tensor in batch.items():
        if name in keys or name == 'mask':
            arrays[name] = tensor.detach().cpu().numpy()

    # The batch size is read from whichever selected array comes first.
    first_array = next(iter(arrays.values()))
    n_rows = min(first_array.shape[0], max_items)

    rows = [
        visualize_mask_and_images(
            {name: array[idx] for name, array in arrays.items()},
            keys,
            last_without_mask=last_without_mask,
            rescale_keys=rescale_keys,
        )
        for idx in range(n_rows)
    ]
    return np.concatenate(rows, axis=0)
|