Spaces:
Runtime error
Runtime error
import argparse | |
import torch | |
from torchvision import utils | |
from model import Generator | |
def _parse_args(argv=None):
    """Build the command-line parser and parse *argv*.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Apply closed form factorization")
    parser.add_argument(
        "-i", "--index", type=int, default=0, help="index of eigenvector"
    )
    parser.add_argument(
        "-d",
        "--degree",
        type=float,
        default=5,
        help="scalar factors for moving latent vectors along eigenvector",
    )
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help='channel multiplier factor. config-f = 2, else = 1',
    )
    parser.add_argument("--ckpt", type=str, required=True, help="stylegan2 checkpoints")
    parser.add_argument(
        "--size", type=int, default=256, help="output image size of the generator"
    )
    parser.add_argument(
        "-n", "--n_sample", type=int, default=7, help="number of samples created"
    )
    parser.add_argument(
        "--truncation", type=float, default=0.7, help="truncation factor"
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="device to run the model"
    )
    parser.add_argument(
        "--out_prefix",
        type=str,
        default="factor",
        help="filename prefix to result samples",
    )
    parser.add_argument(
        "factor",
        type=str,
        help="name of the closed form factorization result factor file",
    )
    return parser.parse_args(argv)


def _resolve_device(device):
    """Return *device*, or ``"cpu"`` if a CUDA device is requested but unavailable.

    The default ``--device`` is ``"cuda"``. Without this fallback the script
    crashes on CPU-only hosts.
    """
    if device.startswith("cuda") and not torch.cuda.is_available():
        import warnings

        warnings.warn(
            f"CUDA device {device!r} requested but CUDA is unavailable; using CPU."
        )
        return "cpu"
    return device


def _value_range_kwargs(make_grid_fn, bounds=(-1, 1)):
    """Return the ``{keyword: bounds}`` pair understood by *make_grid_fn*.

    Newer torchvision versions accept ``value_range``. Older ones only accept
    ``range``, which was deprecated and later removed. ``save_image`` forwards
    these keywords to ``make_grid``, so the choice is made by inspecting that
    function's signature. If the signature cannot be determined, the modern
    name is used.
    """
    import inspect

    try:
        params = inspect.signature(make_grid_fn).parameters
    except (TypeError, ValueError):
        params = {}
    if "value_range" not in params and "range" in params:
        return {"range": bounds}
    return {"value_range": bounds}


def _main(argv=None):
    """Render samples shifted along one factor eigenvector and save them as a grid.

    The saved image has three rows of ``n_sample`` images each:
    ``latent + direction``, ``latent``, and ``latent - direction``.
    """
    torch.set_grad_enabled(False)
    args = _parse_args(argv)
    device = _resolve_device(args.device)

    # map_location lets checkpoints saved on GPU load on the chosen device.
    # NOTE(review): torch >= 2.6 defaults torch.load(weights_only=True), which
    # may reject checkpoints that pickle non-tensor objects — confirm with the
    # deployed torch version.
    eigvec = torch.load(args.factor, map_location=device)["eigvec"].to(device)
    ckpt = torch.load(args.ckpt, map_location=device)

    g = Generator(args.size, 512, 8, channel_multiplier=args.channel_multiplier).to(device)
    # strict=False tolerates missing/unexpected keys in the g_ema state dict.
    g.load_state_dict(ckpt["g_ema"], strict=False)

    # Presumably the mean of 4096 mapped latents, used as the truncation center.
    trunc = g.mean_latent(4096)
    latent = g.get_latent(torch.randn(args.n_sample, 512, device=device))
    # Column `index` of eigvec, scaled by `degree`; unsqueeze(0) makes it
    # broadcast across the n_sample rows of `latent`.
    direction = args.degree * eigvec[:, args.index].unsqueeze(0)

    def render(w):
        """Run the generator on an already-mapped latent batch *w*."""
        image, _ = g(
            [w],
            truncation=args.truncation,
            truncation_latent=trunc,
            input_is_latent=True,
        )
        return image

    # Keep the original call order: each forward pass may draw fresh noise.
    img = render(latent)
    img1 = render(latent + direction)
    img2 = render(latent - direction)

    # save_image returns None, so its result is not stored.
    utils.save_image(
        torch.cat([img1, img, img2], 0),
        f"{args.out_prefix}_index-{args.index}_degree-{args.degree}.png",
        normalize=True,
        nrow=args.n_sample,
        **_value_range_kwargs(utils.make_grid),
    )


if __name__ == "__main__":
    _main()