import itertools
import math
from dataclasses import dataclass

import torch
import torch.nn.functional as F
import pytorch_lightning as L
import torchmetrics

from models import dit, ema
import noise_schedule  # Assuming this is part of the MDLM repository.

LOG2 = math.log(2)


@dataclass
class Loss:
    loss: torch.FloatTensor
    nlls: torch.FloatTensor
    token_mask: torch.FloatTensor


class NLL(torchmetrics.MeanMetric):
    pass


class BPD(NLL):
    def compute(self) -> torch.Tensor:
        """Computes the bits per dimension.

        Returns:
            bpd
        """
        return self.mean_value / self.weight / LOG2


class Perplexity(NLL):
    def compute(self) -> torch.Tensor:
        """Computes the perplexity.

        Returns:
            Perplexity
        """
        return torch.exp(self.mean_value / self.weight)


class Diffusion(L.LightningModule):
    def __init__(self, config, latent_dim):
        super().__init__()
        self.config = config
        self.latent_dim = latent_dim
        self.backbone = dit.DIT(config, vocab_size=self.latent_dim)
        self.T = self.config.T
        self.subs_masking = self.config.subs_masking
        self.softplus = torch.nn.Softplus()

        # Metrics accumulate in float64 to avoid precision loss over long
        # validation/test runs.
        metrics = torchmetrics.MetricCollection({
            'nll': NLL(),
            'bpd': BPD(),
            'ppl': Perplexity(),
        })
        metrics.set_dtype(torch.float64)
        self.train_metrics = metrics.clone(prefix='train/')
        self.valid_metrics = metrics.clone(prefix='val/')
        self.test_metrics = metrics.clone(prefix='test/')

        self.noise = noise_schedule.get_noise(self.config, dtype=self.dtype)
        self.lr = self.config.optim["lr"]
        self.sampling_eps = self.config.training.get("sampling_eps", 1e-5)
        self.time_conditioning = self.config.get("time_conditioning", True)
        # Large negative constant used as a finite stand-in for -inf in logits.
        self.neg_infinity = -1000000.0

    def forward(self, latents, sigma):
        """Forward diffusion process: adds Gaussian noise to the latents."""
        # Reshape sigma from (batch,) to (batch, 1, ..., 1) so it broadcasts
        # against latents of any rank.
        sigma = sigma.view(-1, *([1] * (latents.dim() - 1)))
        noise = sigma * torch.randn_like(latents)
        noisy_latents = latents + noise
        return noisy_latents

    def reverse_diffusion(self, noisy_latents, sigma):
        """Reverse diffusion process: denoises the latents via the backbone."""
        denoised_latents = self.backbone(noisy_latents, sigma)
        return denoised_latents

    def training_step(self, batch, batch_idx):
        # Sample one noise level per example, corrupt the batch, and train
        # the backbone to reconstruct the clean latents.
        sigma = torch.rand(batch.size(0), device=self.device)
        noisy_latents = self(batch, sigma)
        denoised_latents = self.reverse_diffusion(noisy_latents, sigma)
        loss = F.mse_loss(denoised_latents, batch)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
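
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): wiring Diffusion
# into a Lightning Trainer. The config values below are illustrative
# assumptions, not the MDLM defaults, and running this still requires the
# repository's `models.dit` and `noise_schedule` modules to be importable.
# An OmegaConf DictConfig is used because it supports both the attribute
# access (config.T) and the .get(...) calls made in __init__.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from omegaconf import OmegaConf
    from torch.utils.data import DataLoader

    cfg = OmegaConf.create({
        "T": 1000,                    # number of diffusion steps (assumed)
        "subs_masking": False,        # assumed default
        "optim": {"lr": 3e-4},
        "training": {"sampling_eps": 1e-5},
        "time_conditioning": True,
    })
    model = Diffusion(cfg, latent_dim=64)

    # A plain tensor works as a map-style dataset: each row is one latent
    # vector, and the default collate stacks rows into a (batch, dim) tensor.
    train_latents = torch.randn(256, 64)
    loader = DataLoader(train_latents, batch_size=32, shuffle=True)

    trainer = L.Trainer(max_steps=10, logger=False,
                        enable_checkpointing=False)
    trainer.fit(model, loader)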