import gradio as gr
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoModelForSequenceClassification, AutoTokenizer
import librosa
import numpy as np
import plotly.graph_objects as go
import warnings
import os

warnings.filterwarnings('ignore')

# Global variables for models
processor = None
whisper_model = None
emotion_tokenizer = None
emotion_model = None


def load_models():
    """Initialize and load all required models"""
    global processor, whisper_model, emotion_tokenizer, emotion_model
    try:
        print("Loading Whisper model...")
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

        print("Loading emotion model...")
        emotion_tokenizer = AutoTokenizer.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
        emotion_model = AutoModelForSequenceClassification.from_pretrained("j-hartmann/emotion-english-distilroberta-base")

        # Move models to CPU explicitly
        whisper_model.to("cpu")
        emotion_model.to("cpu")

        print("Models loaded successfully!")
        return True
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        return False


def process_audio(audio_input):
    """Process audio file and extract waveform"""
    try:
        print(f"Audio input received: {type(audio_input)}")

        # Handle tuple input from Gradio
        if isinstance(audio_input, tuple):
            print(f"Audio input is tuple: {audio_input[0]}, {audio_input[1]}")
            audio_path = audio_input[0]  # Get the file path
        else:
            audio_path = audio_input

        print(f"Processing audio from path: {audio_path}")

        # Verify file exists
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found at {audio_path}")

        # Load and resample audio
        print("Loading audio file with librosa...")
        waveform, sr = librosa.load(audio_path, sr=16000)
        print(f"Audio loaded successfully. Shape: {waveform.shape}, SR: {sr}")

        return waveform
    except Exception as e:
        print(f"Error processing audio: {str(e)}")
        raise


def create_emotion_plot(emotions):
    """Create plotly visualization for emotion scores"""
    try:
        fig = go.Figure(data=[
            go.Bar(
                x=list(emotions.keys()),
                y=list(emotions.values()),
                marker_color='rgb(55, 83, 109)'
            )
        ])

        fig.update_layout(
            title='Emotion Analysis',
            xaxis_title='Emotion',
            yaxis_title='Score',
            yaxis_range=[0, 1],
            template='plotly_white',
            height=400
        )

        return fig.to_html(include_plotlyjs=True)
    except Exception as e:
        print(f"Error creating plot: {str(e)}")
        return "Error creating visualization"


def analyze_audio(audio_input):
    """Main function to analyze audio input"""
    try:
        if audio_input is None:
            print("No audio input provided")
            return "No audio file provided", "Please provide an audio file"

        print(f"Received audio input: {audio_input}")

        # Process audio
        waveform = process_audio(audio_input)
        if waveform is None or len(waveform) == 0:
            return "Error: Invalid audio file", "Please provide a valid audio file"

        # Transcribe audio
        print("Transcribing audio...")
        inputs = processor(waveform, sampling_rate=16000, return_tensors="pt").input_features
        with torch.no_grad():
            predicted_ids = whisper_model.generate(inputs)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        print(f"Transcription completed: {transcription}")

        if not transcription or transcription.isspace():
            return "No speech detected in audio", "Unable to analyze emotions without speech"

        # Analyze emotions
        print("Analyzing emotions...")
        inputs = emotion_tokenizer(
            transcription,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        with torch.no_grad():
            outputs = emotion_model(**inputs)
        emotions = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Read labels from the model config rather than hard-coding them: this
        # checkpoint defines seven classes (anger, disgust, fear, joy, neutral,
        # sadness, surprise), so the original six-item list silently shifted
        # every label after "anger" and dropped the last score.
        emotion_scores = {
            emotion_model.config.id2label[i]: float(score)
            for i, score in enumerate(emotions[0].cpu().numpy())
        }
print(f"Emotion analysis completed: {emotion_scores}") | |
# Create visualization | |
emotion_viz = create_emotion_plot(emotion_scores) | |
return transcription, emotion_viz | |
except FileNotFoundError as e: | |
error_msg = f"Audio file not found: {str(e)}" | |
print(error_msg) | |
return error_msg, "Please provide a valid audio file" | |
except Exception as e: | |
error_msg = f"Error analyzing audio: {str(e)}" | |
print(error_msg) | |
return error_msg, "Error in analysis" | |
# Load models at startup
print("Initializing application...")
if not load_models():
    raise RuntimeError("Failed to load required models")

# Create Gradio interface
demo = gr.Interface(
    fn=analyze_audio,
    inputs=gr.Audio(
        # Gradio 4.x renamed the old `source=` argument to `sources=` (a list);
        # the old keyword raises a TypeError at startup on current versions.
        # Allow both upload and microphone, matching the description below.
        sources=["upload", "microphone"],
        type="filepath",
        label="Audio Input"
    ),
    outputs=[
        gr.Textbox(label="Transcription"),
        gr.HTML(label="Emotion Analysis")
    ],
    title="Vocal Emotion Analysis",
    description="""
    This app analyzes voice recordings to:
    1. Transcribe speech to text
    2. Detect emotions in the speech

    Upload an audio file or record directly through your microphone.
    """,
    article="""
    Models used:
    - Speech recognition: Whisper (tiny)
    - Emotion detection: DistilRoBERTa

    Note: Processing may take a few moments depending on the length of the audio.
    """,
    examples=None,
    cache_examples=False
)

if __name__ == "__main__":
    demo.launch(debug=True)
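
# The Space's dependency file is not shown here; as a sketch (an assumption, not
# taken from the original repo), a requirements.txt covering the imports above
# would list roughly:
#
#   gradio
#   torch
#   transformers
#   librosa
#   plotly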