Spaces:
Sleeping
Sleeping
import spaces | |
import gradio as gr | |
import torch | |
from TTS.api import TTS | |
import os | |
import tempfile | |
import torchaudio | |
from huggingface_hub import hf_hub_download | |
from TTS.tts.configs.xtts_config import XttsConfig | |
from TTS.tts.models.xtts import Xtts | |
# Accept the Coqui terms of service via the environment variable the TTS library reads.
os.environ["COQUI_TOS_AGREED"] = "1"

# Target device: CPU only.
device = "cpu"

# Hugging Face repository hosting the fine-tuned XTTS checkpoint, config and vocab.
# Kept in one place so switching to another voice model is a one-line change.
MODEL_REPO_ID = "RedSparkie/danielmula"

# Download the three files passed to load_model() (hf_hub_download reuses the local
# HF cache when the files are already present).
model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="model.pth")
config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="config.json")
vocab_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="vocab.json")
def clear_gpu_cache():
    """Release PyTorch's cached CUDA memory when a CUDA device is available.

    Kept in case the app is later moved to a GPU; on CPU-only hosts it does nothing.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
# Module-level handle to the loaded XTTS model; None until load_model() succeeds.
XTTS_MODEL = None


def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
    """Load the XTTS model from local files and store it in ``XTTS_MODEL``.

    Args:
        xtts_checkpoint: Path to the model checkpoint file.
        xtts_config: Path to the XTTS JSON config file.
        xtts_vocab: Path to the tokenizer vocab file.

    Returns:
        An error message string if any of the three paths is empty/None;
        otherwise None (the model is published in the global ``XTTS_MODEL``).
    """
    global XTTS_MODEL
    clear_gpu_cache()
    if not xtts_checkpoint or not xtts_config or not xtts_vocab:
        return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"

    # Model configuration.
    config = XttsConfig()
    config.load_json(xtts_config)

    # Build and load into a local first: if load_checkpoint raises, the global keeps
    # its previous value instead of pointing at a model whose checkpoint never loaded.
    model = Xtts.init_from_config(config)
    print("Loading XTTS model!")
    # NOTE(review): `weights_only` must be accepted by the installed TTS version's
    # Xtts.load_checkpoint — confirm against the pinned dependency.
    model.load_checkpoint(
        config,
        checkpoint_path=xtts_checkpoint,
        vocab_path=xtts_vocab,
        use_deepspeed=False,
        weights_only=True,
    )
    XTTS_MODEL = model
    print("Model Loaded!")
def run_tts(lang, tts_text, speaker_audio_file):
    """Synthesize ``tts_text`` with the voice of ``speaker_audio_file``.

    Uses the module-level ``XTTS_MODEL``; all sampling settings
    (temperature, penalties, top-k/top-p) come from the model's own config.

    Args:
        lang: Language code passed to XTTS inference.
        tts_text: Text to synthesize.
        speaker_audio_file: Path to a reference recording of the target voice.

    Returns:
        ``(out_path, speaker_audio_file)`` where ``out_path`` is a temporary WAV
        file written at 24000 Hz (created with ``delete=False``, so the caller owns it).
        If the model is not loaded or no reference audio is given, returns
        ``(error_message, None)`` instead.
    """
    if XTTS_MODEL is None or not speaker_audio_file:
        # Same 2-tuple shape as the success path so callers can always unpack
        # exactly two values.
        return "You need to run the previous step to load the model !!", None

    cfg = XTTS_MODEL.config
    # inference_mode turns off autograd tracking, which synthesis does not need.
    with torch.inference_mode():
        gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
            audio_path=speaker_audio_file,
            gpt_cond_len=cfg.gpt_cond_len,
            max_ref_length=cfg.max_ref_len,
            sound_norm_refs=cfg.sound_norm_refs,
        )
        out = XTTS_MODEL.inference(
            text=tts_text,
            language=lang,
            gpt_cond_latent=gpt_cond_latent,
            speaker_embedding=speaker_embedding,
            temperature=cfg.temperature,
            length_penalty=cfg.length_penalty,
            repetition_penalty=cfg.repetition_penalty,
            top_k=cfg.top_k,
            top_p=cfg.top_p,
        )

    # Reserve a unique temp path and close the handle before writing: reopening a
    # still-open NamedTemporaryFile by name is not portable (fails on Windows).
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        out_path = fp.name
    # Add a leading channel axis: torchaudio.save expects (channels, samples) by default.
    wav = torch.tensor(out["wav"]).unsqueeze(0)
    torchaudio.save(out_path, wav, 24000)
    print("Speech generated!")
    return out_path, speaker_audio_file
def generate(text, audio):
    """Gradio callback: synthesize ``text`` in Spanish using the voice in ``audio``.

    The model is loaded on the first request and reused afterwards, instead of
    being re-read from disk on every call.

    Args:
        text: Sentence to synthesize (from the text box).
        audio: Filepath of the reference voice recording (``gr.Audio(type='filepath')``).

    Returns:
        Filepath of the generated WAV file.

    Raises:
        gr.Error: if the text or reference audio is missing, or if load_model()
            returns without setting the model, so the UI shows a readable message.
    """
    if not text or not text.strip():
        raise gr.Error("Escribe una frase a generar.")
    if not audio:
        raise gr.Error("Sube o graba una voz de referencia.")

    # Load lazily once; later requests reuse the global model.
    if XTTS_MODEL is None:
        message = load_model(model_path, config_path, vocab_path)
        if XTTS_MODEL is None:
            raise gr.Error(message or "No se pudo cargar el modelo.")

    # Inputs were validated above, so run_tts takes its success path (2-tuple).
    out_path, _speaker_audio_file = run_tts(lang='es', tts_text=text, speaker_audio_file=audio)
    return out_path
# Gradio UI: a text box plus a reference-voice audio input, returning an audio file.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label='Frase a generar'),
        gr.Audio(type='filepath', label='Voz de referencia'),
    ],
    outputs=gr.Audio(type='filepath'),
)

# Launch the server only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()