import os

import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torchaudio
from pydub import AudioSegment
from scipy.io import wavfile
from transformers import WhisperForConditionalGeneration, WhisperProcessor
# Load the Whisper processor (feature extractor + tokenizer) and the
# fine-tuned checkpoint once, at import time; transcribe_audio uses both.
# NOTE(review): the processor comes from the base "openai/whisper-small"
# checkpoint while the weights come from a fine-tuned repo — presumably the
# fine-tune keeps whisper-small's feature extractor and tokenizer; confirm.
processor, model = (
    WhisperProcessor.from_pretrained("openai/whisper-small"),
    WhisperForConditionalGeneration.from_pretrained(
        "YMEA/bambara-fr-asr-whisper_25_v2_15k"
    ),
)
def record_audio(audio):
    """Resample a microphone recording to 16 kHz, peak-normalize it, and save it.

    Args:
        audio: The ``(sample_rate, samples)`` tuple produced by a
            ``gr.Audio(type="numpy")`` component, or ``None`` when nothing
            was recorded.

    Returns:
        The relative path ``"processed_audio.wav"`` of the written WAV file,
        or ``None`` when there is no audio (``audio`` is ``None`` or holds
        zero samples).
    """
    if audio is None:
        return None

    sr, data = audio
    data = np.asarray(data)
    if data.size == 0:
        # Nothing to resample or normalize; treat it like "no recording".
        return None

    if sr != 16000:
        from scipy.signal import resample

        original_dtype = data.dtype
        target_len = max(1, int(len(data) * 16000 / sr))
        data = resample(data, target_len)
        if np.issubdtype(original_dtype, np.integer):
            # resample() returns floats on the integer sample scale (e.g. up
            # to +/-32767 for int16). soundfile treats float input as [-1, 1]
            # and would hard-clip it, so restore the original integer type.
            info = np.iinfo(original_dtype)
            data = np.clip(np.round(data), info.min, info.max).astype(original_dtype)
        sr = 16000

    temp_audio_path = "temp_recorded_audio.wav"
    sf.write(temp_audio_path, data, sr)
    try:
        sound = AudioSegment.from_wav(temp_audio_path)
    finally:
        # Clean up the intermediate file even if decoding fails.
        os.remove(temp_audio_path)

    peak_dbfs = sound.max_dBFS
    if np.isfinite(peak_dbfs):
        # NOTE(review): bringing the peak to 0 dBFS and then adding +5 dB
        # pushes peaks above full scale, so the loudest samples clip —
        # confirm this boost is intended.
        normalized_sound = sound.apply_gain(-peak_dbfs).apply_gain(5)
    else:
        # An all-silent recording has max_dBFS == -inf; applying -max_dBFS
        # would mean infinite gain, so keep the audio unchanged.
        normalized_sound = sound

    processed_audio_path = "processed_audio.wav"
    normalized_sound.export(processed_audio_path, format="wav")
    return processed_audio_path
def transcribe_audio(audio_path):
    """Transcribe an audio file with the module-level Whisper processor/model.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``,
            or ``None`` when nothing was recorded.

    Returns:
        The decoded text of the first generated sequence, or the message
        ``"No audio was recorded."`` when ``audio_path`` is ``None``.
    """
    if audio_path is None:
        return "No audio was recorded."

    # torchaudio.load returns the waveform channels-first: (channels, frames).
    waveform, sample_rate = torchaudio.load(audio_path)
    if sample_rate != 16000:
        waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
        sample_rate = 16000
    if waveform.shape[0] > 1:
        # Down-mix any multi-channel input (not only stereo) to one channel.
        waveform = waveform.mean(dim=0, keepdim=True)

    # Pass a 1-D numpy signal; squeeze() could collapse to a 0-D scalar.
    mono = waveform[0].numpy()
    audio_input = processor(mono, sampling_rate=sample_rate, return_tensors="pt")
    with torch.no_grad():
        generated_ids = model.generate(inputs=audio_input.input_features)
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return transcription[0]
with gr.Blocks() as demo:
    # Widgets, top to bottom: raw microphone capture, a button that runs the
    # resample/normalize step, the processed audio, and the transcript.
    audio_input = gr.Audio(sources=["microphone"], type="numpy")
    record_button = gr.Button("Record Audio")
    audio_output = gr.Audio(type="filepath")
    transcribe_button = gr.Button("Transcribe")
    transcription_output = gr.Textbox(label="Transcription", lines=3)

    # "Record Audio": raw (sample_rate, samples) -> processed WAV file path.
    record_button.click(
        fn=record_audio,
        inputs=[audio_input],
        outputs=[audio_output],
    )
    # "Transcribe": processed WAV file path -> transcript text.
    transcribe_button.click(
        fn=transcribe_audio,
        inputs=[audio_output],
        outputs=[transcription_output],
    )

demo.launch(show_error=True)