Update modeling_prismatic.py to account for the case where `input_ids` is `None`

#5
by eliotj - opened
Files changed (1) hide show
  1. modeling_prismatic.py +5 -2
modeling_prismatic.py CHANGED
@@ -322,7 +322,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
322
  # => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])
323
 
324
  # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
325
- if input_ids.shape[1] == 1:
326
  assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
327
  assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
328
  assert labels is None, "Unexpected key `labels` provided during cached generation!"
@@ -359,7 +359,10 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
359
  )
360
 
361
  # === Handle Multimodal Forward ===
362
- elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
 
 
 
363
  assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
364
 
365
  # Visual Feature Extraction
 
322
  # => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])
323
 
324
  # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
325
+ if input_ids is not None and input_ids.shape[1] == 1:
326
  assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
327
  assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
328
  assert labels is None, "Unexpected key `labels` provided during cached generation!"
 
359
  )
360
 
361
  # === Handle Multimodal Forward ===
362
+ elif (
363
+ (input_ids is not None and input_ids.shape[0] == pixel_values.shape[0]) or
364
+ (inputs_embeds is not None and inputs_embeds.shape[0] == pixel_values.shape[0])
365
+ ):
366
  assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
367
 
368
  # Visual Feature Extraction