File size: 2,039 Bytes
f60e2c1
 
 
d864aed
f60e2c1
1231312
57a760c
922c47d
f60e2c1
 
922c47d
 
f60e2c1
 
 
922c47d
 
 
 
 
f60e2c1
57a760c
f60e2c1
 
922c47d
 
f60e2c1
922c47d
 
 
f60e2c1
922c47d
f60e2c1
d864aed
f60e2c1
 
 
922c47d
f60e2c1
922c47d
f60e2c1
 
 
922c47d
f60e2c1
 
 
 
 
922c47d
 
f60e2c1
 
 
d864aed
922c47d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import torchaudio
from transformers import pipeline

# Load your model
# Hugging Face audio-classification pipeline wrapping a wav2vec2 checkpoint.
# NOTE(review): runs at import time and downloads weights on first use —
# model load is a module-level side effect.
classifier = pipeline("audio-classification", model="Ahmed107/wav2vec2-base-eos-v5-mulaw-eos-v5-mulaw")

# Function to resample audio to 16kHz and convert to mono if needed
def resample_audio(audio_file, target_sampling_rate=16000):
    """Load an audio file and normalize it for the classifier.

    Resamples to ``target_sampling_rate`` when needed and downmixes
    multi-channel audio to mono by averaging channels.

    Returns:
        tuple: (1-D numpy array of samples, target_sampling_rate)
    """
    waveform, source_rate = torchaudio.load(audio_file)

    # Only pay the resampling cost when the rates actually differ.
    if source_rate != target_sampling_rate:
        resample = torchaudio.transforms.Resample(
            orig_freq=source_rate, new_freq=target_sampling_rate
        )
        waveform = resample(waveform)

    # Downmix: average across the channel dimension if more than one channel.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    return waveform.squeeze().numpy(), target_sampling_rate

# Define the prediction function
def classify_audio(audio_file):
    """Classify an audio file and return a label -> score mapping."""
    # Normalize to 16 kHz mono before handing the samples to the pipeline.
    audio_array, rate = resample_audio(audio_file)

    # The pipeline accepts a dict carrying both the raw samples and their rate.
    predictions = classifier({"array": audio_array, "sampling_rate": rate})

    # Flatten the list of {label, score} dicts into one mapping for gr.Label.
    scores = {}
    for item in predictions:
        scores[item['label']] = item['score']
    return scores

# Define Gradio interface
def demo():
    """Build and return the Gradio Blocks UI for EOS audio classification."""
    # Use a distinct local name ("app") so the builder function is not
    # shadowed inside its own body.
    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown("## Eos Audio Classification")

        # Audio upload / recording input.
        with gr.Row():
            audio_in = gr.Audio(type="filepath", label="Input Audio")

        # Scored-label output display.
        with gr.Row():
            prediction_out = gr.Label(label="Prediction")

        run_button = gr.Button("Classify")

        # Wire the button to the classification function.
        run_button.click(fn=classify_audio, inputs=audio_in, outputs=prediction_out)

    return app

# Launch the Gradio demo only when run as a script. The guard prevents the
# web server from starting as a side effect of importing this module
# (e.g. from tests or other apps).
if __name__ == "__main__":
    demo().launch()