import math
from dataclasses import dataclass

import torch
import torch.nn.functional as F
import pytorch_lightning as L
import torchmetrics

from models import dit, ema
import noise_schedule  # Assuming this is part of the MDLM repository

LOG2 = math.log(2)  # nats -> bits conversion factor, used by BPD

@dataclass
class Loss:
    """Bundles the scalar training loss with per-token NLLs and the token mask used to weight them."""
    loss: torch.FloatTensor
    nlls: torch.FloatTensor
    token_mask: torch.FloatTensor

class NLL(torchmetrics.MeanMetric):
    """Mean negative log-likelihood in nats per token."""

class BPD(NLL):
    def compute(self) -> torch.Tensor:
        """Computes bits per dimension: the mean NLL in nats, converted to bits.

        Returns:
          bpd
        """
        return self.mean_value / self.weight / LOG2

class Perplexity(NLL):
    def compute(self) -> torch.Tensor:
        """Computes perplexity: the exponential of the mean NLL.

        Returns:
          Perplexity
        """
        return torch.exp(self.mean_value / self.weight)
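
# Hypothetical usage note (not from the source): because these metrics subclass
# MeanMetric, they aggregate weighted means, so a training step following the
# MDLM convention would update them with per-token NLLs and the token mask:
#   self.train_metrics.update(losses.nlls, losses.token_mask)
#   self.log_dict(self.train_metrics, on_step=True, on_epoch=False)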

class Diffusion(L.LightningModule):
    def __init__(self, config, latent_dim):
        super().__init__()
        self.config = config
        self.latent_dim = latent_dim

        self.backbone = dit.DIT(config, vocab_size=self.latent_dim)
        self.T = self.config.T
        self.subs_masking = self.config.subs_masking

        self.softplus = torch.nn.Softplus()
        metrics = torchmetrics.MetricCollection({
            'nll': NLL(),
            'bpd': BPD(),
            'ppl': Perplexity(),
        })
        metrics.set_dtype(torch.float64)
        self.train_metrics = metrics.clone(prefix='train/')
        self.valid_metrics = metrics.clone(prefix='val/')
        self.test_metrics = metrics.clone(prefix='test/')

        self.noise = noise_schedule.get_noise(self.config, dtype=self.dtype)
        self.lr = self.config.optim["lr"]
        self.sampling_eps = self.config.training.get("sampling_eps", 1e-5)
        self.time_conditioning = self.config.get("time_conditioning", True)
        self.neg_infinity = -1000000.0  # large negative value used in place of -inf when masking logits

    def forward(self, latents, sigma):
        """Forward diffusion process: corrupts the latents with Gaussian noise scaled by sigma."""
        # sigma arrives as (batch,); reshape it so it broadcasts over the latent dims.
        sigma = sigma.view(-1, *([1] * (latents.dim() - 1)))
        noisy_latents = latents + sigma * torch.randn_like(latents)
        return noisy_latents

    def reverse_diffusion(self, noisy_latents, sigma):
        """Reverse diffusion process, denoises the latents."""
        denoised_latents = self.backbone(noisy_latents, sigma)
        return denoised_latents
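
    @torch.no_grad()
    def naive_sample(self, shape, steps=50):
        # Hypothetical sampling sketch (not part of the original module): walk a
        # linearly decreasing sigma grid, predict the clean latents at each level,
        # then re-noise them to the next, smaller sigma. A faithful sampler would
        # instead follow the schedule returned by noise_schedule.get_noise.
        sigmas = torch.linspace(1.0, self.sampling_eps, steps, device=self.device)
        latents = sigmas[0] * torch.randn(shape, device=self.device)
        for cur, nxt in zip(sigmas[:-1], sigmas[1:]):
            x0_pred = self.reverse_diffusion(latents, cur.expand(shape[0]))
            latents = x0_pred + nxt * torch.randn_like(latents)
        return latents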

    def training_step(self, batch, batch_idx):
        # Sample noise levels in [sampling_eps, 1) so the model never sees a
        # completely noise-free input.
        sigma = self.sampling_eps + (1 - self.sampling_eps) * torch.rand(
            batch.size(0), device=self.device)
        noisy_latents = self(batch, sigma)
        denoised_latents = self.reverse_diffusion(noisy_latents, sigma)
        loss = F.mse_loss(denoised_latents, batch)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
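
# Minimal usage sketch, assuming the MDLM repository layout (dit.DIT and
# noise_schedule.get_noise). The config below only stubs the fields this module
# reads; the real config carries many more entries.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    config = OmegaConf.create({
        "T": 1000,
        "subs_masking": False,
        "time_conditioning": True,
        "optim": {"lr": 3e-4},
        "training": {"sampling_eps": 1e-5},
        # ... plus whatever dit.DIT and noise_schedule.get_noise expect.
    })
    model = Diffusion(config, latent_dim=1024)
    latents = torch.randn(8, 128, 1024)  # (batch, seq_len, latent_dim)
    # Outside a Trainer, self.log is a no-op warning in recent Lightning versions.
    loss = model.training_step(latents, batch_idx=0)
    print(f"reconstruction loss: {loss.item():.4f}")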