Update demo.py
Browse files
demo.py
CHANGED
@@ -1,14 +1,18 @@
|
|
1 |
import torch
|
2 |
from models.pcsr import PCSR
|
|
|
3 |
|
4 |
from torchvision import transforms
|
5 |
from utils import *
|
6 |
from PIL import Image
|
7 |
import numpy as np
|
|
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
12 |
|
13 |
scale = 4 # only support x4
|
14 |
model = PCSR.from_pretrained("3587jjh/pcsr_carn").cuda()
|
@@ -19,7 +23,7 @@ rgb_std = torch.tensor([1.0, 1.0, 1.0], device='cuda').view(1,3,1,1)
|
|
19 |
|
20 |
with torch.no_grad():
|
21 |
# prepare inputs
|
22 |
-
lr = transforms.ToTensor()(Image.open(
|
23 |
h,w = lr.shape[-2:]
|
24 |
H,W = h*scale, w*scale
|
25 |
coord = make_coord((H,W), flatten=True, device='cuda').unsqueeze(0)
|
@@ -45,7 +49,7 @@ with torch.no_grad():
|
|
45 |
pred = pred.transpose(1,2).view(-1,3,H,W)
|
46 |
pred = pred * rgb_std + rgb_mean
|
47 |
pred = tensor2numpy(pred)
|
48 |
-
Image.fromarray(pred).save(
|
49 |
|
50 |
flag = flag.view(-1,1,H,W).repeat(1,3,1,1).squeeze(0).detach().cpu()
|
51 |
H,W = pred.shape[:2]
|
@@ -53,4 +57,4 @@ vis_img = np.zeros_like(pred)
|
|
53 |
vis_img[flag[0] == 0] = np.array([0,255,0])
|
54 |
vis_img[flag[0] == 1] = np.array([255,0,0])
|
55 |
vis_img = vis_img*0.35 + pred*0.65
|
56 |
-
Image.fromarray(vis_img.astype('uint8')).save('
|
|
|
1 |
import torch
|
2 |
from models.pcsr import PCSR
|
3 |
+
import argparse
|
4 |
|
5 |
from torchvision import transforms
|
6 |
from utils import *
|
7 |
from PIL import Image
|
8 |
import numpy as np
|
9 |
+
import os
|
10 |
|
11 |
+
parser = argparse.ArgumentParser(description="PCSR Super-Resolution with Input and Output Paths")
|
12 |
+
parser.add_argument('--lr_path', type=str, default='comic.png', help='Path to the input LR image (.png format only)')
|
13 |
+
parser.add_argument('--output_path', type=str, default='results', help='Path to save the outputs')
|
14 |
+
args = parser.parse_args()
|
15 |
+
os.makedirs(args.output_path, exist_ok=True)
|
16 |
|
17 |
scale = 4 # only support x4
|
18 |
model = PCSR.from_pretrained("3587jjh/pcsr_carn").cuda()
|
|
|
23 |
|
24 |
with torch.no_grad():
|
25 |
# prepare inputs
|
26 |
+
lr = transforms.ToTensor()(Image.open(args.lr_path)).unsqueeze(0).cuda() # (1,3,h,w), range=[0,1]
|
27 |
h,w = lr.shape[-2:]
|
28 |
H,W = h*scale, w*scale
|
29 |
coord = make_coord((H,W), flatten=True, device='cuda').unsqueeze(0)
|
|
|
49 |
pred = pred.transpose(1,2).view(-1,3,H,W)
|
50 |
pred = pred * rgb_std + rgb_mean
|
51 |
pred = tensor2numpy(pred)
|
52 |
+
Image.fromarray(pred).save(os.path.join(args.output_path, 'output.png'))
|
53 |
|
54 |
flag = flag.view(-1,1,H,W).repeat(1,3,1,1).squeeze(0).detach().cpu()
|
55 |
H,W = pred.shape[:2]
|
|
|
57 |
vis_img[flag[0] == 0] = np.array([0,255,0])
|
58 |
vis_img[flag[0] == 1] = np.array([255,0,0])
|
59 |
vis_img = vis_img*0.35 + pred*0.65
|
60 |
+
Image.fromarray(vis_img.astype('uint8')).save(os.path.join(args.output_path, 'output_vis.png'))
|