---
language:
- en
tags:
- audio
- automatic-speech-recognition
widget:
- example_title: LibriSpeech sample 1
  src: https://cdn-media.huggingface.co/speech_samples/sample1.flac
- example_title: LibriSpeech sample 2
  src: https://cdn-media.huggingface.co/speech_samples/sample2.flac
pipeline_tag: automatic-speech-recognition
license: mit
library_name: transformers
---

# Distil-Whisper: distil-medium.en

Distil-Whisper was proposed in the paper [Robust Knowledge Distillation via Large-Scale Pseudo Labelling](https://arxiv.org/abs/2311.00430). It is a distilled version of the Whisper model that is **6 times faster**, 49% smaller, and performs **within 1% WER** on out-of-distribution evaluation sets. This is the repository for distil-medium.en, a distilled variant of [Whisper medium.en](https://huggingface.co/openai/whisper-medium.en).

| Model                                                                      | Params / M | Rel. Latency | Short-Form WER | Long-Form WER |
|----------------------------------------------------------------------------|------------|--------------|----------------|---------------|
| [large-v2](https://huggingface.co/openai/whisper-large-v2)                 | 1550       | 1.0          | **9.1**        | 11.7          |
|                                                                            |            |              |                |               |
| [distil-large-v2](https://huggingface.co/distil-whisper/distil-large-v2)   | 756        | 5.8          | 10.1           | **11.6**      |
| [distil-medium.en](https://huggingface.co/distil-whisper/distil-medium.en) | **394**    | **6.8**      | 11.1           | 12.4          |

## Usage

Distil-Whisper is supported in Hugging Face 🤗 Transformers from version 4.35 onwards. To run the model, first install the latest version of the Transformers library.
For this example, we'll also install 🤗 Datasets to load a toy audio dataset from the Hugging Face Hub:

```bash
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
```

### Short-Form Transcription

The model can be used with the [`pipeline`](https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline) class to transcribe short-form audio files as follows:

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset

# Use the GPU with half precision when available, otherwise fall back to CPU and float32
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-medium.en"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

# The processor bundles the tokenizer and the feature extractor
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)

# Load a short sample from a dummy LibriSpeech dataset and transcribe it
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])
```

To transcribe a local audio file, simply pass the path to your audio file when you call the pipeline:

```diff
- result = pipe(sample)
+ result = pipe("audio.mp3")
```

### Long-Form Transcription

Distil-Whisper uses a chunked algorithm to transcribe long-form audio files. In practice, this chunked long-form algorithm is 9x faster than the sequential algorithm proposed by OpenAI in the Whisper paper (see Table 7 of the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430)).

To enable chunking, pass the `chunk_length_s` parameter to the `pipeline`. For Distil-Whisper, a chunk length of 15 seconds is optimal.
To activate batching, pass the argument `batch_size`:

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-medium.en"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,  # split the audio into 15-second chunks
    batch_size=16,  # transcribe 16 chunks per forward pass
    torch_dtype=torch_dtype,
    device=device,
)

# Load a long-form sample and transcribe it
dataset = load_dataset("distil-whisper/librispeech_long", "default", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])
```

### Speculative Decoding

Distil-Whisper can be used as an assistant model to Whisper for speculative decoding. Speculative decoding mathematically ensures that exactly the same outputs as Whisper are obtained, while being 2 times faster. This makes it the perfect drop-in replacement for existing Whisper pipelines, since the same outputs are guaranteed.

In the following code snippet, we load the assistant Distil-Whisper model separately from the main Whisper pipeline.
We then specify it as the "assistant model" for generation:

```python
from transformers import pipeline, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Assistant: Distil-Whisper, loaded as a causal LM
assistant_model_id = "distil-whisper/distil-medium.en"

assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)

# Main model: the original Whisper checkpoint
model_id = "openai/whisper-medium.en"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    generate_kwargs={"assistant_model": assistant_model},  # enables speculative decoding
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])
```

## Additional Speed & Memory Improvements

You can apply additional speed and memory improvements to Distil-Whisper, which we cover in the following sections.

### Flash Attention

We recommend using [Flash-Attention 2](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#flashattention-2) if your GPU allows for it.
To do so, you first need to install [Flash Attention](https://github.com/Dao-AILab/flash-attention):

```
pip install flash-attn --no-build-isolation
```

Then pass `use_flash_attention_2=True` to `from_pretrained`:

```diff
- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=True)
```

### Torch Scaled Dot-Product Attention (SDPA)

If your GPU does not support Flash Attention, we recommend making use of [BetterTransformer](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#bettertransformer). To do so, you first need to install optimum:

```
pip install --upgrade optimum
```

Then convert your model to a "BetterTransformer" model before using it:

```diff
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = model.to_bettertransformer()
```

### 8bit & 4bit Quantization

Coming soon ...

### Candle

Coming soon ...

### Whisper.cpp

Coming soon ...

### Running Whisper in `openai/whisper`

Coming soon ...

## Model Details

Distil-Whisper inherits the encoder-decoder architecture from Whisper. The encoder maps a sequence of speech vector inputs to a sequence of hidden-state vectors. The decoder auto-regressively predicts text tokens, conditional on all previous tokens and the encoder hidden-states. Consequently, the encoder is only run forward once, whereas the decoder is run as many times as the number of tokens generated. In practice, this means the decoder accounts for over 90% of total inference time. Thus, to optimise for latency, the focus should be on minimising the inference time of the decoder.

To distill the Whisper model, we reduce the number of decoder layers while keeping the encoder fixed.
The encoder (shown in green) is entirely copied from the teacher to the student and frozen during training. The student's decoder consists of only two decoder layers, which are initialised from the first and last decoder layer of the teacher (shown in red). All other decoder layers of the teacher are discarded. The model is then trained on a weighted sum of the KL divergence and pseudo-label loss terms.
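
As a rough illustration of the training objective described above, the sketch below computes a weighted sum of a KL divergence term (teacher distribution against student distribution) and a cross-entropy term on the teacher's pseudo-labels. It is a dependency-free toy in plain Python, not the training code used for Distil-Whisper: the function names, the per-token averaging, and the default weights `kl_weight` and `pl_weight` are all assumptions made for the example.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def distillation_loss(student_logits, teacher_logits, pseudo_labels,
                      kl_weight=1.0, pl_weight=1.0):
    """Weighted sum of KL(teacher || student) and pseudo-label cross-entropy.

    student_logits / teacher_logits: one list of vocabulary logits per token.
    pseudo_labels: token ids transcribed by the teacher, one per token.
    kl_weight / pl_weight: placeholder weights (assumed, not from the model card).
    """
    kl_total, ce_total = 0.0, 0.0
    for s_logits, t_logits, label in zip(student_logits, teacher_logits, pseudo_labels):
        s_logp = log_softmax(s_logits)
        t_logp = log_softmax(t_logits)
        # KL term: pull the student's distribution towards the teacher's
        kl_total += sum(math.exp(tl) * (tl - sl) for tl, sl in zip(t_logp, s_logp))
        # Pseudo-label term: cross-entropy against the teacher's predicted token
        ce_total += -s_logp[label]
    n = len(pseudo_labels)
    return kl_weight * kl_total / n + pl_weight * ce_total / n


# Toy example: two token positions over a three-word vocabulary
student = [[1.0, 0.5, -1.0], [0.2, 0.1, 2.0]]
teacher = [[2.0, 0.0, -2.0], [0.0, 0.0, 3.0]]
labels = [0, 2]
print(distillation_loss(student, teacher, labels))
```

When the student's logits match the teacher's exactly, the KL term vanishes and only the pseudo-label cross-entropy remains.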