Crystalcareai committed
Commit 6adf5e4 · verified · 1 Parent(s): 2e51e15

Update modeling_gemmoe.py

Files changed (1):
  1. modeling_gemmoe.py +45 -25
modeling_gemmoe.py CHANGED
@@ -55,6 +55,7 @@ if is_flash_attn_2_available():
 
 # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
 # It means that the function will not be traced through and simply appear as a node in the graph.
+
 if is_torch_fx_available():
     if not is_torch_greater_or_equal_than_1_13:
         import torch.fx
@@ -166,42 +167,52 @@ class GemmoeRMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.zeros(dim))
 
     def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+        # Ensure the entire normalization is done in float32
+        x_float = x.float()  # upcast to float32
+        mean = x_float.pow(2).mean(-1, keepdim=True)
+        normed_x = x_float * torch.rsqrt(mean + self.eps)
+        return normed_x
 
     def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * (self.weight + 1)
+        normed_x = self._norm(x)
+        # Downcast the result to the original dtype at the end
+        normed_x = normed_x.type_as(x)
+        return normed_x * (self.weight + 1)
 
 ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)
 
 class GemmoeRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
-
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
-
-    @torch.no_grad()
-    def forward(self, x, position_ids, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        freq_exponents = (2.0 / self.dim) * (
+            torch.arange(self.dim // 2, dtype=torch.int64, device="cpu").float()
+        )
+        timescale = self.base ** freq_exponents
+        positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.int64).float()
+        radians_new = positions[..., None] / timescale[None, None, :]
+        radians_new = radians_new.squeeze(0)
+        emb = torch.cat((radians_new, radians_new), dim=-1)
+        cos = emb.cos().to(device=device, non_blocking=True)
+        sin = emb.sin().to(device=device, non_blocking=True)
+        self.register_buffer("cos_cached", cos, persistent=False)
+        self.register_buffer("sin_cached", sin, persistent=False)
+
+    def forward(self, x, position_ids=None, seq_len=None):
+        if seq_len is None:
+            seq_len = x.size(2)
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+        return (
+            self.cos_cached[:seq_len],
+            self.sin_cached[:seq_len],
+        )
 
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
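The hunk above moves the whole RMS normalization into float32 and only downcasts at the end. As a minimal standalone sketch of that pattern (the function name rms_norm_fp32 and the toy shapes are illustrative, not part of modeling_gemmoe.py):

import torch

def rms_norm_fp32(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Upcast to float32, normalize, downcast, then apply the (1 + weight) scale,
    # mirroring the updated GemmoeRMSNorm._norm / forward split.
    x_float = x.float()
    normed = x_float * torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + eps)
    return normed.type_as(x) * (weight + 1)

x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
weight = torch.zeros(8, dtype=torch.bfloat16)  # zero-initialized, as in the module
print(rms_norm_fp32(x, weight).dtype)  # torch.bfloat16; the normalization itself ran in float32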
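The rotary rework in the same hunk replaces the on-the-fly inv_freq computation with a precomputed cos/sin cache that is sliced per call. A quick standalone check of the cache shapes (dim, base and seq_len here are arbitrary toy values, not read from the Gemmoe config):

import torch

dim, base, max_seq_len = 8, 10000, 16
# Same arithmetic as _set_cos_sin_cache, written without the extra squeeze
freq_exponents = (2.0 / dim) * torch.arange(dim // 2, dtype=torch.int64).float()
timescale = base ** freq_exponents
positions = torch.arange(max_seq_len, dtype=torch.int64).float()
radians = positions[:, None] / timescale[None, :]   # (max_seq_len, dim // 2)
emb = torch.cat((radians, radians), dim=-1)         # (max_seq_len, dim)
cos_cached, sin_cached = emb.cos(), emb.sin()

seq_len = 4
print(cos_cached[:seq_len].shape, sin_cached[:seq_len].shape)  # torch.Size([4, 8]) twice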
@@ -1034,6 +1045,15 @@ class GemmoeModel(GemmoePreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        # Scale embeddings
+        # Fix for precision issue when casting to bfloat16
+        hidden_size_sqrt = math.sqrt(self.config.hidden_size)
+        if inputs_embeds.dtype == torch.bfloat16:
+            # Do the scaling in float32, then cast back, to avoid precision loss
+            hidden_states = (inputs_embeds.float() * hidden_size_sqrt).to(torch.bfloat16)
+        else:
+            hidden_states = inputs_embeds * hidden_size_sqrt
+
         past_seen_tokens = 0
         if use_cache:  # kept for BC (cache positions)
             if not isinstance(past_key_values, StaticCache):
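The embedding-scaling hunk multiplies the embeddings by sqrt(hidden_size) and special-cases bfloat16 because the scale itself loses precision when rounded to bfloat16. A toy illustration of that rounding (the hidden_size value is arbitrary, not the Gemmoe config value):

import math
import torch

hidden_size = 2048
scale = math.sqrt(hidden_size)                    # 45.254833... as a Python float
print(torch.tensor(scale, dtype=torch.bfloat16))  # rounds to roughly 45.25 in bfloat16

inputs_embeds = torch.randn(1, 3, hidden_size, dtype=torch.bfloat16)
# Scale in float32 and cast back, as the hunk above does for bfloat16 inputs
hidden_states = (inputs_embeds.float() * scale).to(torch.bfloat16)
print(hidden_states.dtype)                        # torch.bfloat16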
 