wolof-asr / app.py
dofbi's picture
♻️ refactor (model): add new
6728136
import spaces
import torch
import gradio as gr
import librosa
import numpy as np
import json
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from scipy.signal import butter, lfilter
# Charger la liste des modèles depuis un fichier JSON
def load_model_list(file_path="model_list.json"):
try:
with open(file_path, "r") as f:
return json.load(f)
except Exception as e:
raise ValueError(f"Erreur lors du chargement de la liste des modèles : {str(e)}")
# Charger les modèles depuis le fichier JSON
MODEL_LIST = load_model_list()
# Fonction pour charger le modèle et le processeur
def load_model_and_processor(model_name):
model_info = MODEL_LIST.get(model_name)
if not model_info:
raise ValueError("Modèle non trouvé dans la liste.")
model_path = model_info["model_path"]
processor = WhisperProcessor.from_pretrained(model_path)
model = WhisperForConditionalGeneration.from_pretrained(model_path)
model.eval()
return processor, model
# Nettoyage et normalisation de l'audio
def preprocess_audio(audio, sr=16000):
# Charger l'audio
audio_data, _ = librosa.load(audio, sr=sr)
# Filtrage passe-bas pour réduire les bruits aigus
b, a = butter(6, 0.1, btype="low", analog=False)
audio_data = lfilter(b, a, audio_data)
# Normaliser l'audio
audio_data = librosa.util.normalize(audio_data)
return audio_data
# Fonction pour transcrire l'audio
@spaces.GPU(duration=120)
def transcribe_audio(audio, model_name):
try:
# Charger le modèle et le processeur en fonction du choix
processor, model = load_model_and_processor(model_name)
# Nettoyer et normaliser l'audio
audio_input = preprocess_audio(audio)
# Prétraiter l'audio avec le processeur
inputs = processor(audio_input, sampling_rate=16000, return_tensors="pt")
inputs["attention_mask"] = torch.ones_like(inputs["input_features"]).to(inputs["input_features"].dtype)
# Faire la prédiction
with torch.no_grad():
predicted_ids = model.generate(
inputs['input_features'],
forced_decoder_ids=None, # Suppression du conflit
language="fr", # Ajustez selon votre langue cible
task="transcribe"
)
# Convertir les IDs de prédiction en texte
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
return transcription[0]
except Exception as e:
return f"Erreur de transcription : {str(e)}"
# Charger une seule fois le tableau (statique)
MODEL_TABLE = [
[name, details.get("dataset", "Non spécifié"), details.get("performance", {}).get("WER", "Non spécifié"), details.get("performance", {}).get("CER", "Non spécifié")]
for name, details in MODEL_LIST.items()
]
# Interface Gradio
with gr.Blocks() as app:
# Section principale
with gr.Row():
with gr.Column(scale=2):
gr.Markdown("## Téléchargez ou enregistrez un fichier audio")
audio_input = gr.Audio(type="filepath", label="Audio (télécharger ou enregistrer)")
model_dropdown = gr.Dropdown(choices=list(MODEL_LIST.keys()), label="Sélectionnez un modèle", value="Wolof ASR - dofbi")
submit_button = gr.Button("Transcrire")
with gr.Column(scale=3):
transcription_output = gr.Textbox(label="Transcription", lines=6)
# Tableau statique en bas
gr.Markdown("## Informations sur les modèles disponibles")
gr.Dataframe(
headers=["Nom du modèle", "Dataset utilisé", "WER", "CER"],
value=MODEL_TABLE,
interactive=False,
label="Informations sur les modèles"
)
# Action du bouton
submit_button.click(
fn=transcribe_audio,
inputs=[audio_input, model_dropdown],
outputs=transcription_output
)
# Lancer l'application
if __name__ == "__main__":
app.launch()