# app.py
import io
import os
import tempfile
from collections import Counter
from datetime import datetime

import gradio as gr
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
from PIL import Image  # For converting the plot buffer into a PIL Image
from speechbrain.inference.interfaces import foreign_class

# noisereduce is optional; noise reduction is silently skipped if it is not installed
try:
    import noisereduce as nr
    NOISEREDUCE_AVAILABLE = True
except ImportError:
    NOISEREDUCE_AVAILABLE = False
# ---------------------------
# Firebase Setup
# ---------------------------
import firebase_admin
from firebase_admin import credentials, db
from google.cloud import storage

# Update the path below to your Firebase service account key JSON file
SERVICE_ACCOUNT_KEY = "path/to/serviceAccountKey.json"  # <-- update this!

# Initialize Firebase Admin for the Realtime Database
cred = credentials.Certificate(SERVICE_ACCOUNT_KEY)
firebase_admin.initialize_app(cred, {
    'databaseURL': 'https://your-database-name.firebaseio.com/'  # <-- update with your DB URL
})

def upload_file_to_firebase(file_path, destination_blob_name):
    """
    Uploads a file to Firebase Storage and returns its public URL.
    """
    # Update the bucket name (usually: your-project-id.appspot.com)
    bucket_name = "your-project-id.appspot.com"  # <-- update this!
    storage_client = storage.Client.from_service_account_json(SERVICE_ACCOUNT_KEY)
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(destination_blob_name)
    blob.upload_from_filename(file_path)
    blob.make_public()
    print(f"File uploaded to {blob.public_url}")
    return blob.public_url
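
# Example usage (hypothetical file names, for illustration only):
#   url = upload_file_to_firebase("recording.wav", "uploads/recording.wav")
#   # -> https://storage.googleapis.com/your-project-id.appspot.com/uploads/recording.wav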

def store_prediction_metadata(file_url, predicted_emotion):
    """
    Stores the file URL, predicted emotion, and timestamp in the Firebase Realtime Database.
    """
    ref = db.reference('predictions')
    data = {
        'file_url': file_url,
        'predicted_emotion': predicted_emotion,
        'timestamp': datetime.now().isoformat()
    }
    new_record_ref = ref.push(data)
    print(f"Stored metadata with key: {new_record_ref.key}")
    return new_record_ref.key
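
# Each push creates a record under /predictions shaped roughly like
# (illustrative values only):
#   {"file_url": "https://...", "predicted_emotion": "Happy 😊", "timestamp": "2024-01-01T12:00:00"}
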
# ---------------------------
# Emotion Recognition Code
# ---------------------------
# Mapping from emotion labels to emojis. The IEMOCAP model emits short labels
# ("neu", "ang", "hap", "sad"), so those are mapped alongside the full names.
emotion_to_emoji = {
    "angry": "😠",
    "ang": "😠",
    "happy": "😊",
    "hap": "😊",
    "sad": "😒",
    "neutral": "😐",
    "neu": "😐",
    "excited": "😄",
    "fear": "😨",
    "disgust": "🤢",
    "surprise": "😲"
}

def add_emoji_to_label(label):
    """Append the emoji corresponding to the emotion label."""
    emoji = emotion_to_emoji.get(label.lower(), "")
    return f"{label.capitalize()} {emoji}"
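
# e.g. add_emoji_to_label("happy") -> "Happy 😊"; unknown labels pass through without an emoji.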

# Load the pre-trained SpeechBrain classifier (emotion recognition with wav2vec2 on IEMOCAP)
classifier = foreign_class(
    source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
    pymodule_file="custom_interface.py",
    classname="CustomEncoderWav2vec2Classifier",
    run_opts={"device": "cpu"}  # Change to {"device": "cuda"} if a GPU is available
)
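# Note: on first use, foreign_class downloads the model from the Hugging Face Hub
# and caches it locally, so the initial startup can take a while.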

def preprocess_audio(audio_file, apply_noise_reduction=False):
    """
    Load and preprocess the audio file:
    - Convert to 16 kHz mono.
    - Optionally apply noise reduction (if noisereduce is installed).
    - Peak-normalize the audio.
    Saves the processed audio to a temporary WAV file and returns its path.
    """
    y, sr = librosa.load(audio_file, sr=16000, mono=True)
    if apply_noise_reduction and NOISEREDUCE_AVAILABLE:
        y = nr.reduce_noise(y=y, sr=sr)
    # Peak normalization; the guard avoids dividing by zero on silent clips
    if np.max(np.abs(y)) > 0:
        y = y / np.max(np.abs(y))
    temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    sf.write(temp_file.name, y, sr)
    temp_file.close()
    return temp_file.name
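
# Example (hypothetical path): cleaned = preprocess_audio("clip.wav", apply_noise_reduction=True)
# The caller is responsible for deleting the temporary file, e.g. os.remove(cleaned).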

def ensemble_prediction(audio_file, apply_noise_reduction=False, segment_duration=3.0, overlap=1.0):
    """
    For longer audio files, split the audio into overlapping segments, predict
    each segment, and return the majority-voted emotion label.
    """
    y, sr = librosa.load(audio_file, sr=16000, mono=True)
    total_duration = librosa.get_duration(y=y, sr=sr)

    # Short clips do not need segmentation; classify them directly
    if total_duration <= segment_duration:
        temp_file = preprocess_audio(audio_file, apply_noise_reduction)
        _, _, _, label = classifier.classify_file(temp_file)
        os.remove(temp_file)
        return label[0]

    # Guard against a non-positive step size when overlap >= segment_duration
    step = max(segment_duration - overlap, 0.1)

    # Slice the waveform into overlapping windows of segment_duration seconds
    segments = []
    for start in np.arange(0, total_duration - segment_duration + 0.001, step):
        start_sample = int(start * sr)
        end_sample = int((start + segment_duration) * sr)
        segment_audio = y[start_sample:end_sample]
        temp_seg = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        sf.write(temp_seg.name, segment_audio, sr)
        temp_seg.close()
        segments.append(temp_seg.name)

    # Classify each segment and majority-vote over the per-segment labels
    predictions = []
    for seg in segments:
        temp_file = preprocess_audio(seg, apply_noise_reduction)
        _, _, _, label = classifier.classify_file(temp_file)
        predictions.append(label[0])
        os.remove(temp_file)
        os.remove(seg)
    vote = Counter(predictions)
    return vote.most_common(1)[0][0]
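
# Majority-vote illustration (hypothetical labels): ["hap", "hap", "neu"] -> "hap".
# Counter.most_common(1) breaks ties in favor of the label encountered first.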

def predict_emotion(audio_file, use_ensemble=False, apply_noise_reduction=False, segment_duration=3.0, overlap=1.0):
    """
    Predict the emotion in an audio file and return the label with an emoji.
    """
    try:
        if use_ensemble:
            label = ensemble_prediction(audio_file, apply_noise_reduction, segment_duration, overlap)
        else:
            temp_file = preprocess_audio(audio_file, apply_noise_reduction)
            result = classifier.classify_file(temp_file)
            os.remove(temp_file)
            # classify_file returns (out_prob, score, index, text_lab); take the text label
            if isinstance(result, tuple) and len(result) > 3:
                label = result[3][0]
            else:
                label = str(result)
        return add_emoji_to_label(label.lower())
    except Exception as e:
        return f"Error processing file: {str(e)}"
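
# Example (hypothetical file): predict_emotion("clip.wav") might return "Hap 😊",
# or an "Error processing file: ..." string if anything fails.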

def plot_waveform(audio_file):
    """
    Generate and return a waveform plot (as a PIL Image) for the given audio file.
    """
    y, sr = librosa.load(audio_file, sr=16000, mono=True)
    plt.figure(figsize=(10, 3))
    librosa.display.waveshow(y, sr=sr)
    plt.title("Waveform")
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    plt.close()
    buf.seek(0)
    return Image.open(buf)
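
# Example (hypothetical file): img = plot_waveform("clip.wav"); img.size is roughly
# (1000, 300) pixels, given figsize=(10, 3) at matplotlib's default 100 dpi.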

def predict_and_plot(audio_file, use_ensemble, apply_noise_reduction, segment_duration, overlap):
    """
    Run emotion prediction and generate a waveform plot.
    Also uploads the audio file to Firebase Storage and stores the metadata
    in the Firebase Realtime Database.
    Returns a tuple: (emotion label with emoji, waveform image as a PIL Image).
    """
    # Upload the original audio file to Firebase Storage
    destination_blob_name = os.path.basename(audio_file)
    file_url = upload_file_to_firebase(audio_file, destination_blob_name)
    # Predict the emotion and generate the waveform plot
    emotion = predict_emotion(audio_file, use_ensemble, apply_noise_reduction, segment_duration, overlap)
    waveform = plot_waveform(audio_file)
    # Store the file URL and predicted emotion in the Firebase Realtime Database
    record_key = store_prediction_metadata(file_url, emotion)
    print(f"Record stored with key: {record_key}")
    return emotion, waveform
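
# Gradio passes the uploaded audio as a temporary file path (gr.Audio(type="filepath")),
# which is exactly what predict_and_plot expects as its first argument.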

# ---------------------------
# Gradio App UI
# ---------------------------
with gr.Blocks(css=".gradio-container {background-color: #f7f7f7; font-family: Arial;}") as demo:
    gr.Markdown("<h1 style='text-align: center;'>Enhanced Emotion Recognition</h1>")
    gr.Markdown(
        "Upload an audio file, and the model will predict the emotion using a wav2vec2 model fine-tuned on IEMOCAP data. "
        "The prediction is shown with an emoji, and you can view the audio's waveform. "
        "The audio file is uploaded to Firebase Storage, and the prediction metadata is stored in the Firebase Realtime Database."
    )
    with gr.Tabs():
        with gr.TabItem("Emotion Recognition"):
            with gr.Row():
                audio_input = gr.Audio(type="filepath", label="Upload Audio")
                use_ensemble_checkbox = gr.Checkbox(label="Use Ensemble Prediction (for long audio)", value=False)
                apply_noise_reduction_checkbox = gr.Checkbox(label="Apply Noise Reduction", value=False)
            with gr.Row():
                segment_duration_slider = gr.Slider(minimum=1.0, maximum=10.0, step=0.5, value=3.0, label="Segment Duration (s)")
                overlap_slider = gr.Slider(minimum=0.0, maximum=5.0, step=0.5, value=1.0, label="Segment Overlap (s)")
            predict_button = gr.Button("Predict Emotion")
            result_text = gr.Textbox(label="Predicted Emotion")
            waveform_image = gr.Image(label="Audio Waveform", type="pil")
            predict_button.click(
                predict_and_plot,
                inputs=[audio_input, use_ensemble_checkbox, apply_noise_reduction_checkbox, segment_duration_slider, overlap_slider],
                outputs=[result_text, waveform_image]
            )
        with gr.TabItem("About"):
            gr.Markdown("""
            **Enhanced Emotion Recognition App**

            - **Model:** SpeechBrain's wav2vec2 model fine-tuned on IEMOCAP for emotion recognition.
            - **Features:**
                - Ensemble prediction for long audio files.
                - Optional noise reduction.
                - Waveform visualization.
                - Emoji representation of the predicted emotion.
                - Audio files uploaded to Firebase Storage, with prediction metadata in the Firebase Realtime Database.

            **Credits:**
            - [SpeechBrain](https://speechbrain.github.io)
            - [Gradio](https://gradio.app)
            """)

if __name__ == "__main__":
    demo.launch()