sgoel30 committed
Commit a43e004 · verified · 1 Parent(s): 62cf524

Dependencies for MDLM

Files changed (2):
  1. models/dit.py +370 -0
  2. models/ema.py +97 -0
models/dit.py ADDED
@@ -0,0 +1,370 @@
import math
import typing

import flash_attn
import flash_attn.layers.rotary
import huggingface_hub
import omegaconf
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

# Flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)


def bias_dropout_add_scale(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
    training: bool) -> torch.Tensor:
  if bias is not None:
    out = scale * F.dropout(x + bias, p=prob, training=training)
  else:
    out = scale * F.dropout(x, p=prob, training=training)

  if residual is not None:
    out = residual + out
  return out


def get_bias_dropout_add_scale(training):
  def _bias_dropout_add(x, bias, scale, residual, prob):
    return bias_dropout_add_scale(
      x, bias, scale, residual, prob, training)

  return _bias_dropout_add


# function overload
def modulate(x: torch.Tensor,
             shift: torch.Tensor,
             scale: torch.Tensor) -> torch.Tensor:
  return x * (1 + scale) + shift

@torch.jit.script
def bias_dropout_add_scale_fused_train(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float) -> torch.Tensor:
  return bias_dropout_add_scale(
    x, bias, scale, residual, prob, True)


@torch.jit.script
def bias_dropout_add_scale_fused_inference(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float) -> torch.Tensor:
  return bias_dropout_add_scale(
    x, bias, scale, residual, prob, False)


@torch.jit.script
def modulate_fused(x: torch.Tensor,
                   shift: torch.Tensor,
                   scale: torch.Tensor) -> torch.Tensor:
  return modulate(x, shift, scale)


class Rotary(torch.nn.Module):
  def __init__(self, dim, base=10_000):
    super().__init__()
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    self.register_buffer('inv_freq', inv_freq)
    self.seq_len_cached = None
    self.cos_cached = None
    self.sin_cached = None

  def forward(self, x, seq_dim=1):
    seq_len = x.shape[seq_dim]
    if seq_len != self.seq_len_cached:
      self.seq_len_cached = seq_len
      t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
      freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
      emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
      # dims are: batch, seq_len, qkv, head, dim
      self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1,1,3,1,1)
      self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1,1,3,1,1)
      # This makes the transformation on v an identity.
      self.cos_cached[:,:,2,:,:].fill_(1.)
      self.sin_cached[:,:,2,:,:].fill_(0.)

    return self.cos_cached, self.sin_cached


def rotate_half(x):
  x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
  return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(qkv, cos, sin):
  cos = cos[0,:,0,0,:cos.shape[-1]//2]
  sin = sin[0,:,0,0,:sin.shape[-1]//2]
  return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)


# function overload
def modulate(x, shift, scale):
  return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


#################################################################################
#                                    Layers                                     #
#################################################################################
class LayerNorm(nn.Module):
  def __init__(self, dim):
    super().__init__()
    self.weight = nn.Parameter(torch.ones([dim]))
    self.dim = dim
  def forward(self, x):
    with torch.cuda.amp.autocast(enabled=False):
      x = F.layer_norm(x.float(), [self.dim])
    return x * self.weight[None,None,:]


def residual_linear(x, W, x_skip, residual_scale):
  """x_skip + residual_scale * W @ x"""
  dim_out, dim_in = W.shape[0], W.shape[1]
  return torch.addmm(
    x_skip.view(-1, dim_out),
    x.view(-1, dim_in),
    W.T,
    alpha=residual_scale).view(*x.shape[:-1], dim_out)

#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################
class TimestepEmbedder(nn.Module):
  """
  Embeds scalar timesteps into vector representations.
  """
  def __init__(self, hidden_size, frequency_embedding_size=256):
    super().__init__()
    self.mlp = nn.Sequential(
      nn.Linear(frequency_embedding_size, hidden_size, bias=True),
      nn.SiLU(),
      nn.Linear(hidden_size, hidden_size, bias=True))
    self.frequency_embedding_size = frequency_embedding_size

  @staticmethod
  def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    half = dim // 2
    freqs = torch.exp(
      - math.log(max_period)
      * torch.arange(start=0, end=half, dtype=torch.float32)
      / half).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
      embedding = torch.cat(
        [embedding,
         torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

  def forward(self, t):
    t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
    t_emb = self.mlp(t_freq)
    return t_emb


class LabelEmbedder(nn.Module):
  """Embeds class labels into vector representations.

  Also handles label dropout for classifier-free guidance.
  """
  def __init__(self, num_classes, cond_size):
    super().__init__()
    self.embedding_table = nn.Embedding(num_classes + 1, cond_size)
    self.num_classes = num_classes

    # TODO think of initializing with 0.02 std deviation like in original DiT paper

  def forward(self, labels):
    embeddings = self.embedding_table(labels)
    return embeddings

#################################################################################
#                                 Core Model                                    #
#################################################################################


class DDiTBlock(nn.Module):
  def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
    super().__init__()
    self.n_heads = n_heads

    self.norm1 = LayerNorm(dim)
    self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
    self.attn_out = nn.Linear(dim, dim, bias=False)
    self.dropout1 = nn.Dropout(dropout)

    self.norm2 = LayerNorm(dim)
    self.mlp = nn.Sequential(
      nn.Linear(dim, mlp_ratio * dim, bias=True),
      nn.GELU(approximate='tanh'),
      nn.Linear(mlp_ratio * dim, dim, bias=True))
    self.dropout2 = nn.Dropout(dropout)
    self.dropout = dropout

    self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
    self.adaLN_modulation.weight.data.zero_()
    self.adaLN_modulation.bias.data.zero_()


  def _get_bias_dropout_scale(self):
    if self.training:
      return bias_dropout_add_scale_fused_train
    else:
      return bias_dropout_add_scale_fused_inference


  def forward(self, x, rotary_cos_sin, c, seqlens=None):
    batch_size, seq_len = x.shape[0], x.shape[1]

    bias_dropout_scale_fn = self._get_bias_dropout_scale()

    (shift_msa, scale_msa, gate_msa, shift_mlp,
     scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)

    # attention operation
    x_skip = x
    x = modulate_fused(self.norm1(x), shift_msa, scale_msa)

    qkv = self.attn_qkv(x)
    qkv = rearrange(qkv,
                    'b s (three h d) -> b s three h d',
                    three=3,
                    h=self.n_heads)
    with torch.cuda.amp.autocast(enabled=False):
      cos, sin = rotary_cos_sin
      qkv = apply_rotary_pos_emb(
        qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
    qkv = rearrange(qkv, 'b s ... -> (b s) ...')
    if seqlens is None:
      cu_seqlens = torch.arange(
        0, (batch_size + 1) * seq_len, step=seq_len,
        dtype=torch.int32, device=qkv.device)
    else:
      cu_seqlens = seqlens.cumsum(-1)
    x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
      qkv, cu_seqlens, seq_len, 0., causal=False)

    x = rearrange(x, '(b s) h d -> b s (h d)', b=batch_size)

    x = bias_dropout_scale_fn(self.attn_out(x),
                              None,
                              gate_msa,
                              x_skip,
                              self.dropout)

    # mlp operation
    x = bias_dropout_scale_fn(
      self.mlp(modulate_fused(
        self.norm2(x), shift_mlp, scale_mlp)),
      None, gate_mlp, x, self.dropout)
    return x


class EmbeddingLayer(nn.Module):
  def __init__(self, dim, vocab_dim):
    super().__init__()
    self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))
    torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))

  def forward(self, x):
    return self.embedding[x]


class DDitFinalLayer(nn.Module):
  def __init__(self, hidden_size, out_channels, cond_dim):
    super().__init__()
    self.norm_final = LayerNorm(hidden_size)
    self.linear = nn.Linear(hidden_size, out_channels)
    self.linear.weight.data.zero_()
    self.linear.bias.data.zero_()

    self.adaLN_modulation = nn.Linear(cond_dim,
                                      2 * hidden_size,
                                      bias=True)
    self.adaLN_modulation.weight.data.zero_()
    self.adaLN_modulation.bias.data.zero_()


  def forward(self, x, c):
    shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
    x = modulate_fused(self.norm_final(x), shift, scale)
    x = self.linear(x)
    return x


class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
  def __init__(self, config, vocab_size: int):
    super().__init__()
    if type(config) == dict:
      config = omegaconf.OmegaConf.create(config)

    self.config = config
    self.vocab_size = vocab_size

    self.vocab_embed = EmbeddingLayer(config.model.hidden_size,
                                      vocab_size)
    self.sigma_map = TimestepEmbedder(config.model.cond_dim)
    self.rotary_emb = Rotary(
      config.model.hidden_size // config.model.n_heads)

    blocks = []
    for _ in range(config.model.n_blocks):
      blocks.append(DDiTBlock(config.model.hidden_size,
                              config.model.n_heads,
                              config.model.cond_dim,
                              dropout=config.model.dropout))
    self.blocks = nn.ModuleList(blocks)

    self.output_layer = DDitFinalLayer(
      config.model.hidden_size,
      vocab_size,
      config.model.cond_dim)
    self.scale_by_sigma = config.model.scale_by_sigma

  def _get_bias_dropout_scale(self):
    if self.training:
      return bias_dropout_add_scale_fused_train
    else:
      return bias_dropout_add_scale_fused_inference

  def forward(self, indices, sigma):
    x = self.vocab_embed(indices)
    c = F.silu(self.sigma_map(sigma))

    rotary_cos_sin = self.rotary_emb(x)

    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
      for i in range(len(self.blocks)):
        x = self.blocks[i](x, rotary_cos_sin, c, seqlens=None)
      x = self.output_layer(x, c)

    return x
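
A minimal usage sketch (not part of the commit): how this DIT backbone might be instantiated and called, assuming an omegaconf config that carries the `model.*` fields read in `DIT.__init__` (hidden_size, cond_dim, n_heads, n_blocks, dropout, scale_by_sigma). The concrete sizes, vocab size, and batch shape below are illustrative only, and a CUDA device is assumed because the attention path goes through flash-attn.

# Illustrative sketch, not part of models/dit.py; all config values are made up.
import omegaconf
import torch

config = omegaconf.OmegaConf.create({
    'model': {
        'hidden_size': 768,
        'cond_dim': 128,
        'n_heads': 12,
        'n_blocks': 12,
        'dropout': 0.1,
        'scale_by_sigma': False,
    }
})

model = DIT(config, vocab_size=30522).cuda()  # flash-attn kernels need a CUDA device
indices = torch.randint(0, 30522, (2, 128), device='cuda')  # (batch, seq_len) token ids
sigma = torch.rand(2, device='cuda')                        # one noise level per sequence
logits = model(indices, sigma)                              # -> (2, 128, 30522)
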
models/ema.py ADDED
@@ -0,0 +1,97 @@
import torch


class ExponentialMovingAverage:
  """
  Maintains (exponential) moving average of a set of parameters.
  """

  def __init__(self, parameters, decay, use_num_updates=True):
    """
    Args:
      parameters: Iterable of `torch.nn.Parameter`; usually the result of
        `model.parameters()`.
      decay: The exponential decay.
      use_num_updates: Whether to use number of updates when computing
        averages.
    """
    if decay < 0.0 or decay > 1.0:
      raise ValueError('Decay must be between 0 and 1')
    self.decay = decay
    self.num_updates = 0 if use_num_updates else None
    self.shadow_params = [p.clone().detach()
                          for p in parameters if p.requires_grad]
    self.collected_params = []

  def move_shadow_params_to_device(self, device):
    self.shadow_params = [i.to(device) for i in self.shadow_params]

  def update(self, parameters):
    """
    Update currently maintained parameters.

    Call this every time the parameters are updated, such as the result of
    the `optimizer.step()` call.

    Args:
      parameters: Iterable of `torch.nn.Parameter`; usually the same set of
        parameters used to initialize this object.
    """
    decay = self.decay
    if self.num_updates is not None:
      self.num_updates += 1
      decay = min(decay, (1 + self.num_updates) /
                  (10 + self.num_updates))
    one_minus_decay = 1.0 - decay
    with torch.no_grad():
      parameters = [p for p in parameters if p.requires_grad]
      for s_param, param in zip(self.shadow_params, parameters):
        s_param.sub_(one_minus_decay * (s_param - param))

  def copy_to(self, parameters):
    """
    Copy current parameters into given collection of parameters.

    Args:
      parameters: Iterable of `torch.nn.Parameter`; the parameters to be
        updated with the stored moving averages.
    """
    parameters = [p for p in parameters if p.requires_grad]
    for s_param, param in zip(self.shadow_params, parameters):
      if param.requires_grad:
        param.data.copy_(s_param.data)

  def store(self, parameters):
    """
    Save the current parameters for restoring later.

    Args:
      parameters: Iterable of `torch.nn.Parameter`; the parameters to be
        temporarily stored.
    """
    self.collected_params = [param.clone() for param in parameters]

  def restore(self, parameters):
    """
    Restore the parameters stored with the `store` method.
    Useful to validate the model with EMA parameters without affecting the
    original optimization process. Store the parameters before the
    `copy_to` method. After validation (or model saving), use this to
    restore the former parameters.

    Args:
      parameters: Iterable of `torch.nn.Parameter`; the parameters to be
        updated with the stored parameters.
    """
    for c_param, param in zip(self.collected_params, parameters):
      param.data.copy_(c_param.data)

  def state_dict(self):
    return dict(decay=self.decay,
                num_updates=self.num_updates,
                shadow_params=self.shadow_params)

  def load_state_dict(self, state_dict):
    self.decay = state_dict['decay']
    self.num_updates = state_dict['num_updates']
    self.shadow_params = state_dict['shadow_params']
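
A minimal usage sketch (not part of the commit): the intended call pattern for ExponentialMovingAverage is to call `update` after each optimizer step, then `store`/`copy_to` before evaluation and `restore` afterwards. The `model`, `optimizer`, and training loop below are placeholders chosen only to make the example self-contained.

# Illustrative sketch, not part of models/ema.py; model and loop are placeholders.
import torch

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
ema = ExponentialMovingAverage(model.parameters(), decay=0.9999)

for _ in range(100):                       # placeholder training loop
    loss = model(torch.randn(4, 10)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model.parameters())         # refresh the shadow weights after each step

# Evaluate with EMA weights, then restore the raw training weights.
ema.store(model.parameters())
ema.copy_to(model.parameters())
# ... run validation here ...
ema.restore(model.parameters())
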