dominic1021 committed
Commit 21c88fa · verified · 1 Parent(s): 568661b

Upload 13 files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/My_Love_2.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/My_Love.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/My_MiSheng.jpg filter=lfs diff=lfs merge=lfs -text
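
These three new rules route the example JPEGs through Git LFS. As a minimal sketch, such files can be fetched through the Hub API, which resolves LFS pointers to the actual files (the repo id below is a placeholder, not this Space's real id):

# Sketch: download one LFS-tracked example image; hf_hub_download resolves
# the LFS pointer transparently. The repo_id is a placeholder.
from huggingface_hub import hf_hub_download

local_path = hf_hub_download(
    repo_id="user/BiRefNet-demo",      # placeholder Space id
    repo_type="space",
    filename="examples/My_Love.jpg",
)
print(local_path)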
README.md ADDED
@@ -0,0 +1,19 @@
+ ---
+ title: BiRefNet Demo
+ emoji: 🐠
+ colorFrom: green
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 4.38.1
+ app_file: app.py
+ pinned: false
+ license: mit
+ models:
+ - ZhengPeng7/BiRefNet
+ - ZhengPeng7/BiRefNet-portrait
+ preload_from_hub:
+ - ZhengPeng7/BiRefNet
+ - ZhengPeng7/BiRefNet-portrait
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
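
In this config, `preload_from_hub` downloads the listed model repos while the Space image is built, so the first request does not wait on the weights. A small sketch, assuming `huggingface_hub` is installed, to confirm they landed in the local Hub cache:

# Sketch: verify the preloaded repos are present in the local Hub cache.
from huggingface_hub import scan_cache_dir

cached = {repo.repo_id for repo in scan_cache_dir().repos}
for repo_id in ("ZhengPeng7/BiRefNet", "ZhengPeng7/BiRefNet-portrait"):
    print(repo_id, "cached" if repo_id in cached else "missing")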
app.py ADDED
@@ -0,0 +1,228 @@
+ import os
+ import cv2
+ import numpy as np
+ import torch
+ import gradio as gr
+ import spaces
+
+ from glob import glob
+ from typing import Tuple
+
+ from PIL import Image
+ from gradio_imageslider import ImageSlider
+ from transformers import AutoModelForImageSegmentation
+ from torchvision import transforms
+
+ import requests
+ from io import BytesIO
+ import zipfile
+
+
+ torch.set_float32_matmul_precision('high')
+ torch.jit.script = lambda f: f  # make JIT scripting a no-op inside the remote model code
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ ### image_proc.py
+ def refine_foreground(image, mask, r=90):
+     if mask.size != image.size:
+         mask = mask.resize(image.size)
+     image = np.array(image) / 255.0
+     mask = np.array(mask) / 255.0
+     estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
+     image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
+     return image_masked
+
+
+ def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
+     # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
+     alpha = alpha[:, :, None]
+     F, blur_B = FB_blur_fusion_foreground_estimator(
+         image, image, image, alpha, r)
+     return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
+
+
+ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
+     if isinstance(image, Image.Image):
+         image = np.array(image) / 255.0
+     blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
+
+     blurred_FA = cv2.blur(F * alpha, (r, r))
+     blurred_F = blurred_FA / (blurred_alpha + 1e-5)
+
+     blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
+     blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
+     F = blurred_F + alpha * \
+         (image - alpha * blurred_F - (1 - alpha) * blurred_B)
+     F = np.clip(F, 0, 1)
+     return F, blurred_B
+
+
+ class ImagePreprocessor():
+     def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
+         self.transform_image = transforms.Compose([
+             transforms.Resize(resolution),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ])
+
+     def proc(self, image: Image.Image) -> torch.Tensor:
+         image = self.transform_image(image)
+         return image
+
+
+ usage_to_weights_file = {
+     'General': 'BiRefNet',
+     'General-Lite': 'BiRefNet_lite',
+     'General-Lite-2K': 'BiRefNet_lite-2K',
+     'Matting': 'BiRefNet-matting',
+     'Portrait': 'BiRefNet-portrait',
+     'DIS': 'BiRefNet-DIS5K',
+     'HRSOD': 'BiRefNet-HRSOD',
+     'COD': 'BiRefNet-COD',
+     'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
+     'General-legacy': 'BiRefNet-legacy'
+ }
+
+ birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
+ birefnet.to(device)
+ birefnet.eval()
+
+
+ @spaces.GPU
+ def predict(images, resolution, weights_file):
+     assert images is not None, 'images cannot be None.'
+
+     global birefnet
+     # Load BiRefNet with the chosen weights
+     _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
+     print('Using weights: {}.'.format(_weights_file))
+     birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
+     birefnet.to(device)
+     birefnet.eval()
+
+     try:
+         resolution = [int(reso) // 32 * 32 for reso in resolution.strip().split('x')]
+     except Exception:
+         resolution = (1024, 1024) if weights_file not in ['General-Lite-2K'] else (2560, 1440)
+         print('Invalid resolution input. Automatically changed to 1024x1024 or 2K.')
+
+     if isinstance(images, list):
+         # For tab_batch
+         save_paths = []
+         save_dir = 'preds-BiRefNet'
+         if not os.path.exists(save_dir):
+             os.makedirs(save_dir)
+         tab_is_batch = True
+     else:
+         images = [images]
+         tab_is_batch = False
+
+     for idx_image, image_src in enumerate(images):
+         if isinstance(image_src, str):
+             if os.path.isfile(image_src):
+                 image_ori = Image.open(image_src)
+             else:
+                 response = requests.get(image_src)
+                 image_data = BytesIO(response.content)
+                 image_ori = Image.open(image_data)
+         else:
+             image_ori = Image.fromarray(image_src)
+
+         image = image_ori.convert('RGB')
+         # Preprocess the image
+         image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
+         image_proc = image_preprocessor.proc(image)
+         image_proc = image_proc.unsqueeze(0)
+
+         # Prediction
+         with torch.no_grad():
+             preds = birefnet(image_proc.to(device))[-1].sigmoid().cpu()
+         pred = preds[0].squeeze()
+
+         # Show Results
+         pred_pil = transforms.ToPILImage()(pred)
+         image_masked = refine_foreground(image, pred_pil)
+         image_masked.putalpha(pred_pil.resize(image.size))
+
+         torch.cuda.empty_cache()
+
+         if tab_is_batch:
+             save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
+             image_masked.save(save_file_path)
+             save_paths.append(save_file_path)
+
+     if tab_is_batch:
+         zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
+         with zipfile.ZipFile(zip_file_path, 'w') as zipf:
+             for file in save_paths:
+                 zipf.write(file, os.path.basename(file))
+         return save_paths, zip_file_path
+     else:
+         return (image_masked, image_ori)
+
+
+ examples = [[_] for _ in glob('examples/*')]
+ # Add the resolution option in a text box.
+ for idx_example, example in enumerate(examples):
+     examples[idx_example].append('1024x1024')
+ examples.append(examples[-1].copy())
+ examples[-1][1] = '512x512'
+
+ examples_url = [
+     ['https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg'],
+ ]
+ for idx_example_url, example_url in enumerate(examples_url):
+     examples_url[idx_example_url].append('1024x1024')
+
+ descriptions = ('Upload a picture and our model will extract a highly accurate segmentation of its subject.\n'
+                 ' The model was trained at `1024x1024`, so that is the suggested resolution for good results.\n'
+                 ' Our code can be found at https://github.com/ZhengPeng7/BiRefNet.\n'
+                 ' We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access.')
+
+ tab_image = gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.Image(label='Upload an image'),
+         gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
+         gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
+     ],
+     outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
+     examples=examples,
+     api_name="image",
+     description=descriptions,
+ )
+
+ tab_text = gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.Textbox(label="Paste an image URL"),
+         gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
+         gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
+     ],
+     outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
+     examples=examples_url,
+     api_name="text",
+     description=descriptions + '\nTab-URL is partially modified from https://huggingface.co/spaces/not-lain/background-removal, thanks to this great work!',
+ )
+
+ tab_batch = gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.File(label="Upload multiple images", type="filepath", file_count="multiple"),
+         gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
+         gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
+     ],
+     outputs=[gr.Gallery(label="BiRefNet's predictions"), gr.File(label="Download masked images.")],
+     api_name="batch",
+     description=descriptions + '\nTab-batch is partially modified from https://huggingface.co/spaces/NegiTurkey/Multi_Birefnetfor_Background_Removal, thanks to this great work!',
+ )
+
+ demo = gr.TabbedInterface(
+     [tab_image, tab_text, tab_batch],
+     ['image', 'text', 'batch'],
+     title="BiRefNet demo for subject extraction (general / matting / salient / camouflaged / portrait).",
+ )
+
+ if __name__ == "__main__":
+     demo.launch(debug=True)
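
Because each tab sets an explicit `api_name`, the deployed Space can also be driven programmatically. A minimal sketch with `gradio_client` (the Space id is a placeholder, and `handle_file` assumes a recent client version):

# Sketch: call the Space's "/image" endpoint programmatically.
from gradio_client import Client, handle_file

client = Client("user/BiRefNet-demo")  # placeholder Space id
result = client.predict(
    handle_file("examples/My_Love.jpg"),  # image
    "1024x1024",                          # resolution, `WxH`
    "General",                            # weights
    api_name="/image",
)
print(result)  # the ImageSlider output: (masked image, original image)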
app_local.py ADDED
@@ -0,0 +1,228 @@
+ import os
+ import cv2
+ import numpy as np
+ import torch
+ import gradio as gr
+ # import spaces
+
+ from glob import glob
+ from typing import Tuple
+
+ from PIL import Image
+ # from gradio_imageslider import ImageSlider
+ from transformers import AutoModelForImageSegmentation
+ from torchvision import transforms
+
+ import requests
+ from io import BytesIO
+ import zipfile
+
+
+ torch.set_float32_matmul_precision('high')
+ # torch.jit.script = lambda f: f
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ ### image_proc.py
+ def refine_foreground(image, mask, r=90):
+     if mask.size != image.size:
+         mask = mask.resize(image.size)
+     image = np.array(image) / 255.0
+     mask = np.array(mask) / 255.0
+     estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
+     image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
+     return image_masked
+
+
+ def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
+     # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
+     alpha = alpha[:, :, None]
+     F, blur_B = FB_blur_fusion_foreground_estimator(
+         image, image, image, alpha, r)
+     return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
+
+
+ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
+     if isinstance(image, Image.Image):
+         image = np.array(image) / 255.0
+     blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
+
+     blurred_FA = cv2.blur(F * alpha, (r, r))
+     blurred_F = blurred_FA / (blurred_alpha + 1e-5)
+
+     blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
+     blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
+     F = blurred_F + alpha * \
+         (image - alpha * blurred_F - (1 - alpha) * blurred_B)
+     F = np.clip(F, 0, 1)
+     return F, blurred_B
+
+
+ class ImagePreprocessor():
+     def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
+         self.transform_image = transforms.Compose([
+             transforms.Resize(resolution),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ])
+
+     def proc(self, image: Image.Image) -> torch.Tensor:
+         image = self.transform_image(image)
+         return image
+
+
+ usage_to_weights_file = {
+     'General': 'BiRefNet',
+     'General-Lite': 'BiRefNet_lite',
+     'General-Lite-2K': 'BiRefNet_lite-2K',
+     'Matting': 'BiRefNet-matting',
+     'Portrait': 'BiRefNet-portrait',
+     'DIS': 'BiRefNet-DIS5K',
+     'HRSOD': 'BiRefNet-HRSOD',
+     'COD': 'BiRefNet-COD',
+     'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
+     'General-legacy': 'BiRefNet-legacy'
+ }
+
+ birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
+ birefnet.to(device)
+ birefnet.eval()
+
+
+ # @spaces.GPU
+ def predict(images, resolution, weights_file):
+     assert images is not None, 'images cannot be None.'
+
+     global birefnet
+     # Load BiRefNet with the chosen weights
+     _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
+     print('Using weights: {}.'.format(_weights_file))
+     birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
+     birefnet.to(device)
+     birefnet.eval()
+
+     try:
+         resolution = [int(reso) // 32 * 32 for reso in resolution.strip().split('x')]
+     except Exception:
+         resolution = (1024, 1024) if weights_file not in ['General-Lite-2K'] else (2560, 1440)
+         print('Invalid resolution input. Automatically changed to 1024x1024 or 2K.')
+
+     if isinstance(images, list):
+         # For tab_batch
+         save_paths = []
+         save_dir = 'preds-BiRefNet'
+         if not os.path.exists(save_dir):
+             os.makedirs(save_dir)
+         tab_is_batch = True
+     else:
+         images = [images]
+         tab_is_batch = False
+
+     for idx_image, image_src in enumerate(images):
+         if isinstance(image_src, str):
+             if os.path.isfile(image_src):
+                 image_ori = Image.open(image_src)
+             else:
+                 response = requests.get(image_src)
+                 image_data = BytesIO(response.content)
+                 image_ori = Image.open(image_data)
+         else:
+             image_ori = Image.fromarray(image_src)
+
+         image = image_ori.convert('RGB')
+         # Preprocess the image
+         image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
+         image_proc = image_preprocessor.proc(image)
+         image_proc = image_proc.unsqueeze(0)
+
+         # Prediction
+         with torch.no_grad():
+             preds = birefnet(image_proc.to(device))[-1].sigmoid().cpu()
+         pred = preds[0].squeeze()
+
+         # Show Results
+         pred_pil = transforms.ToPILImage()(pred)
+         image_masked = refine_foreground(image, pred_pil)
+         image_masked.putalpha(pred_pil.resize(image.size))
+
+         torch.cuda.empty_cache()
+
+         if tab_is_batch:
+             save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
+             image_masked.save(save_file_path)
+             save_paths.append(save_file_path)
+
+     if tab_is_batch:
+         zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
+         with zipfile.ZipFile(zip_file_path, 'w') as zipf:
+             for file in save_paths:
+                 zipf.write(file, os.path.basename(file))
+         return save_paths, zip_file_path
+     else:
+         return (image_masked, image_ori)[0]  # single gr.Image output: return only the masked image
+
+
+ examples = [[_] for _ in glob('examples/*')]
+ # Add the resolution option in a text box.
+ for idx_example, example in enumerate(examples):
+     examples[idx_example].append('1024x1024')
+ examples.append(examples[-1].copy())
+ examples[-1][1] = '512x512'
+
+ examples_url = [
+     ['https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg'],
+ ]
+ for idx_example_url, example_url in enumerate(examples_url):
+     examples_url[idx_example_url].append('1024x1024')
+
+ descriptions = ('Upload a picture and our model will extract a highly accurate segmentation of its subject.\n'
+                 ' The model was trained at `1024x1024`, so that is the suggested resolution for good results.\n'
+                 ' Our code can be found at https://github.com/ZhengPeng7/BiRefNet.\n'
+                 ' We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access.')
+
+ tab_image = gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.Image(label='Upload an image'),
+         gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
+         gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
+     ],
+     outputs=gr.Image(label="BiRefNet's prediction", type="pil", format='png'),
+     examples=examples,
+     api_name="image",
+     description=descriptions,
+ )
+
+ tab_text = gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.Textbox(label="Paste an image URL"),
+         gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
+         gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
+     ],
+     outputs=gr.Image(label="BiRefNet's prediction", type="pil", format='png'),
+     examples=examples_url,
+     api_name="text",
+     description=descriptions + '\nTab-URL is partially modified from https://huggingface.co/spaces/not-lain/background-removal, thanks to this great work!',
+ )
+
+ tab_batch = gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.File(label="Upload multiple images", type="filepath", file_count="multiple"),
+         gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
+         gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
+     ],
+     outputs=[gr.Gallery(label="BiRefNet's predictions"), gr.File(label="Download masked images.")],
+     api_name="batch",
+     description=descriptions + '\nTab-batch is partially modified from https://huggingface.co/spaces/NegiTurkey/Multi_Birefnetfor_Background_Removal, thanks to this great work!',
+ )
+
+ demo = gr.TabbedInterface(
+     [tab_image, tab_text, tab_batch],
+     ['image', 'text', 'batch'],
+     title="BiRefNet demo for subject extraction (general / matting / salient / camouflaged / portrait).",
+ )
+
+ if __name__ == "__main__":
+     demo.launch(debug=True)
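
Relative to app.py, this local variant comments out the `spaces` GPU decorator and the ImageSlider, renders the result with a plain `gr.Image`, and has `predict` return a single PIL image. That also makes it easy to use without the UI; a sketch, assuming the example image exists next to this file:

# Sketch: call app_local.py's predict() directly, without launching Gradio.
# Importing the module loads the default BiRefNet weights once.
from app_local import predict

masked = predict("examples/My_Love.jpg", "1024x1024", "General")
masked.save("My_Love_masked.png")  # RGBA image with the predicted alpha matte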
examples/Helicopter.jpg ADDED
examples/Jewelry.jpg ADDED
examples/My_Love.jpg ADDED

Git LFS Details

  • SHA256: ffca8347b4e2bbc4e064f02b1aaf93cd53a8834e6fd210200172b913206c2aef
  • Pointer size: 132 Bytes
  • Size of remote file: 1.47 MB
examples/My_Love_1.jpg ADDED
examples/My_Love_2.jpg ADDED

Git LFS Details

  • SHA256: 55060d321436a58d17b98f1f10de546ef638bcf2b6774c8d714a3e2d1851cf55
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB
examples/My_MiSheng.jpg ADDED

Git LFS Details

  • SHA256: 0eb9960dbeb7e9ace2e1c9a20cdad9405983d182bad4ce337527f13be45af230
  • Pointer size: 132 Bytes
  • Size of remote file: 2.05 MB
examples/Windmill.jpg ADDED
gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
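
The final `*.jpg` rule sends every JPEG through LFS, which is why the example images above are stored as pointer stubs with SHA256 digests. In a clone made without `git lfs`, such stubs are easy to detect, since every LFS pointer file begins with a fixed version line:

# Sketch: check whether a checked-out file is an unmaterialized LFS pointer stub.
def is_lfs_pointer(path: str) -> bool:
    with open(path, "rb") as f:
        return f.read(42) == b"version https://git-lfs.github.com/spec/v1"

print(is_lfs_pointer("examples/My_Love.jpg"))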
gitignore ADDED
@@ -0,0 +1,5 @@
+ flagged/
+
+ __pycache__/
+
+ .DS_Store
requirements (1).txt ADDED
@@ -0,0 +1,12 @@
+ torch==2.0.1
+ torchvision==0.15.2
+ opencv-python==4.9.0.80
+ tqdm==4.66.2
+ timm==0.9.16
+ prettytable==3.10.0
+ scipy==1.12.0
+ scikit-image==0.22.0
+ kornia==0.7.1
+ gradio_imageslider==0.0.18
+ transformers==4.42.4
+ huggingface_hub==0.23.4
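
One caveat on naming: Spaces install dependencies only from a file called exactly `requirements.txt`, so this `requirements (1).txt` (apparently a browser-download rename) needs renaming before these pins take effect; the local equivalent is then `pip install -r requirements.txt`.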