WhiSPA: Whisper Semantically and Psychologically Aligned

This model is the smallest variant from the WhiSPA paper.

Description

WhiSPA (Whisper with Semantic-Psychological Alignment) is a speech encoder that uses Whisper as its backbone and aligns its audio embeddings with SBERT text representations and psychological embeddings. The alignment is learned with a contrastive student-teacher objective on hundreds of thousands of audio segments from mental health interviews. The goal is an audio-only encoder that captures both semantic and psychological information, and it is reported to surpass state-of-the-art speech models on various tasks.

Training Procedure

WhiSPA is trained with a student-teacher contrastive alignment objective. Whisper acts as the student, and SBERT text embeddings together with psychological embeddings act as the teacher. Training increases the cosine similarity between the student's audio embeddings and the teacher's embeddings, so that the audio embeddings carry both semantic and psychological information.
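
To make the idea of contrastive alignment concrete, here is a minimal sketch in plain Python. It is not the training code from the WhiSPA repository: the InfoNCE-style formulation, the function names, and the `temperature` value are all assumptions chosen for illustration, and a real implementation would use batched tensor operations. Each student embedding is pulled toward the teacher embedding of the same audio segment and pushed away from the teacher embeddings of the other segments in the batch.

```python
import math

def cosine_similarity(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def contrastive_alignment_loss(student, teacher, temperature=0.1):
    """InfoNCE-style loss over a batch of paired embeddings (illustrative only).

    student[i] and teacher[i] come from the same segment (positive pair);
    teacher[j] with j != i serve as in-batch negatives.
    `temperature` is a hypothetical hyperparameter, not a value from the paper.
    """
    total = 0.0
    for i, s in enumerate(student):
        logits = [cosine_similarity(s, t) / temperature for t in teacher]
        # Numerically stable log-sum-exp over all teacher candidates
        m = max(logits)
        log_denominator = m + math.log(sum(math.exp(x - m) for x in logits))
        # Negative log-probability of picking the matching teacher
        total += log_denominator - logits[i]
    return total / len(student)

# Toy 3-dimensional embeddings standing in for real model outputs
teacher = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
aligned = [[0.9, 0.1, 0.0], [0.1, 0.9, 0.0]]      # close to their own teacher
misaligned = [[0.1, 0.9, 0.0], [0.9, 0.1, 0.0]]   # close to the wrong teacher

loss_aligned = contrastive_alignment_loss(aligned, teacher)
loss_misaligned = contrastive_alignment_loss(misaligned, teacher)
print(loss_aligned < loss_misaligned)
```

In this toy batch the aligned student embeddings give a lower loss than the misaligned ones, which is the behavior such an objective rewards: higher cosine similarity with the matching teacher embedding than with the others.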

Example Usage

import torch, torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# WhiSPA PyTorch module can be sourced from https://github.com/Jarhatz/WhiSPA
from pretrain.whispa_model import WhiSPAModel

def preprocess_audio(audio_path):
    waveform, sample_rate = torchaudio.load(audio_path)
    # Convert stereo (or multi-channel) to mono if needed
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # Resample if necessary (Whisper requires 16kHz input)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    return waveform

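# Placeholders: set these for your environment
audio_path = 'path/to/audio.wav'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

processor = WhisperProcessor.from_pretrained('openai/whisper-tiny')
whisper = WhisperForConditionalGeneration.from_pretrained('openai/whisper-tiny').to(device)
whispa = WhiSPAModel.from_pretrained('Jarhatz/whispa_394_v1').to(device)

# Audio processing: convert the 16 kHz mono waveform to Whisper input features
waveform = preprocess_audio(audio_path)
input_features = processor(
    waveform.squeeze(),
    sampling_rate=16000,
    return_tensors="pt"
).input_features.to(device)

with torch.no_grad():
    # Transcribe with Whisper to obtain the token IDs used as text input
    tokens = whisper.generate(input_features)

    # WhiSPA embedding
    emb = whispa(
        audio_inputs=input_features,
        text_input_ids=tokens,
        text_attention_mask=torch.ones(tokens.size(), device=device),
    )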

print(f'WhiSPA Embedding: {emb.shape}')