import argparse
import pickle

import numpy as np
import torch
from scipy import linalg
from torch import nn
from tqdm import tqdm

from calc_inception import load_patched_inception_v3
from model import Generator


@torch.no_grad()
def extract_feature_from_samples(
    generator, inception, truncation, truncation_latent, batch_size, n_sample, device
):
    """Draw `n_sample` images from `generator` and return their Inception features.

    Args:
        generator: image generator; called as
            ``generator([latent], truncation=..., truncation_latent=...)`` and
            expected to return ``(images, _)``.
        inception: feature extractor; ``inception(img)[0]`` is flattened per
            image into one feature vector.
        truncation: truncation factor forwarded to the generator.
        truncation_latent: mean latent used for truncation (may be ``None``
            when ``truncation == 1``).
        batch_size: number of latents sampled per forward pass.
        n_sample: total number of samples to generate.
        device: device on which latents are created.

    Returns:
        CPU tensor of shape ``(n_sample, feature_dim)``.
    """
    n_batch = n_sample // batch_size
    resid = n_sample - n_batch * batch_size
    batch_sizes = [batch_size] * n_batch
    # Only add the remainder batch when it is non-empty; the original always
    # appended `resid`, feeding a zero-sized batch through the models whenever
    # n_sample was an exact multiple of batch_size.
    if resid > 0:
        batch_sizes.append(resid)

    features = []

    for batch in tqdm(batch_sizes):
        latent = torch.randn(batch, 512, device=device)
        # BUG FIX: the original called the module-level global `g` here instead
        # of the `generator` parameter, silently ignoring the model passed in.
        img, _ = generator(
            [latent], truncation=truncation, truncation_latent=truncation_latent
        )
        feat = inception(img)[0].view(img.shape[0], -1)
        features.append(feat.to("cpu"))

    return torch.cat(features, 0)


def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
    """Frechet Inception Distance between two Gaussians.

    FID = ||mu_s - mu_r||^2 + Tr(C_s + C_r - 2 * sqrt(C_s @ C_r)).

    Args:
        sample_mean: mean feature vector of the generated samples.
        sample_cov: covariance matrix of the generated samples.
        real_mean: mean feature vector of the real data.
        real_cov: covariance matrix of the real data.
        eps: diagonal offset added when the covariance product is singular.

    Returns:
        The scalar FID value.

    Raises:
        ValueError: if the matrix square root has a non-negligible imaginary
            component (atol 1e-3 on the diagonal).
    """
    cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)

    if not np.isfinite(cov_sqrt).all():
        print("product of cov matrices is singular")
        # Regularize both covariances slightly and retry the square root.
        offset = np.eye(sample_cov.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))

    if np.iscomplexobj(cov_sqrt):
        # sqrtm may return a complex result with tiny imaginary noise; accept
        # it only when the imaginary part is numerically negligible.
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))
            raise ValueError(f"Imaginary component {m}")

        cov_sqrt = cov_sqrt.real

    mean_diff = sample_mean - real_mean
    mean_norm = mean_diff @ mean_diff

    trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)

    return mean_norm + trace


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(description="Calculate FID scores")

    parser.add_argument(
        "--truncation", type=float, default=1, help="truncation factor"
    )
    parser.add_argument(
        "--truncation_mean",
        type=int,
        default=4096,
        help="number of samples to calculate mean for truncation",
    )
    parser.add_argument(
        "--batch", type=int, default=64, help="batch size for the generator"
    )
    parser.add_argument(
        "--n_sample",
        type=int,
        default=50000,
        help="number of the samples for calculating FID",
    )
    parser.add_argument(
        "--size", type=int, default=256, help="image sizes for generator"
    )
    parser.add_argument(
        "--inception",
        type=str,
        default=None,
        required=True,
        help="path to precomputed inception embedding",
    )
    parser.add_argument(
        "ckpt", metavar="CHECKPOINT", help="path to generator checkpoint"
    )

    args = parser.parse_args()

    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, 512, 8).to(device)
    g.load_state_dict(ckpt["g_ema"])
    g = nn.DataParallel(g)
    g.eval()

    if args.truncation < 1:
        with torch.no_grad():
            # BUG FIX: `g` is already wrapped in DataParallel, which only
            # replicates `forward` and does not proxy other methods of the
            # wrapped module — `g.mean_latent` would raise AttributeError.
            # Reach through `.module` to the underlying Generator.
            mean_latent = g.module.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    inception = nn.DataParallel(load_patched_inception_v3()).to(device)
    inception.eval()

    features = extract_feature_from_samples(
        g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device
    ).numpy()
    print(f"extracted {features.shape[0]} features")

    sample_mean = np.mean(features, 0)
    sample_cov = np.cov(features, rowvar=False)

    # NOTE(review): pickle.load on an arbitrary path executes embedded code —
    # only load embeddings files you trust.
    with open(args.inception, "rb") as f:
        embeds = pickle.load(f)
        real_mean = embeds["mean"]
        real_cov = embeds["cov"]

    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)

    print("fid:", fid)