from typing import List

import cv2
import numpy as np
import torch

import models.ultrasharp.arch as arch
from models.ultrasharp.util import infer_params, upscale


class Ultrasharp:
    """Upscale images with an RRDBNet checkpoint loaded from ``model_path``.

    The checkpoint is read and the network rebuilt on every ``enhance`` call;
    nothing is cached on the instance besides the constructor arguments.
    """

    def __init__(self, model_path, tile_pad=0, tile=0):
        """Store the checkpoint path and the tiling settings.

        Args:
            model_path: Path of the checkpoint handed to ``torch.load``.
            tile_pad: Forwarded unchanged to ``upscale``.
            tile: Forwarded unchanged to ``upscale``.
        """
        self.filename = model_path
        self.tile_pad = tile_pad
        self.tile = tile

    def _load_model(self):
        """Build the RRDBNet described by the checkpoint and move it to CUDA."""
        # NOTE(security): torch.load unpickles the file, so only trusted
        # checkpoints should ever be passed in as ``model_path``.
        state_dict = torch.load(self.filename, map_location="cpu")
        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
        model = arch.RRDBNet(
            in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus
        )
        model.load_state_dict(state_dict)
        model.eval()
        # The device is hard-coded; a CUDA-capable runtime is required.
        model.to("cuda")
        return model

    def enhance(self, img, outscale=4):
        """Upscale ``img`` and, if present, its alpha channel.

        Args:
            img: Image array. A 2-D (grayscale) array is expanded to three
                channels; a 3-D array whose last axis has length 4 has its
                fourth channel treated as alpha and upscaled separately.
                Presumably BGR(A) channel order, given the ``BGR2BGRA`` merge
                below -- confirm against callers.
            outscale: Accepted for interface compatibility; currently unused,
                the output size is whatever ``upscale`` produces for the model.

        Returns:
            A tuple ``(img, None)`` where ``img`` is the upscaled image, with
            a fourth channel holding the upscaled alpha when the input had one.
        """
        model = self._load_model()

        # Previously a 2-D array crashed on ``img.shape[2]``. Expand it to
        # three identical channels, the same way the alpha plane is prepared
        # before being fed to the model.
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        alpha = None
        if img.shape[2] == 4:
            # Split the alpha plane off and replicate it to three channels so
            # it can go through the same network as the colour data.
            alpha = cv2.cvtColor(img[:, :, 3], cv2.COLOR_GRAY2RGB)
            img = img[:, :, 0:3]

        img = upscale(model, img, self.tile_pad, self.tile)

        if alpha is not None:
            # Upscale alpha, collapse it back to one channel, and merge it
            # into the colour result as the fourth channel.
            output_alpha = upscale(model, alpha, self.tile_pad, self.tile)
            output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
            img[:, :, 3] = output_alpha

        return img, None