File size: 2,587 Bytes
61522a1
c3c4871
2349b6e
c3c4871
61522a1
 
 
 
2349b6e
61522a1
2349b6e
 
 
 
 
61522a1
c3c4871
 
61522a1
 
 
 
 
 
 
2349b6e
61522a1
 
 
 
 
 
 
 
c3c4871
 
 
 
 
61522a1
 
 
 
 
 
 
 
 
 
 
 
2349b6e
61522a1
 
 
 
 
 
 
2349b6e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from models.pcsr import PCSR
import argparse

from torchvision import transforms
from utils import *
from PIL import Image
import numpy as np
import os

# Command-line interface: where to read the low-resolution input image and
# where to write the generated files.
parser = argparse.ArgumentParser(
    description="PCSR Super-Resolution with Input and Output Paths",
)
parser.add_argument(
    '--lr_path',
    type=str,
    default='comic.png',
    help='Path to the input LR image (.png format only)',
)
parser.add_argument(
    '--output_path',
    type=str,
    default='results',
    help='Path to save the outputs',
)
args = parser.parse_args()

# Make sure the output directory exists before anything is saved into it
# (exist_ok=True: re-running into the same directory is not an error).
os.makedirs(args.output_path, exist_ok=True)

# Upscaling factor; the pretrained checkpoint only supports x4.
scale = 4

# Load the pretrained PCSR weights onto the GPU and switch to inference mode
# (nn.Module.eval() returns the module itself, so the calls can be chained).
model = PCSR.from_pretrained("3587jjh/pcsr_carn").cuda().eval()

# Per-channel normalization statistics, shaped (1, 3, 1, 1) so they broadcast
# over an NCHW batch. The std is all ones, i.e. only the mean is subtracted.
rgb_mean = torch.tensor([0.4488, 0.4371, 0.4040], device='cuda').reshape(1, -1, 1, 1)
rgb_std = torch.ones(1, 3, 1, 1, device='cuda')

with torch.no_grad():
    # --- prepare inputs ---
    # convert('RGB') guarantees a 3-channel input. PNGs are frequently RGBA,
    # grayscale or palette-based, and ToTensor would then yield 4- or 1-channel
    # tensors that cannot be normalized with the 3-channel rgb_mean/rgb_std below
    # (or fed to the model). The context manager closes the file handle once the
    # pixels have been converted.
    with Image.open(args.lr_path) as lr_img:
        lr = transforms.ToTensor()(lr_img.convert('RGB')).unsqueeze(0).cuda()  # (1,3,h,w), range=[0,1]
    h, w = lr.shape[-2:]
    H, W = h * scale, w * scale

    # One query coordinate per HR pixel. flatten=True presumably yields (H*W, 2),
    # making coord (1, H*W, 2) -- TODO confirm against utils.make_coord.
    coord = make_coord((H, W), flatten=True, device='cuda').unsqueeze(0)

    # Per-query cell size: 2/H along the first coordinate axis, 2/W along the
    # second (a span of 2 suggests coordinates normalized to [-1, 1] -- confirm).
    cell = torch.ones_like(coord)
    cell[:, :, 0] *= 2 / H
    cell[:, :, 1] *= 2 / W

    inp_lr = (lr - rgb_mean) / rgb_std

    # Model hyperparameters:
    #   k: traverses the PSNR-FLOPs trade-off; smaller k -> larger FLOPs & PSNR
    #      (useful range is about [-1, 2]).
    #   adaptive_cluster: whether k is decided automatically.
    #   refinement: whether to use pixel-wise refinement (post-processing that
    #      reduces artifacts).
    # Arguments shared by the inference run and both FLOP measurements.
    shared_kwargs = dict(coord=coord, cell=cell, scale=scale,
                         pixel_batch_size=300000, refinement=True)

    pred, flag = model(inp_lr, k=0, adaptive_cluster=True, **shared_kwargs)
    flops = get_model_flops(model, inp_lr, k=0, adaptive_cluster=True, **shared_kwargs)
    # Reference cost reported as 100 %: a very small k with the automatic
    # decision disabled (presumably sends every pixel down the heaviest path).
    max_flops = get_model_flops(model, inp_lr, k=-25, adaptive_cluster=False, **shared_kwargs)
    print('flops: {:.1f}G ({:.1f} %) | max_flops: {:.1f}G (100 %)'.format(
        flops / 1e9, (flops / max_flops) * 100, max_flops / 1e9))

# Fold the flat per-pixel predictions (presumably (1, H*W, 3)) back into a
# (1, 3, H, W) image tensor and undo the input normalization in one step.
pred = pred.transpose(1, 2).view(-1, 3, H, W) * rgb_std + rgb_mean
# tensor2numpy presumably returns an (H, W, 3) uint8 array suitable for PIL.
pred = tensor2numpy(pred)

output_file = os.path.join(args.output_path, 'output.png')
Image.fromarray(pred).save(output_file)

# Overlay the model's per-pixel flag on the SR image: pixels with flag == 0
# are tinted green, flag == 1 red, blended 35 % tint / 65 % image.
# (flag presumably encodes the model's per-pixel routing decision -- confirm.)
#
# Only one (H, W) map is needed, so reshape directly to it and convert to a
# numpy boolean-mask source. This avoids repeating the map to 3 channels just
# to read channel 0 back, and avoids indexing a numpy array with a torch tensor.
# H, W still hold the HR size computed during inference, which equals
# pred.shape[:2].
flag_map = flag.detach().reshape(-1, H, W)[0].cpu().numpy()

vis_img = np.zeros_like(pred)
vis_img[flag_map == 0] = np.array([0, 255, 0])
vis_img[flag_map == 1] = np.array([255, 0, 0])
vis_img = vis_img * 0.35 + pred * 0.65
Image.fromarray(vis_img.astype('uint8')).save(os.path.join(args.output_path, 'output_vis.png'))