# MeMDLM/models/diffusion.py

import math

import torch
import torch.nn.functional as F
import pytorch_lightning as L
import torchmetrics
from dataclasses import dataclass

from models import dit, ema
import noise_schedule  # Assumed to be part of the MDLM repository

LOG2 = math.log(2)


@dataclass
class Loss:
    loss: torch.FloatTensor
    nlls: torch.FloatTensor
    token_mask: torch.FloatTensor


class NLL(torchmetrics.MeanMetric):
    pass


class BPD(NLL):
    def compute(self) -> torch.Tensor:
        """Computes bits per dimension.

        Returns:
            bpd
        """
        return self.mean_value / self.weight / LOG2


class Perplexity(NLL):
    def compute(self) -> torch.Tensor:
        """Computes perplexity.

        Returns:
            Perplexity
        """
        return torch.exp(self.mean_value / self.weight)
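

# Illustrative sketch (not part of the original file): all three metrics share
# the MeanMetric accumulator, so for a tensor of per-token NLLs `nlls` and a
# matching weight tensor `mask`:
#
#   nll = NLL(); nll.update(nlls, mask)   # mean negative log-likelihood
#   bpd = nll.compute() / LOG2            # bits per dimension = NLL / ln 2
#   ppl = torch.exp(nll.compute())        # perplexity = exp(NLL)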


class Diffusion(L.LightningModule):
    def __init__(self, config, latent_dim):
        super().__init__()
        self.config = config
        self.latent_dim = latent_dim
        self.backbone = dit.DIT(config, vocab_size=self.latent_dim)
        self.T = self.config.T
        self.subs_masking = self.config.subs_masking
        self.softplus = torch.nn.Softplus()

        # Metrics are accumulated in float64 to limit accumulation error and
        # cloned per split so train/val/test state stays independent.
        metrics = torchmetrics.MetricCollection({
            'nll': NLL(),
            'bpd': BPD(),
            'ppl': Perplexity(),
        })
        metrics.set_dtype(torch.float64)
        self.train_metrics = metrics.clone(prefix='train/')
        self.valid_metrics = metrics.clone(prefix='val/')
        self.test_metrics = metrics.clone(prefix='test/')

        self.noise = noise_schedule.get_noise(self.config, dtype=self.dtype)
        self.lr = self.config.optim["lr"]
        self.sampling_eps = self.config.training.get("sampling_eps", 1e-5)
        self.time_conditioning = self.config.get("time_conditioning", True)
        self.neg_infinity = -1000000.0

    def forward(self, latents, sigma):
        """Forward diffusion: corrupt the latents with Gaussian noise scaled by sigma."""
        # Reshape the per-sample sigma of shape (batch,) so it broadcasts over
        # the remaining latent dimensions, e.g. (batch, seq_len, latent_dim);
        # without this, multiplying against randn_like(latents) fails to broadcast.
        sigma = sigma.view(-1, *([1] * (latents.dim() - 1)))
        noise = sigma * torch.randn_like(latents)
        noisy_latents = latents + noise
        return noisy_latents
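
    # In equation form (a sketch of the assumed variance-exploding corruption):
    #   x_sigma = x + sigma * eps,  eps ~ N(0, I)
    # reverse_diffusion below is trained to map x_sigma back to x.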

    def reverse_diffusion(self, noisy_latents, sigma):
        """Reverse diffusion: denoise the latents with the DiT backbone."""
        denoised_latents = self.backbone(noisy_latents, sigma)
        return denoised_latents

    def training_step(self, batch, batch_idx):
        # Draw one noise level per sample, corrupt the latents, denoise them,
        # and regress the prediction onto the clean latents.
        sigma = torch.rand(batch.size(0), device=self.device)
        noisy_latents = self.forward(batch, sigma)
        denoised_latents = self.reverse_diffusion(noisy_latents, sigma)
        loss = F.mse_loss(denoised_latents, batch)
        self.log("train_loss", loss)
        return loss
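
    # Hedged sketch (not in the original file): a validation step mirroring
    # training_step. The NLL/BPD/PPL collections built in __init__ expect
    # per-token negative log-likelihoods, which this MSE objective does not
    # produce, so only the reconstruction loss is logged here.
    def validation_step(self, batch, batch_idx):
        sigma = torch.rand(batch.size(0), device=self.device)
        noisy_latents = self.forward(batch, sigma)
        denoised_latents = self.reverse_diffusion(noisy_latents, sigma)
        loss = F.mse_loss(denoised_latents, batch)
        self.log("val_loss", loss, sync_dist=True)
        return loss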

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
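

# Usage sketch (an assumption, not part of the original file): requires a
# config exposing the fields read in __init__ (T, subs_masking, optim["lr"],
# training, time_conditioning) and a dataloader yielding latent tensors of
# shape (batch, seq_len, latent_dim). The latent_dim value below is
# hypothetical.
#
#   import pytorch_lightning as pl
#   model = Diffusion(config, latent_dim=640)
#   trainer = pl.Trainer(max_steps=10_000)
#   trainer.fit(model, train_dataloaders=train_loader)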