import torch
from torch import nn


def _expand(img):
    # Prepend singleton dimensions until the tensor is 4-D (N, C, H, W).
    if img.ndim < 4:
        img = img.expand([1] * (4 - img.ndim) + list(img.shape))
    return img


class PSNR(nn.Module):
    def __init__(self):
        super(PSNR, self).__init__()

    def forward(self, im1, im2, data_range=None):
        # Infer the dynamic range when not given: assume [0, 255] if values
        # exceed 1, otherwise [0, 1].
        if data_range is None:
            data_range = 255 if im1.max() > 1 else 1

        # Work in floating point so integer inputs do not wrap on subtraction.
        if not im1.is_floating_point():
            im1 = im1.float()
        if not im2.is_floating_point():
            im2 = im2.float()

        # Per-sample MSE over all non-batch dimensions.
        se = (im1 - im2) ** 2
        se = _expand(se)
        mse = se.mean(dim=list(range(1, se.ndim)))

        # PSNR = 10 * log10(data_range^2 / MSE), averaged over the batch.
        psnr = 10 * (data_range ** 2 / mse).log10().mean()
        return psnr
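
# Worked example (illustrative only, not from the original module): with
# data_range=255 and a mean squared error of 100, the metric above evaluates
# to 10 * log10(255**2 / 100) ≈ 28.13 dB.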


class SSIM(nn.Module):
    def __init__(self, device_type='cpu', dtype=torch.float32):
        super(SSIM, self).__init__()

        self.device_type = device_type
        self.dtype = dtype

        def _get_ssim_weight():
            # Gaussian window matching skimage's defaults (sigma=1.5, truncated
            # at 3.5 sigma, i.e. an 11x11 kernel), replicated per channel for a
            # grouped convolution over 3-channel images.
            truncate = 3.5
            sigma = 1.5
            r = int(truncate * sigma + 0.5)
            win_size = 2 * r + 1
            nch = 3

            weight = torch.Tensor(
                [-(x - win_size // 2) ** 2 / float(2 * sigma ** 2) for x in range(win_size)]
            ).exp().unsqueeze(1)
            weight = weight.mm(weight.t())
            weight /= weight.sum()
            weight = weight.repeat(nch, 1, 1, 1)
            return weight

        self.weight = _get_ssim_weight().to(self.device_type, dtype=self.dtype, non_blocking=True)

    def forward(self, im1, im2, data_range=None):
        """Implementation adapted from skimage.metrics.structural_similarity
        with multichannel=True, gaussian_weights=True and
        use_sample_covariance=False.
        """
        im1 = im1.to(self.device_type, dtype=self.dtype, non_blocking=True)
        im2 = im2.to(self.device_type, dtype=self.dtype, non_blocking=True)

        K1 = 0.01
        K2 = 0.03
        sigma = 1.5

        truncate = 3.5
        r = int(truncate * sigma + 0.5)
        win_size = 2 * r + 1

        im1 = _expand(im1)
        im2 = _expand(im2)

        # The precomputed window was built for 3 channels; inputs with a
        # different channel count will not match the grouped-convolution weight.
        nch = im1.shape[1]

        if im1.shape[2] < win_size or im1.shape[3] < win_size:
            raise ValueError(
                "win_size exceeds image extent: both spatial dimensions must "
                "be at least %d pixels." % win_size)

        if data_range is None:
            data_range = 255 if im1.max() > 1 else 1

        def filter_func(img):
            # Valid (unpadded) grouped convolution with the Gaussian window.
            return nn.functional.conv2d(img, self.weight, groups=nch).to(self.dtype)

        # Local means.
        ux = filter_func(im1)
        uy = filter_func(im2)

        # Local (population) variances and covariance, matching
        # use_sample_covariance=False.
        uxx = filter_func(im1 * im1)
        uyy = filter_func(im2 * im2)
        uxy = filter_func(im1 * im2)
        vx = uxx - ux * ux
        vy = uyy - uy * uy
        vxy = uxy - ux * uy

        R = data_range
        C1 = (K1 * R) ** 2
        C2 = (K2 * R) ** 2

        # SSIM map: ((2*ux*uy + C1) * (2*vxy + C2)) /
        #           ((ux**2 + uy**2 + C1) * (vx + vy + C2))
        A1, A2, B1, B2 = (2 * ux * uy + C1,
                          2 * vxy + C2,
                          ux ** 2 + uy ** 2 + C1,
                          vx + vy + C2)
        D = B1 * B2
        S = (A1 * A2) / D

        mssim = S.mean()
        return mssim
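

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module):
    # compare two random batches in [0, 1] with shape (N, C, H, W) and C = 3,
    # the channel count the SSIM window above is built for.
    torch.manual_seed(0)
    clean = torch.rand(2, 3, 64, 64)
    noisy = (clean + 0.05 * torch.randn_like(clean)).clamp(0, 1)

    psnr = PSNR()
    ssim = SSIM()
    print("PSNR:", psnr(clean, noisy, data_range=1).item())
    print("SSIM:", ssim(clean, noisy, data_range=1).item())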