Spaces:
Sleeping
Sleeping
File size: 5,014 Bytes
bc3753a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import os
import torch
from tqdm import tqdm
from src.utils.videoio import load_video_to_cv2
import cv2
class GeneratorWithLen(object):
    """Wrap an iterable together with a known length.

    Lets a lazily produced sequence be handed to code that calls ``len()``
    on its input (for example progress bars or writers that pre-allocate).

    From https://stackoverflow.com/a/7460929

    Args:
        gen: The generator (or any other iterable) to wrap.
        length: The value reported by ``len()``. It is trusted as given and
            is not checked against the number of items ``gen`` produces.
    """

    def __init__(self, gen, length):
        self.gen = gen
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        # ``__iter__`` must return an iterator. ``iter()`` returns a generator
        # unchanged, but also makes plain iterables such as lists work instead
        # of failing with "iter() returned non-iterator".
        return iter(self.gen)
def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):
    """Enhance every frame eagerly and return the results as a list.

    Convenience wrapper that drains ``enhancer_generator_no_len``; all
    enhanced frames are held in memory at once.
    """
    return list(
        enhancer_generator_no_len(
            images,
            method=method,
            bg_upsampler=bg_upsampler,
        )
    )
def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'):
    """Return a lazy enhancer that also supports ``len()``.

    Provides a generator with a ``__len__`` method so that it can be passed
    to functions that call ``len()`` on their input.

    Args:
        images: Either a list of frames or a path to a video file. A video
            path is decoded with ``load_video_to_cv2`` up front so that the
            frame count is known.
        method: Face-restoration model name, forwarded to
            ``enhancer_generator_no_len``.
        bg_upsampler: Background upsampler name, forwarded to
            ``enhancer_generator_no_len``.

    Returns:
        A ``GeneratorWithLen`` whose length is ``len(images)`` and whose
        iteration yields the enhanced frames.
    """
    # Only probe the filesystem for non-list input: os.path.isfile() raises
    # TypeError when handed a list of frames. This mirrors the guard used in
    # enhancer_generator_no_len.
    if not isinstance(images, list) and os.path.isfile(images):  # handle video to images
        # TODO: Create a generator version of load_video_to_cv2
        images = load_video_to_cv2(images)

    gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
    return GeneratorWithLen(gen, len(images))
def _restorer_settings(method):
    """Map a restoration method name to (arch, channel_multiplier, model_name, url)."""
    settings = {
        'gfpgan': (
            'clean',
            2,
            'GFPGANv1.4',
            'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
        ),
        'RestoreFormer': (
            'RestoreFormer',
            2,
            'RestoreFormer',
            'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
        ),
        # TODO: this entry was marked TODO in the original code.
        # NOTE(review): confirm GFPGANer actually accepts this arch.
        'codeformer': (
            'CodeFormer',
            2,
            'CodeFormer',
            'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
        ),
    }
    try:
        return settings[method]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which are equally invalid names.
        raise ValueError(f'Wrong model version {method}.') from None


def _import_gfpganer():
    """Return the GFPGANer class, pip-installing the gfpgan package if it is missing."""
    try:
        from gfpgan import GFPGANer
    except ImportError:
        print("GFPGAN library not found. Installing...")
    else:
        return GFPGANer

    import importlib
    import subprocess
    import sys

    try:
        # Run pip through the current interpreter so the package lands in the
        # environment this process is actually using.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "gfpgan"])
        # Let the import system notice the freshly installed package.
        importlib.invalidate_caches()
        # Retry the import after installation
        from gfpgan import GFPGANer
    except (subprocess.CalledProcessError, OSError, ImportError) as e:
        print(f"Failed to install GFPGAN library. Error: {e}")
        # Fail here with a clear cause rather than later with an unbound name.
        raise ImportError(
            "The 'gfpgan' package is required and could not be installed automatically."
        ) from e
    print("GFPGAN library installed successfully!")
    return GFPGANer


def _build_bg_upsampler(bg_upsampler):
    """Build the RealESRGAN background upsampler, or return None when it is not used."""
    if bg_upsampler != 'realesrgan':
        return None

    if not torch.cuda.is_available():  # CPU
        import warnings
        warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
                      'If you really want to use it, please modify the corresponding codes.')
        return None

    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
    return RealESRGANer(
        scale=2,
        model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
        model=model,
        tile=400,
        tile_pad=10,
        pre_pad=0,
        half=True)  # need to set False in CPU mode


def _resolve_model_path(model_name, url):
    """Return the first existing local weights file for model_name, else the download url."""
    filename = model_name + '.pth'
    for directory in ('gfpgan/weights', 'checkpoints'):
        candidate = os.path.join(directory, filename)
        if os.path.isfile(candidate):
            return candidate
    # download pre-trained models from url
    return url


def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'):
    """Lazily enhance frames one at a time.

    Being a generator, this avoids holding all enhanced images in memory at
    once, which can save a lot of RAM compared to ``enhancer_list``. Nothing
    runs until the first item is requested.

    Args:
        images: Either a list of frames or a path to a video file; a video
            path is decoded with ``load_video_to_cv2``. Frames are converted
            with ``cv2.COLOR_RGB2BGR`` before enhancement, i.e. they are
            treated as RGB.
        method: One of ``'gfpgan'``, ``'RestoreFormer'`` or ``'codeformer'``.
        bg_upsampler: ``'realesrgan'`` to upsample the background (only when
            CUDA is available); any other value disables it.

    Yields:
        The enhanced frame, converted back with ``cv2.COLOR_BGR2RGB``.

    Raises:
        ValueError: If ``method`` is not a known model name. This is checked
            before any import or install attempt.
        ImportError: If ``gfpgan`` is missing and cannot be installed.
    """
    # Validate the cheap argument first so a typo fails fast, before any
    # pip install or model download is attempted.
    arch, channel_multiplier, model_name, url = _restorer_settings(method)

    GFPGANer = _import_gfpganer()

    print('face enhancer....')
    if not isinstance(images, list) and os.path.isfile(images):  # handle video to images
        images = load_video_to_cv2(images)

    # ------------------------ set up background upsampler ------------------------
    upsampler = _build_bg_upsampler(bg_upsampler)

    # ------------------------ set up GFPGAN restorer ------------------------
    restorer = GFPGANer(
        model_path=_resolve_model_path(model_name, url),
        upscale=2,
        arch=arch,
        channel_multiplier=channel_multiplier,
        bg_upsampler=upsampler)

    # ------------------------ restore ------------------------
    for frame in tqdm(images, 'Face Enhancer:'):
        bgr_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        # restore faces and background if necessary; only the pasted-back
        # full image is used, the per-face crops are discarded.
        _, _, restored = restorer.enhance(
            bgr_frame,
            has_aligned=False,
            only_center_face=False,
            paste_back=True)

        yield cv2.cvtColor(restored, cv2.COLOR_BGR2RGB)
|