File size: 7,770 Bytes
d16b52d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import torch
import torch.nn.functional as F
from einops import rearrange

from .attention import CrossAttention
from .positional_encoding import PositionalEncoding


class StreamTemporalAttention(CrossAttention):
    """Temporal self-attention that attends over a cached key/value window.

    Constructor arguments specific to this class:

    * window_size: The max length of the attention window, i.e. the number of
      key/value slots kept per denoising step.
    * sink_size: The number of sink tokens inside the window.

    ``cache_size`` is derived as ``window_size - sink_size``. Therefore, the
    key/value sequence length of the temporal self-attention is:
        sink_size + cache_size (== window_size)

    Positional encodings are not added to the hidden states directly; they are
    projected once by `prepare_pe_buffer` and added to q/k/v at lookup time in
    `prepare_qkv_full_and_cache`.

    NOTE(review): `self.h`, `self.w` and `self.kv_channels` are read by
    `set_cache` / `forward` but are never assigned in this class. They must be
    set on the instance beforehand (by the caller, or presumably by the parent
    class) -- confirm against the code that builds this module.
    """

    def __init__(
        self,
        attention_mode=None,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=32,
        window_size=8,
        sink_size=0,
        *args,
        **kwargs,
    ):
        """Build the attention layer and its positional-encoding table.

        Args:
            attention_mode: Stored on `attention_mode` and `_orig_attention_mode`.
            cross_frame_attention_mode: Accepted but not read by this class.
            temporal_position_encoding: Accepted but not read by this class; the
                positional encoder is always created.
            temporal_position_encoding_max_len: `max_len` of the positional
                encoder.
            window_size: Max number of key/value positions in the window.
            sink_size: Number of sink positions inside the window.
            *args, **kwargs: Forwarded unchanged to the parent constructor.
                `kwargs` must contain ``query_dim``; ``cross_attention_dim`` is
                optional.

        Raises:
            ValueError: If ``sink_size`` is larger than ``window_size``.
        """
        super().__init__(*args, **kwargs)

        self.attention_mode = self._orig_attention_mode = attention_mode
        # `.get` rather than indexing: a caller that omits `cross_attention_dim`
        # should be treated as self-attention instead of hitting a KeyError
        # after the parent has already been constructed.
        self.is_cross_attention = kwargs.get("cross_attention_dim") is not None

        self.pos_encoder = PositionalEncoding(
            kwargs["query_dim"],
            dropout=0.0,
            max_len=temporal_position_encoding_max_len,
        )

        self.window_size = window_size
        self.sink_size = sink_size
        self.cache_size = self.window_size - self.sink_size
        # Explicit raise instead of `assert`: assertions are stripped when
        # Python runs with -O, which would let a negative cache size through.
        if self.cache_size < 0:
            raise ValueError(
                "cache_size must be greater or equal to 0. Please check your configuration. "
                f"window_size: {window_size}, sink_size: {sink_size}, "
                f"cache_size: {self.cache_size}"
            )

        self.motion_module_idx = None

    def set_index(self, idx):
        """Record the index of the motion module this layer belongs to."""
        self.motion_module_idx = idx

    @torch.no_grad()
    def set_cache(self, denoising_steps_num: int):
        """Allocate a zero-filled key/value cache and remember the step count.

        Original author's note: a larger buffer index means a cleaner latent.

        Args:
            denoising_steps_num: Size of the leading cache dimension. It is also
                stored on `self.denoising_steps_num`, which
                `prepare_qkv_full_and_cache` relies on.

        Returns:
            A tensor of shape
            ``[denoising_steps_num, 2, h * w, window_size, kv_channels]`` on the
            device/dtype of this module's first parameter. Index 0 of the second
            dimension holds keys and index 1 holds values.
        """
        # Look the reference parameter up once; it fixes both device and dtype.
        ref_param = next(self.parameters())

        # [t, 2, hw, L, c], 2 means k and v
        kv_cache = torch.zeros(
            denoising_steps_num,
            2,
            self.h * self.w,
            self.window_size,
            self.kv_channels,
            device=ref_param.device,
            dtype=ref_param.dtype,
        )
        self.denoising_steps_num = denoising_steps_num

        return kv_cache

    @torch.no_grad()
    def prepare_pe_buffer(self):
        """Pre-project the positional encodings and register them as buffers.

        In AnimateDiff, temporal self-attention uses absolute positional encoding:
            q = w_q * (x + pe) + bias
            k = w_k * (x + pe) + bias
            v = w_v * (x + pe) + bias

        To apply relative positional encoding together with a kv-cache, we
        pre-calculate `w_q/k/v * pe` here and cache `w_q/k/v * x + bias`
        separately; the two parts are summed in `prepare_qkv_full_and_cache`.

        Registers the buffers ``q_pe``, ``k_pe`` and ``v_pe``.
        """
        pe_list = self.pos_encoder.pe[:, : self.window_size]  # [1, window_size, ch]
        # No bias on purpose: the bias is already part of the cached projection.
        q_pe = F.linear(pe_list, self.to_q.weight)
        k_pe = F.linear(pe_list, self.to_k.weight)
        v_pe = F.linear(pe_list, self.to_v.weight)

        self.register_buffer("q_pe", q_pe)
        self.register_buffer("k_pe", k_pe)
        self.register_buffer("v_pe", v_pe)

    def prepare_qkv_full_and_cache(self, hidden_states, kv_cache, pe_idx, update_idx):
        """Project q/k/v, write the new k/v into the cache, and add positions.

        Args:
            hidden_states: [(N * bhw), F, c], where N is
                `self.denoising_steps_num` (set by `set_cache`).
            kv_cache: [N, 2, hw, L, c]; index 0 of the second dimension is keys,
                index 1 is values. **Modified in place.**
            pe_idx: Per-step positional indices for the cached window; row
                ``pe_idx[n]`` is used as a 1-D index into the pre-projected
                tables (assumed shape [N, L] -- TODO confirm with caller).
            update_idx: Per-step cache slot that receives the new key/value.

        Original author's notes on the expected call patterns:

        * for warmup case: `N` should be 1 and `F` should be warmup_size (`sink_size`)
        * for streaming case: `N` should be `denoising_steps_num` and `F` should be `chunk_size`

        Only the first frame (index 0 along `F`) of each step is written into
        the cache.

        Returns:
            Tuple ``(q, k_full, v_full)``, each flattened back to
            ``[(N * bhw), seq, c]``; `q` keeps the input frame length while
            `k_full` / `v_full` span the cached window.
        """
        num_steps = self.denoising_steps_num

        q_layer = self.to_q(hidden_states)
        k_layer = self.to_k(hidden_states)
        v_layer = self.to_v(hidden_states)

        q_layer = rearrange(q_layer, "(n bhw) f c -> n bhw f c", n=num_steps)
        k_layer = rearrange(k_layer, "(n bhw) f c -> n bhw f c", n=num_steps)
        v_layer = rearrange(v_layer, "(n bhw) f c -> n bhw f c", n=num_steps)

        # onnx & trt friendly indexing
        for idx in range(num_steps):
            kv_cache[idx, 0, :, update_idx[idx]] = k_layer[idx, :, 0]
            kv_cache[idx, 1, :, update_idx[idx]] = v_layer[idx, :, 0]

        k_full = kv_cache[:, 0]
        v_full = kv_cache[:, 1]

        kv_idx = pe_idx
        # The query uses the positional index of the slot that was just written.
        q_idx = torch.stack([kv_idx[idx, update_idx[idx]] for idx in range(num_steps)]).unsqueeze_(
            1
        )  # [n, 1]

        pe_k = torch.cat(
            [self.k_pe.index_select(1, kv_idx[idx]) for idx in range(num_steps)], dim=0
        )  # [n, window_size, c]
        pe_v = torch.cat(
            [self.v_pe.index_select(1, kv_idx[idx]) for idx in range(num_steps)], dim=0
        )  # [n, window_size, c]
        pe_q = torch.cat(
            [self.q_pe.index_select(1, q_idx[idx]) for idx in range(num_steps)], dim=0
        )  # [n, 1, c]

        # unsqueeze(1) inserts the `bhw` axis so the encodings broadcast over it.
        q_layer = q_layer + pe_q.unsqueeze(1)
        k_full = k_full + pe_k.unsqueeze(1)
        v_full = v_full + pe_v.unsqueeze(1)

        q_layer = rearrange(q_layer, "n bhw f c -> (n bhw) f c")
        k_full = rearrange(k_full, "n bhw f c -> (n bhw) f c")
        v_full = rearrange(v_full, "n bhw f c -> (n bhw) f c")

        return q_layer, k_full, v_full

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        video_length=None,
        temporal_attention_mask=None,
        kv_cache=None,
        pe_idx=None,
        update_idx=None,
        *args,
        **kwargs,
    ):
        """Run cached temporal attention over `hidden_states`.

        Args:
            hidden_states: Input in ``[(b f), d, c]`` layout.
            encoder_hidden_states: Accepted but not read by this method.
            attention_mask: Accepted but not read by this method.
            video_length: The frame count `f` used to split the leading axis.
            temporal_attention_mask: attention mask specific for the temporal
                self-attention, indexed as ``[n, window_size]``; ``None``
                disables masking.
            kv_cache, pe_idx, update_idx: Passed straight to
                `prepare_qkv_full_and_cache`; `kv_cache` is updated in place.
            *args, **kwargs: Accepted but not read by this method.

        Returns:
            The attended and output-projected hidden states, rearranged back to
            the ``(b f) d c`` layout.
        """
        d = hidden_states.shape[1]
        # Move frames onto the sequence axis so attention runs across time.
        hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query_layer, key_full, value_full = self.prepare_qkv_full_and_cache(
            hidden_states, kv_cache, pe_idx, update_idx
        )

        # [(n * hw * b), f, c] -> [(n * hw * b * head), f, c // head]
        query_layer = self.reshape_heads_to_batch_dim(query_layer)
        key_full = self.reshape_heads_to_batch_dim(key_full)
        value_full = self.reshape_heads_to_batch_dim(value_full)

        if temporal_attention_mask is not None:
            q_size = query_layer.shape[1]
            # [n, self.window_size] -> [n, hw, q_size, window_size]
            temporal_attention_mask_ = temporal_attention_mask[:, None, None, :].repeat(1, self.h * self.w, q_size, 1)
            temporal_attention_mask_ = rearrange(temporal_attention_mask_, "n hw Q KV -> (n hw) Q KV")
            # One copy of the mask per head, matching the head-expanded batch axis.
            temporal_attention_mask_ = temporal_attention_mask_.repeat_interleave(self.heads, dim=0)
        else:
            temporal_attention_mask_ = None

        # Pick the attention backend: PyTorch 2.0 SDPA if available, then
        # xformers if enabled, otherwise the plain implementation.
        if hasattr(F, "scaled_dot_product_attention"):
            hidden_states = self._memory_efficient_attention_pt20(
                query_layer, key_full, value_full, attention_mask=temporal_attention_mask_
            )

        elif self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(
                query_layer, key_full, value_full, attention_mask=temporal_attention_mask_
            )
            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
            hidden_states = hidden_states.to(query_layer.dtype)
        else:
            hidden_states = self._attention(query_layer, key_full, value_full, temporal_attention_mask_)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)

        # dropout
        hidden_states = self.to_out[1](hidden_states)

        # Restore the original (b f) d c layout.
        hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states