XuyaoWang commited on
Commit
2010bc5
1 Parent(s): 500568f

Update any_model.py

Browse files
Files changed (1) hide show
  1. any_model.py +2 -12
any_model.py CHANGED
@@ -667,11 +667,7 @@ class AnyModelForConditionalGeneration(AnyModelPreTrainedModel):
667
  # 2. Merge text and images
668
  if pixel_values_1 is not None and pixel_values_1 is not None and input_ids.shape[1] != 1:
669
  assert modality is not None, "modality must be provided when pixel_values is not None"
670
- '''
671
- if isinstance(modality, list):
672
- assert len(set(modality)) == 1, "only one kind modality can be provided in a batch"
673
- modality = modality[0]
674
- '''
675
  for i in range(2):
676
  pixel_values = pixel_values_1 if i == 0 else pixel_values_2
677
  if modality[0][i] == ModalityType.IMAGE:
@@ -698,12 +694,6 @@ class AnyModelForConditionalGeneration(AnyModelPreTrainedModel):
698
  raise ValueError(f"modality {modality[i]} is not supported")
699
 
700
  inputs_embeds = inputs_embeds.to(features.dtype)
701
- '''
702
- print('+++'*10)
703
- print(input_ids)
704
- print(torch.sum(input_ids == self.config.audio_token_index, dim=-1))
705
- print('+++'*10)
706
- '''
707
  inputs_embeds, attention_mask, labels, position_ids = self.merge_input_ids_with_other_features(
708
  features, inputs_embeds, input_ids, attention_mask, labels
709
  )
@@ -832,7 +822,7 @@ class AnyRewardModel(AnyModelForConditionalGeneration):
832
  attention_mask: torch.Tensor | None = None,
833
  **kwargs,
834
  ) -> torch.Tensor:
835
- outputs = self.model(
836
  input_ids,
837
  attention_mask=attention_mask,
838
  output_hidden_states=True,
 
667
  # 2. Merge text and images
668
  if pixel_values_1 is not None and pixel_values_1 is not None and input_ids.shape[1] != 1:
669
  assert modality is not None, "modality must be provided when pixel_values is not None"
670
+
 
 
 
 
671
  for i in range(2):
672
  pixel_values = pixel_values_1 if i == 0 else pixel_values_2
673
  if modality[0][i] == ModalityType.IMAGE:
 
694
  raise ValueError(f"modality {modality[i]} is not supported")
695
 
696
  inputs_embeds = inputs_embeds.to(features.dtype)
 
 
 
 
 
 
697
  inputs_embeds, attention_mask, labels, position_ids = self.merge_input_ids_with_other_features(
698
  features, inputs_embeds, input_ids, attention_mask, labels
699
  )
 
822
  attention_mask: torch.Tensor | None = None,
823
  **kwargs,
824
  ) -> torch.Tensor:
825
+ outputs = super().forward(
826
  input_ids,
827
  attention_mask=attention_mask,
828
  output_hidden_states=True,