"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
|
from carvekit.ml.wrap.fba_matting import FBAMatting |
|
from typing import Union, List |
|
from PIL import Image |
|
from pathlib import Path |
|
from carvekit.trimap.cv_gen import CV2TrimapGenerator |
|
from carvekit.trimap.generator import TrimapGenerator |
|
from carvekit.utils.mask_utils import apply_mask |
|
from carvekit.utils.pool_utils import thread_pool_processing |
|
from carvekit.utils.image_utils import load_image, convert_image |
|
|
|
__all__ = ["MattingMethod"]


class MattingMethod:
    """
    Improve the edges of an object mask using a matting neural network and a trimap generator.

    The matting network performs accurate object edge detection by using a special map called a trimap,
    which marks an unknown area that is scanned for the boundary, the already known general object area
    and the background.
    """

    def __init__(
        self,
        matting_module: FBAMatting,
        trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator],
        device="cpu",
    ):
        """
        Initializes Matting Method class.

        Args:
            matting_module: Initialized matting neural network class
            trimap_generator: Initialized trimap generator class
            device: Processing device used for applying mask to image
        """
        self.device = device
        self.matting_module = matting_module
        self.trimap_generator = trimap_generator

    def __call__(
        self,
        images: List[Union[str, Path, Image.Image]],
        masks: List[Union[str, Path, Image.Image]],
    ):
        """
        Refines the given masks with the matting network and applies the result to the images.

        Each image and mask is passed through load_image and convert_image (masks with mode="L"),
        a trimap is generated for every image/mask pair, the matting module is called with the
        images and trimaps, and the alpha it returns is applied to each image with apply_mask.

        Args:
            images: list of images
            masks: list of masks, one per image and in the same order

        Returns:
            list of images

        Raises:
            ValueError: if images and masks do not have the same length
        """
        if len(images) != len(masks):
            raise ValueError("Images and Masks lists should have same length!")
        images = thread_pool_processing(lambda x: convert_image(load_image(x)), images)
        masks = thread_pool_processing(
            lambda x: convert_image(load_image(x), mode="L"), masks
        )
        # Pair every image with its mask up front; the pairs are materialized into a list
        # so the pool helper receives a sized sequence, as it did with range() before.
        trimaps = thread_pool_processing(
            lambda pair: self.trimap_generator(original_image=pair[0], mask=pair[1]),
            list(zip(images, masks)),
        )
        alpha = self.matting_module(images=images, trimaps=trimaps)
        # alpha is indexed explicitly (instead of zipped) so that a matting result shorter
        # than the image list still raises instead of being silently truncated.
        return [
            apply_mask(image=image, mask=alpha[idx], device=self.device)
            for idx, image in enumerate(images)
        ]
|
|