from IPython import display
import torch
from torch import nn
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid

from k_diffusion import utils as k_utils
from k_diffusion.external import CompVisDenoiser


class CFGDenoiser(nn.Module):
    # Classifier-free guidance: denoise the sample under both the conditional
    # and unconditional prompts in one batched pass, then blend the two
    # predictions with cond_scale.
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
        return uncond + (cond - uncond) * cond_scale


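# A minimal sampling sketch (an illustration, not part of the original module):
# how CFGDenoiser is typically driven by a k-diffusion sampler. The names
# `sd_model`, `conditioning` and `unconditioning`, the latent shape and the
# choice of sampler are all assumptions here.
def sample_with_cfg(sd_model, conditioning, unconditioning, steps=20, cond_scale=7.5):
    from k_diffusion import sampling as k_sampling

    model_wrap = CompVisDenoiser(sd_model)    # wraps the CompVis model for k-diffusion
    model_wrap_cfg = CFGDenoiser(model_wrap)  # adds classifier-free guidance
    sigmas = model_wrap.get_sigmas(steps)     # noise schedule, high to low
    x = torch.randn([1, 4, 64, 64], device=sigmas.device) * sigmas[0]
    extra_args = {'uncond': unconditioning, 'cond': conditioning, 'cond_scale': cond_scale}
    return k_sampling.sample_euler(model_wrap_cfg, x, sigmas, extra_args=extra_args)

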
class CFGDenoiserWithGrad(CompVisDenoiser):
    def __init__(self, model,
                 loss_fns_scales,             # list of (loss_fn, scale) pairs
                 clamp_func=None,             # clamp_func(grad, sigma) -> clamped grad
                 gradient_wrt=None,           # 'x' or 'x0_pred'
                 gradient_add_to=None,        # 'cond', 'uncond' or 'both'
                 cond_uncond_sync=True,       # apply the gradient to the combined CFG prediction
                 decode_method=None,          # None (raw latent), 'autoencoder' or 'linear'
                 grad_inject_timing_fn=None,  # grad_inject_timing_fn(sigma) -> bool
                 grad_consolidate_fn=None,    # grad_consolidate_fn(img, grad, sigma) -> updated img
                 verbose=False):
        super().__init__(model.inner_model)
        self.inner_model = model
        self.cond_uncond_sync = cond_uncond_sync

        self.clamp_func = clamp_func
        self.gradient_add_to = gradient_add_to
        # Default to differentiating with respect to the noisy sample x.
        self.gradient_wrt = 'x' if gradient_wrt is None else gradient_wrt

        # Choose how latents are decoded into sample space before the loss
        # functions see them.
        if decode_method is None:
            decode_fn = lambda x: x
        elif decode_method == "autoencoder":
            decode_fn = model.inner_model.differentiable_decode_first_stage
        elif decode_method == "linear":
            decode_fn = model.inner_model.linear_decode
        else:
            raise Exception(f"Unrecognised option for decode_method: {decode_method}")
        self.decode_fn = decode_fn

        # Build a conditioning function for each loss with a non-zero scale.
        cond_fns = []
        for loss_fn, scale in loss_fns_scales:
            if scale != 0:
                cond_fn = self.make_cond_fn(loss_fn, scale)
            else:
                cond_fn = None
            cond_fns += [cond_fn]
        self.cond_fns = cond_fns

        # By default, inject the gradient at every step and consolidate it by
        # adding it to the image scaled by sigma.
        if grad_inject_timing_fn is None:
            self.grad_inject_timing_fn = lambda sigma: True
        else:
            self.grad_inject_timing_fn = grad_inject_timing_fn
        if grad_consolidate_fn is None:
            self.grad_consolidate_fn = lambda img, grad, sigma: img + grad * sigma
        else:
            self.grad_consolidate_fn = grad_consolidate_fn

        self.verbose = verbose
        self.verbose_print = print if self.verbose else lambda *args, **kwargs: None

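    # Illustrative examples (assumptions, not part of the original code) of
    # the callables accepted above; the cutoff values are arbitrary:
    #   clamp_func=lambda grad, sigma: grad.clamp(-0.05, 0.05)  # hard-limit gradient magnitude
    #   grad_inject_timing_fn=lambda sigma: sigma > 1.0         # guide only the high-noise steps
    #   grad_consolidate_fn=lambda img, grad, sigma: img + grad * sigma ** 2
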
    def cond_model_fn_(self, x, sigma, inner_model=None, **kwargs):
        # Accumulate the conditioning gradients, consolidate them into the
        # sample, and return the conditioned denoised prediction.
        if inner_model is None:
            inner_model = self.inner_model

        total_cond_grad = torch.zeros_like(x)
        for cond_fn in self.cond_fns:
            if cond_fn is None:
                continue

            if self.gradient_wrt == 'x':
                # Differentiate the loss with respect to the noisy sample x.
                with torch.enable_grad():
                    x = x.detach().requires_grad_()
                    denoised = inner_model(x, sigma, **kwargs)
                    cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
            elif self.gradient_wrt == 'x0_pred':
                # Differentiate the loss with respect to the denoised prediction.
                with torch.no_grad():
                    denoised = inner_model(x, sigma, **kwargs)
                with torch.enable_grad():
                    cond_grad = cond_fn(x, sigma, denoised=denoised.detach().requires_grad_(), **kwargs).detach()
            total_cond_grad += cond_grad

        # Zero out NaNs; infinities are passed through for clamp_func to handle.
        total_cond_grad = torch.nan_to_num(total_cond_grad, nan=0.0, posinf=float('inf'), neginf=-float('inf'))
        total_cond_grad = self.clamp_grad_verbose(total_cond_grad, sigma)

        # Fold the gradient into the sample in place; for 'x0_pred', also fold
        # it into the denoised prediction that gets returned.
        if self.gradient_wrt == 'x':
            x.copy_(self.grad_consolidate_fn(x.detach(), total_cond_grad, k_utils.append_dims(sigma, x.ndim)))
            cond_denoised = inner_model(x, sigma, **kwargs)
        elif self.gradient_wrt == 'x0_pred':
            x.copy_(self.grad_consolidate_fn(x.detach(), total_cond_grad, k_utils.append_dims(sigma, x.ndim)))
            cond_denoised = self.grad_consolidate_fn(denoised.detach(), total_cond_grad, k_utils.append_dims(sigma, x.ndim))

        return cond_denoised

    def forward(self, x, sigma, uncond, cond, cond_scale):

        # Classifier-free guidance with the cond and uncond prompts batched
        # through the model in a single pass.
        def _cfg_model(x, sigma, cond, **kwargs):
            x_in = torch.cat([x] * 2)
            sigma_in = torch.cat([sigma] * 2)

            denoised = self.inner_model(x_in, sigma_in, cond=cond, **kwargs)
            uncond_x0, cond_x0 = denoised.chunk(2)
            x0_pred = uncond_x0 + (cond_x0 - uncond_x0) * cond_scale
            return x0_pred

        if self.check_conditioning_schedule(sigma):
            # Conditioning step: inject the loss gradients.
            if self.cond_uncond_sync:
                # Apply the gradient to the combined CFG prediction.
                cond_in = torch.cat([uncond, cond])
                x0 = self.cond_model_fn_(x, sigma, cond=cond_in, inner_model=_cfg_model)
            else:
                # Apply the gradient to the cond and/or uncond predictions
                # separately, then combine them.
                if self.gradient_add_to == "uncond":
                    uncond = self.cond_model_fn_(x, sigma, cond=uncond)
                    cond = self.inner_model(x, sigma, cond=cond)
                elif self.gradient_add_to == "cond":
                    uncond = self.inner_model(x, sigma, cond=uncond)
                    cond = self.cond_model_fn_(x, sigma, cond=cond)
                elif self.gradient_add_to == "both":
                    uncond = self.cond_model_fn_(x, sigma, cond=uncond)
                    cond = self.cond_model_fn_(x, sigma, cond=cond)
                else:
                    raise Exception(f"Unrecognised option for gradient_add_to: {self.gradient_add_to}")
                x0 = uncond + (cond - uncond) * cond_scale
        else:
            # Not a conditioning step: plain classifier-free guidance.
            if self.cond_uncond_sync:
                cond_in = torch.cat([uncond, cond])
                x0 = _cfg_model(x, sigma, cond=cond_in)
            else:
                uncond = self.inner_model(x, sigma, cond=uncond)
                cond = self.inner_model(x, sigma, cond=cond)
                x0 = uncond + (cond - uncond) * cond_scale

        return x0

    def make_cond_fn(self, loss_fn, scale):
        # Turn a loss function into a conditioning function that returns the
        # negated, scaled gradient of the loss.
        #   loss_fn: takes a decoded sample (and sigma) and returns a loss
        #   scale: multiplier applied to the loss before differentiation

        # Gradient with respect to the noisy sample x
        def cond_fn(x, sigma, denoised, **kwargs):
            with torch.enable_grad():
                denoised_sample = self.decode_fn(denoised).requires_grad_()
                loss = loss_fn(denoised_sample, sigma, **kwargs) * scale
                grad = -torch.autograd.grad(loss, x)[0]
            self.verbose_print('Loss:', loss.item())
            return grad

        # Gradient with respect to the denoised prediction x0
        def cond_fn_pred(x, sigma, denoised, **kwargs):
            with torch.enable_grad():
                denoised_sample = self.decode_fn(denoised).requires_grad_()
                loss = loss_fn(denoised_sample, sigma, **kwargs) * scale
                grad = -torch.autograd.grad(loss, denoised)[0]
            self.verbose_print('Loss:', loss.item())
            return grad

        if self.gradient_wrt == 'x':
            return cond_fn
        elif self.gradient_wrt == 'x0_pred':
            return cond_fn_pred
        else:
            raise Exception(f"Unrecognised option for gradient_wrt: {self.gradient_wrt}")

    def clamp_grad_verbose(self, grad, sigma):
        if self.clamp_func is not None:
            if self.verbose:
                print("Grad before clamping:")
                self.display_samples(torch.abs(grad * 2.0) - 1.0)
            grad = self.clamp_func(grad, sigma)
        if self.verbose:
            print("Conditioning gradient:")
            self.display_samples(torch.abs(grad * 2.0) - 1.0)
        return grad

    def check_conditioning_schedule(self, sigma):
        is_conditioning_step = False

        if (self.cond_fns is not None and
                any(cond_fn is not None for cond_fn in self.cond_fns)):
            if self.grad_inject_timing_fn(sigma):
                is_conditioning_step = True

                if self.verbose:
                    print(f"Conditioning step for sigma={sigma}")

        return is_conditioning_step

    def display_samples(self, images):
        # Map images from [-1, 1] to [0, 1] and show them as a grid.
        images = images.detach().double().cpu().add(1).div(2).clamp(0, 1)
        grid = make_grid(images, 4)
        display.display(to_pil_image(grid))
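
# --- Usage sketch (an illustration, not part of the original module) ---
# A minimal example of wiring a loss into CFGDenoiserWithGrad. The names
# `sd_model` and `target`, the loss scale and the clamp threshold are
# assumptions; 'autoencoder' decoding assumes the model exposes
# differentiable_decode_first_stage, as referenced above.

def make_mse_loss(target):
    # A loss_fn with the (decoded_sample, sigma, **kwargs) signature expected
    # by CFGDenoiserWithGrad: mean squared error against a target image.
    def mse_loss(x, sigma, **kwargs):
        return (x - target).pow(2).mean()
    return mse_loss


def build_guided_denoiser(sd_model, target):
    model_wrap = CompVisDenoiser(sd_model)
    loss_fns_scales = [(make_mse_loss(target), 0.5)]  # a scale of 0 would disable the loss
    return CFGDenoiserWithGrad(
        model_wrap,
        loss_fns_scales,
        clamp_func=lambda grad, sigma: grad.clamp(-0.05, 0.05),
        gradient_wrt='x0_pred',       # differentiate w.r.t. the denoised prediction
        gradient_add_to='both',
        cond_uncond_sync=True,
        decode_method='autoencoder',  # decode latents before the loss sees them
    )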
|