# Projects input images into the latent (W / W+) space of a trained generator
# by optimizing latent codes and per-layer noise maps against a perceptual loss.
import argparse
import math
import os

import torch
from torch import optim
from torch.nn import functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

import lpips
from model import Generator
def noise_regularize(noises):
    """Penalize spatial self-correlation in each noise map.

    For every noise tensor, accumulates the squared mean of the product of
    the map with its one-pixel circular shift along each spatial axis, then
    2x-downsamples by averaging and repeats until the map is 8x8 or smaller.
    Returns the scalar sum of all penalties.
    """
    total = 0
    for layer_noise in noises:
        current = layer_noise
        size = current.shape[2]

        while True:
            # Correlation with horizontal and vertical single-pixel shifts.
            shift_w = (current * torch.roll(current, shifts=1, dims=3)).mean().pow(2)
            shift_h = (current * torch.roll(current, shifts=1, dims=2)).mean().pow(2)
            total = total + shift_w + shift_h

            if size <= 8:
                break

            # Average 2x2 cells to move one resolution level down.
            current = current.reshape([-1, 1, size // 2, 2, size // 2, 2])
            current = current.mean([3, 5])
            size //= 2

    return total
def noise_normalize_(noises):
    """In-place normalize each noise map to zero mean and unit std.

    Fix: the original mutated through the deprecated ``Tensor.data``
    attribute (``noise.data.add_(-mean).div_(std)``), which silently
    bypasses autograd's version tracking. ``torch.no_grad()`` is the
    supported way to apply an untracked in-place update to tensors that
    have ``requires_grad=True`` (as the noises in this script do).
    """
    with torch.no_grad():
        for noise in noises:
            mean = noise.mean()
            std = noise.std()
            noise.sub_(mean).div_(std)
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Learning-rate schedule: linear warmup, flat middle, cosine rampdown.

    `t` is the normalized progress in [0, 1]; `rampup`/`rampdown` are the
    fractions of the run spent warming up and decaying.
    """
    decay = min(1, (1 - t) / rampdown)
    cosine = 0.5 - 0.5 * math.cos(decay * math.pi)
    warmup = min(1, t / rampup)
    return initial_lr * cosine * warmup
def latent_noise(latent, strength):
    """Return `latent` perturbed by Gaussian noise scaled by `strength`."""
    perturbation = torch.randn_like(latent)
    return latent + perturbation * strength
def make_image(tensor):
    """Convert a batch of [-1, 1] NCHW images to uint8 NHWC numpy arrays.

    Note: `clamp_` is in-place on the detached view, so the caller's tensor
    storage is clamped as a side effect (matches the original behavior).
    """
    out = tensor.detach().clamp_(min=-1, max=1)
    out = out.add(1).div_(2).mul(255)
    out = out.type(torch.uint8)  # truncates fractional values
    return out.permute(0, 2, 3, 1).to("cpu").numpy()
if __name__ == "__main__":
    # Optimization runs on GPU; lpips is also configured for CUDA below.
    device = "cuda"

    parser = argparse.ArgumentParser(
        description="Image projector to the generator latent spaces"
    )

    parser.add_argument(
        "--ckpt", type=str, required=True, help="path to the model checkpoint"
    )
    parser.add_argument(
        "--size", type=int, default=256, help="output image sizes of the generator"
    )
    # NOTE(review): --lr_rampup / --lr_rampdown are parsed but never used;
    # get_lr() below is called with its own defaults (0.05 / 0.25).
    parser.add_argument(
        "--lr_rampup",
        type=float,
        default=0.05,
        help="duration of the learning rate warmup",
    )
    parser.add_argument(
        "--lr_rampdown",
        type=float,
        default=0.25,
        help="duration of the learning rate decay",
    )
    parser.add_argument("--lr", type=float, default=0.1, help="learning rate")
    parser.add_argument(
        "--noise", type=float, default=0.05, help="strength of the noise level"
    )
    parser.add_argument(
        "--noise_ramp",
        type=float,
        default=0.75,
        help="duration of the noise level decay",
    )
    parser.add_argument("--step", type=int, default=1000, help="optimize iterations")
    parser.add_argument(
        "--noise_regularize",
        type=float,
        default=1e5,
        help="weight of the noise regularization",
    )
    parser.add_argument("--mse", type=float, default=0, help="weight of the mse loss")
    parser.add_argument(
        "--w_plus",
        action="store_true",
        help="allow to use distinct latent codes to each layers",
    )
    parser.add_argument(
        "files", metavar="FILES", nargs="+", help="path to image files to be projected"
    )

    args = parser.parse_args()

    # Number of random z samples used to estimate the mean and std of the
    # generator's latent (style) space.
    n_mean_latent = 10000

    # Inputs are resized/center-cropped to at most 256px and normalized to
    # [-1, 1] (matching the generator's output range).
    resize = min(args.size, 256)

    transform = transforms.Compose(
        [
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    imgs = []

    for imgfile in args.files:
        img = transform(Image.open(imgfile).convert("RGB"))
        imgs.append(img)

    # (B, 3, resize, resize) batch of target images on the device.
    imgs = torch.stack(imgs, 0).to(device)

    # Load the EMA generator; strict=False tolerates extra/missing keys in
    # the checkpoint.
    g_ema = Generator(args.size, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.to(device)

    # Estimate the mean latent and a scalar spread of W by pushing random z
    # vectors through the style MLP.
    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)

        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    # LPIPS (VGG) perceptual distance used as the main reconstruction loss.
    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    # One noise tensor per generator layer, replicated per input image and
    # re-randomized; these are optimized jointly with the latent code.
    noises_single = g_ema.make_noise()
    noises = []
    for noise in noises_single:
        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())

    # Start every image's latent at the mean latent (W space); with --w_plus,
    # expand to one latent per generator layer (W+ space).
    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)

    if args.w_plus:
        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

    latent_in.requires_grad = True

    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)

    pbar = tqdm(range(args.step))
    latent_path = []

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr

        # Exploration noise injected into the latent, decayed to zero over
        # the first noise_ramp fraction of the run.
        noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2
        latent_n = latent_noise(latent_in, noise_strength.item())

        img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises)

        batch, channel, height, width = img_gen.shape

        # Downsample generated images to <=256px by block-averaging so they
        # are comparable to the preprocessed targets.
        if height > 256:
            factor = height // 256

            img_gen = img_gen.reshape(
                batch, channel, height // factor, factor, width // factor, factor
            )
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = F.mse_loss(img_gen, imgs)

        # Total objective: perceptual + weighted noise regularizer + optional MSE.
        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Keep the noise maps zero-mean/unit-std so the regularizer cannot be
        # trivially minimized by shrinking them.
        noise_normalize_(noises)

        # NOTE(review): latent_path stays empty when args.step < 100, which
        # makes latent_path[-1] below raise IndexError — confirm intended.
        if (i + 1) % 100 == 0:
            latent_path.append(latent_in.detach().clone())

        pbar.set_description(
            (
                f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
                f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
            )
        )

    # Final render from the last checkpointed latent (noise-free injection).
    img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises)

    # NOTE(review): the single output .pt file is named after the FIRST input
    # file only, even when several images are projected.
    filename = os.path.splitext(os.path.basename(args.files[0]))[0] + ".pt"

    img_ar = make_image(img_gen)

    # Per-image results: rendered image tensor, optimized latent, and the
    # per-layer noise slices for that image.
    result_file = {}
    for i, input_name in enumerate(args.files):
        noise_single = []
        for noise in noises:
            noise_single.append(noise[i : i + 1])

        result_file[input_name] = {
            "img": img_gen[i],
            "latent": latent_in[i],
            "noise": noise_single,
        }

        img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png"
        pil_img = Image.fromarray(img_ar[i])
        pil_img.save(img_name)

    torch.save(result_file, filename)