|
import gradio as gr |
|
import tempfile |
|
|
|
from huggingface_hub import hf_hub_download |
|
from torch import no_grad, package |
|
import ctypes |
|
import gc |
|
|
|
|
|
|
|
# Maps the voice name shown in the UI to the Hugging Face Hub repository
# that hosts the packaged model for that voice.
config = {
    "mykyta": "theodotus/tts-vits-mykyta-uk",
    "olena": "theodotus/tts-vits-olena-uk",
    "lada": "theodotus/tts-vits-lada-uk",
}

# Voice names in declaration order; the first entry is the UI default.
voices = list(config)

# Keyword arguments forwarded unchanged to every synthesizer `tts` call.
tts_kwargs = {"speaker_name": "uk", "language_name": "uk"}
|
|
|
|
|
def trim_memory():
    """Free unreferenced Python objects and hand spare heap back to the OS.

    Runs a full garbage collection and then asks glibc to release free
    heap memory via ``malloc_trim(0)``. The trim step is best effort: on
    platforms where ``libc.so.6`` or ``malloc_trim`` is unavailable it is
    skipped instead of raising, so callers never fail because of it.
    """
    # Collect first so that whatever the collector frees is already back
    # on the allocator's free lists by the time malloc_trim runs.
    gc.collect()
    try:
        libc = ctypes.CDLL("libc.so.6")
        libc.malloc_trim(0)
    except (OSError, AttributeError):
        # OSError: the library cannot be loaded (non-glibc platform).
        # AttributeError: the loaded libc does not export malloc_trim.
        # Trimming is only an optimisation, so there is nothing to do.
        pass
|
|
|
|
|
def init_models():
    """Download and load one synthesizer per entry in ``config``.

    Returns:
        A dict keyed by voice name (the keys of ``config``) whose values
        are the objects unpickled from each repository's ``model.pt``.
    """

    def _load(repo_id):
        # Fetch the packaged model file from the Hub, then unpickle the
        # "model" resource stored in its "tts_models" package.
        # NOTE(review): load_pickle executes code from the downloaded
        # archive, so only trusted repositories should be listed in config.
        archive_path = hf_hub_download(repo_id, "model.pt")
        return package.PackageImporter(archive_path).load_pickle(
            "tts_models", "model"
        )

    return {voice: _load(repo_id) for voice, repo_id in config.items()}
|
|
|
|
|
def tts(text: str, voice: str) -> str:
    """Synthesize ``text`` with the selected voice and save it as a WAV file.

    Args:
        text: Text to synthesize.
        voice: Key into the module-level ``models`` dict.

    Returns:
        Path of a newly created temporary ``.wav`` file. The file is created
        with ``delete=False`` so it outlives this call; whoever consumes the
        path is responsible for it afterwards.

    Raises:
        KeyError: If ``voice`` is not a key of ``models``.
        Exception: Anything raised by the synthesizer is propagated after
            the partially written temporary file has been removed.
    """
    import os  # local import: only needed to clean up after a failure

    synt = models[voice]
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        try:
            # Inference only, so gradient tracking is disabled.
            with no_grad():
                wav_data = synt.tts(text, **tts_kwargs)
                synt.save_wav(wav_data, fp)
        except Exception:
            # delete=False means nothing else removes this file, so a
            # failed request would otherwise leave an orphan on disk.
            # Close before unlinking so removal works on every platform.
            fp.close()
            os.unlink(fp.name)
            raise
        finally:
            # Release memory whether or not synthesis succeeded.
            trim_memory()
    return fp.name
|
|
|
|
|
|
|
# Load every voice once at start-up; `tts` looks models up by voice name.
models = init_models()

# Input widgets, built in the order they are displayed.
# NOTE(review): the "+" characters in the sample text look like stress
# marks for the synthesizer — confirm against the model documentation.
text_input = gr.Textbox(
    label="Input",
    value="К+ам'ян+ець-Под+ільський - м+істо в Хмельн+ицькій +області Укра+їни, ц+ентр Кам'ян+ець-Под+ільської міськ+ої об'+єднаної територі+альної гром+ади +і Кам'ян+ець-Под+ільського рай+ону.",
)
voice_input = gr.Radio(
    label="Voice",
    choices=voices,
    value=voices[0],
)

# Output widget: plays the file whose path `tts` returns.
audio_output = gr.Audio(label="Output")

iface = gr.Interface(
    fn=tts,
    inputs=[text_input, voice_input],
    outputs=audio_output,
    title="🇺🇦 - Ukrainian Voices",
)

# Start the web server for the demo.
iface.launch()