File size: 2,269 Bytes
17cb7d3
 
4db4bee
17cb7d3
 
 
 
 
 
 
 
f4ab270
 
 
17cb7d3
 
 
8afa9a4
 
 
 
 
 
 
 
 
17cb7d3
8afa9a4
4db4bee
8afa9a4
 
 
 
17cb7d3
 
 
 
 
 
4db4bee
8afa9a4
17cb7d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""Røst ASR demo."""

import os
import warnings

import gradio as gr
import numpy as np
import samplerate
import torch
from punctfix import PunctFixer
from transformers import pipeline
from dotenv import load_dotenv

load_dotenv()

warnings.filterwarnings("ignore", category=FutureWarning)

# Inline SVG upload icon, embedded in the description so the text can point at
# the actual button glyph Gradio renders.
icon = """
<svg xmlns="http://www.w3.org/2000/svg" width="14px" viewBox="0 0 24 24" fill="none"
     stroke="currentColor" stroke-width="2" stroke-linecap="round"
     stroke-linejoin="round" style="display: inline;">
  <path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4"/>
  <polyline points="17 8 12 3 7 8"/>
  <line x1="12" y1="3" x2="12" y2="15"/>
</svg>
"""
TITLE = "Røst ASR Demo"
DESCRIPTION = f"""
This is a demo of the Danish speech recognition model
[Røst](https://huggingface.co/alexandrainst/roest-315m). Press "Record" to record your
own voice. When you're done you can press "Stop" to stop recording and "Submit" to
send the audio to the model for transcription. You can also upload an audio file by
pressing the {icon} button.
"""

# Prefer GPU when available; both the ASR pipeline and the punctuation fixer
# are placed on the same device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
transcriber = pipeline(
    task="automatic-speech-recognition",
    model="alexandrainst/roest-315m",
    device=device,
    # Fall back to `True` (use the locally stored Hugging Face token, if any)
    # when HUGGINGFACE_HUB_TOKEN is not set in the environment / .env file.
    token=os.getenv("HUGGINGFACE_HUB_TOKEN", True),
)
# Restores punctuation and casing in the raw (unpunctuated) ASR output.
transcription_fixer = PunctFixer(language="da", device=device)

def transcribe_audio(sampling_rate_and_audio: tuple[int, np.ndarray]) -> str:
    """Transcribe the audio.

    Args:
        sampling_rate_and_audio:
            A tuple with the sampling rate and the audio samples, as produced
            by a Gradio audio component.

    Returns:
        The punctuation-restored transcription, or an empty string if the
        model did not return a usable result.
    """
    sampling_rate, audio = sampling_rate_and_audio

    # Downmix multi-channel audio to mono by averaging the channels.
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)

    # Gradio delivers integer PCM samples (typically int16), while
    # libsamplerate and the ASR feature extractor work on floating-point
    # audio. Normalise integer samples into [-1, 1] and use float32
    # throughout.
    if np.issubdtype(audio.dtype, np.integer):
        audio = audio / np.iinfo(audio.dtype).max
    audio = audio.astype(np.float32)

    # Resample to the model's 16 kHz input rate, but skip the (lossy and
    # wasteful) sinc-resampling pass when the audio is already at that rate.
    if sampling_rate != 16_000:
        audio = samplerate.resample(audio, 16_000 / sampling_rate, "sinc_best")

    transcription = transcriber(inputs=audio)
    if not isinstance(transcription, dict):
        return ""
    return transcription_fixer.punctuate(text=transcription["text"])

# Wire the transcription function into a simple record/upload -> textbox UI.
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(sources=["microphone", "upload"]),
    outputs="textbox",
    title=TITLE,
    description=DESCRIPTION,
    # Disable Gradio's built-in flagging button.
    # NOTE(review): `allow_flagging` is deprecated in newer Gradio releases in
    # favour of `flagging_mode` — confirm against the pinned Gradio version.
    allow_flagging="never",
)

# Blocks and serves the app; runs at import time since this is a script.
demo.launch()