# vtubers-speak / app.py
import spaces
import gradio as gr
import torch
import soundfile as sf
from transformers import AutoTokenizer, AutoModelForCausalLM
from xcodec2.modeling_xcodec2 import XCodec2Model
import tempfile
import torchaudio
import os
device = "cuda" if torch.cuda.is_available() else "cpu"
####################
# Global model loading
####################
model_name = "fakeavatar/vtubers-4"
print("Loading tokenizer & model ...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
if os.name != "nt" and torch.cuda.is_available():  # 'nt' is Windows, where torch.compile is not supported
    model = torch.compile(model)
    torch.backends.cudnn.benchmark = True  # let cuDNN autotune kernels (pays off when input shapes repeat)
    torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 matmuls on Ampere and newer GPUs
model.eval().to(device)
print("Loading XCodec2Model ...")
codec_model_path = "HKUSTAudio/xcodec2"
codec_model = XCodec2Model.from_pretrained(codec_model_path)
codec_model.eval().to(device)
print("Models loaded.")
####################
# Inference function
####################
def extract_speech_ids(speech_tokens_str):
    """
    Restore the integer 23456 from a token like <|s_23456|>.
    """
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith("<|s_") and token_str.endswith("|>"):
            num_str = token_str[4:-2]
            speech_ids.append(int(num_str))
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids
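# For example, extract_speech_ids(["<|s_23456|>", "<|s_7|>"]) returns [23456, 7];
# any malformed token is logged and skipped.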
@spaces.GPU
def text2speech(input_text, num_samples):
    """
    Generate `num_samples` speech clips for the input text and return a list of
    audio file paths, padded to the fixed number of Gradio output slots.
    """
    results = []
    with torch.no_grad():
        # Reference clip used to seed the voice
        audio, sr = torchaudio.load("./sample.wav")
        # XCodec2 works on 16 kHz mono audio; downmix/resample if the clip differs
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)
        if sr != 16000:
            audio = torchaudio.functional.resample(audio, sr, 16000)
        vq_code = codec_model.encode_code(audio.to(device))  # use `device`, not a hard-coded "cuda"
        vq_strings = [f"<|s_{i}|>" for i in vq_code.cpu()[0][0].tolist()]
        vq_str = "".join(vq_strings)
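        # Contextual seeding: these reference tokens are pre-filled at the start of
        # the assistant turn below, so the LM continues generating speech tokens in
        # the reference voice instead of sampling an unconditioned one.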
        for _ in range(num_samples):
            # Build the LLaSA-style chat prompt. "Hey, wake up!" appears to be the
            # transcript of sample.wav, so the text lines up with the seeded tokens.
            chat = [
                {"role": "system", "content": "the speaker is yui. She has a mild chinese accent and is speaking english. The voice is flowing and nasal, high pitched with a measured speed. The sound is recorded in a fairly clean environment and carries a medium happy emotion."},
                {"role": "user", "content": "Convert the text to speech:" + f"<|TEXT_UNDERSTANDING_START|>Hey, wake up! {input_text}<|TEXT_UNDERSTANDING_END|>"},
                {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + vq_str},
            ]
            # LLaSA-style TTS drives generation through the chat template,
            # continuing the pre-filled assistant message
            input_ids = tokenizer.apply_chat_template(
                chat,
                tokenize=True,
                return_tensors='pt',
                continue_final_message=True
            ).to(device)
            # Stop generation at the speech end token
            speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
            # Sample speech tokens from the LM
            outputs = model.generate(
                input_ids,
                max_length=2048,  # the model was trained with a max length of 2048
                eos_token_id=speech_end_id,
                do_sample=True,
                top_p=0.95,  # nucleus sampling; controls diversity
                temperature=0.9,  # controls randomness of sampling
                repetition_penalty=1.2,
            )
            # Keep only the newly generated tokens (drop the prompt and the end token)
            generated_ids = outputs[0][input_ids.shape[1]:-1]
            if generated_ids.shape[0] < 2:
                continue  # generation ended immediately; skip this sample
            speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            # Turn "<|s_23456|>" tokens into the integer IDs [23456, ...]
            speech_tokens_int = extract_speech_ids(speech_tokens_str)
            speech_tokens_int = torch.tensor(speech_tokens_int).to(device).unsqueeze(0).unsqueeze(0)
            # Decode the token IDs back to a waveform with XCodec2
            gen_wav = codec_model.decode_code(speech_tokens_int)  # [batch, channels, samples]
            audio_out = gen_wav[0, 0, :].cpu().numpy()
            sample_rate = 16000  # XCodec2 decodes at 16 kHz
            # Write the waveform to a temporary file and hand Gradio the path
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
                sf.write(tmpfile.name, audio_out, sample_rate)
                results.append(tmpfile.name)
    # Pad to the fixed number of Gradio output slots; None leaves a slot
    # empty if nothing was generated at all
    while len(results) < 10:
        results.append(results[-1] if results else None)
    return results
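# Quick local smoke test (hypothetical; assumes ./sample.wav exists and a GPU
# session is granted via @spaces.GPU):
#   paths = text2speech("Nice to meet you!", num_samples=2)
#   print(paths[0])  # path to the first generated clip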
####################
# Gradio Interface
####################
# Slider to control the number of audio samples to generate
num_samples_slider = gr.Slider(minimum=1, maximum=10, value=4, step=1, label="Number of Audio Samples")
demo = gr.Interface(
    fn=text2speech,
    inputs=[gr.Textbox(label="Enter text", lines=5), num_samples_slider],
    # text2speech returns file paths, so the outputs are declared as filepaths
    outputs=[gr.Audio(label=f"Generated Audio {i+1}", type="filepath") for i in range(10)],
    title="VTuber TTS",
    description="Enter a piece of English text and click to generate speech.",
)
if __name__ == "__main__":
    demo.launch(share=True)