"""Perceptual Path Length (PPL) calculator for a StyleGAN2 generator.

Samples pairs of latents, perturbs the interpolation parameter by a small
epsilon, and averages the LPIPS distance between the resulting image pairs,
scaled by 1 / eps^2.
"""

import argparse

import torch
from torch.nn import functional as F
import numpy as np
from tqdm import tqdm

import lpips
from model import Generator


def normalize(x):
    return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True))


def slerp(a, b, t):
    # Spherical interpolation between a and b at parameter t (used for Z space).
    a = normalize(a)
    b = normalize(b)
    d = (a * b).sum(-1, keepdim=True)
    p = t * torch.acos(d)
    c = normalize(b - d * a)
    d = a * torch.cos(p) + c * torch.sin(p)

    return normalize(d)


def lerp(a, b, t):
    # Linear interpolation between a and b at parameter t (used for W space).
    return a + (b - a) * t


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(description="Perceptual Path Length calculator")

    parser.add_argument(
        "--space",
        choices=["z", "w"],
        default="w",
        help="latent space in which PPL is calculated",
    )
    parser.add_argument(
        "--batch", type=int, default=64, help="batch size for the models"
    )
    parser.add_argument(
        "--n_sample",
        type=int,
        default=5000,
        help="number of samples for calculating PPL",
    )
    parser.add_argument(
        "--size", type=int, default=256, help="output image size of the generator"
    )
    parser.add_argument(
        "--eps", type=float, default=1e-4, help="epsilon for numerical stability"
    )
    parser.add_argument(
        "--crop", action="store_true", help="apply center crop to the images"
    )
    parser.add_argument(
        "--sampling",
        default="end",
        choices=["end", "full"],
        help="set endpoint sampling method",
    )
    parser.add_argument(
        "ckpt", metavar="CHECKPOINT", help="path to the model checkpoint"
    )

    args = parser.parse_args()

    latent_dim = 512

    ckpt = torch.load(args.ckpt, map_location=device)

    g = Generator(args.size, latent_dim, 8).to(device)
    g.load_state_dict(ckpt["g_ema"])
    g.eval()

    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    distances = []

    n_batch = args.n_sample // args.batch
    resid = args.n_sample - (n_batch * args.batch)
    # Avoid an empty trailing batch when n_sample is divisible by batch.
    batch_sizes = [args.batch] * n_batch + ([resid] if resid > 0 else [])

    with torch.no_grad():
        for batch in tqdm(batch_sizes):
            noise = g.make_noise()

            inputs = torch.randn([batch * 2, latent_dim], device=device)

            # "full" samples the interpolation parameter uniformly in [0, 1];
            # "end" evaluates at the endpoints only (t = 0).
            if args.sampling == "full":
                lerp_t = torch.rand(batch, device=device)
            else:
                lerp_t = torch.zeros(batch, device=device)

            if args.space == "w":
                # Interpolate linearly in W, the mapped latent space.
                latent = g.get_latent(inputs)
                latent_t0, latent_t1 = latent[::2], latent[1::2]
                latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None])
                latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
                latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape)
            else:
                # Interpolate spherically in Z; the mapping network is applied
                # inside the generator (input_is_latent=False below).
                latent_t0, latent_t1 = inputs[::2], inputs[1::2]
                latent_e0 = slerp(latent_t0, latent_t1, lerp_t[:, None])
                latent_e1 = slerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
                latent_e = torch.stack([latent_e0, latent_e1], 1).view(*inputs.shape)

            image, _ = g(
                [latent_e], input_is_latent=args.space == "w", noise=noise
            )

            if args.crop:
                c = image.shape[2] // 8
                image = image[:, :, c * 3 : c * 7, c * 2 : c * 6]

            factor = image.shape[2] // 256

            if factor > 1:
                image = F.interpolate(
                    image, size=(256, 256), mode="bilinear", align_corners=False
                )

            # LPIPS distance between each perturbed pair, scaled by 1 / eps^2
            # to approximate the squared perceptual path derivative.
            dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (
                args.eps ** 2
            )
            distances.append(dist.to("cpu").numpy())

    distances = np.concatenate(distances, 0)

    # Discard the lowest and highest 1% of distances before averaging,
    # following the standard PPL protocol.
    lo = np.percentile(distances, 1, interpolation="lower")
    hi = np.percentile(distances, 99, interpolation="higher")
    filtered_dist = np.extract(
        np.logical_and(lo <= distances, distances <= hi), distances
    )

    print("ppl:", filtered_dist.mean())
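# Example invocation (a sketch: assumes this file is saved as ppl.py next to
# the repository's model.py, and "checkpoint/stylegan2.pt" is a placeholder
# path to a checkpoint containing a "g_ema" state dict):
#
#   python ppl.py --space w --sampling end --size 256 checkpoint/stylegan2.pt
#
# --space z interpolates in the input latent space with slerp instead of
# interpolating in the mapped W space with lerp.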