roest-demo / app.py
"""Røst ASR demo."""

import os
import warnings

import gradio as gr
import numpy as np
import samplerate
import torch
from dotenv import load_dotenv
from punctfix import PunctFixer
from transformers import pipeline

load_dotenv()
warnings.filterwarnings("ignore", category=FutureWarning)
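
# An upload icon as inline SVG, embedded in the Markdown description below.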
icon = """
<svg xmlns="http://www.w3.org/2000/svg" width="14px" viewBox="0 0 24 24" fill="none"
stroke="currentColor" stroke-width="2" stroke-linecap="round"
stroke-linejoin="round" style="display: inline;">
<path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4"/>
<polyline points="17 8 12 3 7 8"/>
<line x1="12" y1="3" x2="12" y2="15"/>
</svg>
"""

TITLE = "Røst ASR Demo"
DESCRIPTION = f"""
This is a demo of the Danish speech recognition model
[Røst](https://huggingface.co/alexandrainst/roest-315m). Press "Record" to record
your own voice. When you're done, press "Stop" to stop recording and "Submit" to
send the audio to the model for transcription. You can also upload an audio file by
pressing the {icon} button.
"""

# Run on a GPU when available; the ASR pipeline and the punctuation model share
# this device. If HUGGINGFACE_HUB_TOKEN is unset, the `True` fallback below makes
# transformers use the token cached locally by `huggingface-cli login`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

transcriber = pipeline(
    task="automatic-speech-recognition",
    model="alexandrainst/roest-315m",
    device=device,
    token=os.getenv("HUGGINGFACE_HUB_TOKEN", True),
)
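
# The raw ASR output has neither punctuation nor casing, so restore both with
# punctfix before returning the transcription.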
transcription_fixer = PunctFixer(language="da", device=device)


def transcribe_audio(sampling_rate_and_audio: tuple[int, np.ndarray]) -> str:
    """Transcribe the audio.

    Args:
        sampling_rate_and_audio:
            A tuple with the sampling rate and the audio.

    Returns:
        The transcription.
    """
    sampling_rate, audio = sampling_rate_and_audio

    # Collapse multi-channel recordings to mono by averaging the channels.
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)

    # Resample to the 16 kHz rate the model expects; the second argument to
    # `samplerate.resample` is the ratio of target rate to source rate.
    audio = samplerate.resample(audio, 16_000 / sampling_rate, "sinc_best")

    transcription = transcriber(inputs=audio)
    if not isinstance(transcription, dict):
        return ""

    cleaned_transcription = transcription_fixer.punctuate(text=transcription["text"])
    return cleaned_transcription
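

# A hypothetical local smoke test (not part of the original app): the Gradio
# Audio component passes the function a (sampling_rate, samples) tuple, so a
# synthetic 48 kHz tone exercises the full resample-transcribe-punctuate path:
#
#     sr = 48_000
#     tone = np.sin(2 * np.pi * 440 * np.arange(sr) / sr).astype(np.float32)
#     print(transcribe_audio((sr, tone)))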

demo = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(sources=["microphone", "upload"]),
    outputs="textbox",
    title=TITLE,
    description=DESCRIPTION,
    allow_flagging="never",
)

demo.launch()
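
# To run the demo locally (a sketch, assuming the dependencies imported above are
# installed and that any required Hugging Face token lives in a local .env file):
#
#     python app.py
#
# Gradio then serves the interface on http://127.0.0.1:7860 by default.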