language:
  - en
license: apache-2.0
pipeline_tag: automatic-speech-recognition
tags:
  - audio
  - automatic-speech-recognition
  - hf-asr-leaderboard
widget:
  - example_title: Librispeech sample 1
    src: https://cdn-media.huggingface.co/speech_samples/sample1.flac
  - example_title: Librispeech sample 2
    src: https://cdn-media.huggingface.co/speech_samples/sample2.flac
model-index:
  - name: whisper-medium
    results:
      - task:
          type: automatic-speech-recognition
          name: Automatic Speech Recognition
        dataset:
          name: Afrispeech-200
          type: intronhealth/afrispeech-200
          config: clean
          split: test
          args:
            language: en
        metrics:
          - type: wer
            value: 0
            name: Test WER

Afrispeech-Whisper-Medium-All

This model builds on Whisper Medium, a pre-trained model for speech recognition and translation trained on 680k hours of audio. Whisper already generalizes well across datasets and domains; this model adapts it specifically to African accents.

It is fine-tuned on AfriSpeech-200, a dataset built specifically for African accents, and offers improved speech recognition performance on African-accented English.

Transcription

In this example, the context tokens are 'unforced', meaning the model automatically predicts the output language (English) and task (transcribe).

>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset

>>> # load model and processor
>>> processor = WhisperProcessor.from_pretrained("intronhealth/afrispeech-whisper-medium-all")
>>> model = WhisperForConditionalGeneration.from_pretrained("intronhealth/afrispeech-whisper-medium-all")
>>> model.config.forced_decoder_ids = None

>>> # load dummy dataset and read audio files
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> sample = ds[0]["audio"]
>>> input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features 

>>> # generate token ids
>>> predicted_ids = model.generate(input_features)
>>> # decode token ids to text
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
['<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|endoftext|>']

>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
[' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.']

The context tokens can be removed from the start of the transcription by setting skip_special_tokens=True.
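
Transcriptions such as the one above are usually scored with word error rate (WER), the metric listed in this model's metadata. The following is a minimal sketch of that metric in plain Python, not the evaluation script behind the reported number. The function name `word_error_rate` and the simple normalization (lowercasing and whitespace splitting) are assumptions made for illustration.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level edit distance divided by the number of reference words."""
    # Assumed normalization: lowercase and split on whitespace only.
    ref = reference.lower().split()
    hyp = hypothesis.lower().split()
    if not ref:
        raise ValueError("reference must contain at least one word")

    # prev[j] is the edit distance between the reference words seen so far
    # (before the current one) and the first j hypothesis words.
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(
                prev[j] + 1,         # deletion
                curr[j - 1] + 1,     # insertion
                prev[j - 1] + cost,  # substitution or match
            ))
        prev = curr
    return prev[-1] / len(ref)


reference = "Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
hypothesis = " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
print(word_error_rate(reference, hypothesis))  # 0.0
```

Published WER figures typically apply a more thorough text normalizer (punctuation, numerals, spelling variants) before scoring, so values from this sketch may differ from leaderboard numbers.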

Long-Form Transcription

The Whisper model is intrinsically designed to work on audio samples of up to 30 s in duration. With a chunking algorithm, however, it can transcribe audio of arbitrary length. This is available through the Transformers pipeline method. Chunking is enabled by setting chunk_length_s=30 when instantiating the pipeline. With chunking enabled, the pipeline can be run with batched inference. It can also be extended to predict sequence-level timestamps by passing return_timestamps=True:

>>> import torch
>>> from transformers import pipeline
>>> from datasets import load_dataset

>>> device = "cuda:0" if torch.cuda.is_available() else "cpu"

>>> pipe = pipeline(
...   "automatic-speech-recognition",
...   model="intronhealth/afrispeech-whisper-medium-all",
...   chunk_length_s=30,
...   device=device,
... )

>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> sample = ds[0]["audio"]

>>> prediction = pipe(sample.copy(), batch_size=8)["text"]
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."

>>> # we can also return timestamps for the predictions
>>> prediction = pipe(sample.copy(), batch_size=8, return_timestamps=True)["chunks"]
[{'text': ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.',
  'timestamp': (0.0, 5.44)}]
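
The timestamped chunks shown above are plain Python dictionaries, so they can be post-processed directly. As a small sketch, the hypothetical helper `format_chunks` below turns them into readable lines. It assumes the end time may be `None` when the last segment is cut off, and handles that case defensively.

```python
def format_chunks(chunks):
    """Render pipeline-style chunks as '[start - end] text' lines."""
    lines = []
    for chunk in chunks:
        start, end = chunk["timestamp"]
        # Assumption: the end time can be missing for a truncated final segment.
        end_label = f"{end:.2f}s" if end is not None else "end"
        lines.append(f"[{start:.2f}s - {end_label}] {chunk['text'].strip()}")
    return lines


chunks = [
    {
        "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
        "timestamp": (0.0, 5.44),
    }
]
for line in format_chunks(chunks):
    print(line)
```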

Refer to the blog post ASR Chunking for more details on the chunking algorithm.
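
To give a rough idea of what chunking involves, the sketch below computes overlapping 30 s windows over a long recording. It is an illustration under stated assumptions, not the pipeline's actual implementation: the function name `chunk_bounds`, the 16 kHz sampling rate, and the 5 s overlap on each side of a window are chosen here for the example.

```python
def chunk_bounds(num_samples, sampling_rate=16_000, chunk_length_s=30.0, stride_s=5.0):
    """Return (start, end) sample indices of overlapping windows covering the audio."""
    chunk_len = int(chunk_length_s * sampling_rate)
    stride = int(stride_s * sampling_rate)
    # Each window overlaps its neighbours by `stride` samples on both sides,
    # so consecutive windows start `chunk_len - 2 * stride` samples apart.
    step = chunk_len - 2 * stride
    if step <= 0:
        raise ValueError("stride_s must be less than half of chunk_length_s")

    bounds = []
    start = 0
    while True:
        end = min(start + chunk_len, num_samples)
        bounds.append((start, end))
        if end >= num_samples:  # the last window reaches the end of the audio
            break
        start += step
    return bounds


# A 70 s recording at 16 kHz is covered by three overlapping windows.
windows = chunk_bounds(70 * 16_000)
print([(s / 16_000, e / 16_000) for s, e in windows])
# [(0.0, 30.0), (20.0, 50.0), (40.0, 70.0)]
```

Each window can then be transcribed independently (and in a batch), with the overlapping regions used to reconcile the text at window boundaries.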