# app.py - Gradio demo for a fine-tuned Parler-TTS Mini v1 model.
# Source: Hugging Face repository by TArtx, commit 1346f0a ("Update app.py"), 3.52 kB.
import gradio as gr
import torch
from transformers.models.speecht5.number_normalizer import EnglishNumberNormalizer
from string import punctuation
import re
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
# Set device to CPU only
device = "cpu"

# Load the model from the fine-tuned repo; the tokenizer and the feature
# extractor are loaded from the "parler-tts/parler-tts-mini-v1" repo.
repo_id = "TArtx/parler-tts-mini-v1-finetuned-12"
model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
# The repo id must be a quoted string: left unquoted it is parsed as an
# arithmetic expression over undefined names and raises NameError on import.
feature_extractor = AutoFeatureExtractor.from_pretrained("parler-tts/parler-tts-mini-v1")
# Fixed seed passed to set_seed() before each generation, and the sampling
# rate reported by the feature extractor (returned alongside the audio).
SEED = 42
SAMPLE_RATE = feature_extractor.sampling_rate

# Values pre-filled in the two UI text boxes
default_text = (
    "This is a demonstration of my ability to convert written words into "
    "spoken language, seamlessly and naturally. As a text-to-speech model, "
    "my goal is to sound as clear and engaging as a human, making sure "
    "every word I say leaves an impression."
)
default_description = (
    "moderate speed, very clear, monotone, wonderful speech quality"
)

# Shared normalizer instance applied to the input text in preprocess()
number_normalizer = EnglishNumberNormalizer()
# Preprocessing function
def preprocess(text):
    """Normalize raw input text before it is tokenized for synthesis.

    Steps, in order:
      1. Run ``number_normalizer`` over the text and strip surrounding
         whitespace.
      2. Replace hyphens with spaces.
      3. Append a period when the text does not already end in punctuation.
      4. Spell out upper-case abbreviations letter by letter, dropping any
         dots ("NASA" -> "N A S A", "U.S.A" -> "U S A").

    Args:
        text: Raw text to normalize.

    Returns:
        The normalized text. If nothing is left after step 1, the empty
        string is returned unchanged.
    """
    text = number_normalizer(text).strip()
    text = text.replace("-", " ")
    # Check for emptiness first: indexing text[-1] on "" raises IndexError.
    if text and text[-1] not in punctuation:
        text = f"{text}."

    abbreviations_pattern = r'\b[A-Z][A-Z\.]+\b'

    def separate_abb(match):
        # Drop the dots, then put a space between the remaining letters.
        chunk = match.group(0).replace(".", "")
        return " ".join(chunk)

    # Rewrite each match exactly where it occurs. A global str.replace per
    # match would also hit the same letters inside longer tokens, e.g.
    # expanding "US" first turns a later "USA" into "U SA".
    return re.sub(abbreviations_pattern, separate_abb, text)
# TTS generation function
def gen_tts(text, description):
    """Generate speech audio for ``text`` in the style given by ``description``.

    Args:
        text: Text to speak; it is passed through ``preprocess`` before
            tokenization.
        description: Free-form description of the desired voice; it is
            stripped of surrounding whitespace before tokenization.

    Returns:
        A ``(SAMPLE_RATE, audio_arr)`` tuple, where ``audio_arr`` is the
        model output moved to CPU, converted to a NumPy array and squeezed.

    Raises:
        gr.Error: If preprocessing, tokenization or generation fails. The
            original exception is chained as the cause.
    """
    try:
        # Tokenize inputs and prompts with truncation to avoid memory issues
        inputs = tokenizer(
            description.strip(), return_tensors="pt", truncation=True, max_length=128
        ).to(device)
        prompt = tokenizer(
            preprocess(text), return_tensors="pt", truncation=True, max_length=128
        ).to(device)

        # Re-seed before every call so sampling starts from the same state.
        set_seed(SEED)
        generation = model.generate(
            input_ids=inputs.input_ids,
            prompt_input_ids=prompt.input_ids,
            attention_mask=inputs.attention_mask,
            # The tokenizer output exposes the mask as ``attention_mask``;
            # ``prompt_attention_mask`` is only the name of the generate() kwarg.
            prompt_attention_mask=prompt.attention_mask,
            do_sample=True,
            temperature=1.0,
        )
        audio_arr = generation.cpu().numpy().squeeze()
        return SAMPLE_RATE, audio_arr
    except Exception as e:
        # The output component is gr.Audio(type="numpy"), which cannot render
        # a string; raise gr.Error so the message is shown in the UI instead.
        raise gr.Error(f"Error: {str(e)}") from e
# Build the demo page: text inputs and the trigger button on the left,
# the generated audio player on the right.
with gr.Blocks() as block:
    gr.Markdown(
        "\n"
        "## Parler-TTS 🗣️\n"
        "Parler-TTS is a training and inference library for high-fidelity text-to-speech (TTS) models. This demo uses the Mini v1 model.\n"
    )

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                lines=2,
                value=default_text,
                elem_id="input_text",
            )
            description = gr.Textbox(
                label="Description",
                lines=2,
                value=default_description,
                elem_id="input_description",
            )
            run_button = gr.Button("Generate Audio", variant="primary")

        with gr.Column():
            audio_out = gr.Audio(
                label="Parler-TTS generation",
                type="numpy",
                elem_id="audio_out",
            )

    # Wire the button to gen_tts: both text boxes in, the audio player out.
    inputs = [input_text, description]
    outputs = [audio_out]
    run_button.click(
        fn=gen_tts,
        inputs=inputs,
        outputs=outputs,
        queue=True,
    )

# Enable the request queue, then start serving the interface.
block.queue().launch()