import os
import torch
import random
import numpy as np
import gradio as gr
import librosa
import spaces
from accelerate import Accelerator
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import DDIMScheduler
from src.models.conditioners import MaskDiT
from src.modules.autoencoder_wrapper import Autoencoder
from src.inference import inference
from src.utils import load_yaml_with_includes


# Load model and configs
def load_models(config_name, ckpt_path, vae_path, device):
    """Build every model component the demo needs.

    Args:
        config_name: Path of the YAML config (read with ``load_yaml_with_includes``).
        ckpt_path: Checkpoint of the diffusion backbone; its ``'model'`` entry holds
            the state dict.
        vae_path: Checkpoint of the audio autoencoder (codec).
        device: Torch device string the models are moved to.

    Returns:
        ``(autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params)``,
        where ``params`` is the parsed config dict.
    """
    params = load_yaml_with_includes(config_name)

    # Load codec model
    autoencoder = Autoencoder(ckpt_path=vae_path,
                              model_type=params['autoencoder']['name'],
                              quantization_first=params['autoencoder']['q_first']).to(device)
    autoencoder.eval()

    # Load text encoder
    tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model'])
    text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model']).to(device)
    text_encoder.eval()

    # Load the main diffusion backbone (MaskDiT). It is kept under the name `unet`
    # because the rest of the script refers to it that way.
    unet = MaskDiT(**params['model']).to(device)
    # torch.load unpickles the checkpoint: only ever point this at trusted files.
    unet.load_state_dict(torch.load(ckpt_path, map_location='cpu')['model'])
    unet.eval()

    # NOTE(review): fp16 mixed precision is requested even when device == 'cpu';
    # confirm the installed accelerate version supports that combination.
    accelerator = Accelerator(mixed_precision="fp16")
    unet = accelerator.prepare(unet)

    # Load noise scheduler
    noise_scheduler = DDIMScheduler(**params['diff'])

    # Dummy add_noise call whose result is discarded. Presumably a warm-up so the
    # scheduler's internal buffers end up on `device` before the first request.
    latents = torch.randn((1, 128, 128), device=device)
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=device)
    _ = noise_scheduler.add_noise(latents, noise, timesteps)

    return autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params


MAX_SEED = np.iinfo(np.int32).max

# Model and config paths
config_name = 'ckpts/ezaudio-xl.yml'
ckpt_path = 'ckpts/s3/ezaudio_s3_xl.pt'
vae_path = 'ckpts/vae/1m.pt'
# save_path = 'output/'
# os.makedirs(save_path, exist_ok=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params = load_models(
    config_name, ckpt_path, vae_path, device)


@spaces.GPU
def generate_audio(text, length, guidance_scale, guidance_rescale, ddim_steps, eta, random_seed, randomize_seed):
    """Generate audio from a text prompt.

    Args:
        text: Text prompt. An empty prompt disables guidance (``guidance_scale=None``).
        length: Requested duration in seconds; multiplied by
            ``params['autoencoder']['latent_sr']`` to get the latent length.
        guidance_scale, guidance_rescale, ddim_steps, eta: Forwarded to ``inference``.
        random_seed: Seed forwarded to ``inference``.
        randomize_seed: If true, ``random_seed`` is replaced by a random value in
            ``[0, MAX_SEED]``.

    Returns:
        ``(sample_rate, waveform)`` as consumed by ``gr.Audio(type="numpy")``.
    """
    neg_text = None
    length = length * params['autoencoder']['latent_sr']
    # No reference latent and no mask: plain text-to-audio generation.
    gt, gt_mask = None, None

    if not text:
        guidance_scale = None
        print('empty input')

    if randomize_seed:
        random_seed = random.randint(0, MAX_SEED)

    pred = inference(autoencoder, unet, gt, gt_mask, tokenizer, text_encoder, params, noise_scheduler,
                     text, neg_text, length, guidance_scale, guidance_rescale, ddim_steps, eta, random_seed, device)
    # NOTE(review): assumes inference returns a (1, 1, num_samples) tensor.
    pred = pred.cpu().numpy().squeeze(0).squeeze(0)
    # output_file = f"{save_path}/{text}.wav"
    # sf.write(output_file, pred, samplerate=params['autoencoder']['sr'])
    return params['autoencoder']['sr'], pred


@spaces.GPU
def editing_audio(text, boundary, gt_file, mask_start, mask_length, guidance_scale, guidance_rescale,
                  ddim_steps, eta, random_seed, randomize_seed):
    """Edit (inpaint) or extend (outpaint) a segment of an uploaded audio clip.

    The segment ``[mask_start, mask_start + mask_length)`` seconds is regenerated
    from ``text``. If the segment runs past the end of the clip, the clip is first
    zero-padded (outpainting). Only a window around the segment is encoded and
    passed to ``inference``. That window extends the segment by up to ``boundary``
    seconds of context on each side. The decoded window is then spliced back into
    the full clip.

    Args:
        text: Edit prompt. An empty prompt disables guidance (``guidance_scale=None``).
        boundary: Seconds of context on each side of the segment; capped at half
            the segment length.
        gt_file: Path of the audio file to edit (``gr.Audio(type="filepath")``).
        mask_start: Edit start in seconds. Negative values are clamped to 0.
        mask_length: Edit duration in seconds.
        guidance_scale, guidance_rescale, ddim_steps, eta: Forwarded to ``inference``.
        random_seed: Seed forwarded to ``inference``.
        randomize_seed: If true, ``random_seed`` is replaced by a random value in
            ``[0, MAX_SEED]``.

    Returns:
        ``(sample_rate, waveform)`` as consumed by ``gr.Audio(type="numpy")``.
        The waveform is the peak-normalised input (padded if outpainting) with
        the edited window replaced.

    Raises:
        gr.Error: If no audio file was provided or the file contains no samples.
    """
    neg_text = None
    sr = params['autoencoder']['sr']

    if not text:
        guidance_scale = None
        print('empty input')

    if gt_file is None:
        raise gr.Error("Please upload an audio file to edit.")

    # A negative start would later become a negative slice index into the latent
    # mask (wrapping around to the end of the window), so clamp it to 0.
    mask_start = max(mask_start, 0)
    mask_end = mask_start + mask_length

    # Load the ground-truth audio at the codec sample rate and peak-normalise it
    # (the epsilon avoids dividing by zero on silent input).
    gt, _ = librosa.load(gt_file, sr=sr)
    if gt.size == 0:
        raise gr.Error("The uploaded audio file is empty.")
    gt = gt / (np.max(np.abs(gt)) + 1e-9)

    audio_length = len(gt) / sr
    mask_start = min(mask_start, audio_length)
    if mask_end > audio_length:
        # Outpainting mode: zero-pad so the edited segment lies inside the clip.
        padding = round((mask_end - audio_length) * sr)
        gt = np.pad(gt, (0, padding), 'constant')
        audio_length = len(gt) / sr

    output_audio = gt.copy()

    boundary = min((mask_end - mask_start) / 2, boundary)

    # Processing window in seconds: the edited segment plus context on each side.
    start_idx = max(mask_start - boundary, 0)
    end_idx = min(mask_end + boundary, audio_length)
    # Convert to sample indices once. Both the cut-out and the write-back use
    # these same indices, so the two slices always have matching lengths.
    start_sample = round(start_idx * sr)
    end_sample = round(end_idx * sr)

    # Mask position relative to the start of the window.
    mask_start -= start_idx
    mask_end -= start_idx

    window = torch.tensor(gt[start_sample:end_sample]).unsqueeze(0).unsqueeze(1).to(device)

    # Encode the audio to latent space
    gt_latent = autoencoder(audio=window)
    B, D, L = gt_latent.shape

    # Latent frames covering the edited segment are set to True. Presumably these
    # are the frames `inference` regenerates; the rest serve as context.
    latent_sr = params['autoencoder']['latent_sr']
    gt_mask = torch.zeros((B, D, L), dtype=torch.bool, device=device)
    gt_mask[:, :, round(mask_start * latent_sr):round(mask_end * latent_sr)] = True

    if randomize_seed:
        random_seed = random.randint(0, MAX_SEED)

    # Perform inference to get the edited audio for the window
    pred = inference(autoencoder, unet, gt_latent, gt_mask, tokenizer, text_encoder, params, noise_scheduler,
                     text, neg_text, L, guidance_scale, guidance_rescale, ddim_steps, eta, random_seed, device)
    # NOTE(review): assumes inference returns a (1, 1, num_samples) tensor.
    pred = pred.cpu().numpy().squeeze(0).squeeze(0)

    # Splice the decoded window back into the full clip. Write at most as many
    # samples as were cut out. If the decoder returned fewer, the tail of the
    # window keeps the normalised input audio.
    n = min(end_sample - start_sample, len(pred))
    output_audio[start_sample:start_sample + n] = pred[:n]

    return sr, output_audio


# Examples (if needed for the demo)
examples = [
    "a dog barking in the distance",
    "the sound of rain falling softly",
    "light guitar music is playing",
]

# Editing examples: [prompt, edit start (s), edit length (s)]
examples_edit = [
    ["a dog barking in the background", 2, 3],
    ["kids playing and laughing nearby", 5, 4],
    ["rock music playing on the street", 8, 6]
]

# CSS styling (optional)
css = """
#col-container {
    margin: 0 auto;
    max-width: 1280px;
}
"""

# Gradio Blocks layout
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # EzAudio: High-quality Text-to-Audio Generator
        Generate and edit audio from text using a diffusion transformer. Adjust advanced settings for more control.
        """)

        # Tabs for Generate and Edit
        with gr.Tab("Audio Generation"):
            # Basic Input: Text prompt
            with gr.Row():
                text_input = gr.Textbox(
                    label="Text Prompt",
                    show_label=True,
                    max_lines=2,
                    placeholder="Enter your prompt",
                    container=True,
                    value="a dog barking in the distance",
                    scale=4
                )
                # Run button
                run_button = gr.Button("Generate", scale=1)

            # Output Component
            result = gr.Audio(label="Generate", type="numpy")

            # Advanced settings in an Accordion
            with gr.Accordion("Advanced Settings", open=False):
                # Audio Length
                audio_length = gr.Slider(minimum=1, maximum=10, step=1, value=10, label="Audio Length (in seconds)")
                guidance_scale = gr.Slider(minimum=1.0, maximum=10, step=0.1, value=5.0, label="Guidance Scale")
                guidance_rescale = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0.75, label="Guidance Rescale")
                ddim_steps = gr.Slider(minimum=25, maximum=200, step=5, value=50, label="DDIM Steps")
                eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="Eta")
                seed = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Seed")
                randomize_seed = gr.Checkbox(label="Randomize Seed (Disable Seed)", value=True)

            # Examples block
            gr.Examples(
                examples=examples,
                inputs=[text_input]
            )

            # Both the button and pressing Enter in the prompt trigger generation
            # with the same inputs; the order matches generate_audio's parameters.
            gen_inputs = [text_input, audio_length, guidance_scale, guidance_rescale, ddim_steps, eta, seed,
                          randomize_seed]
            run_button.click(fn=generate_audio, inputs=gen_inputs, outputs=[result])
            text_input.submit(fn=generate_audio, inputs=gen_inputs, outputs=[result])

        with gr.Tab("Audio Editing and Inpainting"):
            # Input: Upload audio file
            with gr.Row():
                gt_file_input = gr.Audio(label="Upload Audio to Edit", type="filepath", value="edit_example.wav")

            # Text prompt for editing
            text_edit_input = gr.Textbox(
                label="Edit Prompt",
                show_label=True,
                max_lines=2,
                placeholder="Describe the edit you want",
                container=True,
                value="a dog barking in the background",
                scale=4
            )

            # Mask settings
            mask_start = gr.Number(label="Edit Start (seconds)", value=2.0)
            mask_length = gr.Slider(minimum=0.5, maximum=10, step=0.5, value=3, label="Edit Length (seconds)")
            edit_explanation = gr.Markdown(value="**Edit Start**: The time when the edit begins. \n\n**Edit Length**: The duration of the segment to be edited. \n\n**Outpainting**: If the edit extends beyond the audio's length, Outpainting Mode will automatically activate.")

            # Run button for editing
            edit_button = gr.Button("Generate", scale=1)

            # Output Component for edited audio
            edited_result = gr.Audio(label="Edited Audio", type="numpy")

            # Advanced settings in an Accordion
            with gr.Accordion("Advanced Settings", open=False):
                # Seconds of unedited context kept on each side of the edited segment
                edit_boundary = gr.Slider(minimum=0.5, maximum=4, step=0.5, value=2, label="Edit Boundary (in seconds)")
                edit_guidance_scale = gr.Slider(minimum=1.0, maximum=10, step=0.5, value=3.0, label="Guidance Scale")
                edit_guidance_rescale = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0.0, label="Guidance Rescale")
                edit_ddim_steps = gr.Slider(minimum=25, maximum=200, step=5, value=50, label="DDIM Steps")
                edit_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="Eta")
                edit_seed = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Seed")
                edit_randomize_seed = gr.Checkbox(label="Randomize Seed (Disable Seed)", value=True)

            # Examples block
            gr.Examples(
                examples=examples_edit,
                inputs=[text_edit_input, mask_start, mask_length]
            )

            # Both the button and pressing Enter in the prompt trigger editing with
            # the same inputs; the order matches editing_audio's parameters.
            edit_inputs = [
                text_edit_input, edit_boundary, gt_file_input, mask_start, mask_length,
                edit_guidance_scale, edit_guidance_rescale, edit_ddim_steps, edit_eta,
                edit_seed, edit_randomize_seed
            ]
            edit_button.click(fn=editing_audio, inputs=edit_inputs, outputs=[edited_result])
            text_edit_input.submit(fn=editing_audio, inputs=edit_inputs, outputs=[edited_result])

# Launch the Gradio demo
demo.launch()