ZhengPeng7 committed on
Commit
741bf59
1 Parent(s): 9e40d13

Add a batch-inference tab with a save function.

Browse files
Files changed (1) hide show
  1. app.py +100 -41
app.py CHANGED
@@ -57,15 +57,37 @@ birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7',
57
  birefnet.to(device)
58
  birefnet.eval()
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  @spaces.GPU
62
- def predict(image, resolution, weights_file):
63
- assert (image is not None), 'AssertionError: image cannot be None.'
64
-
65
- if isinstance(image, str):
66
- response = requests.get(image)
67
- image_data = BytesIO(response.content)
68
- image = np.array(Image.open(image_data))
69
  global birefnet
70
  # Load BiRefNet with chosen weights
71
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
@@ -74,33 +96,63 @@ def predict(image, resolution, weights_file):
74
  birefnet.to(device)
75
  birefnet.eval()
76
 
77
- resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
78
- resolution = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- image_shape = image.shape[:2]
81
- image_pil = array_to_pil_image(image, tuple(resolution))
82
-
83
- # Preprocess the image
84
- image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
85
- image_proc = image_preprocessor.proc(image_pil)
86
- image_proc = image_proc.unsqueeze(0)
87
-
88
- # Perform the prediction
89
- with torch.no_grad():
90
- scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()
91
-
92
- if device == 'cuda':
93
- scaled_pred_tensor = scaled_pred_tensor.cpu()
94
 
95
- # Resize the prediction to match the original image shape
96
- pred = torch.nn.functional.interpolate(scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
97
-
98
- # Apply the prediction mask to the original image
99
- image_pil = image_pil.resize(pred.shape[::-1])
100
- pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
101
- image_pred = (pred * np.array(image_pil)).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- torch.cuda.empty_cache()
 
 
 
 
104
 
105
  return image, image_pred
106
 
@@ -118,6 +170,11 @@ examples_url = [
118
  for idx_example_url, example_url in enumerate(examples_url):
119
  examples_url[idx_example_url].append('1024x1024')
120
 
 
 
 
 
 
121
  tab_image = gr.Interface(
122
  fn=predict,
123
  inputs=[
@@ -128,10 +185,7 @@ tab_image = gr.Interface(
128
  outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
129
  examples=examples,
130
  api_name="image",
131
- description=('Upload a picture, our model will extract a highly accurate segmentation of the subject in it.\n)'
132
- ' The resolution used in our training was `1024x1024`, thus the suggested resolution to obtain good results!\n'
133
- ' Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n'
134
- ' We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access.'),
135
  )
136
 
137
  tab_text = gr.Interface(
@@ -144,15 +198,20 @@ tab_text = gr.Interface(
144
  outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
145
  examples=examples_url,
146
  api_name="text",
147
- description=('Upload a URL, our model will extract a highly accurate segmentation of the subject in it.\n)'
148
- ' The resolution used in our training was `1024x1024`, thus the suggested resolution to obtain good results!\n'
149
- ' Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n'
150
- ' We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access.'),
 
 
 
 
 
151
  )
152
 
153
  demo = gr.TabbedInterface(
154
- [tab_image, tab_text],
155
- ["image", "text"],
156
  title="BiRefNet demo for subject extraction (general / salient / camouflaged / portrait).",
157
  )
158
 
 
57
  birefnet.to(device)
58
  birefnet.eval()
59
 
60
+ # for idx, image_path in enumerate(images):
61
+ # im = load_img(image_path, output_type="pil")
62
+ # if im is None:
63
+ # continue
64
+
65
+ # im = im.convert("RGB")
66
+ # image_size = im.size
67
+ # input_images = transform_image(im).unsqueeze(0).to("cpu")
68
+
69
+ # with torch.no_grad():
70
+ # preds = birefnet(input_images)[-1].sigmoid().cpu()
71
+ # pred = preds[0].squeeze()
72
+ # pred_pil = transforms.ToPILImage()(pred)
73
+ # mask = pred_pil.resize(image_size)
74
+
75
+ # im.putalpha(mask)
76
+ # output_file_path = os.path.join(save_dir, f"output_image_batch_{idx + 1}.png")
77
+ # im.save(output_file_path)
78
+ # output_paths.append(output_file_path)
79
+
80
+ # zip_file_path = os.path.join(save_dir, "processed_images.zip")
81
+ # with zipfile.ZipFile(zip_file_path, 'w') as zipf:
82
+ # for file in output_paths:
83
+ # zipf.write(file, os.path.basename(file))
84
+
85
+ # return output_paths, zip_file_path
86
 
87
  @spaces.GPU
88
+ def predict(images, resolution, weights_file):
89
+ assert (images is not None), 'AssertionError: images cannot be None.'
90
+
 
 
 
 
91
  global birefnet
92
  # Load BiRefNet with chosen weights
93
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
 
96
  birefnet.to(device)
97
  birefnet.eval()
98
 
99
+ try:
100
+ resolution = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]
101
+ except:
102
+ resolution = [1024, 1024]
103
+ print('Invalid resolution input. Automatically changed to 1024x1024.')
104
+
105
+ if isinstance(images, list):
106
+ save_dir = 'preds-BiRefNet'
107
+ if not os.path.exists(save_dir):
108
+ os.makedirs(save_dir)
109
+ else:
110
+ # For tab_batch
111
+ save_paths = []
112
+ images = [images]
113
+
114
+ for idx_image, image_src in enumerate(images):
115
+ if isinstance(image_src, str):
116
+ response = requests.get(image_src)
117
+ image_data = BytesIO(response.content)
118
+ image = np.array(Image.open(image_data))
119
+ else:
120
+ image = image_src
121
 
122
+ image_shape = image.shape[:2]
123
+ image_pil = array_to_pil_image(image, tuple(resolution))
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
+ # Preprocess the image
126
+ image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
127
+ image_proc = image_preprocessor.proc(image_pil)
128
+ image_proc = image_proc.unsqueeze(0)
129
+
130
+ # Perform the prediction
131
+ with torch.no_grad():
132
+ scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()
133
+
134
+ if device == 'cuda':
135
+ scaled_pred_tensor = scaled_pred_tensor.cpu()
136
+
137
+ # Resize the prediction to match the original image shape
138
+ pred = torch.nn.functional.interpolate(scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
139
+
140
+ # Apply the prediction mask to the original image
141
+ image_pil = image_pil.resize(pred.shape[::-1])
142
+ pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
143
+ image_pred = (pred * np.array(image_pil)).astype(np.uint8)
144
+
145
+ torch.cuda.empty_cache()
146
+
147
+ save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
148
+ cv2.imwrite(save_file_path)
149
+ save_paths.append(save_file_path)
150
 
151
+ if len(images) > 1:
152
+ zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
153
+ with zipfile.ZipFile(zip_file_path, 'w') as zipf:
154
+ for file in save_paths:
155
+ zipf.write(file, os.path.basename(file))
156
 
157
  return image, image_pred
158
 
 
170
  for idx_example_url, example_url in enumerate(examples_url):
171
  examples_url[idx_example_url].append('1024x1024')
172
 
173
# Shared description text passed as `description=` to the demo tabs; individual
# tabs append their own credit line to it.
# The first fragment previously ended in "\n)", which put a stray ")" at the
# start of the second line of the displayed text.
descriptions = (
    'Upload a picture, our model will extract a highly accurate segmentation of the subject in it.\n'
    ' The resolution used in our training was `1024x1024`, thus the suggested resolution to obtain good results!\n'
    ' Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n'
    ' We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access.'
)
177
+
178
  tab_image = gr.Interface(
179
  fn=predict,
180
  inputs=[
 
185
  outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
186
  examples=examples,
187
  api_name="image",
188
+ description=descriptions,
 
 
 
189
  )
190
 
191
  tab_text = gr.Interface(
 
198
  outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
199
  examples=examples_url,
200
  api_name="text",
201
+ description=descriptions+'\nTab-URL is partially modified from https://huggingface.co/spaces/not-lain/background-removal, thanks to this great work!',
202
+ )
203
+
204
# Batch tab: upload several image files at once and run them through `predict`.
tab_batch = gr.Interface(
    fn=predict,
    # `predict` is declared as predict(images, resolution, weights_file), so the
    # tab supplies one input component per parameter. With only the gr.File
    # input, `resolution` and `weights_file` were left unfilled.
    inputs=[
        gr.File(label="Upload multiple images", type="filepath", file_count="multiple"),
        gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
        # presumably `usage_to_weights_file` is a dict keyed by usage name
        # ('General' is looked up in `predict`) — verify against its definition.
        gr.Radio(list(usage_to_weights_file), value='General', label="Weights"),
    ],
    # NOTE(review): these outputs expect (gallery items, downloadable file), but
    # the visible `predict` ends with `return image, image_pred` — confirm that
    # the batch path returns the saved image paths and the zip path instead.
    outputs=[gr.Gallery(label="BiRefNet's predictions"), gr.File(label="Download masked images.")],
    api_name="batch",
    description=descriptions+'\nTab-batch is partially modified from https://huggingface.co/spaces/NegiTurkey/Multi_Birefnetfor_Background_Removal, thanks to this great work!',
)

# Assemble the three tabs into a single demo; tab names are listed in the same
# order as the interfaces.
demo = gr.TabbedInterface(
    [tab_image, tab_text, tab_batch],
    ['image', 'text', 'batch'],
    title="BiRefNet demo for subject extraction (general / salient / camouflaged / portrait).",
)
217