aiola committed
Commit d11bb28
1 Parent(s): ff7d8b1

Update README.md

Files changed (1)
  1. README.md +70 -3
README.md CHANGED
@@ -1,3 +1,70 @@
- ---
- license: mit
- ---
+ ---
+ license: mit
+ datasets:
+ - openslr/librispeech_asr
+ tags:
+ - ASR
+ - Automatic Speech Recognition
+ - Whisper
+ - Medusa
+ - Speech
+ - Speculative Decoding
+ language:
+ - en
+ ---
+
+ # Whisper Medusa
+
+ Whisper is an advanced encoder-decoder model for speech transcription and
+ translation, processing audio through encoding and decoding stages. Given
+ its large size and slow inference speed, various optimization strategies like
+ Faster-Whisper and Speculative Decoding have been proposed to enhance performance.
+ Our Medusa model builds on Whisper by predicting multiple tokens per iteration,
+ which significantly improves speed with a small degradation in WER. We train and
+ evaluate our model on the LibriSpeech dataset, demonstrating speed improvements.
+
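+ To give intuition for the decoding loop, here is a minimal, illustrative sketch of
+ Medusa-style multi-token decoding. It is not the package's implementation: `medusa_step`
+ and `base_step` are hypothetical stand-ins for the model's forward passes, and real
+ implementations verify all speculative tokens in a single batched pass rather than
+ one at a time.
+
+ ```python
+ def medusa_decode(base_step, medusa_step, prefix, max_len, eos_id):
+     """Greedy Medusa-style loop: propose several tokens per forward pass,
+     then keep the longest continuation the base model agrees with."""
+     tokens = list(prefix)
+     while len(tokens) < max_len and tokens[-1] != eos_id:
+         # One forward pass yields the base model's next token plus
+         # K speculative tokens predicted by the Medusa heads.
+         next_token, speculative = medusa_step(tokens)
+         tokens.append(next_token)
+         # Accept speculative tokens until the first one the base
+         # model would not have produced itself.
+         for tok in speculative:
+             if tokens[-1] == eos_id or tok != base_step(tokens):
+                 break
+             tokens.append(tok)
+     return tokens
+ ```
+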
+ ---------
+
+ ## Training Details
+ `aiola/whisper-medusa-linear-libri` was trained on the LibriSpeech dataset to perform speech transcription.
+ The Medusa heads were optimized for English, so for optimal performance and speed improvements, please use English audio only.
+
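+ For a quick English sanity check, a LibriSpeech utterance can be streamed from the Hub.
+ This is a sketch assuming the `datasets` library; the config and split names follow the
+ published `openslr/librispeech_asr` layout.
+
+ ```python
+ from datasets import load_dataset
+
+ # Stream a single validation example instead of downloading the full corpus.
+ sample = next(iter(load_dataset(
+     "openslr/librispeech_asr", "clean", split="validation", streaming=True
+ )))
+ waveform = sample["audio"]["array"]   # 16 kHz mono waveform as a float array
+ print(sample["text"])                 # reference transcription
+ ```
+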
+ ---------
+
+ ## Usage
+ To use `whisper-medusa-linear-libri`, install the [`whisper-medusa`](https://github.com/aiola/whisper-medusa-linear-libri) repo by following its README instructions.
+
+ Inference can be done using the following code:
+ ```python
+ import torch
+ import torchaudio
+
+ from whisper_medusa import WhisperMedusaModel
+ from transformers import WhisperProcessor
+
+ model_name = "aiola/whisper-medusa-linear-libri"
+ model = WhisperMedusaModel.from_pretrained(model_name)
+ processor = WhisperProcessor.from_pretrained(model_name)
+
+ path_to_audio = "path/to/audio.wav"
+ SAMPLING_RATE = 16000  # Whisper models expect 16 kHz input
+ language = "en"
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load the audio; downmix to mono and resample to 16 kHz if needed.
+ input_speech, sr = torchaudio.load(path_to_audio)
+ if input_speech.shape[0] > 1:
+     input_speech = input_speech.mean(dim=0, keepdim=True)
+ if sr != SAMPLING_RATE:
+     input_speech = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(input_speech)
+
+ # Convert the waveform into log-mel input features for the encoder.
+ input_features = processor(input_speech.squeeze(), return_tensors="pt", sampling_rate=SAMPLING_RATE).input_features
+ input_features = input_features.to(device)
+
+ model = model.to(device)
+ model_output = model.generate(
+     input_features,
+     language=language,
+ )
+ predict_ids = model_output[0]
+ pred = processor.decode(predict_ids, skip_special_tokens=True)
+ print(pred)
+ ```
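+
+ To verify the speed-up on your own hardware, generation can be timed directly. The helper
+ below is an illustrative sketch, not part of the package; it reuses `model` and
+ `input_features` from the snippet above.
+
+ ```python
+ import time
+
+ def timed_generate(model, input_features, language="en", n_runs=3):
+     """Average wall-clock time of model.generate over a few runs."""
+     model.generate(input_features, language=language)  # warm-up run
+     if torch.cuda.is_available():
+         torch.cuda.synchronize()
+     start = time.perf_counter()
+     for _ in range(n_runs):
+         model.generate(input_features, language=language)
+     if torch.cuda.is_available():
+         torch.cuda.synchronize()
+     return (time.perf_counter() - start) / n_runs
+
+ print(f"avg generate time: {timed_generate(model, input_features):.2f}s")
+ ```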