hail75 commited on
Commit
405a22d
·
1 Parent(s): d09509f

add new func

Browse files
Files changed (1) hide show
  1. models/SRFlow/srflow.py +27 -32
models/SRFlow/srflow.py CHANGED
@@ -4,9 +4,10 @@ import sys
4
  sys.path.append('models')
5
  from SRFlow.code import imread, impad, load_model, t, rgb
6
  from PIL import Image
7
- from torchvision.transforms import PILToTensor
 
8
 
9
- def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=0.7):
10
  """
11
  Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.
12
 
@@ -39,43 +40,37 @@ def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.
39
  sr = Image.fromarray((sr).astype('uint8'))
40
  return sr
41
 
42
- def test_srflow(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=0.7):
43
  """
44
- Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.
45
 
46
  Args:
47
- - lr: tensor
48
- - conf_path (str): Configuration file path for the SRFlow model. Default is SRFlow_DF2K_4X.yml.
49
- - heat (float): Heat parameter for the SRFlow model. Default is 0.6.
50
 
51
  Returns:
52
- - sr: tensor
53
  """
54
- model, opt = load_model(conf_path)
 
55
 
56
- scale = opt['scale']
57
- pad_factor = 2
58
-
59
- lr = lr.squeeze(0).permute(0, 1, 2).numpy()
60
- h, w, c = lr.shape
61
- lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
62
- right=int(np.ceil(w / pad_factor) * pad_factor - w))
63
-
64
- lr_t = t(lr)
65
- heat = opt['heat']
66
-
67
- sr_t = model.get_sr(lq=lr_t, heat=heat)
68
 
69
- sr = rgb(torch.clamp(sr_t, 0, 1))
70
- sr = sr[:h * scale, :w * scale]
71
- sr = sr.unsqueeze(0).permute(0, 3, 1, 2)
72
-
73
- return sr
74
 
75
  if __name__ == '__main__':
76
- ip = Image.open('images/demo.png')
77
- lr = PILToTensor()(ip).permute(1, 2, 0).numpy()
78
- print(lr.shape)
79
- sr = test_srflow(lr)
80
- # sr = return_SRFlow_result(ip)
81
- print(sr.size)
 
 
 
 
 
 
4
  sys.path.append('models')
5
  from SRFlow.code import imread, impad, load_model, t, rgb
6
  from PIL import Image
7
+ import matplotlib.pyplot as plt
8
+ from torchvision.transforms import PILToTensor, ToPILImage
9
 
10
+ def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml'):
11
  """
12
  Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.
13
 
 
40
  sr = Image.fromarray((sr).astype('uint8'))
41
  return sr
42
 
43
+ def return_SRFlow_result_from_tensor(lr_tensor):
44
  """
45
+ Apply Super-Resolution using SRFlow model to the input batched BCHW tensor.
46
 
47
  Args:
48
+ - lr_tensor: Batched BCHW tensor
 
 
49
 
50
  Returns:
51
+ - sr_tensor: Processed batched BCHW tensor
52
  """
53
+ batch_size = lr_tensor.shape[0]
54
+ sr_list = []
55
 
56
+ for b in range(batch_size):
57
+ lr_image = ToPILImage()(lr_tensor[b])
58
+ sr_image = return_SRFlow_result(lr_image)
59
+ sr_tensor = PILToTensor()(sr_image).unsqueeze(0)
60
+ sr_list.append(sr_tensor)
 
 
 
 
 
 
 
61
 
62
+ sr_tensor = torch.cat(sr_list, dim=0)
63
+ return sr_tensor
 
 
 
64
 
65
  if __name__ == '__main__':
66
+ lr = Image.open('images/demo.png')
67
+
68
+ lr_tensor = PILToTensor()(lr).unsqueeze(0)
69
+
70
+ sr = return_SRFlow_result_from_tensor(lr_tensor)
71
+ print(sr.shape)
72
+
73
+ plt.imshow(np.transpose(sr[0].cpu().detach().numpy(), (1, 2, 0)))
74
+ plt.axis('off')
75
+ plt.title('Super-Resolved Image')
76
+ plt.show()