#!/usr/bin/env python
# ruff: noqa: E402
import json
import tempfile
import os
import click
import gradio as gr
import numpy as np
import soundfile as sf
import torchaudio
from importlib.resources import files
from groq import Groq
from cached_path import cached_path
from transformers import AutoModelForCausalLM, AutoTokenizer

# Try to import spaces; if available, wrap functions for GPU support.
try:
    import spaces

    USING_SPACES = True
except ImportError:
    USING_SPACES = False


def gpu_decorator(func):
    """
    Decorator that wraps a function with GPU acceleration if running in a Spaces environment.
    """
    if USING_SPACES:
        return spaces.GPU(func)
    return func
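
# When the `spaces` package is available (i.e. on Hugging Face Spaces ZeroGPU hardware),
# `spaces.GPU` schedules the decorated call on a GPU; heavy functions such as `infer`
# below could be wrapped with @gpu_decorator to take advantage of it.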

# Local package imports
from f5_tts.model import DiT, UNetT
from f5_tts.infer.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
    infer_process,
    remove_silence_for_generated_wav,
    save_spectrogram,
)

DEFAULT_TTS_MODEL = "F5-TTS"
DEFAULT_TTS_MODEL_CFG = [
    "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
    "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
    json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
]
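# DEFAULT_TTS_MODEL_CFG lists, in order, the checkpoint path, the vocab file, and the
# DiT architecture hyperparameters (serialized as JSON); the same architecture dict is
# repeated inside load_f5tts() below.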

# Load vocoder and TTS model
vocoder = load_vocoder()


def load_f5tts(
    ckpt_path: str = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
):
    """
    Load the F5-TTS model from the given checkpoint path.
    """
    F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
    return load_model(DiT, F5TTS_model_cfg, ckpt_path)


F5TTS_ema_model = load_f5tts()
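# Both the vocoder and the F5-TTS checkpoint are loaded at import time, so the first
# startup downloads and caches the weights (via cached_path) before the UI is served.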

def generate_response(messages, apikey):
    """
    Generate a chat response using the Groq API.
    If messages is a string, wrap it as a user message.
    """
    if isinstance(messages, str):
        messages_payload = [{"role": "user", "content": messages}]
    else:
        messages_payload = messages

    client = Groq(api_key=apikey)
    chat_completion = client.chat.completions.create(
        messages=messages_payload,
        model="deepseek-r1-distill-llama-70b",
        stream=False,
    )
    if chat_completion.choices and hasattr(chat_completion.choices[0].message, "content"):
        return chat_completion.choices[0].message.content
    return ""

def process_audio_input(audio_path, text, apikey, history, conv_state):
    """
    Process audio and/or text input from the user:
      - If an audio file is provided, its transcript is obtained.
      - The conversation state and history are updated.
    The chat history is kept as a list of role/content dictionaries.
    """
    if not audio_path and not text.strip():
        return history, conv_state, ""

    if audio_path:
        # preprocess_ref_audio_text returns a tuple (audio, transcript)
        _, text = preprocess_ref_audio_text(audio_path, text)

    if not text.strip():
        return history, conv_state, ""

    # Wrap the user input in a dict.
    user_msg = {"role": "user", "content": text}
    conv_state.append(user_msg)
    history.append(user_msg)

    response = generate_response(conv_state, apikey)
    assistant_msg = {"role": "assistant", "content": response}
    conv_state.append(assistant_msg)
    history.append(assistant_msg)
    return history, conv_state, ""
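
# The three return values line up with the Gradio outputs wired up below: the chatbot
# display (history), the conversation state, and an empty string that clears the hidden
# dummy textbox.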

def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    remove_silence,
    cross_fade_duration: float = 0.15,
    nfe_step: int = 32,
    speed: float = 1,
    show_info=print,
):
    """
    Generate speech audio with F5-TTS from a reference audio/text pair and the text to synthesize.
    """
    if not ref_audio_orig:
        gr.Warning("Please provide reference audio.")
        return gr.update(), gr.update(), ref_text

    if not gen_text.strip():
        gr.Warning("Please enter text to generate.")
        return gr.update(), gr.update(), ref_text

    # Preprocess the reference audio and text.
    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)

    ema_model = F5TTS_ema_model  # Use the default F5-TTS model.
    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )

    if remove_silence:
        # Write the generated waveform to a temporary file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            temp_audio_path = f.name
        sf.write(temp_audio_path, final_wave, final_sample_rate)
        # Process the file to remove silence.
        remove_silence_for_generated_wav(temp_audio_path)
        final_wave_tensor, _ = torchaudio.load(temp_audio_path)
        final_wave = final_wave_tensor.squeeze().cpu().numpy()
        os.unlink(temp_audio_path)  # Clean up the temporary file.

    # Save the spectrogram as a temporary PNG file.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
    save_spectrogram(combined_spectrogram, spectrogram_path)

    return (final_sample_rate, final_wave), spectrogram_path, ref_text
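
# gr.Audio accepts a (sample_rate, numpy_array) tuple, so the generated waveform is
# returned in memory together with its sample rate rather than written to disk.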

with gr.Blocks() as app:
    gr.Markdown(
        """
# Voice Chat
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and optionally its transcript.
2. Enter your Groq API key.
3. Record your message through your microphone or type it.
4. The AI will respond using the reference voice.
"""
    )

    with gr.Row():
        with gr.Column():
            ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
        with gr.Column():
            with gr.Accordion("Advanced Settings", open=False):
                remove_silence_chat = gr.Checkbox(label="Remove Silences", value=True)
                ref_text_chat = gr.Textbox(
                    label="Reference Text",
                    info="Optional: Leave blank to auto-transcribe",
                    lines=2,
                )
                system_prompt_chat = gr.Textbox(
                    label="System Prompt",
                    value=(
                        "You are not an AI assistant, you are whoever the user says you are. "
                        "You must stay in character. Keep your responses concise since they will be spoken out loud."
                    ),
                    lines=2,
                )

    # Specify the chatbot type to avoid deprecation warnings.
    chatbot_interface = gr.Chatbot(label="Conversation", type="messages")

    with gr.Row():
        with gr.Column():
            audio_input_chat = gr.Microphone(label="Speak your message", type="filepath")
            audio_output_chat = gr.Audio(autoplay=True)
        with gr.Column():
            groq_apikey = gr.Textbox(label="Your Groq API Key")
            text_input_chat = gr.Textbox(label="Type your message", lines=1)
            send_btn_chat = gr.Button("Send Message")
            clear_btn_chat = gr.Button("Clear Conversation")

    # Initialize the conversation state with the system prompt.
    conversation_state = gr.State(
        value=[
            {
                "role": "system",
                "content": (
                    "You are not an AI assistant, you are whoever the user says you are. "
                    "You must stay in character. Keep your responses concise since they will be spoken out loud."
                ),
            }
        ]
    )
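    # gr.State is kept per browser session, so each visitor gets an independent
    # conversation history seeded with the system prompt above.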

    # Create a dummy hidden output to capture the extra (unused) output.
    dummy_output = gr.Textbox(visible=False)

    def generate_audio_response(history, ref_audio, ref_text, remove_silence):
        """
        Generate an audio response from the last AI message in the conversation.
        Returns the generated audio, the (possibly updated) reference text, and the unchanged chat history.
        """
        if not history or not ref_audio:
            return None, ref_text, history

        # Find the last assistant message in the history.
        last_assistant = None
        for message in reversed(history):
            if message.get("role") == "assistant":
                last_assistant = message
                break

        if last_assistant is None or not last_assistant.get("content", "").strip():
            return None, ref_text, history

        audio_result, _, ref_text_out = infer(
            ref_audio,
            ref_text,
            last_assistant["content"],
            remove_silence,
            cross_fade_duration=0.15,
            speed=1.0,
            show_info=print,
        )
        return audio_result, ref_text_out, history
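
    # `history` is passed through unchanged; it is returned only because the chatbot
    # display is listed among the outputs of this callback below.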

    def clear_conversation():
        """
        Clear the chat conversation and reset the conversation state.
        """
        initial_state = [
            {
                "role": "system",
                "content": (
                    "You are not an AI assistant, you are whoever the user says you are. "
                    "You must stay in character. Keep your responses concise since they will be spoken out loud."
                ),
            }
        ]
        return [], initial_state

    def update_system_prompt(new_prompt):
        """
        Update the system prompt and reset the conversation.
        """
        initial_state = [{"role": "system", "content": new_prompt}]
        return [], initial_state
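
    # Both reset helpers return (chatbot messages, conversation state): the chatbot is
    # emptied while the state keeps only the (possibly updated) system prompt.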

    # Set up callbacks so that when recording stops or text is submitted, the processing chain is run.
    audio_input_chat.stop_recording(
        process_audio_input,
        inputs=[audio_input_chat, text_input_chat, groq_apikey, chatbot_interface, conversation_state],
        outputs=[chatbot_interface, conversation_state, dummy_output],
    ).then(
        generate_audio_response,
        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
        outputs=[audio_output_chat, ref_text_chat, chatbot_interface],
    ).then(lambda: None, None, audio_input_chat)

    text_input_chat.submit(
        process_audio_input,
        inputs=[audio_input_chat, text_input_chat, groq_apikey, chatbot_interface, conversation_state],
        outputs=[chatbot_interface, conversation_state, dummy_output],
    ).then(
        generate_audio_response,
        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
        outputs=[audio_output_chat, ref_text_chat, chatbot_interface],
    ).then(lambda: None, None, text_input_chat)

    send_btn_chat.click(
        process_audio_input,
        inputs=[audio_input_chat, text_input_chat, groq_apikey, chatbot_interface, conversation_state],
        outputs=[chatbot_interface, conversation_state, dummy_output],
    ).then(
        generate_audio_response,
        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
        outputs=[audio_output_chat, ref_text_chat, chatbot_interface],
    ).then(lambda: None, None, text_input_chat)

    clear_btn_chat.click(clear_conversation, outputs=[chatbot_interface, conversation_state])

    system_prompt_chat.change(
        update_system_prompt,
        inputs=system_prompt_chat,
        outputs=[chatbot_interface, conversation_state],
    )
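
    # Each .then() step runs after the previous handler finishes, so every send path is:
    # route the user input -> synthesize the assistant reply -> clear the microphone or
    # text input.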

# CLI options for the launcher; click parses sys.argv when main() is invoked below.
@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@click.option("--share", "-s", default=False, is_flag=True, help="Share the app via a Gradio share link")
@click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
@click.option("--root_path", "-r", default=None, help="Root path when served behind a reverse proxy")
def main(port, host, share, api, root_path):
    """
    Launch the Gradio app.
    """
    app.queue(api_open=api).launch(
        server_name=host,
        server_port=port,
        share=share,
        show_api=api,
        root_path=root_path,
    )
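
# Example invocation (using the click options defined above):
#   python app.py --port 7860 --host 0.0.0.0 --share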

if __name__ == "__main__":
    main()