# tabs/speech_emotion_recognition.py
import gradio as gr
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
from transformers import pipeline
import torch
import tempfile
import warnings
import os
# Suppress specific warnings from transformers if needed
warnings.filterwarnings("ignore", category=UserWarning, module='transformers')
# Determine the device
def get_device():
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using MPS device for inference.")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using CUDA device for inference.")
    else:
        device = torch.device("cpu")
        print("Using CPU for inference.")
    return device
device = get_device()
# Initialize the pipelines with the specified device
try:
    emotion_model = pipeline(
        "audio-classification",
        model="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition",
        device=0 if device.type == "cuda" else ("mps" if device.type == "mps" else -1)
    )
    print("Emotion model loaded successfully.")
except Exception as e:
    print(f"Error loading emotion model: {e}")
    emotion_model = None
try:
    transcription_model = pipeline(
        "automatic-speech-recognition",
        model="facebook/wav2vec2-base-960h",
        device=0 if device.type == "cuda" else ("mps" if device.type == "mps" else -1)
    )
    print("Transcription model loaded successfully.")
except Exception as e:
    print(f"Error loading transcription model: {e}")
    transcription_model = None
# Emotion Mapping: each predicted label maps to an (arousal, dominance, valence) tuple.
emotion_mapping = {
    "angry": (0.8, 0.8, -0.5),
    "happy": (0.6, 0.6, 0.8),
    "sad": (-0.6, -0.4, -0.6),
    "neutral": (0, 0, 0),
    "fear": (0.3, -0.3, -0.7),
    "surprise": (0.4, 0.2, 0.2),
    "disgust": (0.2, 0.5, -0.6),
    "calm": (-0.2, 0.1, 0.3),
    "excited": (0.7, 0.5, 0.7),
    "frustrated": (0.6, 0.5, -0.4)
}
def process_audio_emotion(audio_file):
    """
    Processes the input audio file to perform transcription and emotion recognition.
    Generates waveform and mel spectrogram plots.
    Returns:
        A tuple containing:
        - Transcription (str)
        - Emotion (str)
        - Confidence (%) (float)
        - Arousal (float)
        - Dominance (float)
        - Valence (float)
        - Waveform Plot (str: filepath)
        - Mel Spectrogram Plot (str: filepath)
    """
    if not audio_file:
        return (
            "No audio file provided.",  # Transcription (textbox)
            None,                       # Emotion (textbox)
            None,                       # Confidence (%) (number)
            None,                       # Arousal (number)
            None,                       # Dominance (number)
            None,                       # Valence (number)
            None,                       # Waveform Plot (image)
            None                        # Mel Spectrogram Plot (image)
        )
    try:
        y, sr = librosa.load(audio_file, sr=None)
        # Transcription
        if transcription_model:
            transcription_result = transcription_model(audio_file)
            transcription = transcription_result.get("text", "N/A")
        else:
            transcription = "Transcription model not loaded."
        # Emotion Recognition
        if emotion_model:
            emotion_results = emotion_model(audio_file)
            if emotion_results:
                emotion_result = emotion_results[0]
                emotion = emotion_result.get("label", "Unknown").lower()
                confidence = emotion_result.get("score", 0.0) * 100  # Convert to percentage
                arousal, dominance, valence = emotion_mapping.get(emotion, (0.0, 0.0, 0.0))
            else:
                emotion = "No emotion detected."
                confidence = 0.0
                arousal, dominance, valence = 0.0, 0.0, 0.0
        else:
            emotion = "Emotion model not loaded."
            confidence = 0.0
            arousal, dominance, valence = 0.0, 0.0, 0.0
        # Plotting Waveform
        plt.figure(figsize=(10, 4))
        librosa.display.waveshow(y, sr=sr)
        plt.title("Waveform")
        plt.xlabel("Time (s)")
        plt.ylabel("Amplitude")
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_waveform:
            plt.savefig(tmp_waveform.name, bbox_inches='tight')
            waveform_plot_path = tmp_waveform.name
        plt.close()
        # Plotting Mel Spectrogram
        mel_spec = librosa.feature.melspectrogram(y=y, sr=sr)
        plt.figure(figsize=(10, 4))
        librosa.display.specshow(librosa.power_to_db(mel_spec, ref=np.max), sr=sr, x_axis='time', y_axis='mel')
        plt.colorbar(format='%+2.0f dB')
        plt.title("Mel Spectrogram")
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_mel:
            plt.savefig(tmp_mel.name, bbox_inches='tight')
            mel_spec_plot_path = tmp_mel.name
        plt.close()
        return (
            transcription,          # Transcription (textbox)
            emotion.capitalize(),   # Emotion (textbox)
            confidence,             # Confidence (%) (number)
            arousal,                # Arousal (number)
            dominance,              # Dominance (number)
            valence,                # Valence (number)
            waveform_plot_path,     # Waveform Plot (image)
            mel_spec_plot_path      # Mel Spectrogram Plot (image)
        )
    except Exception as e:
        return (
            f"Error: {str(e)}",  # Transcription (textbox)
            None,                # Emotion (textbox)
            None,                # Confidence (%) (number)
            None,                # Arousal (number)
            None,                # Dominance (number)
            None,                # Valence (number)
            None,                # Waveform Plot (image)
            None                 # Mel Spectrogram Plot (image)
        )
def create_emotion_recognition_tab():
    """
    Creates the Emotion Recognition tab in the Gradio interface.
    """
    with gr.Row():
        with gr.Column(scale=2):
            input_audio = gr.Audio(label="Input Audio", type="filepath")
            gr.Examples(
                examples=["./assets/audio/fitness.wav"],
                inputs=[input_audio],
                label="Examples"
            )
        with gr.Column(scale=1):
            transcription_output = gr.Textbox(label="Transcription", interactive=False)
            emotion_output = gr.Textbox(label="Emotion", interactive=False)
            confidence_output = gr.Number(label="Confidence (%)", interactive=False)
            arousal_output = gr.Number(label="Arousal (Level of Energy)", interactive=False)
            dominance_output = gr.Number(label="Dominance (Degree of Control)", interactive=False)
            valence_output = gr.Number(label="Valence (Positivity/Negativity)", interactive=False)
        with gr.Column(scale=1):
            waveform_plot = gr.Image(label="Waveform")
            mel_spec_plot = gr.Image(label="Mel Spectrogram")
    input_audio.change(
        fn=process_audio_emotion,
        inputs=[input_audio],
        outputs=[
            transcription_output,
            emotion_output,
            confidence_output,
            arousal_output,
            dominance_output,
            valence_output,
            waveform_plot,
            mel_spec_plot
        ]
    )
# create_emotion_recognition_tab() is called from the main app to build this tab's Gradio interface.
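# A minimal launch sketch for running this tab on its own, assuming no other tabs are
# needed: the real entry point lives elsewhere in the repo, so the Blocks layout and
# tab title below are illustrative only.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Tab("Speech Emotion Recognition"):
            create_emotion_recognition_tab()
    demo.launch()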