|
import numpy as np |
|
|
|
import torch |
|
|
|
from skimage.metrics import peak_signal_noise_ratio, structural_similarity |
|
|
|
|
|
def calculate_batch_psnr(gt_tensor, output_tensor, mode='avg', data_range=1.0):
    """Compute the PSNR between two batches of images.

    Each image pair is scored independently with
    ``skimage.metrics.peak_signal_noise_ratio`` after moving the first
    per-image axis to the end, i.e. ``(C, H, W)`` -> ``(H, W, C)``.

    Args:
        gt_tensor: Ground-truth batch; indexed as a 4-D tensor, presumably
            ``(N, C, H, W)``.
        output_tensor: Predicted batch, indexed the same way as ``gt_tensor``.
        mode: Reduction mode. Only ``'avg'`` is implemented.
        data_range: Distance between the minimum and maximum possible pixel
            values, forwarded to skimage. The default ``1.0`` matches images
            normalized to [0, 1].

    Returns:
        A tuple ``(mean_psnr, psnr_list)``: the batch-average PSNR as a Python
        float, and the per-image PSNR values in batch order.

    Raises:
        NotImplementedError: If ``mode`` is not ``'avg'``.
        ZeroDivisionError: If the batch is empty.
    """
    if mode != 'avg':
        raise NotImplementedError

    # detach() so tensors that still track gradients can be converted to NumPy.
    gt_np = gt_tensor.detach().cpu().numpy().astype(np.float32)
    output_np = output_tensor.detach().cpu().numpy().astype(np.float32)

    bs = gt_np.shape[0]
    psnr_list = []
    for i in range(bs):
        # Channels-first -> channels-last layout for skimage.
        gt_im = gt_np[i].transpose((1, 2, 0))
        output_im = output_np[i].transpose((1, 2, 0))
        # Compute once and reuse for both the list and the average.
        psnr_list.append(
            peak_signal_noise_ratio(gt_im, output_im, data_range=data_range)
        )

    return float(sum(psnr_list) / bs), psnr_list
|
|
|
|
|
def calculate_batch_ssim(gt_tensor, output_tensor, mode='avg', data_range=1.0):
    """Compute the mean SSIM between two batches of images.

    Each image pair is scored independently with
    ``skimage.metrics.structural_similarity`` after moving the first
    per-image axis to the end (``(C, H, W)`` -> ``(H, W, C)``); that last
    axis is then treated as the channel axis.

    Args:
        gt_tensor: Ground-truth batch; indexed as a 4-D tensor, presumably
            ``(N, C, H, W)``.
        output_tensor: Predicted batch, indexed the same way as ``gt_tensor``.
        mode: Reduction mode. Only ``'avg'`` is implemented.
        data_range: Distance between the minimum and maximum possible pixel
            values, forwarded to skimage. The default ``1.0`` matches images
            normalized to [0, 1].

    Returns:
        A tuple ``(mean_ssim, bs)``: the batch-average SSIM as a Python float
        and the batch size. (Note: unlike ``calculate_batch_psnr``, the second
        element is the batch size, not a per-image list; kept for existing
        callers.)

    Raises:
        NotImplementedError: If ``mode`` is not ``'avg'``.
        ZeroDivisionError: If the batch is empty.
    """
    if mode != 'avg':
        raise NotImplementedError

    # detach() so tensors that still track gradients can be converted to NumPy.
    gt_np = gt_tensor.detach().cpu().numpy().astype(np.float32)
    output_np = output_tensor.detach().cpu().numpy().astype(np.float32)

    bs = gt_np.shape[0]
    ssim = 0
    for i in range(bs):
        gt_im = gt_np[i].transpose((1, 2, 0))
        output_im = output_np[i].transpose((1, 2, 0))
        # channel_axis alone selects the channel dimension; the previously
        # passed ``multichannel`` flag is deprecated (scikit-image 0.19) and
        # redundant with channel_axis.
        ssim += structural_similarity(
            gt_im, output_im, data_range=data_range, channel_axis=2
        )

    return float(ssim / bs), bs
|
|