# Provenance: uploaded by jayparmr via huggingface_hub
# ("Upload folder using huggingface_hub"), commit 22df957 (verified);
# original upload size 779 bytes.
from typing import List
import torch
import models.ultrasharp.arch as arch
from models.ultrasharp.util import infer_params, upscale
class Ultrasharp:
    """Upscaler that runs an ``arch.RRDBNet`` checkpoint loaded from disk.

    The network hyper-parameters (``in_nc``, ``out_nc``, ``nf``, ``nb``,
    ``plus`` and the model scale) are derived from the checkpoint by
    ``infer_params`` rather than hard-coded.
    """

    def __init__(self, model_path, tile_pad=0, tile=0):
        """Store the checkpoint path and the tiling options.

        Args:
            model_path: Path of the checkpoint handed to ``torch.load``.
            tile_pad: Forwarded unchanged to ``upscale`` (tile padding).
            tile: Forwarded unchanged to ``upscale`` (tile size). Its exact
                semantics are defined by ``upscale``; presumably 0 means
                "no tiling" — TODO confirm.
        """
        self.filename = model_path
        self.tile_pad = tile_pad
        self.tile = tile

    def enhance(self, img, outscale=4):
        """Upscale ``img`` with a freshly built RRDBNet model.

        Args:
            img: Input image, passed straight to ``upscale``; its expected
                type is defined by that helper (not visible here).
            outscale: Accepted but ignored by this implementation (presumably
                kept so the signature matches other enhancers — TODO
                confirm). The effective scale comes from the checkpoint via
                ``infer_params``.

        Returns:
            A ``(image, None)`` tuple: the value returned by ``upscale`` and a
            ``None`` placeholder as the second element.

        Note:
            The checkpoint is read and the network rebuilt on every call;
            nothing is cached on the instance.
        """
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Newer PyTorch releases accept ``weights_only=True`` to
        # restrict unpickling to tensors/primitive containers.
        # Weights are materialised on the CPU first and moved to the GPU only
        # after the architecture has been built and populated.
        state_dict = torch.load(self.filename, map_location="cpu")
        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
        model = arch.RRDBNet(
            in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus
        )
        model.load_state_dict(state_dict)
        model.eval()
        model.to("cuda")

        # ``eval()`` only switches layer behaviour (dropout/batch-norm); it
        # does not disable autograd. ``no_grad`` prevents building a gradient
        # graph during inference, and is harmless if ``upscale`` already
        # disables gradients itself.
        with torch.no_grad():
            output = upscale(model, img, self.tile_pad, self.tile)
        return output, None