|
import os, sys |
|
sys.path.insert(0, f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}") |
|
|
|
import numpy as np |
|
from PIL import Image |
|
from rembg import remove, new_session |
|
from infer.utils import timing_decorator |
|
|
|
class Removebg:
    """Background remover backed by a ``rembg`` session.

    Calling an instance cuts the foreground out of an image, then passes the
    result through ``white_out_background`` and ``preprocess``.
    """

    def __init__(self, name="u2net"):
        """Create the ``rembg`` session.

        Args:
            name: model name passed to ``rembg.new_session``.
        """
        self.session = new_session(name)

    @timing_decorator("remove background")
    def __call__(self, rgb_maybe, force=True):
        """Remove the background of ``rgb_maybe`` and return a processed RGBA image.

        Args:
            rgb_maybe: PIL.Image, with RGB mode or RGBA mode.
            force: bool; if the input is RGBA, convert it to RGB (discarding
                its alpha) and run background removal anyway. If False, an
                RGBA input's existing alpha channel is kept as the mask.

        Returns:
            rgba_img: PIL.Image, with RGBA mode (the output of ``preprocess``).
        """
        if rgb_maybe.mode == "RGBA" and not force:
            # Reuse the caller's alpha. Copy first: white_out_background edits
            # its argument in place, which would otherwise silently mutate
            # the image the caller passed in.
            rgba_img = rgb_maybe.copy()
        else:
            if rgb_maybe.mode == "RGBA":
                rgb_maybe = rgb_maybe.convert("RGB")
            rgba_img = remove(rgb_maybe, session=self.session)

        rgba_img = white_out_background(rgba_img)
        return preprocess(rgba_img)
|
|
|
|
|
def white_out_background(pil_img):
    """Normalise transparent and near-white pixels of an RGBA image, in place.

    * Pixels with alpha < 16 become (255, 255, 255, 0).
    * Remaining pixels whose R, G and B are all > 235 have R, G and B clamped
      to 235; their alpha is kept. (Presumably this keeps near-white
      foreground distinct from pure-white background — confirm intent.)

    Args:
        pil_img: PIL.Image in RGBA mode. It is modified in place.

    Returns:
        The same ``pil_img`` object, for call chaining.

    Raises:
        ValueError: if ``pil_img`` is not in RGBA mode.
    """
    if pil_img.mode != "RGBA":
        raise ValueError(
            f"white_out_background expects an RGBA image, got mode {pil_img.mode!r}"
        )

    pixels = np.array(pil_img)  # writable (H, W, 4) uint8 copy
    # Both masks are computed from the untouched pixel values.
    transparent = pixels[:, :, 3] < 16
    near_white = np.all(pixels[:, :, :3] > 235, axis=-1) & ~transparent

    pixels[near_white, :3] = 235
    pixels[transparent] = (255, 255, 255, 0)

    # Write back into the caller's image to keep the original in-place contract.
    pil_img.paste(Image.fromarray(pixels))
    return pil_img
|
|
|
def preprocess(rgba_img, size=(512, 512), ratio=1.15):
    """Crop an image to its foreground, pad it to a centred square and resize.

    The bounding box of all pixels whose alpha is above 0.1 (on a 0-1 scale)
    is cropped out. The crop is placed in the middle of a square canvas whose
    side is ``ratio`` times the crop's longer side. The canvas is white in RGB
    and fully transparent in alpha. The canvas is then resized to ``size``.

    Args:
        rgba_img: PIL.Image; converted to RGBA first if in another mode.
        size: target size passed to ``Image.resize``.
        ratio: padding factor applied to the crop's longer side. The canvas is
            never smaller than the crop, so values below 1 act like 1.

    Returns:
        PIL.Image in RGBA mode, resized to ``size``.

    Raises:
        ValueError: if no pixel has alpha above the threshold.
    """
    if rgba_img.mode != "RGBA":
        rgba_img = rgba_img.convert("RGBA")
    image = np.asarray(rgba_img)
    rgb, alpha = image[:, :, :3], image[:, :, 3]

    # alpha / 255 > 0.1  <=>  alpha > 25.5  <=>  alpha > 25 for uint8 values;
    # comparing in uint8 avoids a lossy float round-trip of the pixel data.
    rows, cols = np.nonzero(alpha > 25)
    if rows.size == 0:
        raise ValueError("preprocess: image has no foreground pixels (alpha > 0.1)")

    # Slice ends are exclusive, so +1 keeps the last foreground row/column.
    top_row, bottom_row = rows.min(), rows.max() + 1
    left_col, right_col = cols.min(), cols.max() + 1
    rgb = rgb[top_row:bottom_row, left_col:right_col]
    alpha = alpha[top_row:bottom_row, left_col:right_col]

    h, w = rgb.shape[:2]
    side = max(int(max(h, w) * ratio), h, w)
    start_h, start_w = (side - h) // 2, (side - w) // 2

    canvas = np.zeros((side, side, 4), dtype=np.uint8)
    canvas[:, :, :3] = 255  # white RGB, alpha stays 0 (transparent)
    canvas[start_h:start_h + h, start_w:start_w + w, :3] = rgb
    canvas[start_h:start_h + h, start_w:start_w + w, 3] = alpha

    # A (H, W, 4) uint8 array is inferred as RGBA by Image.fromarray.
    return Image.fromarray(canvas).resize(size)
|
|
|
|
|
if __name__ == "__main__":

    import argparse

    def get_args():
        """Parse the command-line arguments for the standalone script."""
        parser = argparse.ArgumentParser()
        parser.add_argument("--rgb_path", type=str, required=True)
        parser.add_argument("--output_rgba_path", type=str, required=True)
        parser.add_argument("--force", default=False, action="store_true")
        return parser.parse_args()

    args = get_args()

    # Close the input file once processing is done; preprocess returns a new
    # image, so the result stays valid after the source image is closed.
    with Image.open(args.rgb_path) as rgb_maybe:
        model = Removebg()
        rgba_pil = model(rgb_maybe, args.force)

    rgba_pil.save(args.output_rgba_path)
|
|
|
|