import numpy as np
import torch


def infer_params(state_dict):
    """Infer network hyper-parameters from a checkpoint's state dict.

    The parsing heuristic is copied from https://github.com/victorca25/iNNfer.
    It relies on the key layout that code expects: a first convolution under
    ``"model.0.weight"``, trunk entries under ``"<a>.<b>.sub.<n>.<param>"`` and
    top-level layers under ``"<prefix>.<i>.<param>"``.

    Args:
        state_dict: Mapping from parameter name to an object with a
            ``.shape`` (e.g. tensors loaded with ``torch.load``).

    Returns:
        Tuple ``(in_nc, out_nc, nf, nb, plus, scale)`` where

        - ``in_nc`` / ``nf`` are dims 1 / 0 of ``"model.0.weight"``;
        - ``out_nc`` is dim 0 of the first 3-part key carrying the largest
          index ``<i>`` (only indices > 0 are considered);
        - ``nb`` is ``<n>`` from the last 5-part ``"<a>.<b>.sub.<n>.<param>"``
          key encountered;
        - ``plus`` is True if any key contains ``"conv1x1"``;
        - ``scale`` is ``2 ** k`` where ``k`` counts ``"model.<i>.weight"``
          keys with ``i > 6``.

    Raises:
        KeyError: if ``"model.0.weight"`` is missing.
        ValueError: if ``nb`` or ``out_nc`` cannot be inferred, i.e. the state
            dict does not follow the expected key layout.
    """
    scale2x = 0
    # Every "model.<i>.weight" with i above this index contributes a factor of 2
    # to the inferred scale (heuristic inherited from iNNfer).
    scalemin = 6
    n_uplayer = 0
    plus = False
    nb = None
    out_nc = None

    for block in state_dict:
        parts = block.split(".")
        n_parts = len(parts)
        if n_parts == 5 and parts[2] == "sub":
            nb = int(parts[3])
        elif n_parts == 3:
            part_num = int(parts[1])
            if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
                scale2x += 1
            # Track the highest-indexed top-level layer; its dim 0 is taken
            # as the output channel count.
            if part_num > n_uplayer:
                n_uplayer = part_num
                out_nc = state_dict[block].shape[0]
        if not plus and "conv1x1" in block:
            plus = True

    nf = state_dict["model.0.weight"].shape[0]
    in_nc = state_dict["model.0.weight"].shape[1]

    # Previously these surfaced as an opaque UnboundLocalError.
    if nb is None:
        raise ValueError(
            "cannot infer number of blocks (nb): no '<a>.<b>.sub.<n>.<param>' "
            "keys found in state dict"
        )
    if out_nc is None:
        raise ValueError(
            "cannot infer output channels (out_nc): no '<prefix>.<i>.<param>' "
            "keys with index > 0 found in state dict"
        )

    scale = 2**scale2x
    return in_nc, out_nc, nf, nb, plus, scale


def upscale_without_tiling(model, img, device="cuda"):
    """Run ``model`` on the whole image in a single forward pass (no tiling).

    Args:
        model: Callable taking a ``(1, C, H, W)`` float tensor with values in
            [0, 1] and returning a ``(1, C', H', W')`` tensor.
        img: Anything ``np.array`` turns into an ``(H, W, C)`` array; values
            are assumed to be 0..255 (e.g. a PIL image) — TODO confirm callers.
        device: Device the input tensor is moved to before calling ``model``.
            Defaults to ``"cuda"`` (the previously hard-coded value); it must
            match the device ``model`` lives on.

    Returns:
        Contiguous ``(H', W', C')`` uint8 array, channels flipped back so they
        are in the same order as ``img``'s channels.
    """
    img = np.array(img)
    # Reverse the channel axis (RGB -> BGR for a PIL image); presumably the
    # network expects BGR input — TODO confirm.
    img = img[:, :, ::-1]
    img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
    tensor = torch.from_numpy(img).float().unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(tensor)

    # squeeze(0) drops only the batch dim; a bare squeeze() would also drop a
    # size-1 channel/height/width dim and break the moveaxis below.
    output = output.squeeze(0).float().cpu().clamp_(0, 1).numpy()
    output = np.moveaxis(output, 0, 2) * 255.0
    # Round instead of truncating so float round-off (e.g. 127.99999) maps to
    # the nearest integer value rather than one below it.
    output = output.round().astype(np.uint8)
    # Undo the channel reversal; copy so callers get a contiguous array.
    return np.ascontiguousarray(output[:, :, ::-1])