import torch
import models
from torchvision import transforms
from utils import *
from PIL import Image
import numpy as np

img_path = 'myimage.png' # only .png input is supported
scale = 4 # only x4 upscaling is supported

'''
k: hyperparameter that traverses the PSNR-FLOPs trade-off; a smaller k means larger FLOPs and higher PSNR. The useful range is roughly [-1, 2].
adaptive_cluster: whether to decide k automatically
refinement: whether to apply pixel-wise refinement (post-processing that reduces artifacts)
opacity: opacity used when blending the colored visualization (default 0.65)
pixel_batch_size: number of output pixels processed at a time (default 300000)
'''
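
# load the pretrained CARN-PCSR checkpoint (phase 1) and restore the model weights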
resume_path = 'carn-pcsr-phase1.pth'
sv_file = torch.load(resume_path)
model = models.make(sv_file['model'], load_sd=True).cuda()
model.eval()
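
# channel-wise normalization statistics: the DIV2K RGB mean (as used by CARN/EDSR) with identity std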
rgb_mean = torch.tensor([0.4488, 0.4371, 0.4040], device='cuda').view(1,3,1,1)
rgb_std = torch.tensor([1.0, 1.0, 1.0], device='cuda').view(1,3,1,1)
with torch.no_grad():
    # prepare inputs
    lr = transforms.ToTensor()(Image.open(img_path)).unsqueeze(0).cuda() # (1,3,h,w), range=[0,1]
    h,w = lr.shape[-2:]
    H,W = h*scale, w*scale
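    # query coordinates for every HR pixel, flattened to (1, H*W, 2), plus the matching
    # cell sizes in the [-1, 1] normalized coordinate system used by make_coord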
    coord = make_coord((H,W), flatten=True, device='cuda').unsqueeze(0)
    cell = torch.ones_like(coord)
    cell[:,:,0] *= 2/H
    cell[:,:,1] *= 2/W
    inp_lr = (lr - rgb_mean) / rgb_std
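
    # single forward pass: pred holds the SR pixel values, flag the per-pixel branch assignment made by the classifier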
    pred, flag = model(inp_lr, coord=coord, cell=cell, scale=scale, k=0,
                       pixel_batch_size=300000, adaptive_cluster=True, refinement=True)
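
    # measured FLOPs for this run vs. an upper bound: k=-25 with adaptive clustering disabled
    # effectively routes every pixel through the most expensive path, so it serves as the 100% reference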
    flops = get_model_flops(model, inp_lr, coord=coord, cell=cell, scale=scale, k=0,
                            pixel_batch_size=300000, adaptive_cluster=True, refinement=True)
    max_flops = get_model_flops(model, inp_lr, coord=coord, cell=cell, scale=scale, k=-25,
                                pixel_batch_size=300000, adaptive_cluster=False, refinement=True)
    print('flops: {:.1f}G ({:.1f} %) | max_flops: {:.1f}G (100 %)'.format(
        flops/1e9, (flops/max_flops)*100, max_flops/1e9))
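
    # reshape the flat prediction back to an image, undo the normalization, and save the SR result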
    pred = pred.transpose(1,2).view(-1,3,H,W)
    pred = pred * rgb_std + rgb_mean
    pred = tensor2numpy(pred)
    Image.fromarray(pred).save('output.png')
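
    # colored map of the classifier decisions: green where flag == 0, red where flag == 1, blended over the SR output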
    flag = flag.view(-1,1,H,W).repeat(1,3,1,1).squeeze(0).detach().cpu()
    H,W = pred.shape[:2]
    vis_img = np.zeros_like(pred)
    vis_img[flag[0] == 0] = np.array([0,255,0])
    vis_img[flag[0] == 1] = np.array([255,0,0])
    vis_img = vis_img*0.35 + pred*0.65
    Image.fromarray(vis_img.astype('uint8')).save('output_vis.png')
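
    # (optional sketch, not part of the original demo) report how the classifier split the pixels;
    # it assumes flag takes the values {0, 1} used for the visualization above
    heavy_ratio = (flag[0] == 1).float().mean().item()
    print('pixels assigned flag 1: {:.1f} %'.format(heavy_ratio * 100))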