Spaces:
Runtime error
Runtime error
import argparse | |
import torch | |
def _load_generator_state(path):
    """Load a training checkpoint and return its ``g_ema`` state dict.

    ``map_location="cpu"`` remaps every stored tensor onto the CPU. A
    checkpoint that was saved from GPU tensors can then be opened on a
    machine without CUDA.

    NOTE(review): ``torch.load`` unpickles the file, so only pass checkpoints
    from a trusted source. On recent torch releases ``weights_only`` defaults
    to True. That may reject checkpoints that also pickle non-tensor objects
    (e.g. an argparse Namespace), so confirm against the checkpoints this
    script is used with.
    """
    ckpt = torch.load(path, map_location="cpu")
    return ckpt["g_ema"]


def _modulation_weights(state_dict):
    """Return the modulation weight tensors selected from ``state_dict``.

    An entry is kept when its key contains both "modulation" and "weight" and
    does not contain "to_rgbs". Entries are returned in the state dict's
    iteration order.
    """
    return [
        value
        for key, value in state_dict.items()
        if "modulation" in key and "to_rgbs" not in key and "weight" in key
    ]


def _closed_form_factor(weights):
    """Return ``V`` from the SVD of the row-wise concatenation of ``weights``.

    The tensors are concatenated along dim 0, so they must agree in every
    other dimension. ``torch.svd`` returns ``V`` such that
    ``W = U diag(S) V^T``. Its columns are the right singular vectors,
    ordered by decreasing singular value.

    Raises:
        ValueError: if ``weights`` is empty. Without this check ``torch.cat``
            would fail with a less descriptive error.
    """
    if not weights:
        raise ValueError("no modulation weights found in the generator state dict")
    stacked = torch.cat(weights, 0)
    # Tensors are already on the CPU after loading; the explicit move keeps the
    # saved factor file device-independent regardless of how it was computed.
    return torch.svd(stacked).V.to("cpu")


def main(argv=None):
    """Command-line entry point: factorize a checkpoint and save the result.

    Reads the ``g_ema`` state dict from the checkpoint given on the command
    line. It computes the right singular vectors of the stacked modulation
    weights and saves ``{"ckpt": <checkpoint path>, "eigvec": <V>}`` to
    ``--out``.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        KeyError: if the checkpoint has no ``"g_ema"`` entry.
        ValueError: if no modulation weights are found.
    """
    parser = argparse.ArgumentParser(
        description="Extract factor/eigenvectors of latent spaces using closed form factorization"
    )
    parser.add_argument(
        "--out", type=str, default="factor.pt", help="name of the result factor file"
    )
    parser.add_argument("ckpt", type=str, help="name of the model checkpoint")
    args = parser.parse_args(argv)

    state_dict = _load_generator_state(args.ckpt)
    eigvec = _closed_form_factor(_modulation_weights(state_dict))
    # Store the source checkpoint path next to the factors.
    torch.save({"ckpt": args.ckpt, "eigvec": eigvec}, args.out)


if __name__ == "__main__":
    main()