import argparse import torch from torchvision import utils from model import Generator from tqdm import tqdm def generate(args, g_ema, device, mean_latent): with torch.no_grad(): g_ema.eval() for i in tqdm(range(args.pics)): sample_z = torch.randn(args.sample, args.latent, device=device) sample, _ = g_ema( [sample_z], truncation=args.truncation, truncation_latent=mean_latent ) utils.save_image( sample, f"sample/{str(i).zfill(6)}.png", nrow=1, normalize=True, range=(-1, 1), ) if __name__ == "__main__": device = "cuda" parser = argparse.ArgumentParser(description="Generate samples from the generator") parser.add_argument( "--size", type=int, default=1024, help="output image size of the generator" ) parser.add_argument( "--sample", type=int, default=1, help="number of samples to be generated for each image", ) parser.add_argument( "--pics", type=int, default=20, help="number of images to be generated" ) parser.add_argument("--truncation", type=float, default=1, help="truncation ratio") parser.add_argument( "--truncation_mean", type=int, default=4096, help="number of vectors to calculate mean for the truncation", ) parser.add_argument( "--ckpt", type=str, default="stylegan2-ffhq-config-f.pt", help="path to the model checkpoint", ) parser.add_argument( "--channel_multiplier", type=int, default=2, help="channel multiplier of the generator. config-f = 2, else = 1", ) args = parser.parse_args() args.latent = 512 args.n_mlp = 8 g_ema = Generator( args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier ).to(device) checkpoint = torch.load(args.ckpt) g_ema.load_state_dict(checkpoint["g_ema"]) if args.truncation < 1: with torch.no_grad(): mean_latent = g_ema.mean_latent(args.truncation_mean) else: mean_latent = None generate(args, g_ema, device, mean_latent)