File size: 4,019 Bytes
9639dd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e84a93e
9639dd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
981b92f
 
 
 
9639dd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import subprocess
import spaces
import torch
import cv2
import uuid
import gradio as gr
import numpy as np

from PIL import Image
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer

def runcmd(cmd, verbose=False):
    """Run a shell command line and capture its output.

    Args:
        cmd: Command line string. It is executed with ``shell=True``, so it must
            come from trusted code (in this file: hard-coded download commands),
            never from user input.
        verbose: When true, print the captured stdout (stripped) and stderr.

    Returns:
        The ``subprocess.CompletedProcess`` for the command, so callers can
        inspect ``returncode``, ``stdout`` and ``stderr``.

    A non-zero exit status is reported on stderr but never raised: callers
    treat this as best-effort, and a missing file surfaces later anyway.
    """
    import sys

    result = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        shell=True,
        check=False,  # best-effort: never raise on a non-zero exit
    )
    if verbose:
        print(result.stdout.strip(), result.stderr)
    if result.returncode != 0:
        # Without this, a failed download would be invisible until model loading
        # fails with an unrelated-looking error.
        print(
            f"runcmd: command exited with status {result.returncode}: {cmd}",
            file=sys.stderr,
        )
    return result


# Fetch each pretrained weight file only when it is not already present in the
# working directory.
for _weights, _download_cmd in (
    ('GFPGANv1.4.pth', "wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ."),
    ('realesr-general-x4v3.pth', "wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ."),
):
    if not os.path.exists(_weights):
        runcmd(_download_cmd)
del _weights, _download_cmd

# Real-ESRGAN general x4 (v3) network. enhance_image uses it directly for
# "Only Image Enhance" and as GFPGAN's background upsampler in the combined mode.
model_path = 'realesr-general-x4v3.pth'
model = SRVGGNetCompact(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_conv=32,
    upscale=4,
    act_type='prelu',
)
# Half precision is enabled only when a CUDA device is reported as available.
half = torch.cuda.is_available()
upsampler = RealESRGANer(
    scale=4,
    model_path=model_path,
    model=model,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    half=half,
)


@spaces.GPU(duration=15)
def enhance_image(
    input_image: Image.Image,
    scale: int,
    enhance_mode: str,
):
    """Restore faces and/or upscale an image, and save the result as a PNG.

    Args:
        input_image: Image to enhance; it is converted to RGB before processing.
        scale: Upscaling factor handed to GFPGAN / Real-ESRGAN (the UI slider
            offers 1-4).
        enhance_mode: "Only Face Enhance" (GFPGAN without a background
            upsampler), "Only Image Enhance" (Real-ESRGAN only), or any other
            value, which behaves like "Face Enhance + Image Enhance" (GFPGAN
            with Real-ESRGAN as background upsampler).

    Returns:
        ``(enhanced_image, enhanced_path)``: the enhanced PIL image and the path
        of its PNG copy under ``/tmp/gradio/output/``.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_image is None:
        # Clicking the button without an upload passes None; report it in the UI
        # instead of failing deep inside OpenCV.
        raise gr.Error("Please upload an image first.")

    only_face = enhance_mode == "Only Face Enhance"
    if enhance_mode == "Only Image Enhance":
        face_enhancer = None
    else:
        face_enhancer = GFPGANer(
            model_path='GFPGANv1.4.pth',
            upscale=scale,
            arch='clean',
            channel_multiplier=2,
            # Face-only mode leaves the background to GFPGAN itself; the
            # combined mode lets Real-ESRGAN upscale the background too.
            bg_upsampler=None if only_face else upsampler,
        )

    # OpenCV-based models expect 3-channel BGR; convert("RGB") also handles
    # grayscale/palette/RGBA inputs.
    img = cv2.cvtColor(np.array(input_image.convert("RGB")), cv2.COLOR_RGB2BGR)

    h, w = img.shape[:2]
    if h < 300:
        # Small inputs get a 2x Lanczos pre-upscale (presumably to give the models
        # more pixels); the final output is then 2*scale times the input size.
        img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

    if face_enhancer is None:
        output, _ = upsampler.enhance(img, outscale=scale)
    else:
        # Restore every detected face in both face modes. Restricting face-only
        # mode to the center face dropped all other faces in group photos.
        _, _, output = face_enhancer.enhance(
            img, has_aligned=False, only_center_face=False, paste_back=True
        )

    enhanced_image = Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))

    target_dir = "/tmp/gradio/output/"
    # exist_ok avoids a FileExistsError race when concurrent requests both try
    # to create the directory.
    os.makedirs(target_dir, exist_ok=True)
    enhanced_path = f"{target_dir}{uuid.uuid4()}.png"
    enhanced_image.save(enhanced_path)

    return enhanced_image, enhanced_path


def create_demo() -> gr.Blocks:
    """Build the Gradio Blocks UI that wires enhance_image to its controls.

    Layout:
        Row 1: a scale slider (1-4), and a column holding the mode dropdown
            and the "Enhance Image" button.
        Row 2: the input image, and a column holding the enhanced image plus
            a downloadable file of it.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    mode_choices = [
        "Only Face Enhance",
        "Only Image Enhance",
        "Face Enhance + Image Enhance",
    ]

    with gr.Blocks() as app:
        with gr.Row():
            with gr.Column():
                scale_slider = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Scale")
            with gr.Column():
                mode_dropdown = gr.Dropdown(
                    label="Enhance Mode",
                    choices=mode_choices,
                    value=mode_choices[-1],  # default: face + image enhance
                )
                run_button = gr.Button("Enhance Image")

        with gr.Row():
            with gr.Column():
                source_view = gr.Image(label="Input Image", type="pil")
            with gr.Column():
                result_view = gr.Image(label="Enhanced Image", type="pil", interactive=False)
                result_file = gr.File(label="Download the Enhanced Image", interactive=False)

        # Input order must match enhance_image(input_image, scale, enhance_mode).
        run_button.click(
            fn=enhance_image,
            inputs=[source_view, scale_slider, mode_dropdown],
            outputs=[result_view, result_file],
        )

    return app