Upload FlaxTransformerLMForCausalLM
Browse files
modeling_transformerlm_flax.py
CHANGED
@@ -426,6 +426,7 @@ class FlaxTransformerLMPreTrainedModel(FlaxPreTrainedModel):
|
|
426 |
mutable=mutable,
|
427 |
)
|
428 |
lm_logits = output.logits
|
|
|
429 |
if input_ids.shape[1] > 1:
|
430 |
lm_logits = lm_logits[:, 1:, :] # Ignore leading zeros in prompts
|
431 |
|
|
|
426 |
mutable=mutable,
|
427 |
)
|
428 |
lm_logits = output.logits
|
429 |
+
|
430 |
if input_ids.shape[1] > 1:
|
431 |
lm_logits = lm_logits[:, 1:, :] # Ignore leading zeros in prompts
|
432 |
|