import torch


class WhisperMixin:
    is_initialized = False

    def setup_whisper(
        self,
        pretrained_model_name_or_path: str = "openai/whisper-base.en",
        device: torch.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        ),
    ):
        """Load the Whisper processor and model and move the model to ``device``."""
        from transformers import WhisperForConditionalGeneration
        from transformers import WhisperProcessor

        self.whisper_device = device
        self.whisper_processor = WhisperProcessor.from_pretrained(
            pretrained_model_name_or_path
        )
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path
        ).to(self.whisper_device)
        self.is_initialized = True
    def get_whisper_features(self) -> torch.Tensor:
        """Preprocess the audio signal as per the Whisper model's training config.

        Returns
        -------
        torch.Tensor
            The preprocessed input features of the audio signal.
            Shape: (1, num_mel_bins, sequence_length)
        """
        if not self.is_initialized:
            self.setup_whisper()

        # Resample to Whisper's expected sampling rate, take the first
        # channel, and move to CPU so the tensor can be converted to numpy.
        signal = self.to(self.device)
        raw_speech = list(
            (
                signal.clone()
                .resample(self.whisper_processor.feature_extractor.sampling_rate)
                .audio_data[:, 0, :]
                .cpu()
                .numpy()
            )
        )

        with torch.inference_mode():
            input_features = self.whisper_processor(
                raw_speech,
                sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
                return_tensors="pt",
            ).input_features

        return input_features
    def get_whisper_transcript(self) -> str:
        """Transcribe the audio signal using the Whisper model.

        Returns
        -------
        str
            The transcript of the audio signal, including special tokens
            such as <|startoftranscript|> and <|endoftext|>.
        """
        if not self.is_initialized:
            self.setup_whisper()

        input_features = self.get_whisper_features()

        with torch.inference_mode():
            input_features = input_features.to(self.whisper_device)
            # Pass the features positionally so this works across
            # transformers versions that name the first argument differently.
            generated_ids = self.whisper_model.generate(input_features)

        transcription = self.whisper_processor.batch_decode(generated_ids)
        return transcription[0]
    def get_whisper_embeddings(self) -> torch.Tensor:
        """Get the last-hidden-state embeddings of the audio signal from the
        Whisper encoder.

        Returns
        -------
        torch.Tensor
            The Whisper embeddings of the audio signal.
            Shape: (1, seq_len, hidden_size)
        """
        if not self.is_initialized:
            self.setup_whisper()

        input_features = self.get_whisper_features()
        encoder = self.whisper_model.get_encoder()

        with torch.inference_mode():
            input_features = input_features.to(self.whisper_device)
            embeddings = encoder(input_features)

        return embeddings.last_hidden_state
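

# A minimal usage sketch, assuming this mixin is combined with an
# AudioSignal-like class from descript-audiotools that provides .to(),
# .clone(), .resample(), .audio_data, and .device (the attributes the
# methods above rely on). "speech.wav" is a placeholder path.
if __name__ == "__main__":
    from audiotools import AudioSignal

    class WhisperSignal(AudioSignal, WhisperMixin):
        pass

    signal = WhisperSignal("speech.wav")
    print(signal.get_whisper_transcript())
    # Encoder embeddings, e.g. (1, 1500, 512) for whisper-base checkpoints.
    print(signal.get_whisper_embeddings().shape)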