dinhdat1110 committed
Commit dabac1b · 1 Parent(s): 7d078ca
.DS_Store ADDED
Binary file (6.15 kB)
 
.gitignore CHANGED
@@ -161,8 +161,8 @@ cython_debug/
 *.jpeg
 *.gz
 cifar-10-batches-py
-checkpoints
 MNIST
 *.ipynb
 data
 wandb
+/checkpoints/lightning_logs
checkpoints/{cifar.ckpt → model/celebahq.ckpt} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d9efb3494902fa10ab74f65d670a114a5470cfb879fa60bdd0292956895de587
-size 278317592
+oid sha256:bfc9fa8cb71bc57bc4d1f54da56e71060609828bc6903cec3ae46418c18bf3a1
+size 99080226
checkpoints/model/cifar10.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:509e43eb3be202b3d71ef37ca5b66de501fccda8806e4a193a466bcbfcb71b83
+size 99090784
checkpoints/{mnist.ckpt → model/mnist.ckpt} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:584b8c9097fc8022291ad2760512bf5cef8bee1b1f4fc52b211e112dca44c643
-size 278302296
+oid sha256:d2cde8ee89e68b413c32685145b4ad1fea10b7a6617c0f14a12a9af8afac9712
+size 99081632
diffusion/dataset/celeba.py CHANGED
@@ -11,12 +11,13 @@ class CelebADataset(Dataset):
     def __init__(
         self,
         data_dir: str,
+        img_dim: int = 64
     ):
         self.list_path = os.listdir(data_dir)
         self.data_dir = data_dir
         self.transform = transforms.Compose(
             [
-                transforms.Resize((64, 64)),
+                transforms.Resize((img_dim, img_dim)),
                 transforms.ToTensor(),
                 transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
             ]
@@ -37,13 +38,15 @@ class CelebADataModule(pl.LightningDataModule):
         batch_size: int = 32,
         num_workers: int = 0,
         seed: int = 42,
-        train_ratio: float = 0.99
+        train_ratio: float = 0.99,
+        img_dim: int = 64
     ):
         super().__init__()
         self.data_dir = data_dir
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.train_ratio = min(train_ratio, 0.99)
+        self.img_dim = img_dim
         self.seed = seed
 
         self.loader = partial(
@@ -56,7 +59,7 @@ class CelebADataModule(pl.LightningDataModule):
 
     def setup(self, stage: str):
         if stage == "fit":
-            dataset = CelebADataset(self.data_dir)
+            dataset = CelebADataset(self.data_dir, self.img_dim)
             self.CelebA_train, self.CelebA_val, _ = random_split(
                 dataset=dataset,
                 lengths=[self.train_ratio, 0.01, 1 - 0.01 - self.train_ratio],
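The same img_dim plumbing is applied to the CIFAR-10 and MNIST datamodules below. As a minimal usage sketch (not taken from the commit; the data_dir value is a placeholder and the other arguments are the module's own defaults), the resized CelebA loader would be constructed like this:

```python
from diffusion.dataset.celeba import CelebADataModule

# Hypothetical call: "./data/celeba_hq" is a placeholder path;
# img_dim controls the Resize((img_dim, img_dim)) transform inside CelebADataset.
dm = CelebADataModule(
    data_dir="./data/celeba_hq",
    batch_size=32,
    num_workers=4,
    train_ratio=0.99,
    img_dim=64,
)
dm.setup("fit")
```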
diffusion/dataset/cifar10.py CHANGED
@@ -13,7 +13,8 @@ class CIFAR10DataModule(pl.LightningDataModule):
         batch_size: int = 32,
         num_workers: int = 0,
         seed: int = 42,
-        train_ratio: float = 0.99
+        train_ratio: float = 0.99,
+        img_dim: int = 32
     ):
         super().__init__()
         self.data_dir = data_dir
@@ -23,7 +24,7 @@ class CIFAR10DataModule(pl.LightningDataModule):
         self.train_ratio = min(train_ratio, 0.99)
         self.transform = transforms.Compose(
             [
-                transforms.Resize((32, 32)),
+                transforms.Resize((img_dim, img_dim)),
                 transforms.ToTensor(),
                 transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
             ]
diffusion/dataset/mnist.py CHANGED
@@ -13,7 +13,8 @@ class MNISTDataModule(pl.LightningDataModule):
         batch_size: int = 32,
         num_workers: int = 0,
         seed: int = 42,
-        train_ratio: float = 0.99
+        train_ratio: float = 0.99,
+        img_dim: int = 32
     ):
         super().__init__()
         self.data_dir = data_dir
@@ -23,7 +24,7 @@ class MNISTDataModule(pl.LightningDataModule):
         self.seed = seed
         self.transform = transforms.Compose(
             [
-                transforms.Resize((32, 32)),
+                transforms.Resize((img_dim, img_dim)),
                 transforms.ToTensor(),
                 transforms.Normalize(mean=(0.5), std=(0.5))
             ]
diffusion/model/diffusion/__init__.py CHANGED
@@ -1,4 +1,3 @@
 from .unet import *
 from .model import *
-from .sampling import *
 from .scheduler import *
diffusion/model/diffusion/model.py CHANGED
@@ -4,8 +4,10 @@ import numpy as np
 import pytorch_lightning as pl
 import diffusion
 import wandb
+import matplotlib.pyplot as plt
 from torchvision.utils import make_grid
 from torch.optim.lr_scheduler import OneCycleLR
+from IPython.display import clear_output
 
 
 class DiffusionModel(pl.LightningModule):
@@ -16,6 +18,7 @@ class DiffusionModel(pl.LightningModule):
         beta_1: float = 0.0001,
         beta_2: float = 0.02,
         in_channels: int = 3,
+        mode: str = "ddpm",
         dim: int = 32,
         num_classes: int | None = 10,
         sample_per_epochs: int = 50,
@@ -33,11 +36,17 @@ class DiffusionModel(pl.LightningModule):
         self.max_timesteps = max_timesteps
         self.in_channels = in_channels
         self.dim = dim
+        self.mode = mode
         self.num_classes = num_classes
 
-        self.scheduler = diffusion.LinearScheduler(
-            max_timesteps, beta_1, beta_2
-        )
+        if mode == "ddpm":
+            self.scheduler = diffusion.DDPMScheduler(
+                max_timesteps, beta_1, beta_2
+            )
+        elif mode == "ddim":
+            self.scheduler = diffusion.DDIMScheduler(
+                max_timesteps, beta_1, beta_2
+            )
 
         self.criterion = nn.MSELoss()
 
@@ -49,8 +58,6 @@ class DiffusionModel(pl.LightningModule):
 
         self.sampling_kwargs = {
             'model': self.model,
-            'scheduler': self.scheduler,
-            'max_timesteps': self.max_timesteps,
             'in_channels': self.in_channels,
             'dim': self.dim,
         }
@@ -75,29 +82,37 @@ class DiffusionModel(pl.LightningModule):
         x_0: torch.Tensor,
         t: torch.Tensor
     ):
-        noise = torch.randn_like(x_0, device=x_0.device)
-        new_x = self.scheduler.get('sqrt_alpha_hat', t) * x_0
-        new_noise = self.scheduler.get('sqrt_one_minus_alpha_hat', t) * noise
-        return new_x + new_noise, noise
+        return self.scheduler.noising(x_0, t)
 
-    def sampling(self, labels=None, n_samples: int = 16):
-        return diffusion.ddpm_sampling(
-            n_samples=n_samples,
-            labels=labels,
-            **self.sampling_kwargs
-        )
-
-    def sampling_demo(self, labels=None, n_samples: int = 16):
-        return diffusion.ddpm_sampling_demo(
-            n_samples=n_samples,
-            labels=labels,
-            **self.sampling_kwargs
-        )
+    def sampling(
+        self,
+        labels=None,
+        mode: str = "ddpm",
+        demo: bool = True,
+        n_samples: int = 16,
+        timesteps: int = 1000,
+    ):
+        if mode == "ddpm":
+            self.test_scheduler = diffusion.DDPMScheduler(self.max_timesteps)
+        elif mode == "ddim":
+            self.test_scheduler = diffusion.DDIMScheduler(self.max_timesteps)
+
+        kwargs = {
+            "n_samples": n_samples,
+            "labels": labels,
+            "timesteps": timesteps,
+        } | self.sampling_kwargs
+        if demo:
+            return self.test_scheduler.sampling_demo(**kwargs)
+        else:
+            return self.test_scheduler.sampling(**kwargs)
 
     def forward(self, x_0, labels):
+        n = x_0.shape[0]
         t = torch.randint(
-            low=0, high=self.max_timesteps, size=(x_0.shape[0],), device=x_0.device
+            low=0, high=self.max_timesteps, size=(n//2+1,), device=x_0.device
         )
+        t = torch.cat([t, self.max_timesteps - t - 1], dim=0)[:n]
         x_noise, noise = self.noising(x_0, t)
         noise_pred = self.model(x_noise, t, labels)
         return noise, noise_pred
@@ -108,8 +123,8 @@ class DiffusionModel(pl.LightningModule):
             labels = None
         else:
             x_0, labels = batch
-        if np.random.random() < 0.1:
-            labels = None
+            if np.random.random() < 0.1:
+                labels = None
         noise, noise_pred = self(x_0, labels)
         loss = self.criterion(noise, noise_pred)
         self.train_loss.append(loss)
@@ -135,19 +150,20 @@ class DiffusionModel(pl.LightningModule):
             )
             self.train_loss.clear()
 
-        if self.epoch_count % self.spe == 0:
-            wandblog = self.logger.experiment
-            x_t = self.sampling(n_samples=self.n_samples)
-            img_array = [x_t[i] for i in range(x_t.shape[0])]
-
-            wandblog.log(
-                {
-                    "sampling": wandb.Image(
-                        make_grid(img_array, nrow=4).permute(1, 2, 0).cpu().numpy(),
-                        caption="Sampled Image!"
-                    )
-                }
-            )
+        if self.spe > 0:
+            if self.epoch_count % self.spe == 0:
+                wandblog = self.logger.experiment
+                x_t = self.sampling(n_samples=self.n_samples, timesteps=100, demo=False)
+                img_array = [x_t[i] for i in range(x_t.shape[0])]
+
+                wandblog.log(
+                    {
+                        "sampling": wandb.Image(
+                            make_grid(img_array, nrow=4).permute(1, 2, 0).cpu().numpy(),
+                            caption="Sampled Image!"
+                        )
+                    }
+                )
 
         self.epoch_count += 1
 
@@ -173,10 +189,34 @@ class DiffusionModel(pl.LightningModule):
             total_steps=self.trainer.estimated_stepping_batches,
 
         )
-        return {
-            'optimizer': optimizer,
-            'lr_scheduler': scheduler
-        }
+        return [optimizer], [scheduler]
+
+    def draw(
+        self,
+        labels=None,
+        mode: str = "ddpm",
+        n_samples: int = 1,
+        timesteps: int = 1000,
+    ):
+        demo = self.sampling(
+            labels=labels,
+            mode=mode,
+            n_samples=n_samples,
+            timesteps=timesteps,
+            demo=True
+        )
+        idx = 0
+        length = labels.shape[0] if labels is not None else n_samples
+        for img in demo:
+            for i in range(length):
+                plt.subplot(1, length, i+1)
+                plt.imshow(img[i].permute(1, 2, 0))
+                plt.axis('off')
+                plt.title(f"{idx+1}/{timesteps}")
+            idx += 1
+            plt.show()
+            if idx < timesteps:
+                clear_output(wait=True)
 
 
 if __name__ == "__main__":
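The training step above now pairs each sampled t with max_timesteps - t - 1 (antithetic timesteps), and sampling delegates entirely to the scheduler classes. A hedged usage sketch of the new mode-aware sampling path follows; the checkpoint path is an assumption based on this commit's file layout, and load_from_checkpoint is the standard Lightning API rather than anything added here:

```python
import torch
from diffusion.model.diffusion.model import DiffusionModel

# Assumed checkpoint path from this commit's checkpoints/model/ directory.
model = DiffusionModel.load_from_checkpoint("checkpoints/model/cifar10.ckpt")
model.eval()

# Class-conditional DDIM sampling using 100 of the 1000 training timesteps.
labels = torch.arange(10, device=model.device)
images = model.sampling(labels=labels, mode="ddim", demo=False, timesteps=100)
# images should be a uint8 tensor of shape (10, in_channels, dim, dim).
```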
diffusion/model/diffusion/sampling.py DELETED
@@ -1,82 +0,0 @@
-import torch
-
-
-def ddpm_sampling_timestep(
-    x_t,
-    model,
-    scheduler,
-    labels,
-    t,
-    n_samples: int = 16,
-    cfg_scale: int = 3,
-):
-    time = torch.full((n_samples,), fill_value=t, device=model.device)
-    pred_noise = model(x_t, time, labels)
-    if cfg_scale > 0:
-        uncond_pred_noise = model(x_t, time, None)
-        pred_noise = torch.lerp(uncond_pred_noise, pred_noise, cfg_scale)
-    alpha = scheduler.get('alpha', time)
-    sqrt_alpha = scheduler.get('sqrt_alpha', time)
-    somah = scheduler.get('sqrt_one_minus_alpha_hat', time)
-    sqrt_beta = scheduler.get('sqrt_beta', time)
-    if t > 0:
-        noise = torch.randn_like(x_t, device=model.device)
-    else:
-        noise = torch.zeros_like(x_t, device=model.device)
-
-    x_t_new = 1 / sqrt_alpha * (x_t - (1-alpha) / somah * pred_noise) + sqrt_beta * noise
-    return x_t_new.clamp(-1, 1)
-
-
-@torch.no_grad()
-def ddpm_sampling(
-    model,
-    scheduler,
-    n_samples: int = 16,
-    max_timesteps: int = 1000,
-    in_channels: int = 3,
-    dim: int = 32,
-    cfg_scale: int = 3,
-
-    labels=None
-):
-    if labels is not None:
-        n_samples = labels.shape[0]
-
-    x_t = torch.randn(
-        n_samples, in_channels, dim, dim, device=model.device
-    )
-    model.eval()
-    for t in range(max_timesteps-1, -1, -1):
-        x_t = ddpm_sampling_timestep(x_t=x_t, model=model, scheduler=scheduler,
-                                     labels=labels, t=t, n_samples=n_samples,
-                                     cfg_scale=cfg_scale)
-
-    model.train()
-    x_t = (x_t + 1) / 2 * 255.  # range [0,255]
-    return x_t.type(torch.uint8)
-
-
-@torch.no_grad()
-def ddpm_sampling_demo(
-    model,
-    scheduler,
-    n_samples: int = 16,
-    max_timesteps: int = 1000,
-    in_channels: int = 3,
-    dim: int = 32,
-    cfg_scale: int = 3,
-    labels=None
-):
-    if labels is not None:
-        n_samples = labels.shape[0]
-
-    x_t = torch.randn(
-        n_samples, in_channels, dim, dim, device=model.device
-    )
-    model.eval()
-    for t in range(max_timesteps-1, -1, -1):
-        x_t = ddpm_sampling_timestep(x_t=x_t, model=model, scheduler=scheduler,
-                                     labels=labels, t=t, n_samples=n_samples,
-                                     cfg_scale=cfg_scale)
-        yield ((x_t + 1) / 2 * 255).type(torch.uint8)
diffusion/model/diffusion/scheduler.py CHANGED
@@ -1,20 +1,184 @@
 import torch
 
 
-class LinearScheduler:
-    def __init__(
-        self,
-        max_timesteps: int = 1000,
-        beta_1: int = 0.0001,
-        beta_2: int = 0.02
-    ) -> None:
-        self.beta = torch.linspace(beta_1, beta_2, max_timesteps)
-        self.sqrt_beta = torch.sqrt(self.beta)[:, None, None, None]
-        self.alpha = (1 - self.beta)[:, None, None, None]
-        self.sqrt_alpha = torch.sqrt(self.alpha)
-        self.alpha_hat = torch.cumprod(1 - self.beta, dim=0)[:, None, None, None]
-        self.sqrt_alpha_hat = torch.sqrt(self.alpha_hat)
-        self.sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat)
-
-    def get(self, key: str, t: torch.Tensor):
-        return self.__dict__[key].to(t.device)[t]
+class DDPMScheduler:
+    def __init__(
+        self,
+        max_timesteps: int = 1000,
+        beta_1: int = 0.0001,
+        beta_2: int = 0.02
+    ) -> None:
+        self.beta_1 = beta_1
+        self.beta_2 = beta_2
+        self.max_timesteps = max_timesteps
+        self._init_params()
+
+    def _init_params(self, timesteps: int | None = None):
+        self.beta = torch.linspace(self.beta_1, self.beta_2, timesteps or self.max_timesteps)
+        self.sqrt_beta = torch.sqrt(self.beta)
+        self.alpha = (1 - self.beta)
+        self.sqrt_alpha = torch.sqrt(self.alpha)
+        self.alpha_hat = torch.cumprod(1 - self.beta, dim=0)
+        self.sqrt_alpha_hat = torch.sqrt(self.alpha_hat)
+        self.sqrt_one_minus_alpha = torch.sqrt(1 - self.alpha)
+        self.sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat)
+
+    def noising(
+        self,
+        x_0: torch.Tensor,
+        t: torch.Tensor
+    ):
+        if t.device != x_0.device:
+            t = t.to(x_0.device)
+        noise = torch.randn_like(x_0, device=x_0.device)
+        new_x = self.sqrt_alpha_hat.to(x_0.device)[t][:, None, None, None] * x_0
+        new_noise = self.sqrt_one_minus_alpha_hat.to(x_0.device)[t][:, None, None, None] * noise
+        return new_x + new_noise, noise
+
+    @torch.no_grad()
+    def sampling_t(
+        self,
+        x_t: torch.Tensor,
+        model,
+        labels: torch.Tensor,
+        timesteps: int,
+        t: int,
+        n_samples: int = 16,
+        cfg_scale: int = 3,
+    ):
+        time = torch.full((n_samples,), fill_value=t, device=model.device)
+        pred_noise = model(x_t, time, labels)
+        if cfg_scale > 0 and labels is not None:
+            uncond_pred_noise = model(x_t, time, None)
+            pred_noise = torch.lerp(uncond_pred_noise, pred_noise, cfg_scale)
+        alpha = self.alpha.to(model.device)[time][:, None, None, None]
+        sqrt_alpha = self.sqrt_alpha.to(model.device)[time][:, None, None, None]
+        somah = self.sqrt_one_minus_alpha_hat.to(model.device)[time][:, None, None, None]
+        sqrt_beta = self.sqrt_beta.to(model.device)[time][:, None, None, None]
+        if t > 1:
+            noise = torch.randn_like(x_t, device=model.device)
+        else:
+            noise = torch.zeros_like(x_t, device=model.device)
+
+        x_t_new = 1 / sqrt_alpha * (x_t - (1-alpha) / somah * pred_noise) + sqrt_beta * noise
+        return x_t_new.clamp(-1, 1)
+
+    @torch.no_grad()
+    def sampling(
+        self,
+        model,
+        n_samples: int = 16,
+        in_channels: int = 3,
+        dim: int = 32,
+        timesteps: int = 1000,
+        cfg_scale: int = 3,
+        labels=None,
+        *args, **kwargs
+    ):
+        if labels is not None:
+            n_samples = labels.shape[0]
+        model.eval()
+        x_t = torch.randn(
+            n_samples, in_channels, dim, dim, device=model.device
+        )
+        step_ratios = self.max_timesteps // timesteps
+        all_timesteps = torch.flip(torch.arange(0, timesteps) * step_ratios, dims=(0,))
+        for t in all_timesteps:
+            x_t = self.sampling_t(x_t=x_t, model=model, labels=labels, t=t, timesteps=timesteps,
+                                  n_samples=n_samples, cfg_scale=cfg_scale)
+        model.train()
+        x_t = (x_t.clamp(-1, 1) + 1) / 2 * 255.  # range [0,255]
+        return x_t.type(torch.uint8)
+
+    @torch.no_grad()
+    def sampling_demo(
+        self,
+        model,
+        n_samples: int = 16,
+        in_channels: int = 3,
+        dim: int = 32,
+        timesteps: int = 1000,
+        cfg_scale: int = 3,
+        labels=None,
+        *args, **kwargs
+    ):
+        if labels is not None:
+            n_samples = labels.shape[0]
+
+        x_t = torch.randn(
+            n_samples, in_channels, dim, dim, device=model.device
+        )
+        model.eval()
+        step_ratios = self.max_timesteps // timesteps
+        all_timesteps = torch.flip(torch.arange(0, timesteps) * step_ratios, dims=(0,))
+        for t in all_timesteps:
+            x_t = self.sampling_t(x_t=x_t, model=model, labels=labels, t=t, timesteps=timesteps,
+                                  n_samples=n_samples, cfg_scale=cfg_scale)
+            yield ((x_t.clamp(-1, 1) + 1) / 2 * 255).type(torch.uint8)
+
+
+class DDIMScheduler(DDPMScheduler):
+    def __init__(
+        self,
+        max_timesteps: int = 1000,
+        beta_1: int = 0.0001,
+        beta_2: int = 0.02
+    ) -> None:
+        super().__init__(beta_1=beta_1, beta_2=beta_2, max_timesteps=max_timesteps)
+        self._init_params()
+
+    def _init_params(self, timesteps: int | None = None):
+        self.beta = torch.linspace(self.beta_1, self.beta_2, timesteps or self.max_timesteps)
+        self.sqrt_beta = torch.sqrt(self.beta)
+        self.alpha = (1 - self.beta)
+        self.sqrt_alpha = torch.sqrt(self.alpha)
+        self.alpha_hat = torch.cumprod(1 - self.beta, dim=0)
+        self.sqrt_alpha_hat = torch.sqrt(self.alpha_hat)
+        self.sqrt_one_minus_alpha = torch.sqrt(1 - self.alpha)
+        self.sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat)
+        self.alpha_hat_prev = torch.cat([torch.tensor([1.]), self.alpha_hat], dim=0)[:-1]
+        self.variance = (1 - self.alpha_hat_prev) / (1 - self.alpha_hat) * \
+            (1 - self.alpha_hat / self.alpha_hat_prev)
+
+    @torch.no_grad()
+    def sampling_t(
+        self,
+        x_t: torch.Tensor, model, t: int,
+        timesteps: int,
+        labels: torch.Tensor | None = None,
+        n_samples: int = 16,
+        eta: float = 0.0,
+        *args, **kwargs
+    ):
+        time = torch.full((n_samples,), fill_value=t, device=model.device)
+        time_prev = time - self.max_timesteps // timesteps
+        pred_noise = model(x_t, time, labels)
+
+        sqrt_one_minus_alpha_hat = self.sqrt_one_minus_alpha_hat.to(model.device)[time][:, None, None, None]
+        sqrt_alpha_hat = self.sqrt_alpha_hat.to(model.device)[time][:, None, None, None]
+        alpha_hat_prev = self.alpha_hat[time_prev] if time_prev[0] >= 0 else torch.ones_like(time_prev)
+        alpha_hat_prev = alpha_hat_prev.to(model.device)[:, None, None, None]
+        sqrt_alpha_hat_prev = torch.sqrt(alpha_hat_prev)
+        posterior_std = torch.sqrt(self.variance)[time][:, None, None, None] * eta
+
+        if t > 0:
+            noise = torch.randn_like(x_t, device=model.device)
+        else:
+            noise = torch.zeros_like(x_t, device=model.device)
+
+        x_0_pred = (x_t - sqrt_one_minus_alpha_hat * pred_noise) / sqrt_alpha_hat
+        x_0_pred = x_0_pred.clamp(-1, 1)
+        x_t_direction = torch.sqrt(1. - alpha_hat_prev - posterior_std**2) * pred_noise
+        random_noise = posterior_std * noise
+        x_t_1 = sqrt_alpha_hat_prev * x_0_pred + x_t_direction + random_noise
+
+        return x_t_1
+
+
+if __name__ == "__main__":
+    dct = DDIMScheduler().__dict__
+    for k in dct.keys():
+        if isinstance(dct[k], torch.Tensor):
+            print(k, dct[k].shape)
+        else:
+            print(k, dct[k])
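For reference, the two update rules implemented by sampling_t in the new schedulers correspond, up to notation, to the standard DDPM ancestral step and the DDIM step. A sketch in the code's symbols, writing alpha_hat as \bar\alpha and noting that the classifier-free guidance line is pred = lerp(uncond, cond, cfg_scale) = uncond + cfg_scale * (cond - uncond):

```latex
% DDPM step (DDPMScheduler.sampling_t), z ~ N(0, I) for t > 1, z = 0 at the last step:
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}
          \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}} \, \epsilon_\theta(x_t, t) \right)
          + \sqrt{\beta_t}\, z

% DDIM step (DDIMScheduler.sampling_t), with
% \sigma_t = \eta \sqrt{ \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}
%                        \left( 1 - \frac{\bar\alpha_t}{\bar\alpha_{t-1}} \right) }
% (eta = 0 gives the deterministic DDIM sampler used by default):
\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\, \epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}
\qquad
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\, \hat{x}_0
          + \sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2}\, \epsilon_\theta(x_t, t)
          + \sigma_t z
```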
diffusion/model/diffusion/unet.py CHANGED
@@ -45,10 +45,10 @@ class DoubleConv(nn.Module):
         mid_channels = out_channels
         self.double_conv = nn.Sequential(
             nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
-            nn.GroupNorm(8, mid_channels),
+            nn.GroupNorm(32, mid_channels, eps=1e-6, affine=True),
             nn.GELU(),
             nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
-            nn.GroupNorm(8, out_channels),
+            nn.GroupNorm(32, out_channels, eps=1e-6, affine=True),
         )
 
     def forward(self, x):
@@ -137,15 +137,20 @@ class UNet(pl.LightningModule):
         self.sa3 = SelfAttention(channels=256)
 
         self.mid1 = DoubleConv(in_channels=256, out_channels=512)
+        self.sa4 = SelfAttention(channels=512)
         self.mid2 = DoubleConv(in_channels=512, out_channels=512)
 
         self.up1 = UpSample(in_channels=512, out_channels=256)
-        self.sa4 = SelfAttention(channels=256)
+        self.sa5 = SelfAttention(channels=256)
         self.up2 = UpSample(in_channels=256, out_channels=128)
-        self.sa5 = SelfAttention(channels=128)
+        self.sa6 = SelfAttention(channels=128)
         self.up3 = UpSample(in_channels=128, out_channels=64)
-        self.sa6 = SelfAttention(channels=64)
-        self.outc = nn.Conv2d(64, c_out, kernel_size=1)
+        self.sa7 = SelfAttention(channels=64)
+        self.outc = nn.Sequential(
+            nn.GroupNorm(32, 64, eps=1e-6, affine=True),
+            nn.SiLU(),
+            nn.Conv2d(64, c_out, kernel_size=3, padding=1)
+        )
 
     def pos_encoding(self, t, channels):
         inv_freq = 1.0 / (
@@ -168,14 +173,15 @@
         x4 = self.sa3(x4)
 
         x4 = self.mid1(x4)
+        x4 = self.sa4(x4)
         x4 = self.mid2(x4)
 
         x = self.up1(x4, x3, t)
-        x = self.sa4(x)
-        x = self.up2(x, x2, t)
         x = self.sa5(x)
-        x = self.up3(x, x1, t)
+        x = self.up2(x, x2, t)
         x = self.sa6(x)
+        x = self.up3(x, x1, t)
+        x = self.sa7(x)
         output = self.outc(x)
         return output
 
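One constraint worth noting on the normalization change: nn.GroupNorm(32, C) requires the channel count C to be divisible by 32, which holds for every width this UNet uses (64, 128, 256, 512). A small standalone check, independent of this repo's modules:

```python
import torch
import torch.nn as nn

# GroupNorm with 32 groups, as in the updated DoubleConv blocks and output head;
# each channel width used by the UNet divides evenly into 32 groups.
for channels in (64, 128, 256, 512):
    gn = nn.GroupNorm(32, channels, eps=1e-6, affine=True)
    x = torch.randn(2, channels, 16, 16)
    assert gn(x).shape == x.shape
```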
diffusion/model/ldm/model.py CHANGED
@@ -2,4 +2,5 @@ import torch
 import pytorch_lightning as pl
 
 class LatentDiffusionModel(pl.LightningModule):
+    # TODO
     pass
diffusion/model/ldm/tests/__init__.py DELETED
File without changes
diffusion/tests/__init__.py DELETED
File without changes
diffusion/train/__main__.py CHANGED
@@ -20,6 +20,10 @@ def main():
         '--data_dir', '-dd', type=str, default='./data/',
         help='model name'
     )
+    parser.add_argument(
+        '--mode', type=str, default='ddim',
+        help='sampling mode'
+    )
     parser.add_argument(
         '--max_epochs', '-me', type=int, default=200,
         help='max epoch'
@@ -117,7 +121,8 @@ def main():
         batch_size=args.batch_size,
         num_workers=args.num_workers,
         seed=args.seed,
-        train_ratio=args.train_ratio
+        train_ratio=args.train_ratio,
+        img_dim=img_dim
     )
 
     # MODEL
@@ -129,7 +134,8 @@ def main():
         max_timesteps=args.timesteps,
         dim=img_dim,
         num_classes=num_classes,
-        n_samples=args.n_samples
+        n_samples=args.n_samples,
+        mode=args.mode
    )
 
     # CALLBACK
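With the new argument wired through, a training run could be launched roughly as follows; this assumes the package is invoked as a module from the repository root (the file lives at diffusion/train/__main__.py), and only flags visible in this diff are passed, leaving everything else at its default:

    python -m diffusion.train --data_dir ./data/ --max_epochs 200 --mode ddim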