jayparmr's picture
Upload folder using huggingface_hub
22df957 verified
raw
history blame
4.43 kB
import math
import numpy as np
import torch
def infer_params(state_dict):
    """Infer network hyper-parameters from a ``model.N...``-style state dict.

    Keys are inspected by their dot-separated parts:

    * 5-part keys of the form ``<a>.<b>.sub.<k>.<c>`` set ``nb`` to ``k``
      (the last such key wins).
    * 3-part keys ``<a>.<n>.<c>`` are indexed by the integer ``n``; the first
      key seen for each new highest ``n`` sets ``out_nc`` to its tensor's
      ``shape[0]``, and every ``model.<n>.weight`` with ``n > 6`` counts one
      2x upscaling step.
    * Any key containing ``conv1x1`` sets ``plus``.

    Args:
        state_dict: Mapping from parameter name to a tensor-like object with a
            ``shape`` attribute. Must contain ``"model.0.weight"``.

    Returns:
        ``(in_nc, out_nc, nf, nb, plus, scale)`` where ``nf`` / ``in_nc`` are
        ``shape[0]`` / ``shape[1]`` of ``model.0.weight`` and
        ``scale = 2 ** <number of counted upscaling steps>``.

    Raises:
        KeyError: If ``"model.0.weight"`` is missing.
        ValueError: If no key yields ``nb`` or ``out_nc`` (the layout is not
            the one this parser expects), or if a numeric key part is not an
            integer.
    """
    # this code is copied from https://github.com/victorca25/iNNfer
    scale2x = 0
    scalemin = 6
    n_uplayer = 0
    plus = False
    nb = None
    out_nc = None

    for block in state_dict:
        parts = block.split(".")
        n_parts = len(parts)
        if n_parts == 5 and parts[2] == "sub":
            nb = int(parts[3])
        elif n_parts == 3:
            part_num = int(parts[1])
            if (
                part_num > scalemin
                and parts[0] == "model"
                and parts[2] == "weight"
            ):
                scale2x += 1
            if part_num > n_uplayer:
                # Highest layer index so far: assume it is the output conv.
                n_uplayer = part_num
                out_nc = state_dict[block].shape[0]
        if "conv1x1" in block:
            plus = True

    nf = state_dict["model.0.weight"].shape[0]
    in_nc = state_dict["model.0.weight"].shape[1]
    # Previously these surfaced as an UnboundLocalError; report them clearly.
    if nb is None:
        raise ValueError("could not infer nb: no '<a>.<b>.sub.<k>.<c>' keys found")
    if out_nc is None:
        raise ValueError("could not infer out_nc: no layer index above 0 found")
    scale = 2**scale2x
    return in_nc, out_nc, nf, nb, plus, scale
def tile_process(model, img, tile_pad, tile_size, scale=4):
    """Run ``model`` over ``img`` tile by tile and stitch the results together.

    The input is cut into ``tile_size`` x ``tile_size`` tiles (tiles on the
    right/bottom edge may be smaller). Each tile is passed to ``model`` grown
    by up to ``tile_pad`` pixels on every side (clipped at the image border),
    so its border pixels are computed with neighbouring context; only the
    un-padded centre of the model's output is written into the result.
    Modified from: https://github.com/ata4/esrgan-launcher

    Args:
        model: Callable taking a ``(batch, channel, h, w)`` slice of ``img``
            and returning it enlarged by ``scale`` along both spatial axes.
            Called under ``torch.no_grad()``.
        img: Tensor indexed as ``(batch, channel, height, width)``.
        tile_pad: Context pixels added around each tile.
        tile_size: Edge length, in input pixels, of each tile.
        scale: Spatial enlargement factor produced by ``model``.

    Returns:
        Tensor of shape ``(batch, channel, height * scale, width * scale)``
        with the dtype and device of ``img``.

    Raises:
        ValueError: If ``tile_size`` is not positive.
        RuntimeError: Any ``RuntimeError`` raised by ``model`` is printed and
            then re-raised.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")

    batch, channel, height, width = img.shape
    # Start with a black canvas; every tile overwrites its own region.
    output = img.new_zeros((batch, channel, height * scale, width * scale))

    tiles_x = math.ceil(width / tile_size)
    tiles_y = math.ceil(height / tile_size)
    n_tiles = tiles_x * tiles_y

    for y in range(tiles_y):
        for x in range(tiles_x):
            # Tile area on the input image, without padding.
            in_x0 = x * tile_size
            in_x1 = min(in_x0 + tile_size, width)
            in_y0 = y * tile_size
            in_y1 = min(in_y0 + tile_size, height)
            # The same area grown by tile_pad on each side, clipped to the image.
            pad_x0 = max(in_x0 - tile_pad, 0)
            pad_x1 = min(in_x1 + tile_pad, width)
            pad_y0 = max(in_y0 - tile_pad, 0)
            pad_y1 = min(in_y1 + tile_pad, height)

            input_tile = img[:, :, pad_y0:pad_y1, pad_x0:pad_x1]
            try:
                with torch.no_grad():
                    output_tile = model(input_tile)
            except RuntimeError as error:
                print("Error", error)
                # Without a result for this tile, continuing would either hit an
                # unbound ``output_tile`` (first tile, masking this error) or
                # paste the previous tile's pixels here; propagate instead.
                raise
            print(f"\tTile {y * tiles_x + x + 1}/{n_tiles}")

            # Location of the un-padded region inside the enlarged tile.
            crop_x0 = (in_x0 - pad_x0) * scale
            crop_y0 = (in_y0 - pad_y0) * scale
            crop_x1 = crop_x0 + (in_x1 - in_x0) * scale
            crop_y1 = crop_y0 + (in_y1 - in_y0) * scale

            output[
                :, :, in_y0 * scale : in_y1 * scale, in_x0 * scale : in_x1 * scale
            ] = output_tile[:, :, crop_y0:crop_y1, crop_x0:crop_x1]

    return output
def upscale(model, img, tile_pad, tile_size, *, scale=4, device="cuda"):
    """Enlarge an 8-bit image with ``model`` using tiled inference.

    Args:
        model: Network forwarded to ``tile_process``; must enlarge its input
            by ``scale`` along both spatial axes.
        img: Anything ``np.array`` turns into a ``(height, width, channels)``
            array with values in 0..255 — presumably a PIL RGB image; TODO
            confirm against callers.
        tile_pad: Forwarded to ``tile_process``.
        tile_size: Forwarded to ``tile_process``.
        scale: Enlargement factor of ``model`` (default 4, the value that was
            previously hard-coded).
        device: Torch device the input tensor is moved to (default
            ``"cuda"``, the value that was previously hard-coded).

    Returns:
        ``uint8`` array of shape ``(height * scale, width * scale, channels)``
        in the same channel order as ``img``.
    """
    img = np.array(img)
    # Reverse the channel axis (RGB -> BGR for RGB input); presumably the
    # model expects BGR input — NOTE(review): confirm.
    img = img[:, :, ::-1]
    img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
    img = torch.from_numpy(img).float().unsqueeze(0).to(device)
    output = tile_process(model, img, tile_pad, tile_size, scale=scale)
    # squeeze(0) removes only the batch axis; a bare squeeze() would also drop
    # a size-1 channel axis and break the moveaxis below.
    output = output.squeeze(0).float().cpu().clamp_(0, 1).numpy()
    output = np.moveaxis(output, 0, 2)
    # Round rather than truncate: float32 values such as 254.99998 must map
    # back to 255, otherwise unchanged pixels lose one intensity level.
    output = np.round(output * 255.0).astype(np.uint8)
    # Undo the channel reversal applied to the input.
    return output[:, :, ::-1]