import torch
import torch.nn as nn
import torch.utils.checkpoint
import einops
from einops import rearrange, repeat
from inspect import isfunction
from .rotary import RotaryEmbedding

# Prefer PyTorch's fused scaled_dot_product_attention when it is available
# (PyTorch >= 2.0); otherwise fall back to an explicit softmax ("math") implementation.
if hasattr(nn.functional, 'scaled_dot_product_attention'):
    ATTENTION_MODE = 'flash'
else:
    ATTENTION_MODE = 'math'
print(f'attention mode is {ATTENTION_MODE}')


def add_mask(sim, mask):
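    """Apply a boolean mask to an attention score tensor.

    2-D (L_q, L_k) or 3-D (B, L_q, L_k) masks are broadcast to (B, 1, L_q, L_k);
    4-D masks are used as-is. Positions where the mask is False are filled with
    the most negative finite value of the score dtype.
    """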
    b, ndim = sim.shape[0], mask.ndim
    if ndim == 3:
        mask = rearrange(mask, "b n m -> b 1 n m")
    if ndim == 2:
        mask = repeat(mask, "n m -> b 1 n m", b=b)
    max_neg_value = -torch.finfo(sim.dtype).max
    sim = sim.masked_fill(~mask, max_neg_value)
    return sim


def create_mask(q, k, q_mask=None, k_mask=None):
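    """Build a boolean attention mask of shape (B, 1, L_q, L_k) from optional
    per-token query/key masks of shape (B, L_q) and (B, L_k); a missing mask
    defaults to all-True (no masking)."""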
    def default(val, d):
        return val if val is not None else (d() if isfunction(d) else d)

    b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
    q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
    k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
    attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j')
    return attn_mask


class Attention(nn.Module):
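    """Multi-head attention supporting self- and cross-attention, an optional
    boolean context mask, and optional rotary position embeddings (RoPE)."""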
    def __init__(self, dim, context_dim=None, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., use_rope=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        context_dim = dim if context_dim is None else context_dim

        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
        self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
        self.attn_drop_p = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.use_rope = use_rope
        if self.use_rope:
            self.rotary = RotaryEmbedding(dim=head_dim)

    def forward(self, x, context=None, context_mask=None):
        B, L, C = x.shape
        q = self.to_q(x)
        if context is None:
            context = x
        else:
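            # The rotary embedding is applied to both q and k, so RoPE is only
            # supported for self-attention (no external context).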
            assert self.use_rope is False

        k = self.to_k(context)
        v = self.to_v(context)

        if context_mask is not None:
            mask_binary = create_mask(x, context, None, context_mask)
        else:
            mask_binary = None

        # Split heads: (B, L, H*D) -> (B, H, L, D); projections are cast to float32.
        q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads).float()
        k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads).float()
        v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads).float()

        if self.use_rope:
            q, k = self.rotary(q=q, k=k)

        if ATTENTION_MODE == 'flash':
            # scaled_dot_product_attention has no train/eval switch, so only apply
            # attention dropout during training (matching nn.Dropout in the math branch).
            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop_p if self.training else 0.,
                attn_mask=mask_binary)
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        elif ATTENTION_MODE == 'math':
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, L, C)
        else:
            raise NotImplementedError

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
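

# Minimal usage sketch (illustrative, not part of the original module). Assumes dim
# is divisible by num_heads and, for use_rope=True, that the sibling rotary.py
# provides the RotaryEmbedding imported above. Because of the relative import,
# run this as a module, e.g. `python -m <package>.attention`.
if __name__ == "__main__":
    x = torch.randn(2, 77, 256)

    # Self-attention with rotary position embeddings.
    attn = Attention(dim=256, num_heads=8, use_rope=True)
    print(attn(x).shape)  # torch.Size([2, 77, 256])

    # Cross-attention over a masked context of a different width and dimension.
    cross = Attention(dim=256, context_dim=512, num_heads=8)
    ctx = torch.randn(2, 100, 512)
    ctx_mask = torch.ones(2, 100, dtype=torch.bool)
    print(cross(x, context=ctx, context_mask=ctx_mask).shape)  # torch.Size([2, 77, 256])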