KenjieDec commited on
Commit
3a6f1f2
·
1 Parent(s): 158fac3
LICENSE.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Daniel Gatis
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,12 @@
1
  ---
2
- title: RemBG
3
- emoji: 🐠
4
  colorFrom: pink
5
- colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 3.0.20
8
  app_file: app.py
9
  pinned: false
10
- license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Rembg
3
+ emoji: 👀
4
  colorFrom: pink
5
+ colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 3.0.20
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ '''
4
+ @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
5
+ @author: yangxy ([email protected])
6
+ '''
7
+ import os
8
+ import cv2
9
+
10
+
11
+ def inference(file):
12
+ im = cv2.imread(file, cv2.IMREAD_COLOR)
13
+ cv2.imwrite(os.path.join("input.png"), im)
14
+
15
+ from rembg import remove
16
+
17
+ input_path = 'input.png'
18
+ output_path = 'output.png'
19
+
20
+ with open(input_path, 'rb') as i:
21
+ with open(output_path, 'wb') as o:
22
+ input = i.read()
23
+ output = remove(input)
24
+ o.write(output)
25
+ return os.path.join("output.png")
26
+
27
+ title = "RemBG"
28
+ description = "Gradio demo for RemBG. To use it, simply upload your image and wait. Read more at the link below."
29
+
30
+ article = "<p style='text-align: center;'><a href='https://github.com/danielgatis/rembg' target='_blank'>Github Repo</a></p>"
31
+
32
+
33
+ gr.Interface(
34
+ inference,
35
+ [gr.inputs.Image(type="filepath", label="Input")],
36
+ gr.outputs.Image(type="file", label="Output"),
37
+ title=title,
38
+ description=description,
39
+ article=article,
40
+ examples=[],
41
+ enable_queue=True
42
+ ).launch()
rembg/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import warnings
3
+
4
+ if not (sys.version_info.major == 3 and sys.version_info.minor == 9):
5
+ warnings.warn("This library is only for Python 3.9", RuntimeWarning)
6
+
7
+ from . import _version
8
+
9
+ __version__ = _version.get_versions()["version"]
10
+
11
+ from .bg import remove
rembg/_version.py ADDED
@@ -0,0 +1,677 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file helps to compute a version number in source trees obtained from
2
+ # git-archive tarball (such as those provided by githubs download-from-tag
3
+ # feature). Distribution tarballs (built by setup.py sdist) and build
4
+ # directories (produced by setup.py build) will contain a much shorter file
5
+ # that just contains the computed version number.
6
+
7
+ # This file is released into the public domain. Generated by
8
+ # versioneer-0.21 (https://github.com/python-versioneer/python-versioneer)
9
+
10
+ """Git implementation of _version.py."""
11
+
12
+ import errno
13
+ import os
14
+ import re
15
+ import subprocess
16
+ import sys
17
+ from typing import Callable, Dict
18
+
19
+
20
+ def get_keywords():
21
+ """Get the keywords needed to look up the version information."""
22
+ # these strings will be replaced by git during git-archive.
23
+ # setup.py/versioneer.py will grep for the variable names, so they must
24
+ # each be defined on a line of their own. _version.py will just call
25
+ # get_keywords().
26
+ git_refnames = " (HEAD -> main)"
27
+ git_full = "3bc1c1af99ebd47dd08d02763fc754d70d42afea"
28
+ git_date = "2022-06-16 23:00:14 -0300"
29
+ keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
30
+ return keywords
31
+
32
+
33
+ class VersioneerConfig:
34
+ """Container for Versioneer configuration parameters."""
35
+
36
+
37
+ def get_config():
38
+ """Create, populate and return the VersioneerConfig() object."""
39
+ # these strings are filled in when 'setup.py versioneer' creates
40
+ # _version.py
41
+ cfg = VersioneerConfig()
42
+ cfg.VCS = "git"
43
+ cfg.style = "pep440"
44
+ cfg.tag_prefix = "v"
45
+ cfg.parentdir_prefix = "rembg-"
46
+ cfg.versionfile_source = "rembg/_version.py"
47
+ cfg.verbose = False
48
+ return cfg
49
+
50
+
51
+ class NotThisMethod(Exception):
52
+ """Exception raised if a method is not valid for the current scenario."""
53
+
54
+
55
+ LONG_VERSION_PY: Dict[str, str] = {}
56
+ HANDLERS: Dict[str, Dict[str, Callable]] = {}
57
+
58
+
59
+ def register_vcs_handler(vcs, method): # decorator
60
+ """Create decorator to mark a method as the handler of a VCS."""
61
+
62
+ def decorate(f):
63
+ """Store f in HANDLERS[vcs][method]."""
64
+ if vcs not in HANDLERS:
65
+ HANDLERS[vcs] = {}
66
+ HANDLERS[vcs][method] = f
67
+ return f
68
+
69
+ return decorate
70
+
71
+
72
+ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None):
73
+ """Call the given command(s)."""
74
+ assert isinstance(commands, list)
75
+ process = None
76
+ for command in commands:
77
+ try:
78
+ dispcmd = str([command] + args)
79
+ # remember shell=False, so use git.cmd on windows, not just git
80
+ process = subprocess.Popen(
81
+ [command] + args,
82
+ cwd=cwd,
83
+ env=env,
84
+ stdout=subprocess.PIPE,
85
+ stderr=(subprocess.PIPE if hide_stderr else None),
86
+ )
87
+ break
88
+ except OSError:
89
+ e = sys.exc_info()[1]
90
+ if e.errno == errno.ENOENT:
91
+ continue
92
+ if verbose:
93
+ print("unable to run %s" % dispcmd)
94
+ print(e)
95
+ return None, None
96
+ else:
97
+ if verbose:
98
+ print("unable to find command, tried %s" % (commands,))
99
+ return None, None
100
+ stdout = process.communicate()[0].strip().decode()
101
+ if process.returncode != 0:
102
+ if verbose:
103
+ print("unable to run %s (error)" % dispcmd)
104
+ print("stdout was %s" % stdout)
105
+ return None, process.returncode
106
+ return stdout, process.returncode
107
+
108
+
109
+ def versions_from_parentdir(parentdir_prefix, root, verbose):
110
+ """Try to determine the version from the parent directory name.
111
+
112
+ Source tarballs conventionally unpack into a directory that includes both
113
+ the project name and a version string. We will also support searching up
114
+ two directory levels for an appropriately named parent directory
115
+ """
116
+ rootdirs = []
117
+
118
+ for _ in range(3):
119
+ dirname = os.path.basename(root)
120
+ if dirname.startswith(parentdir_prefix):
121
+ return {
122
+ "version": dirname[len(parentdir_prefix) :],
123
+ "full-revisionid": None,
124
+ "dirty": False,
125
+ "error": None,
126
+ "date": None,
127
+ }
128
+ rootdirs.append(root)
129
+ root = os.path.dirname(root) # up a level
130
+
131
+ if verbose:
132
+ print(
133
+ "Tried directories %s but none started with prefix %s"
134
+ % (str(rootdirs), parentdir_prefix)
135
+ )
136
+ raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
137
+
138
+
139
+ @register_vcs_handler("git", "get_keywords")
140
+ def git_get_keywords(versionfile_abs):
141
+ """Extract version information from the given file."""
142
+ # the code embedded in _version.py can just fetch the value of these
143
+ # keywords. When used from setup.py, we don't want to import _version.py,
144
+ # so we do it with a regexp instead. This function is not used from
145
+ # _version.py.
146
+ keywords = {}
147
+ try:
148
+ with open(versionfile_abs, "r") as fobj:
149
+ for line in fobj:
150
+ if line.strip().startswith("git_refnames ="):
151
+ mo = re.search(r'=\s*"(.*)"', line)
152
+ if mo:
153
+ keywords["refnames"] = mo.group(1)
154
+ if line.strip().startswith("git_full ="):
155
+ mo = re.search(r'=\s*"(.*)"', line)
156
+ if mo:
157
+ keywords["full"] = mo.group(1)
158
+ if line.strip().startswith("git_date ="):
159
+ mo = re.search(r'=\s*"(.*)"', line)
160
+ if mo:
161
+ keywords["date"] = mo.group(1)
162
+ except OSError:
163
+ pass
164
+ return keywords
165
+
166
+
167
+ @register_vcs_handler("git", "keywords")
168
+ def git_versions_from_keywords(keywords, tag_prefix, verbose):
169
+ """Get version information from git keywords."""
170
+ if "refnames" not in keywords:
171
+ raise NotThisMethod("Short version file found")
172
+ date = keywords.get("date")
173
+ if date is not None:
174
+ # Use only the last line. Previous lines may contain GPG signature
175
+ # information.
176
+ date = date.splitlines()[-1]
177
+
178
+ # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
179
+ # datestamp. However we prefer "%ci" (which expands to an "ISO-8601
180
+ # -like" string, which we must then edit to make compliant), because
181
+ # it's been around since git-1.5.3, and it's too difficult to
182
+ # discover which version we're using, or to work around using an
183
+ # older one.
184
+ date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
185
+ refnames = keywords["refnames"].strip()
186
+ if refnames.startswith("$Format"):
187
+ if verbose:
188
+ print("keywords are unexpanded, not using")
189
+ raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
190
+ refs = {r.strip() for r in refnames.strip("()").split(",")}
191
+ # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
192
+ # just "foo-1.0". If we see a "tag: " prefix, prefer those.
193
+ TAG = "tag: "
194
+ tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)}
195
+ if not tags:
196
+ # Either we're using git < 1.8.3, or there really are no tags. We use
197
+ # a heuristic: assume all version tags have a digit. The old git %d
198
+ # expansion behaves like git log --decorate=short and strips out the
199
+ # refs/heads/ and refs/tags/ prefixes that would let us distinguish
200
+ # between branches and tags. By ignoring refnames without digits, we
201
+ # filter out many common branch names like "release" and
202
+ # "stabilization", as well as "HEAD" and "master".
203
+ tags = {r for r in refs if re.search(r"\d", r)}
204
+ if verbose:
205
+ print("discarding '%s', no digits" % ",".join(refs - tags))
206
+ if verbose:
207
+ print("likely tags: %s" % ",".join(sorted(tags)))
208
+ for ref in sorted(tags):
209
+ # sorting will prefer e.g. "2.0" over "2.0rc1"
210
+ if ref.startswith(tag_prefix):
211
+ r = ref[len(tag_prefix) :]
212
+ # Filter out refs that exactly match prefix or that don't start
213
+ # with a number once the prefix is stripped (mostly a concern
214
+ # when prefix is '')
215
+ if not re.match(r"\d", r):
216
+ continue
217
+ if verbose:
218
+ print("picking %s" % r)
219
+ return {
220
+ "version": r,
221
+ "full-revisionid": keywords["full"].strip(),
222
+ "dirty": False,
223
+ "error": None,
224
+ "date": date,
225
+ }
226
+ # no suitable tags, so version is "0+unknown", but full hex is still there
227
+ if verbose:
228
+ print("no suitable tags, using unknown + full revision id")
229
+ return {
230
+ "version": "0+unknown",
231
+ "full-revisionid": keywords["full"].strip(),
232
+ "dirty": False,
233
+ "error": "no suitable tags",
234
+ "date": None,
235
+ }
236
+
237
+
238
+ @register_vcs_handler("git", "pieces_from_vcs")
239
+ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
240
+ """Get version from 'git describe' in the root of the source tree.
241
+
242
+ This only gets called if the git-archive 'subst' keywords were *not*
243
+ expanded, and _version.py hasn't already been rewritten with a short
244
+ version string, meaning we're inside a checked out source tree.
245
+ """
246
+ GITS = ["git"]
247
+ TAG_PREFIX_REGEX = "*"
248
+ if sys.platform == "win32":
249
+ GITS = ["git.cmd", "git.exe"]
250
+ TAG_PREFIX_REGEX = r"\*"
251
+
252
+ _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True)
253
+ if rc != 0:
254
+ if verbose:
255
+ print("Directory %s not under git control" % root)
256
+ raise NotThisMethod("'git rev-parse --git-dir' returned error")
257
+
258
+ # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
259
+ # if there isn't one, this yields HEX[-dirty] (no NUM)
260
+ describe_out, rc = runner(
261
+ GITS,
262
+ [
263
+ "describe",
264
+ "--tags",
265
+ "--dirty",
266
+ "--always",
267
+ "--long",
268
+ "--match",
269
+ "%s%s" % (tag_prefix, TAG_PREFIX_REGEX),
270
+ ],
271
+ cwd=root,
272
+ )
273
+ # --long was added in git-1.5.5
274
+ if describe_out is None:
275
+ raise NotThisMethod("'git describe' failed")
276
+ describe_out = describe_out.strip()
277
+ full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
278
+ if full_out is None:
279
+ raise NotThisMethod("'git rev-parse' failed")
280
+ full_out = full_out.strip()
281
+
282
+ pieces = {}
283
+ pieces["long"] = full_out
284
+ pieces["short"] = full_out[:7] # maybe improved later
285
+ pieces["error"] = None
286
+
287
+ branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root)
288
+ # --abbrev-ref was added in git-1.6.3
289
+ if rc != 0 or branch_name is None:
290
+ raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
291
+ branch_name = branch_name.strip()
292
+
293
+ if branch_name == "HEAD":
294
+ # If we aren't exactly on a branch, pick a branch which represents
295
+ # the current commit. If all else fails, we are on a branchless
296
+ # commit.
297
+ branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
298
+ # --contains was added in git-1.5.4
299
+ if rc != 0 or branches is None:
300
+ raise NotThisMethod("'git branch --contains' returned error")
301
+ branches = branches.split("\n")
302
+
303
+ # Remove the first line if we're running detached
304
+ if "(" in branches[0]:
305
+ branches.pop(0)
306
+
307
+ # Strip off the leading "* " from the list of branches.
308
+ branches = [branch[2:] for branch in branches]
309
+ if "master" in branches:
310
+ branch_name = "master"
311
+ elif not branches:
312
+ branch_name = None
313
+ else:
314
+ # Pick the first branch that is returned. Good or bad.
315
+ branch_name = branches[0]
316
+
317
+ pieces["branch"] = branch_name
318
+
319
+ # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
320
+ # TAG might have hyphens.
321
+ git_describe = describe_out
322
+
323
+ # look for -dirty suffix
324
+ dirty = git_describe.endswith("-dirty")
325
+ pieces["dirty"] = dirty
326
+ if dirty:
327
+ git_describe = git_describe[: git_describe.rindex("-dirty")]
328
+
329
+ # now we have TAG-NUM-gHEX or HEX
330
+
331
+ if "-" in git_describe:
332
+ # TAG-NUM-gHEX
333
+ mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
334
+ if not mo:
335
+ # unparsable. Maybe git-describe is misbehaving?
336
+ pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
337
+ return pieces
338
+
339
+ # tag
340
+ full_tag = mo.group(1)
341
+ if not full_tag.startswith(tag_prefix):
342
+ if verbose:
343
+ fmt = "tag '%s' doesn't start with prefix '%s'"
344
+ print(fmt % (full_tag, tag_prefix))
345
+ pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (
346
+ full_tag,
347
+ tag_prefix,
348
+ )
349
+ return pieces
350
+ pieces["closest-tag"] = full_tag[len(tag_prefix) :]
351
+
352
+ # distance: number of commits since tag
353
+ pieces["distance"] = int(mo.group(2))
354
+
355
+ # commit: short hex revision ID
356
+ pieces["short"] = mo.group(3)
357
+
358
+ else:
359
+ # HEX: no tags
360
+ pieces["closest-tag"] = None
361
+ count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root)
362
+ pieces["distance"] = int(count_out) # total number of commits
363
+
364
+ # commit date: see ISO-8601 comment in git_versions_from_keywords()
365
+ date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
366
+ # Use only the last line. Previous lines may contain GPG signature
367
+ # information.
368
+ date = date.splitlines()[-1]
369
+ pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
370
+
371
+ return pieces
372
+
373
+
374
+ def plus_or_dot(pieces):
375
+ """Return a + if we don't already have one, else return a ."""
376
+ if "+" in pieces.get("closest-tag", ""):
377
+ return "."
378
+ return "+"
379
+
380
+
381
+ def render_pep440(pieces):
382
+ """Build up version string, with post-release "local version identifier".
383
+
384
+ Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
385
+ get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
386
+
387
+ Exceptions:
388
+ 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
389
+ """
390
+ if pieces["closest-tag"]:
391
+ rendered = pieces["closest-tag"]
392
+ if pieces["distance"] or pieces["dirty"]:
393
+ rendered += plus_or_dot(pieces)
394
+ rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
395
+ if pieces["dirty"]:
396
+ rendered += ".dirty"
397
+ else:
398
+ # exception #1
399
+ rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
400
+ if pieces["dirty"]:
401
+ rendered += ".dirty"
402
+ return rendered
403
+
404
+
405
+ def render_pep440_branch(pieces):
406
+ """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
407
+
408
+ The ".dev0" means not master branch. Note that .dev0 sorts backwards
409
+ (a feature branch will appear "older" than the master branch).
410
+
411
+ Exceptions:
412
+ 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
413
+ """
414
+ if pieces["closest-tag"]:
415
+ rendered = pieces["closest-tag"]
416
+ if pieces["distance"] or pieces["dirty"]:
417
+ if pieces["branch"] != "master":
418
+ rendered += ".dev0"
419
+ rendered += plus_or_dot(pieces)
420
+ rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
421
+ if pieces["dirty"]:
422
+ rendered += ".dirty"
423
+ else:
424
+ # exception #1
425
+ rendered = "0"
426
+ if pieces["branch"] != "master":
427
+ rendered += ".dev0"
428
+ rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
429
+ if pieces["dirty"]:
430
+ rendered += ".dirty"
431
+ return rendered
432
+
433
+
434
+ def pep440_split_post(ver):
435
+ """Split pep440 version string at the post-release segment.
436
+
437
+ Returns the release segments before the post-release and the
438
+ post-release version number (or -1 if no post-release segment is present).
439
+ """
440
+ vc = str.split(ver, ".post")
441
+ return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
442
+
443
+
444
+ def render_pep440_pre(pieces):
445
+ """TAG[.postN.devDISTANCE] -- No -dirty.
446
+
447
+ Exceptions:
448
+ 1: no tags. 0.post0.devDISTANCE
449
+ """
450
+ if pieces["closest-tag"]:
451
+ if pieces["distance"]:
452
+ # update the post release segment
453
+ tag_version, post_version = pep440_split_post(pieces["closest-tag"])
454
+ rendered = tag_version
455
+ if post_version is not None:
456
+ rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"])
457
+ else:
458
+ rendered += ".post0.dev%d" % (pieces["distance"])
459
+ else:
460
+ # no commits, use the tag as the version
461
+ rendered = pieces["closest-tag"]
462
+ else:
463
+ # exception #1
464
+ rendered = "0.post0.dev%d" % pieces["distance"]
465
+ return rendered
466
+
467
+
468
+ def render_pep440_post(pieces):
469
+ """TAG[.postDISTANCE[.dev0]+gHEX] .
470
+
471
+ The ".dev0" means dirty. Note that .dev0 sorts backwards
472
+ (a dirty tree will appear "older" than the corresponding clean one),
473
+ but you shouldn't be releasing software with -dirty anyways.
474
+
475
+ Exceptions:
476
+ 1: no tags. 0.postDISTANCE[.dev0]
477
+ """
478
+ if pieces["closest-tag"]:
479
+ rendered = pieces["closest-tag"]
480
+ if pieces["distance"] or pieces["dirty"]:
481
+ rendered += ".post%d" % pieces["distance"]
482
+ if pieces["dirty"]:
483
+ rendered += ".dev0"
484
+ rendered += plus_or_dot(pieces)
485
+ rendered += "g%s" % pieces["short"]
486
+ else:
487
+ # exception #1
488
+ rendered = "0.post%d" % pieces["distance"]
489
+ if pieces["dirty"]:
490
+ rendered += ".dev0"
491
+ rendered += "+g%s" % pieces["short"]
492
+ return rendered
493
+
494
+
495
+ def render_pep440_post_branch(pieces):
496
+ """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
497
+
498
+ The ".dev0" means not master branch.
499
+
500
+ Exceptions:
501
+ 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
502
+ """
503
+ if pieces["closest-tag"]:
504
+ rendered = pieces["closest-tag"]
505
+ if pieces["distance"] or pieces["dirty"]:
506
+ rendered += ".post%d" % pieces["distance"]
507
+ if pieces["branch"] != "master":
508
+ rendered += ".dev0"
509
+ rendered += plus_or_dot(pieces)
510
+ rendered += "g%s" % pieces["short"]
511
+ if pieces["dirty"]:
512
+ rendered += ".dirty"
513
+ else:
514
+ # exception #1
515
+ rendered = "0.post%d" % pieces["distance"]
516
+ if pieces["branch"] != "master":
517
+ rendered += ".dev0"
518
+ rendered += "+g%s" % pieces["short"]
519
+ if pieces["dirty"]:
520
+ rendered += ".dirty"
521
+ return rendered
522
+
523
+
524
+ def render_pep440_old(pieces):
525
+ """TAG[.postDISTANCE[.dev0]] .
526
+
527
+ The ".dev0" means dirty.
528
+
529
+ Exceptions:
530
+ 1: no tags. 0.postDISTANCE[.dev0]
531
+ """
532
+ if pieces["closest-tag"]:
533
+ rendered = pieces["closest-tag"]
534
+ if pieces["distance"] or pieces["dirty"]:
535
+ rendered += ".post%d" % pieces["distance"]
536
+ if pieces["dirty"]:
537
+ rendered += ".dev0"
538
+ else:
539
+ # exception #1
540
+ rendered = "0.post%d" % pieces["distance"]
541
+ if pieces["dirty"]:
542
+ rendered += ".dev0"
543
+ return rendered
544
+
545
+
546
+ def render_git_describe(pieces):
547
+ """TAG[-DISTANCE-gHEX][-dirty].
548
+
549
+ Like 'git describe --tags --dirty --always'.
550
+
551
+ Exceptions:
552
+ 1: no tags. HEX[-dirty] (note: no 'g' prefix)
553
+ """
554
+ if pieces["closest-tag"]:
555
+ rendered = pieces["closest-tag"]
556
+ if pieces["distance"]:
557
+ rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
558
+ else:
559
+ # exception #1
560
+ rendered = pieces["short"]
561
+ if pieces["dirty"]:
562
+ rendered += "-dirty"
563
+ return rendered
564
+
565
+
566
+ def render_git_describe_long(pieces):
567
+ """TAG-DISTANCE-gHEX[-dirty].
568
+
569
+ Like 'git describe --tags --dirty --always -long'.
570
+ The distance/hash is unconditional.
571
+
572
+ Exceptions:
573
+ 1: no tags. HEX[-dirty] (note: no 'g' prefix)
574
+ """
575
+ if pieces["closest-tag"]:
576
+ rendered = pieces["closest-tag"]
577
+ rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
578
+ else:
579
+ # exception #1
580
+ rendered = pieces["short"]
581
+ if pieces["dirty"]:
582
+ rendered += "-dirty"
583
+ return rendered
584
+
585
+
586
+ def render(pieces, style):
587
+ """Render the given version pieces into the requested style."""
588
+ if pieces["error"]:
589
+ return {
590
+ "version": "unknown",
591
+ "full-revisionid": pieces.get("long"),
592
+ "dirty": None,
593
+ "error": pieces["error"],
594
+ "date": None,
595
+ }
596
+
597
+ if not style or style == "default":
598
+ style = "pep440" # the default
599
+
600
+ if style == "pep440":
601
+ rendered = render_pep440(pieces)
602
+ elif style == "pep440-branch":
603
+ rendered = render_pep440_branch(pieces)
604
+ elif style == "pep440-pre":
605
+ rendered = render_pep440_pre(pieces)
606
+ elif style == "pep440-post":
607
+ rendered = render_pep440_post(pieces)
608
+ elif style == "pep440-post-branch":
609
+ rendered = render_pep440_post_branch(pieces)
610
+ elif style == "pep440-old":
611
+ rendered = render_pep440_old(pieces)
612
+ elif style == "git-describe":
613
+ rendered = render_git_describe(pieces)
614
+ elif style == "git-describe-long":
615
+ rendered = render_git_describe_long(pieces)
616
+ else:
617
+ raise ValueError("unknown style '%s'" % style)
618
+
619
+ return {
620
+ "version": rendered,
621
+ "full-revisionid": pieces["long"],
622
+ "dirty": pieces["dirty"],
623
+ "error": None,
624
+ "date": pieces.get("date"),
625
+ }
626
+
627
+
628
+ def get_versions():
629
+ """Get version information or return default if unable to do so."""
630
+ # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
631
+ # __file__, we can work backwards from there to the root. Some
632
+ # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
633
+ # case we can only use expanded keywords.
634
+
635
+ cfg = get_config()
636
+ verbose = cfg.verbose
637
+
638
+ try:
639
+ return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose)
640
+ except NotThisMethod:
641
+ pass
642
+
643
+ try:
644
+ root = os.path.realpath(__file__)
645
+ # versionfile_source is the relative path from the top of the source
646
+ # tree (where the .git directory might live) to this file. Invert
647
+ # this to find the root from __file__.
648
+ for _ in cfg.versionfile_source.split("/"):
649
+ root = os.path.dirname(root)
650
+ except NameError:
651
+ return {
652
+ "version": "0+unknown",
653
+ "full-revisionid": None,
654
+ "dirty": None,
655
+ "error": "unable to find root of source tree",
656
+ "date": None,
657
+ }
658
+
659
+ try:
660
+ pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
661
+ return render(pieces, cfg.style)
662
+ except NotThisMethod:
663
+ pass
664
+
665
+ try:
666
+ if cfg.parentdir_prefix:
667
+ return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
668
+ except NotThisMethod:
669
+ pass
670
+
671
+ return {
672
+ "version": "0+unknown",
673
+ "full-revisionid": None,
674
+ "dirty": None,
675
+ "error": "unable to compute version",
676
+ "date": None,
677
+ }
rembg/bg.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from enum import Enum
3
+ from typing import List, Optional, Union
4
+
5
+ import numpy as np
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+ from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
9
+ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
10
+ from pymatting.util.util import stack_images
11
+ from scipy.ndimage.morphology import binary_erosion
12
+
13
+ from .session_base import BaseSession
14
+ from .session_factory import new_session
15
+
16
+
17
+ class ReturnType(Enum):
18
+ BYTES = 0
19
+ PILLOW = 1
20
+ NDARRAY = 2
21
+
22
+
23
+ def alpha_matting_cutout(
24
+ img: PILImage,
25
+ mask: PILImage,
26
+ foreground_threshold: int,
27
+ background_threshold: int,
28
+ erode_structure_size: int,
29
+ ) -> PILImage:
30
+ img = np.asarray(img)
31
+ mask = np.asarray(mask)
32
+
33
+ is_foreground = mask > foreground_threshold
34
+ is_background = mask < background_threshold
35
+
36
+ structure = None
37
+ if erode_structure_size > 0:
38
+ structure = np.ones(
39
+ (erode_structure_size, erode_structure_size), dtype=np.uint8
40
+ )
41
+
42
+ is_foreground = binary_erosion(is_foreground, structure=structure)
43
+ is_background = binary_erosion(is_background, structure=structure, border_value=1)
44
+
45
+ trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
46
+ trimap[is_foreground] = 255
47
+ trimap[is_background] = 0
48
+
49
+ img_normalized = img / 255.0
50
+ trimap_normalized = trimap / 255.0
51
+
52
+ alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
53
+ foreground = estimate_foreground_ml(img_normalized, alpha)
54
+ cutout = stack_images(foreground, alpha)
55
+
56
+ cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
57
+ cutout = Image.fromarray(cutout)
58
+
59
+ return cutout
60
+
61
+
62
+ def naive_cutout(img: PILImage, mask: PILImage) -> PILImage:
63
+ empty = Image.new("RGBA", (img.size), 0)
64
+ cutout = Image.composite(img, empty, mask)
65
+ return cutout
66
+
67
+
68
+ def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
69
+ pivot = imgs.pop(0)
70
+ for im in imgs:
71
+ pivot = get_concat_v(pivot, im)
72
+ return pivot
73
+
74
+
75
+ def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
76
+ dst = Image.new("RGBA", (img1.width, img1.height + img2.height))
77
+ dst.paste(img1, (0, 0))
78
+ dst.paste(img2, (0, img1.height))
79
+ return dst
80
+
81
+
82
+ def remove(
83
+ data: Union[bytes, PILImage, np.ndarray],
84
+ alpha_matting: bool = False,
85
+ alpha_matting_foreground_threshold: int = 240,
86
+ alpha_matting_background_threshold: int = 10,
87
+ alpha_matting_erode_size: int = 10,
88
+ session: Optional[BaseSession] = None,
89
+ only_mask: bool = False,
90
+ ) -> Union[bytes, PILImage, np.ndarray]:
91
+
92
+ if isinstance(data, PILImage):
93
+ return_type = ReturnType.PILLOW
94
+ img = data
95
+ elif isinstance(data, bytes):
96
+ return_type = ReturnType.BYTES
97
+ img = Image.open(io.BytesIO(data))
98
+ elif isinstance(data, np.ndarray):
99
+ return_type = ReturnType.NDARRAY
100
+ img = Image.fromarray(data)
101
+ else:
102
+ raise ValueError("Input type {} is not supported.".format(type(data)))
103
+
104
+ if session is None:
105
+ session = new_session("u2net")
106
+
107
+ masks = session.predict(img)
108
+ cutouts = []
109
+
110
+ for mask in masks:
111
+ if only_mask:
112
+ cutout = mask
113
+
114
+ elif alpha_matting:
115
+ try:
116
+ cutout = alpha_matting_cutout(
117
+ img,
118
+ mask,
119
+ alpha_matting_foreground_threshold,
120
+ alpha_matting_background_threshold,
121
+ alpha_matting_erode_size,
122
+ )
123
+ except ValueError:
124
+ cutout = naive_cutout(img, mask)
125
+
126
+ else:
127
+ cutout = naive_cutout(img, mask)
128
+
129
+ cutouts.append(cutout)
130
+
131
+ cutout = img
132
+ if len(cutouts) > 0:
133
+ cutout = get_concat_v_multi(cutouts)
134
+
135
+ if ReturnType.PILLOW == return_type:
136
+ return cutout
137
+
138
+ if ReturnType.NDARRAY == return_type:
139
+ return np.asarray(cutout)
140
+
141
+ bio = io.BytesIO()
142
+ cutout.save(bio, "PNG")
143
+ bio.seek(0)
144
+
145
+ return bio.read()
rembg/cli.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ import sys
3
+ import time
4
+ from enum import Enum
5
+ from typing import IO, cast
6
+
7
+ import aiohttp
8
+ import click
9
+ import filetype
10
+ import uvicorn
11
+ from asyncer import asyncify
12
+ from fastapi import Depends, FastAPI, File, Form, Query
13
+ from fastapi.middleware.cors import CORSMiddleware
14
+ from starlette.responses import Response
15
+ from tqdm import tqdm
16
+ from watchdog.events import FileSystemEvent, FileSystemEventHandler
17
+ from watchdog.observers import Observer
18
+
19
+ from . import _version
20
+ from .bg import remove
21
+ from .session_base import BaseSession
22
+ from .session_factory import new_session
23
+
24
+
25
+ @click.group()
26
+ @click.version_option(version=_version.get_versions()["version"])
27
+ def main() -> None:
28
+ pass
29
+
30
+
31
+ @main.command(help="for a file as input")
32
+ @click.option(
33
+ "-m",
34
+ "--model",
35
+ default="u2net",
36
+ type=click.Choice(["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg"]),
37
+ show_default=True,
38
+ show_choices=True,
39
+ help="model name",
40
+ )
41
+ @click.option(
42
+ "-a",
43
+ "--alpha-matting",
44
+ is_flag=True,
45
+ show_default=True,
46
+ help="use alpha matting",
47
+ )
48
+ @click.option(
49
+ "-af",
50
+ "--alpha-matting-foreground-threshold",
51
+ default=240,
52
+ type=int,
53
+ show_default=True,
54
+ help="trimap fg threshold",
55
+ )
56
+ @click.option(
57
+ "-ab",
58
+ "--alpha-matting-background-threshold",
59
+ default=10,
60
+ type=int,
61
+ show_default=True,
62
+ help="trimap bg threshold",
63
+ )
64
+ @click.option(
65
+ "-ae",
66
+ "--alpha-matting-erode-size",
67
+ default=10,
68
+ type=int,
69
+ show_default=True,
70
+ help="erode size",
71
+ )
72
+ @click.option(
73
+ "-om",
74
+ "--only-mask",
75
+ is_flag=True,
76
+ show_default=True,
77
+ help="output only the mask",
78
+ )
79
+ @click.argument(
80
+ "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
81
+ )
82
+ @click.argument(
83
+ "output",
84
+ default=(None if sys.stdin.isatty() else "-"),
85
+ type=click.File("wb", lazy=True),
86
+ )
87
+ def i(model: str, input: IO, output: IO, **kwargs) -> None:
88
+ output.write(remove(input.read(), session=new_session(model), **kwargs))
89
+
90
+
91
+ @main.command(help="for a folder as input")
92
+ @click.option(
93
+ "-m",
94
+ "--model",
95
+ default="u2net",
96
+ type=click.Choice(["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg"]),
97
+ show_default=True,
98
+ show_choices=True,
99
+ help="model name",
100
+ )
101
+ @click.option(
102
+ "-a",
103
+ "--alpha-matting",
104
+ is_flag=True,
105
+ show_default=True,
106
+ help="use alpha matting",
107
+ )
108
+ @click.option(
109
+ "-af",
110
+ "--alpha-matting-foreground-threshold",
111
+ default=240,
112
+ type=int,
113
+ show_default=True,
114
+ help="trimap fg threshold",
115
+ )
116
+ @click.option(
117
+ "-ab",
118
+ "--alpha-matting-background-threshold",
119
+ default=10,
120
+ type=int,
121
+ show_default=True,
122
+ help="trimap bg threshold",
123
+ )
124
+ @click.option(
125
+ "-ae",
126
+ "--alpha-matting-erode-size",
127
+ default=10,
128
+ type=int,
129
+ show_default=True,
130
+ help="erode size",
131
+ )
132
+ @click.option(
133
+ "-om",
134
+ "--only-mask",
135
+ is_flag=True,
136
+ show_default=True,
137
+ help="output only the mask",
138
+ )
139
+ @click.option(
140
+ "-w",
141
+ "--watch",
142
+ default=False,
143
+ is_flag=True,
144
+ show_default=True,
145
+ help="watches a folder for changes",
146
+ )
147
+ @click.argument(
148
+ "input",
149
+ type=click.Path(
150
+ exists=True,
151
+ path_type=pathlib.Path,
152
+ file_okay=False,
153
+ dir_okay=True,
154
+ readable=True,
155
+ ),
156
+ )
157
+ @click.argument(
158
+ "output",
159
+ type=click.Path(
160
+ exists=False,
161
+ path_type=pathlib.Path,
162
+ file_okay=False,
163
+ dir_okay=True,
164
+ writable=True,
165
+ ),
166
+ )
167
+ def p(
168
+ model: str, input: pathlib.Path, output: pathlib.Path, watch: bool, **kwargs
169
+ ) -> None:
170
+ session = new_session(model)
171
+
172
+ def process(each_input: pathlib.Path) -> None:
173
+ try:
174
+ mimetype = filetype.guess(each_input)
175
+ if mimetype is None:
176
+ return
177
+ if mimetype.mime.find("image") < 0:
178
+ return
179
+
180
+ each_output = (output / each_input.name).with_suffix(".png")
181
+ each_output.parents[0].mkdir(parents=True, exist_ok=True)
182
+
183
+ if not each_output.exists():
184
+ each_output.write_bytes(
185
+ cast(
186
+ bytes,
187
+ remove(each_input.read_bytes(), session=session, **kwargs),
188
+ )
189
+ )
190
+
191
+ if watch:
192
+ print(
193
+ f"processed: {each_input.absolute()} -> {each_output.absolute()}"
194
+ )
195
+ except Exception as e:
196
+ print(e)
197
+
198
+ inputs = list(input.glob("**/*"))
199
+ if not watch:
200
+ inputs = tqdm(inputs)
201
+
202
+ for each_input in inputs:
203
+ if not each_input.is_dir():
204
+ process(each_input)
205
+
206
+ if watch:
207
+ observer = Observer()
208
+
209
+ class EventHandler(FileSystemEventHandler):
210
+ def on_any_event(self, event: FileSystemEvent) -> None:
211
+ if not (
212
+ event.is_directory or event.event_type in ["deleted", "closed"]
213
+ ):
214
+ process(pathlib.Path(event.src_path))
215
+
216
+ event_handler = EventHandler()
217
+ observer.schedule(event_handler, input, recursive=False)
218
+ observer.start()
219
+
220
+ try:
221
+ while True:
222
+ time.sleep(1)
223
+
224
+ finally:
225
+ observer.stop()
226
+ observer.join()
227
+
228
+
229
+ @main.command(help="for a http server")
230
+ @click.option(
231
+ "-p",
232
+ "--port",
233
+ default=5000,
234
+ type=int,
235
+ show_default=True,
236
+ help="port",
237
+ )
238
+ @click.option(
239
+ "-l",
240
+ "--log_level",
241
+ default="info",
242
+ type=str,
243
+ show_default=True,
244
+ help="log level",
245
+ )
246
+ def s(port: int, log_level: str) -> None:
247
+ sessions: dict[str, BaseSession] = {}
248
+ tags_metadata = [
249
+ {
250
+ "name": "Background Removal",
251
+ "description": "Endpoints that perform background removal with different image sources.",
252
+ "externalDocs": {
253
+ "description": "GitHub Source",
254
+ "url": "https://github.com/danielgatis/rembg",
255
+ },
256
+ },
257
+ ]
258
+ app = FastAPI(
259
+ title="Rembg",
260
+ description="Rembg is a tool to remove images background. That is it.",
261
+ version=_version.get_versions()["version"],
262
+ contact={
263
+ "name": "Daniel Gatis",
264
+ "url": "https://github.com/danielgatis",
265
+ "email": "[email protected]",
266
+ },
267
+ license_info={
268
+ "name": "MIT License",
269
+ "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
270
+ },
271
+ openapi_tags=tags_metadata,
272
+ )
273
+
274
+ app.add_middleware(
275
+ CORSMiddleware,
276
+ allow_credentials=True,
277
+ allow_origins=["*"],
278
+ allow_methods=["*"],
279
+ allow_headers=["*"],
280
+ )
281
+
282
+ class ModelType(str, Enum):
283
+ u2net = "u2net"
284
+ u2netp = "u2netp"
285
+ u2net_human_seg = "u2net_human_seg"
286
+ u2net_cloth_seg = "u2net_cloth_seg"
287
+
288
+ class CommonQueryParams:
289
+ def __init__(
290
+ self,
291
+ model: ModelType = Query(
292
+ default=ModelType.u2net,
293
+ description="Model to use when processing image",
294
+ ),
295
+ a: bool = Query(default=False, description="Enable Alpha Matting"),
296
+ af: int = Query(
297
+ default=240,
298
+ ge=0,
299
+ le=255,
300
+ description="Alpha Matting (Foreground Threshold)",
301
+ ),
302
+ ab: int = Query(
303
+ default=10,
304
+ ge=0,
305
+ le=255,
306
+ description="Alpha Matting (Background Threshold)",
307
+ ),
308
+ ae: int = Query(
309
+ default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
310
+ ),
311
+ om: bool = Query(default=False, description="Only Mask"),
312
+ ):
313
+ self.model = model
314
+ self.a = a
315
+ self.af = af
316
+ self.ab = ab
317
+ self.ae = ae
318
+ self.om = om
319
+
320
+ class CommonQueryPostParams:
321
+ def __init__(
322
+ self,
323
+ model: ModelType = Form(
324
+ default=ModelType.u2net,
325
+ description="Model to use when processing image",
326
+ ),
327
+ a: bool = Form(default=False, description="Enable Alpha Matting"),
328
+ af: int = Form(
329
+ default=240,
330
+ ge=0,
331
+ le=255,
332
+ description="Alpha Matting (Foreground Threshold)",
333
+ ),
334
+ ab: int = Form(
335
+ default=10,
336
+ ge=0,
337
+ le=255,
338
+ description="Alpha Matting (Background Threshold)",
339
+ ),
340
+ ae: int = Form(
341
+ default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
342
+ ),
343
+ om: bool = Form(default=False, description="Only Mask"),
344
+ ):
345
+ self.model = model
346
+ self.a = a
347
+ self.af = af
348
+ self.ab = ab
349
+ self.ae = ae
350
+ self.om = om
351
+
352
+ def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
353
+ return Response(
354
+ remove(
355
+ content,
356
+ session=sessions.setdefault(
357
+ commons.model.value, new_session(commons.model.value)
358
+ ),
359
+ alpha_matting=commons.a,
360
+ alpha_matting_foreground_threshold=commons.af,
361
+ alpha_matting_background_threshold=commons.ab,
362
+ alpha_matting_erode_size=commons.ae,
363
+ only_mask=commons.om,
364
+ ),
365
+ media_type="image/png",
366
+ )
367
+
368
+ @app.get(
369
+ path="/",
370
+ tags=["Background Removal"],
371
+ summary="Remove from URL",
372
+ description="Removes the background from an image obtained by retrieving an URL.",
373
+ )
374
+ async def get_index(
375
+ url: str = Query(
376
+ default=..., description="URL of the image that has to be processed."
377
+ ),
378
+ commons: CommonQueryParams = Depends(),
379
+ ):
380
+ async with aiohttp.ClientSession() as session:
381
+ async with session.get(url) as response:
382
+ file = await response.read()
383
+ return await asyncify(im_without_bg)(file, commons)
384
+
385
+ @app.post(
386
+ path="/",
387
+ tags=["Background Removal"],
388
+ summary="Remove from Stream",
389
+ description="Removes the background from an image sent within the request itself.",
390
+ )
391
+ async def post_index(
392
+ file: bytes = File(
393
+ default=...,
394
+ description="Image file (byte stream) that has to be processed.",
395
+ ),
396
+ commons: CommonQueryPostParams = Depends(),
397
+ ):
398
+ return await asyncify(im_without_bg)(file, commons)
399
+
400
+ uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
rembg/session_base.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Tuple
2
+
3
+ import numpy as np
4
+ import onnxruntime as ort
5
+ from PIL import Image
6
+ from PIL.Image import Image as PILImage
7
+
8
+
9
+ class BaseSession:
10
+ def __init__(self, model_name: str, inner_session: ort.InferenceSession):
11
+ self.model_name = model_name
12
+ self.inner_session = inner_session
13
+
14
+ def normalize(
15
+ self,
16
+ img: PILImage,
17
+ mean: Tuple[float, float, float],
18
+ std: Tuple[float, float, float],
19
+ size: Tuple[int, int],
20
+ ) -> Dict[str, np.ndarray]:
21
+ im = img.convert("RGB").resize(size, Image.LANCZOS)
22
+
23
+ im_ary = np.array(im)
24
+ im_ary = im_ary / np.max(im_ary)
25
+
26
+ tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))
27
+ tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0]
28
+ tmpImg[:, :, 1] = (im_ary[:, :, 1] - mean[1]) / std[1]
29
+ tmpImg[:, :, 2] = (im_ary[:, :, 2] - mean[2]) / std[2]
30
+
31
+ tmpImg = tmpImg.transpose((2, 0, 1))
32
+
33
+ return {
34
+ self.inner_session.get_inputs()[0]
35
+ .name: np.expand_dims(tmpImg, 0)
36
+ .astype(np.float32)
37
+ }
38
+
39
+ def predict(self, img: PILImage) -> List[PILImage]:
40
+ raise NotImplementedError
rembg/session_cloth.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+ from PIL.Image import Image as PILImage
6
+ from scipy.special import log_softmax
7
+
8
+ from .session_base import BaseSession
9
+
10
+ pallete1 = [
11
+ 0,
12
+ 0,
13
+ 0,
14
+ 255,
15
+ 255,
16
+ 255,
17
+ 0,
18
+ 0,
19
+ 0,
20
+ 0,
21
+ 0,
22
+ 0,
23
+ ]
24
+
25
+ pallete2 = [
26
+ 0,
27
+ 0,
28
+ 0,
29
+ 0,
30
+ 0,
31
+ 0,
32
+ 255,
33
+ 255,
34
+ 255,
35
+ 0,
36
+ 0,
37
+ 0,
38
+ ]
39
+
40
+ pallete3 = [
41
+ 0,
42
+ 0,
43
+ 0,
44
+ 0,
45
+ 0,
46
+ 0,
47
+ 0,
48
+ 0,
49
+ 0,
50
+ 255,
51
+ 255,
52
+ 255,
53
+ ]
54
+
55
+
56
+ class ClothSession(BaseSession):
57
+ def predict(self, img: PILImage) -> List[PILImage]:
58
+ ort_outs = self.inner_session.run(
59
+ None, self.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), (768, 768))
60
+ )
61
+
62
+ pred = ort_outs
63
+ pred = log_softmax(pred[0], 1)
64
+ pred = np.argmax(pred, axis=1, keepdims=True)
65
+ pred = np.squeeze(pred, 0)
66
+ pred = np.squeeze(pred, 0)
67
+
68
+ mask = Image.fromarray(pred.astype("uint8"), mode="L")
69
+ mask = mask.resize(img.size, Image.LANCZOS)
70
+
71
+ masks = []
72
+
73
+ mask1 = mask.copy()
74
+ mask1.putpalette(pallete1)
75
+ mask1 = mask1.convert("RGB").convert("L")
76
+ masks.append(mask1)
77
+
78
+ mask2 = mask.copy()
79
+ mask2.putpalette(pallete2)
80
+ mask2 = mask2.convert("RGB").convert("L")
81
+ masks.append(mask2)
82
+
83
+ mask3 = mask.copy()
84
+ mask3.putpalette(pallete3)
85
+ mask3 = mask3.convert("RGB").convert("L")
86
+ masks.append(mask3)
87
+
88
+ return masks
rembg/session_factory.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import sys
4
+ from contextlib import redirect_stdout
5
+ from pathlib import Path
6
+ from typing import Type
7
+
8
+ import gdown
9
+ import onnxruntime as ort
10
+
11
+ from .session_base import BaseSession
12
+ from .session_cloth import ClothSession
13
+ from .session_simple import SimpleSession
14
+
15
+
16
+ def new_session(model_name: str) -> BaseSession:
17
+ session_class: Type[BaseSession]
18
+
19
+ if model_name == "u2netp":
20
+ md5 = "8e83ca70e441ab06c318d82300c84806"
21
+ url = "https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR"
22
+ session_class = SimpleSession
23
+ elif model_name == "u2net":
24
+ md5 = "60024c5c889badc19c04ad937298a77b"
25
+ url = "https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab"
26
+ session_class = SimpleSession
27
+ elif model_name == "u2net_human_seg":
28
+ md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
29
+ url = "https://drive.google.com/uc?id=1ZfqwVxu-1XWC1xU1GHIP-FM_Knd_AX5j"
30
+ session_class = SimpleSession
31
+ elif model_name == "u2net_cloth_seg":
32
+ md5 = "2434d1f3cb744e0e49386c906e5a08bb"
33
+ url = "https://drive.google.com/uc?id=15rKbQSXQzrKCQurUjZFg8HqzZad8bcyz"
34
+ session_class = ClothSession
35
+ else:
36
+ assert AssertionError(
37
+ "Choose between u2net, u2netp, u2net_human_seg or u2net_cloth_seg"
38
+ )
39
+
40
+ home = os.getenv("U2NET_HOME", os.path.join("~", ".u2net"))
41
+ path = Path(home).expanduser() / f"{model_name}.onnx"
42
+ path.parents[0].mkdir(parents=True, exist_ok=True)
43
+
44
+ if not path.exists():
45
+ with redirect_stdout(sys.stderr):
46
+ gdown.download(url, str(path), use_cookies=False)
47
+ else:
48
+ hashing = hashlib.new("md5", path.read_bytes(), usedforsecurity=False)
49
+ if hashing.hexdigest() != md5:
50
+ with redirect_stdout(sys.stderr):
51
+ gdown.download(url, str(path), use_cookies=False)
52
+
53
+ sess_opts = ort.SessionOptions()
54
+
55
+ if "OMP_NUM_THREADS" in os.environ:
56
+ sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
57
+
58
+ return session_class(
59
+ model_name,
60
+ ort.InferenceSession(
61
+ str(path), providers=ort.get_available_providers(), sess_options=sess_opts
62
+ ),
63
+ )
rembg/session_simple.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+ from PIL.Image import Image as PILImage
6
+
7
+ from .session_base import BaseSession
8
+
9
+
10
+ class SimpleSession(BaseSession):
11
+ def predict(self, img: PILImage) -> List[PILImage]:
12
+ ort_outs = self.inner_session.run(
13
+ None,
14
+ self.normalize(
15
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
16
+ ),
17
+ )
18
+
19
+ pred = ort_outs[0][:, 0, :, :]
20
+
21
+ ma = np.max(pred)
22
+ mi = np.min(pred)
23
+
24
+ pred = (pred - mi) / (ma - mi)
25
+ pred = np.squeeze(pred)
26
+
27
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
28
+ mask = mask.resize(img.size, Image.LANCZOS)
29
+
30
+ return [mask]
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohttp==3.8.1
2
+ asyncer==0.0.1
3
+ click==8.0.3
4
+ fastapi==0.72.0
5
+ filetype==1.0.9
6
+ gdown==4.4.0
7
+ numpy==1.22.3
8
+ onnxruntime==1.10.0
9
+ pillow==9.0.1
10
+ pymatting==1.1.5
11
+ python-multipart==0.0.5
12
+ scikit-image==0.19.1
13
+ scipy==1.8.0
14
+ tqdm==4.62.3
15
+ uvicorn==0.17.0
16
+ watchdog==2.1.7
17
+ opencv-python