Spaces:
Running
Running
File size: 2,285 Bytes
5b83793 95110bc 405a22d 5b83793 473a424 5b83793 95110bc 5b83793 95110bc 5b83793 d89aac0 95110bc 96e29c0 d89aac0 5b83793 d89aac0 5b83793 4b80f4f 95110bc 5b83793 95110bc 473a424 12f4dcf 405a22d 12f4dcf 405a22d 12f4dcf 405a22d 12f4dcf 405a22d 12f4dcf 405a22d 473a424 405a22d dc49371 405a22d 12f4dcf 405a22d dc49371 bc0241b dc49371 405a22d 12f4dcf 95110bc 405a22d 45b47cc 405a22d 45b47cc 405a22d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import numpy as np
import torch
import sys
sys.path.append('models')
from SRFlow.code import imread, impad, load_model, t, rgb
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.transforms import PILToTensor, ToPILImage
def return_SRFlow_result(lr, divide, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml'):
    """
    Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.

    The model and its options are loaded from `conf_path` on every call. The
    sampling heat and the upscaling factor are not arguments: they are read
    from the loaded options (`opt['heat']` and `opt['scale']`).

    Args:
    - lr: PIL Image
    - divide (bool): When falsy, the pixel array is multiplied by 255 before it
      is converted with `t`. When truthy, the pixel values are passed through
      unchanged.
    - conf_path (str): Configuration file path for the SRFlow model. Default is SRFlow_DF2K_4X.yml.
    Returns:
    - sr: PIL Image, cropped to (height * scale, width * scale) of the input.
    """
    # NOTE(review): the model is reloaded on each call; consider caching it if
    # `load_model` turns out to be expensive.
    model, opt = load_model(conf_path)

    # CHW tensor -> HWC numpy array. PILToTensor does not rescale values, so an
    # 8-bit image yields a uint8 array.
    lr = PILToTensor()(lr).permute(1, 2, 0).numpy()

    if not divide:
        # Upcast before scaling: an in-place `*= 255` on a uint8 array wraps
        # around modulo 256 instead of producing the scaled values.
        # NOTE(review): presumably this is meant to offset a normalisation
        # inside `t` -- confirm that scaling 0-255 pixel data is really wanted.
        lr = lr.astype(np.float32) * 255

    scale = opt['scale']
    pad_factor = 2

    # Remember the unpadded size so the output can be cropped back afterwards.
    h, w, _ = lr.shape

    # Pad bottom/right so that both spatial sizes are multiples of pad_factor.
    lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
               right=int(np.ceil(w / pad_factor) * pad_factor - w))

    lr_t = t(lr)
    heat = opt['heat']

    sr_t = model.get_sr(lq=lr_t, heat=heat)

    sr = rgb(torch.clamp(sr_t, 0, 1))
    # Drop the rows/columns that correspond to the padding added above.
    sr = sr[:h * scale, :w * scale]

    sr = Image.fromarray((sr).astype('uint8'))
    return sr
def return_SRFlow_result_from_tensor(lr_tensor, divide=True):
    """
    Apply Super-Resolution using SRFlow model to the input batched BCHW tensor.

    Each batch element is converted to a PIL image, super-resolved with
    `return_SRFlow_result`, converted back to a tensor, and the results are
    concatenated along the batch dimension.

    Args:
    - lr_tensor: Batched BCHW tensor
    - divide (bool): Forwarded to `return_SRFlow_result`. When falsy, the
      result is additionally converted to float and divided by 255; otherwise
      the tensor produced by PILToTensor is returned as-is.
    Returns:
    - sr_tensor: Processed batched BCHW tensor
    """
    # Build the transforms once instead of once per batch element.
    to_pil = ToPILImage()
    to_tensor = PILToTensor()

    sr_list = []
    for lr_single in lr_tensor:
        sr_image = return_SRFlow_result(to_pil(lr_single), divide)
        sr_list.append(to_tensor(sr_image).unsqueeze(0))

    sr_tensor = torch.cat(sr_list, dim=0)

    if not divide:
        # PILToTensor yields an integer (uint8) tensor; an in-place `/= 255.0`
        # on it raises a RuntimeError because the float result cannot be
        # stored in the integer tensor. Convert out-of-place instead.
        sr_tensor = sr_tensor.float() / 255.0

    return sr_tensor
if __name__ == '__main__':
    # Demo: super-resolve a single image and display the result.
    demo_image = Image.open('images/demo.png')

    # Add a leading batch dimension so the image forms a one-element batch.
    demo_batch = PILToTensor()(demo_image).unsqueeze(0)

    sr_batch = return_SRFlow_result_from_tensor(demo_batch)
    print(sr_batch.shape)

    # Show SR image of the first one in the batch (CHW -> HWC for matplotlib).
    first_sr = sr_batch[0].cpu().detach().numpy()
    plt.imshow(np.transpose(first_sr, (1, 2, 0)))
    # plt.axis('off')
    plt.show()