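# Provenance (from the hosting page, kept as comments so the module parses):
# uploaded by jayparmr; commit message "update : inference"
# commit 35575bb (verified)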
from typing import List
import cv2
import numpy as np
import torch
import models.ultrasharp.arch as arch
from models.ultrasharp.util import infer_params, upscale
class Ultrasharp:
    """Run an ``arch.RRDBNet`` checkpoint over an image.

    The instance only stores the checkpoint path and the two tiling
    settings; the network itself is rebuilt from the checkpoint on every
    call to :meth:`enhance`.
    """

    def __init__(self, model_path, tile_pad=0, tile=0):
        """Store the checkpoint path and the tiling settings.

        Args:
            model_path: Location of the checkpoint handed to ``torch.load``.
            tile_pad: Forwarded unchanged to ``upscale`` as its third argument.
            tile: Forwarded unchanged to ``upscale`` as its fourth argument.
        """
        self.filename, self.tile_pad, self.tile = model_path, tile_pad, tile

    def _load_network(self):
        """Build the RRDBNet described by the checkpoint and place it on CUDA."""
        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the installed PyTorch version's defaults -- only use trusted files.
        weights = torch.load(self.filename, map_location="cpu")

        # The architecture hyper-parameters are derived from the weights.
        in_nc, out_nc, nf, nb, plus, mscale = infer_params(weights)
        network = arch.RRDBNet(
            in_nc=in_nc,
            out_nc=out_nc,
            nf=nf,
            nb=nb,
            upscale=mscale,
            plus=plus,
        )
        network.load_state_dict(weights)
        network.eval()
        network.to("cuda")
        return network

    def _run(self, network, pixels):
        """Upscale ``pixels`` with ``network`` using the stored tiling settings."""
        return upscale(network, pixels, self.tile_pad, self.tile)

    def enhance(self, img, outscale=4):
        """Upscale ``img`` and, when present, its alpha channel.

        Args:
            img: Image indexed as ``img[:, :, channel]``. A size of 4 along
                the third axis is treated as colour plus alpha; anything
                else is passed to the network as-is. Presumably a NumPy
                array in OpenCV channel order -- confirm against callers.
            outscale: Accepted for interface compatibility; this method
                does not read it.

        Returns:
            A ``(result, None)`` tuple. For 4-channel input ``result`` has
            four channels, with the upscaled alpha stored in channel 3.

        NOTE(review): an input without a third axis fails at the
        ``img.shape[2]`` lookup, after the checkpoint has been loaded.
        """
        network = self._load_network()

        has_alpha = img.shape[2] == 4
        if not has_alpha:
            return self._run(network, img), None

        # The alpha plane is expanded to three channels so it can be pushed
        # through the same network as the colour data.
        alpha_3ch = cv2.cvtColor(img[:, :, 3], cv2.COLOR_GRAY2RGB)
        colour_out = self._run(network, img[:, :, 0:3])

        # Collapse the upscaled alpha back to a single plane.
        alpha_out = cv2.cvtColor(
            self._run(network, alpha_3ch), cv2.COLOR_BGR2GRAY
        )

        # Add a fourth channel to the colour result and fill it with alpha.
        merged = cv2.cvtColor(colour_out, cv2.COLOR_BGR2BGRA)
        merged[:, :, 3] = alpha_out
        return merged, None