|
""" |
|
Source url: https://github.com/OPHoperHPO/image-background-remove-tool |
|
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. |
|
License: Apache License 2.0 |
|
""" |
|
import pathlib |
|
from typing import List, Union |
|
|
|
import PIL.Image |
|
import torch |
|
from PIL import Image |
|
from torchvision import transforms |
|
from torchvision.models.segmentation import deeplabv3_resnet101 |
|
from carvekit.ml.files.models_loc import deeplab_pretrained |
|
from carvekit.utils.image_utils import convert_image, load_image |
|
from carvekit.utils.models_utils import get_precision_autocast, cast_network |
|
from carvekit.utils.pool_utils import batch_generator, thread_pool_processing |
|
|
|
__all__ = ["DeepLabV3"] |
|
|
|
|
|
class DeepLabV3: |
|
def __init__( |
|
self, |
|
device="cpu", |
|
batch_size: int = 10, |
|
input_image_size: Union[List[int], int] = 1024, |
|
load_pretrained: bool = True, |
|
fp16: bool = False, |
|
): |
|
""" |
|
Initialize the DeepLabV3 model |
|
|
|
Args: |
|
device: processing device |
|
input_image_size: input image size |
|
batch_size: the number of images that the neural network processes in one run |
|
load_pretrained: loading pretrained model |
|
fp16: use half precision |
|
|
|
""" |
|
self.device = device |
|
self.batch_size = batch_size |
|
self.network = deeplabv3_resnet101( |
|
pretrained=False, pretrained_backbone=False, aux_loss=True |
|
) |
|
self.network.to(self.device) |
|
if load_pretrained: |
|
self.network.load_state_dict( |
|
torch.load(deeplab_pretrained(), map_location=self.device) |
|
) |
|
if isinstance(input_image_size, list): |
|
self.input_image_size = input_image_size[:2] |
|
else: |
|
self.input_image_size = (input_image_size, input_image_size) |
|
self.network.eval() |
|
self.fp16 = fp16 |
|
self.transform = transforms.Compose( |
|
[ |
|
transforms.ToTensor(), |
|
transforms.Normalize( |
|
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] |
|
), |
|
] |
|
) |
|
|
|
def to(self, device: str): |
|
""" |
|
Moves neural network to specified processing device |
|
|
|
Args: |
|
device (:class:`torch.device`): the desired device. |
|
Returns: |
|
None |
|
|
|
""" |
|
self.network.to(device) |
|
|
|
def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor: |
|
""" |
|
Transform input image to suitable data format for neural network |
|
|
|
Args: |
|
data: input image |
|
|
|
Returns: |
|
input for neural network |
|
|
|
""" |
|
copy = data.copy() |
|
copy.thumbnail(self.input_image_size, resample=3) |
|
return self.transform(copy) |
|
|
|
@staticmethod |
|
def data_postprocessing( |
|
data: torch.tensor, original_image: PIL.Image.Image |
|
) -> PIL.Image.Image: |
|
""" |
|
Transforms output data from neural network to suitable data |
|
format for using with other components of this framework. |
|
|
|
Args: |
|
data: output data from neural network |
|
original_image: input image which was used for predicted data |
|
|
|
Returns: |
|
Segmentation mask as PIL Image instance |
|
|
|
""" |
|
return ( |
|
Image.fromarray(data.numpy() * 255).convert("L").resize(original_image.size) |
|
) |
|
|
|
def __call__( |
|
self, images: List[Union[str, pathlib.Path, PIL.Image.Image]] |
|
) -> List[PIL.Image.Image]: |
|
""" |
|
Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances |
|
|
|
Args: |
|
images: input images |
|
|
|
Returns: |
|
segmentation masks as for input images, as PIL.Image.Image instances |
|
|
|
""" |
|
collect_masks = [] |
|
autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16) |
|
with autocast: |
|
cast_network(self.network, dtype) |
|
for image_batch in batch_generator(images, self.batch_size): |
|
images = thread_pool_processing( |
|
lambda x: convert_image(load_image(x)), image_batch |
|
) |
|
batches = thread_pool_processing(self.data_preprocessing, images) |
|
with torch.no_grad(): |
|
masks = [ |
|
self.network(i.to(self.device).unsqueeze(0))["out"][0] |
|
.argmax(0) |
|
.byte() |
|
.cpu() |
|
for i in batches |
|
] |
|
del batches |
|
masks = thread_pool_processing( |
|
lambda x: self.data_postprocessing(masks[x], images[x]), |
|
range(len(images)), |
|
) |
|
collect_masks += masks |
|
return collect_masks |
|
|