"""Gradio app that transcribes speech with Whisper and detects emotions in the transcript."""

import gradio as gr
import torch
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
import librosa
import numpy as np
import plotly.graph_objects as go
import warnings
import os

warnings.filterwarnings('ignore')

# Model handles, populated once by load_models()
processor = None
whisper_model = None
emotion_tokenizer = None
emotion_model = None


def load_models():
    """Initialize and load all required models."""
    global processor, whisper_model, emotion_tokenizer, emotion_model

    try:
        print("Loading Whisper model...")
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

        print("Loading emotion model...")
        emotion_tokenizer = AutoTokenizer.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
        emotion_model = AutoModelForSequenceClassification.from_pretrained("j-hartmann/emotion-english-distilroberta-base")

        whisper_model.to("cpu")
        emotion_model.to("cpu")

        print("Models loaded successfully!")
        return True
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        return False

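# The models above are pinned to the CPU. If a GPU is available, one optional
# tweak (not part of the original script) is to pick the device dynamically;
# the input tensors built in analyze_audio would then need .to(device) as well:
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     whisper_model.to(device)
#     emotion_model.to(device)
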
def process_audio(audio_input):
    """Load the audio input and return a mono 16 kHz waveform."""
    try:
        print(f"Audio input received: {type(audio_input)}")

        # With type="numpy", Gradio passes audio as a (sample_rate, data) tuple.
        if isinstance(audio_input, tuple):
            sr, data = audio_input
            print(f"Audio input is a (sample_rate, data) tuple: sr={sr}")
            data = np.asarray(data)
            if np.issubdtype(data.dtype, np.integer):
                data = data.astype(np.float32) / np.iinfo(data.dtype).max
            else:
                data = data.astype(np.float32)
            if data.ndim > 1:
                data = data.mean(axis=1)  # downmix stereo to mono
            if sr != 16000:
                data = librosa.resample(data, orig_sr=sr, target_sr=16000)
            return data

        # With type="filepath" (as configured below), the input is a path string.
        audio_path = audio_input
        print(f"Processing audio from path: {audio_path}")

        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found at {audio_path}")

        print("Loading audio file with librosa...")
        waveform, sr = librosa.load(audio_path, sr=16000)
        print(f"Audio loaded successfully. Shape: {waveform.shape}, SR: {sr}")

        return waveform
    except Exception as e:
        print(f"Error processing audio: {str(e)}")
        raise

def create_emotion_plot(emotions):
    """Create a Plotly bar chart of the emotion scores."""
    try:
        fig = go.Figure(data=[
            go.Bar(
                x=list(emotions.keys()),
                y=list(emotions.values()),
                marker_color='rgb(55, 83, 109)'
            )
        ])

        fig.update_layout(
            title='Emotion Analysis',
            xaxis_title='Emotion',
            yaxis_title='Score',
            yaxis_range=[0, 1],
            template='plotly_white',
            height=400
        )

        # Return the figure itself; gr.Plot renders Plotly figures directly,
        # which is more reliable than embedding to_html() output in gr.HTML.
        return fig
    except Exception as e:
        print(f"Error creating plot: {str(e)}")
        return None

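# Quick standalone check of the plotting helper (the scores below are made up):
#
#     create_emotion_plot({"joy": 0.7, "neutral": 0.2, "sadness": 0.1}).show()
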
def analyze_audio(audio_input):
    """Transcribe the audio and analyze the emotions expressed in the transcript."""
    try:
        if audio_input is None:
            print("No audio input provided")
            return "No audio file provided. Please provide an audio file.", None

        print(f"Received audio input: {audio_input}")

        waveform = process_audio(audio_input)

        if waveform is None or len(waveform) == 0:
            return "Error: invalid audio file. Please provide a valid audio file.", None

        # Speech-to-text with Whisper
        print("Transcribing audio...")
        input_features = processor(waveform, sampling_rate=16000, return_tensors="pt").input_features

        with torch.no_grad():
            predicted_ids = whisper_model.generate(input_features)
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

        print(f"Transcription completed: {transcription}")

        if not transcription or transcription.isspace():
            return "No speech detected in audio", None

        # Emotion classification on the transcript
        print("Analyzing emotions...")
        text_inputs = emotion_tokenizer(
            transcription,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )

        with torch.no_grad():
            outputs = emotion_model(**text_inputs)
            scores = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].cpu().numpy()

        # Take the label order from the model config (this checkpoint has seven
        # classes: anger, disgust, fear, joy, neutral, sadness, surprise) rather
        # than a hard-coded list, which is easy to get out of sync.
        emotion_scores = {
            emotion_model.config.id2label[i]: float(score)
            for i, score in enumerate(scores)
        }

        print(f"Emotion analysis completed: {emotion_scores}")

        emotion_viz = create_emotion_plot(emotion_scores)

        return transcription, emotion_viz

    except FileNotFoundError as e:
        error_msg = f"Audio file not found: {str(e)}"
        print(error_msg)
        return error_msg, None
    except Exception as e:
        error_msg = f"Error analyzing audio: {str(e)}"
        print(error_msg)
        return error_msg, None

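# analyze_audio can also be exercised without the Gradio UI, e.g.:
#
#     load_models()
#     transcription, fig = analyze_audio("sample.wav")  # "sample.wav" is a placeholder path
#     print(transcription)
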
print("Initializing application...")
if not load_models():
    raise RuntimeError("Failed to load required models")

demo = gr.Interface(
    fn=analyze_audio,
    inputs=gr.Audio(
        sources=["upload", "microphone"],  # Gradio 4.x; older 3.x releases used source="microphone"
        type="filepath",
        label="Audio Input"
    ),
    outputs=[
        gr.Textbox(label="Transcription"),
        gr.Plot(label="Emotion Analysis")
    ],
    title="Vocal Emotion Analysis",
    description="""
This app analyzes voice recordings to:
1. Transcribe speech to text
2. Detect emotions in the speech

Upload an audio file or record directly through your microphone.
""",
    article="""
Models used:
- Speech recognition: Whisper (tiny)
- Emotion detection: DistilRoBERTa (j-hartmann/emotion-english-distilroberta-base)

Note: Processing may take a few moments depending on the length of the audio.
""",
    examples=None,
    cache_examples=False
)

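# If several users may hit the app at once, enabling Gradio's request queue
# before launching is a common option (not used in the original script):
#
#     demo.queue()
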
if __name__ == "__main__":
    demo.launch(debug=True)