ybelkada committed
Commit 4f95d51
Parent: a5c24a9

Update llama_xformers_attention.py

Files changed (1):
  llama_xformers_attention.py  +4 -27
llama_xformers_attention.py CHANGED
@@ -23,27 +23,9 @@ class LlamaXFormersAttention(LlamaAttention):
 
         bsz, q_len, _ = hidden_states.size()
 
-        if self.config.pretraining_tp > 1:
-            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
-            query_slices = self.q_proj.weight.split(
-                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-            )
-            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
-            query_states = torch.cat(query_states, dim=-1)
-
-            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
-            key_states = torch.cat(key_states, dim=-1)
-
-            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
-            value_states = torch.cat(value_states, dim=-1)
-
-        else:
-            query_states = self.q_proj(hidden_states)
-            key_states = self.k_proj(hidden_states)
-            value_states = self.v_proj(hidden_states)
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
 
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -95,12 +77,7 @@ class LlamaXFormersAttention(LlamaAttention):
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
-        if self.config.pretraining_tp > 1:
-            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
-            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
-        else:
-            attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
             attn_weights = None
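What the commit does: the `pretraining_tp > 1` branches are removed, so the query/key/value projections (and, in the second hunk, the output projection) always go through the plain linear modules. The deleted branch computed the same result by slicing the projection weights into `pretraining_tp` row-shards and concatenating the partial outputs; Transformers ships this path to reproduce the numerics of Llama's tensor-parallel pretraining layout. Below is a minimal, self-contained sketch of that equivalence for the q/k/v case; the toy sizes and the weight name `w_q` are made up here, nothing in it comes from the commit.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, q_len, hidden_size, tp = 2, 5, 16, 4           # hypothetical toy sizes
hidden_states = torch.randn(bsz, q_len, hidden_size)
w_q = torch.randn(hidden_size, hidden_size)         # stand-in for self.q_proj.weight

# Kept path: one matmul against the full weight.
plain = F.linear(hidden_states, w_q)

# Removed path: split the weight row-wise (dim=0, output features) into
# tp shards, project against each shard, concatenate along the feature dim.
shards = w_q.split(hidden_size // tp, dim=0)
sliced = torch.cat([F.linear(hidden_states, s) for s in shards], dim=-1)

print(torch.allclose(plain, sliced, atol=1e-6))     # True: same projection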
 
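The removed `o_proj` branch slices along the opposite axis: the activations are split on their feature dimension and the weight column-wise (dim=1, input features), and the partial projections are summed rather than concatenated. A companion sketch under the same hypothetical setup, again not taken from the commit:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, q_len, hidden_size, tp = 2, 5, 16, 4           # same hypothetical sizes
attn_output = torch.randn(bsz, q_len, hidden_size)
w_o = torch.randn(hidden_size, hidden_size)         # stand-in for self.o_proj.weight

# Kept path: one matmul.
plain = F.linear(attn_output, w_o)

# Removed path: matching chunks of the input and column-shards of the weight
# each yield a partial projection; the partials sum to the full result
# because a matmul is linear in its input.
chunks = attn_output.split(hidden_size // tp, dim=2)
cols = w_o.split(hidden_size // tp, dim=1)
sliced = sum(F.linear(chunks[i], cols[i]) for i in range(tp))

print(torch.allclose(plain, sliced, atol=1e-5))     # True up to fp summation order

Column-parallel q/k/v plus row-parallel o_proj is the standard Megatron-style tensor-parallel split; on a single device the slicing buys nothing and only adds Python-loop overhead, which is presumably why this xformers port drops it.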