|
from typing import Any, Dict, List, Union
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
|
|
def invert_mask(mask: np.ndarray) -> np.ndarray:
    """Return the bitwise complement of a mask.

    The mask is first cast to ``uint8`` and every bit is then flipped, so
    0 becomes 255 and 255 becomes 0. Note that a boolean ``True`` is cast
    to 1 and therefore becomes 254, not 0.

    Args:
        mask (np.ndarray): mask to invert

    Returns:
        np.ndarray: ``uint8`` array of the same shape holding the inverted mask

    Raises:
        ValueError: if ``mask`` is not a ``np.ndarray`` (this includes ``None``)
    """
    # ``None`` is never an ndarray instance, so one isinstance test covers both cases.
    if not isinstance(mask, np.ndarray):
        raise ValueError("Invalid mask")

    as_uint8 = mask.astype(np.uint8)
    return ~as_uint8
|
|
|
|
|
|
def check_inputs_create_mask_image(
    mask: Union[np.ndarray, Image.Image],
    sam_masks: List[Dict[str, Any]],
    ignore_black_chk: bool = True,
) -> None:
    """Validate the argument types accepted by ``create_mask_image``.

    Only the top-level types are checked; the contents of ``sam_masks`` are
    not inspected here.

    Args:
        mask (Union[np.ndarray, Image.Image]): mask
        sam_masks (List[Dict[str, Any]]): SAM masks
        ignore_black_chk (bool): ignore black check

    Returns:
        None

    Raises:
        ValueError: if ``mask`` is neither a ``np.ndarray`` nor a PIL image,
            if ``sam_masks`` is not a list, or if ``ignore_black_chk`` is
            not a bool. ``None`` is rejected for every argument.
    """
    # ``None`` fails every isinstance test below, so it needs no separate check.
    mask_types = (np.ndarray, Image.Image)
    if not isinstance(mask, mask_types):
        raise ValueError("Invalid mask")

    if not isinstance(sam_masks, list):
        raise ValueError("Invalid SAM masks")

    if not isinstance(ignore_black_chk, bool):
        raise ValueError("Invalid ignore black check")
|
|
|
|
|
|
def convert_mask(mask: Union[np.ndarray, Image.Image]) -> np.ndarray:
    """Normalize a mask to an array with a single channel on axis 2.

    A PIL image is converted to a numpy array first. A 2-D array gains a
    trailing axis of length one; an array whose third axis has more than one
    entry is reduced to its first channel. An array that already has exactly
    one channel is returned unchanged.

    Args:
        mask (Union[np.ndarray, Image.Image]): mask

    Returns:
        np.ndarray: converted mask, with ``shape[2] == 1``
    """
    channels = np.array(mask) if isinstance(mask, Image.Image) else mask

    if channels.ndim == 2:
        # (H, W) -> (H, W, 1)
        channels = channels[..., np.newaxis]
    elif channels.shape[2] != 1:
        # Keep only the first channel, preserving the channel axis.
        channels = channels[:, :, :1]

    return channels
|
|
|
|
|
|
def create_mask_image(
    mask: Union[np.ndarray, Image.Image],
    sam_masks: List[Dict[str, Any]],
    ignore_black_chk: bool = True,
) -> np.ndarray:
    """Build a 3-channel image of the SAM segments touched by a mask.

    Segments are visited in list order, and each one only claims the pixels
    that no earlier segment has already claimed. A segment's claimed pixels
    are added to the result when at least one of them coincides with a
    non-zero pixel of ``mask``. When ``ignore_black_chk`` is False, the pixels
    left unclaimed by every segment are handled the same way as one extra
    segment.

    Args:
        mask (Union[np.ndarray, Image.Image]): mask
        sam_masks (List[Dict[str, Any]]): SAM masks; each dict must provide a
            ``"segmentation"`` array (assumed to be a boolean H x W map
            matching the mask size -- confirm against the SAM producer)
        ignore_black_chk (bool): if True, skip the check of the region not
            covered by any segment

    Returns:
        np.ndarray: ``uint8`` mask image with three channels; selected pixels
        are 255 and the rest 0 (for 0/1 segmentations)

    Raises:
        ValueError: if the arguments fail ``check_inputs_create_mask_image``
    """
    check_inputs_create_mask_image(mask, sam_masks, ignore_black_chk)
    mask = convert_mask(mask)

    # ``claimed`` marks pixels already owned by a visited segment;
    # ``selected`` accumulates the pixels that end up in the output.
    claimed = np.zeros(mask.shape, dtype=np.uint8)
    selected = np.zeros(mask.shape, dtype=np.uint8)

    for entry in sam_masks:
        segment = entry["segmentation"].astype(np.uint8)[..., np.newaxis]
        unclaimed = (claimed == 0).astype(np.uint8)
        # Part of this segment that no earlier segment has taken.
        fresh = segment * unclaimed
        if (fresh * mask).astype(bool).any():
            selected = selected + fresh
        claimed = claimed + fresh

    if not ignore_black_chk:
        # Region that no segment covered at all.
        leftover = (claimed == 0).astype(np.uint8)
        if (leftover * mask).astype(bool).any():
            selected = selected + leftover

    # Scale 0/1 to 0/255 and replicate the single channel three times.
    return np.tile(selected * 255, (1, 1, 3)).astype(np.uint8)
|
|
|