# File: carvekit/pipelines/postprocessing.py (Hugging Face repository CM2000112)
# Uploaded by jayparmr with huggingface_hub, commit a3d6c18 (2.69 kB).
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from carvekit.ml.wrap.fba_matting import FBAMatting
from typing import Union, List
from PIL import Image
from pathlib import Path
from carvekit.trimap.cv_gen import CV2TrimapGenerator
from carvekit.trimap.generator import TrimapGenerator
from carvekit.utils.mask_utils import apply_mask
from carvekit.utils.pool_utils import thread_pool_processing
from carvekit.utils.image_utils import load_image, convert_image
# Public API of this module: only MattingMethod is exported by `import *`.
__all__ = ["MattingMethod"]
class MattingMethod:
    """
    Refine object-mask edges with a matting neural network guided by a trimap.

    A trimap generator turns each coarse object mask into a trimap: a map with
    an already-known object area, a known background area, and an unknown
    area around the boundary that the matting network scans to detect the
    object edge accurately. The resulting alpha masks are then applied to the
    original images.
    """

    def __init__(
        self,
        matting_module: Union[FBAMatting],
        trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator],
        device: str = "cpu",
    ):
        """
        Initializes Matting Method class.

        Args:
            matting_module: Initialized matting neural network class. It is
                called as ``matting_module(images=..., trimaps=...)`` and its
                result is indexed once per input image.
            trimap_generator: Initialized trimap generator class. It is called
                as ``trimap_generator(original_image=..., mask=...)`` for every
                image/mask pair.
            device: Processing device passed to ``apply_mask`` when the alpha
                masks are applied to the images.
        """
        self.device = device
        self.matting_module = matting_module
        self.trimap_generator = trimap_generator

    def __call__(
        self,
        images: List[Union[str, Path, Image.Image]],
        masks: List[Union[str, Path, Image.Image]],
    ):
        """
        Refines the masks with trimap-guided matting and applies them to the images.

        Each image is loaded with ``load_image`` and converted with
        ``convert_image``. Each mask is loaded and converted to single-channel
        ``"L"`` mode. A trimap is generated for every image/mask pair, the
        matting module predicts alpha masks from the images and trimaps, and
        ``apply_mask`` combines each image with its alpha mask.

        Args:
            images: list of images (file paths or PIL images)
            masks: list of masks (file paths or PIL images), one per image and
                in the same order as ``images``

        Returns:
            list of images with the alpha masks applied, in input order.
            Empty input yields an empty list without invoking the trimap
            generator or the matting module.

        Raises:
            ValueError: if ``images`` and ``masks`` have different lengths.
        """
        if len(images) != len(masks):
            raise ValueError("Images and Masks lists should have same length!")
        if not images:
            # Nothing to refine: avoid running the trimap generator and the
            # matting network on an empty batch.
            return []

        images = thread_pool_processing(lambda x: convert_image(load_image(x)), images)
        masks = thread_pool_processing(
            lambda x: convert_image(load_image(x), mode="L"), masks
        )
        # Index-based so each trimap is built from the image and mask at the
        # same position after conversion.
        trimaps = thread_pool_processing(
            lambda x: self.trimap_generator(original_image=images[x], mask=masks[x]),
            range(len(images)),
        )
        alpha = self.matting_module(images=images, trimaps=trimaps)
        # alpha[i] is the matte predicted for images[i]; indexing (rather than
        # zip) keeps a length mismatch from being silently truncated.
        return [
            apply_mask(image=image, mask=alpha[i], device=self.device)
            for i, image in enumerate(images)
        ]