|
|
|
from typing import Any, Dict, Optional |
|
|
|
import torch |
|
from einops import rearrange |
|
|
|
from .attention import BasicTransformerBlock |
|
from .attention import TemporalBasicTransformerBlock |
|
|
|
|
|
def torch_dfs(model: torch.nn.Module):
    """Return ``model`` followed by all of its descendants, in DFS pre-order.

    Children are visited in the order produced by ``Module.children()``. A
    module reachable through several parents appears once per path, exactly
    as a recursive pre-order walk would list it.
    """
    ordered = []
    pending = [model]
    while pending:
        node = pending.pop()
        ordered.append(node)
        # Push children reversed so the first child is popped (visited) next.
        pending.extend(reversed(list(node.children())))
    return ordered
|
|
|
|
|
class ReferenceAttentionControl:
    """Patch a UNet's transformer blocks to exchange reference features.

    In ``"write"`` mode each patched block appends to ``block.bank`` the
    normalized hidden states that feed its cross-attention layer (``attn2``).
    :meth:`update` copies those banks from a writer controller's UNet into
    this controller's blocks. In ``"read"`` mode each patched block attends
    to the first bank entry through an extra ``attn1_5`` layer, scaled by
    ``reference_attention_weight``. It also scales its cross-attention output
    by ``audio_attention_weight``.
    """

    def __init__(
        self,
        unet,
        mode="write",
        do_classifier_free_guidance=False,
        attention_auto_machine_weight=float("inf"),
        gn_auto_machine_weight=1.0,
        style_fidelity=1.0,
        reference_attn=True,
        reference_adain=False,
        fusion_blocks="midup",
        batch_size=1,
        reference_attention_weight=1.0,
        audio_attention_weight=1.0,
    ) -> None:
        """Store the configuration and patch ``unet`` immediately.

        Args:
            unet: Model whose transformer blocks are patched. With
                ``fusion_blocks="midup"`` it must expose ``mid_block`` and
                ``up_blocks``.
            mode: ``"write"`` to record reference features, ``"read"`` to
                consume them.
            do_classifier_free_guidance, attention_auto_machine_weight,
            gn_auto_machine_weight, style_fidelity, batch_size: Forwarded to
                :meth:`register_reference_hooks`. They do not affect the
                patched forward.
            reference_attn: When false, no block is patched and
                :meth:`update` / :meth:`clear` are no-ops.
            reference_adain: Stored only.
            fusion_blocks: ``"midup"`` (mid + up blocks) or ``"full"``
                (whole UNet). Selects which blocks are patched.
            reference_attention_weight: Multiplier for the read-mode
                reference-attention (``attn1_5``) output.
            audio_attention_weight: Multiplier for the read-mode
                cross-attention (``attn2``) output.
        """
        self.unet = unet
        assert mode in ["read", "write"]
        assert fusion_blocks in ["midup", "full"]
        self.reference_attn = reference_attn
        self.reference_adain = reference_adain
        self.fusion_blocks = fusion_blocks
        self.reference_attention_weight = reference_attention_weight
        self.audio_attention_weight = audio_attention_weight
        # Pass the trailing options by keyword: positionally, ``fusion_blocks``
        # would land in the ``dtype`` slot of ``register_reference_hooks``.
        self.register_reference_hooks(
            mode,
            do_classifier_free_guidance,
            attention_auto_machine_weight,
            gn_auto_machine_weight,
            style_fidelity,
            reference_attn,
            reference_adain,
            batch_size=batch_size,
            fusion_blocks=fusion_blocks,
        )

    def _collect_blocks(self, unet, block_types):
        """Return the blocks of ``unet`` that are instances of ``block_types``.

        The search covers ``unet.mid_block`` + ``unet.up_blocks`` when
        ``self.fusion_blocks == "midup"``, and all of ``unet`` when it is
        ``"full"``. The result is sorted by decreasing
        ``norm1.normalized_shape[0]``. The sort is stable, so equal widths
        keep traversal order. :meth:`update` relies on this ordering to pair
        reader and writer blocks by position.

        Raises:
            ValueError: if ``self.fusion_blocks`` is not a supported value.
        """
        if self.fusion_blocks == "midup":
            candidates = torch_dfs(unet.mid_block) + torch_dfs(unet.up_blocks)
        elif self.fusion_blocks == "full":
            candidates = torch_dfs(unet)
        else:
            raise ValueError(f"Unsupported fusion_blocks: {self.fusion_blocks!r}")
        blocks = [module for module in candidates if isinstance(module, block_types)]
        return sorted(blocks, key=lambda module: -module.norm1.normalized_shape[0])

    def register_reference_hooks(
        self,
        mode,
        do_classifier_free_guidance,
        attention_auto_machine_weight,
        gn_auto_machine_weight,
        style_fidelity,
        reference_attn,
        reference_adain,
        dtype=torch.float16,
        batch_size=1,
        num_images_per_prompt=1,
        device=torch.device("cpu"),
        fusion_blocks="midup",
    ):
        """Replace ``forward`` on the selected transformer blocks.

        Only ``mode`` influences the patched forward. The attention weights
        are read from ``self``. Block selection uses ``self.reference_attn``
        and ``self.fusion_blocks``. The remaining parameters are accepted for
        API compatibility and are currently unused.

        Each patched block also gets ``_original_inner_forward`` (its previous
        ``forward``), an empty ``bank`` list, and ``attn_weight`` (its index
        divided by the number of patched blocks).
        """
        MODE = mode
        reference_attention_weight = self.reference_attention_weight
        audio_attention_weight = self.audio_attention_weight

        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            class_labels: Optional[torch.LongTensor] = None,
            video_length=None,
        ):
            """Replacement ``forward``; ``self`` here is the patched block.

            ``video_length`` is used in read mode to repeat 3-D bank entries
            per frame. The temporal branch also uses it, treating the leading
            axis of ``hidden_states`` as ``(batch * video_length)``.
            """
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                (
                    norm_hidden_states,
                    gate_msa,
                    shift_mlp,
                    scale_mlp,
                    gate_mlp,
                ) = self.norm1(
                    hidden_states,
                    timestep,
                    class_labels,
                    hidden_dtype=hidden_states.dtype,
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)

            cross_attention_kwargs = (
                cross_attention_kwargs if cross_attention_kwargs is not None else {}
            )

            if self.only_cross_attention:
                # NOTE(review): this output is never added to hidden_states; the
                # block falls straight through to the feed-forward below.
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
                if MODE == "write":
                    # 1. Self-attention.
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=None,
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )
                    if self.use_ada_layer_norm_zero:
                        attn_output = gate_msa.unsqueeze(1) * attn_output
                    hidden_states = attn_output + hidden_states

                    if self.attn2 is not None:
                        norm_hidden_states = (
                            self.norm2(hidden_states, timestep)
                            if self.use_ada_layer_norm
                            else self.norm2(hidden_states)
                        )
                        # Record the cross-attention input; update() copies it
                        # into the bank of the matching read-mode block.
                        self.bank.append(norm_hidden_states.clone())

                        # 2. Cross-attention.
                        attn_output = self.attn2(
                            norm_hidden_states,
                            encoder_hidden_states=encoder_hidden_states,
                            attention_mask=encoder_attention_mask,
                            **cross_attention_kwargs,
                        )
                        hidden_states = attn_output + hidden_states

                if MODE == "read":
                    # 1. Self-attention (no ada-zero gating on this path).
                    hidden_states = (
                        self.attn1(
                            norm_hidden_states,
                            encoder_hidden_states=norm_hidden_states,
                            attention_mask=attention_mask,
                        )
                        + hidden_states
                    )

                    # 2. Reference attention over the stored bank feature.
                    if self.use_ada_layer_norm:
                        norm_hidden_states = self.norm1_5(hidden_states, timestep)
                    elif self.use_ada_layer_norm_zero:
                        (
                            norm_hidden_states,
                            _gate_msa,
                            _shift_mlp,
                            _scale_mlp,
                            _gate_mlp,
                        ) = self.norm1_5(
                            hidden_states,
                            timestep,
                            class_labels,
                            hidden_dtype=hidden_states.dtype,
                        )
                    else:
                        norm_hidden_states = self.norm1_5(hidden_states)

                    if not self.bank:
                        raise RuntimeError(
                            "Reference bank is empty: run the writer UNet and call "
                            "ReferenceAttentionControl.update() before a read-mode "
                            "forward pass."
                        )
                    # Only the first bank entry is consumed. A 3-D entry is
                    # repeated video_length times along a new frame axis, then
                    # batch and frame axes are merged.
                    reference = self.bank[0]
                    if reference.ndim == 3:
                        reference = reference.unsqueeze(1).repeat(1, video_length, 1, 1)
                    reference = rearrange(reference, "b t l c -> (b t) l c")

                    attn_hidden_states = self.attn1_5(
                        norm_hidden_states,
                        encoder_hidden_states=reference,
                        attention_mask=attention_mask,
                    )
                    if reference_attention_weight != 1.0:
                        attn_hidden_states *= reference_attention_weight
                    hidden_states = attn_hidden_states + hidden_states

                    # 3. Cross-attention, scaled by audio_attention_weight.
                    if self.attn2 is not None:
                        norm_hidden_states = (
                            self.norm2(hidden_states, timestep)
                            if self.use_ada_layer_norm
                            else self.norm2(hidden_states)
                        )
                        # Keys come from encoder_hidden_states, so the matching
                        # mask is encoder_attention_mask (as in the write path).
                        attn_hidden_states = self.attn2(
                            norm_hidden_states,
                            encoder_hidden_states=encoder_hidden_states,
                            attention_mask=encoder_attention_mask,
                        )
                        if audio_attention_weight != 1.0:
                            attn_hidden_states *= audio_attention_weight
                        hidden_states = attn_hidden_states + hidden_states

                    # 4. Feed-forward.
                    hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

                    # 5. Optional temporal attention across the frame axis.
                    if self.unet_use_temporal_attention:
                        d = hidden_states.shape[1]
                        hidden_states = rearrange(
                            hidden_states, "(b f) d c -> (b d) f c", f=video_length
                        )
                        norm_hidden_states = (
                            self.norm_temp(hidden_states, timestep)
                            if self.use_ada_layer_norm
                            else self.norm_temp(hidden_states)
                        )
                        hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
                        hidden_states = rearrange(
                            hidden_states, "(b d) f c -> (b f) d c", d=d
                        )

                    return hidden_states

            # 3. Feed-forward (write mode and only_cross_attention blocks).
            norm_hidden_states = self.norm3(hidden_states)
            if self.use_ada_layer_norm_zero:
                norm_hidden_states = (
                    norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
                )

            ff_output = self.ff(norm_hidden_states)
            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output

            hidden_states = ff_output + hidden_states
            return hidden_states

        if self.reference_attn:
            attn_modules = self._collect_blocks(
                self.unet, (BasicTransformerBlock, TemporalBasicTransformerBlock)
            )
            for i, module in enumerate(attn_modules):
                module._original_inner_forward = module.forward
                module.forward = hacked_basic_transformer_inner_forward.__get__(
                    module, type(module)
                )
                module.bank = []
                module.attn_weight = float(i) / float(len(attn_modules))

    def update(
        self,
        writer,
        do_classifier_free_guidance=True,
        dtype=torch.float16,
    ):
        """Copy the banks of ``writer``'s blocks into this UNet's blocks.

        Reader blocks are this UNet's ``TemporalBasicTransformerBlock``s.
        Writer blocks are ``writer.unet``'s ``BasicTransformerBlock``s. Both
        are selected with this controller's ``fusion_blocks`` and paired by
        position after sorting. Extra blocks on either side are ignored
        (``zip``).

        Args:
            writer: Controller (in write mode) whose ``unet`` blocks hold banks.
            do_classifier_free_guidance: When true, each bank tensor is
                prefixed along dim 0 with an all-zeros copy of itself.
                Otherwise it is cloned.
            dtype: dtype the copied tensors are cast to.
        """
        if self.reference_attn:
            reader_attn_modules = self._collect_blocks(
                self.unet, TemporalBasicTransformerBlock
            )
            writer_attn_modules = self._collect_blocks(
                writer.unet, BasicTransformerBlock
            )
            for r, w in zip(reader_attn_modules, writer_attn_modules):
                if do_classifier_free_guidance:
                    r.bank = [
                        torch.cat([torch.zeros_like(v), v]).to(dtype) for v in w.bank
                    ]
                else:
                    r.bank = [v.clone().to(dtype) for v in w.bank]

    def clear(self):
        """Empty the bank of every patched block in this controller's UNet."""
        if self.reference_attn:
            for module in self._collect_blocks(
                self.unet, (BasicTransformerBlock, TemporalBasicTransformerBlock)
            ):
                module.bank.clear()
|
|