from transformers import T5EncoderModel, T5TokenizerFast
import torch
from diffusers import FluxTransformer2DModel
from torch import nn
from typing import List, Optional, Union
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_density_for_timestep_sampling
import copy
import random
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from datasets import load_dataset, Audio
from math import pi
import inspect
import yaml


class StableAudioPositionalEmbedding(nn.Module):
    """Fourier embedding of continuous scalar values (adapted from Stable Audio Open).

    For input ``times`` of shape ``(...)`` the output has shape ``(..., dim + 1)``:
    the raw value, followed by ``dim // 2`` sines and ``dim // 2`` cosines of
    learned frequencies (``self.weights``).
    """

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, times: torch.Tensor) -> torch.Tensor:
        times = times[..., None]
        freqs = times * self.weights[None] * 2 * pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        # Prepend the raw value, hence the `+ 1` in the following Linear layer.
        fouriered = torch.cat((times, fouriered), dim=-1)
        return fouriered


class DurationEmbedder(nn.Module):
    """Project a scalar (e.g. a duration in seconds) into an embedding token.

    Code is adapted from https://github.com/Stability-AI/stable-audio-tools

    Args:
        number_embedding_dim (`int`):
            Dimensionality of the output embeddings.
        min_value (`int`):
            Lower clamp bound for the input values.
        max_value (`int`):
            Upper clamp bound for the input values.
        internal_dim (`int`):
            Dimensionality of the intermediate Fourier features.
    """

    def __init__(
        self,
        number_embedding_dim,
        min_value,
        max_value,
        internal_dim: Optional[int] = 256,
    ):
        super().__init__()
        self.time_positional_embedding = nn.Sequential(
            StableAudioPositionalEmbedding(internal_dim),
            nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim),
        )

        self.number_embedding_dim = number_embedding_dim
        self.min_value = min_value
        self.max_value = max_value
        self.dtype = torch.float32

    def forward(
        self,
        floats: torch.Tensor,
    ):
        """Embed ``floats`` and return a tensor of shape ``(N, 1, number_embedding_dim)``.

        Values are clamped to ``[min_value, max_value]`` and rescaled to ``[0, 1]``
        before embedding.
        """
        floats = floats.clamp(self.min_value, self.max_value)
        normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value)

        # Match both dtype AND device of the embedder weights, so a CPU input (such as
        # the default duration tensor in `TangoFlux.forward`) works with a GPU model.
        embedder_param = next(self.time_positional_embedding.parameters())
        normalized_floats = normalized_floats.to(
            device=embedder_param.device, dtype=embedder_param.dtype
        )

        embedding = self.time_positional_embedding(normalized_floats)
        float_embeds = embedding.view(-1, 1, self.number_embedding_dim)

        return float_embeds


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Call ``scheduler.set_timesteps`` and return ``(timesteps, num_inference_steps)``.

    At most one of ``timesteps`` / ``sigmas`` may be given. When one is given, the
    scheduler must accept it in ``set_timesteps``, and ``num_inference_steps`` is
    recomputed from the resulting schedule.

    Raises:
        ValueError: if both custom schedules are given, or the scheduler does not
            accept the custom schedule that was given.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class TangoFlux(nn.Module):
    """Text-to-audio model: a T5 text encoder conditioning a ``FluxTransformer2DModel``.

    The conditioning has three parts:
    - Prompts are encoded with T5.
    - A learned duration embedding is appended as one extra conditioning token.
    - A mean-pooled, projected text vector is passed as ``pooled_projections``.

    ``forward`` returns a training loss. It is either a flow-matching MSE
    (``sft=True``) or a DPO preference loss against ``self.ref_transformer``
    (``sft=False``). ``inference_flow`` integrates the scheduler from noise and
    returns latents.

    Args:
        config: mapping read via ``.get``. The keys and defaults are the
            ``config.get`` calls in ``__init__``.
        initialize_reference_model: if True, ``self.ref_transformer`` is created as
            a frozen deep copy of ``self.transformer`` (needed by ``forward(sft=False)``).
    """

    def __init__(self, config, initialize_reference_model=False):
        super().__init__()

        self.num_layers = config.get('num_layers', 6)
        self.num_single_layers = config.get('num_single_layers', 18)
        self.in_channels = config.get('in_channels', 64)
        self.attention_head_dim = config.get('attention_head_dim', 128)
        self.joint_attention_dim = config.get('joint_attention_dim', 1024)
        self.num_attention_heads = config.get('num_attention_heads', 8)
        self.audio_seq_len = config.get('audio_seq_len', 645)
        self.max_duration = config.get('max_duration', 30)
        self.uncondition = config.get('uncondition', False)
        self.text_encoder_name = config.get('text_encoder_name', "google/flan-t5-large")

        self.noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
        # Separate copy used for training-time sigma lookups (`get_sigmas`), so that
        # `set_timesteps` calls during inference on `noise_scheduler` do not affect it.
        self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)

        self.max_text_seq_len = config.get('max_text_seq_len', 64)
        self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
        self.tokenizer = T5TokenizerFast.from_pretrained(self.text_encoder_name)
        self.text_embedding_dim = self.text_encoder.config.d_model

        self.fc = nn.Sequential(nn.Linear(self.text_embedding_dim, self.joint_attention_dim), nn.ReLU())
        # NOTE: the misspelled attribute name is intentional to keep: it is part of the
        # state-dict keys, so renaming it would break loading existing checkpoints.
        self.duration_emebdder = DurationEmbedder(
            self.text_embedding_dim, min_value=0, max_value=self.max_duration
        )

        # NOTE(review): `fc` outputs `joint_attention_dim` but `pooled_projection_dim` is
        # `text_embedding_dim`. Also, T5 states (`text_embedding_dim`) are fed where
        # `joint_attention_dim` is expected. This only works when the two dims are equal
        # (1024 for flan-t5-large with the defaults).
        self.transformer = FluxTransformer2DModel(
            in_channels=self.in_channels,
            num_layers=self.num_layers,
            num_single_layers=self.num_single_layers,
            attention_head_dim=self.attention_head_dim,
            num_attention_heads=self.num_attention_heads,
            joint_attention_dim=self.joint_attention_dim,
            pooled_projection_dim=self.text_embedding_dim,
            guidance_embeds=False,
        )

        if initialize_reference_model:
            # Frozen reference policy for DPO (`forward(sft=False)`).
            self.ref_transformer = copy.deepcopy(self.transformer)
            self.ref_transformer.requires_grad_(False)

        self.beta_dpo = config.get('beta_dpo', 2000)  # scale of the DPO preference term

    def get_sigmas(self, timesteps, n_dim=3, dtype=torch.float32):
        """Look up training-schedule sigmas for ``timesteps``.

        The result is reshaped with trailing singleton dims to rank ``n_dim`` so it
        broadcasts against latents. Every timestep must occur in
        ``noise_scheduler_copy.timesteps``; otherwise ``.item()`` raises.
        """
        device = self.text_encoder.device
        sigmas = self.noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)
        schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device)
        timesteps = timesteps.to(device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    def encode_text_classifier_free(self, prompt: List[str], num_samples_per_prompt=1):
        """Encode ``prompt`` plus empty-string prompts for classifier-free guidance.

        Returns:
            ``(prompt_embeds, boolean_prompt_mask)``. The unconditional (empty-prompt)
            rows come first and the conditional rows second. Each half is repeated
            ``num_samples_per_prompt`` times per prompt. The mask is True where the
            tokenizer's attention mask is 1.
        """
        device = self.text_encoder.device
        # Same truncation length as `encode_text` (used in `forward` and unguided
        # inference), so guided and unguided paths see the same prompt tokens.
        batch = self.tokenizer(
            prompt,
            max_length=self.max_text_seq_len,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)

        with torch.no_grad():
            prompt_embeds = self.text_encoder(
                input_ids=input_ids, attention_mask=attention_mask
            )[0]

        prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
        attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)

        # One empty prompt per input prompt, padded to the conditional sequence length
        # so both halves can be concatenated along the batch dimension.
        uncond_tokens = [""] * input_ids.shape[0]
        max_length = prompt_embeds.shape[1]
        uncond_batch = self.tokenizer(
            uncond_tokens,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )
        uncond_input_ids = uncond_batch.input_ids.to(device)
        uncond_attention_mask = uncond_batch.attention_mask.to(device)

        with torch.no_grad():
            negative_prompt_embeds = self.text_encoder(
                input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
            )[0]

        negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
        uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)

        # Stack unconditional and conditional inputs into one batch, so guidance needs
        # a single transformer call per step instead of two.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
        boolean_prompt_mask = (prompt_mask == 1).to(device)

        return prompt_embeds, boolean_prompt_mask

    @torch.no_grad()
    def encode_text(self, prompt):
        """Encode ``prompt`` with T5 (no gradients).

        Input is truncated to ``max_text_seq_len`` tokens. Returns
        ``(hidden_states, boolean_mask)``.
        """
        device = self.text_encoder.device
        batch = self.tokenizer(
            prompt, max_length=self.max_text_seq_len, padding=True, truncation=True, return_tensors="pt"
        )
        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)

        encoder_hidden_states = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask
        )[0]

        boolean_encoder_mask = (attention_mask == 1).to(device)

        return encoder_hidden_states, boolean_encoder_mask

    def encode_duration(self, duration):
        """Embed ``duration`` into a ``(N, 1, text_embedding_dim)`` conditioning token."""
        return self.duration_emebdder(duration)

    def _pooled_text_projection(self, encoder_hidden_states, boolean_encoder_mask):
        """Mean-pool text states over unmasked tokens and project them with ``self.fc``."""
        mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(encoder_hidden_states)
        # NaN-fill padded positions so `nanmean` averages real tokens only; build the
        # fill value on the states' own device/dtype.
        nan = encoder_hidden_states.new_tensor(float('nan'))
        masked_data = torch.where(mask_expanded, encoder_hidden_states, nan)
        pooled = torch.nanmean(masked_data, dim=1)
        return self.fc(pooled)

    @torch.no_grad()
    def inference_flow(
        self,
        prompt,
        num_inference_steps=50,
        timesteps=None,
        guidance_scale=3,
        duration=10,
        disable_progress=False,
        num_samples_per_prompt=1,
    ):
        """Sample latents for ``prompt`` by stepping the flow-matching scheduler from noise.

        Primarily exercised with a single prompt.

        Args:
            prompt: a string or a list of strings.
            num_inference_steps: number of scheduler steps. Ignored when
                ``timesteps`` is given.
            timesteps: optional custom timestep schedule. When omitted, a linear
                sigma schedule from 1.0 down to ``1 / num_inference_steps`` is used.
            guidance_scale: classifier-free guidance weight. Guidance is applied
                only when it is > 1.
            duration: scalar or tensor. It is clamped to ``[0, max_duration]`` by
                the duration embedder. A single value is shared by all samples;
                otherwise one value per prompt is assumed.
            disable_progress: hide the tqdm progress bar.
            num_samples_per_prompt: number of latents sampled per prompt.

        Returns:
            Latents of shape ``(len(prompt) * num_samples_per_prompt, audio_seq_len,
            in_channels)``. No decoding to audio happens here.
        """
        device = self.transformer.device
        scheduler = self.noise_scheduler

        if not isinstance(prompt, list):
            prompt = [prompt]
        if not isinstance(duration, torch.Tensor):
            duration = torch.tensor([duration], device=device)

        num_samples = len(prompt) * num_samples_per_prompt
        classifier_free_guidance = guidance_scale > 1.0

        duration_hidden_states = self.encode_duration(duration)
        if duration_hidden_states.shape[0] == 1:
            duration_hidden_states = duration_hidden_states.repeat(num_samples, 1, 1)
        else:
            # assumes one duration per prompt — TODO confirm with callers
            duration_hidden_states = duration_hidden_states.repeat_interleave(num_samples_per_prompt, 0)

        if classifier_free_guidance:
            encoder_hidden_states, boolean_encoder_mask = self.encode_text_classifier_free(
                prompt, num_samples_per_prompt=num_samples_per_prompt
            )
            # Unconditional and conditional halves share the same duration token.
            duration_hidden_states = torch.cat([duration_hidden_states, duration_hidden_states])
        else:
            encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
            encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_samples_per_prompt, 0)
            boolean_encoder_mask = boolean_encoder_mask.repeat_interleave(num_samples_per_prompt, 0)

        bsz = encoder_hidden_states.shape[0]

        pooled_projection = self._pooled_text_projection(encoder_hidden_states, boolean_encoder_mask)

        encoder_hidden_states = torch.cat(
            [encoder_hidden_states, duration_hidden_states], dim=1
        )  # (bs, seq_len, dim)

        # Custom `timesteps` and the default sigma schedule are mutually exclusive in
        # `retrieve_timesteps`, so only build sigmas when no timesteps were given.
        sigmas = None
        if timesteps is None:
            sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        timesteps, num_inference_steps = retrieve_timesteps(
            scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
        )

        latents = torch.randn(num_samples, self.audio_seq_len, self.in_channels)

        txt_ids = torch.zeros(bsz, encoder_hidden_states.shape[1], 3).to(device)
        audio_ids = (
            torch.arange(self.audio_seq_len).unsqueeze(0).unsqueeze(-1).repeat(bsz, 1, 3).to(device)
        )

        timesteps = timesteps.to(device)
        latents = latents.to(device)
        encoder_hidden_states = encoder_hidden_states.to(device)

        progress_bar = tqdm(total=num_inference_steps, disable=disable_progress)
        for t in timesteps:
            latents_input = torch.cat([latents] * 2) if classifier_free_guidance else latents

            noise_pred = self.transformer(
                hidden_states=latents_input,
                # The Flux transformer multiplies `timestep` by 1000 internally, so the
                # scheduler's timesteps are divided by 1000 here.
                timestep=(t / 1000).expand(latents_input.shape[0]),
                guidance=None,
                pooled_projections=pooled_projection,
                encoder_hidden_states=encoder_hidden_states,
                txt_ids=txt_ids,
                img_ids=audio_ids,
                return_dict=False,
            )[0]

            if classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = scheduler.step(noise_pred, t, latents).prev_sample
            progress_bar.update(1)
        progress_bar.close()

        return latents

    def forward(self, latents, prompt, duration=None, sft=True):
        """Compute a training loss for ``latents`` conditioned on ``prompt`` and ``duration``.

        Args:
            latents: clean target latents. They are 3-D: the noise and the sigma
                broadcasting below assume rank 3.
            prompt: list of prompts. With ``sft=False`` the latent batch is twice
                the prompt batch, because text states are repeated twice.
            duration: duration tensor. Defaults to ``torch.tensor([10])``. A single
                value is broadcast to every prompt.
            sft: True for the plain flow-matching loss, False for the DPO loss.
                In the DPO case the first half of ``latents`` is treated as
                preferred ("w") and the second half as rejected ("l").

        Returns:
            ``(loss, raw_model_loss, raw_ref_loss, implicit_acc, epsilon_diff)``.
            The last four are 0 when ``sft=True``.

        Raises:
            AttributeError: if ``sft=False`` and no ``ref_transformer`` is set.
        """
        if duration is None:
            duration = torch.tensor([10])

        device = latents.device
        audio_seq_length = self.audio_seq_len
        bsz = latents.shape[0]

        encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
        duration_hidden_states = self.encode_duration(duration)

        pooled_projection = self._pooled_text_projection(encoder_hidden_states, boolean_encoder_mask)

        # Broadcast a single duration value to every prompt.
        num_texts = encoder_hidden_states.shape[0]
        if duration_hidden_states.shape[0] == 1 and num_texts > 1:
            duration_hidden_states = duration_hidden_states.expand(num_texts, -1, -1)

        # Append the duration token to the text tokens.
        encoder_hidden_states = torch.cat(
            [encoder_hidden_states, duration_hidden_states], dim=1
        )  # (bs, seq_len, dim)

        txt_ids = torch.zeros(bsz, encoder_hidden_states.shape[1], 3).to(device)
        audio_ids = (
            torch.arange(audio_seq_length).unsqueeze(0).unsqueeze(-1).repeat(bsz, 1, 3).to(device)
        )

        if sft:
            if self.uncondition:
                # Zero the conditioning (text + duration tokens) for ~10% of samples.
                mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
                if len(mask_indices) > 0:
                    encoder_hidden_states[mask_indices] = 0

            noise = torch.randn_like(latents)

            u = compute_density_for_timestep_sampling(
                weighting_scheme='logit_normal',
                batch_size=bsz,
                logit_mean=0,
                logit_std=1,
                mode_scale=None,
            )

            indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
            timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
            sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)

            noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise

            model_pred = self.transformer(
                hidden_states=noisy_model_input,
                encoder_hidden_states=encoder_hidden_states,
                pooled_projections=pooled_projection,
                img_ids=audio_ids,
                txt_ids=txt_ids,
                guidance=None,
                # The Flux transformer multiplies `timestep` by 1000 internally.
                timestep=timesteps / 1000,
                return_dict=False,
            )[0]

            target = noise - latents
            loss = torch.mean(
                ((model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                1,
            )
            loss = loss.mean()
            raw_model_loss, raw_ref_loss, implicit_acc, epsilon_diff = 0, 0, 0, 0  # unused for SFT

        else:
            ref_transformer = getattr(self, "ref_transformer", None)
            if ref_transformer is None:
                raise AttributeError(
                    "DPO training (sft=False) requires `ref_transformer`; construct TangoFlux "
                    "with initialize_reference_model=True or assign it before calling forward."
                )

            encoder_hidden_states = encoder_hidden_states.repeat(2, 1, 1)
            pooled_projection = pooled_projection.repeat(2, 1)
            # Preferred and rejected halves must share the same noise.
            noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1)

            u = compute_density_for_timestep_sampling(
                weighting_scheme='logit_normal',
                batch_size=bsz // 2,
                logit_mean=0,
                logit_std=1,
                mode_scale=None,
            )

            indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
            timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
            timesteps = timesteps.repeat(2)  # same timestep for each preferred/rejected pair
            sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)

            noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise

            model_pred = self.transformer(
                hidden_states=noisy_model_input,
                encoder_hidden_states=encoder_hidden_states,
                pooled_projections=pooled_projection,
                img_ids=audio_ids,
                txt_ids=txt_ids,
                guidance=None,
                # The Flux transformer multiplies `timestep` by 1000 internally.
                timestep=timesteps / 1000,
                return_dict=False,
            )[0]

            target = noise - latents
            model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
            model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
            model_losses_w, model_losses_l = model_losses.chunk(2)
            model_diff = model_losses_w - model_losses_l
            raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())

            with torch.no_grad():
                ref_preds = ref_transformer(
                    hidden_states=noisy_model_input,
                    encoder_hidden_states=encoder_hidden_states,
                    pooled_projections=pooled_projection,
                    img_ids=audio_ids,
                    txt_ids=txt_ids,
                    guidance=None,
                    timestep=timesteps / 1000,
                    return_dict=False,
                )[0]

                ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none")
                ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))

                ref_losses_w, ref_losses_l = ref_loss.chunk(2)
                ref_diff = ref_losses_w - ref_losses_l
                raw_ref_loss = ref_loss.mean()

            # Mean amount by which the model beats the reference on preferred samples.
            epsilon_diff = torch.max(
                torch.zeros_like(model_losses_w), ref_losses_w - model_losses_w
            ).mean()

            scale_term = -0.5 * self.beta_dpo
            inside_term = scale_term * (model_diff - ref_diff)
            implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
            # DPO term plus the plain loss on preferred samples.
            loss = -1 * F.logsigmoid(inside_term).mean() + model_losses_w.mean()

        return loss, raw_model_loss, raw_ref_loss, implicit_acc, epsilon_diff