AnyControl / models /anycontrol.py
nowsyn's picture
upload codes
54a7220
raw
history blame
No virus
12.4 kB
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import log_txt_as_img, instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from models.q_formers import load_qformer_model
class AnyControlNet(LatentDiffusion):
    """Latent diffusion model steered by local and (optionally) global controls.

    ``mode`` selects which control branches are active:

    * ``'local'`` -- only the local adapter contributes; the global content and
      color controls are replaced by all-zero tensors in :meth:`apply_model`.
    * ``'uni'`` -- both the local adapter and the global adapter contribute.

    The base diffusion UNet is frozen in ``__init__``. By default only the
    Q-Former and the local adapter are optimized (see :meth:`configure_optimizers`).
    """

    def __init__(self, mode, qformer_config=None, local_control_config=None, global_control_config=None, *args, **kwargs):
        """Build the model.

        Args:
            mode: ``'local'`` or ``'uni'`` (see class docstring).
            qformer_config: passed to ``load_qformer_model``.
            local_control_config: config instantiated into ``self.local_adapter``.
            global_control_config: config whose ``params`` provide
                ``clip_embeddings_dim`` and ``color_in_dim``. In ``'uni'`` mode it is
                also instantiated into ``self.global_adapter``. It is required in both
                modes, because ``'local'`` mode builds zero global controls of these sizes.
            *args, **kwargs: forwarded to ``LatentDiffusion.__init__``.

        Raises:
            ValueError: if ``mode`` is unsupported or ``global_control_config`` is None.
        """
        # Validate cheap arguments before the (potentially expensive) base-class init.
        if mode not in ('local', 'uni'):
            raise ValueError(f"mode must be 'local' or 'uni', got {mode!r}")
        if global_control_config is None:
            raise ValueError('global_control_config is required: its params supply '
                             'clip_embeddings_dim and color_in_dim')
        super().__init__(*args, **kwargs)
        self.mode = mode
        self.qformer_config = qformer_config
        self.local_control_config = local_control_config
        self.global_control_config = global_control_config

        # Freeze the base UNet. configure_optimizers re-enables gradients only on the
        # parts it unlocks when sd_locked is False.
        self.model.diffusion_model.requires_grad_(False)

        q_former, (vis_processor, txt_processor) = load_qformer_model(qformer_config)
        self.q_former = q_former
        self.qformer_vis_processor = vis_processor
        self.qformer_txt_processor = txt_processor

        self.local_adapter = instantiate_from_config(local_control_config)
        # Per-output multipliers for the local adapter's outputs (see apply_model).
        # zip() there truncates, so at most 13 outputs are forwarded.
        # NOTE(review): 13 presumably matches the number of residuals the adapter
        # emits -- confirm against the local adapter implementation.
        self.local_control_scales = [1.0] * 13

        self.global_adapter = instantiate_from_config(global_control_config) if self.mode == 'uni' else None
        self.clip_embeddings_dim = global_control_config.params.clip_embeddings_dim
        self.color_in_dim = global_control_config.params.color_in_dim

    @torch.no_grad()
    def get_input(self, batch, k, bs=None, *args, **kwargs):
        """Encode a batch into latents plus the control conditions.

        ``k`` is accepted for interface compatibility but ignored: the latent is
        always read via ``self.first_stage_key``. Extra positional/keyword args are
        forwarded to ``LatentDiffusion.get_input`` (``bs`` is not forwarded).

        Expects ``batch`` to provide ``'txt'``, ``'local_conditions'`` and
        ``'global_conditions'``. The condition entries expose a ``.data`` attribute.

        Returns:
            ``(x, conditions)``. ``conditions`` is a dict whose values are
            single-element lists:

            * ``text``: ``batch['txt']``
            * ``c_crossattn``: the conditioning returned by the base class
            * ``local_control``: from :meth:`get_local_conditions_for_vision_encoder`
            * ``global_control``: a dict of float tensors, one per key of
              ``batch['global_conditions'][0].data``, clipped to ``bs`` samples
        """
        # latent and text conditioning come from the base class
        x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
        bs = bs or x.size(0)

        local_control = self.get_local_conditions_for_vision_encoder(batch, bs)
        local_control = local_control.to(memory_format=torch.contiguous_format).float()

        global_conditions = batch['global_conditions'][:bs]
        global_control = {}
        for key in batch['global_conditions'][0].data.keys():
            global_cond = torch.stack([torch.Tensor(dc.data[key]) for dc in global_conditions])
            global_control[key] = global_cond.to(self.device).to(memory_format=torch.contiguous_format).float()

        conditions = dict(
            text=[batch['txt']],
            c_crossattn=[c],
            local_control=[local_control],
            global_control=[global_control],
        )
        return x, conditions

    def apply_model(self, x_noisy, t, cond, local_strength=1.0, content_strength=1.0, color_strength=1.0, *args, **kwargs):
        """Run the diffusion UNet on ``x_noisy`` at timestep ``t`` under ``cond``.

        Args:
            x_noisy: noisy latent batch; its first dimension is the batch size.
            t: timesteps, forwarded to the local adapter and the UNet.
            cond: dict in the layout built by :meth:`get_input` / :meth:`log_images`.
            local_strength, content_strength, color_strength: forwarded to the
                UNet as ``local_w``, ``extra_w`` and ``color_w`` respectively.

        Returns:
            The UNet output ``eps``.
        """
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model
        cond_txt = torch.cat(cond['c_crossattn'], 1)
        text = cond['text'][0]
        bs = x_noisy.shape[0]

        # Global control comes from the adapter in 'uni' mode; otherwise use all-zero placeholders.
        if self.mode == 'uni':
            global_cond = cond['global_control'][0]
            content_control, color_control = self.global_adapter(global_cond['clipembedding'], global_cond['color'])
        else:
            content_control = torch.zeros(bs, self.clip_embeddings_dim, device=self.device, dtype=torch.float32)
            color_control = torch.zeros(bs, self.color_in_dim, device=self.device, dtype=torch.float32)

        # Local control: the adapter extracts features with the Q-Former from the prompt
        # and the local condition images, then produces per-level outputs.
        # Both supported modes use it.
        local_features = self.local_adapter.extract_local_features(self.q_former, text, cond['local_control'][0])
        local_control = self.local_adapter(x=x_noisy, timesteps=t, context=cond_txt, local_features=local_features)
        local_control = [feat * scale for feat, scale in zip(local_control, self.local_control_scales)]

        eps = diffusion_model(
            x=x_noisy, timesteps=t, context=cond_txt,
            local_control=local_control, local_w=local_strength,
            content_control=content_control, extra_w=content_strength,
            color_control=color_control, color_w=color_strength)
        return eps

    @torch.no_grad()
    def get_unconditional_conditioning(self, N):
        """Return the learned conditioning for ``N`` empty prompts."""
        return self.get_learned_conditioning([""] * N)

    @torch.no_grad()
    def get_unconditional_global_conditioning(self, c):
        """Return zeros shaped like ``c`` (a dict of tensors, a list of tensors, or a tensor)."""
        if isinstance(c, dict):
            return {k: torch.zeros_like(v) for k, v in c.items()}
        elif isinstance(c, list):
            return [torch.zeros_like(v) for v in c]
        else:
            return torch.zeros_like(c)

    @torch.no_grad()
    def get_shape(self, batch, N):
        """Return ``shape[:2]`` of the first local-condition image of each of the first ``N`` samples.

        ``sample_log`` unpacks these as ``(h, w)``, which assumes HWC images.
        """
        return [dc.data[0].shape[:2] for dc in batch['local_conditions'][:N]]

    @torch.no_grad()
    def get_local_conditions_for_vision_encoder(self, batch, N):
        """Preprocess local-condition images for the Q-Former's vision encoder.

        Each image goes through ``self.qformer_vis_processor['eval']``. A sample's
        images are concatenated along dim 1, and every sample is zero-padded along
        dim 1 to ``max_len * 3`` channels. The padding assumes 3 channels per
        processed image.

        Returns:
            A tensor on ``self.device`` of shape ``(N, max_len * 3, h, w)``, where
            ``max_len`` is the largest number of condition images in a sample.
        """
        local_conditions = []
        max_len = max([len(dc.data) for dc in batch['local_conditions'][:N]])
        for dc in batch['local_conditions'][:N]:
            conds = torch.cat([self.qformer_vis_processor['eval'](Image.fromarray(img)).unsqueeze(0) for img in dc.data], dim=1)
            local_conditions.append(conds)
        # F.pad pairs run from the last dim backwards: (w, h, channel, batch); only channels are padded.
        local_conditions = [F.pad(cond, (0, 0, 0, 0, 0, max_len * 3 - cond.shape[1], 0, 0)) for cond in local_conditions]
        local_conditions = torch.cat(local_conditions, dim=0).to(self.device)
        return local_conditions

    @torch.no_grad()
    def get_local_conditions_for_logging(self, batch, N):
        """Stack raw local-condition images for visualization.

        Images are permuted HWC -> CHW and mapped from ``[0, 255]`` to ``[-1, 1]``,
        assuming 0..255 pixel values. Samples with fewer images are zero-padded to
        ``max_len`` images. The image and channel axes are then merged.

        Returns:
            A tensor on ``self.device`` of shape ``(N, max_len * c, h, w)``.
        """
        local_conditions = []
        max_len = max([len(dc.data) for dc in batch['local_conditions'][:N]])
        for dc in batch['local_conditions'][:N]:
            conds = torch.stack([torch.Tensor(img).permute(2, 0, 1) for img in dc.data], dim=0)  # (n, c, h, w)
            conds = conds.float() / 255.
            conds = conds * 2.0 - 1.0
            local_conditions.append(conds)
        # pad only dim 0 (number of images) up to max_len
        local_conditions = [F.pad(cond, (0, 0, 0, 0, 0, 0, 0, max_len - cond.shape[0])) for cond in local_conditions]
        local_conditions = torch.stack(local_conditions, dim=0).to(self.device)  # (bs, n, c, h, w)
        local_conditions = local_conditions.flatten(1, 2)
        return local_conditions

    def clip_batch(self, batch, key, N, flag=True):
        """Return the first ``N`` items of ``batch[key]``.

        If ``batch`` is itself a tensor, array, list or tuple, it is sliced directly
        and ``key`` is ignored. Otherwise ``batch[key]`` is looked up. When ``flag``
        is true, its first element is used, which matches the single-element-list
        layout of the condition dicts.

        Raises:
            ValueError: if the selected value is not a tensor, array, list, tuple
                or dict of those.
        """
        if isinstance(batch, (torch.Tensor, np.ndarray, list, tuple)):
            return batch[:N]
        value = batch[key][0] if flag else batch[key]
        return self._clip_value(value, N)

    def _clip_value(self, value, N):
        """Slice ``value`` to ``N`` items, recursing into dicts."""
        if isinstance(value, (torch.Tensor, np.ndarray, list, tuple)):
            return value[:N]
        if isinstance(value, dict):
            return {k: self._clip_value(v, N) for k, v in value.items()}
        raise ValueError(f'Unsupported type {type(value)}')

    @torch.no_grad()
    def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0,
                   plot_denoise_rows=False, plot_diffusion_rows=False, unconditional_guidance_scale=9.0, **kwargs):
        """Build a dict of images for logging.

        Always logs:

        * ``reconstruction``: decoded latents
        * ``conditioning``: the prompt rendered as an image
        * ``local_control``: the local condition images

        Optionally also logs:

        * ``diffusion_row``: forward-noising row (``plot_diffusion_rows``)
        * ``samples`` and denoise rows (``sample``)
        * classifier-free-guidance samples (``unconditional_guidance_scale > 1``).
          The unconditional branch uses an empty prompt and zeroed global controls
          but keeps the local controls.
        """
        use_ddim = ddim_steps is not None
        log = dict()
        z, c = self.get_input(batch, self.first_stage_key, bs=N)
        # get_input does not forward bs to the base class, so z may cover the whole
        # batch; clip it so every logged item refers to the same N samples.
        z = z[:N]
        shape = self.get_shape(batch, N)
        c_local = self.clip_batch(c, "local_control", N)
        c_global = self.clip_batch(c, "global_control", N)
        c_context = self.clip_batch(c, "c_crossattn", N)
        c_text = self.clip_batch(batch, self.cond_stage_key, N, False)
        N = min(z.shape[0], N)
        n_row = min(z.shape[0], n_row)

        log["reconstruction"] = self.decode_first_stage(z)
        log["conditioning"] = log_txt_as_img((512, 512), c_text, size=16)
        log["local_control"] = self.get_local_conditions_for_logging(batch, N)

        if plot_diffusion_rows:
            diffusion_row = list()
            z_start = z[:n_row]
            for step in range(self.num_timesteps):
                if step % self.log_every_t == 0 or step == self.num_timesteps - 1:
                    t_batch = repeat(torch.tensor([step]), '1 -> b', b=n_row)
                    t_batch = t_batch.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t_batch, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))
            diffusion_row = torch.stack(diffusion_row)
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        cond_dict = dict(
            local_control=[c_local],
            global_control=[c_global],
            c_crossattn=[c_context],
            text=[c_text],
            shape=[shape],
        )

        if sample:
            samples, z_denoise_row = self.sample_log(cond=cond_dict,
                                                     batch_size=N, ddim=use_ddim,
                                                     ddim_steps=ddim_steps, eta=ddim_eta,
                                                     log_every_t=self.log_every_t * 0.05)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                if isinstance(z_denoise_row, dict):
                    for key in ['pred_x0', 'x_inter']:
                        z_denoise_row_key = z_denoise_row[key]
                        denoise_grid = self._get_denoise_row_from_list(z_denoise_row_key)
                        log[f"denoise_row_{key}"] = denoise_grid
                else:
                    denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                    log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_context = self.get_unconditional_conditioning(N)
            uc_global = self.get_unconditional_global_conditioning(c_global)
            # local controls and text are shared between the two guidance branches
            uc_local = c_local
            uc_text = c_text
            uncond_dict = dict(
                local_control=[uc_local],
                global_control=[uc_global],
                c_crossattn=[uc_context],
                text=[uc_text],
                shape=[shape]
            )
            samples_cfg, _ = self.sample_log(cond=cond_dict,
                                             batch_size=N, ddim=use_ddim,
                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                             unconditional_guidance_scale=unconditional_guidance_scale,
                                             unconditional_conditioning=uncond_dict,
                                             )
            x_samples_cfg = self.decode_first_stage(samples_cfg)
            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        return log

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        """Sample latents with a DDIM sampler.

        The latent spatial size is ``(h // 8, w // 8)``. ``(h, w)`` is the first
        sample's entry in ``cond['shape']``; it defaults to 512x512 when that entry
        is missing or None. ``ddim`` is accepted but unused: DDIM is always used.
        Extra kwargs are forwarded to ``DDIMSampler.sample``.

        Returns:
            ``(samples, intermediates)`` from ``DDIMSampler.sample``.
        """
        ddim_sampler = DDIMSampler(self)
        shapes = cond.get('shape')
        if shapes is None:
            h, w = 512, 512
        else:
            h, w = shapes[0][0]
        shape = (self.channels, h // 8, w // 8)
        samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
        return samples, intermediates

    def configure_optimizers(self):
        """Build an AdamW optimizer (lr ``self.learning_rate``) over the trainable modules.

        The Q-Former and local adapter are always optimized. When
        ``self.sd_locked`` is False, the UNet's ``output_blocks`` and ``out`` are
        unfrozen and optimized as well.
        """
        lr = self.learning_rate
        params = list(self.q_former.parameters()) + list(self.local_adapter.parameters())
        # NOTE(review): the global adapter is not optimized even in 'uni' mode.
        # Presumably it is loaded pretrained and kept fixed -- confirm this is intended.
        if not self.sd_locked:
            for module in (self.model.diffusion_model.output_blocks, self.model.diffusion_model.out):
                # __init__ froze the whole UNet. Without re-enabling gradients, these
                # parameters would be handed to the optimizer but never receive a
                # gradient, so they would never be updated.
                module.requires_grad_(True)
                params += list(module.parameters())
        opt = torch.optim.AdamW(params, lr=lr)
        return opt

    def low_vram_shift(self, is_diffusing):
        """Move sub-modules between GPU and CPU to reduce peak VRAM usage.

        While diffusing, the modules that ``apply_model`` uses go to the GPU: the
        UNet, the local adapter, the Q-Former and, in 'uni' mode, the global
        adapter. The first-stage and cond-stage models go to the CPU. Otherwise
        the placement is reversed.
        """
        if is_diffusing:
            self.model = self.model.cuda()
            self.local_adapter = self.local_adapter.cuda()
            self.q_former = self.q_former.cuda()
            if self.global_adapter is not None:
                self.global_adapter = self.global_adapter.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model = self.model.cpu()
            self.local_adapter = self.local_adapter.cpu()
            self.q_former = self.q_former.cpu()
            if self.global_adapter is not None:
                self.global_adapter = self.global_adapter.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()