File size: 2,690 Bytes
a3d6c18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from carvekit.ml.wrap.fba_matting import FBAMatting
from typing import Union, List
from PIL import Image
from pathlib import Path
from carvekit.trimap.cv_gen import CV2TrimapGenerator
from carvekit.trimap.generator import TrimapGenerator
from carvekit.utils.mask_utils import apply_mask
from carvekit.utils.pool_utils import thread_pool_processing
from carvekit.utils.image_utils import load_image, convert_image

__all__ = ["MattingMethod"]


class MattingMethod:
    """
    Refines the edges of an object mask with a matting neural network driven by a trimap.

    A trimap generator is called for every image/mask pair, the resulting trimaps are handed
    to the matting module together with the images, and the alpha output of the matting module
    is then applied to each image with ``apply_mask``.
    """

    def __init__(
        self,
        matting_module: Union[FBAMatting],
        trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator],
        device="cpu",
    ):
        """
        Stores the components used by the matting pipeline.

        Args:
            matting_module: Initialized matting neural network class
            trimap_generator: Initialized trimap generator class
            device: Processing device passed to ``apply_mask`` when applying the mask to the image
        """
        self.matting_module = matting_module
        self.trimap_generator = trimap_generator
        self.device = device

    def __call__(
        self,
        images: List[Union[str, Path, Image.Image]],
        masks: List[Union[str, Path, Image.Image]],
    ):
        """
        Runs the full pipeline: load inputs, build trimaps, run matting, apply the alpha masks.

        Args:
            images: list of images (paths or PIL images)
            masks: list of masks (paths or PIL images), one per image, in the same order

        Returns:
            list with one ``apply_mask`` result per input image, in input order

        Raises:
            ValueError: if ``images`` and ``masks`` differ in length
        """
        if len(images) != len(masks):
            raise ValueError("Images and Masks lists should have same length!")

        def _prepare_image(source):
            # Images go through convert_image with its default mode.
            return convert_image(load_image(source))

        def _prepare_mask(source):
            # Masks are explicitly converted with mode="L".
            return convert_image(load_image(source), mode="L")

        images = thread_pool_processing(_prepare_image, images)
        masks = thread_pool_processing(_prepare_mask, masks)

        def _build_trimap(index):
            # Work by index so every image is paired with the mask at the same position.
            return self.trimap_generator(
                original_image=images[index], mask=masks[index]
            )

        trimaps = thread_pool_processing(_build_trimap, range(len(images)))
        alpha = self.matting_module(images=images, trimaps=trimaps)

        return [
            apply_mask(image=picture, mask=alpha[index], device=self.device)
            for index, picture in enumerate(images)
        ]