3587jjh commited on
Commit
2349b6e
·
verified ·
1 Parent(s): 5bf8222

Update demo.py

Browse files
Files changed (1) hide show
  1. demo.py +10 -6
demo.py CHANGED
@@ -1,14 +1,18 @@
1
  import torch
2
  from models.pcsr import PCSR
 
3
 
4
  from torchvision import transforms
5
  from utils import *
6
  from PIL import Image
7
  import numpy as np
 
8
 
9
- ##############################################
10
- img_path = 'sample/comic.png' # your LR input image path, only support .png
11
- ##############################################
 
 
12
 
13
  scale = 4 # only support x4
14
  model = PCSR.from_pretrained("3587jjh/pcsr_carn").cuda()
@@ -19,7 +23,7 @@ rgb_std = torch.tensor([1.0, 1.0, 1.0], device='cuda').view(1,3,1,1)
19
 
20
  with torch.no_grad():
21
  # prepare inputs
22
- lr = transforms.ToTensor()(Image.open(img_path)).unsqueeze(0).cuda() # (1,3,h,w), range=[0,1]
23
  h,w = lr.shape[-2:]
24
  H,W = h*scale, w*scale
25
  coord = make_coord((H,W), flatten=True, device='cuda').unsqueeze(0)
@@ -45,7 +49,7 @@ with torch.no_grad():
45
  pred = pred.transpose(1,2).view(-1,3,H,W)
46
  pred = pred * rgb_std + rgb_mean
47
  pred = tensor2numpy(pred)
48
- Image.fromarray(pred).save(f'sample/output.png')
49
 
50
  flag = flag.view(-1,1,H,W).repeat(1,3,1,1).squeeze(0).detach().cpu()
51
  H,W = pred.shape[:2]
@@ -53,4 +57,4 @@ vis_img = np.zeros_like(pred)
53
  vis_img[flag[0] == 0] = np.array([0,255,0])
54
  vis_img[flag[0] == 1] = np.array([255,0,0])
55
  vis_img = vis_img*0.35 + pred*0.65
56
- Image.fromarray(vis_img.astype('uint8')).save('sample/output_vis.png')
 
1
  import torch
2
  from models.pcsr import PCSR
3
+ import argparse
4
 
5
  from torchvision import transforms
6
  from utils import *
7
  from PIL import Image
8
  import numpy as np
9
+ import os
10
 
11
+ parser = argparse.ArgumentParser(description="PCSR Super-Resolution with Input and Output Paths")
12
+ parser.add_argument('--lr_path', type=str, default='comic.png', help='Path to the input LR image (.png format only)')
13
+ parser.add_argument('--output_path', type=str, default='results', help='Path to save the outputs')
14
+ args = parser.parse_args()
15
+ os.makedirs(args.output_path, exist_ok=True)
16
 
17
  scale = 4 # only support x4
18
  model = PCSR.from_pretrained("3587jjh/pcsr_carn").cuda()
 
23
 
24
  with torch.no_grad():
25
  # prepare inputs
26
+ lr = transforms.ToTensor()(Image.open(args.lr_path)).unsqueeze(0).cuda() # (1,3,h,w), range=[0,1]
27
  h,w = lr.shape[-2:]
28
  H,W = h*scale, w*scale
29
  coord = make_coord((H,W), flatten=True, device='cuda').unsqueeze(0)
 
49
  pred = pred.transpose(1,2).view(-1,3,H,W)
50
  pred = pred * rgb_std + rgb_mean
51
  pred = tensor2numpy(pred)
52
+ Image.fromarray(pred).save(os.path.join(args.output_path, 'output.png'))
53
 
54
  flag = flag.view(-1,1,H,W).repeat(1,3,1,1).squeeze(0).detach().cpu()
55
  H,W = pred.shape[:2]
 
57
  vis_img[flag[0] == 0] = np.array([0,255,0])
58
  vis_img[flag[0] == 1] = np.array([255,0,0])
59
  vis_img = vis_img*0.35 + pred*0.65
60
+ Image.fromarray(vis_img.astype('uint8')).save(os.path.join(args.output_path, 'output_vis.png'))