import math
from typing import Any

from einops import rearrange
import torch
from diffusers.models.attention_processor import Attention

# flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135

EPSILON = 1e-6
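

# FlashAttentionFunction computes softmax(q @ k^T * scale) @ v in tiles of
# (q_bucket_size x k_bucket_size) query/key blocks, so the full attention
# matrix is never materialized.  For each query row it keeps two running
# statistics across key blocks, the row maximum m and the row sum of
# exponentials l, and folds each new block in with the online-softmax update
# from the paper linked above:
#     m_new = max(m, m_block)
#     l_new = exp(m - m_new) * l + exp(m_block - m_new) * l_block
#     o_new = (l / l_new) * exp(m - m_new) * o
#             + (exp(m_block - m_new) / l_new) * (exp_weights @ v_block)
# where exp_weights = exp(scores - m_block) for the current block.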
class FlashAttentionFunction(torch.autograd.function.Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper"""
        device = q.device
        dtype = q.dtype
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full(
            (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
        )

        scale = q.shape[-1] ** -0.5

        if mask is None:
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, "b n -> b 1 1 n")
            mask = mask.split(q_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = (
                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
                )

                if row_mask is not None:
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if row_mask is not None:
                    exp_weights.masked_fill_(~row_mask, 0.0)

                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
                    min=EPSILON
                )

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = torch.einsum(
                    "... i j, ... j d -> ... i d", exp_weights, vc
                )
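
                # Fold this block into the running per-row statistics: rescale
                # the previously accumulated output and the current block so
                # both refer to the updated row maximum, then combine them
                # (the online-softmax update sketched above the class).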
                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = (
                    exp_row_max_diff * row_sums
                    + exp_block_row_max_diff * block_row_sums
                )

                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
                    (exp_block_row_max_diff / new_row_sums) * exp_values
                )

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o
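
    # The backward pass below uses the same tiling: it recomputes each
    # attention block from q, k and the saved row sums/maxes instead of
    # storing the attention matrix, and applies the softmax gradient as
    # ds = p * scale * (dp - D) with D = rowsum(do * o), per Algorithm 4.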
    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 in the paper"""
        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            l.split(q_bucket_size, dim=-2),
            m.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = (
                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
                )

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                exp_attn_weights = torch.exp(attn_weights - mc)

                if row_mask is not None:
                    exp_attn_weights.masked_fill_(~row_mask, 0.0)

                p = exp_attn_weights / lc

                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)

                D = (doc * oc).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - D)

                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None
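

# FlashAttnProcessor is an attention processor in the diffusers sense: once
# assigned to an `Attention` module it is called in place of the default
# processor and routes the computation through FlashAttentionFunction above.
# If the module carries a `hypernetwork` attribute (not a standard diffusers
# attribute), the key/value context is taken from that hook instead of the
# encoder hidden states.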
class FlashAttnProcessor:
    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ) -> Any:
        q_bucket_size = 512
        k_bucket_size = 1024

        h = attn.heads
        q = attn.to_q(hidden_states)

        encoder_hidden_states = (
            encoder_hidden_states
            if encoder_hidden_states is not None
            else hidden_states
        )
        encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)

        if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
            context_k, context_v = attn.hypernetwork.forward(
                hidden_states, encoder_hidden_states
            )
            context_k = context_k.to(hidden_states.dtype)
            context_v = context_v.to(hidden_states.dtype)
        else:
            context_k = encoder_hidden_states
            context_v = encoder_hidden_states

        k = attn.to_k(context_k)
        v = attn.to_v(context_v)
        del encoder_hidden_states, hidden_states

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        out = FlashAttentionFunction.apply(
            q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
        )

        out = rearrange(out, "b h n d -> b n (h d)")

        out = attn.to_out[0](out)
        out = attn.to_out[1](out)
        return out
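

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).  The
# standard diffusers way to swap in a custom attention processor is
# `set_attn_processor`; the model id below is only an example:
#
#     from diffusers import UNet2DConditionModel
#
#     unet = UNet2DConditionModel.from_pretrained(
#         "runwayml/stable-diffusion-v1-5", subfolder="unet"
#     )
#     unet.set_attn_processor(FlashAttnProcessor())
#
# A single `Attention` module can also be targeted directly with
# `attn.set_processor(FlashAttnProcessor())`.
# ---------------------------------------------------------------------------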