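"""Compute the FID (Frechet Inception Distance) between samples drawn from a
StyleGAN2 generator checkpoint and precomputed Inception statistics of a real
dataset.

Example invocation (the file names below are placeholders, not part of the
original script):

    python fid.py --inception inception_embeds.pkl --size 256 stylegan2.pt
"""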
import argparse
import pickle

import torch
from torch import nn
import numpy as np
from scipy import linalg
from tqdm import tqdm

from model import Generator
from calc_inception import load_patched_inception_v3
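

# Feature extraction: draw random latents, synthesize images with the generator,
# and collect the corresponding Inception activations on the CPU so the GPU only
# has to hold one batch at a time.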
@torch.no_grad()  # inference only; avoids accumulating autograd state over tens of thousands of samples
def extract_feature_from_samples(
    generator, inception, truncation, truncation_latent, batch_size, n_sample, device
):
    n_batch = n_sample // batch_size
    resid = n_sample - (n_batch * batch_size)
    # Full batches plus one residual batch when n_sample is not a multiple of batch_size.
    batch_sizes = [batch_size] * n_batch + ([resid] if resid > 0 else [])
    features = []

    for batch in tqdm(batch_sizes):
        latent = torch.randn(batch, 512, device=device)
        img, _ = generator(
            [latent], truncation=truncation, truncation_latent=truncation_latent
        )
        feat = inception(img)[0].view(img.shape[0], -1)
        features.append(feat.to("cpu"))

    features = torch.cat(features, 0)

    return features
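

# FID between two Gaussians (m_s, C_s) and (m_r, C_r):
#     FID = ||m_s - m_r||^2 + Tr(C_s + C_r - 2 * (C_s @ C_r)^(1/2))
# The matrix square root may pick up a small imaginary component from numerical
# error; it is discarded below as long as it stays within tolerance.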
def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
    cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)

    if not np.isfinite(cov_sqrt).all():
        print("product of cov matrices is singular")
        offset = np.eye(sample_cov.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))

    if np.iscomplexobj(cov_sqrt):
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))
            raise ValueError(f"Imaginary component {m}")

        cov_sqrt = cov_sqrt.real

    mean_diff = sample_mean - real_mean
    mean_norm = mean_diff @ mean_diff

    trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)

    fid = mean_norm + trace

    return fid
if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(description="Calculate FID scores")

    parser.add_argument(
        "--truncation", type=float, default=1, help="truncation factor"
    )
    parser.add_argument(
        "--truncation_mean",
        type=int,
        default=4096,
        help="number of samples used to compute the mean latent for truncation",
    )
    parser.add_argument(
        "--batch", type=int, default=64, help="batch size for the generator"
    )
    parser.add_argument(
        "--n_sample",
        type=int,
        default=50000,
        help="number of samples used to calculate FID",
    )
    parser.add_argument(
        "--size", type=int, default=256, help="output image size of the generator"
    )
    parser.add_argument(
        "--inception",
        type=str,
        default=None,
        required=True,
        help="path to precomputed inception embeddings",
    )
    parser.add_argument(
        "ckpt", metavar="CHECKPOINT", help="path to the generator checkpoint"
    )

    args = parser.parse_args()
    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, 512, 8).to(device)
    g.load_state_dict(ckpt["g_ema"])
    g = nn.DataParallel(g)
    g.eval()

    if args.truncation < 1:
        with torch.no_grad():
            # g is wrapped in DataParallel, so custom methods live on g.module.
            mean_latent = g.module.mean_latent(args.truncation_mean)

    else:
        mean_latent = None

    inception = nn.DataParallel(load_patched_inception_v3()).to(device)
    inception.eval()
    features = extract_feature_from_samples(
        g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device
    ).numpy()
    print(f"extracted {features.shape[0]} features")

    sample_mean = np.mean(features, 0)
    sample_cov = np.cov(features, rowvar=False)

    # The precomputed embedding file is expected to hold the real-data statistics
    # under the "mean" and "cov" keys.
    with open(args.inception, "rb") as f:
        embeds = pickle.load(f)
        real_mean = embeds["mean"]
        real_cov = embeds["cov"]

    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)

    print("fid:", fid)