katuni4ka committed (verified)
Commit 1952dda · Parent: 0d670f1

Update modeling_baichuan.py

Files changed (1): modeling_baichuan.py (+16 -14)
modeling_baichuan.py CHANGED
@@ -59,7 +59,7 @@ def _make_causal_mask(
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), float("-inf"), device=device)
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
     mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
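Note (not part of the commit): a minimal sketch of why the fill value changes from float("-inf") to torch.finfo(dtype).min. A finite fill value survives the later cast to low-precision dtypes and avoids inf arithmetic when masks are added to attention scores; the helper name and shapes below are illustrative only.

import torch

def make_causal_mask_sketch(tgt_len, dtype, device="cpu"):
    # Fill with the most negative finite value of `dtype` instead of -inf.
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(tgt_len, device=device)
    # Zero out positions each query is allowed to attend to (lower triangle).
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)
    return mask.to(dtype)

# The float16 mask stays finite, so softmax over the masked row is well defined.
m = make_causal_mask_sketch(4, torch.float16)
print(torch.softmax(m[0].float(), dim=-1))  # tensor([1., 0., 0., 0.])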
@@ -109,15 +109,14 @@ class RMSNorm(nn.Module):
 class RotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
-        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
         self.max_seq_len_cached = max_position_embeddings
         t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
         freqs = torch.outer(t, self.inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(self.inv_freq.device)[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(self.inv_freq.device)[None, None, :, :], persistent=False)
-
-    def forward(self, x, seq_len):
+        self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
+        self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
+    def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
         if seq_len > self.max_seq_len_cached:
@@ -125,11 +124,14 @@ class RotaryEmbedding(torch.nn.Module):
             t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
             freqs = torch.outer(t, self.inv_freq)
             emb = torch.cat((freqs, freqs), dim=-1)
-            self.register_buffer("cos_cached", emb.cos().to(self.inv_freq.device)[None, None, :, :], persistent=False)
-            self.register_buffer("sin_cached", emb.sin().to(self.inv_freq.device)[None, None, :, :], persistent=False)
+            self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
+            self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
+        elif self.cos_cached.device != x.device:
+            self.cos_cached = self.cos_cached.to(x.device)
+            self.sin_cached = self.sin_cached.to(x.device)
         return (
-            self.cos_cached[:, :, :seq_len, :].to(x.device),
-            self.sin_cached[:, :, :seq_len, :].to(x.device),
+            self.cos_cached[:, :, :seq_len, ...],
+            self.sin_cached[:, :, :seq_len, ...],
         )

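Note (not part of the commit): with this change the cos/sin tables are plain float32 tensors instead of registered buffers; they are grown when a longer sequence shows up and follow the input's device lazily. A standalone sketch of the same caching pattern (class and method names are illustrative):

import torch

class RotaryCacheSketch(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self._build(max_position_embeddings)

    def _build(self, seq_len, device=None):
        # Recompute the cos/sin tables up to `seq_len` positions.
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
        self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
        if device is not None:
            self.cos_cached = self.cos_cached.to(device)
            self.sin_cached = self.sin_cached.to(device)

    def forward(self, x, seq_len=None):
        if seq_len > self.max_seq_len_cached:
            # Grow the cache when a longer sequence is seen.
            self._build(seq_len, device=x.device)
        elif self.cos_cached.device != x.device:
            # Plain tensors are not moved by Module.to(), so follow the input lazily.
            self.cos_cached = self.cos_cached.to(x.device)
            self.sin_cached = self.sin_cached.to(x.device)
        return self.cos_cached[:, :, :seq_len, ...], self.sin_cached[:, :, :seq_len, ...]

cos, sin = RotaryCacheSketch(dim=64)(torch.randn(1, 8, 16, 64), seq_len=16)
print(cos.shape)  # torch.Size([1, 1, 16, 64])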
@@ -208,7 +210,7 @@ class Attention(nn.Module):

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len = key_states.shape[-2] + past_key_value[0].shape[-2]
+            kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
         # [bsz, nh, t, hd]
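Note (not part of the commit): with a KV cache the rotary tables must cover the cached keys plus the newly projected ones, so the length is accumulated instead of recomputed from the new keys alone. A toy illustration with hypothetical shapes:

import torch

past_key = torch.zeros(1, 8, 3, 64)    # [bsz, num_heads, cached_len, head_dim]
key_states = torch.zeros(1, 8, 1, 64)  # [bsz, num_heads, new_len, head_dim]

kv_seq_len = key_states.shape[-2]
if past_key is not None:
    kv_seq_len += past_key.shape[-2]
print(kv_seq_len)  # 4 -> cos/sin are sliced to cover all four positions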
@@ -228,8 +230,8 @@ class Attention(nn.Module):
                 query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
             )
         else:
-            #with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
-            attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
+            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
+                attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
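Note (not part of the commit): torch.backends.cuda.sdp_kernel is a PyTorch context manager that restricts which scaled_dot_product_attention backends may be chosen (flash, math, memory-efficient); it does not change the call itself and has no effect on CPU. A minimal usage sketch with hypothetical sizes:

import torch
import torch.nn.functional as F

bsz, n_heads, q_len, head_dim = 1, 8, 16, 64  # hypothetical sizes
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, q_len, head_dim)
v = torch.randn(bsz, n_heads, q_len, head_dim)
attention_mask = torch.zeros(bsz, 1, q_len, q_len)  # additive mask; 0 = attend

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
    attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)

print(attn_output.shape)  # torch.Size([1, 8, 16, 64])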
@@ -701,4 +703,4 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         else:
             outputs = self.generate(input_ids, generation_config=generation_config)
         response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
-        return response
+        return response
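Note (not part of the commit): outputs[0] returned by generate() contains the prompt tokens followed by the newly generated ones, so slicing at len(input_ids[0]) decodes only the model's reply. A toy illustration with hypothetical token ids:

import torch

input_ids = torch.tensor([[101, 2023, 2003]])              # hypothetical prompt ids
outputs = torch.tensor([[101, 2023, 2003, 7592, 2088]])    # prompt + generated ids
new_tokens = outputs[0][len(input_ids[0]):]
print(new_tokens.tolist())  # [7592, 2088] -> what tokenizer.decode(...) receives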
 