# feizhengcong's picture
# Upload 198 files
# history blame
# 10.2 kB
from torch import nn
from k_diffusion import utils as k_utils
import torch
from k_diffusion.external import CompVisDenoiser
from torchvision.utils import make_grid
from IPython import display
from torchvision.transforms.functional import to_pil_image
class CFGDenoiser(nn.Module):
    """Classifier-free-guidance wrapper around a denoiser callable.

    The wrapped ``model`` is called once on a doubled batch (unconditional and
    conditional halves) and the two halves of its output are blended with
    ``cond_scale``.
    """

    def __init__(self, model):
        """Store ``model`` as ``self.inner_model``.

        Args:
            model: callable invoked as ``model(x, sigma, cond=...)``.
        """
        # nn.Module has to be initialised before an attribute that may itself
        # be a module is assigned; otherwise PyTorch refuses the assignment.
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        """Return ``uncond_out + (cond_out - uncond_out) * cond_scale``.

        Args:
            x: input passed to the inner model (duplicated along dim 0).
            sigma: noise level tensor (duplicated along dim 0).
            uncond: unconditional conditioning, concatenated first.
            cond: conditional conditioning, concatenated second.
            cond_scale: guidance weight applied to ``cond_out - uncond_out``.
        """
        # Evaluate both branches in a single call on a doubled batch.
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        # chunk(2) splits the output back into the (uncond, cond) halves, in
        # the same order in which cond_in was concatenated.
        uncond_out, cond_out = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
        return uncond_out + (cond_out - uncond_out) * cond_scale
class CFGDenoiserWithGrad(CompVisDenoiser):
def __init__(self, model,
             loss_fns_scales,  # List of [cond_function, scale] pairs
             clamp_func=None,  # Gradient clamping function, clamp_func(grad, sigma)
             gradient_wrt=None,  # Calculate gradient with respect to "x" (default) or "x0_pred"
             gradient_add_to=None,  # Add gradient to ["cond", "uncond", "both"]
             cond_uncond_sync=True,  # Calculates the cond and uncond simultaneously
             decode_method=None,  # How to decode the latent during gradient calculation: None, "autoencoder" or "linear"
             grad_inject_timing_fn=None,  # Option to use grad in only a few of the steps
             grad_consolidate_fn=None,  # Function to add grad to image fn(img, grad, sigma)
             verbose=False):  # Print losses and show gradient previews
    """Configure gradient-conditioned classifier-free guidance.

    Args:
        model: denoiser wrapper; ``model.inner_model`` must expose the decode
            method selected by ``decode_method``.
        loss_fns_scales: iterable of ``(loss_fn, scale)`` pairs. Pairs whose
            scale is 0 are kept as ``None`` placeholders in ``self.cond_fns``.
        clamp_func: optional ``clamp_func(grad, sigma)``.
        gradient_wrt: ``'x'`` or ``'x0_pred'``; ``None`` selects ``'x'``.
        gradient_add_to: ``'cond'``, ``'uncond'`` or ``'both'``; only read by
            ``forward`` when ``cond_uncond_sync`` is false.
        cond_uncond_sync: evaluate cond and uncond in one batched call.
        decode_method: ``None`` (identity), ``'autoencoder'`` or ``'linear'``.
        grad_inject_timing_fn: ``fn(sigma) -> bool``; defaults to always true.
        grad_consolidate_fn: ``fn(img, grad, sigma)``; defaults to
            ``img + grad * sigma``.
        verbose: enable loss printing and gradient previews.

    Raises:
        ValueError: if ``decode_method`` is not one of the supported values.
    """
    # The base class must be initialised before attributes are assigned on the
    # module. NOTE(review): assumes `model` is itself a CompVisDenoiser-style
    # wrapper, so `model.inner_model` is what the base class expects - confirm
    # against k_diffusion.external.
    super().__init__(model.inner_model)
    self.inner_model = model
    self.cond_uncond_sync = cond_uncond_sync

    # Initialize gradient calculation variables
    self.clamp_func = clamp_func
    self.gradient_add_to = gradient_add_to
    # Fall back to 'x' only when nothing was requested; an explicit value must
    # not be replaced and the default must not be overwritten by None.
    self.gradient_wrt = 'x' if gradient_wrt is None else gradient_wrt

    if decode_method is None:
        self.decode_fn = lambda x: x
    elif decode_method == "autoencoder":
        self.decode_fn = model.inner_model.differentiable_decode_first_stage
    elif decode_method == "linear":
        self.decode_fn = model.inner_model.linear_decode
    else:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError(f"Unrecognised option for decode_method: {decode_method}")

    # Parse loss function-scale pairs. make_cond_fn reads self.gradient_wrt,
    # so this has to run after it is set. A zero scale disables the loss but
    # keeps its slot so positions still line up with loss_fns_scales.
    self.cond_fns = [
        self.make_cond_fn(loss_fn, scale) if scale != 0 else None
        for loss_fn, scale in loss_fns_scales
    ]

    if grad_inject_timing_fn is None:
        self.grad_inject_timing_fn = lambda sigma: True
    else:
        self.grad_inject_timing_fn = grad_inject_timing_fn

    if grad_consolidate_fn is None:
        self.grad_consolidate_fn = lambda img, grad, sigma: img + grad * sigma
    else:
        self.grad_consolidate_fn = grad_consolidate_fn

    self.verbose = verbose
    # No-op printer when not verbose so call sites need no branching.
    self.verbose_print = print if self.verbose else lambda *args, **kwargs: None
# General denoising model with gradient conditioning
def cond_model_fn_(self, x, sigma, inner_model=None, **kwargs):
    """Denoise ``x`` with the summed conditioning gradients applied.

    Args:
        x: current sample. It is updated in place through ``x.copy_``.
        sigma: noise level; broadcast to ``x.ndim`` for the consolidate step.
        inner_model: optional replacement for ``self.inner_model`` (or a
            wrapper around it, see ``forward._cfg_model``).
        **kwargs: forwarded to the inner model and to every cond function.

    Returns:
        For ``gradient_wrt == 'x'``: the inner model evaluated on the updated
        ``x``. For ``'x0_pred'``: ``grad_consolidate_fn`` applied to the
        detached prediction.

    Raises:
        Exception: if ``self.gradient_wrt`` is neither ``'x'`` nor ``'x0_pred'``.
    """
    # Reject unsupported modes up front; otherwise the function would only
    # fail later on an unbound local.
    if self.gradient_wrt not in ('x', 'x0_pred'):
        raise Exception(f"Variable gradient_wrt == {self.gradient_wrt} not recognised.")
    wrt_x = self.gradient_wrt == 'x'

    if inner_model is None:
        inner_model = self.inner_model

    denoised = None
    total_cond_grad = torch.zeros_like(x)
    for cond_fn in self.cond_fns:
        if cond_fn is None:
            continue

        if wrt_x:
            # Gradient with respect to x: the inner model call has to be
            # recorded by autograd.
            with torch.enable_grad():
                x = x.detach().requires_grad_()
                denoised = inner_model(x, sigma, **kwargs)
                cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
        else:
            # Gradient wrt x0_pred, so save some compute: don't record grad
            # until after denoised is calculated.
            with torch.no_grad():
                denoised = inner_model(x, sigma, **kwargs)
            with torch.enable_grad():
                cond_grad = cond_fn(x, sigma, denoised=denoised.detach().requires_grad_(), **kwargs).detach()
        total_cond_grad += cond_grad

    # Only NaNs are zeroed; posinf/neginf are mapped to themselves.
    total_cond_grad = torch.nan_to_num(total_cond_grad, nan=0.0, posinf=float('inf'), neginf=-float('inf'))

    # Clamp the gradient
    total_cond_grad = self.clamp_grad_verbose(total_cond_grad, sigma)

    if not wrt_x and denoised is None:
        # No active cond function ran, so there is no prediction yet.
        with torch.no_grad():
            denoised = inner_model(x, sigma, **kwargs)

    # Add gradient to the image. The broadcast sigma is the same for every
    # consolidate call below, so build it once.
    sigma_bcast = k_utils.append_dims(sigma, x.ndim)
    # NOTE(review): copy_ is in place; after requires_grad_() this is only
    # legal with grad mode disabled - presumably callers run under
    # torch.no_grad(). Confirm against the sampler.
    x.copy_(self.grad_consolidate_fn(x.detach(), total_cond_grad, sigma_bcast))
    if wrt_x:
        return inner_model(x, sigma, **kwargs)
    return self.grad_consolidate_fn(denoised.detach(), total_cond_grad, sigma_bcast)
def forward(self, x, sigma, uncond, cond, cond_scale):
    """Return the guided prediction ``uncond + (cond - uncond) * cond_scale``.

    On steps where ``check_conditioning_schedule(sigma)`` is true, the
    prediction is routed through ``cond_model_fn_``: around the combined
    model when ``cond_uncond_sync`` is set, otherwise around the branch(es)
    named by ``gradient_add_to``.

    Raises:
        Exception: on a conditioning step with ``cond_uncond_sync`` false and
            ``gradient_add_to`` not in ``{'uncond', 'cond', 'both'}``.
    """

    def _cfg_model(x, sigma, cond, **kwargs):
        # Wrapper to add denoised cond and uncond as in a cfg model
        # input "cond" is both cond and uncond weights: torch.cat([uncond, cond])
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        denoised = self.inner_model(x_in, sigma_in, cond=cond, **kwargs)
        uncond_x0, cond_x0 = denoised.chunk(2)
        return uncond_x0 + (cond_x0 - uncond_x0) * cond_scale

    conditioning = self.check_conditioning_schedule(sigma)

    if self.cond_uncond_sync:
        # Calculate cond and uncond simultaneously in one doubled batch.
        cond_in = torch.cat([uncond, cond])
        if conditioning:
            # Apply the conditioning gradient to the completed denoised
            # (after both cond and uncond are combined).
            return self.cond_model_fn_(x, sigma, cond=cond_in, inner_model=_cfg_model)
        return _cfg_model(x, sigma, cond=cond_in)

    # Calculate cond and uncond separately: choose, per branch, whether the
    # plain model or the gradient-conditioned model is used.
    plain = self.inner_model
    if not conditioning:
        uncond_fn, cond_fn = plain, plain
    elif self.gradient_add_to == "uncond":
        uncond_fn, cond_fn = self.cond_model_fn_, plain
    elif self.gradient_add_to == "cond":
        uncond_fn, cond_fn = plain, self.cond_model_fn_
    elif self.gradient_add_to == "both":
        uncond_fn, cond_fn = self.cond_model_fn_, self.cond_model_fn_
    else:
        raise Exception(f"Unrecognised option for gradient_add_to: {self.gradient_add_to}")

    # Evaluate uncond first, then cond; cond_model_fn_ updates x in place,
    # so the order is significant.
    uncond_x0 = uncond_fn(x, sigma, cond=uncond)
    cond_x0 = cond_fn(x, sigma, cond=cond)
    return uncond_x0 + (cond_x0 - uncond_x0) * cond_scale
def make_cond_fn(self, loss_fn, scale):
    """Turn a loss function into a cond function returning a negated gradient.

    Args:
        loss_fn: called as ``loss_fn(decoded_sample, sigma, **kwargs)`` where
            ``decoded_sample`` is ``self.decode_fn(denoised)``; must return a
            single-element tensor (``.item()`` is called on the scaled loss).
        scale: multiplier applied to the loss.

    Returns:
        ``cond_fn(x, sigma, denoised, **kwargs)`` returning
        ``-d(loss * scale)/d(target)``, where ``target`` is ``x`` when
        ``self.gradient_wrt == 'x'`` and ``denoised`` when it is ``'x0_pred'``.

    Raises:
        Exception: if ``self.gradient_wrt`` is neither of those two values.
    """
    # Resolve the differentiation target once, when the closure is built.
    if self.gradient_wrt == 'x':
        wrt_x = True
    elif self.gradient_wrt == 'x0_pred':
        wrt_x = False
    else:
        raise Exception(f"Variable gradient_wrt == {self.gradient_wrt} not recognised.")

    def cond_fn(x, sigma, denoised, **kwargs):
        # Gradients are needed here even if the caller disabled grad mode.
        with torch.enable_grad():
            denoised_sample = self.decode_fn(denoised).requires_grad_()
            loss = loss_fn(denoised_sample, sigma, **kwargs) * scale
            target = x if wrt_x else denoised
            grad = -torch.autograd.grad(loss, target)[0]
        self.verbose_print('Loss:', loss.item())
        return grad

    return cond_fn
def clamp_grad_verbose(self, grad, sigma):
    """Apply ``self.clamp_func`` to ``grad`` (if set) and return the result.

    When ``self.verbose`` is true, the gradient is previewed through
    ``self.display_samples``: once before clamping (only if a clamp function
    is configured) and once for the final gradient.

    Args:
        grad: gradient tensor.
        sigma: passed through to ``clamp_func(grad, sigma)``.
    """

    def _preview(label, value):
        # Shared by both previews: print the label, then show the magnitude
        # encoded as abs(value * 2) - 1.
        print(label)
        self.display_samples(torch.abs(value * 2.0) - 1.0)

    if self.clamp_func is not None:
        if self.verbose:
            _preview("Grad before clamping:", grad)
        grad = self.clamp_func(grad, sigma)

    if self.verbose:
        _preview("Conditioning gradient", grad)
    return grad
def check_conditioning_schedule(self, sigma):
    """Return True when gradient conditioning should run at this ``sigma``.

    This requires at least one non-``None`` entry in ``self.cond_fns`` and a
    truthy ``self.grad_inject_timing_fn(sigma)``. When both hold and
    ``self.verbose`` is set, a message is printed.
    """
    # No active cond function (every scale was 0, or none were given):
    # conditioning strength is zero, so there is nothing to inject.
    if self.cond_fns is None or all(cond_fn is None for cond_fn in self.cond_fns):
        return False

    # Check if this is a conditioning step
    if not self.grad_inject_timing_fn(sigma):
        return False

    if self.verbose:
        print(f"Conditioning step for sigma={sigma}")
    return True
def display_samples(self, images):
images = images.double().cpu().add(1).div(2).clamp(0, 1)
images = torch.tensor(images.numpy())
grid = make_grid(images, 4).cpu()