Spaces:
Running
Running
Now the user doesn't have to wait for the queue to know if their input was too long :)
3271f83
verified
import gradio as gr
from outetts.v0_1.interface import InterfaceHF
import torch

# True when no CUDA device is usable; CPU-only hosts get a short-input limit
# (enforced in generate_tts and on the Textbox below).
is_cpu = not torch.cuda.is_available()

# Load the OuteTTS 0.1 (350M) model once at import time; every request made
# through generate_tts reuses this single interface object.
interface = InterfaceHF("OuteAI/OuteTTS-0.1-350M")
# Define a function to generate and save TTS output from input text.
def generate_tts(text, temperature=0.1, repetition_penalty=1.1, max_length=4096):
    """Synthesize ``text`` with the OuteTTS model and save it as ``output.wav``.

    Args:
        text: Text to convert to speech. Must contain non-whitespace
            characters; when running on CPU it may be at most
            ``max_characters`` (30) characters long.
        temperature: Sampling temperature passed through to
            ``interface.generate``.
        repetition_penalty: Repetition penalty passed through to
            ``interface.generate``.
        max_length: Generation length cap passed through to
            ``interface.generate`` (coerced to ``int``).

    Returns:
        The path ``"output.wav"``, which the ``gr.Audio(type="filepath")``
        output component loads.

    Raises:
        gr.Error: If ``text`` is empty/blank, or exceeds the character limit
            while running on CPU. Gradio shows the message to the user.
    """
    # Character cap, enforced only on CPU to avoid long processing times.
    # Keep in sync with the Textbox max_length used for CPU in the UI.
    max_characters = 30  # adjust as needed

    # Reject empty/blank input up front instead of running the model on nothing.
    if not text or not text.strip():
        raise gr.Error("Please enter some text to convert to speech.")

    # Check if input text exceeds character limit when on CPU
    if is_cpu and len(text) > max_characters:
        raise gr.Error(
            f"Text input is too long! Please limit to {max_characters} characters.\n"
            "This limit is in place to prevent long processing times as this "
            "interface is running on a free CPU tier."
        )

    # Log user input and parameters in the terminal
    print(f"User entered text: {text}")
    print(f"Temperature set to: {temperature}")
    print(f"Repetition Penalty set to: {repetition_penalty}")
    print(f"Max Length set to: {max_length}")

    # Generate TTS output
    output = interface.generate(
        text=text,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        # NOTE(review): `max_lenght` (sic) is presumably the keyword spelled this
        # way by the outetts v0.1 `generate` API — confirm before renaming it.
        # int() guards against the slider delivering a float such as 4096.0.
        max_lenght=int(max_length),
    )

    # Save the output audio to a file.
    # NOTE(review): every request writes the same fixed path; this is only safe
    # while requests are processed one at a time — confirm the queue's
    # concurrency before raising it.
    output_path = "output.wav"
    output.save(output_path)
    print("Audio generated and saved as output.wav")
    return output_path
# Build the Gradio UI: a text box, three generation-parameter sliders, a
# button, and an audio player, wired so the button runs generate_tts.
with gr.Blocks() as demo:
    # Debug-logging callbacks for UI interactions. Their names are kept as-is:
    # Gradio may use a listener's function name as its default API endpoint name.
    def on_text_input(text):
        print(f"User typed text: {text}")

    def on_temperature_change(val):
        print(f"Temperature slider adjusted to: {val}")

    def on_repetition_penalty_change(val):
        print(f"Repetition Penalty slider adjusted to: {val}")

    def on_max_length_change(val):
        print(f"Max Length slider adjusted to: {val}")

    # On CPU the box itself caps input at 30 characters (mirroring the check in
    # generate_tts), so users see the limit before their request is queued.
    # max_length is only passed on CPU; the GPU box keeps Gradio's default.
    textbox_options = {"lines": 2, "label": "Text"}
    if is_cpu:
        textbox_options["placeholder"] = "Enter text to convert to speech (30 character limit on CPU)"
        textbox_options["max_length"] = 30
    else:
        textbox_options["placeholder"] = "Enter text to convert to speech"
    text_input = gr.Textbox(**textbox_options)
    text_input.change(on_text_input, inputs=text_input)

    # Generation-parameter sliders; each logs its new value when moved.
    temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.1, label="Temperature")
    temperature_slider.change(on_temperature_change, inputs=temperature_slider)

    repetition_penalty_slider = gr.Slider(
        minimum=1.0, maximum=2.0, value=1.1, label="Repetition Penalty"
    )
    repetition_penalty_slider.change(
        on_repetition_penalty_change, inputs=repetition_penalty_slider
    )

    max_length_slider = gr.Slider(
        minimum=512, maximum=4096, value=4096, step=256, label="Max Length"
    )
    max_length_slider.change(on_max_length_change, inputs=max_length_slider)

    # The button feeds the text box and sliders (in generate_tts's parameter
    # order) into generate_tts and plays the returned file path.
    generate_button = gr.Button("Generate Speech")
    audio_output = gr.Audio(type="filepath", label="Generated Speech")
    generate_button.click(
        fn=generate_tts,
        inputs=[
            text_input,
            temperature_slider,
            repetition_penalty_slider,
            max_length_slider,
        ],
        outputs=audio_output,
    )
# Start the web server only when this file is executed as a script, so that
# importing the module (for tests or tooling) does not block on a running server.
if __name__ == "__main__":
    print("Launching Gradio interface...")
    demo.launch()