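"""Gradio app: background removal with the Trendyol ONNX model.

The predicted foreground mask is used to composite the input image onto a white
background, and the result is returned as a PNG file.
"""
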
import gradio as gr
import torch
import uuid
import numpy as np
import onnxruntime as ort
import cv2
from PIL import Image
from torchvision.transforms.functional import normalize
import torch.nn.functional as F
from typing import Union
from huggingface_hub import hf_hub_download

# ---- Config ----
INPUT_SIZE = [1200, 1800] # (H, W)
# ---- Load ONNX model ----
model_path = hf_hub_download(repo_id="Trendyol/background-removal", filename="model.onnx")
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
try:
    ort_sess = ort.InferenceSession(model_path, providers=providers)
except Exception:
    # Fall back to CPU execution if the CUDA provider is unavailable.
    ort_sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
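
# Optional sanity check (not required by the app): inspect the model's declared
# input name and shape via onnxruntime's standard session API, e.g.:
#   print(ort_sess.get_inputs()[0].name, ort_sess.get_inputs()[0].shape)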

# ---- Utils from Trendyol ----
def keep_large_components(a: np.ndarray) -> np.ndarray:
    """Zero out small connected components of the alpha map, keeping only large foreground regions."""
    dilate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
    a_mask = (a > 25).astype(np.uint8) * 255
    analysis = cv2.connectedComponentsWithStats(a_mask, 4, cv2.CV_32S)
    (totalLabels, label_ids, values, _) = analysis
    h, w = a.shape[:2]
    # Scale the area threshold to the current resolution relative to the model input size.
    area_limit = 50000 * (h * w) / (INPUT_SIZE[1] * INPUT_SIZE[0])
    i_to_keep = []
    for i in range(1, totalLabels):
        area = values[i, cv2.CC_STAT_AREA]
        if area > area_limit:
            i_to_keep.append(i)
    if len(i_to_keep) > 0:
        final_mask = np.zeros_like(a, dtype=np.uint8)
        for i in i_to_keep:
            componentMask = (label_ids == i).astype("uint8") * 255
            final_mask = cv2.bitwise_or(final_mask, componentMask)
        final_mask = cv2.dilate(final_mask, dilate_kernel, iterations=2)
        a = cv2.bitwise_and(a, final_mask)
    a = a.reshape((a.shape[0], a.shape[1], 1))
    return a

def preprocess_input(im: np.ndarray) -> torch.Tensor:
    """Resize the image to the model input size and normalize it for inference."""
    if len(im.shape) < 3:
        im = im[:, :, np.newaxis]
    if im.shape[2] == 4:
        # Drop the alpha channel if present.
        im = im[:, :, :3]
    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    # F.interpolate replaces the deprecated F.upsample; the result is unchanged.
    im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), INPUT_SIZE, mode="bilinear").type(torch.uint8)
    image = torch.divide(im_tensor, 255.0)
    image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    return image

def postprocess_output(result: np.ndarray, orig_im_shape) -> np.ndarray:
    """Resize the raw prediction back to the original resolution and convert it to an 8-bit alpha map."""
    result = torch.squeeze(
        F.interpolate(torch.from_numpy(result).unsqueeze(0), orig_im_shape, mode="bilinear"), 0
    )
    # Min-max normalize the prediction to [0, 1] before scaling to 8 bits.
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi + 1e-8)
    a = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
    a = keep_large_components(a)
    return a

# ---- Core processing ----
def process(image: Image.Image) -> Image.Image:
    """Run background removal on a PIL image and composite the foreground onto a white background."""
    image_size = image.size
    np_img = np.array(image.convert("RGB"))
    # Preprocess
    img_tensor = preprocess_input(np_img)
    # Inference
    inputs = {ort_sess.get_inputs()[0].name: img_tensor.numpy()}
    result = ort_sess.run(None, inputs)[0][0]  # first output, batch 0 -> (1, H, W)
    # Postprocess to an alpha mask
    alpha = postprocess_output(result, (np_img.shape[0], np_img.shape[1]))  # (H, W, 1)
    # White background composite
    mask = Image.fromarray(alpha.squeeze(-1)).convert("L")
    binary_mask = mask.point(lambda p: 255 if p > 25 else 0)
    white_bg = Image.new("RGB", image_size, (255, 255, 255))
    result = Image.composite(image.convert("RGB"), white_bg, binary_mask)
    return result
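
# Example of calling process() directly, outside the Gradio UI (the file paths
# below are illustrative, not part of the app):
#   from PIL import Image
#   out = process(Image.open("input.jpg"))
#   out.save("output.png")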

# ---- Gradio handler ----
def handler(image=None) -> Union[str, None]:
    if image is not None:
        processed = process(image)
        filename = f"output_{uuid.uuid4().hex[:8]}.png"
        processed.save(filename)
        return filename
    return None

# ---- Gradio UI ----
demo = gr.Interface(
    fn=handler,
    inputs=gr.Image(label="Upload Image", type="pil"),
    outputs=gr.File(label="Output File"),
    title="Background Remover (Trendyol)",
    description="Upload an image to remove the background with the Trendyol ONNX model. The background is replaced with white.",
)

if __name__ == "__main__":
    demo.launch(show_error=True)