not-lain commited on
Commit
c368dca
Β·
1 Parent(s): e546fea

🌘wπŸŒ–

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -8,8 +8,8 @@ from torchvision import transforms
8
 
9
  # torch.set_float32_matmul_precision(['high', 'highest'][0])
10
 
11
- birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True,device="auto",torch_dtype=torch.float16)
12
-
13
  transform_image = transforms.Compose([
14
  transforms.Resize((1024, 1024)),
15
  transforms.ToTensor(),
@@ -33,7 +33,7 @@ def fn(image):
33
  return out
34
 
35
  slider1 = ImageSlider(label="birefnet", type="pil")
36
- slider2 = ImageSlider(label="RMBG", type="pil")
37
  image = gr.Image(label="Upload an image")
38
  text = gr.Textbox(label="Paste an image URL")
39
 
@@ -41,7 +41,7 @@ text = gr.Textbox(label="Paste an image URL")
41
  tab1 = gr.Interface(fn,inputs= image, outputs= slider1, api_name="image")
42
  tab2 = gr.Interface(fn,inputs= text, outputs= slider2, api_name="text")
43
 
44
- demo = gr.TabbedInterface([tab1,tab2],["image","text"],title="RMBG with image slider")
45
 
46
  if __name__ == "__main__":
47
  demo.launch()
 
8
 
9
  # torch.set_float32_matmul_precision(['high', 'highest'][0])
10
 
11
+ birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True,torch_dtype=torch.float16)
12
+ birefnet.to("cuda")
13
  transform_image = transforms.Compose([
14
  transforms.Resize((1024, 1024)),
15
  transforms.ToTensor(),
 
33
  return out
34
 
35
  slider1 = ImageSlider(label="birefnet", type="pil")
36
+ slider2 = ImageSlider(label="birefnet", type="pil")
37
  image = gr.Image(label="Upload an image")
38
  text = gr.Textbox(label="Paste an image URL")
39
 
 
41
  tab1 = gr.Interface(fn,inputs= image, outputs= slider1, api_name="image")
42
  tab2 = gr.Interface(fn,inputs= text, outputs= slider2, api_name="text")
43
 
44
+ demo = gr.TabbedInterface([tab1,tab2],["image","text"],title="birefnet with image slider")
45
 
46
  if __name__ == "__main__":
47
  demo.launch()