# bam-fr / app.py
# Author: YMEA
# Last commit: "Update app.py" (a35dd1c, verified)
# Import necessary libraries
import gradio as gr
import numpy as np
import torch
import torchaudio # Import torchaudio
import soundfile as sf # pour sauvegarder l'audio
import os
from pydub import AudioSegment
from scipy.io import wavfile
# Load the Whisper processor (feature extractor + tokenizer) and the fine-tuned
# checkpoint once at start-up; both module-level objects are shared by every request.
# The transformers import was missing, which made the app fail with a NameError.
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small")
# NOTE(review): the processor comes from the base openai/whisper-small checkpoint
# while the weights come from the fine-tuned repo; this presumes the fine-tune kept
# whisper-small's tokenizer and feature-extractor settings — confirm.
model = WhisperForConditionalGeneration.from_pretrained("YMEA/bambara-fr-asr-whisper_25_v2_15k")
# Function to resample a recording to 16 kHz, peak-normalise it and save it as WAV
def record_audio(audio):
    """Resample a microphone recording to 16 kHz, peak-normalise it and save it as WAV.

    Args:
        audio: The value produced by a ``gr.Audio(type="numpy")`` component: a
            ``(sample_rate, samples)`` tuple, or ``None`` when nothing was recorded.

    Returns:
        Path to the processed WAV file, or ``None`` if ``audio`` is ``None``.
    """
    if audio is None:
        return None

    import tempfile

    # Peak level the clip is normalised to; kept just below full scale so the
    # normalisation itself cannot clip (the old "0 dBFS then +5 dB" always did).
    target_peak_dbfs = -1.0

    sr, data = audio
    data = np.asarray(data)

    # Ensure the audio is resampled to 16kHz
    if sr != 16000:
        from scipy.signal import resample
        # resample() returns floats, and soundfile treats float input as
        # full scale == 1.0. Integer PCM must therefore be scaled into [-1, 1]
        # first, otherwise int16-range floats are hard-clipped on write.
        if np.issubdtype(data.dtype, np.signedinteger):
            data = data.astype(np.float32) / float(-np.iinfo(data.dtype).min)
        data = resample(data, int(len(data) * 16000 / sr))
        sr = 16000

    # Unique per-call file names: fixed names in the working directory would let
    # concurrent Gradio sessions overwrite each other's recordings.
    fd, temp_audio_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        sf.write(temp_audio_path, data, sr)
        sound = AudioSegment.from_wav(temp_audio_path)
    finally:
        # Removed even when writing/decoding fails, so temp files do not pile up.
        os.remove(temp_audio_path)

    # A silent clip has max_dBFS == -inf; leave it untouched rather than
    # applying an infinite gain.
    peak_dbfs = sound.max_dBFS
    if peak_dbfs != float("-inf"):
        sound = sound.apply_gain(target_peak_dbfs - peak_dbfs)

    # NOTE(review): processed files are not deleted here; consider periodic
    # cleanup of the temp directory for long-running deployments.
    fd, processed_audio_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    # export() returns the open file handle it wrote to; close it explicitly.
    sound.export(processed_audio_path, format="wav").close()
    return processed_audio_path
# Function to transcribe audio
def transcribe_audio(audio_path):
    """Transcribe an audio file with the module-level Whisper ``model``.

    The audio is loaded with torchaudio, resampled to 16 kHz if needed,
    down-mixed to mono, turned into input features by ``processor`` and
    decoded with ``model.generate``.

    Args:
        audio_path: Path to an audio file (as returned by ``record_audio``),
            or ``None`` when nothing has been recorded yet.

    Returns:
        The decoded transcription text, or a notice string when
        ``audio_path`` is ``None``.
    """
    if audio_path is None:
        return "No audio was recorded."

    # torchaudio.load returns a (channels, frames) tensor and its sample rate.
    waveform, sample_rate = torchaudio.load(audio_path)

    # Ensure the audio is at 16kHz
    if sample_rate != 16000:
        waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
        sample_rate = 16000

    # Down-mix any multi-channel recording (not only stereo) to mono. Indexing
    # row 0 keeps a 1-D signal even for a one-sample clip, where squeeze()
    # would collapse it to a 0-d scalar.
    mono = waveform.mean(dim=0) if waveform.shape[0] > 1 else waveform[0]

    # Hand the feature extractor a NumPy array, its documented input type.
    # NOTE(review): with default settings the Whisper feature extractor
    # pads/truncates to 30 s, so longer recordings are presumably cut — confirm.
    audio_input = processor(mono.numpy(), sampling_rate=sample_rate, return_tensors="pt")

    # Generate the transcription without tracking gradients (inference only).
    with torch.no_grad():
        generated_ids = model.generate(inputs=audio_input.input_features)

    # Decode the generated IDs to text, dropping special tokens.
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return transcription[0]
# Create the Gradio interface: capture audio from the microphone, post-process
# it with record_audio, then transcribe the processed file with transcribe_audio.
with gr.Blocks() as demo:
    # Microphone input; type="numpy" delivers a (sample_rate, samples) tuple.
    audio_input = gr.Audio(sources=["microphone"], type="numpy", label="Microphone")
    # Despite its caption, this button does not record: recording happens in
    # the component above, and the button runs record_audio on the captured clip.
    record_button = gr.Button("Record Audio")
    # Processed (16 kHz, normalised) clip, passed on to the transcriber as a path.
    audio_output = gr.Audio(type="filepath", label="Processed audio (16 kHz)")
    # Button that triggers transcription of the processed clip.
    transcribe_button = gr.Button("Transcribe")
    # Text box that displays the transcription.
    transcription_output = gr.Textbox(label="Transcription", lines=3)

    record_button.click(fn=record_audio, inputs=audio_input, outputs=audio_output)
    transcribe_button.click(fn=transcribe_audio, inputs=audio_output, outputs=transcription_output)

# Launch the server only when run as a script, so importing this module
# does not start a web server as a side effect.
if __name__ == "__main__":
    demo.launch(show_error=True)