farzadab committed
Commit 6510504 · verified · 1 Parent(s): 1e13c4c

Upload 4 files
preprocessor_config.json ADDED
@@ -0,0 +1,14 @@
+{
+  "chunk_length": 30,
+  "feature_extractor_type": "WhisperFeatureExtractor",
+  "feature_size": 80,
+  "hop_length": 160,
+  "n_fft": 400,
+  "n_samples": 480000,
+  "nb_max_frames": 3000,
+  "padding_side": "right",
+  "padding_value": 0.0,
+  "processor_class": "WhisperProcessor",
+  "return_attention_mask": false,
+  "sampling_rate": 16000
+}
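These are the standard Whisper log-mel settings (80 mel bins, 16 kHz audio, 10 ms hop, 30 s chunks), and the derived fields are pure arithmetic. A minimal sketch (not part of the commit) that builds the same extractor from these values and checks the arithmetic:

```python
# Sketch only: instantiate a WhisperFeatureExtractor with the values from
# preprocessor_config.json and verify the derived fields.
from transformers import WhisperFeatureExtractor

fe = WhisperFeatureExtractor(
    feature_size=80,       # mel bins
    sampling_rate=16000,
    hop_length=160,        # 10 ms hop at 16 kHz
    chunk_length=30,       # seconds per chunk
    n_fft=400,
    padding_value=0.0,
)

# 30 s * 16000 Hz = 480000 samples; 480000 / 160 hop = 3000 mel frames.
assert fe.n_samples == 30 * 16000 == 480000
assert fe.nb_max_frames == 480000 // 160 == 3000
```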
processor_config.json ADDED
@@ -0,0 +1,11 @@
+{
+  "audio_context_size": 3000,
+  "audio_padding": "longest",
+  "audio_placeholder": "<|audio|>",
+  "auto_map": {
+    "AutoProcessor": "ultravox_processing.UltravoxProcessor"
+  },
+  "encoder_ds_factor": 2,
+  "processor_class": "UltravoxProcessor",
+  "stack_factor": 8
+}
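The `auto_map` entry is what lets `AutoProcessor` resolve to the custom `ultravox_processing.UltravoxProcessor` shipped in this repo. A hedged sketch (the repo id below is a placeholder, not stated in this commit) of how these values are consumed:

```python
# Sketch only: load the remote-code processor defined by processor_config.json.
import transformers

processor = transformers.AutoProcessor.from_pretrained(
    "your-org/your-ultravox-checkpoint",  # placeholder repo id
    trust_remote_code=True,               # required for auto_map -> ultravox_processing.UltravoxProcessor
)

# With encoder_ds_factor=2 and stack_factor=8, one audio token covers 2 * 8 = 16
# encoder frames, so a full 3000-frame (30 s) chunk yields ceil(3000 / 16) = 188 tokens.
```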
ultravox_model.py CHANGED
@@ -1,5 +1,6 @@
 import logging
-from typing import Any, Dict, Optional, Set, Tuple, Union
+import re
+from typing import Any, Dict, Generator, Optional, Set, Tuple, Union
 
 import peft
 import torch
@@ -9,6 +10,7 @@ import transformers
 import transformers.activations
 import transformers.modeling_outputs
 import transformers.models
+from transformers.generation.utils import GenerationMixin
 from transformers.models.whisper import modeling_whisper as whisper
 
 # We must use relative import in this directory to allow uploading to HF Hub
@@ -18,7 +20,7 @@ from .ultravox_config import LossFunction
 from .ultravox_config import UltravoxConfig
 
 
-class UltravoxModel(transformers.LlamaPreTrainedModel):
+class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     """
     The Ultravox model which consists of an audio encoder and a language model.
 
@@ -36,6 +38,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
     config: UltravoxConfig  # for type hinting
     # Usually we load encoder and LLM weights from a pretrained model separately, so they are allowed to be missing
     _keys_to_ignore_on_load_missing = ["audio_tower.*", "language_model.*"]
+    # Since we have kwargs in forward, we need to set this to False, otherwise grad_accum_steps will cause incorrect train loss to be reported
+    # see https://github.com/huggingface/transformers/issues/35856 and https://github.com/huggingface/trl/pull/2615/files
+    accepts_loss_kwargs = False
 
     def __init__(self, config: UltravoxConfig):
         super().__init__(config)
@@ -45,15 +50,16 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         self.vocab_size = config.vocab_size
 
         self.audio_tower = self._create_audio_tower(config)
+        self.audio_tower_context_length: Optional[int] = None
+        self.audio_tower_context_length = self.audio_tower.max_context_length
+
         self.multi_modal_projector = self._create_multi_modal_projector(config)
         self.language_model = self._create_language_model(config)
 
         # Determine no_split_modules dynamically to use with FSDP auto_wrap policy.
         # FSDP throws an error if some of the layer types are not found in the model.
-        # This would be something like ["LlamaDecoderLayer", "WhisperEncoderLayer"]
-        self._no_split_modules = (self.language_model._no_split_modules or []) + (
-            self.audio_tower._no_split_modules or []
-        )
+        # This would be something like ["LlamaDecoderLayer"] as we don't split audio encoder layers.
+        self._no_split_modules = self.language_model._no_split_modules
 
         self.loss_config = LossConfig()
         self.post_init()
@@ -140,6 +146,24 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         )
         return {"loss": kl_loss}
 
+    def _audio_iter(
+        self, audio_batch_size: torch.Tensor
+    ) -> Generator[Tuple[int, int], None, None]:
+        """
+        Iterate over the audio batch size and yield the batch index and audio index of each audio item.
+
+        Args:
+            audio_batch_size: A tensor of shape (B,) where B is the batch size.
+
+        Returns:
+            A generator that yields a tuple of (start index, length) for each audio item.
+        """
+        audio_index = 0
+        for i_b, batch_count in enumerate(audio_batch_size):
+            for _ in range(batch_count):
+                yield i_b, audio_index
+                audio_index += 1
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -148,8 +172,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         labels: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         audio_token_start_idx: Optional[torch.Tensor] = None,
-        audio_len: Optional[torch.Tensor] = None,
+        audio_lens: Optional[torch.Tensor] = None,
         audio_token_len: Optional[torch.Tensor] = None,
+        audio_batch_size: Optional[torch.Tensor] = None,
         past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
         # the alt_* fields are needed for KL divergence loss
         alt_input_ids: Optional[torch.Tensor] = None,
@@ -180,29 +205,37 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         # B x T -> B x T x D
         inputs_embeds = self.get_input_embeddings().forward(input_ids)
 
-        if audio_values is not None:
+        if audio_values is not None and len(audio_values) > 0:
             assert (
-                audio_token_start_idx is not None and audio_token_len is not None
-            ), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided."
+                audio_token_start_idx is not None
+                and audio_token_len is not None
+                and audio_lens is not None
+                and audio_batch_size is not None
+            ), "audio_token_start_idx/audio_token_len/audio_lens must be provided if audio_values are provided."
             assert (
-                len(audio_token_start_idx) == len(audio_token_len) == len(audio_values)
-            ), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size."
-
-            # B x A/3200 x D
+                len(audio_token_start_idx)
+                == len(audio_token_len)
+                == len(audio_lens)
+                == len(audio_values)
+            ), "audio_token_start_idx/audio_token_len/audio_lens/audio_values must have the same batch size."
+            assert len(audio_batch_size) == len(
+                inputs_embeds
+            ), "audio_batch_size and inputs_embeds must have the same batch size."
+
+            # B x A/3200 x (D=max-audio-length-in-batch)
             audio_tower_output = self.audio_tower.forward(
                 audio_values.to(self.audio_tower.dtype),
-                audio_len = audio_len
+                audio_len=audio_lens,
             ).last_hidden_state
             audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)
-
             audio_embeds = self.multi_modal_projector.forward(audio_tower_output)
 
             # combine audio and text embeddings
-            for i, (audio, start, length) in enumerate(
-                zip(audio_embeds, audio_token_start_idx, audio_token_len)
-            ):
-                length = min(length, audio.shape[0])
-                inputs_embeds[i, start : start + length] = audio[:length]
+            for i_b, i_a in self._audio_iter(audio_batch_size):
+                start_idx = audio_token_start_idx[i_a]
+                token_len = audio_token_len[i_a]
+                item_embedding = audio_embeds[i_a][:token_len]
+                inputs_embeds[i_b][start_idx : start_idx + token_len] = item_embedding
 
         lm_output = self.language_model.forward(
             inputs_embeds=inputs_embeds,
@@ -237,7 +270,8 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         audio_values: Optional[torch.FloatTensor] = None,
         audio_token_start_idx: Optional[torch.Tensor] = None,
         audio_token_len: Optional[torch.Tensor] = None,
-        audio_len: Optional[torch.Tensor] = None,
+        audio_lens: Optional[torch.Tensor] = None,
+        audio_batch_size: Optional[torch.Tensor] = None,
         past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
@@ -266,7 +300,8 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
                 audio_token_start_idx - prefill_start_idx
             )
             model_input["audio_token_len"] = audio_token_len
-            model_input["audio_len"] = audio_len
+            model_input["audio_batch_size"] = audio_batch_size
+            model_input["audio_lens"] = audio_lens
 
         return model_input
 
@@ -283,18 +318,32 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         cls, config: UltravoxConfig
     ) -> Union[transformers.Wav2Vec2Model, "ModifiedWhisperEncoder"]:
         if config.audio_model_id is not None:
-            if "whisper" in config.audio_model_id is not None:
+            if "whisper" in config.audio_model_id.lower():
                 audio_tower = ModifiedWhisperEncoder.from_pretrained(
                     config.audio_model_id, torch_dtype=config.torch_dtype
                 )
+                audio_tower.init_latency_mask(
+                    config.audio_latency_block_size, dtype=config.torch_dtype
+                )
             else:
+                assert config.audio_latency_block_size in (
+                    None,
+                    0,
+                ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
                 audio_tower = transformers.AutoModel.from_pretrained(
                     config.audio_model_id, torch_dtype=config.torch_dtype
                 )
         else:
-            if "whisper" in config.audio_config._name_or_path:
+            if "whisper" in config.audio_config._name_or_path.lower():
                 audio_tower = ModifiedWhisperEncoder(config.audio_config)
+                audio_tower.init_latency_mask(
+                    config.audio_latency_block_size, dtype=config.torch_dtype
+                )
             else:
+                assert config.audio_latency_block_size in (
+                    None,
+                    0,
+                ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
                 with transformers.modeling_utils.no_init_weights():
                     # we only ever use from_config if the weights are retrained, hence initializing is not
                     # required. This makes the model quite creation faster since init on CPU is quite slow.
@@ -370,24 +419,34 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
 
     def push_to_hub(self, *args, **kwargs):
         self.merge_and_unload()
-        self.to(self.language_model.dtype)
         return super().push_to_hub(*args, **kwargs)
 
-    def save_pretrained(
-        self, *args, state_dict: Optional[Dict[str, Any]] = None, **kwargs
-    ):
+    def diff_state_dict(
+        self, state_dict: Optional[Dict[str, Any]] = None
+    ) -> Dict[str, Any]:
        if state_dict is None:
            state_dict = super().state_dict()
 
-        named_params = dict(self.named_parameters())
+        trainable_params = {k for k, v in self.named_parameters() if v.requires_grad}
+        # normalize the keys to match the original model
+        # Example: audio_tower.base_model.model.layers.0._fsdp_wrapped_module.self_attn.k_proj.lora_B.default.weight
+        trainable_params = {
+            k.replace("_fsdp_wrapped_module.", "") for k in trainable_params
+        }
 
         state_dict = {
             k: v
             for k, v in state_dict.items()
-            if k in self.keep_params
-            or (k in named_params and named_params[k].requires_grad)
+            if k in self.keep_params or k in trainable_params
         }
 
+        return state_dict
+
+    def save_pretrained(
+        self, *args, state_dict: Optional[Dict[str, Any]] = None, **kwargs
+    ):
+        state_dict = self.diff_state_dict(state_dict)
+
         super().save_pretrained(*args, state_dict=state_dict, **kwargs)
 
     def _pre_load_state_dict_hook(self, state_dict: Dict[str, Any], *args, **kwargs):
@@ -422,8 +481,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         )
 
 
+# TODO: refactor common parts to a shared module
 def is_cache_empty(
-    past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]]
+    past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]],
 ) -> bool:
     """
     Check if the cache is empty.
@@ -439,12 +499,18 @@ def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
     """
     Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
     """
+    unfreeze_layers = lora_config.pop("unfreeze_layers", None)
    lora_config = peft.LoraConfig(**lora_config or {})
 
    if lora_config.r == 0:
-        # freeze the model entirely
-        for param in model.parameters():
-            param.requires_grad = False
+        # freeze the model entirely, except for the specified layers
+        for name, param in model.named_parameters():
+            if not unfreeze_layers or not any(
+                re.match(layer, name) for layer in unfreeze_layers
+            ):
+                param.requires_grad = False
+            else:
+                logging.info(f"Unfreezing layer: {name} with #{param.numel()} params")
    else:
        model = peft.get_peft_model(model, lora_config)
 
@@ -453,12 +519,8 @@
 
 class StackAudioFrames(nn.Module):
     """
-    Stack the audio embedding frames to reduce the sequence length by a factor of `stack_factor`.
-
-    The number of output frames will be `ceil(T / stack_factor) + 1` where `T` is the number of input frames.
-    NOTE: the extra +1 is intentional: in case the number of audio tokens are over-estimated by the processor,
-    we want to make sure `processor.audio_token_replacement` (i.e. EOS) doesn't get leaked into the middle of embeddings.
-    In most cases this extra padding will get removed in the model's forward function so it has no effect.
+    Stack the audio embedding frames to reduce the sequence length by a factor
+    of `stack_factor`.
     """
 
     def __init__(self, stack_factor: int = 8):
@@ -468,7 +530,7 @@ class StackAudioFrames(nn.Module):
     def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
         B, T, C = audio_embeds.shape
         T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
-        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T + self.stack_factor))
+        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
         B, T, C = audio_embeds.shape
         audio_embeds = audio_embeds.view(
             B, T // self.stack_factor, C * self.stack_factor
@@ -488,31 +550,43 @@ class SwiGLU(nn.Module):
         return F.silu(gate) * x
 
 
-class UltravoxProjector(nn.Sequential):
+class UltravoxProjector(nn.Module):
     def __init__(self, config: UltravoxConfig):
         super().__init__()
         self.hidden_dim = config.hidden_size
         self._pad_and_stack = StackAudioFrames(config.stack_factor)
-        dim = config.audio_config.hidden_size * config.stack_factor
-        self.ln_pre = RMSNorm(dim, init=config.norm_init)
-        self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
-        dim = self.hidden_dim
+        dim_in = config.audio_config.hidden_size * config.stack_factor
+        self.ln_pre = RMSNorm(dim_in, init=config.norm_init)
+        self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
+        dim_mid = self.hidden_dim
         self.act = transformers.activations.get_activation(config.projector_act)
-        dim = dim // 2 if config.projector_act == "swiglu" else dim
-        self.linear_2 = nn.Linear(dim, config.text_config.hidden_size, bias=False)
-        self.ln_post = RMSNorm(config.text_config.hidden_size, init=config.norm_init)
+        dim_mid = dim_mid // 2 if config.projector_act == "swiglu" else dim_mid
+        dim_out = config.text_config.hidden_size
+        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
+
+        # Ultravox v0.4.1 and below uses layer_norm after the second linear layer,
+        # while v0.5.0 and above uses layer_norm after the first linear layer.
+        if config.projector_ln_mid:
+            self.ln_mid: nn.Module = RMSNorm(dim_mid, init=config.norm_init)
+            self.ln_post: nn.Module = nn.Identity()
+        else:
+            self.ln_mid = nn.Identity()
+            self.ln_post = RMSNorm(dim_out, init=config.norm_init)
 
     def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
         audio_features = self._pad_and_stack(audio_features)
         audio_features = self.ln_pre(audio_features)
         hidden_states = self.linear_1(audio_features)
         hidden_states = self.act(hidden_states)
+        hidden_states = self.ln_mid(hidden_states)
         hidden_states = self.linear_2(hidden_states)
         hidden_states = self.ln_post(hidden_states)
         return hidden_states
 
 
-class ModifiedWhisperEncoder(whisper.WhisperEncoder, transformers.modeling_utils.ModuleUtilsMixin):
+class ModifiedWhisperEncoder(
+    whisper.WhisperEncoder, transformers.modeling_utils.ModuleUtilsMixin
+):
     """
     Encoder portion of OpenAI's Whisper model.
 
@@ -528,6 +602,47 @@ class ModifiedWhisperEncoder(whisper.WhisperEncoder, transformers.modeling_utils
     base_model_prefix = "model.encoder"
     _no_split_modules = ["WhisperEncoderLayer"]
 
+    def __init__(self, config: transformers.WhisperConfig):
+        super().__init__(config)
+        self.config.is_decoder = False
+
+    @property
+    def max_context_length(self):
+        return (
+            self.config.max_source_positions
+            * self.conv1.stride[0]
+            * self.conv2.stride[0]
+        )
+
+    def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.dtype):
+        if audio_latency_block_size is None:
+            self.audio_streaming_mask = None
+            return
+
+        # Use max_context_length directly in the calculation
+        max_seqlen = self.max_context_length
+        assert (
+            max_seqlen > 0
+        ), f"maximum sequence length must be positive, got {max_seqlen}"
+        assert (
+            max_seqlen % audio_latency_block_size == 0
+        ), f"audio_latency_block_size {audio_latency_block_size} must divide {max_seqlen} evenly."
+        # Given the block size, we calculate number of blocks.
+        audio_latency_nblocks = max_seqlen // audio_latency_block_size
+        audio_streaming_mask = (
+            torch.tril(
+                torch.ones(audio_latency_nblocks, audio_latency_nblocks),
+                diagonal=0,
+            )
+            .repeat_interleave(audio_latency_block_size, dim=0)
+            .repeat_interleave(audio_latency_block_size, dim=1)
+        )
+        audio_streaming_mask = (1.0 - audio_streaming_mask) * torch.finfo(dtype).min
+        audio_streaming_mask = audio_streaming_mask[None, None, :, :]
+        self.register_buffer(
+            "audio_streaming_mask", audio_streaming_mask, persistent=False
+        )
+
     def forward(
         self,
         input_features,
@@ -537,11 +652,7 @@ class ModifiedWhisperEncoder(whisper.WhisperEncoder, transformers.modeling_utils
         output_hidden_states=None,
         return_dict=None,
     ):
-        expected_seq_length = (
-            self.config.max_source_positions
-            * self.conv1.stride[0]
-            * self.conv2.stride[0]
-        )
+        expected_seq_length = self.max_context_length
         if input_features.shape[-1] > expected_seq_length:
             raise ValueError(
                 f"Whisper expects the mel input features to be of length {expected_seq_length} or less, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
@@ -574,23 +685,37 @@ class ModifiedWhisperEncoder(whisper.WhisperEncoder, transformers.modeling_utils
         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
 
+        # Create attention mask based on audio lengths to mask out padding tokens
+        # For each sample in batch:
+        # - Convert raw audio length to feature length after convolutions
+        # - Create boolean mask that is True for valid positions and False for padding
+        # - Convert to extended attention mask format expected by transformer layers
+        #   (1.0 for positions to attend to, large negative for positions to ignore)
+        # This masking ensures consistent behavior between training and inference
+        # by preventing the model from attending to padding tokens in both cases
         attention_mask = None
         if audio_len != None:
             audio_feature_len = self._get_feat_extract_output_lengths(audio_len)
-            batch_size = hidden_states.shape[0]
             max_seq_len = hidden_states.shape[1]
-            attention_mask = (
-                torch.arange(max_seq_len, device=hidden_states.device)[None, :]
-                .expand(batch_size, -1)
-                .lt(audio_feature_len.view(batch_size, 1))
-            )
+            attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[
+                None, :
+            ].lt(audio_feature_len.view(-1, 1))
             attention_mask = self.get_extended_attention_mask(
                 attention_mask,
                 None,
-                device=hidden_states.device,
                 dtype=hidden_states.dtype,
             )
 
+        if self.audio_streaming_mask is not None:
+            seqlen = hidden_states.size(-2)
+            if attention_mask is not None:
+                attention_mask = torch.minimum(
+                    self.audio_streaming_mask[:, :, :seqlen, :seqlen], attention_mask
+                )  # merge
+            else:
+                attention_mask = self.audio_streaming_mask[:, :, :seqlen, :seqlen]
+            attention_mask = attention_mask.to(hidden_states.dtype)
+
         # check if head_mask has a correct number of layers specified if desired
         if head_mask is not None:
             assert head_mask.size()[0] == (
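The new `init_latency_mask` builds a block-causal ("streaming") mask: a lower-triangular block matrix expanded to frame resolution with `repeat_interleave`, then converted to an additive attention mask. A standalone sketch of the same construction (mirroring the code above, not an exact copy):

```python
# Sketch only: block-causal additive attention mask for a streaming audio encoder.
import torch

def build_streaming_mask(max_seqlen: int, block_size: int, dtype=torch.float32) -> torch.Tensor:
    assert max_seqlen % block_size == 0
    nblocks = max_seqlen // block_size
    block_causal = torch.tril(torch.ones(nblocks, nblocks))  # 1 = block may attend
    mask = (
        block_causal
        .repeat_interleave(block_size, dim=0)  # expand rows to frame resolution
        .repeat_interleave(block_size, dim=1)  # expand columns to frame resolution
    )
    # 0 where attention is allowed, a very large negative value where it is blocked.
    return ((1.0 - mask) * torch.finfo(dtype).min)[None, None, :, :]

mask = build_streaming_mask(max_seqlen=3000, block_size=100)
print(mask.shape)  # torch.Size([1, 1, 3000, 3000])
```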
ultravox_processing.py CHANGED
@@ -1,12 +1,69 @@
1
- from typing import Optional, Union
 
2
 
3
  import numpy as np
4
  import torch
 
5
  import transformers
6
 
7
  from .ultravox_config import UltravoxConfig
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  class UltravoxProcessor(transformers.ProcessorMixin):
11
  """
12
  Constructs an Ultravox processor which wraps an audio processor and a tokenizer into a single processor.
@@ -17,11 +74,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
17
  """
18
 
19
  attributes = ["audio_processor", "tokenizer"]
20
- audio_processor_class = (
21
- "Wav2Vec2Processor",
22
- "SeamlessM4TFeatureExtractor",
23
- "WhisperProcessor",
24
- )
25
  tokenizer_class = (
26
  "PreTrainedTokenizer",
27
  "PreTrainedTokenizerFast",
@@ -35,27 +88,32 @@ class UltravoxProcessor(transformers.ProcessorMixin):
35
  audio_processor=None,
36
  tokenizer=None,
37
  audio_padding: str = "longest",
38
- encoder_ds_factor: int = 320,
39
  stack_factor: int = 8,
40
  audio_placeholder: str = "<|audio|>",
 
 
41
  ):
42
  """
43
  Args:
44
  audio_processor: The audio processor for the audio encoder.
45
  tokenizer: The tokenizer for the language model.
46
  audio_padding: The padding strategy for the audio encoder.
47
- encoder_ds_factor: The downsample factor of the audio encoder.
48
  stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
 
49
  audio_placeholder: The placeholder for the audio in the text.
 
50
  """
51
  self.audio_padding = audio_padding
52
  self.encoder_ds_factor = encoder_ds_factor
53
  self.stack_factor = stack_factor
54
  self.audio_placeholder = audio_placeholder
55
- self.audio_token_replacement = tokenizer.eos_token
56
  assert (
57
- self.audio_token_replacement is not None
58
  ), "The tokenizer has no EOS token. Cannot recover."
 
 
59
  if tokenizer.pad_token_id is None:
60
  tokenizer.pad_token_id = tokenizer.eos_token_id
61
 
@@ -69,7 +127,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
69
  audio_processor = transformers.AutoProcessor.from_pretrained(
70
  config.audio_model_id
71
  or config.audio_config._name_or_path
72
- or "facebook/wav2vec2-base-960h"
73
  )
74
 
75
  tokenizer = transformers.AutoTokenizer.from_pretrained(
@@ -84,30 +142,100 @@ class UltravoxProcessor(transformers.ProcessorMixin):
84
  stack_factor=config.stack_factor,
85
  )
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def __call__(
88
  self,
89
  text: Optional[str] = None,
90
  audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
 
 
 
 
 
91
  sampling_rate: Optional[int] = None,
92
  return_tensors: Optional[
93
  Union[str, transformers.TensorType]
94
  ] = transformers.TensorType.PYTORCH,
 
95
  **kwargs,
96
  ) -> transformers.BatchFeature:
97
  """
98
  Main method to prepare for the model one text sequence and audio. This method forwards the `text`
99
  and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
100
  the text. To prepare the audio(s), this method forwards the `audio`, `sampling_rate` and `kwargs` arguments to
101
- audio processor's [`~Wav2Vec2Processor.__call__`] if `audio` is not `None`. Please refer to the docstring
102
  of the above two methods for more information.
103
 
104
  Args:
105
  text (`str`, `List[str]`):
106
  The sequence to be encoded. Sequence can be a string or (pretokenized string).
107
  audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
108
- The audio to be prepared. Audio can be NumPy array or PyTorch tensor. In case of a
109
- NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels, and T the
110
- sample length of the audio.
111
  sampling_rate (`int`, *optional*, defaults to 16000):
112
  Sampling rate of the input audio. We expect 16kHz audio. Don't change this value unless you know what
113
  you are doing.
@@ -131,69 +259,105 @@ class UltravoxProcessor(transformers.ProcessorMixin):
131
  Returned when `audio` is not `None`.
132
  - **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when `audio` is not `None`.
133
  """
134
- # TODO: Add support for multiple audio and text inputs.
 
 
 
 
 
 
 
135
  data = {}
136
- audio_embed_frames = 0
137
- if audio is not None and len(audio) > 0:
138
- if self.audio_padding == "max_length":
139
- # 30 seconds is the expected length for Whisper
140
- assert sampling_rate is not None, "Sampling rate must be provided."
141
- audio_len = 30 * sampling_rate
142
- else:
143
- audio_len = audio.shape[-1]
144
- # It's guaranteed that the number of frames is less than or equal to this amount.
145
- # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
146
- # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
147
- nb_encoder_frames = int(round(audio_len / self.encoder_ds_factor + 1e-4))
148
- audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor))
149
- data["audio_token_len"] = [audio_embed_frames]
150
 
151
  # Main audio processing. The processor is model-specific.
152
- x = self.audio_processor(
153
- audio,
154
  sampling_rate=sampling_rate,
155
  padding="longest",
156
- max_length=audio_len,
 
157
  return_attention_mask=True,
158
  **kwargs,
159
  )
160
- if "input_features" in x:
161
- data["audio_values"] = x.input_features
162
- else:
163
- data["audio_values"] = x.input_values
164
- if self.audio_padding == "max_length":
165
- data["audio_len"] = x.attention_mask.sum(-1) - 1
166
- else:
167
- data["audio_len"] = [data["audio_values"].shape[-1]]
168
 
169
- if text is not None:
170
- assert isinstance(
171
- text, str
172
- ), "Text must be a string. Batch mode not supported yet."
173
- if self.audio_placeholder in text:
174
- if "audio_token_len" not in data:
175
- raise ValueError(
176
- f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
177
- )
178
-
179
- start_idx = len(
180
- self.tokenizer.encode(
181
- text[: text.index(self.audio_placeholder)],
182
- add_special_tokens=False,
183
- )
184
- )
185
- data["audio_token_start_idx"] = [start_idx]
186
-
187
- # Replace the audio placeholder with the audio token.
188
- # e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
189
- # where the number of </s> is the number of audio frames.
190
- text = text.replace(
191
- self.audio_placeholder,
192
- self.audio_token_replacement * audio_embed_frames,
193
  )
 
 
 
 
 
 
 
 
 
 
194
 
195
  # Special tokens like BOS should already have been added by the caller.
196
- data.update(self.tokenizer([text], add_special_tokens=False, **kwargs))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  return transformers.BatchFeature(data=data, tensor_type=return_tensors)
199
 
 
1
+ import dataclasses
2
+ from typing import Any, Dict, List, Optional, Union
3
 
4
  import numpy as np
5
  import torch
6
+ import torch.nn.functional as F
7
  import transformers
8
 
9
  from .ultravox_config import UltravoxConfig
10
 
11
 
12
+ @dataclasses.dataclass
13
+ class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):
14
+ # when enabled, the alt_input_ids, alt_attention_mask, and alt_labels fields are used for computing the KL loss in UltravoxModel
15
+ include_alt_fields: bool = False
16
+
17
+ def __call__(self, features, *args, **kwargs):
18
+ audio_values = [x for f in features for x in f.pop("audio_values", [])]
19
+ audio_lens = [x for f in features for x in f.pop("audio_lens", [])]
20
+ audio_token_len = [x for f in features for x in f.pop("audio_token_len", [])]
21
+ audio_token_start_idx = [
22
+ x for f in features for x in f.pop("audio_token_start_idx", [])
23
+ ]
24
+
25
+ if self.include_alt_fields:
26
+ # these fields are hard-coded in the transformer data collator, so they need special handling before calling the super method
27
+ alt_features = [
28
+ {
29
+ "input_ids": f.pop("alt_input_ids"),
30
+ "attention_mask": f.pop("alt_attention_mask"),
31
+ "labels": f.pop("alt_labels"),
32
+ }
33
+ for f in features
34
+ ]
35
+
36
+ batch = super().__call__(features, *args, **kwargs)
37
+ if self.include_alt_fields:
38
+ alt_batch = super().__call__(alt_features, *args, **kwargs)
39
+ batch["alt_input_ids"] = alt_batch["input_ids"]
40
+ batch["alt_attention_mask"] = alt_batch["attention_mask"]
41
+ batch["alt_labels"] = alt_batch["labels"]
42
+
43
+ batch["audio_token_start_idx"] = torch.stack(audio_token_start_idx)
44
+ batch["audio_lens"] = torch.stack(audio_lens)
45
+ batch["audio_token_len"] = torch.stack(audio_token_len)
46
+
47
+ # Pad the last dimension of all audio_values to the same length, with 0s on the right.
48
+ if audio_values:
49
+ max_len = max([x.shape[-1] for x in audio_values])
50
+ batch["audio_values"] = torch.stack(
51
+ [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
52
+ )
53
+ if self.tokenizer.padding_side == "left":
54
+ input_ids_lens = torch.LongTensor(
55
+ [f["input_ids"].shape[-1] for f in features]
56
+ )
57
+ displacement = batch["input_ids"].shape[-1] - input_ids_lens
58
+ displacement = displacement.repeat_interleave(
59
+ batch["audio_batch_size"].squeeze(-1)
60
+ )
61
+ batch["audio_token_start_idx"] += displacement.to(
62
+ batch["audio_token_start_idx"].device
63
+ )
64
+ return batch
65
+
66
+
67
  class UltravoxProcessor(transformers.ProcessorMixin):
68
  """
69
  Constructs an Ultravox processor which wraps an audio processor and a tokenizer into a single processor.
 
74
  """
75
 
76
  attributes = ["audio_processor", "tokenizer"]
77
+ audio_processor_class = ("WhisperProcessor",)
 
 
 
 
78
  tokenizer_class = (
79
  "PreTrainedTokenizer",
80
  "PreTrainedTokenizerFast",
 
88
  audio_processor=None,
89
  tokenizer=None,
90
  audio_padding: str = "longest",
91
+ encoder_ds_factor: int = 2,
92
  stack_factor: int = 8,
93
  audio_placeholder: str = "<|audio|>",
94
+ # Defaults to whisper encoder context size
95
+ audio_context_size: Optional[int] = 3000,
96
  ):
97
  """
98
  Args:
99
  audio_processor: The audio processor for the audio encoder.
100
  tokenizer: The tokenizer for the language model.
101
  audio_padding: The padding strategy for the audio encoder.
 
102
  stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
103
+ encoder_ds_factor: The downsampling factor of the audio encoder.
104
  audio_placeholder: The placeholder for the audio in the text.
105
+ audio_context_size: The maximum number of frames that the audio encoder can handle.
106
  """
107
  self.audio_padding = audio_padding
108
  self.encoder_ds_factor = encoder_ds_factor
109
  self.stack_factor = stack_factor
110
  self.audio_placeholder = audio_placeholder
111
+ self.audio_context_size = audio_context_size
112
  assert (
113
+ tokenizer.eos_token is not None
114
  ), "The tokenizer has no EOS token. Cannot recover."
115
+ self.vocab = tokenizer.get_vocab()
116
+ self.audio_token_replacement = tokenizer.eos_token
117
  if tokenizer.pad_token_id is None:
118
  tokenizer.pad_token_id = tokenizer.eos_token_id
119
 
 
127
  audio_processor = transformers.AutoProcessor.from_pretrained(
128
  config.audio_model_id
129
  or config.audio_config._name_or_path
130
+ or "openai/whisper-tiny"
131
  )
132
 
133
  tokenizer = transformers.AutoTokenizer.from_pretrained(
 
142
  stack_factor=config.stack_factor,
143
  )
144
 
145
+ def _chunk_and_pad_audio(
146
+ self,
147
+ audio_values: torch.Tensor,
148
+ audio_lens: torch.Tensor,
149
+ include_audio_num_chunks: bool = False,
150
+ ) -> Dict[str, Any]:
151
+ """
152
+ Processes the audio batch by chunking any items in the batch according to the audio_context_size,
153
+ padding the last chunk if needed, and returns a dictionary with updated audio data.
154
+
155
+ Args:
156
+ audio_values (torch.Tensor): A tensor of audio values (e.g., in B, D, T format).
157
+ audio_lens (torch.Tensor): A tensor of audio lengths.
158
+
159
+ Returns:
160
+ Dict[str, Any]: Dictionary with the following keys:
161
+ - "audio_values": The concatenated audio tensor after chunking and padding.
162
+ - "audio_lens": Tensor of lengths for each chunk.
163
+ - "audio_is_continuation": Tensor of booleans indicating if the chunk is a continuation of the previous chunk.
164
+ - "audio_batch_size": A Tensor with one integer representing the number of chunks.
165
+
166
+ """
167
+ chunked_audio_values: List[torch.Tensor] = []
168
+ chunked_audio_lens: List[int] = []
169
+ is_continuation_list: List[bool] = []
170
+ num_chunks: List[int] = []
171
+ context_size = self.audio_context_size or audio_values.shape[-1]
172
+
173
+ for i in range(audio_values.shape[0]): # iterate over the batch
174
+ num_chunks.append(int(np.ceil(audio_lens[i] / context_size)))
175
+ for offset in range(0, audio_lens[i], context_size):
176
+ is_continuation = offset > 0
177
+ chunk = audio_values[i, :, offset : offset + context_size]
178
+ if is_continuation and chunk.shape[-1] < context_size:
179
+ # N.B. We only need to pad continuation chunks. If none of the samples require chunking, the
180
+ # batch might not (need to) be padded all the way to the audio_context_size, in which case
181
+ # we've already included the padding above. On the other hand, if we have any continuation
182
+ # chunks we know that the batch needs to be padded to audio_context_size because that's what
183
+ # we're slicing to.
184
+ chunk = F.pad(chunk, (0, context_size - chunk.shape[-1]))
185
+ chunked_audio_values.append(chunk)
186
+ chunked_audio_lens.append(
187
+ min(int(audio_lens[i].item()) - offset, context_size)
188
+ )
189
+ is_continuation_list.append(is_continuation)
190
+
191
+ data = {
192
+ "audio_values": torch.stack(chunked_audio_values, dim=0),
193
+ "audio_lens": torch.tensor(
194
+ chunked_audio_lens, dtype=torch.int64, device=audio_values.device
195
+ ),
196
+ "audio_is_continuation": torch.tensor(
197
+ is_continuation_list, dtype=torch.bool, device=audio_values.device
198
+ ),
199
+ "audio_batch_size": torch.tensor(
200
+ [len(chunked_audio_values)], device=audio_values.device
201
+ ),
202
+ }
203
+ if include_audio_num_chunks:
204
+ data["audio_num_chunks"] = torch.tensor(
205
+ num_chunks, dtype=torch.int64, device=audio_values.device
206
+ )
207
+ return data
208
+
209
  def __call__(
210
  self,
211
  text: Optional[str] = None,
212
  audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
213
+ audios: Optional[
214
+ Union[
215
+ List[Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor]
216
+ ]
217
+ ] = None,
218
  sampling_rate: Optional[int] = None,
219
  return_tensors: Optional[
220
  Union[str, transformers.TensorType]
221
  ] = transformers.TensorType.PYTORCH,
222
+ include_audio_num_chunks: bool = False,
223
  **kwargs,
224
  ) -> transformers.BatchFeature:
225
  """
226
  Main method to prepare for the model one text sequence and audio. This method forwards the `text`
227
  and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
228
  the text. To prepare the audio(s), this method forwards the `audio`, `sampling_rate` and `kwargs` arguments to
229
+ audio processor's [`~WhisperProcessor.__call__`] if `audio` is not `None`. Please refer to the docstring
230
  of the above two methods for more information.
231
 
232
  Args:
233
  text (`str`, `List[str]`):
234
  The sequence to be encoded. Sequence can be a string or (pretokenized string).
235
  audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
236
+ The audio to be prepared. Audio can be a single-channel (1-dimensional) NumPy array or PyTorch tensor.
237
+ audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
238
+ A list or two dimensional array of audio to be prepared.
239
  sampling_rate (`int`, *optional*, defaults to 16000):
240
  Sampling rate of the input audio. We expect 16kHz audio. Don't change this value unless you know what
241
  you are doing.
 
259
  Returned when `audio` is not `None`.
260
  - **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when `audio` is not `None`.
261
  """
262
+ # TODO: Add support for multiple text inputs.
263
+ if audio is not None and audios is not None:
264
+ raise ValueError("Only one of `audio` or `audios` should be provided.")
265
+ elif audio is not None:
266
+ audios = audio if isinstance(audio, list) or audio.ndim == 2 else [audio]
267
+ elif audios is None:
268
+ audios = []
269
+
270
  data = {}
271
+ audio_is_continuation = []
272
+ if len(audios) > 0:
273
+ audios = [x.numpy() if isinstance(x, torch.Tensor) else x for x in audios]
274
+
275
+ # Pad out each audio to at least 2 hops (the minimum required by the processor).
276
+ hop_length = self.audio_processor.feature_extractor.hop_length
277
+ audios = [
278
+ (
279
+ np.pad(x, (0, 2 * hop_length - len(x)), mode="constant")
280
+ if len(x) < 2 * hop_length
281
+ else x
282
+ )
283
+ for x in audios
284
+ ]
285
 
286
  # Main audio processing. The processor is model-specific.
287
+ x: transformers.BatchFeature = self.audio_processor(
288
+ audios,
289
  sampling_rate=sampling_rate,
290
  padding="longest",
291
+ pad_to_multiple_of=hop_length, # The attention mask effectively gets padded to the hop length, so pad the audio to be consistent.
292
+ truncation=False,
293
  return_attention_mask=True,
294
  **kwargs,
295
  )
 
 
 
 
 
 
 
 
296
 
297
+ data.update(
298
+ self._chunk_and_pad_audio(
299
+ audio_values=torch.as_tensor(
300
+ x.input_features if "input_features" in x else x.input_values
301
+ ),
302
+ audio_lens=torch.as_tensor(x.attention_mask).sum(-1),
303
+ include_audio_num_chunks=include_audio_num_chunks,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  )
305
+ )
306
+
307
+ audio_is_continuation = data.pop("audio_is_continuation")
308
+ data["audio_token_len"] = torch.ceil(
309
+ data["audio_lens"] / (self.encoder_ds_factor * self.stack_factor)
310
+ ).to(dtype=torch.int)
311
+
312
+ if text is not None:
313
+ if not isinstance(text, str):
314
+ raise ValueError("Text must be a string. Batch mode not supported yet.")
315
 
316
  # Special tokens like BOS should already have been added by the caller.
317
+ tokenized_parts = self.tokenizer(
318
+ text.split(
319
+ "<|audio|>" # The placeholder isn't part of the vocabulary, so split the text around it.
320
+ ),
321
+ add_special_tokens=False,
322
+ **kwargs,
323
+ )
324
+
325
+ audio_token_start_idx = []
326
+ placeholder_index = -1
327
+ split_input_ids = tokenized_parts["input_ids"]
328
+ input_ids: List[int] = []
329
+
330
+ audio_token_replacement_token_id = self.vocab[self.audio_token_replacement]
331
+
332
+ for i, token_len in enumerate(data.get("audio_token_len", [])):
333
+ if not audio_is_continuation[i]:
334
+ placeholder_index += 1
335
+ if placeholder_index >= len(split_input_ids):
336
+ raise ValueError(
337
+ f"Text contains too few audio placeholders. (Expected {len(audios)} placeholders)"
338
+ )
339
+
340
+ input_ids.extend(split_input_ids[placeholder_index])
341
+
342
+ audio_token_start_idx.append(len(input_ids))
343
+
344
+ input_ids.extend([audio_token_replacement_token_id] * token_len)
345
+
346
+ # Include any tokens after the last audio.
347
+ placeholder_index += 1
348
+ if placeholder_index != len(split_input_ids) - 1:
349
+ raise ValueError(
350
+ f"Text contains too many audio placeholders. (Expected {len(audios)} placeholders)"
351
+ )
352
+ input_ids.extend(split_input_ids[placeholder_index])
353
+
354
+ if "audio_token_len" in data:
355
+ data["audio_token_start_idx"] = torch.as_tensor(audio_token_start_idx)
356
+
357
+ data["input_ids"] = [input_ids]
358
+ data["attention_mask"] = [[1] * len(input_ids)]
359
+
360
+ # Ensure that there are no audio placeholders after the last audio.
361
 
362
  return transformers.BatchFeature(data=data, tensor_type=return_tensors)
363