zhiweili commited on
Commit
9639dd1
·
1 Parent(s): bc3d042

add app_enhance

Browse files
Files changed (5) hide show
  1. .gitignore +3 -0
  2. app.py +10 -0
  3. app_enhance.py +119 -0
  4. enhance_utils.py +54 -0
  5. requirements.txt +7 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .vscode
2
+ .DS_Store
3
+ __pycache__
app.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from app_enhance import create_demo as create_demo_enhance
4
+
5
+ with gr.Blocks(css="style.css") as demo:
6
+ with gr.Tabs():
7
+ with gr.Tab(label="Enhance"):
8
+ create_demo_enhance()
9
+
10
+ demo.launch()
app_enhance.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import spaces
4
+ import torch
5
+ import cv2
6
+ import uuid
7
+ import gradio as gr
8
+ import numpy as np
9
+
10
+ from PIL import Image
11
+ from basicsr.archs.srvgg_arch import SRVGGNetCompact
12
+ from gfpgan.utils import GFPGANer
13
+ from realesrgan.utils import RealESRGANer
14
+
15
+ def runcmd(cmd, verbose = False):
16
+
17
+ process = subprocess.Popen(
18
+ cmd,
19
+ stdout = subprocess.PIPE,
20
+ stderr = subprocess.PIPE,
21
+ text = True,
22
+ shell = True
23
+ )
24
+ std_out, std_err = process.communicate()
25
+ if verbose:
26
+ print(std_out.strip(), std_err)
27
+ pass
28
+
29
+
30
+ if not os.path.exists('GFPGANv1.4.pth'):
31
+ runcmd("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
32
+ if not os.path.exists('realesr-general-x4v3.pth'):
33
+ runcmd("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
34
+
35
+
36
+
37
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
38
+ model_path = 'realesr-general-x4v3.pth'
39
+ half = True if torch.cuda.is_available() else False
40
+ upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
41
+
42
+
43
+ @spaces.GPU(duration=5)
44
+ def enhance_image(
45
+ input_image: Image,
46
+ scale: int,
47
+ enhance_mode: str,
48
+ ):
49
+ only_face = enhance_mode == "Only Face Enhance"
50
+ if enhance_mode == "Only Face Enhance":
51
+ face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=scale, arch='clean', channel_multiplier=2)
52
+ elif enhance_mode == "Only Image Enhance":
53
+ face_enhancer = None
54
+ else:
55
+ face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=scale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
56
+
57
+ img = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
58
+
59
+ h, w = img.shape[0:2]
60
+ if h < 300:
61
+ img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
62
+
63
+ if face_enhancer is not None:
64
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=only_face, paste_back=True)
65
+ else:
66
+ output, _ = upsampler.enhance(img, outscale=scale)
67
+
68
+ if scale != 2:
69
+ interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
70
+ h, w = img.shape[0:2]
71
+ output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
72
+
73
+ enhanced_image = Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
74
+ tmpPrefix = "/tmp/gradio/"
75
+
76
+ extension = 'png'
77
+
78
+ targetDir = f"{tmpPrefix}output/"
79
+ if not os.path.exists(targetDir):
80
+ os.makedirs(targetDir)
81
+
82
+ enhanced_path = f"{targetDir}{uuid.uuid4()}.{extension}"
83
+ enhanced_image.save(enhanced_path, quality=100)
84
+
85
+ return enhanced_image, enhanced_path
86
+
87
+
88
+ def create_demo() -> gr.Blocks:
89
+
90
+ with gr.Blocks() as demo:
91
+ with gr.Row():
92
+ with gr.Column():
93
+ scale = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Scale")
94
+ with gr.Column():
95
+ enhance_mode = gr.Dropdown(
96
+ label="Enhance Mode",
97
+ choices=[
98
+ "Only Face Enhance",
99
+ "Only Image Enhance",
100
+ "Face Enhance + Image Enhance",
101
+ ],
102
+ value="Face Enhance + Image Enhance",
103
+ )
104
+ g_btn = gr.Button("Enhance Image")
105
+ with gr.Row():
106
+ with gr.Column():
107
+ input_image = gr.Image(label="Input Image", type="pil")
108
+ with gr.Column():
109
+ output_image = gr.Image(label="Enhanced Image", type="pil", interactive=False)
110
+ enhance_image_path = gr.File(label="Download the Enhanced Image", interactive=False)
111
+
112
+
113
+ g_btn.click(
114
+ fn=enhance_image,
115
+ inputs=[input_image, scale, enhance_mode],
116
+ outputs=[output_image, enhance_image_path],
117
+ )
118
+
119
+ return demo
enhance_utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ import subprocess
6
+
7
+ from PIL import Image
8
+ from gfpgan.utils import GFPGANer
9
+ from basicsr.archs.srvgg_arch import SRVGGNetCompact
10
+ from realesrgan.utils import RealESRGANer
11
+
12
+ def runcmd(cmd, verbose = False, *args, **kwargs):
13
+
14
+ process = subprocess.Popen(
15
+ cmd,
16
+ stdout = subprocess.PIPE,
17
+ stderr = subprocess.PIPE,
18
+ text = True,
19
+ shell = True
20
+ )
21
+ std_out, std_err = process.communicate()
22
+ if verbose:
23
+ print(std_out.strip(), std_err)
24
+ pass
25
+
26
+ runcmd("pip freeze")
27
+ if not os.path.exists('GFPGANv1.4.pth'):
28
+ runcmd("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
29
+ if not os.path.exists('realesr-general-x4v3.pth'):
30
+ runcmd("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
31
+
32
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
33
+ model_path = 'realesr-general-x4v3.pth'
34
+ half = True if torch.cuda.is_available() else False
35
+ upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
36
+
37
+ face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2)
38
+
39
+ def enhance_image(
40
+ pil_image: Image,
41
+ enhance_face: bool = True,
42
+ ):
43
+ img = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
44
+
45
+ h, w = img.shape[0:2]
46
+ if h < 300:
47
+ img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
48
+ if enhance_face:
49
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=True, paste_back=True)
50
+ else:
51
+ output, _ = upsampler.enhance(img, outscale=2)
52
+ pil_output = Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
53
+
54
+ return pil_output
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ gradio
3
+ spaces
4
+ gfpgan
5
+ git+https://github.com/XPixelGroup/BasicSR@master
6
+ facexlib
7
+ realesrgan