"""Røst ASR demo.""" import os import warnings import gradio as gr import numpy as np import samplerate import torch from punctfix import PunctFixer from transformers import pipeline from dotenv import load_dotenv load_dotenv() warnings.filterwarnings("ignore", category=FutureWarning) icon = """ """ TITLE = "Røst ASR Demo" DESCRIPTION = f""" This is a demo of the Danish speech recognition model [Røst](https://huggingface.co/alexandrainst/roest-315m). Press "Record" to record your own voice. When you're done you can press "Stop" to stop recording and "Submit" to send the audio to the model for transcription. You can also upload an audio file by pressing the {icon} button. """ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") transcriber = pipeline( task="automatic-speech-recognition", model="alexandrainst/roest-315m", device=device, token=os.getenv("HUGGINGFACE_HUB_TOKEN", True), ) transcription_fixer = PunctFixer(language="da", device=device) def transcribe_audio(sampling_rate_and_audio: tuple[int, np.ndarray]) -> str: """Transcribe the audio. Args: sampling_rate_and_audio: A tuple with the sampling rate and the audio. Returns: The transcription. """ sampling_rate, audio = sampling_rate_and_audio if audio.ndim > 1: audio = np.mean(audio, axis=1) audio = samplerate.resample(audio, 16_000 / sampling_rate, "sinc_best") transcription = transcriber(inputs=audio) if not isinstance(transcription, dict): return "" cleaned_transcription = transcription_fixer.punctuate( text=transcription["text"] ) return cleaned_transcription demo = gr.Interface( fn=transcribe_audio, inputs=gr.Audio(sources=["microphone", "upload"]), outputs="textbox", title=TITLE, description=DESCRIPTION, allow_flagging="never", ) demo.launch()