File size: 1,889 Bytes
e546fea
 
 
f9858d8
e546fea
 
 
 
958511f
e546fea
958511f
 
 
f9858d8
 
958511f
 
 
 
 
 
 
e546fea
 
f9858d8
e546fea
958511f
 
3e75999
 
e546fea
4c33e65
e546fea
 
 
 
 
3e75999
 
43b320b
2e133a7
 
958511f
e546fea
43b320b
e546fea
2e133a7
e546fea
20a2fe0
 
958511f
521737f
3231e76
2e133a7
958511f
 
3231e76
958511f
e546fea
3231e76
 
 
e546fea
 
958511f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
# import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms

# Float32 matmul precision hint for PyTorch. "high" is the setting in use;
# "highest" is the stricter alternative if full float32 accuracy is needed.
torch.set_float32_matmul_precision("high")

# Run on the first CUDA device when one is present, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# trust_remote_code=True lets transformers execute the model code published
# in the Hub repository, so the checkpoint name below must stay trusted.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to(device)

# Preprocessing applied to every input before inference: fixed 1024x1024
# resize, conversion to a tensor, then per-channel normalisation (these are
# the commonly used ImageNet mean/std values).
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


# @spaces.GPU
def fn(image):
    """Cut the foreground out of an image with the BiRefNet model.

    The input is loaded as an RGB PIL image, preprocessed with the
    module-level ``transform_image``, and passed through ``birefnet``. The
    predicted mask is resized back to the input's original size and attached
    to the image as its alpha channel.

    Args:
        image: The value supplied by the Gradio input component; it is
            passed straight to ``load_img``, so it may be anything that
            helper accepts.

    Returns:
        A ``(cutout, path)`` tuple: the PIL image carrying the predicted
        mask as alpha, and ``"img.png"``, the path the same image was
        saved to in PNG format.
    """
    source = load_img(image, output_type="pil").convert("RGB")
    full_size = source.size

    # NOTE(review): running load_img again on an already-loaded PIL image
    # looks redundant -- confirm against loadimg before dropping the call.
    cutout = load_img(source)

    batch = transform_image(cutout).unsqueeze(0).to(device)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        # The last element of the model output is used as the mask; sigmoid
        # maps it into [0, 1] before it is moved back to the CPU.
        preds = birefnet(batch)[-1].sigmoid().cpu()

    # Turn the first (only) prediction into a PIL mask and scale it back up
    # from the model resolution to the input's original size.
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(full_size)
    cutout.putalpha(mask)

    # NOTE(review): fixed relative path -- every call overwrites the same
    # file. Confirm this is acceptable if requests can ever overlap.
    cutout.save("img.png", "PNG")
    return (cutout, "img.png")


# Output preview; RGBA mode keeps the alpha channel that fn attaches.
img1 = gr.Image(type="pil", image_mode="RGBA")
# Input component shown to the user.
image = gr.Image(label="Upload an image")
# Download slot for the PNG file that fn writes.
file = gr.File()


# Example input, loaded once at start-up as a PIL image.
chameleon = load_img("chameleon.jpg", output_type="pil")

# Only referenced by the disabled URL tab further down.
url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"

# Single-function UI: one image in, the cut-out preview and its PNG file out.
demo = gr.Interface(
    fn=fn,
    inputs=image,
    outputs=[img1, file],
    examples=[chameleon],
    api_name="image",
)

# Disabled tabbed layout, kept for reference. It relies on `tab1`, `text` and
# `slider2`, none of which are defined in the code shown here.
# tab2 = gr.Interface(fn, inputs=text, outputs=slider2, examples=[url], api_name="text")


# demo = gr.TabbedInterface(
#     [tab1, tab2], ["image", "text"], title="birefnet for background removal (WIP 🛠️, works for linux)"
# )

if __name__ == "__main__":
    demo.launch()