Update any_model.py
any_model.py  +2 -12
@@ -667,11 +667,7 @@ class AnyModelForConditionalGeneration(AnyModelPreTrainedModel):
         # 2. Merge text and images
         if pixel_values_1 is not None and pixel_values_1 is not None and input_ids.shape[1] != 1:
             assert modality is not None, "modality must be provided when pixel_values is not None"
-
-            if isinstance(modality, list):
-                assert len(set(modality)) == 1, "only one kind modality can be provided in a batch"
-                modality = modality[0]
-            '''
+
             for i in range(2):
                 pixel_values = pixel_values_1 if i == 0 else pixel_values_2
                 if modality[0][i] == ModalityType.IMAGE:
@@ -698,12 +694,6 @@ class AnyModelForConditionalGeneration(AnyModelPreTrainedModel):
                 raise ValueError(f"modality {modality[i]} is not supported")

             inputs_embeds = inputs_embeds.to(features.dtype)
-            '''
-            print('+++'*10)
-            print(input_ids)
-            print(torch.sum(input_ids == self.config.audio_token_index, dim=-1))
-            print('+++'*10)
-            '''
             inputs_embeds, attention_mask, labels, position_ids = self.merge_input_ids_with_other_features(
                 features, inputs_embeds, input_ids, attention_mask, labels
             )
@@ -832,7 +822,7 @@ class AnyRewardModel(AnyModelForConditionalGeneration):
         attention_mask: torch.Tensor | None = None,
         **kwargs,
     ) -> torch.Tensor:
-        outputs =
+        outputs = super().forward(
             input_ids,
             attention_mask=attention_mask,
             output_hidden_states=True,