not-lain commited on
Commit
958511f
β€’
1 Parent(s): 521737f

🌘wπŸŒ–

Browse files
Files changed (2) hide show
  1. app.py +29 -16
  2. cool kid.jpg +0 -0
app.py CHANGED
@@ -6,26 +6,30 @@ from transformers import AutoModelForImageSegmentation
6
  import torch
7
  from torchvision import transforms
8
  from PIL import Image
9
- torch.set_float32_matmul_precision(['high', 'highest'][0])
10
 
11
- birefnet = AutoModelForImageSegmentation.from_pretrained('ZhengPeng7/BiRefNet', trust_remote_code=True)
12
- birefnet.to("cuda")
13
- transform_image = transforms.Compose([
14
- transforms.Resize((1024, 1024)),
15
- transforms.ToTensor(),
16
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
17
- ])
18
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  @spaces.GPU
22
  def fn(image):
23
- im = load_img(image,output_type="pil")
24
- im = im.convert('RGB')
25
  image_size = im.size
26
  origin = im.copy()
27
  image = load_img(im)
28
- input_images = transform_image(image).unsqueeze(0).to('cuda')
29
  # Prediction
30
  with torch.no_grad():
31
  preds = birefnet(input_images)[-1].sigmoid().cpu()
@@ -33,7 +37,8 @@ def fn(image):
33
  pred_pil = transforms.ToPILImage()(pred)
34
  mask = pred_pil.resize(image_size)
35
  image.putalpha(mask)
36
- return (image , origin)
 
37
 
38
  slider1 = ImageSlider(label="birefnet", type="pil")
39
  slider2 = ImageSlider(label="birefnet", type="pil")
@@ -41,11 +46,19 @@ image = gr.Image(label="Upload an image")
41
  text = gr.Textbox(label="Paste an image URL")
42
 
43
  chameleon = Image.open("chameleon.jpg")
 
 
44
  url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
45
- tab1 = gr.Interface(fn,inputs= image, outputs= slider1,examples=[chameleon], api_name="image")
46
- tab2 = gr.Interface(fn,inputs= text, outputs= slider2,examples=[url], api_name="text")
 
 
 
 
47
 
48
- demo = gr.TabbedInterface([tab1,tab2],["image","text"],title="birefnet with image slider")
 
 
49
 
50
  if __name__ == "__main__":
51
- demo.launch()
 
6
  import torch
7
  from torchvision import transforms
8
  from PIL import Image
 
9
 
10
+ torch.set_float32_matmul_precision(["high", "highest"][0])
 
 
 
 
 
 
11
 
12
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
13
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
14
+ )
15
+ birefnet.to("cuda")
16
+ transform_image = transforms.Compose(
17
+ [
18
+ transforms.Resize((1024, 1024)),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
21
+ ]
22
+ )
23
 
24
 
25
  @spaces.GPU
26
  def fn(image):
27
+ im = load_img(image, output_type="pil")
28
+ im = im.convert("RGB")
29
  image_size = im.size
30
  origin = im.copy()
31
  image = load_img(im)
32
+ input_images = transform_image(image).unsqueeze(0).to("cuda")
33
  # Prediction
34
  with torch.no_grad():
35
  preds = birefnet(input_images)[-1].sigmoid().cpu()
 
37
  pred_pil = transforms.ToPILImage()(pred)
38
  mask = pred_pil.resize(image_size)
39
  image.putalpha(mask)
40
+ return (image, origin)
41
+
42
 
43
  slider1 = ImageSlider(label="birefnet", type="pil")
44
  slider2 = ImageSlider(label="birefnet", type="pil")
 
46
  text = gr.Textbox(label="Paste an image URL")
47
 
48
  chameleon = Image.open("chameleon.jpg")
49
+ cool = Image.open("cool kid.jpg")
50
+
51
  url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
52
+ tab1 = gr.Interface(
53
+ fn, inputs=image, outputs=slider1, examples=[[chameleon], [cool]], api_name="image"
54
+ )
55
+
56
+ tab2 = gr.Interface(fn, inputs=text, outputs=slider2, examples=[url], api_name="text")
57
+
58
 
59
+ demo = gr.TabbedInterface(
60
+ [tab1, tab2], ["image", "text"], title="birefnet for background removal"
61
+ )
62
 
63
  if __name__ == "__main__":
64
+ demo.launch()
cool kid.jpg ADDED