diego2554 committed on
Commit
af295ba
·
1 Parent(s): b0e8dee

Update rembg/bg.py

Browse files
Files changed (1) hide show
  1. rembg/bg.py +180 -35
rembg/bg.py CHANGED
@@ -1,26 +1,142 @@
1
  import io
2
- from typing import Any, Union
3
- import numpy as np
4
- from PIL import Image
5
  from enum import Enum
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- from rembg import new_session, remove
8
- from rembg.sessions.base import BaseSession
9
- from rembg.util.util import fix_image_orientation
10
 
11
  class ReturnType(Enum):
12
  BYTES = 0
13
  PILLOW = 1
14
  NDARRAY = 2
15
 
16
- def remove_background(
17
- data: Union[bytes, Image.Image, np.ndarray],
18
- alpha_influence: float = 0.5,
19
- segmentation_strength: float = 0.5,
20
- smoothing: float = 0.5,
21
- model: str = "u2net",
22
- ) -> Union[bytes, Image.Image, np.ndarray]:
23
- if isinstance(data, Image.Image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  return_type = ReturnType.PILLOW
25
  img = data
26
  elif isinstance(data, bytes):
@@ -30,27 +146,56 @@ def remove_background(
30
  return_type = ReturnType.NDARRAY
31
  img = Image.fromarray(data)
32
  else:
33
- raise ValueError("Input type {} is not supported.".format(type(data))
34
 
 
35
  img = fix_image_orientation(img)
36
- session = new_session(model)
37
- output = remove(
38
- img,
39
- alpha_matting=True,
40
- alpha_matting_foreground_threshold=alpha_influence * 255,
41
- alpha_matting_background_threshold=(1 - alpha_influence) * 255,
42
- alpha_matting_erode_size=int(segmentation_strength * 20),
43
- alpha_matting_smoothing=smoothing,
44
- session=session
45
- )
46
-
47
- if return_type == ReturnType.PILLOW:
48
- return output
49
- elif return_type == ReturnType.NDARRAY:
50
- return np.array(output)
51
- else:
52
- bio = io.BytesIO()
53
- output.save(bio, "PNG")
54
- bio.seek(0)
55
- return bio.read()
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import io
 
 
 
2
  from enum import Enum
3
+ from typing import Any, List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ from cv2 import (
7
+ BORDER_DEFAULT,
8
+ MORPH_ELLIPSE,
9
+ MORPH_OPEN,
10
+ GaussianBlur,
11
+ getStructuringElement,
12
+ morphologyEx,
13
+ )
14
+ from PIL import Image, ImageOps
15
+ from PIL.Image import Image as PILImage
16
+ from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
17
+ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
18
+ from pymatting.util.util import stack_images
19
+ from scipy.ndimage import binary_erosion
20
+
21
+ from .session_factory import new_session
22
+ from .sessions import sessions_class
23
+ from .sessions.base import BaseSession
24
+
25
# 3x3 elliptical structuring element, built once at import time and used by
# post_process() for its morphological opening.
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
26
 
 
 
 
27
 
28
  class ReturnType(Enum):
      """Kind of object remove() hands back; chosen from the type of its input."""

      # Result is serialized to PNG and returned as bytes
      # (presumably selected for bytes input - that branch is not visible here).
      BYTES = 0
      # Result is returned as a PIL image (selected for PIL image input).
      PILLOW = 1
      # Result is returned as a numpy array (selected for ndarray input).
      NDARRAY = 2
32
 
33
+
34
def alpha_matting_cutout(
    img: PILImage,
    mask: PILImage,
    foreground_threshold: int,
    background_threshold: int,
    erode_structure_size: int,
) -> PILImage:
    """
    Cut the subject out of an image using alpha matting driven by a trimap.

    The trimap is derived from the mask: values above ``foreground_threshold``
    are marked as foreground (255), values below ``background_threshold`` as
    background (0), and everything else as unknown (128). Both known regions
    are eroded before being written into the trimap.

    args:
        img: Source image; RGBA and CMYK inputs are converted to RGB first
        mask: Mask whose values are compared against the two thresholds
        foreground_threshold: Mask values strictly above this are foreground
        background_threshold: Mask values strictly below this are background
        erode_structure_size: Side length of the square erosion structure;
            when not positive, ``structure=None`` is passed to binary_erosion
    returns:
        Image built from the estimated foreground stacked with the estimated alpha
    """
    if img.mode in ("RGBA", "CMYK"):
        img = img.convert("RGB")

    pixels = np.asarray(img)
    mask_values = np.asarray(mask)

    # Square structuring element for the erosion, or None to let
    # binary_erosion fall back to its own default structure.
    if erode_structure_size > 0:
        structure = np.ones(
            (erode_structure_size, erode_structure_size), dtype=np.uint8
        )
    else:
        structure = None

    sure_foreground = binary_erosion(
        mask_values > foreground_threshold, structure=structure
    )
    sure_background = binary_erosion(
        mask_values < background_threshold, structure=structure, border_value=1
    )

    # Unknown everywhere by default; background is written last so it wins
    # wherever the two regions overlap.
    trimap = np.full(mask_values.shape, 128, dtype=np.uint8)
    trimap[sure_foreground] = 255
    trimap[sure_background] = 0

    scaled_img = pixels / 255.0
    scaled_trimap = trimap / 255.0

    alpha = estimate_alpha_cf(scaled_img, scaled_trimap)
    foreground = estimate_foreground_ml(scaled_img, alpha)
    stacked = stack_images(foreground, alpha)

    # Back from the [0, 1] float range to 8-bit channel values.
    as_bytes = np.clip(stacked * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(as_bytes)
74
+
75
+
76
def naive_cutout(img: PILImage, mask: PILImage) -> PILImage:
    """
    Cut out the subject by compositing the image over an empty canvas.

    args:
        img: Source image
        mask: Mask passed to Image.composite to choose between image and canvas
    returns:
        Result of compositing ``img`` with a zero-filled RGBA canvas of the same size
    """
    blank_canvas = Image.new("RGBA", img.size, 0)
    return Image.composite(img, blank_canvas, mask)
80
+
81
+
82
def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
    """
    Stack a list of images vertically, first image on top.

    The input list is only read, never modified, so callers can keep using
    it after the call.

    args:
        imgs: Images to stack, in top-to-bottom order; must not be empty
    returns:
        The single stacked image (the first image itself if there is only one)
    raises:
        IndexError: if ``imgs`` is empty
    """
    if not imgs:
        raise IndexError("get_concat_v_multi() requires at least one image")

    pivot = imgs[0]
    for im in imgs[1:]:
        pivot = get_concat_v(pivot, im)
    return pivot
87
+
88
+
89
def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
    """
    Place ``img2`` directly below ``img1`` on a new RGBA canvas.

    The canvas takes its width from ``img1`` and its height from the sum of
    both image heights.

    args:
        img1: Image pasted at the top-left corner
        img2: Image pasted at the left edge, immediately below ``img1``
    returns:
        The new canvas containing both images
    """
    top_height = img1.height
    canvas = Image.new("RGBA", (img1.width, top_height + img2.height))
    for tile, y_offset in ((img1, 0), (img2, top_height)):
        canvas.paste(tile, (0, y_offset))
    return canvas
94
+
95
+
96
def post_process(mask: np.ndarray) -> np.ndarray:
    """
    Smooth the boundary of a mask with morphological operations.

    Steps: morphological opening with the module-level ``kernel``, a 5x5
    Gaussian blur (sigma 2 on both axes), then thresholding back to a
    two-valued mask.
    Research based on paper: https://www.sciencedirect.com/science/article/pii/S2352914821000757

    args:
        mask: Binary Numpy Mask
    returns:
        uint8 mask holding 0 where the blurred value is below 127, else 255
    """
    opened = morphologyEx(mask, MORPH_OPEN, kernel)
    blurred = GaussianBlur(
        opened, (5, 5), sigmaX=2, sigmaY=2, borderType=BORDER_DEFAULT
    )
    # The blur introduces intermediate values; convert again to binary.
    binary = np.where(blurred < 127, 0, 255)
    return binary.astype(np.uint8)
107
+
108
+
109
def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:
    """
    Put an image on top of a solid-colour RGBA background.

    args:
        img: Image to paste; it is also used as its own paste mask
        color: Background colour as a 4-item (r, g, b, a) sequence
    returns:
        New RGBA image of the same size as ``img`` with the colour behind it
    """
    # Unpacking insists on exactly four components.
    red, green, blue, alpha = color
    backdrop = Image.new("RGBA", img.size, (red, green, blue, alpha))
    backdrop.paste(img, mask=img)
    return backdrop
115
+
116
+
117
def fix_image_orientation(img: PILImage) -> PILImage:
    """
    Apply the image's EXIF orientation via ``ImageOps.exif_transpose``.

    args:
        img: Image to process
    returns:
        Whatever ``ImageOps.exif_transpose`` returns for ``img``
    """
    transposed = ImageOps.exif_transpose(img)
    return transposed
119
+
120
+
121
def download_models() -> None:
    """Call ``download_models()`` on every entry of ``sessions_class``."""
    for session_cls in sessions_class:
        session_cls.download_models()
124
+
125
+
126
+ def remove(
127
+ data: Union[bytes, PILImage, np.ndarray],
128
+ alpha_matting: bool = False,
129
+ alpha_matting_foreground_threshold: int = 240,
130
+ alpha_matting_background_threshold: int = 10,
131
+ alpha_matting_erode_size: int = 10,
132
+ session: Optional[BaseSession] = None,
133
+ only_mask: bool = False,
134
+ post_process_mask: bool = False,
135
+ bgcolor: Optional[Tuple[int, int, int, int]] = None,
136
+ *args: Optional[Any],
137
+ **kwargs: Optional[Any]
138
+ ) -> Union[bytes, PILImage, np.ndarray]:
139
+ if isinstance(data, PILImage):
140
  return_type = ReturnType.PILLOW
141
  img = data
142
  elif isinstance(data, bytes):
 
146
  return_type = ReturnType.NDARRAY
147
  img = Image.fromarray(data)
148
  else:
149
+ raise ValueError("Input type {} is not supported.".format(type(data)))
150
 
151
+ # Fix image orientation
152
  img = fix_image_orientation(img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
+ if session is None:
155
+ session = new_session("u2net", *args, **kwargs)
156
+
157
+ masks = session.predict(img, *args, **kwargs)
158
+ cutouts = []
159
+
160
+ for mask in masks:
161
+ if post_process_mask:
162
+ mask = Image.fromarray(post_process(np.array(mask)))
163
+
164
+ if only_mask:
165
+ cutout = mask
166
+
167
+ elif alpha_matting:
168
+ try:
169
+ cutout = alpha_matting_cutout(
170
+ img,
171
+ mask,
172
+ alpha_matting_foreground_threshold,
173
+ alpha_matting_background_threshold,
174
+ alpha_matting_erode_size,
175
+ )
176
+ except ValueError:
177
+ cutout = naive_cutout(img, mask)
178
+
179
+ else:
180
+ cutout = naive_cutout(img, mask)
181
+
182
+ cutouts.append(cutout)
183
+
184
+ cutout = img
185
+ if len(cutouts) > 0:
186
+ cutout = get_concat_v_multi(cutouts)
187
+
188
+ if bgcolor is not None and not only_mask:
189
+ cutout = apply_background_color(cutout, bgcolor)
190
+
191
+ if ReturnType.PILLOW == return_type:
192
+ return cutout
193
+
194
+ if ReturnType.NDARRAY == return_type:
195
+ return np.asarray(cutout)
196
+
197
+ bio = io.BytesIO()
198
+ cutout.save(bio, "PNG")
199
+ bio.seek(0)
200
+
201
+ return bio.read()