File size: 3,472 Bytes
003d203
 
 
0e6c023
 
 
aa16383
dcfda89
0e6c023
 
 
 
 
 
 
 
 
 
 
 
 
 
 
003d203
aa16383
 
 
 
 
 
 
dcfda89
 
 
 
aa16383
dcfda89
 
 
 
 
 
aa16383
dcfda89
5fb3d3c
 
 
 
 
aa16383
 
 
 
 
dcfda89
 
 
 
 
30b17a2
 
aa16383
 
dcfda89
aa16383
 
 
dcfda89
 
 
 
aa16383
30b17a2
 
aa16383
 
 
e406805
9a3b780
aa16383
 
 
 
e7bef73
 
0e6c023
 
 
 
 
 
 
 
 
 
 
 
dcfda89
e406805
2ca2e3d
 
 
991bda2
2ca2e3d
e406805
aa16383
 
 
2ca2e3d
 
 
 
1167bff
aa16383
 
 
dcfda89
 
2ca2e3d
dcfda89
9a38c25
 
 
 
6ba9354
30b17a2
 
dcfda89
 
 
aa16383
0e6c023
 
aa16383
 
94386db
0e6c023
 
003d203
0e6c023
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import gradio as gr
import spaces
import torch
from loadimg import load_img
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from diffusers import FluxFillPipeline
from PIL import Image, ImageOps

# Select the "high" float32 matmul precision mode, which lets PyTorch trade
# some float32 matmul precision for speed.
torch.set_float32_matmul_precision("high")

# BiRefNet segmentation model used by `rmbg`; trust_remote_code lets
# transformers run the custom model code published in the hub repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cuda")

# Preprocessing applied before BiRefNet: fixed 1024x1024 resize, conversion
# to a tensor, then per-channel normalization with the usual ImageNet stats.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# FLUX.1 Fill pipeline used by `inpaint`, loaded in bfloat16 and moved to GPU.
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")


def _as_padding(name, value):
    """Return ``value`` as a non-negative integer pixel count.

    ``None`` (e.g. an empty numeric input) counts as 0 and fractional values
    are truncated toward zero, because PIL requires integer border sizes.

    Raises:
        ValueError: If the resulting pixel count is negative.
    """
    if value is None:
        return 0
    pixels = int(value)
    if pixels < 0:
        raise ValueError(f"{name} must be >= 0, got {value!r}")
    return pixels


def prepare_image_and_mask(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
):
    """Pad an image for outpainting and build the matching mask.

    Args:
        image: Source image, passed through ``load_img`` and converted to RGB.
        padding_top, padding_bottom, padding_left, padding_right: Border width
            in pixels to add on each side. ``None`` counts as 0 and fractional
            values are truncated to integers.

    Returns:
        ``(background, mask)``: the RGB image expanded with a white border,
        and an RGB mask of the same size that is black over the original
        pixels and white over the added border (presumably the region the
        fill pipeline should generate).

    Raises:
        ValueError: If any padding is negative.
    """
    # Validate before loading so bad input fails fast. PIL's ``expand`` takes
    # the border as (left, top, right, bottom) and needs integer sizes.
    border = (
        _as_padding("padding_left", padding_left),
        _as_padding("padding_top", padding_top),
        _as_padding("padding_right", padding_right),
        _as_padding("padding_bottom", padding_bottom),
    )
    image = load_img(image).convert("RGB")
    background = ImageOps.expand(image, border=border, fill="white")
    mask = ImageOps.expand(
        Image.new("RGB", image.size, "black"),
        border=border,
        fill="white",
    )
    return background, mask


def inpaint(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
    prompt="",
    num_inference_steps=28,
    guidance_scale=50,
):
    """Outpaint ``image`` by padding it and filling the border with FLUX.1 Fill.

    Args:
        image: Source image, forwarded to ``prepare_image_and_mask``.
        padding_top, padding_bottom, padding_left, padding_right: Border width
            in pixels to add on each side.
        prompt: Text prompt for the fill pipeline (``None`` is treated as "").
        num_inference_steps: Number of denoising steps; coerced to ``int``
            because numeric UI inputs may deliver floats such as ``50.0``.
        guidance_scale: Guidance scale passed to the pipeline, as ``float``.

    Returns:
        The first generated image, converted to RGBA.

    Raises:
        ValueError: If ``num_inference_steps`` is less than 1.

    NOTE(review): these defaults (28 steps, guidance 50) differ from the UI
    defaults (50 steps, guidance 28) — they may be swapped; confirm intent.
    """
    # The pipeline's scheduler builds its timestep grid from this count and
    # rejects non-integer values, so normalize it before doing any work.
    steps = int(num_inference_steps)
    if steps < 1:
        raise ValueError(f"num_inference_steps must be >= 1, got {num_inference_steps!r}")

    background, mask = prepare_image_and_mask(
        image, padding_top, padding_bottom, padding_left, padding_right
    )

    result = pipe(
        prompt=prompt if prompt is not None else "",
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        num_inference_steps=steps,
        guidance_scale=float(guidance_scale),
    ).images[0]

    return result.convert("RGBA")


def rmbg(image, url):
    """Remove the background of an image with BiRefNet.

    Args:
        image: Image from the UI; used whenever it is not None.
        url: Fallback source handed to ``load_img`` when ``image`` is None.

    Returns:
        The source converted to RGB with BiRefNet's predicted mask attached
        as its alpha channel (``putalpha`` turns it into an RGBA image).

    Raises:
        ValueError: If neither ``image`` nor a non-empty ``url`` is given.
    """
    # Compare with None explicitly: the image presumably arrives as a numpy
    # array from the "image" component, whose truth value is ambiguous.
    source = url if image is None else image
    if source is None or (isinstance(source, str) and not source.strip()):
        raise ValueError("rmbg needs either an uploaded image or a URL")

    image = load_img(source).convert("RGB")
    original_size = image.size
    batch = transform_image(image).unsqueeze(0).to("cuda")

    with torch.no_grad():
        # Use the model's last output and squash it to [0, 1] with sigmoid.
        preds = birefnet(batch)[-1].sigmoid().cpu()

    # Drop the batch/channel dims, then scale the 1024x1024 mask back to
    # the original image size before attaching it as alpha.
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(original_size)
    image.putalpha(mask)
    return image


@spaces.GPU
def main(*args, progress=gr.Progress(track_tqdm=True)):
    """Dispatch a GPU request to the tool selected by the first argument.

    The first positional argument is the api selector (1 -> ``rmbg``,
    2 -> ``inpaint``), supplied by a hidden Number input in each tab; the
    remaining arguments are forwarded unchanged to the selected function.

    ``progress`` is not used directly: declaring it with ``track_tqdm=True``
    is what makes Gradio mirror tqdm progress bars in the UI.

    Raises:
        ValueError: If no arguments are given or the selector is unknown.
    """
    if not args:
        raise ValueError("main() requires an api selector as first argument")
    api_num, *rest = args
    # Numeric inputs may arrive as floats; 1.0 and 1 hash equally, so the
    # dict lookup matches either form just like the original ``==`` checks.
    handlers = {1: rmbg, 2: inpaint}
    try:
        handler = handlers[api_num]
    except (KeyError, TypeError):
        raise ValueError(f"unknown api selector: {api_num!r}") from None
    return handler(*rest)


# Each tab calls `main`; the hidden Number is the api selector `main`
# dispatches on (1 = background removal, 2 = outpainting).
rmbg_tab = gr.Interface(
    fn=main,
    inputs=[gr.Number(1, visible=False), "image", "text"],
    outputs=["image"],
    api_name="rmbg",
    # Example rows need one value per input component, including the
    # hidden selector; otherwise the image path lands in the Number input.
    examples=[[1, "./assets/Inpainting mask.png", None]],
)

outpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(2, visible=False),
        "image",
        # precision=0 delivers integers (PIL borders and the step count must
        # be ints); value=0 avoids passing None for untouched fields.
        gr.Number(value=0, precision=0, label="padding top"),
        gr.Number(value=0, precision=0, label="padding bottom"),
        gr.Number(value=0, precision=0, label="padding left"),
        gr.Number(value=0, precision=0, label="padding right"),
        gr.Text(label="prompt"),
        gr.Number(value=50, precision=0, label="num_inference_steps"),
        gr.Number(value=28, label="guidance_scale"),
    ],
    outputs=["image"],
    api_name="outpainting",
)

demo = gr.TabbedInterface(
    [rmbg_tab, outpaint_tab],
    ["remove background", "outpainting"],
    title="Utilities that require GPU",
)


demo.launch()