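"""Sachi ASR demo Space.

Records or uploads audio, always denoises it (and optionally enhances it) with
resemble-enhance, transcribes it with the SaChi-ASR model, and appends each
example to a Hugging Face dataset.
"""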
import os
import logging
import tempfile
from datetime import datetime
import spaces
import gradio as gr
import soundfile as sf
import torch
import torchaudio
from datasets import Dataset, DatasetDict, concatenate_datasets, Audio, load_dataset, DownloadConfig
from transformers import pipeline
from huggingface_hub import HfApi, login
from resemble_enhance.enhancer.inference import denoise, enhance
# Configure logging
logging.basicConfig(
format="%(asctime)s — %(levelname)s — %(message)s",
level=logging.INFO
)
logger = logging.getLogger(__name__)
# Constants
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
logger.error("Hugging Face token not found. Please set HF_TOKEN environment variable.")
raise SystemExit
CURRENT_DATASET = "sawadogosalif/Sachi_demo_dataset"
SAMPLE_RATE = 16_000
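# Note: the transformers ASR pipeline resamples file inputs itself; SAMPLE_RATE is kept for reference.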
ASR_MODEL = "sawadogosalif/SaChi-ASR"
# Authenticate with Hugging Face
login(token=HF_TOKEN)
api = HfApi(token=HF_TOKEN)
def get_or_create_dataset(dataset_name: str) -> Dataset:
"""
Load the dataset if it exists, otherwise create a new empty one.
"""
try:
ds = load_dataset(
dataset_name,
split="train",
download_config=DownloadConfig(token=HF_TOKEN)
)
logger.info(f"Loaded dataset '{dataset_name}' with {len(ds)} examples.")
    except Exception as exc:
        logger.warning(f"Dataset '{dataset_name}' not found or failed to load ({exc}). Creating a new one.")
ds = Dataset.from_dict({
"audio": [],
"text": [],
"language": [],
"datetime": [],
})
DatasetDict({"train": ds}).push_to_hub(dataset_name, token=HF_TOKEN)
logger.info(f"Created empty dataset '{dataset_name}'.")
return ds
def save_dataset(dataset: Dataset, dataset_name: str) -> None:
"""
Push the updated dataset back to Hugging Face hub.
"""
ds_dict = DatasetDict({"train": dataset})
ds_dict.push_to_hub(dataset_name, token=HF_TOKEN)
logger.info(f"Pushed updated dataset to '{dataset_name}' ({len(dataset)} records).")
class Transcriber:
def __init__(self, asr_model: str):
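        # The pipeline task (automatic-speech-recognition) is inferred from the
        # model's metadata on the Hub.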
self.pipeline = pipeline(model=asr_model)
def transcribe(self, audio_path: str) -> str:
result = self.pipeline(audio_path)
return result.get("text", "")
# Initialize components
current_dataset = get_or_create_dataset(CURRENT_DATASET)
asr_client = Transcriber(ASR_MODEL)
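# On ZeroGPU Spaces, @spaces.GPU borrows a GPU for each call, here for up to 15 seconds.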
@spaces.GPU(duration=15)
def transcribe_and_update(audio_filepath: str, history: str, apply_enhance: bool) -> tuple:
"""
Denoise every input, optionally enhance, then transcribe and push to HF dataset.
"""
if not audio_filepath:
return "No audio detected. Please record or upload audio.", history
try:
# Load and preprocess
        audio_data, sr = torchaudio.load(audio_filepath)
# Always denoise
try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
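            # Downmix to mono: resemble-enhance's denoise/enhance expect a 1-D waveform.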
audio_data = audio_data.mean(dim=0)
denoised_data, sr = denoise(audio_data, sr, device)
logger.info("Audio denoised successfully.")
except Exception as e:
logger.warning(f"Denoise failed, using raw audio: {e}")
denoised_data = audio_data
# Optionally enhance
if apply_enhance:
try:
enhanced_data, sr = enhance(denoised_data, sr, device)
final_audio = enhanced_data
logger.info("Audio enhanced successfully.")
except Exception as e:
logger.warning(f"Enhancement failed, using denoised audio: {e}")
final_audio = denoised_data
else:
final_audio = denoised_data
# Save processed audio to temp file
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmpf:
            sf.write(tmpf.name, final_audio.detach().cpu().numpy(), sr)
local_path = tmpf.name
# Transcription
transcription = asr_client.transcribe(local_path)
logger.info(f"Transcription: {transcription}")
# Prepare new record
new_record = {
"audio": [local_path],
"text": [transcription],
"language": ["moore"],
"datetime": [datetime.utcnow().isoformat()]
}
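        # cast_column(Audio()) turns the file path into an Audio feature, so
        # push_to_hub uploads the audio bytes rather than a bare path string.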
new_ds = Dataset.from_dict(new_record).cast_column("audio", Audio())
# Update in-memory dataset
global current_dataset
if len(current_dataset) == 0:
current_dataset = new_ds
else:
current_dataset = concatenate_datasets([current_dataset, new_ds])
# Push to hub
save_dataset(current_dataset, CURRENT_DATASET)
# Update conversation history
history = history + f"\nUser: [audio]\nAssistant: {transcription}"
return transcription, history
except Exception as exc:
logger.error(f"Error during transcription pipeline: {exc}")
return f"Error: {exc}", history
def build_interface():
with gr.Blocks() as demo:
gr.Markdown("# 🗣️ ASR Moore Live 🧠")
gr.Markdown("Speech Recognition interface for Moore language. Records or uploads audio, always denoises, and optionally enhances before ASR.")
with gr.Row():
audio_input = gr.Audio(type="filepath", label="Record or upload audio", sources=["microphone", "upload"])
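        # Hidden session state that accumulates the transcript history across calls.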
state_box = gr.State(value="")
enhance_checkbox = gr.Checkbox(label="Apply Enhancement", value=False)
output_text = gr.Textbox(label="Transcription")
submit_btn = gr.Button("Transcribe and Save")
submit_btn.click(fn=transcribe_and_update,
inputs=[audio_input, state_box, enhance_checkbox],
outputs=[output_text, state_box])
demo.launch(debug=True)
if __name__ == "__main__":
build_interface()