File size: 779 Bytes
1bc457e
 
 
 
 
22df957
1bc457e
 
 
22df957
 
 
 
1bc457e
 
 
 
 
 
 
 
 
 
 
 
 
 
22df957
1bc457e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from typing import List

import torch

import models.ultrasharp.arch as arch
from models.ultrasharp.util import infer_params, upscale


class Ultrasharp:
    """Upscaler that builds an ``arch.RRDBNet`` from a checkpoint and runs it on an image.

    The network hyper-parameters (channels, feature width, block count, "plus"
    variant, native scale) are inferred from the checkpoint's state dict via
    ``infer_params``. The model is rebuilt on every ``enhance`` call, so no
    weights are held between calls.
    """

    def __init__(self, model_path, tile_pad=0, tile=0, device=None):
        """Store the upscaling configuration; no model is loaded here.

        Args:
            model_path: Path to the checkpoint file passed to ``torch.load``.
            tile_pad: Tile padding forwarded unchanged to ``upscale``.
            tile: Tile size forwarded unchanged to ``upscale``.
            device: Device to run the model on (anything ``torch.device``
                accepts). ``None`` picks CUDA when available, else CPU.
                Previously the device was hard-coded to ``"cuda"``.
        """
        self.filename = model_path
        self.tile_pad = tile_pad
        self.tile = tile
        self.device = device

    def _resolve_device(self):
        """Return the ``torch.device`` the model should be moved to."""
        # getattr keeps instances created without __init__ working.
        requested = getattr(self, "device", None)
        if requested is not None:
            return torch.device(requested)
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def enhance(self, img, outscale=4):
        """Upscale ``img`` with the checkpoint at ``self.filename``.

        Args:
            img: Input image, passed unchanged to ``upscale``. Its expected
                type and layout are defined by ``upscale`` (not visible here).
            outscale: Not used by this implementation. The output scale is
                the model's native scale as inferred from the checkpoint. The
                parameter is kept for interface compatibility.

        Returns:
            A tuple ``(upscaled_image, None)``.
        """
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True (torch>=1.13) if
        # infer_params only needs plain tensors.
        state_dict = torch.load(self.filename, map_location="cpu")

        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)

        model = arch.RRDBNet(
            in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus
        )
        model.load_state_dict(state_dict)
        model.eval()
        model.to(self._resolve_device())

        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            img = upscale(model, img, self.tile_pad, self.tile)
        return img, None