# RemBG_super: rembg/commands/b_command.py
# Last change by KenjieDec, commit 5f57808 ("Update").
import asyncio
import io
import json
import os
import sys
from typing import IO
import click
from PIL import Image
from ..bg import remove
from ..session_factory import new_session
from ..sessions import sessions_names
@click.command(
    name="b",
    help="for a byte stream as input",
)
@click.option(
    "-m",
    "--model",
    default="u2net",
    type=click.Choice(sessions_names),
    show_default=True,
    show_choices=True,
    help="model name",
)
@click.option(
    "-a",
    "--alpha-matting",
    is_flag=True,
    show_default=True,
    help="use alpha matting",
)
@click.option(
    "-af",
    "--alpha-matting-foreground-threshold",
    default=240,
    type=int,
    show_default=True,
    help="trimap fg threshold",
)
@click.option(
    "-ab",
    "--alpha-matting-background-threshold",
    default=10,
    type=int,
    show_default=True,
    help="trimap bg threshold",
)
@click.option(
    "-ae",
    "--alpha-matting-erode-size",
    default=10,
    type=int,
    show_default=True,
    help="erode size",
)
@click.option(
    "-om",
    "--only-mask",
    is_flag=True,
    show_default=True,
    help="output only the mask",
)
@click.option(
    "-ppm",
    "--post-process-mask",
    is_flag=True,
    show_default=True,
    help="post process the mask",
)
@click.option(
    "-bgc",
    "--bgcolor",
    default=None,
    type=(int, int, int, int),
    nargs=4,
    help="Background color (R G B A) to replace the removed background with",
)
@click.option(
    "-x",
    "--extras",
    type=str,
    help="extra keyword arguments for remove(), as a JSON object",
)
@click.option(
    "-o",
    "--output_specifier",
    type=str,
    help="printf-style specifier for output filenames (e.g. 'output-%d.png')",
)
@click.argument(
    "image_width",
    type=click.IntRange(min=1),
)
@click.argument(
    "image_height",
    type=click.IntRange(min=1),
)
def rs_command(
    model: str,
    extras: str,
    image_width: int,
    image_height: int,
    output_specifier: str,
    **kwargs,
) -> None:
    """Remove the background from a stream of raw RGB frames read from stdin.

    stdin is consumed as consecutive fixed-size frames of exactly
    ``image_width * image_height * 3`` bytes, each decoded with
    ``Image.frombytes("RGB", ...)`` and passed to ``remove``. Each result is
    either saved as PNG to ``output_specifier % frame_index`` or, when no
    specifier is given, PNG-encoded and written to stdout back to back.

    Args:
        model: session name passed to ``new_session``.
        extras: optional JSON object whose keys are merged into the keyword
            arguments forwarded to ``remove``. Invalid JSON or a non-object
            value is ignored with a warning on stderr.
        image_width: frame width in pixels (>= 1).
        image_height: frame height in pixels (>= 1).
        output_specifier: printf-style path with one integer placeholder
            (e.g. ``'output-%d.png'``); ``~`` is expanded. If omitted, PNG
            bytes are written to stdout.
        **kwargs: remaining CLI options, forwarded to ``remove``.

    Raises:
        click.BadParameter: if ``output_specifier`` cannot be formatted with
            a single integer.
    """
    if extras:
        try:
            parsed_extras = json.loads(extras)
        except json.JSONDecodeError as err:
            # Best-effort: keep running without extras, but say so. Warnings go
            # to stderr because stdout may be carrying PNG data.
            click.echo(f"Ignoring --extras: invalid JSON ({err})", err=True)
        else:
            if isinstance(parsed_extras, dict):
                kwargs.update(parsed_extras)
            else:
                click.echo("Ignoring --extras: expected a JSON object", err=True)

    output_path = None
    if output_specifier:
        # Expand "~" once and use the same path both for creating the
        # directory and for saving; previously only the directory was expanded.
        output_path = os.path.expanduser(output_specifier)
        try:
            # Fail fast, before loading the model, if the pattern is unusable.
            output_path % 0
        except (TypeError, ValueError) as err:
            raise click.BadParameter(
                f"must contain one integer placeholder such as %d ({err})",
                param_hint="'-o' / '--output_specifier'",
            ) from err
        output_dir = os.path.dirname(os.path.abspath(output_path))
        os.makedirs(output_dir, exist_ok=True)

    session = new_session(model)
    bytes_per_img = image_width * image_height * 3

    def img_to_byte_array(img: Image.Image) -> bytes:
        """Encode ``img`` as PNG and return the encoded bytes."""
        buff = io.BytesIO()
        img.save(buff, format="PNG")
        return buff.getvalue()

    async def connect_stdin_stdout():
        """Wrap stdin/stdout in an asyncio ``StreamReader``/``StreamWriter`` pair."""
        loop = asyncio.get_running_loop()
        reader = asyncio.StreamReader()
        protocol = asyncio.StreamReaderProtocol(reader)
        await loop.connect_read_pipe(lambda: protocol, sys.stdin)
        w_transport, w_protocol = await loop.connect_write_pipe(
            asyncio.streams.FlowControlMixin, sys.stdout
        )
        # With a zero high-water mark, writer.drain() waits until the transport
        # buffer is completely empty, so nothing is left pending when the
        # event loop shuts down.
        w_transport.set_write_buffer_limits(high=0)
        writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
        return reader, writer

    async def main():
        """Process frames from stdin until end of stream."""
        reader, writer = await connect_stdin_stdout()
        idx = 0
        while True:
            try:
                img_bytes = await reader.readexactly(bytes_per_img)
            except asyncio.IncompleteReadError as err:
                # readexactly never returns short: EOF is reported here.
                # No leftover bytes means a clean end of stream; anything
                # else is a truncated final frame.
                if err.partial:
                    click.echo(
                        f"Ignoring truncated trailing frame "
                        f"({len(err.partial)} of {bytes_per_img} bytes)",
                        err=True,
                    )
                break

            img = Image.frombytes("RGB", (image_width, image_height), img_bytes)
            output = remove(img, session=session, **kwargs)

            if output_path:
                output.save(output_path % idx, format="PNG")
            else:
                writer.write(img_to_byte_array(output))
                # Apply back-pressure and make sure the frame actually reaches
                # the pipe; un-drained data may otherwise pile up in memory or
                # be lost when asyncio.run() closes the loop.
                await writer.drain()
            idx += 1

    asyncio.run(main())