CM2000112 / carvekit /trimap /generator.py
jayparmr's picture
Upload folder using huggingface_hub
a3d6c18
raw
history blame
1.96 kB
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from PIL import Image
from carvekit.trimap.cv_gen import CV2TrimapGenerator
from carvekit.trimap.add_ops import prob_filter, prob_as_unknown_area, post_erosion
class TrimapGenerator(CV2TrimapGenerator):
def __init__(
self, prob_threshold: int = 231, kernel_size: int = 30, erosion_iters: int = 5
):
"""
Initialize a TrimapGenerator instance
Args:
prob_threshold: Probability threshold at which the
prob_filter and prob_as_unknown_area operations will be applied
kernel_size: The size of the offset from the object mask
in pixels when an unknown area is detected in the trimap
erosion_iters: The number of iterations of erosion that
the object's mask will be subjected to before forming an unknown area
"""
super().__init__(kernel_size, erosion_iters=0)
self.prob_threshold = prob_threshold
self.__erosion_iters = erosion_iters
def __call__(self, original_image: Image.Image, mask: Image.Image) -> Image.Image:
"""
Generates trimap based on predicted object mask to refine object mask borders.
Based on cv2 erosion algorithm and additional prob. filters.
Args:
original_image: Original image
mask: Predicted object mask
Returns:
Generated trimap for image.
"""
filter_mask = prob_filter(mask=mask, prob_threshold=self.prob_threshold)
trimap = super(TrimapGenerator, self).__call__(original_image, filter_mask)
new_trimap = prob_as_unknown_area(
trimap=trimap, mask=mask, prob_threshold=self.prob_threshold
)
new_trimap = post_erosion(new_trimap, self.__erosion_iters)
return new_trimap