'extend_with_continuation' issue

#3
by Aviel08 - opened

Hi! Congratulations on MelodyFlow, this is a big leap forward compared with the previous models.
I'm trying to work around its main limitation, the 30-second cap, by extending generations with a windowing approach, just like in MusicGen, but I'm running into issues. After the original 30 seconds the audio quality degrades and I can't figure out why: the continuation is musically coherent, but it sounds muffled and distorted.
Is it related to how I prepare the prompt samples? A simple demo would be much appreciated.
Thanks!
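
For reference, this is roughly the MusicGen-style windowing I'm trying to reproduce, built on MusicGen's public generate_continuation (a rough sketch; the checkpoint name and parameters are placeholders, not what I actually run):

import torch
from audiocraft.models import MusicGen

# Placeholder checkpoint; any MusicGen model illustrates the idea
model = MusicGen.get_pretrained("facebook/musicgen-small")
model.set_generation_params(duration=30)

def extend_musicgen(model, audio, target_duration=60.0, overlap=5.0):
    """audio: [B, C, T] waveform already at model.sample_rate with the model's channel count."""
    sr = model.sample_rate
    while audio.shape[-1] / sr < target_duration:
        # Feed the last `overlap` seconds back in as the prompt
        prompt = audio[..., -int(overlap * sr):]
        # The returned waveform contains the prompt followed by the new continuation
        out = model.generate_continuation(prompt, sr, descriptions=None, progress=True)
        # Drop the old overlap so the re-decoded prompt isn't duplicated
        audio = torch.cat([audio[..., :-int(overlap * sr)].to(out.device), out], dim=-1)
    return audio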

Here's a simple script showing the possible issue:

import torch
import random
from audiocraft.models import MelodyFlow
from audiocraft.data.audio import audio_write
import soundfile as sf

# Use CUDA if available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Creates a dummy initial audio tensor (replace with your actual short audio clip if needed).
# Shape is [batch, channels, samples], so this is 48000 samples of stereo noise;
# its duration in seconds depends on model.sample_rate.
initial_audio = torch.randn(1, 2, 48000).to(device)

# Loads the model
model = MelodyFlow.get_pretrained("facebook/melodyflow-t24-30secs", device=device)

# Simplified extend_with_continuation function (for demonstration)
def extend_with_continuation(model, initial_audio_tensor, target_duration=32, overlap=1, duration_multiplier=1, max_iterations=2):
    output = initial_audio_tensor

    sample_rate = model.sample_rate
    current_duration = output.shape[-1] / sample_rate
    remaining_duration = target_duration - current_duration

    # Samples per latent frame; this config lookup is a guess and the default of 512
    # may not match the actual compression model.
    hop_length = model.lm.cfg.dataset.get("hop_length", 512)

    iterations = 0
    while remaining_duration > 0 and iterations < max_iterations:
        iterations += 1
        last_chunk = output[:, :, -int(overlap * sample_rate):].contiguous()

        # Encode last_chunk into latent tokens
        attributes, _ = model._prepare_tokens_and_attributes([""], last_chunk)
        prompt_tokens = model.encode_audio(last_chunk).contiguous()

        # Calculate desired next_segment length in samples
        gen_duration_samples = int(overlap * duration_multiplier * sample_rate)

        # Generate continuation tokens
        next_segment_tokens = model._generate_tokens(attributes, prompt_tokens, progress=True)

        # Sanity-check the number of generated samples (simplified for demonstration)
        generated_samples = next_segment_tokens.shape[-1] * hop_length
        if generated_samples != gen_duration_samples:
            raise ValueError(
                f"Generated samples ({generated_samples}) do not match "
                f"gen_duration_samples ({gen_duration_samples}); maybe this needs a fix "
                f"in the audiocraft library?"
            )

        next_segment_tokens = next_segment_tokens.contiguous()

        # Try to reset any decoder recurrent state between chunks
        # (these attribute names are guesses and may not exist on MelodyFlow's decoder)
        if hasattr(model.compression_model.decoder.model, 'lstm'):
            model.compression_model.decoder.model.lstm.reset_hidden_state()
        elif hasattr(model.compression_model.decoder, 'reset_state'):
            model.compression_model.decoder.reset_state()


        # Decode on CPU if CUDA causes issues
        # next_segment = model.compression_model.to('cpu').decode(next_segment_tokens.cpu(), None).to(model.device)
        next_segment = model.generate_audio(next_segment_tokens)


        # Calculate actual next_segment duration
        next_segment_duration = next_segment.shape[-1] / sample_rate

        # Drop the overlap from the existing audio, then append the new segment
        # (assuming the generated tokens include the re-decoded prompt)
        output = torch.cat([output[:, :, :-int(overlap * sample_rate)], next_segment], dim=-1)

        # Update remaining duration
        remaining_duration -= (next_segment_duration - overlap)
        current_duration = output.shape[-1] / sample_rate

        print(f"Current duration: {current_duration:.1f}s, Remaining: {remaining_duration:.1f}s")

    output = output.detach().cpu().float()[0]
    return output, sample_rate


# Call extend_with_continuation
extended_audio, extended_sample_rate = extend_with_continuation(
    model, initial_audio, target_duration=32, overlap=1, duration_multiplier=1, max_iterations=2
)

# Save the extended audio (optional); extended_audio is [channels, samples], so transpose for soundfile
sf.write("extended_audio.wav", extended_audio.numpy().T, extended_sample_rate)

print("Finished")
