Time-TravelRephotography / projector.py
feng2022's picture
time
f9827f9
raw
history blame
6.73 kB
import argparse
import math
import os
import torch
from torch import optim
from torch.nn import functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import lpips
from model import Generator
def noise_regularize(noises):
    """Return a multi-scale spatial autocorrelation penalty for noise maps.

    For every tensor in ``noises`` and at every resolution, the penalty adds
    the squared mean of the product between the map and a copy of itself
    rolled by one position along dim 3, plus the same term for dim 2. The map
    is then averaged over 2x2 blocks and the process repeats.

    Downsampling stops as soon as either spatial side is 8 or smaller, or
    cannot be halved exactly. Height and width are tracked separately, so
    non-square maps are pooled correctly.

    Args:
        noises: Iterable of 4-D tensors whose dims 2 and 3 are the spatial
            axes. The pooling reshape folds every leading dim into the batch
            (it forces dim 1 to size 1).

    Returns:
        The accumulated penalty as a 0-dim tensor, or the int ``0`` when
        ``noises`` is empty.
    """
    loss = 0

    for noise in noises:
        height, width = noise.shape[2], noise.shape[3]

        while True:
            loss = (
                loss
                + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
                + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
            )

            # Coarsest scale reached, or a side is odd and the 2x2 pooling
            # reshape below would not be exact.
            if min(height, width) <= 8 or height % 2 or width % 2:
                break

            # Average each non-overlapping 2x2 block: split both spatial axes
            # into (half, 2) and reduce over the two size-2 axes.
            noise = noise.reshape([-1, 1, height // 2, 2, width // 2, 2])
            noise = noise.mean([3, 5])
            height //= 2
            width //= 2

    return loss
def noise_normalize_(noises):
    """Standardize every noise tensor in place to zero mean and unit std.

    The statistics are taken over all elements of each tensor. The update
    runs under ``torch.no_grad()`` so it is not recorded by autograd, which
    lets it be applied to leaf tensors that require gradients without going
    through the discouraged ``.data`` attribute.

    Args:
        noises: Iterable of floating-point tensors; each one is modified in
            place.

    Returns:
        None.
    """
    # Lower bound for the divisor: a constant map (std == 0) is only centred
    # instead of being filled with NaN/inf by a division by zero.
    eps = 1e-8

    with torch.no_grad():
        for noise in noises:
            mean = noise.mean()
            std = noise.std()
            noise.sub_(mean).div_(std.clamp_min(eps))
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Return the learning rate at normalized progress ``t``.

    The schedule multiplies ``initial_lr`` by two factors:

    * a cosine decay that falls from 1 to 0 over the last ``rampdown``
      fraction of progress, and
    * a linear warm-up that rises from 0 to 1 over the first ``rampup``
      fraction of progress.

    Args:
        t: Progress through the optimization, expected in ``[0, 1]``.
        initial_lr: Peak learning rate.
        rampdown: Fraction of progress spent decaying at the end. A value
            ``<= 0`` disables the decay instead of dividing by zero.
        rampup: Fraction of progress spent warming up at the start. A value
            ``<= 0`` disables the warm-up instead of dividing by zero.

    Returns:
        The scheduled learning rate as a float.
    """
    decay = min(1, (1 - t) / rampdown) if rampdown > 0 else 1
    # Maps decay in [0, 1] onto a smooth 0 -> 1 cosine ease.
    lr_ramp = 0.5 - 0.5 * math.cos(decay * math.pi)

    warmup = min(1, t / rampup) if rampup > 0 else 1
    lr_ramp = lr_ramp * warmup

    return initial_lr * lr_ramp
def latent_noise(latent, strength):
    """Return ``latent`` perturbed by scaled standard-normal noise.

    Args:
        latent: Tensor to perturb; it is not modified.
        strength: Multiplier applied to the sampled noise.

    Returns:
        A new tensor equal to ``latent`` plus ``strength`` times a
        standard-normal sample with the same shape, dtype and device as
        ``latent``.
    """
    gaussian = torch.randn_like(latent)
    scaled = gaussian * strength
    return latent + scaled
def make_image(tensor):
    """Convert a batch of images in ``[-1, 1]`` to a uint8 NumPy array.

    Values are clamped to ``[-1, 1]``, mapped linearly to ``[0, 255]`` and
    cast to ``uint8`` (the cast truncates the fractional part). The dim order
    is changed with ``permute(0, 2, 3, 1)``, i.e. dim 1 is moved last.

    The input is left untouched: clamping is done out of place, because an
    in-place clamp on ``tensor.detach()`` would write through to the caller's
    tensor, which shares the same storage.

    Args:
        tensor: 4-D image tensor; dim 1 is presumably the channel axis
            (N, C, H, W) — confirm against the generator output.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` on the CPU with dims ordered
        as (0, 2, 3, 1) of the input.
    """
    return (
        tensor.detach()
        .clamp(min=-1, max=1)
        .add(1)
        .div(2)
        .mul(255)
        .to(torch.uint8)
        .permute(0, 2, 3, 1)
        .to("cpu")
        .numpy()
    )
if __name__ == "__main__":
    # Command-line entry point: optimizes a latent code (and per-layer noise
    # maps) so that the generator reproduces each input image, then writes
    # one "<name>-project.png" per input plus a single ".pt" file holding the
    # results for the whole batch.
    device = "cuda"

    parser = argparse.ArgumentParser(
        description="Image projector to the generator latent spaces"
    )
    parser.add_argument(
        "--ckpt", type=str, required=True, help="path to the model checkpoint"
    )
    parser.add_argument(
        "--size", type=int, default=256, help="output image sizes of the generator"
    )
    parser.add_argument(
        "--lr_rampup",
        type=float,
        default=0.05,
        help="duration of the learning rate warmup",
    )
    parser.add_argument(
        "--lr_rampdown",
        type=float,
        default=0.25,
        help="duration of the learning rate decay",
    )
    parser.add_argument("--lr", type=float, default=0.1, help="learning rate")
    parser.add_argument(
        "--noise", type=float, default=0.05, help="strength of the noise level"
    )
    parser.add_argument(
        "--noise_ramp",
        type=float,
        default=0.75,
        help="duration of the noise level decay",
    )
    parser.add_argument("--step", type=int, default=1000, help="optimize iterations")
    parser.add_argument(
        "--noise_regularize",
        type=float,
        default=1e5,
        help="weight of the noise regularization",
    )
    parser.add_argument("--mse", type=float, default=0, help="weight of the mse loss")
    parser.add_argument(
        "--w_plus",
        action="store_true",
        help="allow to use distinct latent codes to each layers",
    )
    parser.add_argument(
        "files", metavar="FILES", nargs="+", help="path to image files to be projected"
    )

    args = parser.parse_args()

    # Number of random samples used to estimate latent statistics below.
    n_mean_latent = 10000

    # Targets are compared at 256 px at most; larger generator outputs are
    # average-pooled down inside the optimization loop.
    resize = min(args.size, 256)

    transform = transforms.Compose(
        [
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            # Maps [0, 1] pixel values to [-1, 1].
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    imgs = []
    for imgfile in args.files:
        # Close the file handle as soon as the pixels have been converted.
        with Image.open(imgfile) as source_img:
            imgs.append(transform(source_img.convert("RGB")))

    imgs = torch.stack(imgs, 0).to(device)

    g_ema = Generator(args.size, 512, 8)
    # NOTE(security): torch.load unpickles the file; only use checkpoints
    # from a trusted source.
    # map_location="cpu" lets checkpoints saved on a device that is not
    # available here load; the model is moved to `device` right after.
    checkpoint = torch.load(args.ckpt, map_location="cpu")
    g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.to(device)

    with torch.no_grad():
        # Estimate the mean and overall spread of the outputs of
        # `g_ema.style` for standard-normal inputs.
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)

        latent_mean = latent_out.mean(0)
        # Square root of (total squared deviation / number of samples).
        latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    # One independent, freshly sampled noise map per layer and per image.
    noises_single = g_ema.make_noise()
    noises = [
        noise.repeat(imgs.shape[0], 1, 1, 1).normal_() for noise in noises_single
    ]

    # Every image starts from the mean latent.
    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)

    if args.w_plus:
        # Give each of the generator's `n_latent` slots its own copy.
        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

    latent_in.requires_grad = True

    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)

    pbar = tqdm(range(args.step))
    latent_path = []

    for i in pbar:
        t = i / args.step
        # Honour the --lr_rampdown / --lr_rampup command-line options.
        lr = get_lr(t, args.lr, rampdown=args.lr_rampdown, rampup=args.lr_rampup)
        optimizer.param_groups[0]["lr"] = lr
        # Latent perturbation decays quadratically to zero over the first
        # `noise_ramp` fraction of the run.
        noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2
        latent_n = latent_noise(latent_in, noise_strength.item())

        img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises)

        batch, channel, height, width = img_gen.shape

        if height > 256:
            # Average-pool by an integer factor so the output matches the
            # resolution of the (at most 256 px) targets.
            factor = height // 256

            img_gen = img_gen.reshape(
                batch, channel, height // factor, factor, width // factor, factor
            )
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = F.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Re-standardize the noise maps after each update.
        noise_normalize_(noises)

        if (i + 1) % 100 == 0:
            latent_path.append(latent_in.detach().clone())

        pbar.set_description(
            (
                f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
                f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
            )
        )

    # Snapshots are taken every 100 steps. Make sure the last entry is the
    # final latent: otherwise a run shorter than 100 steps would fail on an
    # empty list, and a step count that is not a multiple of 100 would render
    # an image from a stale latent that differs from the one saved below.
    if args.step % 100 != 0 or not latent_path:
        latent_path.append(latent_in.detach().clone())

    img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises)

    # The results file for the whole batch is named after the first input.
    filename = os.path.splitext(os.path.basename(args.files[0]))[0] + ".pt"

    img_ar = make_image(img_gen)

    result_file = {}
    for i, input_name in enumerate(args.files):
        # Keep the batch dim so each stored noise map has batch size 1.
        noise_single = [noise[i : i + 1] for noise in noises]

        result_file[input_name] = {
            "img": img_gen[i],
            "latent": latent_in[i],
            "noise": noise_single,
        }

        img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png"
        pil_img = Image.fromarray(img_ar[i])
        pil_img.save(img_name)

    torch.save(result_file, filename)