"""Røst ASR demo."""
import os
import warnings
import gradio as gr
import numpy as np
import samplerate
import torch
from punctfix import PunctFixer
from transformers import pipeline
from dotenv import load_dotenv
# Load environment variables (e.g. HUGGINGFACE_HUB_TOKEN) from a local .env file.
load_dotenv()

# Silence noisy FutureWarnings emitted by transformers/torch during model loading.
warnings.filterwarnings("ignore", category=FutureWarning)

# Inline SVG upload icon, embedded in the Markdown description below.
icon = """
<svg xmlns="http://www.w3.org/2000/svg" width="14px" viewBox="0 0 24 24" fill="none"
stroke="currentColor" stroke-width="2" stroke-linecap="round"
stroke-linejoin="round" style="display: inline;">
<path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4"/>
<polyline points="17 8 12 3 7 8"/>
<line x1="12" y1="3" x2="12" y2="15"/>
</svg>
"""

# UI copy shown by the Gradio interface.
TITLE = "Røst ASR Demo"
DESCRIPTION = f"""
This is a demo of the Danish speech recognition model
[Røst](https://huggingface.co/alexandrainst/roest-315m). Press "Record" to record your
own voice. When you're done you can press "Stop" to stop recording and "Submit" to
send the audio to the model for transcription. You can also upload an audio file by
pressing the {icon} button.
"""

# Run on GPU when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Speech-recognition pipeline for the Danish Røst model. The token default of
# `True` makes huggingface_hub fall back to the locally cached login token when
# the HUGGINGFACE_HUB_TOKEN environment variable is not set.
transcriber = pipeline(
task="automatic-speech-recognition",
model="alexandrainst/roest-315m",
device=device,
token=os.getenv("HUGGINGFACE_HUB_TOKEN", True),
)

# Danish punctuation/casing restoration model applied to the raw transcript.
transcription_fixer = PunctFixer(language="da", device=device)
def transcribe_audio(sampling_rate_and_audio: tuple[int, np.ndarray]) -> str:
    """Transcribe the audio.

    Args:
        sampling_rate_and_audio:
            A tuple with the sampling rate and the audio, as produced by the
            Gradio audio component (the samples are typically int16 PCM —
            TODO confirm against the component's `type` setting).

    Returns:
        The transcription, or an empty string if the pipeline did not return
        a dict.
    """
    sampling_rate, audio = sampling_rate_and_audio

    # Down-mix multi-channel recordings to mono by averaging the channels.
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)

    # Gradio delivers integer PCM samples; libsamplerate and the ASR pipeline
    # operate on floating-point audio, so convert explicitly up front instead
    # of relying on implicit casts downstream.
    audio = audio.astype(np.float32)

    # The model expects 16 kHz audio. Skip the sinc resampling entirely when
    # the input is already at the target rate — a ratio-1.0 resample is wasted
    # work and not guaranteed to be a bit-exact no-op.
    if sampling_rate != 16_000:
        audio = samplerate.resample(audio, 16_000 / sampling_rate, "sinc_best")

    transcription = transcriber(inputs=audio)
    if not isinstance(transcription, dict):
        return ""

    # Restore punctuation and casing in the raw transcript.
    return transcription_fixer.punctuate(text=transcription["text"])
# Wire the transcription function into a simple Gradio UI: one audio input
# (microphone recording or file upload) and one textbox output.
demo = gr.Interface(
fn=transcribe_audio,
inputs=gr.Audio(sources=["microphone", "upload"]),
outputs="textbox",
title=TITLE,
description=DESCRIPTION,
allow_flagging="never",  # hide the flagging button; no feedback is collected
)

# Start the web server (blocks until the app is shut down).
demo.launch()