Genesis3D / infer /removebg.py
Huiwenshi's picture
Upload folder using huggingface_hub
e3e5f9e verified
import os, sys
sys.path.insert(0, f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}")
import numpy as np
from PIL import Image
from rembg import remove, new_session
from infer.utils import timing_decorator
class Removebg():
    """Callable background remover backed by a rembg session."""

    def __init__(self, name="u2net"):
        # The session is created once here and reused by every __call__.
        self.session = new_session(name)

    @timing_decorator("remove background")
    def __call__(self, rgb_maybe, force=True):
        '''
        args:
            rgb_maybe: PIL.Image, with RGB mode or RGBA mode
            force: bool, if input is RGBA mode, convert to RGB then remove bg
        return:
            rgba_img: PIL.Image, with RGBA mode
        '''
        has_alpha = rgb_maybe.mode == "RGBA"
        if has_alpha and not force:
            # Trust the alpha channel the caller already provided.
            cutout = rgb_maybe
        else:
            # Drop any existing alpha first so the matte is computed from colour only.
            source = rgb_maybe.convert("RGB") if has_alpha else rgb_maybe
            cutout = remove(source, session=self.session)
        return preprocess(white_out_background(cutout))
def white_out_background(pil_img):
    """Return a cleaned copy of an RGBA image.

    Pixels whose alpha is below 16 are treated as background and replaced by
    fully transparent white. Remaining pixels whose three colour channels are
    all above 235 are set to (235, 235, 235), alpha unchanged, so that no
    foreground pixel is brighter than that on every channel.

    args:
        pil_img: PIL.Image in RGBA mode (pixels must unpack to r, g, b, a)
    return:
        a new PIL.Image in RGBA mode; ``pil_img`` itself is left untouched
    """
    def _clean(pixel):
        r, g, b, a = pixel
        if a < 16:  # background
            return (255, 255, 255, 0)
        if r > 235 and g > 235 and b > 235:  # near-white foreground
            return (235, 235, 235, a)
        return (r, g, b, a)

    # Work on a copy: the caller may hand us an image it still needs
    # (e.g. an RGBA input passed straight through without re-matting).
    cleaned = pil_img.copy()
    cleaned.putdata([_clean(pixel) for pixel in pil_img.getdata()])
    return cleaned
def preprocess(rgba_img, size=(512, 512), ratio=1.15):
    """Crop an RGBA image to its foreground, centre it on a square and resize.

    The foreground is every pixel whose alpha exceeds 10% of full scale. Its
    bounding box is pasted in the middle of a square canvas (white, alpha 0)
    whose side is ``ratio`` times the longer side of the box, and the canvas
    is then resized to ``size``.

    args:
        rgba_img: RGBA image (PIL.Image or H x W x 4 array with 0-255 values)
        size: target size handed to ``Image.resize``
        ratio: canvas side relative to the longer bounding-box side
    return:
        rgba_image: PIL.Image, with RGBA mode
    raises:
        ValueError: if no pixel has alpha above the foreground threshold
    """
    image = np.asarray(rgba_img)
    rgb, alpha = image[:, :, :3], image[:, :, 3]

    # crop
    # Same cut as ``alpha / 255 > 0.1`` but without a float round trip.
    rows, cols = np.nonzero(alpha > 0.1 * 255)
    if rows.size == 0:
        raise ValueError(
            "preprocess: image has no foreground pixels (alpha is <= 10% everywhere)"
        )
    # max() is the index of the last foreground row/column, so the slice end
    # has to be one past it or that row/column would be cut off.
    top, bottom = rows.min(), rows.max() + 1
    left, right = cols.min(), cols.max() + 1
    rgb = rgb[top:bottom, left:right]
    alpha = alpha[top:bottom, left:right]

    # padding
    h, w = alpha.shape
    resize_side = int(max(h, w) * ratio)
    start_h, start_w = (resize_side - h) // 2, (resize_side - w) // 2
    new_rgb = np.full((resize_side, resize_side, 3), 255, dtype=np.uint8)
    new_alpha = np.zeros((resize_side, resize_side), dtype=np.uint8)
    new_rgb[start_h:start_h + h, start_w:start_w + w] = rgb
    new_alpha[start_h:start_h + h, start_w:start_w + w] = alpha

    rgba_array = np.concatenate((new_rgb, new_alpha[:, :, None]), axis=-1)
    rgba_image = Image.fromarray(rgba_array, 'RGBA')
    return rgba_image.resize(size)
if __name__ == "__main__":
    # Command-line entry point: read an image, strip its background and
    # write the RGBA result to the requested path.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--rgb_path", type=str, required=True)
    cli.add_argument("--output_rgba_path", type=str, required=True)
    cli.add_argument("--force", default=False, action="store_true")
    options = cli.parse_args()

    # Open the input before building the model so a bad path fails early.
    input_image = Image.open(options.rgb_path)
    remover = Removebg()
    remover(input_image, options.force).save(options.output_rgba_path)