File size: 2,039 Bytes
f60e2c1 d864aed f60e2c1 1231312 57a760c 922c47d f60e2c1 922c47d f60e2c1 922c47d f60e2c1 57a760c f60e2c1 922c47d f60e2c1 922c47d f60e2c1 922c47d f60e2c1 d864aed f60e2c1 922c47d f60e2c1 922c47d f60e2c1 922c47d f60e2c1 922c47d f60e2c1 d864aed 922c47d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import gradio as gr
import torchaudio
from transformers import pipeline
# Load the pretrained audio-classification pipeline once at import time so
# every request reuses the same model instance.
classifier = pipeline("audio-classification", model="Ahmed107/wav2vec2-base-eos-v5-mulaw-eos-v5-mulaw")
def resample_audio(audio_file, target_sampling_rate=16000):
    """Load an audio file, resample it to ``target_sampling_rate`` Hz,
    and collapse multi-channel audio down to mono.

    Returns a tuple ``(samples, sampling_rate)`` where ``samples`` is a
    1-D numpy array of the (possibly resampled) waveform.
    """
    waveform, source_rate = torchaudio.load(audio_file)
    # Only resample when the file's native rate differs from the target.
    if source_rate != target_sampling_rate:
        waveform = torchaudio.transforms.Resample(
            orig_freq=source_rate, new_freq=target_sampling_rate
        )(waveform)
    # Average across the channel dimension to produce a single mono channel.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    return waveform.squeeze().numpy(), target_sampling_rate
def classify_audio(audio_file):
    """Run the end-of-speech classifier on an audio file path.

    The audio is first normalised to 16 kHz mono via ``resample_audio``,
    then passed to the pipeline; predictions come back as a
    ``{label: score}`` mapping suitable for a Gradio Label component.
    """
    samples, rate = resample_audio(audio_file)
    model_input = {"array": samples, "sampling_rate": rate}
    # Flatten the pipeline's list of {label, score} dicts into one mapping.
    return {result["label"]: result["score"] for result in classifier(model_input)}
def demo():
    """Build and return the Gradio Blocks UI for EOS audio classification."""
    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown("## Eos Audio Classification")
        # Audio upload / record input.
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Input Audio")
        # Per-label score output.
        with gr.Row():
            prediction_output = gr.Label(label="Prediction")
        # Button wiring: run the classifier on click.
        classify_button = gr.Button("Classify")
        classify_button.click(fn=classify_audio, inputs=audio_input, outputs=prediction_output)
    return app
# Launch the Gradio demo only when this file is executed as a script, so
# that importing the module (e.g. from tests or another app) does not
# start a web server as a side effect.
if __name__ == "__main__":
    demo().launch()
|