import asyncio
import os
import warnings
from enum import Enum

import cv2
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from gfpgan import GFPGANer
from realesrgan import RealESRGANer
from tqdm import tqdm


class EnhancementMethod(str, Enum):
    """Names of the enhancement back-ends that ``Enhancer`` can be built with."""

    gfpgan = "gfpgan"
    RestoreFormer = "RestoreFormer"
    codeformer = "codeformer"
    realesrgan = "realesrgan"


class Enhancer:
    """Wrap a GFPGANer face restorer or a stand-alone RealESRGANer upscaler.

    For ``EnhancementMethod.realesrgan`` only ``realesrgan_enhancer`` is
    built. For every other method a ``GFPGANer`` is built as
    ``face_enhancer``, optionally with a RealESRGANer passed to it as
    ``bg_upsampler``.
    """

    def __init__(self, method: EnhancementMethod, background_enhancement=True, upscale=2):
        """Build the models required by ``method``.

        Args:
            method: Which enhancement back-end to use.
            background_enhancement: When true (and ``method`` is not
                ``realesrgan``), try to create a background upsampler and
                hand it to the face enhancer.
            upscale: Scale factor forwarded to the models; it is also
                interpolated into the RealESRGAN weights URL.

        Raises:
            ValueError: If ``method`` is ``realesrgan`` and CUDA is not
                available, or if ``method`` has no face-model configuration.
        """
        self.method = method
        self.background_enhancement = background_enhancement
        self.upscale = upscale
        self.bg_upsampler = None
        self.realesrgan_enhancer = None
        self.face_enhancer = None

        if self.method != EnhancementMethod.realesrgan:
            # The background upsampler has to exist *before* the face
            # enhancer is constructed, because setup_face_enhancer() passes
            # self.bg_upsampler into GFPGANer at construction time.
            if self.background_enhancement:
                self.setup_background_enhancer()
            self.setup_face_enhancer()
        else:
            self.setup_realesrgan_enhancer()

    def _build_realesrgan(self):
        """Create a RealESRGANer configured for ``self.upscale``."""
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=self.upscale,
        )
        model_path = f'https://huggingface.co/dtarnow/UPscaler/resolve/main/RealESRGAN_x{self.upscale}plus.pth'
        return RealESRGANer(
            scale=self.upscale,
            model_path=model_path,
            model=model,
            tile=400,
            tile_pad=10,
            pre_pad=0,
            half=True,
        )

    def setup_background_enhancer(self):
        """Set ``self.bg_upsampler``; warn and leave it ``None`` without CUDA."""
        if not torch.cuda.is_available():
            warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it.')
            return
        self.bg_upsampler = self._build_realesrgan()

    def setup_realesrgan_enhancer(self):
        """Set ``self.realesrgan_enhancer``.

        Raises:
            ValueError: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise ValueError('CUDA is not available for RealESRGAN')
        self.realesrgan_enhancer = self._build_realesrgan()

    def setup_face_enhancer(self):
        """Set ``self.face_enhancer`` to a GFPGANer for ``self.method``.

        Weights are looked up in ``gfpgan/weights`` first, then in
        ``checkpoints``; when neither file exists the configured URL is
        passed to GFPGANer as the model path instead.

        Raises:
            ValueError: If ``self.method`` has no entry in the config table.
        """
        model_configs = {
            EnhancementMethod.gfpgan: {
                'arch': 'clean',
                'channel_multiplier': 2,
                'model_name': 'GFPGANv1.4',
                'url': 'https://huggingface.co/gmk123/GFPGAN/resolve/main/GFPGANv1.4.pth',
            },
            EnhancementMethod.RestoreFormer: {
                'arch': 'RestoreFormer',
                'channel_multiplier': 2,
                'model_name': 'RestoreFormer',
                'url': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
            },
            # NOTE(review): confirm that the installed GFPGANer accepts the
            # 'CodeFormer' arch; this file cannot tell.
            EnhancementMethod.codeformer: {
                'arch': 'CodeFormer',
                'channel_multiplier': 2,
                'model_name': 'CodeFormer',
                'url': 'https://huggingface.co/sinadi/aar/resolve/main/codeformer.pth',
            },
        }

        config = model_configs.get(self.method)
        if config is None:
            raise ValueError(f'Wrong model version {self.method}')

        weights_file = config['model_name'] + '.pth'
        # Default to the remote URL; a local copy, if present, wins.
        model_path = config['url']
        for directory in ('gfpgan/weights', 'checkpoints'):
            candidate = os.path.join(directory, weights_file)
            if os.path.isfile(candidate):
                model_path = candidate
                break

        self.face_enhancer = GFPGANer(
            model_path=model_path,
            upscale=self.upscale,
            arch=config['arch'],
            channel_multiplier=config['channel_multiplier'],
            bg_upsampler=self.bg_upsampler,
        )

    def check_image_resolution(self, image):
        """Return ``(width, height)`` read from the first two axes of ``image.shape``."""
        # Slicing (rather than unpacking exactly three values) also accepts
        # arrays that carry no trailing channel axis.
        height, width = image.shape[:2]
        return width, height

    async def enhance(self, image):
        """Enhance ``image`` in a worker thread.

        Args:
            image: Input array; it is converted with ``cv2.COLOR_RGB2BGR``
                before inference (presumably an H x W x 3 RGB array -
                confirm against callers).

        Returns:
            A tuple ``(enhanced_img, (width, height),
            (enhanced_width, enhanced_height))`` where ``enhanced_img`` has
            been converted back with ``cv2.COLOR_BGR2RGB``.
        """
        img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        width, height = self.check_image_resolution(img)

        # Inference is blocking, so it is pushed off the event loop.
        if self.method == EnhancementMethod.realesrgan:
            enhanced_img, _ = await asyncio.to_thread(
                self.realesrgan_enhancer.enhance,
                img,
                outscale=self.upscale,
            )
        else:
            _, _, enhanced_img = await asyncio.to_thread(
                self.face_enhancer.enhance,
                img,
                has_aligned=False,
                only_center_face=False,
                paste_back=True,
            )

        enhanced_img = cv2.cvtColor(enhanced_img, cv2.COLOR_BGR2RGB)
        enhanced_width, enhanced_height = self.check_image_resolution(enhanced_img)
        return enhanced_img, (width, height), (enhanced_width, enhanced_height)