hail75 commited on
Commit
473a424
·
1 Parent(s): 45b47cc

add divide param

Browse files
models/SRFlow/code/__init__.py CHANGED
@@ -35,7 +35,12 @@ def predict(model, lr):
35
  return visuals.get('rlt', visuals.get("SR"))
36
 
37
 
38
- def t(array): return torch.Tensor(np.expand_dims(array.transpose([2, 0, 1]), axis=0).astype(np.float32)) / 255
 
 
 
 
 
39
 
40
 
41
  def rgb(t): return (
 
35
  return visuals.get('rlt', visuals.get("SR"))
36
 
37
 
38
+ def t(array, divide):
39
+ output = torch.Tensor(np.expand_dims(array.transpose([2, 0, 1]), axis=0).astype(np.float32))
40
+ if divide:
41
+ return output / 255
42
+ else:
43
+ return output
44
 
45
 
46
  def rgb(t): return (
models/SRFlow/srflow.py CHANGED
@@ -7,7 +7,7 @@ from PIL import Image
7
  import matplotlib.pyplot as plt
8
  from torchvision.transforms import PILToTensor, ToPILImage
9
 
10
- def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml'):
11
  """
12
  Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.
13
 
@@ -29,7 +29,7 @@ def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.
29
  lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
30
  right=int(np.ceil(w / pad_factor) * pad_factor - w))
31
 
32
- lr_t = t(lr)
33
  heat = opt['heat']
34
 
35
  sr_t = model.get_sr(lq=lr_t, heat=heat)
@@ -40,7 +40,7 @@ def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.
40
  sr = Image.fromarray((sr).astype('uint8'))
41
  return sr
42
 
43
- def return_SRFlow_result_from_tensor(lr_tensor):
44
  """
45
  Apply Super-Resolution using SRFlow model to the input batched BCHW tensor.
46
 
@@ -55,7 +55,7 @@ def return_SRFlow_result_from_tensor(lr_tensor):
55
 
56
  for b in range(batch_size):
57
  lr_image = ToPILImage()(lr_tensor[b])
58
- sr_image = return_SRFlow_result(lr_image)
59
  sr_tensor = PILToTensor()(sr_image).unsqueeze(0)
60
  sr_list.append(sr_tensor)
61
 
 
7
  import matplotlib.pyplot as plt
8
  from torchvision.transforms import PILToTensor, ToPILImage
9
 
10
+ def return_SRFlow_result(lr, divide, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml'):
11
  """
12
  Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.
13
 
 
29
  lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
30
  right=int(np.ceil(w / pad_factor) * pad_factor - w))
31
 
32
+ lr_t = t(lr, divide)
33
  heat = opt['heat']
34
 
35
  sr_t = model.get_sr(lq=lr_t, heat=heat)
 
40
  sr = Image.fromarray((sr).astype('uint8'))
41
  return sr
42
 
43
+ def return_SRFlow_result_from_tensor(lr_tensor, divide=True):
44
  """
45
  Apply Super-Resolution using SRFlow model to the input batched BCHW tensor.
46
 
 
55
 
56
  for b in range(batch_size):
57
  lr_image = ToPILImage()(lr_tensor[b])
58
+ sr_image = return_SRFlow_result(lr_image, divide)
59
  sr_tensor = PILToTensor()(sr_image).unsqueeze(0)
60
  sr_list.append(sr_tensor)
61