Instructions to use Localsong/LocalSong with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use Localsong/LocalSong with Diffusers:
```bash
pip install -U diffusers transformers accelerate
```
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("Localsong/LocalSong", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import torch.optim as optim | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| from torchvision.utils import make_grid, save_image | |
| from tqdm import tqdm | |
| from ddt_model import LocalSongModel | |
| from transformers import get_cosine_schedule_with_warmup | |
| from datasets import load_from_disk | |
| from accelerate import Accelerator | |
| import os | |
| import argparse | |
| from torch.utils.tensorboard import SummaryWriter | |
| from datetime import datetime | |
| from collections import deque | |
| import torchaudio | |
| import re | |
| import sys | |
| import math | |
| from tag_embedder import TagEmbedder | |
| # Import MusicDCAE | |
| from acestep.music_dcae.music_dcae_pipeline import MusicDCAE | |
| # Import Muon optimizer | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import timm.optim | |
| import os | |
# Turn off Hugging Face tokenizers' internal parallelism for this process;
# presumably to avoid its fork warnings once DataLoader workers are spawned.
os.environ.update(TOKENIZERS_PARALLELISM="false")
def save(model, optimizer, scheduler, global_step, accelerator, keep_last=5):
    """Write a training checkpoint and prune older ones (main process only).

    Saves ``checkpoints/checkpoint_{global_step}.pth`` with the unwrapped
    model's state dict, the optimizer state, the LR-scheduler state (when a
    scheduler is given) and ``global_step``. Afterwards only the ``keep_last``
    highest-step ``checkpoint_<int>.pth`` files are kept.

    Args:
        model: the (possibly Accelerate-wrapped) model to save.
        optimizer: optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler whose ``state_dict()`` is stored; may be None.
        global_step: training step used in the file name and stored in the file.
        accelerator: provides ``is_main_process``, ``unwrap_model`` and ``save``.
        keep_last: number of newest checkpoints to retain (at least one is kept).
    """
    if not accelerator.is_main_process:
        return

    checkpoint_dir = "checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)
    unwrapped_model = accelerator.unwrap_model(model)
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{global_step}.pth")
    save_dict = {
        'model_state_dict': unwrapped_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'global_step': global_step,
    }
    # Store the LR schedule too, so a resumed run continues it instead of
    # starting the warmup/cosine curve over from step 0.
    if scheduler is not None:
        save_dict['scheduler_state_dict'] = scheduler.state_dict()
    accelerator.save(save_dict, checkpoint_path)
    print(f"Checkpoint saved at step {global_step}: {checkpoint_path}")

    # Rotate: collect (step, filename) for well-formed checkpoint names only, so
    # stray files such as "checkpoint_best.pth" neither crash nor get deleted.
    found = []
    for name in os.listdir(checkpoint_dir):
        match = re.fullmatch(r"checkpoint_(\d+)\.pth", name)
        if match:
            found.append((int(match.group(1)), name))
    found.sort(reverse=True)
    for _, old_checkpoint in found[max(keep_last, 1):]:
        os.remove(os.path.join(checkpoint_dir, old_checkpoint))
        print(f"Removed old checkpoint: {old_checkpoint}")
def load_checkpoint(model, optimizer, scheduler, checkpoint_path, accelerator):
    """Restore model, optimizer and (if saved) LR-scheduler state from a checkpoint.

    Args:
        model: model (possibly Accelerate-wrapped) to load weights into.
        optimizer: optimizer to restore when the checkpoint has its state.
        scheduler: LR scheduler to restore when the checkpoint has its state;
            may be None, and older checkpoints without scheduler state are accepted.
        checkpoint_path: path of the file written by ``save``.
        accelerator: used to unwrap ``model``.

    Returns:
        The ``global_step`` stored in the checkpoint.

    Raises:
        RuntimeError: from ``load_state_dict(strict=True)`` if model keys do not match.
    """
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    unwrapped_model = accelerator.unwrap_model(model)
    # Drop the "_orig_mod." key prefix that a torch.compile-wrapped module adds,
    # so weights saved from a compiled model load into a plain one.
    state_dict = {k.replace("_orig_mod.", ""): v for k, v in checkpoint['model_state_dict'].items()}
    missing, unexpected = unwrapped_model.load_state_dict(state_dict, strict=True)
    print("MISSING:", missing)
    print("UNEXPECTED:", unexpected)
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Optimizer loaded")
    # Continue the LR schedule where it stopped rather than restarting it.
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        print("Scheduler loaded")
    global_step = checkpoint['global_step']
    print(f"Resumed from step {global_step}")
    return global_step
def resume(model, optimizer, scheduler, accelerator):
    """Load the highest-step checkpoint from ``checkpoints/`` when one exists.

    Returns:
        The global step returned by ``load_checkpoint``. Returns 0 if the
        ``checkpoints`` directory is missing or contains no
        ``checkpoint_*.pth`` file.
    """
    checkpoint_dir = "checkpoints"
    verbose = accelerator.is_main_process

    if not os.path.exists(checkpoint_dir):
        if verbose:
            print("Checkpoint directory not found. Starting from scratch.")
        return 0

    candidates = [
        name
        for name in os.listdir(checkpoint_dir)
        if name.startswith("checkpoint_") and name.endswith(".pth")
    ]
    if not candidates:
        if verbose:
            print("No checkpoints found. Starting from scratch.")
        return 0

    def step_of(name):
        # "checkpoint_<step>.pth" -> <step>
        return int(name.split("_")[1].split(".")[0])

    checkpoint_path = os.path.join(checkpoint_dir, max(candidates, key=step_of))
    if verbose:
        print(f"Resuming from checkpoint: {checkpoint_path}")
    return load_checkpoint(model, optimizer, scheduler, checkpoint_path, accelerator)
class AudioVAE:
    """Wrapper around a MusicDCAE model that normalizes its latents per channel.

    ``encode`` maps audio to latents standardized with fixed per-channel
    mean/std; ``decode`` undoes that standardization before decoding. Both
    call the underlying model with ``sr=48000`` and run without gradients.
    """

    def __init__(self, device):
        self.model = MusicDCAE().to(device)
        self.model.eval()
        self.device = device
        # Per-channel statistics, shaped (1, C, 1, 1) to broadcast over 4-D latents.
        mean_values = [0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526]
        std_values = [0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707]
        self.latent_mean = torch.tensor(mean_values, device=device).view(1, -1, 1, 1)
        self.latent_std = torch.tensor(std_values, device=device).view(1, -1, 1, 1)

    def encode(self, audio):
        """Encode audio to latents"""
        # Expected input: (B, 2, T) stereo audio at 48 kHz; every item is
        # reported to the model as having the full length T.
        n_items, n_samples = audio.shape[0], audio.shape[2]
        with torch.no_grad():
            lengths = torch.tensor([n_samples] * n_items).to(self.device)
            raw_latents, _ = self.model.encode(audio, lengths, sr=48000)
            # Standardize: (latents - mean) / std
            normalized = (raw_latents - self.latent_mean) / self.latent_std
        return normalized

    def decode(self, latents):
        """Decode latents to audio"""
        with torch.no_grad():
            # Undo the standardization: latents * std + mean
            raw_latents = latents * self.latent_std + self.latent_mean
            _sample_rate, waveforms = self.model.decode(raw_latents, sr=48000)
            # The model returns one tensor per item; stack them into a batch.
            return torch.stack(waveforms).to(self.device)
class RF:
    """Rectified-flow training loss and Euler sampler around an x-prediction model.

    ``model(z_t, t, cond)`` is treated as predicting the clean sample ``x``.
    Noisy inputs are ``z_t = (1 - t) * x + t * noise`` with ``t`` in [0, 1]
    (1 = pure noise). Velocities are derived from the x-prediction as
    ``(z_t - x_pred) / t`` when sampling, and ``/ (t + 0.05)`` in the loss.

    Args:
        model: callable ``(z_t, t, cond) -> x_pred``.
        time_sampling: ``"sigmoid"``, ``"warped"`` or ``"uniform"``; selects both
            the training-time distribution of ``t`` and the sampling schedule.
    """

    # Constants of the "warped" schedule: alpha = max(1, sqrt(_WARP_PM / _WARP_BASE)).
    _WARP_PM = 128 * 16 * 16
    _WARP_BASE = 4096.0

    def __init__(self, model, time_sampling="sigmoid"):
        self.model = model
        self.time_sampling = time_sampling

    def _warp(self, u):
        """Apply the shift ``alpha * u / (1 + (alpha - 1) * u)`` used by "warped".

        Maps [0, 1] monotonically onto [0, 1]. Because ``alpha >= 1``, interior
        values move toward 1 (higher noise).
        """
        alpha = max(1.0, math.sqrt(self._WARP_PM / self._WARP_BASE))
        return alpha * u / (1.0 + (alpha - 1.0) * u)

    def sample_timesteps(self, batch, device):
        """Draw ``batch`` training noise levels according to ``time_sampling``.

        Raises:
            ValueError: if ``time_sampling`` is not a known strategy.
        """
        if self.time_sampling == "sigmoid":
            return torch.sigmoid(torch.randn((batch,), device=device))
        if self.time_sampling == "warped":
            return self._warp(torch.rand(batch, device=device))
        if self.time_sampling == "uniform":
            return torch.rand(batch, device=device)
        raise ValueError(f"Unknown time_sampling strategy: {self.time_sampling}")

    def forward(self, x, cond):
        """Return the scalar rectified-flow MSE loss for clean samples ``x``.

        Each sample is noised to its own random level ``t``. The model's
        x-prediction is converted to a velocity and compared against the true
        velocity. Both are divided by ``t + 0.05``, which keeps the division
        bounded as ``t -> 0``.
        """
        b = x.size(0)
        t = self.sample_timesteps(b, x.device)
        # Reshape t to (b, 1, ..., 1) so it broadcasts over x's remaining dims.
        texp = t.view([b, *([1] * len(x.shape[1:]))])
        z1 = torch.randn_like(x)
        zt = (1 - texp) * x + texp * z1
        x_pred = self.model(zt, t, cond)
        target = (zt - x) / (texp + 0.05)
        v_pred = (zt - x_pred) / (texp + 0.05)
        return F.mse_loss(target, v_pred)

    def get_sampling_timesteps(self, steps, device):
        """Return ``steps`` decreasing noise levels starting at 1.0 (0.0 excluded).

        The "uniform" and "sigmoid" strategies use evenly spaced levels.
        The "warped" strategy applies the same shift as in training.

        Raises:
            ValueError: if ``time_sampling`` is not a known strategy.
        """
        if self.time_sampling not in ("uniform", "sigmoid", "warped"):
            raise ValueError(f"Unknown time_sampling strategy: {self.time_sampling}")
        u = torch.linspace(1.0, 0.0, steps + 1, device=device)[:-1]
        return self._warp(u) if self.time_sampling == "warped" else u

    def sample(self, z, cond, null_cond=None, sample_steps=100, cfg=3.0):
        """Integrate from noise ``z`` (t=1) down to t=0 with Euler steps.

        Args:
            z: starting noise; its first dimension is the batch size.
            cond: conditioning passed to the model.
            null_cond: unconditional conditioning. When given, classifier-free
                guidance ``v_u + cfg * (v_c - v_u)`` is applied.
            sample_steps: number of Euler steps.
            cfg: guidance scale (unused when ``null_cond`` is None).

        Returns:
            List of ``sample_steps + 1`` tensors: ``z`` followed by every
            intermediate state. The last entry is the final sample.
        """
        b = z.size(0)
        device = z.device
        timesteps = self.get_sampling_timesteps(sample_steps, device)
        images = [z]
        for idx in range(sample_steps):
            t_curr = timesteps[idx]
            t_next = timesteps[idx + 1] if idx + 1 < sample_steps else torch.tensor(0.0, device=device)
            dt = t_curr - t_next
            t = t_curr.expand(b)
            # Convert the model's x-prediction into a velocity estimate.
            vc = (z - self.model(z, t, cond)) / t_curr
            if null_cond is not None:
                vu = (z - self.model(z, t, null_cond)) / t_curr
                vc = vu + cfg * (vc - vu)
            z = z - dt * vc
            images.append(z)
        return images
def save_audio_samples(audio_batch, sample_rate, filename, output_dir="audio_samples"):
    """Save the first item of ``audio_batch`` as an audio file.

    Args:
        audio_batch: tensor indexed by item on its first dimension.
            ``audio_batch[0]`` is written, so it should be (channels, time),
            e.g. (2, T) for stereo.
        sample_rate: sample rate passed to ``torchaudio.save``.
        filename: file name inside ``output_dir``; the extension selects the
            format (callers use ".wav").
        output_dir: destination directory, created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    # detach() so tensors that still track gradients can be written too.
    audio = audio_batch[0].detach().cpu()
    # torchaudio's backends do not accept half-precision tensors; upcast them.
    if audio.dtype in (torch.float16, torch.bfloat16):
        audio = audio.float()
    filepath = os.path.join(output_dir, filename)
    torchaudio.save(filepath, audio, sample_rate)
    print(f"Saved audio sample: {filepath}")
def parse_args(argv=None):
    """Parse command-line options for the audio training script.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    def _str_to_bool(value):
        # argparse's ``type=bool`` turns any non-empty string (including
        # "False") into True, so parse the usual spellings explicitly.
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description='Audio training script with TensorBoard logging')
    parser.add_argument('--channels', type=int, default=8, help='Number of input channels in the audio latents')
    parser.add_argument('--audio_height', type=int, default=16, help='Height of audio latents')
    parser.add_argument('--max_audio_width', type=int, default=4096, help='Max width of audio latents')
    parser.add_argument('--subsection_length', type=int, default=256, help='Length of random subsection to sample from each audio latent')
    parser.add_argument('--n_layers', type=int, default=36, help='Number of layers in the model')
    parser.add_argument('--n_encoder_layers', type=int, default=36, help='Number of encoder layers in the model')
    parser.add_argument('--n_heads', type=int, default=16, help='Number of heads in the model')
    parser.add_argument('--dim', type=int, default=768, help='Dimension of the encoder')
    parser.add_argument('--decoder_dim', type=int, default=1536, help='Dimension of the decoder (if None, uses --dim)')
    parser.add_argument('--dataset_name', type=str, default="cache", help='Audio dataset name')
    parser.add_argument('--num_workers', type=int, default=16, help='Number of workers for dataloader')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
    parser.add_argument('--warmup_steps', type=int, default=0, help='Number of warmup steps')
    parser.add_argument('--sample_every', type=int, default=500, help='Audio sampling interval (batches)')
    parser.add_argument('--save_every', type=int, default=1000, help='Model saving interval (batches)')
    parser.add_argument('--num_samples', type=int, default=16, help='Number of samples to generate')
    # "--resume", "--resume true" -> True; "--resume false" -> False (default True).
    parser.add_argument('--resume', type=_str_to_bool, nargs='?', const=True, default=True, help='Resume training from checkpoint')
    parser.add_argument('--pad_to_length', action='store_true', help='Pad short samples to subsection_length instead of filtering them out')
    parser.add_argument('--time_sampling', type=str, default='warped', choices=['sigmoid', 'warped', 'uniform'], help='Timestep sampling strategy')
    return parser.parse_args(argv)
def main():
    """Train LocalSongModel with rectified flow on cached audio latents.

    Steps:
    1. Parse CLI args and build the latent dataset, length-filtered unless
       ``--pad_to_length`` is set, plus its dataloader.
    2. Construct the model, the Muon optimizer and the cosine LR schedule.
    3. Optionally resume from ``checkpoints/``, then wrap everything with
       Accelerate.
    4. Train for ``--epochs`` epochs.

    On the main process it also:
    - logs a running-mean MSE loss and the learning rate to TensorBoard;
    - saves a checkpoint every ``--save_every`` steps;
    - decodes audio from a fixed batch every ``--sample_every`` steps.
    """
    args = parse_args()
    accelerator = Accelerator(mixed_precision="bf16" if torch.cuda.is_available() else "no")
    is_main_process = accelerator.is_main_process

    writer = None
    if is_main_process:
        run_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        writer = SummaryWriter(log_dir=f"runs/{run_datetime}")

    dataset = load_from_disk(args.dataset_name).with_format(type="torch")

    # Unless padding is enabled, drop samples narrower than 2 * subsection_length.
    if not args.pad_to_length:
        def filter_by_length(example):
            latent_width = example['latents'].shape[-1]
            return latent_width >= args.subsection_length * 2

        dataset = dataset.filter(filter_by_length)
        if is_main_process:
            print(f"Dataset filtered to {len(dataset)} samples with width >= {args.subsection_length * 2}")
    elif is_main_process:
        print(f"Padding enabled: short samples will be zero-padded to {args.subsection_length}")

    # Latent normalization parameters (per-channel, shaped (1, C, 1, 1)). These are
    # the same values AudioVAE uses, so AudioVAE.decode undoes this normalization.
    latent_mean = torch.tensor([0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526]).view(1, -1, 1, 1)
    latent_std = torch.tensor([0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707]).view(1, -1, 1, 1)

    # Initialize tag embedder for converting (album_name, song_name) metadata to tag indices
    num_classes = 2304
    tag_embedder = TagEmbedder(num_classes=num_classes)

    def collate_fn(batch):
        """Crop or pad each latent to subsection_length, normalize it, and attach tags.

        Returns a dict with:
        - ``latents``: stacked, normalized tensor;
        - ``tags``: one tag list per sample;
        - ``album_names`` and ``song_names``: used to log the fixed sampling batch.
        """
        subsection_length = args.subsection_length
        sampled_latents = []
        album_names = []
        song_names = []
        tags = []  # List of tag lists for each sample
        for item in batch:
            latent = item['latents']
            if len(latent.shape) == 3:  # Add batch dimension if missing
                latent = latent.unsqueeze(0)
            _, _, _, width = latent.shape
            if width < subsection_length:
                # Only reachable with --pad_to_length: otherwise the dataset was filtered
                # to widths >= 2 * subsection_length. Zero-pad on the right so
                # ``sampled_latent`` is always assigned for this item.
                pad_amount = subsection_length - width
                sampled_latent = torch.nn.functional.pad(latent, (0, pad_amount), mode='constant', value=0)
            else:
                # Random window of subsection_length along the last (width) axis.
                max_start = width - subsection_length
                start_idx = torch.randint(0, max_start + 1, (1,)).item()
                sampled_latent = latent[:, :, :, start_idx:start_idx + subsection_length]
            sampled_latents.append(sampled_latent.squeeze(0))  # Remove batch dim for stacking
            album_name = item['album_name']
            song_name = item['song_name']
            album_names.append(album_name)
            song_names.append(song_name)
            tags.append(tag_embedder.get_tags(album_name, song_name))

        stacked_latents = torch.stack(sampled_latents)
        normalized_latents = (stacked_latents - latent_mean) / latent_std
        return {
            'latents': normalized_latents,
            'tags': tags,
            'album_names': album_names,
            'song_names': song_names,
        }

    num_workers = args.num_workers if torch.cuda.is_available() else 0
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0 (CPU path).
        persistent_workers=num_workers > 0,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        collate_fn=collate_fn,
    )
    if len(dataloader) == 0:
        raise ValueError(
            f"Dataset has {len(dataset)} samples, fewer than one full batch of "
            f"{args.batch_size}; lower --batch_size or add data."
        )

    channels = args.channels
    model = LocalSongModel(
        in_channels=channels,
        num_groups=args.n_heads,
        hidden_size=args.dim,
        decoder_hidden_size=args.decoder_dim,
        num_blocks=args.n_layers,
        patch_size=(16, 1),  # Audio patch size (16 in height, 1 in width)
        num_classes=num_classes,  # Number of tag classes
        max_tags=8,  # Maximum number of tags per sample
    )
    vae = AudioVAE(accelerator.device)
    rf = RF(model, time_sampling=args.time_sampling)
    optimizer = timm.optim.Muon(model.parameters(), lr=args.lr)
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=args.epochs * len(dataloader))

    global_step = 0
    if args.resume:
        global_step = resume(model, optimizer, scheduler, accelerator)

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    model.forward_emb = torch.compile(model.forward_emb)

    model, optimizer, scheduler, dataloader = accelerator.prepare(
        model, optimizer, scheduler, dataloader
    )
    rf.model = model

    if is_main_process:
        model_size = sum(p.numel() for p in accelerator.unwrap_model(model).parameters() if p.requires_grad)
        print(f"Number of parameters: {model_size}, {model_size / 1e6}M")
        os.makedirs("audio_samples", exist_ok=True)

    # Fixed batch for periodic sampling. The iterator is created on every
    # process, not only the main one: creating an iterator over an
    # Accelerate-prepared dataloader may synchronize RNG state across processes.
    fixed_batch = next(iter(dataloader))
    fixed_latents = None
    fixed_tags = None
    fixed_noise = None
    if is_main_process:
        fixed_latents = fixed_batch["latents"][:args.num_samples].to(accelerator.device)
        # The batch may hold fewer than --num_samples items; keep latents, tags
        # and noise the same length.
        num_samples = fixed_latents.shape[0]
        fixed_tags = fixed_batch["tags"][:num_samples]
        print("Fixed ids:", fixed_batch["album_names"][:num_samples])

        # Reverse mapping from tag indices to strings, for readable logging
        idx_to_tag = {v: k for k, v in tag_embedder.tag_mapping.items()}
        print("Fixed tag labels:")
        for i, tag_list in enumerate(fixed_tags):
            labels = [idx_to_tag.get(idx, f"<unknown:{idx}>") for idx in tag_list]
            print(f" Sample {i}: {labels}")

        # Noise matching the fixed latents' channel/height, at subsection_length width
        _, C, H, _ = fixed_latents.shape
        fixed_noise = torch.randn(num_samples, C, H, args.subsection_length, device=accelerator.device)

    if is_main_process:
        print("Begin training")
    mse_loss_window = deque(maxlen=100)
    # NOTE(review): resume() restores global_step, but training always restarts
    # from epoch 0, so a resumed run trains for more than --epochs in total.
    start_epoch = 0
    for epoch in range(start_epoch, args.epochs):
        pbar = tqdm(dataloader) if is_main_process else dataloader
        for batch in pbar:
            x = batch["latents"]
            tags = batch["tags"]
            # Classifier-free guidance dropout: each sample independently has all of
            # its tags replaced by an empty list with probability 0.1.
            c = [[] if torch.rand(1).item() < 0.1 else tag_list for tag_list in tags]

            with accelerator.accumulate(model):
                optimizer.zero_grad()
                mse_loss = rf.forward(x, c)
                accelerator.backward(mse_loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

            if is_main_process:
                mse_loss_window.append(mse_loss.item())
                avg_mse_loss = sum(mse_loss_window) / len(mse_loss_window)
                current_lr = optimizer.param_groups[0]['lr']
                pbar.set_postfix({"mse_loss": avg_mse_loss, "lr": current_lr})
                if writer is not None:
                    writer.add_scalar('Learning_Rate', current_lr, global_step)
                    writer.add_scalar('MSE_Loss', avg_mse_loss, global_step)

            global_step += 1

            if is_main_process and global_step % args.save_every == 0:
                save(model, optimizer, scheduler, global_step, accelerator)

            if is_main_process and global_step % args.sample_every == 0:
                model.eval()
                # These forward passes run on the main process only, so call the
                # unwrapped module and bypass any distributed wrapper (e.g. DDP).
                sampler = RF(accelerator.unwrap_model(model), time_sampling=args.time_sampling)
                with torch.no_grad():
                    # Conditional on the fixed tags; unconditional = empty tag lists
                    null_cond = [[] for _ in range(len(fixed_tags))]
                    sampled_latents = sampler.sample(fixed_noise, fixed_tags, null_cond)[-1]
                    try:
                        sampled_audio = vae.decode(sampled_latents)
                        for i in range(min(8, sampled_audio.shape[0])):  # Save at most 8 samples
                            save_audio_samples(
                                sampled_audio[i:i+1],
                                48000,
                                f"sample_{global_step}_generated_{i}.wav"
                            )
                        # Also save the decoded originals once, for comparison
                        if global_step == args.sample_every:
                            original_audio = vae.decode(fixed_latents)
                            for i in range(min(8, original_audio.shape[0])):
                                save_audio_samples(
                                    original_audio[i:i+1],
                                    48000,
                                    f"sample_{global_step}_original_{i}.wav"
                                )
                    except Exception as e:
                        # Best effort: a decoding/saving failure should not stop training.
                        print(f"Error during audio generation: {e}")
                model.train()

    if is_main_process:
        print("Saving final model")
    save(model, optimizer, scheduler, global_step, accelerator)
    if writer is not None:
        writer.close()
# Run training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()