from typing import List | |
import torch | |
import models.ultrasharp.arch as arch | |
from models.ultrasharp.util import infer_params, upscale | |
class Ultrasharp:
    """Upscale images with an ``arch.RRDBNet`` model loaded from a checkpoint file.

    The checkpoint is read from disk and a fresh model is built on every call
    to :meth:`enhance`; nothing is cached between calls.
    """

    def __init__(self, model_path, tile_pad=0, tile=0, device="cuda"):
        """Store configuration; no file or device work happens here.

        Args:
            model_path: Path of the checkpoint passed to ``torch.load``.
            tile_pad: Forwarded to ``upscale`` as its third positional argument.
            tile: Forwarded to ``upscale`` as its fourth positional argument
                (presumably 0 disables tiling — TODO confirm in util.upscale).
            device: Device the model is moved to before inference. Defaults to
                ``"cuda"``, which was previously hard-coded.
        """
        self.filename = model_path
        self.tile_pad = tile_pad
        self.tile = tile
        self.device = device

    def enhance(self, img, outscale=4):
        """Run the model on *img* and return ``(result, None)``.

        Args:
            img: Input image, passed unchanged to ``upscale``; its expected
                type/layout is defined there (TODO confirm).
            outscale: Currently unused. The scale factor comes from the
                checkpoint via ``infer_params``. It is kept for call-site
                compatibility, presumably to mirror a Real-ESRGAN-style
                ``enhance(img, outscale)`` interface.

        Returns:
            A tuple of the image returned by ``upscale`` and ``None``
            (a placeholder second element).
        """
        # NOTE(security): torch.load unpickles the file and can execute arbitrary
        # code. Only load trusted checkpoints; on torch >= 1.13 consider
        # passing weights_only=True.
        state_dict = torch.load(self.filename, map_location="cpu")
        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
        model = arch.RRDBNet(
            in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus
        )
        model.load_state_dict(state_dict)
        model.eval()
        # NOTE(review): util.upscale may itself assume CUDA tensors; confirm
        # before using a non-CUDA device.
        model.to(self.device)
        # Inference only: disable autograd so the forward pass does not build a
        # graph or keep activations alive for a backward pass that never happens.
        with torch.no_grad():
            img = upscale(model, img, self.tile_pad, self.tile)
        return img, None