|
import math |
|
import os |
|
from pathlib import Path |
|
from typing import Union |
|
|
|
import cv2 |
|
import numpy as np |
|
from basicsr.archs.rrdbnet_arch import RRDBNet |
|
from basicsr.utils.download_util import load_file_from_url |
|
from PIL import Image |
|
from realesrgan import RealESRGANer |
|
|
|
import internals.util.image as ImageUtil |
|
from internals.util.commons import download_image |
|
|
|
|
|
class Upscaler:
    """Upscale images with Real-ESRGAN (general-purpose and anime checkpoints).

    Call :meth:`load` once before :meth:`upscale` or :meth:`upscale_anime`;
    it resolves the local weight-file paths those methods read.
    """

    __model_esrgan_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
    __model_esrgan_anime_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"

    def load(self) -> None:
        """Ensure both weight files exist under ``~/.cache/realesrgan``.

        Creates the cache directory if needed and stores the resolved path of
        each checkpoint on the instance.
        """
        download_dir = Path.home() / ".cache" / "realesrgan"
        download_dir.mkdir(parents=True, exist_ok=True)

        self.__model_path = self.__preload_model(self.__model_esrgan_url, download_dir)
        self.__model_path_anime = self.__preload_model(
            self.__model_esrgan_anime_url, download_dir
        )

    def upscale(self, image: Union[str, Image.Image], resize_dimension: int) -> bytes:
        """Upscale ``image`` with the general RealESRGAN_x4plus model.

        Args:
            image: Either a string (handed to ``download_image``) or a PIL image.
            resize_dimension: Target size used to derive the output scale.

        Returns:
            The upscaled image encoded as PNG bytes.
        """
        return self.__internal_upscale(
            image, resize_dimension, self.__model_path, self.__build_rrdbnet(23)
        )

    def upscale_anime(
        self, image: Union[str, Image.Image], resize_dimension: int
    ) -> bytes:
        """Upscale ``image`` with the RealESRGAN_x4plus_anime_6B model.

        Args:
            image: Either a string (handed to ``download_image``) or a PIL image.
            resize_dimension: Target size used to derive the output scale.

        Returns:
            The upscaled image encoded as PNG bytes.
        """
        # The anime "6B" checkpoint is a 6-block RRDBNet; building a 23-block
        # network here does not match the weights stored in the file.
        return self.__internal_upscale(
            image, resize_dimension, self.__model_path_anime, self.__build_rrdbnet(6)
        )

    @staticmethod
    def __build_rrdbnet(num_block: int) -> RRDBNet:
        """Build the x4 RGB RRDBNet shared by both checkpoints."""
        return RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=num_block,
            num_grow_ch=32,
            scale=4,
        )

    def __preload_model(self, url: str, download_dir: Path) -> str:
        """Return the local path of the weight file named by ``url``.

        The file name is the last path segment of ``url``. If that file is
        already in ``download_dir`` its path is returned directly; otherwise
        the download is delegated to ``load_file_from_url`` and its result is
        returned (presumably the downloaded file's path).
        """
        cached_file = str(download_dir / url.split("/")[-1])
        if os.path.exists(cached_file):
            return cached_file
        return load_file_from_url(
            url=url,
            model_dir=str(download_dir),
            progress=True,
            file_name=None,
        )

    def __internal_upscale(
        self,
        image,
        resize_dimension: int,
        model_path: str,
        rrbdnet: RRDBNet,
    ) -> bytes:
        """Run Real-ESRGAN on ``image`` and return the result as PNG bytes.

        Args:
            image: A string (fetched via ``download_image``) or an image object
                accepted by ``ImageUtil.resize_image_to512``.
            resize_dimension: Target size; the output scale is
                ``floor(resize_dimension / shorter_side)``, but never below 2.
            model_path: Path of the checkpoint to load.
            rrbdnet: Network whose architecture matches ``model_path``.

        Raises:
            ValueError: If the prepared image bytes cannot be decoded.
        """
        if isinstance(image, str):
            image = download_image(image)
        # NOTE(review): resize_image_to512 is a project helper not visible
        # here; the name suggests it bounds the input at 512px. Confirm.
        image = ImageUtil.resize_image_to512(image)
        image = ImageUtil.to_bytes(image)

        # `half` is a boolean switch; the original passed the truthy string
        # "fp16", which has the same effect as True.
        # NOTE(review): gpu_id is kept as the original string "0"; confirm the
        # type RealESRGANer expects.
        upsampler = RealESRGANer(
            scale=4, model_path=model_path, model=rrbdnet, half=True, gpu_id="0"
        )

        image_array = np.frombuffer(image, dtype=np.uint8)
        input_image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
        if input_image is None:
            # imdecode signals failure by returning None; fail with a clear
            # message instead of an AttributeError on `.shape` below.
            raise ValueError("Could not decode image data for upscaling")

        shorter_side = min(input_image.shape[0], input_image.shape[1])
        scale = max(math.floor(resize_dimension / shorter_side), 2)

        output, _ = upsampler.enhance(input_image, outscale=scale)
        return cv2.imencode(".png", output)[1].tobytes()
|
|