ysharma HF staff commited on
Commit
7fe9e0c
·
verified ·
1 Parent(s): 73527b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -40
app.py CHANGED
@@ -84,28 +84,28 @@ usage_to_weights_file = {
84
  'General-legacy': 'BiRefNet-legacy'
85
  }
86
 
87
- birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
88
  birefnet.to(device)
89
  birefnet.eval()
90
 
91
 
92
  @spaces.GPU
93
- def predict(images, resolution, weights_file):
94
  assert (images is not None), 'AssertionError: images cannot be None.'
95
 
96
  global birefnet
97
  # Load BiRefNet with chosen weights
98
- _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
99
  print('Using weights: {}.'.format(_weights_file))
100
  birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
101
  birefnet.to(device)
102
  birefnet.eval()
103
 
104
- try:
105
- resolution = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]
106
- except:
107
- resolution = (1024, 1024) if weights_file not in ['General-Lite-2K'] else (2560, 1440)
108
- print('Invalid resolution input. Automatically changed to 1024x1024 or 2K.')
109
 
110
  if isinstance(images, list):
111
  # For tab_batch
@@ -131,7 +131,7 @@ def predict(images, resolution, weights_file):
131
 
132
  image = image_ori.convert('RGB')
133
  # Preprocess the image
134
- image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
135
  image_proc = image_preprocessor.proc(image)
136
  image_proc = image_proc.unsqueeze(0)
137
 
@@ -184,44 +184,19 @@ tab_image = gr.Interface(
184
  fn=predict,
185
  inputs=[
186
  gr.Image(label='Upload an image'),
187
- gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
188
- gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
189
  ],
190
  outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
191
- examples=examples,
192
  api_name="image",
193
  description=descriptions,
194
  )
195
 
196
- tab_text = gr.Interface(
197
- fn=predict,
198
- inputs=[
199
- gr.Textbox(label="Paste an image URL"),
200
- gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
201
- gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
202
- ],
203
- outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
204
- examples=examples_url,
205
- api_name="text",
206
- description=descriptions+'\nTab-URL is partially modified from https://huggingface.co/spaces/not-lain/background-removal, thanks to this great work!',
207
- )
208
-
209
- tab_batch = gr.Interface(
210
- fn=predict,
211
- inputs=[
212
- gr.File(label="Upload multiple images", type="filepath", file_count="multiple"),
213
- gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
214
- gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
215
- ],
216
- outputs=[gr.Gallery(label="BiRefNet's predictions"), gr.File(label="Download masked images.")],
217
- api_name="batch",
218
- description=descriptions+'\nTab-batch is partially modified from https://huggingface.co/spaces/NegiTurkey/Multi_Birefnetfor_Background_Removal, thanks to this great work!',
219
- )
220
-
221
  demo = gr.TabbedInterface(
222
- [tab_image, tab_text, tab_batch],
223
- ['image', 'text', 'batch'],
224
- title="BiRefNet demo for subject extraction (general / matting / salient / camouflaged / portrait).",
225
  )
226
 
227
  if __name__ == "__main__":
 
84
  'General-legacy': 'BiRefNet-legacy'
85
  }
86
 
87
+ birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join('zhengpeng7', 'BiRefNet_lite'), trust_remote_code=True)
88
  birefnet.to(device)
89
  birefnet.eval()
90
 
91
 
92
  @spaces.GPU
93
+ def predict(images):
94
  assert (images is not None), 'AssertionError: images cannot be None.'
95
 
96
  global birefnet
97
  # Load BiRefNet with chosen weights
98
+ _weights_file = '/'.join('zhengpeng7', 'BiRefNet_lite')
99
  print('Using weights: {}.'.format(_weights_file))
100
  birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
101
  birefnet.to(device)
102
  birefnet.eval()
103
 
104
+ #try:
105
+ # resolution = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]
106
+ #except:
107
+ # resolution = (1024, 1024) if weights_file not in ['General-Lite-2K'] else (2560, 1440)
108
+ # print('Invalid resolution input. Automatically changed to 1024x1024 or 2K.')
109
 
110
  if isinstance(images, list):
111
  # For tab_batch
 
131
 
132
  image = image_ori.convert('RGB')
133
  # Preprocess the image
134
+ image_preprocessor = ImagePreprocessor() #(resolution=tuple(resolution))
135
  image_proc = image_preprocessor.proc(image)
136
  image_proc = image_proc.unsqueeze(0)
137
 
 
184
  fn=predict,
185
  inputs=[
186
  gr.Image(label='Upload an image'),
187
+ #gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
188
+ #gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
189
  ],
190
  outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
191
+ #examples=examples,
192
  api_name="image",
193
  description=descriptions,
194
  )
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  demo = gr.TabbedInterface(
197
+ [tab_image],
198
+ ['image'],
199
+ title="BiRefNet demo for subject extraction.",
200
  )
201
 
202
  if __name__ == "__main__":