Crystalcareai committed: Update modeling_gemmoe.py

modeling_gemmoe.py CHANGED (+4, -14)
@@ -1220,14 +1220,6 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
         hidden_states = hidden_states.to(dtype=self.lm_head.weight.dtype)
 
         logits = self.lm_head(hidden_states)
-        logits = logits.float()
-
-        # Handle unused parameters
-        if self.training:
-            for expert in self.model.layers[-1].block_sparse_moe.experts:
-                for param in expert.parameters():
-                    if param.requires_grad and param.grad is None:
-                        param.grad = torch.zeros_like(param)
 
         loss = None
         if labels is not None:
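This hunk removes the float32 upcast of the logits together with a training-time workaround that stubbed in zero gradients for experts the router never selected, a common way to satisfy DistributedDataParallel's unused-parameter check in MoE models. A minimal sketch of what the removed workaround did, with a hypothetical experts list standing in for model.layers[-1].block_sparse_moe.experts:

import torch
import torch.nn as nn

# Hypothetical stand-in for model.layers[-1].block_sparse_moe.experts.
experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])

def zero_fill_unused_expert_grads(experts: nn.ModuleList) -> None:
    # Any expert parameter still lacking a gradient after backward (the
    # router dispatched no tokens to it) gets an all-zero gradient, so a
    # distributed reducer sees every parameter as "used".
    for expert in experts:
        for param in expert.parameters():
            if param.requires_grad and param.grad is None:
                param.grad = torch.zeros_like(param)

# Backward pass that touches only expert 0:
experts[0](torch.randn(2, 8)).sum().backward()
zero_fill_unused_expert_grads(experts)
assert all(p.grad is not None for e in experts for p in e.parameters())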
@@ -1306,8 +1298,8 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
                 past_length = 0
             else:
                 past_length = cache_position[-1] + 1
-            input_ids = input_ids[:, past_length:]
-            position_ids = position_ids[:, past_length:]
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+            position_ids = position_ids[:, -1].unsqueeze(-1)
 
             cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
 
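In the static-cache branch of prepare_inputs_for_generation, the slice from past_length onward is replaced by an unconditional last-token slice, so each decode step now feeds a (batch, 1) tensor. A toy shape check (values made up):

import torch

input_ids = torch.tensor([[5, 9, 2, 7],
                          [3, 1, 4, 8]])  # (batch=2, seq=4)

# New behaviour: keep only the most recent token for the next step.
last = input_ids[:, -1].unsqueeze(-1)
print(last.shape)  # torch.Size([2, 1])
print(last)        # tensor([[7], [8]])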
@@ -1426,10 +1418,8 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
+                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+                sequence_lengths = sequence_lengths.clamp(min=0).to(logits.device)
             else:
                 sequence_lengths = -1
 
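For sequence classification, the index of the last real token is now computed by counting non-pad tokens (clamped at zero) rather than by locating the first pad token and wrapping with a modulo. The two agree on right-padded batches, as the toy comparison below shows (pad_token_id and the inputs are made up); they can differ if pad tokens occur mid-sequence.

import torch

pad_token_id = 0
input_ids = torch.tensor([
    [11, 12, 13, 0, 0],    # 3 real tokens, right-padded
    [21, 22, 23, 24, 25],  # no padding
])

# New approach: index of the last non-pad token, assuming right padding.
seq_len_new = torch.ne(input_ids, pad_token_id).sum(-1) - 1
seq_len_new = seq_len_new.clamp(min=0)
print(seq_len_new)  # tensor([2, 4])

# Old approach: position before the first pad token; the modulo maps an
# all-real row (argmax == 0, so index -1) to the last position.
seq_len_old = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
seq_len_old = seq_len_old % input_ids.shape[-1]
print(seq_len_old)  # tensor([2, 4])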