File size: 1,955 Bytes
a3d6c18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from PIL import Image
from carvekit.trimap.cv_gen import CV2TrimapGenerator
from carvekit.trimap.add_ops import prob_filter, prob_as_unknown_area, post_erosion


class TrimapGenerator(CV2TrimapGenerator):
    """
    Trimap generator built on top of CV2TrimapGenerator.

    Wraps the parent generator call with prob_filter before it and
    prob_as_unknown_area plus post_erosion after it.
    """

    def __init__(
        self, prob_threshold: int = 231, kernel_size: int = 30, erosion_iters: int = 5
    ):
        """
        Set up the generator.

        Args:
            prob_threshold: Threshold value handed to both the
            prob_filter and prob_as_unknown_area operations
            kernel_size: Forwarded unchanged to CV2TrimapGenerator
            (presumably the pixel offset around the object mask used
            for the unknown area; confirm against the parent class)
            erosion_iters: Iteration count handed to post_erosion,
            which runs as the last step of __call__
        """
        # The parent is always given erosion_iters=0; the caller's value is
        # kept privately and only used for the post_erosion step in __call__.
        super().__init__(kernel_size, erosion_iters=0)
        self.__erosion_iters = erosion_iters
        self.prob_threshold = prob_threshold

    def __call__(self, original_image: Image.Image, mask: Image.Image) -> Image.Image:
        """
        Build a trimap from a predicted object mask.

        Steps, in order: prob_filter on the mask, the parent
        CV2TrimapGenerator call on the filtered mask, prob_as_unknown_area
        using the unfiltered mask, and finally post_erosion.

        Args:
            original_image: Original image
            mask: Predicted object mask

        Returns:
            The trimap returned by post_erosion.
        """
        # Filter the predicted mask with the configured threshold first.
        thresholded_mask = prob_filter(mask=mask, prob_threshold=self.prob_threshold)

        # The parent generator works on the filtered mask.
        base_trimap = super().__call__(original_image, thresholded_mask)

        # The second probability pass receives the original, unfiltered mask.
        marked_trimap = prob_as_unknown_area(
            trimap=base_trimap, mask=mask, prob_threshold=self.prob_threshold
        )

        return post_erosion(marked_trimap, self.__erosion_iters)