# app.py
import gradio as gr
import librosa
import numpy as np
import os
import tempfile
from collections import Counter
from speechbrain.inference.interfaces import foreign_class
import io
import matplotlib.pyplot as plt
import librosa.display
from PIL import Image  # For image conversion
from datetime import datetime
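
# Dependencies inferred from the imports in this file (an untested list; versions not pinned):
#   pip install gradio librosa numpy soundfile matplotlib pillow speechbrain \
#       firebase-admin google-cloud-storage noisereduce
# noisereduce is optional and only used when noise reduction is enabled;
# soundfile is imported lazily inside the preprocessing helpers below.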

# ---------------------------
# Firebase Setup
# ---------------------------
import firebase_admin
from firebase_admin import credentials, db
from google.cloud import storage

# Update the path below to your Firebase service account key JSON file
SERVICE_ACCOUNT_KEY = "path/to/serviceAccountKey.json"  # <-- update this!

# Initialize Firebase Admin for Realtime Database
cred = credentials.Certificate(SERVICE_ACCOUNT_KEY)
firebase_admin.initialize_app(cred, {
    'databaseURL': 'https://your-database-name.firebaseio.com/'  # <-- update with your DB URL
})
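
# Optional sketch, assuming you would rather configure credentials via environment
# variables than edit this file; FIREBASE_KEY_PATH and FIREBASE_DB_URL are
# illustrative variable names, not anything required by Firebase:
#
#   SERVICE_ACCOUNT_KEY = os.environ.get("FIREBASE_KEY_PATH", "path/to/serviceAccountKey.json")
#   firebase_admin.initialize_app(
#       credentials.Certificate(SERVICE_ACCOUNT_KEY),
#       {"databaseURL": os.environ.get("FIREBASE_DB_URL", "")},
#   )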

def upload_file_to_firebase(file_path, destination_blob_name):
    """
    Uploads a file to Firebase Storage and returns its public URL.
    """
    # Update bucket name (usually: your-project-id.appspot.com)
    bucket_name = "your-project-id.appspot.com"  # <-- update this!
    storage_client = storage.Client.from_service_account_json(SERVICE_ACCOUNT_KEY)
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(destination_blob_name)
    blob.upload_from_filename(file_path)
    blob.make_public()
    print(f"File uploaded to {blob.public_url}")
    return blob.public_url

def store_prediction_metadata(file_url, predicted_emotion):
    """
    Stores the file URL, predicted emotion, and timestamp in Firebase Realtime Database.
    """
    ref = db.reference('predictions')
    data = {
        'file_url': file_url,
        'predicted_emotion': predicted_emotion,
        'timestamp': datetime.now().isoformat()
    }
    new_record_ref = ref.push(data)
    print(f"Stored metadata with key: {new_record_ref.key}")
    return new_record_ref.key

# ---------------------------
# Emotion Recognition Code
# ---------------------------

# Mapping from emotion labels to emojis
emotion_to_emoji = {
    "angry": "😠",
    "happy": "😊",
    "sad": "😒",
    "neutral": "😐",
    "excited": "πŸ˜„",
    "fear": "😨",
    "disgust": "🀒",
    "surprise": "😲"
}

def add_emoji_to_label(label):
    """Append an emoji corresponding to the emotion label."""
    emoji = emotion_to_emoji.get(label.lower(), "")
    return f"{label.capitalize()} {emoji}"

# Load the pre-trained SpeechBrain classifier (Emotion Recognition with wav2vec2 on IEMOCAP)
classifier = foreign_class(
    source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
    pymodule_file="custom_interface.py",
    classname="CustomEncoderWav2vec2Classifier",
    run_opts={"device": "cpu"}  # Change to {"device": "cuda"} if GPU is available
)
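# The code below assumes classify_file() returns a 4-tuple
# (out_prob, score, index, text_lab), as in SpeechBrain's custom emotion
# recognition interface; only the text label is used for the final output.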

def preprocess_audio(audio_file, apply_noise_reduction=False):
    """
    Load and preprocess the audio file:
      - Convert to 16kHz mono.
      - Optionally apply noise reduction.
      - Normalize the audio.
    Saves the processed audio to a temporary file and returns its path.
    """
    y, sr = librosa.load(audio_file, sr=16000, mono=True)
    try:
        import noisereduce as nr
        NOISEREDUCE_AVAILABLE = True
    except ImportError:
        NOISEREDUCE_AVAILABLE = False
    if apply_noise_reduction and NOISEREDUCE_AVAILABLE:
        y = nr.reduce_noise(y=y, sr=sr)
    if np.max(np.abs(y)) > 0:
        y = y / np.max(np.abs(y))
    temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    temp_file.close()  # close the handle so soundfile can write to the path on all platforms
    import soundfile as sf
    sf.write(temp_file.name, y, sr)
    return temp_file.name
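
# Note: preprocess_audio() returns the path of a new temporary WAV file; the callers
# in this file (ensemble_prediction and predict_emotion) are responsible for deleting
# it with os.remove() once classification is done, e.g. (hypothetical input name):
#   clean_path = preprocess_audio("sample.wav", apply_noise_reduction=True)
#   ...
#   os.remove(clean_path)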

def ensemble_prediction(audio_file, apply_noise_reduction=False, segment_duration=3.0, overlap=1.0):
    """
    For longer audio files, split into overlapping segments, predict each segment,
    and return the majority-voted emotion label.
    """
    y, sr = librosa.load(audio_file, sr=16000, mono=True)
    total_duration = librosa.get_duration(y=y, sr=sr)
    
    if total_duration <= segment_duration:
        temp_file = preprocess_audio(audio_file, apply_noise_reduction)
        _, _, _, label = classifier.classify_file(temp_file)
        os.remove(temp_file)
        return label[0]

    step = segment_duration - overlap
    segments = []
    for start in np.arange(0, total_duration - segment_duration + 0.001, step):
        start_sample = int(start * sr)
        end_sample = int((start + segment_duration) * sr)
        segment_audio = y[start_sample:end_sample]
        temp_seg = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        temp_seg.close()  # close the handle so soundfile can write to the path on all platforms
        import soundfile as sf
        sf.write(temp_seg.name, segment_audio, sr)
        segments.append(temp_seg.name)
    
    predictions = []
    for seg in segments:
        temp_file = preprocess_audio(seg, apply_noise_reduction)
        _, _, _, label = classifier.classify_file(temp_file)
        predictions.append(label[0])
        os.remove(temp_file)
        os.remove(seg)
    
    vote = Counter(predictions)
    most_common = vote.most_common(1)[0][0]
    return most_common

def predict_emotion(audio_file, use_ensemble=False, apply_noise_reduction=False, segment_duration=3.0, overlap=1.0):
    """
    Predict emotion from an audio file and return the emotion with an emoji.
    """
    try:
        if use_ensemble:
            label = ensemble_prediction(audio_file, apply_noise_reduction, segment_duration, overlap)
        else:
            temp_file = preprocess_audio(audio_file, apply_noise_reduction)
            result = classifier.classify_file(temp_file)
            os.remove(temp_file)
            if isinstance(result, tuple) and len(result) > 3:
                label = result[3][0]
            else:
                label = str(result)
        return add_emoji_to_label(label.lower())
    except Exception as e:
        return f"Error processing file: {str(e)}"

def plot_waveform(audio_file):
    """
    Generate and return a waveform plot image (as a PIL Image) for the given audio file.
    """
    y, sr = librosa.load(audio_file, sr=16000, mono=True)
    plt.figure(figsize=(10, 3))
    librosa.display.waveshow(y, sr=sr)
    plt.title("Waveform")
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    plt.close()
    buf.seek(0)
    return Image.open(buf)

def predict_and_plot(audio_file, use_ensemble, apply_noise_reduction, segment_duration, overlap):
    """
    Run emotion prediction and generate a waveform plot.
    Additionally, upload the audio file to Firebase Storage and store the metadata in Firebase Realtime Database.
    Returns a tuple: (emotion label with emoji, waveform image as a PIL Image).
    """
    # Upload the original audio file to Firebase Storage
    destination_blob_name = os.path.basename(audio_file)
    file_url = upload_file_to_firebase(audio_file, destination_blob_name)
    
    # Predict emotion and generate waveform
    emotion = predict_emotion(audio_file, use_ensemble, apply_noise_reduction, segment_duration, overlap)
    waveform = plot_waveform(audio_file)
    
    # Store metadata (file URL and predicted emotion) in Firebase Realtime Database
    record_key = store_prediction_metadata(file_url, emotion)
    print(f"Record stored with key: {record_key}")
    
    return emotion, waveform

# ---------------------------
# Gradio App UI
# ---------------------------
with gr.Blocks(css=".gradio-container {background-color: #f7f7f7; font-family: Arial;}") as demo:
    gr.Markdown("<h1 style='text-align: center;'>Enhanced Emotion Recognition</h1>")
    gr.Markdown(
        "Upload an audio file, and the model will predict the emotion using a wav2vec2 model fine-tuned on IEMOCAP data. "
        "The prediction is accompanied by an emoji, and you can view the audio's waveform. "
        "The audio file and prediction metadata are stored in Firebase Realtime Database."
    )
    
    with gr.Tabs():
        with gr.TabItem("Emotion Recognition"):
            with gr.Row():
                audio_input = gr.Audio(type="filepath", label="Upload Audio")
            use_ensemble_checkbox = gr.Checkbox(label="Use Ensemble Prediction (for long audio)", value=False)
            apply_noise_reduction_checkbox = gr.Checkbox(label="Apply Noise Reduction", value=False)
            with gr.Row():
                segment_duration_slider = gr.Slider(minimum=1.0, maximum=10.0, step=0.5, value=3.0, label="Segment Duration (s)")
                overlap_slider = gr.Slider(minimum=0.0, maximum=5.0, step=0.5, value=1.0, label="Segment Overlap (s)")
            predict_button = gr.Button("Predict Emotion")
            result_text = gr.Textbox(label="Predicted Emotion")
            waveform_image = gr.Image(label="Audio Waveform", type="pil")
            
            predict_button.click(
                predict_and_plot,
                inputs=[audio_input, use_ensemble_checkbox, apply_noise_reduction_checkbox, segment_duration_slider, overlap_slider],
                outputs=[result_text, waveform_image]
            )
        with gr.TabItem("About"):
            gr.Markdown("""
**Enhanced Emotion Recognition App**

- **Model:** SpeechBrain's wav2vec2 model fine-tuned on IEMOCAP for emotion recognition.
- **Features:**
  - Ensemble Prediction for long audio files.
  - Optional Noise Reduction.
  - Visualization of the audio waveform.
  - Emoji representation of the predicted emotion.
  - Audio file uploaded to Firebase Storage, with prediction metadata stored in Firebase Realtime Database.

**Credits:**
- [SpeechBrain](https://speechbrain.github.io)
- [Gradio](https://gradio.app)
            """)

if __name__ == "__main__":
    demo.launch()