File size: 3,767 Bytes
542c815
3f8e328
542c815
 
 
a888400
d6e753e
 
 
8a357d1
542c815
4f91b95
 
542c815
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db969ee
542c815
 
1605763
542c815
 
 
 
 
 
 
 
 
70974c3
542c815
 
 
 
 
 
 
70974c3
542c815
 
 
 
 
 
70974c3
542c815
 
d909bca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542c815
d909bca
 
 
 
542c815
d909bca
 
c530952
d909bca
c530952
 
 
 
 
 
 
 
 
d909bca
c530952
 
d909bca
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import gradio as gr
from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple

# Build the network once at import time and restore its weights from disk.
net = BriaRMBG()
model_path = "./model.pth"
if not torch.cuda.is_available():
    # No GPU present: remap the checkpoint's tensors onto the CPU while loading.
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
else:
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
# Put the module in evaluation mode before serving requests.
net.eval()

def image_size_by_min_resolution(
    image: "Image.Image",
    resolution: Tuple,
    resample=None,
) -> Tuple[int, int]:
    """Return the size that scales ``image`` so its shorter side hits the target.

    The aspect ratio is preserved: both sides are scaled by
    ``min(resolution) / min(image.size)`` and rounded down.

    Args:
        image: Object exposing a ``size`` attribute of ``(width, height)``.
        resolution: Candidate target lengths; only the smallest one is used.
        resample: Unused; kept so existing callers keep working.

    Returns:
        ``(new_width, new_height)`` as integers.

    Raises:
        ZeroDivisionError: If the shorter side of ``image`` is 0.
    """
    w, h = image.size

    image_min = min(w, h)
    resolution_min = min(resolution)

    # Multiply before dividing. Floor-dividing by the float ratio
    # ``image_min / resolution_min`` loses a pixel whenever that ratio is not
    # exactly representable (e.g. ``1 // (1 / 10)`` evaluates to 9, not 10).
    return (
        int(w * resolution_min // image_min),
        int(h * resolution_min // image_min),
    )
    

def resize_image(image):
    """Convert ``image`` to RGB and resize it for the network.

    The target size comes from ``image_size_by_min_resolution`` with a
    ``(1024, 1024)`` resolution; resizing uses bilinear resampling.
    """
    rgb = image.convert('RGB')
    target_size = image_size_by_min_resolution(image=rgb, resolution=(1024, 1024))
    return rgb.resize(target_size, Image.BILINEAR)


def process(image):
    """Remove the background from ``image`` and return a before/after pair.

    The parameter was previously named ``Image``, which shadowed the PIL
    module while the body read an undefined ``image``; every call failed with
    ``NameError``. It is called positionally by the Gradio interface.

    Args:
        image: Array accepted by ``PIL.Image.fromarray`` (presumably the numpy
            array delivered by the Gradio image input — confirm against the
            interface wiring).

    Returns:
        ``[orig_image, new_im]``: the input as a PIL image, and an RGBA image
        of the same size whose transparency follows the predicted mask.
    """
    # prepare input
    orig_image = Image.fromarray(image)
    w, h = orig_image.size
    resized = resize_image(orig_image)
    im_np = np.array(resized)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # inference: no gradients are needed, so skip building the autograd graph
    with torch.no_grad():
        result = net(im_tensor)

    # post process: resize the first prediction back to the input resolution
    result = torch.squeeze(
        F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0
    )
    ma = torch.max(result)
    mi = torch.min(result)
    if ma > mi:
        result = (result - mi) / (ma - mi)
    else:
        # A flat prediction would divide by zero and yield NaN, whose cast to
        # uint8 is undefined; fall back to a deterministic all-zero mask.
        result = torch.zeros_like(result)

    # image to pil
    im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))

    # paste the original onto a fully transparent canvas through the mask;
    # a 3-tuple fill would leave the canvas opaque black instead
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)

    return [orig_image, new_im]


# block = gr.Blocks().queue()

# with block:
#     gr.Markdown("## BRIA RMBG 1.4")
#     gr.HTML('''
#       <p style="margin-bottom: 10px; font-size: 94%">
#         This is a demo for BRIA RMBG 1.4 that using
#         <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone. 
#       </p>
#     ''')
#     with gr.Row():
#         with gr.Column():
#             input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
#             # input_image = gr.Image(sources=None, type="numpy") # None for upload, ctrl+v and webcam
#             run_button = gr.Button(value="Run")
            
#         with gr.Column():
#             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
#     ips = [input_image]
#     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])

# block.launch(debug = True)

# block = gr.Blocks().queue()

# NOTE(review): these two components are created outside any gr.Blocks
# context, so they do not appear to be attached to `demo` — confirm whether
# they should move into a Blocks layout or be dropped.
gr.Markdown("## BRIA RMBG 1.4")
gr.HTML('''
  <p style="margin-bottom: 10px; font-size: 94%">
    This is a demo for BRIA RMBG 1.4 that using
    <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone. 
  </p>
''')

# Interface wiring: one image in, a before/after slider out.
title = "Background Removal"
description = "Remove Image Background"
examples = [['./input.jpg'],]
output = ImageSlider(position=0.5,label='Image without background', type="pil", show_download_button=True)
# Use the lowercase "image" shortcut: it is the documented spelling for the
# Image component, whereas the capitalised "Image" may not resolve.
demo = gr.Interface(fn=process,inputs="image", outputs=output, examples=examples, title=title, description=description)

if __name__ == "__main__":
    demo.launch(share=False)