CM2000112 / carvekit /ml /wrap /fba_matting.py
jayparmr's picture
Upload folder using huggingface_hub
a3d6c18
raw
history blame
8.17 kB
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
import pathlib
from typing import Union, List, Tuple
import PIL
import cv2
import numpy as np
import torch
from PIL import Image
from carvekit.ml.arch.fba_matting.models import FBA
from carvekit.ml.arch.fba_matting.transforms import (
trimap_transform,
groupnorm_normalise_image,
)
from carvekit.ml.files.models_loc import fba_pretrained
from carvekit.utils.image_utils import convert_image, load_image
from carvekit.utils.models_utils import get_precision_autocast, cast_network
from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
__all__ = ["FBAMatting"]
class FBAMatting(FBA):
"""
FBA Matting Neural Network to improve edges on image.
"""
def __init__(
self,
device="cpu",
input_tensor_size: Union[List[int], int] = 2048,
batch_size: int = 2,
encoder="resnet50_GN_WS",
load_pretrained: bool = True,
fp16: bool = False,
):
"""
Initialize the FBAMatting model
Args:
device: processing device
input_tensor_size: input image size
batch_size: the number of images that the neural network processes in one run
encoder: neural network encoder head
load_pretrained: loading pretrained model
fp16: use half precision
"""
super(FBAMatting, self).__init__(encoder=encoder)
self.fp16 = fp16
self.device = device
self.batch_size = batch_size
if isinstance(input_tensor_size, list):
self.input_image_size = input_tensor_size[:2]
else:
self.input_image_size = (input_tensor_size, input_tensor_size)
self.to(device)
if load_pretrained:
self.load_state_dict(torch.load(fba_pretrained(), map_location=self.device))
self.eval()
def data_preprocessing(
self, data: Union[PIL.Image.Image, np.ndarray]
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""
Transform input image to suitable data format for neural network
Args:
data: input image
Returns:
input for neural network
"""
resized = data.copy()
if self.batch_size == 1:
resized.thumbnail(self.input_image_size, resample=3)
else:
resized = resized.resize(self.input_image_size, resample=3)
# noinspection PyTypeChecker
image = np.array(resized, dtype=np.float64)
image = image / 255.0 # Normalize image to [0, 1] values range
if resized.mode == "RGB":
image = image[:, :, ::-1]
elif resized.mode == "L":
image2 = np.copy(image)
h, w = image2.shape
image = np.zeros((h, w, 2)) # Transform trimap to binary data format
image[image2 == 1, 1] = 1
image[image2 == 0, 0] = 1
else:
raise ValueError("Incorrect color mode for image")
h, w = image.shape[:2] # Scale input mlt to 8
h1 = int(np.ceil(1.0 * h / 8) * 8)
w1 = int(np.ceil(1.0 * w / 8) * 8)
x_scale = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_LANCZOS4)
image_tensor = torch.from_numpy(x_scale).permute(2, 0, 1)[None, :, :, :].float()
if resized.mode == "RGB":
return image_tensor, groupnorm_normalise_image(
image_tensor.clone(), format="nchw"
)
else:
return (
image_tensor,
torch.from_numpy(trimap_transform(x_scale))
.permute(2, 0, 1)[None, :, :, :]
.float(),
)
@staticmethod
def data_postprocessing(
data: torch.tensor, trimap: PIL.Image.Image
) -> PIL.Image.Image:
"""
Transforms output data from neural network to suitable data
format for using with other components of this framework.
Args:
data: output data from neural network
trimap: Map with the area we need to refine
Returns:
Segmentation mask as PIL Image instance
"""
if trimap.mode != "L":
raise ValueError("Incorrect color mode for trimap")
pred = data.numpy().transpose((1, 2, 0))
pred = cv2.resize(pred, trimap.size, cv2.INTER_LANCZOS4)[:, :, 0]
# noinspection PyTypeChecker
# Clean mask by removing all false predictions outside trimap and already known area
trimap_arr = np.array(trimap.copy())
pred[trimap_arr[:, :] == 0] = 0
# pred[trimap_arr[:, :] == 255] = 1
pred[pred < 0.3] = 0
return Image.fromarray(pred * 255).convert("L")
def __call__(
self,
images: List[Union[str, pathlib.Path, PIL.Image.Image]],
trimaps: List[Union[str, pathlib.Path, PIL.Image.Image]],
) -> List[PIL.Image.Image]:
"""
Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances
Args:
images: input images
trimaps: Maps with the areas we need to refine
Returns:
segmentation masks as for input images, as PIL.Image.Image instances
"""
if len(images) != len(trimaps):
raise ValueError(
"Len of specified arrays of images and trimaps should be equal!"
)
collect_masks = []
autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
with autocast:
cast_network(self, dtype)
for idx_batch in batch_generator(range(len(images)), self.batch_size):
inpt_images = thread_pool_processing(
lambda x: convert_image(load_image(images[x])), idx_batch
)
inpt_trimaps = thread_pool_processing(
lambda x: convert_image(load_image(trimaps[x]), mode="L"), idx_batch
)
inpt_img_batches = thread_pool_processing(
self.data_preprocessing, inpt_images
)
inpt_trimaps_batches = thread_pool_processing(
self.data_preprocessing, inpt_trimaps
)
inpt_img_batches_transformed = torch.vstack(
[i[1] for i in inpt_img_batches]
)
inpt_img_batches = torch.vstack([i[0] for i in inpt_img_batches])
inpt_trimaps_transformed = torch.vstack(
[i[1] for i in inpt_trimaps_batches]
)
inpt_trimaps_batches = torch.vstack(
[i[0] for i in inpt_trimaps_batches]
)
with torch.no_grad():
inpt_img_batches = inpt_img_batches.to(self.device)
inpt_trimaps_batches = inpt_trimaps_batches.to(self.device)
inpt_img_batches_transformed = inpt_img_batches_transformed.to(
self.device
)
inpt_trimaps_transformed = inpt_trimaps_transformed.to(self.device)
output = super(FBAMatting, self).__call__(
inpt_img_batches,
inpt_trimaps_batches,
inpt_img_batches_transformed,
inpt_trimaps_transformed,
)
output_cpu = output.cpu()
del (
inpt_img_batches,
inpt_trimaps_batches,
inpt_img_batches_transformed,
inpt_trimaps_transformed,
output,
)
masks = thread_pool_processing(
lambda x: self.data_postprocessing(output_cpu[x], inpt_trimaps[x]),
range(len(inpt_images)),
)
collect_masks += masks
return collect_masks