Update modeling_baichuan.py
modeling_baichuan.py (+16 -14)
@@ -59,7 +59,7 @@ def _make_causal_mask(
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len),
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
     mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
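The new line fills the mask with the most negative finite value for the target dtype, so disallowed positions become a large negative bias that softmax maps to roughly zero. Below is a self-contained sketch of the same construction; it is a standalone helper with a simplified signature, not the file's exact _make_causal_mask:

import torch

def make_causal_mask(tgt_len, dtype=torch.float32, device="cpu"):
    # Start from a square matrix filled with the most negative finite value for `dtype`.
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    # Zero out the lower triangle (including the diagonal) so each position may
    # attend to itself and to earlier positions only.
    mask_cond = torch.arange(tgt_len, device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)
    return mask.to(dtype)

print(make_causal_mask(4))  # row i is 0 up to column i and a very large negative value afterwards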
@@ -109,15 +109,14 @@ class RMSNorm(nn.Module):
 class RotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
-        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
         self.max_seq_len_cached = max_position_embeddings
         t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
         freqs = torch.outer(t, self.inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.
-        self.
-
-    def forward(self, x, seq_len):
+        self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
+        self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
+    def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
         if seq_len > self.max_seq_len_cached:
@@ -125,11 +124,14 @@ class RotaryEmbedding(torch.nn.Module):
             t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
             freqs = torch.outer(t, self.inv_freq)
             emb = torch.cat((freqs, freqs), dim=-1)
-            self.
-            self.
+            self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
+            self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
+        elif self.cos_cached.device != x.device:
+            self.cos_cached = self.cos_cached.to(x.device)
+            self.sin_cached = self.sin_cached.to(x.device)
         return (
-            self.cos_cached[:, :, :seq_len,
-            self.sin_cached[:, :, :seq_len,
+            self.cos_cached[:, :, :seq_len, ...],
+            self.sin_cached[:, :, :seq_len, ...],
         )
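The two hunks above build the cos/sin caches in float32 in __init__, regrow them when seq_len exceeds the cached length, and move them to the input's device via the new elif branch. The sketch below shows how such cached tables are typically consumed, using the rotate-half formulation common to LLaMA-style models; apply_rotary here is a simplified stand-in for the file's apply_rotary_pos_emb, not its exact code:

import torch

def rotate_half(x):
    # Swap the two halves of the last dimension and negate the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin, position_ids):
    # cos/sin arrive as [1, 1, seq_len, head_dim]; pick the rows for the requested positions.
    cos = cos.squeeze(0).squeeze(0)[position_ids].unsqueeze(1)  # [bs, 1, q_len, head_dim]
    sin = sin.squeeze(0).squeeze(0)[position_ids].unsqueeze(1)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

# Build the caches the same way __init__ does: float32, shape [1, 1, seq, dim].
dim, seq = 8, 5
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(seq, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
sin_cached = emb.sin()[None, None, :, :].to(torch.float32)

q = torch.randn(1, 2, seq, dim)   # [bs, num_heads, seq_len, head_dim]
k = torch.randn(1, 2, seq, dim)
position_ids = torch.arange(seq)[None, :]
q_rot, k_rot = apply_rotary(q, k, cos_cached[:, :, :seq, ...], sin_cached[:, :, :seq, ...], position_ids)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 5, 8]) for both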
@@ -208,7 +210,7 @@ class Attention(nn.Module):
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len
+            kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
         # [bsz, nh, t, hd]
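With a KV cache, the rotary cos/sin lookup has to cover the cached keys plus the newly projected tokens, which is what kv_seq_len += past_key_value[0].shape[-2] accounts for. A toy illustration with made-up shapes:

import torch

bsz, num_heads, head_dim = 1, 2, 8
past_len, new_len = 10, 1  # 10 cached positions, 1 freshly generated token

key_states = torch.randn(bsz, num_heads, new_len, head_dim)
past_key_value = (
    torch.randn(bsz, num_heads, past_len, head_dim),  # cached keys
    torch.randn(bsz, num_heads, past_len, head_dim),  # cached values
)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
    kv_seq_len += past_key_value[0].shape[-2]
print(kv_seq_len)  # 11: the rotary caches must be sliced up to the full cached length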
@@ -228,8 +230,8 @@ class Attention(nn.Module):
                 query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
             )
         else:
-
-
+            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
+                attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
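When xformers is not available, the new branch falls back to PyTorch's fused F.scaled_dot_product_attention, wrapped in torch.backends.cuda.sdp_kernel(...) so that flash, memory-efficient, and math backends are all allowed (that context manager exists in PyTorch 2.x; newer releases prefer torch.nn.attention.sdpa_kernel). A standalone sketch of the fallback call with an additive mask, independent of the model's tensors:

import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 1, 2, 5, 8
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, q_len, head_dim)
v = torch.randn(bsz, num_heads, q_len, head_dim)

# Additive mask: 0 where attention is allowed, -inf above the diagonal (future positions).
attention_mask = torch.zeros(bsz, 1, q_len, q_len)
attention_mask = attention_mask.masked_fill(
    torch.triu(torch.ones(q_len, q_len, dtype=torch.bool), diagonal=1), float("-inf")
)

# Fused attention; PyTorch selects a flash / memory-efficient / math kernel as available.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
print(out.shape)  # torch.Size([1, 2, 5, 8])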
@@ -701,4 +703,4 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         else:
             outputs = self.generate(input_ids, generation_config=generation_config)
         response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
-        return response
+        return response
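The last hunk touches the tail of chat(), where only the tokens produced after the prompt are decoded into the reply. An illustrative sketch of that decode pattern from outside the model class; the checkpoint path, prompt, and max_new_tokens below are placeholders, not values taken from this repository:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "path/to/baichuan-checkpoint"  # placeholder: substitute the actual repository id
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
outputs = model.generate(input_ids, max_new_tokens=32)
# outputs[0] holds the prompt followed by the continuation, so slicing off the
# prompt length leaves only the newly generated reply.
response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
print(response)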