# WhiSPA: Whisper Semantically and Psychologically Aligned
This model is the smallest variant from the WhiSPA paper.
## Description
WhiSPA (Whisper with Semantic-Psychological Alignment) is a novel speech encoder that uses the Whisper model as a backbone and aligns its audio embeddings with text representations from SBERT and with psychological embeddings. This alignment is achieved through a contrastive student-teacher learning objective, trained on hundreds of thousands of audio segments from mental health interviews. WhiSPA aims to capture both semantic and psychological information within a single audio encoder; the paper reports that it surpasses state-of-the-art speech models on a range of downstream tasks.
## Training Procedure
WhiSPA is trained using a student-teacher contrastive alignment approach. The Whisper model (student) is aligned with SBERT and psychological embeddings (teacher) to increase the cosine similarity between their embeddings. This alignment helps WhiSPA capture both semantic and psychological information in the audio embeddings.
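For intuition, the sketch below shows one common way to implement such a contrastive cosine-similarity alignment in PyTorch: an InfoNCE-style loss over a batch of paired student/teacher embeddings. This is an illustrative sketch, not the exact objective from the WhiSPA paper; the function name `alignment_loss`, the temperature value, and the random tensors in the usage line are assumptions for demonstration.

```python
import torch
import torch.nn.functional as F

def alignment_loss(student_emb, teacher_emb, temperature=0.07):
    # L2-normalize so dot products equal cosine similarities
    student = F.normalize(student_emb, dim=-1)
    teacher = F.normalize(teacher_emb, dim=-1)
    # Pairwise cosine similarities between every student/teacher pair in the batch
    logits = student @ teacher.T / temperature
    # Matching pairs lie on the diagonal: pull them together,
    # push apart the non-matching pairs in the batch
    targets = torch.arange(student.size(0), device=student.device)
    return F.cross_entropy(logits, targets)

# Hypothetical usage: a batch of 8 student and 8 teacher embeddings
loss = alignment_loss(torch.randn(8, 384), torch.randn(8, 384))
```

In WhiSPA's setup, the student embeddings would come from the Whisper-based encoder and the teacher embeddings from SBERT together with the psychological dimensions; see the paper and repository for the actual objective.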
## Example Usage
```python
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# The WhiSPA PyTorch module can be sourced from https://github.com/Jarhatz/WhiSPA
from pretrain.whispa_model import WhiSPAModel

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def preprocess_audio(audio_path):
    waveform, sample_rate = torchaudio.load(audio_path)
    # Convert stereo (or multi-channel) audio to mono if needed
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # Resample if necessary (Whisper requires 16 kHz input)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    return waveform

processor = WhisperProcessor.from_pretrained('openai/whisper-tiny')
whisper = WhisperForConditionalGeneration.from_pretrained('openai/whisper-tiny').to(device)
whispa = WhiSPAModel.from_pretrained('Jarhatz/whispa_394_v1').to(device)

# Audio processing
audio_path = 'path/to/audio.wav'  # replace with your own audio file
waveform = preprocess_audio(audio_path)
input_features = processor(
    waveform.squeeze(),
    sampling_rate=16000,
    return_tensors='pt'
).input_features.to(device)

# Whisper-based tokenization
tokens = whisper.generate(input_features)

# WhiSPA embedding
emb = whispa(
    audio_inputs=input_features,
    text_input_ids=tokens,
    text_attention_mask=torch.ones(tokens.size(), device=device),
)
print(f'WhiSPA Embedding: {emb.shape}')
```
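The resulting embedding is a fixed-size vector, so it can be used like a sentence embedding for downstream tasks. As a minimal sketch (assuming `emb` is shaped `[1, hidden_dim]` as produced above), two audio segments can be compared by cosine similarity:

```python
import torch.nn.functional as F

# Assumes emb_a and emb_b were produced exactly like `emb` above,
# one WhiSPA embedding per audio segment
emb_a = emb
emb_b = emb  # replace with the embedding of a second segment
similarity = F.cosine_similarity(emb_a, emb_b, dim=-1)
print(f'Cosine similarity between segments: {similarity.item():.3f}')
```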
## Base Model
openai/whisper-tiny