Spaces:
Sleeping
Sleeping
| """ | |
| Source url: https://github.com/OPHoperHPO/image-background-remove-tool | |
| Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. | |
| License: Apache License 2.0 | |
| """ | |
| from pathlib import Path | |
| from typing import Union, List, Optional | |
| from PIL import Image | |
| from carvekit.ml.wrap.basnet import BASNET | |
| from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 | |
| from carvekit.ml.wrap.u2net import U2NET | |
| from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 | |
| from carvekit.pipelines.preprocessing import PreprocessingStub | |
| from carvekit.pipelines.postprocessing import MattingMethod | |
| from carvekit.utils.image_utils import load_image | |
| from carvekit.utils.mask_utils import apply_mask | |
| from carvekit.utils.pool_utils import thread_pool_processing | |
class Interface:
    """Callable front-end that chains the configured CarveKit pipelines."""

    def __init__(
        self,
        seg_pipe: Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7],
        pre_pipe: Optional[Union[PreprocessingStub]] = None,
        post_pipe: Optional[Union[MattingMethod]] = None,
        device="cpu",
    ):
        """
        Stores the pipeline objects and the device used by later calls.

        Args:
            seg_pipe: Initialized segmentation network object
            pre_pipe: Initialized pre-processing pipeline object (optional)
            post_pipe: Initialized postprocessing pipeline object (optional)
            device: The processing device handed to ``apply_mask`` when no
                postprocessing pipeline is configured.
        """
        self.device = device
        self.preprocessing_pipeline = pre_pipe
        self.segmentation_pipeline = seg_pipe
        self.postprocessing_pipeline = post_pipe

    def __call__(
        self, images: List[Union[str, Path, Image.Image]]
    ) -> List[Image.Image]:
        """
        Removes the background from the specified images.

        Args:
            images: list of input images (paths or PIL images, per the
                annotation); each item is passed through ``load_image``.

        Returns:
            The result of the postprocessing pipeline when one is configured,
            otherwise a list with ``apply_mask`` applied to every image and
            its mask.
        """
        loaded = thread_pool_processing(load_image, images)

        # Stage 1: obtain the masks. When a pre-processing pipeline is set it
        # is called instead of the segmentation pipeline and receives this
        # interface object.
        pre_pipe = self.preprocessing_pipeline
        if pre_pipe is None:
            masks = self.segmentation_pipeline(images=loaded)
        else:
            masks = pre_pipe(interface=self, images=loaded)

        # Stage 2: combine images and masks.
        post_pipe = self.postprocessing_pipeline
        if post_pipe is not None:
            return post_pipe(images=loaded, masks=masks)

        # Masks are looked up by position so that a mask list shorter than the
        # image list still fails with IndexError instead of being truncated.
        return [
            apply_mask(image=picture, mask=masks[idx], device=self.device)
            for idx, picture in enumerate(loaded)
        ]