Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -25,30 +25,49 @@ from transformers import (
|
|
25 |
AutoConfig,
|
26 |
AutoModelForImageSegmentation,
|
27 |
)
|
|
|
|
|
28 |
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
30 |
config = AutoConfig.from_pretrained(
|
31 |
-
"zhengpeng7/BiRefNet",
|
32 |
trust_remote_code=True
|
33 |
)
|
34 |
|
35 |
-
# 2) config.get_text_config
|
36 |
def dummy_get_text_config(decoder=True):
|
37 |
return type("DummyTextConfig", (), {"tie_word_embeddings": False})()
|
38 |
|
39 |
config.get_text_config = dummy_get_text_config
|
40 |
|
41 |
-
# 3) 모델 구조만 만들기
|
42 |
birefnet = AutoModelForImageSegmentation.from_config(config, trust_remote_code=True)
|
43 |
birefnet.eval()
|
44 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
45 |
birefnet.to(device)
|
46 |
birefnet.half()
|
47 |
|
48 |
-
|
49 |
-
#
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
missing, unexpected = birefnet.load_state_dict(state_dict, strict=False)
|
53 |
print("[Info] Missing keys:", missing)
|
54 |
print("[Info] Unexpected keys:", unexpected)
|
@@ -56,7 +75,7 @@ torch.cuda.empty_cache()
|
|
56 |
|
57 |
|
58 |
##########################################################
|
59 |
-
#
|
60 |
##########################################################
|
61 |
|
62 |
def refine_foreground(image, mask, r=90):
|
@@ -85,7 +104,6 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
|
|
85 |
F = np.clip(F, 0, 1)
|
86 |
return F, blurred_B
|
87 |
|
88 |
-
|
89 |
class ImagePreprocessor():
|
90 |
def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
|
91 |
self.transform_image = transforms.Compose([
|
@@ -99,7 +117,7 @@ class ImagePreprocessor():
|
|
99 |
|
100 |
|
101 |
##########################################################
|
102 |
-
#
|
103 |
##########################################################
|
104 |
|
105 |
usage_to_weights_file = {
|
@@ -130,30 +148,24 @@ descriptions = (
|
|
130 |
"We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access."
|
131 |
)
|
132 |
|
133 |
-
|
134 |
##########################################################
|
135 |
-
#
|
136 |
##########################################################
|
137 |
|
138 |
@spaces.GPU
|
139 |
def predict(images, resolution, weights_file):
|
140 |
-
|
141 |
-
여기서는, 단일 birefnet 모델만 유지하고 있으며,
|
142 |
-
weight_file์ ๋ฐ๊พธ๋๋ผ๋ ์ค์ ๋ก๋ ์ด๋ฏธ ๋ก๋๋ 'birefnet' ๋ชจ๋ธ๋ง ์ฌ์ฉ.
|
143 |
-
(만약 다른 가중치를 로드하고 싶다면, 아래처럼 로컬 state_dict 교체 방식 추가 가능.)
|
144 |
-
"""
|
145 |
assert images is not None, 'Images cannot be None.'
|
146 |
|
147 |
-
#
|
148 |
try:
|
149 |
-
w, h = resolution.strip().split('x')
|
150 |
-
w, h = int(
|
151 |
-
resolution_list = (w, h)
|
152 |
except:
|
153 |
-
|
154 |
-
|
155 |
|
156 |
-
#
|
157 |
if isinstance(images, list):
|
158 |
is_batch = True
|
159 |
outputs, save_paths = [], []
|
@@ -164,65 +176,57 @@ def predict(images, resolution, weights_file):
|
|
164 |
is_batch = False
|
165 |
|
166 |
for idx, image_src in enumerate(images):
|
167 |
-
#
|
168 |
if isinstance(image_src, str):
|
169 |
if os.path.isfile(image_src):
|
170 |
image_ori = Image.open(image_src)
|
171 |
else:
|
172 |
resp = requests.get(image_src)
|
173 |
image_ori = Image.open(BytesIO(resp.content))
|
174 |
-
# numpy
|
175 |
elif isinstance(image_src, np.ndarray):
|
176 |
image_ori = Image.fromarray(image_src)
|
177 |
else:
|
178 |
image_ori = image_src.convert('RGB')
|
179 |
|
180 |
-
|
181 |
-
preproc = ImagePreprocessor(
|
182 |
-
image_proc = preproc.proc(
|
183 |
|
184 |
-
#
|
185 |
with torch.inference_mode():
|
186 |
-
# 결과 맨 마지막 레이어 preds
|
187 |
preds = birefnet(image_proc)[-1].sigmoid().cpu()
|
188 |
pred_mask = preds[0].squeeze()
|
189 |
|
190 |
# 후처리
|
191 |
pred_pil = transforms.ToPILImage()(pred_mask)
|
192 |
-
image_masked = refine_foreground(
|
193 |
-
image_masked.putalpha(pred_pil.resize(
|
194 |
|
195 |
if is_batch:
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
)
|
201 |
-
out_path = os.path.join(save_dir, f"{file_name}.png")
|
202 |
-
image_masked.save(out_path)
|
203 |
-
save_paths.append(out_path)
|
204 |
outputs.append(image_masked)
|
205 |
else:
|
206 |
outputs = [image_masked, image_ori]
|
207 |
|
208 |
torch.cuda.empty_cache()
|
209 |
|
210 |
-
# 배치라면 갤러리 + ZIP 반환
|
211 |
if is_batch:
|
212 |
-
|
213 |
-
with zipfile.ZipFile(
|
214 |
for fpath in save_paths:
|
215 |
zipf.write(fpath, os.path.basename(fpath))
|
216 |
-
return
|
217 |
else:
|
218 |
return outputs
|
219 |
|
220 |
-
|
221 |
##########################################################
|
222 |
-
#
|
223 |
##########################################################
|
224 |
|
225 |
-
# 커스텀 CSS
|
226 |
css = """
|
227 |
body {
|
228 |
background: linear-gradient(135deg, #667eea, #764ba2);
|
@@ -280,14 +284,13 @@ button:hover, .btn:hover {
|
|
280 |
title_html = """
|
281 |
<h1 align="center" style="margin-bottom: 0.2em;">BiRefNet Demo (No Tie-Weights Crash)</h1>
|
282 |
<p align="center" style="font-size:1.1em; color:#555;">
|
283 |
-
Using <code>from_config()</code> + local <code>state_dict</code> to bypass tie_weights issues
|
284 |
</p>
|
285 |
"""
|
286 |
|
287 |
with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
|
288 |
gr.Markdown(title_html)
|
289 |
with gr.Tabs():
|
290 |
-
# 탭 1: Image
|
291 |
with gr.Tab("Image"):
|
292 |
with gr.Row():
|
293 |
with gr.Column(scale=1):
|
@@ -297,13 +300,8 @@ with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
|
|
297 |
predict_btn = gr.Button("Predict")
|
298 |
with gr.Column(scale=2):
|
299 |
output_slider = ImageSlider(label="Result", type="pil")
|
300 |
-
gr.Examples(
|
301 |
-
examples=examples_image,
|
302 |
-
inputs=[image_input, resolution_input, weights_radio],
|
303 |
-
label="Examples"
|
304 |
-
)
|
305 |
|
306 |
-
# 탭 2: Text(URL)
|
307 |
with gr.Tab("Text"):
|
308 |
with gr.Row():
|
309 |
with gr.Column(scale=1):
|
@@ -313,36 +311,23 @@ with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
|
|
313 |
predict_btn_text = gr.Button("Predict")
|
314 |
with gr.Column(scale=2):
|
315 |
output_slider_text = ImageSlider(label="Result", type="pil")
|
316 |
-
gr.Examples(
|
317 |
-
examples=examples_text,
|
318 |
-
inputs=[image_url, resolution_input_text, weights_radio_text],
|
319 |
-
label="Examples"
|
320 |
-
)
|
321 |
|
322 |
-
# 탭 3: Batch
|
323 |
with gr.Tab("Batch"):
|
324 |
with gr.Row():
|
325 |
with gr.Column(scale=1):
|
326 |
-
file_input = gr.File(
|
327 |
-
label="Upload Multiple Images",
|
328 |
-
type="filepath",
|
329 |
-
file_count="multiple"
|
330 |
-
)
|
331 |
resolution_input_batch = gr.Textbox(lines=1, placeholder="e.g., 1024x1024", label="Resolution")
|
332 |
weights_radio_batch = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
|
333 |
predict_btn_batch = gr.Button("Predict")
|
334 |
with gr.Column(scale=2):
|
335 |
output_gallery = gr.Gallery(label="Results", scale=1)
|
336 |
zip_output = gr.File(label="Zip Download")
|
337 |
-
gr.Examples(
|
338 |
-
examples=examples_batch,
|
339 |
-
inputs=[file_input, resolution_input_batch, weights_radio_batch],
|
340 |
-
label="Examples"
|
341 |
-
)
|
342 |
|
343 |
gr.Markdown("<p align='center'>Model by <a href='https://huggingface.co/ZhengPeng7/BiRefNet'>ZhengPeng7/BiRefNet</a></p>")
|
344 |
|
345 |
-
#
|
346 |
predict_btn.click(
|
347 |
fn=predict,
|
348 |
inputs=[image_input, resolution_input, weights_radio],
|
|
|
25 |
AutoConfig,
|
26 |
AutoModelForImageSegmentation,
|
27 |
)
|
28 |
+
# Hugging Face Hub
|
29 |
+
from huggingface_hub import hf_hub_download
|
30 |
|
31 |
+
|
32 |
+
##########################################################
|
33 |
+
# 1. Config 및 from_config() 초기화
|
34 |
+
##########################################################
|
35 |
+
|
36 |
+
# 1) Config๋ง ๋จผ์ ๋ก๋
|
37 |
config = AutoConfig.from_pretrained(
|
38 |
+
"zhengpeng7/BiRefNet", # ์์
|
39 |
trust_remote_code=True
|
40 |
)
|
41 |
|
42 |
+
# 2) config.get_text_config์ ๋๋ฏธ ๋ฉ์๋ ๋ถ์ฌ (tie_word_embeddings=False)
|
43 |
def dummy_get_text_config(decoder=True):
|
44 |
return type("DummyTextConfig", (), {"tie_word_embeddings": False})()
|
45 |
|
46 |
config.get_text_config = dummy_get_text_config
|
47 |
|
48 |
+
# 3) 모델 구조만 만들기
|
49 |
birefnet = AutoModelForImageSegmentation.from_config(config, trust_remote_code=True)
|
50 |
birefnet.eval()
|
51 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
52 |
birefnet.to(device)
|
53 |
birefnet.half()
|
54 |
|
55 |
+
##########################################################
|
56 |
+
# 2. 모델 가중치 다운로드 & 로드
|
57 |
+
##########################################################
|
58 |
+
|
59 |
+
# huggingface_hub์์ safetensors ๋๋ bin ํ์ผ ๋ค์ด๋ก๋
|
60 |
+
# (repo_id, filename 등은 실제 사용 환경에 맞게 변경)
|
61 |
+
weights_path = hf_hub_download(
|
62 |
+
repo_id="zhengpeng7/BiRefNet", # 예시
|
63 |
+
filename="model.safetensors", # 또는 "pytorch_model.bin"
|
64 |
+
trust_remote_code=True
|
65 |
+
)
|
66 |
+
print("Downloaded weights to:", weights_path)
|
67 |
+
|
68 |
+
# state_dict 로드
|
69 |
+
print("Loading BiRefNet weights from HF Hub file:", weights_path)
|
70 |
+
state_dict = torch.load(weights_path, map_location="cpu")
|
71 |
missing, unexpected = birefnet.load_state_dict(state_dict, strict=False)
|
72 |
print("[Info] Missing keys:", missing)
|
73 |
print("[Info] Unexpected keys:", unexpected)
|
|
|
75 |
|
76 |
|
77 |
##########################################################
|
78 |
+
# 3. 이미지 후처리 함수들
|
79 |
##########################################################
|
80 |
|
81 |
def refine_foreground(image, mask, r=90):
|
|
|
104 |
F = np.clip(F, 0, 1)
|
105 |
return F, blurred_B
|
106 |
|
|
|
107 |
class ImagePreprocessor():
|
108 |
def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
|
109 |
self.transform_image = transforms.Compose([
|
|
|
117 |
|
118 |
|
119 |
##########################################################
|
120 |
+
# 4. 예제 설정 및 기타
|
121 |
##########################################################
|
122 |
|
123 |
usage_to_weights_file = {
|
|
|
148 |
"We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access."
|
149 |
)
|
150 |
|
|
|
151 |
##########################################################
|
152 |
+
# 5. 추론 함수 (이미 로드된 birefnet 모델 사용)
|
153 |
##########################################################
|
154 |
|
155 |
@spaces.GPU
|
156 |
def predict(images, resolution, weights_file):
|
157 |
+
# weights_file์ ์ฌ๊ธฐ์๋ ๋ฌด์ํ๊ณ , ์ด๋ฏธ ๋ก๋๋ birefnet ์ฌ์ฉ
|
|
|
|
|
|
|
|
|
158 |
assert images is not None, 'Images cannot be None.'
|
159 |
|
160 |
+
# Parse resolution
|
161 |
try:
|
162 |
+
w, h = map(int, resolution.strip().split('x'))
|
163 |
+
w, h = int(w//32*32), int(h//32*32)
|
|
|
164 |
except:
|
165 |
+
w, h = 1024, 1024
|
166 |
+
resolution_tuple = (w, h)
|
167 |
|
168 |
+
# 리스트인지 확인
|
169 |
if isinstance(images, list):
|
170 |
is_batch = True
|
171 |
outputs, save_paths = [], []
|
|
|
176 |
is_batch = False
|
177 |
|
178 |
for idx, image_src in enumerate(images):
|
179 |
+
# 파일 경로 혹은 URL
|
180 |
if isinstance(image_src, str):
|
181 |
if os.path.isfile(image_src):
|
182 |
image_ori = Image.open(image_src)
|
183 |
else:
|
184 |
resp = requests.get(image_src)
|
185 |
image_ori = Image.open(BytesIO(resp.content))
|
186 |
+
# numpy array → PIL
|
187 |
elif isinstance(image_src, np.ndarray):
|
188 |
image_ori = Image.fromarray(image_src)
|
189 |
else:
|
190 |
image_ori = image_src.convert('RGB')
|
191 |
|
192 |
+
# 전처리
|
193 |
+
preproc = ImagePreprocessor(resolution_tuple)
|
194 |
+
image_proc = preproc.proc(image_ori.convert('RGB')).unsqueeze(0).to(device).half()
|
195 |
|
196 |
+
# 추론
|
197 |
with torch.inference_mode():
|
|
|
198 |
preds = birefnet(image_proc)[-1].sigmoid().cpu()
|
199 |
pred_mask = preds[0].squeeze()
|
200 |
|
201 |
# 후처리
|
202 |
pred_pil = transforms.ToPILImage()(pred_mask)
|
203 |
+
image_masked = refine_foreground(image_ori, pred_pil)
|
204 |
+
image_masked.putalpha(pred_pil.resize(image_ori.size))
|
205 |
|
206 |
if is_batch:
|
207 |
+
fbase = (os.path.splitext(os.path.basename(image_src))[0] if isinstance(image_src, str) else f"img_{idx}")
|
208 |
+
outpath = os.path.join(save_dir, f"{fbase}.png")
|
209 |
+
image_masked.save(outpath)
|
210 |
+
save_paths.append(outpath)
|
|
|
|
|
|
|
|
|
211 |
outputs.append(image_masked)
|
212 |
else:
|
213 |
outputs = [image_masked, image_ori]
|
214 |
|
215 |
torch.cuda.empty_cache()
|
216 |
|
|
|
217 |
if is_batch:
|
218 |
+
zippath = os.path.join(save_dir, f"{save_dir}.zip")
|
219 |
+
with zipfile.ZipFile(zippath, 'w') as zipf:
|
220 |
for fpath in save_paths:
|
221 |
zipf.write(fpath, os.path.basename(fpath))
|
222 |
+
return outputs, zippath
|
223 |
else:
|
224 |
return outputs
|
225 |
|
|
|
226 |
##########################################################
|
227 |
+
# 6. Gradio UI
|
228 |
##########################################################
|
229 |
|
|
|
230 |
css = """
|
231 |
body {
|
232 |
background: linear-gradient(135deg, #667eea, #764ba2);
|
|
|
284 |
title_html = """
|
285 |
<h1 align="center" style="margin-bottom: 0.2em;">BiRefNet Demo (No Tie-Weights Crash)</h1>
|
286 |
<p align="center" style="font-size:1.1em; color:#555;">
|
287 |
+
Using <code>from_config()</code> + local <code>state_dict</code> or <code>hf_hub_download</code> to bypass tie_weights issues
|
288 |
</p>
|
289 |
"""
|
290 |
|
291 |
with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
|
292 |
gr.Markdown(title_html)
|
293 |
with gr.Tabs():
|
|
|
294 |
with gr.Tab("Image"):
|
295 |
with gr.Row():
|
296 |
with gr.Column(scale=1):
|
|
|
300 |
predict_btn = gr.Button("Predict")
|
301 |
with gr.Column(scale=2):
|
302 |
output_slider = ImageSlider(label="Result", type="pil")
|
303 |
+
gr.Examples(examples=examples_image, inputs=[image_input, resolution_input, weights_radio], label="Examples")
|
|
|
|
|
|
|
|
|
304 |
|
|
|
305 |
with gr.Tab("Text"):
|
306 |
with gr.Row():
|
307 |
with gr.Column(scale=1):
|
|
|
311 |
predict_btn_text = gr.Button("Predict")
|
312 |
with gr.Column(scale=2):
|
313 |
output_slider_text = ImageSlider(label="Result", type="pil")
|
314 |
+
gr.Examples(examples=examples_text, inputs=[image_url, resolution_input_text, weights_radio_text], label="Examples")
|
|
|
|
|
|
|
|
|
315 |
|
|
|
316 |
with gr.Tab("Batch"):
|
317 |
with gr.Row():
|
318 |
with gr.Column(scale=1):
|
319 |
+
file_input = gr.File(label="Upload Multiple Images", type="filepath", file_count="multiple")
|
|
|
|
|
|
|
|
|
320 |
resolution_input_batch = gr.Textbox(lines=1, placeholder="e.g., 1024x1024", label="Resolution")
|
321 |
weights_radio_batch = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
|
322 |
predict_btn_batch = gr.Button("Predict")
|
323 |
with gr.Column(scale=2):
|
324 |
output_gallery = gr.Gallery(label="Results", scale=1)
|
325 |
zip_output = gr.File(label="Zip Download")
|
326 |
+
gr.Examples(examples=examples_batch, inputs=[file_input, resolution_input_batch, weights_radio_batch], label="Examples")
|
|
|
|
|
|
|
|
|
327 |
|
328 |
gr.Markdown("<p align='center'>Model by <a href='https://huggingface.co/ZhengPeng7/BiRefNet'>ZhengPeng7/BiRefNet</a></p>")
|
329 |
|
330 |
+
# 이벤트 연결
|
331 |
predict_btn.click(
|
332 |
fn=predict,
|
333 |
inputs=[image_input, resolution_input, weights_radio],
|