fukugawa committed on
Commit
fe82d0f
·
verified ·
1 Parent(s): e06470c

Upload FlaxTransformerLMForCausalLM

Browse files
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f146d3ff30cefcdecb23e15f175395be65b13de37821c4d78b32feb8415f3666
3
  size 524522413
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f43dc830c806b64d6a77027a61d16bd2fcbe896c799d5dbba0a81b9e7f26fc8b
3
  size 524522413
modeling_transformerlm_flax.py CHANGED
@@ -404,6 +404,9 @@ class FlaxTransformerLMPreTrainedModel(FlaxPreTrainedModel):
404
  last_logits, last_cache = last
405
  lm_logits = jnp.reshape(all_logits, (1, seq_length, vcab_size))
406
 
 
 
 
407
  if not return_dict:
408
  outputs = (lm_logits,) + (last_cache,)
409
  else:
 
404
  last_logits, last_cache = last
405
  lm_logits = jnp.reshape(all_logits, (1, seq_length, vcab_size))
406
 
407
+ if input_ids.shape[1] > 1:
408
+ lm_logits = lm_logits[:, 1:, :] # Ignore leading zeros in prompts
409
+
410
  if not return_dict:
411
  outputs = (lm_logits,) + (last_cache,)
412
  else: