|
from __future__ import annotations |
|
|
|
import logging |
|
import os |
|
|
|
import torch |
|
|
|
from modules import ( |
|
devices, |
|
errors, |
|
face_restoration, |
|
face_restoration_utils, |
|
modelloader, |
|
shared, |
|
) |
|
|
|
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Checkpoint location and file name; FaceRestorerGFPGAN.load_net hands these
# to modelloader.load_models as ``model_url`` and ``download_name``.
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
model_download_name = "GFPGANv1.4.pth"

# Module-level restorer instance; None until setup_model() assigns it.
gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
|
|
|
|
|
class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
    """Face restorer that runs a GFPGAN network.

    Provides the restorer name, the device to run on, the network loader and
    the ``restore`` entry point; ``self.model_path``, ``self.net`` and
    ``restore_with_helper`` are not defined here and are expected to come
    from ``CommonFaceRestoration``.
    """

    def name(self) -> str:
        """Return the name of this restorer, ``"GFPGAN"``."""
        return "GFPGAN"

    def get_device(self):
        """Return ``devices.device_gfpgan``, the device used for GFPGAN."""
        return devices.device_gfpgan

    def load_net(self) -> torch.nn.Module:
        """Load a GFPGAN checkpoint and return its network.

        Candidate ``.pth`` paths are obtained from
        ``modelloader.load_models``. The first candidate whose file name
        contains ``"GFPGAN"`` is loaded with
        ``modelloader.load_spandrel_model`` on ``self.get_device()`` and its
        ``.model`` attribute is returned.

        Raises:
            ValueError: If no candidate file name contains ``"GFPGAN"``.
        """
        candidates = modelloader.load_models(
            model_path=self.model_path,
            model_url=model_url,
            command_path=self.model_path,
            download_name=model_download_name,
            ext_filter=['.pth'],
        )
        for model_path in candidates:
            # Only the file name is inspected, so a directory that merely
            # has "GFPGAN" in its path does not qualify a checkpoint.
            if 'GFPGAN' in os.path.basename(model_path):
                model = modelloader.load_spandrel_model(
                    model_path,
                    device=self.get_device(),
                    expected_architecture='GFPGAN',
                ).model
                # NOTE(review): `different_w` is presumably a flag read by
                # the GFPGAN architecture's forward pass - confirm against
                # the model implementation before changing it.
                model.different_w = True
                return model
        raise ValueError("No GFPGAN model found")

    def restore(self, np_image):
        """Restore faces in ``np_image`` with the loaded network.

        Delegates to ``self.restore_with_helper``, passing a callback that
        runs ``self.net`` on one cropped-face input, and returns whatever
        that helper returns.
        """

        def restore_face(cropped_face_t):
            # The network must already be set when the callback runs.
            assert self.net is not None
            # Only element 0 of the network's result is used.
            return self.net(cropped_face_t, return_rgb=False)[0]

        return self.restore_with_helper(np_image, restore_face)
|
|
|
|
|
def gfpgan_fix_faces(np_image):
    """Run the module-level GFPGAN restorer on ``np_image``.

    If ``gfpgan_face_restorer`` is not set, a warning is logged and the
    image is returned unchanged; otherwise the result of its ``restore``
    method is returned.
    """
    if not gfpgan_face_restorer:
        logger.warning("GFPGAN face restorer not set up")
        return np_image

    return gfpgan_face_restorer.restore(np_image)
|
|
|
|
|
def setup_model(dirname: str) -> None:
    """Create the GFPGAN restorer for ``dirname`` and register it.

    Calls ``face_restoration_utils.patch_facexlib(dirname)``, builds a
    ``FaceRestorerGFPGAN`` with ``model_path=dirname``, stores it in the
    module-level ``gfpgan_face_restorer`` and appends it to
    ``shared.face_restorers``. Any exception raised along the way is passed
    to ``errors.report`` and is not re-raised.
    """
    global gfpgan_face_restorer

    try:
        face_restoration_utils.patch_facexlib(dirname)
        restorer = FaceRestorerGFPGAN(model_path=dirname)
        # Publish the instance before registering it, so the global is set
        # even if the append below fails.
        gfpgan_face_restorer = restorer
        shared.face_restorers.append(restorer)
    except Exception:
        errors.report("Error setting up GFPGAN", exc_info=True)
|
|