Update app.py
app.py CHANGED
@@ -10,13 +10,55 @@ import io
 import matplotlib.pyplot as plt
 import librosa.display
 from PIL import Image  # For image conversion
+from datetime import datetime

-#
-
-
-
-
-
+# ---------------------------
+# Firebase Setup
+# ---------------------------
+import firebase_admin
+from firebase_admin import credentials, db
+from google.cloud import storage
+
+# Update the path below to your Firebase service account key JSON file
+SERVICE_ACCOUNT_KEY = "path/to/serviceAccountKey.json"  # <-- update this!
+
+# Initialize Firebase Admin for Realtime Database
+cred = credentials.Certificate(SERVICE_ACCOUNT_KEY)
+firebase_admin.initialize_app(cred, {
+    'databaseURL': 'https://your-database-name.firebaseio.com/'  # <-- update with your DB URL
+})
+
+def upload_file_to_firebase(file_path, destination_blob_name):
+    """
+    Uploads a file to Firebase Storage and returns its public URL.
+    """
+    # Update bucket name (usually: your-project-id.appspot.com)
+    bucket_name = "your-project-id.appspot.com"  # <-- update this!
+    storage_client = storage.Client.from_service_account_json(SERVICE_ACCOUNT_KEY)
+    bucket = storage_client.bucket(bucket_name)
+    blob = bucket.blob(destination_blob_name)
+    blob.upload_from_filename(file_path)
+    blob.make_public()
+    print(f"File uploaded to {blob.public_url}")
+    return blob.public_url
+
+def store_prediction_metadata(file_url, predicted_emotion):
+    """
+    Stores the file URL, predicted emotion, and timestamp in Firebase Realtime Database.
+    """
+    ref = db.reference('predictions')
+    data = {
+        'file_url': file_url,
+        'predicted_emotion': predicted_emotion,
+        'timestamp': datetime.now().isoformat()
+    }
+    new_record_ref = ref.push(data)
+    print(f"Stored metadata with key: {new_record_ref.key}")
+    return new_record_ref.key
+
+# ---------------------------
+# Emotion Recognition Code
+# ---------------------------

 # Mapping from emotion labels to emojis
 emotion_to_emoji = {
@@ -35,7 +77,7 @@ def add_emoji_to_label(label):
     emoji = emotion_to_emoji.get(label.lower(), "")
     return f"{label.capitalize()} {emoji}"

-# Load the pre-trained SpeechBrain classifier
+# Load the pre-trained SpeechBrain classifier (Emotion Recognition with wav2vec2 on IEMOCAP)
 classifier = foreign_class(
     source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
     pymodule_file="custom_interface.py",
@@ -52,6 +94,11 @@ def preprocess_audio(audio_file, apply_noise_reduction=False):
     Saves the processed audio to a temporary file and returns its path.
     """
     y, sr = librosa.load(audio_file, sr=16000, mono=True)
+    try:
+        import noisereduce as nr
+        NOISEREDUCE_AVAILABLE = True
+    except ImportError:
+        NOISEREDUCE_AVAILABLE = False
     if apply_noise_reduction and NOISEREDUCE_AVAILABLE:
         y = nr.reduce_noise(y=y, sr=sr)
     if np.max(np.abs(y)) > 0:
@@ -110,7 +157,7 @@ def predict_emotion(audio_file, use_ensemble=False, apply_noise_reduction=False,
     result = classifier.classify_file(temp_file)
     os.remove(temp_file)
     if isinstance(result, tuple) and len(result) > 3:
-        label = result[3][0]
+        label = result[3][0]
     else:
         label = str(result)
     return add_emoji_to_label(label.lower())
@@ -134,39 +181,52 @@ def plot_waveform(audio_file):
 def predict_and_plot(audio_file, use_ensemble, apply_noise_reduction, segment_duration, overlap):
     """
     Run emotion prediction and generate a waveform plot.
+    Additionally, upload the audio file to Firebase Storage and store the metadata in Firebase Realtime Database.
     Returns a tuple: (emotion label with emoji, waveform image as a PIL Image).
     """
+    # Upload the original audio file to Firebase Storage
+    destination_blob_name = os.path.basename(audio_file)
+    file_url = upload_file_to_firebase(audio_file, destination_blob_name)
+
+    # Predict emotion and generate waveform
     emotion = predict_emotion(audio_file, use_ensemble, apply_noise_reduction, segment_duration, overlap)
     waveform = plot_waveform(audio_file)
+
+    # Store metadata (file URL and predicted emotion) in Firebase Realtime Database
+    record_key = store_prediction_metadata(file_url, emotion)
+    print(f"Record stored with key: {record_key}")
+
     return emotion, waveform

+# ---------------------------
+# Gradio App UI
+# ---------------------------
 with gr.Blocks(css=".gradio-container {background-color: #f7f7f7; font-family: Arial;}") as demo:
     gr.Markdown("<h1 style='text-align: center;'>Enhanced Emotion Recognition</h1>")
     gr.Markdown(
         "Upload an audio file, and the model will predict the emotion using a wav2vec2 model fine-tuned on IEMOCAP data. "
-        "The prediction is accompanied by an emoji
-        "
+        "The prediction is accompanied by an emoji, and you can view the audio's waveform. "
+        "The audio file and prediction metadata are stored in Firebase Realtime Database."
     )

     with gr.Tabs():
         with gr.TabItem("Emotion Recognition"):
             with gr.Row():
                 audio_input = gr.Audio(type="filepath", label="Upload Audio")
-
-
+                use_ensemble_checkbox = gr.Checkbox(label="Use Ensemble Prediction (for long audio)", value=False)
+                apply_noise_reduction_checkbox = gr.Checkbox(label="Apply Noise Reduction", value=False)
             with gr.Row():
-
-
+                segment_duration_slider = gr.Slider(minimum=1.0, maximum=10.0, step=0.5, value=3.0, label="Segment Duration (s)")
+                overlap_slider = gr.Slider(minimum=0.0, maximum=5.0, step=0.5, value=1.0, label="Segment Overlap (s)")
             predict_button = gr.Button("Predict Emotion")
             result_text = gr.Textbox(label="Predicted Emotion")
             waveform_image = gr.Image(label="Audio Waveform", type="pil")

             predict_button.click(
                 predict_and_plot,
-                inputs=[audio_input,
+                inputs=[audio_input, use_ensemble_checkbox, apply_noise_reduction_checkbox, segment_duration_slider, overlap_slider],
                 outputs=[result_text, waveform_image]
             )
-
         with gr.TabItem("About"):
             gr.Markdown("""
             **Enhanced Emotion Recognition App**
@@ -176,7 +236,8 @@ with gr.Blocks(css=".gradio-container {background-color: #f7f7f7; font-family: Arial;}") as demo:
             - Ensemble Prediction for long audio files.
             - Optional Noise Reduction.
             - Visualization of the audio waveform.
-            - Emoji representation of the predicted emotion
+            - Emoji representation of the predicted emotion.
+            - Audio file and prediction metadata stored in Firebase Realtime Database.

             **Credits:**
             - [SpeechBrain](https://speechbrain.github.io)
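After this change, every prediction pushes one record under the `predictions` node via `store_prediction_metadata`. As a quick end-to-end check of the Firebase side, here is a minimal sketch that reads those records back with `firebase_admin`; the service-account path and database URL are the same placeholders used in the diff and must be replaced with real values.

```python
# Minimal sketch: read back the records written by store_prediction_metadata.
# The credentials path and databaseURL below are placeholders, exactly as in app.py.
import firebase_admin
from firebase_admin import credentials, db

cred = credentials.Certificate("path/to/serviceAccountKey.json")  # <-- placeholder
firebase_admin.initialize_app(cred, {
    'databaseURL': 'https://your-database-name.firebaseio.com/'   # <-- placeholder
})

records = db.reference('predictions').get() or {}
for key, record in records.items():
    # Each record carries the fields written in store_prediction_metadata.
    print(key, record['predicted_emotion'], record['file_url'], record['timestamp'])
```

Note that the new imports also mean the Space needs `firebase-admin` and `google-cloud-storage` installed (for example via `requirements.txt`), alongside the optional `noisereduce` package used by the noise-reduction toggle.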
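The `predict_emotion` hunk keeps the `label = result[3][0]` indexing. For context only, a minimal sketch of why that index is used, assuming the model's published custom interface, whose `classify_file` returns an `(out_prob, score, index, text_lab)` tuple; the `classname` argument and the audio path are taken from the model card and are not part of this diff.

```python
# Minimal sketch of the label unpacking in predict_emotion. Assumptions:
# - foreign_class import path and classname follow the model card for
#   speechbrain/emotion-recognition-wav2vec2-IEMOCAP (not shown in this diff).
# - "sample.wav" is a placeholder audio path.
from speechbrain.pretrained.interfaces import foreign_class

classifier = foreign_class(
    source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
    pymodule_file="custom_interface.py",
    classname="CustomEncoderWav2vec2Classifier",
)
out_prob, score, index, text_lab = classifier.classify_file("sample.wav")
print(text_lab[0])  # plain label string, i.e. result[3][0] in app.py
```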