|
import torch |
|
import numpy as np |
|
from internal import math |
|
from skimage.metrics import structural_similarity, peak_signal_noise_ratio |
|
import cv2 |
|
|
|
|
|
def mse_to_psnr(mse):
    """Convert a mean squared error into PSNR (decibels).

    A peak pixel value of 1 is assumed, so PSNR = -10 * log10(mse). The natural
    log is used and rescaled by -10 / ln(10).
    """
    scale = -10. / np.log(10.)
    return scale * np.log(mse)
|
|
|
|
|
def psnr_to_mse(psnr):
    """Convert a PSNR (decibels) back into a mean squared error.

    A peak pixel value of 1 is assumed, so mse = 10 ** (-psnr / 10), evaluated
    as exp(-0.1 * ln(10) * psnr).
    """
    exponent = -0.1 * np.log(10.) * psnr
    return np.exp(exponent)
|
|
|
|
|
def ssim_to_dssim(ssim):
    """Convert an SSIM value into a DSSIM value, defined as (1 - SSIM) / 2."""
    dissimilarity = 1 - ssim
    return dissimilarity / 2
|
|
|
|
|
def dssim_to_ssim(dssim):
    """Compute SSIM given a DSSIM.

    DSSIM is defined as (1 - SSIM) / 2, so this returns SSIM = 1 - 2 * DSSIM
    (the inverse of `ssim_to_dssim`).
    """
    return 1 - 2 * dssim
|
|
|
|
|
def linear_to_srgb(linear, eps=None):
    """Encode linear intensities with the sRGB transfer curve (torch version).

    Inputs are assumed to lie in [0, 1]; see https://en.wikipedia.org/wiki/SRGB.

    Args:
      linear: torch tensor of linear intensities.
      eps: lower bound applied to `linear` before the fractional power, to keep
        that branch free of NaN/inf at non-positive inputs. Defaults to the
        machine epsilon of `linear.dtype`.

    Returns:
      A tensor of sRGB-encoded values.
    """
    if eps is None:
        eps = torch.finfo(linear.dtype).eps
    in_linear_segment = linear <= 0.0031308
    # 323 / 25 == 12.92, the slope of the linear toe.
    toe = 323 / 25 * linear
    # 211 / 200 == 1.055, 11 / 200 == 0.055 and 5 / 12 == 1 / 2.4.
    curve = (211 * torch.clamp(linear, min=eps) ** (5 / 12) - 11) / 200
    return torch.where(in_linear_segment, toe, curve)
|
|
|
|
|
def linear_to_srgb_np(linear, eps=None):
    """Encode linear intensities with the sRGB transfer curve (numpy version).

    Inputs are assumed to lie in [0, 1]; see https://en.wikipedia.org/wiki/SRGB.

    Args:
      linear: numpy array of linear intensities.
      eps: lower bound applied to `linear` before the fractional power.
        Defaults to the machine epsilon of `linear.dtype`.

    Returns:
      An array of sRGB-encoded values.
    """
    floor = np.finfo(linear.dtype).eps if eps is None else eps
    in_linear_segment = linear <= 0.0031308
    # Linear toe with slope 323 / 25 == 12.92.
    toe = 323 / 25 * linear
    # Power segment: 1.055 * x ** (1 / 2.4) - 0.055 using exact rationals.
    curve = (211 * np.maximum(floor, linear) ** (5 / 12) - 11) / 200
    return np.where(in_linear_segment, toe, curve)
|
|
|
|
|
def srgb_to_linear(srgb, eps=None):
    """Decode sRGB-encoded values back to linear intensities (numpy version).

    Inputs are assumed to lie in [0, 1]; see https://en.wikipedia.org/wiki/SRGB.

    Args:
      srgb: numpy array of sRGB-encoded values.
      eps: lower bound applied to the power-segment base before raising it to
        the 2.4 exponent. Defaults to the machine epsilon of `srgb.dtype`.

    Returns:
      An array of linear intensities.
    """
    floor = np.finfo(srgb.dtype).eps if eps is None else eps
    in_linear_segment = srgb <= 0.04045
    # Inverse of the 12.92 toe slope: 25 / 323 == 1 / 12.92.
    toe = 25 / 323 * srgb
    # ((srgb + 0.055) / 1.055) ** 2.4 written with exact rational constants.
    base = (200 * srgb + 11) / (211)
    curve = np.maximum(floor, base) ** (12 / 5)
    return np.where(in_linear_segment, toe, curve)
|
|
|
|
|
def downsample(img, factor):
    """Average-pool `img` over non-overlapping `factor` x `factor` tiles.

    The first two axes are treated as height and width and must both be
    divisible by `factor`; any trailing axes (e.g. channels) are kept as-is.
    Only `.shape`, `.reshape` and `.mean` are used on `img`.

    Raises:
      ValueError: if `factor` does not evenly divide the height and width.
    """
    shape = img.shape
    height, width = shape[0], shape[1]
    if height % factor or width % factor:
        raise ValueError(f'Downsampling factor {factor} does not '
                         f'evenly divide image shape {shape[:2]}')
    # Split each spatial axis into (tile index, offset within tile), then
    # average over the two offset axes.
    tiled = img.reshape(
        (height // factor, factor, width // factor, factor) + shape[2:])
    return tiled.mean((1, 3))
|
|
|
|
|
def color_correct(img, ref, num_iters=5, eps=0.5 / 255):
    """Warp the colors of `img` to match the colors in `ref`.

    For each output channel, a least-squares fit maps `img`'s colors to
    `ref`'s colors. The regressors are every pairwise channel product
    (including squares), the channels themselves, and a constant. The fit is
    repeated `num_iters` times, each time starting from the previously warped
    colors, and the result is clipped to [0, 1].

    Pixels are excluded from a channel's fit when that channel is within
    `eps` of 0 or 1 in the original `img`, in the current warped image, or
    in `ref`.

    Args:
      img: torch tensor whose last axis is the channel axis.
      ref: torch tensor with the same number of channels. It is flattened to
        pixels the same way as `img` and must have the same pixel count (not
        checked here).
      num_iters: number of fit-and-warp rounds.
      eps: margin from 0 and 1 inside which a value counts as clipped.

    Returns:
      A tensor with `img.shape`, holding the color-corrected values in [0, 1].

    Raises:
      ValueError: if the channel counts differ, or if a fitted warp contains
        non-finite values.
    """
    if img.shape[-1] != ref.shape[-1]:
        raise ValueError(
            f'img\'s {img.shape[-1]} and ref\'s {ref.shape[-1]} channels must match'
        )
    num_channels = img.shape[-1]
    img_mat = img.reshape([-1, num_channels])
    ref_mat = ref.reshape([-1, num_channels])

    def is_unclipped(z):
        return (z >= eps) & (z <= (1 - eps))

    # Values clipped in the *input* image stay excluded in every iteration.
    mask0 = is_unclipped(img_mat)
    # The reference never changes, so its mask is computed once.
    ref_unclipped = is_unclipped(ref_mat)

    for _ in range(num_iters):
        # Design matrix columns: products of channel c with channels c..C-1,
        # then the linear terms, then a bias column of ones.
        a_mat = [img_mat[:, c:(c + 1)] * img_mat[:, c:]
                 for c in range(num_channels)]
        a_mat.append(img_mat)
        a_mat.append(torch.ones_like(img_mat[:, :1]))
        a_mat = torch.cat(a_mat, dim=-1)

        warp = []
        for c in range(num_channels):
            b = ref_mat[:, c]
            # Also require the *current* (already warped) value to be unclipped.
            mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & ref_unclipped[:, c]
            # Masked-out rows are zeroed so they do not affect the fit.
            ma_mat = torch.where(mask[:, None], a_mat, torch.zeros_like(a_mat))
            mb = torch.where(mask, b, torch.zeros_like(b))
            w = torch.linalg.lstsq(ma_mat, mb, rcond=-1)[0]
            # Explicit check (not `assert`) so it still runs under `python -O`.
            if not torch.all(torch.isfinite(w)):
                raise ValueError(
                    f'color_correct produced a non-finite warp for channel {c}')
            warp.append(w)
        warp = torch.stack(warp, dim=-1)

        img_mat = torch.clip(math.matmul(a_mat, warp), 0, 1)

    return torch.reshape(img_mat, img.shape)
|
|
|
|
|
class MetricHarness:
    """A helper class for evaluating several error metrics."""

    @staticmethod
    def _to_uint8(rgb):
        """Quantize a float image nominally in [0, 1] to uint8.

        The scaled values are rounded rather than truncated. Truncation turns an
        8-bit value k/255 whose float product lands just below k into k - 1.
        Values are also clipped to [0, 255] before the cast, so out-of-range
        inputs saturate instead of wrapping around.
        """
        return np.clip(np.round(np.asarray(rgb) * 255), 0, 255).astype(np.uint8)

    def __call__(self, rgb_pred, rgb_gt, name_fn=lambda s: s):
        """Evaluate the error between a predicted rgb image and the true image.

        Args:
          rgb_pred: predicted image with float values nominally in [0, 1].
            It is passed to cv2's RGB->gray conversion, so it presumably must
            be an (H, W, 3) RGB array (TODO confirm).
          rgb_gt: ground-truth image in the same layout and range.
          name_fn: maps each metric name ('psnr', 'ssim') to its output key.

        Returns:
          A dict with two entries. The PSNR entry is computed over all RGB
          channels of the 8-bit images. The SSIM entry is computed on their
          grayscale versions. Both use a data range of 255.
        """
        rgb_pred = self._to_uint8(rgb_pred)
        rgb_gt = self._to_uint8(rgb_gt)

        rgb_pred_gray = cv2.cvtColor(rgb_pred, cv2.COLOR_RGB2GRAY)
        rgb_gt_gray = cv2.cvtColor(rgb_gt, cv2.COLOR_RGB2GRAY)

        psnr = float(peak_signal_noise_ratio(rgb_pred, rgb_gt, data_range=255))
        ssim = float(structural_similarity(rgb_pred_gray, rgb_gt_gray, data_range=255))

        return {
            name_fn('psnr'): psnr,
            name_fn('ssim'): ssim,
        }
|
|