Ahmed107 committed (verified)
Commit 922c47d · Parent: bfc238d

Update app.py

Files changed (1): app.py (+21 -14)
app.py CHANGED
@@ -1,50 +1,57 @@
 import gradio as gr
 import torchaudio
 from transformers import pipeline
-from datasets import load_dataset, Audio
 
 # Load your model
 classifier = pipeline("audio-classification", model="Ahmed107/Hamsa-Conversational-v1.0-mulaw-eos-v3-mulaw")
 
-# Function to resample audio to 16kHz
+# Function to resample audio to 16kHz and convert to mono if needed
 def resample_audio(audio_file, target_sampling_rate=16000):
     waveform, original_sample_rate = torchaudio.load(audio_file)
+
+    # Resample if necessary
     if original_sample_rate != target_sampling_rate:
         resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=target_sampling_rate)
         waveform = resampler(waveform)
+
+    # Convert stereo to mono by averaging channels (if needed)
+    if waveform.shape[0] > 1:
+        waveform = waveform.mean(dim=0, keepdim=True)
+
     return waveform.squeeze().numpy(), target_sampling_rate
 
 # Define the prediction function
 def classify_audio(audio_file):
-    # Resample the audio to 16kHz
-    resampled_audio, _ = resample_audio(audio_file)
+    # Resample the audio to 16kHz and handle channels
+    resampled_audio, sampling_rate = resample_audio(audio_file)
 
-    # Classify the audio
-    prediction = classifier(resampled_audio)
+    # Pass both the array and sampling_rate to the classifier
+    input_audio = {"array": resampled_audio, "sampling_rate": sampling_rate}
+    prediction = classifier(input_audio)
 
-    # Return predictions as a dictionary
+    # Return predictions as a dictionary with labels and their scores
     return {entry['label']: entry['score'] for entry in prediction}
 
 # Define Gradio interface
 def demo():
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
-        gr.Markdown("## Eos")
+        gr.Markdown("## Eos Audio Classification")
 
-        # Input Audio
+        # Input Audio component
         with gr.Row():
             audio_input = gr.Audio(type="filepath", label="Input Audio")
 
-        # Output Labels
+        # Output Labels component
        with gr.Row():
             label_output = gr.Label(label="Prediction")
 
         # Predict Button
         classify_btn = gr.Button("Classify")
-
-        # Define the interaction
+
+        # Set the button click action
         classify_btn.click(fn=classify_audio, inputs=audio_input, outputs=label_output)
 
     return demo
 
-# Launch the demo
-demo().launch()
+# Launch the Gradio demo
+demo().launch()
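For reference, a minimal smoke test of the new resample-and-mono path introduced by this commit. This sketch is not part of the commit: the file name `stereo_test.wav` and the `from app import resample_audio` line are assumptions, and importing `app` will also execute the module-level `pipeline(...)` call, which downloads the model.

```python
import torch
import torchaudio

from app import resample_audio  # hypothetical import; also triggers the model load in app.py

# Synthesize one second of stereo audio at 44.1 kHz (a 440 Hz tone in both channels).
sr = 44100
t = torch.arange(sr) / sr
tone = torch.sin(2 * torch.pi * 440 * t)
stereo = torch.stack([tone, tone])          # shape (2, 44100): two channels
torchaudio.save("stereo_test.wav", stereo, sr)

# The updated function should return a mono array at 16 kHz.
array, rate = resample_audio("stereo_test.wav")
print(array.shape, rate)                    # expected: (16000,) 16000

# classify_audio now wraps the array in the dict form the pipeline accepts:
input_audio = {"array": array, "sampling_rate": rate}
```

The dict input (an audio array plus a "sampling_rate" key) is the format transformers audio pipelines accept for pre-loaded audio; in recent versions, passing the rate explicitly lets the pipeline check it against the model's expected sampling rate instead of assuming the bare array already matches, which is what the old code relied on.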