from types import MethodType
from typing import Optional

from diffusers.models.attention_processor import Attention

import torch
import torch.nn.functional as F

from .feature import *
from .utils import *
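
# The helpers used below (normalize, feature_injection, appearance_transfer,
# get_schedule, get_elem) are assumed to come from the star imports of
# `.feature` and `.utils`; the wildcards hide which module defines which.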


def get_control_config(structure_schedule, appearance_schedule):
    """Return the YAML control config string for the given structure and appearance schedules."""
    s = structure_schedule
    a = appearance_schedule

    control_config = f"""control_schedule:
    #       structure_conv  structure_attn     appearance_attn    conv/attn
    encoder:  # (num layers)
        0: [[     ], [              ], [              ]]  # 2/0
        1: [[     ], [              ], [{a}, {a}      ]]  # 2/2
        2: [[     ], [              ], [{a}, {a}      ]]  # 2/2
    middle: [[    ], [              ], [              ]]  # 2/1
    decoder:
        0: [[{s}  ], [{s}, {s}, {s} ], [0.0, {a}, {a} ]]  # 3/3
        1: [[     ], [              ], [{a}, {a}      ]]  # 3/3
        2: [[     ], [              ], [              ]]  # 3/0

control_target:
    - [output_tensor]  # structure_conv choices: {{hidden_states, output_tensor}}
    - [query, key]     # structure_attn choices: {{query, key, value}}
    - [before]         # appearance_attn choices: {{before, value, after}}

self_recurrence_schedule:
    - [0.1, 0.5, 2]  # format: [start, end, num_recurrence]"""

    return control_config
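
# Example: the returned string is YAML; with s = a = 0.6 it can be parsed as
# follows (assumption: PyYAML is available, it is not imported in this module):
#
#     import yaml
#     config = yaml.safe_load(get_control_config(0.6, 0.6))
#     config["control_schedule"]["decoder"][0]
#     # -> [[0.6], [0.6, 0.6, 0.6], [0.0, 0.6, 0.6]]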


def convolution_forward(
    self,
    input_tensor: torch.Tensor,
    temb: torch.Tensor,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Variant of diffusers' ResnetBlock2D.forward with structure injection hooks."""
    do_structure_control = self.do_control and self.t in self.structure_schedule

    hidden_states = input_tensor

    hidden_states = self.norm1(hidden_states)
    hidden_states = self.nonlinearity(hidden_states)

    if self.upsample is not None:
        # upsample_nearest_nhwc fails with large batch sizes, so make the
        # tensors contiguous first (as in diffusers' ResnetBlock2D)
        if hidden_states.shape[0] >= 64:
            input_tensor = input_tensor.contiguous()
            hidden_states = hidden_states.contiguous()
        input_tensor = self.upsample(input_tensor)
        hidden_states = self.upsample(hidden_states)
    elif self.downsample is not None:
        input_tensor = self.downsample(input_tensor)
        hidden_states = self.downsample(hidden_states)

    hidden_states = self.conv1(hidden_states)

    if self.time_emb_proj is not None:
        if not self.skip_time_act:
            temb = self.nonlinearity(temb)
        temb = self.time_emb_proj(temb)[:, :, None, None]

    if self.time_embedding_norm == "default":
        if temb is not None:
            hidden_states = hidden_states + temb
        hidden_states = self.norm2(hidden_states)
    elif self.time_embedding_norm == "scale_shift":
        if temb is None:
            raise ValueError(
                f"`temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
            )
        time_scale, time_shift = torch.chunk(temb, 2, dim=1)
        hidden_states = self.norm2(hidden_states)
        hidden_states = hidden_states * (1 + time_scale) + time_shift
    else:
        hidden_states = self.norm2(hidden_states)

    hidden_states = self.nonlinearity(hidden_states)

    hidden_states = self.dropout(hidden_states)
    hidden_states = self.conv2(hidden_states)

    # Structure control: inject features into the block's inner activations
    if do_structure_control and "hidden_states" in self.structure_target:
        hidden_states = feature_injection(hidden_states, batch_order=self.batch_order)

    if self.conv_shortcut is not None:
        input_tensor = self.conv_shortcut(input_tensor)

    output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

    # Structure control: inject features into the block's residual output
    if do_structure_control and "output_tensor" in self.structure_target:
        output_tensor = feature_injection(output_tensor, batch_order=self.batch_order)

    return output_tensor
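
# Note: convolution_forward() is not called directly; register_control() below
# binds it onto each diffusers ResnetBlock2D via MethodType, replacing the
# block's original forward() with this version plus its injection hooks.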


class AttnProcessor2_0:
    """Variant of diffusers' AttnProcessor2_0 (PyTorch 2.0 scaled_dot_product_attention)
    with structure injection and appearance transfer hooks."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        do_structure_control = attn.do_control and attn.t in attn.structure_schedule
        do_appearance_control = attn.do_control and attn.t in attn.appearance_schedule

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten spatial dimensions: (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects a mask of shape
            # (batch, heads, source_len, target_len), so reshape only when a mask is given
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        no_encoder_hidden_states = encoder_hidden_states is None
        if no_encoder_hidden_states:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if do_appearance_control:
            # Appearance transfer is driven by queries/keys computed from normalized
            # hidden states, kept separate from the attention path below
            hidden_states_normed = normalize(hidden_states, dim=-2)
            encoder_hidden_states_normed = normalize(encoder_hidden_states, dim=-2)

            query_normed = attn.to_q(hidden_states_normed)
            key_normed = attn.to_k(encoder_hidden_states_normed)

            inner_dim = key_normed.shape[-1]
            head_dim = inner_dim // attn.heads
            query_normed = query_normed.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            key_normed = key_normed.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # Structure control must also hit the normalized query/key so they stay
            # consistent with the injected (un-normalized) query and key below
            if do_structure_control:
                if "query" in attn.structure_target:
                    query_normed = feature_injection(query_normed, batch_order=attn.batch_order)
                if "key" in attn.structure_target:
                    key_normed = feature_injection(key_normed, batch_order=attn.batch_order)

        # Appearance control before the attention projections
        if do_appearance_control and "before" in attn.appearance_target:
            hidden_states = hidden_states.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            hidden_states = appearance_transfer(hidden_states, query_normed, key_normed, batch_order=attn.batch_order)
            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

            # hidden_states changed, so refresh encoder_hidden_states for self-attention
            if no_encoder_hidden_states:
                encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (B, L, inner_dim) -> (B, heads, L, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Structure control: inject structure features into the chosen projections
        if do_structure_control:
            if "query" in attn.structure_target:
                query = feature_injection(query, batch_order=attn.batch_order)
            if "key" in attn.structure_target:
                key = feature_injection(key, batch_order=attn.batch_order)
            if "value" in attn.structure_target:
                value = feature_injection(value, batch_order=attn.batch_order)

        # Appearance control on the value projection
        if do_appearance_control and "value" in attn.appearance_target:
            value = appearance_transfer(value, query_normed, key_normed, batch_order=attn.batch_order)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Appearance control after attention
        if do_appearance_control and "after" in attn.appearance_target:
            hidden_states = appearance_transfer(hidden_states, query_normed, key_normed, batch_order=attn.batch_order)

        # (B, heads, L, head_dim) -> (B, L, inner_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Linear projection and dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
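
# Note: register_control() below installs this processor by assigning it to
# `attention.processor` directly; diffusers' Attention.set_processor() would
# be the more conventional way to do the same thing.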


def register_control(
    model,
    timesteps,
    control_schedule,
    control_target=[["output_tensor"], ["query", "key"], ["before"]],
):
    """Attach control schedules, targets, and forward hooks to the UNet's blocks."""
    for block_type in ["encoder", "decoder", "middle"]:
        blocks = {
            "encoder": model.unet.down_blocks,
            "decoder": model.unet.up_blocks,
            "middle": [model.unet.mid_block],
        }[block_type]

        control_schedule_block = control_schedule[block_type]
        if block_type == "middle":
            # Wrap the middle block's schedule so it can be indexed like a one-layer list
            control_schedule_block = [control_schedule_block]

        for layer in range(len(control_schedule_block)):
            # Convolution (ResNet) blocks: structure control only
            num_blocks = len(blocks[layer].resnets) if hasattr(blocks[layer], "resnets") else 0
            for block in range(num_blocks):
                convolution = blocks[layer].resnets[block]
                convolution.structure_target = control_target[0]
                convolution.structure_schedule = get_schedule(
                    timesteps, get_elem(control_schedule_block[layer][0], block)
                )
                convolution.forward = MethodType(convolution_forward, convolution)

            # Self-attention (attn1) blocks: structure and appearance control
            num_blocks = len(blocks[layer].attentions) if hasattr(blocks[layer], "attentions") else 0
            for block in range(num_blocks):
                for transformer_block in blocks[layer].attentions[block].transformer_blocks:
                    attention = transformer_block.attn1
                    attention.structure_target = control_target[1]
                    attention.structure_schedule = get_schedule(
                        timesteps, get_elem(control_schedule_block[layer][1], block)
                    )
                    attention.appearance_target = control_target[2]
                    attention.appearance_schedule = get_schedule(
                        timesteps, get_elem(control_schedule_block[layer][2], block)
                    )
                    attention.processor = AttnProcessor2_0()


def register_attr(model, t, do_control, batch_order):
    """Stamp the current timestep and control flags onto every controlled module."""
    for layer_type in ["encoder", "decoder", "middle"]:
        blocks = {
            "encoder": model.unet.down_blocks,
            "decoder": model.unet.up_blocks,
            "middle": [model.unet.mid_block],
        }[layer_type]
        for layer in blocks:
            # Convolution (ResNet) blocks
            for module in layer.resnets:
                module.t = t
                module.do_control = do_control
                module.batch_order = batch_order
            # Self-attention blocks
            if hasattr(layer, "attentions"):
                for block in layer.attentions:
                    for module in block.transformer_blocks:
                        module.attn1.t = t
                        module.attn1.do_control = do_control
                        module.attn1.batch_order = batch_order
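

# Minimal usage sketch (hypothetical driver code, not part of this module).
# Assumptions: `model` wraps a diffusers UNet as `model.unet`, `scheduler` is a
# diffusers noise scheduler, and PyYAML is available to parse the config string.
#
#     import yaml
#     config = yaml.safe_load(get_control_config(0.6, 0.6))
#     register_control(model, scheduler.timesteps,
#                      config["control_schedule"], config["control_target"])
#     for t in scheduler.timesteps:
#         register_attr(model, t=t, do_control=True, batch_order=batch_order)
#         ...  # denoising step; `batch_order` is whatever ordering
#              # feature_injection/appearance_transfer expect in `.feature`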