---
license: apache-2.0
language: de
library_name: transformers
thumbnail: null
tags:
- automatic-speech-recognition
- whisper-event
datasets:
- mozilla-foundation/common_voice_11_0
metrics:
- wer
model-index:
- name: Fine-tuned whisper-large-v2 model for ASR in German
  results:
  - task:
      name: Automatic Speech Recognition
      type: automatic-speech-recognition
    dataset:
      name: Common Voice 11.0
      type: mozilla-foundation/common_voice_11_0
      config: de
      split: test
      args: de
    metrics:
    - name: WER (Greedy)
      type: wer
      value: 5.76
---
<style>
img {
display: inline;
}
</style>
![Model architecture](https://img.shields.io/badge/Model_Architecture-seq2seq-lightgrey)
![Model size](https://img.shields.io/badge/Params-1550M-lightgrey)
![Language](https://img.shields.io/badge/Language-German-lightgrey)
# Fine-tuned whisper-large-v2 model for ASR in German
This model is a fine-tuned version of [openai/whisper-large-v2](https://huggingface.co/openai/whisper-large-v2), trained on the German (`de`) subset of the mozilla-foundation/common_voice_11_0 dataset. When using the model, make sure that your speech input is sampled at 16 kHz. **This model also predicts casing and punctuation.**
*Below are the WERs of the pre-trained models on [Common Voice 9.0](https://huggingface.co/datasets/mozilla-foundation/common_voice_9_0). These results are reported in the original [paper](https://cdn.openai.com/papers/whisper.pdf).*
| Model | WER |
| --- | --- |
| [openai/whisper-small](https://huggingface.co/openai/whisper-small) | 13.0 |
| [openai/whisper-medium](https://huggingface.co/openai/whisper-medium) | 8.5 |
| [openai/whisper-large-v2](https://huggingface.co/openai/whisper-large-v2) | 6.4 |
*Below are the WERs of the fine-tuned models on [Common Voice 11.0](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0).*
| Model | WER |
| --- | --- |
| [bofenghuang/whisper-small-cv11-german-punct](https://huggingface.co/bofenghuang/whisper-small-cv11-german-punct) | 11.35 |
| [bofenghuang/whisper-medium-cv11-german-punct](https://huggingface.co/bofenghuang/whisper-medium-cv11-german-punct) | 7.05 |
| [bofenghuang/whisper-large-v2-cv11-german-punct](https://huggingface.co/bofenghuang/whisper-large-v2-cv11-german-punct) | **5.76** |
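For readers unfamiliar with the metric, here is a minimal sketch of the standard word error rate definition: the word-level edit distance (substitutions, deletions, insertions) divided by the number of reference words, expressed as a percentage. The function name `word_error_rate` is hypothetical, and this card does not state which tool or text normalisation produced the figures above, so treat this as an illustration of the metric rather than the exact evaluation script.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Return the WER in percent for one reference/hypothesis pair.

    Illustrative helper (not part of any library): words are split on
    whitespace and compared case-sensitively, without normalisation.
    """
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    if not ref_words:
        raise ValueError("reference must contain at least one word")

    # previous_row[j] holds the edit distance between the reference words
    # processed so far and the first j hypothesis words.
    previous_row = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        current_row = [i]
        for j, hyp_word in enumerate(hyp_words, start=1):
            substitution = previous_row[j - 1] + (ref_word != hyp_word)
            deletion = previous_row[j] + 1
            insertion = current_row[j - 1] + 1
            current_row.append(min(substitution, deletion, insertion))
        previous_row = current_row

    return 100.0 * previous_row[-1] / len(ref_words)


if __name__ == "__main__":
    # One of four reference words is missing from the hypothesis.
    print(word_error_rate("das ist ein Test", "das ist Test"))
```

Because the comparison is case-sensitive and punctuation stays attached to words, a casing or punctuation mismatch counts as a word error. Corpus-level scores are usually obtained by summing the edit distances and reference lengths over all utterances before dividing, not by averaging per-utterance values.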
## Usage
Inference with 🤗 Pipeline
```python
import torch
from datasets import load_dataset
from transformers import pipeline
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load pipeline
pipe = pipeline("automatic-speech-recognition", model="bofenghuang/whisper-large-v2-cv11-german-punct", device=device)
# NB: set forced_decoder_ids for generation utils
pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language="de", task="transcribe")
# Load data
ds_mcv_test = load_dataset("mozilla-foundation/common_voice_11_0", "de", split="test", streaming=True)
test_segment = next(iter(ds_mcv_test))
waveform = test_segment["audio"]
# NB: decoding option
# limit the maximum number of generated tokens to 225
pipe.model.config.max_length = 225 + 1
# sampling
# pipe.model.config.do_sample = True
# beam search
# pipe.model.config.num_beams = 5
# return
# pipe.model.config.return_dict_in_generate = True
# pipe.model.config.output_scores = True
# pipe.model.config.num_return_sequences = 5
# Run
generated_sentences = pipe(waveform)["text"]
```
Inference with 🤗 low-level APIs
```python
import torch
import torchaudio
from datasets import load_dataset
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load model
model = AutoModelForSpeechSeq2Seq.from_pretrained("bofenghuang/whisper-large-v2-cv11-german-punct").to(device)
processor = AutoProcessor.from_pretrained("bofenghuang/whisper-large-v2-cv11-german-punct", language="german", task="transcribe")
# NB: set forced_decoder_ids for generation utils
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="de", task="transcribe")
# 16_000
model_sample_rate = processor.feature_extractor.sampling_rate
# Load data
ds_mcv_test = load_dataset("mozilla-foundation/common_voice_11_0", "de", split="test", streaming=True)
test_segment = next(iter(ds_mcv_test))
# Cast to float32, the dtype of torchaudio's default resampling kernel
waveform = torch.from_numpy(test_segment["audio"]["array"]).float()
sample_rate = test_segment["audio"]["sampling_rate"]
# Resample
if sample_rate != model_sample_rate:
    resampler = torchaudio.transforms.Resample(sample_rate, model_sample_rate)
    waveform = resampler(waveform)
# Get feat
inputs = processor(waveform, sampling_rate=model_sample_rate, return_tensors="pt")
input_features = inputs.input_features
input_features = input_features.to(device)
# Generate
generated_ids = model.generate(inputs=input_features, max_new_tokens=225) # greedy
# generated_ids = model.generate(inputs=input_features, max_new_tokens=225, num_beams=5) # beam search
# Detokenize
generated_sentences = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Normalise predicted sentences if necessary
```
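The last comment in the example above mentions normalising the predicted sentences. Since the model outputs casing and punctuation, one plausible approach when comparing against lowercase, unpunctuated references is sketched below. The function `normalise_transcript` and its rules (lowercasing, replacing punctuation with spaces, collapsing whitespace) are assumptions for illustration, not the normalisation used to produce the reported WERs.

```python
import re
import unicodedata


def normalise_transcript(text: str) -> str:
    """Lowercase, strip punctuation, and collapse whitespace.

    Hypothetical helper for illustration. German umlauts and ß are kept,
    because Python's \\w is Unicode-aware for str patterns.
    """
    # Use a single Unicode form so composed and decomposed umlauts match.
    text = unicodedata.normalize("NFC", text).lower()
    # Replace every character that is neither a word character nor
    # whitespace with a space.
    text = re.sub(r"[^\w\s]", " ", text)
    # Collapse runs of whitespace and trim both ends.
    return " ".join(text.split())


if __name__ == "__main__":
    print(normalise_transcript("Guten Morgen, Herr Müller!"))
```

Replacing punctuation with spaces also splits hyphenated or apostrophised words, which may or may not match the conventions of your reference transcripts. Adjust the rule to suit them.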