{
"cells": [
{
"cell_type": "markdown",
"id": "75b58048-7d14-4fc6-8085-1fc08c81b4a6",
"metadata": {
"id": "75b58048-7d14-4fc6-8085-1fc08c81b4a6"
},
"source": [
"# Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers"
]
},
{
"cell_type": "markdown",
"id": "fbfa8ad5-4cdc-4512-9058-836cbbf65e1a",
"metadata": {
"id": "fbfa8ad5-4cdc-4512-9058-836cbbf65e1a"
},
"source": [
"In this Colab, we present a step-by-step guide on how to fine-tune Whisper \n",
"for any multilingual ASR dataset using Hugging Face 🤗 Transformers. This is a \n",
"more \"hands-on\" version of the accompanying [blog post](https://huggingface.co/blog/fine-tune-whisper). \n",
"For a more in-depth explanation of Whisper, the Common Voice dataset and the theory behind fine-tuning, the reader is advised to refer to the blog post."
]
},
{
"cell_type": "markdown",
"id": "afe0d503-ae4e-4aa7-9af4-dbcba52db41e",
"metadata": {
"id": "afe0d503-ae4e-4aa7-9af4-dbcba52db41e"
},
"source": [
"## Introduction"
]
},
{
"cell_type": "markdown",
"id": "9ae91ed4-9c3e-4ade-938e-f4c2dcfbfdc0",
"metadata": {
"id": "9ae91ed4-9c3e-4ade-938e-f4c2dcfbfdc0"
},
"source": [
"Whisper is a pre-trained model for automatic speech recognition (ASR) \n",
"published in [September 2022](https://openai.com/blog/whisper/) by the authors \n",
"Alec Radford et al. from OpenAI. Unlike many of its predecessors, such as \n",
"[Wav2Vec 2.0](https://arxiv.org/abs/2006.11477), which are pre-trained \n",
"on un-labelled audio data, Whisper is pre-trained on a vast quantity of \n",
"**labelled** audio-transcription data, 680,000 hours to be precise. \n",
"This is an order of magnitude more data than the un-labelled audio data used \n",
"to train Wav2Vec 2.0 (60,000 hours). What is more, 117,000 hours of this \n",
"pre-training data is multilingual ASR data. This results in checkpoints \n",
"that can be applied to over 96 languages, many of which are considered \n",
"_low-resource_.\n",
"\n",
"When scaled to 680,000 hours of labelled pre-training data, Whisper models \n",
"demonstrate a strong ability to generalise to many datasets and domains.\n",
"The pre-trained checkpoints achieve competitive results to state-of-the-art \n",
"ASR systems, with near 3% word error rate (WER) on the test-clean subset of \n",
"LibriSpeech ASR and a new state-of-the-art on TED-LIUM with 4.7% WER (_c.f._ \n",
"Table 8 of the [Whisper paper](https://cdn.openai.com/papers/whisper.pdf)).\n",
"The extensive multilingual ASR knowledge acquired by Whisper during pre-training \n",
"can be leveraged for other low-resource languages; through fine-tuning, the \n",
"pre-trained checkpoints can be adapted for specific datasets and languages \n",
"to further improve upon these results. We'll show just how Whisper can be fine-tuned \n",
"for low-resource languages in this Colab."
]
},
{
"cell_type": "markdown",
"id": "21b6316e-8a55-4549-a154-66d3da2ab74a",
"metadata": {
"id": "21b6316e-8a55-4549-a154-66d3da2ab74a"
},
"source": [
"The Whisper checkpoints come in five configurations of varying model sizes.\n",
"The smallest four are trained on either English-only or multilingual data.\n",
"The largest checkpoint is multilingual only. All nine of the pre-trained checkpoints \n",
"are available on the [Hugging Face Hub](https://huggingface.co/models?search=openai/whisper). The \n",
"checkpoints are summarised in the following table with links to the models on the Hub:\n",
"\n",
"| Size | Layers | Width | Heads | Parameters | English-only | Multilingual |\n",
"|--------|--------|-------|-------|------------|------------------------------------------------------|---------------------------------------------------|\n",
"| tiny | 4 | 384 | 6 | 39 M | [âś“](https://huggingface.co/openai/whisper-tiny.en) | [âś“](https://huggingface.co/openai/whisper-tiny.) |\n",
"| base | 6 | 512 | 8 | 74 M | [âś“](https://huggingface.co/openai/whisper-base.en) | [âś“](https://huggingface.co/openai/whisper-base) |\n",
"| small | 12 | 768 | 12 | 244 M | [âś“](https://huggingface.co/openai/whisper-small.en) | [âś“](https://huggingface.co/openai/whisper-small) |\n",
"| medium | 24 | 1024 | 16 | 769 M | [âś“](https://huggingface.co/openai/whisper-medium.en) | [âś“](https://huggingface.co/openai/whisper-medium) |\n",
"| large | 32 | 1280 | 20 | 1550 M | x | [âś“](https://huggingface.co/openai/whisper-large) |\n",
"\n",
"For demonstration purposes, we'll fine-tune the multilingual version of the \n",
"[`\"small\"`](https://huggingface.co/openai/whisper-small) checkpoint with 244M params (~= 1GB). \n",
"As for our data, we'll train and evaluate our system on a low-resource language \n",
"taken from the [Common Voice](https://huggingface.co/datasets/mozilla-foundation/fleurs_11_0)\n",
"dataset. We'll show that with as little as 8 hours of fine-tuning data, we can achieve \n",
"strong performance in this language."
]
},
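{
"cell_type": "markdown",
"id": "f3a9c1d2-8b4e-4f0a-9c6d-2e7b5a1c0d48",
"metadata": {},
"source": [
"As a quick sanity check before any fine-tuning, we can transcribe audio zero-shot with the pre-trained `\"small\"` checkpoint. \n",
"This is a minimal sketch using the 🤗 Transformers `pipeline` API; `\"sample.wav\"` is a hypothetical placeholder for any audio file you have to hand."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d",
"metadata": {},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"\n",
"# Zero-shot transcription with the pre-trained multilingual \"small\" checkpoint.\n",
"# \"sample.wav\" is a hypothetical placeholder audio file.\n",
"asr = pipeline(\"automatic-speech-recognition\", model=\"openai/whisper-small\")\n",
"print(asr(\"sample.wav\")[\"text\"])"
]
},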
{
"cell_type": "markdown",
"id": "3a680dfc-cbba-4f6c-8a1f-e1a5ff3f123a",
"metadata": {
"id": "3a680dfc-cbba-4f6c-8a1f-e1a5ff3f123a"
},
"source": [
"------------------------------------------------------------------------\n",
"\n",
"\\\\({}^1\\\\) The name Whisper follows from the acronym “WSPSR”, which stands for “Web-scale Supervised Pre-training for Speech Recognition”."
]
},
{
"cell_type": "markdown",
"id": "b219c9dd-39b6-4a95-b2a1-3f547a1e7bc0",
"metadata": {
"id": "b219c9dd-39b6-4a95-b2a1-3f547a1e7bc0"
},
"source": [
"## Load Dataset\n",
"Loading MS-MY Dataset from FLEURS.\n",
"Combine train and validation set."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a2787582-554f-44ce-9f38-4180a5ed6b44",
"metadata": {
"id": "a2787582-554f-44ce-9f38-4180a5ed6b44"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset fleurs (/home/ubuntu/.cache/huggingface/datasets/google___fleurs/id_id/2.0.0/aabb39fb29739c495517ac904e2886819b6e344702f0a5b5283cb178b087c94a)\n",
"Found cached dataset fleurs (/home/ubuntu/.cache/huggingface/datasets/google___fleurs/id_id/2.0.0/aabb39fb29739c495517ac904e2886819b6e344702f0a5b5283cb178b087c94a)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['audio', 'transcription'],\n",
" num_rows: 2929\n",
" })\n",
" test: Dataset({\n",
" features: ['audio', 'transcription'],\n",
" num_rows: 687\n",
" })\n",
"})\n"
]
}
],
"source": [
"from datasets import load_dataset, DatasetDict\n",
"\n",
"fleurs = DatasetDict()\n",
"fleurs[\"train\"] = load_dataset(\"google/fleurs\", \"id_id\", split=\"train+validation\", use_auth_token=True)\n",
"fleurs[\"test\"] = load_dataset(\"google/fleurs\", \"id_id\", split=\"test\", use_auth_token=True)\n",
"\n",
"fleurs = fleurs.remove_columns([\"id\", \"num_samples\", \"path\", \"raw_transcription\", \"gender\", \"lang_id\", \"language\", \"lang_group_id\"])\n",
"\n",
"print(fleurs)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d087b451",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset common_voice_11_0 (/home/ubuntu/.cache/huggingface/datasets/mozilla-foundation___common_voice_11_0/id/11.0.0/f8e47235d9b4e68fa24ed71d63266a02018ccf7194b2a8c9c598a5f3ab304d9f)\n",
"Found cached dataset common_voice_11_0 (/home/ubuntu/.cache/huggingface/datasets/mozilla-foundation___common_voice_11_0/id/11.0.0/f8e47235d9b4e68fa24ed71d63266a02018ccf7194b2a8c9c598a5f3ab304d9f)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['audio', 'transcription'],\n",
" num_rows: 8274\n",
" })\n",
" test: Dataset({\n",
" features: ['audio', 'transcription'],\n",
" num_rows: 3618\n",
" })\n",
"})\n"
]
}
],
"source": [
"cv = DatasetDict()\n",
"cv[\"train\"] = load_dataset(\"mozilla-foundation/common_voice_11_0\", \"id\", split=\"train+validation\", use_auth_token=True)\n",
"cv[\"test\"] = load_dataset(\"mozilla-foundation/common_voice_11_0\", \"id\", split=\"test\", use_auth_token=True)\n",
"\n",
"cv = cv.remove_columns([\"client_id\", \"path\", 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'])\n",
"cv = cv.rename_column('sentence', 'transcription')\n",
"print(cv)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "60790ba0",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "99906e8e299f458591312a4b744a3efd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading builder script: 0%| | 0.00/6.25k [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b9b7a67d960a4f6597389eeb70f33a97",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading readme: 0%| | 0.00/5.16k [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "08e49e43530546ada0fce5930391414d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading extra modules: 0%| | 0.00/282 [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6557bddcdbaa4206bfa3646b53608e7d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading extra modules: 0%| | 0.00/1.35k [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading and preparing dataset librivox-indonesia/ind to /home/ubuntu/.cache/huggingface/datasets/indonesian-nlp___librivox-indonesia/ind/1.0.0/80f85a3839e000a9443d20456b1bb183e6e9b0c11e92aa44ec79f2439941eb62...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5b682fd935664374b26e8a6466a1763e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/290M [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9cad9f2937184f36a4e8efe6d527a5d4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/190k [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4d1cecde815148f7b5a8fe506905e458",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/31.3M [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a495c9ce15cb4809a49b92ba3937771c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/24.4k [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0 examples [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating test split: 0 examples [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset librivox-indonesia (/home/ubuntu/.cache/huggingface/datasets/indonesian-nlp___librivox-indonesia/ind/1.0.0/80f85a3839e000a9443d20456b1bb183e6e9b0c11e92aa44ec79f2439941eb62)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset librivox-indonesia downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/indonesian-nlp___librivox-indonesia/ind/1.0.0/80f85a3839e000a9443d20456b1bb183e6e9b0c11e92aa44ec79f2439941eb62. Subsequent calls will reuse this data.\n",
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['transcription', 'audio'],\n",
" num_rows: 5635\n",
" })\n",
" test: Dataset({\n",
" features: ['transcription', 'audio'],\n",
" num_rows: 603\n",
" })\n",
"})\n"
]
}
],
"source": [
"lbv = DatasetDict()\n",
"lbv[\"train\"] = load_dataset(\"indonesian-nlp/librivox-indonesia\", \"ind\", split=\"train\", use_auth_token=True)\n",
"lbv[\"test\"] = load_dataset(\"indonesian-nlp/librivox-indonesia\", \"ind\", split=\"test\", use_auth_token=True)\n",
"\n",
"lbv = lbv.remove_columns([\"path\", \"language\", \"reader\"])\n",
"lbv = lbv.rename_column('sentence', 'transcription')\n",
"print(lbv)"
]
},
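{
"cell_type": "markdown",
"id": "3f2e1d0c-9b8a-4765-b432-10fedcba9876",
"metadata": {},
"source": [
"Since we'll train on all three corpora, we can merge them into a single `DatasetDict`. The sketch below is one way of doing this, \n",
"assuming the corpora should simply be concatenated split-by-split: the audio columns are first cast to a common 16kHz sampling \n",
"rate (the rate Whisper expects), since the three datasets ship at different native rates."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c6b5a49-3827-4160-9f8e-7d6c5b4a3928",
"metadata": {},
"outputs": [],
"source": [
"from datasets import Audio, DatasetDict, concatenate_datasets\n",
"\n",
"# Cast all audio columns to 16kHz so the three corpora share identical features.\n",
"parts = [ds.cast_column(\"audio\", Audio(sampling_rate=16000)) for ds in (fleurs, cv, lbv)]\n",
"\n",
"# Concatenate the corpora split-by-split into one DatasetDict.\n",
"dataset = DatasetDict()\n",
"dataset[\"train\"] = concatenate_datasets([ds[\"train\"] for ds in parts])\n",
"dataset[\"test\"] = concatenate_datasets([ds[\"test\"] for ds in parts])\n",
"print(dataset)"
]
},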
{
"cell_type": "markdown",
"id": "2d63b2d2-f68a-4d74-b7f1-5127f6d16605",
"metadata": {
"id": "2d63b2d2-f68a-4d74-b7f1-5127f6d16605"
},
"source": [
"## Prepare Feature Extractor, Tokenizer and Data"
]
},
{
"cell_type": "markdown",
"id": "601c3099-1026-439e-93e2-5635b3ba5a73",
"metadata": {
"id": "601c3099-1026-439e-93e2-5635b3ba5a73"
},
"source": [
"The ASR pipeline can be de-composed into three stages: \n",
"1) A feature extractor which pre-processes the raw audio-inputs\n",
"2) The model which performs the sequence-to-sequence mapping \n",
"3) A tokenizer which post-processes the model outputs to text format\n",
"\n",
"In 🤗 Transformers, the Whisper model has an associated feature extractor and tokenizer, \n",
"called [WhisperFeatureExtractor](https://huggingface.co/docs/transformers/main/model_doc/whisper#transformers.WhisperFeatureExtractor)\n",
"and [WhisperTokenizer](https://huggingface.co/docs/transformers/main/model_doc/whisper#transformers.WhisperTokenizer) \n",
"respectively.\n",
"\n",
"We'll go through details for setting-up the feature extractor and tokenizer one-by-one!"
]
},
{
"cell_type": "markdown",
"id": "560332eb-3558-41a1-b500-e83a9f695f84",
"metadata": {
"id": "560332eb-3558-41a1-b500-e83a9f695f84"
},
"source": [
"### Load WhisperFeatureExtractor"
]
},
{
"cell_type": "markdown",
"id": "32ec8068-0bd7-412d-b662-0edb9d1e7365",
"metadata": {
"id": "32ec8068-0bd7-412d-b662-0edb9d1e7365"
},
"source": [
"The Whisper feature extractor performs two operations:\n",
"1. Pads / truncates the audio inputs to 30s: any audio inputs shorter than 30s are padded to 30s with silence (zeros), and those longer that 30s are truncated to 30s\n",
"2. Converts the audio inputs to _log-Mel spectrogram_ input features, a visual representation of the audio and the form of the input expected by the Whisper model"
]
},
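{
"cell_type": "markdown",
"id": "5d4c3b2a-1f0e-4d9c-8b7a-6e5f4d3c2b1a",
"metadata": {},
"source": [
"To see both operations concretely, here's a minimal sketch that feeds one second of silence at 16kHz into the feature extractor: \n",
"the input is padded out to 30s and converted to a log-Mel spectrogram with 80 mel bins and 3000 frames."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e8d7c6b-5a49-4382-b716-0f9e8d7c6b5a",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from transformers import WhisperFeatureExtractor\n",
"\n",
"feature_extractor = WhisperFeatureExtractor.from_pretrained(\"openai/whisper-small\")\n",
"\n",
"# One second of silence at 16kHz; the extractor pads it to 30s and\n",
"# returns an 80-bin log-Mel spectrogram with 3000 frames.\n",
"dummy_audio = np.zeros(16000, dtype=np.float32)\n",
"features = feature_extractor(dummy_audio, sampling_rate=16000, return_tensors=\"np\")\n",
"print(features.input_features.shape)  # (1, 80, 3000)"
]
},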
{
"cell_type": "markdown",
"id": "589d9ec1-d12b-4b64-93f7-04c63997da19",
"metadata": {
"id": "589d9ec1-d12b-4b64-93f7-04c63997da19"
},
"source": [
"