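# Header captured from the Hugging Face file-viewer page this file was copied from:
# author: Yehor Smoliakov
# commit: bcaad24 ("Init")
# page links/badges: raw | history | blame | No virus
# size: 7.87 kB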
import sys
import time
from importlib.metadata import version
from os import remove
from os.path import exists
import numpy as np
import torch
import torchaudio
import torchaudio.transforms as T
import streamlit as st
from streamlit.runtime.uploaded_file_manager import UploadedFile
from transformers import HubertForCTC, Wav2Vec2Processor
# Config
model_name = "Yehor/hubert-uk"  # checkpoint id passed to from_pretrained()
torchaudio_backend = "soundfile"  # backend passed to torchaudio.info()/load()
min_duration = 0.5  # shortest accepted audio, in seconds
max_duration = 60  # longest accepted audio, in seconds
concurrency_limit = 5  # NOTE(review): not referenced anywhere in the visible file
use_torch_compile = False

# Torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is only used when a CUDA device is available.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32


# Load the model
@st.cache_resource
def _load_asr_components():
    """Load the CTC model and its processor once per server process.

    Streamlit re-executes this script from top to bottom on every widget
    interaction; without caching, the checkpoint would be re-loaded on each
    rerun. ``st.cache_resource`` keeps a single shared instance instead.

    Returns:
        A ``(model, processor)`` tuple. The model is wrapped with
        ``torch.compile`` when ``use_torch_compile`` is true.
    """
    model = HubertForCTC.from_pretrained(
        model_name, torch_dtype=torch_dtype, device_map=device
    )
    proc = Wav2Vec2Processor.from_pretrained(model_name)

    if use_torch_compile:
        model = torch.compile(model)

    return model, proc


asr_model, processor = _load_asr_components()
# Elements
# File names of the bundled example recordings.
examples = [f"example_{idx}.wav" for idx in range(1, 7)]

# Markdown table pairing every example file with its text, one row per file.
examples_table = "\n".join(
    [
        "| File | Text |",
        "| ------------- | ------------- |",
        *(
            f"| `{file_name}` | {text} |"
            for file_name, text in zip(
                examples,
                (
                    "тема про яку не люблять говорити офіційні джерела у генштабі і міноборони це хімічна зброя окупанти вже тривалий час використовують хімічну зброю заборонену",
                    "всіма конвенціями якщо спочатку це були гранати з дронів то тепер фіксують випадки застосування",
                    "хімічних снарядів причому склад отруйної речовони різний а отже й наслідки для наших військових теж різні",
                    "використовує на фронті все що має і хімічна зброя не вийняток тож з чим маємо справу розбиралася марія моганисян",
                    "двох тисяч випадків застосування росіянами боєприпасів споряджених небезпечними хімічними речовинами",
                    "на всі писані норми марія моганисян олександр моторний спецкор марафон єдині новини",
                ),
            )
        ),
    ]
)
# Markdown/HTML block with the author's contact links; it is embedded into
# `description_foot` below.
# NOTE(review): the "[email protected]" text looks like an e-mail obfuscation
# placeholder rather than a real address - confirm against the original file.
authors_table = """
## Authors
Follow them in social networks and **contact** if you need any help or have any questions:
| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> <br> **Yehor Smoliakov** |
|------------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use [email protected] |
""".strip()

# Intro shown under the page title. The model link is derived from
# `model_name` so it cannot drift from the checkpoint that is actually loaded.
description_head = f"""
## Overview
This space uses https://huggingface.co/{model_name} model to recognize audio files.
> Due to resource limitations, audio duration **must not** exceed **{max_duration}** seconds.
""".strip()

# Footer with community links, followed by the authors table.
description_foot = f"""
## Community
- **Discord**: https://discord.gg/yVAjkBgmt4
- Speech Recognition: https://t.me/speech_recognition_uk
- Speech Synthesis: https://t.me/speech_synthesis_uk
## More
Check out other ASR models: https://github.com/egorsmkv/speech-recognition-uk
{authors_table}
""".strip()

# Placeholder text for the transcription area.
# NOTE(review): not referenced anywhere in the visible file.
transcription_value = """
Recognized text will appear here.
Choose **an example file** below the Recognize button, upload **your audio file**, or use **the microphone** to record own voice.
""".strip()
# Markdown bullet list describing the runtime (interpreter and torch settings).
tech_env = "\n".join(
    [
        "#### Environment",
        f"- Python: {sys.version}",
        f"- Torch device: {device}",
        f"- Torch dtype: {torch_dtype}",
        f"- Use torch.compile: {use_torch_compile}",
    ]
).strip()

# Markdown bullet list with the installed version of each key dependency.
tech_libraries = "\n".join(
    ["#### Libraries"]
    + [
        f"- {package}: {version(package)}"
        for package in (
            "torch",
            "torchaudio",
            "transformers",
            "accelerate",
            "streamlit",
        )
    ]
).strip()
# UploadedFile
def inference(uploaded_file: UploadedFile):
    """Transcribe an uploaded audio file and return a Markdown report.

    The upload is written to ``<file_id>.wav`` in the current working
    directory, checked against ``min_duration`` / ``max_duration``, averaged
    down to a single channel, resampled to 16 kHz when needed and decoded by
    taking the argmax of the CTC logits. The temporary file is removed on
    every exit path, including validation failures and exceptions.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``.

    Returns:
        A Markdown string with the file name, the transcription, the audio
        duration and the real-time factor, or ``None`` when the input was
        rejected (an ``st.error`` message has already been shown).

    Raises:
        Any exception raised by ``torchaudio`` or the model propagates to the
        caller unchanged.
    """
    target_sample_rate = 16_000

    if uploaded_file is None:
        st.error("Please upload an audio file.")
        return None

    payload = uploaded_file.getvalue()
    if not payload:
        # An empty upload cannot be parsed as audio; reject it up front.
        st.error("Please upload an audio file.")
        return None

    audio_path = uploaded_file.file_id + '.wav'

    try:
        with open(audio_path, 'wb') as f:
            f.write(payload)

        st.info("Starting recognition")

        meta = torchaudio.info(audio_path, backend=torchaudio_backend)
        audio_duration = meta.num_frames / meta.sample_rate

        if audio_duration < min_duration:
            st.error(
                f"The duration of the file is less than {min_duration} seconds, it is {round(audio_duration, 2)} seconds."
            )
            return None

        if audio_duration > max_duration:
            st.error(f"The duration of the file exceeds {max_duration} seconds.")
            return None

        t0 = time.time()

        audio_input, sr = torchaudio.load(audio_path, backend=torchaudio_backend)

        if meta.num_channels > 1:
            # Down-mix by averaging the channels.
            audio_input = torch.mean(audio_input, dim=0, keepdim=True)

        if sr != target_sample_rate:
            resampler = T.Resample(sr, target_sample_rate, dtype=audio_input.dtype)
            audio_input = resampler(audio_input)

        audio_input = audio_input.squeeze(0).numpy()

        inputs = processor(
            [audio_input], sampling_rate=target_sample_rate, padding=True
        ).input_values
        features = torch.tensor(np.array(inputs), dtype=torch_dtype).to(device)

        with torch.inference_mode():
            logits = asr_model(features).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        predictions = processor.batch_decode(predicted_ids)

        # Show a placeholder when nothing was recognized; checking the decoded
        # text (not the list) is what makes the placeholder actually appear.
        transcription = "\n".join(predictions).strip() or "-"

        elapsed_time = round(time.time() - t0, 2)
        rtf = round(elapsed_time / audio_duration, 4)

        st.info("Finished!")

        return "\n".join(
            [
                f'**{audio_path.split("/")[-1]}**',
                "\n\n",
                f'> {transcription}',
                "\n\n",
                f'**Audio duration**: {round(audio_duration, 2)}',
                "\n",
                f'**Real-Time Factor**: {rtf}',
            ]
        )
    finally:
        # Always clean up the temporary copy of the upload.
        if exists(audio_path):
            remove(audio_path)
# Page layout: title and overview, upload + recognize controls, examples,
# footer and environment details, rendered top to bottom.
st.title("Speech-to-Text for Ukrainian using HuBERT")
st.markdown(description_head)

st.markdown("## Usage")
audio_file = st.file_uploader("Upload an audio file", type=["wav"])

if st.button("Recognize"):
    if audio_file is None:
        st.error("Please upload an audio file.")
    else:
        transcription = inference(audio_file)
        # inference() returns None after it has reported a validation error;
        # passing None to st.markdown would render the literal text "None".
        if transcription is not None:
            st.markdown(transcription)

st.markdown("### Examples")
st.markdown(examples_table)

# The footer embeds an <img> tag, hence unsafe_allow_html.
st.markdown(description_foot, unsafe_allow_html=True)

st.markdown("### Environment")
st.markdown(tech_env)
st.markdown(tech_libraries)