import torch
import torch.nn.functional as F
from einops import rearrange

from .attention import CrossAttention
from .positional_encoding import PositionalEncoding


class StreamTemporalAttention(CrossAttention):
    """
    * window_size: the maximum length of the attention window.
    * sink_size: the number of sink tokens.
    * positional_rule: absolute or relative.

    Therefore, the sequence length of the temporal self-attention is:
        sink_size + cache_size (== window_size)
    """

    def __init__(
        self,
        attention_mode=None,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=32,
        window_size=8,
        sink_size=0,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.attention_mode = self._orig_attention_mode = attention_mode
        self.is_cross_attention = kwargs["cross_attention_dim"] is not None

        self.pos_encoder = PositionalEncoding(
            kwargs["query_dim"],
            dropout=0.0,
            max_len=temporal_position_encoding_max_len,
        )

        self.window_size = window_size
        self.sink_size = sink_size
        self.cache_size = self.window_size - self.sink_size
        assert self.cache_size >= 0, (
            "cache_size must be greater than or equal to 0. Please check your configuration. "
            f"window_size: {window_size}, sink_size: {sink_size}, "
            f"cache_size: {self.cache_size}"
        )

        self.motion_module_idx = None
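
        # NOTE (assumption): `self.h`, `self.w`, and `self.kv_channels` are not
        # defined anywhere in this module, yet `set_cache` below relies on them.
        # They are presumably assigned by the surrounding motion module before
        # any cache is allocated.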

    def set_index(self, idx):
        self.motion_module_idx = idx

    def set_cache(self, denoising_steps_num: int):
        """
        A larger buffer index corresponds to a cleaner (less noisy) latent.
        """
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        # [denoising_steps_num, 2, h * w, window_size, kv_channels];
        # index 0 / 1 of the second dim holds keys / values respectively.
        kv_cache = torch.zeros(
            denoising_steps_num,
            2,
            self.h * self.w,
            self.window_size,
            self.kv_channels,
            device=device,
            dtype=dtype,
        )
        self.denoising_steps_num = denoising_steps_num
        return kv_cache
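
    # Cache layout sketch (illustrative assumption): with denoising_steps_num=4,
    # a 16x16 latent (h * w = 256), window_size=8, and kv_channels=C, set_cache
    # returns a zero tensor of shape [4, 2, 256, 8, C] that forward() fills and
    # reads through its kv_cache argument.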

    def prepare_pe_buffer(self):
        """In AnimateDiff, the temporal self-attention uses absolute positional encoding:

            q = w_q * (x + pe) + bias
            k = w_k * (x + pe) + bias
            v = w_v * (x + pe) + bias

        To support relative positional encoding with a kv-cache, we pre-calculate
        `w_q/k/v * pe` here and cache only `w_q/k/v * x + bias`.
        """
        pe_list = self.pos_encoder.pe[:, : self.window_size]  # [1, window_size, ch]

        q_pe = F.linear(pe_list, self.to_q.weight)
        k_pe = F.linear(pe_list, self.to_k.weight)
        v_pe = F.linear(pe_list, self.to_v.weight)
        self.register_buffer("q_pe", q_pe)
        self.register_buffer("k_pe", k_pe)
        self.register_buffer("v_pe", v_pe)

    def prepare_qkv_full_and_cache(self, hidden_states, kv_cache, pe_idx, update_idx):
        """
        hidden_states: [(N * bhw), F, c]
        kv_cache: [N, 2, hw, L, c]

        * for the warmup case: `N` should be 1 and `F` should be the warmup size (`sink_size`)
        * for the streaming case: `N` should be `denoising_steps_num` and `F` should be `chunk_size`
        """
        q_layer = self.to_q(hidden_states)
        k_layer = self.to_k(hidden_states)
        v_layer = self.to_v(hidden_states)

        q_layer = rearrange(q_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num)
        k_layer = rearrange(k_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num)
        v_layer = rearrange(v_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num)

        # onnx & trt friendly indexing
        for idx in range(self.denoising_steps_num):
            kv_cache[idx, 0, :, update_idx[idx]] = k_layer[idx, :, 0]
            kv_cache[idx, 1, :, update_idx[idx]] = v_layer[idx, :, 0]
        k_full = kv_cache[:, 0]
        v_full = kv_cache[:, 1]

        kv_idx = pe_idx
        q_idx = torch.stack(
            [kv_idx[idx, update_idx[idx]] for idx in range(self.denoising_steps_num)]
        ).unsqueeze_(1)  # [timesteps, 1]
        pe_k = torch.cat(
            [self.k_pe.index_select(1, kv_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0
        )  # [n, window_size, c]
        pe_v = torch.cat(
            [self.v_pe.index_select(1, kv_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0
        )  # [n, window_size, c]
        pe_q = torch.cat(
            [self.q_pe.index_select(1, q_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0
        )  # [n, 1, c]

        q_layer = q_layer + pe_q.unsqueeze(1)
        k_full = k_full + pe_k.unsqueeze(1)
        v_full = v_full + pe_v.unsqueeze(1)

        q_layer = rearrange(q_layer, "n bhw f c -> (n bhw) f c")
        k_full = rearrange(k_full, "n bhw f c -> (n bhw) f c")
        v_full = rearrange(v_full, "n bhw f c -> (n bhw) f c")

        return q_layer, k_full, v_full
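
    # Indexing convention (inferred from the loop above, stated as an assumption):
    # for denoising step `idx`, `update_idx[idx]` is the slot in the window that
    # the newest frame's k/v overwrite, and `pe_idx[idx]` lists the positional-
    # encoding index assigned to each of the `window_size` cached slots; the
    # query reuses the positional encoding of the slot it just updated.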

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        video_length=None,
        temporal_attention_mask=None,
        kv_cache=None,
        pe_idx=None,
        update_idx=None,
        *args,
        **kwargs,
    ):
        """
        temporal_attention_mask: attention mask specific to the temporal self-attention.
        """
        d = hidden_states.shape[1]
        hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query_layer, key_full, value_full = self.prepare_qkv_full_and_cache(
            hidden_states, kv_cache, pe_idx, update_idx
        )

        # [(n * hw * b), f, c] -> [(n * hw * b * head), f, c // head]
        query_layer = self.reshape_heads_to_batch_dim(query_layer)
        key_full = self.reshape_heads_to_batch_dim(key_full)
        value_full = self.reshape_heads_to_batch_dim(value_full)

        if temporal_attention_mask is not None:
            q_size = query_layer.shape[1]
            # [n, self.window_size] -> [n, hw, q_size, window_size]
            temporal_attention_mask_ = temporal_attention_mask[:, None, None, :].repeat(1, self.h * self.w, q_size, 1)
            temporal_attention_mask_ = rearrange(temporal_attention_mask_, "n hw Q KV -> (n hw) Q KV")
            temporal_attention_mask_ = temporal_attention_mask_.repeat_interleave(self.heads, dim=0)
        else:
            temporal_attention_mask_ = None

        # attention, what we cannot get enough of
        if hasattr(F, "scaled_dot_product_attention"):
            hidden_states = self._memory_efficient_attention_pt20(
                query_layer, key_full, value_full, attention_mask=temporal_attention_mask_
            )
        elif self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(
                query_layer, key_full, value_full, attention_mask=temporal_attention_mask_
            )
            # Some versions of xformers return output in fp32; cast it back to the dtype of the input.
            hidden_states = hidden_states.to(query_layer.dtype)
        else:
            hidden_states = self._attention(query_layer, key_full, value_full, temporal_attention_mask_)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)

        hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
        return hidden_states
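

# A plausible call sequence (sketch only; inferred from this file, not confirmed
# by it): once the parent motion module has set `h`, `w`, and `kv_channels` on
# the attention layer, the streaming driver would do something like
#
#     attn.set_index(layer_idx)          # hypothetical layer index
#     attn.prepare_pe_buffer()
#     kv_cache = attn.set_cache(denoising_steps_num)
#     out = attn(hidden_states, video_length=chunk_size,
#                kv_cache=kv_cache, pe_idx=pe_idx, update_idx=update_idx)
#
# where `pe_idx` and `update_idx` follow the indexing convention noted above.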