Crystalcareai committed on
Commit
e8a1698
·
verified ·
1 Parent(s): 1b5a82b

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +5 -0
modeling_gemmoe.py CHANGED
@@ -1215,6 +1215,10 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1215
  )
1216
 
1217
  hidden_states = outputs[0]
 
 
 
 
1218
  logits = self.lm_head(hidden_states)
1219
  logits = logits.float()
1220
 
@@ -1332,6 +1336,7 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1332
  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1333
  )
1334
  return reordered_past
 
1335
  @add_start_docstrings(
1336
  """
1337
  The Gemmoe Model transformer with a sequence classification head on top (linear layer).
 
1215
  )
1216
 
1217
  hidden_states = outputs[0]
1218
+
1219
+ # Ensure hidden_states and lm_head have compatible dtypes
1220
+ hidden_states = hidden_states.to(dtype=self.lm_head.weight.dtype)
1221
+
1222
  logits = self.lm_head(hidden_states)
1223
  logits = logits.float()
1224
 
 
1336
  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1337
  )
1338
  return reordered_past
1339
+
1340
  @add_start_docstrings(
1341
  """
1342
  The Gemmoe Model transformer with a sequence classification head on top (linear layer).