# Page-viewer residue (not part of the module): AK391 | files | d380b77 | raw | history blame | 1.2 kB
import logging
import torch
from saicinpainting.evaluation.evaluator import InpaintingEvaluatorOnline, ssim_fid100_f1, lpips_fid100_f1
from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore
def make_evaluator(kind='default', ssim=True, lpips=True, fid=True, integral_kind=None, **kwargs):
    """Build an inpainting evaluator carrying the requested metrics.

    Args:
        kind: Evaluator flavour. Only ``'default'`` is supported; it yields
            an ``InpaintingEvaluatorOnline``.
        ssim: When true, register an ``SSIMScore`` under the ``'ssim'`` key.
        lpips: When true, register an ``LPIPSScore`` under the ``'lpips'`` key.
        fid: When true, register a ``FIDScore`` under the ``'fid'`` key. It is
            moved to ``"cuda"`` if CUDA is available, otherwise to ``"cpu"``.
        integral_kind: Name of the combined metric. ``None`` disables it;
            ``'ssim_fid100_f1'`` and ``'lpips_fid100_f1'`` select the
            function of the same name. The value is also passed to the
            evaluator as ``integral_title``.
        **kwargs: Forwarded unchanged to the evaluator constructor.

    Returns:
        The configured ``InpaintingEvaluatorOnline`` instance.

    Raises:
        ValueError: If ``kind`` or ``integral_kind`` is not a supported value.
    """
    logging.info(f'Make evaluator {kind}')

    # Reject unsupported arguments before building any metric object, so a
    # typo fails immediately instead of silently returning None at the end.
    if kind != 'default':
        raise ValueError(f'Unexpected evaluator kind={kind}')

    if integral_kind is None:
        integral_func = None
    elif integral_kind == 'ssim_fid100_f1':
        integral_func = ssim_fid100_f1
    elif integral_kind == 'lpips_fid100_f1':
        integral_func = lpips_fid100_f1
    else:
        raise ValueError(f'Unexpected integral_kind={integral_kind}')

    device = "cuda" if torch.cuda.is_available() else "cpu"

    metrics = {}
    if ssim:
        metrics['ssim'] = SSIMScore()
    if lpips:
        metrics['lpips'] = LPIPSScore()
    if fid:
        # Only the FID metric is explicitly placed on the selected device.
        metrics['fid'] = FIDScore().to(device)

    return InpaintingEvaluatorOnline(scores=metrics,
                                     integral_func=integral_func,
                                     integral_title=integral_kind,
                                     **kwargs)