File size: 2,983 Bytes
bc3753a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from torch.nn.functional import *
from torch.nn.functional import (
    _mha_shape_check,
    _canonical_mask,
    _none_or_dtype,
    _in_projection_packed,
)

def multi_head_attention_forward_patched(
    query,
    key,
    value,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight,
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
    cache=None,
) -> Tensor:
    """Self-attention forward pass with an optional per-stage key/value cache.

    Only a subset of the parameters is actually read: ``query``,
    ``num_heads``, ``in_proj_weight``, ``in_proj_bias``, ``out_proj_weight``,
    ``out_proj_bias``, ``attn_mask``, ``is_causal`` and ``cache``.  The
    remaining parameters are accepted but ignored.  In particular ``key`` and
    ``value`` are not used (q, k and v are all projected from ``query``), and
    ``dropout_p`` is always overridden with ``0.0``.

    Args:
        query: 3-D input tensor whose last dimension is the embedding size.
            NOTE(review): the reshapes below fold the second dimension into
            the sequence axis, so this presumably expects a second dimension
            of 1 -- confirm against the caller.
        num_heads: Number of attention heads; the embedding size is split
            into ``embed_dim // num_heads`` features per head.
        in_proj_weight: Packed projection weight producing q, k and v
            (in that order) along the output dimension.
        in_proj_bias: Optional bias for the packed projection.
        out_proj_weight: Weight of the output projection.
        out_proj_bias: Optional bias of the output projection.
        attn_mask: Optional boolean or floating-point mask.  Boolean masks
            are converted to additive ``-inf`` masks.  Two leading singleton
            dimensions are added before the attention call.  ``None`` means
            no mask.
        is_causal: Forwarded unchanged to ``scaled_dot_product_attention``.
        cache: Optional mutable mapping with the keys ``"first_infer"``,
            ``"stage"``, ``"all_stage"``, ``"k"`` and ``"v"``.  ``"k"`` and
            ``"v"`` are indexed by ``cache["stage"]``.  When ``None``, no
            caching is performed and attention runs over the freshly
            projected keys and values only.

    Returns:
        The attention output after the output projection, reshaped to
        ``(-1, 1, out_features)``.

    Side effects:
        When ``cache`` is given, ``cache["k"]``/``cache["v"]`` for the
        current stage are replaced and ``cache["stage"]`` is advanced modulo
        ``cache["all_stage"]``.
    """
    # set up shape vars
    _, _, embed_dim = query.shape
    head_dim = embed_dim // num_heads

    # Normalise the mask once; a second pass would be a no-op because the
    # mask is already floating point afterwards.
    attn_mask = _canonical_mask(
        mask=attn_mask,
        mask_name="attn_mask",
        other_type=None,
        other_name="",
        target_type=query.dtype,
        check_other=False,
    )

    # Packed projection, then move the (q, k, v) axis to the front so the
    # three projections can be picked by index.
    proj_qkv = linear(query, in_proj_weight, in_proj_bias)
    proj_qkv = (
        proj_qkv.unflatten(-1, (3, query.size(-1)))
        .unsqueeze(0)
        .transpose(0, -2)
        .squeeze(-2)
        .contiguous()
    )
    q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]

    if cache is not None:
        stage = cache["stage"]
        if cache["first_infer"] == 1:
            # First pass: remember the keys/values for this stage as they are.
            cache["k"][stage] = k
            cache["v"][stage] = v
        else:
            # The last cached row is dropped before the new rows are appended.
            # NOTE(review): presumably the caller pads the cache with one
            # placeholder row per step -- confirm against the caller.
            k = torch.cat([cache["k"][stage][:-1], k], 0)
            v = torch.cat([cache["v"][stage][:-1], v], 0)
            cache["k"][stage] = k
            cache["v"][stage] = v
        cache["stage"] = (stage + 1) % cache["all_stage"]

    if attn_mask is not None:
        # Add two leading singleton dimensions to line up with the 4-D q/k/v.
        attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)

    # (rows, embed) -> (1, heads, rows, head_dim)
    q = q.view(-1, num_heads, head_dim).transpose(0, 1).unsqueeze(0)
    k = k.view(-1, num_heads, head_dim).transpose(0, 1).unsqueeze(0)
    v = v.view(-1, num_heads, head_dim).transpose(0, 1).unsqueeze(0)

    # Dropout is always disabled here, whatever the caller passed.
    attn_output = scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal
    )

    # (1, heads, rows, head_dim) -> (rows, embed)
    attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
    return attn_output.view(-1, 1, attn_output.size(1))