import math

import numpy as np
import torch


def infer_params(state_dict):
    """Infer ESRGAN architecture hyper-parameters from a model ``state_dict``.

    This code is copied from https://github.com/victorca25/iNNfer, with
    explicit errors added when the keys it relies on are missing.

    Args:
        state_dict: Mapping of parameter names (e.g. ``"model.0.weight"``) to
            tensors; only each value's ``.shape`` is read.

    Returns:
        Tuple ``(in_nc, out_nc, nf, nb, plus, scale)``:

        * ``in_nc``, ``nf``: dims 1 and 0 of ``state_dict["model.0.weight"]``.
        * ``out_nc``: dim 0 of the first-seen tensor whose key has the form
          ``<prefix>.<n>.<param>`` with the largest ``n``.
        * ``nb``: the index ``i`` of the last-seen ``<a>.<b>.sub.<i>.<param>`` key.
        * ``plus``: True when any key contains ``"conv1x1"``.
        * ``scale``: ``2 ** k`` where ``k`` counts ``model.<n>.weight`` keys
          with ``n > 6``.

    Raises:
        KeyError: if ``"model.0.weight"`` is missing.
        ValueError: if no key yields ``nb`` or ``out_nc``.
    """
    scale2x = 0
    # Layer indices above this are counted toward the scale; presumably these
    # are the upsampling/output convs of the ESRGAN layout — TODO confirm.
    scalemin = 6
    n_uplayer = 0
    plus = False
    nb = None
    out_nc = None

    for block in state_dict:
        parts = block.split(".")
        n_parts = len(parts)
        if n_parts == 5 and parts[2] == "sub":
            nb = int(parts[3])
        elif n_parts == 3:
            part_num = int(parts[1])
            if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
                scale2x += 1
            if part_num > n_uplayer:
                n_uplayer = part_num
                out_nc = state_dict[block].shape[0]
        if not plus and "conv1x1" in block:
            plus = True

    if nb is None:
        raise ValueError("could not infer 'nb': no '<a>.<b>.sub.<i>.<param>' key in state_dict")
    if out_nc is None:
        raise ValueError("could not infer 'out_nc': no '<prefix>.<n>.<param>' key with n > 0")

    nf = state_dict["model.0.weight"].shape[0]
    in_nc = state_dict["model.0.weight"].shape[1]
    scale = 2**scale2x

    return in_nc, out_nc, nf, nb, plus, scale


def tile_process(model, img, tile_pad, tile_size, scale=4):
    """Run ``model`` over ``img`` tile by tile and stitch the results together.

    The input is cut into ``tile_size`` x ``tile_size`` tiles (edge tiles may be
    smaller). Each tile is fed to the model together with up to ``tile_pad``
    extra pixels of surrounding context. That padded border is cropped from
    the model output before the tile is pasted into the result image.

    Modified from: https://github.com/ata4/esrgan-launcher

    Args:
        model: Callable taking an ``(N, C, h, w)`` tensor and returning an
            ``(N, C_out, h * scale, w * scale)`` tensor (assumed; TODO confirm
            for non-ESRGAN models).
        img: 4-D tensor of shape ``(batch, channel, height, width)``.
        tile_pad: Context pixels added on each side of a tile (clipped at the
            image border).
        tile_size: Tile edge length in input pixels; must be positive.
        scale: Upscaling factor produced by ``model``.

    Returns:
        Tensor of shape ``(batch, C_out, height * scale, width * scale)``,
        created with ``img.new_zeros`` (same dtype/device as ``img``).
        ``C_out`` is the channel count returned by the model; it falls back to
        ``channel`` when the image is empty and no tile is processed.

    Raises:
        ValueError: if ``tile_size`` is not positive.
        RuntimeError: re-raised from ``model`` after reporting the failing tile.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")

    batch, channel, height, width = img.shape
    output_height = height * scale
    output_width = width * scale

    tiles_x = math.ceil(width / tile_size)
    tiles_y = math.ceil(height / tile_size)

    # Allocated from the first tile, so models whose output channel count
    # differs from the input's are supported. Starts as a black image.
    output = None

    for y in range(tiles_y):
        for x in range(tiles_x):
            tile_idx = y * tiles_x + x + 1

            # input tile area on total image
            input_start_x = x * tile_size
            input_end_x = min(input_start_x + tile_size, width)
            input_start_y = y * tile_size
            input_end_y = min(input_start_y + tile_size, height)

            # input tile area on total image with padding
            input_start_x_pad = max(input_start_x - tile_pad, 0)
            input_end_x_pad = min(input_end_x + tile_pad, width)
            input_start_y_pad = max(input_start_y - tile_pad, 0)
            input_end_y_pad = min(input_end_y + tile_pad, height)

            input_tile = img[
                :,
                :,
                input_start_y_pad:input_end_y_pad,
                input_start_x_pad:input_end_x_pad,
            ]

            # upscale tile
            try:
                with torch.no_grad():
                    output_tile = model(input_tile)
            except RuntimeError as error:
                # Continuing here would reference an unbound or stale
                # output_tile, so report the failing tile and propagate.
                print("Error", error)
                print(f"\tTile {tile_idx}/{tiles_x * tiles_y}")
                raise
            print(f"\tTile {tile_idx}/{tiles_x * tiles_y}")

            if output is None:
                output = img.new_zeros(
                    (batch, output_tile.shape[1], output_height, output_width)
                )

            # output tile area on total image
            output_start_x = input_start_x * scale
            output_end_x = input_end_x * scale
            output_start_y = input_start_y * scale
            output_end_y = input_end_y * scale

            # position of the unpadded region inside the padded model output
            output_start_x_tile = (input_start_x - input_start_x_pad) * scale
            output_end_x_tile = output_start_x_tile + (input_end_x - input_start_x) * scale
            output_start_y_tile = (input_start_y - input_start_y_pad) * scale
            output_end_y_tile = output_start_y_tile + (input_end_y - input_start_y) * scale

            # put tile into output image
            output[:, :, output_start_y:output_end_y, output_start_x:output_end_x] = output_tile[
                :,
                :,
                output_start_y_tile:output_end_y_tile,
                output_start_x_tile:output_end_x_tile,
            ]

    if output is None:
        # empty image: no tile was processed
        output = img.new_zeros((batch, channel, output_height, output_width))
    return output


def upscale(model, img, tile_pad, tile_size, *, scale=4, device="cuda"):
    """Upscale an ``H x W x C`` image with ``model`` using tiled inference.

    Args:
        model: Super-resolution model, expected to already live on ``device``.
        img: Image convertible with ``np.array`` (e.g. a PIL image) of shape
            ``H x W x C``; values are assumed to be in 0..255 (TODO confirm
            for non-uint8 inputs).
        tile_pad: Context padding per tile, forwarded to ``tile_process``.
        tile_size: Tile edge length, forwarded to ``tile_process``.
        scale: Upscaling factor of ``model`` (default 4).
        device: Torch device the input tensor is moved to (default ``"cuda"``).

    Returns:
        ``uint8`` array of shape ``H*scale x W*scale x C_out``. The channel
        axis is reversed before inference and reversed back afterwards, so the
        output uses the same channel order as ``img``.
    """
    arr = np.array(img)
    # Reverse channel order (RGB -> BGR for RGB input); presumably the model
    # expects BGR as in the original ESRGAN code — verify per checkpoint.
    arr = arr[:, :, ::-1]
    arr = np.ascontiguousarray(np.transpose(arr, (2, 0, 1))) / 255
    tensor = torch.from_numpy(arr).float().unsqueeze(0).to(device)

    output = tile_process(model, tensor, tile_pad, tile_size, scale=scale)

    # squeeze(0) drops only the batch axis; a bare squeeze() would also drop a
    # single output channel and break the moveaxis below.
    output = output.squeeze(0).float().cpu().clamp_(0, 1).numpy()
    output = np.moveaxis(output, 0, 2)
    # Round rather than truncate so that x / 255 * 255 maps back to x.
    output = (output * 255.0).round().astype(np.uint8)
    output = output[:, :, ::-1]
    return output