guetLzy's picture
Update app.py
976a1fb verified
import os
import cv2
import subprocess
import subprocess
import subprocess
# Fix the import statement in basicsr before basicsr is imported further down.
# basicsr/data/degradations.py imports `rgb_to_grayscale` from
# `torchvision.transforms.functional_tensor`; the patch points it at
# `torchvision.transforms.functional` instead (presumably because the installed
# torchvision no longer ships `functional_tensor` -- confirm against the
# deployed torchvision version).
import importlib.util

_OLD_BASICSR_IMPORT = "from torchvision.transforms.functional_tensor import rgb_to_grayscale"
_NEW_BASICSR_IMPORT = "from torchvision.transforms.functional import rgb_to_grayscale"


def _replace_in_file(path, old, new):
    """Replace every occurrence of ``old`` with ``new`` in the text file ``path``.

    Args:
        path: Path of the UTF-8 text file to edit in place.
        old: Exact text to search for.
        new: Replacement text.

    Returns:
        True if the file was rewritten, False if ``old`` was not present
        (for example because the file had already been patched).
    """
    with open(path, encoding="utf-8") as fh:
        text = fh.read()
    if old not in text:
        return False
    with open(path, "w", encoding="utf-8") as fh:
        fh.write(text.replace(old, new))
    return True


def _patch_basicsr_degradations():
    """Rewrite the outdated torchvision import inside the installed basicsr.

    The package directory is looked up with ``importlib.util.find_spec`` on
    the top-level name only, so basicsr itself is not imported here and no
    interpreter-specific site-packages path has to be hard-coded.

    Returns:
        True if ``degradations.py`` was modified, False if basicsr or the
        file could not be found or no change was necessary. A missing
        basicsr is not treated as an error here; the regular
        ``from basicsr...`` import later in the file reports it.
    """
    try:
        spec = importlib.util.find_spec("basicsr")
    except (ImportError, ValueError):
        return False
    if spec is None or not spec.submodule_search_locations:
        return False
    for package_dir in spec.submodule_search_locations:
        target = os.path.join(package_dir, "data", "degradations.py")
        if os.path.isfile(target):
            return _replace_in_file(target, _OLD_BASICSR_IMPORT, _NEW_BASICSR_IMPORT)
    return False


_patch_basicsr_degradations()
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
import gradio as gr
# Model configuration: maps each selectable model name to a zero-argument
# network factory ("model"), the network's native scale ("netscale") and the
# URL of its pretrained weights ("file_url").
def _rrdbnet_builder(num_block, scale):
    """Return a zero-argument factory for an RRDBNet with the given depth and scale."""
    def build():
        return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=num_block, num_grow_ch=32, scale=scale)
    return build


def _srvgg_builder(num_conv):
    """Return a zero-argument factory for a 4x SRVGGNetCompact with ``num_conv`` conv layers."""
    def build():
        return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
    return build


MODEL_OPTIONS = {
    "RealESRGAN_x4plus": dict(
        model=_rrdbnet_builder(num_block=23, scale=4),
        netscale=4,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    ),
    "RealESRNet_x4plus": dict(
        model=_rrdbnet_builder(num_block=23, scale=4),
        netscale=4,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
    ),
    "RealESRGAN_x4plus_anime_6B": dict(
        model=_rrdbnet_builder(num_block=6, scale=4),
        netscale=4,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
    ),
    "RealESRGAN_x2plus": dict(
        model=_rrdbnet_builder(num_block=23, scale=2),
        netscale=2,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
    ),
    "realesr-animevideov3": dict(
        model=_srvgg_builder(num_conv=16),
        netscale=4,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
    ),
    "realesr-general-x4v3": dict(
        model=_srvgg_builder(num_conv=32),
        netscale=4,
        file_url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
    ),
}
def load_model(model_name):
    """Load and initialise the Real-ESRGAN upsampler for ``model_name``.

    The weights are looked up at ``weights/<model_name>.pth`` and downloaded
    from the model's ``file_url`` when that file does not exist yet.

    Args:
        model_name: Key into ``MODEL_OPTIONS``.

    Returns:
        A ``RealESRGANer`` wrapping the freshly built network.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_OPTIONS``.
    """
    model_config = MODEL_OPTIONS[model_name]
    model = model_config["model"]()
    netscale = model_config["netscale"]
    file_url = model_config["file_url"]

    # Download the model weights on first use. The file name is passed
    # explicitly so the downloaded file always matches the path checked here,
    # rather than relying on the URL's basename being "<model_name>.pth".
    weights_dir = "weights"
    weights_file = f"{model_name}.pth"
    model_path = os.path.join(weights_dir, weights_file)
    if not os.path.isfile(model_path):
        os.makedirs(weights_dir, exist_ok=True)
        model_path = load_file_from_url(
            url=file_url,
            model_dir=weights_dir,
            progress=True,
            file_name=weights_file,
        )

    # Initialise RealESRGANer. fp16 is only requested when a CUDA device is
    # present; previously it was forced on unconditionally, including on
    # CPU-only hosts.
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        model=model,
        tile=0,  # no tiling by default
        tile_pad=10,
        pre_pad=0,
        half=torch.cuda.is_available(),
    )
    return upsampler
def enhance_image(input_image, model_name, outscale, face_enhance):
    """Run super-resolution on an uploaded image.

    Args:
        input_image: Image array from the Gradio input component; it is
            converted with ``cv2.COLOR_RGB2BGR``, so RGB channel order is
            assumed. ``None`` when nothing has been uploaded.
        model_name: Key into ``MODEL_OPTIONS`` selecting the upsampler.
        outscale: Requested final upscaling factor.
        face_enhance: When true, faces are restored with GFPGAN and the
            Real-ESRGAN upsampler is used as its background upsampler.

    Returns:
        The enhanced image converted back with ``cv2.COLOR_BGR2RGB``.

    Raises:
        gr.Error: If no image was provided.
    """
    # Without this guard a missing upload reaches cv2.cvtColor and fails with
    # an opaque OpenCV error instead of a message in the UI.
    if input_image is None:
        raise gr.Error("请先上传图像")

    # Convert the Gradio upload to OpenCV's BGR layout.
    img = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)

    # Load the model.
    upsampler = load_model(model_name)

    if face_enhance:
        # Imported lazily so gfpgan is only required when the option is used.
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=outscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler,
        )
        _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
    else:
        output, _ = upsampler.enhance(img, outscale=outscale)

    # Convert the result back to RGB for display in Gradio.
    return cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
# Gradio interface: inputs and controls in the left column, the enhanced
# result in the right column.
with gr.Blocks(theme="NoCrypt/miku", title="Real-ESRGAN 图像超分辨率") as app:
    gr.Markdown("## Real-ESRGAN 图像超分辨率系统")
    gr.Markdown("上传图像,选择模型和参数,生成高清图像!")

    with gr.Row():
        # Left column: upload, model choice, scale, face-enhance toggle, run button.
        with gr.Column():
            input_image = gr.Image(type="numpy", label="输入图像")
            model_dropdown = gr.Dropdown(
                label="选择模型",
                choices=list(MODEL_OPTIONS),
                value="RealESRGAN_x4plus",
            )
            outscale = gr.Slider(
                label="放大倍数",
                minimum=1,
                maximum=4,
                step=0.5,
                value=4,
            )
            face_enhance = gr.Checkbox(value=False, label="启用人脸增强")
            enhance_btn = gr.Button(value="开始增强", variant="primary")

        # Right column: the enhanced image.
        with gr.Column():
            output_image = gr.Image(type="numpy", label="增强结果")

    # Wire the button to the enhancement function; also exposed as the
    # "enhance" API endpoint.
    enhance_btn.click(
        enhance_image,
        [input_image, model_dropdown, outscale, face_enhance],
        [output_image],
        api_name="enhance",
    )

if __name__ == "__main__":
    app.launch(server_port=7860, server_name="0.0.0.0")