Spaces:
Sleeping
Sleeping
- app.py +11 -11
- rembg/_version.py +2 -2
- rembg/bg.py +18 -6
- rembg/cli.py +3 -429
- rembg/commands/__init__.py +13 -0
- rembg/commands/i_command.py +93 -0
- rembg/commands/p_command.py +181 -0
- rembg/commands/s_command.py +238 -0
- rembg/session_factory.py +10 -57
- rembg/sessions/__init__.py +22 -0
- rembg/sessions/base.py +63 -0
- rembg/sessions/dis.py +47 -0
- rembg/sessions/sam.py +165 -0
- rembg/sessions/silueta.py +49 -0
- rembg/sessions/u2net.py +49 -0
- rembg/sessions/u2net_cloth_seg.py +110 -0
- rembg/sessions/u2net_human_seg.py +49 -0
- rembg/sessions/u2netp.py +49 -0
app.py
CHANGED
@@ -9,9 +9,7 @@ def inference(file, af, mask, model):
|
|
9 |
im = cv2.imread(file, cv2.IMREAD_COLOR)
|
10 |
cv2.imwrite(os.path.join("input.png"), im)
|
11 |
|
12 |
-
from rembg import remove
|
13 |
-
from rembg.session_base import BaseSession
|
14 |
-
from rembg.session_factory import new_session
|
15 |
|
16 |
input_path = 'input.png'
|
17 |
output_path = 'output.png'
|
@@ -19,15 +17,15 @@ def inference(file, af, mask, model):
|
|
19 |
with open(input_path, 'rb') as i:
|
20 |
with open(output_path, 'wb') as o:
|
21 |
input = i.read()
|
22 |
-
sessions: dict[str, BaseSession] = {}
|
23 |
output = remove(
|
24 |
input,
|
25 |
-
session=
|
26 |
-
model, new_session(model)
|
27 |
-
),
|
28 |
alpha_matting_erode_size = af,
|
29 |
only_mask = (True if mask == "Mask only" else False)
|
30 |
-
)
|
|
|
|
|
|
|
31 |
o.write(output)
|
32 |
return os.path.join("output.png")
|
33 |
|
@@ -40,7 +38,7 @@ gr.Interface(
|
|
40 |
inference,
|
41 |
[
|
42 |
gr.inputs.Image(type="filepath", label="Input"),
|
43 |
-
gr.inputs.Slider(10, 25, default=10, label="Alpha matting"),
|
44 |
gr.inputs.Radio(
|
45 |
[
|
46 |
"Default",
|
@@ -55,14 +53,16 @@ gr.Interface(
|
|
55 |
"u2netp",
|
56 |
"u2net_human_seg",
|
57 |
"u2net_cloth_seg",
|
58 |
-
"silueta"
|
|
|
|
|
59 |
],
|
60 |
type="value",
|
61 |
default="u2net",
|
62 |
label="Models"
|
63 |
),
|
64 |
],
|
65 |
-
gr.outputs.Image(type="
|
66 |
title=title,
|
67 |
description=description,
|
68 |
article=article,
|
|
|
9 |
im = cv2.imread(file, cv2.IMREAD_COLOR)
|
10 |
cv2.imwrite(os.path.join("input.png"), im)
|
11 |
|
12 |
+
from rembg import new_session, remove
|
|
|
|
|
13 |
|
14 |
input_path = 'input.png'
|
15 |
output_path = 'output.png'
|
|
|
17 |
with open(input_path, 'rb') as i:
|
18 |
with open(output_path, 'wb') as o:
|
19 |
input = i.read()
|
|
|
20 |
output = remove(
|
21 |
input,
|
22 |
+
session = new_session(model),
|
|
|
|
|
23 |
alpha_matting_erode_size = af,
|
24 |
only_mask = (True if mask == "Mask only" else False)
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
o.write(output)
|
30 |
return os.path.join("output.png")
|
31 |
|
|
|
38 |
inference,
|
39 |
[
|
40 |
gr.inputs.Image(type="filepath", label="Input"),
|
41 |
+
gr.inputs.Slider(10, 25, default=10, label="Alpha matting erode size"),
|
42 |
gr.inputs.Radio(
|
43 |
[
|
44 |
"Default",
|
|
|
53 |
"u2netp",
|
54 |
"u2net_human_seg",
|
55 |
"u2net_cloth_seg",
|
56 |
+
"silueta",
|
57 |
+
"isnet-general-use",
|
58 |
+
"sam",
|
59 |
],
|
60 |
type="value",
|
61 |
default="u2net",
|
62 |
label="Models"
|
63 |
),
|
64 |
],
|
65 |
+
gr.outputs.Image(type="filepath", label="Output"),
|
66 |
title=title,
|
67 |
description=description,
|
68 |
article=article,
|
rembg/_version.py
CHANGED
@@ -24,8 +24,8 @@ def get_keywords():
|
|
24 |
# each be defined on a line of their own. _version.py will just call
|
25 |
# get_keywords().
|
26 |
git_refnames = " (HEAD -> main)"
|
27 |
-
git_full = "
|
28 |
-
git_date = "
|
29 |
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
|
30 |
return keywords
|
31 |
|
|
|
24 |
# each be defined on a line of their own. _version.py will just call
|
25 |
# get_keywords().
|
26 |
git_refnames = " (HEAD -> main)"
|
27 |
+
git_full = "e47b2a0ed405a5a30f42bacb142b107f7a4b6536"
|
28 |
+
git_date = "2023-04-26 20:40:21 -0300"
|
29 |
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
|
30 |
return keywords
|
31 |
|
rembg/bg.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import io
|
2 |
from enum import Enum
|
3 |
-
from typing import List, Optional, Union
|
4 |
|
5 |
import numpy as np
|
6 |
from cv2 import (
|
@@ -18,8 +18,8 @@ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
|
18 |
from pymatting.util.util import stack_images
|
19 |
from scipy.ndimage import binary_erosion
|
20 |
|
21 |
-
from .session_base import BaseSession
|
22 |
from .session_factory import new_session
|
|
|
23 |
|
24 |
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
|
25 |
|
@@ -37,7 +37,6 @@ def alpha_matting_cutout(
|
|
37 |
background_threshold: int,
|
38 |
erode_structure_size: int,
|
39 |
) -> PILImage:
|
40 |
-
|
41 |
if img.mode == "RGBA" or img.mode == "CMYK":
|
42 |
img = img.convert("RGB")
|
43 |
|
@@ -106,6 +105,14 @@ def post_process(mask: np.ndarray) -> np.ndarray:
|
|
106 |
return mask
|
107 |
|
108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
def remove(
|
110 |
data: Union[bytes, PILImage, np.ndarray],
|
111 |
alpha_matting: bool = False,
|
@@ -115,8 +122,10 @@ def remove(
|
|
115 |
session: Optional[BaseSession] = None,
|
116 |
only_mask: bool = False,
|
117 |
post_process_mask: bool = False,
|
|
|
|
|
|
|
118 |
) -> Union[bytes, PILImage, np.ndarray]:
|
119 |
-
|
120 |
if isinstance(data, PILImage):
|
121 |
return_type = ReturnType.PILLOW
|
122 |
img = data
|
@@ -130,9 +139,9 @@ def remove(
|
|
130 |
raise ValueError("Input type {} is not supported.".format(type(data)))
|
131 |
|
132 |
if session is None:
|
133 |
-
session = new_session("u2net")
|
134 |
|
135 |
-
masks = session.predict(img)
|
136 |
cutouts = []
|
137 |
|
138 |
for mask in masks:
|
@@ -163,6 +172,9 @@ def remove(
|
|
163 |
if len(cutouts) > 0:
|
164 |
cutout = get_concat_v_multi(cutouts)
|
165 |
|
|
|
|
|
|
|
166 |
if ReturnType.PILLOW == return_type:
|
167 |
return cutout
|
168 |
|
|
|
1 |
import io
|
2 |
from enum import Enum
|
3 |
+
from typing import Any, List, Optional, Tuple, Union
|
4 |
|
5 |
import numpy as np
|
6 |
from cv2 import (
|
|
|
18 |
from pymatting.util.util import stack_images
|
19 |
from scipy.ndimage import binary_erosion
|
20 |
|
|
|
21 |
from .session_factory import new_session
|
22 |
+
from .sessions.base import BaseSession
|
23 |
|
24 |
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
|
25 |
|
|
|
37 |
background_threshold: int,
|
38 |
erode_structure_size: int,
|
39 |
) -> PILImage:
|
|
|
40 |
if img.mode == "RGBA" or img.mode == "CMYK":
|
41 |
img = img.convert("RGB")
|
42 |
|
|
|
105 |
return mask
|
106 |
|
107 |
|
108 |
+
def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:
|
109 |
+
r, g, b, a = color
|
110 |
+
colored_image = Image.new("RGBA", img.size, (r, g, b, a))
|
111 |
+
colored_image.paste(img, mask=img)
|
112 |
+
|
113 |
+
return colored_image
|
114 |
+
|
115 |
+
|
116 |
def remove(
|
117 |
data: Union[bytes, PILImage, np.ndarray],
|
118 |
alpha_matting: bool = False,
|
|
|
122 |
session: Optional[BaseSession] = None,
|
123 |
only_mask: bool = False,
|
124 |
post_process_mask: bool = False,
|
125 |
+
bgcolor: Optional[Tuple[int, int, int, int]] = None,
|
126 |
+
*args: Optional[Any],
|
127 |
+
**kwargs: Optional[Any]
|
128 |
) -> Union[bytes, PILImage, np.ndarray]:
|
|
|
129 |
if isinstance(data, PILImage):
|
130 |
return_type = ReturnType.PILLOW
|
131 |
img = data
|
|
|
139 |
raise ValueError("Input type {} is not supported.".format(type(data)))
|
140 |
|
141 |
if session is None:
|
142 |
+
session = new_session("u2net", *args, **kwargs)
|
143 |
|
144 |
+
masks = session.predict(img, *args, **kwargs)
|
145 |
cutouts = []
|
146 |
|
147 |
for mask in masks:
|
|
|
172 |
if len(cutouts) > 0:
|
173 |
cutout = get_concat_v_multi(cutouts)
|
174 |
|
175 |
+
if bgcolor is not None and not only_mask:
|
176 |
+
cutout = apply_background_color(cutout, bgcolor)
|
177 |
+
|
178 |
if ReturnType.PILLOW == return_type:
|
179 |
return cutout
|
180 |
|
rembg/cli.py
CHANGED
@@ -1,25 +1,7 @@
|
|
1 |
-
import pathlib
|
2 |
-
import sys
|
3 |
-
import time
|
4 |
-
from enum import Enum
|
5 |
-
from typing import IO, cast
|
6 |
-
|
7 |
-
import aiohttp
|
8 |
import click
|
9 |
-
import filetype
|
10 |
-
import uvicorn
|
11 |
-
from asyncer import asyncify
|
12 |
-
from fastapi import Depends, FastAPI, File, Form, Query
|
13 |
-
from fastapi.middleware.cors import CORSMiddleware
|
14 |
-
from starlette.responses import Response
|
15 |
-
from tqdm import tqdm
|
16 |
-
from watchdog.events import FileSystemEvent, FileSystemEventHandler
|
17 |
-
from watchdog.observers import Observer
|
18 |
|
19 |
from . import _version
|
20 |
-
from .
|
21 |
-
from .session_base import BaseSession
|
22 |
-
from .session_factory import new_session
|
23 |
|
24 |
|
25 |
@click.group()
|
@@ -28,413 +10,5 @@ def main() -> None:
|
|
28 |
pass
|
29 |
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
"-m",
|
34 |
-
"--model",
|
35 |
-
default="u2net",
|
36 |
-
type=click.Choice(
|
37 |
-
["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
|
38 |
-
),
|
39 |
-
show_default=True,
|
40 |
-
show_choices=True,
|
41 |
-
help="model name",
|
42 |
-
)
|
43 |
-
@click.option(
|
44 |
-
"-a",
|
45 |
-
"--alpha-matting",
|
46 |
-
is_flag=True,
|
47 |
-
show_default=True,
|
48 |
-
help="use alpha matting",
|
49 |
-
)
|
50 |
-
@click.option(
|
51 |
-
"-af",
|
52 |
-
"--alpha-matting-foreground-threshold",
|
53 |
-
default=240,
|
54 |
-
type=int,
|
55 |
-
show_default=True,
|
56 |
-
help="trimap fg threshold",
|
57 |
-
)
|
58 |
-
@click.option(
|
59 |
-
"-ab",
|
60 |
-
"--alpha-matting-background-threshold",
|
61 |
-
default=10,
|
62 |
-
type=int,
|
63 |
-
show_default=True,
|
64 |
-
help="trimap bg threshold",
|
65 |
-
)
|
66 |
-
@click.option(
|
67 |
-
"-ae",
|
68 |
-
"--alpha-matting-erode-size",
|
69 |
-
default=10,
|
70 |
-
type=int,
|
71 |
-
show_default=True,
|
72 |
-
help="erode size",
|
73 |
-
)
|
74 |
-
@click.option(
|
75 |
-
"-om",
|
76 |
-
"--only-mask",
|
77 |
-
is_flag=True,
|
78 |
-
show_default=True,
|
79 |
-
help="output only the mask",
|
80 |
-
)
|
81 |
-
@click.option(
|
82 |
-
"-ppm",
|
83 |
-
"--post-process-mask",
|
84 |
-
is_flag=True,
|
85 |
-
show_default=True,
|
86 |
-
help="post process the mask",
|
87 |
-
)
|
88 |
-
@click.argument(
|
89 |
-
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
|
90 |
-
)
|
91 |
-
@click.argument(
|
92 |
-
"output",
|
93 |
-
default=(None if sys.stdin.isatty() else "-"),
|
94 |
-
type=click.File("wb", lazy=True),
|
95 |
-
)
|
96 |
-
def i(model: str, input: IO, output: IO, **kwargs) -> None:
|
97 |
-
output.write(remove(input.read(), session=new_session(model), **kwargs))
|
98 |
-
|
99 |
-
|
100 |
-
@main.command(help="for a folder as input")
|
101 |
-
@click.option(
|
102 |
-
"-m",
|
103 |
-
"--model",
|
104 |
-
default="u2net",
|
105 |
-
type=click.Choice(
|
106 |
-
["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
|
107 |
-
),
|
108 |
-
show_default=True,
|
109 |
-
show_choices=True,
|
110 |
-
help="model name",
|
111 |
-
)
|
112 |
-
@click.option(
|
113 |
-
"-a",
|
114 |
-
"--alpha-matting",
|
115 |
-
is_flag=True,
|
116 |
-
show_default=True,
|
117 |
-
help="use alpha matting",
|
118 |
-
)
|
119 |
-
@click.option(
|
120 |
-
"-af",
|
121 |
-
"--alpha-matting-foreground-threshold",
|
122 |
-
default=240,
|
123 |
-
type=int,
|
124 |
-
show_default=True,
|
125 |
-
help="trimap fg threshold",
|
126 |
-
)
|
127 |
-
@click.option(
|
128 |
-
"-ab",
|
129 |
-
"--alpha-matting-background-threshold",
|
130 |
-
default=10,
|
131 |
-
type=int,
|
132 |
-
show_default=True,
|
133 |
-
help="trimap bg threshold",
|
134 |
-
)
|
135 |
-
@click.option(
|
136 |
-
"-ae",
|
137 |
-
"--alpha-matting-erode-size",
|
138 |
-
default=10,
|
139 |
-
type=int,
|
140 |
-
show_default=True,
|
141 |
-
help="erode size",
|
142 |
-
)
|
143 |
-
@click.option(
|
144 |
-
"-om",
|
145 |
-
"--only-mask",
|
146 |
-
is_flag=True,
|
147 |
-
show_default=True,
|
148 |
-
help="output only the mask",
|
149 |
-
)
|
150 |
-
@click.option(
|
151 |
-
"-ppm",
|
152 |
-
"--post-process-mask",
|
153 |
-
is_flag=True,
|
154 |
-
show_default=True,
|
155 |
-
help="post process the mask",
|
156 |
-
)
|
157 |
-
@click.option(
|
158 |
-
"-w",
|
159 |
-
"--watch",
|
160 |
-
default=False,
|
161 |
-
is_flag=True,
|
162 |
-
show_default=True,
|
163 |
-
help="watches a folder for changes",
|
164 |
-
)
|
165 |
-
@click.argument(
|
166 |
-
"input",
|
167 |
-
type=click.Path(
|
168 |
-
exists=True,
|
169 |
-
path_type=pathlib.Path,
|
170 |
-
file_okay=False,
|
171 |
-
dir_okay=True,
|
172 |
-
readable=True,
|
173 |
-
),
|
174 |
-
)
|
175 |
-
@click.argument(
|
176 |
-
"output",
|
177 |
-
type=click.Path(
|
178 |
-
exists=False,
|
179 |
-
path_type=pathlib.Path,
|
180 |
-
file_okay=False,
|
181 |
-
dir_okay=True,
|
182 |
-
writable=True,
|
183 |
-
),
|
184 |
-
)
|
185 |
-
def p(
|
186 |
-
model: str, input: pathlib.Path, output: pathlib.Path, watch: bool, **kwargs
|
187 |
-
) -> None:
|
188 |
-
session = new_session(model)
|
189 |
-
|
190 |
-
def process(each_input: pathlib.Path) -> None:
|
191 |
-
try:
|
192 |
-
mimetype = filetype.guess(each_input)
|
193 |
-
if mimetype is None:
|
194 |
-
return
|
195 |
-
if mimetype.mime.find("image") < 0:
|
196 |
-
return
|
197 |
-
|
198 |
-
each_output = (output / each_input.name).with_suffix(".png")
|
199 |
-
each_output.parents[0].mkdir(parents=True, exist_ok=True)
|
200 |
-
|
201 |
-
if not each_output.exists():
|
202 |
-
each_output.write_bytes(
|
203 |
-
cast(
|
204 |
-
bytes,
|
205 |
-
remove(each_input.read_bytes(), session=session, **kwargs),
|
206 |
-
)
|
207 |
-
)
|
208 |
-
|
209 |
-
if watch:
|
210 |
-
print(
|
211 |
-
f"processed: {each_input.absolute()} -> {each_output.absolute()}"
|
212 |
-
)
|
213 |
-
except Exception as e:
|
214 |
-
print(e)
|
215 |
-
|
216 |
-
inputs = list(input.glob("**/*"))
|
217 |
-
if not watch:
|
218 |
-
inputs = tqdm(inputs)
|
219 |
-
|
220 |
-
for each_input in inputs:
|
221 |
-
if not each_input.is_dir():
|
222 |
-
process(each_input)
|
223 |
-
|
224 |
-
if watch:
|
225 |
-
observer = Observer()
|
226 |
-
|
227 |
-
class EventHandler(FileSystemEventHandler):
|
228 |
-
def on_any_event(self, event: FileSystemEvent) -> None:
|
229 |
-
if not (
|
230 |
-
event.is_directory or event.event_type in ["deleted", "closed"]
|
231 |
-
):
|
232 |
-
process(pathlib.Path(event.src_path))
|
233 |
-
|
234 |
-
event_handler = EventHandler()
|
235 |
-
observer.schedule(event_handler, input, recursive=False)
|
236 |
-
observer.start()
|
237 |
-
|
238 |
-
try:
|
239 |
-
while True:
|
240 |
-
time.sleep(1)
|
241 |
-
|
242 |
-
finally:
|
243 |
-
observer.stop()
|
244 |
-
observer.join()
|
245 |
-
|
246 |
-
|
247 |
-
@main.command(help="for a http server")
|
248 |
-
@click.option(
|
249 |
-
"-p",
|
250 |
-
"--port",
|
251 |
-
default=5000,
|
252 |
-
type=int,
|
253 |
-
show_default=True,
|
254 |
-
help="port",
|
255 |
-
)
|
256 |
-
@click.option(
|
257 |
-
"-l",
|
258 |
-
"--log_level",
|
259 |
-
default="info",
|
260 |
-
type=str,
|
261 |
-
show_default=True,
|
262 |
-
help="log level",
|
263 |
-
)
|
264 |
-
@click.option(
|
265 |
-
"-t",
|
266 |
-
"--threads",
|
267 |
-
default=None,
|
268 |
-
type=int,
|
269 |
-
show_default=True,
|
270 |
-
help="number of worker threads",
|
271 |
-
)
|
272 |
-
def s(port: int, log_level: str, threads: int) -> None:
|
273 |
-
sessions: dict[str, BaseSession] = {}
|
274 |
-
tags_metadata = [
|
275 |
-
{
|
276 |
-
"name": "Background Removal",
|
277 |
-
"description": "Endpoints that perform background removal with different image sources.",
|
278 |
-
"externalDocs": {
|
279 |
-
"description": "GitHub Source",
|
280 |
-
"url": "https://github.com/danielgatis/rembg",
|
281 |
-
},
|
282 |
-
},
|
283 |
-
]
|
284 |
-
app = FastAPI(
|
285 |
-
title="Rembg",
|
286 |
-
description="Rembg is a tool to remove images background. That is it.",
|
287 |
-
version=_version.get_versions()["version"],
|
288 |
-
contact={
|
289 |
-
"name": "Daniel Gatis",
|
290 |
-
"url": "https://github.com/danielgatis",
|
291 |
-
"email": "[email protected]",
|
292 |
-
},
|
293 |
-
license_info={
|
294 |
-
"name": "MIT License",
|
295 |
-
"url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
|
296 |
-
},
|
297 |
-
openapi_tags=tags_metadata,
|
298 |
-
)
|
299 |
-
|
300 |
-
app.add_middleware(
|
301 |
-
CORSMiddleware,
|
302 |
-
allow_credentials=True,
|
303 |
-
allow_origins=["*"],
|
304 |
-
allow_methods=["*"],
|
305 |
-
allow_headers=["*"],
|
306 |
-
)
|
307 |
-
|
308 |
-
class ModelType(str, Enum):
|
309 |
-
u2net = "u2net"
|
310 |
-
u2netp = "u2netp"
|
311 |
-
u2net_human_seg = "u2net_human_seg"
|
312 |
-
u2net_cloth_seg = "u2net_cloth_seg"
|
313 |
-
silueta = "silueta"
|
314 |
-
|
315 |
-
class CommonQueryParams:
|
316 |
-
def __init__(
|
317 |
-
self,
|
318 |
-
model: ModelType = Query(
|
319 |
-
default=ModelType.u2net,
|
320 |
-
description="Model to use when processing image",
|
321 |
-
),
|
322 |
-
a: bool = Query(default=False, description="Enable Alpha Matting"),
|
323 |
-
af: int = Query(
|
324 |
-
default=240,
|
325 |
-
ge=0,
|
326 |
-
le=255,
|
327 |
-
description="Alpha Matting (Foreground Threshold)",
|
328 |
-
),
|
329 |
-
ab: int = Query(
|
330 |
-
default=10,
|
331 |
-
ge=0,
|
332 |
-
le=255,
|
333 |
-
description="Alpha Matting (Background Threshold)",
|
334 |
-
),
|
335 |
-
ae: int = Query(
|
336 |
-
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
337 |
-
),
|
338 |
-
om: bool = Query(default=False, description="Only Mask"),
|
339 |
-
ppm: bool = Query(default=False, description="Post Process Mask"),
|
340 |
-
):
|
341 |
-
self.model = model
|
342 |
-
self.a = a
|
343 |
-
self.af = af
|
344 |
-
self.ab = ab
|
345 |
-
self.ae = ae
|
346 |
-
self.om = om
|
347 |
-
self.ppm = ppm
|
348 |
-
|
349 |
-
class CommonQueryPostParams:
|
350 |
-
def __init__(
|
351 |
-
self,
|
352 |
-
model: ModelType = Form(
|
353 |
-
default=ModelType.u2net,
|
354 |
-
description="Model to use when processing image",
|
355 |
-
),
|
356 |
-
a: bool = Form(default=False, description="Enable Alpha Matting"),
|
357 |
-
af: int = Form(
|
358 |
-
default=240,
|
359 |
-
ge=0,
|
360 |
-
le=255,
|
361 |
-
description="Alpha Matting (Foreground Threshold)",
|
362 |
-
),
|
363 |
-
ab: int = Form(
|
364 |
-
default=10,
|
365 |
-
ge=0,
|
366 |
-
le=255,
|
367 |
-
description="Alpha Matting (Background Threshold)",
|
368 |
-
),
|
369 |
-
ae: int = Form(
|
370 |
-
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
371 |
-
),
|
372 |
-
om: bool = Form(default=False, description="Only Mask"),
|
373 |
-
ppm: bool = Form(default=False, description="Post Process Mask"),
|
374 |
-
):
|
375 |
-
self.model = model
|
376 |
-
self.a = a
|
377 |
-
self.af = af
|
378 |
-
self.ab = ab
|
379 |
-
self.ae = ae
|
380 |
-
self.om = om
|
381 |
-
self.ppm = ppm
|
382 |
-
|
383 |
-
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
384 |
-
return Response(
|
385 |
-
remove(
|
386 |
-
content,
|
387 |
-
session=sessions.setdefault(
|
388 |
-
commons.model.value, new_session(commons.model.value)
|
389 |
-
),
|
390 |
-
alpha_matting=commons.a,
|
391 |
-
alpha_matting_foreground_threshold=commons.af,
|
392 |
-
alpha_matting_background_threshold=commons.ab,
|
393 |
-
alpha_matting_erode_size=commons.ae,
|
394 |
-
only_mask=commons.om,
|
395 |
-
post_process_mask=commons.ppm,
|
396 |
-
),
|
397 |
-
media_type="image/png",
|
398 |
-
)
|
399 |
-
|
400 |
-
@app.on_event("startup")
|
401 |
-
def startup():
|
402 |
-
if threads is not None:
|
403 |
-
from anyio import CapacityLimiter
|
404 |
-
from anyio.lowlevel import RunVar
|
405 |
-
|
406 |
-
RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
|
407 |
-
|
408 |
-
@app.get(
|
409 |
-
path="/",
|
410 |
-
tags=["Background Removal"],
|
411 |
-
summary="Remove from URL",
|
412 |
-
description="Removes the background from an image obtained by retrieving an URL.",
|
413 |
-
)
|
414 |
-
async def get_index(
|
415 |
-
url: str = Query(
|
416 |
-
default=..., description="URL of the image that has to be processed."
|
417 |
-
),
|
418 |
-
commons: CommonQueryParams = Depends(),
|
419 |
-
):
|
420 |
-
async with aiohttp.ClientSession() as session:
|
421 |
-
async with session.get(url) as response:
|
422 |
-
file = await response.read()
|
423 |
-
return await asyncify(im_without_bg)(file, commons)
|
424 |
-
|
425 |
-
@app.post(
|
426 |
-
path="/",
|
427 |
-
tags=["Background Removal"],
|
428 |
-
summary="Remove from Stream",
|
429 |
-
description="Removes the background from an image sent within the request itself.",
|
430 |
-
)
|
431 |
-
async def post_index(
|
432 |
-
file: bytes = File(
|
433 |
-
default=...,
|
434 |
-
description="Image file (byte stream) that has to be processed.",
|
435 |
-
),
|
436 |
-
commons: CommonQueryPostParams = Depends(),
|
437 |
-
):
|
438 |
-
return await asyncify(im_without_bg)(file, commons)
|
439 |
-
|
440 |
-
uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import click
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
from . import _version
|
4 |
+
from .commands import command_functions
|
|
|
|
|
5 |
|
6 |
|
7 |
@click.group()
|
|
|
10 |
pass
|
11 |
|
12 |
|
13 |
+
for command in command_functions:
|
14 |
+
main.add_command(command)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rembg/commands/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from importlib import import_module
|
2 |
+
from pathlib import Path
|
3 |
+
from pkgutil import iter_modules
|
4 |
+
|
5 |
+
command_functions = []
|
6 |
+
|
7 |
+
package_dir = Path(__file__).resolve().parent
|
8 |
+
for _b, module_name, _p in iter_modules([str(package_dir)]):
|
9 |
+
module = import_module(f"{__name__}.{module_name}")
|
10 |
+
for attribute_name in dir(module):
|
11 |
+
attribute = getattr(module, attribute_name)
|
12 |
+
if attribute_name.endswith("_command"):
|
13 |
+
command_functions.append(attribute)
|
rembg/commands/i_command.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import sys
|
3 |
+
from typing import IO
|
4 |
+
|
5 |
+
import click
|
6 |
+
|
7 |
+
from ..bg import remove
|
8 |
+
from ..session_factory import new_session
|
9 |
+
from ..sessions import sessions_names
|
10 |
+
|
11 |
+
|
12 |
+
@click.command(
|
13 |
+
name="i",
|
14 |
+
help="for a file as input",
|
15 |
+
)
|
16 |
+
@click.option(
|
17 |
+
"-m",
|
18 |
+
"--model",
|
19 |
+
default="u2net",
|
20 |
+
type=click.Choice(sessions_names),
|
21 |
+
show_default=True,
|
22 |
+
show_choices=True,
|
23 |
+
help="model name",
|
24 |
+
)
|
25 |
+
@click.option(
|
26 |
+
"-a",
|
27 |
+
"--alpha-matting",
|
28 |
+
is_flag=True,
|
29 |
+
show_default=True,
|
30 |
+
help="use alpha matting",
|
31 |
+
)
|
32 |
+
@click.option(
|
33 |
+
"-af",
|
34 |
+
"--alpha-matting-foreground-threshold",
|
35 |
+
default=240,
|
36 |
+
type=int,
|
37 |
+
show_default=True,
|
38 |
+
help="trimap fg threshold",
|
39 |
+
)
|
40 |
+
@click.option(
|
41 |
+
"-ab",
|
42 |
+
"--alpha-matting-background-threshold",
|
43 |
+
default=10,
|
44 |
+
type=int,
|
45 |
+
show_default=True,
|
46 |
+
help="trimap bg threshold",
|
47 |
+
)
|
48 |
+
@click.option(
|
49 |
+
"-ae",
|
50 |
+
"--alpha-matting-erode-size",
|
51 |
+
default=10,
|
52 |
+
type=int,
|
53 |
+
show_default=True,
|
54 |
+
help="erode size",
|
55 |
+
)
|
56 |
+
@click.option(
|
57 |
+
"-om",
|
58 |
+
"--only-mask",
|
59 |
+
is_flag=True,
|
60 |
+
show_default=True,
|
61 |
+
help="output only the mask",
|
62 |
+
)
|
63 |
+
@click.option(
|
64 |
+
"-ppm",
|
65 |
+
"--post-process-mask",
|
66 |
+
is_flag=True,
|
67 |
+
show_default=True,
|
68 |
+
help="post process the mask",
|
69 |
+
)
|
70 |
+
@click.option(
|
71 |
+
"-bgc",
|
72 |
+
"--bgcolor",
|
73 |
+
default=None,
|
74 |
+
type=(int, int, int, int),
|
75 |
+
nargs=4,
|
76 |
+
help="Background color (R G B A) to replace the removed background with",
|
77 |
+
)
|
78 |
+
@click.option("-x", "--extras", type=str)
|
79 |
+
@click.argument(
|
80 |
+
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
|
81 |
+
)
|
82 |
+
@click.argument(
|
83 |
+
"output",
|
84 |
+
default=(None if sys.stdin.isatty() else "-"),
|
85 |
+
type=click.File("wb", lazy=True),
|
86 |
+
)
|
87 |
+
def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
|
88 |
+
try:
|
89 |
+
kwargs.update(json.loads(extras))
|
90 |
+
except Exception:
|
91 |
+
pass
|
92 |
+
|
93 |
+
output.write(remove(input.read(), session=new_session(model), **kwargs))
|
rembg/commands/p_command.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import pathlib
|
3 |
+
import time
|
4 |
+
from typing import cast
|
5 |
+
|
6 |
+
import click
|
7 |
+
import filetype
|
8 |
+
from tqdm import tqdm
|
9 |
+
from watchdog.events import FileSystemEvent, FileSystemEventHandler
|
10 |
+
from watchdog.observers import Observer
|
11 |
+
|
12 |
+
from ..bg import remove
|
13 |
+
from ..session_factory import new_session
|
14 |
+
from ..sessions import sessions_names
|
15 |
+
|
16 |
+
|
17 |
+
@click.command(
|
18 |
+
name="p",
|
19 |
+
help="for a folder as input",
|
20 |
+
)
|
21 |
+
@click.option(
|
22 |
+
"-m",
|
23 |
+
"--model",
|
24 |
+
default="u2net",
|
25 |
+
type=click.Choice(sessions_names),
|
26 |
+
show_default=True,
|
27 |
+
show_choices=True,
|
28 |
+
help="model name",
|
29 |
+
)
|
30 |
+
@click.option(
|
31 |
+
"-a",
|
32 |
+
"--alpha-matting",
|
33 |
+
is_flag=True,
|
34 |
+
show_default=True,
|
35 |
+
help="use alpha matting",
|
36 |
+
)
|
37 |
+
@click.option(
|
38 |
+
"-af",
|
39 |
+
"--alpha-matting-foreground-threshold",
|
40 |
+
default=240,
|
41 |
+
type=int,
|
42 |
+
show_default=True,
|
43 |
+
help="trimap fg threshold",
|
44 |
+
)
|
45 |
+
@click.option(
|
46 |
+
"-ab",
|
47 |
+
"--alpha-matting-background-threshold",
|
48 |
+
default=10,
|
49 |
+
type=int,
|
50 |
+
show_default=True,
|
51 |
+
help="trimap bg threshold",
|
52 |
+
)
|
53 |
+
@click.option(
|
54 |
+
"-ae",
|
55 |
+
"--alpha-matting-erode-size",
|
56 |
+
default=10,
|
57 |
+
type=int,
|
58 |
+
show_default=True,
|
59 |
+
help="erode size",
|
60 |
+
)
|
61 |
+
@click.option(
|
62 |
+
"-om",
|
63 |
+
"--only-mask",
|
64 |
+
is_flag=True,
|
65 |
+
show_default=True,
|
66 |
+
help="output only the mask",
|
67 |
+
)
|
68 |
+
@click.option(
|
69 |
+
"-ppm",
|
70 |
+
"--post-process-mask",
|
71 |
+
is_flag=True,
|
72 |
+
show_default=True,
|
73 |
+
help="post process the mask",
|
74 |
+
)
|
75 |
+
@click.option(
|
76 |
+
"-w",
|
77 |
+
"--watch",
|
78 |
+
default=False,
|
79 |
+
is_flag=True,
|
80 |
+
show_default=True,
|
81 |
+
help="watches a folder for changes",
|
82 |
+
)
|
83 |
+
@click.option(
|
84 |
+
"-bgc",
|
85 |
+
"--bgcolor",
|
86 |
+
default=None,
|
87 |
+
type=(int, int, int, int),
|
88 |
+
nargs=4,
|
89 |
+
help="Background color (R G B A) to replace the removed background with",
|
90 |
+
)
|
91 |
+
@click.option("-x", "--extras", type=str)
|
92 |
+
@click.argument(
|
93 |
+
"input",
|
94 |
+
type=click.Path(
|
95 |
+
exists=True,
|
96 |
+
path_type=pathlib.Path,
|
97 |
+
file_okay=False,
|
98 |
+
dir_okay=True,
|
99 |
+
readable=True,
|
100 |
+
),
|
101 |
+
)
|
102 |
+
@click.argument(
|
103 |
+
"output",
|
104 |
+
type=click.Path(
|
105 |
+
exists=False,
|
106 |
+
path_type=pathlib.Path,
|
107 |
+
file_okay=False,
|
108 |
+
dir_okay=True,
|
109 |
+
writable=True,
|
110 |
+
),
|
111 |
+
)
|
112 |
+
def p_command(
|
113 |
+
model: str,
|
114 |
+
extras: str,
|
115 |
+
input: pathlib.Path,
|
116 |
+
output: pathlib.Path,
|
117 |
+
watch: bool,
|
118 |
+
**kwargs,
|
119 |
+
) -> None:
|
120 |
+
try:
|
121 |
+
kwargs.update(json.loads(extras))
|
122 |
+
except Exception:
|
123 |
+
pass
|
124 |
+
|
125 |
+
session = new_session(model)
|
126 |
+
|
127 |
+
def process(each_input: pathlib.Path) -> None:
|
128 |
+
try:
|
129 |
+
mimetype = filetype.guess(each_input)
|
130 |
+
if mimetype is None:
|
131 |
+
return
|
132 |
+
if mimetype.mime.find("image") < 0:
|
133 |
+
return
|
134 |
+
|
135 |
+
each_output = (output / each_input.name).with_suffix(".png")
|
136 |
+
each_output.parents[0].mkdir(parents=True, exist_ok=True)
|
137 |
+
|
138 |
+
if not each_output.exists():
|
139 |
+
each_output.write_bytes(
|
140 |
+
cast(
|
141 |
+
bytes,
|
142 |
+
remove(each_input.read_bytes(), session=session, **kwargs),
|
143 |
+
)
|
144 |
+
)
|
145 |
+
|
146 |
+
if watch:
|
147 |
+
print(
|
148 |
+
f"processed: {each_input.absolute()} -> {each_output.absolute()}"
|
149 |
+
)
|
150 |
+
except Exception as e:
|
151 |
+
print(e)
|
152 |
+
|
153 |
+
inputs = list(input.glob("**/*"))
|
154 |
+
if not watch:
|
155 |
+
inputs = tqdm(inputs)
|
156 |
+
|
157 |
+
for each_input in inputs:
|
158 |
+
if not each_input.is_dir():
|
159 |
+
process(each_input)
|
160 |
+
|
161 |
+
if watch:
|
162 |
+
observer = Observer()
|
163 |
+
|
164 |
+
class EventHandler(FileSystemEventHandler):
|
165 |
+
def on_any_event(self, event: FileSystemEvent) -> None:
|
166 |
+
if not (
|
167 |
+
event.is_directory or event.event_type in ["deleted", "closed"]
|
168 |
+
):
|
169 |
+
process(pathlib.Path(event.src_path))
|
170 |
+
|
171 |
+
event_handler = EventHandler()
|
172 |
+
observer.schedule(event_handler, input, recursive=False)
|
173 |
+
observer.start()
|
174 |
+
|
175 |
+
try:
|
176 |
+
while True:
|
177 |
+
time.sleep(1)
|
178 |
+
|
179 |
+
finally:
|
180 |
+
observer.stop()
|
181 |
+
observer.join()
|
rembg/commands/s_command.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import Annotated, Optional, Tuple, cast
|
3 |
+
|
4 |
+
import aiohttp
|
5 |
+
import click
|
6 |
+
import uvicorn
|
7 |
+
from asyncer import asyncify
|
8 |
+
from fastapi import Depends, FastAPI, File, Form, Query
|
9 |
+
from fastapi.middleware.cors import CORSMiddleware
|
10 |
+
from starlette.responses import Response
|
11 |
+
|
12 |
+
from .._version import get_versions
|
13 |
+
from ..bg import remove
|
14 |
+
from ..session_factory import new_session
|
15 |
+
from ..sessions import sessions_names
|
16 |
+
from ..sessions.base import BaseSession
|
17 |
+
|
18 |
+
|
19 |
+
@click.command(
|
20 |
+
name="s",
|
21 |
+
help="for a http server",
|
22 |
+
)
|
23 |
+
@click.option(
|
24 |
+
"-p",
|
25 |
+
"--port",
|
26 |
+
default=5000,
|
27 |
+
type=int,
|
28 |
+
show_default=True,
|
29 |
+
help="port",
|
30 |
+
)
|
31 |
+
@click.option(
|
32 |
+
"-l",
|
33 |
+
"--log_level",
|
34 |
+
default="info",
|
35 |
+
type=str,
|
36 |
+
show_default=True,
|
37 |
+
help="log level",
|
38 |
+
)
|
39 |
+
@click.option(
|
40 |
+
"-t",
|
41 |
+
"--threads",
|
42 |
+
default=None,
|
43 |
+
type=int,
|
44 |
+
show_default=True,
|
45 |
+
help="number of worker threads",
|
46 |
+
)
|
47 |
+
def s_command(port: int, log_level: str, threads: int) -> None:
|
48 |
+
sessions: dict[str, BaseSession] = {}
|
49 |
+
tags_metadata = [
|
50 |
+
{
|
51 |
+
"name": "Background Removal",
|
52 |
+
"description": "Endpoints that perform background removal with different image sources.",
|
53 |
+
"externalDocs": {
|
54 |
+
"description": "GitHub Source",
|
55 |
+
"url": "https://github.com/danielgatis/rembg",
|
56 |
+
},
|
57 |
+
},
|
58 |
+
]
|
59 |
+
app = FastAPI(
|
60 |
+
title="Rembg",
|
61 |
+
description="Rembg is a tool to remove images background. That is it.",
|
62 |
+
version=get_versions()["version"],
|
63 |
+
contact={
|
64 |
+
"name": "Daniel Gatis",
|
65 |
+
"url": "https://github.com/danielgatis",
|
66 |
+
"email": "[email protected]",
|
67 |
+
},
|
68 |
+
license_info={
|
69 |
+
"name": "MIT License",
|
70 |
+
"url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
|
71 |
+
},
|
72 |
+
openapi_tags=tags_metadata,
|
73 |
+
)
|
74 |
+
|
75 |
+
app.add_middleware(
|
76 |
+
CORSMiddleware,
|
77 |
+
allow_credentials=True,
|
78 |
+
allow_origins=["*"],
|
79 |
+
allow_methods=["*"],
|
80 |
+
allow_headers=["*"],
|
81 |
+
)
|
82 |
+
|
83 |
+
class CommonQueryParams:
|
84 |
+
def __init__(
|
85 |
+
self,
|
86 |
+
model: Annotated[
|
87 |
+
str, Query(regex=r"(" + "|".join(sessions_names) + ")")
|
88 |
+
] = Query(
|
89 |
+
description="Model to use when processing image",
|
90 |
+
),
|
91 |
+
a: bool = Query(default=False, description="Enable Alpha Matting"),
|
92 |
+
af: int = Query(
|
93 |
+
default=240,
|
94 |
+
ge=0,
|
95 |
+
le=255,
|
96 |
+
description="Alpha Matting (Foreground Threshold)",
|
97 |
+
),
|
98 |
+
ab: int = Query(
|
99 |
+
default=10,
|
100 |
+
ge=0,
|
101 |
+
le=255,
|
102 |
+
description="Alpha Matting (Background Threshold)",
|
103 |
+
),
|
104 |
+
ae: int = Query(
|
105 |
+
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
106 |
+
),
|
107 |
+
om: bool = Query(default=False, description="Only Mask"),
|
108 |
+
ppm: bool = Query(default=False, description="Post Process Mask"),
|
109 |
+
bgc: Optional[str] = Query(default=None, description="Background Color"),
|
110 |
+
extras: Optional[str] = Query(
|
111 |
+
default=None, description="Extra parameters as JSON"
|
112 |
+
),
|
113 |
+
):
|
114 |
+
self.model = model
|
115 |
+
self.a = a
|
116 |
+
self.af = af
|
117 |
+
self.ab = ab
|
118 |
+
self.ae = ae
|
119 |
+
self.om = om
|
120 |
+
self.ppm = ppm
|
121 |
+
self.extras = extras
|
122 |
+
self.bgc = (
|
123 |
+
cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
|
124 |
+
if bgc
|
125 |
+
else None
|
126 |
+
)
|
127 |
+
|
128 |
+
class CommonQueryPostParams:
|
129 |
+
def __init__(
|
130 |
+
self,
|
131 |
+
model: Annotated[
|
132 |
+
str, Form(regex=r"(" + "|".join(sessions_names) + ")")
|
133 |
+
] = Form(
|
134 |
+
description="Model to use when processing image",
|
135 |
+
),
|
136 |
+
a: bool = Form(default=False, description="Enable Alpha Matting"),
|
137 |
+
af: int = Form(
|
138 |
+
default=240,
|
139 |
+
ge=0,
|
140 |
+
le=255,
|
141 |
+
description="Alpha Matting (Foreground Threshold)",
|
142 |
+
),
|
143 |
+
ab: int = Form(
|
144 |
+
default=10,
|
145 |
+
ge=0,
|
146 |
+
le=255,
|
147 |
+
description="Alpha Matting (Background Threshold)",
|
148 |
+
),
|
149 |
+
ae: int = Form(
|
150 |
+
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
151 |
+
),
|
152 |
+
om: bool = Form(default=False, description="Only Mask"),
|
153 |
+
ppm: bool = Form(default=False, description="Post Process Mask"),
|
154 |
+
bgc: Optional[str] = Query(default=None, description="Background Color"),
|
155 |
+
extras: Optional[str] = Query(
|
156 |
+
default=None, description="Extra parameters as JSON"
|
157 |
+
),
|
158 |
+
):
|
159 |
+
self.model = model
|
160 |
+
self.a = a
|
161 |
+
self.af = af
|
162 |
+
self.ab = ab
|
163 |
+
self.ae = ae
|
164 |
+
self.om = om
|
165 |
+
self.ppm = ppm
|
166 |
+
self.extras = extras
|
167 |
+
self.bgc = (
|
168 |
+
cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
|
169 |
+
if bgc
|
170 |
+
else None
|
171 |
+
)
|
172 |
+
|
173 |
+
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
174 |
+
kwargs = {}
|
175 |
+
|
176 |
+
if commons.extras:
|
177 |
+
try:
|
178 |
+
kwargs.update(json.loads(commons.extras))
|
179 |
+
except Exception:
|
180 |
+
pass
|
181 |
+
|
182 |
+
return Response(
|
183 |
+
remove(
|
184 |
+
content,
|
185 |
+
session=sessions.setdefault(commons.model, new_session(commons.model)),
|
186 |
+
alpha_matting=commons.a,
|
187 |
+
alpha_matting_foreground_threshold=commons.af,
|
188 |
+
alpha_matting_background_threshold=commons.ab,
|
189 |
+
alpha_matting_erode_size=commons.ae,
|
190 |
+
only_mask=commons.om,
|
191 |
+
post_process_mask=commons.ppm,
|
192 |
+
bgcolor=commons.bgc,
|
193 |
+
**kwargs
|
194 |
+
),
|
195 |
+
media_type="image/png",
|
196 |
+
)
|
197 |
+
|
198 |
+
@app.on_event("startup")
|
199 |
+
def startup():
|
200 |
+
if threads is not None:
|
201 |
+
from anyio import CapacityLimiter
|
202 |
+
from anyio.lowlevel import RunVar
|
203 |
+
|
204 |
+
RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
|
205 |
+
|
206 |
+
@app.get(
|
207 |
+
path="/",
|
208 |
+
tags=["Background Removal"],
|
209 |
+
summary="Remove from URL",
|
210 |
+
description="Removes the background from an image obtained by retrieving an URL.",
|
211 |
+
)
|
212 |
+
async def get_index(
|
213 |
+
url: str = Query(
|
214 |
+
default=..., description="URL of the image that has to be processed."
|
215 |
+
),
|
216 |
+
commons: CommonQueryParams = Depends(),
|
217 |
+
):
|
218 |
+
async with aiohttp.ClientSession() as session:
|
219 |
+
async with session.get(url) as response:
|
220 |
+
file = await response.read()
|
221 |
+
return await asyncify(im_without_bg)(file, commons)
|
222 |
+
|
223 |
+
@app.post(
|
224 |
+
path="/",
|
225 |
+
tags=["Background Removal"],
|
226 |
+
summary="Remove from Stream",
|
227 |
+
description="Removes the background from an image sent within the request itself.",
|
228 |
+
)
|
229 |
+
async def post_index(
|
230 |
+
file: bytes = File(
|
231 |
+
default=...,
|
232 |
+
description="Image file (byte stream) that has to be processed.",
|
233 |
+
),
|
234 |
+
commons: CommonQueryPostParams = Depends(),
|
235 |
+
):
|
236 |
+
return await asyncify(im_without_bg)(file, commons) # type: ignore
|
237 |
+
|
238 |
+
uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
|
rembg/session_factory.py
CHANGED
@@ -1,71 +1,24 @@
|
|
1 |
-
import hashlib
|
2 |
import os
|
3 |
-
import sys
|
4 |
-
from contextlib import redirect_stdout
|
5 |
-
from pathlib import Path
|
6 |
from typing import Type
|
7 |
|
8 |
import onnxruntime as ort
|
9 |
-
import pooch
|
10 |
|
11 |
-
from .
|
12 |
-
from .
|
13 |
-
from .
|
14 |
|
15 |
|
16 |
-
def new_session(model_name: str = "u2net") -> BaseSession:
|
17 |
-
session_class: Type[BaseSession]
|
18 |
-
md5 = "60024c5c889badc19c04ad937298a77b"
|
19 |
-
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
|
20 |
-
session_class = SimpleSession
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
)
|
27 |
-
session_class = SimpleSession
|
28 |
-
elif model_name == "u2net_human_seg":
|
29 |
-
md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
|
30 |
-
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx"
|
31 |
-
session_class = SimpleSession
|
32 |
-
elif model_name == "u2net_cloth_seg":
|
33 |
-
md5 = "2434d1f3cb744e0e49386c906e5a08bb"
|
34 |
-
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx"
|
35 |
-
session_class = ClothSession
|
36 |
-
elif model_name == "silueta":
|
37 |
-
md5 = "55e59e0d8062d2f5d013f4725ee84782"
|
38 |
-
url = (
|
39 |
-
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx"
|
40 |
-
)
|
41 |
-
session_class = SimpleSession
|
42 |
-
|
43 |
-
u2net_home = os.getenv(
|
44 |
-
"U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
|
45 |
-
)
|
46 |
-
|
47 |
-
fname = f"{model_name}.onnx"
|
48 |
-
path = Path(u2net_home).expanduser()
|
49 |
-
full_path = Path(u2net_home).expanduser() / fname
|
50 |
-
|
51 |
-
pooch.retrieve(
|
52 |
-
url,
|
53 |
-
f"md5:{md5}",
|
54 |
-
fname=fname,
|
55 |
-
path=Path(u2net_home).expanduser(),
|
56 |
-
progressbar=True,
|
57 |
-
)
|
58 |
|
59 |
sess_opts = ort.SessionOptions()
|
60 |
|
61 |
if "OMP_NUM_THREADS" in os.environ:
|
62 |
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
|
63 |
|
64 |
-
return session_class(
|
65 |
-
model_name,
|
66 |
-
ort.InferenceSession(
|
67 |
-
str(full_path),
|
68 |
-
providers=ort.get_available_providers(),
|
69 |
-
sess_options=sess_opts,
|
70 |
-
),
|
71 |
-
)
|
|
|
|
|
1 |
import os
|
|
|
|
|
|
|
2 |
from typing import Type
|
3 |
|
4 |
import onnxruntime as ort
|
|
|
5 |
|
6 |
+
from .sessions import sessions_class
|
7 |
+
from .sessions.base import BaseSession
|
8 |
+
from .sessions.u2net import U2netSession
|
9 |
|
10 |
|
11 |
+
def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
|
12 |
+
session_class: Type[BaseSession] = U2netSession
|
|
|
|
|
|
|
13 |
|
14 |
+
for sc in sessions_class:
|
15 |
+
if sc.name() == model_name:
|
16 |
+
session_class = sc
|
17 |
+
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
sess_opts = ort.SessionOptions()
|
20 |
|
21 |
if "OMP_NUM_THREADS" in os.environ:
|
22 |
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
|
23 |
|
24 |
+
return session_class(model_name, sess_opts, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rembg/sessions/__init__.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from importlib import import_module
|
2 |
+
from inspect import isclass
|
3 |
+
from pathlib import Path
|
4 |
+
from pkgutil import iter_modules
|
5 |
+
|
6 |
+
from .base import BaseSession
|
7 |
+
|
8 |
+
sessions_class = []
|
9 |
+
sessions_names = []
|
10 |
+
|
11 |
+
package_dir = Path(__file__).resolve().parent
|
12 |
+
for _b, module_name, _p in iter_modules([str(package_dir)]):
|
13 |
+
module = import_module(f"{__name__}.{module_name}")
|
14 |
+
for attribute_name in dir(module):
|
15 |
+
attribute = getattr(module, attribute_name)
|
16 |
+
if (
|
17 |
+
isclass(attribute)
|
18 |
+
and issubclass(attribute, BaseSession)
|
19 |
+
and attribute != BaseSession
|
20 |
+
):
|
21 |
+
sessions_class.append(attribute)
|
22 |
+
sessions_names.append(attribute.name())
|
rembg/sessions/base.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Dict, List, Tuple
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import onnxruntime as ort
|
6 |
+
from PIL import Image
|
7 |
+
from PIL.Image import Image as PILImage
|
8 |
+
|
9 |
+
|
10 |
+
class BaseSession:
|
11 |
+
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
|
12 |
+
self.model_name = model_name
|
13 |
+
self.inner_session = ort.InferenceSession(
|
14 |
+
str(self.__class__.download_models()),
|
15 |
+
providers=ort.get_available_providers(),
|
16 |
+
sess_options=sess_opts,
|
17 |
+
)
|
18 |
+
|
19 |
+
def normalize(
|
20 |
+
self,
|
21 |
+
img: PILImage,
|
22 |
+
mean: Tuple[float, float, float],
|
23 |
+
std: Tuple[float, float, float],
|
24 |
+
size: Tuple[int, int],
|
25 |
+
*args,
|
26 |
+
**kwargs
|
27 |
+
) -> Dict[str, np.ndarray]:
|
28 |
+
im = img.convert("RGB").resize(size, Image.LANCZOS)
|
29 |
+
|
30 |
+
im_ary = np.array(im)
|
31 |
+
im_ary = im_ary / np.max(im_ary)
|
32 |
+
|
33 |
+
tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))
|
34 |
+
tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0]
|
35 |
+
tmpImg[:, :, 1] = (im_ary[:, :, 1] - mean[1]) / std[1]
|
36 |
+
tmpImg[:, :, 2] = (im_ary[:, :, 2] - mean[2]) / std[2]
|
37 |
+
|
38 |
+
tmpImg = tmpImg.transpose((2, 0, 1))
|
39 |
+
|
40 |
+
return {
|
41 |
+
self.inner_session.get_inputs()[0]
|
42 |
+
.name: np.expand_dims(tmpImg, 0)
|
43 |
+
.astype(np.float32)
|
44 |
+
}
|
45 |
+
|
46 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
47 |
+
raise NotImplementedError
|
48 |
+
|
49 |
+
@classmethod
|
50 |
+
def u2net_home(cls, *args, **kwargs):
|
51 |
+
return os.path.expanduser(
|
52 |
+
os.getenv(
|
53 |
+
"U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
|
54 |
+
)
|
55 |
+
)
|
56 |
+
|
57 |
+
@classmethod
|
58 |
+
def download_models(cls, *args, **kwargs):
|
59 |
+
raise NotImplementedError
|
60 |
+
|
61 |
+
@classmethod
|
62 |
+
def name(cls, *args, **kwargs):
|
63 |
+
raise NotImplementedError
|
rembg/sessions/dis.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import pooch
|
6 |
+
from PIL import Image
|
7 |
+
from PIL.Image import Image as PILImage
|
8 |
+
|
9 |
+
from .base import BaseSession
|
10 |
+
|
11 |
+
|
12 |
+
class DisSession(BaseSession):
|
13 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
14 |
+
ort_outs = self.inner_session.run(
|
15 |
+
None,
|
16 |
+
self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
|
17 |
+
)
|
18 |
+
|
19 |
+
pred = ort_outs[0][:, 0, :, :]
|
20 |
+
|
21 |
+
ma = np.max(pred)
|
22 |
+
mi = np.min(pred)
|
23 |
+
|
24 |
+
pred = (pred - mi) / (ma - mi)
|
25 |
+
pred = np.squeeze(pred)
|
26 |
+
|
27 |
+
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
28 |
+
mask = mask.resize(img.size, Image.LANCZOS)
|
29 |
+
|
30 |
+
return [mask]
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def download_models(cls, *args, **kwargs):
|
34 |
+
fname = f"{cls.name()}.onnx"
|
35 |
+
pooch.retrieve(
|
36 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
|
37 |
+
"md5:fc16ebd8b0c10d971d3513d564d01e29",
|
38 |
+
fname=fname,
|
39 |
+
path=cls.u2net_home(),
|
40 |
+
progressbar=True,
|
41 |
+
)
|
42 |
+
|
43 |
+
return os.path.join(cls.u2net_home(), fname)
|
44 |
+
|
45 |
+
@classmethod
|
46 |
+
def name(cls, *args, **kwargs):
|
47 |
+
return "isnet-general-use"
|
rembg/sessions/sam.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import onnxruntime as ort
|
6 |
+
import pooch
|
7 |
+
from PIL import Image
|
8 |
+
from PIL.Image import Image as PILImage
|
9 |
+
|
10 |
+
from .base import BaseSession
|
11 |
+
|
12 |
+
|
13 |
+
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
|
14 |
+
scale = long_side_length * 1.0 / max(oldh, oldw)
|
15 |
+
newh, neww = oldh * scale, oldw * scale
|
16 |
+
neww = int(neww + 0.5)
|
17 |
+
newh = int(newh + 0.5)
|
18 |
+
return (newh, neww)
|
19 |
+
|
20 |
+
|
21 |
+
def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray:
|
22 |
+
old_h, old_w = original_size
|
23 |
+
new_h, new_w = get_preprocess_shape(
|
24 |
+
original_size[0], original_size[1], target_length
|
25 |
+
)
|
26 |
+
coords = coords.copy().astype(float)
|
27 |
+
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
28 |
+
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
29 |
+
return coords
|
30 |
+
|
31 |
+
|
32 |
+
def resize_longes_side(img: PILImage, size=1024):
|
33 |
+
w, h = img.size
|
34 |
+
if h > w:
|
35 |
+
new_h, new_w = size, int(w * size / h)
|
36 |
+
else:
|
37 |
+
new_h, new_w = int(h * size / w), size
|
38 |
+
|
39 |
+
return img.resize((new_w, new_h))
|
40 |
+
|
41 |
+
|
42 |
+
def pad_to_square(img: np.ndarray, size=1024):
|
43 |
+
h, w = img.shape[:2]
|
44 |
+
padh = size - h
|
45 |
+
padw = size - w
|
46 |
+
img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode="constant")
|
47 |
+
img = img.astype(np.float32)
|
48 |
+
return img
|
49 |
+
|
50 |
+
|
51 |
+
class SamSession(BaseSession):
|
52 |
+
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
|
53 |
+
self.model_name = model_name
|
54 |
+
paths = self.__class__.download_models()
|
55 |
+
self.encoder = ort.InferenceSession(
|
56 |
+
str(paths[0]),
|
57 |
+
providers=ort.get_available_providers(),
|
58 |
+
sess_options=sess_opts,
|
59 |
+
)
|
60 |
+
self.decoder = ort.InferenceSession(
|
61 |
+
str(paths[1]),
|
62 |
+
providers=ort.get_available_providers(),
|
63 |
+
sess_options=sess_opts,
|
64 |
+
)
|
65 |
+
|
66 |
+
def normalize(
|
67 |
+
self,
|
68 |
+
img: np.ndarray,
|
69 |
+
mean=(123.675, 116.28, 103.53),
|
70 |
+
std=(58.395, 57.12, 57.375),
|
71 |
+
size=(1024, 1024),
|
72 |
+
*args,
|
73 |
+
**kwargs,
|
74 |
+
):
|
75 |
+
pixel_mean = np.array([*mean]).reshape(1, 1, -1)
|
76 |
+
pixel_std = np.array([*std]).reshape(1, 1, -1)
|
77 |
+
x = (img - pixel_mean) / pixel_std
|
78 |
+
return x
|
79 |
+
|
80 |
+
def predict(
|
81 |
+
self,
|
82 |
+
img: PILImage,
|
83 |
+
*args,
|
84 |
+
**kwargs,
|
85 |
+
) -> List[PILImage]:
|
86 |
+
# Preprocess image
|
87 |
+
image = resize_longes_side(img)
|
88 |
+
image = np.array(image)
|
89 |
+
image = self.normalize(image)
|
90 |
+
image = pad_to_square(image)
|
91 |
+
|
92 |
+
input_labels = kwargs.get("input_labels")
|
93 |
+
input_points = kwargs.get("input_points")
|
94 |
+
|
95 |
+
if input_labels is None:
|
96 |
+
raise ValueError("input_labels is required")
|
97 |
+
if input_points is None:
|
98 |
+
raise ValueError("input_points is required")
|
99 |
+
|
100 |
+
# Transpose
|
101 |
+
image = image.transpose(2, 0, 1)[None, :, :, :]
|
102 |
+
# Run encoder (Image embedding)
|
103 |
+
encoded = self.encoder.run(None, {"x": image})
|
104 |
+
image_embedding = encoded[0]
|
105 |
+
|
106 |
+
# Add a batch index, concatenate a padding point, and transform.
|
107 |
+
onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
|
108 |
+
None, :, :
|
109 |
+
]
|
110 |
+
onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
|
111 |
+
None, :
|
112 |
+
].astype(np.float32)
|
113 |
+
onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32)
|
114 |
+
|
115 |
+
# Create an empty mask input and an indicator for no mask.
|
116 |
+
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
|
117 |
+
onnx_has_mask_input = np.zeros(1, dtype=np.float32)
|
118 |
+
|
119 |
+
decoder_inputs = {
|
120 |
+
"image_embeddings": image_embedding,
|
121 |
+
"point_coords": onnx_coord,
|
122 |
+
"point_labels": onnx_label,
|
123 |
+
"mask_input": onnx_mask_input,
|
124 |
+
"has_mask_input": onnx_has_mask_input,
|
125 |
+
"orig_im_size": np.array(img.size[::-1], dtype=np.float32),
|
126 |
+
}
|
127 |
+
|
128 |
+
masks, _, low_res_logits = self.decoder.run(None, decoder_inputs)
|
129 |
+
masks = masks > 0.0
|
130 |
+
masks = [
|
131 |
+
Image.fromarray((masks[i, 0] * 255).astype(np.uint8))
|
132 |
+
for i in range(masks.shape[0])
|
133 |
+
]
|
134 |
+
|
135 |
+
return masks
|
136 |
+
|
137 |
+
@classmethod
|
138 |
+
def download_models(cls, *args, **kwargs):
|
139 |
+
fname_encoder = f"{cls.name()}_encoder.onnx"
|
140 |
+
fname_decoder = f"{cls.name()}_decoder.onnx"
|
141 |
+
|
142 |
+
pooch.retrieve(
|
143 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
|
144 |
+
"md5:13d97c5c79ab13ef86d67cbde5f1b250",
|
145 |
+
fname=fname_encoder,
|
146 |
+
path=cls.u2net_home(),
|
147 |
+
progressbar=True,
|
148 |
+
)
|
149 |
+
|
150 |
+
pooch.retrieve(
|
151 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
|
152 |
+
"md5:fa3d1c36a3187d3de1c8deebf33dd127",
|
153 |
+
fname=fname_decoder,
|
154 |
+
path=cls.u2net_home(),
|
155 |
+
progressbar=True,
|
156 |
+
)
|
157 |
+
|
158 |
+
return (
|
159 |
+
os.path.join(cls.u2net_home(), fname_encoder),
|
160 |
+
os.path.join(cls.u2net_home(), fname_decoder),
|
161 |
+
)
|
162 |
+
|
163 |
+
@classmethod
|
164 |
+
def name(cls, *args, **kwargs):
|
165 |
+
return "sam"
|
rembg/sessions/silueta.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import pooch
|
6 |
+
from PIL import Image
|
7 |
+
from PIL.Image import Image as PILImage
|
8 |
+
|
9 |
+
from .base import BaseSession
|
10 |
+
|
11 |
+
|
12 |
+
class SiluetaSession(BaseSession):
|
13 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
14 |
+
ort_outs = self.inner_session.run(
|
15 |
+
None,
|
16 |
+
self.normalize(
|
17 |
+
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
|
18 |
+
),
|
19 |
+
)
|
20 |
+
|
21 |
+
pred = ort_outs[0][:, 0, :, :]
|
22 |
+
|
23 |
+
ma = np.max(pred)
|
24 |
+
mi = np.min(pred)
|
25 |
+
|
26 |
+
pred = (pred - mi) / (ma - mi)
|
27 |
+
pred = np.squeeze(pred)
|
28 |
+
|
29 |
+
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
30 |
+
mask = mask.resize(img.size, Image.LANCZOS)
|
31 |
+
|
32 |
+
return [mask]
|
33 |
+
|
34 |
+
@classmethod
|
35 |
+
def download_models(cls, *args, **kwargs):
|
36 |
+
fname = f"{cls.name()}.onnx"
|
37 |
+
pooch.retrieve(
|
38 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
|
39 |
+
"md5:55e59e0d8062d2f5d013f4725ee84782",
|
40 |
+
fname=fname,
|
41 |
+
path=cls.u2net_home(),
|
42 |
+
progressbar=True,
|
43 |
+
)
|
44 |
+
|
45 |
+
return os.path.join(cls.u2net_home(), fname)
|
46 |
+
|
47 |
+
@classmethod
|
48 |
+
def name(cls, *args, **kwargs):
|
49 |
+
return "silueta"
|
rembg/sessions/u2net.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import pooch
|
6 |
+
from PIL import Image
|
7 |
+
from PIL.Image import Image as PILImage
|
8 |
+
|
9 |
+
from .base import BaseSession
|
10 |
+
|
11 |
+
|
12 |
+
class U2netSession(BaseSession):
|
13 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
14 |
+
ort_outs = self.inner_session.run(
|
15 |
+
None,
|
16 |
+
self.normalize(
|
17 |
+
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
|
18 |
+
),
|
19 |
+
)
|
20 |
+
|
21 |
+
pred = ort_outs[0][:, 0, :, :]
|
22 |
+
|
23 |
+
ma = np.max(pred)
|
24 |
+
mi = np.min(pred)
|
25 |
+
|
26 |
+
pred = (pred - mi) / (ma - mi)
|
27 |
+
pred = np.squeeze(pred)
|
28 |
+
|
29 |
+
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
30 |
+
mask = mask.resize(img.size, Image.LANCZOS)
|
31 |
+
|
32 |
+
return [mask]
|
33 |
+
|
34 |
+
@classmethod
|
35 |
+
def download_models(cls, *args, **kwargs):
|
36 |
+
fname = f"{cls.name()}.onnx"
|
37 |
+
pooch.retrieve(
|
38 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
|
39 |
+
"md5:60024c5c889badc19c04ad937298a77b",
|
40 |
+
fname=fname,
|
41 |
+
path=cls.u2net_home(),
|
42 |
+
progressbar=True,
|
43 |
+
)
|
44 |
+
|
45 |
+
return os.path.join(cls.u2net_home(), fname)
|
46 |
+
|
47 |
+
@classmethod
|
48 |
+
def name(cls, *args, **kwargs):
|
49 |
+
return "u2net"
|
rembg/sessions/u2net_cloth_seg.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import pooch
|
6 |
+
from PIL import Image
|
7 |
+
from PIL.Image import Image as PILImage
|
8 |
+
from scipy.special import log_softmax
|
9 |
+
|
10 |
+
from .base import BaseSession
|
11 |
+
|
12 |
+
pallete1 = [
|
13 |
+
0,
|
14 |
+
0,
|
15 |
+
0,
|
16 |
+
255,
|
17 |
+
255,
|
18 |
+
255,
|
19 |
+
0,
|
20 |
+
0,
|
21 |
+
0,
|
22 |
+
0,
|
23 |
+
0,
|
24 |
+
0,
|
25 |
+
]
|
26 |
+
|
27 |
+
pallete2 = [
|
28 |
+
0,
|
29 |
+
0,
|
30 |
+
0,
|
31 |
+
0,
|
32 |
+
0,
|
33 |
+
0,
|
34 |
+
255,
|
35 |
+
255,
|
36 |
+
255,
|
37 |
+
0,
|
38 |
+
0,
|
39 |
+
0,
|
40 |
+
]
|
41 |
+
|
42 |
+
pallete3 = [
|
43 |
+
0,
|
44 |
+
0,
|
45 |
+
0,
|
46 |
+
0,
|
47 |
+
0,
|
48 |
+
0,
|
49 |
+
0,
|
50 |
+
0,
|
51 |
+
0,
|
52 |
+
255,
|
53 |
+
255,
|
54 |
+
255,
|
55 |
+
]
|
56 |
+
|
57 |
+
|
58 |
+
class Unet2ClothSession(BaseSession):
|
59 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
60 |
+
ort_outs = self.inner_session.run(
|
61 |
+
None,
|
62 |
+
self.normalize(
|
63 |
+
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (768, 768)
|
64 |
+
),
|
65 |
+
)
|
66 |
+
|
67 |
+
pred = ort_outs
|
68 |
+
pred = log_softmax(pred[0], 1)
|
69 |
+
pred = np.argmax(pred, axis=1, keepdims=True)
|
70 |
+
pred = np.squeeze(pred, 0)
|
71 |
+
pred = np.squeeze(pred, 0)
|
72 |
+
|
73 |
+
mask = Image.fromarray(pred.astype("uint8"), mode="L")
|
74 |
+
mask = mask.resize(img.size, Image.LANCZOS)
|
75 |
+
|
76 |
+
masks = []
|
77 |
+
|
78 |
+
mask1 = mask.copy()
|
79 |
+
mask1.putpalette(pallete1)
|
80 |
+
mask1 = mask1.convert("RGB").convert("L")
|
81 |
+
masks.append(mask1)
|
82 |
+
|
83 |
+
mask2 = mask.copy()
|
84 |
+
mask2.putpalette(pallete2)
|
85 |
+
mask2 = mask2.convert("RGB").convert("L")
|
86 |
+
masks.append(mask2)
|
87 |
+
|
88 |
+
mask3 = mask.copy()
|
89 |
+
mask3.putpalette(pallete3)
|
90 |
+
mask3 = mask3.convert("RGB").convert("L")
|
91 |
+
masks.append(mask3)
|
92 |
+
|
93 |
+
return masks
|
94 |
+
|
95 |
+
@classmethod
|
96 |
+
def download_models(cls, *args, **kwargs):
|
97 |
+
fname = f"{cls.name()}.onnx"
|
98 |
+
pooch.retrieve(
|
99 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
|
100 |
+
"md5:2434d1f3cb744e0e49386c906e5a08bb",
|
101 |
+
fname=fname,
|
102 |
+
path=cls.u2net_home(),
|
103 |
+
progressbar=True,
|
104 |
+
)
|
105 |
+
|
106 |
+
return os.path.join(cls.u2net_home(), fname)
|
107 |
+
|
108 |
+
@classmethod
|
109 |
+
def name(cls, *args, **kwargs):
|
110 |
+
return "u2net_cloth_seg"
|
rembg/sessions/u2net_human_seg.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import pooch
|
6 |
+
from PIL import Image
|
7 |
+
from PIL.Image import Image as PILImage
|
8 |
+
|
9 |
+
from .base import BaseSession
|
10 |
+
|
11 |
+
|
12 |
+
class U2netHumanSegSession(BaseSession):
|
13 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
14 |
+
ort_outs = self.inner_session.run(
|
15 |
+
None,
|
16 |
+
self.normalize(
|
17 |
+
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
|
18 |
+
),
|
19 |
+
)
|
20 |
+
|
21 |
+
pred = ort_outs[0][:, 0, :, :]
|
22 |
+
|
23 |
+
ma = np.max(pred)
|
24 |
+
mi = np.min(pred)
|
25 |
+
|
26 |
+
pred = (pred - mi) / (ma - mi)
|
27 |
+
pred = np.squeeze(pred)
|
28 |
+
|
29 |
+
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
30 |
+
mask = mask.resize(img.size, Image.LANCZOS)
|
31 |
+
|
32 |
+
return [mask]
|
33 |
+
|
34 |
+
@classmethod
|
35 |
+
def download_models(cls, *args, **kwargs):
|
36 |
+
fname = f"{cls.name()}.onnx"
|
37 |
+
pooch.retrieve(
|
38 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
|
39 |
+
"md5:c09ddc2e0104f800e3e1bb4652583d1f",
|
40 |
+
fname=fname,
|
41 |
+
path=cls.u2net_home(),
|
42 |
+
progressbar=True,
|
43 |
+
)
|
44 |
+
|
45 |
+
return os.path.join(cls.u2net_home(), fname)
|
46 |
+
|
47 |
+
@classmethod
|
48 |
+
def name(cls, *args, **kwargs):
|
49 |
+
return "u2net_human_seg"
|
rembg/sessions/u2netp.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import pooch
|
6 |
+
from PIL import Image
|
7 |
+
from PIL.Image import Image as PILImage
|
8 |
+
|
9 |
+
from .base import BaseSession
|
10 |
+
|
11 |
+
|
12 |
+
class U2netpSession(BaseSession):
|
13 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
14 |
+
ort_outs = self.inner_session.run(
|
15 |
+
None,
|
16 |
+
self.normalize(
|
17 |
+
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
|
18 |
+
),
|
19 |
+
)
|
20 |
+
|
21 |
+
pred = ort_outs[0][:, 0, :, :]
|
22 |
+
|
23 |
+
ma = np.max(pred)
|
24 |
+
mi = np.min(pred)
|
25 |
+
|
26 |
+
pred = (pred - mi) / (ma - mi)
|
27 |
+
pred = np.squeeze(pred)
|
28 |
+
|
29 |
+
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
30 |
+
mask = mask.resize(img.size, Image.LANCZOS)
|
31 |
+
|
32 |
+
return [mask]
|
33 |
+
|
34 |
+
@classmethod
|
35 |
+
def download_models(cls, *args, **kwargs):
|
36 |
+
fname = f"{cls.name()}.onnx"
|
37 |
+
pooch.retrieve(
|
38 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
|
39 |
+
"md5:8e83ca70e441ab06c318d82300c84806",
|
40 |
+
fname=fname,
|
41 |
+
path=cls.u2net_home(),
|
42 |
+
progressbar=True,
|
43 |
+
)
|
44 |
+
|
45 |
+
return os.path.join(cls.u2net_home(), fname)
|
46 |
+
|
47 |
+
@classmethod
|
48 |
+
def name(cls, *args, **kwargs):
|
49 |
+
return "u2netp"
|