Spaces:
Runtime error
Runtime error
File size: 5,720 Bytes
f9827f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from skimage.measure import compare_ssim
import torch
from torch.autograd import Variable
from lpips import dist_model
class PerceptualLoss(torch.nn.Module):
def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
# def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
super(PerceptualLoss, self).__init__()
print('Setting up Perceptual loss...')
self.use_gpu = use_gpu
self.spatial = spatial
self.gpu_ids = gpu_ids
self.model = dist_model.DistModel()
self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
print('...[%s] initialized'%self.model.name())
print('...Done')
def forward(self, pred, target, normalize=False):
"""
Pred and target are Variables.
If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
If normalize is False, assumes the images are already between [-1,+1]
Inputs pred and target are Nx3xHxW
Output pytorch Variable N long
"""
if normalize:
target = 2 * target - 1
pred = 2 * pred - 1
return self.model.forward(target, pred)
def normalize_tensor(in_feat,eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
return in_feat/(norm_factor+eps)
def l2(p0, p1, range=255.):
return .5*np.mean((p0 / range - p1 / range)**2)
def psnr(p0, p1, peak=255.):
return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
def dssim(p0, p1, range=255.):
return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
def rgb2lab(in_img,mean_cent=False):
from skimage import color
img_lab = color.rgb2lab(in_img)
if(mean_cent):
img_lab[:,:,0] = img_lab[:,:,0]-50
return img_lab
def tensor2np(tensor_obj):
# change dimension of a tensor object into a numpy array
return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
def np2tensor(np_obj):
# change dimenion of np array into tensor array
return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
# image tensor to lab tensor
from skimage import color
img = tensor2im(image_tensor)
img_lab = color.rgb2lab(img)
if(mc_only):
img_lab[:,:,0] = img_lab[:,:,0]-50
if(to_norm and not mc_only):
img_lab[:,:,0] = img_lab[:,:,0]-50
img_lab = img_lab/100.
return np2tensor(img_lab)
def tensorlab2tensor(lab_tensor,return_inbnd=False):
from skimage import color
import warnings
warnings.filterwarnings("ignore")
lab = tensor2np(lab_tensor)*100.
lab[:,:,0] = lab[:,:,0]+50
rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
if(return_inbnd):
# convert back to lab, see if we match
lab_back = color.rgb2lab(rgb_back.astype('uint8'))
mask = 1.*np.isclose(lab_back,lab,atol=2.)
mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
return (im2tensor(rgb_back),mask)
else:
return im2tensor(rgb_back)
def rgb2lab(input):
from skimage import color
return color.rgb2lab(input / 255.)
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
return image_numpy.astype(imtype)
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
return torch.Tensor((image / factor - cent)
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
def tensor2vec(vector_tensor):
return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
def voc_ap(rec, prec, use_07_metric=False):
""" ap = voc_ap(rec, prec, [use_07_metric])
Compute VOC AP given precision and recall.
If use_07_metric is true, uses the
VOC 07 11 point method (default:False).
"""
if use_07_metric:
# 11 point metric
ap = 0.
for t in np.arange(0., 1.1, 0.1):
if np.sum(rec >= t) == 0:
p = 0
else:
p = np.max(prec[rec >= t])
ap = ap + p / 11.
else:
# correct AP calculation
# first append sentinel values at the end
mrec = np.concatenate(([0.], rec, [1.]))
mpre = np.concatenate(([0.], prec, [0.]))
# compute the precision envelope
for i in range(mpre.size - 1, 0, -1):
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
# to calculate area under PR curve, look for points
# where X axis (recall) changes value
i = np.where(mrec[1:] != mrec[:-1])[0]
# and sum (\Delta recall) * prec
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
return ap
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
return image_numpy.astype(imtype)
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
return torch.Tensor((image / factor - cent)
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|