File size: 4,055 Bytes
3fcc660 3d3a8e1 bdf6b32 3fcc660 bdf6b32 3fcc660 3d3a8e1 bdf6b32 3fcc660 bdf6b32 3fcc660 bdf6b32 3fcc660 bdf6b32 3fcc660 bdf6b32 3fcc660 bdf6b32 3fcc660 bdf6b32 3fcc660 bdf6b32 3fcc660 bdf6b32 3fcc660 bdf6b32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import gradio as gr
import torch
import uuid
import base64
import numpy as np
import onnxruntime as ort
import cv2
from PIL import Image
from torchvision.transforms.functional import normalize
import torch.nn.functional as F
from typing import Union, List
from io import BytesIO
from huggingface_hub import hf_hub_download
# ---- Config ----
# Resolution the image is resized to before inference, as (height, width).
INPUT_SIZE = [1200, 1800]

# ---- Load ONNX model ----
# Resolve the model file through the Hugging Face Hub.
model_path = hf_hub_download(
    repo_id="Trendyol/background-removal",
    filename="model.onnx",
)

# Execution providers in order of preference.
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]

try:
    ort_sess = ort.InferenceSession(model_path, providers=providers)
except Exception:
    # Deliberate best-effort fallback: if the session cannot be created with
    # the preferred provider list, retry with the CPU provider only.
    ort_sess = ort.InferenceSession(
        model_path,
        providers=["CPUExecutionProvider"],
    )
# ---- Utils from Trendyol ----
def keep_large_components(a: np.ndarray) -> np.ndarray:
    """Remove small connected regions from an alpha matte.

    Pixels with a value above 25 are grouped into 4-connected components.
    Components whose area exceeds a limit (50000 pixels at the ``INPUT_SIZE``
    resolution, scaled proportionally to the size of ``a``) are kept. Their
    union is dilated twice with a 9x9 elliptical kernel and AND-ed with the
    input, so everything outside the kept regions becomes 0. If no component
    is large enough, the matte values are returned unfiltered.

    Args:
        a: Alpha matte of shape (H, W) or (H, W, 1). The caller in this file
            passes a uint8 array.

    Returns:
        The filtered matte, always with shape (H, W, 1).
    """
    dilate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
    foreground = (a > 25).astype(np.uint8) * 255

    _, label_ids, stats, _ = cv2.connectedComponentsWithStats(
        foreground, 4, cv2.CV_32S
    )

    h, w = a.shape[:2]
    area_limit = 50000 * (h * w) / (INPUT_SIZE[1] * INPUT_SIZE[0])

    # One boolean per label: True when the component is large enough to keep.
    keep = stats[:, cv2.CC_STAT_AREA] > area_limit
    keep[0] = False  # label 0 is the background, never a component to keep

    if keep.any():
        # Look every pixel's label up in `keep` in a single pass instead of
        # OR-ing one full-size mask per component.
        final_mask = keep[label_ids].astype(np.uint8) * 255
        # Grow the kept regions so soft edges just outside them survive.
        final_mask = cv2.dilate(final_mask, dilate_kernel, iterations=2)
        a = cv2.bitwise_and(a, final_mask)

    # Normalise the output shape on every path; callers squeeze the last axis.
    return a.reshape((a.shape[0], a.shape[1], 1))
def preprocess_input(im: np.ndarray) -> torch.Tensor:
    """Convert an image array into the normalized tensor fed to the model.

    The image is brought to three channels, resized bilinearly to
    ``INPUT_SIZE``, truncated back to uint8, scaled to [0, 1] and shifted by
    a mean of 0.5 per channel (std 1.0).

    Args:
        im: Image array of shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4).
            A fourth channel is dropped; a single channel is replicated to
            three so the 3-channel normalization applies.

    Returns:
        Float tensor of shape (1, 3, INPUT_SIZE[0], INPUT_SIZE[1]).
    """
    if im.ndim < 3:
        im = im[:, :, np.newaxis]
    if im.shape[2] == 4:
        # Drop the fourth (alpha) channel.
        im = im[:, :, :3]
    elif im.shape[2] == 1:
        # Grayscale: replicate to three channels, otherwise the 3-channel
        # normalize() below fails on a 1-channel tensor.
        im = np.repeat(im, 3, axis=2)

    chw = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    # F.interpolate replaces the deprecated F.upsample (same arguments).
    resized = F.interpolate(chw.unsqueeze(0), size=INPUT_SIZE, mode="bilinear")
    # Truncate the interpolated values back to integers before scaling,
    # exactly as the original pipeline does.
    quantized = resized.type(torch.uint8)
    image = torch.divide(quantized, 255.0)
    return normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
def postprocess_output(result: np.ndarray, orig_im_shape) -> np.ndarray:
    """Turn the raw model output into a uint8 alpha matte at the original size.

    The prediction is resized bilinearly to ``orig_im_shape``, min-max
    rescaled to [0, 1], converted to uint8 in [0, 255] and passed through
    ``keep_large_components``.

    Args:
        result: Model output as a 3-D array (channels first); a batch axis is
            added here for the resize.
        orig_im_shape: Target (height, width) of the matte.

    Returns:
        The value returned by ``keep_large_components`` for the uint8 matte
        of shape (height, width, channels).
    """
    # F.interpolate replaces the deprecated F.upsample (same arguments).
    resized = F.interpolate(
        torch.from_numpy(result).unsqueeze(0), size=orig_im_shape, mode="bilinear"
    )
    resized = torch.squeeze(resized, 0)

    hi = torch.max(resized)
    lo = torch.min(resized)
    # The epsilon avoids division by zero for a constant prediction.
    scaled = (resized - lo) / (hi - lo + 1e-8)

    # Channels-first -> channels-last, then truncate to uint8.
    a = (scaled * 255).permute(1, 2, 0).numpy().astype(np.uint8)
    return keep_large_components(a)
# ---- Core processing ----
def process(image: Image.Image) -> Image.Image:
    """Remove the background of ``image`` and composite it onto white.

    Args:
        image: Input PIL image; it is converted to RGB before use.

    Returns:
        RGB image of the same size in which every pixel whose predicted
        alpha is 25 or lower is replaced with white.
    """
    rgb = image.convert("RGB")
    pixels = np.array(rgb)
    height, width = pixels.shape[0], pixels.shape[1]

    # Preprocess and run the model on its first declared input.
    model_input = preprocess_input(pixels)
    input_name = ort_sess.get_inputs()[0].name
    # First output, first batch item.
    prediction = ort_sess.run(None, {input_name: model_input.numpy()})[0][0]

    # Alpha matte at the original resolution, with a trailing channel axis.
    alpha = postprocess_output(prediction, (height, width))

    # Threshold the matte into a hard mask and paste the image over white.
    mask = Image.fromarray(alpha.squeeze(-1)).convert("L")
    binary_mask = mask.point(lambda p: 255 if p > 25 else 0)
    white_bg = Image.new("RGB", image.size, (255, 255, 255))
    return Image.composite(rgb, white_bg, binary_mask)
# ---- Gradio handler ----
def handler(image=None) -> Union[str, None]:
    """Gradio callback: process an upload and return the saved file's path.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        The name of the PNG written to the current working directory
        (``output_<8 hex chars>.png``), or ``None`` if ``image`` is ``None``.
    """
    if image is None:
        return None

    processed = process(image)
    # A random suffix keeps outputs of separate requests from colliding.
    output_path = f"output_{uuid.uuid4().hex[:8]}.png"
    processed.save(output_path)
    return output_path
# ---- Gradio UI ----
# Input widget: hands the upload to the handler as a PIL image.
_image_input = gr.Image(label="Upload Image", type="pil")
# Output widget: offers the file path returned by the handler.
_file_output = gr.File(label="Output File")

demo = gr.Interface(
    fn=handler,
    inputs=_image_input,
    outputs=_file_output,
    title="Background Remover (Trendyol)",
    description="Upload an image to remove the background with the Trendyol ONNX model. Background is replaced with white.",
)

# Start the app only when the file is run as a script.
if __name__ == "__main__":
    demo.launch(show_error=True)
|