import os
import logging
import tempfile
from datetime import datetime

import spaces
import gradio as gr
import soundfile as sf
import torch
import torchaudio
from datasets import Dataset, DatasetDict, concatenate_datasets, Audio, load_dataset, DownloadConfig
from transformers import pipeline
from huggingface_hub import HfApi, login
from resemble_enhance.enhancer.inference import denoise, enhance

# Configure logging
logging.basicConfig(
    format="%(asctime)s — %(levelname)s — %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Constants
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    logger.error("Hugging Face token not found. Please set the HF_TOKEN environment variable.")
    raise SystemExit("HF_TOKEN is required")
CURRENT_DATASET = "sawadogosalif/Sachi_demo_dataset"
SAMPLE_RATE = 16_000
ASR_MODEL = "sawadogosalif/SaChi-ASR"

# Authenticate with Hugging Face
login(token=HF_TOKEN)
api = HfApi(token=HF_TOKEN)
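# The token must have write access to CURRENT_DATASET, since every recording
# is pushed back to the Hub below; login() also caches the token for hub
# calls that are not passed one explicitly.
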
def get_or_create_dataset(dataset_name: str) -> Dataset:
    """
    Load the dataset if it exists, otherwise create a new empty one.
    """
    try:
        ds = load_dataset(
            dataset_name,
            split="train",
            download_config=DownloadConfig(token=HF_TOKEN),
        )
        logger.info(f"Loaded dataset '{dataset_name}' with {len(ds)} examples.")
    except Exception:
        logger.warning(f"Dataset '{dataset_name}' not found or failed to load. Creating a new one.")
        ds = Dataset.from_dict({
            "audio": [],
            "text": [],
            "language": [],
            "datetime": [],
        })
        DatasetDict({"train": ds}).push_to_hub(dataset_name, token=HF_TOKEN)
        logger.info(f"Created empty dataset '{dataset_name}'.")
    return ds

def save_dataset(dataset: Dataset, dataset_name: str) -> None:
    """
    Push the updated dataset back to the Hugging Face Hub.
    """
    ds_dict = DatasetDict({"train": dataset})
    ds_dict.push_to_hub(dataset_name, token=HF_TOKEN)
    logger.info(f"Pushed updated dataset to '{dataset_name}' ({len(dataset)} records).")
class Transcriber:
    def __init__(self, asr_model: str):
        self.pipeline = pipeline(model=asr_model)

    def transcribe(self, audio_path: str) -> str:
        result = self.pipeline(audio_path)
        return result.get("text", "")
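# The transformers ASR pipeline accepts a local audio file path and returns a
# dict of the form {"text": "..."}; .get("text", "") above guards against an
# unexpected payload shape.
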
# Initialize components
current_dataset = get_or_create_dataset(CURRENT_DATASET)
asr_client = Transcriber(ASR_MODEL)

@spaces.GPU  # request a GPU slot for this call on ZeroGPU hardware
def transcribe_and_update(audio_filepath: str, history: str, apply_enhance: bool) -> tuple:
    """
    Denoise every input, optionally enhance, then transcribe and push to the HF dataset.
    """
    if not audio_filepath:
        return "No audio detected. Please record or upload audio.", history
    try:
        # Load and preprocess
        audio_data, sr = torchaudio.load(audio_filepath)
        # Pick the device once so both denoise and enhance can use it;
        # fall back to CPU when no GPU is available
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Always denoise
        try:
            audio_data = audio_data.mean(dim=0)  # downmix to mono
            denoised_data, sr = denoise(audio_data, sr, device)
            logger.info("Audio denoised successfully.")
        except Exception as e:
            logger.warning(f"Denoise failed, using raw audio: {e}")
            denoised_data = audio_data
        # Optionally enhance
        if apply_enhance:
            try:
                enhanced_data, sr = enhance(denoised_data, sr, device)
                final_audio = enhanced_data
                logger.info("Audio enhanced successfully.")
            except Exception as e:
                logger.warning(f"Enhancement failed, using denoised audio: {e}")
                final_audio = denoised_data
        else:
            final_audio = denoised_data
        # Save processed audio to a temp file; delete=False keeps it on disk
        # so the Audio() cast can read it when the dataset is pushed
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmpf:
            sf.write(tmpf.name, final_audio.cpu().numpy(), sr)  # tensor -> ndarray for soundfile
            local_path = tmpf.name
        # Transcription
        transcription = asr_client.transcribe(local_path)
        logger.info(f"Transcription: {transcription}")

        # Prepare the new record
        new_record = {
            "audio": [local_path],
            "text": [transcription],
            "language": ["moore"],
            "datetime": [datetime.utcnow().isoformat()],
        }
        new_ds = Dataset.from_dict(new_record).cast_column("audio", Audio())
        # Update the in-memory dataset
        global current_dataset
        if len(current_dataset) == 0:
            current_dataset = new_ds
        else:
            current_dataset = concatenate_datasets([current_dataset, new_ds])

        # Push to the Hub
        save_dataset(current_dataset, CURRENT_DATASET)
        # Update conversation history
        history = history + f"\nUser: [audio]\nAssistant: {transcription}"
        return transcription, history

    except Exception as exc:
        logger.error(f"Error during transcription pipeline: {exc}")
        return f"Error: {exc}", history

def build_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# 🗣️ ASR Moore Live 🧠")
        gr.Markdown(
            "Speech recognition interface for the Moore language. "
            "Record or upload audio; every input is denoised and can "
            "optionally be enhanced before ASR."
        )
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Record or upload audio", sources=["microphone", "upload"])
        state_box = gr.State(value="")
        enhance_checkbox = gr.Checkbox(label="Apply Enhancement", value=False)
        output_text = gr.Textbox(label="Transcription")
        submit_btn = gr.Button("Transcribe and Save")
        submit_btn.click(
            fn=transcribe_and_update,
            inputs=[audio_input, state_box, enhance_checkbox],
            outputs=[output_text, state_box],
        )
    demo.launch(debug=True)

if __name__ == "__main__":
    build_interface()
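# Assumed runtime dependencies (a sketch of requirements.txt inferred from the
# imports above, not an authoritative list): gradio, torch, torchaudio,
# soundfile, datasets, transformers, huggingface_hub, resemble-enhance, spaces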