RedSparkie's picture
Update app.py
1a3ef6b verified
raw
history blame
3.54 kB
import spaces
import gradio as gr
import torch
from TTS.api import TTS
import os
import tempfile
import torchaudio
from huggingface_hub import hf_hub_download
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
# Accept the Coqui terms of service.
os.environ["COQUI_TOS_AGREED"] = "1"

# Target device (CPU).
# NOTE(review): `device` is not referenced by the code visible in this file; the model is
# never explicitly moved to it — confirm whether it is still needed.
device = "cpu"

# Download the XTTS checkpoint, its JSON config and its vocabulary from the Hugging Face Hub.
# The files are fetched in this order and bound to the three path names below.
_HF_REPO_ID = "RedSparkie/danielmula"
model_path, config_path, vocab_path = (
    hf_hub_download(repo_id=_HF_REPO_ID, filename=_fname)
    for _fname in ("model.pth", "config.json", "vocab.json")
)
# Helper to release cached GPU memory (harmless on CPU, kept in case a GPU is used later).
def clear_gpu_cache():
    """Ask PyTorch to release cached CUDA memory; do nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
# The XTTS model used by run_tts(); stays None until load_model() succeeds.
XTTS_MODEL = None


def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
    """Load an XTTS model from local files into the module-level ``XTTS_MODEL``.

    Args:
        xtts_checkpoint: Path to the model checkpoint file.
        xtts_config: Path to the XTTS JSON config file.
        xtts_vocab: Path to the tokenizer vocabulary file.

    Returns:
        An error message string if any of the three paths is missing/empty (the
        current ``XTTS_MODEL`` is left untouched in that case); otherwise ``None``,
        with ``XTTS_MODEL`` holding the freshly loaded model.

    If building or loading the checkpoint raises, the exception propagates and
    ``XTTS_MODEL`` is left as ``None`` rather than a partially initialised model.
    """
    global XTTS_MODEL
    # Validate first: there is nothing to free or load if a path is missing.
    if not xtts_checkpoint or not xtts_config or not xtts_vocab:
        return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"

    # Drop the reference to any previous model before clearing the cache, so its
    # memory can actually be released, and so a failed load below never leaves a
    # stale or half-loaded model published.
    XTTS_MODEL = None
    clear_gpu_cache()

    # Model configuration.
    config = XttsConfig()
    config.load_json(xtts_config)

    # Build the network, then load the checkpoint weights into it.
    model = Xtts.init_from_config(config)
    print("Loading XTTS model!")
    model.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False, weights_only=True)

    # Publish only a fully loaded model.
    XTTS_MODEL = model
    print("Model Loaded!")
# Text-to-speech with the loaded XTTS model.
def run_tts(lang, tts_text, speaker_audio_file):
    """Synthesize ``tts_text`` in the voice of ``speaker_audio_file`` using ``XTTS_MODEL``.

    Args:
        lang: Language code passed to XTTS inference (e.g. ``'es'``).
        tts_text: Text to synthesize.
        speaker_audio_file: Path to a reference recording of the voice to clone.

    Returns:
        ``(out_path, speaker_audio_file)``, where ``out_path`` is a temporary WAV
        file written at 24000 Hz (created with ``delete=False``, so it is not removed
        automatically). If the model is not loaded or no reference audio is given,
        returns ``(error_message, None)`` instead. This is the same 2-tuple shape,
        so callers can always unpack two values.
    """
    if XTTS_MODEL is None or not speaker_audio_file:
        # Two elements, matching the success return below (previously three, which
        # broke callers that unpack ``out_path, speaker_audio_file``).
        return "You need to run the previous step to load the model !!", None

    cfg = XTTS_MODEL.config
    # inference_mode disables autograd tracking for both conditioning and synthesis.
    with torch.inference_mode():
        gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
            audio_path=speaker_audio_file,
            gpt_cond_len=cfg.gpt_cond_len,
            max_ref_length=cfg.max_ref_len,
            sound_norm_refs=cfg.sound_norm_refs,
        )
        out = XTTS_MODEL.inference(
            text=tts_text,
            language=lang,
            gpt_cond_latent=gpt_cond_latent,
            speaker_embedding=speaker_embedding,
            temperature=cfg.temperature,
            length_penalty=cfg.length_penalty,
            repetition_penalty=cfg.repetition_penalty,
            top_k=cfg.top_k,
            top_p=cfg.top_p,
        )

    # Reserve a unique temp path and close the handle before writing. Writing the
    # same file through a second handle while the first is still open fails on Windows.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        out_path = fp.name

    # Add a leading channel dimension for torchaudio.save.
    wav = torch.as_tensor(out["wav"]).unsqueeze(0)
    torchaudio.save(out_path, wav, 24000)
    print("Speech generated!")
    return out_path, speaker_audio_file
# Gradio callback.
@spaces.GPU
def generate(text, audio):
    """Clone the voice in ``audio`` and speak ``text`` with it (language ``'es'``).

    Args:
        text: Sentence to synthesize (from the Gradio textbox).
        audio: Filepath of the reference voice recording (``gr.Audio(type='filepath')``).

    Returns:
        Path to the generated WAV file.

    Raises:
        gr.Error: if the text or the reference audio is missing, or if the model
            cannot be loaded; Gradio shows the message to the user.
    """
    global XTTS_MODEL
    if not text or not text.strip():
        raise gr.Error("Please enter a sentence to generate.")
    if not audio:
        raise gr.Error("Please provide a reference voice recording.")

    # Load the model once and reuse it, instead of re-reading the checkpoint on
    # every request.
    if XTTS_MODEL is None:
        try:
            load_error = load_model(model_path, config_path, vocab_path)
        except Exception:
            # Never keep a partially initialised model around; the next request retries.
            XTTS_MODEL = None
            raise
        if XTTS_MODEL is None:
            raise gr.Error(load_error or "Could not load the XTTS model.")

    out_path, _speaker_audio_file = run_tts(lang='es', tts_text=text, speaker_audio_file=audio)
    return out_path
# Gradio UI: a sentence and a reference-voice recording go in, and the generated
# speech comes out as an audio file path.
_text_input = gr.Textbox(label='Frase a generar')
_voice_input = gr.Audio(type='filepath', label='Voz de referencia')
demo = gr.Interface(
    fn=generate,
    inputs=[_text_input, _voice_input],
    outputs=gr.Audio(type='filepath'),
)

# Start the app.
demo.launch()