ybelkada committed on
Commit aecb5f1
1 Parent(s): 95d3a98

Update llama_xformers_attention.py

Files changed (1)
  1. llama_xformers_attention.py +108 -0
llama_xformers_attention.py CHANGED
@@ -0,0 +1,108 @@
+ import math
+ import warnings
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from typing import Optional, Tuple
+
+ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
+
+
+ class LlamaXFormersAttention(LlamaAttention):
+     # Drop-in subclass of LlamaAttention; the forward below follows the standard
+     # eager attention implementation from transformers.models.llama.modeling_llama.
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         if "padding_mask" in kwargs:
+             warnings.warn(
+                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
+             )
+
+         bsz, q_len, _ = hidden_states.size()
+
+         # Project hidden states to queries, keys and values, slicing the projection
+         # weights when pretraining tensor parallelism (pretraining_tp) is enabled.
+         if self.config.pretraining_tp > 1:
+             key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+             query_slices = self.q_proj.weight.split(
+                 (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+             )
+             key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+             value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+             query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+             query_states = torch.cat(query_states, dim=-1)
+
+             key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+             key_states = torch.cat(key_states, dim=-1)
+
+             value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+             value_states = torch.cat(value_states, dim=-1)
+
+         else:
+             query_states = self.q_proj(hidden_states)
+             key_states = self.k_proj(hidden_states)
+             value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+         # Apply rotary position embeddings to queries and keys.
+         kv_seq_len = key_states.shape[-2]
+         if past_key_value is not None:
+             kv_seq_len += past_key_value[0].shape[-2]
+         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+         if past_key_value is not None:
+             # reuse k, v, self_attention
+             key_states = torch.cat([past_key_value[0], key_states], dim=2)
+             value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+         past_key_value = (key_states, value_states) if use_cache else None
+
+         # Expand key/value heads to match the number of query heads (grouped-query attention).
+         key_states = repeat_kv(key_states, self.num_key_value_groups)
+         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+             raise ValueError(
+                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                 f" {attn_weights.size()}"
+             )
+
+         if attention_mask is not None:
+             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                 raise ValueError(
+                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                 )
+             attn_weights = attn_weights + attention_mask
+
+         # upcast attention to fp32
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+         attn_output = torch.matmul(attn_weights, value_states)
+
+         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+             raise ValueError(
+                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                 f" {attn_output.size()}"
+             )
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+
+         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+         if self.config.pretraining_tp > 1:
+             attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+             o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+             attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+         else:
+             attn_output = self.o_proj(attn_output)
+
+         if not output_attentions:
+             attn_weights = None
+
+         return attn_output, attn_weights, past_key_value
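
Usage sketch (not part of the commit): one way to route a Llama checkpoint through this class is to monkey-patch the eager attention forward before instantiating the model. The helper name patch_llama_attention, the import path of this file, and the checkpoint id below are assumptions for illustration only; the exact patch point depends on the installed transformers version.

import torch
from transformers import AutoModelForCausalLM
from transformers.models.llama import modeling_llama

# Assumes this file is importable as `llama_xformers_attention` (e.g. it sits on PYTHONPATH).
from llama_xformers_attention import LlamaXFormersAttention


def patch_llama_attention():
    # Hypothetical helper: replace the eager LlamaAttention.forward with the
    # override defined above so existing attention layers use it.
    modeling_llama.LlamaAttention.forward = LlamaXFormersAttention.forward


patch_llama_attention()
# Example checkpoint id; any Llama-architecture checkpoint would do. On recent
# transformers versions you may also need to force the eager attention path
# (e.g. attn_implementation="eager") so the patched forward is actually called.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16)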