from typing import List

import torch

import models.ultrasharp.arch as arch
from models.ultrasharp.util import infer_params, upscale


class Ultrasharp:
    """Image upscaler backed by an RRDBNet checkpoint.

    The network hyper-parameters (``in_nc``, ``out_nc``, ``nf``, ``nb``,
    ``plus`` and the upscale factor) are inferred from the checkpoint's state
    dict via ``infer_params``. The actual inference is delegated to
    ``upscale``.
    """

    def __init__(self, model_path, tile_pad=0, tile=0, device="cuda"):
        """Store configuration; the model itself is built lazily on first use.

        Args:
            model_path: Path of the checkpoint passed to ``torch.load``.
            tile_pad: Tile padding forwarded to ``upscale``.
            tile: Tile size forwarded to ``upscale``.
            device: Device the model is moved to before inference. Defaults
                to ``"cuda"``, matching the previously hard-coded value.
        """
        self.filename = model_path
        self.tile_pad = tile_pad
        self.tile = tile
        self.device = device
        # Cached RRDBNet instance; None until the first enhance() call.
        self._model = None

    def _load_model(self):
        """Build the RRDBNet from the checkpoint once, then reuse it.

        The model stays resident on ``self.device`` for the lifetime of this
        instance. Previously it was rebuilt (and re-read from disk) on every
        ``enhance`` call.
        """
        if self._model is None:
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code. Only load trusted checkpoints. Consider
            # weights_only=True if the supported torch version allows it.
            state_dict = torch.load(self.filename, map_location="cpu")
            in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
            model = arch.RRDBNet(
                in_nc=in_nc,
                out_nc=out_nc,
                nf=nf,
                nb=nb,
                upscale=mscale,
                plus=plus,
            )
            model.load_state_dict(state_dict)
            model.eval()
            # nn.Module.to moves parameters in place and returns the module.
            self._model = model.to(self.device)
        return self._model

    def enhance(self, img, outscale=4):
        """Upscale ``img`` with the checkpoint's model.

        Args:
            img: Input image, passed unchanged to ``upscale``. Its expected
                type/layout is defined by ``upscale`` (not visible here).
            outscale: Not used by this implementation. The scale applied is
                the one inferred from the checkpoint. Presumably kept for
                interface parity with other upscalers; confirm with callers.

        Returns:
            A ``(upscaled_image, None)`` tuple. The second element is always
            ``None``, presumably to match callers that expect a two-item
            result; verify against call sites.
        """
        model = self._load_model()
        # Inference only: skip building the autograd graph to save memory.
        with torch.no_grad():
            img = upscale(model, img, self.tile_pad, self.tile)
        return img, None