|
import gradio as gr |
|
import torch |
|
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoModelForSequenceClassification, AutoTokenizer |
|
import librosa |
|
import numpy as np |
|
import plotly.graph_objects as go |
|
|
|
class ModelManager:
    """Load and hold the speech-to-text and text-emotion models.

    Loaded objects are stored in dicts keyed by role:
    ``processors['whisper']``, ``models['whisper']``,
    ``tokenizers['emotion']`` and ``models['emotion']``.
    """

    # Hugging Face Hub checkpoint ids used by load_models().
    WHISPER_MODEL_ID = "openai/whisper-base"
    EMOTION_MODEL_ID = "j-hartmann/emotion-english-distilroberta-base"

    def __init__(self, device="cpu"):
        """Create the registries and eagerly load every model.

        Args:
            device: torch device (a string such as ``"cpu"``/``"cuda"`` or a
                ``torch.device``) the models are moved to. Defaults to
                ``"cpu"``, which matches the previous hard-coded behaviour.

        Raises:
            Exception: anything raised by ``load_models`` (re-raised).
        """
        self.device = torch.device(device)
        self.models = {}
        self.tokenizers = {}
        self.processors = {}
        self.load_models()

    def load_models(self):
        """Download/load the Whisper and emotion checkpoints onto ``self.device``.

        Raises:
            Exception: any error from ``from_pretrained`` or ``.to(...)`` is
                printed and then re-raised unchanged.
        """
        try:
            print("Loading Whisper model...")
            self.processors['whisper'] = WhisperProcessor.from_pretrained(
                self.WHISPER_MODEL_ID
            )
            self.models['whisper'] = WhisperForConditionalGeneration.from_pretrained(
                self.WHISPER_MODEL_ID
            ).to(self.device)

            print("Loading emotion model...")
            self.tokenizers['emotion'] = AutoTokenizer.from_pretrained(
                self.EMOTION_MODEL_ID
            )
            self.models['emotion'] = AutoModelForSequenceClassification.from_pretrained(
                self.EMOTION_MODEL_ID
            ).to(self.device)

            print("Models loaded successfully")
        except Exception as e:
            print(f"Error loading models: {str(e)}")
            raise
|
|
|
class AudioProcessor:
    """Load an audio file at a fixed sample rate and derive basic features."""

    def __init__(self):
        # Number of MFCC coefficients returned by _extract_features.
        self.n_mfcc = 13
        # Target rate for librosa resampling; presumably chosen to match the
        # 16 kHz input the Whisper model expects — confirm if changed.
        self.sample_rate = 16000

    def process_audio(self, audio_path):
        """Load *audio_path* resampled to ``self.sample_rate``.

        Returns:
            A ``(waveform, features)`` pair, where ``features`` is the dict
            produced by ``_extract_features``.

        Raises:
            Exception: loading/feature errors are printed and re-raised.
        """
        try:
            # The rate librosa reports back is the one we asked for, so drop it.
            samples, _ = librosa.load(audio_path, sr=self.sample_rate)
            feats = self._extract_features(samples)
            return samples, feats
        except Exception as err:
            print(f"Error processing audio: {str(err)}")
            raise

    def _extract_features(self, waveform):
        """Compute MFCCs and per-frame RMS energy for *waveform*.

        Returns:
            ``{'mfcc': ..., 'energy': ...}`` where ``energy`` is the first row
            of librosa's RMS matrix (assumes mono input — librosa.load
            defaults to mono).
        """
        try:
            mfcc_matrix = librosa.feature.mfcc(
                y=waveform, sr=self.sample_rate, n_mfcc=self.n_mfcc
            )
            rms_rows = librosa.feature.rms(y=waveform)
            return {'mfcc': mfcc_matrix, 'energy': rms_rows[0]}
        except Exception as err:
            print(f"Error extracting features: {str(err)}")
            raise
|
|
|
class Analyzer:
    """Transcribe an audio file with Whisper, then score the text's emotions."""

    def __init__(self):
        """Load all models and the audio loader.

        Raises:
            Exception: any initialization error is printed and re-raised.
        """
        print("Initializing Analyzer...")
        try:
            self.model_manager = ModelManager()
            self.audio_processor = AudioProcessor()
            print("Analyzer initialization complete")
        except Exception as e:
            print(f"Error initializing Analyzer: {str(e)}")
            raise

    def analyze(self, audio_path):
        """Run the full transcription + emotion pipeline on *audio_path*.

        Returns:
            ``{'transcription': str, 'emotions': {label: probability}}`` where
            the emotion labels come from the classifier's own config, so every
            class the model predicts is present and correctly named.

        Raises:
            Exception: any pipeline error is printed and re-raised.
        """
        try:
            print(f"Processing audio file: {audio_path}")
            # The MFCC/energy features are computed but not used by this
            # pipeline yet; only the waveform feeds the models.
            waveform, _features = self.audio_processor.process_audio(audio_path)

            transcription = self._transcribe(waveform)
            emotion_scores = self._classify_emotions(transcription)

            return {
                'transcription': transcription,
                'emotions': emotion_scores,
            }
        except Exception as e:
            print(f"Error in analysis: {str(e)}")
            raise

    def _transcribe(self, waveform):
        """Return the Whisper transcription of *waveform* (first batch item)."""
        print("Transcribing audio...")
        manager = self.model_manager
        processor = manager.processors['whisper']
        # Pass the sample rate explicitly so the feature extractor can verify
        # it instead of silently assuming 16 kHz.
        # NOTE(review): the Whisper feature extractor appears to pad/truncate
        # to a fixed 30 s window by default, so longer recordings are likely
        # only partly transcribed — confirm and consider chunking.
        input_features = processor(
            waveform,
            sampling_rate=self.audio_processor.sample_rate,
            return_tensors="pt",
        ).input_features.to(manager.device)

        with torch.no_grad():
            predicted_ids = manager.models['whisper'].generate(input_features)

        return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    def _classify_emotions(self, text):
        """Return ``{label: probability}`` for *text* from the emotion model."""
        print("Analyzing emotions...")
        manager = self.model_manager
        model = manager.models['emotion']
        inputs = manager.tokenizers['emotion'](
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        inputs = {k: v.to(manager.device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = model(**inputs).logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().tolist()

        # Take label names from the model config rather than a hand-written
        # list: a hard-coded list that misses a class (e.g. 'disgust') shifts
        # every later label and drops the last score when zipped.
        id2label = model.config.id2label
        return {id2label[i]: float(p) for i, p in enumerate(probabilities)}
|
|
|
def create_emotion_plot(emotions):
    """Render emotion scores as an embeddable Plotly bar-chart HTML fragment.

    Args:
        emotions: mapping of emotion label -> score; bars are drawn in the
            mapping's iteration order on a fixed 0-1 y axis.

    Returns:
        An HTML string (a ``<div>`` fragment with plotly.js inlined), or the
        string ``"Error creating visualization"`` if building it fails.
    """
    try:
        labels = list(emotions.keys())
        scores = list(emotions.values())
        fig = go.Figure(data=[
            go.Bar(
                x=labels,
                y=scores,
                marker_color='rgb(55, 83, 109)',
            )
        ])

        fig.update_layout(
            title='Emotion Analysis',
            xaxis_title='Emotion',
            yaxis_title='Score',
            yaxis_range=[0, 1],
            template='plotly_white',
            height=400,
        )

        # full_html=False yields a <div> fragment instead of a complete
        # <html>/<head>/<body> document, which is what gets embedded inside
        # the page's HTML output component.
        # NOTE(review): some Gradio versions insert gr.HTML content without
        # executing <script> tags, in which case the chart will not render;
        # returning the Figure to a gr.Plot output may be more reliable — confirm.
        return fig.to_html(full_html=False, include_plotlyjs=True)
    except Exception as e:
        print(f"Error creating plot: {str(e)}")
        return "Error creating visualization"
|
|
|
def process_audio(audio_file):
    """Gradio callback: transcribe *audio_file* and chart its emotions.

    Relies on a module-level ``analyzer`` created by the script entry point.

    Returns:
        A ``(transcription_text, emotion_html)`` pair for the two outputs.
        On missing input or any failure, both slots hold human-readable
        messages instead (the error is also printed).
    """
    if audio_file is None:
        return "No audio file provided", "Please provide an audio file"

    try:
        print(f"Processing audio file: {audio_file}")
        outcome = analyzer.analyze(audio_file)
        text = outcome['transcription']
        chart = create_emotion_plot(outcome['emotions'])
        return text, chart
    except Exception as exc:
        message = f"Error processing audio: {exc}"
        print(message)
        return message, "Error in analysis"
|
|
|
if __name__ == "__main__":
    # Script entry point: load models once, then serve the Gradio UI.
    # `analyzer` is bound at module level on purpose — process_audio reads it
    # as a global.
    print("Initializing application...")
    try:
        analyzer = Analyzer()

        print("Creating Gradio interface...")
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath")
        result_outputs = [
            gr.Textbox(label="Transcription"),
            gr.HTML(label="Emotion Analysis"),
        ]
        interface = gr.Interface(
            fn=process_audio,
            inputs=audio_input,
            outputs=result_outputs,
            title="Vocal Biomarker Analysis",
            description="Analyze voice for emotional indicators",
            examples=[],
            cache_examples=False,
        )

        print("Launching application...")
        # Bind on all interfaces at the fixed port; no public share link.
        launch_options = {
            "server_name": "0.0.0.0",
            "server_port": 7860,
            "share": False,
        }
        interface.launch(**launch_options)
    except Exception as err:
        print(f"Fatal error during application startup: {err}")
        raise