Spaces:
Build error
Build error
# tabs/speech_emotion_recognition.py | |
import gradio as gr | |
import numpy as np | |
import librosa | |
import librosa.display | |
import matplotlib.pyplot as plt | |
from transformers import pipeline | |
import torch | |
import tempfile | |
import warnings | |
import os | |
# Silence UserWarnings attributed to the transformers package. `module` is a
# regular expression matched against the start of the warning's module name,
# so transformers submodules are covered as well. This runs after the imports
# above, so it cannot hide warnings emitted while those imports executed.
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    module="transformers",
)
# Determine the device
def get_device():
    """Choose the torch device used for inference and report the choice.

    Preference order is Apple MPS, then CUDA, then CPU. A one-line message
    naming the chosen backend is printed to stdout.

    Returns:
        torch.device: the selected device.
    """
    if torch.backends.mps.is_available():
        backend, notice = "mps", "Using MPS device for inference."
    elif torch.cuda.is_available():
        backend, notice = "cuda", "Using CUDA device for inference."
    else:
        backend, notice = "cpu", "Using CPU for inference."
    print(notice)
    return torch.device(backend)
device = get_device()

# Initialize the pipelines with the specified device.
# Value forwarded to transformers' `pipeline(device=...)`:
# 0 = first CUDA GPU, "mps" = Apple Metal, -1 = CPU.
_PIPELINE_DEVICE = 0 if device.type == "cuda" else ("mps" if device.type == "mps" else -1)


def _load_pipeline(task, model_name, description, pipeline_device=_PIPELINE_DEVICE):
    """Create a transformers pipeline, or return None if loading fails.

    Args:
        task: pipeline task name, e.g. "audio-classification".
        model_name: Hugging Face model id to load.
        description: lower-case model description used in the status messages.
        pipeline_device: value forwarded to ``pipeline(device=...)``.

    Returns:
        The pipeline object, or None when construction raised. Failures are
        printed rather than raised so the module still imports; the inference
        code reports a "model not loaded" message when its model is None.
    """
    try:
        model = pipeline(task, model=model_name, device=pipeline_device)
    except Exception as e:  # deliberate best-effort: any load failure degrades to None
        print(f"Error loading {description} model: {e}")
        return None
    print(f"{description.capitalize()} model loaded successfully.")
    return model


emotion_model = _load_pipeline(
    "audio-classification",
    "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition",
    "emotion",
)
transcription_model = _load_pipeline(
    "automatic-speech-recognition",
    "facebook/wav2vec2-base-960h",
    "transcription",
)
# Emotion Mapping
# Lower-cased emotion label -> (arousal, dominance, valence). The tuple order
# matches how process_audio_emotion unpacks it. The coordinates are
# hand-assigned here, not produced by the model; unknown labels fall back to
# (0.0, 0.0, 0.0) at the lookup site.
#
# NOTE(review): the ehcalabres RAVDESS model card lists its classes as
# "fearful" and "surprised" (not "fear"/"surprise"). Without those keys both
# emotions silently mapped to the neutral fallback. The short forms are kept
# for backward compatibility; confirm against the model's config id2label.
emotion_mapping = {
    "angry": (0.8, 0.8, -0.5),
    "happy": (0.6, 0.6, 0.8),
    "sad": (-0.6, -0.4, -0.6),
    "neutral": (0, 0, 0),
    "fear": (0.3, -0.3, -0.7),
    "fearful": (0.3, -0.3, -0.7),
    "surprise": (0.4, 0.2, 0.2),
    "surprised": (0.4, 0.2, 0.2),
    "disgust": (0.2, 0.5, -0.6),
    "calm": (-0.2, 0.1, 0.3),
    "excited": (0.7, 0.5, 0.7),
    "frustrated": (0.6, 0.5, -0.4),
}
def _empty_outputs(message):
    """Return the 8-slot output tuple with *message* as the transcription and all other slots None."""
    return (message, None, None, None, None, None, None, None)


def _save_png(fig):
    """Write *fig* to a new ``.png`` temp file and return the file path.

    The file is created with ``delete=False`` so it outlives this call for the
    UI to display; nothing visible here deletes it afterwards.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        # Save through the already-open handle: re-opening ``tmp.name`` while
        # it is still open fails on Windows.
        fig.savefig(tmp, format="png", bbox_inches="tight")
    return tmp.name


def _plot_waveform(y, sr):
    """Plot the waveform of *y* (sample rate *sr*) and return the PNG path."""
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        librosa.display.waveshow(y, sr=sr, ax=ax)
        ax.set_title("Waveform")
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Amplitude")
        return _save_png(fig)
    finally:
        # Release the figure even when plotting/saving raises, so figures do
        # not accumulate in pyplot across requests.
        plt.close(fig)


def _plot_mel_spectrogram(y, sr):
    """Plot the dB-scaled mel spectrogram of *y* and return the PNG path."""
    mel_db = librosa.power_to_db(librosa.feature.melspectrogram(y=y, sr=sr), ref=np.max)
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        mesh = librosa.display.specshow(mel_db, sr=sr, x_axis="time", y_axis="mel", ax=ax)
        # Attach the colorbar to this figure's mesh explicitly instead of
        # relying on pyplot's "current image".
        fig.colorbar(mesh, ax=ax, format="%+2.0f dB")
        ax.set_title("Mel Spectrogram")
        return _save_png(fig)
    finally:
        plt.close(fig)


def _recognize_emotion(audio_file):
    """Classify the emotion in *audio_file*.

    Returns:
        (label, confidence_percent, arousal, dominance, valence). The label is
        lower-cased. When the model is unavailable or returns nothing, the
        label is a status message and every number is 0.0.
    """
    if emotion_model is None:
        return "Emotion model not loaded.", 0.0, 0.0, 0.0, 0.0
    predictions = emotion_model(audio_file)
    if not predictions:
        return "No emotion detected.", 0.0, 0.0, 0.0, 0.0
    # Only the first prediction is used (presumably the highest-scoring one).
    top = predictions[0]
    label = top.get("label", "Unknown").lower()
    confidence = top.get("score", 0.0) * 100  # fraction -> percentage
    # Labels missing from emotion_mapping fall back to a neutral point.
    arousal, dominance, valence = emotion_mapping.get(label, (0.0, 0.0, 0.0))
    return label, confidence, arousal, dominance, valence


def process_audio_emotion(audio_file):
    """
    Transcribe *audio_file*, classify its emotion, and render waveform and
    mel spectrogram plots.

    Args:
        audio_file: path to the audio file (the tab's Audio input uses
            type="filepath"). A falsy value short-circuits.

    Returns:
        An 8-tuple matching the UI outputs, in order:
          - Transcription (str)
          - Emotion (str): capitalized label or status message
          - Confidence (%) (float)
          - Arousal (float)
          - Dominance (float)
          - Valence (float)
          - Waveform Plot (str: filepath)
          - Mel Spectrogram Plot (str: filepath)
        On missing input or any processing error, the first slot holds a
        message and the remaining seven are None.
    """
    if not audio_file:
        return _empty_outputs("No audio file provided.")
    try:
        # Native sample rate; the decoded signal is only used for the plots.
        y, sr = librosa.load(audio_file, sr=None)
        if transcription_model is None:
            transcription = "Transcription model not loaded."
        else:
            transcription = transcription_model(audio_file).get("text", "N/A")
        emotion, confidence, arousal, dominance, valence = _recognize_emotion(audio_file)
        waveform_plot_path = _plot_waveform(y, sr)
        mel_spec_plot_path = _plot_mel_spectrogram(y, sr)
    except Exception as e:
        # UI boundary: surface the failure in the transcription box instead of raising.
        return _empty_outputs(f"Error: {str(e)}")
    return (
        transcription,
        emotion.capitalize(),
        confidence,
        arousal,
        dominance,
        valence,
        waveform_plot_path,
        mel_spec_plot_path,
    )
def create_emotion_recognition_tab():
    """
    Lay out the Emotion Recognition tab and wire it to process_audio_emotion.

    The layout is one row with three columns:
      1. the audio input plus a bundled example clip,
      2. the text/number results (transcription, emotion, confidence,
         arousal, dominance, valence),
      3. the waveform and mel spectrogram images.
    Any change to the audio input re-runs process_audio_emotion. Presumably
    called from inside the app's enclosing Gradio Blocks layout (not visible here).
    """
    with gr.Row():
        with gr.Column(scale=2):
            audio_in = gr.Audio(label="Input Audio", type="filepath")
            gr.Examples(
                examples=["./assets/audio/fitness.wav"],
                inputs=[audio_in],
                label="Examples",
            )
        with gr.Column(scale=1):
            # Order must match the first six slots returned by process_audio_emotion.
            result_fields = [
                gr.Textbox(label="Transcription", interactive=False),
                gr.Textbox(label="Emotion", interactive=False),
                gr.Number(label="Confidence (%)", interactive=False),
                gr.Number(label="Arousal (Level of Energy)", interactive=False),
                gr.Number(label="Dominance (Degree of Control)", interactive=False),
                gr.Number(label="Valence (Positivity/Negativity)", interactive=False),
            ]
        with gr.Column(scale=1):
            # Last two slots: waveform path, then mel spectrogram path.
            plot_fields = [
                gr.Image(label="Waveform"),
                gr.Image(label="Mel Spectrogram"),
            ]
    audio_in.change(
        fn=process_audio_emotion,
        inputs=[audio_in],
        outputs=result_fields + plot_fields,
    )
# Call create_emotion_recognition_tab to create the Gradio interface | |