'extend_with_continuation' issue #3
by Aviel08 - opened
Hi! Congratulations on MelodyFlow, this is a big leap forward compared with the previous models.
I'm trying to work around its main limitation, the 30-second cap, by extending generations with a windowing approach, just like in MusicGen, but I'm having issues: after the original 30 seconds the audio quality degrades and I can't figure out why. The continuation is coherent, but it sounds muffled and distorted.
Could it be related to how the overlap samples are fed back as the prompt? A simple demo snippet would be appreciated.
Thanks!
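For context, here's the windowing scheme I'm trying to reproduce, as a toy sketch with a stand-in for the model call (dummy_continue is hypothetical, just to illustrate the stitching):

import torch

def dummy_continue(prompt, n_new):
    # Stand-in for the model: echoes the prompt, then appends new samples
    return torch.cat([prompt, torch.randn(prompt.shape[0], prompt.shape[1], n_new)], dim=-1)

sr = 48000
overlap_samples = sr  # 1-second overlap
audio = torch.randn(1, 2, 30 * sr)  # pretend this is the first 30-second generation
for _ in range(2):
    tail = audio[:, :, -overlap_samples:]  # the tail becomes the prompt
    cont = dummy_continue(tail, 5 * sr)    # prompt + 5 s of new audio
    audio = torch.cat([audio[:, :, :-overlap_samples], cont], dim=-1)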
Here's a simple script that reproduces the possible issue:
import torch
import soundfile as sf
from audiocraft.models import MelodyFlow
# Forces CUDA if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Creates a dummy initial audio tensor (replace with your actual short audio clip if needed)
initial_audio = torch.randn(1, 2, 48000).to(device)  # 1 second of stereo audio at 48 kHz
# Loads the model
model = MelodyFlow.get_pretrained("facebook/melodyflow-t24-30secs", device=device)
# Simplified extend_with_continuation function (for demonstration)
def extend_with_continuation(model, initial_audio_tensor, target_duration=32, overlap=1, duration_multiplier=1, max_iterations=2):
    output = initial_audio_tensor
    sample_rate = model.sample_rate
    current_duration = output.shape[-1] / sample_rate
    remaining_duration = target_duration - current_duration
    hop_length = model.lm.cfg.dataset.get("hop_length", 512)
    iterations = 0
    while remaining_duration > 0 and iterations < max_iterations:
        iterations += 1
        last_chunk = output[:, :, -int(overlap * sample_rate):].contiguous()
        # Build conditioning attributes from an empty text prompt
        attributes, _ = model._prepare_tokens_and_attributes([""], last_chunk)
        # Encode last_chunk into latent tokens
        prompt_tokens = model.encode_audio(last_chunk).contiguous()
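        # (Diagnostic sketch, assuming generate_audio decodes latent tokens the
        # same way it does below: round-trip the overlap through the codec.
        # If this file already sounds muffled, the degradation comes from the
        # encode/decode of the prompt, not from the generation step itself.)
        # roundtrip = model.generate_audio(prompt_tokens)
        # sf.write("roundtrip_check.wav", roundtrip.detach().cpu().float()[0].T.numpy(), sample_rate)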
        # Calculate the desired next_segment length in samples
        gen_duration_samples = int(overlap * duration_multiplier * sample_rate)
        # Generate continuation tokens
        next_segment_tokens = model._generate_tokens(attributes, prompt_tokens, progress=True)
        # Ensure the correct number of generated samples (simplified for demonstration)
        generated_samples = next_segment_tokens.shape[-1] * hop_length
        if generated_samples != gen_duration_samples:
            raise ValueError(
                f"Generated samples ({generated_samples}) do not match "
                f"gen_duration_samples ({gen_duration_samples}). "
                "Might this need to be fixed in the audiocraft library?"
            )
        next_segment_tokens = next_segment_tokens.contiguous()
        # Reset the decoder LSTM hidden state (if accessible)
        if hasattr(model.compression_model.decoder.model, 'lstm'):
            model.compression_model.decoder.model.lstm.reset_hidden_state()
        elif hasattr(model.compression_model.decoder, 'reset_state'):
            model.compression_model.decoder.reset_state()
        # Decode on CPU if CUDA causes issues:
        # next_segment = model.compression_model.to('cpu').decode(next_segment_tokens.cpu(), None).to(model.device)
        next_segment = model.generate_audio(next_segment_tokens)
        # Calculate the actual next_segment duration
        next_segment_duration = next_segment.shape[-1] / sample_rate
        # Concatenate, dropping the overlap region from the existing output
        output = torch.cat([output[:, :, :-int(overlap * sample_rate)], next_segment], dim=-1)
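        # (Alternative seam, a sketch: replace the hard cut in the torch.cat
        # line above with a linear crossfade over the overlap region, which
        # often hides boundary clicks and spectral jumps.)
        # n = int(overlap * sample_rate)
        # fade = torch.linspace(0.0, 1.0, n, device=next_segment.device)
        # seam = output[:, :, -n:] * (1.0 - fade) + next_segment[:, :, :n] * fade
        # output = torch.cat([output[:, :, :-n], seam, next_segment[:, :, n:]], dim=-1)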
        # Update remaining duration
        remaining_duration -= (next_segment_duration - overlap)
        current_duration = output.shape[-1] / sample_rate
        print(f"Current duration: {current_duration:.1f}s, Remaining: {remaining_duration:.1f}s")
    output = output.detach().cpu().float()[0]
    return output, sample_rate
# Call extend_with_continuation
extended_audio, extended_sample_rate = extend_with_continuation(
    model, initial_audio, target_duration=32, overlap=1, duration_multiplier=1, max_iterations=2
)
# Save the extended audio (optional); soundfile expects (frames, channels)
sf.write("extended_audio.wav", extended_audio.T.numpy(), extended_sample_rate)
print("Finished")