This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +5 -0
  2. configuration_whisper (1).py +344 -0
  3. configuration_whisper.cpython-312 (1).pyc +0 -0
  4. configuration_whisper.cpython-312.pyc +0 -0
  5. configuration_whisper.py +344 -0
  6. feature_extraction_whisper (1).py +324 -0
  7. feature_extraction_whisper.cpython-312 (1).pyc +0 -0
  8. feature_extraction_whisper.cpython-312.pyc +0 -0
  9. feature_extraction_whisper.py +324 -0
  10. find-corrupt-whisper-files (1).py +81 -0
  11. find-corrupt-whisper-files (2).py +81 -0
  12. find-corrupt-whisper-files.py +81 -0
  13. generation_whisper (1).py +1881 -0
  14. generation_whisper.cpython-312 (1).pyc +0 -0
  15. generation_whisper.cpython-312.pyc +0 -0
  16. generation_whisper.py +1881 -0
  17. modeling_flax_whisper (1).py +1696 -0
  18. modeling_flax_whisper.cpython-312 (1).pyc +0 -0
  19. modeling_flax_whisper.cpython-312.pyc +0 -0
  20. modeling_flax_whisper.py +1696 -0
  21. modeling_tf_whisper (1).py +1758 -0
  22. modeling_tf_whisper.cpython-312 (1).pyc +0 -0
  23. modeling_tf_whisper.cpython-312.pyc +0 -0
  24. modeling_tf_whisper.py +1758 -0
  25. modeling_whisper (1).py +0 -0
  26. modeling_whisper.cpython-312 (1).pyc +3 -0
  27. modeling_whisper.cpython-312.pyc +3 -0
  28. modeling_whisper.py +0 -0
  29. processing_whisper (1).py +97 -0
  30. processing_whisper.cpython-312 (1).pyc +0 -0
  31. processing_whisper.cpython-312.pyc +0 -0
  32. processing_whisper.py +97 -0
  33. realtime-whisper-webgpu/.eslintrc.cjs +21 -0
  34. realtime-whisper-webgpu/.gitignore +24 -0
  35. realtime-whisper-webgpu/README.md +8 -0
  36. realtime-whisper-webgpu/index.html +13 -0
  37. realtime-whisper-webgpu/package-lock.json +0 -0
  38. realtime-whisper-webgpu/package.json +30 -0
  39. realtime-whisper-webgpu/postcss.config.js +6 -0
  40. realtime-whisper-webgpu/public/banner.png +3 -0
  41. realtime-whisper-webgpu/public/logo.png +3 -0
  42. realtime-whisper-webgpu/public/realtime-whisper-webgpu.mp4 +3 -0
  43. realtime-whisper-webgpu/src/App.jsx +321 -0
  44. realtime-whisper-webgpu/src/components/AudioVisualizer.jsx +57 -0
  45. realtime-whisper-webgpu/src/components/LanguageSelector.jsx +134 -0
  46. realtime-whisper-webgpu/src/components/Progress.jsx +22 -0
  47. realtime-whisper-webgpu/src/index.css +32 -0
  48. realtime-whisper-webgpu/src/main.jsx +10 -0
  49. realtime-whisper-webgpu/src/worker.js +143 -0
  50. realtime-whisper-webgpu/tailwind.config.js +8 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ modeling_whisper.cpython-312[[:space:]](1).pyc filter=lfs diff=lfs merge=lfs -text
+ modeling_whisper.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
+ realtime-whisper-webgpu/public/banner.png filter=lfs diff=lfs merge=lfs -text
+ realtime-whisper-webgpu/public/logo.png filter=lfs diff=lfs merge=lfs -text
+ realtime-whisper-webgpu/public/realtime-whisper-webgpu.mp4 filter=lfs diff=lfs merge=lfs -text
configuration_whisper (1).py ADDED
@@ -0,0 +1,344 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Whisper model configuration"""
16
+
17
+ from collections import OrderedDict
18
+ from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
19
+
20
+ from ...configuration_utils import PretrainedConfig
21
+ from ...onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
22
+ from ...utils import logging
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ from ...feature_extraction_utils import FeatureExtractionMixin
27
+ from ...tokenization_utils_base import PreTrainedTokenizerBase
28
+ from ...utils import TensorType
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ # fmt: off
34
+ NON_SPEECH_TOKENS = [
35
+ 1, 2, 7, 8, 9, 10, 14, 25,
36
+ 26, 27, 28, 29, 31, 58, 59, 60, 61, 62,
37
+ 63, 90, 91, 92, 93, 357, 366, 438, 532, 685,
38
+ 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377,
39
+ 1391, 1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211,
40
+ 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, 10563, 10786,
41
+ 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791,
42
+ 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409,
43
+ 34949, 40283, 40493, 40549, 47282, 49146, 50257, 50359, 50360, 50361
44
+ ]
45
+ NON_SPEECH_TOKENS_MULTI = [
46
+ 1, 2, 7, 8, 9, 10, 14, 25,
47
+ 26, 27, 28, 29, 31, 58, 59, 60, 61, 62,
48
+ 63, 90, 91, 92, 93, 359, 503, 522, 542, 873,
49
+ 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627,
50
+ 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647,
51
+ 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793,
52
+ 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675,
53
+ 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865,
54
+ 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362
55
+ ]
56
+ # fmt: on
57
+
58
+
59
+ class WhisperConfig(PretrainedConfig):
60
+ r"""
61
+ This is the configuration class to store the configuration of a [`WhisperModel`]. It is used to instantiate a
62
+ Whisper model according to the specified arguments, defining the model architecture. Instantiating a configuration
63
+ with the defaults will yield a similar configuration to that of the Whisper
64
+ [openai/whisper-tiny](https://huggingface.co/openai/whisper-tiny) architecture.
65
+
66
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
67
+ documentation from [`PretrainedConfig`] for more information.
68
+
69
+
70
+ Args:
71
+ vocab_size (`int`, *optional*, defaults to 51865):
72
+ Vocabulary size of the Whisper model. Defines the number of different tokens that can be represented by the
73
+ `decoder_input_ids` passed when calling [`WhisperModel`]
74
+ num_mel_bins (`int`, *optional*, defaults to 80):
75
+ Number of mel features used per input features. Should correspond to the value used in the
76
+ `WhisperProcessor` class.
77
+ encoder_layers (`int`, *optional*, defaults to 4):
78
+ Number of encoder layers.
79
+ decoder_layers (`int`, *optional*, defaults to 4):
80
+ Number of decoder layers.
81
+ encoder_attention_heads (`int`, *optional*, defaults to 6):
82
+ Number of attention heads for each attention layer in the Transformer encoder.
83
+ decoder_attention_heads (`int`, *optional*, defaults to 6):
84
+ Number of attention heads for each attention layer in the Transformer decoder.
85
+ encoder_ffn_dim (`int`, *optional*, defaults to 1536):
86
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
87
+ decoder_ffn_dim (`int`, *optional*, defaults to 1536):
88
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
89
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
90
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
91
+ for more details.
92
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
93
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
94
+ for more details.
95
+ decoder_start_token_id (`int`, *optional*, defaults to 50257):
96
+ Corresponds to the "<|startoftranscript|>" token, which is automatically used when no `decoder_input_ids`
97
+ are provided to the `generate` function. It is used to guide the model`s generation process depending on
98
+ the task.
99
+ use_cache (`bool`, *optional*, defaults to `True`):
100
+ Whether or not the model should return the last key/values attentions (not used by all models).
101
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
102
+ Whether the model is used as an encoder/decoder or not.
103
+ activation_function (`str`, *optional*, defaults to `"gelu"`):
104
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
105
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
106
+ d_model (`int`, *optional*, defaults to 384):
107
+ Dimensionality of the layers.
108
+ dropout (`float`, *optional*, defaults to 0.1):
109
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
110
+ attention_dropout (`float`, *optional*, defaults to 0.0):
111
+ The dropout ratio for the attention probabilities.
112
+ activation_dropout (`float`, *optional*, defaults to 0.0):
113
+ The dropout ratio for activations inside the fully connected layer.
114
+ init_std (`float`, *optional*, defaults to 0.02):
115
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
116
+ scale_embedding (`bool`, *optional*, defaults to False):
117
+ Scale embeddings by diving by sqrt(d_model).
118
+ max_source_positions (`int`, *optional*, defaults to 1500):
119
+ The maximum sequence length of log-mel filter-bank features that this model might ever be used with.
120
+ max_target_positions (`int`, *optional*, defaults to 448):
121
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
122
+ just in case (e.g., 512 or 1024 or 2048).
123
+ pad_token_id (`int`, *optional*, defaults to 50256):
124
+ Padding token id.
125
+ bos_token_id (`int`, *optional*, defaults to 50256):
126
+ Begin of stream token id.
127
+ eos_token_id (`int`, *optional*, defaults to 50256):
128
+ End of stream token id.
129
+ suppress_tokens (`List[int]`, *optional*):
130
+ A list containing the non-speech tokens that will be used by the logit processor in the `generate`
131
+ function. NON_SPEECH_TOKENS and NON_SPEECH_TOKENS_MULTI each correspond to the `english-only` and the
132
+ `multilingual` model.
133
+ begin_suppress_tokens (`List[int]`, *optional*, defaults to `[220,50256]`):
134
+ A list containing tokens that will be supressed at the beginning of the sampling process. Initialized as
135
+ the token for `" "` (`blank_token_id`) and the `eos_token_id`
136
+ use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
137
+ Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
138
+ instance of [`WhisperForAudioClassification`].
139
+ classifier_proj_size (`int`, *optional*, defaults to 256):
140
+ Dimensionality of the projection before token mean-pooling for classification. Only relevant when using an
141
+ instance of [`WhisperForAudioClassification`].
142
+ apply_spec_augment (`bool`, *optional*, defaults to `False`):
143
+ Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
144
+ [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
145
+ Recognition](https://arxiv.org/abs/1904.08779).
146
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
147
+ Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
148
+ procecure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
149
+ reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
150
+ masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
151
+ actual percentage of masked vectors. This is only relevant if `apply_spec_augment == True`.
152
+ mask_time_length (`int`, *optional*, defaults to 10):
153
+ Length of vector span along the time axis.
154
+ mask_time_min_masks (`int`, *optional*, defaults to 2),:
155
+ The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
156
+ irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
157
+ mask_time_min_masks''
158
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
159
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
160
+ masking procecure generates `mask_feature_prob*len(feature_axis)/mask_time_length` independent masks over
161
+ the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
162
+ span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
163
+ may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
164
+ True`.
165
+ mask_feature_length (`int`, *optional*, defaults to 10):
166
+ Length of vector span along the feature axis.
167
+ mask_feature_min_masks (`int`, *optional*, defaults to 0),:
168
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
169
+ step, irrespectively of `mask_feature_prob`. Only relevant if
170
+ `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
171
+ median_filter_width (`int`, *optional*, defaults to 7):
172
+ Width of the median filter used to smoothen to cross-attention outputs when computing token timestamps.
173
+ Should be an odd number.
174
+
175
+ Example:
176
+
177
+ ```python
178
+ >>> from transformers import WhisperConfig, WhisperModel
179
+
180
+ >>> # Initializing a Whisper tiny style configuration
181
+ >>> configuration = WhisperConfig()
182
+
183
+ >>> # Initializing a model (with random weights) from the tiny style configuration
184
+ >>> model = WhisperModel(configuration)
185
+
186
+ >>> # Accessing the model configuration
187
+ >>> configuration = model.config
188
+ ```"""
189
+
190
+ model_type = "whisper"
191
+ keys_to_ignore_at_inference = ["past_key_values"]
192
+ attribute_map = {
193
+ "num_key_value_heads": "encoder_attention_heads",
194
+ "num_attention_heads": "encoder_attention_heads",
195
+ "hidden_size": "d_model",
196
+ }
197
+
198
+ def __init__(
199
+ self,
200
+ vocab_size=51865,
201
+ num_mel_bins=80,
202
+ encoder_layers=4,
203
+ encoder_attention_heads=6,
204
+ decoder_layers=4,
205
+ decoder_attention_heads=6,
206
+ decoder_ffn_dim=1536,
207
+ encoder_ffn_dim=1536,
208
+ encoder_layerdrop=0.0,
209
+ decoder_layerdrop=0.0,
210
+ decoder_start_token_id=50257,
211
+ use_cache=True,
212
+ is_encoder_decoder=True,
213
+ activation_function="gelu",
214
+ d_model=384,
215
+ dropout=0.0,
216
+ attention_dropout=0.0,
217
+ activation_dropout=0.0,
218
+ init_std=0.02,
219
+ scale_embedding=False,
220
+ max_source_positions=1500,
221
+ max_target_positions=448,
222
+ pad_token_id=50256,
223
+ bos_token_id=50256,
224
+ eos_token_id=50256,
225
+ suppress_tokens=None,
226
+ begin_suppress_tokens=[220, 50256],
227
+ use_weighted_layer_sum=False,
228
+ classifier_proj_size=256,
229
+ apply_spec_augment=False,
230
+ mask_time_prob=0.05,
231
+ mask_time_length=10,
232
+ mask_time_min_masks=2,
233
+ mask_feature_prob=0.0,
234
+ mask_feature_length=10,
235
+ mask_feature_min_masks=0,
236
+ median_filter_width=7,
237
+ **kwargs,
238
+ ):
239
+ self.vocab_size = vocab_size
240
+ self.num_mel_bins = num_mel_bins
241
+ self.d_model = d_model
242
+ self.encoder_layers = encoder_layers
243
+ self.encoder_attention_heads = encoder_attention_heads
244
+ self.decoder_layers = decoder_layers
245
+ self.decoder_attention_heads = decoder_attention_heads
246
+ self.decoder_ffn_dim = decoder_ffn_dim
247
+ self.encoder_ffn_dim = encoder_ffn_dim
248
+ self.dropout = dropout
249
+ self.attention_dropout = attention_dropout
250
+ self.activation_dropout = activation_dropout
251
+ self.activation_function = activation_function
252
+ self.init_std = init_std
253
+ self.encoder_layerdrop = encoder_layerdrop
254
+ self.decoder_layerdrop = decoder_layerdrop
255
+ self.use_cache = use_cache
256
+ self.num_hidden_layers = encoder_layers
257
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
258
+ self.max_source_positions = max_source_positions
259
+ self.max_target_positions = max_target_positions
260
+
261
+ # Audio Classification-specific parameters. Feel free to ignore for other classes.
262
+ self.classifier_proj_size = classifier_proj_size
263
+ self.use_weighted_layer_sum = use_weighted_layer_sum
264
+
265
+ # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
266
+ self.apply_spec_augment = apply_spec_augment
267
+ self.mask_time_prob = mask_time_prob
268
+ self.mask_time_length = mask_time_length
269
+ self.mask_time_min_masks = mask_time_min_masks
270
+ self.mask_feature_prob = mask_feature_prob
271
+ self.mask_feature_length = mask_feature_length
272
+ self.mask_feature_min_masks = mask_feature_min_masks
273
+
274
+ self.median_filter_width = median_filter_width
275
+
276
+ super().__init__(
277
+ pad_token_id=pad_token_id,
278
+ bos_token_id=bos_token_id,
279
+ eos_token_id=eos_token_id,
280
+ is_encoder_decoder=is_encoder_decoder,
281
+ decoder_start_token_id=decoder_start_token_id,
282
+ suppress_tokens=suppress_tokens,
283
+ begin_suppress_tokens=begin_suppress_tokens,
284
+ **kwargs,
285
+ )
286
+
287
+
288
+ class WhisperOnnxConfig(OnnxSeq2SeqConfigWithPast):
289
+ @property
290
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
291
+ common_inputs = OrderedDict(
292
+ [
293
+ ("input_features", {0: "batch", 1: "feature_size", 2: "encoder_sequence"}),
294
+ ]
295
+ )
296
+ if self.use_past:
297
+ common_inputs["decoder_input_ids"] = {0: "batch"}
298
+ else:
299
+ common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
300
+
301
+ if self.use_past:
302
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
303
+
304
+ return common_inputs
305
+
306
+ def generate_dummy_inputs(
307
+ self,
308
+ preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
309
+ batch_size: int = -1,
310
+ seq_length: int = -1,
311
+ is_pair: bool = False,
312
+ framework: Optional["TensorType"] = None,
313
+ sampling_rate: int = 22050,
314
+ time_duration: float = 5.0,
315
+ frequency: int = 220,
316
+ ) -> Mapping[str, Any]:
317
+ dummy_inputs = OrderedDict()
318
+ encoder_inputs = OnnxConfig.generate_dummy_inputs(
319
+ self,
320
+ preprocessor=preprocessor.feature_extractor,
321
+ batch_size=batch_size,
322
+ framework=framework,
323
+ sampling_rate=sampling_rate,
324
+ time_duration=time_duration,
325
+ frequency=frequency,
326
+ )
327
+ encoder_sequence_length = encoder_inputs["input_features"].shape[2]
328
+ seq_length = encoder_sequence_length // 2 if self.use_past else seq_length
329
+
330
+ decoder_inputs = super().generate_dummy_inputs(
331
+ preprocessor.tokenizer, batch_size, seq_length, is_pair, framework
332
+ )
333
+
334
+ dummy_inputs["input_features"] = encoder_inputs.pop("input_features")
335
+ dummy_inputs["decoder_input_ids"] = decoder_inputs.pop("decoder_input_ids")
336
+
337
+ if "past_key_values" in decoder_inputs:
338
+ dummy_inputs["past_key_values"] = decoder_inputs.pop("past_key_values")
339
+
340
+ return dummy_inputs
341
+
342
+ @property
343
+ def atol_for_validation(self) -> float:
344
+ return 1e-3
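A quick illustration of how the configuration defined above behaves in practice (a minimal sketch, assuming a standard `transformers` installation; it only touches attributes that appear in the file shown):

```python
from transformers import WhisperConfig

# The defaults above correspond to a whisper-tiny sized model.
config = WhisperConfig()

# `attribute_map` aliases generic names onto Whisper-specific ones:
# `hidden_size` resolves to `d_model`, `num_attention_heads` to
# `encoder_attention_heads`.
print(config.hidden_size)          # 384 (same value as config.d_model)
print(config.num_attention_heads)  # 6

# Overrides flow through __init__; num_hidden_layers mirrors encoder_layers.
custom = WhisperConfig(d_model=512, encoder_layers=6, decoder_layers=6)
print(custom.num_hidden_layers)    # 6
```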
configuration_whisper.cpython-312 (1).pyc ADDED
Binary file (15.8 kB).
 
configuration_whisper.cpython-312.pyc ADDED
Binary file (15.8 kB).
 
configuration_whisper.py ADDED
@@ -0,0 +1,344 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Whisper model configuration"""
16
+
17
+ from collections import OrderedDict
18
+ from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
19
+
20
+ from ...configuration_utils import PretrainedConfig
21
+ from ...onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
22
+ from ...utils import logging
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ from ...feature_extraction_utils import FeatureExtractionMixin
27
+ from ...tokenization_utils_base import PreTrainedTokenizerBase
28
+ from ...utils import TensorType
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ # fmt: off
34
+ NON_SPEECH_TOKENS = [
35
+ 1, 2, 7, 8, 9, 10, 14, 25,
36
+ 26, 27, 28, 29, 31, 58, 59, 60, 61, 62,
37
+ 63, 90, 91, 92, 93, 357, 366, 438, 532, 685,
38
+ 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377,
39
+ 1391, 1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211,
40
+ 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, 10563, 10786,
41
+ 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791,
42
+ 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409,
43
+ 34949, 40283, 40493, 40549, 47282, 49146, 50257, 50359, 50360, 50361
44
+ ]
45
+ NON_SPEECH_TOKENS_MULTI = [
46
+ 1, 2, 7, 8, 9, 10, 14, 25,
47
+ 26, 27, 28, 29, 31, 58, 59, 60, 61, 62,
48
+ 63, 90, 91, 92, 93, 359, 503, 522, 542, 873,
49
+ 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627,
50
+ 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647,
51
+ 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793,
52
+ 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675,
53
+ 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865,
54
+ 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362
55
+ ]
56
+ # fmt: on
57
+
58
+
59
+ class WhisperConfig(PretrainedConfig):
60
+ r"""
61
+ This is the configuration class to store the configuration of a [`WhisperModel`]. It is used to instantiate a
62
+ Whisper model according to the specified arguments, defining the model architecture. Instantiating a configuration
63
+ with the defaults will yield a similar configuration to that of the Whisper
64
+ [openai/whisper-tiny](https://huggingface.co/openai/whisper-tiny) architecture.
65
+
66
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
67
+ documentation from [`PretrainedConfig`] for more information.
68
+
69
+
70
+ Args:
71
+ vocab_size (`int`, *optional*, defaults to 51865):
72
+ Vocabulary size of the Whisper model. Defines the number of different tokens that can be represented by the
73
+ `decoder_input_ids` passed when calling [`WhisperModel`]
74
+ num_mel_bins (`int`, *optional*, defaults to 80):
75
+ Number of mel features used per input features. Should correspond to the value used in the
76
+ `WhisperProcessor` class.
77
+ encoder_layers (`int`, *optional*, defaults to 4):
78
+ Number of encoder layers.
79
+ decoder_layers (`int`, *optional*, defaults to 4):
80
+ Number of decoder layers.
81
+ encoder_attention_heads (`int`, *optional*, defaults to 6):
82
+ Number of attention heads for each attention layer in the Transformer encoder.
83
+ decoder_attention_heads (`int`, *optional*, defaults to 6):
84
+ Number of attention heads for each attention layer in the Transformer decoder.
85
+ encoder_ffn_dim (`int`, *optional*, defaults to 1536):
86
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
87
+ decoder_ffn_dim (`int`, *optional*, defaults to 1536):
88
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
89
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
90
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
91
+ for more details.
92
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
93
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
94
+ for more details.
95
+ decoder_start_token_id (`int`, *optional*, defaults to 50257):
96
+ Corresponds to the "<|startoftranscript|>" token, which is automatically used when no `decoder_input_ids`
97
+ are provided to the `generate` function. It is used to guide the model`s generation process depending on
98
+ the task.
99
+ use_cache (`bool`, *optional*, defaults to `True`):
100
+ Whether or not the model should return the last key/values attentions (not used by all models).
101
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
102
+ Whether the model is used as an encoder/decoder or not.
103
+ activation_function (`str`, *optional*, defaults to `"gelu"`):
104
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
105
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
106
+ d_model (`int`, *optional*, defaults to 384):
107
+ Dimensionality of the layers.
108
+ dropout (`float`, *optional*, defaults to 0.1):
109
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
110
+ attention_dropout (`float`, *optional*, defaults to 0.0):
111
+ The dropout ratio for the attention probabilities.
112
+ activation_dropout (`float`, *optional*, defaults to 0.0):
113
+ The dropout ratio for activations inside the fully connected layer.
114
+ init_std (`float`, *optional*, defaults to 0.02):
115
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
116
+ scale_embedding (`bool`, *optional*, defaults to False):
117
+ Scale embeddings by diving by sqrt(d_model).
118
+ max_source_positions (`int`, *optional*, defaults to 1500):
119
+ The maximum sequence length of log-mel filter-bank features that this model might ever be used with.
120
+ max_target_positions (`int`, *optional*, defaults to 448):
121
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
122
+ just in case (e.g., 512 or 1024 or 2048).
123
+ pad_token_id (`int`, *optional*, defaults to 50256):
124
+ Padding token id.
125
+ bos_token_id (`int`, *optional*, defaults to 50256):
126
+ Begin of stream token id.
127
+ eos_token_id (`int`, *optional*, defaults to 50256):
128
+ End of stream token id.
129
+ suppress_tokens (`List[int]`, *optional*):
130
+ A list containing the non-speech tokens that will be used by the logit processor in the `generate`
131
+ function. NON_SPEECH_TOKENS and NON_SPEECH_TOKENS_MULTI each correspond to the `english-only` and the
132
+ `multilingual` model.
133
+ begin_suppress_tokens (`List[int]`, *optional*, defaults to `[220,50256]`):
134
+ A list containing tokens that will be supressed at the beginning of the sampling process. Initialized as
135
+ the token for `" "` (`blank_token_id`) and the `eos_token_id`
136
+ use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
137
+ Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
138
+ instance of [`WhisperForAudioClassification`].
139
+ classifier_proj_size (`int`, *optional*, defaults to 256):
140
+ Dimensionality of the projection before token mean-pooling for classification. Only relevant when using an
141
+ instance of [`WhisperForAudioClassification`].
142
+ apply_spec_augment (`bool`, *optional*, defaults to `False`):
143
+ Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
144
+ [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
145
+ Recognition](https://arxiv.org/abs/1904.08779).
146
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
147
+ Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
148
+ procecure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
149
+ reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
150
+ masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
151
+ actual percentage of masked vectors. This is only relevant if `apply_spec_augment == True`.
152
+ mask_time_length (`int`, *optional*, defaults to 10):
153
+ Length of vector span along the time axis.
154
+ mask_time_min_masks (`int`, *optional*, defaults to 2),:
155
+ The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
156
+ irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
157
+ mask_time_min_masks''
158
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
159
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
160
+ masking procecure generates `mask_feature_prob*len(feature_axis)/mask_time_length` independent masks over
161
+ the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
162
+ span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
163
+ may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
164
+ True`.
165
+ mask_feature_length (`int`, *optional*, defaults to 10):
166
+ Length of vector span along the feature axis.
167
+ mask_feature_min_masks (`int`, *optional*, defaults to 0),:
168
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
169
+ step, irrespectively of `mask_feature_prob`. Only relevant if
170
+ `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
171
+ median_filter_width (`int`, *optional*, defaults to 7):
172
+ Width of the median filter used to smoothen to cross-attention outputs when computing token timestamps.
173
+ Should be an odd number.
174
+
175
+ Example:
176
+
177
+ ```python
178
+ >>> from transformers import WhisperConfig, WhisperModel
179
+
180
+ >>> # Initializing a Whisper tiny style configuration
181
+ >>> configuration = WhisperConfig()
182
+
183
+ >>> # Initializing a model (with random weights) from the tiny style configuration
184
+ >>> model = WhisperModel(configuration)
185
+
186
+ >>> # Accessing the model configuration
187
+ >>> configuration = model.config
188
+ ```"""
189
+
190
+ model_type = "whisper"
191
+ keys_to_ignore_at_inference = ["past_key_values"]
192
+ attribute_map = {
193
+ "num_key_value_heads": "encoder_attention_heads",
194
+ "num_attention_heads": "encoder_attention_heads",
195
+ "hidden_size": "d_model",
196
+ }
197
+
198
+ def __init__(
199
+ self,
200
+ vocab_size=51865,
201
+ num_mel_bins=80,
202
+ encoder_layers=4,
203
+ encoder_attention_heads=6,
204
+ decoder_layers=4,
205
+ decoder_attention_heads=6,
206
+ decoder_ffn_dim=1536,
207
+ encoder_ffn_dim=1536,
208
+ encoder_layerdrop=0.0,
209
+ decoder_layerdrop=0.0,
210
+ decoder_start_token_id=50257,
211
+ use_cache=True,
212
+ is_encoder_decoder=True,
213
+ activation_function="gelu",
214
+ d_model=384,
215
+ dropout=0.0,
216
+ attention_dropout=0.0,
217
+ activation_dropout=0.0,
218
+ init_std=0.02,
219
+ scale_embedding=False,
220
+ max_source_positions=1500,
221
+ max_target_positions=448,
222
+ pad_token_id=50256,
223
+ bos_token_id=50256,
224
+ eos_token_id=50256,
225
+ suppress_tokens=None,
226
+ begin_suppress_tokens=[220, 50256],
227
+ use_weighted_layer_sum=False,
228
+ classifier_proj_size=256,
229
+ apply_spec_augment=False,
230
+ mask_time_prob=0.05,
231
+ mask_time_length=10,
232
+ mask_time_min_masks=2,
233
+ mask_feature_prob=0.0,
234
+ mask_feature_length=10,
235
+ mask_feature_min_masks=0,
236
+ median_filter_width=7,
237
+ **kwargs,
238
+ ):
239
+ self.vocab_size = vocab_size
240
+ self.num_mel_bins = num_mel_bins
241
+ self.d_model = d_model
242
+ self.encoder_layers = encoder_layers
243
+ self.encoder_attention_heads = encoder_attention_heads
244
+ self.decoder_layers = decoder_layers
245
+ self.decoder_attention_heads = decoder_attention_heads
246
+ self.decoder_ffn_dim = decoder_ffn_dim
247
+ self.encoder_ffn_dim = encoder_ffn_dim
248
+ self.dropout = dropout
249
+ self.attention_dropout = attention_dropout
250
+ self.activation_dropout = activation_dropout
251
+ self.activation_function = activation_function
252
+ self.init_std = init_std
253
+ self.encoder_layerdrop = encoder_layerdrop
254
+ self.decoder_layerdrop = decoder_layerdrop
255
+ self.use_cache = use_cache
256
+ self.num_hidden_layers = encoder_layers
257
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
258
+ self.max_source_positions = max_source_positions
259
+ self.max_target_positions = max_target_positions
260
+
261
+ # Audio Classification-specific parameters. Feel free to ignore for other classes.
262
+ self.classifier_proj_size = classifier_proj_size
263
+ self.use_weighted_layer_sum = use_weighted_layer_sum
264
+
265
+ # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
266
+ self.apply_spec_augment = apply_spec_augment
267
+ self.mask_time_prob = mask_time_prob
268
+ self.mask_time_length = mask_time_length
269
+ self.mask_time_min_masks = mask_time_min_masks
270
+ self.mask_feature_prob = mask_feature_prob
271
+ self.mask_feature_length = mask_feature_length
272
+ self.mask_feature_min_masks = mask_feature_min_masks
273
+
274
+ self.median_filter_width = median_filter_width
275
+
276
+ super().__init__(
277
+ pad_token_id=pad_token_id,
278
+ bos_token_id=bos_token_id,
279
+ eos_token_id=eos_token_id,
280
+ is_encoder_decoder=is_encoder_decoder,
281
+ decoder_start_token_id=decoder_start_token_id,
282
+ suppress_tokens=suppress_tokens,
283
+ begin_suppress_tokens=begin_suppress_tokens,
284
+ **kwargs,
285
+ )
286
+
287
+
288
+ class WhisperOnnxConfig(OnnxSeq2SeqConfigWithPast):
289
+ @property
290
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
291
+ common_inputs = OrderedDict(
292
+ [
293
+ ("input_features", {0: "batch", 1: "feature_size", 2: "encoder_sequence"}),
294
+ ]
295
+ )
296
+ if self.use_past:
297
+ common_inputs["decoder_input_ids"] = {0: "batch"}
298
+ else:
299
+ common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
300
+
301
+ if self.use_past:
302
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
303
+
304
+ return common_inputs
305
+
306
+ def generate_dummy_inputs(
307
+ self,
308
+ preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
309
+ batch_size: int = -1,
310
+ seq_length: int = -1,
311
+ is_pair: bool = False,
312
+ framework: Optional["TensorType"] = None,
313
+ sampling_rate: int = 22050,
314
+ time_duration: float = 5.0,
315
+ frequency: int = 220,
316
+ ) -> Mapping[str, Any]:
317
+ dummy_inputs = OrderedDict()
318
+ encoder_inputs = OnnxConfig.generate_dummy_inputs(
319
+ self,
320
+ preprocessor=preprocessor.feature_extractor,
321
+ batch_size=batch_size,
322
+ framework=framework,
323
+ sampling_rate=sampling_rate,
324
+ time_duration=time_duration,
325
+ frequency=frequency,
326
+ )
327
+ encoder_sequence_length = encoder_inputs["input_features"].shape[2]
328
+ seq_length = encoder_sequence_length // 2 if self.use_past else seq_length
329
+
330
+ decoder_inputs = super().generate_dummy_inputs(
331
+ preprocessor.tokenizer, batch_size, seq_length, is_pair, framework
332
+ )
333
+
334
+ dummy_inputs["input_features"] = encoder_inputs.pop("input_features")
335
+ dummy_inputs["decoder_input_ids"] = decoder_inputs.pop("decoder_input_ids")
336
+
337
+ if "past_key_values" in decoder_inputs:
338
+ dummy_inputs["past_key_values"] = decoder_inputs.pop("past_key_values")
339
+
340
+ return dummy_inputs
341
+
342
+ @property
343
+ def atol_for_validation(self) -> float:
344
+ return 1e-3
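The `WhisperOnnxConfig` at the end of this file declares the dynamic axes and dummy inputs used for ONNX export. Below is a minimal sketch of how it can be inspected, assuming the legacy `transformers.onnx` utilities are importable in the installed version; the import path is an assumption based on the module layout shown here:

```python
from transformers import WhisperConfig
from transformers.models.whisper.configuration_whisper import WhisperOnnxConfig

# Assumed constructor usage: wrap a model config, use_past defaults to False.
onnx_config = WhisperOnnxConfig(WhisperConfig())

# With use_past False, `inputs` exposes
# input_features -> (batch, feature_size, encoder_sequence) and
# decoder_input_ids -> (batch, decoder_sequence).
print(onnx_config.inputs)
print(onnx_config.atol_for_validation)  # 1e-3
```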
feature_extraction_whisper (1).py ADDED
@@ -0,0 +1,324 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Feature extractor class for Whisper
17
+ """
18
+
19
+ from typing import List, Optional, Union
20
+
21
+ import numpy as np
22
+
23
+ from ... import is_torch_available
24
+ from ...audio_utils import mel_filter_bank, spectrogram, window_function
25
+ from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
26
+ from ...feature_extraction_utils import BatchFeature
27
+ from ...utils import TensorType, logging
28
+
29
+
30
+ if is_torch_available():
31
+ import torch
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ class WhisperFeatureExtractor(SequenceFeatureExtractor):
37
+ r"""
38
+ Constructs a Whisper feature extractor.
39
+
40
+ This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
41
+ most of the main methods. Users should refer to this superclass for more information regarding those methods.
42
+
43
+ This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the `Short Time
44
+ Fourier Transform` which should match pytorch's `torch.stft` equivalent.
45
+
46
+ Args:
47
+ feature_size (`int`, *optional*, defaults to 80):
48
+ The feature dimension of the extracted features.
49
+ sampling_rate (`int`, *optional*, defaults to 16000):
50
+ The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
51
+ hop_length (`int`, *optional*, defaults to 160):
52
+ Length of the overlaping windows for the STFT used to obtain the Mel Frequency coefficients.
53
+ chunk_length (`int`, *optional*, defaults to 30):
54
+ The maximum number of chuncks of `sampling_rate` samples used to trim and pad longer or shorter audio
55
+ sequences.
56
+ n_fft (`int`, *optional*, defaults to 400):
57
+ Size of the Fourier transform.
58
+ padding_value (`float`, *optional*, defaults to 0.0):
59
+ Padding value used to pad the audio. Should correspond to silences.
60
+ """
61
+
62
+ model_input_names = ["input_features"]
63
+
64
+ def __init__(
65
+ self,
66
+ feature_size=80,
67
+ sampling_rate=16000,
68
+ hop_length=160,
69
+ chunk_length=30,
70
+ n_fft=400,
71
+ padding_value=0.0,
72
+ return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
73
+ **kwargs,
74
+ ):
75
+ super().__init__(
76
+ feature_size=feature_size,
77
+ sampling_rate=sampling_rate,
78
+ padding_value=padding_value,
79
+ return_attention_mask=return_attention_mask,
80
+ **kwargs,
81
+ )
82
+ self.n_fft = n_fft
83
+ self.hop_length = hop_length
84
+ self.chunk_length = chunk_length
85
+ self.n_samples = chunk_length * sampling_rate
86
+ self.nb_max_frames = self.n_samples // hop_length
87
+ self.sampling_rate = sampling_rate
88
+ self.mel_filters = mel_filter_bank(
89
+ num_frequency_bins=1 + n_fft // 2,
90
+ num_mel_filters=feature_size,
91
+ min_frequency=0.0,
92
+ max_frequency=8000.0,
93
+ sampling_rate=sampling_rate,
94
+ norm="slaney",
95
+ mel_scale="slaney",
96
+ )
97
+
98
+ def _np_extract_fbank_features(self, waveform_batch: np.array, device: str) -> np.ndarray:
99
+ """
100
+ Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch
101
+ implementation with 1e-5 tolerance.
102
+ """
103
+ if device != "cpu":
104
+ raise ValueError(
105
+ f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator "
106
+ "devices requires torch, which is not installed. Either set `device='cpu'`, or "
107
+ "install torch according to the official instructions: https://pytorch.org/get-started/locally/"
108
+ )
109
+ log_spec_batch = []
110
+ for waveform in waveform_batch:
111
+ log_spec = spectrogram(
112
+ waveform,
113
+ window_function(self.n_fft, "hann"),
114
+ frame_length=self.n_fft,
115
+ hop_length=self.hop_length,
116
+ power=2.0,
117
+ mel_filters=self.mel_filters,
118
+ log_mel="log10",
119
+ )
120
+ log_spec = log_spec[:, :-1]
121
+ log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
122
+ log_spec = (log_spec + 4.0) / 4.0
123
+ log_spec_batch.append(log_spec)
124
+ log_spec_batch = np.array(log_spec_batch)
125
+ return log_spec_batch
126
+
127
+ def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu") -> np.ndarray:
128
+ """
129
+ Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching,
130
+ yielding results similar to cpu computing with 1e-5 tolerance.
131
+ """
132
+ waveform = torch.from_numpy(waveform).type(torch.float32)
133
+
134
+ window = torch.hann_window(self.n_fft)
135
+ if device != "cpu":
136
+ waveform = waveform.to(device)
137
+ window = window.to(device)
138
+ stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
139
+ magnitudes = stft[..., :-1].abs() ** 2
140
+
141
+ mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
142
+ if device != "cpu":
143
+ mel_filters = mel_filters.to(device)
144
+ mel_spec = mel_filters.T @ magnitudes
145
+
146
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
147
+ if waveform.dim() == 2:
148
+ max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
149
+ log_spec = torch.maximum(log_spec, max_val - 8.0)
150
+ else:
151
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
152
+ log_spec = (log_spec + 4.0) / 4.0
153
+ if device != "cpu":
154
+ log_spec = log_spec.detach().cpu()
155
+ return log_spec.numpy()
156
+
157
+ @staticmethod
158
+ # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
159
+ def zero_mean_unit_var_norm(
160
+ input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
161
+ ) -> List[np.ndarray]:
162
+ """
163
+ Every array in the list is normalized to have zero mean and unit variance
164
+ """
165
+ if attention_mask is not None:
166
+ attention_mask = np.array(attention_mask, np.int32)
167
+ normed_input_values = []
168
+
169
+ for vector, length in zip(input_values, attention_mask.sum(-1)):
170
+ normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
171
+ if length < normed_slice.shape[0]:
172
+ normed_slice[length:] = padding_value
173
+
174
+ normed_input_values.append(normed_slice)
175
+ else:
176
+ normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
177
+
178
+ return normed_input_values
179
+
180
+ def __call__(
181
+ self,
182
+ raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
183
+ truncation: bool = True,
184
+ pad_to_multiple_of: Optional[int] = None,
185
+ return_tensors: Optional[Union[str, TensorType]] = None,
186
+ return_attention_mask: Optional[bool] = None,
187
+ padding: Optional[str] = "max_length",
188
+ max_length: Optional[int] = None,
189
+ sampling_rate: Optional[int] = None,
190
+ do_normalize: Optional[bool] = None,
191
+ device: Optional[str] = "cpu",
192
+ return_token_timestamps: Optional[bool] = None,
193
+ **kwargs,
194
+ ) -> BatchFeature:
195
+ """
196
+ Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch for
197
+ the STFT computation if available, otherwise a slower NumPy based one.
198
+
199
+ Args:
200
+ raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
201
+ The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
202
+ values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
203
+ stereo, i.e. single float per timestep.
204
+ truncation (`bool`, *optional*, default to `True`):
205
+ Activates truncation to cut input sequences longer than *max_length* to *max_length*.
206
+ pad_to_multiple_of (`int`, *optional*, defaults to None):
207
+ If set will pad the sequence to a multiple of the provided value.
208
+
209
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
210
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
211
+ return_attention_mask (`bool`, *optional*):
212
+ Whether to return the attention mask. If left to the default, will return the attention mask according
213
+ to the specific feature_extractor's default.
214
+
215
+ [What are attention masks?](../glossary#attention-mask)
216
+
217
+ <Tip>
218
+
219
+ For Whisper models, `attention_mask` should always be passed for batched inference, to avoid subtle
220
+ bugs.
221
+
222
+ </Tip>
223
+
224
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
225
+ If set, will return tensors instead of list of python integers. Acceptable values are:
226
+
227
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
228
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
229
+ - `'np'`: Return Numpy `np.ndarray` objects.
230
+ sampling_rate (`int`, *optional*):
231
+ The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
232
+ `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
233
+ pipeline.
234
+ padding_value (`float`, *optional*, defaults to 0.0):
235
+ The value that is used to fill the padding values / vectors.
236
+ do_normalize (`bool`, *optional*, defaults to `False`):
237
+ Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
238
+ improve the performance of the model.
239
+ device (`str`, *optional*, defaults to `'cpu'`):
240
+ Specifies the device for computation of the log-mel spectrogram of audio signals in the
241
+ `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
242
+ return_token_timestamps (`bool`, *optional*, defaults to `None`):
243
+ Whether or not to return the number of frames of the input raw_speech.
244
+ These num_frames can be used by the model to compute word level timestamps.
245
+ """
246
+
247
+ if sampling_rate is not None:
248
+ if sampling_rate != self.sampling_rate:
249
+ raise ValueError(
250
+ f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
251
+ f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
252
+ f" was sampled with {self.sampling_rate} and not {sampling_rate}."
253
+ )
254
+ else:
255
+ logger.warning(
256
+ "It is strongly recommended to pass the `sampling_rate` argument to this function. "
257
+ "Failing to do so can result in silent errors that might be hard to debug."
258
+ )
259
+
260
+ is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
261
+ if is_batched_numpy and len(raw_speech.shape) > 2:
262
+ raise ValueError(f"Only mono-channel audio is supported for input to {self}")
263
+ is_batched = is_batched_numpy or (
264
+ isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
265
+ )
266
+
267
+ if is_batched:
268
+ raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
269
+ elif not is_batched and not isinstance(raw_speech, np.ndarray):
270
+ raw_speech = np.asarray(raw_speech, dtype=np.float32)
271
+ elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
272
+ raw_speech = raw_speech.astype(np.float32)
273
+
274
+ # always return batch
275
+ if not is_batched:
276
+ raw_speech = [np.asarray([raw_speech]).T]
277
+
278
+ batched_speech = BatchFeature({"input_features": raw_speech})
279
+
280
+ # convert into correct format for padding
281
+
282
+ padded_inputs = self.pad(
283
+ batched_speech,
284
+ padding=padding,
285
+ max_length=max_length if max_length else self.n_samples,
286
+ truncation=truncation,
287
+ pad_to_multiple_of=pad_to_multiple_of,
288
+ return_attention_mask=return_attention_mask or do_normalize,
289
+ )
290
+
291
+ # zero-mean and unit-variance normalization
292
+ if do_normalize:
293
+ padded_inputs["input_features"] = self.zero_mean_unit_var_norm(
294
+ padded_inputs["input_features"],
295
+ attention_mask=padded_inputs["attention_mask"],
296
+ padding_value=self.padding_value,
297
+ )
298
+ padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], axis=0)
299
+
300
+ # make sure list is in array format
301
+ input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
302
+
303
+ extract_fbank_features = (
304
+ self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
305
+ )
306
+ input_features = extract_fbank_features(input_features[0], device)
307
+
308
+ if isinstance(input_features[0], List):
309
+ padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
310
+
311
+ else:
312
+ padded_inputs["input_features"] = input_features
313
+
314
+ if return_attention_mask:
315
+ # rescale from sample (48000) to feature (3000)
316
+ padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
317
+
318
+ if return_token_timestamps is not None:
319
+ padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech]
320
+
321
+ if return_tensors is not None:
322
+ padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
323
+
324
+ return padded_inputs
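End to end, the feature extractor above turns raw mono audio into `(batch, 80, 3000)` log-mel features. A minimal usage sketch, assuming a standard `transformers` installation, with values taken from the defaults defined in the class:

```python
import numpy as np
from transformers import WhisperFeatureExtractor

# Defaults from the class above: 80 mel bins, 16 kHz sampling rate,
# 30 s chunks, hop length 160 -> 3000 frames per padded chunk.
feature_extractor = WhisperFeatureExtractor()

# One second of silence; __call__ pads it to 30 s (padding="max_length").
waveform = np.zeros(16000, dtype=np.float32)
features = feature_extractor(waveform, sampling_rate=16000, return_tensors="np")

print(features["input_features"].shape)  # (1, 80, 3000)
```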
feature_extraction_whisper.cpython-312 (1).pyc ADDED
Binary file (16.2 kB).
 
feature_extraction_whisper.cpython-312.pyc ADDED
Binary file (16.2 kB).
 
feature_extraction_whisper.py ADDED
@@ -0,0 +1,324 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Feature extractor class for Whisper
17
+ """
18
+
19
+ from typing import List, Optional, Union
20
+
21
+ import numpy as np
22
+
23
+ from ... import is_torch_available
24
+ from ...audio_utils import mel_filter_bank, spectrogram, window_function
25
+ from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
26
+ from ...feature_extraction_utils import BatchFeature
27
+ from ...utils import TensorType, logging
28
+
29
+
30
+ if is_torch_available():
31
+ import torch
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ class WhisperFeatureExtractor(SequenceFeatureExtractor):
37
+ r"""
38
+ Constructs a Whisper feature extractor.
39
+
40
+ This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
41
+ most of the main methods. Users should refer to this superclass for more information regarding those methods.
42
+
43
+ This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the `Short Time
44
+ Fourier Transform` which should match pytorch's `torch.stft` equivalent.
45
+
46
+ Args:
47
+ feature_size (`int`, *optional*, defaults to 80):
48
+ The feature dimension of the extracted features.
49
+ sampling_rate (`int`, *optional*, defaults to 16000):
50
+ The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
51
+ hop_length (`int`, *optional*, defaults to 160):
52
+ Length of the overlaping windows for the STFT used to obtain the Mel Frequency coefficients.
53
+ chunk_length (`int`, *optional*, defaults to 30):
54
+ The maximum number of chuncks of `sampling_rate` samples used to trim and pad longer or shorter audio
55
+ sequences.
56
+ n_fft (`int`, *optional*, defaults to 400):
57
+ Size of the Fourier transform.
58
+ padding_value (`float`, *optional*, defaults to 0.0):
59
+ Padding value used to pad the audio. Should correspond to silences.
60
+ """
61
+
62
+ model_input_names = ["input_features"]
63
+
64
+ def __init__(
65
+ self,
66
+ feature_size=80,
67
+ sampling_rate=16000,
68
+ hop_length=160,
69
+ chunk_length=30,
70
+ n_fft=400,
71
+ padding_value=0.0,
72
+ return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
73
+ **kwargs,
74
+ ):
75
+ super().__init__(
76
+ feature_size=feature_size,
77
+ sampling_rate=sampling_rate,
78
+ padding_value=padding_value,
79
+ return_attention_mask=return_attention_mask,
80
+ **kwargs,
81
+ )
82
+ self.n_fft = n_fft
83
+ self.hop_length = hop_length
84
+ self.chunk_length = chunk_length
85
+ self.n_samples = chunk_length * sampling_rate
86
+ self.nb_max_frames = self.n_samples // hop_length
87
+ self.sampling_rate = sampling_rate
88
+ self.mel_filters = mel_filter_bank(
89
+ num_frequency_bins=1 + n_fft // 2,
90
+ num_mel_filters=feature_size,
91
+ min_frequency=0.0,
92
+ max_frequency=8000.0,
93
+ sampling_rate=sampling_rate,
94
+ norm="slaney",
95
+ mel_scale="slaney",
96
+ )
97
+
98
+ def _np_extract_fbank_features(self, waveform_batch: np.array, device: str) -> np.ndarray:
99
+ """
100
+ Compute the log-mel spectrogram of the provided audio; gives results similar to Whisper's original torch
101
+ implementation, within a 1e-5 tolerance.
102
+ """
103
+ if device != "cpu":
104
+ raise ValueError(
105
+ f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator "
106
+ "devices requires torch, which is not installed. Either set `device='cpu'`, or "
107
+ "install torch according to the official instructions: https://pytorch.org/get-started/locally/"
108
+ )
109
+ log_spec_batch = []
110
+ for waveform in waveform_batch:
111
+ log_spec = spectrogram(
112
+ waveform,
113
+ window_function(self.n_fft, "hann"),
114
+ frame_length=self.n_fft,
115
+ hop_length=self.hop_length,
116
+ power=2.0,
117
+ mel_filters=self.mel_filters,
118
+ log_mel="log10",
119
+ )
120
+ log_spec = log_spec[:, :-1]
121
+ log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
122
+ log_spec = (log_spec + 4.0) / 4.0
123
+ log_spec_batch.append(log_spec)
124
+ log_spec_batch = np.array(log_spec_batch)
125
+ return log_spec_batch
126
+
127
+ def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu") -> np.ndarray:
128
+ """
129
+ Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching,
130
+ yielding results similar to the CPU implementation within a 1e-5 tolerance.
131
+ """
132
+ waveform = torch.from_numpy(waveform).type(torch.float32)
133
+
134
+ window = torch.hann_window(self.n_fft)
135
+ if device != "cpu":
136
+ waveform = waveform.to(device)
137
+ window = window.to(device)
138
+ stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
139
+ magnitudes = stft[..., :-1].abs() ** 2
140
+
141
+ mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
142
+ if device != "cpu":
143
+ mel_filters = mel_filters.to(device)
144
+ mel_spec = mel_filters.T @ magnitudes
145
+
146
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
147
+ if waveform.dim() == 2:
148
+ max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
149
+ log_spec = torch.maximum(log_spec, max_val - 8.0)
150
+ else:
151
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
152
+ log_spec = (log_spec + 4.0) / 4.0
153
+ if device != "cpu":
154
+ log_spec = log_spec.detach().cpu()
155
+ return log_spec.numpy()
156
+
157
+ @staticmethod
158
+ # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
159
+ def zero_mean_unit_var_norm(
160
+ input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
161
+ ) -> List[np.ndarray]:
162
+ """
163
+ Every array in the list is normalized to have zero mean and unit variance
164
+ """
165
+ if attention_mask is not None:
166
+ attention_mask = np.array(attention_mask, np.int32)
167
+ normed_input_values = []
168
+
169
+ for vector, length in zip(input_values, attention_mask.sum(-1)):
170
+ normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
171
+ if length < normed_slice.shape[0]:
172
+ normed_slice[length:] = padding_value
173
+
174
+ normed_input_values.append(normed_slice)
175
+ else:
176
+ normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
177
+
178
+ return normed_input_values
179
+
180
+ def __call__(
181
+ self,
182
+ raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
183
+ truncation: bool = True,
184
+ pad_to_multiple_of: Optional[int] = None,
185
+ return_tensors: Optional[Union[str, TensorType]] = None,
186
+ return_attention_mask: Optional[bool] = None,
187
+ padding: Optional[str] = "max_length",
188
+ max_length: Optional[int] = None,
189
+ sampling_rate: Optional[int] = None,
190
+ do_normalize: Optional[bool] = None,
191
+ device: Optional[str] = "cpu",
192
+ return_token_timestamps: Optional[bool] = None,
193
+ **kwargs,
194
+ ) -> BatchFeature:
195
+ """
196
+ Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch for
197
+ the STFT computation if available, otherwise a slower NumPy based one.
198
+
199
+ Args:
200
+ raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
201
+ The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
202
+ values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
203
+ stereo, i.e. single float per timestep.
204
+ truncation (`bool`, *optional*, defaults to `True`):
205
+ Activates truncation to cut input sequences longer than *max_length* to *max_length*.
206
+ pad_to_multiple_of (`int`, *optional*, defaults to None):
207
+ If set will pad the sequence to a multiple of the provided value.
208
+
209
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
210
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
211
+ return_attention_mask (`bool`, *optional*):
212
+ Whether to return the attention mask. If left to the default, will return the attention mask according
213
+ to the specific feature_extractor's default.
214
+
215
+ [What are attention masks?](../glossary#attention-mask)
216
+
217
+ <Tip>
218
+
219
+ For Whisper models, `attention_mask` should always be passed for batched inference, to avoid subtle
220
+ bugs.
221
+
222
+ </Tip>
223
+
224
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
225
+ If set, will return tensors instead of list of python integers. Acceptable values are:
226
+
227
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
228
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
229
+ - `'np'`: Return Numpy `np.ndarray` objects.
230
+ sampling_rate (`int`, *optional*):
231
+ The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
232
+ `sampling_rate` at the forward call to prevent silent errors and to allow the automatic speech recognition
233
+ pipeline to work correctly.
234
+ padding_value (`float`, *optional*, defaults to 0.0):
235
+ The value that is used to fill the padding values / vectors.
236
+ do_normalize (`bool`, *optional*, defaults to `False`):
237
+ Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
238
+ improve the performance of the model.
239
+ device (`str`, *optional*, defaults to `'cpu'`):
240
+ Specifies the device for computation of the log-mel spectrogram of audio signals in the
241
+ `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
242
+ return_token_timestamps (`bool`, *optional*, defaults to `None`):
243
+ Whether or not to return the number of frames of the input raw_speech.
244
+ These num_frames can be used by the model to compute word level timestamps.
245
+ """
246
+
247
+ if sampling_rate is not None:
248
+ if sampling_rate != self.sampling_rate:
249
+ raise ValueError(
250
+ f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
251
+ f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
252
+ f" was sampled with {self.sampling_rate} and not {sampling_rate}."
253
+ )
254
+ else:
255
+ logger.warning(
256
+ "It is strongly recommended to pass the `sampling_rate` argument to this function. "
257
+ "Failing to do so can result in silent errors that might be hard to debug."
258
+ )
259
+
260
+ is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
261
+ if is_batched_numpy and len(raw_speech.shape) > 2:
262
+ raise ValueError(f"Only mono-channel audio is supported for input to {self}")
263
+ is_batched = is_batched_numpy or (
264
+ isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
265
+ )
266
+
267
+ if is_batched:
268
+ raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
269
+ elif not is_batched and not isinstance(raw_speech, np.ndarray):
270
+ raw_speech = np.asarray(raw_speech, dtype=np.float32)
271
+ elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
272
+ raw_speech = raw_speech.astype(np.float32)
273
+
274
+ # always return batch
275
+ if not is_batched:
276
+ raw_speech = [np.asarray([raw_speech]).T]
277
+
278
+ batched_speech = BatchFeature({"input_features": raw_speech})
279
+
280
+ # convert into correct format for padding
281
+
282
+ padded_inputs = self.pad(
283
+ batched_speech,
284
+ padding=padding,
285
+ max_length=max_length if max_length else self.n_samples,
286
+ truncation=truncation,
287
+ pad_to_multiple_of=pad_to_multiple_of,
288
+ return_attention_mask=return_attention_mask or do_normalize,
289
+ )
290
+
291
+ # zero-mean and unit-variance normalization
292
+ if do_normalize:
293
+ padded_inputs["input_features"] = self.zero_mean_unit_var_norm(
294
+ padded_inputs["input_features"],
295
+ attention_mask=padded_inputs["attention_mask"],
296
+ padding_value=self.padding_value,
297
+ )
298
+ padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], axis=0)
299
+
300
+ # make sure list is in array format
301
+ input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
302
+
303
+ extract_fbank_features = (
304
+ self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
305
+ )
306
+ input_features = extract_fbank_features(input_features[0], device)
307
+
308
+ if isinstance(input_features[0], List):
309
+ padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
310
+
311
+ else:
312
+ padded_inputs["input_features"] = input_features
313
+
314
+ if return_attention_mask:
315
+ # rescale from samples (480000 = 30 s at 16 kHz) to features (3000 frames)
316
+ padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
317
+
318
+ if return_token_timestamps is not None:
319
+ padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech]
320
+
321
+ if return_tensors is not None:
322
+ padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
323
+
324
+ return padded_inputs
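The `__call__` method above is the entry point for turning raw audio into `input_features`. A minimal usage sketch, assuming the standard `transformers` package (which these Whisper modules appear to mirror) is installed and that a 16 kHz mono waveform is available as a NumPy array; the silent waveform below is only a stand-in for real audio:

```python
import numpy as np
from transformers import WhisperFeatureExtractor

# Stand-in for real audio: 5 seconds of silence at the expected 16 kHz rate.
waveform = np.zeros(5 * 16000, dtype=np.float32)

feature_extractor = WhisperFeatureExtractor()  # defaults: 80 mel bins, 30 s chunks

inputs = feature_extractor(
    waveform,
    sampling_rate=16000,  # always pass this to avoid the silent-error warning above
    return_tensors="np",
)

# Inputs are padded/truncated to 30 s, i.e. 3000 frames of 80 log-mel features.
print(inputs["input_features"].shape)  # (1, 80, 3000)
```

With `return_tensors="pt"` (and PyTorch installed) the same call returns `torch.FloatTensor`s that can be passed directly to a Whisper model.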
find-corrupt-whisper-files (1).py ADDED
@@ -0,0 +1,81 @@
1
+ #!C:\Python312\python.exe
2
+ # encoding: utf-8
3
+ """Find and (optionally) delete corrupt Whisper data files"""
4
+
5
+ import argparse
6
+ import os
7
+ import sys
8
+ import logging
9
+
10
+ try:
11
+ import whisper
12
+ except ImportError:
13
+ raise SystemExit("[ERROR] Please make sure Whisper is installed properly")
14
+
15
+
16
+ def setup_logging(verbose=False):
17
+ """Configure logging."""
18
+ logging.basicConfig(
19
+ level=logging.DEBUG if verbose else logging.INFO,
20
+ format="%(asctime)s [%(levelname)s]: %(message)s",
21
+ datefmt="%Y-%m-%d %H:%M:%S",
22
+ )
23
+
24
+
25
+ def walk_dir(base_dir, delete_corrupt=False, backup_corrupt=False):
26
+ """Walk through directories to find and handle corrupt Whisper files."""
27
+ total_files = 0
28
+ corrupt_files = 0
29
+ deleted_files = 0
30
+
31
+ for dirpath, _, filenames in os.walk(base_dir):
32
+ logging.info("Scanning %s...", dirpath)
33
+
34
+ whisper_files = (os.path.join(dirpath, f) for f in filenames if f.endswith(".wsp"))
35
+ for f in whisper_files:
36
+ total_files += 1
37
+ try:
38
+ info = whisper.info(f)
39
+ logging.debug("%s: %d points", f, sum(i["points"] for i in info.get("archives", {})))
40
+ except whisper.CorruptWhisperFile:
41
+ corrupt_files += 1
42
+ if backup_corrupt:
43
+ backup_path = f + ".bak"
44
+ try:
45
+ os.rename(f, backup_path)
46
+ logging.warning("Backed up corrupt file: %s -> %s", f, backup_path)
47
+ except OSError as e:
48
+ logging.error("Failed to back up %s: %s", f, e)
49
+ continue
50
+
51
+ if delete_corrupt:
52
+ try:
53
+ os.unlink(f)
54
+ deleted_files += 1
55
+ logging.warning("Deleted corrupt file: %s", f)
56
+ except OSError as e:
57
+ logging.error("Failed to delete %s: %s", f, e)
58
+ else:
59
+ logging.error("Corrupt Whisper file: %s", f)
60
+
61
+ logging.info("Summary: Scanned %d files, Found %d corrupt, Deleted %d", total_files, corrupt_files, deleted_files)
62
+ return total_files, corrupt_files, deleted_files
63
+
64
+
65
+ if __name__ == "__main__":
66
+ parser = argparse.ArgumentParser(description=__doc__.strip())
67
+ parser.add_argument("--delete-corrupt", action="store_true", help="Delete reported corrupt files")
68
+ parser.add_argument("--backup-corrupt", action="store_true", help="Back up corrupt files before deletion")
69
+ parser.add_argument("--verbose", action="store_true", help="Display detailed progress")
70
+ parser.add_argument("directories", type=str, nargs="+", metavar="WHISPER_DIR", help="Directory containing Whisper files")
71
+ args = parser.parse_args()
72
+
73
+ setup_logging(verbose=args.verbose)
74
+
75
+ for d in args.directories:
76
+ d = os.path.realpath(d)
77
+ if not os.path.isdir(d):
78
+ logging.error("%s is not a directory!", d)
79
+ continue
80
+
81
+ walk_dir(d, delete_corrupt=args.delete_corrupt, backup_corrupt=args.backup_corrupt)
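Because `setup_logging` and `walk_dir` are plain functions, the same check can also be driven from Python rather than the command line. A minimal sketch, assuming the script has been saved under an importable name such as `find_corrupt_whisper_files.py` (the module name and the data directory below are illustrative) and that the Graphite `whisper` package is installed:

```python
import logging

# Hypothetical module name; the script above would need to be saved/renamed accordingly.
from find_corrupt_whisper_files import setup_logging, walk_dir

setup_logging(verbose=True)

# Dry run: report corrupt .wsp files without backing them up or deleting them.
total, corrupt, deleted = walk_dir(
    "/var/lib/graphite/whisper",  # illustrative data directory
    delete_corrupt=False,
    backup_corrupt=False,
)
logging.info("scanned=%d corrupt=%d deleted=%d", total, corrupt, deleted)
```

Passing `delete_corrupt=True` mirrors the `--delete-corrupt` flag, and `backup_corrupt=True` renames each corrupt file to `<name>.bak` before any deletion, exactly as the argparse options do.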
find-corrupt-whisper-files (2).py ADDED
@@ -0,0 +1,81 @@
1
+ #!C:\Python312\python.exe
2
+ # encoding: utf-8
3
+ """Find and (optionally) delete corrupt Whisper data files"""
4
+
5
+ import argparse
6
+ import os
7
+ import sys
8
+ import logging
9
+
10
+ try:
11
+ import whisper
12
+ except ImportError:
13
+ raise SystemExit("[ERROR] Please make sure Whisper is installed properly")
14
+
15
+
16
+ def setup_logging(verbose=False):
17
+ """Configure logging."""
18
+ logging.basicConfig(
19
+ level=logging.DEBUG if verbose else logging.INFO,
20
+ format="%(asctime)s [%(levelname)s]: %(message)s",
21
+ datefmt="%Y-%m-%d %H:%M:%S",
22
+ )
23
+
24
+
25
+ def walk_dir(base_dir, delete_corrupt=False, backup_corrupt=False):
26
+ """Walk through directories to find and handle corrupt Whisper files."""
27
+ total_files = 0
28
+ corrupt_files = 0
29
+ deleted_files = 0
30
+
31
+ for dirpath, _, filenames in os.walk(base_dir):
32
+ logging.info("Scanning %s...", dirpath)
33
+
34
+ whisper_files = (os.path.join(dirpath, f) for f in filenames if f.endswith(".wsp"))
35
+ for f in whisper_files:
36
+ total_files += 1
37
+ try:
38
+ info = whisper.info(f)
39
+ logging.debug("%s: %d points", f, sum(i["points"] for i in info.get("archives", {})))
40
+ except whisper.CorruptWhisperFile:
41
+ corrupt_files += 1
42
+ if backup_corrupt:
43
+ backup_path = f + ".bak"
44
+ try:
45
+ os.rename(f, backup_path)
46
+ logging.warning("Backed up corrupt file: %s -> %s", f, backup_path)
47
+ except OSError as e:
48
+ logging.error("Failed to back up %s: %s", f, e)
49
+ continue
50
+
51
+ if delete_corrupt:
52
+ try:
53
+ os.unlink(f)
54
+ deleted_files += 1
55
+ logging.warning("Deleted corrupt file: %s", f)
56
+ except OSError as e:
57
+ logging.error("Failed to delete %s: %s", f, e)
58
+ else:
59
+ logging.error("Corrupt Whisper file: %s", f)
60
+
61
+ logging.info("Summary: Scanned %d files, Found %d corrupt, Deleted %d", total_files, corrupt_files, deleted_files)
62
+ return total_files, corrupt_files, deleted_files
63
+
64
+
65
+ if __name__ == "__main__":
66
+ parser = argparse.ArgumentParser(description=__doc__.strip())
67
+ parser.add_argument("--delete-corrupt", action="store_true", help="Delete reported corrupt files")
68
+ parser.add_argument("--backup-corrupt", action="store_true", help="Back up corrupt files before deletion")
69
+ parser.add_argument("--verbose", action="store_true", help="Display detailed progress")
70
+ parser.add_argument("directories", type=str, nargs="+", metavar="WHISPER_DIR", help="Directory containing Whisper files")
71
+ args = parser.parse_args()
72
+
73
+ setup_logging(verbose=args.verbose)
74
+
75
+ for d in args.directories:
76
+ d = os.path.realpath(d)
77
+ if not os.path.isdir(d):
78
+ logging.error("%s is not a directory!", d)
79
+ continue
80
+
81
+ walk_dir(d, delete_corrupt=args.delete_corrupt, backup_corrupt=args.backup_corrupt)
find-corrupt-whisper-files.py ADDED
@@ -0,0 +1,81 @@
1
+ #!C:\Python312\python.exe
2
+ # encoding: utf-8
3
+ """Find and (optionally) delete corrupt Whisper data files"""
4
+
5
+ import argparse
6
+ import os
7
+ import sys
8
+ import logging
9
+
10
+ try:
11
+ import whisper
12
+ except ImportError:
13
+ raise SystemExit("[ERROR] Please make sure Whisper is installed properly")
14
+
15
+
16
+ def setup_logging(verbose=False):
17
+ """Configure logging."""
18
+ logging.basicConfig(
19
+ level=logging.DEBUG if verbose else logging.INFO,
20
+ format="%(asctime)s [%(levelname)s]: %(message)s",
21
+ datefmt="%Y-%m-%d %H:%M:%S",
22
+ )
23
+
24
+
25
+ def walk_dir(base_dir, delete_corrupt=False, backup_corrupt=False):
26
+ """Walk through directories to find and handle corrupt Whisper files."""
27
+ total_files = 0
28
+ corrupt_files = 0
29
+ deleted_files = 0
30
+
31
+ for dirpath, _, filenames in os.walk(base_dir):
32
+ logging.info("Scanning %s...", dirpath)
33
+
34
+ whisper_files = (os.path.join(dirpath, f) for f in filenames if f.endswith(".wsp"))
35
+ for f in whisper_files:
36
+ total_files += 1
37
+ try:
38
+ info = whisper.info(f)
39
+ logging.debug("%s: %d points", f, sum(i["points"] for i in info.get("archives", {})))
40
+ except whisper.CorruptWhisperFile:
41
+ corrupt_files += 1
42
+ if backup_corrupt:
43
+ backup_path = f + ".bak"
44
+ try:
45
+ os.rename(f, backup_path)
46
+ logging.warning("Backed up corrupt file: %s -> %s", f, backup_path)
47
+ except OSError as e:
48
+ logging.error("Failed to back up %s: %s", f, e)
49
+ continue
50
+
51
+ if delete_corrupt:
52
+ try:
53
+ os.unlink(f)
54
+ deleted_files += 1
55
+ logging.warning("Deleted corrupt file: %s", f)
56
+ except OSError as e:
57
+ logging.error("Failed to delete %s: %s", f, e)
58
+ else:
59
+ logging.error("Corrupt Whisper file: %s", f)
60
+
61
+ logging.info("Summary: Scanned %d files, Found %d corrupt, Deleted %d", total_files, corrupt_files, deleted_files)
62
+ return total_files, corrupt_files, deleted_files
63
+
64
+
65
+ if __name__ == "__main__":
66
+ parser = argparse.ArgumentParser(description=__doc__.strip())
67
+ parser.add_argument("--delete-corrupt", action="store_true", help="Delete reported corrupt files")
68
+ parser.add_argument("--backup-corrupt", action="store_true", help="Back up corrupt files before deletion")
69
+ parser.add_argument("--verbose", action="store_true", help="Display detailed progress")
70
+ parser.add_argument("directories", type=str, nargs="+", metavar="WHISPER_DIR", help="Directory containing Whisper files")
71
+ args = parser.parse_args()
72
+
73
+ setup_logging(verbose=args.verbose)
74
+
75
+ for d in args.directories:
76
+ d = os.path.realpath(d)
77
+ if not os.path.isdir(d):
78
+ logging.error("%s is not a directory!", d)
79
+ continue
80
+
81
+ walk_dir(d, delete_corrupt=args.delete_corrupt, backup_corrupt=args.backup_corrupt)
generation_whisper (1).py ADDED
@@ -0,0 +1,1881 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import copy
16
+ import math
17
+ import warnings
18
+ import zlib
19
+ from typing import Callable, Iterator, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from torch import nn
25
+
26
+ from transformers.cache_utils import EncoderDecoderCache
27
+
28
+ from ...generation import GenerationConfig, GenerationMixin
29
+ from ...generation.logits_process import (
30
+ LogitsProcessorList,
31
+ SuppressTokensAtBeginLogitsProcessor,
32
+ SuppressTokensLogitsProcessor,
33
+ WhisperNoSpeechDetection,
34
+ WhisperTimeStampLogitsProcessor,
35
+ )
36
+ from ...generation.stopping_criteria import StoppingCriteriaList
37
+ from ...modeling_outputs import BaseModelOutput
38
+ from ...utils import logging
39
+ from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+
45
+ def _median_filter(inputs: torch.Tensor, filter_width: int) -> torch.Tensor:
46
+ """
47
+ Applies a median filter of width `filter_width` along the last dimension of the input.
48
+
49
+ The `inputs` tensor is assumed to be 3- or 4-dimensional.
50
+ """
51
+ if filter_width <= 0 or filter_width % 2 != 1:
52
+ raise ValueError("`filter_width` should be a positive odd number")
53
+
54
+ pad_width = filter_width // 2
55
+ if inputs.shape[-1] <= pad_width:
56
+ return inputs
57
+
58
+ # Pad the left and right edges.
59
+ inputs = nn.functional.pad(inputs, (pad_width, pad_width, 0, 0), mode="reflect")
60
+
61
+ # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
62
+ result = inputs.unfold(-1, filter_width, 1).sort()[0][..., pad_width]
63
+ return result
64
+
65
+
66
+ def _dynamic_time_warping(matrix: np.ndarray):
67
+ """
68
+ Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate
69
+ token-level timestamps.
70
+ """
71
+ output_length, input_length = matrix.shape
72
+ cost = np.ones((output_length + 1, input_length + 1), dtype=np.float32) * np.inf
73
+ trace = -np.ones((output_length + 1, input_length + 1), dtype=np.float32)
74
+
75
+ cost[0, 0] = 0
76
+ for j in range(1, input_length + 1):
77
+ for i in range(1, output_length + 1):
78
+ c0 = cost[i - 1, j - 1]
79
+ c1 = cost[i - 1, j]
80
+ c2 = cost[i, j - 1]
81
+
82
+ if c0 < c1 and c0 < c2:
83
+ c, t = c0, 0
84
+ elif c1 < c0 and c1 < c2:
85
+ c, t = c1, 1
86
+ else:
87
+ c, t = c2, 2
88
+
89
+ cost[i, j] = matrix[i - 1, j - 1] + c
90
+ trace[i, j] = t
91
+
92
+ # backtrace
93
+ i = trace.shape[0] - 1
94
+ j = trace.shape[1] - 1
95
+ trace[0, :] = 2
96
+ trace[:, 0] = 1
97
+
98
+ text_indices = []
99
+ time_indices = []
100
+ while i > 0 or j > 0:
101
+ text_indices.append(i - 1)
102
+ time_indices.append(j - 1)
103
+ if trace[i, j] == 0:
104
+ i -= 1
105
+ j -= 1
106
+ elif trace[i, j] == 1:
107
+ i -= 1
108
+ elif trace[i, j] == 2:
109
+ j -= 1
110
+ else:
111
+ raise RuntimeError(
112
+ f"Internal error in dynamic time warping. Unexpected trace[{i}, {j}]. Please file a bug report."
113
+ )
114
+
115
+ text_indices = np.array(text_indices)[::-1]
116
+ time_indices = np.array(time_indices)[::-1]
117
+ return text_indices, time_indices
118
+
119
+
120
+ def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name):
121
+ if logits_processor is not None:
122
+ logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None)
123
+ if logit_processor:
124
+ return getattr(logit_processor, attribute_name, None)
125
+ return None
126
+
127
+
128
+ def _pad_to_max_length(
129
+ current_segments,
130
+ pad_token_id,
131
+ device,
132
+ padding_side="right",
133
+ padding="longest",
134
+ bos_token_tensor=None,
135
+ cut_off_length=None,
136
+ ):
137
+ max_total_length = 0
138
+ sequences = []
139
+
140
+ if padding_side not in ["right", "left"]:
141
+ raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
142
+
143
+ if padding not in ["longest", "max_length"]:
144
+ raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
145
+ elif padding == "max_length" and cut_off_length is None:
146
+ raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
147
+
148
+ for current_segment_list in current_segments:
149
+ if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
150
+ sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
151
+
152
+ if cut_off_length is not None:
153
+ sequence = sequence[-cut_off_length:]
154
+
155
+ if bos_token_tensor is not None:
156
+ sequence = torch.cat([bos_token_tensor, sequence])
157
+
158
+ sequences.append(sequence)
159
+ max_total_length = max(max_total_length, len(sequences[-1]))
160
+ elif bos_token_tensor is not None:
161
+ sequences.append(bos_token_tensor)
162
+ else:
163
+ sequences.append(torch.tensor([], device=device))
164
+
165
+ max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
166
+ for i in range(len(current_segments)):
167
+ pad_length = max_total_length - len(sequences[i])
168
+ pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
169
+ sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
170
+
171
+ sequences = torch.stack(sequences, dim=0)
172
+ return sequences
173
+
174
+
175
+ class WhisperGenerationMixin(GenerationMixin):
176
+ def _extract_token_timestamps(
177
+ self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None, num_input_ids=None
178
+ ):
179
+ """
180
+ Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
181
+ map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
182
+ cross-attentions will be cropped before applying DTW.
183
+
184
+ Returns:
185
+ tensor containing the timestamps in seconds for each predicted token
186
+ """
187
+ # Create a list with `decoder_layers` elements, each a tensor of shape
188
+ # (batch size, attention_heads, output length, input length).
189
+ cross_attentions = []
190
+ for i in range(self.config.decoder_layers):
191
+ cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
192
+
193
+ # Select specific cross-attention layers and heads. This is a tensor
194
+ # of shape (batch size, num selected, output length, input length).
195
+ weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
196
+ weights = weights.permute([1, 0, 2, 3])
197
+
198
+ weight_length = None
199
+
200
+ if "beam_indices" in generate_outputs:
201
+ # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
202
+ # since the beam search strategy chooses the most probable sequences at the end of the search.
203
+ # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
204
+ weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
205
+ weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
206
+
207
+ # beam search takes `decoder_input_ids` into account in the `beam_indices` length
208
+ # but does not shift the beam_indices by the number of `decoder_input_ids`
209
+ beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length])
210
+ # we actually shift the beam indices here
211
+ beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids]
212
+
213
+ weights = weights[:, :, :weight_length]
214
+
215
+ # If beam index is still -1, it means that the associated token id is EOS
216
+ # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
217
+ beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
218
+
219
+ # Select the cross attention from the right beam for each output sequences
220
+ weights = torch.stack(
221
+ [
222
+ torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
223
+ for i in range(beam_indices.shape[1])
224
+ ],
225
+ dim=2,
226
+ )
227
+
228
+ # make sure timestamps are as long as weights
229
+ input_length = weight_length or cross_attentions[0].shape[2]
230
+ batch_size = generate_outputs.sequences.shape[0]
231
+ timestamps = torch.zeros(
232
+ (batch_size, input_length + 1), dtype=torch.float32, device=generate_outputs.sequences.device
233
+ )
234
+
235
+ if num_frames is not None:
236
+ # two cases:
237
+ # 1. num_frames is the same for each sample -> compute the DTW matrix for each sample in parallel
238
+ # 2. num_frames is different, compute the DTW matrix for each sample sequentially
239
+
240
+ # we're using np.unique because num_frames can be int/list/tuple
241
+ if isinstance(num_frames, int):
242
+ weights = weights[..., : num_frames // 2]
243
+
244
+ elif isinstance(num_frames, (list, tuple, np.ndarray)) and len(np.unique(num_frames)) == 1:
245
+ weights = weights[..., : num_frames[0] // 2]
246
+
247
+ elif isinstance(num_frames, (torch.Tensor)) and len(torch.unique(num_frames)) == 1:
248
+ weights = weights[..., : num_frames[0] // 2]
249
+
250
+ else:
251
+ # num_frames is of shape (batch_size,) whereas batch_size is truly batch_size*num_return_sequences
252
+ repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
253
+ num_frames = num_frames.cpu() if isinstance(num_frames, (torch.Tensor)) else num_frames
254
+ num_frames = np.repeat(num_frames, repeat_time)
255
+
256
+ if num_frames is None or isinstance(num_frames, int):
257
+ # Normalize and smooth the weights.
258
+ std = torch.std(weights, dim=-2, keepdim=True, unbiased=False)
259
+ mean = torch.mean(weights, dim=-2, keepdim=True)
260
+ weights = (weights - mean) / std
261
+ weights = _median_filter(weights, self.config.median_filter_width)
262
+
263
+ # Average the different cross-attention heads.
264
+ weights = weights.mean(dim=1)
265
+
266
+ # Perform dynamic time warping on each element of the batch.
267
+ for batch_idx in range(batch_size):
268
+ if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray, torch.Tensor)):
269
+ matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2]
270
+
271
+ # Normalize and smooth the weights.
272
+ std = torch.std(matrix, dim=-2, keepdim=True, unbiased=False)
273
+ mean = torch.mean(matrix, dim=-2, keepdim=True)
274
+ matrix = (matrix - mean) / std
275
+ matrix = _median_filter(matrix, self.config.median_filter_width)
276
+
277
+ # Average the different cross-attention heads.
278
+ matrix = matrix.mean(dim=0)
279
+ else:
280
+ matrix = weights[batch_idx]
281
+
282
+ text_indices, time_indices = _dynamic_time_warping(-matrix.cpu().double().numpy())
283
+ jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
284
+ jump_times = time_indices[jumps] * time_precision
285
+ timestamps[batch_idx, 1:] = torch.tensor(jump_times)
286
+
287
+ return timestamps
288
+
289
+ def generate(
290
+ self,
291
+ input_features: Optional[torch.Tensor] = None,
292
+ generation_config: Optional[GenerationConfig] = None,
293
+ logits_processor: Optional[LogitsProcessorList] = None,
294
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
295
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
296
+ synced_gpus: bool = False,
297
+ return_timestamps: Optional[bool] = None,
298
+ task: Optional[str] = None,
299
+ language: Optional[Union[str, List[str]]] = None,
300
+ is_multilingual: Optional[bool] = None,
301
+ prompt_ids: Optional[torch.Tensor] = None,
302
+ prompt_condition_type: Optional[str] = None, # first-segment, all-segments
303
+ condition_on_prev_tokens: Optional[bool] = None,
304
+ temperature: Optional[Union[float, Tuple[float, ...]]] = None,
305
+ compression_ratio_threshold: Optional[float] = None,
306
+ logprob_threshold: Optional[float] = None,
307
+ no_speech_threshold: Optional[float] = None,
308
+ num_segment_frames: Optional[int] = None,
309
+ attention_mask: Optional[torch.Tensor] = None,
310
+ time_precision: float = 0.02,
311
+ time_precision_features: float = 0.01,
312
+ return_token_timestamps: Optional[bool] = None,
313
+ return_segments: bool = False,
314
+ return_dict_in_generate: Optional[bool] = None,
315
+ **kwargs,
316
+ ):
317
+ """
318
+ Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids.
319
+
320
+ <Tip warning={true}>
321
+
322
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
323
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
324
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
325
+
326
+ For an overview of generation strategies and code examples, check out the [following
327
+ guide](./generation_strategies).
328
+
329
+ </Tip>
330
+
331
+ Parameters:
332
+ input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
333
+ Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
334
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
335
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
336
+ [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
337
+ tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
338
+ generation_config (`~generation.GenerationConfig`, *optional*):
339
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
340
+ passed to generate matching the attributes of `generation_config` will override them. If
341
+ `generation_config` is not provided, the default will be used, which has the following loading
342
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
343
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
344
+ default values, whose documentation should be checked to parameterize generation.
345
+ logits_processor (`LogitsProcessorList`, *optional*):
346
+ Custom logits processors that complement the default logits processors built from arguments and
347
+ generation config. If a logit processor is passed that is already created with the arguments or a
348
+ generation config an error is thrown. This feature is intended for advanced users.
349
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
350
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
351
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
352
+ generation config an error is thrown. This feature is intended for advanced users.
353
+ prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
354
+ If provided, this function constrains the beam search to allowed tokens only at each step. If not
355
+ provided, no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
356
+ `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
357
+ on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
358
+ for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
359
+ Retrieval](https://arxiv.org/abs/2010.00904).
360
+ synced_gpus (`bool`, *optional*, defaults to `False`):
361
+ Whether to continue running the while loop until max_length (needed to avoid deadlocking with
362
+ `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
363
+ return_timestamps (`bool`, *optional*):
364
+ Whether to return the timestamps with the text. This enables the `WhisperTimeStampLogitsProcessor`.
365
+ task (`str`, *optional*):
366
+ Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
367
+ will be updated accordingly.
368
+ language (`str` or list of `str`, *optional*):
369
+ Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. For
370
+ batched generation, a list of language tokens can be passed. You can find all the possible language
371
+ tokens in the `model.generation_config.lang_to_id` dictionary.
372
+ is_multilingual (`bool`, *optional*):
373
+ Whether or not the model is multilingual.
374
+ prompt_ids (`torch.Tensor`, *optional*):
375
+ Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
376
+ provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
377
+ transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
378
+ correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
379
+ prompt_condition_type (`str`, *optional*):
380
+ Only relevant for long-form transcription. Condition type of `prompt_ids`. 'first-segment' means only the first segment is conditioned on `prompt_ids`. 'all-segments' means each segment is conditioned on `prompt_ids`. Make sure to enable `condition_on_prev_tokens` for 'all-segments'.
381
+ Defaults to 'first-segment'. For short-term transcription only 'first-segment' is possible.
382
+ condition_on_prev_tokens (`bool`, *optional*):
383
+ Only relevant for long-form transcription. Whether to condition each segment on the previous segment.
384
+ As shown in [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
385
+ performance.
386
+ temperature (`float` or list of `float`, *optional*):
387
+ The temperature to be used for generation. Passing a single `float` value and `do_sample=True` activates
388
+ generation using sampling. For long-form transcription, temperature fallback can be activated by passing
389
+ a list of float values such as (0.0, 0.2, 0.4, 0.6, 0.8, 1.0). As shown in [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
390
+ performance.
391
+ compression_ratio_threshold (`float`, *optional*):
392
+ Only relevant for long-form transcription. If defined, the zlib compression rate of each segment will be computed. If the compression rate of
393
+ a segment is higher than `compression_ratio_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is
394
+ repeated using a higher temperature. The intuition behind this feature is that segments with very high compression rates
395
+ suffer from a lot of repetition. The unwanted repetition can be reduced by injecting more randomness by increasing the temperature. If `compression_ratio_threshold` is defined
396
+ make sure that `temperature` is a list of values. A common value for `compression_ratio_threshold` is 1.35.
397
+ As shown in [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
398
+ performance.
399
+ logprob_threshold (`float`, *optional*):
400
+ Only relevant for long-form transcription. If defined, the average log-probability of each segment will be computed. If the log-probability of
401
+ a given segment is lower than `logprob_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is
402
+ repeated using a higher temperature. The intuition behind this feature is that segments of low log-probability
403
+ can be improved by injecting more randomness by increasing the temperature. If `logprob_threshold` is defined
404
+ make sure that `temperature` is a list of values. A common value for `logprob_threshold` is -1.0.
405
+ As shown in [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
406
+ performance.
407
+ no_speech_threshold (`float`, *optional*):
408
+ Only relevant for long-form transcription. If defined, the "no-speech" token combined with the `logprob_threshold`
409
+ is used to determine whether a segment contains only silence. In this case, the transcription for this segment
410
+ is skipped.
411
+ As shown in [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
412
+ performance.
413
+ num_segment_frames (`int`, *optional*):
414
+ The number of frames a single segment is made of. If not defined, `num_segment_frames` defaults to the model's stride
415
+ times the maximum input length.
416
+ attention_mask (`torch.Tensor`, *optional*):
417
+ `attention_mask` needs to be passed when doing long-form transcription using a batch size > 1.
418
+ time_precision (`float`, *optional*, defaults to 0.02):
419
+ The duration of an output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
420
+ for 20 ms.
421
+ time_precision_features (`float`, *optional*, defaults to 0.01):
422
+ The duration represented by a feature frame in seconds.
423
+ return_token_timestamps (`bool`, *optional*):
424
+ Whether to return token-level timestamps with the text. This can be used with or without the
425
+ `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
426
+ words.
427
+ return_segments (`bool`, *optional*, defaults to `False`):
428
+ Whether to additionally return a list of all segments. Note that this option can only be enabled
429
+ when doing long-form transcription.
430
+ return_dict_in_generate (`bool`, *optional*, defaults to `False`):
431
+ Whether or not to return a [`~utils.ModelOutput`] instead of just returning the generated tokens.
432
+ Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when
433
+ `return_segments` is set to `True`. In this case the generation outputs of each segment are added to each
434
+ segment.
435
+ kwargs (`Dict[str, Any]`, *optional*):
436
+ Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
437
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
438
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
439
+
440
+ Return:
441
+ [`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
442
+ or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`.
443
+
444
+ If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True`, then a dictionary containing the generated sequence ids (under the key `sequences`) and a list of each generated segment is returned.
445
+
446
+ else if the passed input is <= 30 seconds / <= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
447
+
448
+ - [`~generation.GenerateEncoderDecoderOutput`],
449
+ - [`~generation.GenerateBeamEncoderDecoderOutput`]
450
+
451
+ else only the generated output sequence ids are returned.
452
+
453
+ Example:
454
+
455
+ - *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate.
456
+
457
+ ```python
458
+ >>> import torch
459
+ >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
460
+ >>> from datasets import load_dataset, Audio
461
+
462
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
463
+ >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
464
+ >>> model.cuda() # doctest: +IGNORE_RESULT
465
+
466
+ >>> # load audios > 30 seconds
467
+ >>> ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
468
+ >>> # resample to 16kHz
469
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=16000))
470
+ >>> # take first 8 audios and retrieve array
471
+ >>> audio = ds[:8]["audio"]
472
+ >>> audio = [x["array"] for x in audio]
473
+
474
+ >>> # make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio
475
+ >>> inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000)
476
+ >>> inputs = inputs.to("cuda", torch.float32)
477
+
478
+ >>> # transcribe audio to ids
479
+ >>> generated_ids = model.generate(**inputs)
480
+
481
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
482
+ >>> transcription[0]
483
+ " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile."
484
+ ```
485
+
486
+ - *Shortform transcription*: If passed mel input features are < 30 seconds, the whole audio will be transcribed with a single call to generate.
487
+
488
+ ```python
489
+ >>> import torch
490
+ >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
491
+ >>> from datasets import load_dataset
492
+
493
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
494
+ >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
495
+
496
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
497
+
498
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
499
+ >>> input_features = inputs.input_features
500
+
501
+ >>> generated_ids = model.generate(inputs=input_features)
502
+
503
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
504
+ >>> transcription
505
+ ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
506
+ ```
507
+
508
+ """
509
+ # 0. deprecate old inputs
510
+ if "inputs" in kwargs:
511
+ input_features = kwargs.pop("inputs")
512
+ warnings.warn(
513
+ "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
514
+ FutureWarning,
515
+ )
516
+
517
+ # 1. prepare generation config
518
+ generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
519
+
520
+ # 2. set global generate variables
521
+ input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
522
+ num_segment_frames = input_stride * self.config.max_source_positions
523
+ batch_size, total_input_frames = self._retrieve_total_input_frames(
524
+ input_features=input_features, input_stride=input_stride, kwargs=kwargs
525
+ )
526
+ is_shortform = total_input_frames <= num_segment_frames
527
+
528
+ # 3. Make sure generation config is correctly set
529
+ # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
530
+ return_dict_in_generate = self._set_return_outputs(
531
+ return_dict_in_generate=return_dict_in_generate,
532
+ return_token_timestamps=return_token_timestamps,
533
+ logprob_threshold=logprob_threshold,
534
+ generation_config=generation_config,
535
+ )
536
+ timestamp_begin = self._set_return_timestamps(
537
+ return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config
538
+ )
539
+ self._set_language_and_task(
540
+ language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
541
+ )
542
+ self._set_num_frames(
543
+ return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
544
+ )
545
+ self._set_thresholds_and_condition(
546
+ generation_config=generation_config,
547
+ logprob_threshold=logprob_threshold,
548
+ compression_ratio_threshold=compression_ratio_threshold,
549
+ no_speech_threshold=no_speech_threshold,
550
+ condition_on_prev_tokens=condition_on_prev_tokens,
551
+ )
552
+ self._set_prompt_condition_type(
553
+ generation_config=generation_config,
554
+ prompt_condition_type=prompt_condition_type,
555
+ )
556
+
557
+ # pass self.config for backward compatibility
558
+ init_tokens = self._retrieve_init_tokens(
559
+ input_features,
560
+ batch_size=batch_size,
561
+ generation_config=generation_config,
562
+ config=self.config,
563
+ num_segment_frames=num_segment_frames,
564
+ kwargs=kwargs,
565
+ )
566
+ # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
567
+ # where the input ids are handled explicitly by the generate method
568
+ self._check_decoder_input_ids(kwargs=kwargs)
569
+
570
+ # 3. Retrieve logits processors
571
+ device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
572
+ begin_index = init_tokens.shape[1]
573
+ logits_processor = self._retrieve_logit_processors(
574
+ generation_config=generation_config,
575
+ logits_processor=logits_processor,
576
+ begin_index=begin_index, # begin index is index of first generated decoder token
577
+ num_beams=kwargs.get("num_beams", 1),
578
+ device=device,
579
+ )
580
+
581
+ # 4 Set and retrieve global generation variables
582
+ self._set_condition_on_prev_tokens(
583
+ condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config
584
+ )
585
+
586
+ temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature
587
+ temperature = temperatures[0]
588
+
589
+ max_frames, seek = self._retrieve_max_frames_and_seek(
590
+ batch_size=batch_size,
591
+ attention_mask=attention_mask,
592
+ total_input_frames=total_input_frames,
593
+ is_shortform=is_shortform,
594
+ )
595
+
596
+ # 5 Prepare running variables, list for generation
597
+ num_return_sequences = generation_config.num_return_sequences
598
+ (
599
+ batch_idx_map,
600
+ cur_bsz,
601
+ input_features,
602
+ seek,
603
+ max_frames,
604
+ init_tokens,
605
+ do_condition_on_prev_tokens,
606
+ ) = self._expand_variables_for_generation(
607
+ input_features=input_features,
608
+ seek=seek,
609
+ max_frames=max_frames,
610
+ init_tokens=init_tokens,
611
+ batch_size=batch_size,
612
+ condition_on_prev_tokens=condition_on_prev_tokens,
613
+ generation_config=generation_config,
614
+ )
615
+
616
+ current_segments = self._prepare_segments(
617
+ prompt_ids=prompt_ids,
618
+ batch_size=cur_bsz,
619
+ generation_config=generation_config,
620
+ )
621
+
622
+ # 6 Transcribe audio until we reach the end of all input audios
623
+ while (seek < max_frames).any():
624
+ # 6.1 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
625
+ # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
626
+ # to know which original audio is being decoded
627
+ # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
628
+ input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch(
629
+ input_features=input_features,
630
+ seek=seek,
631
+ max_frames=max_frames,
632
+ cur_bsz=cur_bsz,
633
+ batch_idx_map=batch_idx_map,
634
+ )
635
+ time_offset = (
636
+ seek.to(torch.float32 if device.type == "mps" else torch.float64) * time_precision / input_stride
637
+ )
638
+ seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
639
+
640
+ # 6.2 cut out next 30s segment from input features
641
+ segment_input = self._get_input_segment(
642
+ input_features=input_features,
643
+ seek=seek,
644
+ seek_num_frames=seek_num_frames,
645
+ num_segment_frames=num_segment_frames,
646
+ cur_bsz=cur_bsz,
647
+ batch_idx_map=batch_idx_map,
648
+ )
649
+
650
+ # 6.3 prepare decoder input ids
651
+ suppress_tokens = _get_attr_from_logit_processors(
652
+ logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
653
+ )
654
+
655
+ decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
656
+ cur_bsz=cur_bsz,
657
+ init_tokens=init_tokens,
658
+ current_segments=current_segments,
659
+ batch_idx_map=batch_idx_map,
660
+ do_condition_on_prev_tokens=do_condition_on_prev_tokens,
661
+ prompt_ids=prompt_ids,
662
+ generation_config=generation_config,
663
+ config=self.config,
664
+ device=init_tokens.device,
665
+ suppress_tokens=suppress_tokens,
666
+ timestamp_begin=timestamp_begin,
667
+ kwargs=kwargs,
668
+ )
669
+
670
+ # 6.4 set max new tokens or max length
671
+ self._set_max_new_tokens_and_length(
672
+ config=self.config,
673
+ decoder_input_ids=decoder_input_ids,
674
+ generation_config=generation_config,
675
+ )
676
+
677
+ # 6.5 Set current `begin_index` for all logit processors
678
+ if logits_processor is not None:
679
+ for proc in logits_processor:
680
+ if hasattr(proc, "set_begin_index"):
681
+ proc.set_begin_index(decoder_input_ids.shape[-1])
682
+
683
+ # 6.6 Run generate with fallback
684
+ (
685
+ seek_sequences,
686
+ seek_outputs,
687
+ should_skip,
688
+ do_condition_on_prev_tokens,
689
+ model_output_type,
690
+ ) = self.generate_with_fallback(
691
+ segment_input=segment_input,
692
+ decoder_input_ids=decoder_input_ids,
693
+ cur_bsz=cur_bsz,
694
+ batch_idx_map=batch_idx_map,
695
+ seek=seek,
696
+ num_segment_frames=num_segment_frames,
697
+ max_frames=max_frames,
698
+ temperatures=temperatures,
699
+ generation_config=generation_config,
700
+ logits_processor=logits_processor,
701
+ stopping_criteria=stopping_criteria,
702
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
703
+ synced_gpus=synced_gpus,
704
+ return_token_timestamps=return_token_timestamps,
705
+ do_condition_on_prev_tokens=do_condition_on_prev_tokens,
706
+ is_shortform=is_shortform,
707
+ batch_size=batch_size,
708
+ attention_mask=attention_mask,
709
+ kwargs=kwargs,
710
+ )
711
+
712
+ # 6.7 In every generated sequence, split by timestamp tokens and extract segments
713
+ for i, seek_sequence in enumerate(seek_sequences):
714
+ prev_i = batch_idx_map[i]
715
+
716
+ if should_skip[i]:
717
+ seek[prev_i] += seek_num_frames[prev_i]
718
+ continue
719
+
720
+ segments, segment_offset = self._retrieve_segment(
721
+ seek_sequence=seek_sequence,
722
+ seek_outputs=seek_outputs,
723
+ time_offset=time_offset,
724
+ timestamp_begin=timestamp_begin,
725
+ seek_num_frames=seek_num_frames,
726
+ time_precision=time_precision,
727
+ time_precision_features=time_precision_features,
728
+ input_stride=input_stride,
729
+ prev_idx=prev_i,
730
+ idx=i,
731
+ return_token_timestamps=return_token_timestamps,
732
+ )
733
+
734
+ current_segments[prev_i] += segments
735
+
736
+ if is_shortform:
737
+ seek[prev_i] += max_frames[i]
738
+ else:
739
+ seek[prev_i] += segment_offset
740
+
741
+ # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
742
+ # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
743
+ final_segments = (
744
+ [x[1:] for x in current_segments]
745
+ if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
746
+ else current_segments
747
+ )
748
+
749
+ sequences = _pad_to_max_length(
750
+ final_segments, generation_config.pad_token_id, device=self.device, padding_side="right"
751
+ )
752
+
753
+ # 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
754
+ if return_segments:
755
+ return {"sequences": sequences, "segments": final_segments}
756
+
757
+ if is_shortform:
758
+ # add eos token:
759
+ if generation_config.max_new_tokens is None and generation_config.max_length is None:
760
+ eos_tokens = torch.full((sequences.shape[0], 1), generation_config.eos_token_id)
761
+ sequences = torch.cat([sequences, eos_tokens], dim=-1)
762
+
763
+ if return_token_timestamps:
764
+ outputs = {}
765
+ outputs["sequences"] = sequences
766
+ outputs["token_timestamps"] = torch.stack([d["token_timestamps"] for d in seek_outputs], dim=0)
767
+ else:
768
+ outputs = sequences
769
+
770
+ if return_dict_in_generate and generation_config.return_dict_in_generate:
771
+ dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs)
772
+
773
+ if num_return_sequences > 1:
774
+ if hasattr(dict_outputs, "encoder_attentions") and dict_outputs.encoder_attentions is not None:
775
+ dict_outputs.encoder_attentions = tuple(
776
+ dict_outputs.encoder_attentions[i][::num_return_sequences]
777
+ for i in range(len(dict_outputs.encoder_attentions))
778
+ )
779
+ if (
780
+ hasattr(dict_outputs, "encoder_hidden_states")
781
+ and dict_outputs.encoder_hidden_states is not None
782
+ ):
783
+ dict_outputs.encoder_hidden_states = tuple(
784
+ dict_outputs.encoder_hidden_states[i][::num_return_sequences]
785
+ for i in range(len(dict_outputs.encoder_hidden_states))
786
+ )
787
+ if return_token_timestamps:
788
+ dict_outputs["token_timestamps"] = outputs["token_timestamps"]
789
+ return dict_outputs
790
+
791
+ return outputs
792
+
793
+ return sequences
794
+
795
+ def generate_with_fallback(
796
+ self,
797
+ segment_input,
798
+ decoder_input_ids,
799
+ cur_bsz,
800
+ batch_idx_map,
801
+ seek,
802
+ num_segment_frames,
803
+ max_frames,
804
+ temperatures,
805
+ generation_config,
806
+ logits_processor,
807
+ stopping_criteria,
808
+ prefix_allowed_tokens_fn,
809
+ synced_gpus,
810
+ return_token_timestamps,
811
+ do_condition_on_prev_tokens,
812
+ is_shortform,
813
+ batch_size,
814
+ attention_mask,
815
+ kwargs,
816
+ ):
817
+ kwargs = copy.copy(kwargs)
818
+
819
+ # 6.6 Batch generate current chunk
820
+ seek_sequence_list = [None for _ in range(cur_bsz)]
821
+ seek_outputs_list = [None for _ in range(cur_bsz)]
822
+ needs_fallback = [False for _ in range(cur_bsz)]
823
+ should_skip = [False for _ in range(cur_bsz)]
824
+ fallback_index_map = list(range(cur_bsz))
825
+ if generation_config.no_speech_threshold is not None:
826
+ self._setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs)
827
+
828
+ for fallback_idx, temperature in enumerate(temperatures):
829
+ generation_config.do_sample = temperature is not None and temperature > 0.0
830
+ generation_config.temperature = temperature if generation_config.do_sample else 1.0
831
+ if generation_config.do_sample:
832
+ generation_config.num_beams = 1
833
+
834
+ generate_kwargs = copy.copy(kwargs)
835
+ for key in ["do_sample", "temperature", "num_beams"]:
836
+ if key in generate_kwargs:
837
+ del generate_kwargs[key]
838
+
839
+ cur_bsz = decoder_input_ids.shape[0]
840
+ if generation_config.cache_implementation == "static" and cur_bsz < batch_size:
841
+ segment_input = F.pad(segment_input, (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0)
842
+ decoder_input_ids = F.pad(
843
+ decoder_input_ids, (0, 0, 0, batch_size - cur_bsz), value=generation_config.pad_token_id
844
+ )
845
+ if generate_kwargs.get("decoder_attention_mask") is not None:
846
+ generate_kwargs["decoder_attention_mask"] = F.pad(
847
+ generate_kwargs["decoder_attention_mask"], (0, 0, 0, batch_size - cur_bsz), value=True
848
+ )
849
+ if generate_kwargs.get("encoder_outputs") is not None:
850
+ generate_kwargs["encoder_outputs"] = F.pad(
851
+ generate_kwargs["encoder_outputs"], (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0
852
+ )
853
+
854
+ seek_outputs = super().generate(
855
+ segment_input,
856
+ generation_config=generation_config,
857
+ logits_processor=logits_processor,
858
+ stopping_criteria=stopping_criteria,
859
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
860
+ synced_gpus=synced_gpus,
861
+ decoder_input_ids=decoder_input_ids,
862
+ attention_mask=attention_mask,
863
+ **generate_kwargs,
864
+ )
865
+
866
+ model_output_type = type(seek_outputs)
867
+
868
+ # post-process sequence tokens and outputs to be in list form
869
+ seek_sequences, seek_outputs = self._postprocess_outputs(
870
+ seek_outputs=seek_outputs,
871
+ decoder_input_ids=decoder_input_ids,
872
+ return_token_timestamps=return_token_timestamps,
873
+ generation_config=generation_config,
874
+ is_shortform=is_shortform,
875
+ )
876
+
877
+ if cur_bsz < batch_size:
878
+ seek_sequences = seek_sequences[:cur_bsz]
879
+ seek_outputs = seek_outputs[:cur_bsz]
880
+
881
+ # 6.7 Extract cut sequences from every sequence and check if fallback should be applied
882
+ # Loop over each decoded audio individually as each decoding can be of a different length
883
+ new_fallback_index_map = []
884
+ new_segment_input = []
885
+ new_decoder_input_ids = []
886
+ new_decoder_attention_mask = []
887
+
888
+ for i, seek_sequence in enumerate(seek_sequences):
889
+ # make sure we cut a predicted EOS token if we are not finished with the generation yet
890
+ prev_i = batch_idx_map[fallback_index_map[i]]
891
+ is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i]
892
+
893
+ # remove eos token id
894
+ if is_not_final and seek_sequence[-1] == generation_config.eos_token_id:
895
+ seek_sequence = seek_sequence[:-1]
896
+ if return_token_timestamps and not is_shortform:
897
+ seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1]
898
+
899
+ # remove all padding tokens
900
+ if seek_sequence[-1] == generation_config.pad_token_id:
901
+ num_paddings = (seek_sequence == generation_config.pad_token_id).sum()
902
+ seek_sequence = seek_sequence[:-num_paddings]
903
+ if return_token_timestamps and not is_shortform:
904
+ seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings]
905
+
906
+ # check which sequences in batch need fallback & which should be skipped
907
+ needs_fallback[i], should_skip[i] = self._need_fallback(
908
+ seek_sequence,
909
+ seek_outputs,
910
+ i,
911
+ logits_processor,
912
+ generation_config,
913
+ self.config.vocab_size,
914
+ temperature,
915
+ )
916
+
917
+ seek_sequence_list[fallback_index_map[i]] = seek_sequence
918
+ seek_outputs_list[fallback_index_map[i]] = seek_outputs[i]
919
+ is_low_temperature = temperature is None or temperature < 0.5
920
+ do_condition_on_prev_tokens[fallback_index_map[i]] = (
921
+ generation_config.condition_on_prev_tokens and is_low_temperature
922
+ )
923
+
924
+ if needs_fallback[i]:
925
+ new_fallback_index_map.append(fallback_index_map[i])
926
+ new_segment_input.append(segment_input[i])
927
+ new_decoder_input_ids.append(decoder_input_ids[i])
928
+ if "decoder_attention_mask" in kwargs:
929
+ new_decoder_attention_mask.append(kwargs["decoder_attention_mask"][i])
930
+
931
+ fallback_index_map = new_fallback_index_map
932
+
933
+ # if no sequence needs to be run with temperature fallback, we're finished
934
+ if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1:
935
+ seek_sequences = seek_sequence_list
936
+ seek_outputs = seek_outputs_list
937
+ break
938
+
939
+ # if we're still in the loop, make sure that decoder_input_ids and segment inputs are tensors
940
+ decoder_input_ids = torch.stack(new_decoder_input_ids)
941
+ segment_input = torch.stack(new_segment_input)
942
+ if "decoder_attention_mask" in kwargs:
943
+ kwargs["decoder_attention_mask"] = torch.stack(new_decoder_attention_mask)
944
+
945
+ return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, model_output_type
946
+
947
+ @staticmethod
948
+ def _prepare_segments(prompt_ids, batch_size, generation_config):
949
+ if prompt_ids is not None and generation_config.prompt_condition_type == "first-segment":
950
+ prev_sot_token_id = getattr(generation_config, "prev_sot_token_id", None)
951
+ prompt_ids = prompt_ids[1:] if prompt_ids[0] == prev_sot_token_id else prompt_ids
952
+ current_segments = [[{"tokens": prompt_ids}] for _ in range(batch_size)]
953
+ else:
954
+ current_segments = [[] for _ in range(batch_size)]
955
+
956
+ return current_segments
957
+
958
+ def _postprocess_outputs(
959
+ self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, is_shortform
960
+ ):
961
+ # remove all previously passed decoder input ids
962
+ start_idx = decoder_input_ids.shape[-1] if not is_shortform else torch.tensor(0)
963
+
964
+ if isinstance(seek_outputs, torch.Tensor):
965
+ seek_outputs = seek_outputs[:, start_idx:]
966
+ return seek_outputs, seek_outputs
967
+
968
+ if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
969
+ num_frames = getattr(generation_config, "num_frames", None)
970
+ seek_outputs["token_timestamps"] = self._extract_token_timestamps(
971
+ seek_outputs,
972
+ generation_config.alignment_heads,
973
+ num_frames=num_frames,
974
+ num_input_ids=decoder_input_ids.shape[-1],
975
+ )
976
+ seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
977
+
978
+ seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]
979
+
980
+ def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None):
981
+ if beam_indices is not None and key == "scores":
982
+ return [v[beam_idx].cpu() for (v, beam_idx) in zip(values, beam_indices[batch_idx][: len(values)])]
983
+ if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
984
+ return [v[batch_idx].cpu() for v in values]
985
+ if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
986
+ return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
987
+ elif key == "past_key_values":
988
+ if not is_shortform:
989
+ # we don't save `past_key_values` as this is too costly for longform
990
+ return None
991
+ elif isinstance(values, EncoderDecoderCache):
992
+ all_past_key_values = []
993
+ for layer_idx in range(self.config.decoder_layers):
994
+ layer_past_key_values = []
995
+ for cache_cls in [values.self_attention_cache, values.cross_attention_cache]:
996
+ for v in [cache_cls.key_cache, cache_cls.value_cache]:
997
+ layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu())
998
+ all_past_key_values.append(tuple(layer_past_key_values))
999
+ return tuple(all_past_key_values)
1000
+ else:
1001
+ all_past_key_values = []
1002
+ for v in range(len(values)):
1003
+ layer_past_key_values = []
1004
+ for w in values[v]:
1005
+ if len(w) != 0:
1006
+ layer_past_key_values.append(w[batch_idx][None].cpu())
1007
+ else:
1008
+ layer_past_key_values.append(w)
1009
+ all_past_key_values.append(tuple(layer_past_key_values))
1010
+ return tuple(all_past_key_values)
1011
+
1012
+ return values[batch_idx].cpu()
1013
+
1014
+ sequence_tokens = seek_outputs["sequences"]
1015
+ seek_outputs = [
1016
+ {
1017
+ k: split_by_batch_index(v, k, i, is_shortform, beam_indices=seek_outputs.get("beam_indices"))
1018
+ for k, v in seek_outputs.items()
1019
+ }
1020
+ for i in range(sequence_tokens.shape[0])
1021
+ ]
1022
+
1023
+ return sequence_tokens, seek_outputs
1024
+
1025
+ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
1026
+ # Stack back seek_outputs tensors after splitting them with the split_by_batch_index method
1027
+ outputs = {}
1028
+ for key in seek_outputs[0].keys():
1029
+ if key in ["sequences", "beam_indices"]:
1030
+ outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device)
1031
+ elif key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
1032
+ outputs[key] = tuple(
1033
+ torch.stack([v[key][i] for v in seek_outputs]).to(device) for i in range(len(seek_outputs[0][key]))
1034
+ )
1035
+ elif key == "sequences_scores":
1036
+ outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device)
1037
+ elif key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
1038
+ outputs[key] = tuple(
1039
+ tuple(
1040
+ torch.stack([v[key][i][j] for v in seek_outputs]).squeeze(1).to(device)
1041
+ for j in range(len(seek_outputs[0][key][0]))
1042
+ )
1043
+ for i in range(len(seek_outputs[0][key]))
1044
+ )
1045
+ elif key == "past_key_values":
1046
+ past_key_value_type = kwargs.get("past_key_values")
1047
+ if seek_outputs[0][key] is not None:
1048
+ outputs[key] = tuple(
1049
+ tuple(
1050
+ torch.stack([v[key][i][j] for v in seek_outputs]).squeeze(1).to(device)
1051
+ for j in range(len(seek_outputs[0][key][0]))
1052
+ )
1053
+ for i in range(len(seek_outputs[0][key]))
1054
+ )
1055
+ if past_key_value_type is not None and isinstance(past_key_value_type, EncoderDecoderCache):
1056
+ outputs[key] = past_key_value_type.from_legacy_cache(outputs[key])
1057
+ else:
1058
+ outputs[key] = None
1059
+
1060
+ return model_output_type(**outputs)
1061
+
1062
+ def _need_fallback(
1063
+ self,
1064
+ seek_sequence,
1065
+ seek_outputs,
1066
+ index,
1067
+ logits_processor,
1068
+ generation_config,
1069
+ vocab_size,
1070
+ temperature,
1071
+ ):
1072
+ needs_fallback = False
1073
+ should_skip = False
1074
+ if generation_config.compression_ratio_threshold is not None:
1075
+ compression_ratio = self._retrieve_compression_ratio(seek_sequence, vocab_size)
1076
+
1077
+ if compression_ratio > generation_config.compression_ratio_threshold:
1078
+ needs_fallback = True
1079
+
1080
+ if generation_config.logprob_threshold is not None:
1081
+ if hasattr(seek_outputs[0], "sequences_scores"):
1082
+ logprobs = [s["sequences_scores"] for s in seek_outputs][index]
1083
+ else:
1084
+ scores = seek_outputs[index]["scores"]
1085
+ logprobs = self._retrieve_avg_logprobs(
1086
+ scores, seek_sequence, generation_config.eos_token_id, temperature
1087
+ )
1088
+
1089
+ if logprobs < generation_config.logprob_threshold:
1090
+ needs_fallback = True
1091
+
1092
+ if generation_config.no_speech_threshold is not None:
1093
+ no_speech_prob = _get_attr_from_logit_processors(
1094
+ logits_processor, WhisperNoSpeechDetection, "no_speech_prob"
1095
+ )
1096
+
1097
+ if (
1098
+ logprobs < generation_config.logprob_threshold
1099
+ and no_speech_prob[index] > generation_config.no_speech_threshold
1100
+ ):
1101
+ needs_fallback = False
1102
+ should_skip = True
1103
+
1104
+ return needs_fallback, should_skip
1105
+
1106
+ def _expand_variables_for_generation(
1107
+ self, input_features, seek, max_frames, init_tokens, batch_size, condition_on_prev_tokens, generation_config
1108
+ ):
1109
+ if generation_config.num_return_sequences is not None and generation_config.num_return_sequences > 1:
1110
+ batch_idx_map = list(range(batch_size * generation_config.num_return_sequences))
1111
+ cur_bsz = len(batch_idx_map)
1112
+ do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(len(batch_idx_map))]
1113
+ input_features = input_features.repeat_interleave(generation_config.num_return_sequences, dim=0)
1114
+ seek = seek.repeat_interleave(generation_config.num_return_sequences, dim=0)
1115
+ max_frames = max_frames.repeat_interleave(generation_config.num_return_sequences, dim=0)
1116
+ init_tokens = init_tokens.repeat_interleave(generation_config.num_return_sequences, dim=0)
1117
+ generation_config.num_return_sequences = 1
1118
+ else:
1119
+ cur_bsz = batch_size
1120
+ batch_idx_map = list(range(cur_bsz))
1121
+ do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(cur_bsz)]
1122
+
1123
+ return (
1124
+ batch_idx_map,
1125
+ cur_bsz,
1126
+ input_features,
1127
+ seek,
1128
+ max_frames,
1129
+ init_tokens,
1130
+ do_condition_on_prev_tokens,
1131
+ )
1132
+
1133
+ @staticmethod
1134
+ def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs):
1135
+ set_inputs = _get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs")
1136
+ extra_kwargs = {k: v for k, v in kwargs.items() if torch.is_tensor(v)}
1137
+ set_inputs({"inputs": segment_input, "decoder_input_ids": decoder_input_ids, **extra_kwargs})
1138
+
1139
+ @staticmethod
1140
+ def _retrieve_total_input_frames(input_features, input_stride, kwargs):
1141
+ if input_features is not None:
1142
+ return input_features.shape[0], input_features.shape[-1]
1143
+
1144
+ if "encoder_outputs" in kwargs:
1145
+ encoder_outputs_shape = (
1146
+ kwargs["encoder_outputs"][0].shape
1147
+ if isinstance(kwargs["encoder_outputs"], BaseModelOutput)
1148
+ else kwargs["encoder_outputs"].shape
1149
+ )
1150
+ return encoder_outputs_shape[0], encoder_outputs_shape[1] * input_stride
1151
+
1152
+ raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.")
1153
+
1154
+ @staticmethod
1155
+ def _maybe_warn_unused_inputs(
1156
+ condition_on_prev_tokens,
1157
+ temperature,
1158
+ compression_ratio_threshold,
1159
+ logprob_threshold,
1160
+ no_speech_threshold,
1161
+ total_input_frames,
1162
+ ):
1163
+ warning_prefix = (
1164
+ f"Audio input consists of only {total_input_frames}. "
1165
+ "Short-form transcription is activated."
1166
+ "{}, but will be ignored."
1167
+ )
1168
+ if condition_on_prev_tokens is not None:
1169
+ logger.warning(warning_prefix.format(f"condition_on_prev_tokens is set to {condition_on_prev_tokens}"))
1170
+
1171
+ if compression_ratio_threshold is not None:
1172
+ logger.warning(
1173
+ warning_prefix.format(f"compression_ratio_threshold is set to {compression_ratio_threshold}")
1174
+ )
1175
+
1176
+ if logprob_threshold is not None:
1177
+ logger.warning(warning_prefix.format(f"logprob_threshold is set to {logprob_threshold}"))
1178
+
1179
+ if no_speech_threshold is not None:
1180
+ logger.warning(warning_prefix.format(f"no_speech_threshold is set to {no_speech_threshold}"))
1181
+
1182
+ # when passing temperature as a list it cannot just be ignored => throw error in this case
1183
+ if isinstance(temperature, (list, tuple)):
1184
+ raise ValueError(
1185
+ f"Audio input consists of only {total_input_frames}. Short-form transcription is activated."
1186
+ f"temperature cannot be set to {temperature} which can only be used for temperature fallback for long-form generation. Make sure to set `temperature` to a float value or `None` for short-form generation."
1187
+ )
1188
+
1189
+ @staticmethod
1190
+ def _set_return_outputs(return_dict_in_generate, return_token_timestamps, logprob_threshold, generation_config):
1191
+ if return_dict_in_generate is None:
1192
+ return_dict_in_generate = generation_config.return_dict_in_generate
1193
+ else:
1194
+ generation_config.return_dict_in_generate = return_dict_in_generate
1195
+
1196
+ generation_config.return_token_timestamps = return_token_timestamps
1197
+ if return_token_timestamps:
1198
+ generation_config.return_dict_in_generate = True
1199
+ generation_config.output_attentions = True
1200
+ generation_config.output_scores = True
1201
+
1202
+ if logprob_threshold is not None:
1203
+ generation_config.return_dict_in_generate = True
1204
+ generation_config.output_scores = True
1205
+
1206
+ return return_dict_in_generate
1207
+
1208
+ def _set_return_timestamps(self, return_timestamps, is_shortform, generation_config):
1209
+ if return_timestamps is None and hasattr(generation_config, "return_timestamps"):
1210
+ return_timestamps = generation_config.return_timestamps
1211
+
1212
+ if not is_shortform:
1213
+ if return_timestamps is False:
1214
+ raise ValueError(
1215
+ "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
1216
+ "requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
1217
+ )
1218
+
1219
+ logger.info("Setting `return_timestamps=True` for long-form generation.")
1220
+ return_timestamps = True
1221
+
1222
+ if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"):
1223
+ raise ValueError(
1224
+ "You are trying to return timestamps, but the generation config is not properly set. "
1225
+ "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
1226
+ "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
1227
+ )
1228
+
1229
+ generation_config.return_timestamps = return_timestamps
1230
+
1231
+ if hasattr(generation_config, "no_timestamps_token_id"):
1232
+ timestamp_begin = generation_config.no_timestamps_token_id + 1
1233
+ else:
1234
+ # BC for models missing the `no_timestamps_token_id` in the generation config when generating short-form with no timestamps
1235
+ # We set the timestamp begin token larger than the vocab size, such that the timestamp condition is never met in the decoding loop
1236
+ timestamp_begin = self.config.vocab_size + 1
1237
+
1238
+ return timestamp_begin
1239
+
1240
+ @staticmethod
1241
+ def _set_language_and_task(language, task, is_multilingual, generation_config):
1242
+ if is_multilingual is not None:
1243
+ if not hasattr(generation_config, "is_multilingual"):
1244
+ raise ValueError(
1245
+ "The generation config is outdated and is thus not compatible with the `is_multilingual` argument "
1246
+ "to `generate`. Please update the generation config as per the instructions "
1247
+ "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
1248
+ )
1249
+ generation_config.is_multilingual = is_multilingual
1250
+
1251
+ if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual:
1252
+ if task is not None or language is not None:
1253
+ raise ValueError(
1254
+ "Cannot specify `task` or `language` for an English-only model. If the model is intended to be "
1255
+ "multilingual, pass `is_multilingual=True` to generate, or update the generation config."
1256
+ )
1257
+
1258
+ if language is not None:
1259
+ if not hasattr(generation_config, "lang_to_id"):
1260
+ raise ValueError(
1261
+ "The generation config is outdated and is thus not compatible with the `language` argument "
1262
+ "to `generate`. Either set the language using the `forced_decoder_ids` in the model config, "
1263
+ "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
1264
+ )
1265
+ generation_config.language = language
1266
+
1267
+ if task is not None:
1268
+ if not hasattr(generation_config, "task_to_id"):
1269
+ raise ValueError(
1270
+ "The generation config is outdated and is thus not compatible with the `task` argument "
1271
+ "to `generate`. Either set the task using the `forced_decoder_ids` in the model config, "
1272
+ "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
1273
+ )
1274
+ generation_config.task = task
1275
+
1276
+ def _retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs):
1277
+ def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
1278
+ """short function to replace num with a itr in lst"""
1279
+ found = any(i in lst for i in itr)
1280
+ if found:
1281
+ lst = [num if i in itr else i for i in lst]
1282
+ else:
1283
+ lst.append(num)
1284
+ return lst
1285
+
1286
+ def language_to_id(language: str) -> int:
1287
+ language = language.lower()
1288
+ if language in generation_config.lang_to_id.keys():
1289
+ language_token = language
1290
+ elif language in TO_LANGUAGE_CODE.keys():
1291
+ language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
1292
+ elif language in TO_LANGUAGE_CODE.values():
1293
+ language_token = f"<|{language}|>"
1294
+ else:
1295
+ is_language_code = len(language) == 2
1296
+ raise ValueError(
1297
+ f"Unsupported language: {language}. Language should be one of:"
1298
+ f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
1299
+ )
1300
+ if language_token not in generation_config.lang_to_id:
1301
+ raise ValueError(
1302
+ f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
1303
+ "(You should just add it to the generation config)"
1304
+ )
1305
+
1306
+ return generation_config.lang_to_id[language_token]
1307
+
1308
+ task = getattr(generation_config, "task", None)
1309
+ language = getattr(generation_config, "language", None)
1310
+
1311
+ forced_decoder_ids = generation_config.forced_decoder_ids
1312
+ if forced_decoder_ids is not None:
1313
+ if language is None and task is None and forced_decoder_ids[0][1] is None:
1314
+ logger.warning_once(
1315
+ "Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
1316
+ "This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
1317
+ )
1318
+ elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
1319
+ forced_decoder_ids = config.forced_decoder_ids
1320
+
1321
+ if forced_decoder_ids is not None and task is not None:
1322
+ logger.warning_once(
1323
+ f"You have passed task={task}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of task={task}."
1324
+ )
1325
+ forced_decoder_ids = None
1326
+ elif forced_decoder_ids is not None and language is not None:
1327
+ logger.warning_once(
1328
+ f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
1329
+ )
1330
+ forced_decoder_ids = None
1331
+
1332
+ init_tokens = [generation_config.decoder_start_token_id]
1333
+ if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1:
1334
+ i = 1
1335
+ while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i:
1336
+ init_tokens += [forced_decoder_ids[0][1]]
1337
+ forced_decoder_ids = forced_decoder_ids[1:]
1338
+ i += 1
1339
+
1340
+ if len(forced_decoder_ids) > 0:
1341
+ raise ValueError(
1342
+ f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}.",
1343
+ )
1344
+
1345
+ # from v4.39 the forced decoder ids are always None in favour of decoder input ids
1346
+ generation_config.forced_decoder_ids = None
1347
+
1348
+ is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
1349
+
1350
+ # Make sure language is a list of strings of the correct length
1351
+ if isinstance(language, (list, tuple)):
1352
+ if any(l is None for l in language):
1353
+ raise TypeError(
1354
+ "Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list containing `None`."
1355
+ )
1356
+ if len(language) != batch_size:
1357
+ raise ValueError(
1358
+ "When passing a list of languages, the length of the list must match the batch size. "
1359
+ f"Expected length of {batch_size}, but got {len(language)} languages."
1360
+ )
1361
+ languages = language
1362
+ elif language is None:
1363
+ # Language will be detected for each item in batch
1364
+ languages = [None] * batch_size
1365
+ else:
1366
+ languages = [language] # Use a length-1 list now, broadcast later
1367
+
1368
+ # Separate init_tokens for each language
1369
+ init_tokens = [copy.copy(init_tokens) for _ in languages]
1370
+
1371
+ # Update init_tokens with languages
1372
+ lang_ids = None
1373
+ if language is not None:
1374
+ lang_ids = [language_to_id(l) for l in languages]
1375
+ elif hasattr(generation_config, "lang_to_id") and is_lang_id_undefined:
1376
+ # language is not defined or intentionally set to `None` to trigger language detection
1377
+ lang_ids = self.detect_language(
1378
+ input_features=input_features,
1379
+ encoder_outputs=kwargs.get("encoder_outputs", None),
1380
+ generation_config=generation_config,
1381
+ num_segment_frames=num_segment_frames,
1382
+ ).tolist()
1383
+ if lang_ids is not None:
1384
+ # append or replace lang_ids to init_tokens
1385
+ for i in range(len(init_tokens)):
1386
+ if len(init_tokens[i]) > 1:
1387
+ init_tokens[i][1] = lang_ids[i]
1388
+ else:
1389
+ init_tokens[i].append(lang_ids[i])
1390
+ del languages
1391
+
1392
+ # Update init_tokens with task
1393
+ for i in range(len(init_tokens)):
1394
+ if task is not None:
1395
+ if task in TASK_IDS:
1396
+ init_tokens[i].append(generation_config.task_to_id[generation_config.task])
1397
+ task_id = generation_config.task_to_id[generation_config.task]
1398
+
1399
+ # if task is defined it'll overwrite task ids that might have already been defined via the generation_config
1400
+ replace_or_add(init_tokens[i], task_id, generation_config.task_to_id.values())
1401
+ else:
1402
+ raise ValueError(f"The `{task}`task is not supported. The task should be one of `{TASK_IDS}`")
1403
+ elif language is not None and hasattr(generation_config, "task_to_id"):
1404
+ # if language is defined, but no task id is in `init_tokens`, default to transcribe
1405
+ if not any(ti in init_tokens[i] for ti in generation_config.task_to_id.values()):
1406
+ init_tokens[i].append(generation_config.task_to_id["transcribe"])
1407
+
1408
+ if (
1409
+ not generation_config.return_timestamps
1410
+ and hasattr(generation_config, "no_timestamps_token_id")
1411
+ and init_tokens[i][-1] != generation_config.no_timestamps_token_id
1412
+ ):
1413
+ init_tokens[i].append(generation_config.no_timestamps_token_id)
1414
+ elif (
1415
+ generation_config.return_timestamps and init_tokens[i][-1] == generation_config.no_timestamps_token_id
1416
+ ):
1417
+ logger.info(
1418
+ "<|notimestamps|> prompt token is removed from generation_config since `return_timestamps` is set to `'True'`."
1419
+ )
1420
+ init_tokens[i] = init_tokens[i][:-1]
1421
+
1422
+ # let's make sure we don't pass `None` tokens as prompt tokens
1423
+ init_tokens[i] = [t for t in init_tokens[i] if t is not None]
1424
+
1425
+ return torch.as_tensor(init_tokens, dtype=torch.long, device=self.device).expand(batch_size, -1)
1426
+
1427
+ def detect_language(
1428
+ self,
1429
+ input_features: Optional[torch.FloatTensor] = None,
1430
+ encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None,
1431
+ generation_config: Optional[GenerationConfig] = None,
1432
+ num_segment_frames: int = 3000,
1433
+ ) -> torch.Tensor:
1434
+ """
1435
+ Detects language from log-mel input features or encoder_outputs
1436
+
1437
+ Parameters:
1438
+ input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
1439
+ Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
1440
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
1441
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
1442
+ [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
1443
+ tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
1444
+ encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
1445
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
1446
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of
1447
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
1448
+ generation_config (`~generation.GenerationConfig`, *optional*):
1449
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
1450
+ passed to generate matching the attributes of `generation_config` will override them. If
1451
+ `generation_config` is not provided, the default will be used, which has the following loading
1452
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
1453
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
1454
+ default values, whose documentation should be checked to parameterize generation.
1455
+ num_segment_frames (`int`, *optional*, defaults to 3000):
1456
+ The number of log-mel frames the model expects
1457
+
1458
+ Return:
1459
+ A `torch.LongTensor` representing the detected language ids.
1460
+ """
1461
+ if input_features is None and encoder_outputs is None:
1462
+ raise ValueError("You have to specify either `input_features` or `encoder_outputs`")
1463
+ elif input_features is not None and encoder_outputs is not None:
1464
+ raise ValueError("Make sure to specificy only one of `input_features` or `encoder_outputs` - not both!")
1465
+ elif input_features is not None:
1466
+ inputs = {"input_features": input_features[:, :, :num_segment_frames]}
1467
+ batch_size = input_features.shape[0]
1468
+ elif encoder_outputs is not None:
1469
+ inputs = {"encoder_outputs": encoder_outputs}
1470
+ batch_size = (
1471
+ encoder_outputs[0].shape[0] if isinstance(encoder_outputs, BaseModelOutput) else encoder_outputs[0]
1472
+ )
1473
+
1474
+ generation_config = generation_config or self.generation_config
1475
+ decoder_input_ids = (
1476
+ torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
1477
+ * generation_config.decoder_start_token_id
1478
+ )
1479
+
1480
+ with torch.no_grad():
1481
+ logits = self(**inputs, decoder_input_ids=decoder_input_ids).logits[:, -1]
1482
+
1483
+ non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)
1484
+ non_lang_mask[list(generation_config.lang_to_id.values())] = False
1485
+
1486
+ logits[:, non_lang_mask] = -np.inf
1487
+
1488
+ lang_ids = logits.argmax(-1)
1489
+
1490
+ return lang_ids
1491
+
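# Minimal usage sketch of the `detect_language` helper defined above. The checkpoint
# name and the silent placeholder waveform are assumptions for illustration; any
# multilingual Whisper checkpoint and any 16 kHz mono waveform work the same way.
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

audio = np.zeros(16000, dtype=np.float32)  # placeholder: one second of silence
input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features

lang_token_ids = model.detect_language(input_features=input_features)
id_to_lang = {v: k for k, v in model.generation_config.lang_to_id.items()}
print([id_to_lang[int(t)] for t in lang_token_ids])  # e.g. ['<|en|>']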
1492
+ @staticmethod
1493
+ def _check_decoder_input_ids(kwargs):
1494
+ decoder_input_ids = kwargs.get("decoder_input_ids", None)
1495
+ assistant_model = kwargs.get("assistant_model", None)
1496
+ if decoder_input_ids is not None and assistant_model is not None:
1497
+ raise ValueError(
1498
+ "Passing `decoder_input_ids` is deprecated. Consider passing `prompt_ids` instead.",
1499
+ )
1500
+
1501
+ @staticmethod
1502
+ def _set_num_frames(return_token_timestamps, generation_config, kwargs):
1503
+ if return_token_timestamps:
1504
+ if getattr(generation_config, "task", None) == "translate":
1505
+ logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
1506
+ if not hasattr(generation_config, "alignment_heads"):
1507
+ raise ValueError(
1508
+ "Model generation config has no `alignment_heads`, token-level timestamps not available. "
1509
+ "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
1510
+ )
1511
+ generation_config.num_frames = kwargs.pop("num_frames", None)
1512
+
1513
+ @staticmethod
1514
+ def _set_thresholds_and_condition(
1515
+ generation_config,
1516
+ logprob_threshold,
1517
+ compression_ratio_threshold,
1518
+ no_speech_threshold,
1519
+ condition_on_prev_tokens,
1520
+ ):
1521
+ generation_config.logprob_threshold = (
1522
+ logprob_threshold
1523
+ if logprob_threshold is not None
1524
+ else getattr(generation_config, "logprob_threshold", None)
1525
+ )
1526
+ generation_config.compression_ratio_threshold = (
1527
+ compression_ratio_threshold
1528
+ if compression_ratio_threshold is not None
1529
+ else getattr(generation_config, "compression_ratio_threshold", None)
1530
+ )
1531
+ generation_config.no_speech_threshold = (
1532
+ no_speech_threshold
1533
+ if no_speech_threshold is not None
1534
+ else getattr(generation_config, "no_speech_threshold", None)
1535
+ )
1536
+ generation_config.condition_on_prev_tokens = (
1537
+ condition_on_prev_tokens
1538
+ if condition_on_prev_tokens is not None
1539
+ else getattr(generation_config, "condition_on_prev_tokens", None)
1540
+ )
1541
+
1542
+ @staticmethod
1543
+ def _set_prompt_condition_type(generation_config, prompt_condition_type):
1544
+ allowed_cond_types = ["first-segment", "all-segments"]
1545
+
1546
+ # default to "first-segment"
1547
+ prompt_condition_type = prompt_condition_type or allowed_cond_types[0]
1548
+
1549
+ if prompt_condition_type not in allowed_cond_types:
1550
+ raise ValueError(
1551
+ f"`prompt_condition_type={prompt_condition_type} does not exist. Make sure to set `prompt_condition_type` to one of {', '.join(allowed_cond_types)}"
1552
+ )
1553
+
1554
+ if generation_config.condition_on_prev_tokens is not True and prompt_condition_type == "all-segments":
1555
+ raise ValueError(
1556
+ "Make sure to set `condition_on_prev_tokens=True` when setting `prompt_condition_type='all-segments'`."
1557
+ )
1558
+
1559
+ generation_config.prompt_condition_type = prompt_condition_type
1560
+
1561
+ @staticmethod
1562
+ def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config):
1563
+ condition_on_prev_tokens = (
1564
+ condition_on_prev_tokens
1565
+ if condition_on_prev_tokens is not None
1566
+ else getattr(generation_config, "condition_on_prev_tokens", False)
1567
+ )
1568
+ generation_config.condition_on_prev_tokens = condition_on_prev_tokens
1569
+
1570
+ @staticmethod
1571
+ def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames, is_shortform):
1572
+ if batch_size > 1 and not is_shortform and attention_mask is None:
1573
+ raise ValueError(
1574
+ "When doing batched long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
1575
+ )
1576
+ elif batch_size > 1 and not is_shortform:
1577
+ max_frames = attention_mask.sum(-1).cpu().to(torch.long)
1578
+ seek = torch.zeros((batch_size,), dtype=torch.long)
1579
+ else:
1580
+ max_frames = torch.ones((batch_size,), dtype=torch.long) * total_input_frames
1581
+ seek = torch.zeros((batch_size,), dtype=torch.long)
1582
+
1583
+ return max_frames, seek
1584
+
1585
+ def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, num_beams, device):
1586
+ if generation_config.return_timestamps is True:
1587
+ timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index)
1588
+ logits_processor = (
1589
+ [timestamp_processor] if logits_processor is None else [timestamp_processor] + logits_processor
1590
+ )
1591
+
1592
+ if generation_config.suppress_tokens is not None:
1593
+ suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens, device=device)
1594
+ logits_processor = (
1595
+ [suppress_tokens_processor]
1596
+ if logits_processor is None
1597
+ else [suppress_tokens_processor] + logits_processor
1598
+ )
1599
+ generation_config.suppress_tokens = None
1600
+
1601
+ if generation_config.begin_suppress_tokens is not None:
1602
+ begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(
1603
+ generation_config.begin_suppress_tokens, begin_index=begin_index, device=device
1604
+ )
1605
+ logits_processor = (
1606
+ [begin_suppress_processor]
1607
+ if logits_processor is None
1608
+ else [begin_suppress_processor] + logits_processor
1609
+ )
1610
+ generation_config.begin_suppress_tokens = None
1611
+
1612
+ if generation_config.no_speech_threshold is not None:
1613
+ no_speech_detector = WhisperNoSpeechDetection(
1614
+ no_speech_token=generation_config.no_timestamps_token_id - 1,
1615
+ begin_index=begin_index,
1616
+ scores_is_logprobs=num_beams > 1,
1617
+ )
1618
+ logits_processor = (
1619
+ [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor
1620
+ )
1621
+ no_speech_detector.set_model(self)
1622
+
1623
+ return logits_processor
1624
+
1625
+ @staticmethod
1626
+ def _maybe_reduce_batch(input_features, seek, max_frames, cur_bsz, batch_idx_map):
1627
+ prev_bsz = cur_bsz
1628
+ new_batch_idx_map = []
1629
+ for i in range(prev_bsz):
1630
+ prev_i = batch_idx_map[i]
1631
+ if seek[prev_i] >= max_frames[prev_i]:
1632
+ cut_index = i + (cur_bsz - prev_bsz)
1633
+ cur_bsz -= 1
1634
+ input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0)
1635
+ else:
1636
+ # keep the index of audios that still have frames left to decode
1637
+ new_batch_idx_map.append(prev_i)
1638
+
1639
+ return input_features, cur_bsz, new_batch_idx_map
1640
+
1641
+ @staticmethod
1642
+ def _get_input_segment(input_features, seek, seek_num_frames, num_segment_frames, cur_bsz, batch_idx_map):
1643
+ if input_features is None:
1644
+ return None
1645
+
1646
+ segment_input = []
1647
+ for i in range(cur_bsz):
1648
+ prev_i = batch_idx_map[i]
1649
+ segment_input_slice = input_features[i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i]]
1650
+
1651
+ if segment_input_slice.shape[-1] < num_segment_frames:
1652
+ # pad to 3000 if necessary
1653
+ segment_input_slice = F.pad(
1654
+ segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1])
1655
+ )
1656
+
1657
+ segment_input.append(segment_input_slice)
1658
+
1659
+ segment_input = torch.cat(segment_input, dim=0)
1660
+
1661
+ return segment_input
1662
+
1663
+ @staticmethod
1664
+ def _prepare_decoder_input_ids(
1665
+ cur_bsz,
1666
+ init_tokens,
1667
+ current_segments,
1668
+ batch_idx_map,
1669
+ do_condition_on_prev_tokens,
1670
+ prompt_ids,
1671
+ generation_config,
1672
+ config,
1673
+ device,
1674
+ suppress_tokens,
1675
+ timestamp_begin,
1676
+ kwargs,
1677
+ ):
1678
+ if "decoder_input_ids" in kwargs:
1679
+ decoder_input_ids = kwargs.pop("decoder_input_ids")
1680
+
1681
+ return decoder_input_ids, kwargs
1682
+
1683
+ cut_off_length = config.max_target_positions // 2 - 1
1684
+
1685
+ decoder_input_ids = init_tokens[batch_idx_map]
1686
+
1687
+ prev_start_of_text = getattr(generation_config, "prev_sot_token_id", None)
1688
+ if prev_start_of_text is None:
1689
+ prev_start_of_text = suppress_tokens[-2] if suppress_tokens is not None else None
1690
+
1691
+ if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0:
1692
+ # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
1693
+ active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]
1694
+
1695
+ for segments in active_segments:
1696
+ for seg in segments:
1697
+ if len(seg["tokens"]) > 2 and seg["tokens"][-2] >= timestamp_begin:
1698
+ # the segment finishes with two timestamp tokens
1699
+ # we need to ignore the last timestamp token
1700
+ # see https://github.com/huggingface/transformers/pull/34537
1701
+ seg["tokens"] = seg["tokens"][:-1]
1702
+
1703
+ if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
1704
+ prev_ids = prompt_ids
1705
+ else:
1706
+ one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long)
1707
+ prev_ids = prev_start_of_text * one_tensor[0] if prev_start_of_text is not None else None
1708
+
1709
+ padding = "max_length" if generation_config.cache_implementation == "static" else "longest"
1710
+
1711
+ prev_tokens = _pad_to_max_length(
1712
+ active_segments,
1713
+ generation_config.pad_token_id,
1714
+ device=device,
1715
+ padding_side="left",
1716
+ padding=padding,
1717
+ bos_token_tensor=prev_ids,
1718
+ cut_off_length=cut_off_length,
1719
+ )
1720
+ decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
1721
+
1722
+ kwargs["decoder_attention_mask"] = decoder_input_ids != generation_config.pad_token_id
1723
+ elif prompt_ids is not None:
1724
+ prev_tokens = prompt_ids[None].repeat(decoder_input_ids.shape[0], 1)
1725
+ decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
1726
+ # make sure `"decoder_attention_mask"` is not passed to forward
1727
+ kwargs.pop("decoder_attention_mask", None)
1728
+ else:
1729
+ # make sure `"decoder_attention_mask"` is not passed to forward
1730
+ kwargs.pop("decoder_attention_mask", None)
1731
+
1732
+ return decoder_input_ids, kwargs
1733
+
1734
+ def _set_max_new_tokens_and_length(self, config, decoder_input_ids, generation_config):
1735
+ max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
1736
+ if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
1737
+ raise ValueError(
1738
+ f"The length of `decoder_input_ids`, including special start tokens, prompt tokens, and previous tokens, is {decoder_input_ids.shape[-1]}, "
1739
+ f" and `max_new_tokens` is {max_new_tokens}. Thus, the combined length of "
1740
+ f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
1741
+ f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
1742
+ "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
1743
+ f"so that their combined length is less than {self.config.max_target_positions}."
1744
+ )
1745
+
1746
+ num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1)
1747
+
1748
+ # Make sure we don't get larger than `max_length`
1749
+ if generation_config.max_length is not None and generation_config.max_new_tokens is None:
1750
+ max_length = min(generation_config.max_length + num_initial_tokens, config.max_target_positions)
1751
+ logger.info(
1752
+ f"Increase max_length from {generation_config.max_length} to {max_length} since input is conditioned on previous segment."
1753
+ )
1754
+ elif (
1755
+ generation_config.max_new_tokens is not None
1756
+ and generation_config.max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
1757
+ ):
1758
+ max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
1759
+ generation_config.max_new_tokens = max_new_tokens
1760
+
1761
+ @staticmethod
1762
+ def _retrieve_compression_ratio(tokens, vocab_size):
1763
+ """Compute byte length of zlib compressed token bytes vs. byte length of raw token bytes"""
1764
+ length = int(math.log2(vocab_size) / 8) + 1
1765
+ token_bytes = b"".join([t.to_bytes(length, "little") for t in tokens.tolist()])
1766
+ compression_ratio = len(token_bytes) / len(zlib.compress(token_bytes))
1767
+
1768
+ return compression_ratio
1769
+
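# Self-contained sketch of the zlib heuristic implemented in `_retrieve_compression_ratio`
# above, with made-up token ids and an example threshold: repetitive (degenerate)
# transcriptions compress well, so a high raw-bytes / compressed-bytes ratio is used
# as one of the triggers for temperature fallback.
import math
import zlib

vocab_size = 51866                              # illustrative Whisper-sized vocabulary
tokens = [50258, 50259, 50359] + [220] * 200    # hypothetical, highly repetitive output
length = int(math.log2(vocab_size) / 8) + 1     # bytes needed to store one token id
token_bytes = b"".join(t.to_bytes(length, "little") for t in tokens)
compression_ratio = len(token_bytes) / len(zlib.compress(token_bytes))
print(compression_ratio, compression_ratio > 2.4)  # 2.4 is an example compression_ratio_threshold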
1770
+ @staticmethod
1771
+ def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature):
1772
+ rescale_temperature = temperature if temperature > 0.0 else 1
1773
+ scores = torch.stack(scores).to(tokens.device)
1774
+
1775
+ if scores.shape[0] > tokens.shape[0]:
1776
+ scores = scores[: tokens.shape[0]]
1777
+ else:
1778
+ tokens = tokens[-scores.shape[0] :]
1779
+
1780
+ logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype)
1781
+
1782
+ # retrieve logprob of selected tokens and sum
1783
+ sum_logprobs = sum((logprobs[i][tokens[i]] * (tokens[i] != eos_token_id)) for i in range(logprobs.shape[0]))
1784
+ length = (tokens != eos_token_id).sum(-1) if eos_token_id is not None else tokens.shape[0]
1785
+
1786
+ avg_logprobs = sum_logprobs / (length + 1)
1787
+ return avg_logprobs
1788
+
1789
+ @staticmethod
1790
+ def _retrieve_segment(
1791
+ seek_sequence,
1792
+ seek_outputs,
1793
+ time_offset,
1794
+ timestamp_begin,
1795
+ seek_num_frames,
1796
+ time_precision,
1797
+ time_precision_features,
1798
+ input_stride,
1799
+ prev_idx,
1800
+ idx,
1801
+ return_token_timestamps,
1802
+ ):
1803
+ # find the predicted "end of segment" predictions of Whisper
1804
+ # "end of segment" predictions occur whenever Whisper predicts a timestamp token
1805
+ timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
1806
+ single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
1807
+ timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
1808
+ timestamp_segment_indices.add_(1)
1809
+ token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
1810
+ device = seek_sequence.device
1811
+
1812
+ # If Whisper predicted an "end of segment" via a timestamp token, let's go over each
1813
+ # "end of segment" prediction and slice the decoding into segments accordingly
1814
+ if len(timestamp_segment_indices) > 0:
1815
+ # if the output contains two consecutive timestamp tokens
1816
+ slices = timestamp_segment_indices.tolist()
1817
+ segments = []
1818
+ if single_timestamp_ending:
1819
+ slices.append(len(seek_sequence))
1820
+ else:
1821
+ # we want to include the last timestamp token in the last segment to know it was not a single-timestamp ending
1822
+ slices[-1] += 1
1823
+
1824
+ last_slice = 0
1825
+ # Add each segment to list of all segments
1826
+ for i, current_slice in enumerate(slices):
1827
+ is_last_slice = i == len(slices) - 1
1828
+ sliced_tokens = seek_sequence[last_slice:current_slice]
1829
+ start_timestamp_pos = sliced_tokens[0] - timestamp_begin
1830
+ idx_sliced_tokens = -1 if not is_last_slice or single_timestamp_ending else -2
1831
+ end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
1832
+ segments.append(
1833
+ {
1834
+ "start": time_offset[prev_idx]
1835
+ + start_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
1836
+ * time_precision,
1837
+ "end": time_offset[prev_idx]
1838
+ + end_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
1839
+ * time_precision,
1840
+ "tokens": sliced_tokens,
1841
+ "result": seek_outputs[idx],
1842
+ }
1843
+ )
1844
+ if return_token_timestamps:
1845
+ segments[-1]["token_timestamps"] = (
1846
+ token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
1847
+ )
1848
+ last_slice = current_slice
1849
+
1850
+ if single_timestamp_ending:
1851
+ # single timestamp at the end means no speech after the last timestamp.
1852
+ segment_offset = seek_num_frames[prev_idx]
1853
+ else:
1854
+ # otherwise, ignore the unfinished segment and seek to the last timestamp
1855
+ # here we throw away all predictions after the last predicted "end of segment"
1856
+ # since we are cutting right in the middle of an audio
1857
+ last_timestamp_pos = seek_sequence[last_slice - 2].item() - timestamp_begin
1858
+ segment_offset = last_timestamp_pos * input_stride
1859
+ else:
1860
+ # If whisper does not predict any "end of segment" token, then
1861
+ # the whole decoding is considered a segment and we add it to the list of segments
1862
+ timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
1863
+ last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision)
1864
+ if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin:
1865
+ # no consecutive timestamps but it has a timestamp; use the last one.
1866
+ last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(
1867
+ torch.float32 if device.type == "mps" else torch.float64
1868
+ )
1869
+ segments = [
1870
+ {
1871
+ "start": time_offset[prev_idx],
1872
+ "end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
1873
+ "tokens": seek_sequence,
1874
+ "result": seek_outputs[idx],
1875
+ }
1876
+ ]
1877
+ if return_token_timestamps:
1878
+ segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
1879
+ segment_offset = seek_num_frames[prev_idx]
1880
+
1881
+ return segments, segment_offset
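For reference, the long-form generation path implemented above (temperature fallback, quality thresholds, and segment stitching) is driven through the public `generate` API. Below is a minimal usage sketch: the `long_audio` waveform (a 1-D float array sampled at 16 kHz and longer than 30 seconds) and the openai/whisper-tiny checkpoint are assumptions, and the threshold values are common example settings rather than anything mandated by this file.

from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# `long_audio` is assumed: a 1-D float waveform at 16 kHz, longer than 30 seconds.
inputs = processor(
    long_audio,
    sampling_rate=16000,
    return_tensors="pt",
    truncation=False,            # keep all frames so the long-form path is taken
    padding="longest",
    return_attention_mask=True,  # needed for batched long-form transcription
)

result = model.generate(
    inputs.input_features,
    attention_mask=inputs.attention_mask,
    return_timestamps=True,                      # long-form decoding requires timestamp tokens
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),  # temperatures tried by generate_with_fallback
    logprob_threshold=-1.0,
    compression_ratio_threshold=1.35,
    no_speech_threshold=0.6,
    condition_on_prev_tokens=True,
    return_segments=True,                        # also return per-segment start/end times
)

text = processor.batch_decode(result["sequences"], skip_special_tokens=True)[0]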
generation_whisper.cpython-312 (1).pyc ADDED
Binary file (87.3 kB).
 
generation_whisper.cpython-312.pyc ADDED
Binary file (87.3 kB).
 
generation_whisper.py ADDED
@@ -0,0 +1,1881 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import copy
16
+ import math
17
+ import warnings
18
+ import zlib
19
+ from typing import Callable, Iterator, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from torch import nn
25
+
26
+ from transformers.cache_utils import EncoderDecoderCache
27
+
28
+ from ...generation import GenerationConfig, GenerationMixin
29
+ from ...generation.logits_process import (
30
+ LogitsProcessorList,
31
+ SuppressTokensAtBeginLogitsProcessor,
32
+ SuppressTokensLogitsProcessor,
33
+ WhisperNoSpeechDetection,
34
+ WhisperTimeStampLogitsProcessor,
35
+ )
36
+ from ...generation.stopping_criteria import StoppingCriteriaList
37
+ from ...modeling_outputs import BaseModelOutput
38
+ from ...utils import logging
39
+ from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+
45
+ def _median_filter(inputs: torch.Tensor, filter_width: int) -> torch.Tensor:
46
+ """
47
+ Applies a median filter of width `filter_width` along the last dimension of the input.
48
+
49
+ The `inputs` tensor is assumed to be 3- or 4-dimensional.
50
+ """
51
+ if filter_width <= 0 or filter_width % 2 != 1:
52
+ raise ValueError("`filter_width` should be an odd number")
53
+
54
+ pad_width = filter_width // 2
55
+ if inputs.shape[-1] <= pad_width:
56
+ return inputs
57
+
58
+ # Pad the left and right edges.
59
+ inputs = nn.functional.pad(inputs, (pad_width, pad_width, 0, 0), mode="reflect")
60
+
61
+ # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
62
+ result = inputs.unfold(-1, filter_width, 1).sort()[0][..., pad_width]
63
+ return result
64
+
65
+
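A minimal sketch of the unfold-and-sort trick used by `_median_filter`, on a toy 4-D tensor with `filter_width=3` (values chosen arbitrarily):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[[1.0, 9.0, 2.0, 8.0, 3.0]]]])   # (batch, heads, tokens, frames)
pad = 1                                              # filter_width // 2
padded = F.pad(x, (pad, pad, 0, 0), mode="reflect")  # reflect-pad the frame axis
windows = padded.unfold(-1, 3, 1)                    # sliding windows of width 3
filtered = windows.sort()[0][..., pad]               # middle of each sorted window = median
print(filtered)                                      # tensor([[[[9., 2., 8., 3., 8.]]]])
```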
66
+ def _dynamic_time_warping(matrix: np.ndarray):
67
+ """
68
+ Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate
69
+ token-level timestamps.
70
+ """
71
+ output_length, input_length = matrix.shape
72
+ cost = np.ones((output_length + 1, input_length + 1), dtype=np.float32) * np.inf
73
+ trace = -np.ones((output_length + 1, input_length + 1), dtype=np.float32)
74
+
75
+ cost[0, 0] = 0
76
+ for j in range(1, input_length + 1):
77
+ for i in range(1, output_length + 1):
78
+ c0 = cost[i - 1, j - 1]
79
+ c1 = cost[i - 1, j]
80
+ c2 = cost[i, j - 1]
81
+
82
+ if c0 < c1 and c0 < c2:
83
+ c, t = c0, 0
84
+ elif c1 < c0 and c1 < c2:
85
+ c, t = c1, 1
86
+ else:
87
+ c, t = c2, 2
88
+
89
+ cost[i, j] = matrix[i - 1, j - 1] + c
90
+ trace[i, j] = t
91
+
92
+ # backtrace
93
+ i = trace.shape[0] - 1
94
+ j = trace.shape[1] - 1
95
+ trace[0, :] = 2
96
+ trace[:, 0] = 1
97
+
98
+ text_indices = []
99
+ time_indices = []
100
+ while i > 0 or j > 0:
101
+ text_indices.append(i - 1)
102
+ time_indices.append(j - 1)
103
+ if trace[i, j] == 0:
104
+ i -= 1
105
+ j -= 1
106
+ elif trace[i, j] == 1:
107
+ i -= 1
108
+ elif trace[i, j] == 2:
109
+ j -= 1
110
+ else:
111
+ raise RuntimeError(
112
+ f"Internal error in dynamic time warping. Unexpected trace[{i}, {j}]. Please file a bug report."
113
+ )
114
+
115
+ text_indices = np.array(text_indices)[::-1]
116
+ time_indices = np.array(time_indices)[::-1]
117
+ return text_indices, time_indices
118
+
119
+
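A minimal sketch of how the `_dynamic_time_warping` helper above is used: it takes a cost matrix (the generation code passes the negated, filtered cross-attention weights) and returns a monotonic alignment path between output tokens and audio frames:

```python
import numpy as np

# Align 3 output tokens to 4 audio frames; lower cost means a better match.
cost = -np.array(
    [[0.9, 0.1, 0.1, 0.1],
     [0.1, 0.8, 0.7, 0.1],
     [0.1, 0.1, 0.2, 0.9]],
    dtype=np.float32,
)
text_idx, time_idx = _dynamic_time_warping(cost)
# Both index arrays are monotonic; token-level timestamps are later derived
# from the frames at which text_idx jumps to a new token.
print(list(zip(text_idx.tolist(), time_idx.tolist())))
```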
120
+ def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name):
121
+ if logits_processor is not None:
122
+ logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None)
123
+ if logit_processor:
124
+ return getattr(logit_processor, attribute_name, None)
125
+ return None
126
+
127
+
128
+ def _pad_to_max_length(
129
+ current_segments,
130
+ pad_token_id,
131
+ device,
132
+ padding_side="right",
133
+ padding="longest",
134
+ bos_token_tensor=None,
135
+ cut_off_length=None,
136
+ ):
137
+ max_total_length = 0
138
+ sequences = []
139
+
140
+ if padding_side not in ["right", "left"]:
141
+ raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
142
+
143
+ if padding not in ["longest", "max_length"]:
144
+ raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
145
+ elif padding == "max_length" and cut_off_length is None:
146
+ raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
147
+
148
+ for current_segment_list in current_segments:
149
+ if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
150
+ sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
151
+
152
+ if cut_off_length is not None:
153
+ sequence = sequence[-cut_off_length:]
154
+
155
+ if bos_token_tensor is not None:
156
+ sequence = torch.cat([bos_token_tensor, sequence])
157
+
158
+ sequences.append(sequence)
159
+ max_total_length = max(max_total_length, len(sequences[-1]))
160
+ elif bos_token_tensor is not None:
161
+ sequences.append(bos_token_tensor)
162
+ else:
163
+ sequences.append(torch.tensor([], device=device))
164
+
165
+ max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
166
+ for i in range(len(current_segments)):
167
+ pad_length = max_total_length - len(sequences[i])
168
+ pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
169
+ sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
170
+
171
+ sequences = torch.stack(sequences, dim=0)
172
+ return sequences
173
+
174
+
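A minimal sketch of what `_pad_to_max_length` produces for two per-sample segment lists (the pad id below is hypothetical):

```python
import torch

segments = [
    [{"tokens": torch.tensor([1, 2, 3])}, {"tokens": torch.tensor([4])}],
    [{"tokens": torch.tensor([5, 6])}],
]
padded = _pad_to_max_length(segments, pad_token_id=50257, device="cpu")
print(padded)
# tensor([[    1,     2,     3,     4],
#         [    5,     6, 50257, 50257]])
```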
175
+ class WhisperGenerationMixin(GenerationMixin):
176
+ def _extract_token_timestamps(
177
+ self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None, num_input_ids=None
178
+ ):
179
+ """
180
+ Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
181
+ map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
182
+ cross-attentions will be cropped before applying DTW.
183
+
184
+ Returns:
185
+ tensor containing the timestamps in seconds for each predicted token
186
+ """
187
+ # Create a list with `decoder_layers` elements, each a tensor of shape
188
+ # (batch size, attention_heads, output length, input length).
189
+ cross_attentions = []
190
+ for i in range(self.config.decoder_layers):
191
+ cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
192
+
193
+ # Select specific cross-attention layers and heads. This is a tensor
194
+ # of shape (batch size, num selected, output length, input length).
195
+ weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
196
+ weights = weights.permute([1, 0, 2, 3])
197
+
198
+ weight_length = None
199
+
200
+ if "beam_indices" in generate_outputs:
201
+ # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
202
+ # since the beam search strategy chooses the most probable sequences at the end of the search.
203
+ # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
204
+ weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
205
+ weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
206
+
207
+ # beam search takes `decoder_input_ids` into account in the `beam_indices` length
208
+ # but does not shift the beam_indices by the number of `decoder_input_ids`
209
+ beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length])
210
+ # we actually shift the beam indices here
211
+ beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids]
212
+
213
+ weights = weights[:, :, :weight_length]
214
+
215
+ # If beam index is still -1, it means that the associated token id is EOS
216
+ # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
217
+ beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
218
+
219
+ # Select the cross attention from the right beam for each output sequence
220
+ weights = torch.stack(
221
+ [
222
+ torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
223
+ for i in range(beam_indices.shape[1])
224
+ ],
225
+ dim=2,
226
+ )
227
+
228
+ # make sure timestamps are as long as weights
229
+ input_length = weight_length or cross_attentions[0].shape[2]
230
+ batch_size = generate_outputs.sequences.shape[0]
231
+ timestamps = torch.zeros(
232
+ (batch_size, input_length + 1), dtype=torch.float32, device=generate_outputs.sequences.device
233
+ )
234
+
235
+ if num_frames is not None:
236
+ # two cases:
237
+ # 1. num_frames is the same for each sample -> compute the DTW matrix for each sample in parallel
238
+ # 2. num_frames is different, compute the DTW matrix for each sample sequentially
239
+
240
+ # we're using np.unique because num_frames can be int/list/tuple
241
+ if isinstance(num_frames, int):
242
+ weights = weights[..., : num_frames // 2]
243
+
244
+ elif isinstance(num_frames, (list, tuple, np.ndarray)) and len(np.unique(num_frames)) == 1:
245
+ weights = weights[..., : num_frames[0] // 2]
246
+
247
+ elif isinstance(num_frames, (torch.Tensor)) and len(torch.unique(num_frames)) == 1:
248
+ weights = weights[..., : num_frames[0] // 2]
249
+
250
+ else:
251
+ # num_frames is of shape (batch_size,) whereas batch_size is truly batch_size*num_return_sequences
252
+ repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
253
+ num_frames = num_frames.cpu() if isinstance(num_frames, (torch.Tensor)) else num_frames
254
+ num_frames = np.repeat(num_frames, repeat_time)
255
+
256
+ if num_frames is None or isinstance(num_frames, int):
257
+ # Normalize and smoothen the weights.
258
+ std = torch.std(weights, dim=-2, keepdim=True, unbiased=False)
259
+ mean = torch.mean(weights, dim=-2, keepdim=True)
260
+ weights = (weights - mean) / std
261
+ weights = _median_filter(weights, self.config.median_filter_width)
262
+
263
+ # Average the different cross-attention heads.
264
+ weights = weights.mean(dim=1)
265
+
266
+ # Perform dynamic time warping on each element of the batch.
267
+ for batch_idx in range(batch_size):
268
+ if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray, torch.Tensor)):
269
+ matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2]
270
+
271
+ # Normalize and smoothen the weights.
272
+ std = torch.std(matrix, dim=-2, keepdim=True, unbiased=False)
273
+ mean = torch.mean(matrix, dim=-2, keepdim=True)
274
+ matrix = (matrix - mean) / std
275
+ matrix = _median_filter(matrix, self.config.median_filter_width)
276
+
277
+ # Average the different cross-attention heads.
278
+ matrix = matrix.mean(dim=0)
279
+ else:
280
+ matrix = weights[batch_idx]
281
+
282
+ text_indices, time_indices = _dynamic_time_warping(-matrix.cpu().double().numpy())
283
+ jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
284
+ jump_times = time_indices[jumps] * time_precision
285
+ timestamps[batch_idx, 1:] = torch.tensor(jump_times)
286
+
287
+ return timestamps
288
+
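In practice `_extract_token_timestamps` is driven through `generate`; a minimal usage sketch, assuming a processed short-form input and a checkpoint whose generation config defines `alignment_heads`:

```python
outputs = model.generate(
    input_features,
    return_token_timestamps=True,   # triggers the DTW-based alignment above
    return_timestamps=True,
)
print(outputs["token_timestamps"].shape)  # (batch, generated tokens), values in seconds
```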
289
+ def generate(
290
+ self,
291
+ input_features: Optional[torch.Tensor] = None,
292
+ generation_config: Optional[GenerationConfig] = None,
293
+ logits_processor: Optional[LogitsProcessorList] = None,
294
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
295
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
296
+ synced_gpus: bool = False,
297
+ return_timestamps: Optional[bool] = None,
298
+ task: Optional[str] = None,
299
+ language: Optional[Union[str, List[str]]] = None,
300
+ is_multilingual: Optional[bool] = None,
301
+ prompt_ids: Optional[torch.Tensor] = None,
302
+ prompt_condition_type: Optional[str] = None, # first-segment, all-segments
303
+ condition_on_prev_tokens: Optional[bool] = None,
304
+ temperature: Optional[Union[float, Tuple[float, ...]]] = None,
305
+ compression_ratio_threshold: Optional[float] = None,
306
+ logprob_threshold: Optional[float] = None,
307
+ no_speech_threshold: Optional[float] = None,
308
+ num_segment_frames: Optional[int] = None,
309
+ attention_mask: Optional[torch.Tensor] = None,
310
+ time_precision: float = 0.02,
311
+ time_precision_features: float = 0.01,
312
+ return_token_timestamps: Optional[bool] = None,
313
+ return_segments: bool = False,
314
+ return_dict_in_generate: Optional[bool] = None,
315
+ **kwargs,
316
+ ):
317
+ """
318
+ Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids.
319
+
320
+ <Tip warning={true}>
321
+
322
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
323
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
324
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
325
+
326
+ For an overview of generation strategies and code examples, check out the [following
327
+ guide](./generation_strategies).
328
+
329
+ </Tip>
330
+
331
+ Parameters:
332
+ input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
333
+ Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
334
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
335
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
336
+ [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
337
+ tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
338
+ generation_config (`~generation.GenerationConfig`, *optional*):
339
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
340
+ passed to generate matching the attributes of `generation_config` will override them. If
341
+ `generation_config` is not provided, the default will be used, which has the following loading
342
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
343
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
344
+ default values, whose documentation should be checked to parameterize generation.
345
+ logits_processor (`LogitsProcessorList`, *optional*):
346
+ Custom logits processors that complement the default logits processors built from arguments and
347
+ generation config. If a logit processor is passed that is already created with the arguments or a
348
+ generation config an error is thrown. This feature is intended for advanced users.
349
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
350
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
351
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
352
+ generation config an error is thrown. This feature is intended for advanced users.
353
+ prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
354
+ If provided, this function constrains the beam search to allowed tokens only at each step. If not
355
+ provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
356
+ `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
357
+ on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
358
+ for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
359
+ Retrieval](https://arxiv.org/abs/2010.00904).
360
+ synced_gpus (`bool`, *optional*, defaults to `False`):
361
+ Whether to continue running the while loop until max_length (needed to avoid deadlocking with
362
+ `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
363
+ return_timestamps (`bool`, *optional*):
364
+ Whether to return the timestamps with the text. This enables the `WhisperTimeStampLogitsProcessor`.
365
+ task (`str`, *optional*):
366
+ Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
367
+ will be updated accordingly.
368
+ language (`str` or list of `str`, *optional*):
369
+ Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. For
370
+ batched generation, a list of language tokens can be passed. You can find all the possible language
371
+ tokens in the `model.generation_config.lang_to_id` dictionary.
372
+ is_multilingual (`bool`, *optional*):
373
+ Whether or not the model is multilingual.
374
+ prompt_ids (`torch.Tensor`, *optional*):
375
+ Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
376
+ provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
377
+ transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
378
+ correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
379
+ prompt_condition_type (`str`, *optional*):
380
+ Only relevant for long-form transcription. Condition type of `prompt_ids`. 'first-segment' means only the first segment is conditioned on `prompt_ids`. 'all-segments' means each segment is conditioned on `prompt_ids`. Make sure to enable `condition_on_prev_tokens` for 'all-segments'.
381
+ Defaults to 'first-segment'. For short-form transcription only 'first-segment' is possible.
382
+ condition_on_prev_tokens (`bool`, *optional*):
383
+ Only relevant for long-form transcription. Whether to condition each segment on the previous segment.
384
+ As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
385
+ performance.
386
+ temperature (`float` or list of `float`, *optional*):
387
+ The temperature to be used for generation. Passing a single `float` value and `do_sample=True` activates
388
+ generation using sampling. For long-form transcription, temperature fallback can be activated by passing
389
+ a list of float values such as (0.0, 0.2, 0.4, 0.6, 0.8, 1.0). As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
390
+ performance.
391
+ compression_ratio_threshold (`float`, *optional*):
392
+ Only relevant for long-form transcription. If defined, the zlib compression rate of each segment will be computed. If the compression rate of
393
+ a segment is higher than `compression_ratio_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is
394
+ repeated using a higher temperature. The intuition behind this feature is that segments with very high compression rates
395
+ suffer from a lot of repetition. The unwanted repetition can be reduced by injecting more randomness by increasing the temperature. If `compression_ratio_threshold` is defined
396
+ make sure that `temperature` is a list of values. A common value for `compression_ratio_threshold` is 1.35.
397
+ As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
398
+ performance.
399
+ logprob_threshold (`float`, *optional*):
400
+ Only relevant for long-form transcription. If defined, the average log-probability of each segment will be computed. If the log-probability of
401
+ a given segment is lower than `logprob_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is
402
+ repeated using a higher temperature. The intuition behind this feature is that segments of low log-probability
403
+ can be improved by injecting more randomness by increasing the temperature. If `logprob_threshold` is defined
404
+ make sure that `temperature` is a list of values. A common value for `logprob_threshold` is -1.0.
405
+ As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
406
+ performance.
407
+ no_speech_threshold (`float`, *optional*):
408
+ Only relevant for long-form transcription. If defined, the "no-speech" token combined with the `logprob_threshold`
409
+ is used to determine whether a segment contains only silence. In this case, the transcription for this segment
410
+ is skipped.
411
+ As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
412
+ performance.
413
+ num_segment_frames (`int`, *optional*):
414
+ The number of frames a single segment is made of. If not defined, `num_segment_frames` defaults to the model's stride
415
+ times the maximum input length.
416
+ attention_mask (`torch.Tensor`, *optional*):
417
+ `attention_mask` needs to be passed when doing long-form transcription using a batch size > 1.
418
+ time_precision (`float`, *optional*, defaults to 0.02):
419
+ The duration of an output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
420
+ for 20 ms.
421
+ time_precision_features (`float`, *optional*, defaults to 0.01):
422
+ The duration represented by a feature frame in seconds.
423
+ return_token_timestamps (`bool`, *optional*):
424
+ Whether to return token-level timestamps with the text. This can be used with or without the
425
+ `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
426
+ words.
427
+ return_segments (`bool`, *optional*, defaults to `False`):
428
+ Whether to additionally return a list of all segments. Note that this option can only be enabled
429
+ when doing long-form transcription.
430
+ return_dict_in_generate (`bool`, *optional*, defaults to `False`):
431
+ Whether or not to return a [`~utils.ModelOutput`] instead of just returning the generated tokens.
432
+ Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when
433
+ `return_segments` is set True. In this case, the generation outputs of each segment are added to the respective
434
+ segment.
435
+ kwargs (`Dict[str, Any]`, *optional*):
436
+ Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
437
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
438
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
439
+
440
+ Return:
441
+ [`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
442
+ or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`.
443
+
444
+ If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True`, then a dictionary of generated sequence ids (under `sequences`) and a list of each generated segment is returned.
445
+
446
+ else if the passed input is <= 30 seconds / <= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
447
+
448
+ - [`~generation.GenerateEncoderDecoderOutput`],
449
+ - [`~generation.GenerateBeamEncoderDecoderOutput`]
450
+
451
+ else only the generated output sequence ids are returned.
452
+
453
+ Example:
454
+
455
+ - *Longform transcription*: To transcribe or translate audio longer than 30 seconds, process the audio files without truncation and pass all mel features at once to `generate`.
456
+
457
+ ```python
458
+ >>> import torch
459
+ >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
460
+ >>> from datasets import load_dataset, Audio
461
+
462
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
463
+ >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
464
+ >>> model.cuda() # doctest: +IGNORE_RESULT
465
+
466
+ >>> # load audios > 30 seconds
467
+ >>> ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
468
+ >>> # resample to 16kHz
469
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=16000))
470
+ >>> # take first 8 audios and retrieve array
471
+ >>> audio = ds[:8]["audio"]
472
+ >>> audio = [x["array"] for x in audio]
473
+
474
+ >>> # make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio
475
+ >>> inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000)
476
+ >>> inputs = inputs.to("cuda", torch.float32)
477
+
478
+ >>> # transcribe audio to ids
479
+ >>> generated_ids = model.generate(**inputs)
480
+
481
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
482
+ >>> transcription[0]
483
+ " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile."
484
+ ```
485
+
486
+ - *Shortform transcription*: If the passed mel input features cover less than 30 seconds, the whole audio is transcribed with a single call to `generate`.
487
+
488
+ ```python
489
+ >>> import torch
490
+ >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
491
+ >>> from datasets import load_dataset
492
+
493
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
494
+ >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
495
+
496
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
497
+
498
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
499
+ >>> input_features = inputs.input_features
500
+
501
+ >>> generated_ids = model.generate(inputs=input_features)
502
+
503
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
504
+ >>> transcription
505
+ ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
506
+ ```
507
+
508
+ """
509
+ # 0. deprecate old inputs
510
+ if "inputs" in kwargs:
511
+ input_features = kwargs.pop("inputs")
512
+ warnings.warn(
513
+ "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
514
+ FutureWarning,
515
+ )
516
+
517
+ # 1. prepare generation config
518
+ generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
519
+
520
+ # 2. set global generate variables
521
+ input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
522
+ num_segment_frames = input_stride * self.config.max_source_positions
523
+ batch_size, total_input_frames = self._retrieve_total_input_frames(
524
+ input_features=input_features, input_stride=input_stride, kwargs=kwargs
525
+ )
526
+ is_shortform = total_input_frames <= num_segment_frames
527
+
528
+ # 3. Make sure generation config is correctly set
529
+ # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
530
+ return_dict_in_generate = self._set_return_outputs(
531
+ return_dict_in_generate=return_dict_in_generate,
532
+ return_token_timestamps=return_token_timestamps,
533
+ logprob_threshold=logprob_threshold,
534
+ generation_config=generation_config,
535
+ )
536
+ timestamp_begin = self._set_return_timestamps(
537
+ return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config
538
+ )
539
+ self._set_language_and_task(
540
+ language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
541
+ )
542
+ self._set_num_frames(
543
+ return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
544
+ )
545
+ self._set_thresholds_and_condition(
546
+ generation_config=generation_config,
547
+ logprob_threshold=logprob_threshold,
548
+ compression_ratio_threshold=compression_ratio_threshold,
549
+ no_speech_threshold=no_speech_threshold,
550
+ condition_on_prev_tokens=condition_on_prev_tokens,
551
+ )
552
+ self._set_prompt_condition_type(
553
+ generation_config=generation_config,
554
+ prompt_condition_type=prompt_condition_type,
555
+ )
556
+
557
+ # pass self.config for backward compatibility
558
+ init_tokens = self._retrieve_init_tokens(
559
+ input_features,
560
+ batch_size=batch_size,
561
+ generation_config=generation_config,
562
+ config=self.config,
563
+ num_segment_frames=num_segment_frames,
564
+ kwargs=kwargs,
565
+ )
566
+ # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
567
+ # where the input ids are handled explicitly by the generate method
568
+ self._check_decoder_input_ids(kwargs=kwargs)
569
+
570
+ # 3. Retrieve logits processors
571
+ device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
572
+ begin_index = init_tokens.shape[1]
573
+ logits_processor = self._retrieve_logit_processors(
574
+ generation_config=generation_config,
575
+ logits_processor=logits_processor,
576
+ begin_index=begin_index, # begin index is index of first generated decoder token
577
+ num_beams=kwargs.get("num_beams", 1),
578
+ device=device,
579
+ )
580
+
581
+ # 4 Set and retrieve global generation variables
582
+ self._set_condition_on_prev_tokens(
583
+ condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config
584
+ )
585
+
586
+ temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature
587
+ temperature = temperatures[0]
588
+
589
+ max_frames, seek = self._retrieve_max_frames_and_seek(
590
+ batch_size=batch_size,
591
+ attention_mask=attention_mask,
592
+ total_input_frames=total_input_frames,
593
+ is_shortform=is_shortform,
594
+ )
595
+
596
+ # 5 Prepare running variables, list for generation
597
+ num_return_sequences = generation_config.num_return_sequences
598
+ (
599
+ batch_idx_map,
600
+ cur_bsz,
601
+ input_features,
602
+ seek,
603
+ max_frames,
604
+ init_tokens,
605
+ do_condition_on_prev_tokens,
606
+ ) = self._expand_variables_for_generation(
607
+ input_features=input_features,
608
+ seek=seek,
609
+ max_frames=max_frames,
610
+ init_tokens=init_tokens,
611
+ batch_size=batch_size,
612
+ condition_on_prev_tokens=condition_on_prev_tokens,
613
+ generation_config=generation_config,
614
+ )
615
+
616
+ current_segments = self._prepare_segments(
617
+ prompt_ids=prompt_ids,
618
+ batch_size=cur_bsz,
619
+ generation_config=generation_config,
620
+ )
621
+
622
+ # 6 Transcribe audio until we reach the end of all input audios
623
+ while (seek < max_frames).any():
624
+ # 6.1 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
625
+ # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
626
+ # to know which original audio is being decoded
627
+ # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
628
+ input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch(
629
+ input_features=input_features,
630
+ seek=seek,
631
+ max_frames=max_frames,
632
+ cur_bsz=cur_bsz,
633
+ batch_idx_map=batch_idx_map,
634
+ )
635
+ time_offset = (
636
+ seek.to(torch.float32 if device.type == "mps" else torch.float64) * time_precision / input_stride
637
+ )
638
+ seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
639
+
640
+ # 6.2 cut out next 30s segment from input features
641
+ segment_input = self._get_input_segment(
642
+ input_features=input_features,
643
+ seek=seek,
644
+ seek_num_frames=seek_num_frames,
645
+ num_segment_frames=num_segment_frames,
646
+ cur_bsz=cur_bsz,
647
+ batch_idx_map=batch_idx_map,
648
+ )
649
+
650
+ # 6.3 prepare decoder input ids
651
+ suppress_tokens = _get_attr_from_logit_processors(
652
+ logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
653
+ )
654
+
655
+ decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
656
+ cur_bsz=cur_bsz,
657
+ init_tokens=init_tokens,
658
+ current_segments=current_segments,
659
+ batch_idx_map=batch_idx_map,
660
+ do_condition_on_prev_tokens=do_condition_on_prev_tokens,
661
+ prompt_ids=prompt_ids,
662
+ generation_config=generation_config,
663
+ config=self.config,
664
+ device=init_tokens.device,
665
+ suppress_tokens=suppress_tokens,
666
+ timestamp_begin=timestamp_begin,
667
+ kwargs=kwargs,
668
+ )
669
+
670
+ # 6.4 set max new tokens or max length
671
+ self._set_max_new_tokens_and_length(
672
+ config=self.config,
673
+ decoder_input_ids=decoder_input_ids,
674
+ generation_config=generation_config,
675
+ )
676
+
677
+ # 6.5 Set current `begin_index` for all logit processors
678
+ if logits_processor is not None:
679
+ for proc in logits_processor:
680
+ if hasattr(proc, "set_begin_index"):
681
+ proc.set_begin_index(decoder_input_ids.shape[-1])
682
+
683
+ # 6.6 Run generate with fallback
684
+ (
685
+ seek_sequences,
686
+ seek_outputs,
687
+ should_skip,
688
+ do_condition_on_prev_tokens,
689
+ model_output_type,
690
+ ) = self.generate_with_fallback(
691
+ segment_input=segment_input,
692
+ decoder_input_ids=decoder_input_ids,
693
+ cur_bsz=cur_bsz,
694
+ batch_idx_map=batch_idx_map,
695
+ seek=seek,
696
+ num_segment_frames=num_segment_frames,
697
+ max_frames=max_frames,
698
+ temperatures=temperatures,
699
+ generation_config=generation_config,
700
+ logits_processor=logits_processor,
701
+ stopping_criteria=stopping_criteria,
702
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
703
+ synced_gpus=synced_gpus,
704
+ return_token_timestamps=return_token_timestamps,
705
+ do_condition_on_prev_tokens=do_condition_on_prev_tokens,
706
+ is_shortform=is_shortform,
707
+ batch_size=batch_size,
708
+ attention_mask=attention_mask,
709
+ kwargs=kwargs,
710
+ )
711
+
712
+ # 6.7 In every generated sequence, split by timestamp tokens and extract segments
713
+ for i, seek_sequence in enumerate(seek_sequences):
714
+ prev_i = batch_idx_map[i]
715
+
716
+ if should_skip[i]:
717
+ seek[prev_i] += seek_num_frames[prev_i]
718
+ continue
719
+
720
+ segments, segment_offset = self._retrieve_segment(
721
+ seek_sequence=seek_sequence,
722
+ seek_outputs=seek_outputs,
723
+ time_offset=time_offset,
724
+ timestamp_begin=timestamp_begin,
725
+ seek_num_frames=seek_num_frames,
726
+ time_precision=time_precision,
727
+ time_precision_features=time_precision_features,
728
+ input_stride=input_stride,
729
+ prev_idx=prev_i,
730
+ idx=i,
731
+ return_token_timestamps=return_token_timestamps,
732
+ )
733
+
734
+ current_segments[prev_i] += segments
735
+
736
+ if is_shortform:
737
+ seek[prev_i] += max_frames[i]
738
+ else:
739
+ seek[prev_i] += segment_offset
740
+
741
+ # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
742
+ # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
743
+ final_segments = (
744
+ [x[1:] for x in current_segments]
745
+ if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
746
+ else current_segments
747
+ )
748
+
749
+ sequences = _pad_to_max_length(
750
+ final_segments, generation_config.pad_token_id, device=self.device, padding_side="right"
751
+ )
752
+
753
+ # 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
754
+ if return_segments:
755
+ return {"sequences": sequences, "segments": final_segments}
756
+
757
+ if is_shortform:
758
+ # add eos token:
759
+ if generation_config.max_new_tokens is None and generation_config.max_length is None:
760
+ eos_tokens = torch.full((sequences.shape[0], 1), generation_config.eos_token_id)
761
+ sequences = torch.cat([sequences, eos_tokens], dim=-1)
762
+
763
+ if return_token_timestamps:
764
+ outputs = {}
765
+ outputs["sequences"] = sequences
766
+ outputs["token_timestamps"] = torch.stack([d["token_timestamps"] for d in seek_outputs], dim=0)
767
+ else:
768
+ outputs = sequences
769
+
770
+ if return_dict_in_generate and generation_config.return_dict_in_generate:
771
+ dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs)
772
+
773
+ if num_return_sequences > 1:
774
+ if hasattr(dict_outputs, "encoder_attentions") and dict_outputs.encoder_attentions is not None:
775
+ dict_outputs.encoder_attentions = tuple(
776
+ dict_outputs.encoder_attentions[i][::num_return_sequences]
777
+ for i in range(len(dict_outputs.encoder_attentions))
778
+ )
779
+ if (
780
+ hasattr(dict_outputs, "encoder_hidden_states")
781
+ and dict_outputs.encoder_hidden_states is not None
782
+ ):
783
+ dict_outputs.encoder_hidden_states = tuple(
784
+ dict_outputs.encoder_hidden_states[i][::num_return_sequences]
785
+ for i in range(len(dict_outputs.encoder_hidden_states))
786
+ )
787
+ if return_token_timestamps:
788
+ dict_outputs["token_timestamps"] = outputs["token_timestamps"]
789
+ return dict_outputs
790
+
791
+ return outputs
792
+
793
+ return sequences
794
+
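Tying the arguments of `generate` together, a minimal long-form sketch with temperature fallback; the threshold values are the common choices from the Whisper paper rather than requirements of this code, and `model`, `processor` and `inputs` are assumed to be prepared as in the docstring examples above:

```python
generated = model.generate(
    **inputs,                                      # padded, non-truncated features + attention_mask
    condition_on_prev_tokens=True,
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),    # enables temperature fallback
    compression_ratio_threshold=1.35,
    logprob_threshold=-1.0,
    no_speech_threshold=0.6,
    return_timestamps=True,
)
transcription = processor.batch_decode(generated, skip_special_tokens=True)
```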
795
+ def generate_with_fallback(
796
+ self,
797
+ segment_input,
798
+ decoder_input_ids,
799
+ cur_bsz,
800
+ batch_idx_map,
801
+ seek,
802
+ num_segment_frames,
803
+ max_frames,
804
+ temperatures,
805
+ generation_config,
806
+ logits_processor,
807
+ stopping_criteria,
808
+ prefix_allowed_tokens_fn,
809
+ synced_gpus,
810
+ return_token_timestamps,
811
+ do_condition_on_prev_tokens,
812
+ is_shortform,
813
+ batch_size,
814
+ attention_mask,
815
+ kwargs,
816
+ ):
817
+ kwargs = copy.copy(kwargs)
818
+
819
+ # 6.6 Batch generate current chunk
820
+ seek_sequence_list = [None for _ in range(cur_bsz)]
821
+ seek_outputs_list = [None for _ in range(cur_bsz)]
822
+ needs_fallback = [False for _ in range(cur_bsz)]
823
+ should_skip = [False for _ in range(cur_bsz)]
824
+ fallback_index_map = list(range(cur_bsz))
825
+ if generation_config.no_speech_threshold is not None:
826
+ self._setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs)
827
+
828
+ for fallback_idx, temperature in enumerate(temperatures):
829
+ generation_config.do_sample = temperature is not None and temperature > 0.0
830
+ generation_config.temperature = temperature if generation_config.do_sample else 1.0
831
+ if generation_config.do_sample:
832
+ generation_config.num_beams = 1
833
+
834
+ generate_kwargs = copy.copy(kwargs)
835
+ for key in ["do_sample", "temperature", "num_beams"]:
836
+ if key in generate_kwargs:
837
+ del generate_kwargs[key]
838
+
839
+ cur_bsz = decoder_input_ids.shape[0]
840
+ if generation_config.cache_implementation == "static" and cur_bsz < batch_size:
841
+ segment_input = F.pad(segment_input, (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0)
842
+ decoder_input_ids = F.pad(
843
+ decoder_input_ids, (0, 0, 0, batch_size - cur_bsz), value=generation_config.pad_token_id
844
+ )
845
+ if generate_kwargs.get("decoder_attention_mask") is not None:
846
+ generate_kwargs["decoder_attention_mask"] = F.pad(
847
+ generate_kwargs["decoder_attention_mask"], (0, 0, 0, batch_size - cur_bsz), value=True
848
+ )
849
+ if generate_kwargs.get("encoder_outputs") is not None:
850
+ generate_kwargs["encoder_outputs"] = F.pad(
851
+ generate_kwargs["encoder_outputs"], (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0
852
+ )
853
+
854
+ seek_outputs = super().generate(
855
+ segment_input,
856
+ generation_config=generation_config,
857
+ logits_processor=logits_processor,
858
+ stopping_criteria=stopping_criteria,
859
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
860
+ synced_gpus=synced_gpus,
861
+ decoder_input_ids=decoder_input_ids,
862
+ attention_mask=attention_mask,
863
+ **generate_kwargs,
864
+ )
865
+
866
+ model_output_type = type(seek_outputs)
867
+
868
+ # post-process sequence tokens and outputs to be in list form
869
+ seek_sequences, seek_outputs = self._postprocess_outputs(
870
+ seek_outputs=seek_outputs,
871
+ decoder_input_ids=decoder_input_ids,
872
+ return_token_timestamps=return_token_timestamps,
873
+ generation_config=generation_config,
874
+ is_shortform=is_shortform,
875
+ )
876
+
877
+ if cur_bsz < batch_size:
878
+ seek_sequences = seek_sequences[:cur_bsz]
879
+ seek_outputs = seek_outputs[:cur_bsz]
880
+
881
+ # 6.7 Extract cut sequences from every sequence and check if fallback should be applied
882
+ # Loop over each decoded audio individually as each decoding can be of a different length
883
+ new_fallback_index_map = []
884
+ new_segment_input = []
885
+ new_decoder_input_ids = []
886
+ new_decoder_attention_mask = []
887
+
888
+ for i, seek_sequence in enumerate(seek_sequences):
889
+ # make sure we cut a predicted EOS token if we are not finished with the generation yet
890
+ prev_i = batch_idx_map[fallback_index_map[i]]
891
+ is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i]
892
+
893
+ # remove eos token id
894
+ if is_not_final and seek_sequence[-1] == generation_config.eos_token_id:
895
+ seek_sequence = seek_sequence[:-1]
896
+ if return_token_timestamps and not is_shortform:
897
+ seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1]
898
+
899
+ # remove all padding tokens
900
+ if seek_sequence[-1] == generation_config.pad_token_id:
901
+ num_paddings = (seek_sequence == generation_config.pad_token_id).sum()
902
+ seek_sequence = seek_sequence[:-num_paddings]
903
+ if return_token_timestamps and not is_shortform:
904
+ seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings]
905
+
906
+ # check which sequences in batch need fallback & which should be skipped
907
+ needs_fallback[i], should_skip[i] = self._need_fallback(
908
+ seek_sequence,
909
+ seek_outputs,
910
+ i,
911
+ logits_processor,
912
+ generation_config,
913
+ self.config.vocab_size,
914
+ temperature,
915
+ )
916
+
917
+ seek_sequence_list[fallback_index_map[i]] = seek_sequence
918
+ seek_outputs_list[fallback_index_map[i]] = seek_outputs[i]
919
+ is_low_temperature = temperature is None or temperature < 0.5
920
+ do_condition_on_prev_tokens[fallback_index_map[i]] = (
921
+ generation_config.condition_on_prev_tokens and is_low_temperature
922
+ )
923
+
924
+ if needs_fallback[i]:
925
+ new_fallback_index_map.append(fallback_index_map[i])
926
+ new_segment_input.append(segment_input[i])
927
+ new_decoder_input_ids.append(decoder_input_ids[i])
928
+ if "decoder_attention_mask" in kwargs:
929
+ new_decoder_attention_mask.append(kwargs["decoder_attention_mask"][i])
930
+
931
+ fallback_index_map = new_fallback_index_map
932
+
933
+ # if no sequence needs to be run with temperature fallback, we're finished
934
+ if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1:
935
+ seek_sequences = seek_sequence_list
936
+ seek_outputs = seek_outputs_list
937
+ break
938
+
939
+ # if we're still in the loop, make sure that decoder_input_ids and segment inputs are tensors
940
+ decoder_input_ids = torch.stack(new_decoder_input_ids)
941
+ segment_input = torch.stack(new_segment_input)
942
+ if "decoder_attention_mask" in kwargs:
943
+ kwargs["decoder_attention_mask"] = torch.stack(new_decoder_attention_mask)
944
+
945
+ return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, model_output_type
946
+
947
+ @staticmethod
948
+ def _prepare_segments(prompt_ids, batch_size, generation_config):
949
+ if prompt_ids is not None and generation_config.prompt_condition_type == "first-segment":
950
+ prev_sot_token_id = getattr(generation_config, "prev_sot_token_id", None)
951
+ prompt_ids = prompt_ids[1:] if prompt_ids[0] == prev_sot_token_id else prompt_ids
952
+ current_segments = [[{"tokens": prompt_ids}] for _ in range(batch_size)]
953
+ else:
954
+ current_segments = [[] for _ in range(batch_size)]
955
+
956
+ return current_segments
957
+
958
+ def _postprocess_outputs(
959
+ self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, is_shortform
960
+ ):
961
+ # remove all previously passed decoder input ids
962
+ start_idx = decoder_input_ids.shape[-1] if not is_shortform else torch.tensor(0)
963
+
964
+ if isinstance(seek_outputs, torch.Tensor):
965
+ seek_outputs = seek_outputs[:, start_idx:]
966
+ return seek_outputs, seek_outputs
967
+
968
+ if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
969
+ num_frames = getattr(generation_config, "num_frames", None)
970
+ seek_outputs["token_timestamps"] = self._extract_token_timestamps(
971
+ seek_outputs,
972
+ generation_config.alignment_heads,
973
+ num_frames=num_frames,
974
+ num_input_ids=decoder_input_ids.shape[-1],
975
+ )
976
+ seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
977
+
978
+ seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]
979
+
980
+ def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None):
981
+ if beam_indices is not None and key == "scores":
982
+ return [v[beam_idx].cpu() for (v, beam_idx) in zip(values, beam_indices[batch_idx][: len(values)])]
983
+ if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
984
+ return [v[batch_idx].cpu() for v in values]
985
+ if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
986
+ return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
987
+ elif key == "past_key_values":
988
+ if not is_shortform:
989
+ # we don't save `past_key_values` as this is too costly for longform
990
+ return None
991
+ elif isinstance(values, EncoderDecoderCache):
992
+ all_past_key_values = []
993
+ for layer_idx in range(self.config.decoder_layers):
994
+ layer_past_key_values = []
995
+ for cache_cls in [values.self_attention_cache, values.cross_attention_cache]:
996
+ for v in [cache_cls.key_cache, cache_cls.value_cache]:
997
+ layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu())
998
+ all_past_key_values.append(tuple(layer_past_key_values))
999
+ return tuple(all_past_key_values)
1000
+ else:
1001
+ all_past_key_values = []
1002
+ for v in range(len(values)):
1003
+ layer_past_key_values = []
1004
+ for w in values[v]:
1005
+ if len(w) != 0:
1006
+ layer_past_key_values.append(w[batch_idx][None].cpu())
1007
+ else:
1008
+ layer_past_key_values.append(w)
1009
+ all_past_key_values.append(tuple(layer_past_key_values))
1010
+ return tuple(all_past_key_values)
1011
+
1012
+ return values[batch_idx].cpu()
1013
+
1014
+ sequence_tokens = seek_outputs["sequences"]
1015
+ seek_outputs = [
1016
+ {
1017
+ k: split_by_batch_index(v, k, i, is_shortform, beam_indices=seek_outputs.get("beam_indices"))
1018
+ for k, v in seek_outputs.items()
1019
+ }
1020
+ for i in range(sequence_tokens.shape[0])
1021
+ ]
1022
+
1023
+ return sequence_tokens, seek_outputs
1024
+
1025
+ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
1026
+ # Stack back seek_outputs tensors after splitting them with the split_by_batch_index method
1027
+ outputs = {}
1028
+ for key in seek_outputs[0].keys():
1029
+ if key in ["sequences", "beam_indices"]:
1030
+ outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device)
1031
+ elif key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
1032
+ outputs[key] = tuple(
1033
+ torch.stack([v[key][i] for v in seek_outputs]).to(device) for i in range(len(seek_outputs[0][key]))
1034
+ )
1035
+ elif key == "sequences_scores":
1036
+ outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device)
1037
+ elif key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
1038
+ outputs[key] = tuple(
1039
+ tuple(
1040
+ torch.stack([v[key][i][j] for v in seek_outputs]).squeeze(1).to(device)
1041
+ for j in range(len(seek_outputs[0][key][0]))
1042
+ )
1043
+ for i in range(len(seek_outputs[0][key]))
1044
+ )
1045
+ elif key == "past_key_values":
1046
+ past_key_value_type = kwargs.get("past_key_values")
1047
+ if seek_outputs[0][key] is not None:
1048
+ outputs[key] = tuple(
1049
+ tuple(
1050
+ torch.stack([v[key][i][j] for v in seek_outputs]).squeeze(1).to(device)
1051
+ for j in range(len(seek_outputs[0][key][0]))
1052
+ )
1053
+ for i in range(len(seek_outputs[0][key]))
1054
+ )
1055
+ if past_key_value_type is not None and isinstance(past_key_value_type, EncoderDecoderCache):
1056
+ outputs[key] = past_key_value_type.from_legacy_cache(outputs[key])
1057
+ else:
1058
+ outputs[key] = None
1059
+
1060
+ return model_output_type(**outputs)
1061
+
1062
+ def _need_fallback(
1063
+ self,
1064
+ seek_sequence,
1065
+ seek_outputs,
1066
+ index,
1067
+ logits_processor,
1068
+ generation_config,
1069
+ vocab_size,
1070
+ temperature,
1071
+ ):
1072
+ needs_fallback = False
1073
+ should_skip = False
1074
+ if generation_config.compression_ratio_threshold is not None:
1075
+ compression_ratio = self._retrieve_compression_ratio(seek_sequence, vocab_size)
1076
+
1077
+ if compression_ratio > generation_config.compression_ratio_threshold:
1078
+ needs_fallback = True
1079
+
1080
+ if generation_config.logprob_threshold is not None:
1081
+ if hasattr(seek_outputs[0], "sequences_scores"):
1082
+ logprobs = [s["sequences_scores"] for s in seek_outputs][index]
1083
+ else:
1084
+ scores = seek_outputs[index]["scores"]
1085
+ logprobs = self._retrieve_avg_logprobs(
1086
+ scores, seek_sequence, generation_config.eos_token_id, temperature
1087
+ )
1088
+
1089
+ if logprobs < generation_config.logprob_threshold:
1090
+ needs_fallback = True
1091
+
1092
+ if generation_config.no_speech_threshold is not None:
1093
+ no_speech_prob = _get_attr_from_logit_processors(
1094
+ logits_processor, WhisperNoSpeechDetection, "no_speech_prob"
1095
+ )
1096
+
1097
+ if (
1098
+ logprobs < generation_config.logprob_threshold
1099
+ and no_speech_prob[index] > generation_config.no_speech_threshold
1100
+ ):
1101
+ needs_fallback = False
1102
+ should_skip = True
1103
+
1104
+ return needs_fallback, should_skip
1105
+
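For intuition on the compression-ratio check in `_need_fallback`: highly repetitive output compresses very well under zlib, pushing the ratio above the ~1.35 threshold and triggering a retry at a higher temperature. A rough sketch of the idea (the real check compresses token bytes via `_retrieve_compression_ratio`):

```python
import zlib

text_bytes = ("the the the " * 50).encode("utf-8")
ratio = len(text_bytes) / len(zlib.compress(text_bytes))
print(round(ratio, 2))   # far above 1.35 for degenerate, repetitive output
```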
1106
+ def _expand_variables_for_generation(
1107
+ self, input_features, seek, max_frames, init_tokens, batch_size, condition_on_prev_tokens, generation_config
1108
+ ):
1109
+ if generation_config.num_return_sequences is not None and generation_config.num_return_sequences > 1:
1110
+ batch_idx_map = list(range(batch_size * generation_config.num_return_sequences))
1111
+ cur_bsz = len(batch_idx_map)
1112
+ do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(len(batch_idx_map))]
1113
+ input_features = input_features.repeat_interleave(generation_config.num_return_sequences, dim=0)
1114
+ seek = seek.repeat_interleave(generation_config.num_return_sequences, dim=0)
1115
+ max_frames = max_frames.repeat_interleave(generation_config.num_return_sequences, dim=0)
1116
+ init_tokens = init_tokens.repeat_interleave(generation_config.num_return_sequences, dim=0)
1117
+ generation_config.num_return_sequences = 1
1118
+ else:
1119
+ cur_bsz = batch_size
1120
+ batch_idx_map = list(range(cur_bsz))
1121
+ do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(cur_bsz)]
1122
+
1123
+ return (
1124
+ batch_idx_map,
1125
+ cur_bsz,
1126
+ input_features,
1127
+ seek,
1128
+ max_frames,
1129
+ init_tokens,
1130
+ do_condition_on_prev_tokens,
1131
+ )
1132
+
1133
+ @staticmethod
1134
+ def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs):
1135
+ set_inputs = _get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs")
1136
+ extra_kwargs = {k: v for k, v in kwargs.items() if torch.is_tensor(v)}
1137
+ set_inputs({"inputs": segment_input, "decoder_input_ids": decoder_input_ids, **extra_kwargs})
1138
+
1139
+ @staticmethod
1140
+ def _retrieve_total_input_frames(input_features, input_stride, kwargs):
1141
+ if input_features is not None:
1142
+ return input_features.shape[0], input_features.shape[-1]
1143
+
1144
+ if "encoder_outputs" in kwargs:
1145
+ encoder_outputs_shape = (
1146
+ kwargs["encoder_outputs"][0].shape
1147
+ if isinstance(kwargs["encoder_outputs"], BaseModelOutput)
1148
+ else kwargs["encoder_outputs"].shape
1149
+ )
1150
+ return encoder_outputs_shape[0], encoder_outputs_shape[1] * input_stride
1151
+
1152
+ raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.")
1153
+
1154
+ @staticmethod
1155
+ def _maybe_warn_unused_inputs(
1156
+ condition_on_prev_tokens,
1157
+ temperature,
1158
+ compression_ratio_threshold,
1159
+ logprob_threshold,
1160
+ no_speech_threshold,
1161
+ total_input_frames,
1162
+ ):
1163
+ warning_prefix = (
1164
+ f"Audio input consists of only {total_input_frames} frames. "
1165
+ "Short-form transcription is activated. "
1166
+ "{}, but will be ignored."
1167
+ )
1168
+ if condition_on_prev_tokens is not None:
1169
+ logger.warning(warning_prefix.format(f"condition_on_prev_tokens is set to {condition_on_prev_tokens}"))
1170
+
1171
+ if compression_ratio_threshold is not None:
1172
+ logger.warning(
1173
+ warning_prefix.format(f"compression_ratio_threshold is set to {compression_ratio_threshold}")
1174
+ )
1175
+
1176
+ if logprob_threshold is not None:
1177
+ logger.warning(warning_prefix.format(f"logprob_threshold is set to {logprob_threshold}"))
1178
+
1179
+ if no_speech_threshold is not None:
1180
+ logger.warning(warning_prefix.format(f"no_speech_threshold is set to {no_speech_threshold}"))
1181
+
1182
+ # when passing temperature as a list it cannot just be ignored => throw error in this case
1183
+ if isinstance(temperature, (list, tuple)):
1184
+ raise ValueError(
1185
+ f"Audio input consists of only {total_input_frames} frames. Short-form transcription is activated. "
1186
+ f"`temperature` cannot be set to {temperature}, which can only be used for temperature fallback in long-form generation. Make sure to set `temperature` to a float value or `None` for short-form generation."
1187
+ )
1188
+
1189
+ @staticmethod
1190
+ def _set_return_outputs(return_dict_in_generate, return_token_timestamps, logprob_threshold, generation_config):
1191
+ if return_dict_in_generate is None:
1192
+ return_dict_in_generate = generation_config.return_dict_in_generate
1193
+ else:
1194
+ generation_config.return_dict_in_generate = return_dict_in_generate
1195
+
1196
+ generation_config.return_token_timestamps = return_token_timestamps
1197
+ if return_token_timestamps:
1198
+ generation_config.return_dict_in_generate = True
1199
+ generation_config.output_attentions = True
1200
+ generation_config.output_scores = True
1201
+
1202
+ if logprob_threshold is not None:
1203
+ generation_config.return_dict_in_generate = True
1204
+ generation_config.output_scores = True
1205
+
1206
+ return return_dict_in_generate
1207
+
1208
+ def _set_return_timestamps(self, return_timestamps, is_shortform, generation_config):
1209
+ if return_timestamps is None and hasattr(generation_config, "return_timestamps"):
1210
+ return_timestamps = generation_config.return_timestamps
1211
+
1212
+ if not is_shortform:
1213
+ if return_timestamps is False:
1214
+ raise ValueError(
1215
+ "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
1216
+ "requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
1217
+ )
1218
+
1219
+ logger.info("Setting `return_timestamps=True` for long-form generation.")
1220
+ return_timestamps = True
1221
+
1222
+ if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"):
1223
+ raise ValueError(
1224
+ "You are trying to return timestamps, but the generation config is not properly set. "
1225
+ "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
1226
+ "For more details on how to generate the appropriate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
1227
+ )
1228
+
1229
+ generation_config.return_timestamps = return_timestamps
1230
+
1231
+ if hasattr(generation_config, "no_timestamps_token_id"):
1232
+ timestamp_begin = generation_config.no_timestamps_token_id + 1
1233
+ else:
1234
+ # BC for models missing the `no_timestamps_token_id` in the generation config when generating short-form with no timestamps
1235
+ # We set the timestamp begin token larger than the vocab size, such that the timestamp condition is never met in the decoding loop
1236
+ timestamp_begin = self.config.vocab_size + 1
1237
+
1238
+ return timestamp_begin
1239
+
1240
+ @staticmethod
1241
+ def _set_language_and_task(language, task, is_multilingual, generation_config):
1242
+ if is_multilingual is not None:
1243
+ if not hasattr(generation_config, "is_multilingual"):
1244
+ raise ValueError(
1245
+ "The generation config is outdated and is thus not compatible with the `is_multilingual` argument "
1246
+ "to `generate`. Please update the generation config as per the instructions "
1247
+ "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
1248
+ )
1249
+ generation_config.is_multilingual = is_multilingual
1250
+
1251
+ if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual:
1252
+ if task is not None or language is not None:
1253
+ raise ValueError(
1254
+ "Cannot specify `task` or `language` for an English-only model. If the model is intended to be "
1255
+ "multilingual, pass `is_multilingual=True` to generate, or update the generation config."
1256
+ )
1257
+
1258
+ if language is not None:
1259
+ if not hasattr(generation_config, "lang_to_id"):
1260
+ raise ValueError(
1261
+ "The generation config is outdated and is thus not compatible with the `language` argument "
1262
+ "to `generate`. Either set the language using the `forced_decoder_ids` in the model config, "
1263
+ "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
1264
+ )
1265
+ generation_config.language = language
1266
+
1267
+ if task is not None:
1268
+ if not hasattr(generation_config, "task_to_id"):
1269
+ raise ValueError(
1270
+ "The generation config is outdated and is thus not compatible with the `task` argument "
1271
+ "to `generate`. Either set the task using the `forced_decoder_ids` in the model config, "
1272
+ "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
1273
+ )
1274
+ generation_config.task = task
1275
+
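A brief, hedged illustration of the path above: `language` and `task` are ordinary `generate` keyword arguments on multilingual checkpoints. The names below reuse the hypothetical `model`, `processor`, and `inputs` objects from the earlier sketch.
# Hedged sketch: explicit language and task selection (validated by _set_language_and_task).
# Works only for multilingual checkpoints; English-only models reject both arguments.
input_features = inputs.input_features[..., :3000]   # one 30 s window from the earlier sketch
predicted_ids = model.generate(
    input_features,
    language="french",   # a language name ("french") or a code ("fr") is accepted
    task="transcribe",   # or "translate" to translate the audio into English
)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True))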
1276
+ def _retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs):
1277
+ def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
1278
+ """Replace any element of lst that appears in itr with num; otherwise append num."""
1279
+ found = any(i in lst for i in itr)
1280
+ if found:
1281
+ lst = [num if i in itr else i for i in lst]
1282
+ else:
1283
+ lst.append(num)
1284
+ return lst
1285
+
1286
+ def language_to_id(language: str) -> int:
1287
+ language = language.lower()
1288
+ if language in generation_config.lang_to_id.keys():
1289
+ language_token = language
1290
+ elif language in TO_LANGUAGE_CODE.keys():
1291
+ language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
1292
+ elif language in TO_LANGUAGE_CODE.values():
1293
+ language_token = f"<|{language}|>"
1294
+ else:
1295
+ is_language_code = len(language) == 2
1296
+ raise ValueError(
1297
+ f"Unsupported language: {language}. Language should be one of:"
1298
+ f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
1299
+ )
1300
+ if language_token not in generation_config.lang_to_id:
1301
+ raise ValueError(
1302
+ f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
1303
+ " (You should just add it to the generation config.)"
1304
+ )
1305
+
1306
+ return generation_config.lang_to_id[language_token]
1307
+
1308
+ task = getattr(generation_config, "task", None)
1309
+ language = getattr(generation_config, "language", None)
1310
+
1311
+ forced_decoder_ids = generation_config.forced_decoder_ids
1312
+ if forced_decoder_ids is not None:
1313
+ if language is None and task is None and forced_decoder_ids[0][1] is None:
1314
+ logger.warning_once(
1315
+ "Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
1316
+ " This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
1317
+ )
1318
+ elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
1319
+ forced_decoder_ids = config.forced_decoder_ids
1320
+
1321
+ if forced_decoder_ids is not None and task is not None:
1322
+ logger.warning_once(
1323
+ f"You have passed task={task}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of task={task}."
1324
+ )
1325
+ forced_decoder_ids = None
1326
+ elif forced_decoder_ids is not None and language is not None:
1327
+ logger.warning_once(
1328
+ f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
1329
+ )
1330
+ forced_decoder_ids = None
1331
+
1332
+ init_tokens = [generation_config.decoder_start_token_id]
1333
+ if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1:
1334
+ i = 1
1335
+ while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i:
1336
+ init_tokens += [forced_decoder_ids[0][1]]
1337
+ forced_decoder_ids = forced_decoder_ids[1:]
1338
+ i += 1
1339
+
1340
+ if len(forced_decoder_ids) > 0:
1341
+ raise ValueError(
1342
+ f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}.",
1343
+ )
1344
+
1345
+ # from v4.39 the forced decoder ids are always None in favour of decoder input ids
1346
+ generation_config.forced_decoder_ids = None
1347
+
1348
+ is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
1349
+
1350
+ # Make sure language is a list of strings of the correct length
1351
+ if isinstance(language, (list, tuple)):
1352
+ if any(l is None for l in language):
1353
+ raise TypeError(
1354
+ "Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list containing `None`."
1355
+ )
1356
+ if len(language) != batch_size:
1357
+ raise ValueError(
1358
+ "When passing a list of languages, the length of the list must match the batch size. "
1359
+ f"Expected length of {batch_size}, but got {len(language)} languages."
1360
+ )
1361
+ languages = language
1362
+ elif language is None:
1363
+ # Language will be detected for each item in batch
1364
+ languages = [None] * batch_size
1365
+ else:
1366
+ languages = [language] # Use a length-1 list now, broadcast later
1367
+
1368
+ # Separate init_tokens for each language
1369
+ init_tokens = [copy.copy(init_tokens) for _ in languages]
1370
+
1371
+ # Update init_tokens with languages
1372
+ lang_ids = None
1373
+ if language is not None:
1374
+ lang_ids = [language_to_id(l) for l in languages]
1375
+ elif hasattr(generation_config, "lang_to_id") and is_lang_id_undefined:
1376
+ # language is not defined or intentionally set to `None` to trigger language detection
1377
+ lang_ids = self.detect_language(
1378
+ input_features=input_features,
1379
+ encoder_outputs=kwargs.get("encoder_outputs", None),
1380
+ generation_config=generation_config,
1381
+ num_segment_frames=num_segment_frames,
1382
+ ).tolist()
1383
+ if lang_ids is not None:
1384
+ # append or replace lang_ids to init_tokens
1385
+ for i in range(len(init_tokens)):
1386
+ if len(init_tokens[i]) > 1:
1387
+ init_tokens[i][1] = lang_ids[i]
1388
+ else:
1389
+ init_tokens[i].append(lang_ids[i])
1390
+ del languages
1391
+
1392
+ # Update init_tokens with task
1393
+ for i in range(len(init_tokens)):
1394
+ if task is not None:
1395
+ if task in TASK_IDS:
1396
+ init_tokens[i].append(generation_config.task_to_id[generation_config.task])
1397
+ task_id = generation_config.task_to_id[generation_config.task]
1398
+
1399
+ # if task is defined it'll overwrite task ids that might have already been defined via the generation_config
1400
+ replace_or_add(init_tokens[i], task_id, generation_config.task_to_id.values())
1401
+ else:
1402
+ raise ValueError(f"The `{task}` task is not supported. The task should be one of `{TASK_IDS}`.")
1403
+ elif language is not None and hasattr(generation_config, "task_to_id"):
1404
+ # if language is defined, but no task id is in `init_tokens`, default to transcribe
1405
+ if not any(ti in init_tokens[i] for ti in generation_config.task_to_id.values()):
1406
+ init_tokens[i].append(generation_config.task_to_id["transcribe"])
1407
+
1408
+ if (
1409
+ not generation_config.return_timestamps
1410
+ and hasattr(generation_config, "no_timestamps_token_id")
1411
+ and init_tokens[i][-1] != generation_config.no_timestamps_token_id
1412
+ ):
1413
+ init_tokens[i].append(generation_config.no_timestamps_token_id)
1414
+ elif (
1415
+ generation_config.return_timestamps and init_tokens[i][-1] == generation_config.no_timestamps_token_id
1416
+ ):
1417
+ logger.info(
1418
+ "<|notimestamps|> prompt token is removed from generation_config since `return_timestamps` is set to `True`."
1419
+ )
1420
+ init_tokens[i] = init_tokens[i][:-1]
1421
+
1422
+ # let's make sure we don't pass `None` tokens as prompt tokens
1423
+ init_tokens[i] = [t for t in init_tokens[i] if t is not None]
1424
+
1425
+ return torch.as_tensor(init_tokens, dtype=torch.long, device=self.device).expand(batch_size, -1)
1426
+
1427
+ def detect_language(
1428
+ self,
1429
+ input_features: Optional[torch.FloatTensor] = None,
1430
+ encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None,
1431
+ generation_config: Optional[GenerationConfig] = None,
1432
+ num_segment_frames: int = 3000,
1433
+ ) -> torch.Tensor:
1434
+ """
1435
+ Detects language from log-mel input features or encoder_outputs
1436
+
1437
+ Parameters:
1438
+ input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
1439
+ Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
1440
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
1441
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
1442
+ [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
1443
+ tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
1444
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
1445
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
1446
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
1447
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
1448
+ generation_config (`~generation.GenerationConfig`, *optional*):
1449
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
1450
+ passed to generate matching the attributes of `generation_config` will override them. If
1451
+ `generation_config` is not provided, the default will be used, which had the following loading
1452
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
1453
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
1454
+ default values, whose documentation should be checked to parameterize generation.
1455
+ num_segment_frames (`int`, *optional*, defaults to 3000):
1456
+ The number of log-mel frames the model expects
1457
+
1458
+ Return:
1459
+ A `torch.LongTensor` representing the detected language ids.
1460
+ """
1461
+ if input_features is None and encoder_outputs is None:
1462
+ raise ValueError("You have to specify either `input_features` or `encoder_outputs`")
1463
+ elif input_features is not None and encoder_outputs is not None:
1464
+ raise ValueError("Make sure to specify only one of `input_features` or `encoder_outputs` - not both!")
1465
+ elif input_features is not None:
1466
+ inputs = {"input_features": input_features[:, :, :num_segment_frames]}
1467
+ batch_size = input_features.shape[0]
1468
+ elif encoder_outputs is not None:
1469
+ inputs = {"encoder_outputs": encoder_outputs}
1470
+ batch_size = (
1471
+ encoder_outputs[0].shape[0] if isinstance(encoder_outputs, BaseModelOutput) else encoder_outputs[0]
1472
+ )
1473
+
1474
+ generation_config = generation_config or self.generation_config
1475
+ decoder_input_ids = (
1476
+ torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
1477
+ * generation_config.decoder_start_token_id
1478
+ )
1479
+
1480
+ with torch.no_grad():
1481
+ logits = self(**inputs, decoder_input_ids=decoder_input_ids).logits[:, -1]
1482
+
1483
+ non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)
1484
+ non_lang_mask[list(generation_config.lang_to_id.values())] = False
1485
+
1486
+ logits[:, non_lang_mask] = -np.inf
1487
+
1488
+ lang_ids = logits.argmax(-1)
1489
+
1490
+ return lang_ids
1491
+
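A short, hedged usage sketch for the method above, reusing the hypothetical `model`, `processor`, and `input_features` from the earlier snippets.
# Hedged sketch: standalone language identification with detect_language.
lang_token_ids = model.detect_language(input_features=input_features)
detected = [processor.tokenizer.decode([int(i)]) for i in lang_token_ids]
print(detected)  # e.g. ["<|en|>"] for English audio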
1492
+ @staticmethod
1493
+ def _check_decoder_input_ids(kwargs):
1494
+ decoder_input_ids = kwargs.get("decoder_input_ids", None)
1495
+ assistant_model = kwargs.get("assistant_model", None)
1496
+ if decoder_input_ids is not None and assistant_model is not None:
1497
+ raise ValueError(
1498
+ "Passing `decoder_input_ids` is deprecated. Consider passing `prompt_ids` instead.",
1499
+ )
1500
+
1501
+ @staticmethod
1502
+ def _set_num_frames(return_token_timestamps, generation_config, kwargs):
1503
+ if return_token_timestamps:
1504
+ if getattr(generation_config, "task", None) == "translate":
1505
+ logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
1506
+ if not hasattr(generation_config, "alignment_heads"):
1507
+ raise ValueError(
1508
+ "Model generation config has no `alignment_heads`, token-level timestamps not available. "
1509
+ "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
1510
+ )
1511
+ generation_config.num_frames = kwargs.pop("num_frames", None)
1512
+
1513
+ @staticmethod
1514
+ def _set_thresholds_and_condition(
1515
+ generation_config,
1516
+ logprob_threshold,
1517
+ compression_ratio_threshold,
1518
+ no_speech_threshold,
1519
+ condition_on_prev_tokens,
1520
+ ):
1521
+ generation_config.logprob_threshold = (
1522
+ logprob_threshold
1523
+ if logprob_threshold is not None
1524
+ else getattr(generation_config, "logprob_threshold", None)
1525
+ )
1526
+ generation_config.compression_ratio_threshold = (
1527
+ compression_ratio_threshold
1528
+ if compression_ratio_threshold is not None
1529
+ else getattr(generation_config, "compression_ratio_threshold", None)
1530
+ )
1531
+ generation_config.no_speech_threshold = (
1532
+ no_speech_threshold
1533
+ if no_speech_threshold is not None
1534
+ else getattr(generation_config, "no_speech_threshold", None)
1535
+ )
1536
+ generation_config.condition_on_prev_tokens = (
1537
+ condition_on_prev_tokens
1538
+ if condition_on_prev_tokens is not None
1539
+ else getattr(generation_config, "condition_on_prev_tokens", None)
1540
+ )
1541
+
1542
+ @staticmethod
1543
+ def _set_prompt_condition_type(generation_config, prompt_condition_type):
1544
+ allowed_cond_types = ["first-segment", "all-segments"]
1545
+
1546
+ # default to "first-segment"
1547
+ prompt_condition_type = prompt_condition_type or allowed_cond_types[0]
1548
+
1549
+ if prompt_condition_type not in allowed_cond_types:
1550
+ raise ValueError(
1551
+ f"`prompt_condition_type={prompt_condition_type}` does not exist. Make sure to set `prompt_condition_type` to one of {', '.join(allowed_cond_types)}"
1552
+ )
1553
+
1554
+ if generation_config.condition_on_prev_tokens is not True and prompt_condition_type == "all-segments":
1555
+ raise ValueError(
1556
+ "Make sure to set `condition_on_prev_tokens=True` when setting `prompt_condition_type='all-segments'`."
1557
+ )
1558
+
1559
+ generation_config.prompt_condition_type = prompt_condition_type
1560
+
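For orientation, a hedged sketch of how `prompt_condition_type` is typically exercised from `generate`; the prompt text is arbitrary and the objects reuse the assumptions of the earlier sketches.
# Hedged sketch: conditioning every segment on a domain prompt ("all-segments" requires
# condition_on_prev_tokens=True, which _set_prompt_condition_type enforces above).
prompt_ids = processor.get_prompt_ids("Megatron, Optimus Prime", return_tensors="pt")
predicted_ids = model.generate(
    inputs.input_features,
    attention_mask=inputs.attention_mask,
    prompt_ids=prompt_ids,
    prompt_condition_type="all-segments",
    condition_on_prev_tokens=True,
    return_timestamps=True,
)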
1561
+ @staticmethod
1562
+ def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config):
1563
+ condition_on_prev_tokens = (
1564
+ condition_on_prev_tokens
1565
+ if condition_on_prev_tokens is not None
1566
+ else getattr(generation_config, "condition_on_prev_tokens", False)
1567
+ )
1568
+ generation_config.condition_on_prev_tokens = condition_on_prev_tokens
1569
+
1570
+ @staticmethod
1571
+ def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames, is_shortform):
1572
+ if batch_size > 1 and not is_shortform and attention_mask is None:
1573
+ raise ValueError(
1574
+ "When doing batched long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
1575
+ )
1576
+ elif batch_size > 1 and not is_shortform:
1577
+ max_frames = attention_mask.sum(-1).cpu().to(torch.long)
1578
+ seek = torch.zeros((batch_size,), dtype=torch.long)
1579
+ else:
1580
+ max_frames = torch.ones((batch_size,), dtype=torch.long) * total_input_frames
1581
+ seek = torch.zeros((batch_size,), dtype=torch.long)
1582
+
1583
+ return max_frames, seek
1584
+
1585
+ def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, num_beams, device):
1586
+ if generation_config.return_timestamps is True:
1587
+ timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index)
1588
+ logits_processor = (
1589
+ [timestamp_processor] if logits_processor is None else [timestamp_processor] + logits_processor
1590
+ )
1591
+
1592
+ if generation_config.suppress_tokens is not None:
1593
+ suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens, device=device)
1594
+ logits_processor = (
1595
+ [suppress_tokens_processor]
1596
+ if logits_processor is None
1597
+ else [suppress_tokens_processor] + logits_processor
1598
+ )
1599
+ generation_config.suppress_tokens = None
1600
+
1601
+ if generation_config.begin_suppress_tokens is not None:
1602
+ begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(
1603
+ generation_config.begin_suppress_tokens, begin_index=begin_index, device=device
1604
+ )
1605
+ logits_processor = (
1606
+ [begin_suppress_processor]
1607
+ if logits_processor is None
1608
+ else [begin_suppress_processor] + logits_processor
1609
+ )
1610
+ generation_config.begin_suppress_tokens = None
1611
+
1612
+ if generation_config.no_speech_threshold is not None:
1613
+ no_speech_detector = WhisperNoSpeechDetection(
1614
+ no_speech_token=generation_config.no_timestamps_token_id - 1,
1615
+ begin_index=begin_index,
1616
+ scores_is_logprobs=num_beams > 1,
1617
+ )
1618
+ logits_processor = (
1619
+ [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor
1620
+ )
1621
+ no_speech_detector.set_model(self)
1622
+
1623
+ return logits_processor
1624
+
1625
+ @staticmethod
1626
+ def _maybe_reduce_batch(input_features, seek, max_frames, cur_bsz, batch_idx_map):
1627
+ prev_bsz = cur_bsz
1628
+ new_batch_idx_map = []
1629
+ for i in range(prev_bsz):
1630
+ prev_i = batch_idx_map[i]
1631
+ if seek[prev_i] >= max_frames[prev_i]:
1632
+ cut_index = i + (cur_bsz - prev_bsz)
1633
+ cur_bsz -= 1
1634
+ input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0)
1635
+ else:
1636
+ # cut out index that goes away
1637
+ new_batch_idx_map.append(prev_i)
1638
+
1639
+ return input_features, cur_bsz, new_batch_idx_map
1640
+
1641
+ @staticmethod
1642
+ def _get_input_segment(input_features, seek, seek_num_frames, num_segment_frames, cur_bsz, batch_idx_map):
1643
+ if input_features is None:
1644
+ return None
1645
+
1646
+ segment_input = []
1647
+ for i in range(cur_bsz):
1648
+ prev_i = batch_idx_map[i]
1649
+ segment_input_slice = input_features[i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i]]
1650
+
1651
+ if segment_input_slice.shape[-1] < num_segment_frames:
1652
+ # pad to 3000 if necessary
1653
+ segment_input_slice = F.pad(
1654
+ segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1])
1655
+ )
1656
+
1657
+ segment_input.append(segment_input_slice)
1658
+
1659
+ segment_input = torch.cat(segment_input, dim=0)
1660
+
1661
+ return segment_input
1662
+
1663
+ @staticmethod
1664
+ def _prepare_decoder_input_ids(
1665
+ cur_bsz,
1666
+ init_tokens,
1667
+ current_segments,
1668
+ batch_idx_map,
1669
+ do_condition_on_prev_tokens,
1670
+ prompt_ids,
1671
+ generation_config,
1672
+ config,
1673
+ device,
1674
+ suppress_tokens,
1675
+ timestamp_begin,
1676
+ kwargs,
1677
+ ):
1678
+ if "decoder_input_ids" in kwargs:
1679
+ decoder_input_ids = kwargs.pop("decoder_input_ids")
1680
+
1681
+ return decoder_input_ids, kwargs
1682
+
1683
+ cut_off_length = config.max_target_positions // 2 - 1
1684
+
1685
+ decoder_input_ids = init_tokens[batch_idx_map]
1686
+
1687
+ prev_start_of_text = getattr(generation_config, "prev_sot_token_id", None)
1688
+ if prev_start_of_text is None:
1689
+ prev_start_of_text = suppress_tokens[-2] if suppress_tokens is not None else None
1690
+
1691
+ if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0:
1692
+ # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
1693
+ active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]
1694
+
1695
+ for segments in active_segments:
1696
+ for seg in segments:
1697
+ if len(seg["tokens"]) > 2 and seg["tokens"][-2] >= timestamp_begin:
1698
+ # the segment finishes with two timestamp tokens
1699
+ # we need to ignore the last timestamp token
1700
+ # see https://github.com/huggingface/transformers/pull/34537
1701
+ seg["tokens"] = seg["tokens"][:-1]
1702
+
1703
+ if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
1704
+ prev_ids = prompt_ids
1705
+ else:
1706
+ one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long)
1707
+ prev_ids = prev_start_of_text * one_tensor[0] if prev_start_of_text is not None else None
1708
+
1709
+ padding = "max_length" if generation_config.cache_implementation == "static" else "longest"
1710
+
1711
+ prev_tokens = _pad_to_max_length(
1712
+ active_segments,
1713
+ generation_config.pad_token_id,
1714
+ device=device,
1715
+ padding_side="left",
1716
+ padding=padding,
1717
+ bos_token_tensor=prev_ids,
1718
+ cut_off_length=cut_off_length,
1719
+ )
1720
+ decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
1721
+
1722
+ kwargs["decoder_attention_mask"] = decoder_input_ids != generation_config.pad_token_id
1723
+ elif prompt_ids is not None:
1724
+ prev_tokens = prompt_ids[None].repeat(decoder_input_ids.shape[0], 1)
1725
+ decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
1726
+ # make sure `"decoder_attention_mask"` is not passed to forward
1727
+ kwargs.pop("decoder_attention_mask", None)
1728
+ else:
1729
+ # make sure `"decoder_attention_mask"` is not passed to forward
1730
+ kwargs.pop("decoder_attention_mask", None)
1731
+
1732
+ return decoder_input_ids, kwargs
1733
+
1734
+ def _set_max_new_tokens_and_length(self, config, decoder_input_ids, generation_config):
1735
+ max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
1736
+ if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
1737
+ raise ValueError(
1738
+ f"The length of `decoder_input_ids`, including special start tokens, prompt tokens, and previous tokens, is {decoder_input_ids.shape[-1]}, "
1739
+ f" and `max_new_tokens` is {max_new_tokens}. Thus, the combined length of "
1740
+ f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
1741
+ f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
1742
+ "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
1743
+ f"so that their combined length is less than {self.config.max_target_positions}."
1744
+ )
1745
+
1746
+ num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1)
1747
+
1748
+ # Make sure we don't get larger than `max_length`
1749
+ if generation_config.max_length is not None and generation_config.max_new_tokens is None:
1750
+ max_length = min(generation_config.max_length + num_initial_tokens, config.max_target_positions)
1751
+ logger.info(
1752
+ f"Increase max_length from {generation_config.max_length} to {max_length} since input is conditioned on previous segment."
1753
+ )
1754
+ elif (
1755
+ generation_config.max_new_tokens is not None
1756
+ and generation_config.max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
1757
+ ):
1758
+ max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
1759
+ generation_config.max_new_tokens = max_new_tokens
1760
+
1761
+ @staticmethod
1762
+ def _retrieve_compression_ratio(tokens, vocab_size):
1763
+ """Compute the ratio of the raw token byte length to its zlib-compressed byte length."""
1764
+ length = int(math.log2(vocab_size) / 8) + 1
1765
+ token_bytes = b"".join([t.to_bytes(length, "little") for t in tokens.tolist()])
1766
+ compression_ratio = len(token_bytes) / len(zlib.compress(token_bytes))
1767
+
1768
+ return compression_ratio
1769
+
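A self-contained toy illustration of the heuristic above; the vocabulary size is the multilingual Whisper value and the token sequences are synthetic.
import math
import random
import zlib

def toy_compression_ratio(token_ids, vocab_size):
    # mirrors _retrieve_compression_ratio: raw token bytes vs. their zlib-compressed size
    width = int(math.log2(vocab_size) / 8) + 1
    raw = b"".join(t.to_bytes(width, "little") for t in token_ids)
    return len(raw) / len(zlib.compress(raw))

vocab_size = 51865
random.seed(0)
print(toy_compression_ratio([123] * 200, vocab_size))            # repetitive output -> high ratio
print(toy_compression_ratio(
    [random.randrange(vocab_size) for _ in range(200)], vocab_size))  # random output -> ratio near 1
# A high ratio flags degenerate, repetitive text and triggers the temperature fallback.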
1770
+ @staticmethod
1771
+ def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature):
1772
+ rescale_temperature = temperature if temperature > 0.0 else 1
1773
+ scores = torch.stack(scores).to(tokens.device)
1774
+
1775
+ if scores.shape[0] > tokens.shape[0]:
1776
+ scores = scores[: tokens.shape[0]]
1777
+ else:
1778
+ tokens = tokens[-scores.shape[0] :]
1779
+
1780
+ logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype)
1781
+
1782
+ # retrieve logprob of selected tokens and sum
1783
+ sum_logprobs = sum((logprobs[i][tokens[i]] * (tokens[i] != eos_token_id)) for i in range(logprobs.shape[0]))
1784
+ length = (tokens != eos_token_id).sum(-1) if eos_token_id is not None else tokens.shape[0]
1785
+
1786
+ avg_logprobs = sum_logprobs / (length + 1)
1787
+ return avg_logprobs
1788
+
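A small, hedged toy version of the same computation, to make the thresholding concrete; the values are random, and the real method additionally masks the EOS token and normalizes by length + 1.
import torch
import torch.nn.functional as F

scores = [torch.randn(1, 8) for _ in range(4)]   # one logit row per generated step (toy vocab of 8)
tokens = torch.tensor([1, 5, 2, 7])              # the tokens picked at each step
step_logprobs = torch.stack(
    [F.log_softmax(s[0], dim=-1)[t] for s, t in zip(scores, tokens)]
)
avg_logprob = step_logprobs.mean().item()
# In generate(), a segment whose average log-probability falls below `logprob_threshold`
# (commonly -1.0) is retried at a higher temperature.
print(avg_logprob)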
1789
+ @staticmethod
1790
+ def _retrieve_segment(
1791
+ seek_sequence,
1792
+ seek_outputs,
1793
+ time_offset,
1794
+ timestamp_begin,
1795
+ seek_num_frames,
1796
+ time_precision,
1797
+ time_precision_features,
1798
+ input_stride,
1799
+ prev_idx,
1800
+ idx,
1801
+ return_token_timestamps,
1802
+ ):
1803
+ # find Whisper's predicted "end of segment" positions
1804
+ # "end of segment" predictions occur whenever Whisper predicts a timestamp token
1805
+ timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
1806
+ single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
1807
+ timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
1808
+ timestamp_segment_indices.add_(1)
1809
+ token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
1810
+ device = seek_sequence.device
1811
+
1812
+ # If Whisper predicted an "end of segment" via a timestamp token, let's go over each
1813
+ # "end of segment" prediction and slice the decoding into segments accordingly
1814
+ if len(timestamp_segment_indices) > 0:
1815
+ # if the output contains two consecutive timestamp tokens
1816
+ slices = timestamp_segment_indices.tolist()
1817
+ segments = []
1818
+ if single_timestamp_ending:
1819
+ slices.append(len(seek_sequence))
1820
+ else:
1821
+ # we want to include the last timestamp token in the last segment so we know it was not a single-timestamp ending
1822
+ slices[-1] += 1
1823
+
1824
+ last_slice = 0
1825
+ # Add each segment to list of all segments
1826
+ for i, current_slice in enumerate(slices):
1827
+ is_last_slice = i == len(slices) - 1
1828
+ sliced_tokens = seek_sequence[last_slice:current_slice]
1829
+ start_timestamp_pos = sliced_tokens[0] - timestamp_begin
1830
+ idx_sliced_tokens = -1 if not is_last_slice or single_timestamp_ending else -2
1831
+ end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
1832
+ segments.append(
1833
+ {
1834
+ "start": time_offset[prev_idx]
1835
+ + start_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
1836
+ * time_precision,
1837
+ "end": time_offset[prev_idx]
1838
+ + end_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
1839
+ * time_precision,
1840
+ "tokens": sliced_tokens,
1841
+ "result": seek_outputs[idx],
1842
+ }
1843
+ )
1844
+ if return_token_timestamps:
1845
+ segments[-1]["token_timestamps"] = (
1846
+ token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
1847
+ )
1848
+ last_slice = current_slice
1849
+
1850
+ if single_timestamp_ending:
1851
+ # single timestamp at the end means no speech after the last timestamp.
1852
+ segment_offset = seek_num_frames[prev_idx]
1853
+ else:
1854
+ # otherwise, ignore the unfinished segment and seek to the last timestamp
1855
+ # here we throw away all predictions after the last predicted "end of segment"
1856
+ # since we are cutting right in the middle of an audio
1857
+ last_timestamp_pos = seek_sequence[last_slice - 2].item() - timestamp_begin
1858
+ segment_offset = last_timestamp_pos * input_stride
1859
+ else:
1860
+ # If whisper does not predict any "end of segment" token, then
1861
+ # the whole decoding is considered a segment and we add it to the list of segments
1862
+ timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
1863
+ last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision)
1864
+ if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin:
1865
+ # no consecutive timestamps but it has a timestamp; use the last one.
1866
+ last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(
1867
+ torch.float32 if device.type == "mps" else torch.float64
1868
+ )
1869
+ segments = [
1870
+ {
1871
+ "start": time_offset[prev_idx],
1872
+ "end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
1873
+ "tokens": seek_sequence,
1874
+ "result": seek_outputs[idx],
1875
+ }
1876
+ ]
1877
+ if return_token_timestamps:
1878
+ segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
1879
+ segment_offset = seek_num_frames[prev_idx]
1880
+
1881
+ return segments, segment_offset
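To show what the segments assembled above look like to a caller, here is a hedged sketch that requests per-segment output, reusing the long-form `inputs` assumed in the first sketch; with `return_segments=True`, `generate` returns a dict.
# Hedged sketch: inspecting the per-segment output assembled by _retrieve_segment.
output = model.generate(
    inputs.input_features,
    attention_mask=inputs.attention_mask,
    return_timestamps=True,
    return_segments=True,
)
for seg in output["segments"][0]:   # segments of the first batch item
    start, end = float(seg["start"]), float(seg["end"])
    text = processor.decode(seg["tokens"], skip_special_tokens=True)
    print(f"[{start:6.2f}s -> {end:6.2f}s] {text}")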
modeling_flax_whisper (1).py ADDED
@@ -0,0 +1,1696 @@
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Flax whisper model."""
16
+
17
+ import math
18
+ import random
19
+ from functools import partial
20
+ from typing import Optional, Tuple
21
+
22
+ import flax.linen as nn
23
+ import jax
24
+ import jax.numpy as jnp
25
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
26
+ from flax.linen import combine_masks, make_causal_mask
27
+ from flax.linen import partitioning as nn_partitioning
28
+ from flax.linen.attention import dot_product_attention_weights
29
+ from flax.traverse_util import flatten_dict, unflatten_dict
30
+ from jax import lax
31
+ from jax.random import PRNGKey
32
+
33
+ from ...generation.flax_logits_process import FlaxWhisperTimeStampLogitsProcessor
34
+ from ...modeling_flax_outputs import (
35
+ FlaxBaseModelOutput,
36
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
37
+ FlaxCausalLMOutputWithCrossAttentions,
38
+ FlaxSeq2SeqLMOutput,
39
+ FlaxSeq2SeqModelOutput,
40
+ FlaxSequenceClassifierOutput,
41
+ )
42
+ from ...modeling_flax_utils import (
43
+ ACT2FN,
44
+ FlaxPreTrainedModel,
45
+ append_call_sample_docstring,
46
+ append_replace_return_docstrings,
47
+ overwrite_call_docstring,
48
+ )
49
+ from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
50
+ from .configuration_whisper import WhisperConfig
51
+
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+
56
+ _CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
57
+ _CONFIG_FOR_DOC = "WhisperConfig"
58
+
59
+ remat = nn_partitioning.remat
60
+
61
+
62
+ def sinusoidal_embedding_init(key, shape, dtype=jnp.float_) -> jax.Array:
63
+ """Returns sinusoids for positional embedding"""
64
+ length, channels = shape
65
+ if channels % 2 != 0:
66
+ raise ValueError(
67
+ f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
68
+ )
69
+ log_timescale_increment = math.log(10000) / (channels // 2 - 1)
70
+ inv_timescales = jnp.exp(-log_timescale_increment * jnp.arange(channels // 2))
71
+ scaled_time = jnp.arange(length).reshape(-1, 1) * inv_timescales.reshape(1, -1)
72
+ return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1).astype(dtype)
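A hedged sketch of what the initializer above produces, assuming it is imported from this module; the shape matches whisper-tiny's encoder (`max_source_positions=1500`, `d_model=384`).
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)  # unused by this deterministic initializer, but required by the
                             # flax init signature
pos_emb = sinusoidal_embedding_init(key, (1500, 384), dtype=jnp.float32)
print(pos_emb.shape)    # (1500, 384)
print(pos_emb[0, :4])   # position 0: the sin half of the channels is all zeros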
73
+
74
+
75
+ WHISPER_START_DOCSTRING = r"""
76
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
77
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
78
+ etc.). This model is also a Flax Linen
79
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
80
+ regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
81
+ Finally, this model supports inherent JAX features such as:
82
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
83
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
84
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
85
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
86
+
87
+ Parameters:
88
+ config ([`WhisperConfig`]): Model configuration class with all the parameters of the model.
89
+ Initializing with a config file does not load the weights associated with the model, only the
90
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
91
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
92
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
93
+ `jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision
94
+ inference on GPUs or TPUs. If specified all the computation will be performed with the given `dtype`.
95
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
96
+ parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`]
97
+ and [`~FlaxPreTrainedModel.to_bf16`].
98
+ """
99
+
100
+ WHISPER_INPUTS_DOCSTRING = r"""
101
+ Args:
102
+ input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
103
+ Float values of mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
104
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
105
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
106
+ [`WhisperFeatureExtractor`] should be used for extracting the features, padding and conversion into a
107
+ tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`]
108
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
109
+ Whisper does not support masking of the `input_features`; this argument is preserved for compatibility, but
110
+ is not used. By default, silence in the input log-mel spectrogram is ignored.
111
+ decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
112
+ Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
113
+ [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
114
+ [What are decoder input IDs?](../glossary#decoder-input-ids) Whisper uses the `decoder_start_token_id` as
115
+ the starting token for `decoder_input_ids` generation.
116
+ decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
117
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
118
+ be used by default. If you want to change padding behavior, you should modify it to your needs. See diagram 1
119
+ in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
120
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
121
+ Whisper does not use `position_ids` in the encoder as `input_features` is always the same size and doesn't
122
+ use masking, but this argument is preserved for compatibility. By default, silence in the input log-mel
123
+ spectrogram is ignored.
124
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
125
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
126
+ range `[0, config.max_position_embeddings - 1]`.
127
+ output_attentions (`bool`, *optional*):
128
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
129
+ tensors for more detail.
130
+ output_hidden_states (`bool`, *optional*):
131
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
132
+ more detail.
133
+ return_dict (`bool`, *optional*):
134
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
135
+ """
136
+
137
+ WHISPER_ENCODE_INPUTS_DOCSTRING = r"""
138
+ Args:
139
+ input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
140
+ Float values of mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
141
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
142
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
143
+ [`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
144
+ tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`].
145
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
146
+ Whisper does not support masking of the `input_features`; this argument is preserved for compatibility, but
147
+ is not used. By default, silence in the input log-mel spectrogram is ignored.
148
+ output_attentions (`bool`, *optional*):
149
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
150
+ tensors for more detail.
151
+ output_hidden_states (`bool`, *optional*):
152
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
153
+ more detail.
154
+ return_dict (`bool`, *optional*):
155
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
156
+ """
157
+
158
+ WHISPER_DECODE_INPUTS_DOCSTRING = r"""
159
+ Args:
160
+ decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`):
161
+ Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
162
+ [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
163
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
164
+ encoder_outputs (`tuple(tuple(numpy.ndarray)`):
165
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
166
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
167
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
168
+ encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
169
+ Whisper does not support masking of the `input_features`; this argument is preserved for compatibility,
170
+ but it is not used. By default, silence in the input log-mel spectrogram is ignored.
171
+ decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
172
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
173
+ be used by default. If you want to change padding behavior, you should modify it to your needs. See diagram 1
174
+ in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
175
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
176
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
177
+ range `[0, config.max_position_embeddings - 1]`.
178
+ past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
179
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
180
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
181
+ output_attentions (`bool`, *optional*):
182
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
183
+ tensors for more detail.
184
+ output_hidden_states (`bool`, *optional*):
185
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
186
+ more detail.
187
+ return_dict (`bool`, *optional*):
188
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
189
+ """
190
+
191
+
192
+ class FlaxWhisperAttention(nn.Module):
193
+ config: WhisperConfig
194
+ embed_dim: int
195
+ num_heads: int
196
+ dropout: float = 0.0
197
+ causal: bool = False
198
+ bias: bool = True
199
+ dtype: jnp.dtype = jnp.float32
200
+
201
+ def setup(self) -> None:
202
+ self.head_dim = self.embed_dim // self.num_heads
203
+ if self.head_dim * self.num_heads != self.embed_dim:
204
+ raise ValueError(
205
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
206
+ f" and `num_heads`: {self.num_heads})."
207
+ )
208
+
209
+ dense = partial(
210
+ nn.Dense,
211
+ self.embed_dim,
212
+ dtype=self.dtype,
213
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
214
+ )
215
+
216
+ self.q_proj = dense(use_bias=self.bias)
217
+ self.k_proj = dense(use_bias=False)
218
+ self.v_proj = dense(use_bias=self.bias)
219
+ self.out_proj = dense(use_bias=self.bias)
220
+
221
+ if self.causal:
222
+ self.causal_mask = make_causal_mask(
223
+ jnp.ones((1, self.config.max_target_positions), dtype="bool"), dtype="bool"
224
+ )
225
+
226
+ def __call__(
227
+ self,
228
+ hidden_states: jnp.ndarray,
229
+ key_value_states: Optional[jnp.ndarray] = None,
230
+ attention_mask: Optional[jnp.ndarray] = None,
231
+ init_cache: bool = False,
232
+ deterministic: bool = True,
233
+ ) -> Tuple[jnp.ndarray]:
234
+ is_cross_attention = key_value_states is not None
235
+ batch_size = hidden_states.shape[0]
236
+
237
+ query_states = self.q_proj(hidden_states)
238
+
239
+ if is_cross_attention:
240
+ key_states = self.k_proj(key_value_states)
241
+ value_states = self.v_proj(key_value_states)
242
+ else:
243
+ key_states = self.k_proj(hidden_states)
244
+ value_states = self.v_proj(hidden_states)
245
+
246
+ query_states = self._split_heads(query_states)
247
+ key_states = self._split_heads(key_states)
248
+ value_states = self._split_heads(value_states)
249
+
250
+ if self.causal:
251
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
252
+ if self.has_variable("cache", "cached_key"):
253
+ mask_shift = self.variables["cache"]["cache_index"]
254
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
255
+ causal_mask = lax.dynamic_slice(
256
+ self.causal_mask,
257
+ (0, 0, mask_shift, 0),
258
+ (1, 1, query_length, max_decoder_length),
259
+ )
260
+ else:
261
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
262
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
263
+
264
+ # combine masks if needed
265
+ if attention_mask is not None and self.causal:
266
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
267
+ attention_mask = combine_masks(attention_mask, causal_mask)
268
+ elif self.causal:
269
+ attention_mask = causal_mask
270
+ elif attention_mask is not None:
271
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
272
+
273
+ # During fast autoregressive decoding, we feed one position at a time,
274
+ # and cache the keys and values step by step.
275
+
276
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
277
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
278
+ key_states, value_states, query_states, attention_mask
279
+ )
280
+
281
+ # Convert the boolean attention mask to an attention bias.
282
+ if attention_mask is not None:
283
+ # attention mask in the form of attention bias
284
+ attention_bias = lax.select(
285
+ attention_mask > 0,
286
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
287
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
288
+ )
289
+ else:
290
+ attention_bias = None
291
+
292
+ dropout_rng = None
293
+ if not deterministic and self.dropout > 0.0:
294
+ dropout_rng = self.make_rng("dropout")
295
+
296
+ attn_weights = dot_product_attention_weights(
297
+ query_states,
298
+ key_states,
299
+ bias=attention_bias,
300
+ dropout_rng=dropout_rng,
301
+ dropout_rate=self.dropout,
302
+ broadcast_dropout=True,
303
+ deterministic=deterministic,
304
+ dtype=self.dtype,
305
+ precision=None,
306
+ )
307
+
308
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
309
+ attn_output = self._merge_heads(attn_output)
310
+ attn_output = self.out_proj(attn_output)
311
+
312
+ return attn_output, attn_weights
313
+
314
+ def _split_heads(self, hidden_state) -> jnp.ndarray:
315
+ return hidden_state.reshape(hidden_state.shape[:2] + (self.num_heads, self.head_dim))
316
+
317
+ def _merge_heads(self, hidden_state) -> jnp.ndarray:
318
+ return hidden_state.reshape(hidden_state.shape[:2] + (self.embed_dim,))
319
+
320
+ @nn.compact
321
+ def _concatenate_to_cache(self, key, value, query, attention_mask) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
322
+ # detect if we're initializing by absence of existing cache data.
323
+ is_initialized = self.has_variable("cache", "cached_key")
324
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
325
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
326
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
327
+
328
+ if is_initialized:
329
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
330
+ # update key, value caches with our new 1d spatial slices
331
+ cur_index = cache_index.value
332
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
333
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
334
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
335
+ cached_key.value = key
336
+ cached_value.value = value
337
+ num_updated_cache_vectors = query.shape[1]
338
+ cache_index.value = cache_index.value + num_updated_cache_vectors
339
+ # causal mask for cached decoder self-attention: our single query position should only
340
+ # attend to those key positions that have already been generated and cached, not the
341
+ # remaining zero elements.
342
+ pad_mask = jnp.broadcast_to(
343
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
344
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
345
+ )
346
+ attention_mask = combine_masks(pad_mask, attention_mask)
347
+
348
+ return key, value, attention_mask
349
+
350
+
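The cache update performed in `_concatenate_to_cache` above can be hard to follow in isolation. Below is a minimal, self-contained sketch (toy shapes and hypothetical variable names, not library code) of the same mechanism: a pre-allocated key buffer is written in place at the current decoding index with `lax.dynamic_update_slice`, and a padding mask restricts attention to the positions cached so far.

```python
import jax.numpy as jnp
from jax import lax

# toy dimensions: one batch element, 8 cache slots, 2 heads, head_dim 4
max_length, num_heads, head_dim = 8, 2, 4
cached_key = jnp.zeros((1, max_length, num_heads, head_dim))

cur_index = 2                                    # two positions already cached
new_key = jnp.ones((1, 1, num_heads, head_dim))  # key projection of the new token

# write the new slice into the pre-allocated buffer at position `cur_index`
cached_key = lax.dynamic_update_slice(cached_key, new_key, (0, cur_index, 0, 0))

# the single query may only attend to positions that have already been filled
pad_mask = jnp.arange(max_length) < cur_index + 1
print(cached_key[0, :, 0, 0])  # [0. 0. 1. 0. 0. 0. 0. 0.]
print(pad_mask)                # [ True  True  True False False False False False]
```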
351
+ # Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer with MBart->Whisper
352
+ class FlaxWhisperEncoderLayer(nn.Module):
353
+ config: WhisperConfig
354
+ dtype: jnp.dtype = jnp.float32
355
+
356
+ def setup(self) -> None:
357
+ self.embed_dim = self.config.d_model
358
+ self.self_attn = FlaxWhisperAttention(
359
+ config=self.config,
360
+ embed_dim=self.embed_dim,
361
+ num_heads=self.config.encoder_attention_heads,
362
+ dropout=self.config.attention_dropout,
363
+ dtype=self.dtype,
364
+ )
365
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
366
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
367
+ self.activation_fn = ACT2FN[self.config.activation_function]
368
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
369
+ self.fc1 = nn.Dense(
370
+ self.config.encoder_ffn_dim,
371
+ dtype=self.dtype,
372
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
373
+ )
374
+ self.fc2 = nn.Dense(
375
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
376
+ )
377
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
378
+
379
+ def __call__(
380
+ self,
381
+ hidden_states: jnp.ndarray,
382
+ attention_mask: jnp.ndarray,
383
+ output_attentions: bool = True,
384
+ deterministic: bool = True,
385
+ ) -> Tuple[jnp.ndarray]:
386
+ residual = hidden_states
387
+ hidden_states = self.self_attn_layer_norm(hidden_states)
388
+ hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
389
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
390
+ hidden_states = residual + hidden_states
391
+
392
+ residual = hidden_states
393
+ hidden_states = self.final_layer_norm(hidden_states)
394
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
395
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
396
+ hidden_states = self.fc2(hidden_states)
397
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
398
+ hidden_states = residual + hidden_states
399
+
400
+ outputs = (hidden_states,)
401
+
402
+ if output_attentions:
403
+ outputs += (attn_weights,)
404
+
405
+ return outputs
406
+
407
+
408
+ class FlaxWhisperEncoderLayerCollection(nn.Module):
409
+ config: WhisperConfig
410
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
411
+ gradient_checkpointing: bool = False
412
+
413
+ def setup(self):
414
+ if self.gradient_checkpointing:
415
+ FlaxWhisperEncoderCheckpointLayer = remat(FlaxWhisperEncoderLayer, static_argnums=(2, 3))
416
+ self.layers = [
417
+ FlaxWhisperEncoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
418
+ for i in range(self.config.encoder_layers)
419
+ ]
420
+ else:
421
+ self.layers = [
422
+ FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
423
+ for i in range(self.config.encoder_layers)
424
+ ]
425
+ self.layerdrop = self.config.encoder_layerdrop
426
+
427
+ def __call__(
428
+ self,
429
+ hidden_states,
430
+ attention_mask,
431
+ deterministic: bool = True,
432
+ output_attentions: bool = False,
433
+ output_hidden_states: bool = False,
434
+ return_dict: bool = True,
435
+ ):
436
+ all_attentions = () if output_attentions else None
437
+ all_hidden_states = () if output_hidden_states else None
438
+
439
+ for encoder_layer in self.layers:
440
+ if output_hidden_states:
441
+ all_hidden_states = all_hidden_states + (hidden_states,)
442
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
443
+ dropout_probability = random.uniform(0, 1)
444
+ if not deterministic and (dropout_probability < self.layerdrop): # skip the layer
445
+ layer_outputs = (None, None)
446
+ else:
447
+ layer_outputs = encoder_layer(
448
+ hidden_states,
449
+ attention_mask,
450
+ output_attentions,
451
+ deterministic,
452
+ )
453
+ hidden_states = layer_outputs[0]
454
+ if output_attentions:
455
+ all_attentions = all_attentions + (layer_outputs[1],)
456
+
457
+ if output_hidden_states:
458
+ all_hidden_states += (hidden_states,)
459
+
460
+ outputs = (hidden_states, all_hidden_states, all_attentions)
461
+
462
+ if not return_dict:
463
+ return tuple(v for v in outputs if v is not None)
464
+
465
+ return FlaxBaseModelOutput(
466
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
467
+ )
468
+
469
+
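The `layerdrop` logic in the collection above follows the LayerDrop paper linked in the comment: during training, each layer is skipped with some probability and the hidden states pass through unchanged. A tiny sketch of the same rule, with an assumed probability value:

```python
import random

layerdrop = 0.1        # assumed value for config.encoder_layerdrop
deterministic = False  # training mode; at inference no layer is ever skipped

skipped = 0
for _ in range(10_000):
    dropout_probability = random.uniform(0, 1)
    if not deterministic and dropout_probability < layerdrop:
        skipped += 1   # the layer is skipped; hidden states pass through unchanged
print(f"skipped roughly {100 * skipped / 10_000:.1f}% of layer applications")
```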
470
+ # Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer with MBart->Whisper
471
+ class FlaxWhisperDecoderLayer(nn.Module):
472
+ config: WhisperConfig
473
+ dtype: jnp.dtype = jnp.float32
474
+
475
+ def setup(self) -> None:
476
+ self.embed_dim = self.config.d_model
477
+ self.self_attn = FlaxWhisperAttention(
478
+ config=self.config,
479
+ embed_dim=self.embed_dim,
480
+ num_heads=self.config.decoder_attention_heads,
481
+ dropout=self.config.attention_dropout,
482
+ causal=True,
483
+ dtype=self.dtype,
484
+ )
485
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
486
+ self.activation_fn = ACT2FN[self.config.activation_function]
487
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
488
+
489
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
490
+ self.encoder_attn = FlaxWhisperAttention(
491
+ config=self.config,
492
+ embed_dim=self.embed_dim,
493
+ num_heads=self.config.decoder_attention_heads,
494
+ dropout=self.config.attention_dropout,
495
+ dtype=self.dtype,
496
+ )
497
+ self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
498
+ self.fc1 = nn.Dense(
499
+ self.config.decoder_ffn_dim,
500
+ dtype=self.dtype,
501
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
502
+ )
503
+ self.fc2 = nn.Dense(
504
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
505
+ )
506
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
507
+
508
+ def __call__(
509
+ self,
510
+ hidden_states: jnp.ndarray,
511
+ attention_mask: jnp.ndarray,
512
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
513
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
514
+ init_cache: bool = False,
515
+ output_attentions: bool = True,
516
+ deterministic: bool = True,
517
+ ) -> Tuple[jnp.ndarray]:
518
+ residual = hidden_states
519
+ hidden_states = self.self_attn_layer_norm(hidden_states)
520
+
521
+ # Self Attention
522
+ hidden_states, self_attn_weights = self.self_attn(
523
+ hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
524
+ )
525
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
526
+ hidden_states = residual + hidden_states
527
+
528
+ # Cross-Attention Block
529
+ cross_attn_weights = None
530
+ if encoder_hidden_states is not None:
531
+ residual = hidden_states
532
+
533
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
534
+ hidden_states, cross_attn_weights = self.encoder_attn(
535
+ hidden_states=hidden_states,
536
+ key_value_states=encoder_hidden_states,
537
+ attention_mask=encoder_attention_mask,
538
+ )
539
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
540
+ hidden_states = residual + hidden_states
541
+
542
+ # Fully Connected
543
+ residual = hidden_states
544
+ hidden_states = self.final_layer_norm(hidden_states)
545
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
546
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
547
+ hidden_states = self.fc2(hidden_states)
548
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
549
+ hidden_states = residual + hidden_states
550
+
551
+ outputs = (hidden_states,)
552
+
553
+ if output_attentions:
554
+ outputs += (self_attn_weights, cross_attn_weights)
555
+
556
+ return outputs
557
+
558
+
559
+ class FlaxWhisperDecoderLayerCollection(nn.Module):
560
+ config: WhisperConfig
561
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
562
+ gradient_checkpointing: bool = False
563
+
564
+ def setup(self):
565
+ if self.gradient_checkpointing:
566
+ FlaxWhisperDecoderCheckpointLayer = remat(FlaxWhisperDecoderLayer, static_argnums=(4, 5, 6))
567
+ self.layers = [
568
+ FlaxWhisperDecoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
569
+ for i in range(self.config.decoder_layers)
570
+ ]
571
+ else:
572
+ self.layers = [
573
+ FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
574
+ for i in range(self.config.decoder_layers)
575
+ ]
576
+ self.layerdrop = self.config.decoder_layerdrop
577
+
578
+ def __call__(
579
+ self,
580
+ hidden_states,
581
+ attention_mask,
582
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
583
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
584
+ deterministic: bool = True,
585
+ init_cache: bool = False,
586
+ output_attentions: bool = False,
587
+ output_hidden_states: bool = False,
588
+ return_dict: bool = True,
589
+ ):
590
+ # decoder layers
591
+ all_hidden_states = () if output_hidden_states else None
592
+ all_self_attns = () if output_attentions else None
593
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
594
+
595
+ for decoder_layer in self.layers:
596
+ if output_hidden_states:
597
+ all_hidden_states += (hidden_states,)
598
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
599
+ dropout_probability = random.uniform(0, 1)
600
+ if not deterministic and (dropout_probability < self.layerdrop):
601
+ layer_outputs = (None, None, None)
602
+ else:
603
+ layer_outputs = decoder_layer(
604
+ hidden_states,
605
+ attention_mask,
606
+ encoder_hidden_states,
607
+ encoder_attention_mask,
608
+ init_cache,
609
+ output_attentions,
610
+ deterministic,
611
+ )
612
+
613
+ hidden_states = layer_outputs[0]
614
+ if output_attentions:
615
+ all_self_attns += (layer_outputs[1],)
616
+
617
+ if encoder_hidden_states is not None:
618
+ all_cross_attentions += (layer_outputs[2],)
619
+
620
+ # add hidden states from the last decoder layer
621
+ if output_hidden_states:
622
+ all_hidden_states += (hidden_states,)
623
+
624
+ outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
625
+
626
+ if not return_dict:
627
+ return tuple(v for v in outputs if v is not None)
628
+
629
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
630
+ last_hidden_state=hidden_states,
631
+ hidden_states=all_hidden_states,
632
+ attentions=all_self_attns,
633
+ cross_attentions=all_cross_attentions,
634
+ )
635
+
636
+
637
+ class FlaxWhisperEncoder(nn.Module):
638
+ config: WhisperConfig
639
+ dtype: jnp.dtype = jnp.float32
640
+ gradient_checkpointing: bool = False
641
+
642
+ def setup(self) -> None:
643
+ self.conv1 = nn.Conv(
644
+ self.config.d_model,
645
+ kernel_size=(3,),
646
+ padding=1,
647
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
648
+ dtype=self.dtype,
649
+ )
650
+ self.conv2 = nn.Conv(
651
+ self.config.d_model,
652
+ kernel_size=(3,),
653
+ strides=2,
654
+ padding=1,
655
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
656
+ dtype=self.dtype,
657
+ )
658
+
659
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
660
+
661
+ self.layers = FlaxWhisperEncoderLayerCollection(
662
+ self.config,
663
+ dtype=self.dtype,
664
+ gradient_checkpointing=self.gradient_checkpointing,
665
+ )
666
+
667
+ self.embed_positions = nn.Embed(
668
+ self.config.max_source_positions,
669
+ self.config.d_model,
670
+ dtype=self.dtype,
671
+ embedding_init=sinusoidal_embedding_init,
672
+ )
673
+
674
+ self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
675
+
676
+ def __call__(
677
+ self,
678
+ input_features: jnp.ndarray,
679
+ output_attentions: bool = False,
680
+ output_hidden_states: bool = False,
681
+ return_dict: bool = True,
682
+ deterministic: bool = True,
683
+ ) -> Tuple[jnp.ndarray]:
684
+ if input_features.shape[1:] != (self.config.num_mel_bins, self.config.max_source_positions * 2):
685
+ raise ValueError(
686
+ "input_features.shape[1:], must be equal to (self.config.num_mel_bins,"
687
+ f" self.config.max_source_positions * 2) (got {input_features.shape[1:]}, but should be"
688
+ f" ({self.config.num_mel_bins}, {self.config.max_source_positions * 2}))"
689
+ )
690
+
691
+ input_features = input_features.transpose(0, 2, 1)
692
+ hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False)
693
+ hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)
694
+
695
+ embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions))
696
+ # freeze the sinusoidal embeddings by stopping the back-prop
697
+ embed_positions = jax.lax.stop_gradient(embed_positions)
698
+ hidden_states = hidden_states + embed_positions
699
+
700
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
701
+
702
+ outputs = self.layers(
703
+ hidden_states,
704
+ attention_mask=None,
705
+ deterministic=deterministic,
706
+ output_attentions=output_attentions,
707
+ output_hidden_states=output_hidden_states,
708
+ return_dict=return_dict,
709
+ )
710
+
711
+ last_hidden_states = outputs[0]
712
+ last_hidden_states = self.layer_norm(last_hidden_states)
713
+
714
+ # update the last element in `hidden_states` after applying `layernorm` above
715
+ hidden_states = None
716
+ if output_hidden_states:
717
+ hidden_states = outputs[1]
718
+ hidden_states = hidden_states[:-1] + (last_hidden_states,)
719
+
720
+ if not return_dict:
721
+ outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
722
+ return tuple(v for v in outputs if v is not None)
723
+
724
+ return FlaxBaseModelOutput(
725
+ last_hidden_state=last_hidden_states,
726
+ hidden_states=hidden_states,
727
+ attentions=outputs.attentions,
728
+ )
729
+
730
+
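The shape check and the two convolutions at the top of `__call__` encode a fixed relationship: the encoder expects `max_source_positions * 2` mel frames, and the stride-2 `conv2` halves them back to `max_source_positions`. A quick numeric check, assuming the standard Whisper values (`num_mel_bins=80`, `max_source_positions=1500`):

```python
num_mel_bins = 80            # config.num_mel_bins (assumed default)
max_source_positions = 1500  # config.max_source_positions (assumed default)

expected_feature_shape = (num_mel_bins, max_source_positions * 2)  # (80, 3000)
frames_after_conv2 = (max_source_positions * 2) // 2               # stride-2 convolution
assert frames_after_conv2 == max_source_positions

print(expected_feature_shape, frames_after_conv2)  # (80, 3000) 1500
```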
731
+ class FlaxWhisperDecoder(nn.Module):
732
+ config: WhisperConfig
733
+ dtype: jnp.dtype = jnp.float32
734
+ gradient_checkpointing: bool = False
735
+
736
+ def setup(self) -> None:
737
+ self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.d_model, dtype=self.dtype)
738
+ self.embed_positions = nn.Embed(self.config.max_target_positions, self.config.d_model, dtype=self.dtype)
739
+
740
+ self.layers = FlaxWhisperDecoderLayerCollection(
741
+ self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
742
+ )
743
+
744
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
745
+
746
+ self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-5)
747
+
748
+ def __call__(
749
+ self,
750
+ input_ids: jnp.ndarray,
751
+ attention_mask: jnp.ndarray,
752
+ position_ids: jnp.ndarray,
753
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
754
+ init_cache: bool = False,
755
+ output_attentions: bool = False,
756
+ output_hidden_states: bool = False,
757
+ return_dict: bool = True,
758
+ deterministic: bool = True,
759
+ ) -> Tuple[jnp.ndarray]:
760
+ input_embeds = self.embed_tokens(input_ids)
761
+ position_embeds = self.embed_positions(position_ids)
762
+
763
+ hidden_states = input_embeds + position_embeds
764
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
765
+
766
+ outputs = self.layers(
767
+ hidden_states,
768
+ attention_mask=attention_mask,
769
+ encoder_hidden_states=encoder_hidden_states,
770
+ deterministic=deterministic,
771
+ init_cache=init_cache,
772
+ output_attentions=output_attentions,
773
+ output_hidden_states=output_hidden_states,
774
+ return_dict=return_dict,
775
+ )
776
+
777
+ last_hidden_states = outputs[0]
778
+ last_hidden_states = self.layer_norm(last_hidden_states)
779
+
780
+ # update the last element in `hidden_states` after applying `layernorm` above
781
+ hidden_states = None
782
+ if output_hidden_states:
783
+ hidden_states = outputs[1]
784
+ hidden_states = hidden_states[:-1] + (last_hidden_states,)
785
+
786
+ if not return_dict:
787
+ outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
788
+ return tuple(v for v in outputs if v is not None)
789
+
790
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
791
+ last_hidden_state=last_hidden_states,
792
+ hidden_states=hidden_states,
793
+ attentions=outputs.attentions,
794
+ cross_attentions=outputs.cross_attentions,
795
+ )
796
+
797
+
798
+ class FlaxWhisperModule(nn.Module):
799
+ config: WhisperConfig
800
+ dtype: jnp.dtype = jnp.float32
801
+ gradient_checkpointing: bool = False
802
+
803
+ def setup(self) -> None:
804
+ self.encoder = FlaxWhisperEncoder(
805
+ self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
806
+ )
807
+ self.decoder = FlaxWhisperDecoder(
808
+ self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
809
+ )
810
+
811
+ def __call__(
812
+ self,
813
+ input_features: jnp.ndarray,
814
+ decoder_input_ids: jnp.ndarray,
815
+ decoder_attention_mask: jnp.ndarray,
816
+ decoder_position_ids: jnp.ndarray,
817
+ output_attentions: bool = False,
818
+ output_hidden_states: bool = False,
819
+ return_dict: bool = True,
820
+ deterministic: bool = True,
821
+ ):
822
+ encoder_outputs = self.encoder(
823
+ input_features,
824
+ output_attentions=output_attentions,
825
+ output_hidden_states=output_hidden_states,
826
+ return_dict=return_dict,
827
+ deterministic=deterministic,
828
+ )
829
+
830
+ decoder_outputs = self.decoder(
831
+ input_ids=decoder_input_ids,
832
+ attention_mask=decoder_attention_mask,
833
+ position_ids=decoder_position_ids,
834
+ encoder_hidden_states=encoder_outputs[0],
835
+ output_attentions=output_attentions,
836
+ output_hidden_states=output_hidden_states,
837
+ return_dict=return_dict,
838
+ deterministic=deterministic,
839
+ )
840
+
841
+ if not return_dict:
842
+ return decoder_outputs + encoder_outputs
843
+
844
+ return FlaxSeq2SeqModelOutput(
845
+ last_hidden_state=decoder_outputs.last_hidden_state,
846
+ decoder_hidden_states=decoder_outputs.hidden_states,
847
+ decoder_attentions=decoder_outputs.attentions,
848
+ cross_attentions=decoder_outputs.cross_attentions,
849
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
850
+ encoder_hidden_states=encoder_outputs.hidden_states,
851
+ encoder_attentions=encoder_outputs.attentions,
852
+ )
853
+
854
+ def _get_encoder_module(self):
855
+ return self.encoder
856
+
857
+ def _get_decoder_module(self):
858
+ return self.decoder
859
+
860
+
861
+ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
862
+ config_class = WhisperConfig
863
+ base_model_prefix: str = "model"
864
+ main_input_name = "input_features"
865
+ module_class: nn.Module = None
866
+
867
+ def __init__(
868
+ self,
869
+ config: WhisperConfig,
870
+ input_shape: Tuple[int] = None,
871
+ seed: int = 0,
872
+ dtype: jnp.dtype = jnp.float32,
873
+ _do_init: bool = True,
874
+ gradient_checkpointing: bool = False,
875
+ **kwargs,
876
+ ):
877
+ module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
878
+ if input_shape is None:
879
+ input_shape = (1, config.num_mel_bins, 2 * config.max_source_positions)
880
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
881
+
882
+ def enable_gradient_checkpointing(self):
883
+ self._module = self.module_class(
884
+ config=self.config,
885
+ dtype=self.dtype,
886
+ gradient_checkpointing=True,
887
+ )
888
+
889
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
890
+ # init input tensors
891
+ input_features = jnp.zeros(input_shape, dtype="f4")
892
+ input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)
893
+
894
+ decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
895
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
896
+
897
+ batch_size, sequence_length = decoder_input_ids.shape
898
+ decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
899
+
900
+ params_rng, dropout_rng = jax.random.split(rng)
901
+ rngs = {"params": params_rng, "dropout": dropout_rng}
902
+
903
+ random_params = self.module.init(
904
+ rngs,
905
+ input_features=input_features,
906
+ decoder_input_ids=decoder_input_ids,
907
+ decoder_attention_mask=decoder_attention_mask,
908
+ decoder_position_ids=decoder_position_ids,
909
+ )["params"]
910
+
911
+ if params is not None:
912
+ random_params = flatten_dict(unfreeze(random_params))
913
+ params = flatten_dict(unfreeze(params))
914
+ for missing_key in self._missing_keys:
915
+ params[missing_key] = random_params[missing_key]
916
+ self._missing_keys = set()
917
+ return freeze(unflatten_dict(params))
918
+ else:
919
+ return random_params
920
+
921
+ # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel.init_cache with Bart->Whisper
922
+ def init_cache(self, batch_size, max_length, encoder_outputs):
923
+ r"""
924
+ Args:
925
+ batch_size (`int`):
926
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
927
+ max_length (`int`):
928
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
929
+ cache.
930
+ encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
931
+ `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
932
+ `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
933
+ is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
934
+ cross-attention of the decoder.
935
+ """
936
+ # init input variables to retrieve cache
937
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
938
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
939
+ decoder_position_ids = jnp.broadcast_to(
940
+ jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
941
+ )
942
+
943
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
944
+ decoder_module = module._get_decoder_module()
945
+ return decoder_module(
946
+ decoder_input_ids,
947
+ decoder_attention_mask,
948
+ decoder_position_ids,
949
+ **kwargs,
950
+ )
951
+
952
+ init_variables = self.module.init(
953
+ jax.random.PRNGKey(0),
954
+ decoder_input_ids=decoder_input_ids,
955
+ decoder_attention_mask=decoder_attention_mask,
956
+ decoder_position_ids=decoder_position_ids,
957
+ encoder_hidden_states=encoder_outputs[0],
958
+ init_cache=True,
959
+ method=_decoder_forward, # we only need to call the decoder to init the cache
960
+ )
961
+ return unfreeze(init_variables["cache"])
962
+
963
+ @add_start_docstrings(WHISPER_ENCODE_INPUTS_DOCSTRING)
964
+ @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=WhisperConfig)
965
+ def encode(
966
+ self,
967
+ input_features: jnp.ndarray,
968
+ attention_mask: Optional[jnp.ndarray] = None,
969
+ output_attentions: Optional[bool] = None,
970
+ output_hidden_states: Optional[bool] = None,
971
+ return_dict: Optional[bool] = None,
972
+ train: bool = False,
973
+ params: dict = None,
974
+ dropout_rng: PRNGKey = None,
975
+ **kwargs,
976
+ ):
977
+ r"""
978
+ Returns:
979
+
980
+ Example:
981
+
982
+ ```python
983
+ >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
984
+ >>> from datasets import load_dataset
985
+
986
+ >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
987
+ >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
988
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
989
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
990
+ >>> input_features = inputs.input_features
991
+ >>> encoder_outputs = model.encode(input_features=input_features)
992
+ ```"""
993
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
994
+ output_hidden_states = (
995
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
996
+ )
997
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
998
+
999
+ # Handle any PRNG if needed
1000
+ rngs = {}
1001
+ if dropout_rng is not None:
1002
+ rngs["dropout"] = dropout_rng
1003
+
1004
+ def _encoder_forward(module, input_features, **kwargs):
1005
+ encode_module = module._get_encoder_module()
1006
+ return encode_module(input_features, **kwargs)
1007
+
1008
+ return self.module.apply(
1009
+ {"params": params or self.params},
1010
+ input_features=jnp.array(input_features, dtype="f4"),
1011
+ output_attentions=output_attentions,
1012
+ output_hidden_states=output_hidden_states,
1013
+ return_dict=return_dict,
1014
+ deterministic=not train,
1015
+ rngs=rngs,
1016
+ method=_encoder_forward,
1017
+ )
1018
+
1019
+ @add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
1020
+ @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=WhisperConfig)
1021
+ def decode(
1022
+ self,
1023
+ decoder_input_ids,
1024
+ encoder_outputs,
1025
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1026
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1027
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1028
+ past_key_values: dict = None,
1029
+ output_attentions: Optional[bool] = None,
1030
+ output_hidden_states: Optional[bool] = None,
1031
+ return_dict: Optional[bool] = None,
1032
+ train: bool = False,
1033
+ params: dict = None,
1034
+ dropout_rng: PRNGKey = None,
1035
+ ):
1036
+ r"""
1037
+ Returns:
1038
+
1039
+ Example:
1040
+
1041
+ ```python
1042
+ >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
1043
+ >>> from datasets import load_dataset
1044
+ >>> import jax.numpy as jnp
1045
+
1046
+ >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
1047
+ >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
1048
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1049
+ >>> input_features = processor(ds[0]["audio"]["array"], return_tensors="np").input_features
1050
+
1051
+ >>> encoder_outputs = model.encode(input_features=input_features)
1052
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
1053
+
1054
+ >>> decoder_input_ids = jnp.ones((input_features.shape[0], 1), dtype="i4") * decoder_start_token_id
1055
+
1056
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
1057
+ >>> last_decoder_hidden_states = outputs.last_hidden_state
1058
+ ```"""
1059
+
1060
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1061
+ output_hidden_states = (
1062
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1063
+ )
1064
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1065
+
1066
+ encoder_hidden_states = encoder_outputs[0]
1067
+
1068
+ batch_size, sequence_length = decoder_input_ids.shape
1069
+ if decoder_position_ids is None:
1070
+ if past_key_values is not None:
1071
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
1072
+
1073
+ if decoder_attention_mask is not None:
1074
+ decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
1075
+ else:
1076
+ decoder_position_ids = jnp.broadcast_to(
1077
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1078
+ )
1079
+
1080
+ if decoder_attention_mask is None:
1081
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
1082
+
1083
+ # Handle any PRNG if needed
1084
+ rngs = {}
1085
+ if dropout_rng is not None:
1086
+ rngs["dropout"] = dropout_rng
1087
+
1088
+ inputs = {"params": params or self.params}
1089
+
1090
+ # if `past_key_values` are passed, the cache is already initialized, so the private flag `init_cache` has to be
1091
+ # passed down to ensure the cache is used. The cache also has to be marked as mutable so that
1092
+ # it can be updated by the FlaxWhisperAttention module.
1093
+ if past_key_values:
1094
+ inputs["cache"] = past_key_values
1095
+ mutable = ["cache"]
1096
+ else:
1097
+ mutable = False
1098
+
1099
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
1100
+ decoder_module = module._get_decoder_module()
1101
+ return decoder_module(
1102
+ input_ids=decoder_input_ids,
1103
+ attention_mask=decoder_attention_mask,
1104
+ position_ids=decoder_position_ids,
1105
+ **kwargs,
1106
+ )
1107
+
1108
+ outputs = self.module.apply(
1109
+ inputs,
1110
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1111
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1112
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1113
+ encoder_hidden_states=encoder_hidden_states,
1114
+ output_attentions=output_attentions,
1115
+ output_hidden_states=output_hidden_states,
1116
+ return_dict=return_dict,
1117
+ deterministic=not train,
1118
+ rngs=rngs,
1119
+ mutable=mutable,
1120
+ method=_decoder_forward,
1121
+ )
1122
+
1123
+ # add updated cache to model output
1124
+ if past_key_values is not None and return_dict:
1125
+ outputs, past = outputs
1126
+ outputs["past_key_values"] = unfreeze(past["cache"])
1127
+ return outputs
1128
+ elif past_key_values is not None and not return_dict:
1129
+ outputs, past = outputs
1130
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
1131
+
1132
+ return outputs
1133
+
1134
+ @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
1135
+ def __call__(
1136
+ self,
1137
+ input_features: jnp.ndarray,
1138
+ decoder_input_ids: jnp.ndarray,
1139
+ attention_mask: Optional[jnp.ndarray] = None,
1140
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1141
+ position_ids: Optional[jnp.ndarray] = None,
1142
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1143
+ output_attentions: Optional[bool] = None,
1144
+ output_hidden_states: Optional[bool] = None,
1145
+ return_dict: Optional[bool] = None,
1146
+ train: bool = False,
1147
+ params: dict = None,
1148
+ dropout_rng: PRNGKey = None,
1149
+ ):
1150
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1151
+ output_hidden_states = (
1152
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1153
+ )
1154
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1155
+
1156
+ # prepare decoder inputs
1157
+ if decoder_position_ids is None:
1158
+ if decoder_attention_mask is not None:
1159
+ decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
1160
+ else:
1161
+ batch_size, sequence_length = decoder_input_ids.shape
1162
+ decoder_position_ids = jnp.broadcast_to(
1163
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1164
+ )
1165
+ if decoder_attention_mask is None:
1166
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
1167
+
1168
+ # Handle any PRNG if needed
1169
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
1170
+
1171
+ return self.module.apply(
1172
+ {"params": params or self.params},
1173
+ input_features=jnp.array(input_features, dtype="f4"),
1174
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1175
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1176
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1177
+ output_attentions=output_attentions,
1178
+ output_hidden_states=output_hidden_states,
1179
+ return_dict=return_dict,
1180
+ deterministic=not train,
1181
+ rngs=rngs,
1182
+ )
1183
+
1184
+
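Both `__call__` above and `decode` derive `decoder_position_ids` from the attention mask with `(mask.cumsum(-1) * mask) - 1`. A small sketch (illustrative mask) of what that recipe produces for left-padded inputs: real tokens get increasing positions, padding slots get -1.

```python
import jax.numpy as jnp

decoder_attention_mask = jnp.array([[0, 0, 1, 1, 1]])  # two padding slots, three tokens
decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
print(decoder_position_ids)  # [[-1 -1  0  1  2]]
```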
1185
+ @add_start_docstrings(
1186
+ "The bare Whisper Model transformer outputting raw hidden-states without any specific head on top.",
1187
+ WHISPER_START_DOCSTRING,
1188
+ )
1189
+ class FlaxWhisperModel(FlaxWhisperPreTrainedModel):
1190
+ config: WhisperConfig
1191
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
1192
+ module_class = FlaxWhisperModule
1193
+
1194
+
1195
+ append_call_sample_docstring(FlaxWhisperModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
1196
+
1197
+
1198
+ class FlaxWhisperForConditionalGenerationModule(nn.Module):
1199
+ config: WhisperConfig
1200
+ dtype: jnp.dtype = jnp.float32
1201
+ gradient_checkpointing: bool = False
1202
+
1203
+ def setup(self) -> None:
1204
+ self.model = FlaxWhisperModule(
1205
+ config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
1206
+ )
1207
+ self.lm_head = nn.Dense(
1208
+ self.config.vocab_size,
1209
+ use_bias=False,
1210
+ dtype=self.dtype,
1211
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
1212
+ )
1213
+
1214
+ def _get_encoder_module(self):
1215
+ return self.model.encoder
1216
+
1217
+ def _get_decoder_module(self):
1218
+ return self.model.decoder
1219
+
1220
+ def __call__(
1221
+ self,
1222
+ input_features,
1223
+ decoder_input_ids,
1224
+ decoder_attention_mask: jnp.ndarray = None,
1225
+ decoder_position_ids: jnp.ndarray = None,
1226
+ position_ids: jnp.ndarray = None,
1227
+ attention_mask: jnp.ndarray = None,
1228
+ output_attentions: bool = False,
1229
+ output_hidden_states: bool = False,
1230
+ return_dict: bool = True,
1231
+ deterministic: bool = True,
1232
+ ):
1233
+ outputs = self.model(
1234
+ input_features=input_features,
1235
+ decoder_input_ids=decoder_input_ids,
1236
+ decoder_attention_mask=decoder_attention_mask,
1237
+ decoder_position_ids=decoder_position_ids,
1238
+ output_attentions=output_attentions,
1239
+ output_hidden_states=output_hidden_states,
1240
+ return_dict=return_dict,
1241
+ deterministic=deterministic,
1242
+ )
1243
+
1244
+ hidden_states = outputs[0]
1245
+
1246
+ if self.config.tie_word_embeddings:
1247
+ shared_embedding = self.model.decoder.embed_tokens.variables["params"]["embedding"]
1248
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
1249
+ else:
1250
+ lm_logits = self.lm_head(hidden_states)
1251
+
1252
+ if not return_dict:
1253
+ output = (lm_logits,) + outputs[1:]
1254
+ return output
1255
+
1256
+ return FlaxSeq2SeqLMOutput(
1257
+ logits=lm_logits,
1258
+ decoder_hidden_states=outputs.decoder_hidden_states,
1259
+ decoder_attentions=outputs.decoder_attentions,
1260
+ cross_attentions=outputs.cross_attentions,
1261
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1262
+ encoder_hidden_states=outputs.encoder_hidden_states,
1263
+ encoder_attentions=outputs.encoder_attentions,
1264
+ )
1265
+
1266
+
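When `tie_word_embeddings` is set, the module above reuses the decoder's token-embedding matrix as the LM head kernel. A toy sketch (random values, assumed small shapes) of the equivalent computation, `hidden_states @ embedding.T`:

```python
import jax
import jax.numpy as jnp

vocab_size, d_model, seq_len = 11, 4, 3
key_emb, key_hid = jax.random.split(jax.random.PRNGKey(0))

embedding = jax.random.normal(key_emb, (vocab_size, d_model))      # decoder embed_tokens
hidden_states = jax.random.normal(key_hid, (1, seq_len, d_model))  # decoder output

# a Dense layer with kernel = embedding.T is the same as this einsum
lm_logits = jnp.einsum("bsd,vd->bsv", hidden_states, embedding)
print(lm_logits.shape)  # (1, 3, 11)
```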
1267
+ @add_start_docstrings("The Whisper Model with a language modeling head.", WHISPER_START_DOCSTRING)
1268
+ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
1269
+ module_class = FlaxWhisperForConditionalGenerationModule
1270
+ dtype: jnp.dtype = jnp.float32
1271
+
1272
+ @add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
1273
+ @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=WhisperConfig)
1274
+ def decode(
1275
+ self,
1276
+ decoder_input_ids,
1277
+ encoder_outputs,
1278
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1279
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1280
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1281
+ past_key_values: dict = None,
1282
+ output_attentions: Optional[bool] = None,
1283
+ output_hidden_states: Optional[bool] = None,
1284
+ return_dict: Optional[bool] = None,
1285
+ train: bool = False,
1286
+ params: dict = None,
1287
+ dropout_rng: PRNGKey = None,
1288
+ ):
1289
+ r"""
1290
+ Returns:
1291
+
1292
+ Example:
1293
+
1294
+ ```python
1295
+ >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
1296
+ >>> from datasets import load_dataset
1297
+
1298
+ >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
1299
+ >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
1300
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1301
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
1302
+ >>> input_features = inputs.input_features
1303
+ >>> encoder_outputs = model.encode(input_features=input_features)
1304
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
1305
+
1306
+ >>> decoder_input_ids = jnp.ones((input_features.shape[0], 1), dtype="i4") * decoder_start_token_id
1307
+
1308
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
1309
+ >>> last_decoder_hidden_states = outputs.last_hidden_state
1310
+ ```"""
1311
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1312
+ output_hidden_states = (
1313
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1314
+ )
1315
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1316
+
1317
+ encoder_hidden_states = encoder_outputs[0]
1318
+
1319
+ batch_size, sequence_length = decoder_input_ids.shape
1320
+ if decoder_position_ids is None:
1321
+ if past_key_values is not None:
1322
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
1323
+
1324
+ if decoder_attention_mask is not None:
1325
+ decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
1326
+ else:
1327
+ decoder_position_ids = jnp.broadcast_to(
1328
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1329
+ )
1330
+ if decoder_attention_mask is None:
1331
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length), dtype="i4")
1332
+
1333
+ # Handle any PRNG if needed
1334
+ rngs = {}
1335
+ if dropout_rng is not None:
1336
+ rngs["dropout"] = dropout_rng
1337
+
1338
+ inputs = {"params": params or self.params}
1339
+
1340
+ # if `past_key_values` are passed, the cache is already initialized, so the private flag `init_cache` has to be
1341
+ # passed down to ensure the cache is used. The cache also has to be marked as mutable so that
1342
+ # it can be updated by the FlaxWhisperAttention module.
1343
+ if past_key_values:
1344
+ inputs["cache"] = past_key_values
1345
+ mutable = ["cache"]
1346
+ else:
1347
+ mutable = False
1348
+
1349
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
1350
+ decoder_module = module._get_decoder_module()
1351
+ outputs = decoder_module(
1352
+ input_ids=decoder_input_ids,
1353
+ attention_mask=decoder_attention_mask,
1354
+ position_ids=decoder_position_ids,
1355
+ **kwargs,
1356
+ )
1357
+ hidden_states = outputs[0]
1358
+
1359
+ if self.config.tie_word_embeddings:
1360
+ shared_embedding = module.model.decoder.embed_tokens.variables["params"]["embedding"]
1361
+ lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
1362
+ else:
1363
+ lm_logits = module.lm_head(hidden_states)
1364
+
1365
+ return lm_logits, outputs
1366
+
1367
+ outputs = self.module.apply(
1368
+ inputs,
1369
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1370
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1371
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1372
+ encoder_hidden_states=encoder_hidden_states,
1373
+ output_attentions=output_attentions,
1374
+ output_hidden_states=output_hidden_states,
1375
+ return_dict=return_dict,
1376
+ deterministic=not train,
1377
+ rngs=rngs,
1378
+ mutable=mutable,
1379
+ method=_decoder_forward,
1380
+ )
1381
+
1382
+ if past_key_values is None:
1383
+ lm_logits, decoder_outputs = outputs
1384
+ else:
1385
+ (lm_logits, decoder_outputs), past = outputs
1386
+
1387
+ if return_dict:
1388
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
1389
+ logits=lm_logits,
1390
+ hidden_states=decoder_outputs.hidden_states,
1391
+ attentions=decoder_outputs.attentions,
1392
+ cross_attentions=decoder_outputs.cross_attentions,
1393
+ )
1394
+ else:
1395
+ outputs = (lm_logits,) + decoder_outputs[1:]
1396
+
1397
+ # add updated cache to model output
1398
+ if past_key_values is not None and return_dict:
1399
+ outputs["past_key_values"] = unfreeze(past["cache"])
1400
+ return outputs
1401
+ elif past_key_values is not None and not return_dict:
1402
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
1403
+
1404
+ return outputs
1405
+
1406
+ def generate(
1407
+ self,
1408
+ input_features,
1409
+ generation_config=None,
1410
+ logits_processor=None,
1411
+ return_timestamps=None,
1412
+ task=None,
1413
+ language=None,
1414
+ is_multilingual=None,
1415
+ **kwargs,
1416
+ ):
1417
+ if generation_config is None:
1418
+ generation_config = self.generation_config
1419
+
1420
+ if return_timestamps is not None:
1421
+ generation_config.return_timestamps = return_timestamps
1422
+
1423
+ if task is not None:
1424
+ generation_config.task = task
1425
+
1426
+ if is_multilingual is not None:
1427
+ generation_config.is_multilingual = is_multilingual
1428
+
1429
+ if language is not None:
1430
+ generation_config.language = language
1431
+
1432
+ if kwargs is not None and "decoder_input_ids" in kwargs:
1433
+ decoder_input_length = len(kwargs["decoder_input_ids"])
1434
+ else:
1435
+ decoder_input_length = 1
1436
+
1437
+ forced_decoder_ids = []
1438
+
1439
+ if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
1440
+ if hasattr(generation_config, "language"):
1441
+ forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
1442
+ else:
1443
+ forced_decoder_ids.append((1, None))
1444
+
1445
+ if hasattr(generation_config, "task"):
1446
+ forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
1447
+ else:
1448
+ forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
1449
+
1450
+ if (
1451
+ hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
1452
+ ) or return_timestamps:
1453
+ logits_processor = [
1454
+ FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, decoder_input_length)
1455
+ ]
1456
+ else:
1457
+ if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
1458
+ idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
1459
+ forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
1460
+
1461
+ if len(forced_decoder_ids) > 0:
1462
+ generation_config.forced_decoder_ids = forced_decoder_ids
1463
+
1464
+ return super().generate(
1465
+ input_features,
1466
+ generation_config,
1467
+ logits_processor=logits_processor,
1468
+ **kwargs,
1469
+ )
1470
+
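The `generate` override above turns `language`, `task`, and `return_timestamps` into `forced_decoder_ids`, i.e. (decoder position, token id) pairs that the logits processors force during decoding. A purely illustrative example follows; the token ids are hypothetical placeholders, not taken from a real tokenizer.

```python
# hypothetical ids for a multilingual checkpoint: language at position 1,
# task at position 2, <|notimestamps|> at position 3
forced_decoder_ids = [(1, 50259), (2, 50359), (3, 50363)]

for position, token_id in forced_decoder_ids:
    print(f"decoder position {position} is forced to token id {token_id}")
```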
1471
+ def prepare_inputs_for_generation(
1472
+ self,
1473
+ decoder_input_ids,
1474
+ max_length,
1475
+ attention_mask: Optional[jax.Array] = None,
1476
+ decoder_attention_mask: Optional[jax.Array] = None,
1477
+ encoder_outputs=None,
1478
+ **kwargs,
1479
+ ):
1480
+ # initializing the cache
1481
+ batch_size, seq_length = decoder_input_ids.shape
1482
+
1483
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
1484
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
1485
+ # But since the decoder uses a causal mask, those positions are masked anyway.
1486
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation.
1487
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
1488
+ if decoder_attention_mask is not None:
1489
+ position_ids = decoder_attention_mask.cumsum(-1) - 1
1490
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
1491
+ else:
1492
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
1493
+
1494
+ return {
1495
+ "past_key_values": past_key_values,
1496
+ "encoder_outputs": encoder_outputs,
1497
+ "encoder_attention_mask": attention_mask,
1498
+ "decoder_attention_mask": extended_attention_mask,
1499
+ "decoder_position_ids": position_ids,
1500
+ }
1501
+
1502
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
1503
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
1504
+ model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
1505
+ return model_kwargs
1506
+
1507
+
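To tie the pieces together, here is a hedged sketch of manual, cache-backed greedy decoding built from `encode`, `init_cache`, and `decode` as defined above. The checkpoint and dataset follow the docstring examples; in practice `generate()` (together with `prepare_inputs_for_generation` / `update_inputs_for_generation`) handles this loop, so treat the code as illustrative only.

```python
import jax.numpy as jnp
from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
from datasets import load_dataset

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features = processor(ds[0]["audio"]["array"], return_tensors="np").input_features

encoder_outputs = model.encode(input_features=input_features)
batch_size, max_length = input_features.shape[0], 16

# pre-allocate the decoder self-attention cache for `max_length` positions
past_key_values = model.init_cache(batch_size, max_length, encoder_outputs)
decoder_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")

next_token = jnp.full((batch_size, 1), model.config.decoder_start_token_id, dtype="i4")
generated = [next_token]
for step in range(max_length - 1):
    outputs = model.decode(
        next_token,
        encoder_outputs,
        decoder_attention_mask=decoder_attention_mask,
        decoder_position_ids=jnp.full((batch_size, 1), step, dtype="i4"),
        past_key_values=past_key_values,
    )
    past_key_values = outputs.past_key_values        # carry the updated cache forward
    next_token = jnp.argmax(outputs.logits[:, -1:], axis=-1).astype("i4")
    generated.append(next_token)

sequences = jnp.concatenate(generated, axis=1)
print(processor.batch_decode(sequences.tolist(), skip_special_tokens=True)[0])
```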
1508
+ FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING = r"""
1509
+ Returns:
1510
+
1511
+ Transcription example:
1512
+
1513
+ ```python
1514
+ >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
1515
+ >>> from datasets import load_dataset
1516
+
1517
+ >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
1518
+ >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
1519
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1520
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
1521
+ >>> input_features = inputs.input_features
1522
+ >>> generated_ids = model.generate(input_features).sequences
1523
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
1524
+ >>> transcription
1525
+ ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
1526
+ ```
1527
+ """
1528
+
1529
+ overwrite_call_docstring(
1530
+ FlaxWhisperForConditionalGeneration, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING
1531
+ )
1532
+ append_replace_return_docstrings(
1533
+ FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
1534
+ )
1535
+
1536
+
1537
+ class FlaxWhisperForAudioClassificationModule(nn.Module):
1538
+ config: WhisperConfig
1539
+ dtype: jnp.dtype = jnp.float32
1540
+ gradient_checkpointing: bool = False
1541
+
1542
+ def setup(self) -> None:
1543
+ self.encoder = FlaxWhisperEncoder(
1544
+ config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
1545
+ )
1546
+ self.config.is_encoder_decoder = False
1547
+ num_layers = self.config.num_hidden_layers + 1
1548
+ if self.config.use_weighted_layer_sum:
1549
+ self.layer_weights = jnp.repeat(1 / num_layers, num_layers)
1550
+ self.projector = nn.Dense(self.config.classifier_proj_size, dtype=self.dtype)
1551
+ self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
1552
+
1553
+ def __call__(
1554
+ self,
1555
+ input_features,
1556
+ encoder_outputs=None,
1557
+ output_attentions=None,
1558
+ output_hidden_states: bool = True,
1559
+ return_dict: bool = True,
1560
+ ):
1561
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1562
+ output_hidden_states = (
1563
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1564
+ )
1565
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1566
+
1567
+ if encoder_outputs is None:
1568
+ encoder_outputs = self.encoder(
1569
+ input_features,
1570
+ output_attentions=output_attentions,
1571
+ output_hidden_states=output_hidden_states,
1572
+ return_dict=return_dict,
1573
+ )
1574
+
1575
+ if self.config.use_weighted_layer_sum:
1576
+ hidden_states = jnp.stack(encoder_outputs, axis=1)
1577
+ norm_weights = jax.nn.softmax(self.layer_weights, axis=-1)
1578
+ hidden_states = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1)
1579
+ else:
1580
+ hidden_states = encoder_outputs[0]
1581
+
1582
+ hidden_states = self.projector(hidden_states)
1583
+ pooled_output = jnp.mean(hidden_states, axis=1)
1584
+
1585
+ logits = self.classifier(pooled_output)
1586
+
1587
+ if not return_dict:
1588
+ return (logits,) + encoder_outputs[1:]
1589
+
1590
+ return FlaxSequenceClassifierOutput(
1591
+ logits=logits,
1592
+ hidden_states=encoder_outputs.hidden_states,
1593
+ attentions=encoder_outputs.attentions,
1594
+ )
1595
+
1596
+
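The `use_weighted_layer_sum` branch above mixes all encoder hidden states with a softmax over learned per-layer weights before mean-pooling over time. A toy sketch (random values, assumed small shapes) of that computation:

```python
import jax
import jax.numpy as jnp

num_layers, batch, seq_len, hidden = 5, 2, 7, 4
layer_weights = jnp.repeat(1 / num_layers, num_layers)  # initialized uniform, learned in training
hidden_states = jax.random.normal(jax.random.PRNGKey(0), (batch, num_layers, seq_len, hidden))

norm_weights = jax.nn.softmax(layer_weights, axis=-1)
mixed = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1)
pooled = jnp.mean(mixed, axis=1)  # mean over the time dimension

print(mixed.shape, pooled.shape)  # (2, 7, 4) (2, 4)
```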
1597
+ @add_start_docstrings("The Whisper Model with an audio classification head on top.", WHISPER_START_DOCSTRING)
1598
+ class FlaxWhisperForAudioClassification(FlaxWhisperPreTrainedModel):
1599
+ module_class = FlaxWhisperForAudioClassificationModule
1600
+ dtype: jnp.dtype = jnp.float32
1601
+
1602
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
1603
+ # init input tensors
1604
+ input_features = jnp.zeros(input_shape, dtype="f4")
1605
+ input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)
1606
+
1607
+ params_rng, dropout_rng = jax.random.split(rng)
1608
+ rngs = {"params": params_rng, "dropout": dropout_rng}
1609
+
1610
+ random_params = self.module.init(
1611
+ rngs,
1612
+ input_features=input_features,
1613
+ )["params"]
1614
+
1615
+ if params is not None:
1616
+ random_params = flatten_dict(unfreeze(random_params))
1617
+ params = flatten_dict(unfreeze(params))
1618
+ for missing_key in self._missing_keys:
1619
+ params[missing_key] = random_params[missing_key]
1620
+ self._missing_keys = set()
1621
+ return freeze(unflatten_dict(params))
1622
+ else:
1623
+ return random_params
1624
+
1625
+ @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
1626
+ def __call__(
1627
+ self,
1628
+ input_features: jnp.ndarray,
1629
+ attention_mask: Optional[jnp.ndarray] = None,
1630
+ output_attentions: Optional[bool] = None,
1631
+ output_hidden_states: Optional[bool] = None,
1632
+ return_dict: Optional[bool] = None,
1633
+ train: bool = False,
1634
+ params: dict = None,
1635
+ dropout_rng: PRNGKey = None,
1636
+ **kwargs,
1637
+ ):
1638
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1639
+ output_hidden_states = (
1640
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1641
+ )
1642
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1643
+
1644
+ # Handle any PRNG if needed
1645
+ rngs = {}
1646
+ if dropout_rng is not None:
1647
+ rngs["dropout"] = dropout_rng
1648
+
1649
+ return self.module.apply(
1650
+ {"params": params or self.params},
1651
+ input_features=jnp.array(input_features, dtype="f4"),
1652
+ output_attentions=output_attentions,
1653
+ output_hidden_states=output_hidden_states,
1654
+ return_dict=return_dict,
1655
+ rngs=rngs,
1656
+ )
1657
+
1658
+
1659
+ FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING = r"""
1660
+ Returns:
1661
+
1662
+ Transcription example:
1663
+
1664
+ ```python
1665
+ >>> import jax.numpy as jnp
1666
+ >>> from transformers import AutoFeatureExtractor, FlaxWhisperForAudioClassification
1667
+ >>> from datasets import load_dataset
1668
+
1669
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
1670
+ >>> model = FlaxWhisperForAudioClassification.from_pretrained(
1671
+ ... "sanchit-gandhi/whisper-medium-fleurs-lang-id", from_pt=True
1672
+ ... )
1673
+ >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True, trust_remote_code=True)
1674
+
1675
+ >>> sample = next(iter(ds))
1676
+
1677
+ >>> inputs = feature_extractor(
1678
+ ... sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="np"
1679
+ ... )
1680
+ >>> input_features = inputs.input_features
1681
+
1682
+ >>> logits = model(input_features).logits
1683
+
1684
+ >>> predicted_class_ids = jnp.argmax(logits).item()
1685
+ >>> predicted_label = model.config.id2label[predicted_class_ids]
1686
+ >>> predicted_label
1687
+ 'af_za'
1688
+ ```
1689
+ """
1690
+
1691
+ overwrite_call_docstring(
1692
+ FlaxWhisperForAudioClassification, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING
1693
+ )
1694
+ append_replace_return_docstrings(
1695
+ FlaxWhisperForAudioClassification, output_type=FlaxSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC
1696
+ )
modeling_flax_whisper.cpython-312 (1).pyc ADDED
Binary file (75.9 kB).
 
modeling_flax_whisper.cpython-312.pyc ADDED
Binary file (75.9 kB).
 
modeling_flax_whisper.py ADDED
@@ -0,0 +1,1696 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Flax whisper model."""
16
+
17
+ import math
18
+ import random
19
+ from functools import partial
20
+ from typing import Optional, Tuple
21
+
22
+ import flax.linen as nn
23
+ import jax
24
+ import jax.numpy as jnp
25
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
26
+ from flax.linen import combine_masks, make_causal_mask
27
+ from flax.linen import partitioning as nn_partitioning
28
+ from flax.linen.attention import dot_product_attention_weights
29
+ from flax.traverse_util import flatten_dict, unflatten_dict
30
+ from jax import lax
31
+ from jax.random import PRNGKey
32
+
33
+ from ...generation.flax_logits_process import FlaxWhisperTimeStampLogitsProcessor
34
+ from ...modeling_flax_outputs import (
35
+ FlaxBaseModelOutput,
36
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
37
+ FlaxCausalLMOutputWithCrossAttentions,
38
+ FlaxSeq2SeqLMOutput,
39
+ FlaxSeq2SeqModelOutput,
40
+ FlaxSequenceClassifierOutput,
41
+ )
42
+ from ...modeling_flax_utils import (
43
+ ACT2FN,
44
+ FlaxPreTrainedModel,
45
+ append_call_sample_docstring,
46
+ append_replace_return_docstrings,
47
+ overwrite_call_docstring,
48
+ )
49
+ from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
50
+ from .configuration_whisper import WhisperConfig
51
+
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+
56
+ _CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
57
+ _CONFIG_FOR_DOC = "WhisperConfig"
58
+
59
+ remat = nn_partitioning.remat
60
+
61
+
62
+ def sinusoidal_embedding_init(key, shape, dtype=jnp.float_) -> jax.Array:
63
+ """Returns sinusoids for positional embedding"""
64
+ length, channels = shape
65
+ if channels % 2 != 0:
66
+ raise ValueError(
67
+ f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
68
+ )
69
+ log_timescale_increment = math.log(10000) / (channels // 2 - 1)
70
+ inv_timescales = jnp.exp(-log_timescale_increment * jnp.arange(channels // 2))
71
+ scaled_time = jnp.arange(length).reshape(-1, 1) * inv_timescales.reshape(1, -1)
72
+ return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1).astype(dtype)
73
+
74
+
75
+ WHISPER_START_DOCSTRING = r"""
76
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
77
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
78
+ etc.) This model is also a Flax Linen
79
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
80
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
81
+ Finally, this model supports inherent JAX features such as:
82
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
83
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
84
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
85
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
86
+
87
+ Parameters:
88
+ config ([`WhisperConfig`]): Model configuration class with all the parameters of the model.
89
+ Initializing with a config file does not load the weights associated with the model, only the
90
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
91
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
92
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
93
+ `jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision
94
+ inference on GPUs or TPUs. If specified all the computation will be performed with the given `dtype`.
95
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
96
+ parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`]
97
+ and [`~FlaxPreTrainedModel.to_bf16`].
98
+ """
99
+
100
+ WHISPER_INPUTS_DOCSTRING = r"""
101
+ Args:
102
+ input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
103
+ Float values of mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
104
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
105
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
106
+ [`WhisperFeatureExtractor`] should be used for extracting the features, padding and conversion into a
107
+ tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`]
108
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
109
+ Whisper does not support masking of the `input_features`; this argument is preserved for compatibility but
110
+ is not used. By default, the silence in the input log mel spectrogram is ignored.
111
+ decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
112
+ Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
113
+ [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
114
+ [What are decoder input IDs?](../glossary#decoder-input-ids) Whisper uses the `decoder_start_token_id` as
115
+ the starting token for `decoder_input_ids` generation.
116
+ decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
117
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
118
+ be used by default. If you want to change the padding behavior, you should modify it to your needs. See diagram 1
119
+ in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
120
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
121
+ Whisper does not use `position_ids` in the encoder as `input_features` is always the same size and doesn't
122
+ use masking, but this argument is preserved for compatibility. By default the silence in the input log mel
123
+ spectrogram is ignored.
124
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
125
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
126
+ range `[0, config.max_position_embeddings - 1]`.
127
+ output_attentions (`bool`, *optional*):
128
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
129
+ tensors for more detail.
130
+ output_hidden_states (`bool`, *optional*):
131
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
132
+ more detail.
133
+ return_dict (`bool`, *optional*):
134
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
135
+ """
136
+
137
+ WHISPER_ENCODE_INPUTS_DOCSTRING = r"""
138
+ Args:
139
+ input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
140
+ Float values of mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
141
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
142
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
143
+ [`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
144
+ tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`].
145
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
146
+ Whisper does not support masking of the `input_features`; this argument is preserved for compatibility but
147
+ is not used. By default, the silence in the input log mel spectrogram is ignored.
148
+ output_attentions (`bool`, *optional*):
149
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
150
+ tensors for more detail.
151
+ output_hidden_states (`bool`, *optional*):
152
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
153
+ more detail.
154
+ return_dict (`bool`, *optional*):
155
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
156
+ """
157
+
158
+ WHISPER_DECODE_INPUTS_DOCSTRING = r"""
159
+ Args:
160
+ decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`):
161
+ Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
162
+ [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
163
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
164
+ encoder_outputs (`tuple(tuple(numpy.ndarray)`):
165
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
166
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of
167
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
168
+ encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
169
+ Whisper does not support masking of the `input_features`; this argument is preserved for compatibility,
170
+ but it is not used. By default, the silence in the input log mel spectrogram is ignored.
171
+ decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
172
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
173
+ be used by default. If you want to change the padding behavior, you should modify it to your needs. See diagram 1
174
+ in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
175
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
176
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
177
+ range `[0, config.max_position_embeddings - 1]`.
178
+ past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
179
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
180
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
181
+ output_attentions (`bool`, *optional*):
182
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
183
+ tensors for more detail.
184
+ output_hidden_states (`bool`, *optional*):
185
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
186
+ more detail.
187
+ return_dict (`bool`, *optional*):
188
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
189
+ """
190
+
191
+
192
+ class FlaxWhisperAttention(nn.Module):
193
+ config: WhisperConfig
194
+ embed_dim: int
195
+ num_heads: int
196
+ dropout: float = 0.0
197
+ causal: bool = False
198
+ bias: bool = True
199
+ dtype: jnp.dtype = jnp.float32
200
+
201
+ def setup(self) -> None:
202
+ self.head_dim = self.embed_dim // self.num_heads
203
+ if self.head_dim * self.num_heads != self.embed_dim:
204
+ raise ValueError(
205
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
206
+ f" and `num_heads`: {self.num_heads})."
207
+ )
208
+
209
+ dense = partial(
210
+ nn.Dense,
211
+ self.embed_dim,
212
+ dtype=self.dtype,
213
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
214
+ )
215
+
216
+ self.q_proj = dense(use_bias=self.bias)
217
+ self.k_proj = dense(use_bias=False)
218
+ self.v_proj = dense(use_bias=self.bias)
219
+ self.out_proj = dense(use_bias=self.bias)
220
+
221
+ if self.causal:
222
+ self.causal_mask = make_causal_mask(
223
+ jnp.ones((1, self.config.max_target_positions), dtype="bool"), dtype="bool"
224
+ )
225
+
226
+ def __call__(
227
+ self,
228
+ hidden_states: jnp.ndarray,
229
+ key_value_states: Optional[jnp.ndarray] = None,
230
+ attention_mask: Optional[jnp.ndarray] = None,
231
+ init_cache: bool = False,
232
+ deterministic: bool = True,
233
+ ) -> Tuple[jnp.ndarray]:
234
+ is_cross_attention = key_value_states is not None
235
+ batch_size = hidden_states.shape[0]
236
+
237
+ query_states = self.q_proj(hidden_states)
238
+
239
+ if is_cross_attention:
240
+ key_states = self.k_proj(key_value_states)
241
+ value_states = self.v_proj(key_value_states)
242
+ else:
243
+ key_states = self.k_proj(hidden_states)
244
+ value_states = self.v_proj(hidden_states)
245
+
246
+ query_states = self._split_heads(query_states)
247
+ key_states = self._split_heads(key_states)
248
+ value_states = self._split_heads(value_states)
249
+
250
+ if self.causal:
251
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
252
+ if self.has_variable("cache", "cached_key"):
253
+ mask_shift = self.variables["cache"]["cache_index"]
254
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
255
+ causal_mask = lax.dynamic_slice(
256
+ self.causal_mask,
257
+ (0, 0, mask_shift, 0),
258
+ (1, 1, query_length, max_decoder_length),
259
+ )
260
+ else:
261
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
262
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
263
+
264
+ # combine masks if needed
265
+ if attention_mask is not None and self.causal:
266
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
267
+ attention_mask = combine_masks(attention_mask, causal_mask)
268
+ elif self.causal:
269
+ attention_mask = causal_mask
270
+ elif attention_mask is not None:
271
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
272
+
273
+ # During fast autoregressive decoding, we feed one position at a time,
274
+ # and cache the keys and values step by step.
275
+
276
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
277
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
278
+ key_states, value_states, query_states, attention_mask
279
+ )
280
+
281
+ # Convert the boolean attention mask to an attention bias.
282
+ if attention_mask is not None:
283
+ # attention mask in the form of attention bias
284
+ attention_bias = lax.select(
285
+ attention_mask > 0,
286
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
287
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
288
+ )
289
+ else:
290
+ attention_bias = None
291
+
292
+ dropout_rng = None
293
+ if not deterministic and self.dropout > 0.0:
294
+ dropout_rng = self.make_rng("dropout")
295
+
296
+ attn_weights = dot_product_attention_weights(
297
+ query_states,
298
+ key_states,
299
+ bias=attention_bias,
300
+ dropout_rng=dropout_rng,
301
+ dropout_rate=self.dropout,
302
+ broadcast_dropout=True,
303
+ deterministic=deterministic,
304
+ dtype=self.dtype,
305
+ precision=None,
306
+ )
307
+
308
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
309
+ attn_output = self._merge_heads(attn_output)
310
+ attn_output = self.out_proj(attn_output)
311
+
312
+ return attn_output, attn_weights
313
+
314
+ def _split_heads(self, hidden_state) -> jnp.ndarray:
315
+ return hidden_state.reshape(hidden_state.shape[:2] + (self.num_heads, self.head_dim))
316
+
317
+ def _merge_heads(self, hidden_state) -> jnp.ndarray:
318
+ return hidden_state.reshape(hidden_state.shape[:2] + (self.embed_dim,))
319
+
320
+ @nn.compact
321
+ def _concatenate_to_cache(self, key, value, query, attention_mask) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
322
+ # detect if we're initializing by absence of existing cache data.
323
+ is_initialized = self.has_variable("cache", "cached_key")
324
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
325
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
326
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
327
+
328
+ if is_initialized:
329
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
330
+ # update key, value caches with our new 1d spatial slices
331
+ cur_index = cache_index.value
332
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
333
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
334
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
335
+ cached_key.value = key
336
+ cached_value.value = value
337
+ num_updated_cache_vectors = query.shape[1]
338
+ cache_index.value = cache_index.value + num_updated_cache_vectors
339
+ # causal mask for cached decoder self-attention: our single query position should only
340
+ # attend to those key positions that have already been generated and cached, not the
341
+ # remaining zero elements.
342
+ pad_mask = jnp.broadcast_to(
343
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
344
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
345
+ )
346
+ attention_mask = combine_masks(pad_mask, attention_mask)
347
+
348
+ return key, value, attention_mask
349
+
350
+
351
+ # Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer with MBart->Whisper
352
+ class FlaxWhisperEncoderLayer(nn.Module):
353
+ config: WhisperConfig
354
+ dtype: jnp.dtype = jnp.float32
355
+
356
+ def setup(self) -> None:
357
+ self.embed_dim = self.config.d_model
358
+ self.self_attn = FlaxWhisperAttention(
359
+ config=self.config,
360
+ embed_dim=self.embed_dim,
361
+ num_heads=self.config.encoder_attention_heads,
362
+ dropout=self.config.attention_dropout,
363
+ dtype=self.dtype,
364
+ )
365
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
366
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
367
+ self.activation_fn = ACT2FN[self.config.activation_function]
368
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
369
+ self.fc1 = nn.Dense(
370
+ self.config.encoder_ffn_dim,
371
+ dtype=self.dtype,
372
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
373
+ )
374
+ self.fc2 = nn.Dense(
375
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
376
+ )
377
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
378
+
379
+ def __call__(
380
+ self,
381
+ hidden_states: jnp.ndarray,
382
+ attention_mask: jnp.ndarray,
383
+ output_attentions: bool = True,
384
+ deterministic: bool = True,
385
+ ) -> Tuple[jnp.ndarray]:
386
+ residual = hidden_states
387
+ hidden_states = self.self_attn_layer_norm(hidden_states)
388
+ hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
389
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
390
+ hidden_states = residual + hidden_states
391
+
392
+ residual = hidden_states
393
+ hidden_states = self.final_layer_norm(hidden_states)
394
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
395
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
396
+ hidden_states = self.fc2(hidden_states)
397
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
398
+ hidden_states = residual + hidden_states
399
+
400
+ outputs = (hidden_states,)
401
+
402
+ if output_attentions:
403
+ outputs += (attn_weights,)
404
+
405
+ return outputs
406
+
407
+
408
+ class FlaxWhisperEncoderLayerCollection(nn.Module):
409
+ config: WhisperConfig
410
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
411
+ gradient_checkpointing: bool = False
412
+
413
+ def setup(self):
414
+ if self.gradient_checkpointing:
415
+ FlaxWhisperEncoderCheckpointLayer = remat(FlaxWhisperEncoderLayer, static_argnums=(2, 3))
416
+ self.layers = [
417
+ FlaxWhisperEncoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
418
+ for i in range(self.config.encoder_layers)
419
+ ]
420
+ else:
421
+ self.layers = [
422
+ FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
423
+ for i in range(self.config.encoder_layers)
424
+ ]
425
+ self.layerdrop = self.config.encoder_layerdrop
426
+
427
+ def __call__(
428
+ self,
429
+ hidden_states,
430
+ attention_mask,
431
+ deterministic: bool = True,
432
+ output_attentions: bool = False,
433
+ output_hidden_states: bool = False,
434
+ return_dict: bool = True,
435
+ ):
436
+ all_attentions = () if output_attentions else None
437
+ all_hidden_states = () if output_hidden_states else None
438
+
439
+ for encoder_layer in self.layers:
440
+ if output_hidden_states:
441
+ all_hidden_states = all_hidden_states + (hidden_states,)
442
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
443
+ dropout_probability = random.uniform(0, 1)
444
+ if not deterministic and (dropout_probability < self.layerdrop): # skip the layer
445
+ layer_outputs = (None, None)
446
+ else:
447
+ layer_outputs = encoder_layer(
448
+ hidden_states,
449
+ attention_mask,
450
+ output_attentions,
451
+ deterministic,
452
+ )
453
+ hidden_states = layer_outputs[0]
454
+ if output_attentions:
455
+ all_attentions = all_attentions + (layer_outputs[1],)
456
+
457
+ if output_hidden_states:
458
+ all_hidden_states += (hidden_states,)
459
+
460
+ outputs = (hidden_states, all_hidden_states, all_attentions)
461
+
462
+ if not return_dict:
463
+ return tuple(v for v in outputs if v is not None)
464
+
465
+ return FlaxBaseModelOutput(
466
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
467
+ )
468
+
469
+
470
+ # Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer with MBart->Whisper
471
+ class FlaxWhisperDecoderLayer(nn.Module):
472
+ config: WhisperConfig
473
+ dtype: jnp.dtype = jnp.float32
474
+
475
+ def setup(self) -> None:
476
+ self.embed_dim = self.config.d_model
477
+ self.self_attn = FlaxWhisperAttention(
478
+ config=self.config,
479
+ embed_dim=self.embed_dim,
480
+ num_heads=self.config.decoder_attention_heads,
481
+ dropout=self.config.attention_dropout,
482
+ causal=True,
483
+ dtype=self.dtype,
484
+ )
485
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
486
+ self.activation_fn = ACT2FN[self.config.activation_function]
487
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
488
+
489
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
490
+ self.encoder_attn = FlaxWhisperAttention(
491
+ config=self.config,
492
+ embed_dim=self.embed_dim,
493
+ num_heads=self.config.decoder_attention_heads,
494
+ dropout=self.config.attention_dropout,
495
+ dtype=self.dtype,
496
+ )
497
+ self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
498
+ self.fc1 = nn.Dense(
499
+ self.config.decoder_ffn_dim,
500
+ dtype=self.dtype,
501
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
502
+ )
503
+ self.fc2 = nn.Dense(
504
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
505
+ )
506
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
507
+
508
+ def __call__(
509
+ self,
510
+ hidden_states: jnp.ndarray,
511
+ attention_mask: jnp.ndarray,
512
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
513
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
514
+ init_cache: bool = False,
515
+ output_attentions: bool = True,
516
+ deterministic: bool = True,
517
+ ) -> Tuple[jnp.ndarray]:
518
+ residual = hidden_states
519
+ hidden_states = self.self_attn_layer_norm(hidden_states)
520
+
521
+ # Self Attention
522
+ hidden_states, self_attn_weights = self.self_attn(
523
+ hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
524
+ )
525
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
526
+ hidden_states = residual + hidden_states
527
+
528
+ # Cross-Attention Block
529
+ cross_attn_weights = None
530
+ if encoder_hidden_states is not None:
531
+ residual = hidden_states
532
+
533
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
534
+ hidden_states, cross_attn_weights = self.encoder_attn(
535
+ hidden_states=hidden_states,
536
+ key_value_states=encoder_hidden_states,
537
+ attention_mask=encoder_attention_mask,
538
+ )
539
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
540
+ hidden_states = residual + hidden_states
541
+
542
+ # Fully Connected
543
+ residual = hidden_states
544
+ hidden_states = self.final_layer_norm(hidden_states)
545
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
546
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
547
+ hidden_states = self.fc2(hidden_states)
548
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
549
+ hidden_states = residual + hidden_states
550
+
551
+ outputs = (hidden_states,)
552
+
553
+ if output_attentions:
554
+ outputs += (self_attn_weights, cross_attn_weights)
555
+
556
+ return outputs
557
+
558
+
559
+ class FlaxWhisperDecoderLayerCollection(nn.Module):
560
+ config: WhisperConfig
561
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
562
+ gradient_checkpointing: bool = False
563
+
564
+ def setup(self):
565
+ if self.gradient_checkpointing:
566
+ FlaxWhisperDecoderCheckpointLayer = remat(FlaxWhisperDecoderLayer, static_argnums=(4, 5, 6))
567
+ self.layers = [
568
+ FlaxWhisperDecoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
569
+ for i in range(self.config.decoder_layers)
570
+ ]
571
+ else:
572
+ self.layers = [
573
+ FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
574
+ for i in range(self.config.decoder_layers)
575
+ ]
576
+ self.layerdrop = self.config.decoder_layerdrop
577
+
578
+ def __call__(
579
+ self,
580
+ hidden_states,
581
+ attention_mask,
582
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
583
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
584
+ deterministic: bool = True,
585
+ init_cache: bool = False,
586
+ output_attentions: bool = False,
587
+ output_hidden_states: bool = False,
588
+ return_dict: bool = True,
589
+ ):
590
+ # decoder layers
591
+ all_hidden_states = () if output_hidden_states else None
592
+ all_self_attns = () if output_attentions else None
593
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
594
+
595
+ for decoder_layer in self.layers:
596
+ if output_hidden_states:
597
+ all_hidden_states += (hidden_states,)
598
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
599
+ dropout_probability = random.uniform(0, 1)
600
+ if not deterministic and (dropout_probability < self.layerdrop):
601
+ layer_outputs = (None, None, None)
602
+ else:
603
+ layer_outputs = decoder_layer(
604
+ hidden_states,
605
+ attention_mask,
606
+ encoder_hidden_states,
607
+ encoder_attention_mask,
608
+ init_cache,
609
+ output_attentions,
610
+ deterministic,
611
+ )
612
+
613
+ hidden_states = layer_outputs[0]
614
+ if output_attentions:
615
+ all_self_attns += (layer_outputs[1],)
616
+
617
+ if encoder_hidden_states is not None:
618
+ all_cross_attentions += (layer_outputs[2],)
619
+
620
+ # add hidden states from the last decoder layer
621
+ if output_hidden_states:
622
+ all_hidden_states += (hidden_states,)
623
+
624
+ outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
625
+
626
+ if not return_dict:
627
+ return tuple(v for v in outputs if v is not None)
628
+
629
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
630
+ last_hidden_state=hidden_states,
631
+ hidden_states=all_hidden_states,
632
+ attentions=all_self_attns,
633
+ cross_attentions=all_cross_attentions,
634
+ )
635
+
636
+
637
+ class FlaxWhisperEncoder(nn.Module):
638
+ config: WhisperConfig
639
+ dtype: jnp.dtype = jnp.float32
640
+ gradient_checkpointing: bool = False
641
+
642
+ def setup(self) -> None:
643
+ self.conv1 = nn.Conv(
644
+ self.config.d_model,
645
+ kernel_size=(3,),
646
+ padding=1,
647
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
648
+ dtype=self.dtype,
649
+ )
650
+ self.conv2 = nn.Conv(
651
+ self.config.d_model,
652
+ kernel_size=(3,),
653
+ strides=2,
654
+ padding=1,
655
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
656
+ dtype=self.dtype,
657
+ )
658
+
659
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
660
+
661
+ self.layers = FlaxWhisperEncoderLayerCollection(
662
+ self.config,
663
+ dtype=self.dtype,
664
+ gradient_checkpointing=self.gradient_checkpointing,
665
+ )
666
+
667
+ self.embed_positions = nn.Embed(
668
+ self.config.max_source_positions,
669
+ self.config.d_model,
670
+ dtype=self.dtype,
671
+ embedding_init=sinusoidal_embedding_init,
672
+ )
673
+
674
+ self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
675
+
676
+ def __call__(
677
+ self,
678
+ input_features: jnp.ndarray,
679
+ output_attentions: bool = False,
680
+ output_hidden_states: bool = False,
681
+ return_dict: bool = True,
682
+ deterministic: bool = True,
683
+ ) -> Tuple[jnp.ndarray]:
684
+ if input_features.shape[1:] != (self.config.num_mel_bins, self.config.max_source_positions * 2):
685
+ raise ValueError(
686
+ "input_features.shape[1:], must be equal to (self.config.num_mel_bins,"
687
+ f" self.config.max_source_positions * 2) (got {input_features.shape[1:]}, but should be"
688
+ f" ({self.config.num_mel_bins}, {self.config.max_source_positions * 2}))"
689
+ )
690
+
691
+ input_features = input_features.transpose(0, 2, 1)
692
+ hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False)
693
+ hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)
694
+
695
+ embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions))
696
+ # freeze the sinusoidal embeddings by stopping the back-prop
697
+ embed_positions = jax.lax.stop_gradient(embed_positions)
698
+ hidden_states = hidden_states + embed_positions
699
+
700
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
701
+
702
+ outputs = self.layers(
703
+ hidden_states,
704
+ attention_mask=None,
705
+ deterministic=deterministic,
706
+ output_attentions=output_attentions,
707
+ output_hidden_states=output_hidden_states,
708
+ return_dict=return_dict,
709
+ )
710
+
711
+ last_hidden_states = outputs[0]
712
+ last_hidden_states = self.layer_norm(last_hidden_states)
713
+
714
+ # update the last element in `hidden_states` after applying `layernorm` above
715
+ hidden_states = None
716
+ if output_hidden_states:
717
+ hidden_states = outputs[1]
718
+ hidden_states = hidden_states[:-1] + (last_hidden_states,)
719
+
720
+ if not return_dict:
721
+ outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
722
+ return tuple(v for v in outputs if v is not None)
723
+
724
+ return FlaxBaseModelOutput(
725
+ last_hidden_state=last_hidden_states,
726
+ hidden_states=hidden_states,
727
+ attentions=outputs.attentions,
728
+ )
729
+
730
+
731
+ class FlaxWhisperDecoder(nn.Module):
732
+ config: WhisperConfig
733
+ dtype: jnp.dtype = jnp.float32
734
+ gradient_checkpointing: bool = False
735
+
736
+ def setup(self) -> None:
737
+ self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.d_model, dtype=self.dtype)
738
+ self.embed_positions = nn.Embed(self.config.max_target_positions, self.config.d_model, dtype=self.dtype)
739
+
740
+ self.layers = FlaxWhisperDecoderLayerCollection(
741
+ self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
742
+ )
743
+
744
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
745
+
746
+ self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-5)
747
+
748
+ def __call__(
749
+ self,
750
+ input_ids: jnp.ndarray,
751
+ attention_mask: jnp.ndarray,
752
+ position_ids: jnp.ndarray,
753
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
754
+ init_cache: bool = False,
755
+ output_attentions: bool = False,
756
+ output_hidden_states: bool = False,
757
+ return_dict: bool = True,
758
+ deterministic: bool = True,
759
+ ) -> Tuple[jnp.ndarray]:
760
+ input_embeds = self.embed_tokens(input_ids)
761
+ position_embeds = self.embed_positions(position_ids)
762
+
763
+ hidden_states = input_embeds + position_embeds
764
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
765
+
766
+ outputs = self.layers(
767
+ hidden_states,
768
+ attention_mask=attention_mask,
769
+ encoder_hidden_states=encoder_hidden_states,
770
+ deterministic=deterministic,
771
+ init_cache=init_cache,
772
+ output_attentions=output_attentions,
773
+ output_hidden_states=output_hidden_states,
774
+ return_dict=return_dict,
775
+ )
776
+
777
+ last_hidden_states = outputs[0]
778
+ last_hidden_states = self.layer_norm(last_hidden_states)
779
+
780
+ # update the last element in `hidden_states` after applying `layernorm` above
781
+ hidden_states = None
782
+ if output_hidden_states:
783
+ hidden_states = outputs[1]
784
+ hidden_states = hidden_states[:-1] + (last_hidden_states,)
785
+
786
+ if not return_dict:
787
+ outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
788
+ return tuple(v for v in outputs if v is not None)
789
+
790
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
791
+ last_hidden_state=last_hidden_states,
792
+ hidden_states=hidden_states,
793
+ attentions=outputs.attentions,
794
+ cross_attentions=outputs.cross_attentions,
795
+ )
796
+
797
+
798
+ class FlaxWhisperModule(nn.Module):
799
+ config: WhisperConfig
800
+ dtype: jnp.dtype = jnp.float32
801
+ gradient_checkpointing: bool = False
802
+
803
+ def setup(self) -> None:
804
+ self.encoder = FlaxWhisperEncoder(
805
+ self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
806
+ )
807
+ self.decoder = FlaxWhisperDecoder(
808
+ self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
809
+ )
810
+
811
+ def __call__(
812
+ self,
813
+ input_features: jnp.ndarray,
814
+ decoder_input_ids: jnp.ndarray,
815
+ decoder_attention_mask: jnp.ndarray,
816
+ decoder_position_ids: jnp.ndarray,
817
+ output_attentions: bool = False,
818
+ output_hidden_states: bool = False,
819
+ return_dict: bool = True,
820
+ deterministic: bool = True,
821
+ ):
822
+ encoder_outputs = self.encoder(
823
+ input_features,
824
+ output_attentions=output_attentions,
825
+ output_hidden_states=output_hidden_states,
826
+ return_dict=return_dict,
827
+ deterministic=deterministic,
828
+ )
829
+
830
+ decoder_outputs = self.decoder(
831
+ input_ids=decoder_input_ids,
832
+ attention_mask=decoder_attention_mask,
833
+ position_ids=decoder_position_ids,
834
+ encoder_hidden_states=encoder_outputs[0],
835
+ output_attentions=output_attentions,
836
+ output_hidden_states=output_hidden_states,
837
+ return_dict=return_dict,
838
+ deterministic=deterministic,
839
+ )
840
+
841
+ if not return_dict:
842
+ return decoder_outputs + encoder_outputs
843
+
844
+ return FlaxSeq2SeqModelOutput(
845
+ last_hidden_state=decoder_outputs.last_hidden_state,
846
+ decoder_hidden_states=decoder_outputs.hidden_states,
847
+ decoder_attentions=decoder_outputs.attentions,
848
+ cross_attentions=decoder_outputs.cross_attentions,
849
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
850
+ encoder_hidden_states=encoder_outputs.hidden_states,
851
+ encoder_attentions=encoder_outputs.attentions,
852
+ )
853
+
854
+ def _get_encoder_module(self):
855
+ return self.encoder
856
+
857
+ def _get_decoder_module(self):
858
+ return self.decoder
859
+
860
+
861
+ class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
862
+ config_class = WhisperConfig
863
+ base_model_prefix: str = "model"
864
+ main_input_name = "input_features"
865
+ module_class: nn.Module = None
866
+
867
+ def __init__(
868
+ self,
869
+ config: WhisperConfig,
870
+ input_shape: Tuple[int] = None,
871
+ seed: int = 0,
872
+ dtype: jnp.dtype = jnp.float32,
873
+ _do_init: bool = True,
874
+ gradient_checkpointing: bool = False,
875
+ **kwargs,
876
+ ):
877
+ module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
878
+ if input_shape is None:
879
+ input_shape = (1, config.num_mel_bins, 2 * config.max_source_positions)
880
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
881
+
882
+ def enable_gradient_checkpointing(self):
883
+ self._module = self.module_class(
884
+ config=self.config,
885
+ dtype=self.dtype,
886
+ gradient_checkpointing=True,
887
+ )
888
+
889
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
890
+ # init input tensors
891
+ input_features = jnp.zeros(input_shape, dtype="f4")
892
+ input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)
893
+
894
+ decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
895
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
896
+
897
+ batch_size, sequence_length = decoder_input_ids.shape
898
+ decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
899
+
900
+ params_rng, dropout_rng = jax.random.split(rng)
901
+ rngs = {"params": params_rng, "dropout": dropout_rng}
902
+
903
+ random_params = self.module.init(
904
+ rngs,
905
+ input_features=input_features,
906
+ decoder_input_ids=decoder_input_ids,
907
+ decoder_attention_mask=decoder_attention_mask,
908
+ decoder_position_ids=decoder_position_ids,
909
+ )["params"]
910
+
911
+ if params is not None:
912
+ random_params = flatten_dict(unfreeze(random_params))
913
+ params = flatten_dict(unfreeze(params))
914
+ for missing_key in self._missing_keys:
915
+ params[missing_key] = random_params[missing_key]
916
+ self._missing_keys = set()
917
+ return freeze(unflatten_dict(params))
918
+ else:
919
+ return random_params
920
+
921
+ # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel.init_cache with Bart->Whisper
922
+ def init_cache(self, batch_size, max_length, encoder_outputs):
923
+ r"""
924
+ Args:
925
+ batch_size (`int`):
926
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
927
+ max_length (`int`):
928
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
929
+ cache.
930
+ encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
931
+ `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
932
+ `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*,
933
+ is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
934
+ cross-attention of the decoder.
935
+ """
936
+ # init input variables to retrieve cache
937
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
938
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
939
+ decoder_position_ids = jnp.broadcast_to(
940
+ jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
941
+ )
942
+
943
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
944
+ decoder_module = module._get_decoder_module()
945
+ return decoder_module(
946
+ decoder_input_ids,
947
+ decoder_attention_mask,
948
+ decoder_position_ids,
949
+ **kwargs,
950
+ )
951
+
952
+ init_variables = self.module.init(
953
+ jax.random.PRNGKey(0),
954
+ decoder_input_ids=decoder_input_ids,
955
+ decoder_attention_mask=decoder_attention_mask,
956
+ decoder_position_ids=decoder_position_ids,
957
+ encoder_hidden_states=encoder_outputs[0],
958
+ init_cache=True,
959
+ method=_decoder_forward, # we only need to call the decoder to init the cache
960
+ )
961
+ return unfreeze(init_variables["cache"])
962
+
963
+ @add_start_docstrings(WHISPER_ENCODE_INPUTS_DOCSTRING)
964
+ @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=WhisperConfig)
965
+ def encode(
966
+ self,
967
+ input_features: jnp.ndarray,
968
+ attention_mask: Optional[jnp.ndarray] = None,
969
+ output_attentions: Optional[bool] = None,
970
+ output_hidden_states: Optional[bool] = None,
971
+ return_dict: Optional[bool] = None,
972
+ train: bool = False,
973
+ params: dict = None,
974
+ dropout_rng: PRNGKey = None,
975
+ **kwargs,
976
+ ):
977
+ r"""
978
+ Returns:
979
+
980
+ Example:
981
+
982
+ ```python
983
+ >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
984
+ >>> from datasets import load_dataset
985
+
986
+ >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
987
+ >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
988
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
989
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
990
+ >>> input_features = inputs.input_features
991
+ >>> encoder_outputs = model.encode(input_features=input_features)
992
+ ```"""
993
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
994
+ output_hidden_states = (
995
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
996
+ )
997
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
998
+
999
+ # Handle any PRNG if needed
1000
+ rngs = {}
1001
+ if dropout_rng is not None:
1002
+ rngs["dropout"] = dropout_rng
1003
+
1004
+ def _encoder_forward(module, input_features, **kwargs):
1005
+ encode_module = module._get_encoder_module()
1006
+ return encode_module(input_features, **kwargs)
1007
+
1008
+ return self.module.apply(
1009
+ {"params": params or self.params},
1010
+ input_features=jnp.array(input_features, dtype="f4"),
1011
+ output_attentions=output_attentions,
1012
+ output_hidden_states=output_hidden_states,
1013
+ return_dict=return_dict,
1014
+ deterministic=not train,
1015
+ rngs=rngs,
1016
+ method=_encoder_forward,
1017
+ )
1018
+
1019
+ @add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
1020
+ @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=WhisperConfig)
1021
+ def decode(
1022
+ self,
1023
+ decoder_input_ids,
1024
+ encoder_outputs,
1025
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1026
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1027
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1028
+ past_key_values: dict = None,
1029
+ output_attentions: Optional[bool] = None,
1030
+ output_hidden_states: Optional[bool] = None,
1031
+ return_dict: Optional[bool] = None,
1032
+ train: bool = False,
1033
+ params: dict = None,
1034
+ dropout_rng: PRNGKey = None,
1035
+ ):
1036
+ r"""
1037
+ Returns:
1038
+
1039
+ Example:
1040
+
1041
+ ```python
1042
+ >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
1043
+ >>> from datasets import load_dataset
1044
+ >>> import jax.numpy as jnp
1045
+
1046
+ >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
1047
+ >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
1048
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1049
+ >>> input_features = processor(ds[0]["audio"]["array"], return_tensors="np").input_features
1050
+
1051
+ >>> encoder_outputs = model.encode(input_features=input_features)
1052
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
1053
+
1054
+ >>> decoder_input_ids = jnp.ones((input_features.shape[0], 1), dtype="i4") * decoder_start_token_id
1055
+
1056
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
1057
+ >>> last_decoder_hidden_states = outputs.last_hidden_state
1058
+ ```"""
1059
+
1060
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1061
+ output_hidden_states = (
1062
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1063
+ )
1064
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1065
+
1066
+ encoder_hidden_states = encoder_outputs[0]
1067
+
1068
+ batch_size, sequence_length = decoder_input_ids.shape
1069
+ if decoder_position_ids is None:
1070
+ if past_key_values is not None:
1071
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
1072
+
1073
+ if decoder_attention_mask is not None:
1074
+ decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
1075
+ else:
1076
+ decoder_position_ids = jnp.broadcast_to(
1077
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1078
+ )
1079
+
1080
+ if decoder_attention_mask is None:
1081
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
1082
+
1083
+ # Handle any PRNG if needed
1084
+ rngs = {}
1085
+ if dropout_rng is not None:
1086
+ rngs["dropout"] = dropout_rng
1087
+
1088
+ inputs = {"params": params or self.params}
1089
+
1090
+ # If past_key_values are passed, the cache is already initialized, and a private flag init_cache has to be
1091
+ # passed down to ensure the cache is used. The cache also has to be marked as mutable so that
1092
+ # it can be changed by the FlaxWhisperAttention module.
1093
+ if past_key_values:
1094
+ inputs["cache"] = past_key_values
1095
+ mutable = ["cache"]
1096
+ else:
1097
+ mutable = False
1098
+
1099
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
1100
+ decoder_module = module._get_decoder_module()
1101
+ return decoder_module(
1102
+ input_ids=decoder_input_ids,
1103
+ attention_mask=decoder_attention_mask,
1104
+ position_ids=decoder_position_ids,
1105
+ **kwargs,
1106
+ )
1107
+
1108
+ outputs = self.module.apply(
1109
+ inputs,
1110
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1111
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1112
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1113
+ encoder_hidden_states=encoder_hidden_states,
1114
+ output_attentions=output_attentions,
1115
+ output_hidden_states=output_hidden_states,
1116
+ return_dict=return_dict,
1117
+ deterministic=not train,
1118
+ rngs=rngs,
1119
+ mutable=mutable,
1120
+ method=_decoder_forward,
1121
+ )
1122
+
1123
+ # add updated cache to model output
1124
+ if past_key_values is not None and return_dict:
1125
+ outputs, past = outputs
1126
+ outputs["past_key_values"] = unfreeze(past["cache"])
1127
+ return outputs
1128
+ elif past_key_values is not None and not return_dict:
1129
+ outputs, past = outputs
1130
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
1131
+
1132
+ return outputs
1133
+
1134
+ @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
1135
+ def __call__(
1136
+ self,
1137
+ input_features: jnp.ndarray,
1138
+ decoder_input_ids: jnp.ndarray,
1139
+ attention_mask: Optional[jnp.ndarray] = None,
1140
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1141
+ position_ids: Optional[jnp.ndarray] = None,
1142
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1143
+ output_attentions: Optional[bool] = None,
1144
+ output_hidden_states: Optional[bool] = None,
1145
+ return_dict: Optional[bool] = None,
1146
+ train: bool = False,
1147
+ params: dict = None,
1148
+ dropout_rng: PRNGKey = None,
1149
+ ):
1150
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1151
+ output_hidden_states = (
1152
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1153
+ )
1154
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1155
+
1156
+ # prepare decoder inputs
1157
+ if decoder_position_ids is None:
1158
+ if decoder_attention_mask is not None:
1159
+ decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
1160
+ else:
1161
+ batch_size, sequence_length = decoder_input_ids.shape
1162
+ decoder_position_ids = jnp.broadcast_to(
1163
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1164
+ )
1165
+ if decoder_attention_mask is None:
1166
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
1167
+
1168
+ # Handle any PRNG if needed
1169
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
1170
+
1171
+ return self.module.apply(
1172
+ {"params": params or self.params},
1173
+ input_features=jnp.array(input_features, dtype="f4"),
1174
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1175
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1176
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1177
+ output_attentions=output_attentions,
1178
+ output_hidden_states=output_hidden_states,
1179
+ return_dict=return_dict,
1180
+ deterministic=not train,
1181
+ rngs=rngs,
1182
+ )
1183
+
1184
+
1185
+ @add_start_docstrings(
1186
+ "The bare Whisper Model transformer outputting raw hidden-states without any specific head on top.",
1187
+ WHISPER_START_DOCSTRING,
1188
+ )
1189
+ class FlaxWhisperModel(FlaxWhisperPreTrainedModel):
1190
+ config: WhisperConfig
1191
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
1192
+ module_class = FlaxWhisperModule
1193
+
1194
+
1195
+ append_call_sample_docstring(FlaxWhisperModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
1196
+
1197
+
1198
+ class FlaxWhisperForConditionalGenerationModule(nn.Module):
1199
+ config: WhisperConfig
1200
+ dtype: jnp.dtype = jnp.float32
1201
+ gradient_checkpointing: bool = False
1202
+
1203
+ def setup(self) -> None:
1204
+ self.model = FlaxWhisperModule(
1205
+ config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
1206
+ )
1207
+ self.lm_head = nn.Dense(
1208
+ self.config.vocab_size,
1209
+ use_bias=False,
1210
+ dtype=self.dtype,
1211
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
1212
+ )
1213
+
1214
+ def _get_encoder_module(self):
1215
+ return self.model.encoder
1216
+
1217
+ def _get_decoder_module(self):
1218
+ return self.model.decoder
1219
+
1220
+ def __call__(
1221
+ self,
1222
+ input_features,
1223
+ decoder_input_ids,
1224
+ decoder_attention_mask: jnp.ndarray = None,
1225
+ decoder_position_ids: jnp.ndarray = None,
1226
+ position_ids: jnp.ndarray = None,
1227
+ attention_mask: jnp.ndarray = None,
1228
+ output_attentions: bool = False,
1229
+ output_hidden_states: bool = False,
1230
+ return_dict: bool = True,
1231
+ deterministic: bool = True,
1232
+ ):
1233
+ outputs = self.model(
1234
+ input_features=input_features,
1235
+ decoder_input_ids=decoder_input_ids,
1236
+ decoder_attention_mask=decoder_attention_mask,
1237
+ decoder_position_ids=decoder_position_ids,
1238
+ output_attentions=output_attentions,
1239
+ output_hidden_states=output_hidden_states,
1240
+ return_dict=return_dict,
1241
+ deterministic=deterministic,
1242
+ )
1243
+
1244
+ hidden_states = outputs[0]
1245
+
1246
+ if self.config.tie_word_embeddings:
1247
+ shared_embedding = self.model.decoder.embed_tokens.variables["params"]["embedding"]
1248
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
1249
+ else:
1250
+ lm_logits = self.lm_head(hidden_states)
1251
+
1252
+ if not return_dict:
1253
+ output = (lm_logits,) + outputs[1:]
1254
+ return output
1255
+
1256
+ return FlaxSeq2SeqLMOutput(
1257
+ logits=lm_logits,
1258
+ decoder_hidden_states=outputs.decoder_hidden_states,
1259
+ decoder_attentions=outputs.decoder_attentions,
1260
+ cross_attentions=outputs.cross_attentions,
1261
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1262
+ encoder_hidden_states=outputs.encoder_hidden_states,
1263
+ encoder_attentions=outputs.encoder_attentions,
1264
+ )
1265
+
1266
+
1267
+ @add_start_docstrings("The Whisper Model with a language modeling head.", WHISPER_START_DOCSTRING)
1268
+ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
1269
+ module_class = FlaxWhisperForConditionalGenerationModule
1270
+ dtype: jnp.dtype = jnp.float32
1271
+
1272
+ @add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
1273
+ @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=WhisperConfig)
1274
+ def decode(
1275
+ self,
1276
+ decoder_input_ids,
1277
+ encoder_outputs,
1278
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1279
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1280
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1281
+ past_key_values: dict = None,
1282
+ output_attentions: Optional[bool] = None,
1283
+ output_hidden_states: Optional[bool] = None,
1284
+ return_dict: Optional[bool] = None,
1285
+ train: bool = False,
1286
+ params: dict = None,
1287
+ dropout_rng: PRNGKey = None,
1288
+ ):
1289
+ r"""
1290
+ Returns:
1291
+
1292
+ Example:
1293
+
1294
+ ```python
1295
+ >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
1296
+ >>> from datasets import load_dataset
1297
+
1298
+ >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
1299
+ >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
1300
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1301
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
1302
+ >>> input_features = inputs.input_features
1303
+ >>> encoder_outputs = model.encode(input_features=input_features)
1304
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
1305
+
1306
+ >>> decoder_input_ids = jnp.ones((input_features.shape[0], 1), dtype="i4") * decoder_start_token_id
1307
+
1308
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
1309
+ >>> last_decoder_hidden_states = outputs.last_hidden_state
1310
+ ```"""
1311
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1312
+ output_hidden_states = (
1313
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1314
+ )
1315
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1316
+
1317
+ encoder_hidden_states = encoder_outputs[0]
1318
+
1319
+ batch_size, sequence_length = decoder_input_ids.shape
1320
+ if decoder_position_ids is None:
1321
+ if past_key_values is not None:
1322
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
1323
+
1324
+ if decoder_attention_mask is not None:
1325
+ decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
1326
+ else:
1327
+ decoder_position_ids = jnp.broadcast_to(
1328
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1329
+ )
1330
+ if decoder_attention_mask is None:
1331
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length), dtype="i4")
1332
+
1333
+ # Handle any PRNG if needed
1334
+ rngs = {}
1335
+ if dropout_rng is not None:
1336
+ rngs["dropout"] = dropout_rng
1337
+
1338
+ inputs = {"params": params or self.params}
1339
+
1340
+ # If `past_key_values` are passed, the cache is already initialized and the private flag `init_cache` has to be
+ # passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can be
+ # updated by the FlaxWhisperAttention module.
1343
+ if past_key_values:
1344
+ inputs["cache"] = past_key_values
1345
+ mutable = ["cache"]
1346
+ else:
1347
+ mutable = False
1348
+
1349
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
1350
+ decoder_module = module._get_decoder_module()
1351
+ outputs = decoder_module(
1352
+ input_ids=decoder_input_ids,
1353
+ attention_mask=decoder_attention_mask,
1354
+ position_ids=decoder_position_ids,
1355
+ **kwargs,
1356
+ )
1357
+ hidden_states = outputs[0]
1358
+
1359
+ if self.config.tie_word_embeddings:
1360
+ shared_embedding = module.model.decoder.embed_tokens.variables["params"]["embedding"]
1361
+ lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
1362
+ else:
1363
+ lm_logits = module.lm_head(hidden_states)
1364
+
1365
+ return lm_logits, outputs
1366
+
1367
+ outputs = self.module.apply(
1368
+ inputs,
1369
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1370
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1371
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1372
+ encoder_hidden_states=encoder_hidden_states,
1373
+ output_attentions=output_attentions,
1374
+ output_hidden_states=output_hidden_states,
1375
+ return_dict=return_dict,
1376
+ deterministic=not train,
1377
+ rngs=rngs,
1378
+ mutable=mutable,
1379
+ method=_decoder_forward,
1380
+ )
1381
+
1382
+ if past_key_values is None:
1383
+ lm_logits, decoder_outputs = outputs
1384
+ else:
1385
+ (lm_logits, decoder_outputs), past = outputs
1386
+
1387
+ if return_dict:
1388
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
1389
+ logits=lm_logits,
1390
+ hidden_states=decoder_outputs.hidden_states,
1391
+ attentions=decoder_outputs.attentions,
1392
+ cross_attentions=decoder_outputs.cross_attentions,
1393
+ )
1394
+ else:
1395
+ outputs = (lm_logits,) + decoder_outputs[1:]
1396
+
1397
+ # add updated cache to model output
1398
+ if past_key_values is not None and return_dict:
1399
+ outputs["past_key_values"] = unfreeze(past["cache"])
1400
+ return outputs
1401
+ elif past_key_values is not None and not return_dict:
1402
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
1403
+
1404
+ return outputs
1405
+
1406
+ def generate(
1407
+ self,
1408
+ input_features,
1409
+ generation_config=None,
1410
+ logits_processor=None,
1411
+ return_timestamps=None,
1412
+ task=None,
1413
+ language=None,
1414
+ is_multilingual=None,
1415
+ **kwargs,
1416
+ ):
1417
+ if generation_config is None:
1418
+ generation_config = self.generation_config
1419
+
1420
+ if return_timestamps is not None:
1421
+ generation_config.return_timestamps = return_timestamps
1422
+
1423
+ if task is not None:
1424
+ generation_config.task = task
1425
+
1426
+ if is_multilingual is not None:
1427
+ generation_config.is_multilingual = is_multilingual
1428
+
1429
+ if language is not None:
1430
+ generation_config.language = language
1431
+
1432
+ if kwargs is not None and "decoder_input_ids" in kwargs:
1433
+ decoder_input_length = len(kwargs["decoder_input_ids"])
1434
+ else:
1435
+ decoder_input_length = 1
1436
+
1437
+ forced_decoder_ids = []
1438
+
1439
+ if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
1440
+ if hasattr(generation_config, "language"):
1441
+ forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
1442
+ else:
1443
+ forced_decoder_ids.append((1, None))
1444
+
1445
+ if hasattr(generation_config, "task"):
1446
+ forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
1447
+ else:
1448
+ forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
1449
+
1450
+ if (
1451
+ hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
1452
+ ) or return_timestamps:
1453
+ logits_processor = [
1454
+ FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, decoder_input_length)
1455
+ ]
1456
+ else:
1457
+ if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
1458
+ idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
1459
+ forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
1460
+
1461
+ if len(forced_decoder_ids) > 0:
1462
+ generation_config.forced_decoder_ids = forced_decoder_ids
1463
+
1464
+ return super().generate(
1465
+ input_features,
1466
+ generation_config,
1467
+ logits_processor=logits_processor,
1468
+ **kwargs,
1469
+ )
1470
+
1471
+ def prepare_inputs_for_generation(
1472
+ self,
1473
+ decoder_input_ids,
1474
+ max_length,
1475
+ attention_mask: Optional[jax.Array] = None,
1476
+ decoder_attention_mask: Optional[jax.Array] = None,
1477
+ encoder_outputs=None,
1478
+ **kwargs,
1479
+ ):
1480
+ # initializing the cache
1481
+ batch_size, seq_length = decoder_input_ids.shape
1482
+
1483
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
1484
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
1485
+ # But since the decoder uses a causal mask, those positions are masked anyways.
1486
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
1487
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
1488
+ if decoder_attention_mask is not None:
1489
+ position_ids = decoder_attention_mask.cumsum(-1) - 1
1490
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
1491
+ else:
1492
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
1493
+
1494
+ return {
1495
+ "past_key_values": past_key_values,
1496
+ "encoder_outputs": encoder_outputs,
1497
+ "encoder_attention_mask": attention_mask,
1498
+ "decoder_attention_mask": extended_attention_mask,
1499
+ "decoder_position_ids": position_ids,
1500
+ }
1501
+
1502
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
1503
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
1504
+ model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
1505
+ return model_kwargs
1506
+
1507
+
1508
+ FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING = r"""
1509
+ Returns:
1510
+
1511
+ Transcription example:
1512
+
1513
+ ```python
1514
+ >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
1515
+ >>> from datasets import load_dataset
1516
+
1517
+ >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
1518
+ >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
1519
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1520
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
1521
+ >>> input_features = inputs.input_features
1522
+ >>> generated_ids = model.generate(input_features).sequences
1523
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
1524
+ >>> transcription
1525
+ ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
1526
+ ```
1527
+ """
1528
+
1529
+ overwrite_call_docstring(
1530
+ FlaxWhisperForConditionalGeneration, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING
1531
+ )
1532
+ append_replace_return_docstrings(
1533
+ FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
1534
+ )
1535
+
1536
+
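The `generate` override above translates `task`, `language`, and `return_timestamps` into `forced_decoder_ids` before handing off to the base Flax generation loop. A minimal usage sketch, assuming a multilingual checkpoint and an `audio` array already loaded at 16 kHz (both are assumptions, not part of this file):

```python
from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", from_pt=True)

# `audio` is assumed to be a 1-D float array sampled at 16 kHz
inputs = processor(audio, sampling_rate=16000, return_tensors="np")

# task/language end up as forced decoder ids such as
# [(1, <|fr|> id), (2, <|transcribe|> id), (3, no_timestamps id)]
outputs = model.generate(inputs.input_features, task="transcribe", language="<|fr|>")
transcription = processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
```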
1537
+ class FlaxWhisperForAudioClassificationModule(nn.Module):
1538
+ config: WhisperConfig
1539
+ dtype: jnp.dtype = jnp.float32
1540
+ gradient_checkpointing: bool = False
1541
+
1542
+ def setup(self) -> None:
1543
+ self.encoder = FlaxWhisperEncoder(
1544
+ config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
1545
+ )
1546
+ self.config.is_encoder_decoder = False
1547
+ num_layers = self.config.num_hidden_layers + 1
1548
+ if self.config.use_weighted_layer_sum:
1549
+ self.layer_weights = jnp.repeat(1 / num_layers, num_layers)
1550
+ self.projector = nn.Dense(self.config.classifier_proj_size, dtype=self.dtype)
1551
+ self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
1552
+
1553
+ def __call__(
1554
+ self,
1555
+ input_features,
1556
+ encoder_outputs=None,
1557
+ output_attentions=None,
1558
+ output_hidden_states: bool = True,
1559
+ return_dict: bool = True,
1560
+ ):
1561
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1562
+ output_hidden_states = (
1563
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1564
+ )
1565
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1566
+
1567
+ if encoder_outputs is None:
1568
+ encoder_outputs = self.encoder(
1569
+ input_features,
1570
+ output_attentions=output_attentions,
1571
+ output_hidden_states=output_hidden_states,
1572
+ return_dict=return_dict,
1573
+ )
1574
+
1575
+ if self.config.use_weighted_layer_sum:
1576
+ hidden_states = jnp.stack(encoder_outputs, axis=1)
1577
+ norm_weights = jax.nn.softmax(self.layer_weights, axis=-1)
1578
+ hidden_states = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1)
1579
+ else:
1580
+ hidden_states = encoder_outputs[0]
1581
+
1582
+ hidden_states = self.projector(hidden_states)
1583
+ pooled_output = jnp.mean(hidden_states, axis=1)
1584
+
1585
+ logits = self.classifier(pooled_output)
1586
+
1587
+ if not return_dict:
1588
+ return (logits,) + encoder_outputs[1:]
1589
+
1590
+ return FlaxSequenceClassifierOutput(
1591
+ logits=logits,
1592
+ hidden_states=encoder_outputs.hidden_states,
1593
+ attentions=encoder_outputs.attentions,
1594
+ )
1595
+
1596
+
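When `use_weighted_layer_sum` is enabled, the module above pools over all encoder hidden states with softmax-normalised scalar weights instead of using only the last layer. A standalone sketch of that pooling step with toy shapes (names and sizes are illustrative only):

```python
import jax
import jax.numpy as jnp

batch, num_layers, seq_len, hidden = 2, 5, 10, 8
# stand-in for the stacked encoder hidden states, shape (batch, num_layers, seq_len, hidden)
hidden_states = jnp.ones((batch, num_layers, seq_len, hidden))

layer_weights = jnp.repeat(1 / num_layers, num_layers)   # uniform init, as in setup()
norm_weights = jax.nn.softmax(layer_weights, axis=-1)     # (num_layers,)
pooled = jnp.sum(hidden_states * norm_weights[None, :, None, None], axis=1)
assert pooled.shape == (batch, seq_len, hidden)
```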
1597
+ @add_start_docstrings("The Whisper Model with an audio classification head on top.", WHISPER_START_DOCSTRING)
1598
+ class FlaxWhisperForAudioClassification(FlaxWhisperPreTrainedModel):
1599
+ module_class = FlaxWhisperForAudioClassificationModule
1600
+ dtype: jnp.dtype = jnp.float32
1601
+
1602
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
1603
+ # init input tensors
1604
+ input_features = jnp.zeros(input_shape, dtype="f4")
1605
+ input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)
1606
+
1607
+ params_rng, dropout_rng = jax.random.split(rng)
1608
+ rngs = {"params": params_rng, "dropout": dropout_rng}
1609
+
1610
+ random_params = self.module.init(
1611
+ rngs,
1612
+ input_features=input_features,
1613
+ )["params"]
1614
+
1615
+ if params is not None:
1616
+ random_params = flatten_dict(unfreeze(random_params))
1617
+ params = flatten_dict(unfreeze(params))
1618
+ for missing_key in self._missing_keys:
1619
+ params[missing_key] = random_params[missing_key]
1620
+ self._missing_keys = set()
1621
+ return freeze(unflatten_dict(params))
1622
+ else:
1623
+ return random_params
1624
+
1625
+ @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
1626
+ def __call__(
1627
+ self,
1628
+ input_features: jnp.ndarray,
1629
+ attention_mask: Optional[jnp.ndarray] = None,
1630
+ output_attentions: Optional[bool] = None,
1631
+ output_hidden_states: Optional[bool] = None,
1632
+ return_dict: Optional[bool] = None,
1633
+ train: bool = False,
1634
+ params: dict = None,
1635
+ dropout_rng: PRNGKey = None,
1636
+ **kwargs,
1637
+ ):
1638
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1639
+ output_hidden_states = (
1640
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1641
+ )
1642
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1643
+
1644
+ # Handle any PRNG if needed
1645
+ rngs = {}
1646
+ if dropout_rng is not None:
1647
+ rngs["dropout"] = dropout_rng
1648
+
1649
+ return self.module.apply(
1650
+ {"params": params or self.params},
1651
+ input_features=jnp.array(input_features, dtype="f4"),
1652
+ output_attentions=output_attentions,
1653
+ output_hidden_states=output_hidden_states,
1654
+ return_dict=return_dict,
1655
+ rngs=rngs,
1656
+ )
1657
+
1658
+
1659
+ FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING = r"""
1660
+ Returns:
1661
+
1662
+ Audio classification example:
1663
+
1664
+ ```python
1665
+ >>> import jax.numpy as jnp
1666
+ >>> from transformers import AutoFeatureExtractor, FlaxWhisperForAudioClassification
1667
+ >>> from datasets import load_dataset
1668
+
1669
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
1670
+ >>> model = FlaxWhisperForAudioClassification.from_pretrained(
1671
+ ... "sanchit-gandhi/whisper-medium-fleurs-lang-id", from_pt=True
1672
+ ... )
1673
+ >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True, trust_remote_code=True)
1674
+
1675
+ >>> sample = next(iter(ds))
1676
+
1677
+ >>> inputs = feature_extractor(
1678
+ ... sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="np"
1679
+ ... )
1680
+ >>> input_features = inputs.input_features
1681
+
1682
+ >>> logits = model(input_features).logits
1683
+
1684
+ >>> predicted_class_ids = jnp.argmax(logits).item()
1685
+ >>> predicted_label = model.config.id2label[predicted_class_ids]
1686
+ >>> predicted_label
1687
+ 'af_za'
1688
+ ```
1689
+ """
1690
+
1691
+ overwrite_call_docstring(
1692
+ FlaxWhisperForAudioClassification, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING
1693
+ )
1694
+ append_replace_return_docstrings(
1695
+ FlaxWhisperForAudioClassification, output_type=FlaxSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC
1696
+ )
modeling_tf_whisper (1).py ADDED
@@ -0,0 +1,1758 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """TensorFlow Whisper model."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import math
20
+ import random
21
+ from typing import Dict, List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import tensorflow as tf
25
+
26
+ from ...activations_tf import get_tf_activation
27
+ from ...generation.configuration_utils import GenerationConfig
28
+ from ...generation.tf_logits_process import TFLogitsProcessorList
29
+ from ...modeling_tf_outputs import (
30
+ TFBaseModelOutput,
31
+ TFBaseModelOutputWithPastAndCrossAttentions,
32
+ TFSeq2SeqLMOutput,
33
+ TFSeq2SeqModelOutput,
34
+ )
35
+ from ...modeling_tf_utils import (
36
+ TFCausalLanguageModelingLoss,
37
+ TFModelInputType,
38
+ TFPreTrainedModel,
39
+ keras,
40
+ keras_serializable,
41
+ unpack_inputs,
42
+ )
43
+ from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
44
+ from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
45
+ from .configuration_whisper import WhisperConfig
46
+ from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+ _CONFIG_FOR_DOC = "WhisperConfig"
52
+
53
+
54
+ LARGE_NEGATIVE = -1e8
55
+
56
+
57
+ def sinusoidal_embedding_init(shape, dtype=tf.float32) -> tf.Tensor:
58
+ """Returns sinusoids for positional embedding"""
59
+ length, channels = shape
60
+ if channels % 2 != 0:
61
+ raise ValueError(
62
+ f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
63
+ )
64
+ log_timescale_increment = math.log(10000) / (channels // 2 - 1)
65
+ inv_timescales = tf.exp(-log_timescale_increment * tf.range(channels // 2, dtype=tf.float32))
66
+ scaled_time = tf.reshape(tf.range(length, dtype=tf.float32), (-1, 1)) * tf.reshape(inv_timescales, (1, -1))
67
+ return tf.cast(tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1), dtype)
68
+
69
+
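`sinusoidal_embedding_init` returns a `(length, channels)` table whose first half of the channel axis holds sines and the second half cosines, matching the fixed Whisper encoder positional encoding. A quick shape check, as a sketch (sizes correspond to whisper-tiny):

```python
table = sinusoidal_embedding_init((1500, 384))  # (max_source_positions, d_model)
print(table.shape)  # (1500, 384)
```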
70
+ # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
71
+ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
72
+ pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
73
+ decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
74
+ start_tokens = tf.fill(
75
+ (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
76
+ )
77
+ shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
78
+ # replace possible -100 values in labels by `pad_token_id`
79
+ shifted_input_ids = tf.where(
80
+ shifted_input_ids == -100,
81
+ tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
82
+ shifted_input_ids,
83
+ )
84
+
85
+ # "Verify that `labels` has only positive values and -100"
86
+ assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
87
+
88
+ # Make sure the assertion op is called by wrapping the result in an identity no-op
89
+ with tf.control_dependencies([assert_gte0]):
90
+ shifted_input_ids = tf.identity(shifted_input_ids)
91
+
92
+ return shifted_input_ids
93
+
94
+
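`shift_tokens_right` builds decoder inputs from the labels: it prepends the decoder start token, drops the last label position, and maps the `-100` ignore index back to the pad token. A small sketch of its effect using the helper defined above (the ids are typical Whisper values, shown only for illustration):

```python
import tensorflow as tf

labels = tf.constant([[5, 6, -100, -100]], dtype=tf.int32)
decoder_input_ids = shift_tokens_right(labels, pad_token_id=50257, decoder_start_token_id=50258)
# start token prepended, last position dropped, remaining -100 turned into the pad id:
# [[50258, 5, 6, 50257]]
print(decoder_input_ids.numpy())
```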
95
+ # Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
96
+ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
97
+ """
98
+ Make causal mask used for bi-directional self-attention.
99
+ """
100
+ bsz = input_ids_shape[0]
101
+ tgt_len = input_ids_shape[1]
102
+ mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
103
+ mask_cond = tf.range(shape_list(mask)[-1])
104
+
105
+ mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
106
+
107
+ if past_key_values_length > 0:
108
+ mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
109
+
110
+ return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
111
+
112
+
113
+ # Copied from transformers.models.bart.modeling_tf_bart._expand_mask
114
+ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
115
+ """
116
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
117
+ """
118
+ src_len = shape_list(mask)[1]
119
+ tgt_len = tgt_len if tgt_len is not None else src_len
120
+ one_cst = tf.constant(1.0)
121
+ mask = tf.cast(mask, dtype=one_cst.dtype)
122
+ expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
123
+
124
+ return (one_cst - expanded_mask) * LARGE_NEGATIVE
125
+
126
+
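Both mask helpers are additive: allowed positions become `0.0` and masked positions become `LARGE_NEGATIVE`, so the result can simply be added to the attention logits before the softmax. A short sketch using `_expand_mask` as defined above:

```python
import tensorflow as tf

attention_mask = tf.constant([[1, 1, 0]], dtype=tf.int32)  # last source position is padding
expanded = _expand_mask(attention_mask)                    # shape (1, 1, 3, 3)
# rows are target positions, columns are source positions:
# 0.0 where attention is allowed, -1e8 over the padded column
print(expanded.numpy()[0, 0])
```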
127
+ class TFWhisperPositionalEmbedding(keras.layers.Layer):
128
+ def __init__(
129
+ self,
130
+ num_positions: int,
131
+ embedding_dim: int,
132
+ padding_idx: Optional[int] = None,
133
+ embedding_initializer=None,
134
+ **kwargs,
135
+ ):
136
+ super().__init__(**kwargs)
137
+ self.num_positions = num_positions
138
+ self.embedding_dim = embedding_dim
139
+ self.padding_idx = padding_idx
140
+ self.embedding_initializer = keras.initializers.get(embedding_initializer)
141
+
142
+ def build(self, input_shape):
143
+ self.weight = self.add_weight(
144
+ name="weight",
145
+ shape=[self.num_positions, self.embedding_dim],
146
+ initializer=self.embedding_initializer,
147
+ trainable=True,
148
+ )
149
+ super().build(input_shape)
150
+
151
+ def call(self, input_ids, past_key_values_length=0):
152
+ past_key_values_length = tf.cast(past_key_values_length, tf.int32)
153
+ gather_indices = tf.range(tf.shape(input_ids)[1], delta=1) + past_key_values_length
154
+ return tf.gather(self.weight, gather_indices)
155
+
156
+
157
+ class TFWhisperAttention(keras.layers.Layer):
158
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
159
+
160
+ def __init__(
161
+ self,
162
+ embed_dim: int,
163
+ num_heads: int,
164
+ dropout: float = 0.0,
165
+ is_decoder: bool = False,
166
+ bias: bool = True,
167
+ **kwargs,
168
+ ):
169
+ super().__init__(**kwargs)
170
+ self.embed_dim = embed_dim
171
+ self.num_heads = num_heads
172
+ self.dropout = keras.layers.Dropout(dropout)
173
+ self.head_dim = embed_dim // num_heads
174
+
175
+ if (self.head_dim * num_heads) != self.embed_dim:
176
+ raise ValueError(
177
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
178
+ f" and `num_heads`: {num_heads})."
179
+ )
180
+ self.scaling = self.head_dim**-0.5
181
+ self.is_decoder = is_decoder
182
+
183
+ self.k_proj = keras.layers.Dense(embed_dim, use_bias=False, name="k_proj")
184
+ self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
185
+ self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
186
+ self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
187
+
188
+ # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention._shape with BART->whisper
189
+ def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
190
+ return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
191
+
192
+ # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention.call with BART->whisper
193
+ def call(
194
+ self,
195
+ hidden_states: tf.Tensor,
196
+ key_value_states: tf.Tensor | None = None,
197
+ past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
198
+ attention_mask: tf.Tensor | None = None,
199
+ layer_head_mask: tf.Tensor | None = None,
200
+ training: Optional[bool] = False,
201
+ ) -> Tuple[tf.Tensor, tf.Tensor | None]:
202
+ """Input shape: Batch x Time x Channel"""
203
+
204
+ # if key_value_states are provided this layer is used as a cross-attention layer
205
+ # for the decoder
206
+ is_cross_attention = key_value_states is not None
207
+ bsz, tgt_len, embed_dim = shape_list(hidden_states)
208
+
209
+ # get query proj
210
+ query_states = self.q_proj(hidden_states) * self.scaling
211
+ # get key, value proj
212
+ if is_cross_attention and past_key_value is not None:
213
+ # reuse k,v, cross_attentions
214
+ key_states = past_key_value[0]
215
+ value_states = past_key_value[1]
216
+ elif is_cross_attention:
217
+ # cross_attentions
218
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
219
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
220
+ elif past_key_value is not None:
221
+ # reuse k, v, self_attention
222
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
223
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
224
+ key_states = tf.concat([past_key_value[0], key_states], axis=2)
225
+ value_states = tf.concat([past_key_value[1], value_states], axis=2)
226
+ else:
227
+ # self_attention
228
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
229
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
230
+
231
+ if self.is_decoder:
232
+ # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
233
+ # Further calls to cross_attention layer can then reuse all cross-attention
234
+ # key/value_states (first "if" case)
235
+ # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
236
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
237
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
238
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
239
+ past_key_value = (key_states, value_states)
240
+
241
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
242
+ query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
243
+ key_states = tf.reshape(key_states, proj_shape)
244
+ value_states = tf.reshape(value_states, proj_shape)
245
+
246
+ src_len = shape_list(key_states)[1]
247
+ attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
248
+
249
+ tf.debugging.assert_equal(
250
+ shape_list(attn_weights),
251
+ [bsz * self.num_heads, tgt_len, src_len],
252
+ message=(
253
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
254
+ f" {shape_list(attn_weights)}"
255
+ ),
256
+ )
257
+
258
+ if attention_mask is not None:
259
+ tf.debugging.assert_equal(
260
+ shape_list(attention_mask),
261
+ [bsz, 1, tgt_len, src_len],
262
+ message=(
263
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
264
+ f" {shape_list(attention_mask)}"
265
+ ),
266
+ )
267
+
268
+ attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
269
+ attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
270
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
271
+
272
+ attn_weights = stable_softmax(attn_weights, axis=-1)
273
+
274
+ if layer_head_mask is not None:
275
+ tf.debugging.assert_equal(
276
+ shape_list(layer_head_mask),
277
+ [self.num_heads],
278
+ message=(
279
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
280
+ f" {shape_list(layer_head_mask)}"
281
+ ),
282
+ )
283
+
284
+ attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
285
+ attn_weights, (bsz, self.num_heads, tgt_len, src_len)
286
+ )
287
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
288
+
289
+ attn_probs = self.dropout(attn_weights, training=training)
290
+ attn_output = tf.matmul(attn_probs, value_states)
291
+
292
+ tf.debugging.assert_equal(
293
+ shape_list(attn_output),
294
+ [bsz * self.num_heads, tgt_len, self.head_dim],
295
+ message=(
296
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
297
+ f" {shape_list(attn_output)}"
298
+ ),
299
+ )
300
+
301
+ attn_output = tf.transpose(
302
+ tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
303
+ )
304
+ attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))
305
+
306
+ attn_output = self.out_proj(attn_output)
307
+ attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
308
+
309
+ return attn_output, attn_weights, past_key_value
310
+
311
+ def build(self, input_shape=None):
312
+ if self.built:
313
+ return
314
+ self.built = True
315
+ if getattr(self, "k_proj", None) is not None:
316
+ with tf.name_scope(self.k_proj.name):
317
+ self.k_proj.build([None, None, self.embed_dim])
318
+ if getattr(self, "v_proj", None) is not None:
319
+ with tf.name_scope(self.v_proj.name):
320
+ self.v_proj.build([None, None, self.embed_dim])
321
+ if getattr(self, "q_proj", None) is not None:
322
+ with tf.name_scope(self.q_proj.name):
323
+ self.q_proj.build([None, None, self.embed_dim])
324
+ if getattr(self, "out_proj", None) is not None:
325
+ with tf.name_scope(self.out_proj.name):
326
+ self.out_proj.build([None, None, self.embed_dim])
327
+
328
+
329
+ # Copied from transformers.models.speech_to_text.modeling_tf_speech_to_text.TFSpeech2TextEncoderLayer with Speech2Text->Whisper
330
+ class TFWhisperEncoderLayer(keras.layers.Layer):
331
+ def __init__(self, config: WhisperConfig, **kwargs):
332
+ super().__init__(**kwargs)
333
+ self.embed_dim = config.d_model
334
+ self.self_attn = TFWhisperAttention(
335
+ self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
336
+ )
337
+ self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
338
+ self.dropout = keras.layers.Dropout(config.dropout)
339
+ self.activation_fn = get_tf_activation(config.activation_function)
340
+ self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
341
+ self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
342
+ self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
343
+ self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
344
+ self.config = config
345
+
346
+ def call(
347
+ self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training: bool = False
348
+ ):
349
+ """
350
+ Args:
351
+ hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
352
+ attention_mask (`tf.Tensor`): attention mask of size
353
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
354
+ layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
355
+ `(encoder_attention_heads,)`
356
+ """
357
+ residual = hidden_states
358
+ hidden_states = self.self_attn_layer_norm(hidden_states)
359
+ hidden_states, self_attn_weights, _ = self.self_attn(
360
+ hidden_states=hidden_states,
361
+ attention_mask=attention_mask,
362
+ layer_head_mask=layer_head_mask,
363
+ training=training,
364
+ )
365
+
366
+ tf.debugging.assert_equal(
367
+ shape_list(hidden_states),
368
+ shape_list(residual),
369
+ message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
370
+ )
371
+
372
+ hidden_states = self.dropout(hidden_states, training=training)
373
+ hidden_states = residual + hidden_states
374
+
375
+ residual = hidden_states
376
+ hidden_states = self.final_layer_norm(hidden_states)
377
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
378
+ hidden_states = self.activation_dropout(hidden_states, training=training)
379
+ hidden_states = self.fc2(hidden_states)
380
+ hidden_states = self.dropout(hidden_states, training=training)
381
+ hidden_states = residual + hidden_states
382
+
383
+ return hidden_states, self_attn_weights
384
+
385
+ def build(self, input_shape=None):
386
+ if self.built:
387
+ return
388
+ self.built = True
389
+ if getattr(self, "self_attn", None) is not None:
390
+ with tf.name_scope(self.self_attn.name):
391
+ self.self_attn.build(None)
392
+ if getattr(self, "self_attn_layer_norm", None) is not None:
393
+ with tf.name_scope(self.self_attn_layer_norm.name):
394
+ self.self_attn_layer_norm.build([None, None, self.embed_dim])
395
+ if getattr(self, "fc1", None) is not None:
396
+ with tf.name_scope(self.fc1.name):
397
+ self.fc1.build([None, None, self.embed_dim])
398
+ if getattr(self, "fc2", None) is not None:
399
+ with tf.name_scope(self.fc2.name):
400
+ self.fc2.build([None, None, self.config.encoder_ffn_dim])
401
+ if getattr(self, "final_layer_norm", None) is not None:
402
+ with tf.name_scope(self.final_layer_norm.name):
403
+ self.final_layer_norm.build([None, None, self.embed_dim])
404
+
405
+
406
+ # Copied from transformers.models.speech_to_text.modeling_tf_speech_to_text.TFSpeech2TextDecoderLayer with Speech2Text->Whisper
407
+ class TFWhisperDecoderLayer(keras.layers.Layer):
408
+ def __init__(self, config: WhisperConfig, **kwargs):
409
+ super().__init__(**kwargs)
410
+ self.embed_dim = config.d_model
411
+
412
+ self.self_attn = TFWhisperAttention(
413
+ embed_dim=self.embed_dim,
414
+ num_heads=config.decoder_attention_heads,
415
+ dropout=config.attention_dropout,
416
+ name="self_attn",
417
+ is_decoder=True,
418
+ )
419
+ self.dropout = keras.layers.Dropout(config.dropout)
420
+ self.activation_fn = get_tf_activation(config.activation_function)
421
+ self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
422
+
423
+ self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
424
+ self.encoder_attn = TFWhisperAttention(
425
+ self.embed_dim,
426
+ config.decoder_attention_heads,
427
+ dropout=config.attention_dropout,
428
+ name="encoder_attn",
429
+ is_decoder=True,
430
+ )
431
+ self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
432
+ self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
433
+ self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
434
+ self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
435
+ self.config = config
436
+
437
+ def call(
438
+ self,
439
+ hidden_states,
440
+ attention_mask: tf.Tensor | None = None,
441
+ encoder_hidden_states: tf.Tensor | None = None,
442
+ encoder_attention_mask: tf.Tensor | None = None,
443
+ layer_head_mask: tf.Tensor | None = None,
444
+ cross_attn_layer_head_mask: tf.Tensor | None = None,
445
+ past_key_value: Tuple[tf.Tensor] | None = None,
446
+ training=False,
447
+ ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
448
+ """
449
+ Args:
450
+ hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
451
+ attention_mask (`tf.Tensor`): attention mask of size
452
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
453
+ encoder_hidden_states (`tf.Tensor`):
454
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
455
+ encoder_attention_mask (`tf.Tensor`): encoder attention mask of size
456
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
457
+ layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
458
+ `(decoder_attention_heads,)`
459
+ cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module.
460
+ `(decoder_attention_heads,)`
461
+ past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states
462
+ """
463
+ residual = hidden_states
464
+ hidden_states = self.self_attn_layer_norm(hidden_states)
465
+
466
+ # Self Attention
467
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
468
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
469
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
470
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
471
+ hidden_states=hidden_states,
472
+ past_key_value=self_attn_past_key_value,
473
+ attention_mask=attention_mask,
474
+ layer_head_mask=layer_head_mask,
475
+ training=training,
476
+ )
477
+ hidden_states = self.dropout(hidden_states, training=training)
478
+ hidden_states = residual + hidden_states
479
+
480
+ # Cross-Attention Block
481
+ cross_attn_present_key_value = None
482
+ cross_attn_weights = None
483
+ if encoder_hidden_states is not None:
484
+ residual = hidden_states
485
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
486
+
487
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
488
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
489
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
490
+ hidden_states=hidden_states,
491
+ key_value_states=encoder_hidden_states,
492
+ attention_mask=encoder_attention_mask,
493
+ layer_head_mask=cross_attn_layer_head_mask,
494
+ past_key_value=cross_attn_past_key_value,
495
+ training=training,
496
+ )
497
+ hidden_states = self.dropout(hidden_states, training=training)
498
+ hidden_states = residual + hidden_states
499
+
500
+ # add cross-attn to positions 3,4 of present_key_value tuple
501
+ present_key_value = present_key_value + cross_attn_present_key_value
502
+
503
+ # Fully Connected
504
+ residual = hidden_states
505
+ hidden_states = self.final_layer_norm(hidden_states)
506
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
507
+ hidden_states = self.activation_dropout(hidden_states, training=training)
508
+ hidden_states = self.fc2(hidden_states)
509
+ hidden_states = self.dropout(hidden_states, training=training)
510
+ hidden_states = residual + hidden_states
511
+
512
+ return (
513
+ hidden_states,
514
+ self_attn_weights,
515
+ cross_attn_weights,
516
+ present_key_value,
517
+ )
518
+
519
+ def build(self, input_shape=None):
520
+ if self.built:
521
+ return
522
+ self.built = True
523
+ if getattr(self, "self_attn", None) is not None:
524
+ with tf.name_scope(self.self_attn.name):
525
+ self.self_attn.build(None)
526
+ if getattr(self, "self_attn_layer_norm", None) is not None:
527
+ with tf.name_scope(self.self_attn_layer_norm.name):
528
+ self.self_attn_layer_norm.build([None, None, self.embed_dim])
529
+ if getattr(self, "encoder_attn", None) is not None:
530
+ with tf.name_scope(self.encoder_attn.name):
531
+ self.encoder_attn.build(None)
532
+ if getattr(self, "encoder_attn_layer_norm", None) is not None:
533
+ with tf.name_scope(self.encoder_attn_layer_norm.name):
534
+ self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
535
+ if getattr(self, "fc1", None) is not None:
536
+ with tf.name_scope(self.fc1.name):
537
+ self.fc1.build([None, None, self.embed_dim])
538
+ if getattr(self, "fc2", None) is not None:
539
+ with tf.name_scope(self.fc2.name):
540
+ self.fc2.build([None, None, self.config.decoder_ffn_dim])
541
+ if getattr(self, "final_layer_norm", None) is not None:
542
+ with tf.name_scope(self.final_layer_norm.name):
543
+ self.final_layer_norm.build([None, None, self.embed_dim])
544
+
545
+
546
+ class TFWhisperPreTrainedModel(TFPreTrainedModel):
547
+ config_class = WhisperConfig
548
+ base_model_prefix = "model"
549
+ main_input_name = "input_features"
550
+
551
+ def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor) -> int:
552
+ """
553
+ Computes the output length of the convolutional layers
554
+ """
555
+ input_lengths = (input_lengths - 1) // 2 + 1
556
+
557
+ return input_lengths
558
+
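`_get_feat_extract_output_lengths` mirrors the stride-2 second convolution in the encoder, roughly halving the number of frames. A quick check of the formula, as a sketch:

```python
# 30 s of audio -> 3000 log-mel frames -> 1500 encoder positions
input_frames = 3000
assert (input_frames - 1) // 2 + 1 == 1500
```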
559
+ @property
560
+ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
561
+ """
562
+ Dummy inputs to build the network.
563
+
564
+ Returns:
565
+ `Dict[str, tf.Tensor]`: The dummy inputs.
566
+ """
567
+ return {
568
+ self.main_input_name: tf.random.uniform(
569
+ [1, self.config.num_mel_bins, self.config.max_source_positions * 2 - 1], dtype=tf.float32
570
+ ),
571
+ "decoder_input_ids": tf.constant([[1, 3]], dtype=tf.int32),
572
+ }
573
+
574
+ @property
575
+ def input_signature(self):
576
+ return {
577
+ "input_features": tf.TensorSpec((None, self.config.num_mel_bins, None), tf.float32, name="input_features"),
578
+ "decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
579
+ "decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
580
+ }
581
+
582
+
583
+ WHISPER_START_DOCSTRING = r"""
584
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
585
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
586
+ etc.)
587
+
588
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
589
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
590
+ behavior.
591
+
592
+ Parameters:
593
+ config ([`WhisperConfig`]):
594
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
595
+ load the weights associated with the model, only the configuration. Check out the
596
+ [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
597
+ """
598
+
599
+ WHISPER_INPUTS_DOCSTRING = r"""
600
+ Args:
601
+ input_features (`tf.Tensor` of shape `(batch_size, feature_size, sequence_length)`):
602
+ Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
603
+ by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.*
604
+ via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
605
+ [`AutoFeatureExtractor`] should be used for extracting the fbank features, padding and conversion into a
606
+ tensor of type `tf.Tensor`. See [`~WhisperFeatureExtractor.__call__`]
607
+ decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
608
+ Indices of decoder input sequence tokens in the vocabulary.
609
+
610
+ Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
611
+ [`PreTrainedTokenizer.__call__`] for details.
612
+
613
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
614
+
615
+ Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. If
616
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
617
+ `past_key_values`).
618
+ decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
619
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
620
+ be used by default.
621
+
622
+ If you want to change padding behavior, you should read
623
+ [`modeling_whisper._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the
624
+ paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
625
+ head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
626
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
627
+
628
+ - 1 indicates the head is **not masked**,
629
+ - 0 indicates the head is **masked**.
630
+
631
+ decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
632
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
633
+
634
+ - 1 indicates the head is **not masked**,
635
+ - 0 indicates the head is **masked**.
636
+
637
+ cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
638
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
639
+
640
+ - 1 indicates the head is **not masked**,
641
+ - 0 indicates the head is **masked**.
642
+
643
+ encoder_outputs (`tuple(tuple(tf.Tensor))`, *optional*):
644
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
645
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
646
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
647
+ past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
648
+ Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
649
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
650
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
651
+
652
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
653
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
654
+
655
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
656
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
657
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
658
+ decoder_inputs_embeds (`tf.Tensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
659
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
660
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
661
+ input (see `past_key_values`). This is useful if you want more control over how to convert
662
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
663
+ use_cache (`bool`, *optional*):
664
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
665
+ `past_key_values`).
666
+ output_attentions (`bool`, *optional*):
667
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
668
+ tensors for more detail.
669
+ output_hidden_states (`bool`, *optional*):
670
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
671
+ more detail.
672
+ return_dict (`bool`, *optional*):
673
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
674
+ """
675
+
676
+
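A minimal sketch of preparing `input_features` as described above and running generation with the `TFWhisperForConditionalGeneration` head defined later in this file; `audio` is assumed to be a 16 kHz waveform that is already loaded:

```python
from transformers import WhisperProcessor, TFWhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

inputs = processor(audio, sampling_rate=16000, return_tensors="tf")  # log-mel `input_features`
generated_ids = model.generate(inputs.input_features)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```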
677
+ @keras_serializable
678
+ class TFWhisperEncoder(keras.layers.Layer):
679
+ config_class = WhisperConfig
680
+ """
681
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
682
+ [`TFWhisperEncoderLayer`].
683
+
684
+ Args:
685
+ config: WhisperConfig
686
+ embed_tokens (TFWhisperEmbedding): output embedding
687
+ """
688
+
689
+ def __init__(self, config: WhisperConfig, **kwargs):
690
+ super().__init__(**kwargs)
691
+ self.config = config
692
+ self.layerdrop = config.encoder_layerdrop
693
+
694
+ self.embed_dim = config.d_model
695
+ self.num_mel_bins = config.num_mel_bins
696
+ self.padding_idx = config.pad_token_id
697
+ self.max_source_positions = config.max_source_positions
698
+ self.embed_scale = math.sqrt(self.embed_dim) if config.scale_embedding else 1.0
699
+
700
+ # Padding is added in call() to match the PyTorch implementation
701
+ self.conv1 = keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=1, padding="valid", name="conv1")
702
+ self.conv2 = keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=2, padding="valid", name="conv2")
703
+
704
+ self.embed_positions = TFWhisperPositionalEmbedding(
705
+ num_positions=self.max_source_positions,
706
+ embedding_dim=self.embed_dim,
707
+ embedding_initializer=sinusoidal_embedding_init,
708
+ name="embed_positions",
709
+ )
710
+ self.embed_positions.trainable = False
711
+
712
+ self.encoder_layers = [TFWhisperEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
713
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
714
+
715
+ self.dropout = keras.layers.Dropout(config.dropout)
716
+
717
+ @unpack_inputs
718
+ def call(
719
+ self,
720
+ input_features=None,
721
+ head_mask=None,
722
+ output_attentions=None,
723
+ output_hidden_states=None,
724
+ return_dict=None,
725
+ training=False,
726
+ ):
727
+ r"""
728
+ Args:
729
+ input_features (`tf.Tensor` of shape `(batch_size, feature_size, sequence_length)`):
730
+ Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be
731
+ obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
732
+ `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
733
+ `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the fbank features,
734
+ padding and conversion into a tensor of type `tf.Tensor`. See [`~WhisperFeatureExtractor.__call__`]
735
+ head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
736
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
737
+
738
+ - 1 indicates the head is **not masked**,
739
+ - 0 indicates the head is **masked**.
740
+
741
+ output_attentions (`bool`, *optional*):
742
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
743
+ returned tensors for more detail.
744
+ output_hidden_states (`bool`, *optional*):
745
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
746
+ for more detail.
747
+ return_dict (`bool`, *optional*):
748
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
749
+ """
750
+
751
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
752
+ output_hidden_states = (
753
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
754
+ )
755
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
756
+
757
+ # TF 2.0 layers can't use channels first format when running on CPU.
758
+ input_features = tf.transpose(input_features, perm=(0, 2, 1))
759
+ input_features = tf.pad(input_features, [[0, 0], [1, 1], [0, 0]])
760
+ inputs_embeds = keras.activations.gelu(self.conv1(input_features))
761
+ inputs_embeds = tf.pad(inputs_embeds, [[0, 0], [1, 1], [0, 0]])
762
+ inputs_embeds = keras.activations.gelu(self.conv2(inputs_embeds))
763
+ inputs_embeds = tf.transpose(inputs_embeds, perm=(0, 1, 2))
764
+
765
+ embed_pos = self.embed_positions(input_ids=tf.zeros((1, self.max_source_positions), dtype=tf.int32))
766
+
767
+ hidden_states = inputs_embeds + embed_pos
768
+ hidden_states = self.dropout(hidden_states, training=training)
769
+
770
+ encoder_states = () if output_hidden_states else None
771
+ all_attentions = () if output_attentions else None
772
+
773
+ # check if head_mask has a correct number of layers specified if desired
774
+ if head_mask is not None:
775
+ tf.debugging.assert_equal(
776
+ shape_list(head_mask)[0],
777
+ len(self.encoder_layers),
778
+ message=(
779
+ f"The head_mask should be specified for {len(self.encoder_layers)} layers, but it is for"
780
+ f" {shape_list(head_mask)[0]}."
781
+ ),
782
+ )
783
+
784
+ for idx, encoder_layer in enumerate(self.encoder_layers):
785
+ if output_hidden_states:
786
+ encoder_states = encoder_states + (hidden_states,)
787
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
788
+ dropout_probability = random.uniform(0, 1)
789
+ if training and (dropout_probability < self.layerdrop): # skip the layer
790
+ continue
791
+
792
+ hidden_states, attn = encoder_layer(
793
+ hidden_states,
794
+ None,
795
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
796
+ training=training,
797
+ )
798
+
799
+ if output_attentions:
800
+ all_attentions += (attn,)
801
+
802
+ hidden_states = self.layer_norm(hidden_states)
803
+ if output_hidden_states:
804
+ encoder_states = encoder_states + (hidden_states,)
805
+
806
+ if not return_dict:
807
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
808
+ return TFBaseModelOutput(
809
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
810
+ )
811
+
812
+ def build(self, input_shape=None):
813
+ if self.built:
814
+ return
815
+ self.built = True
816
+ if getattr(self, "conv1", None) is not None:
817
+ with tf.name_scope(self.conv1.name):
818
+ self.conv1.build([None, None, self.num_mel_bins])
819
+ if getattr(self, "conv2", None) is not None:
820
+ with tf.name_scope(self.conv2.name):
821
+ self.conv2.build([None, None, self.embed_dim])
822
+ if getattr(self, "embed_positions", None) is not None:
823
+ with tf.name_scope(self.embed_positions.name):
824
+ self.embed_positions.build(None)
825
+ if getattr(self, "layer_norm", None) is not None:
826
+ with tf.name_scope(self.layer_norm.name):
827
+ self.layer_norm.build([None, None, self.config.d_model])
828
+ if getattr(self, "encoder_layers", None) is not None:
829
+ for layer in self.encoder_layers:
830
+ with tf.name_scope(layer.name):
831
+ layer.build(None)
832
+
833
+
834
+ @keras_serializable
835
+ class TFWhisperDecoder(keras.layers.Layer):
836
+ config_class = WhisperConfig
837
+ """
838
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFWhisperDecoderLayer`]
839
+
840
+ Args:
841
+ config: WhisperConfig
842
+ """
843
+
844
+ def __init__(self, config: WhisperConfig, **kwargs):
845
+ super().__init__(**kwargs)
846
+ self.config = config
847
+ self.dropout = keras.layers.Dropout(config.dropout)
848
+ self.layerdrop = config.decoder_layerdrop
849
+ self.padding_idx = config.pad_token_id
850
+ self.max_target_positions = config.max_target_positions
851
+ self.max_source_positions = config.max_source_positions
852
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
853
+
854
+ self.embed_tokens = keras.layers.Embedding(
855
+ input_dim=config.vocab_size,
856
+ output_dim=config.d_model,
857
+ embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std),
858
+ name="embed_tokens",
859
+ )
860
+ self.embed_positions = TFWhisperPositionalEmbedding(
861
+ self.max_target_positions, config.d_model, name="embed_positions"
862
+ )
863
+
864
+ self.decoder_layers = [TFWhisperDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
865
+
866
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
867
+
868
+ def get_input_embeddings(self):
869
+ return self.embed_tokens
870
+
871
+ def set_input_embeddings(self, value):
872
+ self.embed_tokens = value
873
+
874
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length):
875
+ # create causal mask
876
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
877
+ batch_size, seq_len = input_shape[0], input_shape[1]
878
+
879
+ combined_attention_mask = tf.cond(
880
+ tf.math.greater(seq_len, 1),
881
+ lambda: _make_causal_mask(input_shape, past_key_values_length=past_key_values_length),
882
+ lambda: _expand_mask(tf.ones((batch_size, seq_len + past_key_values_length)), tgt_len=seq_len),
883
+ )
884
+
885
+ if attention_mask is not None:
886
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
887
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1])
888
+ combined_attention_mask = (
889
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
890
+ )
891
+ return combined_attention_mask
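# A minimal, standalone sketch (toy sizes) of the additive mask built above: position i may
# attend to every cached past position and to positions <= i of the current sequence; the
# result is ~0 where attention is allowed and a large negative value where it is blocked,
# matching the `_make_causal_mask` / `_expand_mask` helpers.
import tensorflow as tf

LARGE_NEGATIVE = -1e8
seq_len, past = 3, 2
allowed = tf.concat(
    [tf.ones((seq_len, past)), tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)],
    axis=-1,
)                                                # 1 = may attend, 0 = blocked
causal_bias = (1.0 - allowed) * LARGE_NEGATIVE   # shape (3, 5): 0.0 or -1e8 per (query, key) pair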
892
+
893
+ @unpack_inputs
894
+ def call(
895
+ self,
896
+ input_ids=None,
897
+ attention_mask=None,
898
+ position_ids=None,
899
+ encoder_hidden_states=None,
900
+ head_mask=None,
901
+ cross_attn_head_mask=None,
902
+ past_key_values=None,
903
+ inputs_embeds=None,
904
+ use_cache=None,
905
+ output_attentions=None,
906
+ output_hidden_states=None,
907
+ return_dict=None,
908
+ training=False,
909
+ ):
910
+ r"""
911
+ Args:
912
+ input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
913
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
914
+ provide it.
915
+
916
+ Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
917
+ [`PreTrainedTokenizer.__call__`] for details.
918
+
919
+ [What are input IDs?](../glossary#input-ids)
920
+ attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
921
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
922
+
923
+ - 1 for tokens that are **not masked**,
924
+ - 0 for tokens that are **masked**.
925
+
926
+ [What are attention masks?](../glossary#attention-mask)
927
+ position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
928
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
929
+ range `[0, config.max_position_embeddings - 1]`.
930
+ encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
931
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
932
+ of the decoder.
933
+ head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
934
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
935
+
936
+ - 1 indicates the head is **not masked**,
937
+ - 0 indicates the head is **masked**.
938
+
939
+ cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
940
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing cross-attention
941
+ on hidden heads. Mask values selected in `[0, 1]`:
942
+
943
+ - 1 indicates the head is **not masked**,
944
+ - 0 indicates the head is **masked**.
945
+
946
+ past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
947
+ Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
948
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
949
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
950
+
951
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
952
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
953
+
954
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
955
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
956
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
957
+ inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
958
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
959
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
960
+ than the model's internal embedding lookup matrix.
961
+ output_attentions (`bool`, *optional*):
962
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
963
+ returned tensors for more detail.
964
+ output_hidden_states (`bool`, *optional*):
965
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
966
+ for more detail.
967
+ return_dict (`bool`, *optional*):
968
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
969
+ """
970
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
971
+ output_hidden_states = (
972
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
973
+ )
974
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
975
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
976
+
977
+ # retrieve input_ids and inputs_embeds
978
+ if input_ids is not None and inputs_embeds is not None:
979
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
980
+ elif input_ids is not None:
981
+ input_shape = tf.shape(input_ids)
982
+ input_ids = tf.reshape(input_ids, (-1, input_shape[-1]))
983
+ elif inputs_embeds is not None:
984
+ input_shape = tf.shape(inputs_embeds)[:-1]
985
+ else:
986
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
987
+
988
+ # past_key_values_length
989
+ past_key_values_length = tf.shape(past_key_values[0][0])[2] if past_key_values is not None else 0
990
+
991
+ if inputs_embeds is None:
992
+ check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
993
+ inputs_embeds = self.embed_tokens(input_ids)
994
+
995
+ attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length)
996
+
997
+ # embed positions
998
+ filled_past_positions = past_key_values_length if position_ids is None else position_ids[0, -1]
999
+ positions = self.embed_positions(input_ids, past_key_values_length=filled_past_positions)
1000
+
1001
+ hidden_states = inputs_embeds + positions
1002
+ hidden_states = self.dropout(hidden_states, training=training)
1003
+
1004
+ # decoder layers
1005
+ all_hidden_states = () if output_hidden_states else None
1006
+ all_self_attns = () if output_attentions else None
1007
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1008
+ next_decoder_cache = () if use_cache else None
1009
+
1010
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1011
+ for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
1012
+ if attn_mask is not None:
1013
+ tf.debugging.assert_equal(
1014
+ shape_list(attn_mask)[0],
1015
+ len(self.decoder_layers),
1016
+ message=(
1017
+ f"The {attn_mask_name} should be specified for {len(self.decoder_layers)} layers, but it is"
1018
+ f" for {shape_list(attn_mask)[0]}."
1019
+ ),
1020
+ )
1021
+
1022
+ for idx, decoder_layer in enumerate(self.decoder_layers):
1023
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1024
+ if output_hidden_states:
1025
+ all_hidden_states += (hidden_states,)
1026
+ dropout_probability = random.uniform(0, 1)
1027
+ if training and (dropout_probability < self.layerdrop):
1028
+ continue
1029
+
1030
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
1031
+
1032
+ layer_outputs = decoder_layer(
1033
+ hidden_states,
1034
+ attention_mask=attention_mask,
1035
+ encoder_hidden_states=encoder_hidden_states,
1036
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
1037
+ cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
1038
+ past_key_value=past_key_value,
1039
+ training=training,
1040
+ )
1041
+ hidden_states = layer_outputs[0]
1042
+
1043
+ if use_cache:
1044
+ next_decoder_cache += (layer_outputs[3],)
1045
+
1046
+ if output_attentions:
1047
+ all_self_attns += (layer_outputs[1],)
1048
+
1049
+ if encoder_hidden_states is not None:
1050
+ all_cross_attentions += (layer_outputs[2],)
1051
+
1052
+ hidden_states = self.layer_norm(hidden_states)
1053
+ # add hidden states from the last decoder layer
1054
+ if output_hidden_states:
1055
+ all_hidden_states += (hidden_states,)
1056
+
1057
+ next_cache = next_decoder_cache if use_cache else None
1058
+ if not return_dict:
1059
+ return tuple(
1060
+ v
1061
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
1062
+ if v is not None
1063
+ )
1064
+ return TFBaseModelOutputWithPastAndCrossAttentions(
1065
+ last_hidden_state=hidden_states,
1066
+ past_key_values=next_cache,
1067
+ hidden_states=all_hidden_states,
1068
+ attentions=all_self_attns,
1069
+ cross_attentions=all_cross_attentions,
1070
+ )
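# A minimal sketch (toy shapes) of the cache layout this decoder consumes: one tuple per
# layer holding the self-attention key/value followed by the cross-attention key/value, with
# the already-decoded length read from the time axis of the first self-attention key, as done
# at the top of `call`. All sizes below are illustrative.
import tensorflow as tf

bsz, heads, dec_len, enc_len, head_dim = 1, 6, 7, 1500, 64
layer_cache = (
    tf.zeros((bsz, heads, dec_len, head_dim)),  # self-attention keys
    tf.zeros((bsz, heads, dec_len, head_dim)),  # self-attention values
    tf.zeros((bsz, heads, enc_len, head_dim)),  # cross-attention keys
    tf.zeros((bsz, heads, enc_len, head_dim)),  # cross-attention values
)
past_key_values = tuple(layer_cache for _ in range(4))        # e.g. a 4-layer decoder
past_key_values_length = tf.shape(past_key_values[0][0])[2]   # -> 7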
1071
+
1072
+ def build(self, input_shape=None):
1073
+ if self.built:
1074
+ return
1075
+ self.built = True
1076
+ if getattr(self, "embed_tokens", None) is not None:
1077
+ with tf.name_scope(self.embed_tokens.name):
1078
+ self.embed_tokens.build(None)
1079
+ if getattr(self, "embed_positions", None) is not None:
1080
+ with tf.name_scope(self.embed_positions.name):
1081
+ self.embed_positions.build(None)
1082
+ if getattr(self, "layer_norm", None) is not None:
1083
+ with tf.name_scope(self.layer_norm.name):
1084
+ self.layer_norm.build([None, None, self.config.d_model])
1085
+ if getattr(self, "decoder_layers", None) is not None:
1086
+ for layer in self.decoder_layers:
1087
+ with tf.name_scope(layer.name):
1088
+ layer.build(None)
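# A minimal sketch (made-up sizes) of how cached decoding offsets the positional lookup:
# with `past` steps already cached, only the newest token id is fed in, and it receives the
# positional vector at absolute index `past`, mirroring
# TFWhisperPositionalEmbedding.call(input_ids, past_key_values_length=past).
import tensorflow as tf

weight = tf.random.normal((448, 384))   # assumed (max_target_positions, d_model) table
past = 5
new_ids = tf.constant([[42]])           # only the latest token during cached decoding
idx = tf.range(tf.shape(new_ids)[1]) + past
positions = tf.gather(weight, idx)      # the row for absolute position 5, shape (1, 384)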
1089
+
1090
+
1091
+ @add_start_docstrings(
1092
+ "The bare Whisper Model outputting raw hidden-states without any specific head on top.",
1093
+ WHISPER_START_DOCSTRING,
1094
+ )
1095
+ @keras_serializable
1096
+ class TFWhisperMainLayer(keras.layers.Layer):
1097
+ config_class = WhisperConfig
1098
+
1099
+ def __init__(self, config: WhisperConfig, **kwargs):
1100
+ super().__init__(**kwargs)
1101
+ self.config = config
1102
+ self.encoder = TFWhisperEncoder(config, name="encoder")
1103
+ self.decoder = TFWhisperDecoder(config, name="decoder")
1104
+
1105
+ def get_input_embeddings(self):
1106
+ return self.decoder.embed_tokens
1107
+
1108
+ def set_input_embeddings(self, value):
1109
+ self.decoder.embed_tokens = value
1110
+
1111
+ def get_encoder(self):
1112
+ return self.encoder
1113
+
1114
+ def get_decoder(self):
1115
+ return self.decoder
1116
+
1117
+ @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
1118
+ @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
1119
+ @unpack_inputs
1120
+ def call(
1121
+ self,
1122
+ input_features=None,
1123
+ decoder_input_ids=None,
1124
+ decoder_attention_mask=None,
1125
+ decoder_position_ids=None,
1126
+ head_mask=None,
1127
+ decoder_head_mask=None,
1128
+ cross_attn_head_mask=None,
1129
+ encoder_outputs=None,
1130
+ past_key_values=None,
1131
+ decoder_inputs_embeds=None,
1132
+ use_cache=None,
1133
+ output_attentions=None,
1134
+ output_hidden_states=None,
1135
+ return_dict=None,
1136
+ training=False,
1137
+ ):
1138
+ r"""
1139
+ Returns:
1140
+
1141
+ Example:
1142
+
1143
+ ```python
1144
+ >>> import tensorflow as tf
1145
+ >>> from transformers import TFWhisperModel, AutoFeatureExtractor
1146
+ >>> from datasets import load_dataset
1147
+
1148
+ >>> model = TFWhisperModel.from_pretrained("openai/whisper-base")
1149
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
1150
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1151
+ >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="tf")
1152
+ >>> input_features = inputs.input_features
1153
+ >>> decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id
1154
+ >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
1155
+ >>> list(last_hidden_state.shape)
1156
+ [1, 2, 512]
1157
+ ```"""
1158
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1159
+ output_hidden_states = (
1160
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1161
+ )
1162
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1163
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1164
+
1165
+ if encoder_outputs is None:
1166
+ encoder_outputs = self.encoder(
1167
+ input_features,
1168
+ head_mask=head_mask,
1169
+ output_attentions=output_attentions,
1170
+ output_hidden_states=output_hidden_states,
1171
+ return_dict=return_dict,
1172
+ training=training,
1173
+ )
1174
+ # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
1175
+ elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
1176
+ encoder_outputs = TFBaseModelOutput(
1177
+ last_hidden_state=encoder_outputs[0],
1178
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1179
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1180
+ )
1181
+
1182
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
1183
+ decoder_outputs = self.decoder(
1184
+ input_ids=decoder_input_ids,
1185
+ attention_mask=decoder_attention_mask,
1186
+ position_ids=decoder_position_ids,
1187
+ encoder_hidden_states=encoder_outputs[0],
1188
+ head_mask=decoder_head_mask,
1189
+ cross_attn_head_mask=cross_attn_head_mask,
1190
+ past_key_values=past_key_values,
1191
+ inputs_embeds=decoder_inputs_embeds,
1192
+ use_cache=use_cache,
1193
+ output_attentions=output_attentions,
1194
+ output_hidden_states=output_hidden_states,
1195
+ return_dict=return_dict,
1196
+ training=training,
1197
+ )
1198
+
1199
+ if not return_dict:
1200
+ return decoder_outputs + encoder_outputs
1201
+
1202
+ return TFSeq2SeqModelOutput(
1203
+ last_hidden_state=decoder_outputs.last_hidden_state,
1204
+ past_key_values=decoder_outputs.past_key_values,
1205
+ decoder_hidden_states=decoder_outputs.hidden_states,
1206
+ decoder_attentions=decoder_outputs.attentions,
1207
+ cross_attentions=decoder_outputs.cross_attentions,
1208
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1209
+ encoder_hidden_states=encoder_outputs.hidden_states,
1210
+ encoder_attentions=encoder_outputs.attentions,
1211
+ )
1212
+
1213
+ def build(self, input_shape=None):
1214
+ if self.built:
1215
+ return
1216
+ self.built = True
1217
+ if getattr(self, "encoder", None) is not None:
1218
+ with tf.name_scope(self.encoder.name):
1219
+ self.encoder.build(None)
1220
+ if getattr(self, "decoder", None) is not None:
1221
+ with tf.name_scope(self.decoder.name):
1222
+ self.decoder.build(None)
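# A hypothetical usage sketch: because `call` skips the encoder whenever `encoder_outputs`
# is supplied, the expensive audio encoding can be run once and reused across several decoder
# passes. `model`, `input_features` and `decoder_input_ids` are assumed to exist as in the
# TFWhisperModel docstring example below; this snippet is not part of the file itself.
encoder_outputs = model.get_encoder()(input_features)   # TFBaseModelOutput
outputs = model(
    input_features=None,                 # ignored when encoder_outputs is provided
    encoder_outputs=encoder_outputs,
    decoder_input_ids=decoder_input_ids,
)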
1223
+
1224
+
1225
+ @add_start_docstrings(
1226
+ "The bare Whisper Model outputting raw hidden-states without any specific head on top.",
1227
+ WHISPER_START_DOCSTRING,
1228
+ )
1229
+ class TFWhisperModel(TFWhisperPreTrainedModel):
1230
+ def __init__(self, config: WhisperConfig, **kwargs):
1231
+ super().__init__(config, **kwargs)
1232
+
1233
+ self.model = TFWhisperMainLayer(config, name="model")
1234
+
1235
+ def get_input_embeddings(self):
1236
+ return self.model.decoder.embed_tokens
1237
+
1238
+ def set_input_embeddings(self, value):
1239
+ self.model.decoder.embed_tokens = value
1240
+
1241
+ def get_encoder(self):
1242
+ return self.model.encoder
1243
+
1244
+ def get_decoder(self):
1245
+ return self.model.decoder
1246
+
1247
+ def decoder(self):
1248
+ return self.model.decoder
1249
+
1250
+ def encoder(self):
1251
+ return self.model.encoder
1252
+
1253
+ @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
1254
+ @replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
1255
+ @unpack_inputs
1256
+ def call(
1257
+ self,
1258
+ input_features: TFModelInputType | None = None,
1259
+ decoder_input_ids: np.ndarray | tf.Tensor | None = None,
1260
+ decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
1261
+ decoder_position_ids: np.ndarray | tf.Tensor | None = None,
1262
+ head_mask: np.ndarray | tf.Tensor | None = None,
1263
+ decoder_head_mask: np.ndarray | tf.Tensor | None = None,
1264
+ cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
1265
+ encoder_outputs: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
1266
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
1267
+ decoder_inputs_embeds: Optional[Tuple[Union[np.ndarray, tf.Tensor]]] = None,
1268
+ use_cache: Optional[bool] = None,
1269
+ output_attentions: Optional[bool] = None,
1270
+ output_hidden_states: Optional[bool] = None,
1271
+ return_dict: Optional[bool] = None,
1272
+ training: bool = False,
1273
+ ) -> Union[Tuple[tf.Tensor], TFSeq2SeqModelOutput]:
1274
+ r"""
1275
+ Returns:
1276
+
1277
+ Example:
1278
+
1279
+ ```python
1280
+ >>> import tensorflow as tf
1281
+ >>> from transformers import TFWhisperModel, AutoFeatureExtractor
1282
+ >>> from datasets import load_dataset
1283
+
1284
+ >>> model = TFWhisperModel.from_pretrained("openai/whisper-base")
1285
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
1286
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1287
+ >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="tf")
1288
+ >>> input_features = inputs.input_features
1289
+ >>> decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id
1290
+ >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
1291
+ >>> list(last_hidden_state.shape)
1292
+ [1, 2, 512]
1293
+ ```"""
1294
+ outputs = self.model(
1295
+ input_features=input_features,
1296
+ decoder_input_ids=decoder_input_ids,
1297
+ decoder_attention_mask=decoder_attention_mask,
1298
+ decoder_position_ids=decoder_position_ids,
1299
+ head_mask=head_mask,
1300
+ decoder_head_mask=decoder_head_mask,
1301
+ cross_attn_head_mask=cross_attn_head_mask,
1302
+ encoder_outputs=encoder_outputs,
1303
+ past_key_values=past_key_values,
1304
+ decoder_inputs_embeds=decoder_inputs_embeds,
1305
+ use_cache=use_cache,
1306
+ output_attentions=output_attentions,
1307
+ output_hidden_states=output_hidden_states,
1308
+ return_dict=return_dict,
1309
+ training=training,
1310
+ )
1311
+ return outputs
1312
+
1313
+ def serving_output(self, output):
1314
+ pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
1315
+ dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
1316
+ dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
1317
+ cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
1318
+ enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
1319
+ enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
1320
+
1321
+ return TFSeq2SeqModelOutput(
1322
+ last_hidden_state=output.last_hidden_state,
1323
+ past_key_values=pkv,
1324
+ decoder_hidden_states=dec_hs,
1325
+ decoder_attentions=dec_attns,
1326
+ cross_attentions=cross_attns,
1327
+ encoder_last_hidden_state=output.encoder_last_hidden_state,
1328
+ encoder_hidden_states=enc_hs,
1329
+ encoder_attentions=enc_attns,
1330
+ )
1331
+
1332
+ def build(self, input_shape=None):
1333
+ if self.built:
1334
+ return
1335
+ self.built = True
1336
+ if getattr(self, "model", None) is not None:
1337
+ with tf.name_scope(self.model.name):
1338
+ self.model.build(None)
1339
+
1340
+
1341
+ @add_start_docstrings(
1342
+ "The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
1343
+ WHISPER_START_DOCSTRING,
1344
+ )
1345
+ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLanguageModelingLoss):
1346
+ base_model_prefix = "model"
1347
+ _keys_to_ignore_on_load_missing = [
1348
+ r"encoder.version",
1349
+ r"decoder.version",
1350
+ r"proj_out.weight",
1351
+ ]
1352
+ _keys_to_ignore_on_save = [
1353
+ r"proj_out.weight",
1354
+ ]
1355
+
1356
+ def __init__(self, config: WhisperConfig, **kwargs):
1357
+ super().__init__(config, **kwargs)
1358
+ self.model = TFWhisperMainLayer(config, name="model")
1359
+
1360
+ def get_encoder(self):
1361
+ return self.model.get_encoder()
1362
+
1363
+ def get_decoder(self):
1364
+ return self.model.get_decoder()
1365
+
1366
+ def get_output_embeddings(self):
1367
+ return self.get_input_embeddings()
1368
+
1369
+ def set_output_embeddings(self, value):
1370
+ self.set_input_embeddings(value)
1371
+
1372
+ def resize_token_embeddings(self, new_num_tokens: int) -> keras.layers.Embedding:
1373
+ new_embeddings = super().resize_token_embeddings(new_num_tokens)
1374
+ return new_embeddings
1375
+
1376
+ @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
1377
+ @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
1378
+ @unpack_inputs
1379
+ def call(
1380
+ self,
1381
+ input_features: TFModelInputType | None = None,
1382
+ decoder_input_ids: np.ndarray | tf.Tensor | None = None,
1383
+ decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
1384
+ decoder_position_ids: np.ndarray | tf.Tensor | None = None,
1385
+ head_mask: np.ndarray | tf.Tensor | None = None,
1386
+ decoder_head_mask: np.ndarray | tf.Tensor | None = None,
1387
+ cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
1388
+ encoder_outputs: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
1389
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
1390
+ decoder_inputs_embeds: Optional[Tuple[Union[np.ndarray, tf.Tensor]]] = None,
1391
+ labels: np.ndarray | tf.Tensor | None = None,
1392
+ use_cache: Optional[bool] = None,
1393
+ output_attentions: Optional[bool] = None,
1394
+ output_hidden_states: Optional[bool] = None,
1395
+ return_dict: Optional[bool] = None,
1396
+ training: bool = False,
1397
+ ) -> Union[Tuple[tf.Tensor], TFSeq2SeqLMOutput]:
1398
+ r"""
1399
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1400
+ Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
1401
+ or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
1402
+ only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1403
+
1404
+ Returns:
1405
+
1406
+ Example:
1407
+
1408
+ ```python
1409
+ >>> import tensorflow as tf
1410
+ >>> from transformers import AutoProcessor, TFWhisperForConditionalGeneration
1411
+ >>> from datasets import load_dataset
1412
+
1413
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
1414
+ >>> model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
1415
+
1416
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1417
+
1418
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="tf")
1419
+ >>> input_features = inputs.input_features
1420
+
1421
+ >>> generated_ids = model.generate(input_features=input_features)
1422
+
1423
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
1424
+ >>> transcription
1425
+ ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
1426
+ ```"""
1427
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1428
+
1429
+ if labels is not None:
1430
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
1431
+ decoder_input_ids = shift_tokens_right(
1432
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
1433
+ )
1434
+
1435
+ outputs = self.model(
1436
+ input_features,
1437
+ decoder_input_ids=decoder_input_ids,
1438
+ encoder_outputs=encoder_outputs,
1439
+ decoder_attention_mask=decoder_attention_mask,
1440
+ decoder_position_ids=decoder_position_ids,
1441
+ head_mask=head_mask,
1442
+ decoder_head_mask=decoder_head_mask,
1443
+ cross_attn_head_mask=cross_attn_head_mask,
1444
+ past_key_values=past_key_values,
1445
+ decoder_inputs_embeds=decoder_inputs_embeds,
1446
+ use_cache=use_cache,
1447
+ output_attentions=output_attentions,
1448
+ output_hidden_states=output_hidden_states,
1449
+ return_dict=return_dict,
1450
+ training=training,
1451
+ )
1452
+ decoder_last_hidden_state = outputs[0]
1453
+ # Decoder input and output embeddings are tied
1454
+ lm_logits = tf.matmul(decoder_last_hidden_state, self.get_output_embeddings().weights, transpose_b=True)
1455
+
1456
+ loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
1457
+
1458
+ if not return_dict:
1459
+ output = (lm_logits,) + outputs[1:]
1460
+ return ((loss,) + output) if loss is not None else output
1461
+
1462
+ return TFSeq2SeqLMOutput(
1463
+ loss=loss,
1464
+ logits=lm_logits,
1465
+ past_key_values=outputs.past_key_values,
1466
+ decoder_hidden_states=outputs.decoder_hidden_states,
1467
+ decoder_attentions=outputs.decoder_attentions,
1468
+ cross_attentions=outputs.cross_attentions,
1469
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1470
+ encoder_hidden_states=outputs.encoder_hidden_states,
1471
+ encoder_attentions=outputs.encoder_attentions,
1472
+ )
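# A minimal sketch (toy shapes) of the tied output projection computed above: there is no
# separate LM head, so the logits are the decoder hidden states multiplied by the transposed
# token-embedding matrix (output and input embeddings share weights).
import tensorflow as tf

d_model, vocab_size = 4, 10
decoder_hidden = tf.random.normal((1, 2, d_model))         # (batch, seq_len, d_model)
embedding_table = tf.random.normal((vocab_size, d_model))  # stands in for embed_tokens weights
lm_logits = tf.matmul(decoder_hidden, embedding_table, transpose_b=True)  # (1, 2, vocab_size)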
1473
+
1474
+ def generate(
1475
+ self,
1476
+ inputs: Optional[tf.Tensor] = None,
1477
+ generation_config: Optional[GenerationConfig] = None,
1478
+ logits_processor: Optional[TFLogitsProcessorList] = None,
1479
+ seed: Optional[List[int]] = None,
1480
+ return_timestamps: Optional[bool] = None,
1481
+ task: Optional[str] = None,
1482
+ language: Optional[str] = None,
1483
+ is_multilingual: Optional[bool] = None,
1484
+ prompt_ids: Optional[tf.Tensor] = None,
1485
+ return_token_timestamps=None,
1486
+ **kwargs,
1487
+ ):
1488
+ r"""
1489
+ Generates sequences of token ids for models with a language modeling head.
1490
+
1491
+ <Tip warning={true}>
1492
+
1493
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
1494
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
1495
+ parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
1496
+
1497
+ For an overview of generation strategies and code examples, check out the [following
1498
+ guide](../generation_strategies).
1499
+
1500
+ </Tip>
1501
+
1502
+ Parameters:
1503
+ inputs (`tf.Tensor` of varying shape depending on the modality, *optional*):
1504
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If unset the method
1505
+ initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should be in
1506
+ the format of `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`,
1507
+ `input_values`, `input_features`, or `pixel_values`.
1508
+ generation_config (`~generation.GenerationConfig`, *optional*):
1509
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
1510
+ passed to generate matching the attributes of `generation_config` will override them. If
1511
+ `generation_config` is not provided, the default will be used, which has the following loading
1512
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
1513
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
1514
+ default values, whose documentation should be checked to parameterize generation.
1515
+ logits_processor (`LogitsProcessorList`, *optional*):
1516
+ Custom logits processors that complement the default logits processors built from arguments and
1517
+ generation config. If a logit processor is passed that is already created with the arguments or a
1518
+ generation config an error is thrown. This feature is intended for advanced users.
1519
+ seed (`List[int]`, *optional*):
1520
+ Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
1521
+ `seed` argument from stateless functions in `tf.random`.
1522
+ return_timestamps (`bool`, *optional*):
1523
+ Whether to return the timestamps with the text. This enables the `TFWhisperTimestampsLogitsProcessor`.
1524
+ task (`str`, *optional*):
1525
+ Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
1526
+ will be updated accordingly.
1527
+ language (`str`, *optional*):
1528
+ Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can
1529
+ find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
1530
+ is_multilingual (`bool`, *optional*):
1531
+ Whether or not the model is multilingual.
1532
+ prompt_ids (`tf.Tensor`, *optional*):
1533
+ Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
1534
+ provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
1535
+ transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
1536
+ correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
1537
+ return_token_timestamps (`bool`, *optional*):
1538
+ Whether to return token-level timestamps with the text. This can be used with or without the
1539
+ `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
1540
+ words.
1541
+ kwargs (`Dict[str, Any]`, *optional*):
1542
+ Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
1543
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
1544
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
1545
+
1546
+ Return:
1547
+ [`~utils.ModelOutput`] or `tf.Tensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when
1548
+ `config.return_dict_in_generate=True`) or a `tf.Tensor`.
1549
+
1550
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
1551
+ [`~utils.ModelOutput`] types are:
1552
+
1553
+ - [`~generation.TFGreedySearchDecoderOnlyOutput`],
1554
+ - [`~generation.TFSampleDecoderOnlyOutput`],
1555
+ - [`~generation.TFBeamSearchDecoderOnlyOutput`],
1556
+ - [`~generation.TFBeamSampleDecoderOnlyOutput`]
1557
+
1558
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
1559
+ [`~utils.ModelOutput`] types are:
1560
+
1561
+ - [`~generation.TFGreedySearchEncoderDecoderOutput`],
1562
+ - [`~generation.TFSampleEncoderDecoderOutput`],
1563
+ - [`~generation.TFBeamSearchEncoderDecoderOutput`],
1564
+ - [`~generation.TFBeamSampleEncoderDecoderOutput`]
1565
+
1566
+ """
1567
+ if generation_config is None:
1568
+ generation_config = self.generation_config
1569
+
1570
+ if return_timestamps is not None:
1571
+ if not hasattr(generation_config, "no_timestamps_token_id"):
1572
+ raise ValueError(
1573
+ "You are trying to return timestamps, but the generation config is not properly set. "
1574
+ "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
1575
+ "For more details on how to generate the appropriate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
1576
+ )
1577
+
1578
+ generation_config.return_timestamps = return_timestamps
1579
+ else:
1580
+ generation_config.return_timestamps = False
1581
+
1582
+ if language is not None:
1583
+ language = language.lower()
1584
+ generation_config.language = language
1585
+ if task is not None:
1586
+ generation_config.task = task
1587
+
1588
+ forced_decoder_ids = None
1589
+
1590
+ # Legacy code for backward compatibility
1591
+ if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
1592
+ forced_decoder_ids = self.config.forced_decoder_ids
1593
+ elif (
1594
+ hasattr(self.generation_config, "forced_decoder_ids")
1595
+ and self.generation_config.forced_decoder_ids is not None
1596
+ ):
1597
+ forced_decoder_ids = self.generation_config.forced_decoder_ids
1598
+ else:
1599
+ forced_decoder_ids = kwargs.get("forced_decoder_ids", None)
1600
+
1601
+ if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None):
1602
+ forced_decoder_ids = []
1603
+ if hasattr(generation_config, "language"):
1604
+ if generation_config.language in generation_config.lang_to_id.keys():
1605
+ language_token = generation_config.language
1606
+ elif generation_config.language in TO_LANGUAGE_CODE.keys():
1607
+ language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
1608
+ elif generation_config.language in TO_LANGUAGE_CODE.values():
1609
+ language_token = f"<|{generation_config.language}|>"
1610
+ else:
1611
+ is_language_code = len(generation_config.language) == 2
1612
+ raise ValueError(
1613
+ f"Unsupported language: {generation_config.language}. Language should be one of:"
1614
+ f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
1615
+ )
1616
+ if language_token not in generation_config.lang_to_id:
1617
+ raise ValueError(
1618
+ f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`. "
1619
+ "(You should just add it to the generation config)"
1620
+ )
1621
+ forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
1622
+ else:
1623
+ forced_decoder_ids.append((1, None)) # automatically detect the language
1624
+
1625
+ if hasattr(generation_config, "task"):
1626
+ if generation_config.task in TASK_IDS:
1627
+ forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
1628
+ else:
1629
+ raise ValueError(
1630
+ f"The `{generation_config.task}` task is not supported. The task should be one of `{TASK_IDS}`"
1631
+ )
1632
+ elif hasattr(generation_config, "task_to_id"):
1633
+ forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe
1634
+ if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
1635
+ idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
1636
+ forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
1637
+
1638
+ if forced_decoder_ids is not None:
1639
+ generation_config.forced_decoder_ids = forced_decoder_ids
1640
+
1641
+ if prompt_ids is not None:
1642
+ if kwargs.get("decoder_start_token_id") is not None:
1643
+ raise ValueError(
1644
+ "When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten."
1645
+ )
1646
+ prompt_ids = prompt_ids.tolist()
1647
+ decoder_start_token_id, *text_prompt_ids = prompt_ids
1648
+ # Slicing the text prompt ids in a manner consistent with the OpenAI implementation
1649
+ # to accommodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
1650
+ text_prompt_ids = text_prompt_ids[-self.config.max_length // 2 - 1 :]
1651
+ # Set the decoder_start_token_id to <|startofprev|>
1652
+ kwargs.update({"decoder_start_token_id": decoder_start_token_id})
1653
+
1654
+ # Update the max generation length to include the prompt
1655
+ specified_max_length = kwargs.pop("max_new_tokens", None) or kwargs.pop("max_length", None)
1656
+ default_max_length = generation_config.max_new_tokens or generation_config.max_length
1657
+ non_prompt_max_length = specified_max_length or default_max_length
1658
+ kwargs["max_new_tokens"] = non_prompt_max_length + len(text_prompt_ids)
1659
+
1660
+ # Reformat the forced_decoder_ids to incorporate the prompt
1661
+ non_prompt_forced_decoder_ids = (
1662
+ kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids
1663
+ )
1664
+ forced_decoder_ids = [
1665
+ *text_prompt_ids,
1666
+ generation_config.decoder_start_token_id,
1667
+ *[token for _rank, token in non_prompt_forced_decoder_ids],
1668
+ ]
1669
+ forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
1670
+ generation_config.forced_decoder_ids = forced_decoder_ids
1671
+
1672
+ # TODO: Implement `WhisperTimeStampLogitsProcessor`.
1673
+ if generation_config.return_timestamps:
1674
+ # logits_processor = [TFWhisperTimeStampLogitsProcessor(generation_config)]
1675
+ raise ValueError("`TFWhisperForConditionalGeneration` doesn't support returning the timestamps yet.")
1676
+
1677
+ if return_token_timestamps:
1678
+ kwargs["output_attentions"] = True
1679
+ kwargs["return_dict_in_generate"] = True
1680
+
1681
+ if getattr(generation_config, "task", None) == "translate":
1682
+ logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
1683
+ if not hasattr(generation_config, "alignment_heads"):
1684
+ raise ValueError(
1685
+ "Model generation config has no `alignment_heads`, token-level timestamps not available. "
1686
+ "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
1687
+ )
1688
+
1689
+ outputs = super().generate(
1690
+ inputs,
1691
+ generation_config,
1692
+ logits_processor,
1693
+ **kwargs,
1694
+ )
1695
+
1696
+ if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
1697
+ outputs["token_timestamps"] = self._extract_token_timestamps(outputs, generation_config.alignment_heads)
1698
+
1699
+ return outputs
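# A minimal sketch (all token ids made up) of how `generate` above folds `prompt_ids` into
# `forced_decoder_ids`: the text-prompt tokens come first, then the regular decoder start
# token, then the language/task entries, all re-ranked starting at position 1.
prompt_ids = [50360, 7000, 7001]              # hypothetical <|startofprev|> + prompt text ids
decoder_start_token_id, *text_prompt_ids = prompt_ids
non_prompt_forced_decoder_ids = [(1, 50259), (2, 50359)]   # hypothetical language + task entries
merged = [
    *text_prompt_ids,
    50258,                                    # stands in for generation_config.decoder_start_token_id
    *[token for _rank, token in non_prompt_forced_decoder_ids],
]
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(merged)]
# -> [(1, 7000), (2, 7001), (3, 50258), (4, 50259), (5, 50359)]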
1700
+
1701
+ def serving_output(self, output):
1702
+ pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
1703
+ dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
1704
+ dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
1705
+ cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
1706
+ enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
1707
+ enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
1708
+
1709
+ return TFSeq2SeqLMOutput(
1710
+ logits=output.logits,
1711
+ past_key_values=pkv,
1712
+ decoder_hidden_states=dec_hs,
1713
+ decoder_attentions=dec_attns,
1714
+ cross_attentions=cross_attns,
1715
+ encoder_last_hidden_state=output.encoder_last_hidden_state,
1716
+ encoder_hidden_states=enc_hs,
1717
+ encoder_attentions=enc_attns,
1718
+ )
1719
+
1720
+ def prepare_inputs_for_generation(
1721
+ self,
1722
+ decoder_input_ids,
1723
+ past_key_values=None,
1724
+ use_cache=None,
1725
+ encoder_outputs=None,
1726
+ attention_mask=None,
1727
+ decoder_attention_mask=None,
1728
+ **kwargs,
1729
+ ):
1730
+ # cut decoder_input_ids if past is used
1731
+ if past_key_values is not None:
1732
+ decoder_input_ids = decoder_input_ids[:, -1:]
1733
+
1734
+ if decoder_attention_mask is not None: # xla
1735
+ decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
1736
+ elif past_key_values is not None: # no xla + past
1737
+ decoder_position_ids = past_key_values[0][0].shape[2]
1738
+ else: # no xla + no past
1739
+ decoder_position_ids = tf.range(decoder_input_ids.shape[1])
1740
+ decoder_position_ids = tf.broadcast_to(decoder_position_ids, decoder_input_ids.shape)
1741
+
1742
+ return {
1743
+ "input_features": None, # Needs to be passed to make Keras.layer.__call__ happy
1744
+ "encoder_outputs": encoder_outputs,
1745
+ "past_key_values": past_key_values,
1746
+ "decoder_input_ids": decoder_input_ids,
1747
+ "use_cache": use_cache,
1748
+ "decoder_attention_mask": decoder_attention_mask,
1749
+ "decoder_position_ids": decoder_position_ids,
1750
+ }
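# A minimal sketch (toy mask) of the XLA branch in `prepare_inputs_for_generation` above:
# decoder positions are the exclusive cumulative sum of the padded decoder attention mask,
# and only the last position is kept for the token currently being generated.
import tensorflow as tf

decoder_attention_mask = tf.constant([[1, 1, 1, 0, 0]])    # 3 real tokens, 2 padding slots
positions = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)   # [[0, 1, 2, 3, 3]]
decoder_position_ids = positions[:, -1:]                   # [[3]], the next position to fill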
1751
+
1752
+ def build(self, input_shape=None):
1753
+ if self.built:
1754
+ return
1755
+ self.built = True
1756
+ if getattr(self, "model", None) is not None:
1757
+ with tf.name_scope(self.model.name):
1758
+ self.model.build(None)
modeling_tf_whisper.cpython-312 (1).pyc ADDED
Binary file (88.7 kB). View file
 
modeling_tf_whisper.cpython-312.pyc ADDED
Binary file (88.7 kB). View file
 
modeling_tf_whisper.py ADDED
@@ -0,0 +1,1758 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """TensorFlow Whisper model."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import math
20
+ import random
21
+ from typing import Dict, List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import tensorflow as tf
25
+
26
+ from ...activations_tf import get_tf_activation
27
+ from ...generation.configuration_utils import GenerationConfig
28
+ from ...generation.tf_logits_process import TFLogitsProcessorList
29
+ from ...modeling_tf_outputs import (
30
+ TFBaseModelOutput,
31
+ TFBaseModelOutputWithPastAndCrossAttentions,
32
+ TFSeq2SeqLMOutput,
33
+ TFSeq2SeqModelOutput,
34
+ )
35
+ from ...modeling_tf_utils import (
36
+ TFCausalLanguageModelingLoss,
37
+ TFModelInputType,
38
+ TFPreTrainedModel,
39
+ keras,
40
+ keras_serializable,
41
+ unpack_inputs,
42
+ )
43
+ from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
44
+ from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
45
+ from .configuration_whisper import WhisperConfig
46
+ from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+ _CONFIG_FOR_DOC = "WhisperConfig"
52
+
53
+
54
+ LARGE_NEGATIVE = -1e8
55
+
56
+
57
+ def sinusoidal_embedding_init(shape, dtype=tf.float32) -> tf.Tensor:
58
+ """Returns sinusoids for positional embedding"""
59
+ length, channels = shape
60
+ if channels % 2 != 0:
61
+ raise ValueError(
62
+ f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
63
+ )
64
+ log_timescale_increment = math.log(10000) / (channels // 2 - 1)
65
+ inv_timescales = tf.exp(-log_timescale_increment * tf.range(channels // 2, dtype=tf.float32))
66
+ scaled_time = tf.reshape(tf.range(length, dtype=tf.float32), (-1, 1)) * tf.reshape(inv_timescales, (1, -1))
67
+ return tf.cast(tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1), dtype)
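# A minimal sketch of the table produced by `sinusoidal_embedding_init` above (shape values
# are illustrative): each row corresponds to one position, with sines in the first half of
# the row and cosines in the second half, so row 0 is all zeros followed by all ones.
emb = sinusoidal_embedding_init((1500, 384))   # uses the function defined above
assert tuple(emb.shape) == (1500, 384)
# emb[0, :192] is 0.0 everywhere (sin 0) and emb[0, 192:] is 1.0 everywhere (cos 0)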
68
+
69
+
70
+ # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
71
+ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
72
+ pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
73
+ decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
74
+ start_tokens = tf.fill(
75
+ (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
76
+ )
77
+ shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
78
+ # replace possible -100 values in labels by `pad_token_id`
79
+ shifted_input_ids = tf.where(
80
+ shifted_input_ids == -100,
81
+ tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
82
+ shifted_input_ids,
83
+ )
84
+
85
+ # "Verify that `labels` has only positive values and -100"
86
+ assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
87
+
88
+ # Make sure the assertion op is called by wrapping the result in an identity no-op
89
+ with tf.control_dependencies([assert_gte0]):
90
+ shifted_input_ids = tf.identity(shifted_input_ids)
91
+
92
+ return shifted_input_ids
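# A minimal sketch (toy labels) of `shift_tokens_right` above: the decoder input is the label
# sequence shifted one step to the right, starting with `decoder_start_token_id`, and any
# -100 ignore-index values are replaced by the pad token. Ids below are made up.
import tensorflow as tf

labels = tf.constant([[10, 11, -100]])
shifted = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=50258)
# -> [[50258, 10, 11]] (the trailing -100 is dropped by the shift itself here)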
93
+
94
+
95
+ # Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
96
+ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
97
+ """
98
+ Make causal mask used for unidirectional (decoder) self-attention.
99
+ """
100
+ bsz = input_ids_shape[0]
101
+ tgt_len = input_ids_shape[1]
102
+ mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
103
+ mask_cond = tf.range(shape_list(mask)[-1])
104
+
105
+ mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
106
+
107
+ if past_key_values_length > 0:
108
+ mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
109
+
110
+ return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
111
+
112
+
113
+ # Copied from transformers.models.bart.modeling_tf_bart._expand_mask
114
+ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
115
+ """
116
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
117
+ """
118
+ src_len = shape_list(mask)[1]
119
+ tgt_len = tgt_len if tgt_len is not None else src_len
120
+ one_cst = tf.constant(1.0)
121
+ mask = tf.cast(mask, dtype=one_cst.dtype)
122
+ expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
123
+
124
+ return (one_cst - expanded_mask) * LARGE_NEGATIVE
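# A minimal sketch (toy mask) of `_expand_mask` above: a [bsz, src_len] padding mask becomes
# an additive [bsz, 1, tgt_len, src_len] bias that is 0.0 over real tokens and LARGE_NEGATIVE
# over padded ones.
import tensorflow as tf

pad_mask = tf.constant([[1.0, 1.0, 0.0]])   # last source position is padding
bias = _expand_mask(pad_mask, tgt_len=2)    # shape (1, 1, 2, 3)
# bias[..., :2] == 0.0 and bias[..., 2] == -1e8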
125
+
126
+
127
+ class TFWhisperPositionalEmbedding(keras.layers.Layer):
128
+ def __init__(
129
+ self,
130
+ num_positions: int,
131
+ embedding_dim: int,
132
+ padding_idx: Optional[int] = None,
133
+ embedding_initializer=None,
134
+ **kwargs,
135
+ ):
136
+ super().__init__(**kwargs)
137
+ self.num_positions = num_positions
138
+ self.embedding_dim = embedding_dim
139
+ self.padding_idx = padding_idx
140
+ self.embedding_initializer = keras.initializers.get(embedding_initializer)
141
+
142
+ def build(self, input_shape):
143
+ self.weight = self.add_weight(
144
+ name="weight",
145
+ shape=[self.num_positions, self.embedding_dim],
146
+ initializer=self.embedding_initializer,
147
+ trainable=True,
148
+ )
149
+ super().build(input_shape)
150
+
151
+ def call(self, input_ids, past_key_values_length=0):
152
+ past_key_values_length = tf.cast(past_key_values_length, tf.int32)
153
+ gather_indices = tf.range(tf.shape(input_ids)[1], delta=1) + past_key_values_length
154
+ return tf.gather(self.weight, gather_indices)
155
+
156
+
157
+ class TFWhisperAttention(keras.layers.Layer):
158
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
159
+
160
+ def __init__(
161
+ self,
162
+ embed_dim: int,
163
+ num_heads: int,
164
+ dropout: float = 0.0,
165
+ is_decoder: bool = False,
166
+ bias: bool = True,
167
+ **kwargs,
168
+ ):
169
+ super().__init__(**kwargs)
170
+ self.embed_dim = embed_dim
171
+ self.num_heads = num_heads
172
+ self.dropout = keras.layers.Dropout(dropout)
173
+ self.head_dim = embed_dim // num_heads
174
+
175
+ if (self.head_dim * num_heads) != self.embed_dim:
176
+ raise ValueError(
177
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
178
+ f" and `num_heads`: {num_heads})."
179
+ )
180
+ self.scaling = self.head_dim**-0.5
181
+ self.is_decoder = is_decoder
182
+
183
+ self.k_proj = keras.layers.Dense(embed_dim, use_bias=False, name="k_proj")
184
+ self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
185
+ self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
186
+ self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
187
+
188
+ # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention._shape with BART->whisper
189
+ def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
190
+ return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
191
+
192
+ # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention.call with BART->whisper
193
+ def call(
194
+ self,
195
+ hidden_states: tf.Tensor,
196
+ key_value_states: tf.Tensor | None = None,
197
+ past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
198
+ attention_mask: tf.Tensor | None = None,
199
+ layer_head_mask: tf.Tensor | None = None,
200
+ training: Optional[bool] = False,
201
+ ) -> Tuple[tf.Tensor, tf.Tensor | None]:
202
+ """Input shape: Batch x Time x Channel"""
203
+
204
+ # if key_value_states are provided this layer is used as a cross-attention layer
205
+ # for the decoder
206
+ is_cross_attention = key_value_states is not None
207
+ bsz, tgt_len, embed_dim = shape_list(hidden_states)
208
+
209
+ # get query proj
210
+ query_states = self.q_proj(hidden_states) * self.scaling
211
+ # get key, value proj
212
+ if is_cross_attention and past_key_value is not None:
213
+ # reuse k,v, cross_attentions
214
+ key_states = past_key_value[0]
215
+ value_states = past_key_value[1]
216
+ elif is_cross_attention:
217
+ # cross_attentions
218
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
219
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
220
+ elif past_key_value is not None:
221
+ # reuse k, v, self_attention
222
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
223
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
224
+ key_states = tf.concat([past_key_value[0], key_states], axis=2)
225
+ value_states = tf.concat([past_key_value[1], value_states], axis=2)
226
+ else:
227
+ # self_attention
228
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
229
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
230
+
231
+ if self.is_decoder:
232
+ # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
233
+ # Further calls to cross_attention layer can then reuse all cross-attention
234
+ # key/value_states (first "if" case)
235
+ # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
236
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
237
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
238
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
239
+ past_key_value = (key_states, value_states)
240
+
241
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
242
+ query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
243
+ key_states = tf.reshape(key_states, proj_shape)
244
+ value_states = tf.reshape(value_states, proj_shape)
245
+
246
+ src_len = shape_list(key_states)[1]
247
+ attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
248
+
249
+ tf.debugging.assert_equal(
250
+ shape_list(attn_weights),
251
+ [bsz * self.num_heads, tgt_len, src_len],
252
+ message=(
253
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
254
+ f" {shape_list(attn_weights)}"
255
+ ),
256
+ )
257
+
258
+ if attention_mask is not None:
259
+ tf.debugging.assert_equal(
260
+ shape_list(attention_mask),
261
+ [bsz, 1, tgt_len, src_len],
262
+ message=(
263
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
264
+ f" {shape_list(attention_mask)}"
265
+ ),
266
+ )
267
+
268
+ attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
269
+ attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
270
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
271
+
272
+ attn_weights = stable_softmax(attn_weights, axis=-1)
273
+
274
+ if layer_head_mask is not None:
275
+ tf.debugging.assert_equal(
276
+ shape_list(layer_head_mask),
277
+ [self.num_heads],
278
+ message=(
279
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
280
+ f" {shape_list(layer_head_mask)}"
281
+ ),
282
+ )
283
+
284
+ attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
285
+ attn_weights, (bsz, self.num_heads, tgt_len, src_len)
286
+ )
287
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
288
+
289
+ attn_probs = self.dropout(attn_weights, training=training)
290
+ attn_output = tf.matmul(attn_probs, value_states)
291
+
292
+ tf.debugging.assert_equal(
293
+ shape_list(attn_output),
294
+ [bsz * self.num_heads, tgt_len, self.head_dim],
295
+ message=(
296
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
297
+ f" {shape_list(attn_output)}"
298
+ ),
299
+ )
300
+
301
+ attn_output = tf.transpose(
302
+ tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
303
+ )
304
+ attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))
305
+
306
+ attn_output = self.out_proj(attn_output)
307
+ attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
308
+
309
+ return attn_output, attn_weights, past_key_value
310
+
311
+ def build(self, input_shape=None):
312
+ if self.built:
313
+ return
314
+ self.built = True
315
+ if getattr(self, "k_proj", None) is not None:
316
+ with tf.name_scope(self.k_proj.name):
317
+ self.k_proj.build([None, None, self.embed_dim])
318
+ if getattr(self, "v_proj", None) is not None:
319
+ with tf.name_scope(self.v_proj.name):
320
+ self.v_proj.build([None, None, self.embed_dim])
321
+ if getattr(self, "q_proj", None) is not None:
322
+ with tf.name_scope(self.q_proj.name):
323
+ self.q_proj.build([None, None, self.embed_dim])
324
+ if getattr(self, "out_proj", None) is not None:
325
+ with tf.name_scope(self.out_proj.name):
326
+ self.out_proj.build([None, None, self.embed_dim])
327
+
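+ # Minimal sketch (not part of the original file): the head split/merge round-trip
+ # used by TFWhisperAttention. `_example_split_and_merge_heads` is a hypothetical
+ # helper added only for illustration; it assumes embed_dim is divisible by num_heads
+ # and relies on `tf` and `shape_list`, which this module already imports.
+ def _example_split_and_merge_heads(hidden_states: tf.Tensor, num_heads: int) -> tf.Tensor:
+     bsz, seq_len, embed_dim = shape_list(hidden_states)
+     head_dim = embed_dim // num_heads
+     # split: (bsz, seq_len, embed_dim) -> (bsz * num_heads, seq_len, head_dim)
+     split = tf.reshape(
+         tf.transpose(tf.reshape(hidden_states, (bsz, seq_len, num_heads, head_dim)), (0, 2, 1, 3)),
+         (bsz * num_heads, seq_len, head_dim),
+     )
+     # merge: the inverse transpose/reshape restores (bsz, seq_len, embed_dim)
+     merged = tf.reshape(
+         tf.transpose(tf.reshape(split, (bsz, num_heads, seq_len, head_dim)), (0, 2, 1, 3)),
+         (bsz, seq_len, embed_dim),
+     )
+     return merged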
328
+
329
+ # Copied from transformers.models.speech_to_text.modeling_tf_speech_to_text.TFSpeech2TextEncoderLayer with Speech2Text->Whisper
330
+ class TFWhisperEncoderLayer(keras.layers.Layer):
331
+ def __init__(self, config: WhisperConfig, **kwargs):
332
+ super().__init__(**kwargs)
333
+ self.embed_dim = config.d_model
334
+ self.self_attn = TFWhisperAttention(
335
+ self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
336
+ )
337
+ self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
338
+ self.dropout = keras.layers.Dropout(config.dropout)
339
+ self.activation_fn = get_tf_activation(config.activation_function)
340
+ self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
341
+ self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
342
+ self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
343
+ self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
344
+ self.config = config
345
+
346
+ def call(
347
+ self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training: bool = False
348
+ ):
349
+ """
350
+ Args:
351
+ hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
352
+ attention_mask (`tf.Tensor`): attention mask of size
353
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
354
+ layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
355
+ `(encoder_attention_heads,)`
356
+ """
357
+ residual = hidden_states
358
+ hidden_states = self.self_attn_layer_norm(hidden_states)
359
+ hidden_states, self_attn_weights, _ = self.self_attn(
360
+ hidden_states=hidden_states,
361
+ attention_mask=attention_mask,
362
+ layer_head_mask=layer_head_mask,
363
+ training=training,
364
+ )
365
+
366
+ tf.debugging.assert_equal(
367
+ shape_list(hidden_states),
368
+ shape_list(residual),
369
+ message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
370
+ )
371
+
372
+ hidden_states = self.dropout(hidden_states, training=training)
373
+ hidden_states = residual + hidden_states
374
+
375
+ residual = hidden_states
376
+ hidden_states = self.final_layer_norm(hidden_states)
377
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
378
+ hidden_states = self.activation_dropout(hidden_states, training=training)
379
+ hidden_states = self.fc2(hidden_states)
380
+ hidden_states = self.dropout(hidden_states, training=training)
381
+ hidden_states = residual + hidden_states
382
+
383
+ return hidden_states, self_attn_weights
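+ # Illustrative note (not part of the original file): this is a pre-LayerNorm block,
+ # i.e. layer norm is applied before self-attention and before the feed-forward
+ # network, and the residual is added to the un-normalized input.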
384
+
385
+ def build(self, input_shape=None):
386
+ if self.built:
387
+ return
388
+ self.built = True
389
+ if getattr(self, "self_attn", None) is not None:
390
+ with tf.name_scope(self.self_attn.name):
391
+ self.self_attn.build(None)
392
+ if getattr(self, "self_attn_layer_norm", None) is not None:
393
+ with tf.name_scope(self.self_attn_layer_norm.name):
394
+ self.self_attn_layer_norm.build([None, None, self.embed_dim])
395
+ if getattr(self, "fc1", None) is not None:
396
+ with tf.name_scope(self.fc1.name):
397
+ self.fc1.build([None, None, self.embed_dim])
398
+ if getattr(self, "fc2", None) is not None:
399
+ with tf.name_scope(self.fc2.name):
400
+ self.fc2.build([None, None, self.config.encoder_ffn_dim])
401
+ if getattr(self, "final_layer_norm", None) is not None:
402
+ with tf.name_scope(self.final_layer_norm.name):
403
+ self.final_layer_norm.build([None, None, self.embed_dim])
404
+
405
+
406
+ # Copied from transformers.models.speech_to_text.modeling_tf_speech_to_text.TFSpeech2TextDecoderLayer with Speech2Text->Whisper
407
+ class TFWhisperDecoderLayer(keras.layers.Layer):
408
+ def __init__(self, config: WhisperConfig, **kwargs):
409
+ super().__init__(**kwargs)
410
+ self.embed_dim = config.d_model
411
+
412
+ self.self_attn = TFWhisperAttention(
413
+ embed_dim=self.embed_dim,
414
+ num_heads=config.decoder_attention_heads,
415
+ dropout=config.attention_dropout,
416
+ name="self_attn",
417
+ is_decoder=True,
418
+ )
419
+ self.dropout = keras.layers.Dropout(config.dropout)
420
+ self.activation_fn = get_tf_activation(config.activation_function)
421
+ self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
422
+
423
+ self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
424
+ self.encoder_attn = TFWhisperAttention(
425
+ self.embed_dim,
426
+ config.decoder_attention_heads,
427
+ dropout=config.attention_dropout,
428
+ name="encoder_attn",
429
+ is_decoder=True,
430
+ )
431
+ self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
432
+ self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
433
+ self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
434
+ self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
435
+ self.config = config
436
+
437
+ def call(
438
+ self,
439
+ hidden_states,
440
+ attention_mask: tf.Tensor | None = None,
441
+ encoder_hidden_states: tf.Tensor | None = None,
442
+ encoder_attention_mask: tf.Tensor | None = None,
443
+ layer_head_mask: tf.Tensor | None = None,
444
+ cross_attn_layer_head_mask: tf.Tensor | None = None,
445
+ past_key_value: Tuple[tf.Tensor] | None = None,
446
+ training=False,
447
+ ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
448
+ """
449
+ Args:
450
+ hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
451
+ attention_mask (`tf.Tensor`): attention mask of size
452
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
453
+ encoder_hidden_states (`tf.Tensor`):
454
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
455
+ encoder_attention_mask (`tf.Tensor`): encoder attention mask of size
456
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
457
+ layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
458
+ `(decoder_attention_heads,)`
459
+ cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module.
460
+ `(decoder_attention_heads,)`
461
+ past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states
462
+ """
463
+ residual = hidden_states
464
+ hidden_states = self.self_attn_layer_norm(hidden_states)
465
+
466
+ # Self Attention
467
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
468
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
469
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
470
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
471
+ hidden_states=hidden_states,
472
+ past_key_value=self_attn_past_key_value,
473
+ attention_mask=attention_mask,
474
+ layer_head_mask=layer_head_mask,
475
+ training=training,
476
+ )
477
+ hidden_states = self.dropout(hidden_states, training=training)
478
+ hidden_states = residual + hidden_states
479
+
480
+ # Cross-Attention Block
481
+ cross_attn_present_key_value = None
482
+ cross_attn_weights = None
483
+ if encoder_hidden_states is not None:
484
+ residual = hidden_states
485
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
486
+
487
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
488
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
489
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
490
+ hidden_states=hidden_states,
491
+ key_value_states=encoder_hidden_states,
492
+ attention_mask=encoder_attention_mask,
493
+ layer_head_mask=cross_attn_layer_head_mask,
494
+ past_key_value=cross_attn_past_key_value,
495
+ training=training,
496
+ )
497
+ hidden_states = self.dropout(hidden_states, training=training)
498
+ hidden_states = residual + hidden_states
499
+
500
+ # add cross-attn to positions 3,4 of present_key_value tuple
501
+ present_key_value = present_key_value + cross_attn_present_key_value
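+ # Illustrative note (not part of the original file): present_key_value is now the
+ # 4-tuple (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value)
+ # referred to by the "positions 1,2" / "positions 3,4" comments above.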
502
+
503
+ # Fully Connected
504
+ residual = hidden_states
505
+ hidden_states = self.final_layer_norm(hidden_states)
506
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
507
+ hidden_states = self.activation_dropout(hidden_states, training=training)
508
+ hidden_states = self.fc2(hidden_states)
509
+ hidden_states = self.dropout(hidden_states, training=training)
510
+ hidden_states = residual + hidden_states
511
+
512
+ return (
513
+ hidden_states,
514
+ self_attn_weights,
515
+ cross_attn_weights,
516
+ present_key_value,
517
+ )
518
+
519
+ def build(self, input_shape=None):
520
+ if self.built:
521
+ return
522
+ self.built = True
523
+ if getattr(self, "self_attn", None) is not None:
524
+ with tf.name_scope(self.self_attn.name):
525
+ self.self_attn.build(None)
526
+ if getattr(self, "self_attn_layer_norm", None) is not None:
527
+ with tf.name_scope(self.self_attn_layer_norm.name):
528
+ self.self_attn_layer_norm.build([None, None, self.embed_dim])
529
+ if getattr(self, "encoder_attn", None) is not None:
530
+ with tf.name_scope(self.encoder_attn.name):
531
+ self.encoder_attn.build(None)
532
+ if getattr(self, "encoder_attn_layer_norm", None) is not None:
533
+ with tf.name_scope(self.encoder_attn_layer_norm.name):
534
+ self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
535
+ if getattr(self, "fc1", None) is not None:
536
+ with tf.name_scope(self.fc1.name):
537
+ self.fc1.build([None, None, self.embed_dim])
538
+ if getattr(self, "fc2", None) is not None:
539
+ with tf.name_scope(self.fc2.name):
540
+ self.fc2.build([None, None, self.config.decoder_ffn_dim])
541
+ if getattr(self, "final_layer_norm", None) is not None:
542
+ with tf.name_scope(self.final_layer_norm.name):
543
+ self.final_layer_norm.build([None, None, self.embed_dim])
544
+
545
+
546
+ class TFWhisperPreTrainedModel(TFPreTrainedModel):
547
+ config_class = WhisperConfig
548
+ base_model_prefix = "model"
549
+ main_input_name = "input_features"
550
+
551
+ def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor) -> int:
552
+ """
553
+ Computes the output length of the convolutional layers
554
+ """
555
+ input_lengths = (input_lengths - 1) // 2 + 1
556
+
557
+ return input_lengths
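+ # Worked example (not part of the original file): for the standard 30-second input
+ # of 3000 mel frames, (3000 - 1) // 2 + 1 = 1500, i.e. config.max_source_positions.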
558
+
559
+ @property
560
+ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
561
+ """
562
+ Dummy inputs to build the network.
563
+
564
+ Returns:
565
+ `Dict[str, tf.Tensor]`: The dummy inputs.
566
+ """
567
+ return {
568
+ self.main_input_name: tf.random.uniform(
569
+ [1, self.config.num_mel_bins, self.config.max_source_positions * 2 - 1], dtype=tf.float32
570
+ ),
571
+ "decoder_input_ids": tf.constant([[1, 3]], dtype=tf.int32),
572
+ }
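+ # Illustrative note (not part of the original file): max_source_positions * 2 - 1
+ # (2999 frames for the released configs) appears to be chosen so that the conv stack
+ # downsamples the dummy input to exactly max_source_positions (1500) positions.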
573
+
574
+ @property
575
+ def input_signature(self):
576
+ return {
577
+ "input_features": tf.TensorSpec((None, self.config.num_mel_bins, None), tf.float32, name="input_features"),
578
+ "decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
579
+ "decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
580
+ }
581
+
582
+
583
+ WHISPER_START_DOCSTRING = r"""
584
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
585
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
586
+ etc.)
587
+
588
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
589
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
590
+ behavior.
591
+
592
+ Parameters:
593
+ config ([`WhisperConfig`]):
594
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
595
+ load the weights associated with the model, only the configuration. Check out the
596
+ [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
597
+ """
598
+
599
+ WHISPER_INPUTS_DOCSTRING = r"""
600
+ Args:
601
+ input_features (`tf.Tensor` of shape `(batch_size, feature_size, sequence_length)`):
602
+ Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
603
+ by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.*
604
+ via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
605
+ [`AutoFeatureExtractor`] should be used for extracting the fbank features, padding and conversion into a
606
+ tensor of type `tf.Tensor`. See [`~WhisperFeatureExtractor.__call__`]
607
+ decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
608
+ Indices of decoder input sequence tokens in the vocabulary.
609
+
610
+ Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
611
+ [`PreTrainedTokenizer.__call__`] for details.
612
+
613
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
614
+
615
+ Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. If
616
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
617
+ `past_key_values`).
618
+ decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
619
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
620
+ be used by default.
621
+
622
+ If you want to change padding behavior, you should read
623
+ [`modeling_whisper._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the
624
+ paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
625
+ head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
626
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
627
+
628
+ - 1 indicates the head is **not masked**,
629
+ - 0 indicates the head is **masked**.
630
+
631
+ decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
632
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
633
+
634
+ - 1 indicates the head is **not masked**,
635
+ - 0 indicates the head is **masked**.
636
+
637
+ cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
638
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
639
+
640
+ - 1 indicates the head is **not masked**,
641
+ - 0 indicates the head is **masked**.
642
+
643
+ encoder_outputs (`tuple(tuple(tf.Tensor))`, *optional*):
644
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
645
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
646
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
647
+ past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
648
+ Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
649
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
650
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
651
+
652
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
653
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
654
+
655
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
656
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
657
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
658
+ decoder_inputs_embeds (`tf.Tensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
659
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
660
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
661
+ input (see `past_key_values`). This is useful if you want more control over how to convert
662
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
663
+ use_cache (`bool`, *optional*):
664
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
665
+ `past_key_values`).
666
+ output_attentions (`bool`, *optional*):
667
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
668
+ tensors for more detail.
669
+ output_hidden_states (`bool`, *optional*):
670
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
671
+ more detail.
672
+ return_dict (`bool`, *optional*):
673
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
674
+ """
675
+
676
+
677
+ @keras_serializable
678
+ class TFWhisperEncoder(keras.layers.Layer):
679
+ config_class = WhisperConfig
680
+ """
681
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
682
+ [`TFWhisperEncoderLayer`].
683
+
684
+ Args:
685
+ config: WhisperConfig
686
+ embed_tokens (TFWhisperEmbedding): output embedding
687
+ """
688
+
689
+ def __init__(self, config: WhisperConfig, **kwargs):
690
+ super().__init__(**kwargs)
691
+ self.config = config
692
+ self.layerdrop = config.encoder_layerdrop
693
+
694
+ self.embed_dim = config.d_model
695
+ self.num_mel_bins = config.num_mel_bins
696
+ self.padding_idx = config.pad_token_id
697
+ self.max_source_positions = config.max_source_positions
698
+ self.embed_scale = math.sqrt(self.embed_dim) if config.scale_embedding else 1.0
699
+
700
+ # Padding is added in call() to match the PyTorch implementation
701
+ self.conv1 = keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=1, padding="valid", name="conv1")
702
+ self.conv2 = keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=2, padding="valid", name="conv2")
703
+
704
+ self.embed_positions = TFWhisperPositionalEmbedding(
705
+ num_positions=self.max_source_positions,
706
+ embedding_dim=self.embed_dim,
707
+ embedding_initializer=sinusoidal_embedding_init,
708
+ name="embed_positions",
709
+ )
710
+ self.embed_positions.trainable = False
711
+
712
+ self.encoder_layers = [TFWhisperEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
713
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
714
+
715
+ self.dropout = keras.layers.Dropout(config.dropout)
716
+
717
+ @unpack_inputs
718
+ def call(
719
+ self,
720
+ input_features=None,
721
+ head_mask=None,
722
+ output_attentions=None,
723
+ output_hidden_states=None,
724
+ return_dict=None,
725
+ training=False,
726
+ ):
727
+ r"""
728
+ Args:
729
+ input_features (`tf.Tensor` of shape `(batch_size, feature_size, sequence_length)`):
730
+ Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be
731
+ obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
732
+ `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
733
+ `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the fbank features,
734
+ padding and conversion into a tensor of type `tf.Tensor`. See [`~WhisperFeatureExtractor.__call__`]
735
+ head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
736
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
737
+
738
+ - 1 indicates the head is **not masked**,
739
+ - 0 indicates the head is **masked**.
740
+
741
+ output_attentions (`bool`, *optional*):
742
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
743
+ returned tensors for more detail.
744
+ output_hidden_states (`bool`, *optional*):
745
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
746
+ for more detail.
747
+ return_dict (`bool`, *optional*):
748
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
749
+ """
750
+
751
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
752
+ output_hidden_states = (
753
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
754
+ )
755
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
756
+
757
+ # TF 2.0 layers can't use channels first format when running on CPU.
758
+ input_features = tf.transpose(input_features, perm=(0, 2, 1))
759
+ input_features = tf.pad(input_features, [[0, 0], [1, 1], [0, 0]])
760
+ inputs_embeds = keras.activations.gelu(self.conv1(input_features))
761
+ inputs_embeds = tf.pad(inputs_embeds, [[0, 0], [1, 1], [0, 0]])
762
+ inputs_embeds = keras.activations.gelu(self.conv2(inputs_embeds))
763
+ inputs_embeds = tf.transpose(inputs_embeds, perm=(0, 1, 2))
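+ # Illustrative note (not part of the original file): perm=(0, 1, 2) above is an
+ # identity transpose (the conv stack already yields (batch, time, channels));
+ # conv2's stride of 2 halves the time dimension, e.g. 3000 frames -> 1500 positions.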
764
+
765
+ embed_pos = self.embed_positions(input_ids=tf.zeros((1, self.max_source_positions), dtype=tf.int32))
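+ # Illustrative note (not part of the original file): the encoder positional
+ # embeddings are fixed sinusoidal tables (sinusoidal_embedding_init, trainable=False),
+ # so the same embed_pos is added regardless of the input content.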
766
+
767
+ hidden_states = inputs_embeds + embed_pos
768
+ hidden_states = self.dropout(hidden_states, training=training)
769
+
770
+ encoder_states = () if output_hidden_states else None
771
+ all_attentions = () if output_attentions else None
772
+
773
+ # check if head_mask has a correct number of layers specified if desired
774
+ if head_mask is not None:
775
+ tf.debugging.assert_equal(
776
+ shape_list(head_mask)[0],
777
+ len(self.encoder_layers),
778
+ message=(
779
+ f"The head_mask should be specified for {len(self.encoder_layers)} layers, but it is for"
780
+ f" {shape_list(head_mask)[0]}."
781
+ ),
782
+ )
783
+
784
+ for idx, encoder_layer in enumerate(self.encoder_layers):
785
+ if output_hidden_states:
786
+ encoder_states = encoder_states + (hidden_states,)
787
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
788
+ dropout_probability = random.uniform(0, 1)
789
+ if training and (dropout_probability < self.layerdrop): # skip the layer
790
+ continue
791
+
792
+ hidden_states, attn = encoder_layer(
793
+ hidden_states,
794
+ None,
795
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
796
+ training=training,
797
+ )
798
+
799
+ if output_attentions:
800
+ all_attentions += (attn,)
801
+
802
+ hidden_states = self.layer_norm(hidden_states)
803
+ if output_hidden_states:
804
+ encoder_states = encoder_states + (hidden_states,)
805
+
806
+ if not return_dict:
807
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
808
+ return TFBaseModelOutput(
809
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
810
+ )
811
+
812
+ def build(self, input_shape=None):
813
+ if self.built:
814
+ return
815
+ self.built = True
816
+ if getattr(self, "conv1", None) is not None:
817
+ with tf.name_scope(self.conv1.name):
818
+ self.conv1.build([None, None, self.num_mel_bins])
819
+ if getattr(self, "conv2", None) is not None:
820
+ with tf.name_scope(self.conv2.name):
821
+ self.conv2.build([None, None, self.embed_dim])
822
+ if getattr(self, "embed_positions", None) is not None:
823
+ with tf.name_scope(self.embed_positions.name):
824
+ self.embed_positions.build(None)
825
+ if getattr(self, "layer_norm", None) is not None:
826
+ with tf.name_scope(self.layer_norm.name):
827
+ self.layer_norm.build([None, None, self.config.d_model])
828
+ if getattr(self, "encoder_layers", None) is not None:
829
+ for layer in self.encoder_layers:
830
+ with tf.name_scope(layer.name):
831
+ layer.build(None)
832
+
833
+
834
+ @keras_serializable
835
+ class TFWhisperDecoder(keras.layers.Layer):
836
+ config_class = WhisperConfig
837
+ """
838
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFWhisperDecoderLayer`]
839
+
840
+ Args:
841
+ config: WhisperConfig
842
+ """
843
+
844
+ def __init__(self, config: WhisperConfig, **kwargs):
845
+ super().__init__(**kwargs)
846
+ self.config = config
847
+ self.dropout = keras.layers.Dropout(config.dropout)
848
+ self.layerdrop = config.decoder_layerdrop
849
+ self.padding_idx = config.pad_token_id
850
+ self.max_target_positions = config.max_target_positions
851
+ self.max_source_positions = config.max_source_positions
852
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
853
+
854
+ self.embed_tokens = keras.layers.Embedding(
855
+ input_dim=config.vocab_size,
856
+ output_dim=config.d_model,
857
+ embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std),
858
+ name="embed_tokens",
859
+ )
860
+ self.embed_positions = TFWhisperPositionalEmbedding(
861
+ self.max_target_positions, config.d_model, name="embed_positions"
862
+ )
863
+
864
+ self.decoder_layers = [TFWhisperDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
865
+
866
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
867
+
868
+ def get_input_embeddings(self):
869
+ return self.embed_tokens
870
+
871
+ def set_input_embeddings(self, value):
872
+ self.embed_tokens = value
873
+
874
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length):
875
+ # create causal mask
876
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
877
+ batch_size, seq_len = input_shape[0], input_shape[1]
878
+
879
+ combined_attention_mask = tf.cond(
880
+ tf.math.greater(seq_len, 1),
881
+ lambda: _make_causal_mask(input_shape, past_key_values_length=past_key_values_length),
882
+ lambda: _expand_mask(tf.ones((batch_size, seq_len + past_key_values_length)), tgt_len=seq_len),
883
+ )
884
+
885
+ if attention_mask is not None:
886
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
887
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1])
888
+ combined_attention_mask = (
889
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
890
+ )
891
+ return combined_attention_mask
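+ # Illustrative note (not part of the original file): the returned mask is additive;
+ # allowed positions hold 0 and disallowed positions hold a very large negative value,
+ # so adding it to the attention logits before the softmax effectively zeroes them out.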
892
+
893
+ @unpack_inputs
894
+ def call(
895
+ self,
896
+ input_ids=None,
897
+ attention_mask=None,
898
+ position_ids=None,
899
+ encoder_hidden_states=None,
900
+ head_mask=None,
901
+ cross_attn_head_mask=None,
902
+ past_key_values=None,
903
+ inputs_embeds=None,
904
+ use_cache=None,
905
+ output_attentions=None,
906
+ output_hidden_states=None,
907
+ return_dict=None,
908
+ training=False,
909
+ ):
910
+ r"""
911
+ Args:
912
+ input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
913
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
914
+ provide it.
915
+
916
+ Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
917
+ [`PreTrainedTokenizer.__call__`] for details.
918
+
919
+ [What are input IDs?](../glossary#input-ids)
920
+ attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
921
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
922
+
923
+ - 1 for tokens that are **not masked**,
924
+ - 0 for tokens that are **masked**.
925
+
926
+ [What are attention masks?](../glossary#attention-mask)
927
+ position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
928
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
929
+ range `[0, config.max_position_embeddings - 1]`.
930
+ encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
931
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
932
+ of the decoder.
933
+ head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
934
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
935
+
936
+ - 1 indicates the head is **not masked**,
937
+ - 0 indicates the head is **masked**.
938
+
939
+ cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
940
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+ cross-attention on particular heads. Mask values selected in `[0, 1]`:
942
+
943
+ - 1 indicates the head is **not masked**,
944
+ - 0 indicates the head is **masked**.
945
+
946
+ past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
947
+ Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
948
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
949
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
950
+
951
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
952
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
953
+
954
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
955
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
956
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
957
+ inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
958
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
959
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
960
+ than the model's internal embedding lookup matrix.
961
+ output_attentions (`bool`, *optional*):
962
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
963
+ returned tensors for more detail.
964
+ output_hidden_states (`bool`, *optional*):
965
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
966
+ for more detail.
967
+ return_dict (`bool`, *optional*):
968
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
969
+ """
970
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
971
+ output_hidden_states = (
972
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
973
+ )
974
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
975
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
976
+
977
+ # retrieve input_ids and inputs_embeds
978
+ if input_ids is not None and inputs_embeds is not None:
979
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
980
+ elif input_ids is not None:
981
+ input_shape = tf.shape(input_ids)
982
+ input_ids = tf.reshape(input_ids, (-1, input_shape[-1]))
983
+ elif inputs_embeds is not None:
984
+ input_shape = tf.shape(inputs_embeds)[:-1]
985
+ else:
986
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
987
+
988
+ # past_key_values_length
989
+ past_key_values_length = tf.shape(past_key_values[0][0])[2] if past_key_values is not None else 0
990
+
991
+ if inputs_embeds is None:
992
+ check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
993
+ inputs_embeds = self.embed_tokens(input_ids)
994
+
995
+ attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length)
996
+
997
+ # embed positions
998
+ filled_past_positions = past_key_values_length if position_ids is None else position_ids[0, -1]
999
+ positions = self.embed_positions(input_ids, past_key_values_length=filled_past_positions)
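+ # Illustrative note (not part of the original file): during cached generation only
+ # the newest token is passed in, and filled_past_positions offsets the lookup so that
+ # token still receives the positional embedding for its absolute position.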
1000
+
1001
+ hidden_states = inputs_embeds + positions
1002
+ hidden_states = self.dropout(hidden_states, training=training)
1003
+
1004
+ # decoder layers
1005
+ all_hidden_states = () if output_hidden_states else None
1006
+ all_self_attns = () if output_attentions else None
1007
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1008
+ next_decoder_cache = () if use_cache else None
1009
+
1010
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1011
+ for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
1012
+ if attn_mask is not None:
1013
+ tf.debugging.assert_equal(
1014
+ shape_list(attn_mask)[0],
1015
+ len(self.decoder_layers),
1016
+ message=(
1017
+ f"The {attn_mask_name} should be specified for {len(self.decoder_layers)} layers, but it is"
1018
+ f" for {shape_list(attn_mask)[0]}."
1019
+ ),
1020
+ )
1021
+
1022
+ for idx, decoder_layer in enumerate(self.decoder_layers):
1023
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1024
+ if output_hidden_states:
1025
+ all_hidden_states += (hidden_states,)
1026
+ dropout_probability = random.uniform(0, 1)
1027
+ if training and (dropout_probability < self.layerdrop):
1028
+ continue
1029
+
1030
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
1031
+
1032
+ layer_outputs = decoder_layer(
1033
+ hidden_states,
1034
+ attention_mask=attention_mask,
1035
+ encoder_hidden_states=encoder_hidden_states,
1036
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
1037
+ cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
1038
+ past_key_value=past_key_value,
1039
+ training=training,
1040
+ )
1041
+ hidden_states = layer_outputs[0]
1042
+
1043
+ if use_cache:
1044
+ next_decoder_cache += (layer_outputs[3],)
1045
+
1046
+ if output_attentions:
1047
+ all_self_attns += (layer_outputs[1],)
1048
+
1049
+ if encoder_hidden_states is not None:
1050
+ all_cross_attentions += (layer_outputs[2],)
1051
+
1052
+ hidden_states = self.layer_norm(hidden_states)
1053
+ # add hidden states from the last decoder layer
1054
+ if output_hidden_states:
1055
+ all_hidden_states += (hidden_states,)
1056
+
1057
+ next_cache = next_decoder_cache if use_cache else None
1058
+ if not return_dict:
1059
+ return tuple(
1060
+ v
1061
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
1062
+ if v is not None
1063
+ )
1064
+ return TFBaseModelOutputWithPastAndCrossAttentions(
1065
+ last_hidden_state=hidden_states,
1066
+ past_key_values=next_cache,
1067
+ hidden_states=all_hidden_states,
1068
+ attentions=all_self_attns,
1069
+ cross_attentions=all_cross_attentions,
1070
+ )
1071
+
1072
+ def build(self, input_shape=None):
1073
+ if self.built:
1074
+ return
1075
+ self.built = True
1076
+ if getattr(self, "embed_tokens", None) is not None:
1077
+ with tf.name_scope(self.embed_tokens.name):
1078
+ self.embed_tokens.build(None)
1079
+ if getattr(self, "embed_positions", None) is not None:
1080
+ with tf.name_scope(self.embed_positions.name):
1081
+ self.embed_positions.build(None)
1082
+ if getattr(self, "layer_norm", None) is not None:
1083
+ with tf.name_scope(self.layer_norm.name):
1084
+ self.layer_norm.build([None, None, self.config.d_model])
1085
+ if getattr(self, "decoder_layers", None) is not None:
1086
+ for layer in self.decoder_layers:
1087
+ with tf.name_scope(layer.name):
1088
+ layer.build(None)
1089
+
1090
+
1091
+ @add_start_docstrings(
1092
+ "The bare Whisper Model outputting raw hidden-states without any specific head on top.",
1093
+ WHISPER_START_DOCSTRING,
1094
+ )
1095
+ @keras_serializable
1096
+ class TFWhisperMainLayer(keras.layers.Layer):
1097
+ config_class = WhisperConfig
1098
+
1099
+ def __init__(self, config: WhisperConfig, **kwargs):
1100
+ super().__init__(**kwargs)
1101
+ self.config = config
1102
+ self.encoder = TFWhisperEncoder(config, name="encoder")
1103
+ self.decoder = TFWhisperDecoder(config, name="decoder")
1104
+
1105
+ def get_input_embeddings(self):
1106
+ return self.decoder.embed_tokens
1107
+
1108
+ def set_input_embeddings(self, value):
1109
+ self.decoder.embed_tokens = value
1110
+
1111
+ def get_encoder(self):
1112
+ return self.encoder
1113
+
1114
+ def get_decoder(self):
1115
+ return self.decoder
1116
+
1117
+ @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
1118
+ @replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
1119
+ @unpack_inputs
1120
+ def call(
1121
+ self,
1122
+ input_features=None,
1123
+ decoder_input_ids=None,
1124
+ decoder_attention_mask=None,
1125
+ decoder_position_ids=None,
1126
+ head_mask=None,
1127
+ decoder_head_mask=None,
1128
+ cross_attn_head_mask=None,
1129
+ encoder_outputs=None,
1130
+ past_key_values=None,
1131
+ decoder_inputs_embeds=None,
1132
+ use_cache=None,
1133
+ output_attentions=None,
1134
+ output_hidden_states=None,
1135
+ return_dict=None,
1136
+ training=False,
1137
+ ):
1138
+ r"""
1139
+ Returns:
1140
+
1141
+ Example:
1142
+
1143
+ ```python
1144
+ >>> import tensorflow as tf
1145
+ >>> from transformers import TFWhisperModel, AutoFeatureExtractor
1146
+ >>> from datasets import load_dataset
1147
+
1148
+ >>> model = TFWhisperModel.from_pretrained("openai/whisper-base")
1149
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
1150
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1151
+ >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="tf")
1152
+ >>> input_features = inputs.input_features
1153
+ >>> decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id
1154
+ >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
1155
+ >>> list(last_hidden_state.shape)
1156
+ [1, 2, 512]
1157
+ ```"""
1158
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1159
+ output_hidden_states = (
1160
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1161
+ )
1162
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1163
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1164
+
1165
+ if encoder_outputs is None:
1166
+ encoder_outputs = self.encoder(
1167
+ input_features,
1168
+ head_mask=head_mask,
1169
+ output_attentions=output_attentions,
1170
+ output_hidden_states=output_hidden_states,
1171
+ return_dict=return_dict,
1172
+ training=training,
1173
+ )
1174
+ # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
1175
+ elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
1176
+ encoder_outputs = TFBaseModelOutput(
1177
+ last_hidden_state=encoder_outputs[0],
1178
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1179
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1180
+ )
1181
+
1182
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
1183
+ decoder_outputs = self.decoder(
1184
+ input_ids=decoder_input_ids,
1185
+ attention_mask=decoder_attention_mask,
1186
+ position_ids=decoder_position_ids,
1187
+ encoder_hidden_states=encoder_outputs[0],
1188
+ head_mask=decoder_head_mask,
1189
+ cross_attn_head_mask=cross_attn_head_mask,
1190
+ past_key_values=past_key_values,
1191
+ inputs_embeds=decoder_inputs_embeds,
1192
+ use_cache=use_cache,
1193
+ output_attentions=output_attentions,
1194
+ output_hidden_states=output_hidden_states,
1195
+ return_dict=return_dict,
1196
+ training=training,
1197
+ )
1198
+
1199
+ if not return_dict:
1200
+ return decoder_outputs + encoder_outputs
1201
+
1202
+ return TFSeq2SeqModelOutput(
1203
+ last_hidden_state=decoder_outputs.last_hidden_state,
1204
+ past_key_values=decoder_outputs.past_key_values,
1205
+ decoder_hidden_states=decoder_outputs.hidden_states,
1206
+ decoder_attentions=decoder_outputs.attentions,
1207
+ cross_attentions=decoder_outputs.cross_attentions,
1208
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1209
+ encoder_hidden_states=encoder_outputs.hidden_states,
1210
+ encoder_attentions=encoder_outputs.attentions,
1211
+ )
1212
+
1213
+ def build(self, input_shape=None):
1214
+ if self.built:
1215
+ return
1216
+ self.built = True
1217
+ if getattr(self, "encoder", None) is not None:
1218
+ with tf.name_scope(self.encoder.name):
1219
+ self.encoder.build(None)
1220
+ if getattr(self, "decoder", None) is not None:
1221
+ with tf.name_scope(self.decoder.name):
1222
+ self.decoder.build(None)
1223
+
1224
+
1225
+ @add_start_docstrings(
1226
+ "The bare Whisper Model outputting raw hidden-states without any specific head on top.",
1227
+ WHISPER_START_DOCSTRING,
1228
+ )
1229
+ class TFWhisperModel(TFWhisperPreTrainedModel):
1230
+ def __init__(self, config: WhisperConfig, **kwargs):
1231
+ super().__init__(config, **kwargs)
1232
+
1233
+ self.model = TFWhisperMainLayer(config, name="model")
1234
+
1235
+ def get_input_embeddings(self):
1236
+ return self.model.decoder.embed_tokens
1237
+
1238
+ def set_input_embeddings(self, value):
1239
+ self.model.decoder.embed_tokens = value
1240
+
1241
+ def get_encoder(self):
1242
+ return self.model.encoder
1243
+
1244
+ def get_decoder(self):
1245
+ return self.model.decoder
1246
+
1247
+ def decoder(self):
1248
+ return self.model.decoder
1249
+
1250
+ def encoder(self):
1251
+ return self.model.encoder
1252
+
1253
+ @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
1254
+ @replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
1255
+ @unpack_inputs
1256
+ def call(
1257
+ self,
1258
+ input_features: TFModelInputType | None = None,
1259
+ decoder_input_ids: np.ndarray | tf.Tensor | None = None,
1260
+ decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
1261
+ decoder_position_ids: np.ndarray | tf.Tensor | None = None,
1262
+ head_mask: np.ndarray | tf.Tensor | None = None,
1263
+ decoder_head_mask: np.ndarray | tf.Tensor | None = None,
1264
+ cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
1265
+ encoder_outputs: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
1266
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
1267
+ decoder_inputs_embeds: Optional[Tuple[Union[np.ndarray, tf.Tensor]]] = None,
1268
+ use_cache: Optional[bool] = None,
1269
+ output_attentions: Optional[bool] = None,
1270
+ output_hidden_states: Optional[bool] = None,
1271
+ return_dict: Optional[bool] = None,
1272
+ training: bool = False,
1273
+ ) -> Union[Tuple[tf.Tensor], TFSeq2SeqModelOutput]:
1274
+ r"""
1275
+ Returns:
1276
+
1277
+ Example:
1278
+
1279
+ ```python
1280
+ >>> import tensorflow as tf
1281
+ >>> from transformers import TFWhisperModel, AutoFeatureExtractor
1282
+ >>> from datasets import load_dataset
1283
+
1284
+ >>> model = TFWhisperModel.from_pretrained("openai/whisper-base")
1285
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
1286
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1287
+ >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="tf")
1288
+ >>> input_features = inputs.input_features
1289
+ >>> decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id
1290
+ >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
1291
+ >>> list(last_hidden_state.shape)
1292
+ [1, 2, 512]
1293
+ ```"""
1294
+ outputs = self.model(
1295
+ input_features=input_features,
1296
+ decoder_input_ids=decoder_input_ids,
1297
+ decoder_attention_mask=decoder_attention_mask,
1298
+ decoder_position_ids=decoder_position_ids,
1299
+ head_mask=head_mask,
1300
+ decoder_head_mask=decoder_head_mask,
1301
+ cross_attn_head_mask=cross_attn_head_mask,
1302
+ encoder_outputs=encoder_outputs,
1303
+ past_key_values=past_key_values,
1304
+ decoder_inputs_embeds=decoder_inputs_embeds,
1305
+ use_cache=use_cache,
1306
+ output_attentions=output_attentions,
1307
+ output_hidden_states=output_hidden_states,
1308
+ return_dict=return_dict,
1309
+ training=training,
1310
+ )
1311
+ return outputs
1312
+
1313
+ def serving_output(self, output):
1314
+ pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
1315
+ dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
1316
+ dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
1317
+ cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
1318
+ enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
1319
+ enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
1320
+
1321
+ return TFSeq2SeqModelOutput(
1322
+ last_hidden_state=output.last_hidden_state,
1323
+ past_key_values=pkv,
1324
+ decoder_hidden_states=dec_hs,
1325
+ decoder_attentions=dec_attns,
1326
+ cross_attentions=cross_attns,
1327
+ encoder_last_hidden_state=output.encoder_last_hidden_state,
1328
+ encoder_hidden_states=enc_hs,
1329
+ encoder_attentions=enc_attns,
1330
+ )
1331
+
1332
+ def build(self, input_shape=None):
1333
+ if self.built:
1334
+ return
1335
+ self.built = True
1336
+ if getattr(self, "model", None) is not None:
1337
+ with tf.name_scope(self.model.name):
1338
+ self.model.build(None)
1339
+
1340
+
1341
+ @add_start_docstrings(
1342
+ "The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
1343
+ WHISPER_START_DOCSTRING,
1344
+ )
1345
+ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLanguageModelingLoss):
1346
+ base_model_prefix = "model"
1347
+ _keys_to_ignore_on_load_missing = [
1348
+ r"encoder.version",
1349
+ r"decoder.version",
1350
+ r"proj_out.weight",
1351
+ ]
1352
+ _keys_to_ignore_on_save = [
1353
+ r"proj_out.weight",
1354
+ ]
1355
+
1356
+ def __init__(self, config: WhisperConfig, **kwargs):
1357
+ super().__init__(config, **kwargs)
1358
+ self.model = TFWhisperMainLayer(config, name="model")
1359
+
1360
+ def get_encoder(self):
1361
+ return self.model.get_encoder()
1362
+
1363
+ def get_decoder(self):
1364
+ return self.model.get_decoder()
1365
+
1366
+ def get_output_embeddings(self):
1367
+ return self.get_input_embeddings()
1368
+
1369
+ def set_output_embeddings(self, value):
1370
+ self.set_input_embeddings(value)
1371
+
1372
+ def resize_token_embeddings(self, new_num_tokens: int) -> keras.layers.Embedding:
1373
+ new_embeddings = super().resize_token_embeddings(new_num_tokens)
1374
+ return new_embeddings
1375
+
1376
+ @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
1377
+ @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
1378
+ @unpack_inputs
1379
+ def call(
1380
+ self,
1381
+ input_features: TFModelInputType | None = None,
1382
+ decoder_input_ids: np.ndarray | tf.Tensor | None = None,
1383
+ decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
1384
+ decoder_position_ids: np.ndarray | tf.Tensor | None = None,
1385
+ head_mask: np.ndarray | tf.Tensor | None = None,
1386
+ decoder_head_mask: np.ndarray | tf.Tensor | None = None,
1387
+ cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
1388
+ encoder_outputs: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
1389
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
1390
+ decoder_inputs_embeds: Optional[Tuple[Union[np.ndarray, tf.Tensor]]] = None,
1391
+ labels: np.ndarray | tf.Tensor | None = None,
1392
+ use_cache: Optional[bool] = None,
1393
+ output_attentions: Optional[bool] = None,
1394
+ output_hidden_states: Optional[bool] = None,
1395
+ return_dict: Optional[bool] = None,
1396
+ training: bool = False,
1397
+ ) -> Union[Tuple[tf.Tensor], TFSeq2SeqLMOutput]:
1398
+ r"""
1399
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1400
+ Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
1401
+ or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
1402
+ only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1403
+
1404
+ Returns:
1405
+
1406
+ Example:
1407
+
1408
+ ```python
1409
+ >>> import tensorflow as tf
1410
+ >>> from transformers import AutoProcessor, TFWhisperForConditionalGeneration
1411
+ >>> from datasets import load_dataset
1412
+
1413
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
1414
+ >>> model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
1415
+
1416
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1417
+
1418
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="tf")
1419
+ >>> input_features = inputs.input_features
1420
+
1421
+ >>> generated_ids = model.generate(input_features=input_features)
1422
+
1423
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
1424
+ >>> transcription
1425
+ ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
1426
+ ```"""
1427
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1428
+
1429
+ if labels is not None:
1430
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
1431
+ decoder_input_ids = shift_tokens_right(
1432
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
1433
+ )
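+ # Worked example (not part of the original file): shift_tokens_right maps labels
+ # [[a, b, c]] to decoder_input_ids [[decoder_start_token_id, a, b]], with any -100
+ # label positions replaced by pad_token_id.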
1434
+
1435
+ outputs = self.model(
1436
+ input_features,
1437
+ decoder_input_ids=decoder_input_ids,
1438
+ encoder_outputs=encoder_outputs,
1439
+ decoder_attention_mask=decoder_attention_mask,
1440
+ decoder_position_ids=decoder_position_ids,
1441
+ head_mask=head_mask,
1442
+ decoder_head_mask=decoder_head_mask,
1443
+ cross_attn_head_mask=cross_attn_head_mask,
1444
+ past_key_values=past_key_values,
1445
+ decoder_inputs_embeds=decoder_inputs_embeds,
1446
+ use_cache=use_cache,
1447
+ output_attentions=output_attentions,
1448
+ output_hidden_states=output_hidden_states,
1449
+ return_dict=return_dict,
1450
+ training=training,
1451
+ )
1452
+ decoder_last_hidden_state = outputs[0]
1453
+ # The output projection shares weights with the decoder's input token embeddings
1454
+ lm_logits = tf.matmul(decoder_last_hidden_state, self.get_output_embeddings().weights, transpose_b=True)
1455
+
1456
+ loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
1457
+
1458
+ if not return_dict:
1459
+ output = (lm_logits,) + outputs[1:]
1460
+ return ((loss,) + output) if loss is not None else output
1461
+
1462
+ return TFSeq2SeqLMOutput(
1463
+ loss=loss,
1464
+ logits=lm_logits,
1465
+ past_key_values=outputs.past_key_values,
1466
+ decoder_hidden_states=outputs.decoder_hidden_states,
1467
+ decoder_attentions=outputs.decoder_attentions,
1468
+ cross_attentions=outputs.cross_attentions,
1469
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1470
+ encoder_hidden_states=outputs.encoder_hidden_states,
1471
+ encoder_attentions=outputs.encoder_attentions,
1472
+ )
1473
+
1474
+ def generate(
1475
+ self,
1476
+ inputs: Optional[tf.Tensor] = None,
1477
+ generation_config: Optional[GenerationConfig] = None,
1478
+ logits_processor: Optional[TFLogitsProcessorList] = None,
1479
+ seed: Optional[List[int]] = None,
1480
+ return_timestamps: Optional[bool] = None,
1481
+ task: Optional[str] = None,
1482
+ language: Optional[str] = None,
1483
+ is_multilingual: Optional[bool] = None,
1484
+ prompt_ids: Optional[tf.Tensor] = None,
1485
+ return_token_timestamps=None,
1486
+ **kwargs,
1487
+ ):
1488
+ r"""
1489
+ Generates sequences of token ids for models with a language modeling head.
1490
+
1491
+ <Tip warning={true}>
1492
+
1493
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
1494
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
1495
+ parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
1496
+
1497
+ For an overview of generation strategies and code examples, check out the [following
1498
+ guide](../generation_strategies).
1499
+
1500
+ </Tip>
1501
+
1502
+ Parameters:
1503
+ inputs (`tf.Tensor` of varying shape depending on the modality, *optional*):
1504
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If unset the method
1505
+ initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should be in
1506
+ the format of `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`,
1507
+ `input_values`, `input_features`, or `pixel_values`.
1508
+ generation_config (`~generation.GenerationConfig`, *optional*):
1509
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
1510
+ passed to generate matching the attributes of `generation_config` will override them. If
1511
+ `generation_config` is not provided, the default will be used, which has the following loading
1512
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
1513
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
1514
+ default values, whose documentation should be checked to parameterize generation.
1515
+ logits_processor (`LogitsProcessorList`, *optional*):
1516
+ Custom logits processors that complement the default logits processors built from arguments and
1517
+ generation config. If a logits processor is passed that is already created with the arguments or a
1518
+ generation config, an error is thrown. This feature is intended for advanced users.
1519
+ seed (`List[int]`, *optional*):
1520
+ Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
1521
+ `seed` argument from stateless functions in `tf.random`.
1522
+ return_timestamps (`bool`, *optional*):
1523
+ Whether to return the timestamps with the text. This enables the `TFWhisperTimestampsLogitsProcessor`.
1524
+ task (`str`, *optional*):
1525
+ Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
1526
+ will be updated accordingly.
1527
+ language (`str`, *optional*):
1528
+ Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can
1529
+ find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
1530
+ is_multilingual (`bool`, *optional*):
1531
+ Whether or not the model is multilingual.
1532
+ prompt_ids (`tf.Tensor`, *optional*):
1533
+ Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
1534
+ provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
1535
+ transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
1536
+ correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
1537
+ return_token_timestamps (`bool`, *optional*):
1538
+ Whether to return token-level timestamps with the text. This can be used with or without the
1539
+ `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
1540
+ words.
1541
+ kwargs (`Dict[str, Any]`, *optional*):
1542
+ Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
1543
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
1544
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
1545
+
1546
+ Return:
1547
+ [`~utils.ModelOutput`] or `tf.Tensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when
1548
+ `config.return_dict_in_generate=True`) or a `tf.Tensor`.
1549
+
1550
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
1551
+ [`~utils.ModelOutput`] types are:
1552
+
1553
+ - [`~generation.TFGreedySearchDecoderOnlyOutput`],
1554
+ - [`~generation.TFSampleDecoderOnlyOutput`],
1555
+ - [`~generation.TFBeamSearchDecoderOnlyOutput`],
1556
+ - [`~generation.TFBeamSampleDecoderOnlyOutput`]
1557
+
1558
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
1559
+ [`~utils.ModelOutput`] types are:
1560
+
1561
+ - [`~generation.TFGreedySearchEncoderDecoderOutput`],
1562
+ - [`~generation.TFSampleEncoderDecoderOutput`],
1563
+ - [`~generation.TFBeamSearchEncoderDecoderOutput`],
1564
+ - [`~generation.TFBeamSampleEncoderDecoderOutput`]
1565
+
1566
+ """
1567
+ if generation_config is None:
1568
+ generation_config = self.generation_config
1569
+
1570
+ if return_timestamps is not None:
1571
+ if not hasattr(generation_config, "no_timestamps_token_id"):
1572
+ raise ValueError(
1573
+ "You are trying to return timestamps, but the generation config is not properly set. "
1574
+ "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
1575
+ "For more details on how to generate the appropriate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
1576
+ )
1577
+
1578
+ generation_config.return_timestamps = return_timestamps
1579
+ else:
1580
+ generation_config.return_timestamps = False
1581
+
1582
+ if language is not None:
1583
+ language = language.lower()
1584
+ generation_config.language = language
1585
+ if task is not None:
1586
+ generation_config.task = task
1587
+
1588
+ forced_decoder_ids = None
1589
+
1590
+ # Legacy code for backward compatibility
1591
+ if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
1592
+ forced_decoder_ids = self.config.forced_decoder_ids
1593
+ elif (
1594
+ hasattr(self.generation_config, "forced_decoder_ids")
1595
+ and self.generation_config.forced_decoder_ids is not None
1596
+ ):
1597
+ forced_decoder_ids = self.generation_config.forced_decoder_ids
1598
+ else:
1599
+ forced_decoder_ids = kwargs.get("forced_decoder_ids", None)
1600
+
1601
+ if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None):
1602
+ forced_decoder_ids = []
1603
+ if hasattr(generation_config, "language"):
1604
+ if generation_config.language in generation_config.lang_to_id.keys():
1605
+ language_token = generation_config.language
1606
+ elif generation_config.language in TO_LANGUAGE_CODE.keys():
1607
+ language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
1608
+ elif generation_config.language in TO_LANGUAGE_CODE.values():
1609
+ language_token = f"<|{generation_config.language}|>"
1610
+ else:
1611
+ is_language_code = len(generation_config.language) == 2
1612
+ raise ValueError(
1613
+ f"Unsupported language: {generation_config.language}. Language should be one of:"
1614
+ f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
1615
+ )
1616
+ if language_token not in generation_config.lang_to_id:
1617
+ raise ValueError(
1618
+ f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
1619
+ " (You should just add it to the generation config.)"
1620
+ )
1621
+ forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
1622
+ else:
1623
+ forced_decoder_ids.append((1, None)) # automatically detect the language
1624
+
1625
+ if hasattr(generation_config, "task"):
1626
+ if generation_config.task in TASK_IDS:
1627
+ forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
1628
+ else:
1629
+ raise ValueError(
1630
+ f"The `{generation_config.task}` task is not supported. The task should be one of `{TASK_IDS}`."
1631
+ )
1632
+ elif hasattr(generation_config, "task_to_id"):
1633
+ forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe
1634
+ if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
1635
+ idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
1636
+ forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
1637
+
1638
+ if forced_decoder_ids is not None:
1639
+ generation_config.forced_decoder_ids = forced_decoder_ids
1640
+
1641
+ if prompt_ids is not None:
1642
+ if kwargs.get("decoder_start_token_id") is not None:
1643
+ raise ValueError(
1644
+ "When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten."
1645
+ )
1646
+ prompt_ids = prompt_ids.tolist()
1647
+ decoder_start_token_id, *text_prompt_ids = prompt_ids
1648
+ # Slicing the text prompt ids in a manner consistent with the OpenAI implementation
1649
+ # to accommodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
1650
+ text_prompt_ids = text_prompt_ids[-self.config.max_length // 2 - 1 :]
1651
+ # Set the decoder_start_token_id to <|startofprev|>
1652
+ kwargs.update({"decoder_start_token_id": decoder_start_token_id})
1653
+
1654
+ # Update the max generation length to include the prompt
1655
+ specified_max_length = kwargs.pop("max_new_tokens", None) or kwargs.pop("max_length", None)
1656
+ default_max_length = generation_config.max_new_tokens or generation_config.max_length
1657
+ non_prompt_max_length = specified_max_length or default_max_length
1658
+ kwargs["max_new_tokens"] = non_prompt_max_length + len(text_prompt_ids)
1659
+
1660
+ # Reformat the forced_decoder_ids to incorporate the prompt
1661
+ non_prompt_forced_decoder_ids = (
1662
+ kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids
1663
+ )
1664
+ forced_decoder_ids = [
1665
+ *text_prompt_ids,
1666
+ generation_config.decoder_start_token_id,
1667
+ *[token for _rank, token in non_prompt_forced_decoder_ids],
1668
+ ]
1669
+ forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
1670
+ generation_config.forced_decoder_ids = forced_decoder_ids
1671
+
1672
+ # TODO: Implement `WhisperTimeStampLogitsProcessor`.
1673
+ if generation_config.return_timestamps:
1674
+ # logits_processor = [TFWhisperTimeStampLogitsProcessor(generation_config)]
1675
+ raise ValueError("`TFWhisperForConditionalGeneration` doesn't support returning the timestamps yet.")
1676
+
1677
+ if return_token_timestamps:
1678
+ kwargs["output_attentions"] = True
1679
+ kwargs["return_dict_in_generate"] = True
1680
+
1681
+ if getattr(generation_config, "task", None) == "translate":
1682
+ logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
1683
+ if not hasattr(generation_config, "alignment_heads"):
1684
+ raise ValueError(
1685
+ "Model generation config has no `alignment_heads`, token-level timestamps not available. "
1686
+ "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
1687
+ )
1688
+
1689
+ outputs = super().generate(
1690
+ inputs,
1691
+ generation_config,
1692
+ logits_processor,
1693
+ **kwargs,
1694
+ )
1695
+
1696
+ if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
1697
+ outputs["token_timestamps"] = self._extract_token_timestamps(outputs, generation_config.alignment_heads)
1698
+
1699
+ return outputs
1700
+
1701
+ def serving_output(self, output):
1702
+ pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
1703
+ dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
1704
+ dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
1705
+ cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
1706
+ enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
1707
+ enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
1708
+
1709
+ return TFSeq2SeqLMOutput(
1710
+ logits=output.logits,
1711
+ past_key_values=pkv,
1712
+ decoder_hidden_states=dec_hs,
1713
+ decoder_attentions=dec_attns,
1714
+ cross_attentions=cross_attns,
1715
+ encoder_last_hidden_state=output.encoder_last_hidden_state,
1716
+ encoder_hidden_states=enc_hs,
1717
+ encoder_attentions=enc_attns,
1718
+ )
1719
+
1720
+ def prepare_inputs_for_generation(
1721
+ self,
1722
+ decoder_input_ids,
1723
+ past_key_values=None,
1724
+ use_cache=None,
1725
+ encoder_outputs=None,
1726
+ attention_mask=None,
1727
+ decoder_attention_mask=None,
1728
+ **kwargs,
1729
+ ):
1730
+ # cut decoder_input_ids if past is used
1731
+ if past_key_values is not None:
1732
+ decoder_input_ids = decoder_input_ids[:, -1:]
1733
+
1734
+ if decoder_attention_mask is not None: # xla
1735
+ decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
1736
+ elif past_key_values is not None: # no xla + past
1737
+ decoder_position_ids = past_key_values[0][0].shape[2]
1738
+ else: # no xla + no past
1739
+ decoder_position_ids = tf.range(decoder_input_ids.shape[1])
1740
+ decoder_position_ids = tf.broadcast_to(decoder_position_ids, decoder_input_ids.shape)
1741
+
1742
+ return {
1743
+ "input_features": None, # Needs to be passed to make Keras.layer.__call__ happy
1744
+ "encoder_outputs": encoder_outputs,
1745
+ "past_key_values": past_key_values,
1746
+ "decoder_input_ids": decoder_input_ids,
1747
+ "use_cache": use_cache,
1748
+ "decoder_attention_mask": decoder_attention_mask,
1749
+ "decoder_position_ids": decoder_position_ids,
1750
+ }
1751
+
1752
+ def build(self, input_shape=None):
1753
+ if self.built:
1754
+ return
1755
+ self.built = True
1756
+ if getattr(self, "model", None) is not None:
1757
+ with tf.name_scope(self.model.name):
1758
+ self.model.build(None)
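A minimal usage sketch of the `generate()` API added above (not part of the diff; the `openai/whisper-tiny` checkpoint, the dummy LibriSpeech dataset, and the `transformers`/`datasets` imports are illustrative assumptions):

from transformers import WhisperProcessor, TFWhisperForConditionalGeneration
from datasets import load_dataset

# Hypothetical example: load a small Whisper checkpoint and transcribe one sample
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]
input_features = processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="tf"
).input_features

# `task` and `language` are forwarded to generate(), which builds forced_decoder_ids as shown above
predicted_ids = model.generate(input_features, task="transcribe", language="en")
print(processor.batch_decode(predicted_ids, skip_special_tokens=True))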
modeling_whisper (1).py ADDED
The diff for this file is too large to render. See raw diff
 
modeling_whisper.cpython-312 (1).pyc ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4526e23b9c4686aa6e18e6c4e49c76bb1c06cc8c70ac6c84f5368cf281a5615
3
+ size 105050
modeling_whisper.cpython-312.pyc ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4526e23b9c4686aa6e18e6c4e49c76bb1c06cc8c70ac6c84f5368cf281a5615
3
+ size 105050
modeling_whisper.py ADDED
The diff for this file is too large to render. See raw diff
 
processing_whisper (1).py ADDED
@@ -0,0 +1,97 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Speech processor class for Whisper
17
+ """
18
+
19
+ from ...processing_utils import ProcessorMixin
20
+
21
+
22
+ class WhisperProcessor(ProcessorMixin):
23
+ r"""
24
+ Constructs a Whisper processor which wraps a Whisper feature extractor and a Whisper tokenizer into a single
25
+ processor.
26
+
27
+ [`WhisperProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`] and [`WhisperTokenizer`]. See
28
+ the [`~WhisperProcessor.__call__`] and [`~WhisperProcessor.decode`] for more information.
29
+
30
+ Args:
31
+ feature_extractor (`WhisperFeatureExtractor`):
32
+ An instance of [`WhisperFeatureExtractor`]. The feature extractor is a required input.
33
+ tokenizer (`WhisperTokenizer`):
34
+ An instance of [`WhisperTokenizer`]. The tokenizer is a required input.
35
+ """
36
+
37
+ feature_extractor_class = "WhisperFeatureExtractor"
38
+ tokenizer_class = "WhisperTokenizer"
39
+
40
+ def __init__(self, feature_extractor, tokenizer):
41
+ super().__init__(feature_extractor, tokenizer)
42
+ self.current_processor = self.feature_extractor
43
+ self._in_target_context_manager = False
44
+
45
+ def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
46
+ return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)
47
+
48
+ def __call__(self, *args, **kwargs):
49
+ """
50
+ Forwards the `audio` argument to WhisperFeatureExtractor's [`~WhisperFeatureExtractor.__call__`] and the `text`
51
+ argument to [`~WhisperTokenizer.__call__`]. Please refer to the docstrings of the above two methods for more
52
+ information.
53
+ """
54
+ # For backward compatibility
55
+ if self._in_target_context_manager:
56
+ return self.current_processor(*args, **kwargs)
57
+
58
+ audio = kwargs.pop("audio", None)
59
+ sampling_rate = kwargs.pop("sampling_rate", None)
60
+ text = kwargs.pop("text", None)
61
+ if len(args) > 0:
62
+ audio = args[0]
63
+ args = args[1:]
64
+
65
+ if audio is None and text is None:
66
+ raise ValueError("You need to specify either an `audio` or `text` input to process.")
67
+
68
+ if audio is not None:
69
+ inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
70
+ if text is not None:
71
+ encodings = self.tokenizer(text, **kwargs)
72
+
73
+ if text is None:
74
+ return inputs
75
+
76
+ elif audio is None:
77
+ return encodings
78
+ else:
79
+ inputs["labels"] = encodings["input_ids"]
80
+ return inputs
81
+
82
+ def batch_decode(self, *args, **kwargs):
83
+ """
84
+ This method forwards all its arguments to WhisperTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
85
+ refer to the docstring of this method for more information.
86
+ """
87
+ return self.tokenizer.batch_decode(*args, **kwargs)
88
+
89
+ def decode(self, *args, **kwargs):
90
+ """
91
+ This method forwards all its arguments to WhisperTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
92
+ the docstring of this method for more information.
93
+ """
94
+ return self.tokenizer.decode(*args, **kwargs)
95
+
96
+ def get_prompt_ids(self, text: str, return_tensors="np"):
97
+ return self.tokenizer.get_prompt_ids(text, return_tensors=return_tensors)
processing_whisper.cpython-312 (1).pyc ADDED
Binary file (4.28 kB). View file
 
processing_whisper.cpython-312.pyc ADDED
Binary file (4.28 kB). View file
 
processing_whisper.py ADDED
@@ -0,0 +1,97 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Speech processor class for Whisper
17
+ """
18
+
19
+ from ...processing_utils import ProcessorMixin
20
+
21
+
22
+ class WhisperProcessor(ProcessorMixin):
23
+ r"""
24
+ Constructs a Whisper processor which wraps a Whisper feature extractor and a Whisper tokenizer into a single
25
+ processor.
26
+
27
+ [`WhisperProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`] and [`WhisperTokenizer`]. See
28
+ the [`~WhisperProcessor.__call__`] and [`~WhisperProcessor.decode`] for more information.
29
+
30
+ Args:
31
+ feature_extractor (`WhisperFeatureExtractor`):
32
+ An instance of [`WhisperFeatureExtractor`]. The feature extractor is a required input.
33
+ tokenizer (`WhisperTokenizer`):
34
+ An instance of [`WhisperTokenizer`]. The tokenizer is a required input.
35
+ """
36
+
37
+ feature_extractor_class = "WhisperFeatureExtractor"
38
+ tokenizer_class = "WhisperTokenizer"
39
+
40
+ def __init__(self, feature_extractor, tokenizer):
41
+ super().__init__(feature_extractor, tokenizer)
42
+ self.current_processor = self.feature_extractor
43
+ self._in_target_context_manager = False
44
+
45
+ def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
46
+ return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)
47
+
48
+ def __call__(self, *args, **kwargs):
49
+ """
50
+ Forwards the `audio` argument to WhisperFeatureExtractor's [`~WhisperFeatureExtractor.__call__`] and the `text`
51
+ argument to [`~WhisperTokenizer.__call__`]. Please refer to the docstrings of the above two methods for more
52
+ information.
53
+ """
54
+ # For backward compatibility
55
+ if self._in_target_context_manager:
56
+ return self.current_processor(*args, **kwargs)
57
+
58
+ audio = kwargs.pop("audio", None)
59
+ sampling_rate = kwargs.pop("sampling_rate", None)
60
+ text = kwargs.pop("text", None)
61
+ if len(args) > 0:
62
+ audio = args[0]
63
+ args = args[1:]
64
+
65
+ if audio is None and text is None:
66
+ raise ValueError("You need to specify either an `audio` or `text` input to process.")
67
+
68
+ if audio is not None:
69
+ inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
70
+ if text is not None:
71
+ encodings = self.tokenizer(text, **kwargs)
72
+
73
+ if text is None:
74
+ return inputs
75
+
76
+ elif audio is None:
77
+ return encodings
78
+ else:
79
+ inputs["labels"] = encodings["input_ids"]
80
+ return inputs
81
+
82
+ def batch_decode(self, *args, **kwargs):
83
+ """
84
+ This method forwards all its arguments to WhisperTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
85
+ refer to the docstring of this method for more information.
86
+ """
87
+ return self.tokenizer.batch_decode(*args, **kwargs)
88
+
89
+ def decode(self, *args, **kwargs):
90
+ """
91
+ This method forwards all its arguments to WhisperTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
92
+ the docstring of this method for more information.
93
+ """
94
+ return self.tokenizer.decode(*args, **kwargs)
95
+
96
+ def get_prompt_ids(self, text: str, return_tensors="np"):
97
+ return self.tokenizer.get_prompt_ids(text, return_tensors=return_tensors)
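A short, hedged sketch of how the `WhisperProcessor.__call__` logic above behaves (not part of the diff; the checkpoint name and the dummy audio array are assumptions for illustration):

import numpy as np
from transformers import WhisperProcessor

# Hypothetical checkpoint; any Whisper checkpoint with a processor would do
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

audio = np.zeros(16_000, dtype=np.float32)  # one second of 16 kHz silence as a stand-in

# audio only -> feature-extractor path: returns log-mel `input_features`
features = processor(audio=audio, sampling_rate=16_000, return_tensors="np")

# audio + text -> tokenized text is attached as `labels`, per the branch above
batch = processor(audio=audio, sampling_rate=16_000, text="hello world", return_tensors="np")
print(features.input_features.shape, batch["labels"])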
realtime-whisper-webgpu/.eslintrc.cjs ADDED
@@ -0,0 +1,21 @@
1
+ module.exports = {
2
+ root: true,
3
+ env: { browser: true, es2020: true },
4
+ extends: [
5
+ "eslint:recommended",
6
+ "plugin:react/recommended",
7
+ "plugin:react/jsx-runtime",
8
+ "plugin:react-hooks/recommended",
9
+ ],
10
+ ignorePatterns: ["dist", ".eslintrc.cjs"],
11
+ parserOptions: { ecmaVersion: "latest", sourceType: "module" },
12
+ settings: { react: { version: "18.2" } },
13
+ plugins: ["react-refresh"],
14
+ rules: {
15
+ "react-refresh/only-export-components": [
16
+ "warn",
17
+ { allowConstantExport: true },
18
+ ],
19
+ "react/prop-types": "off",
20
+ },
21
+ };
realtime-whisper-webgpu/.gitignore ADDED
@@ -0,0 +1,24 @@
1
+ # Logs
2
+ logs
3
+ *.log
4
+ npm-debug.log*
5
+ yarn-debug.log*
6
+ yarn-error.log*
7
+ pnpm-debug.log*
8
+ lerna-debug.log*
9
+
10
+ node_modules
11
+ dist
12
+ dist-ssr
13
+ *.local
14
+
15
+ # Editor directories and files
16
+ .vscode/*
17
+ !.vscode/extensions.json
18
+ .idea
19
+ .DS_Store
20
+ *.suo
21
+ *.ntvs*
22
+ *.njsproj
23
+ *.sln
24
+ *.sw?
realtime-whisper-webgpu/README.md ADDED
@@ -0,0 +1,8 @@
1
+ # React + Vite
2
+
3
+ This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
4
+
5
+ Currently, two official plugins are available:
6
+
7
+ - [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/README.md) uses [Babel](https://babeljs.io/) for Fast Refresh
8
+ - [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
realtime-whisper-webgpu/index.html ADDED
@@ -0,0 +1,13 @@
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <link rel="icon" type="image/png" href="/logo.png" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>Whisper WebGPU</title>
8
+ </head>
9
+ <body>
10
+ <div id="root"></div>
11
+ <script type="module" src="/src/main.jsx"></script>
12
+ </body>
13
+ </html>
realtime-whisper-webgpu/package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
realtime-whisper-webgpu/package.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "name": "realtime-whisper-webgpu",
3
+ "private": true,
4
+ "version": "0.0.0",
5
+ "type": "module",
6
+ "scripts": {
7
+ "dev": "vite",
8
+ "build": "vite build",
9
+ "lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0",
10
+ "preview": "vite preview"
11
+ },
12
+ "dependencies": {
13
+ "@huggingface/transformers": "3.0.0",
14
+ "react": "^18.2.0",
15
+ "react-dom": "^18.2.0"
16
+ },
17
+ "devDependencies": {
18
+ "@types/react": "^18.2.43",
19
+ "@types/react-dom": "^18.2.17",
20
+ "@vitejs/plugin-react": "^4.2.1",
21
+ "autoprefixer": "^10.4.19",
22
+ "eslint": "^8.55.0",
23
+ "eslint-plugin-react": "^7.33.2",
24
+ "eslint-plugin-react-hooks": "^4.6.0",
25
+ "eslint-plugin-react-refresh": "^0.4.5",
26
+ "postcss": "^8.4.38",
27
+ "tailwindcss": "^3.4.3",
28
+ "vite": "^5.2.11"
29
+ }
30
+ }
realtime-whisper-webgpu/postcss.config.js ADDED
@@ -0,0 +1,6 @@
1
+ export default {
2
+ plugins: {
3
+ tailwindcss: {},
4
+ autoprefixer: {},
5
+ },
6
+ };
realtime-whisper-webgpu/public/banner.png ADDED

Git LFS Details

  • SHA256: 9cf25d62289d9e499ab0e023bc8586694f5a0a9434bbccc1f2021a32acc1f28f
  • Pointer size: 131 Bytes
  • Size of remote file: 274 kB
realtime-whisper-webgpu/public/logo.png ADDED

Git LFS Details

  • SHA256: 36bf18d7461cc51bdc78e7c322dad09bb781c0525af2c0b56ef97d78a0f2f207
  • Pointer size: 131 Bytes
  • Size of remote file: 321 kB
realtime-whisper-webgpu/public/realtime-whisper-webgpu.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:523ddcfcab7baaca8ee13187a6074e7cbe4422359b1208fa8d8228095c0ee9aa
3
+ size 15302693
realtime-whisper-webgpu/src/App.jsx ADDED
@@ -0,0 +1,321 @@
1
+ import { useEffect, useState, useRef } from "react";
2
+
3
+ import { AudioVisualizer } from "./components/AudioVisualizer";
4
+ import Progress from "./components/Progress";
5
+ import { LanguageSelector } from "./components/LanguageSelector";
6
+
7
+ const IS_WEBGPU_AVAILABLE = !!navigator.gpu;
8
+
9
+ const WHISPER_SAMPLING_RATE = 16_000;
10
+ const MAX_AUDIO_LENGTH = 30; // seconds
11
+ const MAX_SAMPLES = WHISPER_SAMPLING_RATE * MAX_AUDIO_LENGTH;
12
+
13
+ function App() {
14
+ // Create a reference to the worker object.
15
+ const worker = useRef(null);
16
+
17
+ const recorderRef = useRef(null);
18
+
19
+ // Model loading and progress
20
+ const [status, setStatus] = useState(null);
21
+ const [loadingMessage, setLoadingMessage] = useState("");
22
+ const [progressItems, setProgressItems] = useState([]);
23
+
24
+ // Inputs and outputs
25
+ const [text, setText] = useState("");
26
+ const [tps, setTps] = useState(null);
27
+ const [language, setLanguage] = useState("en");
28
+
29
+ // Processing
30
+ const [recording, setRecording] = useState(false);
31
+ const [isProcessing, setIsProcessing] = useState(false);
32
+ const [chunks, setChunks] = useState([]);
33
+ const [stream, setStream] = useState(null);
34
+ const audioContextRef = useRef(null);
35
+
36
+ // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
37
+ useEffect(() => {
38
+ if (!worker.current) {
39
+ // Create the worker if it does not yet exist.
40
+ worker.current = new Worker(new URL("./worker.js", import.meta.url), {
41
+ type: "module",
42
+ });
43
+ }
44
+
45
+ // Create a callback function for messages from the worker thread.
46
+ const onMessageReceived = (e) => {
47
+ switch (e.data.status) {
48
+ case "loading":
49
+ // Model file start load: add a new progress item to the list.
50
+ setStatus("loading");
51
+ setLoadingMessage(e.data.data);
52
+ break;
53
+
54
+ case "initiate":
55
+ setProgressItems((prev) => [...prev, e.data]);
56
+ break;
57
+
58
+ case "progress":
59
+ // Model file progress: update one of the progress items.
60
+ setProgressItems((prev) =>
61
+ prev.map((item) => {
62
+ if (item.file === e.data.file) {
63
+ return { ...item, ...e.data };
64
+ }
65
+ return item;
66
+ }),
67
+ );
68
+ break;
69
+
70
+ case "done":
71
+ // Model file loaded: remove the progress item from the list.
72
+ setProgressItems((prev) =>
73
+ prev.filter((item) => item.file !== e.data.file),
74
+ );
75
+ break;
76
+
77
+ case "ready":
78
+ // Pipeline ready: the worker is ready to accept messages.
79
+ setStatus("ready");
80
+ recorderRef.current?.start();
81
+ break;
82
+
83
+ case "start":
84
+ {
85
+ // Start generation
86
+ setIsProcessing(true);
87
+
88
+ // Request new data from the recorder
89
+ recorderRef.current?.requestData();
90
+ }
91
+ break;
92
+
93
+ case "update":
94
+ {
95
+ // Generation update: update the output text.
96
+ const { tps } = e.data;
97
+ setTps(tps);
98
+ }
99
+ break;
100
+
101
+ case "complete":
102
+ // Generation complete: re-enable the "Generate" button
103
+ setIsProcessing(false);
104
+ setText(e.data.output);
105
+ break;
106
+ }
107
+ };
108
+
109
+ // Attach the callback function as an event listener.
110
+ worker.current.addEventListener("message", onMessageReceived);
111
+
112
+ // Define a cleanup function for when the component is unmounted.
113
+ return () => {
114
+ worker.current.removeEventListener("message", onMessageReceived);
115
+ };
116
+ }, []);
117
+
118
+ useEffect(() => {
119
+ if (recorderRef.current) return; // Already set
120
+
121
+ if (navigator.mediaDevices.getUserMedia) {
122
+ navigator.mediaDevices
123
+ .getUserMedia({ audio: true })
124
+ .then((stream) => {
125
+ setStream(stream);
126
+
127
+ recorderRef.current = new MediaRecorder(stream);
128
+ audioContextRef.current = new AudioContext({
129
+ sampleRate: WHISPER_SAMPLING_RATE,
130
+ });
131
+
132
+ recorderRef.current.onstart = () => {
133
+ setRecording(true);
134
+ setChunks([]);
135
+ };
136
+ recorderRef.current.ondataavailable = (e) => {
137
+ if (e.data.size > 0) {
138
+ setChunks((prev) => [...prev, e.data]);
139
+ } else {
140
+ // Empty chunk received, so we request new data after a short timeout
141
+ setTimeout(() => {
142
+ recorderRef.current.requestData();
143
+ }, 25);
144
+ }
145
+ };
146
+
147
+ recorderRef.current.onstop = () => {
148
+ setRecording(false);
149
+ };
150
+ })
151
+ .catch((err) => console.error("The following error occurred: ", err));
152
+ } else {
153
+ console.error("getUserMedia not supported on your browser!");
154
+ }
155
+
156
+ return () => {
157
+ recorderRef.current?.stop();
158
+ recorderRef.current = null;
159
+ };
160
+ }, []);
161
+
162
+ useEffect(() => {
163
+ if (!recorderRef.current) return;
164
+ if (!recording) return;
165
+ if (isProcessing) return;
166
+ if (status !== "ready") return;
167
+
168
+ if (chunks.length > 0) {
169
+ // Generate from data
170
+ const blob = new Blob(chunks, { type: recorderRef.current.mimeType });
171
+
172
+ const fileReader = new FileReader();
173
+
174
+ fileReader.onloadend = async () => {
175
+ const arrayBuffer = fileReader.result;
176
+ const decoded =
177
+ await audioContextRef.current.decodeAudioData(arrayBuffer);
178
+ let audio = decoded.getChannelData(0);
179
+ if (audio.length > MAX_SAMPLES) {
180
+ // Get last MAX_SAMPLES
181
+ audio = audio.slice(-MAX_SAMPLES);
182
+ }
183
+
184
+ worker.current.postMessage({
185
+ type: "generate",
186
+ data: { audio, language },
187
+ });
188
+ };
189
+ fileReader.readAsArrayBuffer(blob);
190
+ } else {
191
+ recorderRef.current?.requestData();
192
+ }
193
+ }, [status, recording, isProcessing, chunks, language]);
194
+
195
+ return IS_WEBGPU_AVAILABLE ? (
196
+ <div className="flex flex-col h-screen mx-auto justify-end text-gray-800 dark:text-gray-200 bg-white dark:bg-gray-900">
197
+ {
198
+ <div className="h-full overflow-auto scrollbar-thin flex justify-center items-center flex-col relative">
199
+ <div className="flex flex-col items-center mb-1 max-w-[400px] text-center">
200
+ <img
201
+ src="logo.png"
202
+ width="50%"
203
+ height="auto"
204
+ className="block"
205
+ ></img>
206
+ <h1 className="text-4xl font-bold mb-1">Whisper WebGPU</h1>
207
+ <h2 className="text-xl font-semibold">
208
+ Real-time in-browser speech recognition
209
+ </h2>
210
+ </div>
211
+
212
+ <div className="flex flex-col items-center px-4">
213
+ {status === null && (
214
+ <>
215
+ <p className="max-w-[480px] mb-4">
216
+ <br />
217
+ You are about to load{" "}
218
+ <a
219
+ href="https://huggingface.co/onnx-community/whisper-base"
220
+ target="_blank"
221
+ rel="noreferrer"
222
+ className="font-medium underline"
223
+ >
224
+ whisper-base
225
+ </a>
226
+ , a 73 million parameter speech recognition model that is
227
+ optimized for inference on the web. Once downloaded, the model
228
+ (~200&nbsp;MB) will be cached and reused when you revisit the
229
+ page.
230
+ <br />
231
+ <br />
232
+ Everything runs directly in your browser using{" "}
233
+ <a
234
+ href="https://huggingface.co/docs/transformers.js"
235
+ target="_blank"
236
+ rel="noreferrer"
237
+ className="underline"
238
+ >
239
+ 🤗&nbsp;Transformers.js
240
+ </a>{" "}
241
+ and ONNX Runtime Web, meaning no data is sent to a server. You
242
+ can even disconnect from the internet after the model has
243
+ loaded!
244
+ </p>
245
+
246
+ <button
247
+ className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none"
248
+ onClick={() => {
249
+ worker.current.postMessage({ type: "load" });
250
+ setStatus("loading");
251
+ }}
252
+ disabled={status !== null}
253
+ >
254
+ Load model
255
+ </button>
256
+ </>
257
+ )}
258
+
259
+ <div className="w-[500px] p-2">
260
+ <AudioVisualizer className="w-full rounded-lg" stream={stream} />
261
+ {status === "ready" && (
262
+ <div className="relative">
263
+ <p className="w-full h-[80px] overflow-y-auto overflow-wrap-anywhere border rounded-lg p-2">
264
+ {text}
265
+ </p>
266
+ {tps && (
267
+ <span className="absolute bottom-0 right-0 px-1">
268
+ {tps.toFixed(2)} tok/s
269
+ </span>
270
+ )}
271
+ </div>
272
+ )}
273
+ </div>
274
+ {status === "ready" && (
275
+ <div className="relative w-full flex justify-center">
276
+ <LanguageSelector
277
+ language={language}
278
+ setLanguage={(e) => {
279
+ recorderRef.current?.stop();
280
+ setLanguage(e);
281
+ recorderRef.current?.start();
282
+ }}
283
+ />
284
+ <button
285
+ className="border rounded-lg px-2 absolute right-2"
286
+ onClick={() => {
287
+ recorderRef.current?.stop();
288
+ recorderRef.current?.start();
289
+ }}
290
+ >
291
+ Reset
292
+ </button>
293
+ </div>
294
+ )}
295
+ {status === "loading" && (
296
+ <div className="w-full max-w-[500px] text-left mx-auto p-4">
297
+ <p className="text-center">{loadingMessage}</p>
298
+ {progressItems.map(({ file, progress, total }, i) => (
299
+ <Progress
300
+ key={i}
301
+ text={file}
302
+ percentage={progress}
303
+ total={total}
304
+ />
305
+ ))}
306
+ </div>
307
+ )}
308
+ </div>
309
+ </div>
310
+ }
311
+ </div>
312
+ ) : (
313
+ <div className="fixed w-screen h-screen bg-black z-10 bg-opacity-[92%] text-white text-2xl font-semibold flex justify-center items-center text-center">
314
+ WebGPU is not supported
315
+ <br />
316
+ by this browser :&#40;
317
+ </div>
318
+ );
319
+ }
320
+
321
+ export default App;
realtime-whisper-webgpu/src/components/AudioVisualizer.jsx ADDED
@@ -0,0 +1,57 @@
1
+ import { useRef, useCallback, useEffect } from "react";
2
+
3
+ export function AudioVisualizer({ stream, ...props }) {
4
+ const canvasRef = useRef(null);
5
+
6
+ const visualize = useCallback((stream) => {
7
+ const audioContext = new (window.AudioContext ||
8
+ window.webkitAudioContext)();
9
+ const source = audioContext.createMediaStreamSource(stream);
10
+ const analyser = audioContext.createAnalyser();
11
+ analyser.fftSize = 2048;
12
+ source.connect(analyser);
13
+
14
+ const canvas = canvasRef.current;
15
+ const canvasCtx = canvas.getContext("2d");
16
+ const bufferLength = analyser.frequencyBinCount;
17
+ const dataArray = new Uint8Array(bufferLength);
18
+
19
+ const drawVisual = () => {
20
+ requestAnimationFrame(drawVisual);
21
+ analyser.getByteTimeDomainData(dataArray);
22
+
23
+ canvasCtx.fillStyle = "rgb(255, 255, 255)";
24
+ canvasCtx.fillRect(0, 0, canvas.width, canvas.height);
25
+
26
+ canvasCtx.lineWidth = 2;
27
+ canvasCtx.strokeStyle = "rgb(0, 0, 0)";
28
+ canvasCtx.beginPath();
29
+
30
+ const sliceWidth = (canvas.width * 1.0) / bufferLength;
31
+
32
+ let x = 0;
33
+ for (let i = 0; i < bufferLength; ++i) {
34
+ const v = dataArray[i] / 128.0;
35
+ const y = (v * canvas.height) / 2;
36
+
37
+ if (i === 0) {
38
+ canvasCtx.moveTo(x, y);
39
+ } else {
40
+ canvasCtx.lineTo(x, y);
41
+ }
42
+
43
+ x += sliceWidth;
44
+ }
45
+
46
+ canvasCtx.lineTo(canvas.width, canvas.height / 2);
47
+ canvasCtx.stroke();
48
+ };
49
+
50
+ drawVisual();
51
+ }, []);
52
+
53
+ useEffect(() => {
54
+ stream && visualize(stream);
55
+ }, [visualize, stream]);
56
+ return <canvas {...props} width={720} height={240} ref={canvasRef}></canvas>;
57
+ }
realtime-whisper-webgpu/src/components/LanguageSelector.jsx ADDED
@@ -0,0 +1,134 @@
1
+ function titleCase(str) {
2
+ str = str.toLowerCase();
3
+ return (str.match(/\w+.?/g) || [])
4
+ .map((word) => {
5
+ return word.charAt(0).toUpperCase() + word.slice(1);
6
+ })
7
+ .join("");
8
+ }
9
+
10
+ // List of supported languages:
11
+ // https://help.openai.com/en/articles/7031512-whisper-api-faq
12
+ // https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L79
13
+ const LANGUAGES = {
14
+ en: "english",
15
+ zh: "chinese",
16
+ de: "german",
17
+ es: "spanish/castilian",
18
+ ru: "russian",
19
+ ko: "korean",
20
+ fr: "french",
21
+ ja: "japanese",
22
+ pt: "portuguese",
23
+ tr: "turkish",
24
+ pl: "polish",
25
+ ca: "catalan/valencian",
26
+ nl: "dutch/flemish",
27
+ ar: "arabic",
28
+ sv: "swedish",
29
+ it: "italian",
30
+ id: "indonesian",
31
+ hi: "hindi",
32
+ fi: "finnish",
33
+ vi: "vietnamese",
34
+ he: "hebrew",
35
+ uk: "ukrainian",
36
+ el: "greek",
37
+ ms: "malay",
38
+ cs: "czech",
39
+ ro: "romanian/moldavian/moldovan",
40
+ da: "danish",
41
+ hu: "hungarian",
42
+ ta: "tamil",
43
+ no: "norwegian",
44
+ th: "thai",
45
+ ur: "urdu",
46
+ hr: "croatian",
47
+ bg: "bulgarian",
48
+ lt: "lithuanian",
49
+ la: "latin",
50
+ mi: "maori",
51
+ ml: "malayalam",
52
+ cy: "welsh",
53
+ sk: "slovak",
54
+ te: "telugu",
55
+ fa: "persian",
56
+ lv: "latvian",
57
+ bn: "bengali",
58
+ sr: "serbian",
59
+ az: "azerbaijani",
60
+ sl: "slovenian",
61
+ kn: "kannada",
62
+ et: "estonian",
63
+ mk: "macedonian",
64
+ br: "breton",
65
+ eu: "basque",
66
+ is: "icelandic",
67
+ hy: "armenian",
68
+ ne: "nepali",
69
+ mn: "mongolian",
70
+ bs: "bosnian",
71
+ kk: "kazakh",
72
+ sq: "albanian",
73
+ sw: "swahili",
74
+ gl: "galician",
75
+ mr: "marathi",
76
+ pa: "punjabi/panjabi",
77
+ si: "sinhala/sinhalese",
78
+ km: "khmer",
79
+ sn: "shona",
80
+ yo: "yoruba",
81
+ so: "somali",
82
+ af: "afrikaans",
83
+ oc: "occitan",
84
+ ka: "georgian",
85
+ be: "belarusian",
86
+ tg: "tajik",
87
+ sd: "sindhi",
88
+ gu: "gujarati",
89
+ am: "amharic",
90
+ yi: "yiddish",
91
+ lo: "lao",
92
+ uz: "uzbek",
93
+ fo: "faroese",
94
+ ht: "haitian creole/haitian",
95
+ ps: "pashto/pushto",
96
+ tk: "turkmen",
97
+ nn: "nynorsk",
98
+ mt: "maltese",
99
+ sa: "sanskrit",
100
+ lb: "luxembourgish/letzeburgesch",
101
+ my: "myanmar/burmese",
102
+ bo: "tibetan",
103
+ tl: "tagalog",
104
+ mg: "malagasy",
105
+ as: "assamese",
106
+ tt: "tatar",
107
+ haw: "hawaiian",
108
+ ln: "lingala",
109
+ ha: "hausa",
110
+ ba: "bashkir",
111
+ jw: "javanese",
112
+ su: "sundanese",
113
+ };
114
+ export function LanguageSelector({ language, setLanguage }) {
115
+ const handleLanguageChange = (event) => {
116
+ setLanguage(event.target.value);
117
+ };
118
+
119
+ const names = Object.values(LANGUAGES).map(titleCase);
120
+
121
+ return (
122
+ <select
123
+ className="border rounded-lg p-2 max-w-[100px]"
124
+ value={language}
125
+ onChange={handleLanguageChange}
126
+ >
127
+ {Object.keys(LANGUAGES).map((key, i) => (
128
+ <option key={key} value={key}>
129
+ {names[i]}
130
+ </option>
131
+ ))}
132
+ </select>
133
+ );
134
+ }
realtime-whisper-webgpu/src/components/Progress.jsx ADDED
@@ -0,0 +1,22 @@
1
+ function formatBytes(size) {
2
+ const i = size == 0 ? 0 : Math.floor(Math.log(size) / Math.log(1024));
3
+ return (
4
+ +(size / Math.pow(1024, i)).toFixed(2) * 1 +
5
+ ["B", "kB", "MB", "GB", "TB"][i]
6
+ );
7
+ }
8
+
9
+ export default function Progress({ text, percentage, total }) {
10
+ percentage ??= 0;
11
+ return (
12
+ <div className="w-full bg-gray-100 dark:bg-gray-700 text-left rounded-lg overflow-hidden mb-0.5">
13
+ <div
14
+ className="bg-blue-400 whitespace-nowrap px-1 text-sm"
15
+ style={{ width: `${percentage}%` }}
16
+ >
17
+ {text} ({percentage.toFixed(2)}%
18
+ {isNaN(total) ? "" : ` of ${formatBytes(total)}`})
19
+ </div>
20
+ </div>
21
+ );
22
+ }
realtime-whisper-webgpu/src/index.css ADDED
@@ -0,0 +1,32 @@
1
+ @tailwind base;
2
+ @tailwind components;
3
+ @tailwind utilities;
4
+
5
+ @layer utilities {
6
+ .scrollbar-thin::-webkit-scrollbar {
7
+ @apply w-2;
8
+ }
9
+
10
+ .scrollbar-thin::-webkit-scrollbar-track {
11
+ @apply rounded-full bg-gray-100 dark:bg-gray-700;
12
+ }
13
+
14
+ .scrollbar-thin::-webkit-scrollbar-thumb {
15
+ @apply rounded-full bg-gray-300 dark:bg-gray-600;
16
+ }
17
+
18
+ .scrollbar-thin::-webkit-scrollbar-thumb:hover {
19
+ @apply bg-gray-500;
20
+ }
21
+
22
+ .animation-delay-200 {
23
+ animation-delay: 200ms;
24
+ }
25
+ .animation-delay-400 {
26
+ animation-delay: 400ms;
27
+ }
28
+
29
+ .overflow-wrap-anywhere {
30
+ overflow-wrap: anywhere;
31
+ }
32
+ }
realtime-whisper-webgpu/src/main.jsx ADDED
@@ -0,0 +1,10 @@
1
+ import React from "react";
2
+ import ReactDOM from "react-dom/client";
3
+ import App from "./App.jsx";
4
+ import "./index.css";
5
+
6
+ ReactDOM.createRoot(document.getElementById("root")).render(
7
+ <React.StrictMode>
8
+ <App />
9
+ </React.StrictMode>,
10
+ );
realtime-whisper-webgpu/src/worker.js ADDED
@@ -0,0 +1,143 @@
1
+ import {
2
+ AutoTokenizer,
3
+ AutoProcessor,
4
+ WhisperForConditionalGeneration,
5
+ TextStreamer,
6
+ full,
7
+ } from "@huggingface/transformers";
8
+
9
+ const MAX_NEW_TOKENS = 64;
10
+
11
+ /**
12
+ * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
13
+ */
14
+ class AutomaticSpeechRecognitionPipeline {
15
+ static model_id = "onnx-community/whisper-base";
16
+ static tokenizer = null;
17
+ static processor = null;
18
+ static model = null;
19
+
20
+ static async getInstance(progress_callback = null) {
21
+ this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
22
+ progress_callback,
23
+ });
24
+ this.processor ??= AutoProcessor.from_pretrained(this.model_id, {
25
+ progress_callback,
26
+ });
27
+
28
+ this.model ??= WhisperForConditionalGeneration.from_pretrained(
29
+ this.model_id,
30
+ {
31
+ dtype: {
32
+ encoder_model: "fp32", // 'fp16' works too
33
+ decoder_model_merged: "q4", // or 'fp32' ('fp16' is broken)
34
+ },
35
+ device: "webgpu",
36
+ progress_callback,
37
+ },
38
+ );
39
+
40
+ return Promise.all([this.tokenizer, this.processor, this.model]);
41
+ }
42
+ }
43
+
44
+ let processing = false;
45
+ async function generate({ audio, language }) {
46
+ if (processing) return;
47
+ processing = true;
48
+
49
+ // Tell the main thread we are starting
50
+ self.postMessage({ status: "start" });
51
+
52
+ // Retrieve the text-generation pipeline.
53
+ const [tokenizer, processor, model] =
54
+ await AutomaticSpeechRecognitionPipeline.getInstance();
55
+
56
+ let startTime;
57
+ let numTokens = 0;
58
+ let tps;
59
+ const token_callback_function = () => {
60
+ startTime ??= performance.now();
61
+
62
+ if (numTokens++ > 0) {
63
+ tps = (numTokens / (performance.now() - startTime)) * 1000;
64
+ }
65
+ };
66
+ const callback_function = (output) => {
67
+ self.postMessage({
68
+ status: "update",
69
+ output,
70
+ tps,
71
+ numTokens,
72
+ });
73
+ };
74
+
75
+ const streamer = new TextStreamer(tokenizer, {
76
+ skip_prompt: true,
77
+ skip_special_tokens: true,
78
+ callback_function,
79
+ token_callback_function,
80
+ });
81
+
82
+ const inputs = await processor(audio);
83
+
84
+ const outputs = await model.generate({
85
+ ...inputs,
86
+ max_new_tokens: MAX_NEW_TOKENS,
87
+ language,
88
+ streamer,
89
+ });
90
+
91
+ const decoded = tokenizer.batch_decode(outputs, {
92
+ skip_special_tokens: true,
93
+ });
94
+
95
+ // Send the output back to the main thread
96
+ self.postMessage({
97
+ status: "complete",
98
+ output: decoded,
99
+ });
100
+ processing = false;
101
+ }
102
+
103
+ async function load() {
104
+ self.postMessage({
105
+ status: "loading",
106
+ data: "Loading model...",
107
+ });
108
+
109
+ // Load the pipeline and save it for future use.
110
+ const [tokenizer, processor, model] =
111
+ await AutomaticSpeechRecognitionPipeline.getInstance((x) => {
112
+ // We also add a progress callback to the pipeline so that we can
113
+ // track model loading.
114
+ self.postMessage(x);
115
+ });
116
+
117
+ self.postMessage({
118
+ status: "loading",
119
+ data: "Compiling shaders and warming up model...",
120
+ });
121
+
122
+ // Run model with dummy input to compile shaders
123
+ await model.generate({
124
+ input_features: full([1, 80, 3000], 0.0),
125
+ max_new_tokens: 1,
126
+ });
127
+ self.postMessage({ status: "ready" });
128
+ }
129
+
130
+ // Listen for messages from the main thread
131
+ self.addEventListener("message", async (e) => {
132
+ const { type, data } = e.data;
133
+
134
+ switch (type) {
135
+ case "load":
136
+ load();
137
+ break;
138
+
139
+ case "generate":
140
+ generate(data);
141
+ break;
142
+ }
143
+ });
realtime-whisper-webgpu/tailwind.config.js ADDED
@@ -0,0 +1,8 @@
1
+ /** @type {import('tailwindcss').Config} */
2
+ export default {
3
+ content: ["./index.html", "./src/**/*.{js,ts,jsx,tsx}"],
4
+ theme: {
5
+ extend: {},
6
+ },
7
+ plugins: [],
8
+ };