eliotj committed on
Commit
258bd1d
1 Parent(s): 31f090d

Update modeling_prismatic.py to account for the case where `input_ids` is `None`

Browse files

`input_ids` and `inputs_embeds` are both declared as `Optional[torch.LongTensor] = None`. However, calling the `forward()` method without `input_ids` raises an error in the first branch, because the code evaluates `input_ids.shape[1] == 1` without first checking that `input_ids is not None`.

This pull request updates the logic to handle this case in both the Generation with Cache and the Multimodal Forward branches.

Files changed (1) hide show
  1. modeling_prismatic.py +5 -2
modeling_prismatic.py CHANGED
@@ -322,7 +322,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
322
  # => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])
323
 
324
  # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
325
- if input_ids.shape[1] == 1:
326
  assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
327
  assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
328
  assert labels is None, "Unexpected key `labels` provided during cached generation!"
@@ -359,7 +359,10 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
359
  )
360
 
361
  # === Handle Multimodal Forward ===
362
- elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
 
 
 
363
  assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
364
 
365
  # Visual Feature Extraction
 
322
  # => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])
323
 
324
  # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
325
+ if input_ids is not None and input_ids.shape[1] == 1:
326
  assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
327
  assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
328
  assert labels is None, "Unexpected key `labels` provided during cached generation!"
 
359
  )
360
 
361
  # === Handle Multimodal Forward ===
362
+ elif (
363
+ (input_ids is not None and input_ids.shape[0] == pixel_values.shape[0]) or
364
+ (inputs_embeds is not None and inputs_embeds.shape[0] == pixel_values.shape[0])
365
+ ):
366
  assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
367
 
368
  # Visual Feature Extraction