import os
import torch

from tqdm import tqdm

from src.utils.videoio import load_video_to_cv2

import cv2


# Per-method restorer settings: method -> (arch, channel_multiplier, model_name, url).
_MODEL_SPECS = {
    'gfpgan': (
        'clean',
        2,
        'GFPGANv1.4',
        'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
    ),
    'RestoreFormer': (
        'RestoreFormer',
        2,
        'RestoreFormer',
        'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
    ),
    # TODO: flagged as unfinished by the original author; verify that the
    # installed GFPGANer actually accepts the 'CodeFormer' arch.
    'codeformer': (
        'CodeFormer',
        2,
        'CodeFormer',
        'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
    ),
}

# Local directories searched (in order) for pre-downloaded restorer weights.
_WEIGHT_DIRS = ('gfpgan/weights', 'checkpoints')


class GeneratorWithLen(object):
    """Wrap a generator so that ``len()`` can be called on it.

    From https://stackoverflow.com/a/7460929

    Iterating the wrapper hands back the wrapped generator itself, so the
    wrapper can only be consumed once.
    """

    def __init__(self, gen, length):
        self.gen = gen
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        return self.gen


def _is_video_path(images):
    """Return True when *images* is a path to an existing file.

    Only path-like values are passed to ``os.path.isfile``; calling it on a
    list of frames raises ``TypeError``.
    """
    return isinstance(images, (str, bytes, os.PathLike)) and os.path.isfile(images)


def _import_gfpganer():
    """Import and return ``gfpgan.GFPGANer``, installing ``gfpgan`` if missing.

    Raises:
        ImportError: if the library is missing and cannot be installed or
            imported after installation.
    """
    try:
        from gfpgan import GFPGANer
    except ImportError:
        print("GFPGAN library not found. Installing...")
        try:
            import importlib
            import subprocess
            import sys

            # Invoke pip through the running interpreter so the package is
            # installed into the environment this process imports from.
            subprocess.check_call([sys.executable, "-m", "pip", "install", "gfpgan"])
            # Let the import system notice the freshly installed package.
            importlib.invalidate_caches()
            # Retry the import after installation
            from gfpgan import GFPGANer
            print("GFPGAN library installed successfully!")
        except Exception as e:
            print(f"Failed to install GFPGAN library. Error: {e}")
            # Fail here with a clear cause instead of crashing later on an
            # unbound GFPGANer name.
            raise ImportError(
                'GFPGAN is required for face enhancement but could not be '
                'imported or installed.') from e
    return GFPGANer


def _select_model(method):
    """Return ``(arch, channel_multiplier, model_name, url)`` for *method*.

    Raises:
        ValueError: if *method* is not one of the supported restorer names.
    """
    try:
        return _MODEL_SPECS[method]
    except (KeyError, TypeError):
        raise ValueError(f'Wrong model version {method}.') from None


def _build_bg_upsampler(bg_upsampler):
    """Build the background upsampler, or return None when it is not used.

    A RealESRGAN upsampler is only created when ``bg_upsampler`` equals
    ``'realesrgan'`` and CUDA is available; on CPU a warning is emitted and
    None is returned.
    """
    if bg_upsampler != 'realesrgan':
        return None

    if not torch.cuda.is_available():  # CPU
        import warnings
        warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
                      'If you really want to use it, please modify the corresponding codes.')
        return None

    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
    return RealESRGANer(
        scale=2,
        model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
        model=model,
        tile=400,
        tile_pad=10,
        pre_pad=0,
        half=True)  # need to set False in CPU mode


def _resolve_model_path(model_name, url):
    """Return a local weights file for *model_name* if present, else *url*."""
    filename = model_name + '.pth'
    for directory in _WEIGHT_DIRS:
        candidate = os.path.join(directory, filename)
        if os.path.isfile(candidate):
            return candidate
    # download pre-trained models from url
    return url


def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):
    """Enhance all *images* and return the results as a list.

    Args:
        images: a list of frames or a path to a video file.
        method: restorer name ('gfpgan', 'RestoreFormer' or 'codeformer').
        bg_upsampler: 'realesrgan' to upsample the background, anything else
            to disable background upsampling.
    """
    gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
    return list(gen)


def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'):
    """Provide a generator with a __len__ method so that it can be passed to
    functions that call len().

    Args:
        images: a list of frames or a path to a video file.
        method: restorer name, forwarded to ``enhancer_generator_no_len``.
        bg_upsampler: background upsampler name, forwarded likewise.

    Returns:
        A ``GeneratorWithLen`` whose length is the number of input frames.
    """
    if _is_video_path(images):  # handle video to images
        # TODO: Create a generator version of load_video_to_cv2
        images = load_video_to_cv2(images)

    gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
    return GeneratorWithLen(gen, len(images))


def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'):
    """Yield enhanced frames one at a time.

    Being a generator, this avoids keeping all of the enhanced images in
    memory at the same time, which can save tons of RAM compared to
    ``enhancer_list``. All setup (imports, model construction) runs lazily on
    the first iteration.

    Args:
        images: a list of frames or a path to a video file. Each frame is
            converted RGB -> BGR before restoration and BGR -> RGB afterwards.
        method: restorer name ('gfpgan', 'RestoreFormer' or 'codeformer').
        bg_upsampler: 'realesrgan' to upsample the background (CUDA only),
            anything else to disable background upsampling.

    Yields:
        The restored frame for each input frame, in input order.

    Raises:
        ImportError: if GFPGAN cannot be imported or installed.
        ValueError: if *method* is not a supported restorer name.
    """
    GFPGANer = _import_gfpganer()

    print('face enhancer....')
    if _is_video_path(images):  # handle video to images
        images = load_video_to_cv2(images)

    # ------------------------ set up GFPGAN restorer ------------------------
    arch, channel_multiplier, model_name, url = _select_model(method)

    # ------------------------ set up background upsampler ------------------------
    bg_upsampler = _build_bg_upsampler(bg_upsampler)

    # determine model paths
    model_path = _resolve_model_path(model_name, url)

    restorer = GFPGANer(
        model_path=model_path,
        upscale=2,
        arch=arch,
        channel_multiplier=channel_multiplier,
        bg_upsampler=bg_upsampler)

    # ------------------------ restore ------------------------
    for frame in tqdm(images, 'Face Enhancer:'):
        img = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        # restore faces and background if necessary
        _, _, r_img = restorer.enhance(
            img,
            has_aligned=False,
            only_center_face=False,
            paste_back=True)

        r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
        yield r_img