import numpy as np
import torch


def infer_params(state_dict):
    """Infer ESRGAN (RRDB) architecture parameters from a checkpoint state dict."""
    scale2x = 0     # number of 2x upsampling stages found
    scalemin = 6    # conv weights with a "model.N" index above this count as upsampling stages
    n_uplayer = 0   # highest "model.N" index seen so far
    plus = False    # True for the ESRGAN+ variant, which uses conv1x1 layers

    for block in list(state_dict):
        parts = block.split(".")
        n_parts = len(parts)
        if n_parts == 5 and parts[2] == "sub":
            # keys of the form "model.1.sub.<n>.weight" reveal the number of RRDB blocks
            nb = int(parts[3])
        elif n_parts == 3:
            part_num = int(parts[1])
            if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
                scale2x += 1
            if part_num > n_uplayer:
                n_uplayer = part_num
                # the weight with the highest index belongs to the last conv,
                # so its first dimension is the number of output channels
                out_nc = state_dict[block].shape[0]
        if not plus and "conv1x1" in block:
            plus = True

    nf = state_dict["model.0.weight"].shape[0]     # number of feature channels
    in_nc = state_dict["model.0.weight"].shape[1]  # number of input channels
    scale = 2 ** scale2x

    return in_nc, out_nc, nf, nb, plus, scale
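

# A minimal usage sketch (not part of the original module): load an ESRGAN
# checkpoint with torch.load and read the architecture parameters off its
# state dict. The helper name and the path argument are illustrative only.
def describe_checkpoint(path):
    state_dict = torch.load(path, map_location="cpu")  # assumes a plain ESRGAN .pth state dict
    in_nc, out_nc, nf, nb, plus, scale = infer_params(state_dict)
    return {"in_nc": in_nc, "out_nc": out_nc, "nf": nf, "nb": nb, "plus": plus, "scale": scale}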


def upscale_without_tiling(model, img):
    """Upscale a PIL image in a single forward pass (no tiling)."""
    img = np.array(img)
    img = img[:, :, ::-1]  # RGB -> BGR (the model works in BGR channel order)
    img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255  # HWC -> CHW, scaled to [0, 1]
    img = torch.from_numpy(img).float()
    img = img.unsqueeze(0).to("cuda")  # add a batch dimension and move to the GPU

    with torch.no_grad():
        output = model(img)

    output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
    output = 255.0 * np.moveaxis(output, 0, 2)  # CHW -> HWC, back to the [0, 255] range
    output = output.astype(np.uint8)
    output = output[:, :, ::-1]  # BGR -> RGB
    return output
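

# Example usage (a minimal sketch, not part of the original module). It uses
# torch.nn.Upsample as a stand-in for a real ESRGAN network so the tensor
# plumbing can be exercised end to end; a CUDA device is required because
# upscale_without_tiling moves its input to "cuda".
if __name__ == "__main__":
    from PIL import Image

    # stand-in "model": any module mapping a 1x3xHxW float tensor to a larger one works here
    model = torch.nn.Upsample(scale_factor=4, mode="nearest").to("cuda")

    img = Image.new("RGB", (64, 64), color=(200, 120, 40))
    out = upscale_without_tiling(model, img)
    print(out.shape)  # (256, 256, 3) for this 4x stand-in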