Crystalcareai
committed on
Update modeling_gemmoe.py
modeling_gemmoe.py +5 -0
modeling_gemmoe.py
CHANGED
@@ -1215,6 +1215,10 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
         )

         hidden_states = outputs[0]
+
+        # Ensure hidden_states and lm_head have compatible dtypes
+        hidden_states = hidden_states.to(dtype=self.lm_head.weight.dtype)
+
         logits = self.lm_head(hidden_states)
         logits = logits.float()

@@ -1332,6 +1336,7 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
         )
         return reordered_past
+
 @add_start_docstrings(
     """
     The Gemmoe Model transformer with a sequence classification head on top (linear layer).
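The added cast guards against a dtype mismatch between the backbone's activations and the lm_head weights (for example a float16 backbone paired with a float32 head), which would otherwise make the linear layer fail at matmul time. A minimal sketch of the failure mode and the fix, using illustrative shapes rather than the actual Gemmoe modules:

import torch
import torch.nn as nn

# Illustrative shapes only; not the real Gemmoe config.
hidden_states = torch.randn(1, 4, 8, dtype=torch.float16)  # e.g. output of a fp16 backbone
lm_head = nn.Linear(8, 16, bias=False)                     # weights default to float32

# Calling lm_head(hidden_states) here would raise a dtype-mismatch
# RuntimeError, since the input and weight dtypes differ.

# The commit's fix: align the activations with the head's weight dtype.
hidden_states = hidden_states.to(dtype=lm_head.weight.dtype)
logits = lm_head(hidden_states)
logits = logits.float()  # same upcast as in the surrounding code
print(logits.dtype)      # torch.float32

Casting the activations to the head's dtype, rather than converting the head itself, leaves the weights untouched, and the existing logits.float() upcast still runs afterward so downstream loss computation stays in float32.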