NegiTurkey commited on
Commit
abf8cc6
·
verified ·
1 Parent(s): 174d21b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -50
app.py CHANGED
@@ -6,15 +6,23 @@ import torch
6
  from torchvision import transforms
7
  import os
8
  import zipfile
9
- import numpy as np
10
  from PIL import Image
11
 
 
 
 
 
12
  torch.set_float32_matmul_precision(["high", "highest"][0])
13
 
14
- birefnet = AutoModelForImageSegmentation.from_pretrained(
15
- "ZhengPeng7/BiRefNet", trust_remote_code=True
16
- )
17
- birefnet.to("cpu")
 
 
 
 
 
18
  transform_image = transforms.Compose(
19
  [
20
  transforms.Resize((1024, 1024)),
@@ -23,30 +31,14 @@ transform_image = transforms.Compose(
23
  ]
24
  )
25
 
26
- def fn(image):
27
- im = load_img(image, output_type="pil")
28
- im = im.convert("RGB")
29
- image_size = im.size
30
- origin = im.copy()
31
- input_images = transform_image(im).unsqueeze(0).to("cpu")
32
 
33
- with torch.no_grad():
34
- preds = birefnet(input_images)[-1].sigmoid().cpu()
35
- pred = preds[0].squeeze()
36
- pred_pil = transforms.ToPILImage()(pred)
37
- mask = pred_pil.resize(image_size)
38
-
39
- im.putalpha(mask)
40
- output_file_path = os.path.join("output_images", "output_image_single.png")
41
- im.save(output_file_path)
42
-
43
- output_path = os.path.join("output_images", "output_image_processed.png")
44
- im.save(output_path, "PNG")
45
-
46
- return [im, mask], output_path
47
 
48
- def fn_url(url):
49
- im = load_img(url, output_type="pil")
50
  im = im.convert("RGB")
51
  image_size = im.size
52
  origin = im.copy()
@@ -58,19 +50,58 @@ def fn_url(url):
58
  pred_pil = transforms.ToPILImage()(pred)
59
  mask = pred_pil.resize(image_size)
60
 
61
- im.putalpha(mask)
62
- output_file_path = os.path.join("output_images", "output_image_url.png")
63
- im.save(output_file_path)
64
-
65
- output_path = os.path.join("output_images", "output_image_url_processed.png")
66
- im.save(output_path, "PNG")
 
 
 
 
 
 
 
67
 
68
- return [im, mask], output_path
 
 
 
69
 
70
- def batch_fn(images):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  output_paths = []
 
 
 
72
  for idx, image_path in enumerate(images):
73
  im = load_img(image_path, output_type="pil")
 
 
 
74
  im = im.convert("RGB")
75
  image_size = im.size
76
  input_images = transform_image(im).unsqueeze(0).to("cpu")
@@ -81,44 +112,52 @@ def batch_fn(images):
81
  pred_pil = transforms.ToPILImage()(pred)
82
  mask = pred_pil.resize(image_size)
83
 
84
- im.putalpha(mask)
85
-
86
- output_file_path = os.path.join("output_images", f"output_image_batch_{idx + 1}.png")
87
  im.save(output_file_path)
88
  output_paths.append(output_file_path)
89
 
90
- zip_file_path = os.path.join("output_images", "processed_images.zip")
91
  with zipfile.ZipFile(zip_file_path, 'w') as zipf:
92
  for file in output_paths:
93
  zipf.write(file, os.path.basename(file))
94
 
95
- return zip_file_path
96
 
 
 
97
  batch_image = gr.File(label="Upload multiple images", type="filepath", file_count="multiple")
98
 
99
  slider1 = ImageSlider(label="Processed Image", type="pil")
100
  slider2 = ImageSlider(label="Processed Image from URL", type="pil")
101
- image = gr.Image(label="Upload an image")
102
- text = gr.Textbox(label="Paste an image URL")
103
-
104
- chameleon = load_img("chameleon.jpg", output_type="pil")
105
- url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
106
 
107
  tab1 = gr.Interface(
108
- fn, inputs=image, outputs=[slider1, gr.File(label="PNG Output")], examples=[chameleon], api_name="image"
 
 
 
 
109
  )
110
 
111
- tab2 = gr.Interface(fn_url, inputs=text, outputs=[slider2, gr.File(label="PNG Output")], examples=[url], api_name="text")
 
 
 
 
 
 
112
 
113
  tab3 = gr.Interface(
114
- batch_fn,
115
  inputs=batch_image,
116
- outputs=gr.File(label="Download Processed Files"),
117
  api_name="batch"
118
  )
119
 
120
  demo = gr.TabbedInterface(
121
- [tab1, tab2, tab3], ["image", "text", "batch"], title="Multi Birefnet for Background Removal"
 
 
122
  )
123
 
124
  if __name__ == "__main__":
 
6
  from torchvision import transforms
7
  import os
8
  import zipfile
 
9
  from PIL import Image
10
 
11
+ output_folder = 'output_images'
12
+ if not os.path.exists(output_folder):
13
+ os.makedirs(output_folder)
14
+
15
  torch.set_float32_matmul_precision(["high", "highest"][0])
16
 
17
+ try:
18
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
19
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
20
+ )
21
+ birefnet.to("cpu")
22
+ except Exception as e:
23
+ print(f"Error loading model: {e}")
24
+ raise
25
+
26
  transform_image = transforms.Compose(
27
  [
28
  transforms.Resize((1024, 1024)),
 
31
  ]
32
  )
33
 
34
+ def process_single_image(image, output_type="mask"):
35
+ if image is None:
36
+ return [None, None], None
 
 
 
37
 
38
+ im = load_img(image, output_type="pil")
39
+ if im is None:
40
+ return [None, None], None
 
 
 
 
 
 
 
 
 
 
 
41
 
 
 
42
  im = im.convert("RGB")
43
  image_size = im.size
44
  origin = im.copy()
 
50
  pred_pil = transforms.ToPILImage()(pred)
51
  mask = pred_pil.resize(image_size)
52
 
53
+ processed_im = im.copy()
54
+ processed_im.putalpha(mask)
55
+ output_file_path = os.path.join(output_folder, "output_image_i2i.png")
56
+ processed_im.save(output_file_path)
57
+
58
+ if output_type == "origin":
59
+ return [processed_im, origin], output_file_path
60
+ else:
61
+ return [processed_im, mask], output_file_path
62
+
63
+ def process_image_from_url(url, output_type="mask"):
64
+ if url is None or url.strip() == "":
65
+ return [None, None], None
66
 
67
+ try:
68
+ im = load_img(url, output_type="pil")
69
+ if im is None:
70
+ return [None, None], None
71
 
72
+ im = im.convert("RGB")
73
+ image_size = im.size
74
+ origin = im.copy()
75
+ input_images = transform_image(im).unsqueeze(0).to("cpu")
76
+
77
+ with torch.no_grad():
78
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
79
+ pred = preds[0].squeeze()
80
+ pred_pil = transforms.ToPILImage()(pred)
81
+ mask = pred_pil.resize(image_size)
82
+
83
+ processed_im = im.copy()
84
+ processed_im.putalpha(mask)
85
+ output_file_path = os.path.join(output_folder, "output_image_url.png")
86
+ processed_im.save(output_file_path)
87
+
88
+ if output_type == "origin":
89
+ return [processed_im, origin], output_file_path
90
+ else:
91
+ return [processed_im, mask], output_file_path
92
+ except Exception as e:
93
+ return [None, None], str(e)
94
+
95
+ def process_batch_images(images):
96
  output_paths = []
97
+ if not images:
98
+ return [], None
99
+
100
  for idx, image_path in enumerate(images):
101
  im = load_img(image_path, output_type="pil")
102
+ if im is None:
103
+ continue
104
+
105
  im = im.convert("RGB")
106
  image_size = im.size
107
  input_images = transform_image(im).unsqueeze(0).to("cpu")
 
112
  pred_pil = transforms.ToPILImage()(pred)
113
  mask = pred_pil.resize(image_size)
114
 
115
+ im.putalpha(mask)
116
+ output_file_path = os.path.join(output_folder, f"output_image_batch_{idx + 1}.png")
 
117
  im.save(output_file_path)
118
  output_paths.append(output_file_path)
119
 
120
+ zip_file_path = os.path.join(output_folder, "processed_images.zip")
121
  with zipfile.ZipFile(zip_file_path, 'w') as zipf:
122
  for file in output_paths:
123
  zipf.write(file, os.path.basename(file))
124
 
125
+ return output_paths, zip_file_path
126
 
127
+ image = gr.Image(label="Upload an image")
128
+ text = gr.Textbox(label="Paste an image URL")
129
  batch_image = gr.File(label="Upload multiple images", type="filepath", file_count="multiple")
130
 
131
  slider1 = ImageSlider(label="Processed Image", type="pil")
132
  slider2 = ImageSlider(label="Processed Image from URL", type="pil")
 
 
 
 
 
133
 
134
  tab1 = gr.Interface(
135
+ fn=process_single_image,
136
+ inputs=[image, gr.Radio(choices=["mask", "origin"], value="mask", label="Select Output Type")],
137
+ outputs=[slider1, gr.File(label="PNG Output")],
138
+ examples=[["chameleon.jpg"]],
139
+ api_name="image"
140
  )
141
 
142
+ tab2 = gr.Interface(
143
+ fn=process_image_from_url,
144
+ inputs=[text, gr.Radio(choices=["mask", "origin"], value="mask", label="Select Output Type")],
145
+ outputs=[slider2, gr.File(label="PNG Output")],
146
+ examples=[["https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"]],
147
+ api_name="text"
148
+ )
149
 
150
  tab3 = gr.Interface(
151
+ fn=process_batch_images,
152
  inputs=batch_image,
153
+ outputs=[gr.Gallery(label="Processed Images"), gr.File(label="Download Processed Files")],
154
  api_name="batch"
155
  )
156
 
157
  demo = gr.TabbedInterface(
158
+ [tab1, tab2, tab3],
159
+ ["image", "text", "batch"],
160
+ title="Multi Birefnet for Background Removal"
161
  )
162
 
163
  if __name__ == "__main__":