Boltz79 committed
Commit 6f98b5f · verified · 1 parent: d250b36

Update app.py

Files changed (1): app.py (+24 −25)
app.py CHANGED
@@ -30,10 +30,11 @@ emotion_to_emoji = {
 }
 
 def add_emoji_to_label(label):
+    """Append an emoji corresponding to the emotion label."""
     emoji = emotion_to_emoji.get(label.lower(), "")
     return f"{label.capitalize()} {emoji}"
 
-# Load the pre-trained SpeechBrain classifier (Emotion Recognition with wav2vec2 on IEMOCAP)
+# Load the pre-trained SpeechBrain classifier
 classifier = foreign_class(
     source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
     pymodule_file="custom_interface.py",
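For orientation, here is a self-contained sketch of how the pieces in this hunk fit together. The `emotion_to_emoji` entries and the `classname` argument are assumptions (the dict and the rest of the `foreign_class` call sit outside the hunk); the import path is the one current SpeechBrain releases document.

```python
# Sketch, not the file's exact code: the emoji map and classname are assumed.
from speechbrain.inference.interfaces import foreign_class  # older releases: speechbrain.pretrained

# Assumed contents: the four IEMOCAP labels this model emits.
emotion_to_emoji = {"neu": "😐", "ang": "😠", "hap": "😄", "sad": "😢"}

def add_emoji_to_label(label):
    """Append an emoji corresponding to the emotion label."""
    emoji = emotion_to_emoji.get(label.lower(), "")
    return f"{label.capitalize()} {emoji}"

classifier = foreign_class(
    source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
    pymodule_file="custom_interface.py",
    classname="CustomEncoderWav2vec2Classifier",  # assumed from the model card
)
```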
@@ -47,16 +48,13 @@ def preprocess_audio(audio_file, apply_noise_reduction=False):
     - Convert to 16kHz mono.
     - Optionally apply noise reduction.
     - Normalize the audio.
-    The processed audio is saved to a temporary file and its path is returned.
+    Saves the processed audio to a temporary file and returns its path.
     """
     y, sr = librosa.load(audio_file, sr=16000, mono=True)
-
     if apply_noise_reduction and NOISEREDUCE_AVAILABLE:
         y = nr.reduce_noise(y=y, sr=sr)
-
     if np.max(np.abs(y)) > 0:
         y = y / np.max(np.abs(y))
-
     temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
     import soundfile as sf
     sf.write(temp_file.name, y, sr)
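A note on this hunk: `delete=False` means the temporary WAV outlives the function, and callers such as `ensemble_prediction` remove it after classification. A minimal, runnable sketch of the same pipeline, with the optional `noisereduce` import guarded the way the `NOISEREDUCE_AVAILABLE` flag implies:

```python
import tempfile

import librosa
import numpy as np
import soundfile as sf

try:
    import noisereduce as nr
    NOISEREDUCE_AVAILABLE = True
except ImportError:
    NOISEREDUCE_AVAILABLE = False

def preprocess_audio(audio_file, apply_noise_reduction=False):
    """Resample to 16 kHz mono, optionally denoise, peak-normalize, save to a temp WAV."""
    y, sr = librosa.load(audio_file, sr=16000, mono=True)
    if apply_noise_reduction and NOISEREDUCE_AVAILABLE:
        y = nr.reduce_noise(y=y, sr=sr)
    peak = np.max(np.abs(y))
    if peak > 0:
        y = y / peak  # peak normalization: loudest sample becomes +/-1.0
    temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    sf.write(temp_file.name, y, sr)
    temp_file.close()  # caller is responsible for os.remove()
    return temp_file.name
```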
@@ -64,18 +62,19 @@ def preprocess_audio(audio_file, apply_noise_reduction=False):
 
 def ensemble_prediction(audio_file, apply_noise_reduction=False, segment_duration=3.0, overlap=1.0):
     """
-    For long audio files, split the file into overlapping segments, predict the emotion for each segment,
-    and return the majority-voted label.
+    For longer audio files, split into overlapping segments, predict each segment,
+    and return the majority-voted emotion label.
     """
     y, sr = librosa.load(audio_file, sr=16000, mono=True)
     total_duration = librosa.get_duration(y=y, sr=sr)
 
+    # If the audio is short, process it directly
     if total_duration <= segment_duration:
         temp_file = preprocess_audio(audio_file, apply_noise_reduction)
         _, _, _, label = classifier.classify_file(temp_file)
         os.remove(temp_file)
         return label
-
+
     step = segment_duration - overlap
     segments = []
     for start in np.arange(0, total_duration - segment_duration + 0.001, step):
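The loop body and the vote itself are cut off by the hunk boundary, but the windowing scheme is clear: windows of `segment_duration` seconds advanced by `segment_duration - overlap`. A hedged reconstruction of the voting logic follows; `classify_segment` is a stand-in for writing a segment to disk and calling `classifier.classify_file` on it, not the file's actual helper.

```python
from collections import Counter

import numpy as np

def majority_vote(y, sr, classify_segment, segment_duration=3.0, overlap=1.0):
    """Classify overlapping windows of y and return the most frequent label."""
    total_duration = len(y) / sr
    step = segment_duration - overlap
    labels = []
    # The +0.001 tolerance keeps the final window when its end lands exactly
    # on the file boundary despite floating-point rounding.
    for start in np.arange(0, total_duration - segment_duration + 0.001, step):
        segment = y[int(start * sr):int((start + segment_duration) * sr)]
        labels.append(classify_segment(segment, sr))
    if not labels:
        raise ValueError("audio shorter than one segment")
    return Counter(labels).most_common(1)[0][0]
```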
@@ -101,10 +100,10 @@ def ensemble_prediction(audio_file, apply_noise_reduction=False, segment_duration=3.0, overlap=1.0):
 
 def predict_emotion(audio_file, use_ensemble=False, apply_noise_reduction=False, segment_duration=3.0, overlap=1.0):
     """
-    Main prediction function.
+    Main prediction function:
    - Uses ensemble prediction if enabled.
    - Otherwise, processes the entire audio at once.
-    - Returns the predicted emotion with an emoji.
+    Returns the emotion label enhanced with an emoji.
     """
     try:
         if use_ensemble:
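The rest of `predict_emotion` falls outside the hunk. A hedged reconstruction of how the two branches plausibly complete, reusing the short-audio pattern visible in `ensemble_prediction` above:

```python
import os

def predict_emotion(audio_file, use_ensemble=False, apply_noise_reduction=False,
                    segment_duration=3.0, overlap=1.0):
    """Return the predicted emotion label, decorated with an emoji."""
    try:
        if use_ensemble:
            label = ensemble_prediction(audio_file, apply_noise_reduction,
                                        segment_duration, overlap)
        else:
            temp_file = preprocess_audio(audio_file, apply_noise_reduction)
            _, _, _, label = classifier.classify_file(temp_file)
            os.remove(temp_file)
        return add_emoji_to_label(label)
    except Exception as e:  # surface the error in the UI instead of crashing
        return f"Error: {e}"
```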
@@ -119,7 +118,7 @@ def predict_emotion(audio_file, use_ensemble=False, apply_noise_reduction=False, segment_duration=3.0, overlap=1.0):
 
 def plot_waveform(audio_file):
     """
-    Generate a waveform plot for the given audio file and return the image bytes.
+    Generate and return a waveform plot image for the given audio file.
     """
     y, sr = librosa.load(audio_file, sr=16000, mono=True)
     plt.figure(figsize=(10, 3))
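The figure-to-image step is also outside the hunk; a plausible completion renders the figure into an in-memory PNG and hands back a PIL image, one of the types `gr.Image` accepts. Everything below the `plt.figure` call is an assumption, and `waveshow` is the librosa ≥ 0.9 name (earlier releases called it `waveplot`):

```python
import io

import librosa
import librosa.display
import matplotlib
matplotlib.use("Agg")  # headless backend, sensible for a hosted Space
import matplotlib.pyplot as plt
from PIL import Image

def plot_waveform(audio_file):
    """Render the waveform to an in-memory PNG and return it as a PIL image."""
    y, sr = librosa.load(audio_file, sr=16000, mono=True)
    plt.figure(figsize=(10, 3))
    librosa.display.waveshow(y, sr=sr)
    plt.title("Waveform")
    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    plt.close()  # release the figure so repeated calls do not leak memory
    buf.seek(0)
    img = Image.open(buf)
    img.load()  # read fully before the buffer goes out of scope
    return img
```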
@@ -133,8 +132,8 @@ def plot_waveform(audio_file):
 
 def predict_and_plot(audio_file, use_ensemble, apply_noise_reduction, segment_duration, overlap):
     """
-    Predict the emotion and also generate the waveform plot.
-    Returns a tuple: (emotion label with emoji, waveform image)
+    Run emotion prediction and generate a waveform plot.
+    Returns a tuple: (emotion label with emoji, waveform image).
     """
     emotion = predict_emotion(audio_file, use_ensemble, apply_noise_reduction, segment_duration, overlap)
     waveform = plot_waveform(audio_file)
@@ -152,7 +151,7 @@ with gr.Blocks(css=".gradio-container {background-color: #f7f7f7; font-family: A
     with gr.Tabs():
         with gr.TabItem("Emotion Recognition"):
             with gr.Row():
-                # Removed the 'source' argument which caused the error.
+                # 'source' argument removed to avoid errors
                 audio_input = gr.Audio(type="filepath", label="Upload Audio")
                 use_ensemble = gr.Checkbox(label="Use Ensemble Prediction (for long audio)", value=False)
                 apply_noise_reduction = gr.Checkbox(label="Apply Noise Reduction", value=False)
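The sliders, button, and output components these checkboxes feed into sit outside the hunks; the sketch below shows one conventional way such inputs wire up to `predict_and_plot`. All component names, ranges, and defaults here are assumptions:

```python
import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Upload Audio")
        use_ensemble = gr.Checkbox(label="Use Ensemble Prediction (for long audio)", value=False)
        apply_noise_reduction = gr.Checkbox(label="Apply Noise Reduction", value=False)
    # Everything below is assumed: names, ranges, and defaults.
    segment_duration = gr.Slider(1.0, 10.0, value=3.0, label="Segment Duration (s)")
    overlap = gr.Slider(0.0, 5.0, value=1.0, label="Overlap (s)")
    predict_btn = gr.Button("Recognize Emotion")
    emotion_output = gr.Textbox(label="Predicted Emotion")
    waveform_output = gr.Image(label="Waveform")

    predict_btn.click(
        predict_and_plot,
        inputs=[audio_input, use_ensemble, apply_noise_reduction, segment_duration, overlap],
        outputs=[emotion_output, waveform_output],
    )

if __name__ == "__main__":
    demo.launch()
```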
@@ -171,18 +170,18 @@ with gr.Blocks(css=".gradio-container {background-color: #f7f7f7; font-family: A
 
         with gr.TabItem("About"):
             gr.Markdown("""
-            **Enhanced Emotion Recognition App**
+            **Enhanced Emotion Recognition App**
 
-            - **Model:** SpeechBrain's wav2vec2 model fine-tuned on IEMOCAP for emotion recognition.
-            - **Features:**
-              - Ensemble Prediction for long audio files.
-              - Optional Noise Reduction.
-              - Visualization of the audio waveform.
-              - Emoji representation of the predicted emotion.
-
-            **Credits:**
-            - [SpeechBrain](https://speechbrain.github.io)
-            - [Gradio](https://gradio.app)
+            - **Model:** SpeechBrain's wav2vec2 model fine-tuned on IEMOCAP for emotion recognition.
+            - **Features:**
+              - Ensemble Prediction for long audio files.
+              - Optional Noise Reduction.
+              - Visualization of the audio waveform.
+              - Emoji representation of the predicted emotion.
+
+            **Credits:**
+            - [SpeechBrain](https://speechbrain.github.io)
+            - [Gradio](https://gradio.app)
             """)
 
 if __name__ == "__main__":