|
import gradio as gr |
|
import torch |
|
import uuid |
|
import base64 |
|
import numpy as np |
|
import onnxruntime as ort |
|
import cv2 |
|
from PIL import Image |
|
from torchvision.transforms.functional import normalize |
|
import torch.nn.functional as F |
|
from typing import Union, List |
|
from io import BytesIO |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# Spatial size handed to the bilinear resize in preprocess_input (interpreted
# there as height, width); keep_large_components also uses its area as the
# reference for scaling the minimum component size.
INPUT_SIZE = [1200, 1800]

# Fetch the ONNX model file from the Hugging Face Hub.
model_path = hf_hub_download(filename="model.onnx", repo_id="Trendyol/background-removal")

# Ask for CUDA first; if building that session raises for any reason,
# retry with the CPU provider only so the app can still start.
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
try:
    ort_sess = ort.InferenceSession(model_path, providers=providers)
except Exception:
    ort_sess = ort.InferenceSession(model_path, providers=providers[1:])
|
|
|
|
|
|
|
def keep_large_components(a: np.ndarray) -> np.ndarray:
    """Suppress small connected regions of an alpha mask.

    The mask is thresholded at 25, split into 4-connected components, and
    only components whose pixel area exceeds a size-dependent limit are
    kept. The kept region is dilated (9x9 elliptical kernel, two
    iterations) and AND-ed with the original mask so that soft edge values
    around the kept components survive.

    If no component exceeds the limit, the mask values are left unchanged.

    Args:
        a: uint8 alpha mask, either (H, W) or (H, W, 1).

    Returns:
        The filtered mask reshaped to (H, W, 1).
    """
    dilate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
    a_mask = (a > 25).astype(np.uint8) * 255
    total_labels, label_ids, stats, _ = cv2.connectedComponentsWithStats(
        a_mask, 4, cv2.CV_32S
    )

    # The 50000-pixel limit is defined relative to the INPUT_SIZE area and is
    # scaled by how large this mask is compared to that reference.
    h, w = a.shape[:2]
    area_limit = 50000 * (h * w) / (INPUT_SIZE[1] * INPUT_SIZE[0])

    # Label 0 (the background in OpenCV's labelling) is never a candidate.
    labels_to_keep = [
        label
        for label in range(1, total_labels)
        if stats[label, cv2.CC_STAT_AREA] > area_limit
    ]

    if labels_to_keep:
        # One vectorised membership test instead of building and OR-ing a
        # full-size mask per kept component.
        final_mask = np.isin(label_ids, labels_to_keep).astype(np.uint8) * 255
        final_mask = cv2.dilate(final_mask, dilate_kernel, iterations=2)
        a = cv2.bitwise_and(a, final_mask)

    return a.reshape((a.shape[0], a.shape[1], 1))
|
|
|
|
|
def preprocess_input(im: np.ndarray) -> torch.Tensor:
    """Convert an image array into the normalized tensor fed to the model.

    Steps: force three channels, move channels first, add a batch axis,
    resize bilinearly to INPUT_SIZE, cast back to uint8 (dropping the
    fractional part produced by the resize), scale by 1/255, and subtract
    0.5 per channel (the std used for normalization is 1.0).

    Args:
        im: Image as (H, W), (H, W, 1), (H, W, 3) or (H, W, 4). A fourth
            channel is discarded; a single channel is replicated to three.

    Returns:
        float32 tensor of shape (1, 3, INPUT_SIZE[0], INPUT_SIZE[1]).
    """
    if im.ndim < 3:
        im = im[:, :, np.newaxis]
    if im.shape[2] == 1:
        # The normalization below uses three per-channel means, which cannot
        # be applied in place to a one-channel tensor; replicate the single
        # channel so grayscale input works instead of raising.
        im = np.repeat(im, 3, axis=2)
    elif im.shape[2] == 4:
        im = im[:, :, :3]

    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    # F.interpolate is the supported replacement for the deprecated F.upsample.
    im_tensor = F.interpolate(
        torch.unsqueeze(im_tensor, 0), size=INPUT_SIZE, mode="bilinear"
    ).type(torch.uint8)
    image = torch.divide(im_tensor, 255.0)
    return normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
|
|
|
|
|
def postprocess_output(result: np.ndarray, orig_im_shape) -> np.ndarray:
    """Turn the raw model output into a uint8 alpha mask at image resolution.

    The prediction is resized bilinearly to ``orig_im_shape``, min-max
    scaled to [0, 1], multiplied by 255 and truncated to uint8, then passed
    through ``keep_large_components`` to drop small regions.

    Args:
        result: Model output as a 3-D array (channels, height, width); a
            batch axis is added here for the 4-D bilinear resize. The final
            reshape in ``keep_large_components`` implies a single channel.
        orig_im_shape: Target (height, width) of the original image.

    Returns:
        uint8 array of shape (height, width, 1).
    """
    # F.interpolate is the supported replacement for the deprecated F.upsample.
    resized = F.interpolate(
        torch.from_numpy(result).unsqueeze(0), size=orig_im_shape, mode="bilinear"
    )
    resized = torch.squeeze(resized, 0)

    ma = torch.max(resized)
    mi = torch.min(resized)
    # The epsilon keeps a constant prediction (ma == mi) from dividing by zero.
    scaled = (resized - mi) / (ma - mi + 1e-8)

    # .detach() replaces the legacy .data access before leaving torch.
    a = (scaled * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
    return keep_large_components(a)
|
|
|
|
|
|
|
def process(image: Image.Image) -> Image.Image:
    """Replace the background of ``image`` with white.

    The image is converted to RGB, run through the ONNX session to obtain
    an alpha mask, and the mask is binarized at 25 before compositing the
    RGB image over a white canvas of the same size.

    Args:
        image: Input PIL image.

    Returns:
        RGB PIL image with masked-out pixels set to white.
    """
    rgb_image = image.convert("RGB")
    pixels = np.array(rgb_image)

    # Run the model on the preprocessed tensor; take the first output and
    # the first item of its batch.
    model_input = preprocess_input(pixels)
    input_name = ort_sess.get_inputs()[0].name
    prediction = ort_sess.run(None, {input_name: model_input.numpy()})[0][0]

    # Bring the prediction back to the original height and width.
    alpha = postprocess_output(prediction, (pixels.shape[0], pixels.shape[1]))

    # Hard-threshold the alpha so every pixel is either kept or made white.
    mask = Image.fromarray(alpha.squeeze(-1)).convert("L")
    binary_mask = mask.point(lambda p: 255 if p > 25 else 0)
    background = Image.new("RGB", image.size, (255, 255, 255))
    return Image.composite(rgb_image, background, binary_mask)
|
|
|
|
|
|
|
def handler(image=None) -> Union[str, None]:
    """Gradio callback: process the upload and return the saved PNG path.

    Args:
        image: PIL image supplied by the UI, or None when nothing was given.

    Returns:
        Path of the PNG written to the current working directory (a random
        eight-hex-digit suffix keeps names distinct), or None if ``image``
        is None.
    """
    if image is None:
        return None

    output_image = process(image)
    out_path = f"output_{uuid.uuid4().hex[:8]}.png"
    output_image.save(out_path)
    return out_path
|
|
|
|
|
|
|
# Gradio UI: a single PIL image goes in, and the path returned by `handler`
# is offered back as a downloadable file.
_image_input = gr.Image(label="Upload Image", type="pil")
_file_output = gr.File(label="Output File")

demo = gr.Interface(
    fn=handler,
    inputs=_image_input,
    outputs=_file_output,
    title="Background Remover (Trendyol)",
    description="Upload an image to remove the background with the Trendyol ONNX model. Background is replaced with white.",
)


if __name__ == "__main__":
    # show_error=True surfaces exceptions raised by the handler in the UI.
    demo.launch(show_error=True)
|
|