sammichang committed on
Commit
c859cbd
verified
1 Parent(s): 248ef24

Update app.py

Files changed (1)
  1. app.py +71 -78
app.py CHANGED
@@ -1,78 +1,71 @@
- import gradio as gr
- import cv2
- import torch
- import numpy as np
- from torchvision import transforms
-
- title = "η§»ι™€θƒŒζ™― Demo"  # "Remove Background Demo"
- description = "θ‡ͺε‹•εŽ»ι™€θƒŒζ™―εœ–η‰‡."  # "Automatically removes the image background."
- article = "<p style='text-align: center'><a href='https://news.machinelearning.sg/posts/beautiful_profile_pics_remove_background_image_with_deeplabv3/'>Blog</a> | <a href='https://github.com/eugenesiow/practical-ml'>Github Repo</a></p>"
-
-
- def make_transparent_foreground(pic, mask):
-     # split the image into channels
-     b, g, r = cv2.split(np.array(pic).astype('uint8'))
-     # add an alpha channel and fill it with fully opaque pixels (255)
-     a = np.ones(mask.shape, dtype='uint8') * 255
-     # merge the alpha channel back
-     alpha_im = cv2.merge([b, g, r, a], 4)
-     # create a transparent background
-     bg = np.zeros(alpha_im.shape)
-     # set up the new four-channel mask
-     new_mask = np.stack([mask, mask, mask, mask], axis=2)
-     # copy only the foreground color pixels from the original image where the mask is set
-     foreground = np.where(new_mask, alpha_im, bg).astype(np.uint8)
-
-     return foreground
-
-
- def remove_background(input_image):
-     preprocess = transforms.Compose([
-         transforms.ToTensor(),
-         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-     ])
-
-     input_tensor = preprocess(input_image)
-     input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model
-
-     # move the input and model to GPU for speed if available
-     if torch.cuda.is_available():
-         input_batch = input_batch.to('cuda')
-         model.to('cuda')
-
-     with torch.no_grad():
-         output = model(input_batch)['out'][0]
-     output_predictions = output.argmax(0)
-
-     # create a binary (black and white) mask of the profile foreground
-     mask = output_predictions.byte().cpu().numpy()
-     background = np.zeros(mask.shape)
-     bin_mask = np.where(mask, 255, background).astype(np.uint8)
-
-     foreground = make_transparent_foreground(input_image, bin_mask)
-
-     return foreground, bin_mask
-
-
- def inference(img):
-     foreground, _ = remove_background(img)
-     return foreground
-
-
- torch.hub.download_url_to_file('https://pbs.twimg.com/profile_images/691700243809718272/z7XZUARB_400x400.jpg',
-                                'capture.png')
- torch.hub.download_url_to_file('https://hai.stanford.edu/sites/default/files/styles/person_medium/public/2020-03/hai_1512feifei.png?itok=INFuLABp',
-                                'girl1.png')
- model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
- model.eval()
-
- gr.Interface(
-     inference,
-     gr.Image(type="pil", label="Input"),
-     gr.Image(type="pil", label="Output"),
-     title=title,
-     description=description,
-     article=article,
-     examples=[['removebg/girl1.png'], ['removebg/girl2.png'], ['removebg/girl3.png'], ['removebg/gonfu1.jpg'], ['removebg/angel.png']],
-     # enable_queue=True,
- ).launch(debug=False)
 
+ import torch
+ from torchvision import transforms
+ from transformers import AutoModelForImageSegmentation
+ from PIL import Image
+ import requests
+ from io import BytesIO
+ import gradio as gr
+
+ title = "η§»ι™€θƒŒζ™― Demo"  # "Remove Background Demo"
+ description = "δΈŠε‚³εœ–η‰‡ ,θ‡ͺε‹•εŽ»ι™€θƒŒζ™―."  # "Upload an image; the background is removed automatically."
+
+ # Set up CUDA if available
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ torch.set_float32_matmul_precision("high")
+
+ # Load the BiRefNet segmentation model
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
+     "ZhengPeng7/BiRefNet", trust_remote_code=True
+ )
+ birefnet.to(device)
+
+ # Define image transformations
+ transform_image = transforms.Compose([
+     transforms.Resize((256, 256)),
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ ])
+
+
+ def load_img(image_path_or_url):
+     if image_path_or_url.startswith('http'):
+         response = requests.get(image_path_or_url)
+         img = Image.open(BytesIO(response.content))
+     else:
+         img = Image.open(image_path_or_url)
+     return img.convert("RGB")  # always return a three-channel RGB image
+
+ def process(image):
+     image_size = image.size
+     input_images = transform_image(image).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         preds = birefnet(input_images)[-1].sigmoid().cpu()  # last output is the foreground probability map
+
+     pred = preds[0].squeeze()
+     pred_pil = transforms.ToPILImage()(pred)
+     mask = pred_pil.resize(image_size)  # resize the mask back to the original resolution
+
+     # Create a new image with transparency
+     transparent_image = Image.new("RGBA", image.size)
+     transparent_image.paste(image, (0, 0))
+     transparent_image.putalpha(mask)  # apply the mask as the alpha channel
+
+     return transparent_image  # return the new transparent image
+
+ def remove_background_gradio(image):
+     processed_img = process(image)
+     return processed_img
+
+
+ # Create the Gradio interface with drag-and-drop and paste functionality
+ demo = gr.Interface(
+     fn=remove_background_gradio,
+     inputs=gr.Image(type="pil"),  # 'source' argument removed
+     outputs=gr.Image(type="pil"),
+     title=title,
+     description=description,
+     # examples=[['girl1.png'], ['girl2.png'], ['girl3.png'], ['gonfu1.jpg'], ['removebg/angel.png']],
+ )
+
+ demo.launch(share=True)  # launch the interface and get a shareable link
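
For reference, a minimal sketch of how the new BiRefNet pipeline could be exercised outside the Gradio UI. This is not part of the commit: it assumes app.py can be imported without the `demo.launch(share=True)` call firing (for example by guarding it behind `if __name__ == "__main__":`), and `girl1.png` is only an illustrative local file name.

# Minimal smoke test for the pipeline defined in app.py above.
# Assumptions (not in the commit): app.py is importable without auto-launching
# Gradio, and 'girl1.png' is an illustrative local test image.
from app import load_img, process

img = load_img("girl1.png")      # URL or local path; returns an RGB PIL image
cutout = process(img)            # RGBA image with the background made transparent
cutout.save("girl1_nobg.png")    # save as PNG to preserve the alpha channel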