Commit 9e9951f (verified) by Crystalcareai
Parent: e8a1698

Update modeling_gemmoe.py

Files changed (1): modeling_gemmoe.py (+4, -14)
modeling_gemmoe.py CHANGED

@@ -1220,14 +1220,6 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
         hidden_states = hidden_states.to(dtype=self.lm_head.weight.dtype)
 
         logits = self.lm_head(hidden_states)
-        logits = logits.float()
-
-        # Handle unused parameters
-        if self.training:
-            for expert in self.model.layers[-1].block_sparse_moe.experts:
-                for param in expert.parameters():
-                    if param.requires_grad and param.grad is None:
-                        param.grad = torch.zeros_like(param)
 
         loss = None
         if labels is not None:
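Note on the block removed in this hunk: a minimal, self-contained sketch of the pattern it implemented, zero-filling the gradient of expert parameters that received no tokens so that no parameter ends the backward pass with grad=None. The toy module and routing below are illustrative only, not code from modeling_gemmoe.py.

import torch
import torch.nn as nn

# Toy stand-in for a block-sparse MoE layer; names here are hypothetical.
class TinyExpert(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)

experts = nn.ModuleList([TinyExpert() for _ in range(4)])
x = torch.randn(2, 8)

# Route every token to expert 0 only, so experts 1-3 receive no gradient.
experts[0](x).sum().backward()

# The removed workaround: give any still-unused parameter a zero gradient
# so that .grad is never None after the backward pass.
for expert in experts:
    for param in expert.parameters():
        if param.requires_grad and param.grad is None:
            param.grad = torch.zeros_like(param)

assert all(p.grad is not None for e in experts for p in e.parameters())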
@@ -1306,8 +1298,8 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
             past_length = 0
         else:
             past_length = cache_position[-1] + 1
-            input_ids = input_ids[:, past_length:]
-            position_ids = position_ids[:, past_length:]
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+            position_ids = position_ids[:, -1].unsqueeze(-1)
 
         cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
 
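For reference, a small sketch with toy tensors (not code from this file) of what the two changed lines do when inputs are prepared with a KV cache: the old slice drops the tokens already covered by past_length, while the new form always keeps exactly the last token.

import torch

input_ids = torch.tensor([[5, 7, 9, 11]])    # full sequence generated so far
past_length = 3                               # tokens already in the KV cache

old_ids = input_ids[:, past_length:]          # -> tensor([[11]])
new_ids = input_ids[:, -1].unsqueeze(-1)      # -> tensor([[11]]), shape (batch, 1)
assert torch.equal(old_ids, new_ids)

# If past_length ever pointed past the end (e.g. a stale cache_position),
# the old slice would be empty, while the new form still yields one token.
print(input_ids[:, 5:].shape)                 # torch.Size([1, 0])
print(input_ids[:, -1].unsqueeze(-1).shape)   # torch.Size([1, 1])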
@@ -1426,10 +1418,8 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
+                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+                sequence_lengths = sequence_lengths.clamp(min=0).to(logits.device)
             else:
                 sequence_lengths = -1
 
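A short comparison of the old and new sequence-length computations in this hunk, on toy right-padded inputs (illustration only): the old version locates the first pad token and wraps with modulo, the new one counts non-pad tokens and clamps at zero. The two agree for right-padded rows without interior pad tokens; they differ on an all-padding row, as shown below.

import torch

pad_token_id = 0
input_ids = torch.tensor([
    [3, 4, 5, 0, 0],   # 3 real tokens -> last real index 2
    [7, 8, 9, 1, 0],   # 4 real tokens -> last real index 3
    [0, 0, 0, 0, 0],   # all padding
])

# Old form: index of the first pad token minus one, wrapped with modulo.
old = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
old = old % input_ids.shape[-1]

# New form: number of non-pad tokens minus one, clamped at zero.
new = torch.ne(input_ids, pad_token_id).sum(-1) - 1
new = new.clamp(min=0)

print(old)  # tensor([2, 3, 4])  (all-pad row wraps to the last index)
print(new)  # tensor([2, 3, 0])  (all-pad row clamps to index 0)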
 
 
1425