_
File size: 3,145 Bytes
da3eeba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from typing import Any, Dict, List, Union

import numpy as np
from PIL import Image


def invert_mask(mask: np.ndarray) -> np.ndarray:
    """Invert a mask so that selected and unselected regions swap.

    Boolean masks are inverted logically and returned as ``uint8`` holding
    0 and 255. Masks of any other dtype are cast to ``uint8`` and inverted
    bitwise, so a 0/255 mask maps to 255/0.

    Args:
        mask (np.ndarray): mask to invert

    Returns:
        np.ndarray: inverted ``uint8`` mask

    Raises:
        ValueError: if ``mask`` is not a NumPy array (including ``None``)
    """
    if not isinstance(mask, np.ndarray):
        raise ValueError("Invalid mask")

    if mask.dtype == np.bool_:
        # Casting True to uint8 gives 1, and a bitwise NOT of 1 is 254, which
        # is still non-zero. Invert logically so True maps to 0.
        return np.logical_not(mask).astype(np.uint8) * 255

    return np.invert(mask.astype(np.uint8))


def check_inputs_create_mask_image(
        mask: Union[np.ndarray, Image.Image],
        sam_masks: List[Dict[str, Any]],
        ignore_black_chk: bool = True,
) -> None:
    """Validate the argument types used to create a mask image.

    The arguments are checked in order and the first failing check raises.

    Args:
        mask (Union[np.ndarray, Image.Image]): mask
        sam_masks (List[Dict[str, Any]]): SAM masks
        ignore_black_chk (bool): ignore black check

    Returns:
        None

    Raises:
        ValueError: if ``mask`` is not an array or PIL image, ``sam_masks``
            is not a list, or ``ignore_black_chk`` is not a bool
    """
    # (value, accepted type(s), error message); None fails every isinstance
    # test, so it needs no separate check.
    expectations = (
        (mask, (np.ndarray, Image.Image), "Invalid mask"),
        (sam_masks, list, "Invalid SAM masks"),
        (ignore_black_chk, bool, "Invalid ignore black check"),
    )
    for value, accepted, message in expectations:
        if not isinstance(value, accepted):
            raise ValueError(message)


def convert_mask(mask: Union[np.ndarray, Image.Image]) -> np.ndarray:
    """Convert a mask to an array with a single trailing channel.

    Args:
        mask (Union[np.ndarray, Image.Image]): mask

    Returns:
        np.ndarray: converted mask; a 2-D input gains a trailing axis of
            length 1, and a 3-D input with several channels keeps only its
            first channel
    """
    array = np.array(mask) if isinstance(mask, Image.Image) else mask

    # A 2-D mask only needs a channel axis appended.
    if array.ndim == 2:
        return array[:, :, np.newaxis]

    # Already single-channel: hand it back unchanged.
    if array.shape[2] == 1:
        return array

    # Multi-channel: keep the first channel, preserving the channel axis.
    return array[:, :, :1]


def create_mask_image(
        mask: Union[np.ndarray, Image.Image],
        sam_masks: List[Dict[str, Any]],
        ignore_black_chk: bool = True,
) -> np.ndarray:
    """Create a mask image from the SAM segments touched by ``mask``.

    Segments are visited in list order and each one only claims pixels that
    no earlier segment has claimed. A segment's claimed pixels are added to
    the output when any of them coincides with a non-zero pixel of ``mask``.

    Args:
        mask (Union[np.ndarray, Image.Image]): mask
        sam_masks (List[Dict[str, Any]]): SAM masks; each entry's
            ``"segmentation"`` array is read (presumably a boolean H x W
            array -- confirm against the caller)
        ignore_black_chk (bool): ignore black check; when False, the pixels
            claimed by no segment are added as one extra region if ``mask``
            touches them

    Returns:
        np.ndarray: mask image, a ``uint8`` array with three channels
    """
    check_inputs_create_mask_image(mask, sam_masks, ignore_black_chk)
    mask = convert_mask(mask)

    claimed = np.zeros(mask.shape, dtype=np.uint8)
    selected = np.zeros(mask.shape, dtype=np.uint8)

    for entry in sam_masks:
        segment = entry["segmentation"].astype(np.uint8)[..., np.newaxis]
        unclaimed = np.logical_not(claimed.astype(bool)).astype(np.uint8)
        # Part of this segment that no earlier segment has claimed.
        fresh = segment * unclaimed
        if (fresh * mask).astype(bool).any():
            selected = selected + fresh
        claimed = claimed + fresh

    if not ignore_black_chk:
        # Treat everything left unclaimed as one more candidate region.
        leftover = np.logical_not(claimed.astype(bool)).astype(np.uint8)
        if (leftover * mask).astype(bool).any():
            selected = selected + leftover

    # Scale to 0/255 and replicate the single channel three times.
    return np.tile(selected * 255, (1, 1, 3)).astype(np.uint8)