---
language:
- en
tags:
- audio
- automatic-speech-recognition
widget:
- example_title: LibriSpeech sample 1
  src: https://cdn-media.huggingface.co/speech_samples/sample1.flac
- example_title: LibriSpeech sample 2
  src: https://cdn-media.huggingface.co/speech_samples/sample2.flac
pipeline_tag: automatic-speech-recognition
license: mit
library_name: transformers
---

# Distil-Whisper: distil-medium.en

Distil-Whisper was proposed in the paper [Robust Knowledge Distillation via Large-Scale Pseudo Labelling](https://arxiv.org/abs/2311.00430).

It is a distilled version of the Whisper model that is **6 times faster**, 49% smaller, and performs 
**within 1% WER** on out-of-distribution evaluation sets. This is the repository for distil-medium.en, 
a distilled variant of [Whisper medium.en](https://huggingface.co/openai/whisper-medium.en).

| Model                                                                      | Params / M | Rel. Latency | Short-Form WER | Long-Form WER |
|----------------------------------------------------------------------------|------------|--------------|----------------|---------------|
| [large-v2](https://huggingface.co/openai/whisper-large-v2)                 | 1550       | 1.0          | **9.1**        | 11.7          |
|                                                                            |            |              |                |               |
| [distil-large-v2](https://huggingface.co/distil-whisper/distil-large-v2)   | 756        | 5.8          | 10.1           | **11.6**      |
| [distil-medium.en](https://huggingface.co/distil-whisper/distil-medium.en) | **394**    | **6.8**      | 11.1           | 12.4          |

**Note:** Distil-Whisper is currently only available for English speech recognition. Multilingual support will be provided in a follow-up.

## Usage

Distil-Whisper is supported in Hugging Face 🤗 Transformers from version 4.35 onwards. To run the model, first 
install the latest version of the Transformers library. For this example, we'll also install 🤗 Datasets to load a toy 
audio dataset from the Hugging Face Hub:

```bash
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
```

### Short-Form Transcription

The model can be used with the [`pipeline`](https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline)
class to transcribe short-form audio files as follows:

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset


device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-medium.en"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])
```

To transcribe a local audio file, simply pass the path to your audio file when you call the pipeline:
```diff
- result = pipe(sample)
+ result = pipe("audio.mp3")
```

### Long-Form Transcription

Distil-Whisper uses a chunked algorithm to transcribe long-form audio files. In practice, this chunked long-form algorithm 
is 9x faster than the sequential algorithm proposed by OpenAI in the Whisper paper (see Table 7 of the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430)).

To enable chunking, pass the `chunk_length_s` parameter to the `pipeline`. For Distil-Whisper, a chunk length of 15 seconds
is optimal. To activate batching, pass the argument `batch_size`:

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset


device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-medium.en"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset("distil-whisper/librispeech_long", "default", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])
```
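
To make the chunking idea concrete, here is a minimal sketch of how a long recording could be split into overlapping 15-second windows before batching. It is illustrative only and not the pipeline's internal implementation: the helper name `chunk_boundaries` is hypothetical, and the overlap of one sixth of the chunk length on each side is an assumption chosen for the example.

```python
def chunk_boundaries(num_samples, sampling_rate=16_000, chunk_length_s=15.0, stride_length_s=None):
    """Return (start, end) sample indices of overlapping chunks covering the audio."""
    if stride_length_s is None:
        # Assumption for illustration: overlap each side by 1/6 of the chunk length.
        stride_length_s = chunk_length_s / 6

    chunk_len = int(round(chunk_length_s * sampling_rate))
    stride = int(round(stride_length_s * sampling_rate))
    step = chunk_len - 2 * stride  # amount of new audio each chunk contributes
    if step <= 0:
        raise ValueError("stride_length_s must be less than half of chunk_length_s")

    boundaries = []
    start = 0
    while True:
        end = min(start + chunk_len, num_samples)
        boundaries.append((start, end))
        if end == num_samples:
            break
        start += step
    return boundaries


# A 60-second clip sampled at 16 kHz
for start, end in chunk_boundaries(60 * 16_000):
    print(f"{start / 16_000:5.1f}s -> {end / 16_000:5.1f}s")
```

Under these assumptions a 60-second clip yields six chunks. Because the chunks are independent of one another, they can be transcribed together in a batch, which is where the speed-up over sequential decoding comes from.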

<!---
**Tip:** The pipeline can also be used to transcribe an audio file from a remote URL, for example:

```python
result = pipe("https://huggingface.co/datasets/sanchit-gandhi/librispeech_long/resolve/main/audio.wav")
```
--->

### Speculative Decoding

Distil-Whisper can be used as an assistant model to Whisper for speculative decoding. Speculative decoding mathematically
guarantees exactly the same outputs as Whisper while being 2 times faster. This makes it an ideal drop-in 
replacement for existing Whisper pipelines.

In the following code snippet, we load the assistant Distil-Whisper model separately from the main Whisper model. We then
specify it as the "assistant model" for generation:

```python
from transformers import pipeline, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

assistant_model_id = "distil-whisper/distil-medium.en"

assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)

model_id = "openai/whisper-medium.en"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    generate_kwargs={"assistant_model": assistant_model},
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])
```
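
To see why the outputs are unchanged, the following toy sketch implements greedy speculative decoding with two stand-in "models". Everything here is hypothetical and for illustration only: `main_model` and `assistant_model` are plain functions mapping a token prefix to the next token, not Whisper checkpoints, and the number of drafted tokens is an arbitrary choice.

```python
def greedy_decode(model, prompt, max_new_tokens):
    """Plain greedy decoding: one model call per generated token."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(model(tokens))
    return tokens


def speculative_decode(main_model, assistant_model, prompt, max_new_tokens, num_draft_tokens=4):
    """Greedy speculative decoding: the assistant drafts, the main model verifies."""
    tokens = list(prompt)
    target_len = len(prompt) + max_new_tokens
    while len(tokens) < target_len:
        # 1. The (fast) assistant proposes a short continuation.
        draft = []
        for _ in range(min(num_draft_tokens, target_len - len(tokens))):
            draft.append(assistant_model(tokens + draft))

        # 2. The main model checks every drafted position. A real model does this
        #    in a single forward pass; the toy version simply loops.
        for i, drafted in enumerate(draft):
            expected = main_model(tokens + draft[:i])
            if drafted != expected:
                # First mismatch: keep the agreed prefix plus the main model's token.
                tokens.extend(draft[:i])
                tokens.append(expected)
                break
        else:
            # Every drafted token matched the main model.
            tokens.extend(draft)
    return tokens


# Toy stand-ins: the assistant agrees with the main model most of the time.
def main_model(prefix):
    return (7 * prefix[-1] + len(prefix)) % 11


def assistant_model(prefix):
    token = main_model(prefix)
    return token if len(prefix) % 3 else (token + 1) % 11


reference = greedy_decode(main_model, [1], max_new_tokens=20)
speculative = speculative_decode(main_model, assistant_model, [1], max_new_tokens=20)
print(reference == speculative)
```

Every token that is kept is either confirmed or produced by the main model, so the result matches plain greedy decoding with the main model. The speed-up depends on how often the assistant's drafts are accepted.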

## Additional Speed & Memory Improvements

You can apply additional speed and memory improvements to Distil-Whisper, as covered in the following sections.

### Flash Attention

We recommend using [Flash-Attention 2](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#flashattention-2) if your GPU allows for it.
To do so, you first need to install [Flash Attention](https://github.com/Dao-AILab/flash-attention):

```bash
pip install flash-attn --no-build-isolation
```

Then pass `use_flash_attention_2=True` to `from_pretrained`:

```diff
- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=True)
```

### Torch Scaled Dot-Product Attention (SDPA)

If your GPU does not support Flash Attention, we recommend making use of [BetterTransformer](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#bettertransformer).
To do so, you first need to install `optimum`:

```bash
pip install --upgrade optimum
```

And then convert your model to a "BetterTransformer" model before using it:

```diff
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = model.to_bettertransformer()
```

### 8bit & 4bit Quantization

Coming soon ...

### Candle

Coming soon ...

### Whisper.cpp

Coming soon ...

### Running Whisper in `openai/whisper`

Coming soon ...

## Model Details

Distil-Whisper inherits the encoder-decoder architecture from Whisper. The encoder maps a sequence of speech vector 
inputs to a sequence of hidden-state vectors. The decoder auto-regressively predicts text tokens, conditional on all 
previous tokens and the encoder hidden-states. Consequently, the encoder is only run forward once, whereas the decoder 
is run as many times as the number of tokens generated. In practice, this means the decoder accounts for over 90% of 
total inference time. Thus, to optimise for latency, the focus should be on minimising the inference time of the decoder.

To distill the Whisper model, we reduce the number of decoder layers while keeping the encoder fixed. 
The encoder (shown in green) is entirely copied from the teacher to the student and frozen during training. 
The student's decoder consists of only two decoder layers, which are initialised from the first and last decoder layer of 
the teacher (shown in red). All other decoder layers of the teacher are discarded. The model is then trained on a weighted sum 
of the KL divergence and pseudo-label loss terms.

<p align="center">
  <img src="https://huggingface.co/datasets/distil-whisper/figures/resolve/main/architecture.png?raw=true" width="600"/>
</p>
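
As a rough illustration of the training objective described above, the sketch below computes a weighted sum of a KL divergence term (student versus teacher distribution) and a pseudo-label cross-entropy term for a single token position. It is a simplified stand-in written in plain Python, not the training code: the function names are hypothetical and the weights `alpha_kl` and `alpha_pl` are assumed values chosen for the example.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def distillation_loss(student_logits, teacher_logits, pseudo_label, alpha_kl=0.8, alpha_pl=1.0):
    """Weighted sum of KL(teacher || student) and pseudo-label cross-entropy for one token.

    alpha_kl and alpha_pl are illustrative assumptions, not values taken from this card.
    """
    student_logp = log_softmax(student_logits)
    teacher_logp = log_softmax(teacher_logits)

    # KL divergence: pushes the student distribution towards the teacher's.
    kl = sum(math.exp(t) * (t - s) for t, s in zip(teacher_logp, student_logp))

    # Cross-entropy against the token the teacher actually transcribed (the pseudo-label).
    pl = -student_logp[pseudo_label]

    return alpha_kl * kl + alpha_pl * pl


# Toy vocabulary of 3 tokens; the teacher's pseudo-label is token 2.
print(distillation_loss([0.1, 0.2, 1.5], [0.0, 0.1, 3.0], pseudo_label=2))
```

In actual training this quantity would be averaged over all token positions and examples in a batch, with gradients flowing only through the student decoder since the encoder is frozen.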

## Evaluation

The following code snippet demonstrates how to evaluate the Distil-Whisper model on the LibriSpeech validation.clean 
dataset with [streaming mode](https://huggingface.co/blog/audio-datasets#streaming-mode-the-silver-bullet), meaning no 
audio data has to be downloaded to your local device.

First, we need to install the required packages, including 🤗 Datasets to stream and load the audio data, and 🤗 Evaluate to 
perform the WER calculation:

```bash
pip install --upgrade pip
pip install --upgrade transformers datasets[audio] evaluate jiwer
```

Evaluation can then be run end-to-end with the following example: 

```python
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
from datasets import load_dataset
from evaluate import load
import torch
from tqdm import tqdm

# define our torch configuration
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-medium.en"

# load the model + processor
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True)
model = model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

# load the dataset with streaming mode
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)

# define the evaluation metric
wer_metric = load("wer")
normalizer = EnglishTextNormalizer(processor.tokenizer.english_spelling_normalizer)

def inference(batch):
    # 1. Pre-process the audio data to log-mel spectrogram inputs
    audio = [sample["array"] for sample in batch["audio"]]
    input_features = processor(audio, sampling_rate=batch["audio"][0]["sampling_rate"], return_tensors="pt").input_features
    input_features = input_features.to(device, dtype=torch_dtype)
    
    # 2. Auto-regressively generate the predicted token ids
    pred_ids = model.generate(input_features, max_new_tokens=128)
    
    # 3. Decode the token ids to the final transcription
    batch["transcription"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
    batch["reference"] = batch["text"]
    return batch

dataset = dataset.map(function=inference, batched=True, batch_size=16)

all_transcriptions = []
all_references = []

# iterate over the dataset and run inference
for i, result in tqdm(enumerate(dataset), desc="Evaluating..."):
    all_transcriptions.append(result["transcription"])
    all_references.append(result["reference"])

# normalize predictions and references
all_transcriptions = [normalizer(transcription) for transcription in all_transcriptions]
all_references = [normalizer(reference) for reference in all_references]

# compute the WER metric
wer = 100 * wer_metric.compute(predictions=all_transcriptions, references=all_references)
print(wer)

```
**Print Output:**
```
3.593196832001168
```

## Intended Use

Distil-Whisper is intended to be a drop-in replacement for Whisper on English speech recognition. In particular, it 
achieves comparable WER results over out-of-distribution test data, while being 6x faster over both short and long-form 
audio.

## Data

Distil-Whisper is trained on 22,000 hours of audio data from 9 open-source, permissively licensed speech datasets on the 
Hugging Face Hub:

| Dataset                                                                                 | Size / h | Speakers | Domain                      | Licence         |
|-----------------------------------------------------------------------------------------|----------|----------|-----------------------------|-----------------|
| [People's Speech](https://huggingface.co/datasets/MLCommons/peoples_speech)             | 12,000   | unknown  | Internet Archive            | CC-BY-SA-4.0    |
| [Common Voice 13](https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0) | 3,000    | unknown  | Narrated Wikipedia          | CC0-1.0         |
| [GigaSpeech](https://huggingface.co/datasets/speechcolab/gigaspeech)                    | 2,500    | unknown  | Audiobook, podcast, YouTube | apache-2.0      |
| Fisher                                                                                  | 1,960    | 11,900   | Telephone conversations     | LDC             |
| [LibriSpeech](https://huggingface.co/datasets/librispeech_asr)                          | 960      | 2,480    | Audiobooks                  | CC-BY-4.0       |
| [VoxPopuli](https://huggingface.co/datasets/facebook/voxpopuli)                         | 540      | 1,310    | European Parliament         | CC0             |
| [TED-LIUM](https://huggingface.co/datasets/LIUM/tedlium)                                | 450      | 2,030    | TED talks                   | CC-BY-NC-ND 3.0 |
| SwitchBoard                                                                             | 260      | 540      | Telephone conversations     | LDC             |
| [AMI](https://huggingface.co/datasets/edinburghcstr/ami)                                | 100      | unknown  | Meetings                    | CC-BY-4.0       |
||||||
| **Total**                                                                               | 21,770   | 18,260+  |                             |                 |

The combined dataset spans 10 distinct domains and over 50k speakers. The diversity of this dataset is crucial to ensuring 
the distilled model is robust to audio distributions and noise. 

The audio data is then pseudo-labelled using the Whisper large-v2 model: we use Whisper to generate predictions for all 
the audio in our training set and use these as the target labels during training. Using pseudo-labels ensures that the 
transcriptions are consistently formatted across datasets and provides sequence-level distillation signal during training.

## WER Filter

The Whisper pseudo-label predictions are subject to mis-transcriptions and hallucinations. To ensure we only train on 
accurate pseudo-labels, we employ a simple WER heuristic during training. First, we normalise the Whisper pseudo-labels
and the ground truth labels provided by each dataset. We then compute the WER between these labels. If the WER exceeds
a specified threshold, we discard the training example. Otherwise, we keep it for training.
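
The heuristic above can be sketched in a few lines of plain Python. This is a simplified illustration rather than the training code: `normalise` is a crude stand-in for the Whisper English normaliser, the helper names are hypothetical, and the 10% threshold is an assumed value.

```python
import re


def normalise(text):
    """Crude stand-in for the Whisper English normaliser: lowercase, strip punctuation."""
    return " ".join(re.sub(r"[^a-z0-9' ]", " ", text.lower()).split())


def word_error_rate(reference, prediction):
    """Word-level Levenshtein distance divided by the number of reference words."""
    ref, hyp = reference.split(), prediction.split()
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1] / max(len(ref), 1)


def keep_example(ground_truth, pseudo_label, wer_threshold=0.1):
    """Keep a training example only if the pseudo-label is close to the ground truth."""
    wer = word_error_rate(normalise(ground_truth), normalise(pseudo_label))
    return wer <= wer_threshold


print(keep_example("Hello, world!", "hello world"))                # kept
print(keep_example("the cat sat on the mat", "a dog ran away"))    # discarded
```

Normalising both sides first means differences in casing and punctuation do not count as errors, so only genuine mis-transcriptions and hallucinations push an example over the threshold.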

Section 9.2 of the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430) demonstrates the effectiveness of this filter for improving downstream performance
of the distilled model. We also partially attribute Distil-Whisper's robustness to hallucinations to this filter.

## Training

The model was trained for 80,000 optimisation steps (or eight epochs). The TensorBoard training logs can be found at: https://huggingface.co/distil-whisper/distil-medium.en/tensorboard?params=scalars#frame

## Results

The distilled model performs to within 1% WER of Whisper on out-of-distribution (OOD) short-form audio, and outperforms Whisper
by 0.1% on OOD long-form audio. This performance gain is attributed to fewer hallucinations.

For a detailed per-dataset breakdown of the evaluation results, refer to Tables 16 and 17 of the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430).

Distil-Whisper is also evaluated on the [ESB benchmark](https://arxiv.org/abs/2210.13352) datasets as part of the [OpenASR leaderboard](https://huggingface.co/spaces/hf-audio/open_asr_leaderboard),
where it performs to within 0.2% WER of Whisper.

## Reproducing Distil-Whisper

Training and evaluation code to reproduce Distil-Whisper will be made available on the Distil-Whisper repository: https://github.com/huggingface/distil-whisper

## Citation

If you use this model, please consider citing the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430):
```
@misc{gandhi2023distilwhisper,
      title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling}, 
      author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
      year={2023},
      eprint={2311.00430},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```