Update modeling_prismatic.py to account for the case where `input_ids` is `None `
Browse filesInput Ids and Input Embeds are both marked `Optional[torch.LongTensor] = None,` however failing to pass in `input_ids` into the `forward()` method results in an error in the first block, since the code automatically checks if `input_ids.shape[1] == 1` without first checking to see if `input_ids is not None`.
This pull request updates the logic to allow for this case in Generation with Cache and Multimodal Forward.
- modeling_prismatic.py +5 -2
modeling_prismatic.py
CHANGED
@@ -322,7 +322,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
|
|
322 |
# => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])
|
323 |
|
324 |
# === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
|
325 |
-
if input_ids.shape[1] == 1:
|
326 |
assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
|
327 |
assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
|
328 |
assert labels is None, "Unexpected key `labels` provided during cached generation!"
|
@@ -359,7 +359,10 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
|
|
359 |
)
|
360 |
|
361 |
# === Handle Multimodal Forward ===
|
362 |
-
elif (
|
|
|
|
|
|
|
363 |
assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
|
364 |
|
365 |
# Visual Feature Extraction
|
|
|
322 |
# => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])
|
323 |
|
324 |
# === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
|
325 |
+
if input_ids is not None and input_ids.shape[1] == 1:
|
326 |
assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
|
327 |
assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
|
328 |
assert labels is None, "Unexpected key `labels` provided during cached generation!"
|
|
|
359 |
)
|
360 |
|
361 |
# === Handle Multimodal Forward ===
|
362 |
+
elif (
|
363 |
+
(input_ids is not None and input_ids.shape[0] == pixel_values.shape[0]) or
|
364 |
+
(inputs_embeds is not None and inputs_embeds.shape[0] == pixel_values.shape[0])
|
365 |
+
):
|
366 |
assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
|
367 |
|
368 |
# Visual Feature Extraction
|