|
import math |
|
|
|
import numpy as np |
|
import torch |
|
|
|
|
|
def infer_params(state_dict):
    """Infer ESRGAN architecture hyper-parameters from a checkpoint state dict.

    Walks the (old-style) ESRGAN parameter names ("model.N..." keys) to
    recover the settings needed to rebuild the network.

    Args:
        state_dict: mapping of parameter name -> weight tensor.

    Returns:
        Tuple ``(in_nc, out_nc, nf, nb, plus, scale)``:
            in_nc: number of input channels.
            out_nc: number of output channels.
            nf: feature width (out-channels of the first conv).
            nb: number of residual blocks.
            plus: True for the "ESRGAN+" variant (has conv1x1 layers).
            scale: total upscale factor (2 ** number of upsample convs).

    Raises:
        UnboundLocalError: if the state dict does not follow the expected
            naming scheme (no "model.*.sub.*" or upsample keys found), since
            ``nb`` / ``out_nc`` are only bound inside the loop.
    """
    scale2x = 0
    scalemin = 6  # "model.N" layers with N > scalemin are upsample convs
    n_uplayer = 0
    plus = False

    for block in list(state_dict):
        parts = block.split(".")
        n_parts = len(parts)
        if n_parts == 5 and parts[2] == "sub":
            # "model.1.sub.<idx>.<param>": index of a residual block; the
            # last one seen gives the block count.
            nb = int(parts[3])
        elif n_parts == 3:
            part_num = int(parts[1])
            if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
                # Each conv weight past `scalemin` marks one 2x upsample stage.
                scale2x += 1
            if part_num > n_uplayer:
                # Track the deepest layer; its out-channels are the model's
                # output channel count.
                n_uplayer = part_num
                out_nc = state_dict[block].shape[0]
        if not plus and "conv1x1" in block:
            plus = True

    nf = state_dict["model.0.weight"].shape[0]
    in_nc = state_dict["model.0.weight"].shape[1]
    scale = 2**scale2x

    return in_nc, out_nc, nf, nb, plus, scale
|
|
|
|
|
def tile_process(model, img, tile_pad, tile_size, scale=4):
    """Upscale ``img`` tile by tile to bound peak (GPU) memory usage.

    The input is split into a grid of ``tile_size`` x ``tile_size`` tiles.
    Each tile is expanded by ``tile_pad`` pixels of surrounding context
    before being fed to ``model``; the un-padded core of each result is then
    stitched into one output image.

    Modified from: https://github.com/ata4/esrgan-launcher

    Args:
        model: callable mapping a (B, C, H, W) tensor to one upscaled by
            ``scale`` in both spatial dimensions.
        img: input tensor of shape (batch, channel, height, width).
        tile_pad: context padding around each tile, in input pixels.
        tile_size: tile edge length, in input pixels.
        scale: spatial upscale factor of ``model`` (default 4).

    Returns:
        Tensor of shape (batch, channel, height * scale, width * scale).

    Raises:
        RuntimeError: re-raised from ``model`` after logging the failing
            tile index (e.g. CUDA out-of-memory).
    """
    batch, channel, height, width = img.shape
    output_height = height * scale
    output_width = width * scale
    output_shape = (batch, channel, output_height, output_width)

    # Allocate the result on the same device/dtype as the input.
    output = img.new_zeros(output_shape)
    tiles_x = math.ceil(width / tile_size)
    tiles_y = math.ceil(height / tile_size)

    for y in range(tiles_y):
        for x in range(tiles_x):
            # Tile extent in input coordinates (clamped at image borders).
            ofs_x = x * tile_size
            ofs_y = y * tile_size

            input_start_x = ofs_x
            input_end_x = min(ofs_x + tile_size, width)
            input_start_y = ofs_y
            input_end_y = min(ofs_y + tile_size, height)

            # Padded extent: adds context so tile borders match neighbors.
            input_start_x_pad = max(input_start_x - tile_pad, 0)
            input_end_x_pad = min(input_end_x + tile_pad, width)
            input_start_y_pad = max(input_start_y - tile_pad, 0)
            input_end_y_pad = min(input_end_y + tile_pad, height)

            input_tile_width = input_end_x - input_start_x
            input_tile_height = input_end_y - input_start_y
            tile_idx = y * tiles_x + x + 1
            input_tile = img[
                :,
                :,
                input_start_y_pad:input_end_y_pad,
                input_start_x_pad:input_end_x_pad,
            ]

            try:
                with torch.no_grad():
                    output_tile = model(input_tile)
            except RuntimeError as error:
                # Bug fix: previously the error was only printed, after which
                # execution fell through — crashing with NameError on the
                # first failing tile, or silently stitching in the *previous*
                # tile's output on later ones. Log context, then re-raise.
                print("Error", error)
                print(f"\tTile {tile_idx}/{tiles_x * tiles_y}")
                raise

            # Destination region in the output image.
            output_start_x = input_start_x * scale
            output_end_x = input_end_x * scale
            output_start_y = input_start_y * scale
            output_end_y = input_end_y * scale

            # Source region inside the processed tile (strips the padding).
            output_start_x_tile = (input_start_x - input_start_x_pad) * scale
            output_end_x_tile = output_start_x_tile + input_tile_width * scale
            output_start_y_tile = (input_start_y - input_start_y_pad) * scale
            output_end_y_tile = output_start_y_tile + input_tile_height * scale

            output[
                :, :, output_start_y:output_end_y, output_start_x:output_end_x
            ] = output_tile[
                :,
                :,
                output_start_y_tile:output_end_y_tile,
                output_start_x_tile:output_end_x_tile,
            ]

    return output
|
|
|
|
|
def upscale(model, img, tile_pad, tile_size, device="cuda"):
    """Upscale an image 4x with ``model`` and return a uint8 HWC array.

    Flips RGB -> BGR, normalizes to [0, 1], runs tiled inference on
    ``device``, then converts the result back to an HWC uint8 RGB image.

    Args:
        model: upscaling network, already placed on ``device``; maps a
            (1, 3, H, W) float tensor to (1, 3, 4H, 4W).
        img: input image (PIL image or HWC RGB array convertible by
            ``np.array``).
        tile_pad: tile context padding, forwarded to ``tile_process``.
        tile_size: tile edge length, forwarded to ``tile_process``.
        device: torch device for inference (default "cuda", matching the
            previously hard-coded behavior).

    Returns:
        numpy uint8 array of shape (4H, 4W, 3).
    """
    img = np.array(img)
    img = img[:, :, ::-1]  # RGB -> BGR channel order for the network
    img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
    img = torch.from_numpy(img).float()
    img = img.unsqueeze(0).to(device)

    output = tile_process(model, img, tile_pad, tile_size, scale=4)

    output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
    output = 255.0 * np.moveaxis(output, 0, 2)  # CHW -> HWC
    output = output.astype(np.uint8)
    output = output[:, :, ::-1]  # BGR -> RGB
    return output
|
|