3587jjh committed on
Commit
c3c4871
1 Parent(s): a7a99f6

Update demo.py

Browse files
Files changed (1) hide show
  1. demo.py +14 -16
demo.py CHANGED
@@ -1,24 +1,17 @@
1
  import torch
2
- import models
 
3
  from torchvision import transforms
4
  from utils import *
5
  from PIL import Image
6
  import numpy as np
7
 
8
- img_path = 'myimage.png' # only support .png
9
- scale = 4 # only support x4
 
10
 
11
- '''
12
- k: hyperparameter to traverse PSNR-FLOPs trade-off. smaller k → larger FLOPs & PSNR. range is about [-1,2].
13
- adaptive: whether to use automatic decision of k
14
- no_refinement: whether not to use pixel-wise refinement (postprocessing for reducing artifacts)
15
- parser.add_argument('--opacity', type=float, default=0.65, help='opacity for colored visualization')
16
- parser.add_argument('--pixel_batch_size', type=int, default=300000)
17
- '''
18
-
19
- resume_path = 'carn-pcsr-phase1.pth'
20
- sv_file = torch.load(resume_path)
21
- model = models.make(sv_file['model'], load_sd=True).cuda()
22
  model.eval()
23
 
24
  rgb_mean = torch.tensor([0.4488, 0.4371, 0.4040], device='cuda').view(1,3,1,1)
@@ -35,6 +28,11 @@ with torch.no_grad():
35
  cell[:,:,1] *= 2/W
36
  inp_lr = (lr - rgb_mean) / rgb_std
37
 
 
 
 
 
 
38
  pred, flag = model(inp_lr, coord=coord, cell=cell, scale=scale, k=0,
39
  pixel_batch_size=300000, adaptive_cluster=True, refinement=True)
40
  flops = get_model_flops(model, inp_lr, coord=coord, cell=cell, scale=scale, k=0,
@@ -47,7 +45,7 @@ with torch.no_grad():
47
  pred = pred.transpose(1,2).view(-1,3,H,W)
48
  pred = pred * rgb_std + rgb_mean
49
  pred = tensor2numpy(pred)
50
- Image.fromarray(pred).save(f'output.png')
51
 
52
  flag = flag.view(-1,1,H,W).repeat(1,3,1,1).squeeze(0).detach().cpu()
53
  H,W = pred.shape[:2]
@@ -55,4 +53,4 @@ vis_img = np.zeros_like(pred)
55
  vis_img[flag[0] == 0] = np.array([0,255,0])
56
  vis_img[flag[0] == 1] = np.array([255,0,0])
57
  vis_img = vis_img*0.35 + pred*0.65
58
- Image.fromarray(vis_img.astype('uint8')).save('output_vis.png')
 
1
  import torch
2
+ from models.pcsr import PCSR
3
+
4
  from torchvision import transforms
5
  from utils import *
6
  from PIL import Image
7
  import numpy as np
8
 
9
+ ##############################################
10
+ img_path = 'sample/comic.png' # your LR input image path, only support .png
11
+ ##############################################
12
 
13
+ scale = 4 # only support x4
14
+ model = PCSR.from_pretrained("3587jjh/pcsr_carn").cuda()
 
 
 
 
 
 
 
 
 
15
  model.eval()
16
 
17
  rgb_mean = torch.tensor([0.4488, 0.4371, 0.4040], device='cuda').view(1,3,1,1)
 
28
  cell[:,:,1] *= 2/W
29
  inp_lr = (lr - rgb_mean) / rgb_std
30
 
31
+ '''
32
+ k: hyperparameter to traverse PSNR-FLOPs trade-off. smaller k → larger FLOPs & PSNR. range is about [-1,2].
33
+ adaptive: whether to use automatic decision of k
34
+ refinement: whether to use pixel-wise refinement (postprocessing for reducing artifacts)
35
+ '''
36
  pred, flag = model(inp_lr, coord=coord, cell=cell, scale=scale, k=0,
37
  pixel_batch_size=300000, adaptive_cluster=True, refinement=True)
38
  flops = get_model_flops(model, inp_lr, coord=coord, cell=cell, scale=scale, k=0,
 
45
  pred = pred.transpose(1,2).view(-1,3,H,W)
46
  pred = pred * rgb_std + rgb_mean
47
  pred = tensor2numpy(pred)
48
+ Image.fromarray(pred).save(f'sample/output.png')
49
 
50
  flag = flag.view(-1,1,H,W).repeat(1,3,1,1).squeeze(0).detach().cpu()
51
  H,W = pred.shape[:2]
 
53
  vis_img[flag[0] == 0] = np.array([0,255,0])
54
  vis_img[flag[0] == 1] = np.array([255,0,0])
55
  vis_img = vis_img*0.35 + pred*0.65
56
+ Image.fromarray(vis_img.astype('uint8')).save('sample/output_vis.png')