File size: 8,168 Bytes
a3d6c18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
import pathlib
from typing import Union, List, Tuple

import PIL
import cv2
import numpy as np
import torch
from PIL import Image

from carvekit.ml.arch.fba_matting.models import FBA
from carvekit.ml.arch.fba_matting.transforms import (
    trimap_transform,
    groupnorm_normalise_image,
)
from carvekit.ml.files.models_loc import fba_pretrained
from carvekit.utils.image_utils import convert_image, load_image
from carvekit.utils.models_utils import get_precision_autocast, cast_network
from carvekit.utils.pool_utils import batch_generator, thread_pool_processing

__all__ = ["FBAMatting"]


class FBAMatting(FBA):
    """
    FBA Matting Neural Network to improve edges on image.

    Callable wrapper around the ``FBA`` architecture. It loads images and
    trimaps, converts them into batched tensors, runs the network, and turns
    channel 0 of the prediction back into ``L``-mode PIL masks at the
    trimap's original size.
    """

    def __init__(
        self,
        device="cpu",
        input_tensor_size: Union[List[int], Tuple[int, int], int] = 2048,
        batch_size: int = 2,
        encoder="resnet50_GN_WS",
        load_pretrained: bool = True,
        fp16: bool = False,
    ):
        """
        Initialize the FBAMatting model

        Args:
            device: processing device
            input_tensor_size: input image size, handed to PIL as
                ``(width, height)``. Either a single int (square size) or a
                list/tuple whose first two items are used.
            batch_size: the number of images that the neural network processes
                in one run. With ``batch_size == 1`` images keep their aspect
                ratio (PIL ``thumbnail``); otherwise every image is resized to
                exactly ``input_tensor_size`` so the batch can be stacked.
            encoder: neural network encoder head
            load_pretrained: loading pretrained model
            fp16: use half precision

        """
        super(FBAMatting, self).__init__(encoder=encoder)
        self.fp16 = fp16
        self.device = device
        self.batch_size = batch_size
        if isinstance(input_tensor_size, (list, tuple)):
            # Tuples used to fall through to the int branch below and yield a
            # nested ((w, h), (w, h)) size; accept any list/tuple here.
            self.input_image_size = tuple(input_tensor_size[:2])
        else:
            self.input_image_size = (input_tensor_size, input_tensor_size)
        self.to(device)
        if load_pretrained:
            self.load_state_dict(torch.load(fba_pretrained(), map_location=self.device))
        self.eval()

    def data_preprocessing(
        self, data: Union[PIL.Image.Image, np.ndarray]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """
        Transform input image to suitable data format for neural network

        Args:
            data: input image (mode ``"RGB"``) or trimap (mode ``"L"``).
                A numpy array is first converted with ``Image.fromarray``.

        Returns:
            A pair ``(raw, transformed)`` of ``1 x C x H x W`` float tensors
            whose spatial size is rounded up to a multiple of 8. For an RGB
            image ``transformed`` is the group-norm normalised image; for a
            trimap it is the output of ``trimap_transform``.

        Raises:
            ValueError: if the image mode is neither ``"RGB"`` nor ``"L"``.

        """
        if isinstance(data, np.ndarray):
            # The annotation allows arrays, but the PIL-only calls below
            # (thumbnail / resize(size, resample=...) / .mode) need an Image.
            data = Image.fromarray(data)

        if self.batch_size == 1:
            # thumbnail() works in place, so operate on a copy. It keeps the
            # aspect ratio and never enlarges the image.
            resized = data.copy()
            resized.thumbnail(self.input_image_size, resample=3)  # 3 == BICUBIC
        else:
            # Batched inputs are stacked later, so they must share one size.
            resized = data.resize(self.input_image_size, resample=3)  # 3 == BICUBIC

        mode = resized.mode
        if mode not in ("RGB", "L"):
            raise ValueError("Incorrect color mode for image")

        # noinspection PyTypeChecker
        image = np.array(resized, dtype=np.float64) / 255.0  # scale to [0, 1]
        if mode == "RGB":
            # NOTE(review): this reverses the channel order (RGB -> BGR). Kept
            # unchanged; confirm the order the pretrained weights expect
            # before altering it.
            image = image[:, :, ::-1]
        else:
            # Encode the trimap as two binary planes: plane 0 marks pixels
            # equal to 0, plane 1 marks pixels equal to 255. Any other value
            # (the region to refine) is 0 in both planes.
            trimap = image
            image = np.zeros((*trimap.shape, 2))
            image[trimap == 1, 1] = 1
            image[trimap == 0, 0] = 1

        # Round both spatial dims up to a multiple of 8 and rescale to it.
        h, w = image.shape[:2]
        h8 = (h + 7) // 8 * 8
        w8 = (w + 7) // 8 * 8
        scaled = cv2.resize(image, (w8, h8), interpolation=cv2.INTER_LANCZOS4)

        image_tensor = torch.from_numpy(scaled).permute(2, 0, 1)[None].float()
        if mode == "RGB":
            return image_tensor, groupnorm_normalise_image(
                image_tensor.clone(), format="nchw"
            )
        trimap_tensor = (
            torch.from_numpy(trimap_transform(scaled)).permute(2, 0, 1)[None].float()
        )
        return image_tensor, trimap_tensor

    @staticmethod
    def data_postprocessing(
        data: torch.tensor, trimap: PIL.Image.Image
    ) -> PIL.Image.Image:
        """
        Transforms output data from neural network to suitable data
        format for using with other components of this framework.

        Only channel 0 of ``data`` is used. It is resized to the trimap's size,
        zeroed where the trimap is 0, and values below 0.3 are dropped.

        Args:
            data: output data from neural network for one image
                (channels-first, i.e. ``C x H x W``)
            trimap: Map with the area we need to refine (mode ``"L"``)

        Returns:
            Segmentation mask as PIL Image instance (mode ``"L"``)

        Raises:
            ValueError: if ``trimap`` is not an ``"L"``-mode image.

        """
        if trimap.mode != "L":
            raise ValueError("Incorrect color mode for trimap")

        # Resize only the plane we keep. Casting to float32 guards against
        # half-precision outputs under fp16 autocast.
        alpha = data[0].float().numpy()
        # `interpolation` must be a keyword: the third positional argument of
        # cv2.resize is `dst`, so passing the flag positionally ignored it.
        alpha = cv2.resize(alpha, trimap.size, interpolation=cv2.INTER_LANCZOS4)

        # Clean mask by removing all predictions where the trimap is 0.
        # Pixels marked 255 in the trimap are NOT forced to 1; the network's
        # prediction is kept there.
        trimap_arr = np.array(trimap)
        alpha[trimap_arr == 0] = 0
        alpha[alpha < 0.3] = 0
        # Lanczos can overshoot above 1; clamp before scaling to 8 bits.
        np.clip(alpha, 0.0, 1.0, out=alpha)
        return Image.fromarray(alpha * 255).convert("L")

    @staticmethod
    def _stack_pairs(
        pairs: List[Tuple[torch.Tensor, torch.Tensor]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Stack ``(raw, transformed)`` pairs from ``data_preprocessing`` into two batch tensors."""
        pairs = list(pairs)
        return (
            torch.vstack([raw for raw, _ in pairs]),
            torch.vstack([transformed for _, transformed in pairs]),
        )

    def __call__(
        self,
        images: List[Union[str, pathlib.Path, PIL.Image.Image]],
        trimaps: List[Union[str, pathlib.Path, PIL.Image.Image]],
    ) -> List[PIL.Image.Image]:
        """
        Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances

        Args:
            images: input images
            trimaps: Maps with the areas we need to refine; ``trimaps[i]``
                belongs to ``images[i]``

        Returns:
            segmentation masks as for input images, as PIL.Image.Image instances

        Raises:
            ValueError: if ``images`` and ``trimaps`` differ in length.

        """
        if len(images) != len(trimaps):
            raise ValueError(
                "Len of specified arrays of images and trimaps should be equal!"
            )

        collect_masks = []
        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
        with autocast:
            cast_network(self, dtype)
            for idx_batch in batch_generator(range(len(images)), self.batch_size):
                inpt_images = thread_pool_processing(
                    lambda x: convert_image(load_image(images[x])), idx_batch
                )
                inpt_trimaps = thread_pool_processing(
                    lambda x: convert_image(load_image(trimaps[x]), mode="L"), idx_batch
                )

                img_raw, img_transformed = self._stack_pairs(
                    thread_pool_processing(self.data_preprocessing, inpt_images)
                )
                trimap_raw, trimap_transformed = self._stack_pairs(
                    thread_pool_processing(self.data_preprocessing, inpt_trimaps)
                )

                with torch.no_grad():
                    img_raw = img_raw.to(self.device)
                    trimap_raw = trimap_raw.to(self.device)
                    img_transformed = img_transformed.to(self.device)
                    trimap_transformed = trimap_transformed.to(self.device)

                    output = super(FBAMatting, self).__call__(
                        img_raw,
                        trimap_raw,
                        img_transformed,
                        trimap_transformed,
                    )
                    output_cpu = output.cpu()
                    # Release device tensors before the next batch is built.
                    del (
                        img_raw,
                        trimap_raw,
                        img_transformed,
                        trimap_transformed,
                        output,
                    )

                # Post-process against the full-size trimaps loaded above.
                masks = thread_pool_processing(
                    lambda x: self.data_postprocessing(output_cpu[x], inpt_trimaps[x]),
                    range(len(inpt_images)),
                )
                collect_masks += masks
        return collect_masks