File size: 4,309 Bytes
9639dd1 e84a93e 9639dd1 981b92f 55c1cea f590c37 34afbba a362aff 34afbba 9639dd1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import os
import subprocess
import spaces
import torch
import cv2
import uuid
import gradio as gr
import numpy as np
from PIL import Image
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer
def runcmd(cmd, verbose=False):
    """Run *cmd* through the shell, wait for it to finish, and return the result.

    Args:
        cmd: Command line string. It is executed with ``shell=True``, so it must
            only ever be built from trusted, hard-coded text (as the weight
            download commands in this file are), never from user input.
        verbose: If true, print the command's stripped stdout followed by its
            stderr once the command has finished.

    Returns:
        The ``subprocess.CompletedProcess`` with ``returncode`` and text
        ``stdout``/``stderr``. Callers that ignore the return value are
        unaffected.
    """
    import sys  # only needed for the failure warning below

    completed = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        shell=True,
    )
    if verbose:
        print(completed.stdout.strip(), completed.stderr)
    if completed.returncode != 0:
        # Surface failures (e.g. a failed model download) instead of silently
        # dropping them; warn rather than raise so best-effort callers still
        # continue as before.
        print(
            f"runcmd: command exited with status {completed.returncode}: {cmd}",
            file=sys.stderr,
        )
    return completed
# Fetch the pretrained weights into the working directory on first start-up;
# later start-ups find the files already present and skip the download.
# NOTE(review): only file existence is checked, so a truncated file left by a
# failed download would presumably be reused on the next start — confirm.
for _weights_file, _download_cmd in (
    (
        'GFPGANv1.4.pth',
        "wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .",
    ),
    (
        'realesr-general-x4v3.pth',
        "wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .",
    ),
):
    if not os.path.exists(_weights_file):
        runcmd(_download_cmd)
del _weights_file, _download_cmd

# Real-ESRGAN "general x4v3" weights on a compact SRVGG network (4x upscale).
model_path = 'realesr-general-x4v3.pth'
model = SRVGGNetCompact(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_conv=32,
    upscale=4,
    act_type='prelu',
)
# Half precision only when a CUDA device is available (is_available() is a bool).
half = torch.cuda.is_available()
# Shared Real-ESRGAN upsampler; enhance_image uses it both directly and as
# GFPGAN's background upsampler.
upsampler = RealESRGANer(
    scale=4,
    model_path=model_path,
    model=model,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    half=half,
)
def _make_face_enhancer(scale: int, enhance_mode: str):
    """Return the GFPGAN face restorer for *enhance_mode*, or None for image-only mode."""
    if enhance_mode == "Only Face Enhance":
        # Faces only: no background upsampler is attached.
        return GFPGANer(model_path='GFPGANv1.4.pth', upscale=scale, arch='clean', channel_multiplier=2)
    if enhance_mode == "Only Image Enhance":
        return None
    # Any other value (the UI default "Face Enhance + Image Enhance"): restore
    # faces and upsample the background with the module-level Real-ESRGAN upsampler.
    return GFPGANer(
        model_path='GFPGANv1.4.pth',
        upscale=scale,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=upsampler,
    )


def _fit_within(img: np.ndarray, max_size: int) -> np.ndarray:
    """Downscale *img* (keeping aspect ratio) so neither side exceeds *max_size* pixels."""
    h, w = img.shape[0:2]
    new_h, new_w = h, w
    if new_h > max_size:
        # max(1, ...) keeps extreme aspect ratios from truncating a side to 0.
        new_w = max(1, int(new_w * max_size / new_h))
        new_h = max_size
    if new_w > max_size:
        new_h = max(1, int(new_h * max_size / new_w))
        new_w = max_size
    if (new_h, new_w) == (h, w):
        return img  # already within bounds: skip the needless resize
    return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)


def _save_png(image: Image.Image) -> str:
    """Write *image* as a uniquely named PNG under /tmp/gradio/output/ and return its path."""
    target_dir = "/tmp/gradio/output/"
    # exist_ok avoids a FileExistsError race when concurrent requests both
    # find the directory missing and try to create it.
    os.makedirs(target_dir, exist_ok=True)
    path = f"{target_dir}{uuid.uuid4()}.png"
    image.save(path)  # PNG is lossless; a JPEG-style quality option does not apply
    return path


@spaces.GPU(duration=15)
def enhance_image(
    input_image: Image.Image,
    scale: int,
    enhance_mode: str,
):
    """Enhance *input_image* with GFPGAN face restoration and/or Real-ESRGAN upscaling.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input; None when
            the user clicks the button without uploading anything.
        scale: Upscale factor passed to GFPGANer / RealESRGANer (the UI slider
            offers 1-4).
        enhance_mode: "Only Face Enhance", "Only Image Enhance", or
            "Face Enhance + Image Enhance"; any other value behaves like the last.

    Returns:
        ``(enhanced_image, enhanced_path)``: the result as a PIL image and the
        path of the PNG copy written for download.

    Raises:
        gr.Error: if no input image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    face_enhancer = _make_face_enhancer(scale, enhance_mode)
    # NOTE(review): "Only Face Enhance" also sets only_center_face=True, so only
    # the centre face is restored in group photos — confirm this is intended.
    only_face = enhance_mode == "Only Face Enhance"

    # Normalise to 3-channel RGB first (grayscale/palette uploads would break the
    # RGB->BGR conversion), then switch to OpenCV's BGR channel order.
    img = cv2.cvtColor(np.array(input_image.convert("RGB")), cv2.COLOR_RGB2BGR)
    h, w = img.shape[0:2]
    if h < 300:
        # Short inputs are pre-doubled with Lanczos before enhancement.
        img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

    if face_enhancer is not None:
        _, _, output = face_enhancer.enhance(
            img, has_aligned=False, only_center_face=only_face, paste_back=True
        )
    else:
        output, _ = upsampler.enhance(img, outscale=scale)

    output = _fit_within(output, 3480)
    enhanced_image = Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
    return enhanced_image, _save_png(enhanced_image)
def create_demo() -> gr.Blocks:
    """Build the Gradio Blocks UI for the image enhancer.

    Layout:
        - The first row holds the scale slider, the mode dropdown and the run button.
        - The second row holds the input image, the enhanced image and a
          download file.

    Clicking the button calls ``enhance_image(input_image, scale, enhance_mode)``
    and shows its two return values.

    Returns:
        The constructed ``gr.Blocks``; it is not launched here.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                scale = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Scale")
            with gr.Column():
                enhance_mode = gr.Dropdown(
                    label="Enhance Mode",
                    choices=[
                        "Only Face Enhance",
                        "Only Image Enhance",
                        "Face Enhance + Image Enhance",
                    ],
                    value="Face Enhance + Image Enhance",
                )
                g_btn = gr.Button("Enhance Image")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
            with gr.Column():
                output_image = gr.Image(label="Enhanced Image", type="pil", interactive=False)
                enhance_image_path = gr.File(label="Download the Enhanced Image", interactive=False)
        # Event listeners are registered inside the Blocks context.
        g_btn.click(
            fn=enhance_image,
            inputs=[input_image, scale, enhance_mode],
            outputs=[output_image, enhance_image_path],
        )
    return demo