import math
from argparse import (
    ArgumentParser,
    Namespace,
)
from typing import (
    Dict,
    Iterable,
    Optional,
    Tuple,
)

import numpy as np
from tqdm import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from torchvision.transforms import Resize

#from optim import get_optimizer_class, OPTIMIZER_MAP
from losses.regularize_noise import NoiseRegularizer
from optim import RAdam
from utils.misc import (
    iterable_to_str,
    optional_string,
)


class OptimizerArguments:
    @staticmethod
    def add_arguments(parser: ArgumentParser):
        parser.add_argument('--coarse_min', type=int, default=32)
        parser.add_argument('--wplus_step', type=int, nargs="+", default=[250, 750], help="number of steps for optimizing w_plus at each scale")
        #parser.add_argument('--lr_rampup', type=float, default=0.05)
        #parser.add_argument('--lr_rampdown', type=float, default=0.25)
        parser.add_argument('--lr', type=float, default=0.1)
        parser.add_argument('--noise_strength', type=float, default=0.0)
        parser.add_argument('--noise_ramp', type=float, default=0.75)
        #parser.add_argument('--optimize_noise', action="store_true")
        parser.add_argument('--camera_lr', type=float, default=0.01)
        parser.add_argument("--log_dir", default="log/projector", help="tensorboard log directory")
        parser.add_argument("--log_freq", type=int, default=10, help="scalar log frequency")
        parser.add_argument("--log_visual_freq", type=int, default=50, help="visual log frequency")

    @staticmethod
    def to_string(args: Namespace) -> str:
        return (
            f"lr{args.lr}_{args.camera_lr}-c{args.coarse_min}"
            + f"-wp({iterable_to_str(args.wplus_step)})"
            + optional_string(args.noise_strength, f"-n{args.noise_strength}")
        )
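

# Usage sketch (hypothetical wiring, not part of the original module): register
# the flags on a parser and use the run tag to name the log directory.
#
#   parser = ArgumentParser()
#   OptimizerArguments.add_arguments(parser)
#   args = parser.parse_args()
#   run_name = OptimizerArguments.to_string(args)
#   writer = SummaryWriter(f"{args.log_dir}/{run_name}")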


class LatentNoiser(nn.Module):
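    """Perturbs a latent code with Gaussian noise that decays over the course of
    optimization; the magnitude is scaled by an estimate of the W-space std."""
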
    def __init__(
            self, generator: nn.Module,
            noise_ramp: float = 0.75, noise_strength: float = 0.05,
            n_mean_latent: int = 10000
    ):
        super().__init__()

        self.noise_ramp = noise_ramp
        self.noise_strength = noise_strength

        with torch.no_grad():
            # estimate the std of W-space latents from random style vectors
            # TODO: get 512 from generator
            noise_sample = torch.randn(n_mean_latent, 512, device=generator.device)
            latent_out = generator.style(noise_sample)
            latent_mean = latent_out.mean(0)
            self.latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    def forward(self, latent: torch.Tensor, t: float) -> torch.Tensor:
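        """Returns `latent` plus noise whose strength ramps down quadratically,
        reaching zero once the normalized step t exceeds noise_ramp."""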
        strength = self.latent_std * self.noise_strength * max(0, 1 - t / self.noise_ramp) ** 2
        noise = torch.randn_like(latent) * strength
        return latent + noise


class Optimizer:
    @classmethod
    def optimize(
            cls,
            generator: nn.Module,
            criterion: nn.Module,
            degrade: nn.Module,
            target: torch.Tensor,  # only used by the writer since it's mostly baked into criterion
            latent_init: torch.Tensor,
            noise_init: torch.Tensor,
            args: Namespace,
            writer: Optional[SummaryWriter] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
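        """Projects `target` into the generator's latent space, coarse-to-fine:
        one RAdam stage per entry of args.wplus_step, each stage unfreezing the
        coarse part of the W+ code up to a growing resolution.

        Returns the optimized latent code and the optimized per-layer noise maps.
        """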
        # do not optimize the generator
        generator = generator.eval()
        target = target.detach()

        # prepare parameters
        noises = []
        for n in noise_init:
            noise = n.detach().clone()
            noise.requires_grad = True
            noises.append(noise)

        def create_parameters(latent_coarse):
            parameters = [
                {'params': [latent_coarse], 'lr': args.lr},
                {'params': noises, 'lr': args.lr},
                {'params': degrade.parameters(), 'lr': args.camera_lr},
            ]
            return parameters
        device = target.device

        # start optimizing
        total_steps = np.sum(args.wplus_step)
        max_coarse_size = (2 ** (len(args.wplus_step) - 1)) * args.coarse_min
        noiser = LatentNoiser(generator, noise_ramp=args.noise_ramp, noise_strength=args.noise_strength).to(device)
        latent = latent_init.detach().clone()
        for coarse_level, steps in enumerate(args.wplus_step):
            if criterion.weights["contextual"] > 0:
                with torch.no_grad():
                    # synthesize a new sibling image using the current optimization results
                    # FIXME: update rgbs sibling
                    sibling, _, _ = generator([latent], input_is_latent=True, randomize_noise=True)
                    criterion.update_sibling(sibling)

            coarse_size = (2 ** coarse_level) * args.coarse_min
            latent_coarse, latent_fine = cls.split_latent(
                latent, generator.get_latent_size(coarse_size))
            parameters = create_parameters(latent_coarse)
            optimizer = RAdam(parameters)

            print(f"Optimizing {coarse_size}x{coarse_size}")
            pbar = tqdm(range(steps))
            for si in pbar:
                latent = torch.cat((latent_coarse, latent_fine), dim=1)
                niters = si + np.sum(args.wplus_step[:coarse_level])
                latent_noisy = noiser(latent, niters / total_steps)
                img_gen, _, rgbs = generator([latent_noisy], input_is_latent=True, noise=noises)
                # TODO: use coarse_size instead of args.coarse_size for rgb_level
                loss, losses = criterion(img_gen, degrade=degrade, noises=noises, rgbs=rgbs)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                NoiseRegularizer.normalize(noises)

                # log
                pbar.set_description("; ".join([f"{k}: {v.item(): .3e}" for k, v in losses.items()]))
                if writer is not None and niters % args.log_freq == 0:
                    cls.log_losses(writer, niters, loss, losses, criterion.weights)
                    cls.log_parameters(writer, niters, degrade.named_parameters())
                if writer is not None and niters % args.log_visual_freq == 0:
                    cls.log_visuals(writer, niters, img_gen, target, degraded=degrade(img_gen), rgbs=rgbs)

            latent = torch.cat((latent_coarse, latent_fine), dim=1).detach()

        return latent, noises

    @staticmethod
    def split_latent(latent: torch.Tensor, coarse_latent_size: int):
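        """Splits a W+ latent into an optimizable coarse part and a frozen fine part."""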
        latent_coarse = latent[:, :coarse_latent_size]
        latent_coarse.requires_grad = True
        latent_fine = latent[:, coarse_latent_size:]
        latent_fine.requires_grad = False
        return latent_coarse, latent_fine

    @staticmethod
    def log_losses(
            writer: SummaryWriter,
            niters: int,
            loss_total: torch.Tensor,
            losses: Dict[str, torch.Tensor],
            weights: Optional[Dict[str, torch.Tensor]] = None,
    ):
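        """Writes the total loss and each (optionally weighted) loss term to TensorBoard."""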
        writer.add_scalar("loss", loss_total.item(), niters)
        for name, loss in losses.items():
            writer.add_scalar(name, loss.item(), niters)
            if weights is not None:
                writer.add_scalar(f"weighted_{name}", weights[name] * loss.item(), niters)

    @staticmethod
    def log_parameters(
            writer: SummaryWriter,
            niters: int,
            named_parameters: Iterable[Tuple[str, torch.nn.Parameter]],
    ):
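        """Logs each named parameter as a scalar; assumes each parameter is a
        single-element tensor (otherwise `para.item()` raises)."""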
        for name, para in named_parameters:
            writer.add_scalar(name, para.item(), niters)

    @classmethod
    def log_visuals(
            cls,
            writer: SummaryWriter,
            niters: int,
            img: torch.Tensor,
            target: torch.Tensor,
            degraded=None,
            rgbs=None,
    ):
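        """Logs the prediction, an optional degraded version, and the target
        side by side, plus the generator's intermediate to_rgb outputs."""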
        # log the full-resolution prediction separately when it is larger than the target
        if target.shape[-1] != img.shape[-1]:
            visual = make_grid(img, nrow=1, normalize=True, range=(-1, 1))
            writer.add_image("pred", visual, niters)

        def resize(img):
            return F.interpolate(img, size=target.shape[2:], mode="area")

        vis = resize(img)
        if degraded is not None:
            vis = torch.cat((resize(degraded), vis), dim=-1)
        visual = make_grid(torch.cat((target.repeat(1, vis.shape[1] // target.shape[1], 1, 1), vis), dim=-1), nrow=1, normalize=True, range=(-1, 1))
        writer.add_image("gnd[-degraded]-pred", visual, niters)

        # log to_rgbs
        if rgbs is not None:
            cls.log_torgbs(writer, niters, rgbs)

    @staticmethod
    def log_torgbs(writer: SummaryWriter, niters: int, rgbs: Iterable[torch.Tensor], prefix: str = ""):
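        """Logs each intermediate to_rgb output next to a scale-compensated copy,
        labeled by its nominal resolution (4, 8, 16, ...)."""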
        for ri, rgb in enumerate(rgbs):
            scale = 2 ** (-(len(rgbs) - ri))
            visual = make_grid(torch.cat((rgb, rgb / scale), dim=-1), nrow=1, normalize=True, range=(-1, 1))
            writer.add_image(f"{prefix}to_rgb_{2 ** (ri + 2)}", visual, niters)
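

# End-to-end sketch (hypothetical; the generator, criterion, degradation model,
# and initial latent/noise come from the surrounding project, not this module):
#
#   parser = ArgumentParser()
#   OptimizerArguments.add_arguments(parser)
#   args = parser.parse_args()
#   writer = SummaryWriter(f"{args.log_dir}/{OptimizerArguments.to_string(args)}")
#   latent, noises = Optimizer.optimize(
#       generator, criterion, degrade, target,
#       latent_init, noise_init, args, writer=writer,
#   )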