import os
import sys

import gradio as gr
import torch
import torchaudio
import uroman
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from outetts.wav_tokenizer.decoder import WavTokenizer


# Clone the YarnGPT repository on first run; it provides the AudioTokenizerV2 class
# imported below.
if not os.path.exists("yarngpt"):
    print("Cloning YarnGPT repository...")
    os.system("git clone https://github.com/saheedniyi02/yarngpt.git")

sys.path.append("yarngpt")

from yarngpt.audiotokenizer import AudioTokenizerV2


# YarnGPT model on the Hugging Face Hub and local paths for the WavTokenizer files.
MODEL_PATH = "saheedniyi/YarnGPT2b"
WAV_TOKENIZER_CONFIG_PATH = "wavtokenizer_config.yaml"
WAV_TOKENIZER_MODEL_PATH = "wavtokenizer_model.ckpt"


# Download the WavTokenizer config and checkpoint if they are not already present.
if not os.path.exists(WAV_TOKENIZER_CONFIG_PATH):
    print("Downloading WavTokenizer config...")
    os.system(f"wget -O {WAV_TOKENIZER_CONFIG_PATH} https://huggingface.co/novateur/WavTokenizer-medium-speech-75token/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml")

if not os.path.exists(WAV_TOKENIZER_MODEL_PATH):
    print("Downloading WavTokenizer model...")
    os.system(f"wget -O {WAV_TOKENIZER_MODEL_PATH} https://huggingface.co/novateur/WavTokenizer-large-speech-75token/resolve/main/wavtokenizer_large_speech_320_24k.ckpt")

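# The wget calls above assume wget is available on PATH. A rough alternative sketch,
# assuming huggingface_hub is installed (not part of the original script), would be:
#
#   from huggingface_hub import hf_hub_download
#   config_path = hf_hub_download(
#       repo_id="novateur/WavTokenizer-medium-speech-75token",
#       filename="wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml",
#   )
#   ckpt_path = hf_hub_download(
#       repo_id="novateur/WavTokenizer-large-speech-75token",
#       filename="wavtokenizer_large_speech_320_24k.ckpt",
#   )
#
# hf_hub_download returns paths inside the local Hugging Face cache, so the two path
# constants above would then need to point at those returned paths instead.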

def initialize_model():
    """Load the YarnGPT audio tokenizer and the causal LM that generates audio codes."""
    print("Initializing AudioTokenizer and model...")
    audio_tokenizer = AudioTokenizerV2(
        MODEL_PATH,
        WAV_TOKENIZER_MODEL_PATH,
        WAV_TOKENIZER_CONFIG_PATH
    )

    print("Loading YarnGPT model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype="auto"
    ).to(audio_tokenizer.device)

    return model, audio_tokenizer


print("Starting model initialization...")
model, audio_tokenizer = initialize_model()
print("Model initialization complete!")

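# Note: the model and audio tokenizer are created once at startup, before the Gradio UI
# below is defined, so the first request does not pay the model-loading cost.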

VOICES = ["idera", "jude", "kemi", "tunde", "funmi"]
LANGUAGES = ["english", "yoruba", "igbo", "hausa", "pidgin"]

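# The upstream YarnGPT repository may document additional voices; any speaker name the
# audio tokenizer accepts can be added to VOICES without other code changes.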

def generate_speech(text, language, voice, temperature=0.1, rep_penalty=1.1):
    if not text:
        return None, "Please enter some text to convert to speech."

    try:
        # Build a speaker- and language-conditioned prompt and tokenize it.
        prompt = audio_tokenizer.create_prompt(text, lang=language, speaker_name=voice)
        input_ids = audio_tokenizer.tokenize_prompt(prompt)

        # Generate audio codes with the language model. Sampling is enabled so the
        # temperature slider actually takes effect (greedy decoding ignores temperature).
        output = model.generate(
            input_ids=input_ids,
            do_sample=True,
            temperature=temperature,
            repetition_penalty=rep_penalty,
            max_length=4000,
        )

        # Convert the generated codes back into a waveform.
        codes = audio_tokenizer.get_codes(output)
        audio = audio_tokenizer.get_audio(codes)

        # Save to a fixed path in the working directory; concurrent requests will
        # overwrite each other's output.
        temp_audio_path = "output.wav"
        torchaudio.save(temp_audio_path, audio, sample_rate=24000)

        return temp_audio_path, f"Successfully generated speech for: {text[:50]}..."

    except Exception as e:
        return None, f"Error generating speech: {str(e)}"

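# generate_speech can also be called directly, outside the Gradio UI, e.g. (illustrative
# sketch; it writes output.wav to the working directory):
#
#   wav_path, status = generate_speech(
#       "Good afternoon, how is your day going?", "english", "idera"
#   )
#   print(status)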

examples = [
    ["Hello, my name is Claude. I am an AI assistant created by Anthropic.", "english", "idera"],
    ["Báwo ni o ṣe wà? Mo ń gbádùn ọjọ́ mi.", "yoruba", "kemi"],
    ["I don dey come house now, make you prepare food.", "pidgin", "jude"]
]


with gr.Blocks(title="YarnGPT - Nigerian Accented Text-to-Speech") as demo:
    gr.Markdown("# YarnGPT - Nigerian Accented Text-to-Speech")
    gr.Markdown("Generate speech with Nigerian accents using the YarnGPT model.")

    with gr.Tab("Basic TTS"):
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Text to convert to speech",
                    placeholder="Enter text here...",
                    lines=5
                )
                language = gr.Dropdown(
                    label="Language",
                    choices=LANGUAGES,
                    value="english"
                )
                voice = gr.Dropdown(
                    label="Voice",
                    choices=VOICES,
                    value="idera"
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.1,
                    step=0.1
                )
                rep_penalty = gr.Slider(
                    label="Repetition Penalty",
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.1
                )
                generate_btn = gr.Button("Generate Speech")

            with gr.Column():
                audio_output = gr.Audio(label="Generated Speech")
                status_output = gr.Textbox(label="Status")

        gr.Examples(
            examples=examples,
            inputs=[text_input, language, voice],
            outputs=[audio_output, status_output],
            fn=generate_speech,
            cache_examples=False
        )

        generate_btn.click(
            generate_speech,
            inputs=[text_input, language, voice, temperature, rep_penalty],
            outputs=[audio_output, status_output]
        )

    gr.Markdown("""
## About YarnGPT
YarnGPT is a text-to-speech model that produces speech with Nigerian accents. It supports multiple languages and voices.

### Credits
- Model by [saheedniyi](https://huggingface.co/saheedniyi/YarnGPT2b)
- [Original Repository](https://github.com/saheedniyi02/yarngpt)
    """)


if __name__ == "__main__":
    demo.launch()
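    # For a temporary public link (e.g. when running on Colab), Gradio also supports
    # demo.launch(share=True); passing server_name="0.0.0.0" makes the app reachable
    # from other machines on the local network.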