"""Module for modeling discrete diffusion (absorbing state or uniform) and AR (a special case of absorbing state). """ import itertools import math import typing from dataclasses import dataclass import hydra.utils import lightning as L import numpy as np import omegaconf import torch import torch.nn.functional as F import torchmetrics import transformers from mamba_ssm.utils.generation import InferenceParams from torch import Tensor from tqdm.auto import tqdm import pdb import gc import classifier import dataloader import models import noise_schedule LOG2 = math.log(2) def _sample_categorical(categorical_probs): gumbel_norm = ( 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()).to(categorical_probs.dtype) return (categorical_probs / gumbel_norm).argmax(dim=-1) def _unsqueeze(x, reference): return x.view( * x.shape, * ((1,) * (len(reference.shape) - len(x.shape)))) @dataclass class Loss: loss: torch.FloatTensor nlls: torch.FloatTensor token_mask: torch.FloatTensor recon_loss: typing.Optional[torch.FloatTensor] = None diffusion_loss: typing.Optional[torch.FloatTensor] = None class NLL(torchmetrics.aggregation.MeanMetric): pass class BPD(NLL): def compute(self) -> Tensor: """Computes the bits per dimension. Returns: bpd """ return self.mean_value / self.weight / LOG2 class Perplexity(NLL): def compute(self) -> Tensor: """Computes the Perplexity. Returns: Perplexity """ return torch.exp(self.mean_value / self.weight) class Diffusion(L.LightningModule): def __init__( self, config, tokenizer: transformers.PreTrainedTokenizer): super().__init__() self.save_hyperparameters() self.config = config self.tokenizer = tokenizer self.vocab_size = tokenizer.vocab_size self.antithetic_sampling = config.training.antithetic_sampling self.importance_sampling = config.training.importance_sampling self.change_of_variables = config.training.change_of_variables self.noise = noise_schedule.get_noise(config, dtype=self.dtype) if self.config.is_vision: self.mask_index = getattr(tokenizer, 'mask_token_id', -1) else: if (not hasattr(self.tokenizer, 'mask_token') or tokenizer.mask_token is None): self.mask_index = self.vocab_size self.vocab_size += 1 else: self.mask_index = tokenizer.mask_token_id # Note: creating limiting distribution with # broadcast-able batch and sequence dimensions. self.parameterization = config.parameterization self.diffusion = config.diffusion if config.parameterization == 'ar': self.limiting_distribution = None else: if self.diffusion == 'absorbing_state': # Not needed, posterior calculated explicitly. 
limiting_distribution = None elif self.diffusion == 'uniform': limiting_distribution = torch.ones( (1, 1, self.vocab_size), dtype=self.dtype) / self.vocab_size else: raise NotImplementedError( f"Diffusion type {self.diffusion} not implemented.") self.register_buffer('limiting_distribution', limiting_distribution) self.T = config.T self.subs_masking = config.subs_masking self.time_conditioning = config.time_conditioning if self.config.backbone == 'dit': self.backbone = models.dit.DIT( self.config, vocab_size=self.vocab_size) elif self.config.backbone == 'dimamba': self.backbone = models.dimamba.DiMamba( self.config, vocab_size=self.vocab_size, pad_token_id=self.tokenizer.pad_token_id) elif self.config.backbone == 'unet': self.backbone = models.unet.UNet( self.config, vocab_size=self.vocab_size) elif self.config.backbone == 'hf_dit': self.backbone = transformers.AutoModelForMaskedLM.from_pretrained( config.model.pretrained_model_name_or_path, trust_remote_code=True) else: raise NotImplementedError( f"Backbone {self.config.backbone} not implemented.") self.lr = self.config.optim.lr self.sampling_eps = config.training.sampling_eps self.softplus = torch.nn.Softplus() self.neg_infinity = -1_000_000.0 if config.training.ema > 0: self.ema = models.ema.ExponentialMovingAverage( itertools.chain(self.backbone.parameters(), self.noise.parameters()), decay=config.training.ema) else: self.ema = None # metrics are automatically reset at end of epoch metrics = torchmetrics.MetricCollection({ 'nll': NLL(), 'bpd': BPD(), 'ppl': Perplexity(), }) metrics.set_dtype(torch.float64) self.train_metrics = metrics.clone(prefix='train/') self.valid_metrics = metrics.clone(prefix='val/') self.test_metrics = metrics.clone(prefix='test/') self.fast_forward_epochs = None self.fast_forward_batches = None self._validate_configuration() def _validate_configuration(self): assert not (self.change_of_variables and self.importance_sampling) if self.diffusion != 'absorbing_state': assert self.parameterization not in {'ar', 'subs'} if self.T > 0: assert self.parameterization in {'d3pm', 'subs'} if self.subs_masking: assert self.parameterization == 'd3pm' def on_load_checkpoint(self, checkpoint): if self.limiting_distribution is not None: checkpoint['state_dict']['limiting_distribution'] = self.limiting_distribution.to( list(checkpoint['state_dict'].values())[0].device) if self.ema: self.ema.load_state_dict(checkpoint['ema']) # Copied from: # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41 self.fast_forward_epochs = checkpoint['loops'][ 'fit_loop']['epoch_progress']['current']['completed'] self.fast_forward_batches = checkpoint['loops'][ 'fit_loop']['epoch_loop.batch_progress'][ 'current']['completed'] def on_save_checkpoint(self, checkpoint): # Do not save this buffer checkpoint['state_dict'].pop('limiting_distribution', None) if self.ema: checkpoint['ema'] = self.ema.state_dict() # Copied from: # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py # ['epoch_loop.batch_progress']['total']['completed'] is # 1 iteration behind, so we're using the optimizer's # progress. 
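# A minimal, standalone sketch (not part of the training code) of the step
# bookkeeping performed just below: Lightning's batch-progress counters are
# reconstructed from the optimizer's step counters, because with gradient
# accumulation one optimizer step corresponds to several dataloader batches.
# The numbers used here are made up for illustration.
def _illustrate_step_bookkeeping(optimizer_steps: int,
                                 accumulate_grad_batches: int) -> dict:
  # Local (per-dataloader) batches completed so far.
  completed_batches = optimizer_steps * accumulate_grad_batches
  # `_batches_that_stepped` tracks *global* (optimizer) steps, so it is not
  # multiplied by the accumulation factor.
  return {'batches_completed': completed_batches,
          'batches_that_stepped': optimizer_steps}

assert _illustrate_step_bookkeeping(10, 4) == {
  'batches_completed': 40, 'batches_that_stepped': 10}
# The assignments below write exactly these quantities back into the
# Lightning checkpoint dict: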
checkpoint['loops']['fit_loop'][ 'epoch_loop.batch_progress']['total'][ 'completed'] = checkpoint['loops']['fit_loop'][ 'epoch_loop.automatic_optimization.optim_progress'][ 'optimizer']['step']['total'][ 'completed'] * self.trainer.accumulate_grad_batches checkpoint['loops']['fit_loop'][ 'epoch_loop.batch_progress']['current'][ 'completed'] = checkpoint['loops']['fit_loop'][ 'epoch_loop.automatic_optimization.optim_progress'][ 'optimizer']['step']['current'][ 'completed'] * self.trainer.accumulate_grad_batches # _batches_that_stepped tracks the number of global # steps, not the number of local steps, so we don't # multiply with self.trainer.accumulate_grad_batches # here. checkpoint['loops']['fit_loop'][ 'epoch_loop.state_dict'][ '_batches_that_stepped'] = checkpoint['loops']['fit_loop'][ 'epoch_loop.automatic_optimization.optim_progress'][ 'optimizer']['step']['total']['completed'] if 'sampler' not in checkpoint.keys(): checkpoint['sampler'] = {} if hasattr(self.trainer.train_dataloader.sampler, 'state_dict'): sampler_state_dict = self.trainer.\ train_dataloader.sampler.state_dict() checkpoint['sampler'][ 'random_state'] = sampler_state_dict.get( 'random_state', None) else: checkpoint['sampler']['random_state'] = None def on_train_start(self): if self.ema: self.ema.move_shadow_params_to_device(self.device) # Adapted from: # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py distributed = ( self.trainer._accelerator_connector.use_distributed_sampler and self.trainer._accelerator_connector.is_distributed) if distributed: sampler_cls = dataloader.FaultTolerantDistributedSampler else: sampler_cls = dataloader.RandomFaultTolerantSampler updated_dls = [] for dl in self.trainer.fit_loop._combined_loader.flattened: if hasattr(dl.sampler, 'shuffle'): dl_sampler = sampler_cls( dl.dataset, shuffle=dl.sampler.shuffle) else: dl_sampler = sampler_cls(dl.dataset) if (distributed and self.fast_forward_epochs is not None and self.fast_forward_batches is not None): dl_sampler.load_state_dict({ 'epoch': self.fast_forward_epochs, 'counter': (self.fast_forward_batches * self.config.loader.batch_size)}) from functools import partial from dataloader import collate_fn collate_partial = partial(collate_fn) torch.cuda.empty_cache() updated_dls.append( torch.utils.data.DataLoader( dl.dataset, # batch_size=self.config.loader.batch_size, num_workers=self.config.loader.num_workers, pin_memory=self.config.loader.pin_memory, # sampler=dl_sampler, shuffle=False, persistent_workers=self.config.loader.persistent_workers, collate_fn=collate_partial )) self.trainer.fit_loop._combined_loader.flattened = updated_dls def configure_optimizers(self): # TODO(yair): Lightning currently giving this warning when using `fp16`: # "Detected call of `lr_scheduler.step()` before `optimizer.step()`. " # Not clear if this is a problem or not. 
# See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558 optimizer = torch.optim.AdamW( itertools.chain(self.backbone.parameters(), self.noise.parameters()), lr=self.config.optim.lr, betas=(self.config.optim.beta1, self.config.optim.beta2), eps=self.config.optim.eps, weight_decay=self.config.optim.weight_decay) scheduler = hydra.utils.instantiate( self.config.lr_scheduler, optimizer=optimizer) scheduler_dict = { 'scheduler': scheduler, 'interval': 'step', 'monitor': 'val/loss', 'name': 'trainer/lr', } return [optimizer], [scheduler_dict] def optimizer_step(self, *args, **kwargs): super().optimizer_step(*args, **kwargs) if self.ema: self.ema.update(itertools.chain( self.backbone.parameters(), self.noise.parameters())) def _subs_parameterization(self, logits, xt): # "Zero Masking Prob": # log prob at the mask index = - infinity logits[..., self.mask_index] += self.neg_infinity # "Copy over": # Apply updates directly in the logits matrix. # For the logits of the unmasked tokens, set all values # to -infinity except for the indices corresponding to # the unmasked tokens. unmasked_indices = (xt != self.mask_index) logits[unmasked_indices] = self.neg_infinity logits[unmasked_indices, xt[unmasked_indices]] = 0 # Normalize the logits such that x.exp() is # a probability distribution over vocab_size. return logits.log_softmax(dim=-1) def _process_sigma(self, sigma): if sigma is None: assert self.parameterization == 'ar' return sigma if sigma.ndim > 1: sigma = sigma.squeeze(-1) if not self.time_conditioning: sigma = torch.zeros_like(sigma) assert sigma.ndim == 1, sigma.shape return sigma def forward(self, x, sigma, cond=None, x_emb=None, **kwargs): """Returns log_probs / logits.""" sigma = self._process_sigma(sigma) with torch.cuda.amp.autocast(dtype=torch.float32): logits = self.backbone(x, sigma, cond, x_emb=x_emb, **kwargs) if self.parameterization == 'subs': # returns log_probs return self._subs_parameterization( logits=logits, xt=x) if self.parameterization in {'ar', 'd3pm'}: # returns log_probs if self.subs_masking: # Can use "zero masking prob" logits[:, :, self.mask_index] += self.neg_infinity return logits.log_softmax(dim=-1) return logits def _compute_posterior(self, x, xt, alpha_s, alpha_t): """Computes the posterior / approximate posterior. Args: x: Either clean input `x0` (one-hot), or model's predicted `x_theta` of shape (B, L, V). xt: The noisy latent (as indices) of shape (B, L). alpha_s: Noise level at s of shape (B, [L | 1], 1). alpha_t: Noise level at t of shape (B, [L | 1], 1). Returns: Posterior / approximate posterior of shape (B, L, V). """ alpha_ts = alpha_t / alpha_s d_alpha = alpha_s - alpha_t xt_one_hot = F.one_hot(xt, self.vocab_size) if self.diffusion == 'uniform': return ( (alpha_t * self.vocab_size * x * xt_one_hot + (alpha_ts - alpha_t) * xt_one_hot + d_alpha * x + (1 - alpha_ts) * (1 - alpha_s) * self.limiting_distribution) / (alpha_t * self.vocab_size * torch.gather(x, -1, xt[..., None]) + (1 - alpha_t)) ) raise NotImplementedError( f"Diffusion type {self.diffusion} not implemented.") def _d3pm_loss(self, model_output, xt, x0, t): assert self.config.noise.type == 'loglinear', ( 'D3PM loss only implemented for log-linear noise.') dt = 1 / self.T if torch.is_tensor(t): t = t[:, None] assert t.ndim == 2 t = t.clamp(0., 1. 
- 1e-4) alpha_t = 1 - t + torch.zeros_like(xt) alpha_s = 1 - (t - dt) + torch.zeros_like(xt) if self.diffusion == 'absorbing_state': log_x_theta_at_x0 = torch.gather( model_output, -1, x0[:, :, None]).squeeze(-1) log_x_theta_at_m = model_output[:, :, self.mask_index] x_theta_at_m = log_x_theta_at_m.exp() term_1_coef = dt / t term_1_log_nr = torch.log(alpha_t * x_theta_at_m / t + 1) term_1_log_dr = log_x_theta_at_x0 term_2_coef = 1 - dt / t term_2_log_nr = term_1_log_nr term_2_log_dr = torch.log(alpha_s * x_theta_at_m / (t - dt) + 1) L_vb_masked = ( term_1_coef * (term_1_log_nr - term_1_log_dr) + term_2_coef * (term_2_log_nr - term_2_log_dr)) L_vb = L_vb_masked * (xt == self.mask_index) elif self.diffusion == 'uniform': posterior = self._compute_posterior( x=F.one_hot(x0, num_classes=self.vocab_size).to(self.dtype), xt=xt, alpha_s=alpha_s[..., None], alpha_t=alpha_t[..., None]) posterior_pred = self._compute_posterior( x=model_output.exp(), xt=xt, alpha_s=alpha_s[..., None], alpha_t=alpha_t[..., None]) L_vb = ( posterior * (torch.log(posterior + 1e-12) - torch.log(posterior_pred)) ).sum(dim=-1) else: raise NotImplementedError( f"Diffusion type {self.diffusion} not implemented for D3PM.") return self.T * L_vb def _reconstruction_loss(self, x0, cond=None): # For D3PM parameterization assert self.config.noise.type == 'loglinear', ( 'Reconstruction loss only implemented for log-linear ' 'noise.') t0 = torch.zeros(x0.shape[0], dtype=self.dtype, device=self.device) time_conditioning = self.noise(t0)[0][:, None] model_output_t0 = self.forward(x0, time_conditioning, cond=cond) return - torch.gather(input=model_output_t0, dim=-1, index=x0[:, :, None]).squeeze(-1) def _sample_t(self, n): _eps_t = torch.rand(n, device=self.device) if self.antithetic_sampling: offset = torch.arange(n, device=self.device) / n _eps_t = (_eps_t / n + offset) % 1 t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps if self.importance_sampling: return self.noise.importance_sampling_transformation( t) return t def _q_xt(self, x, move_chance): """Computes the noisy sample xt. Args: x: int torch.Tensor with shape (batch_size, diffusion_model_input_length), input. move_chance: float torch.Tensor with shape (batch_size, 1). 
""" move_indices = torch.rand( *x.shape, device=x.device) < move_chance if self.diffusion == 'absorbing_state': return torch.where(move_indices, self.mask_index, x) if self.diffusion == 'uniform': uniform_tensor = torch.randint( 0, self.vocab_size, x.shape, device=x.device) return torch.where(move_indices, uniform_tensor, x) elif self.diffusion == 'uniform_data_marginals': return torch.where( move_indices, self._sample_prior(*x.shape), x) raise NotImplementedError( f"Diffusion type {self.diffusion} not implemented.") def _forward_pass_diffusion(self, x0, cond=None): t = self._sample_t(x0.shape[0]) if self.T > 0: t = (t * self.T).to(torch.int) t = t / self.T # t \in {1/T, 2/T, ..., 1} t += (1 / self.T) if self.change_of_variables: time_conditioning = t[:, None] f_T = torch.log1p(- torch.exp(- self.noise.sigma_max)) f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min)) move_chance = torch.exp(f_0 + t * (f_T - f_0)) move_chance = move_chance[:, None] sigma, dsigma = None, None else: sigma, dsigma = self.noise(t) time_conditioning = sigma[:, None] move_chance = 1 - torch.exp(-sigma[:, None]) xt = self._q_xt(x0, move_chance) model_output = self.forward(xt, time_conditioning, cond=cond) # Discrete (finite T) time if self.T > 0: diffusion_loss = self._d3pm_loss( model_output=model_output, xt=xt, x0=x0, t=t) if self.parameterization == 'd3pm': reconstruction_loss = self._reconstruction_loss( x0, cond=cond) if self.training and self.config.training.use_simple_ce_loss: loss = -torch.gather( input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1) else: loss = reconstruction_loss + diffusion_loss return { 'recon_loss': reconstruction_loss, 'diffusion_loss': diffusion_loss, 'loss': loss} elif self.parameterization == 'subs': if self.training and self.config.training.use_simple_ce_loss: loss = -torch.gather( input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1) else: loss = diffusion_loss return {'diffusion_loss': diffusion_loss, 'loss': loss} else: raise ValueError( f"Invalid parameterization: {self.parameterization} for T > 0.") # Continuous (T --> infty) time if self.diffusion == 'absorbing_state': # SUBS parameterization, continuous time. log_p_theta = torch.gather( input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1) if self.change_of_variables or self.importance_sampling: if self.training and self.config.training.use_simple_ce_loss: return { 'diffusion_loss': log_p_theta * torch.log1p(-torch.exp(- self.noise.sigma_min)), 'loss': -log_p_theta } return log_p_theta * torch.log1p(-torch.exp(- self.noise.sigma_min)) if self.training and self.config.training.use_simple_ce_loss: return { 'diffusion_loss': log_p_theta * (dsigma / torch.expm1(sigma))[:, None], 'loss': log_p_theta } return - log_p_theta * (dsigma / torch.expm1(sigma))[:, None] elif self.diffusion == 'uniform': assert self.config.noise.type == 'loglinear', ( 'Continuous time uniform diffusion only implemented' ' for log-linear noise.') # TODO: Currently α_t' and α_t are hardcoded to a # log-linear noise. # Make generic (as above, for absorbing state): # alpha_t_prime = -dsigma * (-sigma).exp() # alpha_t = (-sigma).exp() alpha_t_prime = -1. alpha_t = 1. 
- t[..., None, None] # B, 1, 1 # x_bar = N * α_t * x + 1 - α_t ; B, L, V x_bar = self.vocab_size * alpha_t * F.one_hot(x0, self.vocab_size).float() + 1 - alpha_t x_bar_theta = self.vocab_size * alpha_t * model_output.exp() + 1 - alpha_t # α_t' / (N*α_t) coeff = alpha_t_prime / (self.vocab_size * alpha_t) # B, 1, 1 # Term 1: indices where z_t = 1 x_bar_zt = torch.gather(x_bar, -1, xt[..., None]) # B, L, 1 x_bar_theta_zt = torch.gather(x_bar_theta, -1, xt[..., None]) # B, L, 1 term1 = ((self.vocab_size / x_bar_zt) - (self.vocab_size / x_bar_theta_zt)) # B, L, 1 # Term 2: indices where z_t = 0 term2 = ( # B, L, V before summing --> B, L, 1 after (x_bar / x_bar_zt) * ( x_bar_theta_zt.log() - x_bar_theta.log() + x_bar.log() - x_bar_zt.log() ) ) term2 = term2.sum(dim=-1, keepdim=True) # B, L, 1 diffusion_loss = (coeff * (term1 - term2)).squeeze() # B, L reconstruction_loss = self._reconstruction_loss( x0, cond=cond) if self.training and self.config.training.use_simple_ce_loss: return { 'recon_loss': reconstruction_loss, 'diffusion_loss': diffusion_loss, 'loss': -torch.gather( input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1) } return { 'recon_loss': reconstruction_loss, 'diffusion_loss': diffusion_loss, 'loss': diffusion_loss if getattr(self.config, 'zero_recon_loss', False) else diffusion_loss + reconstruction_loss } else: raise NotImplementedError( f"Diffusion type {self.diffusion} not " "implemented for continuous time case.") def _maybe_sub_sample(self, x0, attention_mask): seqlen = x0.shape[1] # if seqlen > self.config.model.length: # assert seqlen == 2 * self.config.model.length # # cropping is necessary for the text8-crop dataset; # # try the same starting point for now # start = np.random.choice(self.config.model.length) # end = start + self.config.model.length # input_tokens = x0[:, start: end] # output_tokens = x0[:, start + 1: end + 1] # new_attention_mask = attention_mask[:, start: end] # # Helps with validation PPL, since the val # # examples will all start and end with BOS/EOS # input_tokens[:, 0] = self.tokenizer.bos_token_id # output_tokens[:, -1] = self.tokenizer.eos_token_id # elif self.parameterization == 'ar': # input_tokens = x0[:, :-1] # output_tokens = x0[:, 1:] # new_attention_mask = attention_mask[:, 1:] # else: # input_tokens = x0 # output_tokens = None # new_attention_mask = attention_mask input_tokens = x0 output_tokens = None new_attention_mask = attention_mask return input_tokens, output_tokens, new_attention_mask def _loss(self, x0, attention_mask, cond=None): (input_tokens, output_tokens, attention_mask) = self._maybe_sub_sample( x0, attention_mask) recon_loss, diffusion_loss = None, None if (cond is not None and self.training and self.config.training.guidance is not None and self.config.training.guidance.cond_dropout > 0): # Randomly mask out conditioning for classifier-free # guidance training. 
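# Illustrative sketch (assumed helper names, not part of the model): during
# classifier-free-guidance training, a fraction `cond_dropout` of the class
# labels is replaced by an extra "unconditional" index equal to `num_classes`,
# so the same network learns both p(x | y) and p(x).
import torch

def _drop_labels_for_cfg(labels: torch.Tensor,
                         num_classes: int,
                         cond_dropout: float) -> torch.Tensor:
  drop = torch.bernoulli(
    torch.full_like(labels, cond_dropout, dtype=torch.float)).bool()
  labels = labels.clone()
  labels[drop] = num_classes  # `num_classes` acts as the "no label" token
  return labels

# e.g. with cond_dropout=0.1, roughly 10% of labels become `num_classes`.
# The Bernoulli mask below applies the same idea in place on `cond`: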
p = torch.bernoulli( torch.ones_like(cond) * self.config.training.guidance.cond_dropout).to(torch.bool) # Use num_classes index as conditioning mask_token_id cond[p] = self.config.data.num_classes if self.parameterization == 'ar': logprobs = self.forward( input_tokens, sigma=None, cond=cond) loss = - logprobs.gather( -1, output_tokens[:, :, None])[:, :, 0] else: loss = self._forward_pass_diffusion(input_tokens, cond=cond) if isinstance(loss, dict): recon_loss = loss['recon_loss'] diffusion_loss = loss['diffusion_loss'] loss = loss['loss'] nlls = loss * attention_mask count = attention_mask.sum() if (self.config.training.compute_loss_on_pad_tokens and self.training): token_nll = loss.mean() else: batch_nll = nlls.sum() token_nll = batch_nll / count if recon_loss is not None and diffusion_loss is not None: with torch.no_grad(): recon_loss_batch = (recon_loss * attention_mask).sum() / count diffusion_loss_batch = (diffusion_loss * attention_mask).sum() / count return Loss(loss=token_nll, nlls=nlls, token_mask=attention_mask, recon_loss=recon_loss_batch, diffusion_loss=diffusion_loss_batch) return Loss(loss=token_nll, nlls=nlls, token_mask=attention_mask) def _compute_loss(self, batch, prefix): if 'attention_mask' in batch: attention_mask = batch['attention_mask'] else: attention_mask = None cond = None if (self.config.training.guidance is not None or # Training for / using CFG (hasattr(self.config, 'guidance') and self.config.guidance is not None and self.config.guidance.method == 'cfg')): if self.config.data.label_col in batch: cond = batch[self.config.data.label_col] elif f"{self.config.data.label_col}_threshold" in batch: cond = batch[f"{self.config.data.label_col}_threshold"] else: raise RuntimeError( f"Conditioning {self.config.data.label_col}" f" not found in batch.") losses = self._loss(batch['input_ids'], attention_mask, cond=cond) if prefix == 'train': self.train_metrics.update(losses.nlls, losses.token_mask) metrics = self.train_metrics elif prefix == 'val': self.valid_metrics.update(losses.nlls, losses.token_mask) metrics = self.valid_metrics elif prefix == 'test': self.test_metrics.update(losses.nlls, losses.token_mask) metrics = self.test_metrics else: raise ValueError(f"Invalid prefix: {prefix}") self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True) return losses def training_step(self, batch, batch_idx): losses = self._compute_loss(batch, prefix='train') self.log(name='trainer/loss', value=losses.loss.item(), on_step=True, on_epoch=True, sync_dist=True, prog_bar=True) if losses.recon_loss is not None: self.log(name='trainer/recon_loss', value=losses.recon_loss.item(), on_step=True, on_epoch=True, sync_dist=True, prog_bar=False) self.log(name='trainer/diffusion_loss', value=losses.diffusion_loss.item(), on_step=True, on_epoch=True, sync_dist=True, prog_bar=False) self.log(name='lr', value=self.trainer.optimizers[0].param_groups[0]['lr'], on_step=True, on_epoch=False, sync_dist=True, prog_bar=True, logger=False) return losses.loss def validation_step(self, batch, batch_idx): losses = self._compute_loss(batch, prefix='val') self.log(name='trainer/val_loss', value=losses.loss.item(), on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) return losses.loss def load_ema_params(self): if self.ema: self.ema.store(itertools.chain( self.backbone.parameters(), self.noise.parameters())) self.ema.copy_to(itertools.chain( self.backbone.parameters(), self.noise.parameters())) def _restore_non_ema_params(self): if self.ema: self.ema.restore(itertools.chain( 
self.backbone.parameters(), self.noise.parameters())) def on_validation_epoch_start(self): # pdb.set_trace() gc.collect() torch.cuda.empty_cache() self.load_ema_params() assert self.valid_metrics.nll.mean_value == 0 assert self.valid_metrics.nll.weight == 0 def on_validation_epoch_end(self): # pdb.set_trace() # self._restore_non_ema_params() # if (not self.trainer.sanity_checking # and self.config.eval.generate_samples # and self.trainer.global_rank == 0): # self.config.sampling.batch_size = 1 # if self.config.is_vision: # samples = [] # if self.config.training.guidance is not None: # # Generate one image per class (up to 10 images) # guidance = { # 'method': 'cfg', 'condition': 0, 'gamma': 1.0} # omegaconf.OmegaConf.update( # self.config, key='guidance', value=guidance, # force_add=True) # for i in range(max(self.config.data.num_classes, 10)): # self.config.guidance.condition = i # samples.append(self.sample()) # else: # # Generate ten images # for i in range(10): # samples.append(self.sample()) # image_samples = self.tokenizer.batch_decode( # torch.concat(samples, dim=0)) # if hasattr(self.trainer.logger, 'log_image'): # self.trainer.logger.log_image( # key=f"samples@global_step{self.global_step}", # caption=[str(i) for i in range(len(samples))], # images=[s for s in image_samples.float()]) # else: # if self.config.training.guidance is not None: # guidance = { # 'method': 'cfg', 'condition': 0, 'gamma': 1.0} # omegaconf.OmegaConf.update( # self.config, key='guidance', value=guidance, # force_add=True) # for i in range(self.config.data.num_classes): # self.config.guidance.condition = i # samples = self.sample() # decoded_samples = self.tokenizer.batch_decode( # samples) # if hasattr(self.trainer.logger, 'log_table'): # # Log some generated samples # self.trainer.logger.log_table( # key=f"samples@global_step{self.global_step}_class-{i}", # columns=['Generated Samples'], # data=[decoded_samples]) # else: # self.config.sampling.batch_size = 2 # samples = self.sample() # decoded_samples = self.tokenizer.batch_decode( # samples) # if hasattr(self.trainer.logger, 'log_table'): # # Log some generated samples # self.trainer.logger.log_table( # key=f"samples@global_step{self.global_step}", # columns=['Generated Samples'], # data=[[s] for s in decoded_samples]) gc.collect() torch.cuda.empty_cache() self._restore_non_ema_params() def _sample_prior(self, *batch_dims): if self.diffusion == 'absorbing_state': return self.mask_index * torch.ones( *batch_dims, dtype=torch.int64, device=self.device) if self.diffusion == 'uniform': return torch.randint( 0, self.vocab_size, batch_dims, dtype=torch.int64, device=self.device) elif self.diffusion == 'uniform_data_marginals': if self.limiting_distribution.squeeze().ndim == 2: batch_dims = (batch_dims[0],) return torch.distributions.Categorical( self.limiting_distribution.squeeze()).sample( sample_shape=torch.Size(batch_dims)) raise NotImplementedError( f'Diffusion type {self.diffusion} not ' 'implemented.') def sample( self, eps=1e-5, wt_embed: torch.tensor = None, mut_embed: torch.tensor = None, classifier_model = None): # Note: differs from self.config.training.sampling_eps """Generate samples from (ema) model. Supports both AR and diffusion sampling. Supports: - standard decoding, - classifier-free guidance, - classifier-based guidance - CBG / FUDGE, - NOS / PPLM. """ # WARNING: Lightning auto-casting is not working in this method. 
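# Illustrative, standalone sketch (assumed toy values): the reverse process
# starts from the limiting distribution produced by `_sample_prior`. For
# absorbing-state diffusion that is a sequence of all [MASK] tokens; for
# uniform diffusion it is i.i.d. uniform tokens over the vocabulary.
import torch

def _toy_prior(diffusion: str, batch: int, length: int,
               vocab_size: int, mask_index: int) -> torch.Tensor:
  if diffusion == 'absorbing_state':
    return torch.full((batch, length), mask_index, dtype=torch.int64)
  if diffusion == 'uniform':
    return torch.randint(0, vocab_size, (batch, length), dtype=torch.int64)
  raise NotImplementedError(diffusion)

# e.g. _toy_prior('absorbing_state', 2, 8, vocab_size=32, mask_index=31)
# returns a (2, 8) tensor filled with 31.
# The body of `sample()` continues below: EMA weights are loaded first, then
# any guidance-specific models are set up before dispatching to AR or
# diffusion sampling.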
if not self.config.eval.disable_ema: self.load_ema_params() if getattr(self.config, 'guidance', None) is not None: if self.config.guidance.method == 'cfg': cond = (torch.ones(self.config.sampling.batch_size, device=self.device) * self.config.guidance.condition).to(torch.long) else: cond = None if ((self.parameterization == 'ar' and self.config.guidance.method in {'fudge', 'pplm'}) or self.config.guidance.method in {'cbg', 'nos'}): if classifier_model is None: classifier_model = classifier.Classifier.load_from_checkpoint( self.config.guidance.classifier_checkpoint_path, tokenizer=self.tokenizer, config=self.config, logger=False) classifier_model = classifier_model.to(self.device) classifier_model.eval() else: classifier_model = None else: classifier_model, cond = None, None if self.parameterization == 'ar': samples = self._ar_sample( classifier_model=classifier_model, cond=cond) else: # Diffusion sampling samples = self._diffusion_sample( classifier_model=classifier_model, cond=cond, eps=eps, wt_embed = wt_embed.to(self.device), mut_embed = mut_embed.to(self.device)) if not self.config.eval.disable_ema: self._restore_non_ema_params() return samples @torch.no_grad() def _ar_sample( self, classifier_model: typing.Optional[classifier.Classifier] = None, cond: typing.Optional[torch.tensor] = None, ): # precompute token buffer num_pred_tokens = self.config.model.length - 1 x = torch.zeros( (self.config.sampling.batch_size, num_pred_tokens + 1), dtype=torch.long, device=self.device) x[:, 0] = self.tokenizer.bos_token_id # precompute Gumbel sampling noise if (getattr(self.config, 'guidance', None) is not None and self.config.guidance.method == 'fudge'): noise = torch.distributions.Gumbel(0, 1).sample( (self.config.sampling.batch_size, # type: ignore num_pred_tokens, self.config.guidance.topk)).to(self.device) else: noise = torch.distributions.Gumbel(0, 1).sample( (self.config.sampling.batch_size, # type: ignore num_pred_tokens, self.vocab_size)).to(self.device) if self.config.sampling.use_float64: noise = noise.to(torch.float64) pbar = tqdm(range(num_pred_tokens), desc='AR Sampling', leave=False) inference_params = InferenceParams( max_seqlen=num_pred_tokens, max_batch_size=x.shape[0], seqlen_offset=1) # For cfg we do 2 forward passes, one for conditional # model and one unconditional, so we need 2 copies of # inference_params. 
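# Standalone sketch (assumed shapes) of the classifier-free-guidance mixing
# used during AR decoding: the conditional and unconditional next-token
# log-probabilities are combined as gamma * log p(x|y) + (1 - gamma) * log p(x)
# and re-normalized. Because the two passes use different conditioning, each
# needs its own Mamba cache, hence the second InferenceParams object below.
import torch

def _cfg_mix(log_probs_cond: torch.Tensor,
             log_probs_uncond: torch.Tensor,
             gamma: float) -> torch.Tensor:
  mixed = gamma * log_probs_cond + (1 - gamma) * log_probs_uncond
  # Re-normalize; this matters for stability when gamma > 1.
  return mixed.log_softmax(dim=-1)

# The unconditional stream's cache is allocated next: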
uncond_inference_params = InferenceParams( max_seqlen=num_pred_tokens, max_batch_size=x.shape[0], seqlen_offset=1) for i in pbar: if getattr(self.config, 'guidance', None) is None: if self.config.backbone == 'dimamba': log_probs = self.forward( x[:, i:i + 1], None, cond=None, inference_params=inference_params) else: log_probs = self.forward(x[:, :i + 1], None, cond=None) if self.config.sampling.use_float64: log_probs = log_probs.to(torch.float64) next_log_probs = log_probs[:, -1] y = (next_log_probs + noise[:, i]).argmax(-1) else: if self.config.guidance.method == 'cfg': if self.config.backbone == 'dimamba': next_log_probs = self._ar_cfg_denoise( cond=cond, gamma=self.config.guidance.gamma, x=x[:, i:i + 1], i=i, inference_params=(inference_params, uncond_inference_params)) else: next_log_probs = self._ar_cfg_denoise( cond=cond, gamma=self.config.guidance.gamma, x=x, i=i) y = (next_log_probs + noise[:, i]).argmax(-1) elif self.config.guidance.method == 'fudge': if self.config.backbone == 'dimamba': next_log_probs, top_indices = self._ar_fudge_denoise( classifier_model=classifier_model, guidance_cond=self.config.guidance.condition, topk=self.config.guidance.topk, gamma=self.config.guidance.gamma, x=x[:, i:i + 1], i=i, inference_params=inference_params) else: next_log_probs, top_indices = self._ar_fudge_denoise( classifier_model=classifier_model, guidance_cond=self.config.guidance.condition, topk=self.config.guidance.topk, gamma=self.config.guidance.gamma, x=x, i=i) y = torch.gather( top_indices, 1, (next_log_probs + noise[:, i]).argmax(-1).unsqueeze(1) ).squeeze(1) elif self.config.guidance.method == 'pplm': raise NotImplementedError else: raise NotImplementedError( f"Guidance method {self.config.guidance.method} not implemented.") pbar.set_postfix( prob_check=(next_log_probs.exp().sum() / x.shape[0]).item(), nan_check=bool(next_log_probs.isnan().sum() > 0)) x[:, i + 1] = y return x def _ar_cfg_denoise( self, cond: torch.tensor, gamma: float, x: torch.tensor, i: int, **kwargs ) -> torch.tensor: if self.config.guidance.gamma == 0.0: # Sample unconditionally mask_cond = (torch.ones_like(cond) * self.config.data.num_classes) if self.config.backbone == 'dimamba': inference_params = kwargs.pop('inference_params') log_probs = self.forward( x[:, :i + 1],None, cond=mask_cond, inference_params=inference_params[1]) else: log_probs = self.forward( x[:, :i + 1],None, cond=mask_cond, **kwargs) elif gamma == 1.0: # Sample conditionally if self.config.backbone == 'dimamba': inference_params = kwargs.pop('inference_params') log_probs = self.forward( x[:, :i + 1], None, cond=cond, inference_params=inference_params[0]) else: log_probs = self.forward( x[:, :i + 1], None, cond=cond, **kwargs) else: # Sample from tempered distribution mask_cond = (torch.ones_like(cond) * self.config.data.num_classes) if self.config.backbone == 'dimamba': inference_params = kwargs.pop('inference_params') log_probs_cond = self.forward( x[:, :i + 1], None, cond=cond, inference_params=inference_params[0]) log_probs_uncond = self.forward( x[:, :i + 1],None, cond=mask_cond, inference_params=inference_params[1]) else: log_probs_cond = self.forward( x[:, :i + 1], None, cond=cond, **kwargs) log_probs_uncond = self.forward( x[:, :i + 1],None, cond=mask_cond, **kwargs) log_probs = gamma * log_probs_cond + (1 - gamma) * log_probs_uncond # Gamma > 1.0 causes instability for Mamba, re-normalizing log_probs = log_probs.log_softmax(dim=-1) return log_probs[:, -1] def _ar_fudge_denoise( self, classifier_model: classifier.Classifier, 
guidance_cond: int, topk: int, gamma: float, x: torch.tensor, i: int, **kwargs ) -> typing.Tuple[torch.tensor, torch.LongTensor]: log_probs = self.forward( x[:, :i + 1], None, cond=None, **kwargs) next_log_probs = log_probs[:, -1] top_logits, top_indices = next_log_probs.topk(topk, dim=-1) t_candidates = torch.cat( [x[:, :i + 1].unsqueeze(1).expand(-1, topk, -1), top_indices.unsqueeze(2)], dim=2).view(-1, i + 2) # (B * K), L t = torch.zeros(t_candidates.shape[0], device=self.device) sigma, dsigma = self.noise(t) time_conditioning = sigma[:, None] classifier_log_prob = classifier_model.get_log_probs( t_candidates, time_conditioning) classifier_log_prob = classifier_log_prob[:, i + 1, :].view( x.shape[0], topk, -1)[..., guidance_cond] # (batch, topk) next_log_probs = (top_logits + gamma * classifier_log_prob).log_softmax(dim=-1) return next_log_probs, top_indices def _ar_pplm_denoise( self, classifier_model: classifier.Classifier, guidance_cond: int, num_ppl_steps: int, pplm_step_size: float, pplm_stability_coef: float, x: torch.tensor, i: int, ): raise NotImplementedError @torch.no_grad() def _diffusion_sample( self, classifier_model: typing.Optional[classifier.Classifier] = None, cond: typing.Optional[torch.tensor] = None, eps: float = 1e-5, # Note: differs from self.config.training.sampling_eps wt_embed: torch.tensor = None, mut_embed: torch.tensor = None, ): xt = self._sample_prior( self.config.sampling.batch_size, self.config.model.length ).to(self.device) timesteps = torch.linspace( 1, eps, self.config.sampling.steps + 1, device=self.device) dt = (1 - eps) / self.config.sampling.steps pbar = tqdm(range(self.config.sampling.steps), desc='Sampling', leave=False) NFEs = 0 cache = None for i in pbar: t = timesteps[i] if self.T > 0: # t in {1/T,..., 1}, to match training t = (t * self.T).to(torch.int) t = t / self.T t += (1 / self.T) t = t * torch.ones(xt.shape[0], 1, device=self.device) if cache is None: NFEs += 1 sigma_t, _ = self.noise(t) sigma_s, _ = self.noise(t - dt) if sigma_t.ndim > 1: sigma_t = sigma_t.squeeze(-1) if sigma_s.ndim > 1: sigma_s = sigma_s.squeeze(-1) assert sigma_t.ndim == 1, sigma_t.shape assert sigma_s.ndim == 1, sigma_s.shape move_chance_t = 1 - torch.exp(-sigma_t) move_chance_s = 1 - torch.exp(-sigma_s) move_chance_t = move_chance_t[:, None, None] move_chance_s = move_chance_s[:, None, None] assert move_chance_t.ndim == 3, move_chance_t.shape if getattr(self.config, 'guidance', None) is None: xs, q_xs, cache = self._ddpm_denoise( xt=xt, time_conditioning=sigma_t, move_chance_t=move_chance_t, move_chance_s=move_chance_s, cache=cache) else: if self.config.guidance.method == 'cfg': xs, q_xs, cache = self._cfg_denoise( cond=cond, gamma=self.config.guidance.gamma, xt=xt, time_conditioning=sigma_t, move_chance_t=move_chance_t, move_chance_s=move_chance_s, cache=cache) elif self.config.guidance.method == 'cbg': xs, q_xs, cache = self._cbg_denoise( classifier_model=classifier_model, conditioning_class=self.config.guidance.condition, gamma=self.config.guidance.gamma, use_approx=self.config.guidance.use_approx, xt=xt, time_conditioning=sigma_t, move_chance_t=move_chance_t, move_chance_s=move_chance_s, wt_embed=wt_embed, mut_embed=mut_embed, cache=cache) elif self.config.guidance.method == 'nos': xs, q_xs, cache = self._nos_denoise( classifier_model=classifier_model, conditioning_class=self.config.guidance.condition, num_nos_steps=self.config.guidance.num_nos_steps, nos_step_size=self.config.guidance.nos_step_size, nos_stability_coef=self.config.guidance.nos_stability_coef, 
xt=xt, time_conditioning=sigma_t, move_chance_t=move_chance_t, move_chance_s=move_chance_s) else: raise NotImplementedError( f"Guidance method {self.config.guidance.method} not implemented.") pbar.set_postfix( NFEs=NFEs, prob_check=(q_xs.sum() / xt.numel()).item(), nan_check=bool(q_xs.isnan().sum() > 0)) if (not self.config.sampling.use_cache or not torch.allclose(xs, xt)): # Disable caching cache = None xt = xs return xt def _ddpm_denoise( self, xt: torch.tensor, time_conditioning: torch.tensor, move_chance_t: torch.tensor, move_chance_s: torch.tensor, cache: typing.Optional[typing.Dict[str, torch.Tensor]] = None, ) -> typing.Tuple[torch.tensor, torch.tensor, typing.Dict[str, torch.tensor]]: # Compute x_theta if cache is not None: log_x_theta = cache['log_x_theta'] else: log_x_theta = self.forward(xt, time_conditioning, cond=None) if self.config.sampling.use_float64: log_x_theta = log_x_theta.to(torch.float64) x_theta = log_x_theta.exp() # Compute posterior if self.diffusion == 'absorbing_state': q_xs = x_theta * (move_chance_t - move_chance_s) q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0] q_xs /= move_chance_t elif self.diffusion == 'uniform': q_xs = self._compute_posterior( x=x_theta, xt=xt, alpha_s=1 - move_chance_s, alpha_t=1 - move_chance_t) else: raise NotImplementedError( f"Diffusion type {self.diffusion} not implemented.") # Sample from posterior xs = _sample_categorical(q_xs) if self.diffusion == 'absorbing_state': copy_flag = (xt != self.mask_index).to(torch.bool) q_xs[copy_flag] = 0.0 q_xs[copy_flag, xt[copy_flag]] = 1.0 xs = torch.where(copy_flag, xt, xs) return xs, q_xs, {'log_x_theta': log_x_theta} def _cfg_denoise( self, cond: torch.tensor, gamma: float, xt: torch.tensor, time_conditioning: torch.tensor, move_chance_t: torch.tensor, move_chance_s: torch.tensor, cache: typing.Optional[typing.Dict[str, torch.Tensor]] = None, ) -> typing.Tuple[torch.tensor, torch.tensor, typing.Dict[str, torch.tensor]]: # Compute log_x_theta if cache is not None: log_x_theta_uncond = cache['log_x_theta_uncond'] log_x_theta_cond = cache['log_x_theta_cond'] else: if gamma == 0.0: # Sample unconditionally mask_cond = (torch.ones_like(cond) * self.config.data.num_classes) log_x_theta_uncond = self.forward( xt, time_conditioning, cond=mask_cond) log_x_theta_cond = None elif gamma == 1.0: # Sample conditionally log_x_theta_cond = self.forward(xt, time_conditioning, cond=cond) log_x_theta_uncond = None else: # Sample from tempered distribution log_x_theta_cond = self.forward(xt, time_conditioning, cond=cond) mask_cond = (torch.ones_like(cond) * self.config.data.num_classes) log_x_theta_uncond = self.forward(xt, time_conditioning, cond=mask_cond) # Compute (weighted) posterior if (log_x_theta_cond is None # gamma == 0 or log_x_theta_uncond is None): # or gamma == 1 log_x_theta = log_x_theta_uncond if log_x_theta_uncond is not None else log_x_theta_cond x_theta = log_x_theta.exp() if self.diffusion == 'absorbing_state': q_xs = x_theta * (move_chance_t - move_chance_s) q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0] q_xs /= move_chance_t elif self.diffusion == 'uniform': q_xs = self._compute_posterior( x=x_theta, xt=xt, alpha_s=1 - move_chance_s, alpha_t=1 - move_chance_t) else: raise NotImplementedError( f"Diffusion type {self.diffusion} not implemented.") else: # gamma != 0 and gamma != 1 if self.diffusion == 'absorbing_state': log_x_theta = (gamma * log_x_theta_cond + (1 - gamma) * log_x_theta_uncond) x_theta = log_x_theta.softmax(dim=-1) q_xs = x_theta * (move_chance_t - move_chance_s) 
q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0] q_xs /= move_chance_t elif (self.diffusion == 'uniform' or self.diffusion == 'uniform_data_marginals'): log_q_xs_uncond = self._compute_posterior( x=log_x_theta_uncond.exp(), xt=xt, alpha_s=1 - move_chance_s, alpha_t=1 - move_chance_t).log() log_q_xs_cond = self._compute_posterior( x=log_x_theta_cond.exp(), xt=xt, alpha_s=1 - move_chance_s, alpha_t=1 - move_chance_t).log() log_q_xs = (gamma * log_q_xs_cond + (1 - gamma) * log_q_xs_uncond) q_xs = log_q_xs.softmax(dim=-1) else: raise NotImplementedError( f"Diffusion type {self.diffusion} not implemented.") # Sample from posterior xs = _sample_categorical(q_xs) if self.diffusion == 'absorbing_state': copy_flag = (xt != self.mask_index).to(torch.bool) q_xs[copy_flag] = 0.0 q_xs[copy_flag, xt[copy_flag]] = 1.0 xs = torch.where(copy_flag, xt, xs) return xs, q_xs, {'log_x_theta_uncond': log_x_theta_uncond, 'log_x_theta_cond': log_x_theta_cond} def _cbg_denoise( self, conditioning_class: int, gamma: float, classifier_model: classifier.Classifier, xt: torch.tensor, time_conditioning: torch.tensor, move_chance_t: torch.tensor, move_chance_s: torch.tensor, wt_embed: torch.tensor = None, mut_embed: torch.tensor = None, use_approx: bool = False, # whether to use first-order approximation cache: typing.Optional[typing.Dict[str, torch.Tensor]] = None, ) -> typing.Tuple[torch.tensor, torch.tensor, typing.Dict[str, torch.tensor]]: if cache is not None: log_x_theta = cache['log_x_theta'] classifier_log_prob = cache['classifier_log_prob'] else: # Diffusion model log_x_theta = self.forward(xt, time_conditioning, cond=None) # Classifier model if use_approx: xt_one_hot = torch.nn.functional.one_hot( xt, self.vocab_size).to(torch.float) with torch.enable_grad(): xt_one_hot.requires_grad_(True) classifier_log_prob_xt = classifier_model.get_log_probs( xt_one_hot, time_conditioning) classifier_log_prob_xt[..., conditioning_class].sum().backward() grad_log_prob_xt = xt_one_hot.grad classifier_log_prob_ratio = ( grad_log_prob_xt - (xt_one_hot * grad_log_prob_xt).sum(dim=-1, keepdim=True) ).detach().requires_grad_(False) classifier_log_prob = ( classifier_log_prob_ratio + classifier_log_prob_xt[..., conditioning_class][..., None, None] ).detach().requires_grad_(False) else: # Copied from https://github.com/hnisonoff/discrete_guidance/blob/main/src/fm_utils.py#L441 bsz, seq_len = xt.shape # Create bsz*seq_len*N copies of input sequences # Shape: (bsz, 1, seq_len) -> (bsz, seq_len*N, seq_len) # (where N = vocab_size). 
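# Toy sketch (tiny shapes, assumed helper name) of the transition enumeration
# used for exact classifier guidance: for every position and every vocabulary
# symbol, build the sequence obtained by changing just that one token, so the
# classifier can be scored on all bsz * seq_len * N single-token neighbours
# of x_t.
import torch

def _enumerate_single_token_edits(xt: torch.Tensor,
                                  vocab_size: int) -> torch.Tensor:
  bsz, seq_len = xt.shape
  neighbours = xt.unsqueeze(1).repeat(1, seq_len * vocab_size, 1)
  neighbours = neighbours.view(-1, seq_len)        # (bsz*seq_len*N, seq_len)
  jump_idx = torch.arange(seq_len * vocab_size, device=xt.device)
  jump_idx = jump_idx.repeat(bsz, 1).flatten()
  neighbours[torch.arange(jump_idx.size(0), device=xt.device),
             jump_idx // vocab_size] = jump_idx % vocab_size
  return neighbours

# e.g. for xt of shape (1, 2) and vocab_size=3 this returns 6 sequences, each
# differing from xt in exactly one position.
# The same expansion is applied to the real batch below: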
xt_expand = xt.unsqueeze(1).repeat(1, seq_len * self.vocab_size, 1) # Flatten batch and transition dimensions # Shape: (bsz, seq_len*N, seq_len) -> (bsz*seq_len*N, seq_len) xt_expand = xt_expand.view(-1, seq_len) # Create indices for all possible transitions # Shape: (seq_len*N,) -> (bsz, seq_len*N) -> (bsz*seq_len*N,) jump_idx = torch.arange(seq_len * self.vocab_size).to(xt.device) jump_idx = jump_idx.repeat(bsz, 1).flatten() # Create tensor for states after one transition xt_jumps = xt_expand.clone() # Calculate which dimension changes for each transition # Shape: (bsz*seq_len*N,) jump_dims = jump_idx // self.vocab_size # Calculate new value for changed dimension # Shape: (bsz*seq_len*N,) jump_states = jump_idx % self.vocab_size # Apply transitions by assigning new values at transition dimensions # Shape: (bsz*seq_len*N, seq_len) -> N * (bsz*seq_len, seq_len) xt_jumps[ torch.arange(jump_idx.size(0), device=xt.device), jump_dims, # Index the transitioned dimension ] = jump_states # Assign the new state # classifier_log_prob = (classifier_model.get_log_probs( # xt_jumps, time_conditioning.repeat(seq_len * self.vocab_size) # ))[..., conditioning_class].reshape(bsz, seq_len, self.vocab_size) # xt_jumps shape (bsz*seq_len*N, seq_len) a batch of sequences classifier_log_prob = classifier_model.get_log_probs( xt_jumps, wt_embed.repeat(xt_jumps.shape[0], 1, 1), mut_embed.repeat(xt_jumps.shape[0], 1, 1) )[:, 1].reshape(bsz, seq_len, self.vocab_size) # pdb.set_trace() # Compute unguided posterior if self.diffusion == 'absorbing_state': diffusion_log_probs = log_x_theta + torch.log( 1. - (move_chance_s / move_chance_t)) diffusion_log_probs[..., self.mask_index] = torch.log( move_chance_s / move_chance_t)[:, :, 0] diffusion_log_probs.detach() elif self.diffusion == 'uniform': diffusion_log_probs = self._compute_posterior( x=log_x_theta.exp(), xt=xt, alpha_s=1 - move_chance_s, alpha_t=1 - move_chance_t).log() else: raise NotImplementedError( f"Diffusion type {self.diffusion} not implemented.") # Apply guidance with torch.no_grad(): if self.diffusion == 'absorbing_state': guided_log_probs = (gamma * classifier_log_prob) + diffusion_log_probs copy_flag = (xt != self.mask_index) guided_log_probs[copy_flag] = self.neg_infinity guided_log_probs[copy_flag, xt[copy_flag]] = 0.0 elif self.diffusion == 'uniform': # pdb.set_trace() guided_log_probs = (gamma * classifier_log_prob) + diffusion_log_probs else: raise NotImplementedError( f"Diffusion type {self.diffusion} not implemented.") guided_probs = guided_log_probs.softmax(dim=-1) # Sample from guided posterior xs = _sample_categorical(guided_probs) if self.diffusion == 'absorbing_state': xs = torch.where(copy_flag.to(bool), xt, xs) return xs, guided_probs, {'log_x_theta': log_x_theta, 'classifier_log_prob': classifier_log_prob} def _nos_denoise( self, classifier_model: classifier.Classifier, num_nos_steps: int, nos_step_size: float, nos_stability_coef: float, conditioning_class: int, xt: torch.Tensor, time_conditioning: torch.tensor, move_chance_t: torch.tensor, move_chance_s: torch.tensor, ) -> typing.Tuple[torch.tensor, torch.tensor, None]: # Compute original diffusion_log_probs and hidden states copy_flag = (xt != self.mask_index).to(torch.bool) with torch.no_grad(): time_conditioning = self._process_sigma(time_conditioning) with torch.cuda.amp.autocast(dtype=torch.float32): logits, hidden_states = self.backbone( xt, time_conditioning, cond=None, return_hidden_states=True) if self.parameterization == 'subs': log_x_theta = self._subs_parameterization( 
logits=logits, xt=xt) elif self.parameterization == 'd3pm': # returns log_probs if self.subs_masking: # Can use "zero masking prob" logits[:, :, self.mask_index] += self.neg_infinity log_x_theta = logits.log_softmax(dim=-1) else: raise NotImplementedError( f"Parameterization {self.parameterization} not implemented for NOS guidance.") if self.diffusion == 'absorbing_state': diffusion_log_probs = log_x_theta + torch.log( 1. - (move_chance_s / move_chance_t)) diffusion_log_probs[..., self.mask_index] = torch.log( move_chance_s / move_chance_t)[:, :, 0] diffusion_log_probs[copy_flag] = self.neg_infinity diffusion_log_probs[copy_flag, xt[copy_flag]] = 0.0 elif self.diffusion == 'uniform': diffusion_log_probs = self._compute_posterior( x=log_x_theta.exp(), xt=xt, alpha_s=1 - move_chance_s, alpha_t=1 - move_chance_t).log() # Perform NOS steps kl_loss = torch.nn.KLDivLoss(reduction='batchmean', log_target=True) delta = torch.nn.Parameter( torch.zeros_like(hidden_states[-1]), requires_grad=True) optimizer = torch.optim.Adagrad([delta], lr=nos_step_size) with torch.enable_grad(): for _ in tqdm(range(num_nos_steps), desc='NOS', leave=False): h_current = hidden_states[-1] + delta target_loss = classifier_model.get_log_probs( xt, time_conditioning, x_emb=h_current)[..., conditioning_class].sum() with torch.cuda.amp.autocast(dtype=torch.float32): new_logits = self.forward(xt, time_conditioning, cond=None, x_emb=h_current) if self.diffusion == 'absorbing_state': adjusted_log_probs = new_logits + torch.log( 1. - (move_chance_s / move_chance_t)) adjusted_log_probs[ ..., self.mask_index] = torch.log( move_chance_s / move_chance_t)[:, :, 0] adjusted_log_probs[ copy_flag] = self.neg_infinity adjusted_log_probs[copy_flag, xt[copy_flag]] = 0.0 elif self.diffusion == 'uniform': adjusted_log_probs = self._compute_posterior( x=new_logits.exp(), xt=xt, alpha_s=1 - move_chance_s, alpha_t=1 - move_chance_t).log() kl = kl_loss(adjusted_log_probs, diffusion_log_probs) loss = -target_loss + nos_stability_coef * kl optimizer.zero_grad() loss.backward() optimizer.step() with torch.cuda.amp.autocast(dtype=torch.float32): guided_logits = self.forward( xt, time_conditioning, cond=None, x_emb=hidden_states[-1] + delta.data) if self.diffusion == 'absorbing_state': diffusion_log_probs = guided_logits + torch.log( 1. - (move_chance_s / move_chance_t)) diffusion_log_probs[ ..., self.mask_index] = torch.log( move_chance_s / move_chance_t)[:, :, 0] diffusion_log_probs.detach() guided_probs = diffusion_log_probs.exp() elif self.diffusion == 'uniform': guided_probs = self._compute_posterior( x=guided_logits.exp(), xt=xt, alpha_s=1 - move_chance_s, alpha_t=1 - move_chance_t).detach() else: raise NotImplementedError( f"Diffusion type {self.diffusion} not implemented.") xs = _sample_categorical(guided_probs) if self.diffusion == 'absorbing_state': xs = torch.where(copy_flag, xt, xs) return xs, guided_probs, None
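

# Usage sketch (not a verified entry point): the config path, the
# `dataloader.get_tokenizer` helper, and the embedding shapes below are
# assumptions for illustration only. It shows the typical flow: load a Hydra/
# OmegaConf config, build a tokenizer, instantiate `Diffusion`, and sample.
if __name__ == '__main__':
  config = omegaconf.OmegaConf.load('configs/config.yaml')  # assumed path
  tokenizer = dataloader.get_tokenizer(config)  # assumed helper in dataloader.py
  model = Diffusion(config, tokenizer=tokenizer)
  model = model.to('cuda' if torch.cuda.is_available() else 'cpu')
  model.eval()
  # This fork's diffusion `sample()` forwards wild-type / mutant embeddings to
  # the guidance classifier; the zero tensors below are placeholders standing
  # in for real, precomputed embeddings of the appropriate size.
  embed_dim = 1280  # placeholder; use the classifier's embedding size
  wt_embed = torch.zeros(1, config.model.length, embed_dim)
  mut_embed = torch.zeros(1, config.model.length, embed_dim)
  samples = model.sample(wt_embed=wt_embed, mut_embed=mut_embed)
  print(tokenizer.batch_decode(samples))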