Update demo.py
Browse files
demo.py
CHANGED
@@ -1,24 +1,17 @@
|
|
1 |
import torch
|
2 |
-
import
|
|
|
3 |
from torchvision import transforms
|
4 |
from utils import *
|
5 |
from PIL import Image
|
6 |
import numpy as np
|
7 |
|
8 |
-
|
9 |
-
|
|
|
10 |
|
11 |
-
|
12 |
-
|
13 |
-
adaptive: whether to use automatic decision of k
|
14 |
-
no_refinement: whether not to use pixel-wise refinement (postprocessing for reducing artifacts)
|
15 |
-
parser.add_argument('--opacity', type=float, default=0.65, help='opacity for colored visualization')
|
16 |
-
parser.add_argument('--pixel_batch_size', type=int, default=300000)
|
17 |
-
'''
|
18 |
-
|
19 |
-
resume_path = 'carn-pcsr-phase1.pth'
|
20 |
-
sv_file = torch.load(resume_path)
|
21 |
-
model = models.make(sv_file['model'], load_sd=True).cuda()
|
22 |
model.eval()
|
23 |
|
24 |
rgb_mean = torch.tensor([0.4488, 0.4371, 0.4040], device='cuda').view(1,3,1,1)
|
@@ -35,6 +28,11 @@ with torch.no_grad():
|
|
35 |
cell[:,:,1] *= 2/W
|
36 |
inp_lr = (lr - rgb_mean) / rgb_std
|
37 |
|
|
|
|
|
|
|
|
|
|
|
38 |
pred, flag = model(inp_lr, coord=coord, cell=cell, scale=scale, k=0,
|
39 |
pixel_batch_size=300000, adaptive_cluster=True, refinement=True)
|
40 |
flops = get_model_flops(model, inp_lr, coord=coord, cell=cell, scale=scale, k=0,
|
@@ -47,7 +45,7 @@ with torch.no_grad():
|
|
47 |
pred = pred.transpose(1,2).view(-1,3,H,W)
|
48 |
pred = pred * rgb_std + rgb_mean
|
49 |
pred = tensor2numpy(pred)
|
50 |
-
Image.fromarray(pred).save(f'output.png')
|
51 |
|
52 |
flag = flag.view(-1,1,H,W).repeat(1,3,1,1).squeeze(0).detach().cpu()
|
53 |
H,W = pred.shape[:2]
|
@@ -55,4 +53,4 @@ vis_img = np.zeros_like(pred)
|
|
55 |
vis_img[flag[0] == 0] = np.array([0,255,0])
|
56 |
vis_img[flag[0] == 1] = np.array([255,0,0])
|
57 |
vis_img = vis_img*0.35 + pred*0.65
|
58 |
-
Image.fromarray(vis_img.astype('uint8')).save('output_vis.png')
|
|
|
1 |
import torch
|
2 |
+
from models.pcsr import PCSR
|
3 |
+
|
4 |
from torchvision import transforms
|
5 |
from utils import *
|
6 |
from PIL import Image
|
7 |
import numpy as np
|
8 |
|
9 |
+
##############################################
|
10 |
+
img_path = 'sample/comic.png' # your LR input image path, only support .png
|
11 |
+
##############################################
|
12 |
|
13 |
+
scale = 4 # only support x4
|
14 |
+
model = PCSR.from_pretrained("3587jjh/pcsr_carn").cuda()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
model.eval()
|
16 |
|
17 |
rgb_mean = torch.tensor([0.4488, 0.4371, 0.4040], device='cuda').view(1,3,1,1)
|
|
|
28 |
cell[:,:,1] *= 2/W
|
29 |
inp_lr = (lr - rgb_mean) / rgb_std
|
30 |
|
31 |
+
'''
|
32 |
+
k: hyperparameter to traverse PSNR-FLOPs trade-off. smaller k → larger FLOPs & PSNR. range is about [-1,2].
|
33 |
+
adaptive: whether to use automatic decision of k
|
34 |
+
refinement: whether to use pixel-wise refinement (postprocessing for reducing artifacts)
|
35 |
+
'''
|
36 |
pred, flag = model(inp_lr, coord=coord, cell=cell, scale=scale, k=0,
|
37 |
pixel_batch_size=300000, adaptive_cluster=True, refinement=True)
|
38 |
flops = get_model_flops(model, inp_lr, coord=coord, cell=cell, scale=scale, k=0,
|
|
|
45 |
pred = pred.transpose(1,2).view(-1,3,H,W)
|
46 |
pred = pred * rgb_std + rgb_mean
|
47 |
pred = tensor2numpy(pred)
|
48 |
+
Image.fromarray(pred).save(f'sample/output.png')
|
49 |
|
50 |
flag = flag.view(-1,1,H,W).repeat(1,3,1,1).squeeze(0).detach().cpu()
|
51 |
H,W = pred.shape[:2]
|
|
|
53 |
vis_img[flag[0] == 0] = np.array([0,255,0])
|
54 |
vis_img[flag[0] == 1] = np.array([255,0,0])
|
55 |
vis_img = vis_img*0.35 + pred*0.65
|
56 |
+
Image.fromarray(vis_img.astype('uint8')).save('sample/output_vis.png')
|