lop / src /utils /face_enhancer.py
thepianist9's picture
Upload folder using huggingface_hub
5a256aa verified
import os
import torch
from tqdm import tqdm
from src.utils.videoio import load_video_to_cv2
import cv2
class GeneratorWithLen(object):
    """Pair an iterable with an explicitly supplied length.

    Generators have no ``__len__``; this wrapper lets one be handed to code
    that calls ``len()`` on its input before iterating over it.

    From https://stackoverflow.com/a/7460929
    """

    def __init__(self, gen, length):
        # gen: the wrapped generator / iterable.
        # length: the value reported by len(); it is trusted, not verified.
        self.gen = gen
        self.length = length

    def __len__(self):
        """Return the length given at construction time."""
        return self.length

    def __iter__(self):
        """Return an iterator over the wrapped object.

        ``iter()`` hands back the very same object when ``gen`` is already an
        iterator (e.g. a generator), and additionally makes plain iterables
        such as lists usable instead of failing with ``TypeError``.
        """
        return iter(self.gen)
def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):
    """Enhance every frame and return the results as a list.

    Thin eager wrapper: the arguments are forwarded unchanged to
    ``enhancer_generator_no_len`` and the generator is drained, so all
    enhanced frames are held in memory at once.
    """
    return list(
        enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
    )
def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'):
    """Return a lazy enhancer that also supports ``len()``.

    Provides a generator with a ``__len__`` method so that it can be passed
    to functions that call ``len()`` on their input.

    Args:
        images: either a path to a video file, or an already loaded sequence
            of frames.
        method: forwarded to ``enhancer_generator_no_len``.
        bg_upsampler: forwarded to ``enhancer_generator_no_len``.

    Returns:
        A ``GeneratorWithLen`` wrapping the enhancing generator; its length
        is ``len(images)`` taken after any video has been loaded.
    """
    # Only probe the filesystem for path-like input: os.path.isfile() raises
    # TypeError when it is handed a list of frames.
    if isinstance(images, (str, bytes, os.PathLike)) and os.path.isfile(images):
        # handle video to images
        # TODO: Create a generator version of load_video_to_cv2
        images = load_video_to_cv2(images)

    gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
    return GeneratorWithLen(gen, len(images))
def _restorer_settings(method):
    """Map ``method`` to ``(arch, channel_multiplier, model_name, url)``; raise ValueError if unknown."""
    if method == 'gfpgan':
        return (
            'clean',
            2,
            'GFPGANv1.4',
            'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
        )
    if method == 'RestoreFormer':
        return (
            'RestoreFormer',
            2,
            'RestoreFormer',
            'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
        )
    if method == 'codeformer':  # TODO:
        # NOTE(review): confirm that GFPGANer actually accepts arch='CodeFormer'.
        return (
            'CodeFormer',
            2,
            'CodeFormer',
            'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
        )
    raise ValueError(f'Wrong model version {method}.')


def _import_gfpganer():
    """Import and return ``gfpgan.GFPGANer``, attempting one ``pip install gfpgan`` if it is missing."""
    try:
        from gfpgan import GFPGANer
    except ImportError:
        print("GFPGAN library not found. Installing...")
        import subprocess
        import sys
        try:
            # Run pip through the current interpreter so the package lands in
            # the environment this process is actually using.
            subprocess.check_call([sys.executable, "-m", "pip", "install", "gfpgan"])
            # Retry the import after installation
            from gfpgan import GFPGANer
        except (subprocess.CalledProcessError, OSError, ImportError) as e:
            print(f"Failed to install GFPGAN library. Error: {e}")
            # Fail here with a clear cause instead of hitting an unbound
            # GFPGANer name further down.
            raise ImportError(
                'GFPGAN is required for face enhancement and could not be installed.'
            ) from e
        print("GFPGAN library installed successfully!")
    return GFPGANer


def _make_bg_upsampler(bg_upsampler):
    """Build the RealESRGAN background upsampler, or return None when it is not requested or CUDA is unavailable."""
    if bg_upsampler != 'realesrgan':
        return None

    if not torch.cuda.is_available():  # CPU
        import warnings
        warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
                      'If you really want to use it, please modify the corresponding codes.')
        return None

    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
    return RealESRGANer(
        scale=2,
        model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
        model=model,
        tile=400,
        tile_pad=10,
        pre_pad=0,
        half=True)  # need to set False in CPU mode


def _find_model_path(model_name, url):
    """Return the first local ``<dir>/<model_name>.pth`` that exists, else ``url`` so the weights get downloaded."""
    for folder in ('gfpgan/weights', 'checkpoints'):
        candidate = os.path.join(folder, model_name + '.pth')
        if os.path.isfile(candidate):
            return candidate
    # download pre-trained models from url
    return url


def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'):
    """Lazily enhance frames one at a time.

    Provide a generator function so that all of the enhanced images don't
    need to be stored in memory at the same time. This can save tons of RAM
    compared to the enhancer function.

    Args:
        images: a sequence of frames, or a path to a video file which is
            loaded with ``load_video_to_cv2``. Frames are treated as RGB:
            they are converted to BGR before restoration and back afterwards.
        method: one of ``'gfpgan'``, ``'RestoreFormer'`` or ``'codeformer'``.
        bg_upsampler: ``'realesrgan'`` to upsample the background (only used
            when CUDA is available); any other value disables it.

    Yields:
        The restored frame for each input frame, in RGB order.

    Raises:
        ValueError: if ``method`` is not a supported name.
        ImportError: if ``gfpgan`` is missing and cannot be installed.

    As with any generator, nothing runs (and nothing is raised) until the
    first item is requested.
    """
    # Validate the cheap argument first so a typo fails before any package
    # installation, video decoding or model loading is attempted.
    arch, channel_multiplier, model_name, url = _restorer_settings(method)

    GFPGANer = _import_gfpganer()

    print('face enhancer....')
    # Only path-like input is probed on disk; os.path.isfile() raises
    # TypeError for frame containers such as lists, tuples or arrays.
    if isinstance(images, (str, bytes, os.PathLike)) and os.path.isfile(images):
        # handle video to images
        images = load_video_to_cv2(images)

    # ------------------------ set up background upsampler ------------------------
    upsampler = _make_bg_upsampler(bg_upsampler)

    # ------------------------ set up GFPGAN restorer ------------------------
    restorer = GFPGANer(
        model_path=_find_model_path(model_name, url),
        upscale=2,
        arch=arch,
        channel_multiplier=channel_multiplier,
        bg_upsampler=upsampler)

    # ------------------------ restore ------------------------
    for frame in tqdm(images, 'Face Enhancer:'):
        bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        # restore faces and background if necessary; only the pasted-back
        # full image is needed here.
        _, _, restored = restorer.enhance(
            bgr,
            has_aligned=False,
            only_center_face=False,
            paste_back=True)

        yield cv2.cvtColor(restored, cv2.COLOR_BGR2RGB)