KenjieDec commited on
Commit
5f57808
·
1 Parent(s): 3faa99b
anime-girl.jpg ADDED
app.py CHANGED
@@ -5,7 +5,7 @@ import gradio as gr
5
  import os
6
  import cv2
7
 
8
- def inference(file, af, mask, model):
9
  im = cv2.imread(file, cv2.IMREAD_COLOR)
10
  cv2.imwrite(os.path.join("input.png"), im)
11
 
@@ -20,7 +20,6 @@ def inference(file, af, mask, model):
20
  output = remove(
21
  input,
22
  session = new_session(model),
23
- alpha_matting_erode_size = af,
24
  only_mask = (True if mask == "Mask only" else False)
25
  )
26
 
@@ -38,7 +37,6 @@ gr.Interface(
38
  inference,
39
  [
40
  gr.inputs.Image(type="filepath", label="Input"),
41
- gr.inputs.Slider(10, 25, default=10, label="Alpha matting erode size"),
42
  gr.inputs.Radio(
43
  [
44
  "Default",
@@ -55,10 +53,11 @@ gr.Interface(
55
  "u2net_cloth_seg",
56
  "silueta",
57
  "isnet-general-use",
 
58
  "sam",
59
  ],
60
  type="value",
61
- default="u2net",
62
  label="Models"
63
  ),
64
  ],
@@ -66,6 +65,6 @@ gr.Interface(
66
  title=title,
67
  description=description,
68
  article=article,
69
- examples=[["lion.png", 10, "Default", "u2net"], ["girl.jpg", 10, "Default", "u2net"]],
70
  enable_queue=True
71
  ).launch()
 
5
  import os
6
  import cv2
7
 
8
+ def inference(file, mask, model):
9
  im = cv2.imread(file, cv2.IMREAD_COLOR)
10
  cv2.imwrite(os.path.join("input.png"), im)
11
 
 
20
  output = remove(
21
  input,
22
  session = new_session(model),
 
23
  only_mask = (True if mask == "Mask only" else False)
24
  )
25
 
 
37
  inference,
38
  [
39
  gr.inputs.Image(type="filepath", label="Input"),
 
40
  gr.inputs.Radio(
41
  [
42
  "Default",
 
53
  "u2net_cloth_seg",
54
  "silueta",
55
  "isnet-general-use",
56
+ "isnet-anime",
57
  "sam",
58
  ],
59
  type="value",
60
+ default="isnet-general-use",
61
  label="Models"
62
  ),
63
  ],
 
65
  title=title,
66
  description=description,
67
  article=article,
68
+ examples=[["lion.png", "Default", "u2net"], ["girl.jpg", "Default", "u2net"], ["anime-girl.jpg", "Default", "isnet-anime"]],
69
  enable_queue=True
70
  ).launch()
rembg/_version.py CHANGED
@@ -23,9 +23,9 @@ def get_keywords():
23
  # setup.py/versioneer.py will grep for the variable names, so they must
24
  # each be defined on a line of their own. _version.py will just call
25
  # get_keywords().
26
- git_refnames = " (HEAD -> main)"
27
- git_full = "e47b2a0ed405a5a30f42bacb142b107f7a4b6536"
28
- git_date = "2023-04-26 20:40:21 -0300"
29
  keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
30
  return keywords
31
 
 
23
  # setup.py/versioneer.py will grep for the variable names, so they must
24
  # each be defined on a line of their own. _version.py will just call
25
  # get_keywords().
26
+ git_refnames = " (HEAD -> main, tag: v2.0.43)"
27
+ git_full = "848a38e4cc5cf41522974dea00848596105b1dfa"
28
+ git_date = "2023-06-02 09:20:57 -0300"
29
  keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
30
  return keywords
31
 
rembg/bg.py CHANGED
@@ -11,7 +11,7 @@ from cv2 import (
11
  getStructuringElement,
12
  morphologyEx,
13
  )
14
- from PIL import Image
15
  from PIL.Image import Image as PILImage
16
  from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
17
  from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
@@ -19,6 +19,7 @@ from pymatting.util.util import stack_images
19
  from scipy.ndimage import binary_erosion
20
 
21
  from .session_factory import new_session
 
22
  from .sessions.base import BaseSession
23
 
24
  kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
@@ -113,6 +114,15 @@ def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> P
113
  return colored_image
114
 
115
 
 
 
 
 
 
 
 
 
 
116
  def remove(
117
  data: Union[bytes, PILImage, np.ndarray],
118
  alpha_matting: bool = False,
@@ -138,6 +148,9 @@ def remove(
138
  else:
139
  raise ValueError("Input type {} is not supported.".format(type(data)))
140
 
 
 
 
141
  if session is None:
142
  session = new_session("u2net", *args, **kwargs)
143
 
 
11
  getStructuringElement,
12
  morphologyEx,
13
  )
14
+ from PIL import Image, ImageOps
15
  from PIL.Image import Image as PILImage
16
  from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
17
  from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
 
19
  from scipy.ndimage import binary_erosion
20
 
21
  from .session_factory import new_session
22
+ from .sessions import sessions_class
23
  from .sessions.base import BaseSession
24
 
25
  kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
 
114
  return colored_image
115
 
116
 
117
+ def fix_image_orientation(img: PILImage) -> PILImage:
118
+ return ImageOps.exif_transpose(img)
119
+
120
+
121
+ def download_models() -> None:
122
+ for session in sessions_class:
123
+ session.download_models()
124
+
125
+
126
  def remove(
127
  data: Union[bytes, PILImage, np.ndarray],
128
  alpha_matting: bool = False,
 
148
  else:
149
  raise ValueError("Input type {} is not supported.".format(type(data)))
150
 
151
+ # Fix image orientation
152
+ img = fix_image_orientation(img)
153
+
154
  if session is None:
155
  session = new_session("u2net", *args, **kwargs)
156
 
rembg/commands/b_command.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import io
3
+ import json
4
+ import os
5
+ import sys
6
+ from typing import IO
7
+
8
+ import click
9
+ from PIL import Image
10
+
11
+ from ..bg import remove
12
+ from ..session_factory import new_session
13
+ from ..sessions import sessions_names
14
+
15
+
16
+ @click.command(
17
+ name="b",
18
+ help="for a byte stream as input",
19
+ )
20
+ @click.option(
21
+ "-m",
22
+ "--model",
23
+ default="u2net",
24
+ type=click.Choice(sessions_names),
25
+ show_default=True,
26
+ show_choices=True,
27
+ help="model name",
28
+ )
29
+ @click.option(
30
+ "-a",
31
+ "--alpha-matting",
32
+ is_flag=True,
33
+ show_default=True,
34
+ help="use alpha matting",
35
+ )
36
+ @click.option(
37
+ "-af",
38
+ "--alpha-matting-foreground-threshold",
39
+ default=240,
40
+ type=int,
41
+ show_default=True,
42
+ help="trimap fg threshold",
43
+ )
44
+ @click.option(
45
+ "-ab",
46
+ "--alpha-matting-background-threshold",
47
+ default=10,
48
+ type=int,
49
+ show_default=True,
50
+ help="trimap bg threshold",
51
+ )
52
+ @click.option(
53
+ "-ae",
54
+ "--alpha-matting-erode-size",
55
+ default=10,
56
+ type=int,
57
+ show_default=True,
58
+ help="erode size",
59
+ )
60
+ @click.option(
61
+ "-om",
62
+ "--only-mask",
63
+ is_flag=True,
64
+ show_default=True,
65
+ help="output only the mask",
66
+ )
67
+ @click.option(
68
+ "-ppm",
69
+ "--post-process-mask",
70
+ is_flag=True,
71
+ show_default=True,
72
+ help="post process the mask",
73
+ )
74
+ @click.option(
75
+ "-bgc",
76
+ "--bgcolor",
77
+ default=None,
78
+ type=(int, int, int, int),
79
+ nargs=4,
80
+ help="Background color (R G B A) to replace the removed background with",
81
+ )
82
+ @click.option("-x", "--extras", type=str)
83
+ @click.option(
84
+ "-o",
85
+ "--output_specifier",
86
+ type=str,
87
+ help="printf-style specifier for output filenames (e.g. 'output-%d.png'))",
88
+ )
89
+ @click.argument(
90
+ "image_width",
91
+ type=int,
92
+ )
93
+ @click.argument(
94
+ "image_height",
95
+ type=int,
96
+ )
97
+ def rs_command(
98
+ model: str,
99
+ extras: str,
100
+ image_width: int,
101
+ image_height: int,
102
+ output_specifier: str,
103
+ **kwargs
104
+ ) -> None:
105
+ try:
106
+ kwargs.update(json.loads(extras))
107
+ except Exception:
108
+ pass
109
+
110
+ session = new_session(model)
111
+ bytes_per_img = image_width * image_height * 3
112
+
113
+ if output_specifier:
114
+ output_dir = os.path.dirname(
115
+ os.path.abspath(os.path.expanduser(output_specifier))
116
+ )
117
+
118
+ if not os.path.isdir(output_dir):
119
+ os.makedirs(output_dir, exist_ok=True)
120
+
121
+ def img_to_byte_array(img: Image) -> bytes:
122
+ buff = io.BytesIO()
123
+ img.save(buff, format="PNG")
124
+ return buff.getvalue()
125
+
126
+ async def connect_stdin_stdout():
127
+ loop = asyncio.get_event_loop()
128
+ reader = asyncio.StreamReader()
129
+ protocol = asyncio.StreamReaderProtocol(reader)
130
+
131
+ await loop.connect_read_pipe(lambda: protocol, sys.stdin)
132
+ w_transport, w_protocol = await loop.connect_write_pipe(
133
+ asyncio.streams.FlowControlMixin, sys.stdout
134
+ )
135
+
136
+ writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
137
+ return reader, writer
138
+
139
+ async def main():
140
+ reader, writer = await connect_stdin_stdout()
141
+
142
+ idx = 0
143
+ while True:
144
+ try:
145
+ img_bytes = await reader.readexactly(bytes_per_img)
146
+ if not img_bytes:
147
+ break
148
+
149
+ img = Image.frombytes("RGB", (image_width, image_height), img_bytes)
150
+ output = remove(img, session=session, **kwargs)
151
+
152
+ if output_specifier:
153
+ output.save((output_specifier % idx), format="PNG")
154
+ else:
155
+ writer.write(img_to_byte_array(output))
156
+
157
+ idx += 1
158
+ except asyncio.IncompleteReadError:
159
+ break
160
+
161
+ asyncio.run(main())
rembg/commands/s_command.py CHANGED
@@ -1,8 +1,11 @@
1
  import json
2
- from typing import Annotated, Optional, Tuple, cast
 
 
3
 
4
  import aiohttp
5
  import click
 
6
  import uvicorn
7
  from asyncer import asyncify
8
  from fastapi import Depends, FastAPI, File, Form, Query
@@ -70,6 +73,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
70
  "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
71
  },
72
  openapi_tags=tags_metadata,
 
73
  )
74
 
75
  app.add_middleware(
@@ -83,10 +87,10 @@ def s_command(port: int, log_level: str, threads: int) -> None:
83
  class CommonQueryParams:
84
  def __init__(
85
  self,
86
- model: Annotated[
87
- str, Query(regex=r"(" + "|".join(sessions_names) + ")")
88
- ] = Query(
89
  description="Model to use when processing image",
 
 
90
  ),
91
  a: bool = Query(default=False, description="Enable Alpha Matting"),
92
  af: int = Query(
@@ -128,10 +132,10 @@ def s_command(port: int, log_level: str, threads: int) -> None:
128
  class CommonQueryPostParams:
129
  def __init__(
130
  self,
131
- model: Annotated[
132
- str, Form(regex=r"(" + "|".join(sessions_names) + ")")
133
- ] = Form(
134
  description="Model to use when processing image",
 
 
135
  ),
136
  a: bool = Form(default=False, description="Enable Alpha Matting"),
137
  af: int = Form(
@@ -190,13 +194,18 @@ def s_command(port: int, log_level: str, threads: int) -> None:
190
  only_mask=commons.om,
191
  post_process_mask=commons.ppm,
192
  bgcolor=commons.bgc,
193
- **kwargs
194
  ),
195
  media_type="image/png",
196
  )
197
 
198
  @app.on_event("startup")
199
  def startup():
 
 
 
 
 
200
  if threads is not None:
201
  from anyio import CapacityLimiter
202
  from anyio.lowlevel import RunVar
@@ -204,7 +213,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
204
  RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
205
 
206
  @app.get(
207
- path="/",
208
  tags=["Background Removal"],
209
  summary="Remove from URL",
210
  description="Removes the background from an image obtained by retrieving an URL.",
@@ -221,7 +230,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
221
  return await asyncify(im_without_bg)(file, commons)
222
 
223
  @app.post(
224
- path="/",
225
  tags=["Background Removal"],
226
  summary="Remove from Stream",
227
  description="Removes the background from an image sent within the request itself.",
@@ -235,4 +244,42 @@ def s_command(port: int, log_level: str, threads: int) -> None:
235
  ):
236
  return await asyncify(im_without_bg)(file, commons) # type: ignore
237
 
238
- uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
+ import os
3
+ import webbrowser
4
+ from typing import Optional, Tuple, cast
5
 
6
  import aiohttp
7
  import click
8
+ import gradio as gr
9
  import uvicorn
10
  from asyncer import asyncify
11
  from fastapi import Depends, FastAPI, File, Form, Query
 
73
  "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
74
  },
75
  openapi_tags=tags_metadata,
76
+ docs_url="/api",
77
  )
78
 
79
  app.add_middleware(
 
87
  class CommonQueryParams:
88
  def __init__(
89
  self,
90
+ model: str = Query(
 
 
91
  description="Model to use when processing image",
92
+ regex=r"(" + "|".join(sessions_names) + ")",
93
+ default="u2net",
94
  ),
95
  a: bool = Query(default=False, description="Enable Alpha Matting"),
96
  af: int = Query(
 
132
  class CommonQueryPostParams:
133
  def __init__(
134
  self,
135
+ model: str = Form(
 
 
136
  description="Model to use when processing image",
137
+ regex=r"(" + "|".join(sessions_names) + ")",
138
+ default="u2net",
139
  ),
140
  a: bool = Form(default=False, description="Enable Alpha Matting"),
141
  af: int = Form(
 
194
  only_mask=commons.om,
195
  post_process_mask=commons.ppm,
196
  bgcolor=commons.bgc,
197
+ **kwargs,
198
  ),
199
  media_type="image/png",
200
  )
201
 
202
  @app.on_event("startup")
203
  def startup():
204
+ try:
205
+ webbrowser.open(f"http://localhost:{port}")
206
+ except Exception:
207
+ pass
208
+
209
  if threads is not None:
210
  from anyio import CapacityLimiter
211
  from anyio.lowlevel import RunVar
 
213
  RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
214
 
215
  @app.get(
216
+ path="/api/remove",
217
  tags=["Background Removal"],
218
  summary="Remove from URL",
219
  description="Removes the background from an image obtained by retrieving an URL.",
 
230
  return await asyncify(im_without_bg)(file, commons)
231
 
232
  @app.post(
233
+ path="/api/remove",
234
  tags=["Background Removal"],
235
  summary="Remove from Stream",
236
  description="Removes the background from an image sent within the request itself.",
 
244
  ):
245
  return await asyncify(im_without_bg)(file, commons) # type: ignore
246
 
247
+ def gr_app(app):
248
+ def inference(input_path, model):
249
+ output_path = "output.png"
250
+ with open(input_path, "rb") as i:
251
+ with open(output_path, "wb") as o:
252
+ input = i.read()
253
+ output = remove(input, session=new_session(model))
254
+ o.write(output)
255
+ return os.path.join(output_path)
256
+
257
+ interface = gr.Interface(
258
+ inference,
259
+ [
260
+ gr.components.Image(type="filepath", label="Input"),
261
+ gr.components.Dropdown(
262
+ [
263
+ "u2net",
264
+ "u2netp",
265
+ "u2net_human_seg",
266
+ "u2net_cloth_seg",
267
+ "silueta",
268
+ "isnet-general-use",
269
+ "isnet-anime",
270
+ ],
271
+ value="u2net",
272
+ label="Models",
273
+ ),
274
+ ],
275
+ gr.components.Image(type="filepath", label="Output"),
276
+ )
277
+
278
+ interface.queue(concurrency_count=3)
279
+ app = gr.mount_gradio_app(app, interface, path="/")
280
+ return app
281
+
282
+ print(f"To access the API documentation, go to http://localhost:{port}/api")
283
+ print(f"To access the UI, go to http://localhost:{port}")
284
+
285
+ uvicorn.run(gr_app(app), host="0.0.0.0", port=port, log_level=log_level)
rembg/session_factory.py CHANGED
@@ -8,7 +8,9 @@ from .sessions.base import BaseSession
8
  from .sessions.u2net import U2netSession
9
 
10
 
11
- def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
 
 
12
  session_class: Type[BaseSession] = U2netSession
13
 
14
  for sc in sessions_class:
@@ -21,4 +23,4 @@ def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
21
  if "OMP_NUM_THREADS" in os.environ:
22
  sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
23
 
24
- return session_class(model_name, sess_opts, *args, **kwargs)
 
8
  from .sessions.u2net import U2netSession
9
 
10
 
11
+ def new_session(
12
+ model_name: str = "u2net", providers=None, *args, **kwargs
13
+ ) -> BaseSession:
14
  session_class: Type[BaseSession] = U2netSession
15
 
16
  for sc in sessions_class:
 
23
  if "OMP_NUM_THREADS" in os.environ:
24
  sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
25
 
26
+ return session_class(model_name, sess_opts, providers, *args, **kwargs)
rembg/sessions/base.py CHANGED
@@ -8,11 +8,29 @@ from PIL.Image import Image as PILImage
8
 
9
 
10
  class BaseSession:
11
- def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
 
 
 
 
 
 
 
12
  self.model_name = model_name
 
 
 
 
 
 
 
 
 
 
 
13
  self.inner_session = ort.InferenceSession(
14
  str(self.__class__.download_models()),
15
- providers=ort.get_available_providers(),
16
  sess_options=sess_opts,
17
  )
18
 
@@ -46,6 +64,10 @@ class BaseSession:
46
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
47
  raise NotImplementedError
48
 
 
 
 
 
49
  @classmethod
50
  def u2net_home(cls, *args, **kwargs):
51
  return os.path.expanduser(
 
8
 
9
 
10
  class BaseSession:
11
+ def __init__(
12
+ self,
13
+ model_name: str,
14
+ sess_opts: ort.SessionOptions,
15
+ providers=None,
16
+ *args,
17
+ **kwargs
18
+ ):
19
  self.model_name = model_name
20
+
21
+ self.providers = []
22
+
23
+ _providers = ort.get_available_providers()
24
+ if providers:
25
+ for provider in providers:
26
+ if provider in _providers:
27
+ self.providers.append(provider)
28
+ else:
29
+ self.providers.extend(_providers)
30
+
31
  self.inner_session = ort.InferenceSession(
32
  str(self.__class__.download_models()),
33
+ providers=self.providers,
34
  sess_options=sess_opts,
35
  )
36
 
 
64
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
65
  raise NotImplementedError
66
 
67
+ @classmethod
68
+ def checksum_disabled(cls, *args, **kwargs):
69
+ return os.getenv("MODEL_CHECKSUM_DISABLED", None) is not None
70
+
71
  @classmethod
72
  def u2net_home(cls, *args, **kwargs):
73
  return os.path.expanduser(
rembg/sessions/dis_anime.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class DisSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
17
+ )
18
+
19
+ pred = ort_outs[0][:, 0, :, :]
20
+
21
+ ma = np.max(pred)
22
+ mi = np.min(pred)
23
+
24
+ pred = (pred - mi) / (ma - mi)
25
+ pred = np.squeeze(pred)
26
+
27
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
28
+ mask = mask.resize(img.size, Image.LANCZOS)
29
+
30
+ return [mask]
31
+
32
+ @classmethod
33
+ def download_models(cls, *args, **kwargs):
34
+ fname = f"{cls.name()}.onnx"
35
+ pooch.retrieve(
36
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx",
37
+ None
38
+ if cls.checksum_disabled(*args, **kwargs)
39
+ else "md5:6f184e756bb3bd901c8849220a83e38e",
40
+ fname=fname,
41
+ path=cls.u2net_home(*args, **kwargs),
42
+ progressbar=True,
43
+ )
44
+
45
+ return os.path.join(cls.u2net_home(), fname)
46
+
47
+ @classmethod
48
+ def name(cls, *args, **kwargs):
49
+ return "isnet-anime"
rembg/sessions/dis_general_use.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class DisSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
17
+ )
18
+
19
+ pred = ort_outs[0][:, 0, :, :]
20
+
21
+ ma = np.max(pred)
22
+ mi = np.min(pred)
23
+
24
+ pred = (pred - mi) / (ma - mi)
25
+ pred = np.squeeze(pred)
26
+
27
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
28
+ mask = mask.resize(img.size, Image.LANCZOS)
29
+
30
+ return [mask]
31
+
32
+ @classmethod
33
+ def download_models(cls, *args, **kwargs):
34
+ fname = f"{cls.name()}.onnx"
35
+ pooch.retrieve(
36
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
37
+ None
38
+ if cls.checksum_disabled(*args, **kwargs)
39
+ else "md5:fc16ebd8b0c10d971d3513d564d01e29",
40
+ fname=fname,
41
+ path=cls.u2net_home(*args, **kwargs),
42
+ progressbar=True,
43
+ )
44
+
45
+ return os.path.join(cls.u2net_home(), fname)
46
+
47
+ @classmethod
48
+ def name(cls, *args, **kwargs):
49
+ return "isnet-general-use"
rembg/sessions/sam.py CHANGED
@@ -141,17 +141,21 @@ class SamSession(BaseSession):
141
 
142
  pooch.retrieve(
143
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
144
- "md5:13d97c5c79ab13ef86d67cbde5f1b250",
 
 
145
  fname=fname_encoder,
146
- path=cls.u2net_home(),
147
  progressbar=True,
148
  )
149
 
150
  pooch.retrieve(
151
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
152
- "md5:fa3d1c36a3187d3de1c8deebf33dd127",
 
 
153
  fname=fname_decoder,
154
- path=cls.u2net_home(),
155
  progressbar=True,
156
  )
157
 
 
141
 
142
  pooch.retrieve(
143
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
144
+ None
145
+ if cls.checksum_disabled(*args, **kwargs)
146
+ else "md5:13d97c5c79ab13ef86d67cbde5f1b250",
147
  fname=fname_encoder,
148
+ path=cls.u2net_home(*args, **kwargs),
149
  progressbar=True,
150
  )
151
 
152
  pooch.retrieve(
153
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
154
+ None
155
+ if cls.checksum_disabled(*args, **kwargs)
156
+ else "md5:fa3d1c36a3187d3de1c8deebf33dd127",
157
  fname=fname_decoder,
158
+ path=cls.u2net_home(*args, **kwargs),
159
  progressbar=True,
160
  )
161
 
rembg/sessions/silueta.py CHANGED
@@ -36,9 +36,11 @@ class SiluetaSession(BaseSession):
36
  fname = f"{cls.name()}.onnx"
37
  pooch.retrieve(
38
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
39
- "md5:55e59e0d8062d2f5d013f4725ee84782",
 
 
40
  fname=fname,
41
- path=cls.u2net_home(),
42
  progressbar=True,
43
  )
44
 
 
36
  fname = f"{cls.name()}.onnx"
37
  pooch.retrieve(
38
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
39
+ None
40
+ if cls.checksum_disabled(*args, **kwargs)
41
+ else "md5:55e59e0d8062d2f5d013f4725ee84782",
42
  fname=fname,
43
+ path=cls.u2net_home(*args, **kwargs),
44
  progressbar=True,
45
  )
46
 
rembg/sessions/u2net.py CHANGED
@@ -36,9 +36,11 @@ class U2netSession(BaseSession):
36
  fname = f"{cls.name()}.onnx"
37
  pooch.retrieve(
38
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
39
- "md5:60024c5c889badc19c04ad937298a77b",
 
 
40
  fname=fname,
41
- path=cls.u2net_home(),
42
  progressbar=True,
43
  )
44
 
 
36
  fname = f"{cls.name()}.onnx"
37
  pooch.retrieve(
38
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
39
+ None
40
+ if cls.checksum_disabled(*args, **kwargs)
41
+ else "md5:60024c5c889badc19c04ad937298a77b",
42
  fname=fname,
43
+ path=cls.u2net_home(*args, **kwargs),
44
  progressbar=True,
45
  )
46
 
rembg/sessions/u2net_cloth_seg.py CHANGED
@@ -97,9 +97,11 @@ class Unet2ClothSession(BaseSession):
97
  fname = f"{cls.name()}.onnx"
98
  pooch.retrieve(
99
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
100
- "md5:2434d1f3cb744e0e49386c906e5a08bb",
 
 
101
  fname=fname,
102
- path=cls.u2net_home(),
103
  progressbar=True,
104
  )
105
 
 
97
  fname = f"{cls.name()}.onnx"
98
  pooch.retrieve(
99
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
100
+ None
101
+ if cls.checksum_disabled(*args, **kwargs)
102
+ else "md5:2434d1f3cb744e0e49386c906e5a08bb",
103
  fname=fname,
104
+ path=cls.u2net_home(*args, **kwargs),
105
  progressbar=True,
106
  )
107
 
rembg/sessions/u2net_human_seg.py CHANGED
@@ -36,9 +36,11 @@ class U2netHumanSegSession(BaseSession):
36
  fname = f"{cls.name()}.onnx"
37
  pooch.retrieve(
38
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
39
- "md5:c09ddc2e0104f800e3e1bb4652583d1f",
 
 
40
  fname=fname,
41
- path=cls.u2net_home(),
42
  progressbar=True,
43
  )
44
 
 
36
  fname = f"{cls.name()}.onnx"
37
  pooch.retrieve(
38
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
39
+ None
40
+ if cls.checksum_disabled(*args, **kwargs)
41
+ else "md5:c09ddc2e0104f800e3e1bb4652583d1f",
42
  fname=fname,
43
+ path=cls.u2net_home(*args, **kwargs),
44
  progressbar=True,
45
  )
46
 
rembg/sessions/u2netp.py CHANGED
@@ -36,9 +36,11 @@ class U2netpSession(BaseSession):
36
  fname = f"{cls.name()}.onnx"
37
  pooch.retrieve(
38
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
39
- "md5:8e83ca70e441ab06c318d82300c84806",
 
 
40
  fname=fname,
41
- path=cls.u2net_home(),
42
  progressbar=True,
43
  )
44
 
 
36
  fname = f"{cls.name()}.onnx"
37
  pooch.retrieve(
38
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
39
+ None
40
+ if cls.checksum_disabled(*args, **kwargs)
41
+ else "md5:8e83ca70e441ab06c318d82300c84806",
42
  fname=fname,
43
+ path=cls.u2net_home(*args, **kwargs),
44
  progressbar=True,
45
  )
46