import sys
import time
from importlib.metadata import version
from os import remove
from os.path import exists
import numpy as np
import torch
import torchaudio
import torchaudio.transforms as T
import streamlit as st
from streamlit.runtime.uploaded_file_manager import UploadedFile
from transformers import HubertForCTC, Wav2Vec2Processor
# Config
model_name = "Yehor/hubert-uk"
torchaudio_backend = "soundfile"
min_duration = 0.5
max_duration = 60
concurrency_limit = 5
use_torch_compile = False
# Torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Load the model
asr_model = HubertForCTC.from_pretrained(
model_name, torch_dtype=torch_dtype, device_map=device
)
processor = Wav2Vec2Processor.from_pretrained(model_name)
if use_torch_compile:
asr_model = torch.compile(asr_model)
# Elements
examples = [
"example_1.wav",
"example_2.wav",
"example_3.wav",
"example_4.wav",
"example_5.wav",
"example_6.wav",
]
examples_table = """
| File | Text |
| ------------- | ------------- |
| `example_1.wav` | тема про яку не люблять говорити офіційні джерела у генштабі і міноборони це хімічна зброя окупанти вже тривалий час використовують хімічну зброю заборонену |
| `example_2.wav` | всіма конвенціями якщо спочатку це були гранати з дронів то тепер фіксують випадки застосування |
| `example_3.wav` | хімічних снарядів причому склад отруйної речовони різний а отже й наслідки для наших військових теж різні |
| `example_4.wav` | використовує на фронті все що має і хімічна зброя не вийняток тож з чим маємо справу розбиралася марія моганисян |
| `example_5.wav` | двох тисяч випадків застосування росіянами боєприпасів споряджених небезпечними хімічними речовинами |
| `example_6.wav` | на всі писані норми марія моганисян олександр моторний спецкор марафон єдині новини |
""".strip()
authors_table = """
## Authors
Follow them in social networks and **contact** if you need any help or have any questions:
|
**Yehor Smoliakov** |
|------------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use egorsmkv@gmail.com |
""".strip()
description_head = f"""
## Overview
This space uses https://huggingface.co/Yehor/hubert-uk model to recognize audio files.
> Due to resource limitations, audio duration **must not** exceed **{max_duration}** seconds.
""".strip()
description_foot = f"""
## Community
- **Discord**: https://discord.gg/yVAjkBgmt4
- Speech Recognition: https://t.me/speech_recognition_uk
- Speech Synthesis: https://t.me/speech_synthesis_uk
## More
Check out other ASR models: https://github.com/egorsmkv/speech-recognition-uk
{authors_table}
""".strip()
transcription_value = """
Recognized text will appear here.
Choose **an example file** below the Recognize button, upload **your audio file**, or use **the microphone** to record own voice.
""".strip()
tech_env = f"""
#### Environment
- Python: {sys.version}
- Torch device: {device}
- Torch dtype: {torch_dtype}
- Use torch.compile: {use_torch_compile}
""".strip()
tech_libraries = f"""
#### Libraries
- torch: {version('torch')}
- torchaudio: {version('torchaudio')}
- transformers: {version('transformers')}
- accelerate: {version('accelerate')}
- streamlit: {version('streamlit')}
""".strip()
# UploadedFile
def inference(uploaded_file: UploadedFile):
audio_path = uploaded_file.file_id + '.wav'
with open(audio_path, 'wb') as f:
f.write(uploaded_file.getvalue())
if not audio_path:
st.error("Please upload an audio file.")
return
st.info("Starting recognition")
meta = torchaudio.info(audio_path, backend=torchaudio_backend)
duration = meta.num_frames / meta.sample_rate
if duration < min_duration:
st.error(
f"The duration of the file is less than {min_duration} seconds, it is {round(duration, 2)} seconds."
)
return
if duration > max_duration:
st.error(f"The duration of the file exceeds {max_duration} seconds.")
return
paths = [
audio_path,
]
results = []
for path in paths:
t0 = time.time()
meta = torchaudio.info(audio_path, backend=torchaudio_backend)
audio_duration = meta.num_frames / meta.sample_rate
audio_input, sr = torchaudio.load(path, backend=torchaudio_backend)
if meta.num_channels > 1:
audio_input = torch.mean(audio_input, dim=0, keepdim=True)
if meta.sample_rate != 16_000:
resampler = T.Resample(sr, 16_000, dtype=audio_input.dtype)
audio_input = resampler(audio_input)
audio_input = audio_input.squeeze(0).numpy()
inputs = processor(
[audio_input], sampling_rate=16_000, padding=True
).input_values
features = torch.tensor(np.array(inputs), dtype=torch_dtype).to(device)
with torch.inference_mode():
logits = asr_model(features).logits
predicted_ids = torch.argmax(logits, dim=-1)
predictions = processor.batch_decode(predicted_ids)
if not predictions:
predictions = "-"
elapsed_time = round(time.time() - t0, 2)
rtf = round(elapsed_time / audio_duration, 4)
audio_duration = round(audio_duration, 2)
results.append(
{
"path": path.split("/")[-1],
"transcription": "\n".join(predictions),
"audio_duration": audio_duration,
"rtf": rtf,
}
)
st.info("Finished!")
result_texts = []
for result in results:
result_texts.append(f'**{result["path"]}**')
result_texts.append("\n\n")
result_texts.append(f'> {result["transcription"]}')
result_texts.append("\n\n")
result_texts.append(f'**Audio duration**: {result["audio_duration"]}')
result_texts.append("\n")
result_texts.append(f'**Real-Time Factor**: {result["rtf"]}')
if exists(audio_path):
remove(audio_path)
return "\n".join(result_texts)
st.title("Speech-to-Text for Ukrainian using HuBERT")
st.markdown(description_head)
st.markdown("## Usage")
audio_file = st.file_uploader("Upload an audio file", type=["wav"])
if st.button("Recognize"):
if audio_file is not None:
transcription = inference(audio_file)
st.markdown(transcription)
else:
st.error("Please upload an audio file.")
st.markdown("### Examples")
st.markdown(examples_table)
st.markdown(description_foot, unsafe_allow_html=True)
st.markdown("### Environment")
st.markdown(tech_env)
st.markdown(tech_libraries)