user-agent commited on
Commit
bdf6b32
·
verified ·
1 Parent(s): 0623b90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -184
app.py CHANGED
@@ -1,207 +1,121 @@
1
- # import gradio as gr
2
- # import torch
3
- # import uuid
4
- # from PIL import Image
5
- # from torchvision import transforms
6
- # from transformers import AutoModelForImageSegmentation
7
- # from typing import Union, List
8
- # from loadimg import load_img # Your helper to load from URL or file
9
-
10
- # torch.set_float32_matmul_precision("high")
11
-
12
- # # Load BiRefNet model
13
- # birefnet = AutoModelForImageSegmentation.from_pretrained(
14
- # "ZhengPeng7/BiRefNet", trust_remote_code=True
15
- # )
16
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
- # birefnet.to(device)
18
-
19
- # # Image transformation
20
- # transform_image = transforms.Compose([
21
- # transforms.Resize((1024, 1024)),
22
- # transforms.ToTensor(),
23
- # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
24
- # ])
25
-
26
- # def process(image: Image.Image) -> Image.Image:
27
- # image_size = image.size
28
- # input_tensor = transform_image(image).unsqueeze(0).to(device)
29
-
30
- # with torch.no_grad():
31
- # preds = birefnet(input_tensor)[-1].sigmoid().cpu()
32
-
33
- # pred = preds[0].squeeze()
34
- # mask = transforms.ToPILImage()(pred).resize(image_size).convert("L")
35
- # binary_mask = mask.point(lambda p: 255 if p > 127 else 0)
36
-
37
- # white_bg = Image.new("RGB", image_size, (255, 255, 255))
38
- # result = Image.composite(image, white_bg, binary_mask)
39
- # return result
40
-
41
- # def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
42
- # results = []
43
-
44
- # try:
45
- # # Single image upload
46
- # if image is not None:
47
- # image = image.convert("RGB")
48
- # processed = process(image)
49
- # filename = f"output_{uuid.uuid4().hex[:8]}.png"
50
- # processed.save(filename)
51
- # return filename
52
-
53
- # # Single image from URL
54
- # if image_url:
55
- # im = load_img(image_url, output_type="pil").convert("RGB")
56
- # processed = process(im)
57
- # filename = f"output_{uuid.uuid4().hex[:8]}.png"
58
- # processed.save(filename)
59
- # return filename
60
-
61
- # # Batch of URLs
62
- # if batch_urls:
63
- # urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
64
- # for url in urls:
65
- # try:
66
- # im = load_img(url, output_type="pil").convert("RGB")
67
- # processed = process(im)
68
- # filename = f"output_{uuid.uuid4().hex[:8]}.png"
69
- # processed.save(filename)
70
- # results.append(filename)
71
- # except Exception as e:
72
- # print(f"Error with {url}: {e}")
73
- # return results if results else None
74
-
75
- # except Exception as e:
76
- # print("General error:", e)
77
-
78
- # return None
79
-
80
- # # Interface
81
- # demo = gr.Interface(
82
- # fn=handler,
83
- # inputs=[
84
- # gr.Image(label="Upload Image", type="pil"),
85
- # gr.Textbox(label="Paste Image URL"),
86
- # gr.Textbox(label="Comma-separated Image URLs (Batch)"),
87
- # ],
88
- # outputs=gr.File(label="Output File(s)", file_count="multiple"),
89
- # title="Background Remover (White Fill)",
90
- # description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
91
- # )
92
-
93
- # if __name__ == "__main__":
94
- # demo.launch(show_error=True, mcp_server=True)
95
-
96
-
97
-
98
  import gradio as gr
99
  import torch
100
  import uuid
101
  import base64
 
 
 
102
  from PIL import Image
103
- from torchvision import transforms
104
- from transformers import AutoModelForImageSegmentation
105
  from typing import Union, List
106
- from loadimg import load_img # Your helper to load from URL or file
107
  from io import BytesIO
108
-
109
- torch.set_float32_matmul_precision("high")
110
-
111
- # Load BiRefNet model
112
- birefnet = AutoModelForImageSegmentation.from_pretrained(
113
- "ZhengPeng7/BiRefNet", trust_remote_code=True
114
- )
115
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116
- birefnet.to(device)
117
-
118
- # Image transformation
119
- transform_image = transforms.Compose([
120
- transforms.Resize((1024, 1024)),
121
- transforms.ToTensor(),
122
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
123
- ])
124
-
125
- def load_image_from_data_url(data_url: str) -> Image.Image:
126
- """Load image from base64 data URL"""
127
- if data_url.startswith("data:image/"):
128
- # Extract base64 data after the comma
129
- if "," in data_url:
130
- header, encoded = data_url.split(",", 1)
131
- image_data = base64.b64decode(encoded)
132
- return Image.open(BytesIO(image_data))
133
- else:
134
- raise ValueError(f"Invalid data URL format: {data_url[:50]}...")
135
- else:
136
- # Regular URL, use existing load_img function
137
- return load_img(data_url, output_type="pil")
138
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  def process(image: Image.Image) -> Image.Image:
140
  image_size = image.size
141
- input_tensor = transform_image(image).unsqueeze(0).to(device)
142
 
143
- with torch.no_grad():
144
- preds = birefnet(input_tensor)[-1].sigmoid().cpu()
145
 
146
- pred = preds[0].squeeze()
147
- mask = transforms.ToPILImage()(pred).resize(image_size).convert("L")
148
- binary_mask = mask.point(lambda p: 255 if p > 127 else 0)
149
 
 
 
 
 
 
 
150
  white_bg = Image.new("RGB", image_size, (255, 255, 255))
151
- result = Image.composite(image, white_bg, binary_mask)
152
  return result
153
 
154
- def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
155
- results = []
156
-
157
- try:
158
- # Single image upload
159
- if image is not None:
160
- image = image.convert("RGB")
161
- processed = process(image)
162
- filename = f"output_{uuid.uuid4().hex[:8]}.png"
163
- processed.save(filename)
164
- return filename
165
-
166
- # Single image from URL (supports both regular URLs and base64 data URLs)
167
- if image_url:
168
- im = load_image_from_data_url(image_url).convert("RGB")
169
- processed = process(im)
170
- filename = f"output_{uuid.uuid4().hex[:8]}.png"
171
- processed.save(filename)
172
- return filename
173
-
174
- # Batch of URLs (supports both regular URLs and base64 data URLs)
175
- if batch_urls:
176
- urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
177
- for url in urls:
178
- try:
179
- im = load_image_from_data_url(url).convert("RGB")
180
- processed = process(im)
181
- filename = f"output_{uuid.uuid4().hex[:8]}.png"
182
- processed.save(filename)
183
- results.append(filename)
184
- except Exception as e:
185
- print(f"Error with {url}: {e}")
186
- return results if results else None
187
-
188
- except Exception as e:
189
- print("General error:", e)
190
 
 
 
 
 
 
 
 
191
  return None
192
 
193
- # Interface
 
194
  demo = gr.Interface(
195
  fn=handler,
196
- inputs=[
197
- gr.Image(label="Upload Image", type="pil"),
198
- gr.Textbox(label="Paste Image URL"),
199
- gr.Textbox(label="Comma-separated Image URLs (Batch)"),
200
- ],
201
- outputs=gr.File(label="Output File(s)", file_count="multiple"),
202
- title="Background Remover (White Fill)",
203
- description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
204
  )
205
 
206
  if __name__ == "__main__":
207
- demo.launch(show_error=True, mcp_server=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  import uuid
4
  import base64
5
+ import numpy as np
6
+ import onnxruntime as ort
7
+ import cv2
8
  from PIL import Image
9
+ from torchvision.transforms.functional import normalize
10
+ import torch.nn.functional as F
11
  from typing import Union, List
 
12
  from io import BytesIO
13
+ from huggingface_hub import hf_hub_download
14
+
15
+ # ---- Config ----
16
+ INPUT_SIZE = [1200, 1800] # (H, W)
17
+
18
+ # ---- Load ONNX model ----
19
+ model_path = hf_hub_download(repo_id="Trendyol/background-removal", filename="model.onnx")
20
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
21
+ try:
22
+ ort_sess = ort.InferenceSession(model_path, providers=providers)
23
+ except Exception:
24
+ ort_sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
25
+
26
+
27
+ # ---- Utils from Trendyol ----
28
+ def keep_large_components(a: np.ndarray) -> np.ndarray:
29
+ dilate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
30
+ a_mask = (a > 25).astype(np.uint8) * 255
31
+ analysis = cv2.connectedComponentsWithStats(a_mask, 4, cv2.CV_32S)
32
+ (totalLabels, label_ids, values, _) = analysis
33
+
34
+ h, w = a.shape[:2]
35
+ area_limit = 50000 * (h * w) / (INPUT_SIZE[1] * INPUT_SIZE[0])
36
+ i_to_keep = []
37
+ for i in range(1, totalLabels):
38
+ area = values[i, cv2.CC_STAT_AREA]
39
+ if area > area_limit:
40
+ i_to_keep.append(i)
41
+
42
+ if len(i_to_keep) > 0:
43
+ final_mask = np.zeros_like(a, dtype=np.uint8)
44
+ for i in i_to_keep:
45
+ componentMask = (label_ids == i).astype("uint8") * 255
46
+ final_mask = cv2.bitwise_or(final_mask, componentMask)
47
+ final_mask = cv2.dilate(final_mask, dilate_kernel, iterations=2)
48
+ a = cv2.bitwise_and(a, final_mask)
49
+
50
+ a = a.reshape((a.shape[0], a.shape[1], 1))
51
+ return a
52
+
53
+
54
+ def preprocess_input(im: np.ndarray) -> torch.Tensor:
55
+ if len(im.shape) < 3:
56
+ im = im[:, :, np.newaxis]
57
+ if im.shape[2] == 4:
58
+ im = im[:, :, :3]
59
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
60
+ im_tensor = F.upsample(torch.unsqueeze(im_tensor, 0), INPUT_SIZE, mode="bilinear").type(torch.uint8)
61
+ image = torch.divide(im_tensor, 255.0)
62
+ image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
63
+ return image
64
+
65
+
66
+ def postprocess_output(result: np.ndarray, orig_im_shape) -> np.ndarray:
67
+ result = torch.squeeze(
68
+ F.upsample(torch.from_numpy(result).unsqueeze(0), (orig_im_shape), mode="bilinear"), 0
69
+ )
70
+ ma = torch.max(result)
71
+ mi = torch.min(result)
72
+ result = (result - mi) / (ma - mi + 1e-8)
73
+ a = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
74
+ a = keep_large_components(a)
75
+ return a
76
+
77
+
78
+ # ---- Core processing ----
79
  def process(image: Image.Image) -> Image.Image:
80
  image_size = image.size
81
+ np_img = np.array(image.convert("RGB"))
82
 
83
+ # Preprocess
84
+ img_tensor = preprocess_input(np_img)
85
 
86
+ # Inference
87
+ inputs = {ort_sess.get_inputs()[0].name: img_tensor.numpy()}
88
+ result = ort_sess.run(None, inputs)[0][0] # (1,1,H,W)
89
 
90
+ # Postprocess to mask
91
+ alpha = postprocess_output(result, (np_img.shape[0], np_img.shape[1])) # (H,W,1)
92
+
93
+ # White background composite
94
+ mask = Image.fromarray(alpha.squeeze(-1)).convert("L")
95
+ binary_mask = mask.point(lambda p: 255 if p > 25 else 0)
96
  white_bg = Image.new("RGB", image_size, (255, 255, 255))
97
+ result = Image.composite(image.convert("RGB"), white_bg, binary_mask)
98
  return result
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
+ # ---- Gradio handler ----
102
+ def handler(image=None) -> Union[str, None]:
103
+ if image is not None:
104
+ processed = process(image)
105
+ filename = f"output_{uuid.uuid4().hex[:8]}.png"
106
+ processed.save(filename)
107
+ return filename
108
  return None
109
 
110
+
111
+ # ---- Gradio UI ----
112
  demo = gr.Interface(
113
  fn=handler,
114
+ inputs=gr.Image(label="Upload Image", type="pil"),
115
+ outputs=gr.File(label="Output File"),
116
+ title="Background Remover (Trendyol)",
117
+ description="Upload an image to remove the background with the Trendyol ONNX model. Background is replaced with white.",
 
 
 
 
118
  )
119
 
120
  if __name__ == "__main__":
121
+ demo.launch(show_error=True)