tts / app.py
okewunmi's picture
Update app.py
9147378 verified
raw
history blame
5.89 kB
import os
import subprocess
import sys
import tempfile
import traceback

import gradio as gr
import numpy as np
import torch
import torchaudio
import uroman
from outetts.wav_tokenizer.decoder import WavTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
# Clone the YarnGPT repository at startup (it is only cloned, not
# pip-installed) so that its AudioTokenizer can be imported below.
if not os.path.exists("yarngpt"):
    print("Cloning YarnGPT repository...")
    # Run git with an argument list and check=True so a failed clone stops
    # startup with git's own error instead of a later import failure.
    subprocess.run(
        ["git", "clone", "https://github.com/saheedniyi02/yarngpt.git"],
        check=True,
    )
# Add the repository to Python path
sys.path.append("yarngpt")
# Import the YarnGPT AudioTokenizer
from yarngpt.audiotokenizer import AudioTokenizerV2
# Constants and paths
MODEL_PATH = "saheedniyi/YarnGPT2b"
WAV_TOKENIZER_CONFIG_PATH = "wavtokenizer_config.yaml"
WAV_TOKENIZER_MODEL_PATH = "wavtokenizer_model.ckpt"
WAV_TOKENIZER_CONFIG_URL = "https://huggingface.co/novateur/WavTokenizer-medium-speech-75token/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
WAV_TOKENIZER_MODEL_URL = "https://huggingface.co/novateur/WavTokenizer-large-speech-75token/resolve/main/wavtokenizer_large_speech_320_24k.ckpt"


def _download_if_missing(url, dest_path, label):
    """Download ``url`` to ``dest_path`` with wget unless the file already exists.

    The download goes to ``<dest_path>.part`` first and is renamed only after
    wget exits successfully. ``wget -O`` creates its output file even when the
    transfer fails, so writing straight to ``dest_path`` would leave a broken
    file that the existence check treats as a finished download on the next
    start.

    Args:
        url: Source URL to fetch.
        dest_path: Final location of the downloaded file.
        label: Human-readable name used in the progress message.

    Raises:
        subprocess.CalledProcessError: If wget exits with a non-zero status.
        FileNotFoundError: If the wget executable is not available.
    """
    if os.path.exists(dest_path):
        return
    print(f"Downloading {label}...")
    partial_path = f"{dest_path}.part"
    try:
        # Argument list (no shell) so the URL and path are never shell-parsed.
        subprocess.run(["wget", "-O", partial_path, url], check=True)
    except BaseException:
        # Covers wget failures and interruptions: never keep a partial file.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    os.replace(partial_path, dest_path)


# Download the model files at startup
_download_if_missing(
    WAV_TOKENIZER_CONFIG_URL, WAV_TOKENIZER_CONFIG_PATH, "WavTokenizer config"
)
_download_if_missing(
    WAV_TOKENIZER_MODEL_URL, WAV_TOKENIZER_MODEL_PATH, "WavTokenizer model"
)
def initialize_model():
    """Create the audio tokenizer and load the YarnGPT language model.

    The tokenizer is built from ``MODEL_PATH`` plus the WavTokenizer checkpoint
    and config files; the causal LM is then loaded from ``MODEL_PATH`` and
    moved to the device the tokenizer reports.

    Returns:
        A ``(model, audio_tokenizer)`` tuple.
    """
    print("Initializing AudioTokenizer and model...")
    tokenizer = AudioTokenizerV2(
        MODEL_PATH, WAV_TOKENIZER_MODEL_PATH, WAV_TOKENIZER_CONFIG_PATH
    )

    print("Loading YarnGPT model...")
    language_model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH, torch_dtype="auto"
    )
    # Keep the model on the same device as the tokenizer.
    language_model = language_model.to(tokenizer.device)
    return language_model, tokenizer
# Load the model and audio tokenizer once, at import time. Both are bound as
# module-level globals that generate_speech() reads on every request.
print("Starting model initialization...")
model, audio_tokenizer = initialize_model()
print("Model initialization complete!")
# Choices offered by the "Voice" and "Language" dropdowns in the UI.
# NOTE(review): confirm these names match the speakers and languages that
# AudioTokenizerV2.create_prompt actually accepts.
VOICES = [
    "idera",
    "jude",
    "kemi",
    "tunde",
    "funmi",
]
LANGUAGES = [
    "english",
    "yoruba",
    "igbo",
    "hausa",
    "pidgin",
]
def generate_speech(text, language, voice, temperature=0.1, rep_penalty=1.1):
    """Synthesize speech for ``text`` and return it as a WAV file path.

    Args:
        text: Text to convert to speech. Empty, ``None`` or whitespace-only
            input is rejected without calling the model.
        language: Language name passed to the tokenizer's prompt builder.
        voice: Speaker name passed to the tokenizer's prompt builder.
        temperature: Value forwarded to ``model.generate`` as ``temperature``.
        rep_penalty: Value forwarded to ``model.generate`` as
            ``repetition_penalty``.

    Returns:
        A ``(audio_path, status_message)`` tuple. ``audio_path`` is ``None``
        when the input is rejected or generation fails; the status message
        then explains why.
    """
    if not text or not text.strip():
        return None, "Please enter some text to convert to speech."
    try:
        # Create prompt
        prompt = audio_tokenizer.create_prompt(text, lang=language, speaker_name=voice)
        # Tokenize prompt
        input_ids = audio_tokenizer.tokenize_prompt(prompt)
        # Generate output
        # NOTE(review): in transformers, temperature only has an effect when
        # sampling is enabled; confirm the model's generation config sets
        # do_sample, otherwise the slider is ignored.
        output = model.generate(
            input_ids=input_ids,
            temperature=temperature,
            repetition_penalty=rep_penalty,
            max_length=4000,
        )
        # Convert to audio
        codes = audio_tokenizer.get_codes(output)
        audio = audio_tokenizer.get_audio(codes)
        # Write each result to its own file: a single shared "output.wav"
        # would let concurrent requests overwrite each other's audio. The file
        # is left on disk because the returned path is read after this
        # function returns.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            temp_audio_path = tmp_file.name
        try:
            torchaudio.save(temp_audio_path, audio, sample_rate=24000)
        except Exception:
            # Do not leave an empty placeholder file behind on failure.
            os.remove(temp_audio_path)
            raise
        # Only show an ellipsis when the preview was actually truncated.
        preview = text if len(text) <= 50 else f"{text[:50]}..."
        return temp_audio_path, f"Successfully generated speech for: {preview}"
    except Exception as e:
        # UI boundary: report the error to the user, but keep the full
        # traceback in the server log so failures remain debuggable.
        traceback.print_exc()
        return None, f"Error generating speech: {str(e)}"
# Sample rows for the examples widget; each row is [text, language, voice].
# NOTE(review): confirm every language/voice pair here is one the tokenizer
# supports.
examples = [
    [
        "Hello, my name is Claude. I am an AI assistant created by Anthropic.",
        "english",
        "idera",
    ],
    [
        "Báwo ni o ṣe wà? Mo ń gbádùn ọjọ́ mi.",
        "yoruba",
        "kemi",
    ],
    [
        "I don dey come house now, make you prepare food.",
        "pidgin",
        "jude",
    ],
]
# Build the Gradio UI: a single "Basic TTS" tab with the input controls on the
# left, the generated audio and status on the right, clickable examples, and
# an "About" section underneath.
with gr.Blocks(title="YarnGPT - Nigerian Accented Text-to-Speech") as demo:
    gr.Markdown("# YarnGPT - Nigerian Accented Text-to-Speech")
    gr.Markdown("Generate speech with Nigerian accents using YarnGPT model.")

    with gr.Tab("Basic TTS"):
        with gr.Row():
            # Input controls.
            with gr.Column():
                text_input = gr.Textbox(
                    label="Text to convert to speech",
                    placeholder="Enter text here...",
                    lines=5,
                )
                language = gr.Dropdown(label="Language", choices=LANGUAGES, value="english")
                voice = gr.Dropdown(label="Voice", choices=VOICES, value="idera")
                temperature = gr.Slider(
                    label="Temperature", minimum=0.1, maximum=1.0, value=0.1, step=0.1
                )
                rep_penalty = gr.Slider(
                    label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.1, step=0.1
                )
                generate_btn = gr.Button("Generate Speech")

            # Results.
            with gr.Column():
                audio_output = gr.Audio(label="Generated Speech")
                status_output = gr.Textbox(label="Status")

        # Both the examples widget and the button write to the same outputs.
        tts_outputs = [audio_output, status_output]

        # Examples fill only text, language and voice; the sliders keep their
        # current values.
        gr.Examples(
            examples=examples,
            inputs=[text_input, language, voice],
            outputs=tts_outputs,
            fn=generate_speech,
            cache_examples=False,
        )

        generate_btn.click(
            generate_speech,
            inputs=[text_input, language, voice, temperature, rep_penalty],
            outputs=tts_outputs,
        )

    gr.Markdown("""
## About YarnGPT
YarnGPT is a text-to-speech model with Nigerian accents. It supports multiple languages and voices.
### Credits
- Model by [saheedniyi](https://huggingface.co/saheedniyi/YarnGPT2b)
- [Original Repository](https://github.com/saheedniyi02/yarngpt)
""")

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()