File size: 4,055 Bytes
3fcc660
 
 
3d3a8e1
bdf6b32
 
 
3fcc660
bdf6b32
 
3fcc660
3d3a8e1
bdf6b32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fcc660
 
bdf6b32
3fcc660
bdf6b32
 
3fcc660
bdf6b32
 
 
3fcc660
bdf6b32
 
 
 
 
 
3fcc660
bdf6b32
3fcc660
 
 
bdf6b32
 
 
 
 
 
 
3fcc660
 
bdf6b32
 
3fcc660
 
bdf6b32
 
 
 
3fcc660
 
 
bdf6b32
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
import torch
import uuid
import base64
import numpy as np
import onnxruntime as ort
import cv2
from PIL import Image
from torchvision.transforms.functional import normalize
import torch.nn.functional as F
from typing import Union, List
from io import BytesIO
from huggingface_hub import hf_hub_download

# ---- Config ----
# Spatial size that preprocess_input resamples every image to before
# inference. keep_large_components also uses the product of these two
# numbers as the reference area when scaling its component-size cutoff.
INPUT_SIZE = [1200, 1800]  # (H, W)

# ---- Load ONNX model ----
# Runs at import time: the model file is fetched through huggingface_hub and
# the returned value is handed straight to onnxruntime as the model path.
model_path = hf_hub_download(repo_id="Trendyol/background-removal", filename="model.onnx")
# Provider preference list: CUDA first, CPU second.
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
try:
    ort_sess = ort.InferenceSession(model_path, providers=providers)
except Exception:
    # Any failure while creating the session with the preferred list is
    # answered with a CPU-only retry; a failure of the retry propagates.
    ort_sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])


# ---- Utils from Trendyol ----
def keep_large_components(a: np.ndarray) -> np.ndarray:
    """Suppress small connected regions of an alpha matte.

    Pixels above 25 are treated as foreground and grouped into 4-connected
    components. Components whose area exceeds a cutoff (50000 pixels at
    ``INPUT_SIZE`` resolution, scaled by this image's area) are kept; the
    union of the kept components is dilated and used to mask ``a``. When no
    component passes the cutoff, ``a`` is left as it is.

    Args:
        a: uint8 alpha matte whose first two axes are height and width.

    Returns:
        The filtered matte, reshaped to ``(height, width, 1)``.
    """
    foreground = (a > 25).astype(np.uint8) * 255
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(foreground, 4, cv2.CV_32S)

    rows, cols = a.shape[:2]
    # The cutoff is expressed relative to the model's working resolution, so
    # scale it by the ratio between this image's area and INPUT_SIZE's area.
    min_area = 50000 * (rows * cols) / (INPUT_SIZE[1] * INPUT_SIZE[0])

    # Label 0 is skipped (OpenCV reserves it for the background).
    large_labels = [
        label for label in range(1, num_labels) if stats[label, cv2.CC_STAT_AREA] > min_area
    ]

    if large_labels:
        # One pass over the label image marks every pixel of a kept component.
        keep_mask = np.isin(labels, large_labels).astype(np.uint8) * 255
        # Grow the mask so soft edges just outside a kept component survive.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
        keep_mask = cv2.dilate(keep_mask, kernel, iterations=2)
        a = cv2.bitwise_and(a, keep_mask)

    return a.reshape((a.shape[0], a.shape[1], 1))


def preprocess_input(im: np.ndarray) -> torch.Tensor:
    """Turn an image array into the normalized tensor fed to the model.

    The image is brought to three channels, resized to ``INPUT_SIZE`` with
    bilinear interpolation, truncated to uint8, scaled to ``[0, 1]`` and
    shifted by a mean of 0.5 (std 1.0), giving values in ``[-0.5, 0.5]``.

    Args:
        im: Image array shaped ``(H, W)`` or ``(H, W, C)``. A 2-D or
            single-channel image has its channel replicated three times;
            the fourth channel of a 4-channel image is dropped.

    Returns:
        A float tensor of shape ``(1, 3, INPUT_SIZE[0], INPUT_SIZE[1])``
        for the channel layouts listed above.
    """
    if im.ndim < 3:
        im = im[:, :, np.newaxis]
    if im.shape[2] == 1:
        # normalize() below is given three means, so a single-channel image
        # would fail there; replicate the channel into three instead.
        im = np.repeat(im, 3, axis=2)
    elif im.shape[2] == 4:
        im = im[:, :, :3]

    chw = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    # F.interpolate is the supported replacement for the deprecated
    # F.upsample and takes the same size/mode arguments.
    resized = F.interpolate(chw.unsqueeze(0), size=INPUT_SIZE, mode="bilinear")
    # Truncate back to integer pixel values before scaling, as the original
    # pipeline does.
    quantized = resized.type(torch.uint8)
    image = torch.divide(quantized, 255.0)
    return normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])


def postprocess_output(result: np.ndarray, orig_im_shape) -> np.ndarray:
    """Convert a raw model output map into a uint8 alpha matte.

    The map is resized to the original image size with bilinear
    interpolation, min-max rescaled to ``[0, 1]``, converted to 0-255 uint8
    and then passed through ``keep_large_components``.

    Args:
        result: Model output array; a leading batch axis is added before
            resizing, so a ``(C, H, W)`` array is expected.
        orig_im_shape: Target ``(height, width)`` of the original image.

    Returns:
        The matte returned by ``keep_large_components``, shaped
        ``(height, width, 1)``.
    """
    # F.interpolate is the supported replacement for the deprecated
    # F.upsample and takes the same size/mode arguments.
    resized = F.interpolate(
        torch.from_numpy(result).unsqueeze(0), size=orig_im_shape, mode="bilinear"
    ).squeeze(0)

    lowest = torch.min(resized)
    highest = torch.max(resized)
    # The epsilon keeps the division finite when the map is constant.
    scaled = (resized - lowest) / (highest - lowest + 1e-8)

    # The tensor was built from a NumPy array, so it is already a CPU tensor
    # without autograd history and can be converted back directly.
    a = (scaled * 255).permute(1, 2, 0).numpy().astype(np.uint8)
    return keep_large_components(a)


# ---- Core processing ----
def process(image: Image.Image) -> Image.Image:
    """Replace the background of ``image`` with white.

    The image is run through the ONNX session to obtain an alpha matte, the
    matte is thresholded at 25 into a binary mask, and the RGB image is
    composited over a white canvas using that mask.

    Args:
        image: Input PIL image; it is converted to RGB before use.

    Returns:
        An RGB PIL image of the same size with the background set to white.
    """
    rgb = image.convert("RGB")
    pixels = np.array(rgb)
    height, width = pixels.shape[0], pixels.shape[1]

    # Preprocess and run the model on its first declared input.
    input_name = ort_sess.get_inputs()[0].name
    model_input = preprocess_input(pixels).numpy()
    # First output, then its first entry along the leading axis
    # (presumably the batch axis - confirm against the model's output spec).
    raw_map = ort_sess.run(None, {input_name: model_input})[0][0]

    # Matte at the original resolution, shaped (H, W, 1).
    alpha = postprocess_output(raw_map, (height, width))

    # Hard-threshold the matte so the composite has no semi-transparent fringe.
    matte = Image.fromarray(alpha.squeeze(-1)).convert("L")
    cutout_mask = matte.point(lambda p: 255 if p > 25 else 0)
    backdrop = Image.new("RGB", image.size, (255, 255, 255))
    return Image.composite(rgb, backdrop, cutout_mask)


# ---- Gradio handler ----
def handler(image=None) -> Union[str, None]:
    """Gradio callback: process an upload and return the saved PNG's path.

    Args:
        image: The uploaded image, or ``None`` when nothing was provided.

    Returns:
        The name of a PNG written to the current working directory
        (``output_<8 hex chars>.png``), or ``None`` if ``image`` is ``None``.
    """
    if image is None:
        return None

    result_image = process(image)
    # A random suffix keeps outputs of separate requests from colliding.
    out_path = f"output_{uuid.uuid4().hex[:8]}.png"
    result_image.save(out_path)
    return out_path


# ---- Gradio UI ----
# Single-function interface: one image upload in, one downloadable file out.
# The input component is declared with type="pil" and handler() passes what it
# receives to process(), which expects a PIL image; handler() returns the
# path of the PNG it saved (or None), which feeds the File output.
demo = gr.Interface(
    fn=handler,
    inputs=gr.Image(label="Upload Image", type="pil"),
    outputs=gr.File(label="Output File"),
    title="Background Remover (Trendyol)",
    description="Upload an image to remove the background with the Trendyol ONNX model. Background is replaced with white.",
)

# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    # NOTE(review): show_error=True presumably surfaces handler exceptions in
    # the browser UI - confirm against the installed Gradio version.
    demo.launch(show_error=True)