Upload FlaxTransformerLMForCausalLM
- flax_model.msgpack +1 -1
- modeling_transformerlm_flax.py +3 -0
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:f43dc830c806b64d6a77027a61d16bd2fcbe896c799d5dbba0a81b9e7f26fc8b
 size 524522413
modeling_transformerlm_flax.py CHANGED
@@ -404,6 +404,9 @@ class FlaxTransformerLMPreTrainedModel(FlaxPreTrainedModel):
         last_logits, last_cache = last
         lm_logits = jnp.reshape(all_logits, (1, seq_length, vcab_size))
 
+        if input_ids.shape[1] > 1:
+            lm_logits = lm_logits[:, 1:, :]  # Ignore leading zeros in prompts
+
         if not return_dict:
             outputs = (lm_logits,) + (last_cache,)
         else:
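A minimal sketch of what the three added lines do, assuming lm_logits has already been reshaped to (1, seq_length, vocab_size) as in the surrounding code; the shapes and values below are illustrative, not taken from the model:

import jax.numpy as jnp

# Illustrative shapes only; in the model, all_logits is reshaped to
# (1, seq_length, vocab_size) just before this check.
seq_length, vocab_size = 5, 32000
input_ids = jnp.ones((1, seq_length), dtype=jnp.int32)
lm_logits = jnp.zeros((1, seq_length, vocab_size))

# Mirror of the added diff lines: for prompts longer than one token,
# drop the logits at the first position.
if input_ids.shape[1] > 1:
    lm_logits = lm_logits[:, 1:, :]  # shape becomes (1, seq_length - 1, vocab_size)

print(lm_logits.shape)  # (1, 4, 32000)

For a single-token input (for example during incremental decoding with a cache), the check is False and the logits keep their full length, so only multi-token prompt passes are affected.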