# Spaces: Runtime error
import argparse
import os

import torch
from torchvision import utils
from tqdm import tqdm

from model import Generator
def generate(args, g_ema, device, mean_latent):
    """Sample images from the generator and save them as PNG files.

    Runs ``args.pics`` iterations. Each iteration draws ``args.sample`` latent
    vectors of size ``args.latent`` and writes the resulting batch to
    ``sample/NNNNNN.png``, with one image per grid row.

    Args:
        args: Parsed command-line namespace. Reads ``pics``, ``sample``,
            ``latent`` and ``truncation``.
        g_ema: Generator module. Called as
            ``g_ema([z], truncation=..., truncation_latent=...)`` and
            unpacked into ``(images, _)``.
        device: Device on which the random latent vectors are created.
        mean_latent: Passed through as ``truncation_latent``. The ``__main__``
            block passes ``None`` when truncation is not below 1.
    """
    out_dir = "sample"
    # save_image does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)

    with torch.no_grad():
        g_ema.eval()
        for i in tqdm(range(args.pics)):
            sample_z = torch.randn(args.sample, args.latent, device=device)
            sample, _ = g_ema(
                [sample_z], truncation=args.truncation, truncation_latent=mean_latent
            )
            utils.save_image(
                sample,
                f"{out_dir}/{str(i).zfill(6)}.png",
                nrow=1,
                normalize=True,
                # torchvision renamed make_grid's `range` keyword to
                # `value_range`; the old name was later removed and now
                # raises TypeError.
                value_range=(-1, 1),
            )
if __name__ == "__main__":
    # Fall back to CPU so the script also runs on hosts without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser(description="Generate samples from the generator")
    parser.add_argument(
        "--size", type=int, default=1024, help="output image size of the generator"
    )
    parser.add_argument(
        "--sample",
        type=int,
        default=1,
        help="number of samples to be generated for each image",
    )
    parser.add_argument(
        "--pics", type=int, default=20, help="number of images to be generated"
    )
    parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
    parser.add_argument(
        "--truncation_mean",
        type=int,
        default=4096,
        help="number of vectors to calculate mean for the truncation",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="stylegan2-ffhq-config-f.pt",
        help="path to the model checkpoint",
    )
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help="channel multiplier of the generator. config-f = 2, else = 1",
    )

    args = parser.parse_args()

    # Hard-coded values passed to Generator as its latent size and n_mlp.
    args.latent = 512
    args.n_mlp = 8

    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)

    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(args.ckpt, map_location=device)
    g_ema.load_state_dict(checkpoint["g_ema"])

    # The mean latent is only needed when truncation is active (ratio < 1).
    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    generate(args, g_ema, device, mean_latent)