import argparse
import os

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from models.pcsr import PCSR
from utils import *  # provides make_coord, get_model_flops, tensor2numpy

parser = argparse.ArgumentParser(description="PCSR Super-Resolution with Input and Output Paths")
parser.add_argument('--lr_path', type=str, default='comic.png', help='Path to the input LR image (.png format only)')
parser.add_argument('--output_path', type=str, default='results', help='Directory to save the outputs')
args = parser.parse_args()
os.makedirs(args.output_path, exist_ok=True)

# load the pretrained PCSR model (CARN backbone) and set up input normalization
scale = 4  # upscaling factor
model = PCSR.from_pretrained("3587jjh/pcsr_carn").cuda()
model.eval()

rgb_mean = torch.tensor([0.4488, 0.4371, 0.4040], device='cuda').view(1, 3, 1, 1)
rgb_std = torch.tensor([1.0, 1.0, 1.0], device='cuda').view(1, 3, 1, 1)

with torch.no_grad():
    # load the LR image as a (1,3,h,w) tensor in [0,1]; convert('RGB') guards against palette/RGBA PNGs
    lr = transforms.ToTensor()(Image.open(args.lr_path).convert('RGB')).unsqueeze(0).cuda()
    h, w = lr.shape[-2:]
    H, W = h*scale, w*scale

    # per-pixel coordinates and cell sizes of the target HR grid
    coord = make_coord((H, W), flatten=True, device='cuda').unsqueeze(0)
    cell = torch.ones_like(coord)
    cell[:, :, 0] *= 2/H
    cell[:, :, 1] *= 2/W
    inp_lr = (lr - rgb_mean) / rgb_std

    '''
    k: hyperparameter that traverses the PSNR-FLOPs trade-off; smaller k → higher FLOPs and higher PSNR. Useful range is roughly [-1, 2].
    adaptive_cluster: whether to decide k automatically
    refinement: whether to apply pixel-wise refinement (post-processing that reduces artifacts)
    '''
    pred, flag = model(inp_lr, coord=coord, cell=cell, scale=scale, k=0,
        pixel_batch_size=300000, adaptive_cluster=True, refinement=True)

    # FLOPs for this setting, compared against the k=-25 setting used as the 100% (maximum-FLOPs) reference
    flops = get_model_flops(model, inp_lr, coord=coord, cell=cell, scale=scale, k=0,
        pixel_batch_size=300000, adaptive_cluster=True, refinement=True)
    max_flops = get_model_flops(model, inp_lr, coord=coord, cell=cell, scale=scale, k=-25,
        pixel_batch_size=300000, adaptive_cluster=False, refinement=True)
    print('flops: {:.1f}G ({:.1f} %) | max_flops: {:.1f}G (100 %)'.format(
        flops/1e9, (flops/max_flops)*100, max_flops/1e9))
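
    # Optional sketch (not part of the original demo; it reuses the same model and
    # get_model_flops calls shown above): sweep k over its useful range to trace the
    # PSNR-FLOPs trade-off with automatic clustering disabled.
    # for k in (-1.0, 0.0, 1.0, 2.0):
    #     flops_k = get_model_flops(model, inp_lr, coord=coord, cell=cell, scale=scale, k=k,
    #         pixel_batch_size=300000, adaptive_cluster=False, refinement=True)
    #     print('k={:+.1f}: {:.1f}G FLOPs ({:.1f} %)'.format(k, flops_k/1e9, flops_k/max_flops*100))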

    # reshape the flattened prediction to (1,3,H,W), de-normalize, and save as PNG
    pred = pred.transpose(1, 2).view(-1, 3, H, W)
    pred = pred * rgb_std + rgb_mean
    pred = tensor2numpy(pred)
    Image.fromarray(pred).save(os.path.join(args.output_path, 'output.png'))

    # overlay the per-pixel classifier decision (flag) on the output: 0 → green, 1 → red
    flag = flag.view(-1, 1, H, W).repeat(1, 3, 1, 1).squeeze(0).detach().cpu().numpy()
    H, W = pred.shape[:2]
    vis_img = np.zeros_like(pred)
    vis_img[flag[0] == 0] = np.array([0, 255, 0])
    vis_img[flag[0] == 1] = np.array([255, 0, 0])
    vis_img = vis_img*0.35 + pred*0.65
    Image.fromarray(vis_img.astype('uint8')).save(os.path.join(args.output_path, 'output_vis.png'))
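
# Example invocation (the script filename is an assumption; use whatever name this file is saved under):
#   python demo.py --lr_path comic.png --output_path results
# This writes results/output.png (the x4 SR result) and results/output_vis.png
# (the same result blended with the per-pixel classification map).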