import uuid
from typing import Union

import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.transforms.functional import normalize

# ---- Config ----
INPUT_SIZE = [1200, 1800]  # (H, W) expected by the model

# ---- Load ONNX model ----
model_path = hf_hub_download(repo_id="Trendyol/background-removal", filename="model.onnx")
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
try:
    ort_sess = ort.InferenceSession(model_path, providers=providers)
except Exception:
    # Fall back to CPU if the CUDA provider is unavailable
    ort_sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])


# ---- Utils from Trendyol ----
def keep_large_components(a: np.ndarray) -> np.ndarray:
    """Drop small connected components from the alpha matte, keeping only large regions."""
    dilate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
    a_mask = (a > 25).astype(np.uint8) * 255
    analysis = cv2.connectedComponentsWithStats(a_mask, 4, cv2.CV_32S)
    (totalLabels, label_ids, values, _) = analysis
    h, w = a.shape[:2]
    # Scale the area threshold to the current resolution
    area_limit = 50000 * (h * w) / (INPUT_SIZE[1] * INPUT_SIZE[0])
    i_to_keep = []
    for i in range(1, totalLabels):
        area = values[i, cv2.CC_STAT_AREA]
        if area > area_limit:
            i_to_keep.append(i)
    if len(i_to_keep) > 0:
        final_mask = np.zeros_like(a, dtype=np.uint8)
        for i in i_to_keep:
            componentMask = (label_ids == i).astype("uint8") * 255
            final_mask = cv2.bitwise_or(final_mask, componentMask)
        final_mask = cv2.dilate(final_mask, dilate_kernel, iterations=2)
        a = cv2.bitwise_and(a, final_mask)
    a = a.reshape((a.shape[0], a.shape[1], 1))
    return a


def preprocess_input(im: np.ndarray) -> torch.Tensor:
    """Resize to the model input size and normalize to the range the model expects."""
    if len(im.shape) < 3:
        im = im[:, :, np.newaxis]
    if im.shape[2] == 4:
        im = im[:, :, :3]
    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    # F.interpolate replaces the deprecated F.upsample
    im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), INPUT_SIZE, mode="bilinear").type(torch.uint8)
    image = torch.divide(im_tensor, 255.0)
    image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    return image


def postprocess_output(result: np.ndarray, orig_im_shape) -> np.ndarray:
    """Resize the predicted matte back to the original resolution and rescale to 0-255."""
    result = torch.squeeze(
        F.interpolate(torch.from_numpy(result).unsqueeze(0), orig_im_shape, mode="bilinear"), 0
    )
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi + 1e-8)
    a = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
    a = keep_large_components(a)
    return a


# ---- Core processing ----
def process(image: Image.Image) -> Image.Image:
    image_size = image.size  # (W, H) for PIL
    np_img = np.array(image.convert("RGB"))

    # Preprocess
    img_tensor = preprocess_input(np_img)

    # Inference
    inputs = {ort_sess.get_inputs()[0].name: img_tensor.numpy()}
    result = ort_sess.run(None, inputs)[0][0]  # (1, H, W)

    # Postprocess to mask
    alpha = postprocess_output(result, (np_img.shape[0], np_img.shape[1]))  # (H, W, 1)

    # White background composite
    mask = Image.fromarray(alpha.squeeze(-1)).convert("L")
    binary_mask = mask.point(lambda p: 255 if p > 25 else 0)
    white_bg = Image.new("RGB", image_size, (255, 255, 255))
    result = Image.composite(image.convert("RGB"), white_bg, binary_mask)
    return result


# ---- Gradio handler ----
def handler(image=None) -> Union[str, None]:
    if image is not None:
        processed = process(image)
        filename = f"output_{uuid.uuid4().hex[:8]}.png"
        processed.save(filename)
        return filename
    return None


# ---- Gradio UI ----
demo = gr.Interface(
    fn=handler,
    inputs=gr.Image(label="Upload Image", type="pil"),
    outputs=gr.File(label="Output File"),
    title="Background Remover (Trendyol)",
    description="Upload an image to remove the background with the Trendyol ONNX model. "
    "Background is replaced with white.",
)

if __name__ == "__main__":
    demo.launch(show_error=True)
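
# ---- Example: programmatic use (illustrative sketch) ----
# A minimal sketch of calling `process` directly, without the Gradio UI.
# The file paths below are placeholders, not part of the original app.
#
#   from PIL import Image
#   img = Image.open("input.jpg")
#   out = process(img)            # RGB image composited onto a white background
#   out.save("output_white_bg.png")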