Antoni Bigata committed on
Commit b5ce381 · 1 Parent(s): 17d618b

first commit

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. WavLM.py +854 -0
  2. WavLM_modules.py +765 -0
  3. __pycache__/WavLM.cpython-311.pyc +0 -0
  4. __pycache__/WavLM_modules.cpython-311.pyc +0 -0
  5. __pycache__/data_utils.cpython-311.pyc +0 -0
  6. __pycache__/dino_game.cpython-311.pyc +0 -0
  7. __pycache__/inference_functions.cpython-311.pyc +0 -0
  8. __pycache__/landmarks_extractor.cpython-311.pyc +0 -0
  9. __pycache__/utils.cpython-311.pyc +0 -0
  10. __pycache__/vae_wrapper.cpython-311.pyc +0 -0
  11. __pycache__/wordle_game.cpython-311.pyc +0 -0
  12. app.py +978 -0
  13. data_utils.py +635 -0
  14. inference_functions.py +493 -0
  15. landmarks_extractor.py +35 -0
  16. sgm/__init__.py +4 -0
  17. sgm/__pycache__/__init__.cpython-311.pyc +0 -0
  18. sgm/__pycache__/lr_scheduler.cpython-311.pyc +0 -0
  19. sgm/__pycache__/util.cpython-311.pyc +0 -0
  20. sgm/callbacks/__pycache__/video_logger.cpython-311.pyc +0 -0
  21. sgm/callbacks/custom_ddp.py +10 -0
  22. sgm/callbacks/image_logger.py +193 -0
  23. sgm/callbacks/setup_callback.py +86 -0
  24. sgm/callbacks/video_logger.py +294 -0
  25. sgm/data/__init__.py +1 -0
  26. sgm/data/__pycache__/__init__.cpython-311.pyc +0 -0
  27. sgm/data/__pycache__/data_utils.cpython-311.pyc +0 -0
  28. sgm/data/__pycache__/mask.cpython-311.pyc +0 -0
  29. sgm/data/__pycache__/video_datamodule_latent.cpython-311.pyc +0 -0
  30. sgm/data/__pycache__/video_dataset_latent.cpython-311.pyc +0 -0
  31. sgm/data/data_utils.py +561 -0
  32. sgm/data/dataset.py +80 -0
  33. sgm/data/mask.py +525 -0
  34. sgm/data/video_datamodule_latent.py +138 -0
  35. sgm/data/video_dataset_latent.py +780 -0
  36. sgm/inference/api.py +385 -0
  37. sgm/inference/helpers.py +305 -0
  38. sgm/lr_scheduler.py +135 -0
  39. sgm/models/__init__.py +2 -0
  40. sgm/models/__pycache__/__init__.cpython-311.pyc +0 -0
  41. sgm/models/__pycache__/autoencoder.cpython-311.pyc +0 -0
  42. sgm/models/__pycache__/diffusion.cpython-311.pyc +0 -0
  43. sgm/models/autoencoder.py +615 -0
  44. sgm/models/diffusion.py +747 -0
  45. sgm/modules/__init__.py +6 -0
  46. sgm/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  47. sgm/modules/__pycache__/attention.cpython-311.pyc +0 -0
  48. sgm/modules/__pycache__/ema.cpython-311.pyc +0 -0
  49. sgm/modules/__pycache__/video_attention.cpython-311.pyc +0 -0
  50. sgm/modules/attention.py +889 -0
WavLM.py ADDED
@@ -0,0 +1,854 @@
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import logging
12
+ from typing import List, Optional, Tuple
13
+
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torch.nn import LayerNorm
20
+ from einops import rearrange
21
+ import requests
22
+ from clint.textui import progress
23
+ import os
24
+ from WavLM_modules import (
25
+ Fp32GroupNorm,
26
+ Fp32LayerNorm,
27
+ GradMultiply,
28
+ MultiheadAttention,
29
+ SamePad,
30
+ init_bert_params,
31
+ get_activation_fn,
32
+ TransposeLast,
33
+ GLU_Linear,
34
+ )
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ class WavLM_wrapper(nn.Module):
40
+ def __init__(
41
+ self, model_size="Base+", feed_as_frames=True, merge_type="cat", model_path=None
42
+ ):
43
+ super().__init__()
44
+ assert model_size in ["Base+", "Large"]
45
+ if model_path is None:
46
+ model_path = os.path.join(
47
+ os.path.dirname(__file__), f"WavLM-{model_size}.pt"
48
+ )
49
+ if not os.path.exists(model_path):
50
+ self.download_model(model_path, model_size)
51
+ checkpoint = torch.load(model_path)
52
+ cfg = WavLMConfig(checkpoint["cfg"])
53
+ self.cfg = cfg
54
+ self.model = WavLM(cfg)
55
+ self.model.load_state_dict(checkpoint["model"])
56
+ self.model.eval()
57
+ for param in self.model.parameters():
58
+ param.requires_grad = False
59
+ self.code_size = 768 * 2 if merge_type == "cat" else 768
60
+ self.merge_type = merge_type
61
+ self.feed_as_frames = feed_as_frames
62
+
63
+ def download_model(self, out_path, size: str = "Base+"):
64
+ print("Downloading model...")
65
+ if size == "Base+":
66
+ url = "https://valle.blob.core.windows.net/share/wavlm/WavLM-Base+.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D"
67
+ else:
68
+ url = "https://valle.blob.core.windows.net/share/wavlm/WavLM-Large.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D"
69
+ r = requests.get(url, allow_redirects=True, stream=True)
70
+ with open(out_path, "wb") as f:
71
+ total_length = int(r.headers.get("content-length"))
72
+ for chunk in progress.bar(
73
+ r.iter_content(chunk_size=1024), expected_size=(total_length / 1024) + 1
74
+ ):
75
+ if chunk:
76
+ f.write(chunk)
77
+ f.flush()
78
+ print("Model downloaded to %s" % out_path)
79
+
80
+ def forward(self, x):
81
+ """
82
+ Args:
83
+ x: (batch, n_frames, audio_features)
84
+ """
85
+ T = x.shape[1]
86
+
87
+ if self.feed_as_frames:
88
+ x = rearrange(x, "b f d -> (b f) d")
89
+ else:
90
+ x = rearrange(x, "b ... -> b (...)")
91
+
92
+ if self.cfg.normalize:
93
+ x = torch.nn.functional.layer_norm(x, x.shape)
94
+
95
+ x = self.model.extract_features(x)[0] # B, new_features, C
96
+ if self.feed_as_frames:
97
+ x = rearrange(x, "(b f) d c -> b f d c", f=T)
98
+ else:
99
+ x = torch.nn.functional.interpolate(
100
+ x.permute(0, 2, 1), T * 2, mode="nearest"
101
+ )
102
+ x = rearrange(x, "b c (f d) -> b f d c", d=2)
103
+
104
+ if self.merge_type == "cat":
105
+ if x.dim() == 3:
106
+ return rearrange(x, "b d c -> b (d c)")
107
+ return rearrange(x, "b f d c -> b f (d c)")
108
+ elif self.merge_type == "sum":
109
+ return x.sum(dim=-2)
110
+ elif self.merge_type == "mean":
111
+ return x.mean(dim=-2)
112
+ elif self.merge_type == "None":
113
+ return x
114
+ else:
115
+ raise NotImplementedError
116
+
117
+
118
+ def compute_mask_indices(
119
+ shape: Tuple[int, int],
120
+ padding_mask: Optional[torch.Tensor],
121
+ mask_prob: float,
122
+ mask_length: int,
123
+ mask_type: str = "static",
124
+ mask_other: float = 0.0,
125
+ min_masks: int = 0,
126
+ no_overlap: bool = False,
127
+ min_space: int = 0,
128
+ ) -> np.ndarray:
129
+ """
130
+ Computes random mask spans for a given shape
131
+ Args:
132
+ shape: the shape for which to compute masks.
133
+ should be of size 2 where first element is batch size and 2nd is timesteps
134
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
135
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
136
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
137
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
138
+ mask_type: how to compute mask lengths
139
+ static = fixed size
140
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
141
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
142
+ poisson = sample from poisson distribution with lambda = mask length
143
+ min_masks: minimum number of masked spans
144
+ no_overlap: if true, will use an alternative recursive algorithm that prevents spans from overlapping
145
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
146
+ """
147
+
148
+ bsz, all_sz = shape
149
+ mask = np.full((bsz, all_sz), False)
150
+
151
+ all_num_mask = int(
152
+ # add a random number for probabilistic rounding
153
+ mask_prob * all_sz / float(mask_length) + np.random.rand()
154
+ )
155
+
156
+ all_num_mask = max(min_masks, all_num_mask)
157
+
158
+ mask_idcs = []
159
+ for i in range(bsz):
160
+ if padding_mask is not None:
161
+ sz = all_sz - padding_mask[i].long().sum().item()
162
+ num_mask = int(
163
+ # add a random number for probabilistic rounding
164
+ mask_prob * sz / float(mask_length) + np.random.rand()
165
+ )
166
+ num_mask = max(min_masks, num_mask)
167
+ else:
168
+ sz = all_sz
169
+ num_mask = all_num_mask
170
+
171
+ if mask_type == "static":
172
+ lengths = np.full(num_mask, mask_length)
173
+ elif mask_type == "uniform":
174
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
175
+ elif mask_type == "normal":
176
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
177
+ lengths = [max(1, int(round(x))) for x in lengths]
178
+ elif mask_type == "poisson":
179
+ lengths = np.random.poisson(mask_length, size=num_mask)
180
+ lengths = [int(round(x)) for x in lengths]
181
+ else:
182
+ raise Exception("unknown mask selection " + mask_type)
183
+
184
+ if sum(lengths) == 0:
185
+ lengths[0] = min(mask_length, sz - 1)
186
+
187
+ if no_overlap:
188
+ mask_idc = []
189
+
190
+ def arrange(s, e, length, keep_length):
191
+ span_start = np.random.randint(s, e - length)
192
+ mask_idc.extend(span_start + i for i in range(length))
193
+
194
+ new_parts = []
195
+ if span_start - s - min_space >= keep_length:
196
+ new_parts.append((s, span_start - min_space + 1))
197
+ if e - span_start - keep_length - min_space > keep_length:
198
+ new_parts.append((span_start + length + min_space, e))
199
+ return new_parts
200
+
201
+ parts = [(0, sz)]
202
+ min_length = min(lengths)
203
+ for length in sorted(lengths, reverse=True):
204
+ lens = np.fromiter(
205
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
206
+ int,  # np.int was removed in recent NumPy; the builtin int is the equivalent dtype
207
+ )
208
+ l_sum = np.sum(lens)
209
+ if l_sum == 0:
210
+ break
211
+ probs = lens / np.sum(lens)
212
+ c = np.random.choice(len(parts), p=probs)
213
+ s, e = parts.pop(c)
214
+ parts.extend(arrange(s, e, length, min_length))
215
+ mask_idc = np.asarray(mask_idc)
216
+ else:
217
+ min_len = min(lengths)
218
+ if sz - min_len <= num_mask:
219
+ min_len = sz - num_mask - 1
220
+
221
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
222
+
223
+ mask_idc = np.asarray(
224
+ [
225
+ mask_idc[j] + offset
226
+ for j in range(len(mask_idc))
227
+ for offset in range(lengths[j])
228
+ ]
229
+ )
230
+
231
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
232
+
233
+ min_len = min([len(m) for m in mask_idcs])
234
+ for i, mask_idc in enumerate(mask_idcs):
235
+ if len(mask_idc) > min_len:
236
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
237
+ mask[i, mask_idc] = True
238
+
239
+ return mask
240
+
241
+
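A minimal sketch of how compute_mask_indices behaves on its own, assuming the file above is importable as WavLM (with its own dependencies such as einops and clint installed); the batch size, sequence length, and masking parameters below are illustrative only:

# Illustrative sketch, not part of the committed file; all parameter values are assumptions.
from WavLM import compute_mask_indices

mask = compute_mask_indices(
    shape=(2, 100),      # batch of 2 sequences, 100 timesteps each
    padding_mask=None,
    mask_prob=0.65,
    mask_length=10,
    mask_type="static",
    min_masks=2,
)
print(mask.shape)        # (2, 100), boolean NumPy array
print(mask.sum(axis=1))  # masked timesteps per row, equalized to the batch minimum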
242
+ class WavLMConfig:
243
+ def __init__(self, cfg=None):
244
+ self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
245
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
246
+
247
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
248
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
249
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
250
+ self.activation_fn: str = "gelu" # activation function to use
251
+
252
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
253
+ self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
254
+ self.conv_bias: bool = False # include bias in conv encoder
255
+ self.feature_grad_mult: float = (
256
+ 1.0 # multiply feature extractor var grads by this
257
+ )
258
+
259
+ self.normalize: bool = (
260
+ False # normalize input to have 0 mean and unit variance during training
261
+ )
262
+
263
+ # dropouts
264
+ self.dropout: float = 0.1 # dropout probability for the transformer
265
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
266
+ self.activation_dropout: float = (
267
+ 0.0 # dropout probability after activation in FFN
268
+ )
269
+ self.encoder_layerdrop: float = (
270
+ 0.0 # probability of dropping a transformer layer
271
+ )
272
+ self.dropout_input: float = (
273
+ 0.0 # dropout to apply to the input (after feat extr)
274
+ )
275
+ self.dropout_features: float = (
276
+ 0.0 # dropout to apply to the features (after feat extr)
277
+ )
278
+
279
+ # masking
280
+ self.mask_length: int = 10 # mask length
281
+ self.mask_prob: float = 0.65 # probability of replacing a token with mask
282
+ self.mask_selection: str = "static" # how to choose mask length
283
+ self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
284
+ self.no_mask_overlap: bool = False # whether to allow masks to overlap
285
+ self.mask_min_space: int = (
286
+ 1 # min space between spans (if no overlap is enabled)
287
+ )
288
+
289
+ # channel masking
290
+ self.mask_channel_length: int = 10 # length of the mask for features (channels)
291
+ self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
292
+ self.mask_channel_selection: str = (
293
+ "static" # how to choose mask length for channel masking
294
+ )
295
+ self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
296
+ self.no_mask_channel_overlap: bool = (
297
+ False # whether to allow channel masks to overlap
298
+ )
299
+ self.mask_channel_min_space: int = (
300
+ 1 # min space between spans (if no overlap is enabled)
301
+ )
302
+
303
+ # positional embeddings
304
+ self.conv_pos: int = (
305
+ 128 # number of filters for convolutional positional embeddings
306
+ )
307
+ self.conv_pos_groups: int = (
308
+ 16 # number of groups for convolutional positional embedding
309
+ )
310
+
311
+ # relative position embedding
312
+ self.relative_position_embedding: bool = (
313
+ False # apply relative position embedding
314
+ )
315
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
316
+ self.max_distance: int = (
317
+ 1280 # maximum distance for relative position embedding
318
+ )
319
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
320
+
321
+ if cfg is not None:
322
+ self.update(cfg)
323
+
324
+ def update(self, cfg: dict):
325
+ self.__dict__.update(cfg)
326
+
327
+
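A quick sketch of the configuration flow, assuming the defaults above: WavLMConfig starts from Base-sized values and update() merges a plain dict into its attributes, which is how the cfg stored in a checkpoint is applied by WavLM_wrapper.

# Illustrative sketch, not part of the committed file; override values are assumptions.
from WavLM import WavLMConfig

cfg = WavLMConfig()
print(cfg.encoder_layers, cfg.encoder_embed_dim)  # 12 768 (Base-sized defaults)

cfg.update({"encoder_layers": 24, "encoder_embed_dim": 1024, "normalize": True})
print(cfg.encoder_layers, cfg.encoder_embed_dim)  # 24 1024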
328
+ class WavLM(nn.Module):
329
+ def __init__(
330
+ self,
331
+ cfg: WavLMConfig,
332
+ ) -> None:
333
+ super().__init__()
334
+ logger.info(f"WavLM Config: {cfg.__dict__}")
335
+
336
+ self.cfg = cfg
337
+ feature_enc_layers = eval(cfg.conv_feature_layers)
338
+ self.embed = feature_enc_layers[-1][0]
339
+
340
+ self.feature_extractor = ConvFeatureExtractionModel(
341
+ conv_layers=feature_enc_layers,
342
+ dropout=0.0,
343
+ mode=cfg.extractor_mode,
344
+ conv_bias=cfg.conv_bias,
345
+ )
346
+
347
+ self.post_extract_proj = (
348
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
349
+ if self.embed != cfg.encoder_embed_dim
350
+ else None
351
+ )
352
+
353
+ self.mask_prob = cfg.mask_prob
354
+ self.mask_selection = cfg.mask_selection
355
+ self.mask_other = cfg.mask_other
356
+ self.mask_length = cfg.mask_length
357
+ self.no_mask_overlap = cfg.no_mask_overlap
358
+ self.mask_min_space = cfg.mask_min_space
359
+
360
+ self.mask_channel_prob = cfg.mask_channel_prob
361
+ self.mask_channel_selection = cfg.mask_channel_selection
362
+ self.mask_channel_other = cfg.mask_channel_other
363
+ self.mask_channel_length = cfg.mask_channel_length
364
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
365
+ self.mask_channel_min_space = cfg.mask_channel_min_space
366
+
367
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
368
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
369
+
370
+ self.feature_grad_mult = cfg.feature_grad_mult
371
+
372
+ self.mask_emb = nn.Parameter(
373
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
374
+ )
375
+
376
+ self.encoder = TransformerEncoder(cfg)
377
+ self.layer_norm = LayerNorm(self.embed)
378
+
379
+ def apply_mask(self, x, padding_mask):
380
+ B, T, C = x.shape
381
+ if self.mask_prob > 0:
382
+ mask_indices = compute_mask_indices(
383
+ (B, T),
384
+ padding_mask,
385
+ self.mask_prob,
386
+ self.mask_length,
387
+ self.mask_selection,
388
+ self.mask_other,
389
+ min_masks=2,
390
+ no_overlap=self.no_mask_overlap,
391
+ min_space=self.mask_min_space,
392
+ )
393
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
394
+ x[mask_indices] = self.mask_emb
395
+ else:
396
+ mask_indices = None
397
+
398
+ if self.mask_channel_prob > 0:
399
+ mask_channel_indices = compute_mask_indices(
400
+ (B, C),
401
+ None,
402
+ self.mask_channel_prob,
403
+ self.mask_channel_length,
404
+ self.mask_channel_selection,
405
+ self.mask_channel_other,
406
+ no_overlap=self.no_mask_channel_overlap,
407
+ min_space=self.mask_channel_min_space,
408
+ )
409
+ mask_channel_indices = (
410
+ torch.from_numpy(mask_channel_indices)
411
+ .to(x.device)
412
+ .unsqueeze(1)
413
+ .expand(-1, T, -1)
414
+ )
415
+ x[mask_channel_indices] = 0
416
+
417
+ return x, mask_indices
418
+
419
+ def forward_padding_mask(
420
+ self,
421
+ features: torch.Tensor,
422
+ padding_mask: torch.Tensor,
423
+ ) -> torch.Tensor:
424
+ extra = padding_mask.size(1) % features.size(1)
425
+ if extra > 0:
426
+ padding_mask = padding_mask[:, :-extra]
427
+ padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
428
+ padding_mask = padding_mask.all(-1)
429
+ return padding_mask
430
+
431
+ def extract_features(
432
+ self,
433
+ source: torch.Tensor,
434
+ padding_mask: Optional[torch.Tensor] = None,
435
+ mask: bool = False,
436
+ ret_conv: bool = False,
437
+ output_layer: Optional[int] = None,
438
+ ret_layer_results: bool = False,
439
+ ):
440
+ if self.feature_grad_mult > 0:
441
+ features = self.feature_extractor(source)
442
+ if self.feature_grad_mult != 1.0:
443
+ features = GradMultiply.apply(features, self.feature_grad_mult)
444
+ else:
445
+ with torch.no_grad():
446
+ features = self.feature_extractor(source)
447
+
448
+ features = features.transpose(1, 2)
449
+ features = self.layer_norm(features)
450
+
451
+ if padding_mask is not None:
452
+ padding_mask = self.forward_padding_mask(features, padding_mask)
453
+
454
+ if self.post_extract_proj is not None:
455
+ features = self.post_extract_proj(features)
456
+
457
+ features = self.dropout_input(features)
458
+
459
+ if mask:
460
+ x, mask_indices = self.apply_mask(features, padding_mask)
461
+ else:
462
+ x = features
463
+
464
+ # feature: (B, T, D), float
465
+ # target: (B, T), long
466
+ # x: (B, T, D), float
467
+ # padding_mask: (B, T), bool
468
+ # mask_indices: (B, T), bool
469
+ x, layer_results = self.encoder(
470
+ x,
471
+ padding_mask=padding_mask,
472
+ layer=None if output_layer is None else output_layer - 1,
473
+ )
474
+
475
+ res = {
476
+ "x": x,
477
+ "padding_mask": padding_mask,
478
+ "features": features,
479
+ "layer_results": layer_results,
480
+ }
481
+
482
+ feature = res["features"] if ret_conv else res["x"]
483
+ if ret_layer_results:
484
+ feature = (feature, res["layer_results"])
485
+ return feature, res["padding_mask"]
486
+
487
+
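For orientation, a minimal sketch of calling extract_features directly on a raw 16 kHz waveform, assuming a WavLM checkpoint is available on disk (the filename is an assumption); the loading steps mirror what WavLM_wrapper does above.

# Illustrative sketch, not part of the committed file; checkpoint path is an assumption.
import torch
from WavLM import WavLM, WavLMConfig

ckpt = torch.load("WavLM-Base+.pt", map_location="cpu")
model = WavLM(WavLMConfig(ckpt["cfg"]))
model.load_state_dict(ckpt["model"])
model.eval()

wav = torch.randn(1, 16000)                 # one second of 16 kHz audio
with torch.no_grad():
    feats, _ = model.extract_features(wav)  # roughly one 768-d token per 20 ms
print(feats.shape)                          # e.g. torch.Size([1, 49, 768]) for the Base model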
488
+ class ConvFeatureExtractionModel(nn.Module):
489
+ def __init__(
490
+ self,
491
+ conv_layers: List[Tuple[int, int, int]],
492
+ dropout: float = 0.0,
493
+ mode: str = "default",
494
+ conv_bias: bool = False,
495
+ conv_type: str = "default",
496
+ ):
497
+ super().__init__()
498
+
499
+ assert mode in {"default", "layer_norm"}
500
+
501
+ def block(
502
+ n_in,
503
+ n_out,
504
+ k,
505
+ stride,
506
+ is_layer_norm=False,
507
+ is_group_norm=False,
508
+ conv_bias=False,
509
+ ):
510
+ def make_conv():
511
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
512
+ nn.init.kaiming_normal_(conv.weight)
513
+ return conv
514
+
515
+ assert not (is_layer_norm and is_group_norm), (
516
+ "layer norm and group norm are exclusive"
517
+ )
518
+
519
+ if is_layer_norm:
520
+ return nn.Sequential(
521
+ make_conv(),
522
+ nn.Dropout(p=dropout),
523
+ nn.Sequential(
524
+ TransposeLast(),
525
+ Fp32LayerNorm(dim, elementwise_affine=True),
526
+ TransposeLast(),
527
+ ),
528
+ nn.GELU(),
529
+ )
530
+ elif is_group_norm:
531
+ return nn.Sequential(
532
+ make_conv(),
533
+ nn.Dropout(p=dropout),
534
+ Fp32GroupNorm(dim, dim, affine=True),
535
+ nn.GELU(),
536
+ )
537
+ else:
538
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
539
+
540
+ self.conv_type = conv_type
541
+ if self.conv_type == "default":
542
+ in_d = 1
543
+ self.conv_layers = nn.ModuleList()
544
+ for i, cl in enumerate(conv_layers):
545
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
546
+ (dim, k, stride) = cl
547
+
548
+ self.conv_layers.append(
549
+ block(
550
+ in_d,
551
+ dim,
552
+ k,
553
+ stride,
554
+ is_layer_norm=mode == "layer_norm",
555
+ is_group_norm=mode == "default" and i == 0,
556
+ conv_bias=conv_bias,
557
+ )
558
+ )
559
+ in_d = dim
560
+ elif self.conv_type == "conv2d":
561
+ in_d = 1
562
+ self.conv_layers = nn.ModuleList()
563
+ for i, cl in enumerate(conv_layers):
564
+ assert len(cl) == 3
565
+ (dim, k, stride) = cl
566
+
567
+ self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride))
568
+ self.conv_layers.append(torch.nn.ReLU())
569
+ in_d = dim
570
+ elif self.conv_type == "custom":
571
+ in_d = 1
572
+ idim = 80
573
+ self.conv_layers = nn.ModuleList()
574
+ for i, cl in enumerate(conv_layers):
575
+ assert len(cl) == 3
576
+ (dim, k, stride) = cl
577
+ self.conv_layers.append(
578
+ torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
579
+ )
580
+ self.conv_layers.append(torch.nn.LayerNorm([dim, idim]))
581
+ self.conv_layers.append(torch.nn.ReLU())
582
+ in_d = dim
583
+ if (i + 1) % 2 == 0:
584
+ self.conv_layers.append(
585
+ torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
586
+ )
587
+ idim = int(math.ceil(idim / 2))
588
+ else:
589
+ pass
590
+
591
+ def forward(self, x, mask=None):
592
+ # BxT -> BxCxT
593
+ x = x.unsqueeze(1)
594
+ if self.conv_type == "custom":
595
+ for conv in self.conv_layers:
596
+ if isinstance(conv, nn.LayerNorm):
597
+ x = x.transpose(1, 2)
598
+ x = conv(x).transpose(1, 2)
599
+ else:
600
+ x = conv(x)
601
+ x = x.transpose(2, 3).contiguous()
602
+ x = x.view(x.size(0), -1, x.size(-1))
603
+ else:
604
+ for conv in self.conv_layers:
605
+ x = conv(x)
606
+ if self.conv_type == "conv2d":
607
+ b, c, t, f = x.size()
608
+ x = x.transpose(2, 3).contiguous().view(b, c * f, t)
609
+ return x
610
+
611
+
612
+ class TransformerEncoder(nn.Module):
613
+ def __init__(self, args):
614
+ super().__init__()
615
+
616
+ self.dropout = args.dropout
617
+ self.embedding_dim = args.encoder_embed_dim
618
+
619
+ self.pos_conv = nn.Conv1d(
620
+ self.embedding_dim,
621
+ self.embedding_dim,
622
+ kernel_size=args.conv_pos,
623
+ padding=args.conv_pos // 2,
624
+ groups=args.conv_pos_groups,
625
+ )
626
+ dropout = 0
627
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
628
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
629
+ nn.init.constant_(self.pos_conv.bias, 0)
630
+
631
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
632
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
633
+
634
+ if hasattr(args, "relative_position_embedding"):
635
+ self.relative_position_embedding = args.relative_position_embedding
636
+ self.num_buckets = args.num_buckets
637
+ self.max_distance = args.max_distance
638
+ else:
639
+ self.relative_position_embedding = False
640
+ self.num_buckets = 0
641
+ self.max_distance = 0
642
+
643
+ self.layers = nn.ModuleList(
644
+ [
645
+ TransformerSentenceEncoderLayer(
646
+ embedding_dim=self.embedding_dim,
647
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
648
+ num_attention_heads=args.encoder_attention_heads,
649
+ dropout=self.dropout,
650
+ attention_dropout=args.attention_dropout,
651
+ activation_dropout=args.activation_dropout,
652
+ activation_fn=args.activation_fn,
653
+ layer_norm_first=args.layer_norm_first,
654
+ has_relative_attention_bias=(
655
+ self.relative_position_embedding and i == 0
656
+ ),
657
+ num_buckets=self.num_buckets,
658
+ max_distance=self.max_distance,
659
+ gru_rel_pos=args.gru_rel_pos,
660
+ )
661
+ for i in range(args.encoder_layers)
662
+ ]
663
+ )
664
+
665
+ self.layer_norm_first = args.layer_norm_first
666
+ self.layer_norm = LayerNorm(self.embedding_dim)
667
+ self.layerdrop = args.encoder_layerdrop
668
+
669
+ self.apply(init_bert_params)
670
+
671
+ def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
672
+ x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
673
+
674
+ if self.layer_norm_first and layer is None:
675
+ x = self.layer_norm(x)
676
+
677
+ return x, layer_results
678
+
679
+ def extract_features(
680
+ self, x, padding_mask=None, streaming_mask=None, tgt_layer=None
681
+ ):
682
+ if padding_mask is not None:
683
+ x[padding_mask] = 0
684
+
685
+ x_conv = self.pos_conv(x.transpose(1, 2))
686
+ x_conv = x_conv.transpose(1, 2)
687
+ x += x_conv
688
+
689
+ if not self.layer_norm_first:
690
+ x = self.layer_norm(x)
691
+
692
+ x = F.dropout(x, p=self.dropout, training=self.training)
693
+
694
+ # B x T x C -> T x B x C
695
+ x = x.transpose(0, 1)
696
+
697
+ layer_results = []
698
+ z = None
699
+ if tgt_layer is not None:
700
+ layer_results.append((x, z))
701
+ r = None
702
+ pos_bias = None
703
+ for i, layer in enumerate(self.layers):
704
+ dropout_probability = np.random.random()
705
+ if not self.training or (dropout_probability > self.layerdrop):
706
+ x, z, pos_bias = layer(
707
+ x,
708
+ self_attn_padding_mask=padding_mask,
709
+ need_weights=False,
710
+ self_attn_mask=streaming_mask,
711
+ pos_bias=pos_bias,
712
+ )
713
+ if tgt_layer is not None:
714
+ layer_results.append((x, z))
715
+ if i == tgt_layer:
716
+ r = x
717
+ break
718
+
719
+ if r is not None:
720
+ x = r
721
+
722
+ # T x B x C -> B x T x C
723
+ x = x.transpose(0, 1)
724
+
725
+ return x, layer_results
726
+
727
+
728
+ class TransformerSentenceEncoderLayer(nn.Module):
729
+ """
730
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
731
+ models.
732
+ """
733
+
734
+ def __init__(
735
+ self,
736
+ embedding_dim: float = 768,
737
+ ffn_embedding_dim: float = 3072,
738
+ num_attention_heads: float = 8,
739
+ dropout: float = 0.1,
740
+ attention_dropout: float = 0.1,
741
+ activation_dropout: float = 0.1,
742
+ activation_fn: str = "relu",
743
+ layer_norm_first: bool = False,
744
+ has_relative_attention_bias: bool = False,
745
+ num_buckets: int = 0,
746
+ max_distance: int = 0,
747
+ rescale_init: bool = False,
748
+ gru_rel_pos: bool = False,
749
+ ) -> None:
750
+ super().__init__()
751
+ # Initialize parameters
752
+ self.embedding_dim = embedding_dim
753
+ self.dropout = dropout
754
+ self.activation_dropout = activation_dropout
755
+
756
+ # Initialize blocks
757
+ self.activation_name = activation_fn
758
+ self.activation_fn = get_activation_fn(activation_fn)
759
+ self.self_attn = MultiheadAttention(
760
+ self.embedding_dim,
761
+ num_attention_heads,
762
+ dropout=attention_dropout,
763
+ self_attention=True,
764
+ has_relative_attention_bias=has_relative_attention_bias,
765
+ num_buckets=num_buckets,
766
+ max_distance=max_distance,
767
+ rescale_init=rescale_init,
768
+ gru_rel_pos=gru_rel_pos,
769
+ )
770
+
771
+ self.dropout1 = nn.Dropout(dropout)
772
+ self.dropout2 = nn.Dropout(self.activation_dropout)
773
+ self.dropout3 = nn.Dropout(dropout)
774
+
775
+ self.layer_norm_first = layer_norm_first
776
+
777
+ # layer norm associated with the self attention layer
778
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
779
+
780
+ if self.activation_name == "glu":
781
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
782
+ else:
783
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
784
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
785
+
786
+ # layer norm associated with the position wise feed-forward NN
787
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
788
+
789
+ def forward(
790
+ self,
791
+ x: torch.Tensor,
792
+ self_attn_mask: torch.Tensor = None,
793
+ self_attn_padding_mask: torch.Tensor = None,
794
+ need_weights: bool = False,
795
+ pos_bias=None,
796
+ ):
797
+ """
798
+ LayerNorm is applied either before or after the self-attention/ffn
799
+ modules similar to the original Transformer implementation.
800
+ """
801
+ residual = x
802
+
803
+ if self.layer_norm_first:
804
+ x = self.self_attn_layer_norm(x)
805
+ x, attn, pos_bias = self.self_attn(
806
+ query=x,
807
+ key=x,
808
+ value=x,
809
+ key_padding_mask=self_attn_padding_mask,
810
+ need_weights=False,
811
+ attn_mask=self_attn_mask,
812
+ position_bias=pos_bias,
813
+ )
814
+ x = self.dropout1(x)
815
+ x = residual + x
816
+
817
+ residual = x
818
+ x = self.final_layer_norm(x)
819
+ if self.activation_name == "glu":
820
+ x = self.fc1(x)
821
+ else:
822
+ x = self.activation_fn(self.fc1(x))
823
+ x = self.dropout2(x)
824
+ x = self.fc2(x)
825
+ x = self.dropout3(x)
826
+ x = residual + x
827
+ else:
828
+ x, attn, pos_bias = self.self_attn(
829
+ query=x,
830
+ key=x,
831
+ value=x,
832
+ key_padding_mask=self_attn_padding_mask,
833
+ need_weights=need_weights,
834
+ attn_mask=self_attn_mask,
835
+ position_bias=pos_bias,
836
+ )
837
+
838
+ x = self.dropout1(x)
839
+ x = residual + x
840
+
841
+ x = self.self_attn_layer_norm(x)
842
+
843
+ residual = x
844
+ if self.activation_name == "glu":
845
+ x = self.fc1(x)
846
+ else:
847
+ x = self.activation_fn(self.fc1(x))
848
+ x = self.dropout2(x)
849
+ x = self.fc2(x)
850
+ x = self.dropout3(x)
851
+ x = residual + x
852
+ x = self.final_layer_norm(x)
853
+
854
+ return x, attn, pos_bias
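A hedged usage sketch of WavLM_wrapper as defined above, assuming a local WavLM-Base+.pt checkpoint (otherwise the wrapper attempts the download) and audio already chunked into one window per video frame; the shapes and frame rate below are assumptions for illustration.

# Illustrative sketch, not part of the committed files; shapes and frame rate are assumptions.
import torch
from WavLM import WavLM_wrapper

wrapper = WavLM_wrapper(model_size="Base+", feed_as_frames=False, merge_type="cat")

# 4 clips, 25 video frames, 640 audio samples per frame (16 kHz audio at 25 fps)
audio = torch.randn(4, 25, 640)
with torch.no_grad():
    feats = wrapper(audio)
print(wrapper.code_size)  # 1536 for merge_type="cat"
print(feats.shape)        # torch.Size([4, 25, 1536]): two WavLM tokens per frame, concatenated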
WavLM_modules.py ADDED
@@ -0,0 +1,765 @@
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import warnings
12
+ from typing import Dict, Optional, Tuple
13
+ import torch
14
+ from torch import Tensor, nn
15
+ from torch.nn import Parameter
16
+ import torch.nn.functional as F
17
+
18
+
19
+ class TransposeLast(nn.Module):
20
+ def __init__(self, deconstruct_idx=None):
21
+ super().__init__()
22
+ self.deconstruct_idx = deconstruct_idx
23
+
24
+ def forward(self, x):
25
+ if self.deconstruct_idx is not None:
26
+ x = x[self.deconstruct_idx]
27
+ return x.transpose(-2, -1)
28
+
29
+
30
+ class Fp32LayerNorm(nn.LayerNorm):
31
+ def __init__(self, *args, **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+
34
+ def forward(self, input):
35
+ output = F.layer_norm(
36
+ input.float(),
37
+ self.normalized_shape,
38
+ self.weight.float() if self.weight is not None else None,
39
+ self.bias.float() if self.bias is not None else None,
40
+ self.eps,
41
+ )
42
+ return output.type_as(input)
43
+
44
+
45
+ class Fp32GroupNorm(nn.GroupNorm):
46
+ def __init__(self, *args, **kwargs):
47
+ super().__init__(*args, **kwargs)
48
+
49
+ def forward(self, input):
50
+ output = F.group_norm(
51
+ input.float(),
52
+ self.num_groups,
53
+ self.weight.float() if self.weight is not None else None,
54
+ self.bias.float() if self.bias is not None else None,
55
+ self.eps,
56
+ )
57
+ return output.type_as(input)
58
+
59
+
60
+ class GradMultiply(torch.autograd.Function):
61
+ @staticmethod
62
+ def forward(ctx, x, scale):
63
+ ctx.scale = scale
64
+ res = x.new(x)
65
+ return res
66
+
67
+ @staticmethod
68
+ def backward(ctx, grad):
69
+ return grad * ctx.scale, None
70
+
71
+
72
+ class SamePad(nn.Module):
73
+ def __init__(self, kernel_size, causal=False):
74
+ super().__init__()
75
+ if causal:
76
+ self.remove = kernel_size - 1
77
+ else:
78
+ self.remove = 1 if kernel_size % 2 == 0 else 0
79
+
80
+ def forward(self, x):
81
+ if self.remove > 0:
82
+ x = x[:, :, : -self.remove]
83
+ return x
84
+
85
+
86
+ class Swish(nn.Module):
87
+ """Swish function"""
88
+
89
+ def __init__(self):
90
+ """Construct an MultiHeadedAttention object."""
91
+ super(Swish, self).__init__()
92
+ self.act = torch.nn.Sigmoid()
93
+
94
+ def forward(self, x):
95
+ return x * self.act(x)
96
+
97
+
98
+ class GLU_Linear(nn.Module):
99
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
100
+ super(GLU_Linear, self).__init__()
101
+
102
+ self.glu_type = glu_type
103
+ self.output_dim = output_dim
104
+
105
+ if glu_type == "sigmoid":
106
+ self.glu_act = torch.nn.Sigmoid()
107
+ elif glu_type == "swish":
108
+ self.glu_act = Swish()
109
+ elif glu_type == "relu":
110
+ self.glu_act = torch.nn.ReLU()
111
+ elif glu_type == "gelu":
112
+ self.glu_act = torch.nn.GELU()
113
+
114
+ if bias_in_glu:
115
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
116
+ else:
117
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
118
+
119
+ def forward(self, x):
120
+ # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
121
+ x = self.linear(x)
122
+
123
+ if self.glu_type == "bilinear":
124
+ x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2]
125
+ else:
126
+ x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2])
127
+
128
+ return x
129
+
130
+
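A small sketch of the gating performed by GLU_Linear, assuming the file above is importable as WavLM_modules: the linear layer projects to 2*output_dim channels and the first half is multiplied element-wise by the activated second half.

# Illustrative sketch, not part of the committed file.
import torch
from WavLM_modules import GLU_Linear

glu = GLU_Linear(input_dim=768, output_dim=3072, glu_type="swish")
y = glu(torch.randn(2, 50, 768))  # forward() indexes the last dim of a 3-D (batch, time, channels) input
print(y.shape)                    # torch.Size([2, 50, 3072])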
131
+ def gelu_accurate(x):
132
+ if not hasattr(gelu_accurate, "_a"):
133
+ gelu_accurate._a = math.sqrt(2 / math.pi)
134
+ return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
135
+
136
+
137
+ def gelu(x: torch.Tensor) -> torch.Tensor:
138
+ return torch.nn.functional.gelu(x.float()).type_as(x)
139
+
140
+
141
+ def get_activation_fn(activation: str):
142
+ """Returns the activation function corresponding to `activation`"""
143
+
144
+ if activation == "relu":
145
+ return F.relu
146
+ elif activation == "gelu":
147
+ return gelu
148
+ elif activation == "gelu_fast":
149
+ warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate")
150
+ return gelu_accurate
151
+ elif activation == "gelu_accurate":
152
+ return gelu_accurate
153
+ elif activation == "tanh":
154
+ return torch.tanh
155
+ elif activation == "linear":
156
+ return lambda x: x
157
+ elif activation == "glu":
158
+ return lambda x: x
159
+ else:
160
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
161
+
162
+
163
+ def init_bert_params(module):
164
+ """
165
+ Initialize the weights specific to the BERT Model.
166
+ This overrides the default initializations depending on the specified arguments.
167
+ 1. If normal_init_linear_weights is set then weights of linear
168
+ layer will be initialized using the normal distribution and
169
+ bias will be set to the specified value.
170
+ 2. If normal_init_embed_weights is set then weights of embedding
171
+ layer will be initialized using the normal distribution.
172
+ 3. If normal_init_proj_weights is set then weights of
173
+ in_project_weight for MultiHeadAttention initialized using
174
+ the normal distribution (to be validated).
175
+ """
176
+
177
+ def normal_(data):
178
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
179
+ # so that the RNG is consistent with and without FSDP
180
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
181
+
182
+ if isinstance(module, nn.Linear):
183
+ normal_(module.weight.data)
184
+ if module.bias is not None:
185
+ module.bias.data.zero_()
186
+ if isinstance(module, nn.Embedding):
187
+ normal_(module.weight.data)
188
+ if module.padding_idx is not None:
189
+ module.weight.data[module.padding_idx].zero_()
190
+ if isinstance(module, MultiheadAttention):
191
+ normal_(module.q_proj.weight.data)
192
+ normal_(module.k_proj.weight.data)
193
+ normal_(module.v_proj.weight.data)
194
+
195
+
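The initializer above is meant to be applied recursively; a minimal sketch mirroring how TransformerEncoder in WavLM.py calls it through nn.Module.apply:

# Illustrative sketch, not part of the committed file.
import torch.nn as nn
from WavLM_modules import init_bert_params

mlp = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768))
mlp.apply(init_bert_params)  # every Linear gets weights drawn with std 0.02 and zeroed biases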
196
+ def quant_noise(module, p, block_size):
197
+ """
198
+ Wraps modules and applies quantization noise to the weights for
199
+ subsequent quantization with Iterative Product Quantization as
200
+ described in "Training with Quantization Noise for Extreme Model Compression"
201
+ Args:
202
+ - module: nn.Module
203
+ - p: amount of Quantization Noise
204
+ - block_size: size of the blocks for subsequent quantization with iPQ
205
+ Remarks:
206
+ - Module weights must have the right sizes wrt the block size
207
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
208
+ - For more detail on how to quantize by blocks with convolutional weights,
209
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
210
+ - We implement the simplest form of noise here as stated in the paper
211
+ which consists in randomly dropping blocks
212
+ """
213
+
214
+ # if no quantization noise, don't register hook
215
+ if p <= 0:
216
+ return module
217
+
218
+ # supported modules
219
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
220
+
221
+ # test whether module.weight has the right sizes wrt block_size
222
+ is_conv = module.weight.ndim == 4
223
+
224
+ # 2D matrix
225
+ if not is_conv:
226
+ assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes"
227
+
228
+ # 4D matrix
229
+ else:
230
+ # 1x1 convolutions
231
+ if module.kernel_size == (1, 1):
232
+ assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes"
233
+ # regular convolutions
234
+ else:
235
+ k = module.kernel_size[0] * module.kernel_size[1]
236
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
237
+
238
+ def _forward_pre_hook(mod, input):
239
+ # no noise for evaluation
240
+ if mod.training:
241
+ if not is_conv:
242
+ # gather weight and sizes
243
+ weight = mod.weight
244
+ in_features = weight.size(1)
245
+ out_features = weight.size(0)
246
+
247
+ # split weight matrix into blocks and randomly drop selected blocks
248
+ mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
249
+ mask.bernoulli_(p)
250
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
251
+
252
+ else:
253
+ # gather weight and sizes
254
+ weight = mod.weight
255
+ in_channels = mod.in_channels
256
+ out_channels = mod.out_channels
257
+
258
+ # split weight matrix into blocks and randomly drop selected blocks
259
+ if mod.kernel_size == (1, 1):
260
+ mask = torch.zeros(
261
+ int(in_channels // block_size * out_channels),
262
+ device=weight.device,
263
+ )
264
+ mask.bernoulli_(p)
265
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
266
+ else:
267
+ mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
268
+ mask.bernoulli_(p)
269
+ mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
270
+
271
+ # scale weights and apply mask
272
+ mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript
273
+ s = 1 / (1 - p)
274
+ mod.weight.data = s * weight.masked_fill(mask, 0)
275
+
276
+ module.register_forward_pre_hook(_forward_pre_hook)
277
+ return module
278
+
279
+
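A hedged sketch of the quantization-noise wrapper in isolation; with p > 0 the registered forward pre-hook randomly zeroes weight blocks and rescales the rest, and only while the module is in training mode.

# Illustrative sketch, not part of the committed file; p and block_size are example values.
import torch
import torch.nn as nn
from WavLM_modules import quant_noise

layer = quant_noise(nn.Linear(768, 768), p=0.1, block_size=8)  # in_features must be a multiple of block_size
layer.train()
_ = layer(torch.randn(2, 768))  # pre-hook masks roughly 10% of the weight blocks for this call
layer.eval()
_ = layer(torch.randn(2, 768))  # no noise at evaluation time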
280
+ class MultiheadAttention(nn.Module):
281
+ """Multi-headed attention.
282
+ See "Attention Is All You Need" for more details.
283
+ """
284
+
285
+ def __init__(
286
+ self,
287
+ embed_dim,
288
+ num_heads,
289
+ kdim=None,
290
+ vdim=None,
291
+ dropout=0.0,
292
+ bias=True,
293
+ add_bias_kv=False,
294
+ add_zero_attn=False,
295
+ self_attention=False,
296
+ encoder_decoder_attention=False,
297
+ q_noise=0.0,
298
+ qn_block_size=8,
299
+ has_relative_attention_bias=False,
300
+ num_buckets=32,
301
+ max_distance=128,
302
+ gru_rel_pos=False,
303
+ rescale_init=False,
304
+ ):
305
+ super().__init__()
306
+ self.embed_dim = embed_dim
307
+ self.kdim = kdim if kdim is not None else embed_dim
308
+ self.vdim = vdim if vdim is not None else embed_dim
309
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
310
+
311
+ self.num_heads = num_heads
312
+ self.dropout_module = nn.Dropout(dropout)
313
+
314
+ self.has_relative_attention_bias = has_relative_attention_bias
315
+ self.num_buckets = num_buckets
316
+ self.max_distance = max_distance
317
+ if self.has_relative_attention_bias:
318
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
319
+
320
+ self.head_dim = embed_dim // num_heads
321
+ self.q_head_dim = self.head_dim
322
+ self.k_head_dim = self.head_dim
323
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
324
+ self.scaling = self.head_dim**-0.5
325
+
326
+ self.self_attention = self_attention
327
+ self.encoder_decoder_attention = encoder_decoder_attention
328
+
329
+ assert not self.self_attention or self.qkv_same_dim, (
330
+ "Self-attention requires query, key and " "value to be of the same size"
331
+ )
332
+
333
+ k_bias = True
334
+ if rescale_init:
335
+ k_bias = False
336
+
337
+ k_embed_dim = embed_dim
338
+ q_embed_dim = embed_dim
339
+
340
+ self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size)
341
+ self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
342
+ self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size)
343
+
344
+ self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
345
+
346
+ if add_bias_kv:
347
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
348
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
349
+ else:
350
+ self.bias_k = self.bias_v = None
351
+
352
+ self.add_zero_attn = add_zero_attn
353
+
354
+ self.gru_rel_pos = gru_rel_pos
355
+ if self.gru_rel_pos:
356
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
357
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
358
+
359
+ self.reset_parameters()
360
+
361
+ def reset_parameters(self):
362
+ if self.qkv_same_dim:
363
+ # Empirically observed the convergence to be much better with
364
+ # the scaled initialization
365
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
366
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
367
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
368
+ else:
369
+ nn.init.xavier_uniform_(self.k_proj.weight)
370
+ nn.init.xavier_uniform_(self.v_proj.weight)
371
+ nn.init.xavier_uniform_(self.q_proj.weight)
372
+
373
+ nn.init.xavier_uniform_(self.out_proj.weight)
374
+ if self.out_proj.bias is not None:
375
+ nn.init.constant_(self.out_proj.bias, 0.0)
376
+ if self.bias_k is not None:
377
+ nn.init.xavier_normal_(self.bias_k)
378
+ if self.bias_v is not None:
379
+ nn.init.xavier_normal_(self.bias_v)
380
+ if self.has_relative_attention_bias:
381
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
382
+
383
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
384
+ num_buckets = self.num_buckets
385
+ max_distance = self.max_distance
386
+ relative_buckets = 0
387
+
388
+ if bidirectional:
389
+ num_buckets = num_buckets // 2
390
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
391
+ relative_positions = torch.abs(relative_positions)
392
+ else:
393
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
394
+
395
+ max_exact = num_buckets // 2
396
+ is_small = relative_positions < max_exact
397
+
398
+ relative_postion_if_large = max_exact + (
399
+ torch.log(relative_positions.float() / max_exact)
400
+ / math.log(max_distance / max_exact)
401
+ * (num_buckets - max_exact)
402
+ ).to(torch.long)
403
+ relative_postion_if_large = torch.min(
404
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
405
+ )
406
+
407
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
408
+ return relative_buckets
409
+
410
+ def compute_bias(self, query_length, key_length):
411
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
412
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
413
+ relative_position = memory_position - context_position
414
+ relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
415
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
416
+ values = self.relative_attention_bias(relative_position_bucket)
417
+ values = values.permute([2, 0, 1])
418
+ return values
419
+
420
+ def forward(
421
+ self,
422
+ query,
423
+ key: Optional[Tensor],
424
+ value: Optional[Tensor],
425
+ key_padding_mask: Optional[Tensor] = None,
426
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
427
+ need_weights: bool = True,
428
+ static_kv: bool = False,
429
+ attn_mask: Optional[Tensor] = None,
430
+ before_softmax: bool = False,
431
+ need_head_weights: bool = False,
432
+ position_bias: Optional[Tensor] = None,
433
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
434
+ """Input shape: Time x Batch x Channel
435
+ Args:
436
+ key_padding_mask (ByteTensor, optional): mask to exclude
437
+ keys that are pads, of shape `(batch, src_len)`, where
438
+ padding elements are indicated by 1s.
439
+ need_weights (bool, optional): return the attention weights,
440
+ averaged over heads (default: False).
441
+ attn_mask (ByteTensor, optional): typically used to
442
+ implement causal attention, where the mask prevents the
443
+ attention from looking forward in time (default: None).
444
+ before_softmax (bool, optional): return the raw attention
445
+ weights and values before the attention softmax.
446
+ need_head_weights (bool, optional): return the attention
447
+ weights for each head. Implies *need_weights*. Default:
448
+ return the average attention weights over all heads.
449
+ """
450
+ if need_head_weights:
451
+ need_weights = True
452
+
453
+ is_tpu = query.device.type == "xla"
454
+
455
+ tgt_len, bsz, embed_dim = query.size()
456
+ src_len = tgt_len
457
+ assert embed_dim == self.embed_dim
458
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
459
+ if key is not None:
460
+ src_len, key_bsz, _ = key.size()
461
+ if not torch.jit.is_scripting():
462
+ assert key_bsz == bsz
463
+ assert value is not None
464
+ assert (src_len, bsz) == value.shape[:2]  # parenthesized so the shape check is actually asserted
465
+
466
+ if self.has_relative_attention_bias and position_bias is None:
467
+ position_bias = self.compute_bias(tgt_len, src_len)
468
+ position_bias = (
469
+ position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
470
+ )
471
+
472
+ if (
473
+ not is_tpu # don't use PyTorch version on TPUs
474
+ and incremental_state is None
475
+ and not static_kv
476
+ # A workaround for quantization to work. Otherwise JIT compilation
477
+ # treats bias in linear module as method.
478
+ and not torch.jit.is_scripting()
479
+ and self.q_head_dim == self.head_dim
480
+ ):
481
+ assert key is not None and value is not None
482
+ assert attn_mask is None
483
+
484
+ attn_mask_rel_pos = None
485
+ if position_bias is not None:
486
+ attn_mask_rel_pos = position_bias
487
+ if self.gru_rel_pos:
488
+ query_layer = query.transpose(0, 1)
489
+ new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
490
+ query_layer = query_layer.view(*new_x_shape)
491
+ query_layer = query_layer.permute(0, 2, 1, 3)
492
+ _B, _H, _L, __ = query_layer.size()
493
+
494
+ gate_a, gate_b = torch.sigmoid(
495
+ self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
496
+ ).chunk(2, dim=-1)
497
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
498
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
499
+
500
+ attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
501
+ k_proj_bias = self.k_proj.bias
502
+ if k_proj_bias is None:
503
+ k_proj_bias = torch.zeros_like(self.q_proj.bias)
504
+
505
+ x, attn = F.multi_head_attention_forward(
506
+ query,
507
+ key,
508
+ value,
509
+ self.embed_dim,
510
+ self.num_heads,
511
+ torch.empty([0]),
512
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
513
+ self.bias_k,
514
+ self.bias_v,
515
+ self.add_zero_attn,
516
+ self.dropout_module.p,
517
+ self.out_proj.weight,
518
+ self.out_proj.bias,
519
+ self.training,
520
+ # self.training or self.dropout_module.apply_during_inference,
521
+ key_padding_mask,
522
+ need_weights,
523
+ attn_mask_rel_pos,
524
+ use_separate_proj_weight=True,
525
+ q_proj_weight=self.q_proj.weight,
526
+ k_proj_weight=self.k_proj.weight,
527
+ v_proj_weight=self.v_proj.weight,
528
+ )
529
+ return x, attn, position_bias
530
+
531
+ if incremental_state is not None:
532
+ saved_state = self._get_input_buffer(incremental_state)
533
+ if saved_state is not None and "prev_key" in saved_state:
534
+ # previous time steps are cached - no need to recompute
535
+ # key and value if they are static
536
+ if static_kv:
537
+ assert self.encoder_decoder_attention and not self.self_attention
538
+ key = value = None
539
+ else:
540
+ saved_state = None
541
+
542
+ if self.self_attention:
543
+ q = self.q_proj(query)
544
+ k = self.k_proj(query)
545
+ v = self.v_proj(query)
546
+ elif self.encoder_decoder_attention:
547
+ # encoder-decoder attention
548
+ q = self.q_proj(query)
549
+ if key is None:
550
+ assert value is None
551
+ k = v = None
552
+ else:
553
+ k = self.k_proj(key)
554
+ v = self.v_proj(key)
555
+
556
+ else:
557
+ assert key is not None and value is not None
558
+ q = self.q_proj(query)
559
+ k = self.k_proj(key)
560
+ v = self.v_proj(value)
561
+ q *= self.scaling
562
+
563
+ if self.bias_k is not None:
564
+ assert self.bias_v is not None
565
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
566
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
567
+ if attn_mask is not None:
568
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
569
+ if key_padding_mask is not None:
570
+ key_padding_mask = torch.cat(
571
+ [
572
+ key_padding_mask,
573
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
574
+ ],
575
+ dim=1,
576
+ )
577
+
578
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1)
579
+ if k is not None:
580
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1)
581
+ if v is not None:
582
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
583
+
584
+ if saved_state is not None:
585
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
586
+ if "prev_key" in saved_state:
587
+ _prev_key = saved_state["prev_key"]
588
+ assert _prev_key is not None
589
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
590
+ if static_kv:
591
+ k = prev_key
592
+ else:
593
+ assert k is not None
594
+ k = torch.cat([prev_key, k], dim=1)
595
+ src_len = k.size(1)
596
+ if "prev_value" in saved_state:
597
+ _prev_value = saved_state["prev_value"]
598
+ assert _prev_value is not None
599
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
600
+ if static_kv:
601
+ v = prev_value
602
+ else:
603
+ assert v is not None
604
+ v = torch.cat([prev_value, v], dim=1)
605
+ prev_key_padding_mask: Optional[Tensor] = None
606
+ if "prev_key_padding_mask" in saved_state:
607
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
608
+ assert k is not None and v is not None
609
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
610
+ key_padding_mask=key_padding_mask,
611
+ prev_key_padding_mask=prev_key_padding_mask,
612
+ batch_size=bsz,
613
+ src_len=k.size(1),
614
+ static_kv=static_kv,
615
+ )
616
+
617
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
618
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
619
+ saved_state["prev_key_padding_mask"] = key_padding_mask
620
+ # In this branch incremental_state is never None
621
+ assert incremental_state is not None
622
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
623
+ assert k is not None
624
+ assert k.size(1) == src_len
625
+
626
+ # This is part of a workaround to get around fork/join parallelism
627
+ # not supporting Optional types.
628
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
629
+ key_padding_mask = None
630
+
631
+ if key_padding_mask is not None:
632
+ assert key_padding_mask.size(0) == bsz
633
+ assert key_padding_mask.size(1) == src_len
634
+
635
+ if self.add_zero_attn:
636
+ assert v is not None
637
+ src_len += 1
638
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
639
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
640
+ if attn_mask is not None:
641
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
642
+ if key_padding_mask is not None:
643
+ key_padding_mask = torch.cat(
644
+ [
645
+ key_padding_mask,
646
+ torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
647
+ ],
648
+ dim=1,
649
+ )
650
+
651
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
652
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
653
+
654
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
655
+
656
+ if attn_mask is not None:
657
+ attn_mask = attn_mask.unsqueeze(0)
658
+ attn_weights += attn_mask
659
+
660
+ if key_padding_mask is not None:
661
+ # don't attend to padding symbols
662
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
663
+ if not is_tpu:
664
+ attn_weights = attn_weights.masked_fill(
665
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
666
+ float("-inf"),
667
+ )
668
+ else:
669
+ attn_weights = attn_weights.transpose(0, 2)
670
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
671
+ attn_weights = attn_weights.transpose(0, 2)
672
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
673
+
674
+ if before_softmax:
675
+ return attn_weights, v, position_bias
676
+
677
+ if position_bias is not None:
678
+ if self.gru_rel_pos == 1:
679
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
680
+ _B, _H, _L, __ = query_layer.size()
681
+ gate_a, gate_b = torch.sigmoid(
682
+ self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
683
+ ).chunk(2, dim=-1)
684
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
685
+ position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
686
+
687
+ position_bias = position_bias.view(attn_weights.size())
688
+
689
+ attn_weights = attn_weights + position_bias
690
+
691
+ attn_weights_float = F.softmax(attn_weights, dim=-1)
692
+ attn_weights = attn_weights_float.type_as(attn_weights)
693
+ attn_probs = self.dropout_module(attn_weights)
694
+
695
+ assert v is not None
696
+ attn = torch.bmm(attn_probs, v)
697
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
698
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
699
+ attn = self.out_proj(attn)
700
+ attn_weights: Optional[Tensor] = None
701
+ if need_weights:
702
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
703
+ if not need_head_weights:
704
+ # average attention weights over heads
705
+ attn_weights = attn_weights.mean(dim=0)
706
+
707
+ return attn, attn_weights, position_bias
708
+
709
+ @staticmethod
710
+ def _append_prev_key_padding_mask(
711
+ key_padding_mask: Optional[Tensor],
712
+ prev_key_padding_mask: Optional[Tensor],
713
+ batch_size: int,
714
+ src_len: int,
715
+ static_kv: bool,
716
+ ) -> Optional[Tensor]:
717
+ # saved key padding masks have shape (bsz, seq_len)
718
+ if prev_key_padding_mask is not None and static_kv:
719
+ new_key_padding_mask = prev_key_padding_mask
720
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
721
+ new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
722
+ # During incremental decoding, as the padding token enters and
723
+ # leaves the frame, there will be a time when prev or current
724
+ # is None
725
+ elif prev_key_padding_mask is not None:
726
+ if src_len > prev_key_padding_mask.size(1):
727
+ filler = torch.zeros(
728
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
729
+ device=prev_key_padding_mask.device,
730
+ )
731
+ new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
732
+ else:
733
+ new_key_padding_mask = prev_key_padding_mask.float()
734
+ elif key_padding_mask is not None:
735
+ if src_len > key_padding_mask.size(1):
736
+ filler = torch.zeros(
737
+ (batch_size, src_len - key_padding_mask.size(1)),
738
+ device=key_padding_mask.device,
739
+ )
740
+ new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
741
+ else:
742
+ new_key_padding_mask = key_padding_mask.float()
743
+ else:
744
+ new_key_padding_mask = prev_key_padding_mask
745
+ return new_key_padding_mask
746
+
747
+ def _get_input_buffer(
748
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
749
+ ) -> Dict[str, Optional[Tensor]]:
750
+ result = self.get_incremental_state(incremental_state, "attn_state")
751
+ if result is not None:
752
+ return result
753
+ else:
754
+ empty_result: Dict[str, Optional[Tensor]] = {}
755
+ return empty_result
756
+
757
+ def _set_input_buffer(
758
+ self,
759
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
760
+ buffer: Dict[str, Optional[Tensor]],
761
+ ):
762
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
763
+
764
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
765
+ return attn_weights
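# --- Illustrative sketch (annotation, not part of this commit) ---
# Reproduces the gated relative-position-bias step from the forward pass above in
# isolation. The tensor sizes (batch, heads, lengths, head_dim) and the 8-unit
# output of grep_linear are assumptions inferred from the `.view(_B, _H, _L, 2, 4)`
# call; grep_a stands in for self.grep_a.
import torch
import torch.nn as nn

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 10, 10, 16
q = torch.randn(bsz * num_heads, tgt_len, head_dim)
position_bias = torch.randn(bsz * num_heads, tgt_len, src_len)

grep_linear = nn.Linear(head_dim, 8)      # assumed layout: 2 gates x 4 summed logits
grep_a = torch.ones(1, num_heads, 1, 1)   # learnable scale in the real module

query_layer = q.view(bsz, num_heads, tgt_len, head_dim)
gate_a, gate_b = torch.sigmoid(
    grep_linear(query_layer).view(bsz, num_heads, tgt_len, 2, 4).sum(-1)
).chunk(2, dim=-1)
gate_a_1 = gate_a * (gate_b * grep_a - 1.0) + 2.0          # gating factor per query position
gated_bias = gate_a_1.view(bsz * num_heads, -1, 1) * position_bias
assert gated_bias.shape == position_bias.shape
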
__pycache__/WavLM.cpython-311.pyc ADDED
Binary file (38.8 kB).
 
__pycache__/WavLM_modules.cpython-311.pyc ADDED
Binary file (39.9 kB).
 
__pycache__/data_utils.cpython-311.pyc ADDED
Binary file (28.5 kB).
 
__pycache__/dino_game.cpython-311.pyc ADDED
Binary file (5.35 kB).
 
__pycache__/inference_functions.cpython-311.pyc ADDED
Binary file (20.2 kB).
 
__pycache__/landmarks_extractor.cpython-311.pyc ADDED
Binary file (1.93 kB).
 
__pycache__/utils.cpython-311.pyc ADDED
Binary file (13.7 kB).
 
__pycache__/vae_wrapper.cpython-311.pyc ADDED
Binary file (8.83 kB).
 
__pycache__/wordle_game.cpython-311.pyc ADDED
Binary file (6.88 kB).
 
app.py ADDED
@@ -0,0 +1,978 @@
1
+ import gradio as gr
2
+ import torch
3
+ import tempfile
4
+ import os
5
+ from vae_wrapper import VaeWrapper, encode_video_chunk
6
+ from landmarks_extractor import LandmarksExtractor
7
+ import decord
8
+ from utils import (
9
+ get_raw_audio,
10
+ save_audio_video,
11
+ calculate_splits,
12
+ instantiate_from_config,
13
+ create_pipeline_inputs,
14
+ )
15
+ from transformers import HubertModel
16
+ from einops import rearrange
17
+ import numpy as np
18
+ from WavLM import WavLM_wrapper
19
+ from omegaconf import OmegaConf
20
+ from inference_functions import (
21
+ sample_keyframes,
22
+ sample_interpolation,
23
+ )
24
+ from wordle_game import WordleGame
25
+ import torch.cuda.amp as amp # Import amp for mixed precision
26
+
27
+
28
+ # Speed up CUDA matmuls: the float16 default-tensor-type switch below is left disabled; TF32 is enabled instead
29
+ if torch.cuda.is_available():
30
+ # torch.set_default_tensor_type(torch.cuda.FloatTensor)
31
+ # Enable TF32 precision for better performance on Ampere+ GPUs
32
+ torch.backends.cuda.matmul.allow_tf32 = True
33
+ torch.backends.cudnn.allow_tf32 = True
34
+
35
+ # Cache for video and audio processing
36
+ cache = {
37
+ "video": {
38
+ "path": None,
39
+ "embedding": None,
40
+ "frames": None,
41
+ "landmarks": None,
42
+ },
43
+ "audio": {
44
+ "path": None,
45
+ "raw_audio": None,
46
+ "hubert_embedding": None,
47
+ "wavlm_embedding": None,
48
+ },
49
+ }
50
+
51
+ # Gradient scaler for mixed precision (only needed for training; the inference paths below rely on autocast)
52
+ scaler = amp.GradScaler()
53
+
54
+
55
+ def load_model(
56
+ config: str,
57
+ device: str = "cuda",
58
+ ckpt: str = None,
59
+ ):
60
+ """
61
+ Load a model from configuration.
62
+
63
+ Args:
64
+ config: Path to model configuration file
65
+ device: Device to load the model on
66
+ ckpt: Optional checkpoint path
+
+ Returns:
+ The instantiated model in eval mode (FP16 on CUDA, with the first-stage model kept in FP32)
72
+ """
73
+ config = OmegaConf.load(config)
74
+
75
+ config["model"]["params"]["input_key"] = "latents"
76
+
77
+ if ckpt is not None:
78
+ config.model.params.ckpt_path = ckpt
79
+
80
+ with torch.device(device):
81
+ model = instantiate_from_config(config.model).to(device).eval()
82
+ # Convert model to half precision
83
+ if torch.cuda.is_available():
84
+ model = model.half()
85
+ model.first_stage_model = model.first_stage_model.float()
86
+ print("Converted model to FP16 precision")
87
+
88
+ # Compile model for faster inference
89
+ if torch.cuda.is_available():
90
+ try:
91
+ model = torch.compile(model)
92
+ print(f"Successfully compiled model with torch.compile()")
93
+ except Exception as e:
94
+ print(f"Warning: Failed to compile model: {e}")
95
+
96
+ return model
97
+
98
+
99
+ # keyframe_model = KeyframeModel(device=device)
100
+ # interpolation_model = InterpolationModel(device=device)
101
+ vae_model = VaeWrapper("video")
102
+ if torch.cuda.is_available():
103
+ vae_model = vae_model.half() # Convert to half precision
104
+ try:
105
+ vae_model = torch.compile(vae_model)
106
+ print("Successfully compiled vae_model in FP16")
107
+ except Exception as e:
108
+ print(f"Warning: Failed to compile vae_model: {e}")
109
+
110
+ hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
111
+ if torch.cuda.is_available():
112
+ hubert_model = hubert_model.half() # Convert to half precision
113
+ try:
114
+ hubert_model = torch.compile(hubert_model)
115
+ print("Successfully compiled hubert_model in FP16")
116
+ except Exception as e:
117
+ print(f"Warning: Failed to compile hubert_model: {e}")
118
+
119
+ wavlm_model = WavLM_wrapper(
120
+ model_size="Base+",
121
+ feed_as_frames=False,
122
+ merge_type="None",
123
+ model_path="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/WavLM-Base+.pt",
124
+ ).cuda()
125
+ if torch.cuda.is_available():
126
+ wavlm_model = wavlm_model.half() # Convert to half precision
127
+ try:
128
+ wavlm_model = torch.compile(wavlm_model)
129
+ print("Successfully compiled wavlm_model in FP16")
130
+ except Exception as e:
131
+ print(f"Warning: Failed to compile wavlm_model: {e}")
132
+
133
+ landmarks_extractor = LandmarksExtractor()
134
+ keyframe_model = load_model(
135
+ config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/keyframe.yaml",
136
+ ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/keyframe_dub.pt",
137
+ )
138
+ interpolation_model = load_model(
139
+ config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/interpolation.yaml",
140
+ ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/interpolation_dub.pt",
141
+ )
142
+ keyframe_model.en_and_decode_n_samples_a_time = 2
143
+ interpolation_model.en_and_decode_n_samples_a_time = 2
144
+
145
+ # Default media paths
146
+ DEFAULT_VIDEO_PATH = os.path.join(
147
+ os.path.dirname(__file__), "assets", "sample_video.mp4"
148
+ )
149
+ DEFAULT_AUDIO_PATH = os.path.join(
150
+ os.path.dirname(__file__), "assets", "sample_audio.wav"
151
+ )
152
+
153
+
154
+ @torch.no_grad()
155
+ def compute_video_embedding(video_reader, min_len):
156
+ """Compute embeddings from video"""
157
+
158
+ total_frames = min_len
159
+
160
+ encoded = []
161
+ video_frames = []
162
+ chunk_size = 16
163
+ resolution = 512
164
+
165
+ # # Create a progress bar for Gradio
166
+ progress = gr.Progress()
167
+
168
+ # Calculate total chunks for progress tracking
169
+ total_chunks = (total_frames + chunk_size - 1) // chunk_size
170
+
171
+ for i, start_idx in enumerate(range(0, total_frames, chunk_size)):
172
+ # Update progress bar
173
+ progress(i / total_chunks, desc="Processing video chunks")
174
+
175
+ end_idx = min(start_idx + chunk_size, total_frames)
176
+ video_chunk = video_reader.get_batch(range(start_idx, end_idx))
177
+ # Interpolate video chunk to the target resolution
178
+ video_chunk = rearrange(video_chunk, "f h w c -> f c h w")
179
+ video_chunk = torch.nn.functional.interpolate(
180
+ video_chunk,
181
+ size=(resolution, resolution),
182
+ mode="bilinear",
183
+ align_corners=False,
184
+ )
185
+ video_chunk = rearrange(video_chunk, "f c h w -> f h w c")
186
+ video_frames.append(video_chunk)
187
+
188
+ # Convert chunk to FP16 if using CUDA
189
+ if torch.cuda.is_available():
190
+ video_chunk = video_chunk.half()
191
+
192
+ # Always use autocast for FP16 computation
193
+ with amp.autocast(enabled=True):
194
+ encoded.append(encode_video_chunk(vae_model, video_chunk, resolution))
195
+
196
+ encoded = torch.cat(encoded, dim=0)
197
+ video_frames = torch.cat(video_frames, dim=0)
198
+ video_frames = rearrange(video_frames, "f h w c -> f c h w")
199
+ torch.cuda.empty_cache()
200
+ return encoded, video_frames
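# --- Illustrative sketch (annotation, not part of this commit) ---
# Shape check for the per-chunk preprocessing above: frames come out of decord as
# (f, h, w, c), are moved to channels-first for bilinear resizing to 512x512, and
# are moved back before being stacked. The frame count and source resolution here
# are arbitrary example values.
import torch
from einops import rearrange

chunk = torch.randint(0, 256, (16, 256, 256, 3), dtype=torch.uint8)  # (f h w c)
x = rearrange(chunk, "f h w c -> f c h w").float()
x = torch.nn.functional.interpolate(
    x, size=(512, 512), mode="bilinear", align_corners=False
)
x = rearrange(x, "f c h w -> f h w c")
assert x.shape == (16, 512, 512, 3)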
201
+
202
+
203
+ @torch.no_grad()
204
+ def compute_hubert_embedding(raw_audio):
205
+ """Compute embeddings from audio"""
206
+ print(f"Computing audio embedding from {raw_audio.shape}")
207
+
208
+ audio = (
209
+ (raw_audio - raw_audio.mean()) / torch.sqrt(raw_audio.var() + 1e-7)
210
+ ).unsqueeze(0)
211
+ chunks = 16000 * 20
212
+
213
+ # Create a progress bar for Gradio
214
+ progress = gr.Progress()
215
+
216
+ # Get audio embeddings
217
+ audio_embeddings = []
218
+ splits = list(calculate_splits(audio, chunks))
219
+ total_splits = len(splits)
220
+
221
+ for i, chunk in enumerate(splits):
222
+ # Update progress bar
223
+ progress(i / total_splits, desc="Processing audio chunks")
224
+
225
+ # Convert audio chunk to half precision
226
+ if torch.cuda.is_available():
227
+ chunk_cuda = chunk.cuda().half()
228
+ else:
229
+ chunk_cuda = chunk.cuda()
230
+
231
+ # Always use autocast for FP16 computation
232
+ with amp.autocast(enabled=True):
233
+ hidden_states = hubert_model(chunk_cuda)[0]
234
+
235
+ audio_embeddings.append(hidden_states)
236
+ audio_embeddings = torch.cat(audio_embeddings, dim=1)
237
+
238
+ # audio_embeddings = self.model.wav2vec2(rearrange(audio_frames, "f s -> () (f s)"))[0]
239
+ if audio_embeddings.shape[1] % 2 != 0:
240
+ audio_embeddings = torch.cat(
241
+ [audio_embeddings, torch.zeros_like(audio_embeddings[:, :1])], dim=1
242
+ )
243
+ audio_embeddings = rearrange(audio_embeddings, "() (f d) c -> f d c", d=2)
244
+ torch.cuda.empty_cache()
245
+
246
+ return audio_embeddings
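# --- Illustrative sketch (annotation, not part of this commit) ---
# HuBERT emits roughly 50 feature vectors per second of 16 kHz audio, while the
# video runs at 25 fps, so the rearrange above pairs every 2 consecutive audio
# features with one video frame. Example sizes below are arbitrary.
import torch
from einops import rearrange

hidden = torch.randn(1, 100, 768)          # (batch=1, 100 features ~ 2 s of audio, dim)
per_frame = rearrange(hidden, "() (f d) c -> f d c", d=2)
assert per_frame.shape == (50, 2, 768)     # 50 video frames, 2 audio features each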
247
+
248
+
249
+ @torch.no_grad()
250
+ def compute_wavlm_embedding(raw_audio):
251
+ """Compute embeddings from audio"""
252
+ audio = rearrange(raw_audio, "(f s) -> f s", s=640)
253
+
254
+ if audio.shape[0] % 2 != 0:
255
+ audio = torch.cat([audio, torch.zeros(1, 640)], dim=0)
256
+ chunks = 500
257
+
258
+ # Create a progress bar for Gradio
259
+ progress = gr.Progress()
260
+
261
+ # Get audio embeddings
262
+ audio_embeddings = []
263
+ splits = list(calculate_splits(audio, chunks))
264
+ total_splits = len(splits)
265
+
266
+ for i, chunk in enumerate(splits):
267
+ # Update progress bar
268
+ progress(i / total_splits, desc="Processing audio chunks")
269
+
270
+ # Convert chunk to half precision
271
+ if torch.cuda.is_available():
272
+ chunk_cuda = chunk.unsqueeze(0).cuda().half()
273
+ else:
274
+ chunk_cuda = chunk.unsqueeze(0).cuda()
275
+
276
+ # Always use autocast for FP16 computation
277
+ with amp.autocast(enabled=True):
278
+ wavlm_hidden_states = wavlm_model(chunk_cuda).squeeze(0)
279
+
280
+ audio_embeddings.append(wavlm_hidden_states)
281
+ audio_embeddings = torch.cat(audio_embeddings, dim=0)
282
+
283
+ torch.cuda.empty_cache()
284
+
285
+ return audio_embeddings
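# --- Illustrative sketch (annotation, not part of this commit) ---
# The reshape above assumes 16 kHz audio aligned to 25 fps video: 16000 / 25 = 640
# samples per video frame, with a zero frame appended when the frame count is odd
# so frames can later be grouped in pairs.
import torch
from einops import rearrange

raw = torch.randn(640 * 75)                 # 3 s of 16 kHz audio, flattened
frames = rearrange(raw, "(f s) -> f s", s=640)
assert frames.shape == (75, 640)
if frames.shape[0] % 2 != 0:                # 75 is odd -> pad one silent frame
    frames = torch.cat([frames, torch.zeros(1, 640)], dim=0)
assert frames.shape == (76, 640)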
286
+
287
+
288
+ @torch.no_grad()
289
+ def extract_video_landmarks(video_frames):
290
+ """Extract landmarks from video frames"""
291
+
292
+ # Create a progress bar for Gradio
293
+ progress = gr.Progress()
294
+
295
+ landmarks = []
296
+ batch_size = 10
297
+
298
+ for i in range(0, len(video_frames), batch_size):
299
+ # Update progress bar
300
+ progress(i / len(video_frames), desc="Extracting facial landmarks")
301
+
302
+ batch = video_frames[i : i + batch_size].cpu().float()
303
+ batch_landmarks = landmarks_extractor.extract_landmarks(batch)
304
+ landmarks.extend(batch_landmarks)
305
+
306
+ torch.cuda.empty_cache()
307
+
308
+ # Convert landmarks to a list of numpy arrays with consistent shape
309
+ processed_landmarks = []
310
+
311
+ expected_shape = (68, 2) # Common shape for facial landmarks
312
+
313
+ # Process each landmark to ensure consistent shape
314
+ last_valid_landmark = None
315
+ for i, lm in enumerate(landmarks):
316
+ if lm is not None and isinstance(lm, np.ndarray) and lm.shape == expected_shape:
317
+ processed_landmarks.append(lm)
318
+ last_valid_landmark = lm
319
+ else:
320
+ # Print information about inconsistent landmarks
321
+ if lm is None:
322
+ print(f"Warning: Landmark at index {i} is None")
323
+ elif not isinstance(lm, np.ndarray):
324
+ print(
325
+ f"Warning: Landmark at index {i} is not a numpy array, type: {type(lm)}"
326
+ )
327
+ elif lm.shape != expected_shape:
328
+ print(
329
+ f"Warning: Landmark at index {i} has shape {lm.shape}, expected {expected_shape}"
330
+ )
331
+
332
+ # Replace invalid landmarks with the closest valid landmark if available
333
+ if last_valid_landmark is not None:
334
+ processed_landmarks.append(last_valid_landmark.copy())
335
+ else:
336
+ # If no valid landmark has been seen yet, look ahead for a valid one
337
+ found_future_valid = False
338
+ for future_lm in landmarks[i + 1 :]:
339
+ if (
340
+ future_lm is not None
341
+ and isinstance(future_lm, np.ndarray)
342
+ and future_lm.shape == expected_shape
343
+ ):
344
+ processed_landmarks.append(future_lm.copy())
345
+ found_future_valid = True
346
+ break
347
+
348
+ # If no valid landmark found in the future, use zeros
349
+ if not found_future_valid:
350
+ processed_landmarks.append(np.zeros(expected_shape))
351
+
352
+ return np.array(processed_landmarks)
353
+
354
+
355
+ @torch.no_grad()
356
+ def sample(
357
+ audio_list,
358
+ gt_keyframes,
359
+ masks_keyframes,
360
+ to_remove,
361
+ test_keyframes_list,
362
+ num_frames,
363
+ device,
364
+ emb,
365
+ force_uc_zero_embeddings,
366
+ n_batch_keyframes,
367
+ n_batch,
368
+ test_interpolation_list,
369
+ audio_interpolation_list,
370
+ masks_interpolation,
371
+ gt_interpolation,
372
+ model_keyframes,
373
+ model,
374
+ ):
375
+ # Create a progress bar for Gradio
376
+ progress = gr.Progress()
377
+
378
+ condition = torch.zeros(1, 3, 512, 512).to(device)
379
+ if torch.cuda.is_available():
380
+ condition = condition.half()
381
+
382
+ audio_list = rearrange(audio_list, "(b t) c d -> b t c d", t=num_frames)
383
+ gt_keyframes = rearrange(gt_keyframes, "(b t) c h w -> b t c h w", t=num_frames)
384
+ # Rearrange masks_keyframes and save locally
385
+ masks_keyframes = rearrange(
386
+ masks_keyframes, "(b t) c h w -> b t c h w", t=num_frames
387
+ )
388
+
389
+ # Convert to_remove into chunks of num_frames
390
+ to_remove_chunks = [
391
+ to_remove[i : i + num_frames] for i in range(0, len(to_remove), num_frames)
392
+ ]
393
+ test_keyframes_list = [
394
+ test_keyframes_list[i : i + num_frames]
395
+ for i in range(0, len(test_keyframes_list), num_frames)
396
+ ]
397
+
398
+ audio_cond = audio_list
399
+ if emb is not None:
400
+ embeddings = emb.unsqueeze(0).to(device)
401
+ if torch.cuda.is_available():
402
+ embeddings = embeddings.half()
403
+ else:
404
+ embeddings = None
405
+
406
+ # One batch of keyframes is approximately 7 seconds
407
+ chunk_size = 2
408
+ complete_video = []
409
+ start_idx = 0
410
+ last_frame_z = None
411
+ last_frame_x = None
412
+ last_keyframe_idx = None
413
+ last_to_remove = None
414
+
415
+ total_chunks = (len(audio_cond) + chunk_size - 1) // chunk_size
416
+
417
+ for chunk_idx, chunk_start in enumerate(range(0, len(audio_cond), chunk_size)):
418
+ # Update progress bar
419
+ progress(chunk_idx / total_chunks, desc="Generating video")
420
+
421
+ # Clear GPU cache between chunks
422
+ torch.cuda.empty_cache()
423
+
424
+ chunk_end = min(chunk_start + chunk_size, len(audio_cond))
425
+
426
+ chunk_audio_cond = audio_cond[chunk_start:chunk_end].cuda()
427
+ if torch.cuda.is_available():
428
+ chunk_audio_cond = chunk_audio_cond.half()
429
+
430
+ chunk_gt_keyframes = gt_keyframes[chunk_start:chunk_end].cuda()
431
+ chunk_masks = masks_keyframes[chunk_start:chunk_end].cuda()
432
+
433
+ if torch.cuda.is_available():
434
+ chunk_gt_keyframes = chunk_gt_keyframes.half()
435
+ chunk_masks = chunk_masks.half()
436
+
437
+ test_keyframes_list_unwrapped = [
438
+ elem
439
+ for sublist in test_keyframes_list[chunk_start:chunk_end]
440
+ for elem in sublist
441
+ ]
442
+ to_remove_chunks_unwrapped = [
443
+ elem
444
+ for sublist in to_remove_chunks[chunk_start:chunk_end]
445
+ for elem in sublist
446
+ ]
447
+
448
+ if last_keyframe_idx is not None:
449
+ test_keyframes_list_unwrapped = [
450
+ last_keyframe_idx
451
+ ] + test_keyframes_list_unwrapped
452
+ to_remove_chunks_unwrapped = [last_to_remove] + to_remove_chunks_unwrapped
453
+
454
+ last_keyframe_idx = test_keyframes_list_unwrapped[-1]
455
+ last_to_remove = to_remove_chunks_unwrapped[-1]
456
+ # Find the first non-None keyframe in the chunk
457
+ first_keyframe = next(
458
+ (kf for kf in test_keyframes_list_unwrapped if kf is not None), None
459
+ )
460
+
461
+ # Find the last non-None keyframe in the chunk
462
+ last_keyframe = next(
463
+ (kf for kf in reversed(test_keyframes_list_unwrapped) if kf is not None),
464
+ None,
465
+ )
466
+
467
+ start_idx = next(
468
+ (
469
+ idx
470
+ for idx, comb in enumerate(test_interpolation_list)
471
+ if comb[0] == first_keyframe
472
+ ),
473
+ None,
474
+ )
475
+ end_idx = next(
476
+ (
477
+ idx
478
+ for idx, comb in enumerate(reversed(test_interpolation_list))
479
+ if comb[1] == last_keyframe
480
+ ),
481
+ None,
482
+ )
483
+
484
+ if start_idx is not None and end_idx is not None:
485
+ end_idx = (
486
+ len(test_interpolation_list) - 1 - end_idx
487
+ ) # Adjust for reversed enumeration
488
+ end_idx += 1
489
+ if start_idx is None:
490
+ break
491
+ if end_idx < start_idx:
492
+ end_idx = len(audio_interpolation_list)
493
+
494
+ audio_interpolation_list_chunk = audio_interpolation_list[start_idx:end_idx]
495
+ chunk_masks_interpolation = masks_interpolation[start_idx:end_idx]
496
+ gt_interpolation_chunks = gt_interpolation[start_idx:end_idx]
497
+
498
+ if torch.cuda.is_available():
499
+ audio_interpolation_list_chunk = [
500
+ chunk.half() for chunk in audio_interpolation_list_chunk
501
+ ]
502
+ chunk_masks_interpolation = [
503
+ chunk.half() for chunk in chunk_masks_interpolation
504
+ ]
505
+ gt_interpolation_chunks = [
506
+ chunk.half() for chunk in gt_interpolation_chunks
507
+ ]
508
+
509
+ progress(chunk_idx / total_chunks, desc="Generating keyframes")
510
+
511
+ # Always use autocast for FP16 computation
512
+ with amp.autocast(enabled=True):
513
+ samples_z = sample_keyframes(
514
+ model_keyframes,
515
+ chunk_audio_cond,
516
+ chunk_gt_keyframes,
517
+ chunk_masks,
518
+ condition.cuda(),
519
+ num_frames,
520
+ 24,
521
+ 0.0,
522
+ device,
523
+ embeddings.cuda() if embeddings is not None else None,
524
+ force_uc_zero_embeddings,
525
+ n_batch_keyframes,
526
+ 0,
527
+ 1.0,
528
+ None,
529
+ gt_as_cond=False,
530
+ )
531
+
532
+ if last_frame_x is not None:
533
+ # samples_x = torch.cat([last_frame_x.unsqueeze(0), samples_x], axis=0)
534
+ samples_z = torch.cat([last_frame_z.unsqueeze(0), samples_z], axis=0)
535
+
536
+ # last_frame_x = samples_x[-1]
537
+ last_frame_z = samples_z[-1]
538
+
539
+ progress(chunk_idx / total_chunks, desc="Interpolating frames")
540
+
541
+ # Always use autocast for FP16 computation
542
+ with amp.autocast(enabled=True):
543
+ vid = sample_interpolation(
544
+ model,
545
+ samples_z,
546
+ # samples_x,
547
+ audio_interpolation_list_chunk,
548
+ gt_interpolation_chunks,
549
+ chunk_masks_interpolation,
550
+ condition.cuda(),
551
+ num_frames,
552
+ device,
553
+ 1,
554
+ 24,
555
+ 0.0,
556
+ force_uc_zero_embeddings,
557
+ n_batch,
558
+ chunk_size,
559
+ 1.0,
560
+ None,
561
+ cut_audio=False,
562
+ to_remove=to_remove_chunks_unwrapped,
563
+ )
564
+
565
+ if chunk_start == 0:
566
+ complete_video = vid
567
+ else:
568
+ complete_video = np.concatenate([complete_video[:-1], vid], axis=0)
569
+
570
+ return complete_video
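# --- Illustrative sketch (annotation, not part of this commit) ---
# The loop above carries the last generated frame over into the next chunk, so
# consecutive chunks share one boundary frame; dropping the previous chunk's final
# frame before concatenating avoids duplicating it. Toy "videos" of frame indices:
import numpy as np

chunk_a = np.array([0, 1, 2, 3])        # frames produced by the first chunk
chunk_b = np.array([3, 4, 5, 6])        # next chunk re-generates frame 3 as context
complete = chunk_a
complete = np.concatenate([complete[:-1], chunk_b], axis=0)
assert complete.tolist() == [0, 1, 2, 3, 4, 5, 6]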
571
+
572
+
573
+ def process_video(video_input, audio_input, max_num_seconds):
574
+ """Main processing function to generate synchronized video"""
575
+
576
+ # Display a message to the user about the processing time
577
+ gr.Info("Processing video. This may take a while...", duration=10)
578
+ gr.Info(
579
+ "If you're tired of waiting, try playing the Wordle game in the other tab!",
580
+ duration=10,
581
+ )
582
+
583
+ # Use default media if none provided
584
+ if video_input is None:
585
+ video_input = DEFAULT_VIDEO_PATH
586
+ print(f"Using default video: {DEFAULT_VIDEO_PATH}")
587
+
588
+ if audio_input is None:
589
+ audio_input = DEFAULT_AUDIO_PATH
590
+ print(f"Using default audio: {DEFAULT_AUDIO_PATH}")
591
+
592
+ try:
593
+ # Calculate hashes for cache keys
594
+ video_path_hash = video_input
595
+ audio_path_hash = audio_input
596
+
597
+ # Check if we need to recompute video embeddings
598
+ video_cache_hit = cache["video"]["path"] == video_path_hash
599
+ audio_cache_hit = cache["audio"]["path"] == audio_path_hash
600
+
601
+ if video_cache_hit and audio_cache_hit:
602
+ print("Using cached video and audio computations")
603
+ # Make copies of cached data to avoid modifying cache
604
+ video_embedding = cache["video"]["embedding"].clone()
605
+ video_frames = cache["video"]["frames"].clone()
606
+ video_landmarks = cache["video"]["landmarks"].copy()
607
+ raw_audio = cache["audio"]["raw_audio"].clone()
608
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
609
+ hubert_embedding = cache["audio"]["hubert_embedding"].clone()
610
+ wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()
611
+
612
+ # Ensure all data is truncated to the same length if needed
613
+ min_len = min(
614
+ len(video_frames),
615
+ len(raw_audio),
616
+ len(hubert_embedding),
617
+ len(wavlm_embedding),
618
+ )
619
+ video_frames = video_frames[:min_len]
620
+ video_embedding = video_embedding[:min_len]
621
+ video_landmarks = video_landmarks[:min_len]
622
+ raw_audio = raw_audio[:min_len]
623
+ hubert_embedding = hubert_embedding[:min_len]
624
+ wavlm_embedding = wavlm_embedding[:min_len]
625
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
626
+
627
+ else:
628
+ # Process video if needed
629
+ if not video_cache_hit:
630
+ print("Computing video embeddings and landmarks")
631
+ video_reader = decord.VideoReader(video_input)
632
+ decord.bridge.set_bridge("torch")
633
+
634
+ if not audio_cache_hit:
635
+ # Need to process audio to determine min_len
636
+ raw_audio = get_raw_audio(audio_input, 16000)
637
+ if len(raw_audio) == 0 or len(video_reader) == 0:
638
+ raise ValueError("Empty audio or video input")
639
+
640
+ min_len = min(len(raw_audio), len(video_reader))
641
+
642
+ # Store full audio in cache
643
+ cache["audio"]["path"] = audio_path_hash
644
+ cache["audio"]["raw_audio"] = raw_audio.clone()
645
+
646
+ # Create truncated copy for processing
647
+ raw_audio = raw_audio[:min_len]
648
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
649
+ else:
650
+ # Use cached audio - make a copy
651
+ if cache["audio"]["raw_audio"] is None:
652
+ raise ValueError("Cached audio is None")
653
+
654
+ raw_audio = cache["audio"]["raw_audio"].clone()
655
+ if len(raw_audio) == 0 or len(video_reader) == 0:
656
+ raise ValueError("Empty cached audio or video input")
657
+
658
+ min_len = min(len(raw_audio), len(video_reader))
659
+
660
+ # Create truncated copy for processing
661
+ raw_audio = raw_audio[:min_len]
662
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
663
+
664
+ # Compute video embeddings and landmarks - store full version in cache
665
+ video_embedding, video_frames = compute_video_embedding(
666
+ video_reader, len(video_reader)
667
+ )
668
+ video_landmarks = extract_video_landmarks(video_frames)
669
+
670
+ # Update video cache with full versions
671
+ cache["video"]["path"] = video_path_hash
672
+ cache["video"]["embedding"] = video_embedding
673
+ cache["video"]["frames"] = video_frames
674
+ cache["video"]["landmarks"] = video_landmarks
675
+
676
+ # Create truncated copies for processing
677
+ video_embedding = video_embedding[:min_len]
678
+ video_frames = video_frames[:min_len]
679
+ video_landmarks = video_landmarks[:min_len]
680
+
681
+ else:
682
+ # Use cached video data - make copies
683
+ print("Using cached video computations")
684
+
685
+ if (
686
+ cache["video"]["embedding"] is None
687
+ or cache["video"]["frames"] is None
688
+ or cache["video"]["landmarks"] is None
689
+ ):
690
+ raise ValueError("One or more video cache entries are None")
691
+
692
+ if not audio_cache_hit:
693
+ # New audio with cached video
694
+ raw_audio = get_raw_audio(audio_input, 16000)
695
+ if len(raw_audio) == 0:
696
+ raise ValueError("Empty audio input")
697
+
698
+ # Store full audio in cache
699
+ cache["audio"]["path"] = audio_path_hash
700
+ cache["audio"]["raw_audio"] = raw_audio.clone()
701
+
702
+ # Make copies of video data
703
+ video_embedding = cache["video"]["embedding"].clone()
704
+ video_frames = cache["video"]["frames"].clone()
705
+ video_landmarks = cache["video"]["landmarks"].copy()
706
+
707
+ # Determine truncation length and create truncated copies
708
+ min_len = min(len(raw_audio), len(video_frames))
709
+ raw_audio = raw_audio[:min_len]
710
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
711
+ video_frames = video_frames[:min_len]
712
+ video_embedding = video_embedding[:min_len]
713
+ video_landmarks = video_landmarks[:min_len]
714
+ else:
715
+ # Both video and audio are cached - should not reach here
716
+ # as it's handled in the first if statement
717
+ pass
718
+
719
+ # Process audio if needed
720
+ if not audio_cache_hit:
721
+ print("Computing audio embeddings")
722
+
723
+ # Compute audio embeddings with the truncated audio
724
+ hubert_embedding = compute_hubert_embedding(raw_audio_reshape)
725
+ wavlm_embedding = compute_wavlm_embedding(raw_audio_reshape)
726
+
727
+ # Update audio cache with full embeddings
728
+ # Note: raw_audio was already cached above
729
+ cache["audio"]["hubert_embedding"] = hubert_embedding.clone()
730
+ cache["audio"]["wavlm_embedding"] = wavlm_embedding.clone()
731
+ else:
732
+ # Use cached audio data - make copies
733
+ if (
734
+ cache["audio"]["hubert_embedding"] is None
735
+ or cache["audio"]["wavlm_embedding"] is None
736
+ ):
737
+ raise ValueError(
738
+ "One or more audio embedding cache entries are None"
739
+ )
740
+
741
+ hubert_embedding = cache["audio"]["hubert_embedding"].clone()
742
+ wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()
743
+
744
+ # Make sure embeddings match the truncated video length if needed
745
+ if "min_len" in locals() and (
746
+ min_len < len(hubert_embedding) or min_len < len(wavlm_embedding)
747
+ ):
748
+ hubert_embedding = hubert_embedding[:min_len]
749
+ wavlm_embedding = wavlm_embedding[:min_len]
750
+
751
+ # Apply max_num_seconds limit if specified
752
+ if max_num_seconds > 0:
753
+ # Convert seconds to frames (assuming 25 fps)
754
+ max_frames = int(max_num_seconds * 25)
755
+
756
+ # Truncate all data to max_frames
757
+ video_embedding = video_embedding[:max_frames]
758
+ video_frames = video_frames[:max_frames]
759
+ video_landmarks = video_landmarks[:max_frames]
760
+ hubert_embedding = hubert_embedding[:max_frames]
761
+ wavlm_embedding = wavlm_embedding[:max_frames]
762
+ raw_audio = raw_audio[:max_frames]
763
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
764
+
765
+ # Validate shapes before proceeding
766
+ assert video_embedding.shape[0] == hubert_embedding.shape[0], (
767
+ f"Video embedding length ({video_embedding.shape[0]}) doesn't match Hubert embedding length ({hubert_embedding.shape[0]})"
768
+ )
769
+ assert video_embedding.shape[0] == wavlm_embedding.shape[0], (
770
+ f"Video embedding length ({video_embedding.shape[0]}) doesn't match WavLM embedding length ({wavlm_embedding.shape[0]})"
771
+ )
772
+ assert video_embedding.shape[0] == video_landmarks.shape[0], (
773
+ f"Video embedding length ({video_embedding.shape[0]}) doesn't match landmarks length ({video_landmarks.shape[0]})"
774
+ )
775
+
776
+ print(f"Hubert embedding shape: {hubert_embedding.shape}")
777
+ print(f"WavLM embedding shape: {wavlm_embedding.shape}")
778
+ print(f"Video embedding shape: {video_embedding.shape}")
779
+ print(f"Video landmarks shape: {video_landmarks.shape}")
780
+
781
+ # Create pipeline inputs for models
782
+ (
783
+ interpolation_chunks,
784
+ keyframe_chunks,
785
+ audio_interpolation_chunks,
786
+ audio_keyframe_chunks,
787
+ emb_cond,
788
+ masks_keyframe_chunks,
789
+ masks_interpolation_chunks,
790
+ to_remove,
791
+ audio_interpolation_idx,
792
+ audio_keyframe_idx,
793
+ ) = create_pipeline_inputs(
794
+ hubert_embedding,
795
+ wavlm_embedding,
796
+ 14,
797
+ video_embedding,
798
+ video_landmarks,
799
+ overlap=1,
800
+ add_zero_flag=True,
801
+ mask_arms=None,
802
+ nose_index=28,
803
+ )
804
+
805
+ complete_video = sample(
806
+ audio_keyframe_chunks,
807
+ keyframe_chunks,
808
+ masks_keyframe_chunks,
809
+ to_remove,
810
+ audio_keyframe_idx,
811
+ 14,
812
+ "cuda",
813
+ emb_cond,
814
+ [],
815
+ 3,
816
+ 3,
817
+ audio_interpolation_idx,
818
+ audio_interpolation_chunks,
819
+ masks_interpolation_chunks,
820
+ interpolation_chunks,
821
+ keyframe_model,
822
+ interpolation_model,
823
+ )
824
+
825
+ complete_audio = rearrange(
826
+ raw_audio[: complete_video.shape[0]], "f s -> () (f s)"
827
+ )
828
+
829
+ # 4. Convert frames to video and combine with audio
830
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
831
+ output_path = temp_video.name
832
+
833
+ print("Saving video to", output_path)
834
+
835
+ save_audio_video(complete_video, audio=complete_audio, save_path=output_path)
836
+ torch.cuda.empty_cache()
837
+ return output_path
838
+
839
+ except Exception as e:
840
+ print(f"Error processing video: {str(e)}")
+ raise
843
+
844
+
845
+ def get_max_duration(video_input, audio_input):
846
+ """Get the maximum duration in seconds for the slider"""
847
+ try:
848
+ # Default to 60 seconds if files don't exist
849
+ if video_input is None or not os.path.exists(video_input):
850
+ video_input = DEFAULT_VIDEO_PATH
851
+
852
+ if audio_input is None or not os.path.exists(audio_input):
853
+ audio_input = DEFAULT_AUDIO_PATH
854
+
855
+ # Get video duration
856
+ video_reader = decord.VideoReader(video_input)
857
+ video_duration = len(video_reader) / video_reader.get_avg_fps()
858
+
859
+ # Get audio duration
860
+ raw_audio = get_raw_audio(audio_input, 16000)
861
+ audio_duration = len(raw_audio) / 25 # Assuming 25 fps
862
+
863
+ # Return the minimum of the two durations
864
+ return min(video_duration, audio_duration)
865
+ except Exception as e:
866
+ print(f"Error getting max duration: {str(e)}")
867
+ return 60 # Default to 60 seconds
868
+
869
+
870
+ def new_game_click(state):
871
+ """Handle the 'New Game' button click."""
872
+ message = state.new_game()
873
+ feedback_history = state.get_feedback_history()
874
+ return state, feedback_history, message
875
+
876
+
877
+ def submit_guess_click(guess, state):
878
+ """Handle the 'Submit Guess' button click."""
879
+ message = state.submit_guess(guess)
880
+ feedback_history = state.get_feedback_history()
881
+ return state, feedback_history, message
882
+
883
+
884
+ # Create Gradio interface
885
+ with gr.Blocks(title="Video Synchronization with Diffusion Models") as demo:
886
+ gr.Markdown("# Video Synchronization with Diffusion Models")
887
+ gr.Markdown(
888
+ "Upload a video and audio to create a synchronized video with the same visuals but synchronized to the new audio."
889
+ )
890
+
891
+ with gr.Tabs():
892
+ with gr.TabItem("Video Synchronization"):
893
+ with gr.Row():
894
+ with gr.Column():
895
+ video_input = gr.Video(
896
+ label="Input Video",
897
+ value=DEFAULT_VIDEO_PATH
898
+ if os.path.exists(DEFAULT_VIDEO_PATH)
899
+ else None,
900
+ width=512,
901
+ height=512,
902
+ )
903
+ audio_input = gr.Audio(
904
+ label="Input Audio",
905
+ type="filepath",
906
+ value=DEFAULT_AUDIO_PATH
907
+ if os.path.exists(DEFAULT_AUDIO_PATH)
908
+ else None,
909
+ )
910
+
911
+ max_duration = gr.State(value=60) # Default max duration
912
+
913
+ max_seconds_slider = gr.Slider(
914
+ minimum=0,
915
+ maximum=60, # Will be updated dynamically
916
+ value=0,
917
+ step=1,
918
+ label="Max Duration (seconds, 0 = full length)",
919
+ info="Limit the processing duration (0 means use full length)",
920
+ )
921
+
922
+ process_button = gr.Button("Generate Synchronized Video")
923
+
924
+ with gr.Column("Output Video"):
925
+ video_output = gr.Video(label="Output Video", width=512, height=512)
926
+
927
+ # Update slider max value when inputs change
928
+ def update_slider_max(video, audio):
929
+ max_dur = get_max_duration(video, audio)
930
+ return {"maximum": max_dur, "__type__": "update"}
931
+
932
+ video_input.change(
933
+ update_slider_max, [video_input, audio_input], [max_seconds_slider]
934
+ )
935
+ audio_input.change(
936
+ update_slider_max, [video_input, audio_input], [max_seconds_slider]
937
+ )
938
+
939
+ # Show Wordle message when processing starts and hide when complete
940
+ process_button.click(
941
+ fn=process_video,
942
+ inputs=[video_input, audio_input, max_seconds_slider],
943
+ outputs=video_output,
944
+ )
945
+
946
+ with gr.TabItem("Wordle Game"):
947
+ state = gr.State(WordleGame()) # Persist the WordleGame instance
948
+ guess_input = gr.Textbox(label="Your guess (5 letters)", max_length=5)
949
+ submit_btn = gr.Button("Submit Guess")
950
+ new_game_btn = gr.Button("New Game")
951
+ feedback_display = gr.HTML(label="Guesses")
952
+ message_display = gr.Textbox(
953
+ label="Message", interactive=False, value="Click 'New Game' to start."
954
+ )
955
+ # Connect the 'New Game' button
956
+ new_game_btn.click(
957
+ fn=new_game_click,
958
+ inputs=[state],
959
+ outputs=[state, feedback_display, message_display],
960
+ )
961
+ # Connect the 'Submit Guess' button
962
+ submit_btn.click(
963
+ fn=submit_guess_click,
964
+ inputs=[guess_input, state],
965
+ outputs=[state, feedback_display, message_display],
966
+ )
967
+
968
+ gr.Markdown("## How it works")
969
+ gr.Markdown("""
970
+ 1. The system extracts embeddings and landmarks from the input video
971
+ 2. Audio embeddings are computed from the input audio
972
+ 3. A keyframe model generates key visual frames
973
+ 4. An interpolation model creates a smooth video between keyframes
974
+ 5. The final video is rendered with the new audio
975
+ """)
976
+
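# --- Annotation (not part of this commit): processing order used by process_video ---
# 1. decord.VideoReader + compute_video_embedding   -> VAE latents and resized frames
# 2. extract_video_landmarks                        -> 68-point landmarks per frame
# 3. get_raw_audio + compute_hubert_embedding / compute_wavlm_embedding
# 4. create_pipeline_inputs                         -> chunked keyframe/interpolation inputs
# 5. sample (keyframe model, then interpolation)    -> generated frames
# 6. save_audio_video                               -> final .mp4 with the new audio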
977
+ if __name__ == "__main__":
978
+ demo.launch()
data_utils.py ADDED
@@ -0,0 +1,635 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image, ImageDraw
4
+ import cv2
5
+ from functools import partial
6
+ import math
7
+
8
+
9
+ def get_size(img):
10
+ if isinstance(img, (np.ndarray, torch.Tensor)):
11
+ return img.shape[1::-1]
12
+ else:
13
+ return img.size
14
+
15
+
16
+ def imresample(img, sz):
17
+ im_data = torch.nn.functional.interpolate(img, size=sz, mode="area")
18
+ return im_data
19
+
20
+
21
+ def crop_resize(img, box, image_size):
22
+ if isinstance(img, np.ndarray):
23
+ img = img[box[1] : box[3], box[0] : box[2]]
24
+ out = cv2.resize(
25
+ img, (image_size, image_size), interpolation=cv2.INTER_AREA
26
+ ).copy()
27
+ elif isinstance(img, torch.Tensor):
28
+ img = img[box[1] : box[3], box[0] : box[2]]
29
+ out = (
30
+ imresample(
31
+ img.permute(2, 0, 1).unsqueeze(0).float(), (image_size, image_size)
32
+ )
33
+ .byte()
34
+ .squeeze(0)
35
+ .permute(1, 2, 0)
36
+ )
37
+ else:
38
+ out = img.crop(box).copy().resize((image_size, image_size), Image.BILINEAR)
39
+ return out
40
+
41
+
42
+ def fixed_image_standardization(image_tensor):
43
+ processed_tensor = (image_tensor - 127.5) / 128.0
44
+ return processed_tensor
45
+
46
+
47
+ def extract_face(img, landmarks, image_size=160, margin=0, postprocess=False):
48
+ """Extract face + margin from images given facial landmarks.
49
+
50
+ Arguments:
51
+ img {PIL.Image/torch.Tensor/np.ndarray} -- Input image(s) with shape (B, H, W, C)
52
+ landmarks {numpy.ndarray} -- Facial landmarks with shape (B, 68, 2)
53
+ image_size {int} -- Output image size in pixels. The image will be square.
54
+ margin {int} -- Margin to add to bounding box, in terms of pixels in the final image.
55
+ postprocess {bool} -- Whether to apply standardization
56
+
57
+ Returns:
58
+ torch.tensor -- tensor representing the extracted faces with shape (B, H, W, C)
59
+ """
60
+ # Calculate bounding boxes from landmarks for all faces in batch
61
+ x_min = np.min(landmarks, axis=1)[:, 0] # Shape: (B,)
62
+ y_min = np.min(landmarks, axis=1)[:, 1] # Shape: (B,)
63
+ x_max = np.max(landmarks, axis=1)[:, 0] # Shape: (B,)
64
+ y_max = np.max(landmarks, axis=1)[:, 1] # Shape: (B,)
65
+
66
+ # Calculate margin for top only
67
+ box_height = y_max - y_min
68
+ top_margin = margin * box_height / (image_size - margin)
69
+
70
+ # Create boxes for all faces
71
+ boxes = np.stack(
72
+ [
73
+ x_min,
74
+ np.maximum(y_min - top_margin, 0), # Only add margin to top
75
+ x_max,
76
+ y_max,
77
+ ],
78
+ axis=1,
79
+ ).astype(int) # Shape: (B, 4)
80
+
81
+ # Process each face in the batch
82
+ faces = []
83
+ for i in range(len(boxes)):
84
+ face = crop_resize(img[i], boxes[i], image_size)
85
+ faces.append(face)
86
+
87
+ faces = torch.stack(faces, dim=0)
88
+ faces = faces.float()
89
+
90
+ if postprocess:
91
+ faces = fixed_image_standardization(faces)
92
+
93
+ return faces
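# --- Illustrative sketch (annotation, not part of this commit) ---
# Minimal usage of extract_face on synthetic data; assumes this module is
# importable as data_utils. The landmark coordinates are random points inside a
# 256x256 image, so the extracted "faces" are just arbitrary crops.
import numpy as np
import torch
from data_utils import extract_face

imgs = torch.randint(0, 256, (2, 256, 256, 3), dtype=torch.uint8)  # (B, H, W, C)
lms = np.random.randint(40, 200, size=(2, 68, 2))                  # (B, 68, 2)
faces = extract_face(imgs, lms, image_size=160, margin=20, postprocess=True)
assert faces.shape == (2, 160, 160, 3)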
94
+
95
+
96
+ def crop_mouth_region(images, landmarks, crop_size=96):
97
+ """
98
+ Takes a fixed-size square crop centered on the mouth region.
99
+
100
+ Parameters:
101
+ - images: tensor/array of shape (num_frames, height, width, channels) or (height, width, channels)
102
+ - landmarks: numpy array of shape (num_frames, 68, 2) or (68, 2)
103
+ - crop_size: size of the square crop (both height and width)
104
+ - padding: percentage of padding around the mouth region (0.0 to 1.0)
105
+
106
+ Returns:
107
+ - List of fixed-size crops or single crop if input is single image
108
+ """
109
+ # Handle single image case
110
+ single_image = False
111
+ if len(images.shape) == 3:
112
+ images = images[None]
113
+ landmarks = landmarks[None]
114
+ single_image = True
115
+
116
+ num_frames = len(images)
117
+ crops = []
118
+
119
+ # Mouth landmarks indices (48-67 for mouth region)
120
+ mouth_indices = range(48, 68)
121
+
122
+ for i in range(num_frames):
123
+ # Get mouth landmarks for current frame
124
+ mouth_landmarks = landmarks[i][mouth_indices]
125
+
126
+ # Find center of mouth
127
+ center_x = int(np.mean(mouth_landmarks[:, 0]))
128
+ center_y = int(np.mean(mouth_landmarks[:, 1]))
129
+
130
+ # Calculate crop boundaries
131
+ half_size = crop_size // 2
132
+ left = max(0, center_x - half_size)
133
+ right = min(images.shape[2], center_x + half_size)
134
+ top = max(0, center_y - half_size)
135
+ bottom = min(images.shape[1], center_y + half_size)
136
+
137
+ # Adjust if crop would go out of bounds
138
+ if left == 0:
139
+ right = crop_size
140
+ if right == images.shape[2]:
141
+ left = images.shape[2] - crop_size
142
+ if top == 0:
143
+ bottom = crop_size
144
+ if bottom == images.shape[1]:
145
+ top = images.shape[1] - crop_size
146
+
147
+ # Take the crop
148
+ crop = images[i, top:bottom, left:right]
149
+ crops.append(crop)
150
+
151
+ return crops[0] if single_image else crops
152
+
153
+
154
+ def create_masks_from_landmarks_box(
155
+ landmark_list, img_shape, nose_index=28, dtype="uint8", box_expand=0.0
156
+ ):
157
+ height, width = img_shape[:2]
158
+ num_frames = landmark_list.shape[0]
159
+
160
+ # Initialize the masks array
161
+ masks = np.zeros((num_frames, height, width), dtype=dtype)
162
+
163
+ if 0 <= box_expand < 1:
164
+ box_expand = int(box_expand * width)
165
+
166
+ for i in range(num_frames):
167
+ # Get the landmarks for the current frame
168
+ landmarks = landmark_list[i]
169
+
170
+ # Get the y-coordinate of the nose landmark
171
+ nose_point_h = landmarks[nose_index, 1]
172
+ cut_h = nose_point_h
173
+
174
+ # Find the leftmost and rightmost landmarks
175
+ far_left_index = np.argmin(landmarks[:, 0])
176
+ far_right_index = np.argmax(landmarks[:, 0])
177
+
178
+ # Define the points for the mask contour
179
+ left_up_point = np.array(
180
+ [landmarks[far_left_index][0], cut_h - box_expand], dtype=np.int32
181
+ )
182
+ left_down_point = np.array(
183
+ [landmarks[far_left_index][0], height], dtype=np.int32
184
+ )
185
+ right_up_point = np.array(
186
+ [landmarks[far_right_index][0], cut_h - box_expand], dtype=np.int32
187
+ )
188
+ right_down_point = np.array(
189
+ [landmarks[far_right_index][0], height], dtype=np.int32
190
+ )
191
+
192
+ # Define the contour
193
+ contour = np.array(
194
+ [[left_up_point, left_down_point, right_down_point, right_up_point]]
195
+ )
196
+
197
+ # Draw the contour on the mask
198
+ cv2.drawContours(masks[i], [contour], -1, color=(1), thickness=cv2.FILLED)
199
+
200
+ return torch.from_numpy(masks)
201
+
202
+
203
+ def create_masks_from_landmarks_full_size(
204
+ landmarks_batch,
205
+ image_height,
206
+ image_width,
207
+ start_index=48,
208
+ end_index=68,
209
+ offset=0,
210
+ nose_index=33,
211
+ ):
212
+ """
213
+ Efficiently creates a batch of masks using vectorized operations where each mask has ones from the nose
+ landmark's y-coordinate (adjusted by an offset) to the bottom of the image, and zeros otherwise.
215
+
216
+ Parameters:
217
+ - landmarks_batch (np.array): An array of shape (B, 68, 2) containing facial landmarks for multiple samples.
218
+ - image_height (int): The height of the image for which masks are created.
219
+ - image_width (int): The width of the image for which masks are created.
220
+ - start_index (int): The starting index of the range to check (inclusive).
221
+ - end_index (int): The ending index of the range to check (inclusive).
222
+ - offset (int): An offset to add or subtract from the y-coordinate of the highest landmark.
223
+
224
+ Returns:
225
+ - np.array: An array of masks of shape (B, image_height, image_width) for each batch.
226
+ """
227
+ # Extract the y-coordinates for the specified range across all batches
228
+ y_coords = landmarks_batch[:, nose_index : nose_index + 1, 1]
229
+
230
+ # Find the index of the minimum y-coordinate in the specified range for each batch
231
+ min_y_indices = np.argmin(y_coords, axis=1)
232
+
233
+ # Gather the highest landmarks' y-coordinates using the indices found
234
+ highest_y_coords = y_coords[np.arange(len(y_coords)), min_y_indices]
235
+
236
+ if abs(offset) < 1 and abs(offset) > 0:
237
+ offset = int(offset * image_height)
238
+
239
+ # Apply the offset to the highest y-coordinate
240
+ adjusted_y_coords = highest_y_coords + offset
241
+
242
+ # Clip the coordinates to stay within image boundaries
243
+ adjusted_y_coords = np.clip(adjusted_y_coords, 0, image_height - 1)
244
+
245
+ # Use broadcasting to create a mask without loops
246
+ # Create a range of indices from 0 to image_height - 1
247
+ all_indices = np.arange(image_height)
248
+
249
+ # Compare each index in 'all_indices' to each 'adjusted_y_coord' in the batch
250
+ # 'all_indices' has shape (image_height,), we reshape to (1, image_height) to broadcast against (B, 1)
251
+ mask_2d = (all_indices >= adjusted_y_coords[:, None]).astype(int)
252
+
253
+ # Extend the 2D mask to a full 3D mask of size (B, image_height, image_width)
254
+ full_mask = np.tile(mask_2d[:, :, np.newaxis], (1, 1, image_width))
255
+
256
+ return torch.from_numpy(full_mask)
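# --- Illustrative sketch (annotation, not part of this commit) ---
# Minimal usage of create_masks_from_landmarks_full_size: for every sample, rows at
# or below the (offset-adjusted) nose landmark are 1 and rows above it are 0.
# Assumes this module is importable as data_utils; landmarks are synthetic.
import numpy as np
from data_utils import create_masks_from_landmarks_full_size

lms = np.random.randint(0, 512, size=(4, 68, 2))   # (B, 68, 2) for a 512x512 image
masks = create_masks_from_landmarks_full_size(lms, 512, 512, offset=0, nose_index=33)
assert masks.shape == (4, 512, 512)
assert masks.min() >= 0 and masks.max() <= 1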
257
+
258
+
259
+ def expand_polygon(polygon, expand_size):
260
+ """
261
+ Expands the polygon outward by a specified number of pixels.
262
+
263
+ Parameters:
264
+ - polygon (list of tuples): The polygon points as (x, y).
265
+ - expand_size (int): The number of pixels to expand the polygon outward.
266
+
267
+ Returns:
268
+ - expanded_polygon (list of tuples): The expanded polygon points as (x, y).
269
+ """
270
+ if expand_size == 0:
271
+ return polygon
272
+
273
+ # Calculate centroid of the polygon
274
+ centroid_x = sum([point[0] for point in polygon]) / len(polygon)
275
+ centroid_y = sum([point[1] for point in polygon]) / len(polygon)
276
+
277
+ # Expand each point outward from the centroid
278
+ expanded_polygon = []
279
+ for x, y in polygon:
280
+ vector_x = x - centroid_x
281
+ vector_y = y - centroid_y
282
+ length = np.sqrt(vector_x**2 + vector_y**2)
283
+ if length == 0:
284
+ expanded_polygon.append((x, y))
285
+ else:
286
+ new_x = x + expand_size * (vector_x / length)
287
+ new_y = y + expand_size * (vector_y / length)
288
+ expanded_polygon.append((int(new_x), int(new_y)))
289
+
290
+ return expanded_polygon
291
+
292
+
293
+ def create_masks_from_landmarks_mouth(
294
+ landmark_list, img_shape, nose_index=33, dtype="uint8", box_expand=0.0
295
+ ):
296
+ height, width = img_shape[:2]
297
+ num_frames = landmark_list.shape[0]
298
+
299
+ # Initialize the masks array
300
+ masks = np.zeros((num_frames, height, width), dtype=dtype)
301
+
302
+ if 0 <= box_expand < 1:
303
+ box_expand = int(box_expand * width)
304
+
305
+ for i in range(num_frames):
306
+ # Get the landmarks for the current frame
307
+ landmarks = landmark_list[i]
308
+
309
+ # Get the y-coordinate of the nose landmark
310
+ nose_point_h = landmarks[nose_index, 1]
311
+ cut_h = nose_point_h
312
+
313
+ # Find the leftmost and rightmost landmarks
314
+ far_left_index = np.argmin(landmarks[:, 0])
315
+ far_right_index = np.argmax(landmarks[:, 0])
316
+
317
+ # Find lowest landmark y-coordinate
318
+ lowest_y = np.max(landmarks[:, 1])
319
+ # Add box_expand to the lowest point
320
+ lowest_y = min(height, lowest_y + box_expand)
321
+
322
+ # Define the points for the mask contour
323
+ left_up_point = np.array(
324
+ [landmarks[far_left_index][0], cut_h - box_expand], dtype=np.int32
325
+ )
326
+ left_down_point = np.array(
327
+ [landmarks[far_left_index][0], lowest_y], dtype=np.int32
328
+ )
329
+ right_up_point = np.array(
330
+ [landmarks[far_right_index][0], cut_h - box_expand], dtype=np.int32
331
+ )
332
+ right_down_point = np.array(
333
+ [landmarks[far_right_index][0], lowest_y], dtype=np.int32
334
+ )
335
+
336
+ # Define the contour
337
+ contour = np.array(
338
+ [[left_up_point, left_down_point, right_down_point, right_up_point]]
339
+ )
340
+
341
+ # Draw the contour on the mask
342
+ cv2.drawContours(masks[i], [contour], -1, color=(1), thickness=cv2.FILLED)
343
+
344
+ return torch.from_numpy(masks)
345
+
346
+
347
+ def create_face_mask_from_landmarks(
348
+ landmarks_batch, image_height, image_width, mask_expand=0
349
+ ):
350
+ """
351
+ Creates a batch of masks where each mask covers the face region using landmarks.
352
+
353
+ Parameters:
354
+ - landmarks_batch (np.array): An array of shape (B, 68, 2) containing facial landmarks for multiple samples.
355
+ - image_height (int): The height of the image for which masks are created.
356
+ - image_width (int): The width of the image for which masks are created.
357
+ - mask_expand (int): The number of pixels to expand the mask outward.
358
+
359
+ Returns:
360
+ - np.array: An array of masks of shape (B, image_height, image_width) for each batch.
361
+ """
362
+ # Initialize an array to hold all masks
363
+ masks = np.zeros(
364
+ (landmarks_batch.shape[0], image_height, image_width), dtype=np.uint8
365
+ )
366
+
367
+ if abs(mask_expand) < 1 and abs(mask_expand) > 0:
368
+ mask_expand = int(mask_expand * image_height)
369
+
370
+ for i, landmarks in enumerate(landmarks_batch):
371
+ # Create a blank image for each mask
372
+ mask = Image.new("L", (image_width, image_height), 0)
373
+ draw = ImageDraw.Draw(mask)
374
+
375
+ # Extract relevant landmarks for the face
376
+ jawline_landmarks = landmarks[2:15] # Jawline
377
+ # upper_face_landmarks = landmarks[17:27] # Eyebrows and top of nose bridge
378
+
379
+ # Combine landmarks to form a polygon around the face
380
+ # face_polygon = np.concatenate((jawline_landmarks, upper_face_landmarks[::-1]), axis=0)
381
+ face_polygon = jawline_landmarks
382
+
383
+ # Convert landmarks to a list of tuples
384
+ face_polygon = [(int(x), int(y)) for x, y in face_polygon]
385
+
386
+ # Expand the polygon if necessary
387
+ expanded_polygon = expand_polygon(face_polygon, mask_expand)
388
+
389
+ # Draw the polygon and fill it
390
+ draw.polygon(expanded_polygon, outline=1, fill=1)
391
+
392
+ # Convert mask to numpy array and add it to the batch of masks
393
+ masks[i] = np.array(mask)
394
+
395
+ return torch.from_numpy(masks)
396
+
397
+
398
+ ALL_FIXED_POINTS = (
399
+ [i for i in range(0, 4)]
400
+ + [i for i in range(13, 17)]
401
+ + [i for i in range(27, 36)]
402
+ + [36, 39, 42, 45]
403
+ )
404
+
405
+
406
+ def gaussian_kernel(sigma, width, height):
407
+ """Create a 2D Gaussian kernel."""
408
+ x = torch.arange(0, width, 1) - width // 2
409
+ y = torch.arange(0, height, 1) - height // 2
410
+ x = x.float()
411
+ y = y.float()
412
+ x2 = x**2
413
+ y2 = y[:, None] ** 2
414
+ g = torch.exp(-(x2 + y2) / (2 * sigma**2))
415
+ return g / g.sum()
416
+
417
+
418
+ def generate_hm(landmarks, height, width, n_points="all", sigma=3):
419
+ if n_points == "all":
420
+ Nlandmarks = range(len(landmarks))
421
+ elif n_points == "fixed":
422
+ Nlandmarks = ALL_FIXED_POINTS
423
+ elif n_points == "stable":
424
+ Nlandmarks = [33, 36, 39, 42, 45]
425
+
426
+ kernel = gaussian_kernel(sigma, width, height)
427
+ hm = torch.zeros((height, width))
428
+ for I in Nlandmarks:
429
+ x0, y0 = landmarks[I]
430
+ x0, y0 = int(x0), int(y0)
431
+ left, right = max(0, x0 - width // 2), min(width, x0 + width // 2)
432
+ top, bottom = max(0, y0 - height // 2), min(height, y0 + height // 2)
433
+ hm[top:bottom, left:right] += kernel[
434
+ max(0, -y0 + height // 2) : min(height, height - y0 + height // 2),
435
+ max(0, -x0 + width // 2) : min(width, width - x0 + width // 2),
436
+ ]
437
+ # Normalize the heatmap to have values between 0 and 1
438
+ max_val = hm.max()
439
+ if max_val > 0:
440
+ hm /= max_val
441
+ return hm
442
+
443
+
444
+ def get_heatmap(landmarks, image_size, or_im_size, n_points="stable", sigma=4):
445
+ stack = []
446
+ seq_length = landmarks.shape[0]
447
+ if or_im_size[0] != image_size[0] or or_im_size[1] != image_size[1]:
448
+ landmarks = scale_landmarks(landmarks, or_im_size, image_size)
449
+ gen_single_heatmap = partial(
450
+ generate_hm,
451
+ height=image_size[0],
452
+ width=image_size[1],
453
+ n_points=n_points,
454
+ sigma=sigma,
455
+ )
456
+ for i in range(seq_length):
457
+ stack.append(gen_single_heatmap(landmarks[i]))
458
+
459
+ return torch.stack(stack, axis=0).unsqueeze(0) # (1, seq_length, height, width)
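A short sketch of how gaussian_kernel, generate_hm and get_heatmap fit together; the sequence length, resolutions and sigma below are illustrative, not values taken from the repo:

import numpy as np

# Hypothetical sequence: 16 frames of 68 landmarks in 772x772 coordinates.
landmarks_seq = np.random.randint(0, 772, size=(16, 68, 2))

# get_heatmap rescales the landmarks to the working resolution, then calls
# generate_hm per frame, which stamps a normalized Gaussian (gaussian_kernel)
# on each selected landmark ("stable" keeps points 33, 36, 39, 42, 45).
heatmaps = get_heatmap(landmarks_seq, image_size=(64, 64), or_im_size=(772, 772), sigma=2)
# heatmaps: tensor of shape (1, 16, 64, 64) with values in [0, 1]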
460
+
461
+
462
+ def scale_landmarks(landmarks, original_size, target_size):
463
+ """
464
+ Scale landmarks from original size to target size.
465
+
466
+ Parameters:
467
+ - landmarks (np.array): An array of shape (N, 2) containing facial landmarks.
468
+ - original_size (tuple): The size (height, width) for which the landmarks are currently scaled.
469
+ - target_size (tuple): The size (height, width) to which landmarks should be scaled.
470
+
471
+ Returns:
472
+ - scaled_landmarks (np.array): Scaled landmarks.
473
+ """
474
+ scale_y = target_size[0] / original_size[0]
475
+ scale_x = target_size[1] / original_size[1]
476
+ scaled_landmarks = landmarks * np.array([scale_x, scale_y])
477
+ return scaled_landmarks.astype(int)
478
+
479
+
480
+ def draw_kps_image(
481
+ image_shape,
482
+ original_size,
483
+ landmarks,
484
+ color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255)],
485
+ rgb=True,
486
+ pts_width=4,
487
+ ):
488
+ stick_width = pts_width
489
+ limb_seq = np.array([[0, 2], [1, 2]])
490
+ kps = landmarks[[36, 45, 33], :]
491
+ kps = scale_landmarks(kps, original_size, image_shape)
492
+ if not rgb: # Grayscale image
493
+ canvas = np.zeros((image_shape[0], image_shape[1], 1))
494
+ color_mode = "grayscale"
495
+ else: # Color image
496
+ canvas = np.zeros((image_shape[0], image_shape[1], 3))
497
+ color_mode = "color"
498
+
499
+ polygon_cache = {}
500
+
501
+ for index in limb_seq:
502
+ color = color_list[index[0]]
503
+ if color_mode == "grayscale":
504
+ color = (
505
+ int(0.299 * color[2] + 0.587 * color[1] + 0.114 * color[0]),
506
+ ) # Convert to grayscale intensity
507
+
508
+ x = kps[index][:, 0]
509
+ y = kps[index][:, 1]
510
+ length = np.sqrt((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2)
511
+ angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
512
+
513
+ cache_key = (
514
+ color,
515
+ int(np.mean(x)),
516
+ int(np.mean(y)),
517
+ int(length / 2),
518
+ int(angle),
519
+ )
520
+ if cache_key not in polygon_cache:
521
+ polygon_cache[cache_key] = cv2.ellipse2Poly(
522
+ (int(np.mean(x)), int(np.mean(y))),
523
+ (int(length / 2), stick_width),
524
+ int(angle),
525
+ 0,
526
+ 360,
527
+ 1,
528
+ )
529
+
530
+ polygon = polygon_cache[cache_key]
531
+ cv2.fillConvexPoly(canvas, polygon, [int(c * 0.6) for c in color])
532
+
533
+ for idx, kp in enumerate(kps):
534
+ if color_mode == "grayscale":
535
+ color = (
536
+ int(
537
+ 0.299 * color_list[idx][2]
538
+ + 0.587 * color_list[idx][1]
539
+ + 0.114 * color_list[idx][0]
540
+ ),
541
+ )
542
+ else:
543
+ color = color_list[idx]
544
+ cv2.circle(canvas, (int(kp[0]), int(kp[1])), pts_width, color, -1)
545
+
546
+ return canvas.transpose(2, 0, 1)
547
+
548
+
549
+ def create_landmarks_image(
550
+ landmarks,
551
+ original_size=(772, 772),
552
+ target_size=(772, 772),
553
+ point_size=3,
554
+ n_points="all",
555
+ dim=3,
556
+ ):
557
+ """
558
+ Creates an image of landmarks on a black background using efficient NumPy operations.
559
+
560
+ Parameters:
561
+ - landmarks (np.array): An array of shape (68, 2) containing facial landmarks.
562
+ - image_size (tuple): The size of the output image (height, width).
563
+ - point_size (int): The radius of each landmark point in pixels.
564
+
565
+ Returns:
566
+ - img (np.array): An image array with landmarks plotted.
567
+ """
568
+ if n_points == "all":
569
+ indexes = range(len(landmarks))
570
+ elif n_points == "fixed":
571
+ indexes = ALL_FIXED_POINTS
572
+ elif n_points == "stable":
573
+ indexes = [33, 36, 39, 42, 45]
574
+
575
+ landmarks = landmarks[indexes]
576
+
577
+ img = np.zeros(target_size, dtype=np.uint8)
578
+
579
+ landmarks = scale_landmarks(landmarks, original_size, target_size)
580
+
581
+ # Ensure the landmarks are in bounds and integer
582
+ landmarks = np.clip(
583
+ landmarks, [0, 0], [target_size[1] - 1, target_size[0] - 1]
584
+ ).astype(int)
585
+
586
+ # Get x and y coordinates from landmarks
587
+ x, y = landmarks[:, 0], landmarks[:, 1]
588
+
589
+ # Define a grid offset based on point_size around each landmark
590
+ offset = np.arange(-point_size // 2, point_size // 2 + 1)
591
+ grid_x, grid_y = np.meshgrid(offset, offset, indexing="ij")
592
+
593
+ # Calculate the full set of x and y coordinates for the points
594
+ full_x = x[:, np.newaxis, np.newaxis] + grid_x[np.newaxis, :, :]
595
+ full_y = y[:, np.newaxis, np.newaxis] + grid_y[np.newaxis, :, :]
596
+
597
+ # Clip the coordinates to stay within image boundaries
598
+ full_x = np.clip(full_x, 0, target_size[1] - 1)
599
+ full_y = np.clip(full_y, 0, target_size[0] - 1)
600
+
601
+ # Flatten the arrays to use them as indices
602
+ full_x = full_x.ravel()
603
+ full_y = full_y.ravel()
604
+
605
+ # Set the points in the image
606
+ img[full_y, full_x] = 255
607
+
608
+ return np.stack([img] * dim, axis=0)
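A usage sketch for the two landmark-visualisation helpers above, with illustrative sizes (the index triple 36/45/33 used by draw_kps_image is the two outer eye corners plus the nose tip):

import numpy as np

landmarks = np.random.randint(0, 772, size=(68, 2))  # hypothetical single frame

# Sparse landmark image for the 5 "stable" points, replicated over 3 channels.
lm_img = create_landmarks_image(
    landmarks, original_size=(772, 772), target_size=(256, 256), n_points="stable"
)
# lm_img: uint8 array of shape (3, 256, 256), 255 at the landmark points.

# Keypoint image connecting the eye corners to the nose tip.
kps_img = draw_kps_image((256, 256), (772, 772), landmarks, rgb=True, pts_width=4)
# kps_img: array of shape (3, 256, 256)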
609
+
610
+
611
+ def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
612
+ len_file = audio.shape[-1]
613
+
614
+ if max_len_sec or max_len_raw:
615
+ max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
616
+ if len_file < int(max_len):
617
+ # dummy = np.zeros((1, int(max_len_sec * sr) - len_file))
618
+ # extended_wav = np.concatenate((audio_data, dummy[0]))
619
+ extended_wav = torch.nn.functional.pad(
620
+ audio, (0, int(max_len) - len_file), "constant"
621
+ )
622
+ else:
623
+ extended_wav = audio[:, : int(max_len)]
624
+ else:
625
+ extended_wav = audio
626
+
627
+ return extended_wav
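A quick sketch of the padding/trimming behaviour, with illustrative numbers:

import torch

audio = torch.randn(1, 12800)  # hypothetical 0.8 s mono waveform at 16 kHz

# Zero-pad up to 1 s (16000 samples); a longer clip would be truncated instead.
audio_1s = trim_pad_audio(audio, sr=16000, max_len_sec=1.0)
# audio_1s.shape -> torch.Size([1, 16000])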
628
+
629
+
630
+ def ssim_to_bin(ssim_score):
631
+ # Normalize the SSIM score to a 0-100 scale
632
+ normalized_diff_ssim = (1 - ((ssim_score + 1) / 2)) * 100
633
+ # Assign to one of the 100 bins
634
+ bin_index = float(min(np.floor(normalized_diff_ssim), 99))
635
+ return bin_index
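As a quick worked example of the binning above: an SSIM of 0.9 gives (1 - (0.9 + 1) / 2) * 100 = 5, i.e. bin 5.0, while an SSIM of -1 (maximally dissimilar) gives 100, which is clamped to the last bin, 99.0.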
inference_functions.py ADDED
@@ -0,0 +1,493 @@
1
+ import torch
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+ import numpy as np
4
+ from einops import rearrange, repeat
5
+ import math
6
+
7
+
8
+ def get_unique_embedder_keys_from_conditioner(conditioner):
9
+ return list(set([x.input_key for x in conditioner.embedders]))
10
+
11
+
12
+ def get_batch(keys, value_dict, N, T, device):
13
+ batch = {}
14
+ batch_uc = {}
15
+
16
+ for key in keys:
17
+ if key == "fps_id":
18
+ batch[key] = (
19
+ torch.tensor([value_dict["fps_id"]])
20
+ .to(device)
21
+ .repeat(int(math.prod(N)))
22
+ )
23
+ elif key == "motion_bucket_id":
24
+ batch[key] = (
25
+ torch.tensor([value_dict["motion_bucket_id"]])
26
+ .to(device)
27
+ .repeat(int(math.prod(N)))
28
+ )
29
+ elif key == "cond_aug":
30
+ batch[key] = repeat(
31
+ torch.tensor([value_dict["cond_aug"]]).to(device),
32
+ "1 -> b",
33
+ b=math.prod(N),
34
+ )
35
+ elif key == "cond_frames":
36
+ batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
37
+ elif key == "cond_frames_without_noise":
38
+ batch[key] = repeat(
39
+ value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
40
+ )
41
+ else:
42
+ batch[key] = value_dict[key]
43
+
44
+ if T is not None:
45
+ batch["num_video_frames"] = T
46
+
47
+ for key in batch.keys():
48
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
49
+ batch_uc[key] = torch.clone(batch[key])
50
+ return batch, batch_uc
51
+
52
+
53
+ def merge_overlapping_segments(segments: torch.Tensor, overlap: int) -> torch.Tensor:
54
+ """
55
+ Merges overlapping segments by averaging overlapping frames.
56
+ Segments have shape (b, t, ...), where 'b' is the number of segments,
57
+ 't' is frames per segment, and '...' are other dimensions.
58
+
59
+ Args:
60
+ segments: Tensor of shape (b, t, ...)
61
+ overlap: Integer, number of frames that overlap between consecutive segments
62
+
63
+ Returns:
64
+ Tensor of the merged video
65
+ """
66
+ # Get the shape details
67
+ b, t, *other_dims = segments.shape
68
+ num_frames = (b - 1) * (
69
+ t - overlap
70
+ ) + t # Calculate the total number of frames in the merged video
71
+
72
+ # Initialize the output tensor and a count tensor to keep track of contributions for averaging
73
+ output_shape = [num_frames] + other_dims
74
+ output = torch.zeros(output_shape, dtype=segments.dtype, device=segments.device)
75
+ count = torch.zeros(output_shape, dtype=torch.float32, device=segments.device)
76
+
77
+ current_index = 0
78
+ for i in range(b):
79
+ end_index = current_index + t
80
+ # Add the segment to the output tensor
81
+ output[current_index:end_index] += rearrange(segments[i], "... -> ...")
82
+ # Increment the count tensor for each frame that's added
83
+ count[current_index:end_index] += 1
84
+ # Update the starting index for the next segment
85
+ current_index += t - overlap
86
+
87
+ # Avoid division by zero
88
+ count[count == 0] = 1
89
+ # Average the frames where there's overlap
90
+ output /= count
91
+
92
+ return output
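A small illustration of the merge arithmetic: three 14-frame segments with a 1-frame overlap merge into (3 - 1) * (14 - 1) + 14 = 40 frames, and each overlapped frame is the average of the two segments that contributed to it. The shapes below are illustrative:

import torch

segments = torch.randn(3, 14, 4, 64, 64)  # (b, t, c, h, w)
merged = merge_overlapping_segments(segments, overlap=1)
# merged.shape -> torch.Size([40, 4, 64, 64])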
93
+
94
+
95
+ def get_batch_overlap(
96
+ keys: List[str],
97
+ value_dict: Dict[str, Any],
98
+ N: Tuple[int, ...],
99
+ T: Optional[int],
100
+ device: str,
101
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
102
+ """
103
+ Create a batch dictionary with overlapping frames for model input.
104
+
105
+ Args:
106
+ keys: List of keys to include in the batch
107
+ value_dict: Dictionary containing values for each key
108
+ N: Batch dimensions
109
+ T: Number of frames (optional)
110
+ device: Device to place tensors on
111
+
112
+ Returns:
113
+ Tuple of (batch dictionary, unconditional batch dictionary)
114
+ """
115
+ batch = {}
116
+ batch_uc = {}
117
+
118
+ for key in keys:
119
+ if key == "fps_id":
120
+ batch[key] = (
121
+ torch.tensor([value_dict["fps_id"]])
122
+ .to(device)
123
+ .repeat(int(math.prod(N)))
124
+ )
125
+ elif key == "motion_bucket_id":
126
+ batch[key] = (
127
+ torch.tensor([value_dict["motion_bucket_id"]])
128
+ .to(device)
129
+ .repeat(int(math.prod(N)))
130
+ )
131
+ elif key == "cond_aug":
132
+ batch[key] = repeat(
133
+ torch.tensor([value_dict["cond_aug"]]).to(device),
134
+ "1 -> b",
135
+ b=math.prod(N),
136
+ )
137
+ elif key == "cond_frames":
138
+ batch[key] = repeat(value_dict["cond_frames"], "b ... -> (b t) ...", t=N[0])
139
+ elif key == "cond_frames_without_noise":
140
+ batch[key] = repeat(
141
+ value_dict["cond_frames_without_noise"], "b ... -> (b t) ...", t=N[0]
142
+ )
143
+ else:
144
+ batch[key] = value_dict[key]
145
+
146
+ if T is not None:
147
+ batch["num_video_frames"] = T
148
+
149
+ for key in batch.keys():
150
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
151
+ batch_uc[key] = torch.clone(batch[key])
152
+ return batch, batch_uc
153
+
154
+
155
+ @torch.inference_mode()
156
+ def sample_keyframes(
157
+ model_keyframes: Any,
158
+ audio_list: torch.Tensor,
159
+ gt_list: torch.Tensor,
160
+ masks_list: torch.Tensor,
161
+ condition: torch.Tensor,
162
+ num_frames: int,
163
+ fps_id: int,
164
+ cond_aug: float,
165
+ device: str,
166
+ embbedings: Optional[torch.Tensor],
167
+ force_uc_zero_embeddings: List[str],
168
+ n_batch_keyframes: int,
169
+ added_frames: int,
170
+ strength: float,
171
+ scale: Optional[Union[float, List[float]]],
172
+ gt_as_cond: bool = False,
173
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
174
+ """
175
+ Sample keyframes using the keyframe generation model.
176
+
177
+ Args:
178
+ model_keyframes: The keyframe generation model
179
+ audio_list: List of audio embeddings
180
+ gt_list: List of ground truth frames
181
+ masks_list: List of masks
182
+ condition: Conditioning tensor
183
+ num_frames: Number of frames to generate
184
+ fps_id: FPS ID
185
+ cond_aug: Conditioning augmentation factor
186
+ device: Device to use for computation
187
+ embbedings: Optional embeddings
188
+ force_uc_zero_embeddings: List of embeddings to force to zero in unconditional case
189
+ n_batch_keyframes: Batch size for keyframe generation
190
+ added_frames: Number of additional frames
191
+ strength: Strength parameter for sampling
192
+ scale: Scale parameter for guidance
193
+ gt_as_cond: Whether to use ground truth as conditioning
194
+
195
+ Returns:
196
+ Tuple of (latent samples, decoded samples)
197
+ """
198
+ if scale is not None:
199
+ model_keyframes.sampler.guider.set_scale(scale)
200
+ # samples_list = []
201
+ samples_z_list = []
202
+ # samples_x_list = []
203
+
204
+ for i in range(audio_list.shape[0]):
205
+ H, W = condition.shape[-2:]
206
+ assert condition.shape[1] == 3
207
+ F = 8
208
+ C = 4
209
+ shape = (num_frames, C, H // F, W // F)
210
+
211
+ audio_cond = audio_list[i].unsqueeze(0)
212
+
213
+ value_dict: Dict[str, Any] = {}
214
+ value_dict["fps_id"] = fps_id
215
+ value_dict["cond_aug"] = cond_aug
216
+ value_dict["cond_frames_without_noise"] = condition
217
+ if embbedings is not None:
218
+ value_dict["cond_frames"] = embbedings + cond_aug * torch.randn_like(
219
+ embbedings
220
+ )
221
+ else:
222
+ value_dict["cond_frames"] = condition + cond_aug * torch.randn_like(
223
+ condition
224
+ )
225
+ gt = rearrange(gt_list[i].unsqueeze(0), "b t c h w -> b c t h w").to(device)
226
+
227
+ if gt_as_cond:
228
+ value_dict["cond_frames"] = gt[:, :, 0]
229
+
230
+ value_dict["cond_aug"] = cond_aug
231
+ value_dict["audio_emb"] = audio_cond
232
+
233
+ value_dict["gt"] = gt
234
+ value_dict["masks"] = masks_list[i].unsqueeze(0).transpose(1, 2).to(device)
235
+
236
+ with torch.no_grad():
237
+ batch, batch_uc = get_batch(
238
+ get_unique_embedder_keys_from_conditioner(model_keyframes.conditioner),
239
+ value_dict,
240
+ [1, 1],
241
+ T=num_frames,
242
+ device=device,
243
+ )
244
+
245
+ c, uc = model_keyframes.conditioner.get_unconditional_conditioning(
246
+ batch,
247
+ batch_uc=batch_uc,
248
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
249
+ )
250
+
251
+ for k in ["crossattn"]:
252
+ if c[k].shape[1] != num_frames:
253
+ uc[k] = repeat(
254
+ uc[k],
255
+ "b ... -> b t ...",
256
+ t=num_frames,
257
+ )
258
+ uc[k] = rearrange(
259
+ uc[k],
260
+ "b t ... -> (b t) ...",
261
+ t=num_frames,
262
+ )
263
+ c[k] = repeat(
264
+ c[k],
265
+ "b ... -> b t ...",
266
+ t=num_frames,
267
+ )
268
+ c[k] = rearrange(
269
+ c[k],
270
+ "b t ... -> (b t) ...",
271
+ t=num_frames,
272
+ )
273
+
274
+ video = torch.randn(shape, device=device)
275
+
276
+ additional_model_inputs: Dict[str, torch.Tensor] = {}
277
+ additional_model_inputs["image_only_indicator"] = torch.zeros(
278
+ n_batch_keyframes, num_frames
279
+ ).to(device)
280
+ additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
281
+
282
+ def denoiser(
283
+ input: torch.Tensor, sigma: torch.Tensor, c: Dict[str, torch.Tensor]
284
+ ) -> torch.Tensor:
285
+ return model_keyframes.denoiser(
286
+ model_keyframes.model,
287
+ input,
288
+ sigma,
289
+ c,
290
+ **additional_model_inputs,
291
+ )
292
+
293
+ samples_z = model_keyframes.sampler(
294
+ denoiser, video, cond=c, uc=uc, strength=strength
295
+ )
296
+ samples_z_list.append(samples_z)
297
+ # samples_x_list.append(samples_x)
298
+ # samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
299
+ # samples_list.append(samples)
300
+
301
+ video = None
302
+
303
+ # samples = (
304
+ # torch.concat(samples_list)[:-added_frames]
305
+ # if added_frames > 0
306
+ # else torch.concat(samples_list)
307
+ # )
308
+ samples_z = (
309
+ torch.concat(samples_z_list)[:-added_frames]
310
+ if added_frames > 0
311
+ else torch.concat(samples_z_list)
312
+ )
313
+ # samples_x = (
314
+ # torch.concat(samples_x_list)[:-added_frames]
315
+ # if added_frames > 0
316
+ # else torch.concat(samples_x_list)
317
+ # )
318
+
319
+ return samples_z
320
+
321
+
322
+ @torch.inference_mode()
323
+ def sample_interpolation(
324
+ model: Any,
325
+ samples_z: torch.Tensor,
326
+ # samples_x: torch.Tensor,
327
+ audio_interpolation_list: List[torch.Tensor],
328
+ gt_chunks: List[torch.Tensor],
329
+ masks_chunks: List[torch.Tensor],
330
+ condition: torch.Tensor,
331
+ num_frames: int,
332
+ device: str,
333
+ overlap: int,
334
+ fps_id: int,
335
+ cond_aug: float,
336
+ force_uc_zero_embeddings: List[str],
337
+ n_batch: int,
338
+ chunk_size: Optional[int],
339
+ strength: float,
340
+ scale: Optional[float] = None,
341
+ cut_audio: bool = False,
342
+ to_remove: List[bool] = [],
343
+ ) -> np.ndarray:
344
+ """
345
+ Sample interpolation frames between keyframes.
346
+
347
+ Args:
348
+ model: The interpolation model
349
+ samples_z: Latent samples from keyframe generation
350
+ samples_x: Decoded samples from keyframe generation
351
+ audio_interpolation_list: List of audio embeddings for interpolation
352
+ gt_chunks: Ground truth video chunks
353
+ masks_chunks: Mask chunks for conditional generation
354
+ condition: Visual conditioning
355
+ num_frames: Number of frames to generate
356
+ device: Device to run inference on
357
+ overlap: Number of frames to overlap between segments
358
+ fps_id: FPS ID for conditioning
359
+ motion_bucket_id: Motion bucket ID for conditioning
360
+ cond_aug: Conditioning augmentation strength
361
+ force_uc_zero_embeddings: Keys to zero out in unconditional embeddings
362
+ n_batch: Batch size for generation
363
+ chunk_size: Size of chunks for processing (to manage memory)
364
+ strength: Strength of the conditioning
365
+ scale: Optional scale for classifier-free guidance
366
+ cut_audio: Whether to cut audio embeddings
367
+ to_remove: List of flags indicating which frames to remove
368
+
369
+ Returns:
370
+ Generated video frames as numpy array
371
+ """
372
+ if scale is not None:
373
+ model.sampler.guider.set_scale(scale)
374
+
375
+ # Creating condition for interpolation model. We need to create a list of inputs, each input is [first, last]
376
+ # The first and last are the first and last frames of the interpolation
377
+ # interpolation_cond_list = []
378
+ interpolation_cond_list_emb = []
379
+
380
+ # samples_x = [sample for i, sample in zip(to_remove, samples_x) if not i]
381
+ samples_z = [sample for i, sample in zip(to_remove, samples_z) if not i]
382
+
383
+ for i in range(0, len(samples_z) - 1, overlap if overlap > 0 else 2):
384
+ # interpolation_cond_list.append(
385
+ # torch.stack([samples_x[i], samples_x[i + 1]], dim=1)
386
+ # )
387
+ interpolation_cond_list_emb.append(
388
+ torch.stack([samples_z[i], samples_z[i + 1]], dim=1)
389
+ )
390
+
391
+ # condition = torch.stack(interpolation_cond_list).to(device)
392
+ audio_cond = torch.stack(audio_interpolation_list).to(device)
393
+ embbedings = torch.stack(interpolation_cond_list_emb).to(device)
394
+
395
+ gt_chunks = torch.stack(gt_chunks).to(device)
396
+ masks_chunks = torch.stack(masks_chunks).to(device)
397
+
398
+ H, W = 512, 512
399
+ F = 8
400
+ C = 4
401
+ shape = (num_frames * audio_cond.shape[0], C, H // F, W // F)
402
+
403
+ value_dict: Dict[str, Any] = {}
404
+ value_dict["fps_id"] = fps_id
405
+ value_dict["cond_aug"] = cond_aug
406
+ # value_dict["cond_frames_without_noise"] = condition
407
+
408
+ value_dict["cond_frames"] = embbedings
409
+ value_dict["cond_aug"] = cond_aug
410
+ if cut_audio:
411
+ value_dict["audio_emb"] = audio_cond[:, :, :, :768]
412
+ else:
413
+ value_dict["audio_emb"] = audio_cond
414
+
415
+ value_dict["gt"] = rearrange(gt_chunks, "b t c h w -> b c t h w").to(device)
416
+ value_dict["masks"] = masks_chunks.transpose(1, 2).to(device)
417
+
418
+ with torch.no_grad():
419
+ with torch.autocast(device):
420
+ batch, batch_uc = get_batch_overlap(
421
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
422
+ value_dict,
423
+ [1, num_frames],
424
+ T=num_frames,
425
+ device=device,
426
+ )
427
+
428
+ c, uc = model.conditioner.get_unconditional_conditioning(
429
+ batch,
430
+ batch_uc=batch_uc,
431
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
432
+ )
433
+
434
+ for k in ["crossattn"]:
435
+ if c[k].shape[1] != num_frames:
436
+ uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
437
+ uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
438
+ c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
439
+ c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
440
+
441
+ video = torch.randn(shape, device=device)
442
+
443
+ additional_model_inputs: Dict[str, torch.Tensor] = {}
444
+ additional_model_inputs["image_only_indicator"] = torch.zeros(
445
+ n_batch, num_frames
446
+ ).to(device)
447
+ additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
448
+
449
+ # Debug information
450
+ print(
451
+ f"Shapes - Embeddings: {embbedings.shape}, "
452
+ f"Audio: {audio_cond.shape}, Video: {shape}, Additional inputs: {additional_model_inputs}"
453
+ )
454
+
455
+ if chunk_size is not None:
456
+ chunk_size = chunk_size * num_frames
457
+
458
+ def denoiser(
459
+ input: torch.Tensor, sigma: torch.Tensor, c: Dict[str, torch.Tensor]
460
+ ) -> torch.Tensor:
461
+ return model.denoiser(
462
+ model.model,
463
+ input,
464
+ sigma,
465
+ c,
466
+ num_overlap_frames=overlap,
467
+ num_frames=num_frames,
468
+ n_skips=n_batch,
469
+ chunk_size=chunk_size,
470
+ **additional_model_inputs,
471
+ )
472
+
473
+ samples_z = model.sampler(denoiser, video, cond=c, uc=uc, strength=strength)
474
+ samples_z = rearrange(samples_z, "(b t) c h w -> b t c h w", t=num_frames)
475
+ samples_z[:, 0] = embbedings[:, :, 0]
476
+ samples_z[:, -1] = embbedings[:, :, 1]
477
+ samples_z = rearrange(samples_z, "b t c h w -> (b t) c h w")
478
+
479
+ samples_x = model.decode_first_stage(samples_z)
480
+
481
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
482
+
483
+ # Free up memory
484
+ video = None
485
+
486
+ samples = rearrange(samples, "(b t) c h w -> b t c h w", t=num_frames)
487
+ samples = merge_overlapping_segments(samples, overlap)
488
+
489
+ vid = (
490
+ (rearrange(samples, "t c h w -> t c h w") * 255).cpu().numpy().astype(np.uint8)
491
+ )
492
+
493
+ return vid
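For orientation, the two samplers above are meant to be chained: sample_keyframes produces keyframe latents, which sample_interpolation then densifies into the final clip. A heavily simplified, hypothetical wiring sketch follows; every tensor, model and hyperparameter here is a placeholder rather than the repo's actual configuration:

import torch

def generate_video(model_keyframes, model, audio_list, gt_list, masks_list, condition,
                   audio_interpolation_list, gt_chunks, masks_chunks, to_remove, device="cuda"):
    # 1) Sparse keyframes in latent space.
    samples_z = sample_keyframes(
        model_keyframes, audio_list, gt_list, masks_list, condition,
        num_frames=14, fps_id=24, cond_aug=0.0, device=device, embbedings=None,
        force_uc_zero_embeddings=[], n_batch_keyframes=1, added_frames=0,
        strength=1.0, scale=2.5,
    )
    # 2) Interpolate between consecutive keyframes and decode to pixels.
    vid = sample_interpolation(
        model, samples_z, audio_interpolation_list, gt_chunks, masks_chunks, condition,
        num_frames=14, device=device, overlap=1, fps_id=24, cond_aug=0.0,
        force_uc_zero_embeddings=[], n_batch=1, chunk_size=2, strength=1.0,
        to_remove=to_remove,
    )
    return vid  # uint8 numpy array of frames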
landmarks_extractor.py ADDED
@@ -0,0 +1,35 @@
1
+ from skimage import io
2
+ import face_alignment
3
+
4
+
5
+ class LandmarksExtractor:
6
+ def __init__(self, device="cuda", landmarks_type="2D", flip=False):
7
+ self.fa = face_alignment.FaceAlignment(
8
+ face_alignment.LandmarksType.TWO_D
9
+ if landmarks_type == "2D"
10
+ else face_alignment.LandmarksType.THREE_D,
11
+ flip_input=flip,
12
+ device=device,
13
+ face_detector="sfd",
14
+ )
15
+
16
+ self.landmarks = []
17
+
18
+ def cuda(self):
19
+ return self
20
+
21
+ def extract_landmarks(self, image):
22
+ # image: either a path to an image or a numpy array (H, W, C) or tensor batch (B, C, H, W)
23
+ if isinstance(image, str):
24
+ image = io.imread(image)
25
+
26
+ # Ensure image is on CPU
27
+ if hasattr(image, "device"):
28
+ image = image.cpu()
29
+
30
+ if len(image.shape) == 3:
31
+ preds = self.fa.get_landmarks(image)
32
+ else:
33
+ preds = self.fa.get_landmarks_from_batch(image)
34
+
35
+ return preds
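A minimal usage sketch for the wrapper above (the image path is hypothetical):

import numpy as np

extractor = LandmarksExtractor(device="cuda")             # or device="cpu"
preds = extractor.extract_landmarks("example_face.jpg")   # path is illustrative
# preds is a list with one (68, 2) array per detected face, or None if no face is found
landmarks = np.array(preds[0]) if preds else None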
sgm/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .models import AutoencodingEngine, DiffusionEngine
2
+ from .util import get_configs_path, instantiate_from_config
3
+
4
+ __version__ = "0.1.0"
sgm/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (418 Bytes). View file
 
sgm/__pycache__/lr_scheduler.cpython-311.pyc ADDED
Binary file (6.6 kB). View file
 
sgm/__pycache__/util.cpython-311.pyc ADDED
Binary file (21.5 kB). View file
 
sgm/callbacks/__pycache__/video_logger.cpython-311.pyc ADDED
Binary file (14.3 kB). View file
 
sgm/callbacks/custom_ddp.py ADDED
@@ -0,0 +1,10 @@
1
+ # from pytorch_lightning.overrides import LightningDistributedModule
2
+ from pytorch_lightning.strategies import DDPStrategy
3
+
4
+
5
+ class CustomDDPPlugin(DDPStrategy):
6
+ def configure_ddp(self):
7
+ # self.pre_configure_ddp()
8
+ self._model = self._setup_model((self.model))
9
+ self._register_ddp_hooks()
10
+ self._model._set_static_graph() # THIS IS THE MAGIC LINE
sgm/callbacks/image_logger.py ADDED
@@ -0,0 +1,193 @@
1
+ from pytorch_lightning.callbacks import Callback
2
+ from pytorch_lightning.loggers import WandbLogger
3
+ import numpy as np
4
+ from pytorch_lightning.utilities import rank_zero_only
5
+ from typing import Union
6
+ import pytorch_lightning as pl
7
+ import os
8
+ from matplotlib import pyplot as plt
9
+ from sgm.util import exists, isheatmap
10
+ import torchvision
11
+ from PIL import Image
12
+ import torch
13
+ import wandb
14
+ from einops import rearrange
15
+
16
+
17
+ class ImageLogger(Callback):
18
+ def __init__(
19
+ self,
20
+ batch_frequency,
21
+ max_images,
22
+ clamp=True,
23
+ increase_log_steps=True,
24
+ rescale=True,
25
+ disabled=False,
26
+ log_on_batch_idx=False,
27
+ log_first_step=False,
28
+ log_images_kwargs=None,
29
+ log_before_first_step=False,
30
+ enable_autocast=True,
31
+ ):
32
+ super().__init__()
33
+ self.enable_autocast = enable_autocast
34
+ self.rescale = rescale
35
+ self.batch_freq = batch_frequency
36
+ self.max_images = max_images
37
+ self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
38
+ if not increase_log_steps:
39
+ self.log_steps = [self.batch_freq]
40
+ self.clamp = clamp
41
+ self.disabled = disabled
42
+ self.log_on_batch_idx = log_on_batch_idx
43
+ self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
44
+ self.log_first_step = log_first_step
45
+ self.log_before_first_step = log_before_first_step
46
+
47
+ @rank_zero_only
48
+ def log_local(
49
+ self,
50
+ save_dir,
51
+ split,
52
+ images,
53
+ global_step,
54
+ current_epoch,
55
+ batch_idx,
56
+ pl_module: Union[None, pl.LightningModule] = None,
57
+ ):
58
+ root = os.path.join(save_dir, "images", split)
59
+ for k in images:
60
+ if isheatmap(images[k]):
61
+ fig, ax = plt.subplots()
62
+ ax = ax.matshow(images[k].cpu().numpy(), cmap="hot", interpolation="lanczos")
63
+ plt.colorbar(ax)
64
+ plt.axis("off")
65
+
66
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
67
+ os.makedirs(root, exist_ok=True)
68
+ path = os.path.join(root, filename)
69
+ plt.savefig(path)
70
+ plt.close()
71
+ # TODO: support wandb
72
+ else:
73
+ grid = torchvision.utils.make_grid(images[k].squeeze(2), nrow=4)
74
+ if self.rescale:
75
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
76
+ # print(grid.shape, grid.dtype, grid.min(), grid.max(), k)
77
+ grid = rearrange(grid.squeeze(1), "c h w -> h w c")
78
+ # grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
79
+ grid = grid.numpy()
80
+ grid = (grid * 255).astype(np.uint8)
81
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
82
+ path = os.path.join(root, filename)
83
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
84
+ img = Image.fromarray(grid)
85
+ img.save(path)
86
+ if exists(pl_module):
87
+ assert isinstance(
88
+ pl_module.logger, WandbLogger
89
+ ), "logger_log_image only supports WandbLogger currently"
90
+ pl_module.logger.log_image(
91
+ key=f"{split}/{k}",
92
+ images=[
93
+ img,
94
+ ],
95
+ step=pl_module.global_step,
96
+ )
97
+
98
+ @rank_zero_only
99
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
100
+ check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
101
+ if (
102
+ self.check_frequency(check_idx)
103
+ and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0
104
+ and callable(pl_module.log_images)
105
+ and
106
+ # batch_idx > 5 and
107
+ self.max_images > 0
108
+ ):
109
+ logger = type(pl_module.logger)
110
+ is_train = pl_module.training
111
+ if is_train:
112
+ pl_module.eval()
113
+
114
+ gpu_autocast_kwargs = {
115
+ "enabled": self.enable_autocast, # torch.is_autocast_enabled(),
116
+ "dtype": torch.get_autocast_gpu_dtype(),
117
+ "cache_enabled": torch.is_autocast_cache_enabled(),
118
+ }
119
+ with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
120
+ images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
121
+
122
+ for k in images:
123
+ N = min(images[k].shape[0], self.max_images)
124
+ if not isheatmap(images[k]):
125
+ images[k] = images[k][:N]
126
+ if isinstance(images[k], torch.Tensor):
127
+ images[k] = images[k].detach().float().cpu()
128
+ if self.clamp and not isheatmap(images[k]):
129
+ images[k] = torch.clamp(images[k], -1.0, 1.0)
130
+
131
+ self.log_local(
132
+ pl_module.logger.save_dir,
133
+ split,
134
+ images,
135
+ pl_module.global_step,
136
+ pl_module.current_epoch,
137
+ batch_idx,
138
+ pl_module=pl_module if isinstance(pl_module.logger, WandbLogger) else None,
139
+ )
140
+
141
+ if is_train:
142
+ pl_module.train()
143
+
144
+ def check_frequency(self, check_idx):
145
+ if check_idx:
146
+ check_idx -= 1
147
+ if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
148
+ check_idx > 0 or self.log_first_step
149
+ ):
150
+ try:
151
+ self.log_steps.pop(0)
152
+ except IndexError as e:
153
+ print(e)
154
+ pass
155
+ return True
156
+ return False
157
+
158
+ @rank_zero_only
159
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
160
+ if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
161
+ self.log_img(pl_module, batch, batch_idx, split="train")
162
+
163
+ @rank_zero_only
164
+ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
165
+ if self.log_before_first_step and pl_module.global_step == 0:
166
+ print(f"{self.__class__.__name__}: logging before training")
167
+ self.log_img(pl_module, batch, batch_idx, split="train")
168
+
169
+ @rank_zero_only
170
+ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs):
171
+ if not self.disabled and pl_module.global_step > 0:
172
+ self.log_img(pl_module, batch, batch_idx, split="val")
173
+ if hasattr(pl_module, "calibrate_grad_norm"):
174
+ if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
175
+ self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
176
+
177
+
178
+ @rank_zero_only
179
+ def init_wandb(save_dir, opt, config, group_name, name_str):
180
+ print(f"setting WANDB_DIR to {save_dir}")
181
+ os.makedirs(save_dir, exist_ok=True)
182
+
183
+ os.environ["WANDB_DIR"] = save_dir
184
+ if opt.debug:
185
+ wandb.init(project=opt.projectname, mode="offline", group=group_name)
186
+ else:
187
+ wandb.init(
188
+ project=opt.projectname,
189
+ config=config,
190
+ settings=wandb.Settings(code_dir="./sgm"),
191
+ group=group_name,
192
+ name=name_str,
193
+ )
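For context, the ImageLogger defined above is meant to be passed to the Lightning Trainer like any other callback; a minimal, illustrative wiring (the frequencies are arbitrary):

import pytorch_lightning as pl

image_logger = ImageLogger(batch_frequency=1000, max_images=4, clamp=True)
trainer = pl.Trainer(callbacks=[image_logger])  # plus the usual logger/strategy arguments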
sgm/callbacks/setup_callback.py ADDED
@@ -0,0 +1,86 @@
1
+ from pytorch_lightning.callbacks import Callback
2
+ import pytorch_lightning as pl
3
+ import os
4
+ from omegaconf import OmegaConf
5
+ from pytorch_lightning.utilities import rank_zero_only
6
+
7
+ MULTINODE_HACKS = True
8
+
9
+
10
+ class SetupCallback(Callback):
11
+ def __init__(
12
+ self,
13
+ resume,
14
+ now,
15
+ logdir,
16
+ ckptdir,
17
+ cfgdir,
18
+ config,
19
+ lightning_config,
20
+ debug,
21
+ ckpt_name=None,
22
+ ):
23
+ super().__init__()
24
+ self.resume = resume
25
+ self.now = now
26
+ self.logdir = logdir
27
+ self.ckptdir = ckptdir
28
+ self.cfgdir = cfgdir
29
+ self.config = config
30
+ self.lightning_config = lightning_config
31
+ self.debug = debug
32
+ self.ckpt_name = ckpt_name
33
+
34
+ @rank_zero_only
35
+ def on_exception(self, trainer: pl.Trainer, pl_module, exception):
36
+ print("Exception occurred: {}".format(exception))
37
+ if not self.debug and trainer.global_rank == 0:
38
+ print("Summoning checkpoint.")
39
+ if self.ckpt_name is None:
40
+ ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
41
+ else:
42
+ ckpt_path = os.path.join(self.ckptdir, self.ckpt_name)
43
+ trainer.save_checkpoint(ckpt_path)
44
+
45
+ @rank_zero_only
46
+ def on_fit_start(self, trainer, pl_module):
47
+ if trainer.global_rank == 0:
48
+ # Create logdirs and save configs
49
+ os.makedirs(self.logdir, exist_ok=True)
50
+ os.makedirs(self.ckptdir, exist_ok=True)
51
+ os.makedirs(self.cfgdir, exist_ok=True)
52
+
53
+ if "callbacks" in self.lightning_config:
54
+ if "metrics_over_trainsteps_checkpoint" in self.lightning_config["callbacks"]:
55
+ os.makedirs(
56
+ os.path.join(self.ckptdir, "trainstep_checkpoints"),
57
+ exist_ok=True,
58
+ )
59
+ print("Project config")
60
+ print(OmegaConf.to_yaml(self.config))
61
+ if MULTINODE_HACKS:
62
+ import time
63
+
64
+ time.sleep(5)
65
+ OmegaConf.save(
66
+ self.config,
67
+ os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
68
+ )
69
+
70
+ print("Lightning config")
71
+ print(OmegaConf.to_yaml(self.lightning_config))
72
+ OmegaConf.save(
73
+ OmegaConf.create({"lightning": self.lightning_config}),
74
+ os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
75
+ )
76
+
77
+ else:
78
+ # ModelCheckpoint callback created log directory --- remove it
79
+ if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
80
+ dst, name = os.path.split(self.logdir)
81
+ dst = os.path.join(dst, "child_runs", name)
82
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
83
+ try:
84
+ os.rename(self.logdir, dst)
85
+ except FileNotFoundError:
86
+ pass
sgm/callbacks/video_logger.py ADDED
@@ -0,0 +1,294 @@
1
+ from pytorch_lightning.callbacks import Callback
2
+ from pytorch_lightning.loggers import WandbLogger
3
+ import numpy as np
4
+ from pytorch_lightning.utilities import rank_zero_only
5
+ from typing import Union
6
+ import pytorch_lightning as pl
7
+ import os
8
+ from sgm.util import exists, suppress_output, default
9
+ import torchvision
10
+ from PIL import Image
11
+ import torch
12
+ import wandb
13
+ import moviepy.editor as mpy
14
+ from einops import rearrange
15
+ import torchaudio
16
+ # import tempfile
17
+ # import cv2
18
+ # import scipy.io.wavfile as wav
19
+ # import ffmpeg
20
+
21
+
22
+ @suppress_output
23
+ def save_audio_video(
24
+ video, audio=None, frame_rate=25, sample_rate=16000, save_path="temp.mp4", keep_intermediate=False
25
+ ):
26
+ """Save audio and video to a single file.
27
+ video: (t, c, h, w)
28
+ audio: (channels, t)
29
+ """
30
+
31
+ # temp_filename = next(tempfile._get_candidate_names())
32
+ # if save_path:
33
+ # save_path = save_path
34
+ # else:
35
+ # save_path = "/tmp/" + next(tempfile._get_candidate_names()) + ".mp4"
36
+ save_path = str(save_path)
37
+ try:
38
+ torchvision.io.write_video(
39
+ "temp_video.mp4", rearrange(video.detach().cpu(), "t c h w -> t h w c").to(torch.uint8), frame_rate
40
+ )
41
+ video_clip = mpy.VideoFileClip("temp_video.mp4")
42
+ if audio is not None:
43
+ torchaudio.save("temp_audio.wav", audio.detach().cpu(), sample_rate)
44
+ audio_clip = mpy.AudioFileClip("temp_audio.wav")
45
+ video_clip = video_clip.set_audio(audio_clip)
46
+ video_clip.write_videofile(save_path, fps=frame_rate, codec="libx264", audio_codec="aac", verbose=False)
47
+ if not keep_intermediate:
48
+ os.remove("temp_video.mp4")
49
+ if audio is not None:
50
+ os.remove("temp_audio.wav")
51
+ return 1
52
+ except Exception as e:
53
+ print(e)
54
+ print("Saving video to file failed")
55
+ return 0
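A small usage sketch for save_audio_video with dummy data; shapes follow the docstring above, and the file name is illustrative:

import torch

video = (torch.rand(25, 3, 256, 256) * 255).to(torch.uint8)  # 1 s of video at 25 fps
audio = torch.randn(1, 16000)                                # 1 s of mono audio at 16 kHz
ok = save_audio_video(video, audio=audio, frame_rate=25, sample_rate=16000, save_path="sample.mp4")
# ok is 1 on success, 0 if writing/muxing failed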
56
+
57
+
58
+ # def write_video_opencv(video, video_rate, video_path):
59
+ # fourcc = cv2.VideoWriter_fourcc(*"mp4v")
60
+ # out = cv2.VideoWriter(video_path, fourcc, video_rate, (video.shape[2], video.shape[3]), 0)
61
+ # for frame in list(video):
62
+ # frame = np.squeeze(frame)
63
+ # out.write(np.squeeze(frame))
64
+ # out.release()
65
+
66
+
67
+ # # Code mostly inherited from bulletin
68
+ # def save_av_sample(video, video_rate, audio=None, audio_rate=16_000, path=None):
69
+ # # Save video sample in train dir for debugging
70
+ # # video_save = 0.5 * video.detach().cpu().numpy() + 0.5
71
+ # video_save = rearrange(video, "t c h w -> t h w c").detach().cpu().numpy()
72
+ # temp_filename = next(tempfile._get_candidate_names())
73
+ # if path:
74
+ # video_path = path
75
+ # else:
76
+ # video_path = "/tmp/" + next(tempfile._get_candidate_names()) + ".mp4"
77
+ # write_video_opencv((video_save).astype(np.uint8), video_rate, "/tmp/" + temp_filename + ".mp4")
78
+ # audio_save = audio.detach().squeeze().cpu().numpy()
79
+ # wav.write("/tmp/" + temp_filename + ".wav", audio_rate, audio_save)
80
+ # try:
81
+ # in1 = ffmpeg.input("/tmp/" + temp_filename + ".mp4")
82
+ # in2 = ffmpeg.input("/tmp/" + temp_filename + ".wav")
83
+ # out = ffmpeg.output(in1["v"], in2["a"], video_path, loglevel="panic").overwrite_output()
84
+ # out.run(capture_stdout=True, capture_stderr=True)
85
+ # except ffmpeg.Error as e:
86
+ # print("stdout:", e.stdout.decode("utf8"))
87
+ # print("stderr:", e.stderr.decode("utf8"))
88
+ # raise e
89
+ # return video_path
90
+
91
+
92
+ class VideoLogger(Callback):
93
+ def __init__(
94
+ self,
95
+ batch_frequency,
96
+ max_videos,
97
+ clamp=True,
98
+ increase_log_steps=True,
99
+ rescale=True,
100
+ disabled=False,
101
+ log_on_batch_idx=False,
102
+ log_first_step=False,
103
+ log_videos_kwargs=None,
104
+ log_before_first_step=False,
105
+ enable_autocast=True,
106
+ batch_frequency_val=None,
107
+ ):
108
+ super().__init__()
109
+ self.enable_autocast = enable_autocast
110
+ self.rescale = rescale
111
+ self.batch_freq = batch_frequency
112
+ self.max_videos = max_videos
113
+ self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
114
+ if not increase_log_steps:
115
+ self.log_steps = [self.batch_freq]
116
+ self.batch_freq_val = default(batch_frequency_val, self.batch_freq)
117
+ self.log_steps_val = [2**n for n in range(int(np.log2(self.batch_freq_val)) + 1)]
118
+ if not increase_log_steps:
119
+ self.log_steps_val = [self.batch_freq_val]
120
+ self.clamp = clamp
121
+ self.disabled = disabled
122
+ self.log_on_batch_idx = log_on_batch_idx
123
+ self.log_videos_kwargs = log_videos_kwargs if log_videos_kwargs else {}
124
+ self.log_first_step = log_first_step
125
+ self.log_before_first_step = log_before_first_step
126
+
127
+ @rank_zero_only
128
+ def log_local(
129
+ self,
130
+ save_dir,
131
+ split,
132
+ log_elements,
133
+ raw_audio,
134
+ global_step,
135
+ current_epoch,
136
+ batch_idx,
137
+ pl_module: Union[None, pl.LightningModule] = None,
138
+ ):
139
+ root = os.path.join(save_dir, "videos", split)
140
+ for k in log_elements:
141
+ element = log_elements[k]
142
+ if len(element.shape) == 4:
143
+ grid = torchvision.utils.make_grid(element, nrow=4)
144
+ if self.rescale:
145
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
146
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
147
+ grid = grid.numpy()
148
+ grid = (grid * 255).astype(np.uint8)
149
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
150
+ path = os.path.join(root, filename)
151
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
152
+ img = Image.fromarray(grid)
153
+ img.save(path)
154
+ if exists(pl_module):
155
+ assert isinstance(
156
+ pl_module.logger, WandbLogger
157
+ ), "logger_log_image only supports WandbLogger currently"
158
+ pl_module.logger.log_image(
159
+ key=f"{split}/{k}",
160
+ images=[
161
+ img,
162
+ ],
163
+ step=pl_module.global_step,
164
+ )
165
+ elif len(element.shape) == 5:
166
+ video = element
167
+ if self.rescale:
168
+ video = (video + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
169
+ video = video * 255.0
170
+ video = video.permute(0, 2, 1, 3, 4).cpu().detach().to(torch.uint8) # b,t,c,h,w
171
+ for i in range(video.shape[0]):
172
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}_{}.mp4".format(k, global_step, current_epoch, batch_idx, i)
173
+ path = os.path.join(root, filename)
174
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
175
+ log_audio = raw_audio[i] if raw_audio is not None else None
176
+ success = save_audio_video(
177
+ video[i],
178
+ audio=log_audio.unsqueeze(0) if log_audio is not None else None,
179
+ frame_rate=25,
180
+ sample_rate=16000,
181
+ save_path=path,
182
+ keep_intermediate=False,
183
+ )
184
+
185
+ # video_path = save_av_sample(video[i], 25, audio=raw_audio, audio_rate=16000, path=None)
186
+ if exists(pl_module):
187
+ assert isinstance(
188
+ pl_module.logger, WandbLogger
189
+ ), "logger_log_image only supports WandbLogger currently"
190
+ pl_module.logger.experiment.log(
191
+ {
192
+ f"{split}/{k}": wandb.Video(
193
+ path if success else video,
194
+ # caption=f"diffused videos w {n_frames} frames (condition left, generated right)",
195
+ fps=25,
196
+ format="mp4",
197
+ )
198
+ },
199
+ )
200
+
201
+ @rank_zero_only
202
+ def log_video(self, pl_module, batch, batch_idx, split="train"):
203
+ check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
204
+ # print(f"check_idx: {check_idx}", f"split: {split}")
205
+ if (
206
+ self.check_frequency(check_idx, split=split)
207
+ and hasattr(pl_module, "log_videos") # batch_idx % self.batch_freq == 0
208
+ and callable(pl_module.log_videos)
209
+ and
210
+ # batch_idx > 5 and
211
+ self.max_videos > 0
212
+ ):
213
+ logger = type(pl_module.logger)
214
+ is_train = pl_module.training
215
+ if is_train:
216
+ pl_module.eval()
217
+
218
+ gpu_autocast_kwargs = {
219
+ "enabled": self.enable_autocast, # torch.is_autocast_enabled(),
220
+ "dtype": torch.get_autocast_gpu_dtype(),
221
+ "cache_enabled": torch.is_autocast_cache_enabled(),
222
+ }
223
+ with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
224
+ videos = pl_module.log_videos(batch, split=split, **self.log_videos_kwargs)
225
+
226
+ for k in videos:
227
+ N = min(videos[k].shape[0], self.max_videos)
228
+ videos[k] = videos[k][:N]
229
+ if isinstance(videos[k], torch.Tensor):
230
+ videos[k] = videos[k].detach().float().cpu()
231
+ if self.clamp:
232
+ videos[k] = torch.clamp(videos[k], -1.0, 1.0)
233
+
234
+ raw_audio = batch.get("raw_audio", None)
235
+
236
+ self.log_local(
237
+ pl_module.logger.save_dir,
238
+ split,
239
+ videos,
240
+ raw_audio,
241
+ pl_module.global_step,
242
+ pl_module.current_epoch,
243
+ batch_idx,
244
+ pl_module=pl_module if isinstance(pl_module.logger, WandbLogger) else None,
245
+ )
246
+
247
+ if is_train:
248
+ pl_module.train()
249
+
250
+ def check_frequency(self, check_idx, split="train"):
251
+ if split == "val":
252
+ if check_idx:
253
+ check_idx -= 1
254
+ if ((check_idx % self.batch_freq_val) == 0 or (check_idx in self.log_steps_val)) and (
255
+ check_idx > 0 or self.log_first_step
256
+ ):
257
+ try:
258
+ self.log_steps_val.pop(0)
259
+ except IndexError as e:
260
+ print(e)
261
+ pass
262
+ return True
263
+ return False
264
+ if check_idx:
265
+ check_idx -= 1
266
+ if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
267
+ check_idx > 0 or self.log_first_step
268
+ ):
269
+ try:
270
+ self.log_steps.pop(0)
271
+ except IndexError as e:
272
+ print(e)
273
+ pass
274
+ return True
275
+ return False
276
+
277
+ @rank_zero_only
278
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
279
+ if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
280
+ self.log_video(pl_module, batch, batch_idx, split="train")
281
+
282
+ @rank_zero_only
283
+ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
284
+ if self.log_before_first_step and pl_module.global_step == 0:
285
+ print(f"{self.__class__.__name__}: logging before training")
286
+ self.log_video(pl_module, batch, batch_idx, split="train")
287
+
288
+ @rank_zero_only
289
+ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs):
290
+ if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
291
+ self.log_video(pl_module, batch, batch_idx, split="val")
292
+ if hasattr(pl_module, "calibrate_grad_norm"):
293
+ if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
294
+ self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
sgm/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # from .dataset import StableDataModuleFromConfig
sgm/data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (193 Bytes). View file
 
sgm/data/__pycache__/data_utils.cpython-311.pyc ADDED
Binary file (28.3 kB). View file
 
sgm/data/__pycache__/mask.cpython-311.pyc ADDED
Binary file (17.9 kB). View file
 
sgm/data/__pycache__/video_datamodule_latent.cpython-311.pyc ADDED
Binary file (7.4 kB). View file
 
sgm/data/__pycache__/video_dataset_latent.cpython-311.pyc ADDED
Binary file (34.4 kB). View file
 
sgm/data/data_utils.py ADDED
@@ -0,0 +1,561 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image, ImageDraw
4
+ import cv2
5
+ from functools import partial
6
+ import math
7
+
8
+
9
+ def get_size(img):
10
+ if isinstance(img, (np.ndarray, torch.Tensor)):
11
+ return img.shape[1::-1]
12
+ else:
13
+ return img.size
14
+
15
+
16
+ def imresample(img, sz):
17
+ im_data = torch.nn.functional.interpolate(img, size=sz, mode="area")
18
+ return im_data
19
+
20
+
21
+ def crop_resize(img, box, image_size):
22
+ if isinstance(img, np.ndarray):
23
+ img = img[box[1] : box[3], box[0] : box[2]]
24
+ out = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_AREA).copy()
25
+ elif isinstance(img, torch.Tensor):
26
+ img = img[box[1] : box[3], box[0] : box[2]]
27
+ out = (
28
+ imresample(img.permute(2, 0, 1).unsqueeze(0).float(), (image_size, image_size))
29
+ .byte()
30
+ .squeeze(0)
31
+ .permute(1, 2, 0)
32
+ )
33
+ else:
34
+ out = img.crop(box).copy().resize((image_size, image_size), Image.BILINEAR)
35
+ return out
36
+
37
+
38
+ def fixed_image_standardization(image_tensor):
39
+ processed_tensor = (image_tensor - 127.5) / 128.0
40
+ return processed_tensor
41
+
42
+
43
+ def extract_face(img, landmarks, image_size=160, margin=0, postprocess=False):
44
+ """Extract face + margin from images given facial landmarks.
45
+
46
+ Arguments:
47
+ img {PIL.Image/torch.Tensor/np.ndarray} -- Input image(s) with shape (B, H, W, C)
48
+ landmarks {numpy.ndarray} -- Facial landmarks with shape (B, 68, 2)
49
+ image_size {int} -- Output image size in pixels. The image will be square.
50
+ margin {int} -- Margin to add to bounding box, in terms of pixels in the final image.
51
+ postprocess {bool} -- Whether to apply standardization
52
+
53
+ Returns:
54
+ torch.tensor -- tensor representing the extracted faces with shape (B, H, W, C)
55
+ """
56
+ # Calculate bounding boxes from landmarks for all faces in batch
57
+ x_min = np.min(landmarks, axis=1)[:, 0] # Shape: (B,)
58
+ y_min = np.min(landmarks, axis=1)[:, 1] # Shape: (B,)
59
+ x_max = np.max(landmarks, axis=1)[:, 0] # Shape: (B,)
60
+ y_max = np.max(landmarks, axis=1)[:, 1] # Shape: (B,)
61
+
62
+ # Calculate margin for top only
63
+ box_height = y_max - y_min
64
+ top_margin = margin * box_height / (image_size - margin)
65
+
66
+ # Create boxes for all faces
67
+ boxes = np.stack(
68
+ [
69
+ x_min,
70
+ np.maximum(y_min - top_margin, 0), # Only add margin to top
71
+ x_max,
72
+ y_max,
73
+ ],
74
+ axis=1,
75
+ ).astype(int) # Shape: (B, 4)
76
+
77
+ # Process each face in the batch
78
+ faces = []
79
+ for i in range(len(boxes)):
80
+ face = crop_resize(img[i], boxes[i], image_size)
81
+ faces.append(face)
82
+
83
+ faces = torch.stack(faces, dim=0)
84
+ faces = faces.float()
85
+
86
+ if postprocess:
87
+ faces = fixed_image_standardization(faces)
88
+
89
+ return faces
90
+
91
+
92
+ def crop_mouth_region(images, landmarks, crop_size=96):
93
+ """
94
+ Takes a fixed-size square crop centered on the mouth region.
95
+
96
+ Parameters:
97
+ - images: tensor/array of shape (num_frames, height, width, channels) or (height, width, channels)
98
+ - landmarks: numpy array of shape (num_frames, 68, 2) or (68, 2)
99
+ - crop_size: size of the square crop (both height and width)
100
+ - padding: percentage of padding around the mouth region (0.0 to 1.0)
101
+
102
+ Returns:
103
+ - List of fixed-size crops or single crop if input is single image
104
+ """
105
+ # Handle single image case
106
+ single_image = False
107
+ if len(images.shape) == 3:
108
+ images = images[None]
109
+ landmarks = landmarks[None]
110
+ single_image = True
111
+
112
+ num_frames = len(images)
113
+ crops = []
114
+
115
+ # Mouth landmarks indices (48-67 for mouth region)
116
+ mouth_indices = range(48, 68)
117
+
118
+ for i in range(num_frames):
119
+ # Get mouth landmarks for current frame
120
+ mouth_landmarks = landmarks[i][mouth_indices]
121
+
122
+ # Find center of mouth
123
+ center_x = int(np.mean(mouth_landmarks[:, 0]))
124
+ center_y = int(np.mean(mouth_landmarks[:, 1]))
125
+
126
+ # Calculate crop boundaries
127
+ half_size = crop_size // 2
128
+ left = max(0, center_x - half_size)
129
+ right = min(images.shape[2], center_x + half_size)
130
+ top = max(0, center_y - half_size)
131
+ bottom = min(images.shape[1], center_y + half_size)
132
+
133
+ # Adjust if crop would go out of bounds
134
+ if left == 0:
135
+ right = crop_size
136
+ if right == images.shape[2]:
137
+ left = images.shape[2] - crop_size
138
+ if top == 0:
139
+ bottom = crop_size
140
+ if bottom == images.shape[1]:
141
+ top = images.shape[1] - crop_size
142
+
143
+ # Take the crop
144
+ crop = images[i, top:bottom, left:right]
145
+ crops.append(crop)
146
+
147
+ return crops[0] if single_image else crops
148
+
149
+
150
+ def create_masks_from_landmarks_box(landmark_list, img_shape, nose_index=28, dtype="uint8", box_expand=0.0):
151
+ height, width = img_shape[:2]
152
+ num_frames = landmark_list.shape[0]
153
+
154
+ # Initialize the masks array
155
+ masks = np.zeros((num_frames, height, width), dtype=dtype)
156
+
157
+ if 0 <= box_expand < 1:
158
+ box_expand = int(box_expand * width)
159
+
160
+ for i in range(num_frames):
161
+ # Get the landmarks for the current frame
162
+ landmarks = landmark_list[i]
163
+
164
+ # Get the y-coordinate of the nose landmark
165
+ nose_point_h = landmarks[nose_index, 1]
166
+ cut_h = nose_point_h
167
+
168
+ # Find the leftmost and rightmost landmarks
169
+ far_left_index = np.argmin(landmarks[:, 0])
170
+ far_right_index = np.argmax(landmarks[:, 0])
171
+
172
+ # Define the points for the mask contour
173
+ left_up_point = np.array([landmarks[far_left_index][0], cut_h - box_expand], dtype=np.int32)
174
+ left_down_point = np.array([landmarks[far_left_index][0], height], dtype=np.int32)
175
+ right_up_point = np.array([landmarks[far_right_index][0], cut_h - box_expand], dtype=np.int32)
176
+ right_down_point = np.array([landmarks[far_right_index][0], height], dtype=np.int32)
177
+
178
+ # Define the contour
179
+ contour = np.array([[left_up_point, left_down_point, right_down_point, right_up_point]])
180
+
181
+ # Draw the contour on the mask
182
+ cv2.drawContours(masks[i], [contour], -1, color=(1), thickness=cv2.FILLED)
183
+
184
+ return torch.from_numpy(masks)
185
+
186
+
187
+ def create_masks_from_landmarks_full_size(
188
+ landmarks_batch, image_height, image_width, start_index=48, end_index=68, offset=0, nose_index=33
189
+ ):
190
+ """
191
+ Efficiently creates a batch of masks using vectorized operations where each mask has ones from the highest
192
+ landmark in the specified range (adjusted by an offset) to the bottom of the image, and zeros otherwise.
193
+
194
+ Parameters:
195
+ - landmarks_batch (np.array): An array of shape (B, 68, 2) containing facial landmarks for multiple samples.
196
+ - image_height (int): The height of the image for which masks are created.
197
+ - image_width (int): The width of the image for which masks are created.
198
+ - start_index (int): The starting index of the range to check (inclusive).
199
+ - end_index (int): The ending index of the range to check (inclusive).
200
+ - offset (int): An offset to add or subtract from the y-coordinate of the highest landmark.
201
+
202
+ Returns:
203
+ - np.array: An array of masks of shape (B, image_height, image_width) for each batch.
204
+ """
205
+ # Extract the y-coordinates for the specified range across all batches
206
+ y_coords = landmarks_batch[:, nose_index : nose_index + 1, 1]
207
+
208
+ # Find the index of the minimum y-coordinate in the specified range for each batch
209
+ min_y_indices = np.argmin(y_coords, axis=1)
210
+
211
+ # Gather the highest landmarks' y-coordinates using the indices found
212
+ highest_y_coords = y_coords[np.arange(len(y_coords)), min_y_indices]
213
+
214
+ if abs(offset) < 1 and abs(offset) > 0:
215
+ offset = int(offset * image_height)
216
+
217
+ # Apply the offset to the highest y-coordinate
218
+ adjusted_y_coords = highest_y_coords + offset
219
+
220
+ # Clip the coordinates to stay within image boundaries
221
+ adjusted_y_coords = np.clip(adjusted_y_coords, 0, image_height - 1)
222
+
223
+ # Use broadcasting to create a mask without loops
224
+ # Create a range of indices from 0 to image_height - 1
225
+ all_indices = np.arange(image_height)
226
+
227
+ # Compare each index in 'all_indices' to each 'adjusted_y_coord' in the batch
228
+ # 'all_indices' has shape (image_height,), we reshape to (1, image_height) to broadcast against (B, 1)
229
+ mask_2d = (all_indices >= adjusted_y_coords[:, None]).astype(int)
230
+
231
+ # Extend the 2D mask to a full 3D mask of size (B, image_height, image_width)
232
+ full_mask = np.tile(mask_2d[:, :, np.newaxis], (1, 1, image_width))
233
+
234
+ return torch.from_numpy(full_mask)
235
+
236
+
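Note that, despite the start_index/end_index parameters, the cut row is taken from the nose landmark alone; everything from that (offset-adjusted) row down is set to 1. A small usage sketch with dummy landmarks:

    import numpy as np
    from sgm.data.data_utils import create_masks_from_landmarks_full_size

    landmarks = np.random.rand(4, 68, 2) * 512  # dummy (B, 68, 2) landmarks
    masks = create_masks_from_landmarks_full_size(landmarks, image_height=512, image_width=512, offset=0.05)
    print(masks.shape)  # torch.Size([4, 512, 512]); rows below the shifted nose line are 1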
237
+ def expand_polygon(polygon, expand_size):
238
+ """
239
+ Expands the polygon outward by a specified number of pixels.
240
+
241
+ Parameters:
242
+ - polygon (list of tuples): The polygon points as (x, y).
243
+ - expand_size (int): The number of pixels to expand the polygon outward.
244
+
245
+ Returns:
246
+ - expanded_polygon (list of tuples): The expanded polygon points as (x, y).
247
+ """
248
+ if expand_size == 0:
249
+ return polygon
250
+
251
+ # Calculate centroid of the polygon
252
+ centroid_x = sum([point[0] for point in polygon]) / len(polygon)
253
+ centroid_y = sum([point[1] for point in polygon]) / len(polygon)
254
+
255
+ # Expand each point outward from the centroid
256
+ expanded_polygon = []
257
+ for x, y in polygon:
258
+ vector_x = x - centroid_x
259
+ vector_y = y - centroid_y
260
+ length = np.sqrt(vector_x**2 + vector_y**2)
261
+ if length == 0:
262
+ expanded_polygon.append((x, y))
263
+ else:
264
+ new_x = x + expand_size * (vector_x / length)
265
+ new_y = y + expand_size * (vector_y / length)
266
+ expanded_polygon.append((int(new_x), int(new_y)))
267
+
268
+ return expanded_polygon
269
+
270
+
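A worked example of the centroid-based expansion: for a square, each corner moves expand_size pixels along the centroid-to-corner direction, i.e. about expand_size/sqrt(2) in x and y.

    from sgm.data.data_utils import expand_polygon

    square = [(50, 50), (150, 50), (150, 150), (50, 150)]  # 100x100 square centred at (100, 100)
    print(expand_polygon(square, expand_size=10))
    # -> [(42, 42), (157, 42), (157, 157), (42, 157)]  (each corner moves ~7 px in x and y)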
271
+ def create_masks_from_landmarks_mouth(landmark_list, img_shape, nose_index=33, dtype="uint8", box_expand=0.0):
272
+ height, width = img_shape[:2]
273
+ num_frames = landmark_list.shape[0]
274
+
275
+ # Initialize the masks array
276
+ masks = np.zeros((num_frames, height, width), dtype=dtype)
277
+
278
+ if 0 <= box_expand < 1:
279
+ box_expand = int(box_expand * width)
280
+
281
+ for i in range(num_frames):
282
+ # Get the landmarks for the current frame
283
+ landmarks = landmark_list[i]
284
+
285
+ # Get the y-coordinate of the nose landmark
286
+ nose_point_h = landmarks[nose_index, 1]
287
+ cut_h = nose_point_h
288
+
289
+ # Find the leftmost and rightmost landmarks
290
+ far_left_index = np.argmin(landmarks[:, 0])
291
+ far_right_index = np.argmax(landmarks[:, 0])
292
+
293
+ # Find lowest landmark y-coordinate
294
+ lowest_y = np.max(landmarks[:, 1])
295
+ # Add box_expand to the lowest point
296
+ lowest_y = min(height, lowest_y + box_expand)
297
+
298
+ # Define the points for the mask contour
299
+ left_up_point = np.array([landmarks[far_left_index][0], cut_h - box_expand], dtype=np.int32)
300
+ left_down_point = np.array([landmarks[far_left_index][0], lowest_y], dtype=np.int32)
301
+ right_up_point = np.array([landmarks[far_right_index][0], cut_h - box_expand], dtype=np.int32)
302
+ right_down_point = np.array([landmarks[far_right_index][0], lowest_y], dtype=np.int32)
303
+
304
+ # Define the contour
305
+ contour = np.array([[left_up_point, left_down_point, right_down_point, right_up_point]])
306
+
307
+ # Draw the contour on the mask
308
+ cv2.drawContours(masks[i], [contour], -1, color=(1), thickness=cv2.FILLED)
309
+
310
+ return torch.from_numpy(masks)
311
+
312
+
313
+ def create_face_mask_from_landmarks(landmarks_batch, image_height, image_width, mask_expand=0):
314
+ """
315
+ Creates a batch of masks where each mask covers the face region using landmarks.
316
+
317
+ Parameters:
318
+ - landmarks_batch (np.array): An array of shape (B, 68, 2) containing facial landmarks for multiple samples.
319
+ - image_height (int): The height of the image for which masks are created.
320
+ - image_width (int): The width of the image for which masks are created.
321
+ - mask_expand (int): The number of pixels to expand the mask outward.
322
+
323
+ Returns:
324
+ - torch.Tensor: A tensor of masks of shape (B, image_height, image_width).
325
+ """
326
+ # Initialize an array to hold all masks
327
+ masks = np.zeros((landmarks_batch.shape[0], image_height, image_width), dtype=np.uint8)
328
+
329
+ if abs(mask_expand) < 1 and abs(mask_expand) > 0:
330
+ mask_expand = int(mask_expand * image_height)
331
+
332
+ for i, landmarks in enumerate(landmarks_batch):
333
+ # Create a blank image for each mask
334
+ mask = Image.new("L", (image_width, image_height), 0)
335
+ draw = ImageDraw.Draw(mask)
336
+
337
+ # Extract relevant landmarks for the face
338
+ jawline_landmarks = landmarks[2:15] # Jawline
339
+ # upper_face_landmarks = landmarks[17:27] # Eyebrows and top of nose bridge
340
+
341
+ # Combine landmarks to form a polygon around the face
342
+ # face_polygon = np.concatenate((jawline_landmarks, upper_face_landmarks[::-1]), axis=0)
343
+ face_polygon = jawline_landmarks
344
+
345
+ # Convert landmarks to a list of tuples
346
+ face_polygon = [(int(x), int(y)) for x, y in face_polygon]
347
+
348
+ # Expand the polygon if necessary
349
+ expanded_polygon = expand_polygon(face_polygon, mask_expand)
350
+
351
+ # Draw the polygon and fill it
352
+ draw.polygon(expanded_polygon, outline=1, fill=1)
353
+
354
+ # Convert mask to numpy array and add it to the batch of masks
355
+ masks[i] = np.array(mask)
356
+
357
+ return torch.from_numpy(masks)
358
+
359
+
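Usage sketch for the jawline face mask (dummy landmarks; the function relies on PIL's Image/ImageDraw, which are assumed to be imported at the top of this file):

    import numpy as np
    from sgm.data.data_utils import create_face_mask_from_landmarks

    landmarks = np.random.rand(2, 68, 2) * 256  # dummy (B, 68, 2) landmarks
    masks = create_face_mask_from_landmarks(landmarks, image_height=256, image_width=256, mask_expand=0.05)
    print(masks.shape)  # torch.Size([2, 256, 256]); 1 inside the (expanded) jawline polygon, 0 elsewhere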
360
+ ALL_FIXED_POINTS = (
361
+ [i for i in range(0, 4)] + [i for i in range(13, 17)] + [i for i in range(27, 36)] + [36, 39, 42, 45]
362
+ )
363
+
364
+
365
+ def gaussian_kernel(sigma, width, height):
366
+ """Create a 2D Gaussian kernel."""
367
+ x = torch.arange(0, width, 1) - width // 2
368
+ y = torch.arange(0, height, 1) - height // 2
369
+ x = x.float()
370
+ y = y.float()
371
+ x2 = x**2
372
+ y2 = y[:, None] ** 2
373
+ g = torch.exp(-(x2 + y2) / (2 * sigma**2))
374
+ return g / g.sum()
375
+
376
+
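Two quick sanity checks on the kernel: it is normalised to sum to 1 and it peaks at the centre.

    import torch
    from sgm.data.data_utils import gaussian_kernel

    k = gaussian_kernel(sigma=3, width=32, height=32)
    print(k.shape)                                    # torch.Size([32, 32])
    print(torch.isclose(k.sum(), torch.tensor(1.0)))  # tensor(True)
    print(int(k.argmax()) == 16 * 32 + 16)            # True: peak at (height // 2, width // 2)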
377
+ def generate_hm(landmarks, height, width, n_points="all", sigma=3):
378
+ if n_points == "all":
379
+ Nlandmarks = range(len(landmarks))
380
+ elif n_points == "fixed":
381
+ Nlandmarks = ALL_FIXED_POINTS
382
+ elif n_points == "stable":
383
+ Nlandmarks = [33, 36, 39, 42, 45]
384
+
385
+ kernel = gaussian_kernel(sigma, width, height)
386
+ hm = torch.zeros((height, width))
387
+ for I in Nlandmarks:
388
+ x0, y0 = landmarks[I]
389
+ x0, y0 = int(x0), int(y0)
390
+ left, right = max(0, x0 - width // 2), min(width, x0 + width // 2)
391
+ top, bottom = max(0, y0 - height // 2), min(height, y0 + height // 2)
392
+ hm[top:bottom, left:right] += kernel[
393
+ max(0, -y0 + height // 2) : min(height, height - y0 + height // 2),
394
+ max(0, -x0 + width // 2) : min(width, width - x0 + width // 2),
395
+ ]
396
+ # Normalize the heatmap to have values between 0 and 1
397
+ max_val = hm.max()
398
+ if max_val > 0:
399
+ hm /= max_val
400
+ return hm
401
+
402
+
403
+ def get_heatmap(landmarks, image_size, or_im_size, n_points="stable", sigma=4):
404
+ stack = []
405
+ seq_length = landmarks.shape[0]
406
+ if or_im_size[0] != image_size[0] or or_im_size[1] != image_size[1]:
407
+ landmarks = scale_landmarks(landmarks, or_im_size, image_size)
408
+ gen_single_heatmap = partial(
409
+ generate_hm,
410
+ height=image_size[0],
411
+ width=image_size[1],
412
+ n_points=n_points,
413
+ sigma=sigma,
414
+ )
415
+ for i in range(seq_length):
416
+ stack.append(gen_single_heatmap(landmarks[i]))
417
+
418
+ return torch.stack(stack, axis=0).unsqueeze(0) # (1, seq_length, height, width)
419
+
420
+
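A usage sketch producing heatmaps at latent resolution from full-resolution landmarks (dummy values):

    import numpy as np
    from sgm.data.data_utils import get_heatmap

    landmarks = np.random.rand(3, 68, 2) * 512  # dummy (seq_length, 68, 2) landmarks in 512x512 frames
    heatmaps = get_heatmap(landmarks, image_size=(64, 64), or_im_size=(512, 512), n_points="stable", sigma=2)
    print(heatmaps.shape)  # torch.Size([1, 3, 64, 64]); one peak-normalised blob per stable landmark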
421
+ def scale_landmarks(landmarks, original_size, target_size):
422
+ """
423
+ Scale landmarks from original size to target size.
424
+
425
+ Parameters:
426
+ - landmarks (np.array): An array of shape (N, 2) containing facial landmarks.
427
+ - original_size (tuple): The size (height, width) for which the landmarks are currently scaled.
428
+ - target_size (tuple): The size (height, width) to which landmarks should be scaled.
429
+
430
+ Returns:
431
+ - scaled_landmarks (np.array): Scaled landmarks.
432
+ """
433
+ scale_y = target_size[0] / original_size[0]
434
+ scale_x = target_size[1] / original_size[1]
435
+ scaled_landmarks = landmarks * np.array([scale_x, scale_y])
436
+ return scaled_landmarks.astype(int)
437
+
438
+
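Example of rescaling landmarks from a 512x512 frame to a 64x64 latent grid; note that the cast to int truncates fractional coordinates.

    import numpy as np
    from sgm.data.data_utils import scale_landmarks

    landmarks_512 = np.array([[256, 256], [100, 400]])
    print(scale_landmarks(landmarks_512, original_size=(512, 512), target_size=(64, 64)))
    # -> [[32 32]
    #     [12 50]]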
439
+ def draw_kps_image(
440
+ image_shape, original_size, landmarks, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255)], rgb=True, pts_width=4
441
+ ):
442
+ stick_width = pts_width
443
+ limb_seq = np.array([[0, 2], [1, 2]])
444
+ kps = landmarks[[36, 45, 33], :]
445
+ kps = scale_landmarks(kps, original_size, image_shape)
446
+ if not rgb: # Grayscale image
447
+ canvas = np.zeros((image_shape[0], image_shape[1], 1))
448
+ color_mode = "grayscale"
449
+ else: # Color image
450
+ canvas = np.zeros((image_shape[0], image_shape[1], 3))
451
+ color_mode = "color"
452
+
453
+ polygon_cache = {}
454
+
455
+ for index in limb_seq:
456
+ color = color_list[index[0]]
457
+ if color_mode == "grayscale":
458
+ color = (int(0.299 * color[2] + 0.587 * color[1] + 0.114 * color[0]),) # Convert to grayscale intensity
459
+
460
+ x = kps[index][:, 0]
461
+ y = kps[index][:, 1]
462
+ length = np.sqrt((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2)
463
+ angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
464
+
465
+ cache_key = (color, int(np.mean(x)), int(np.mean(y)), int(length / 2), int(angle))
466
+ if cache_key not in polygon_cache:
467
+ polygon_cache[cache_key] = cv2.ellipse2Poly(
468
+ (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stick_width), int(angle), 0, 360, 1
469
+ )
470
+
471
+ polygon = polygon_cache[cache_key]
472
+ cv2.fillConvexPoly(canvas, polygon, [int(c * 0.6) for c in color])
473
+
474
+ for idx, kp in enumerate(kps):
475
+ if color_mode == "grayscale":
476
+ color = (int(0.299 * color_list[idx][2] + 0.587 * color_list[idx][1] + 0.114 * color_list[idx][0]),)
477
+ else:
478
+ color = color_list[idx]
479
+ cv2.circle(canvas, (int(kp[0]), int(kp[1])), pts_width, color, -1)
480
+
481
+ return canvas.transpose(2, 0, 1)
482
+
483
+
484
+ def create_landmarks_image(
485
+ landmarks, original_size=(772, 772), target_size=(772, 772), point_size=3, n_points="all", dim=3
486
+ ):
487
+ """
488
+ Creates an image of landmarks on a black background using efficient NumPy operations.
489
+
490
+ Parameters:
491
+ - landmarks (np.array): An array of shape (68, 2) containing facial landmarks.
492
+ - original_size (tuple): The size (height, width) in which the landmarks are given.
+ - target_size (tuple): The size (height, width) of the output image.
+ - point_size (int): The radius of each landmark point in pixels.
+ - n_points (str): Which landmark subset to draw ("all", "fixed", or "stable").
+ - dim (int): Number of channels to stack in the returned array.
494
+
495
+ Returns:
496
+ - img (np.array): An array of shape (dim, target_height, target_width) with landmark pixels set to 255.
497
+ """
498
+ if n_points == "all":
499
+ indexes = range(len(landmarks))
500
+ elif n_points == "fixed":
501
+ indexes = ALL_FIXED_POINTS
502
+ elif n_points == "stable":
503
+ indexes = [33, 36, 39, 42, 45]
504
+
505
+ landmarks = landmarks[indexes]
506
+
507
+ img = np.zeros(target_size, dtype=np.uint8)
508
+
509
+ landmarks = scale_landmarks(landmarks, original_size, target_size)
510
+
511
+ # Ensure the landmarks are in bounds and integer
512
+ landmarks = np.clip(landmarks, [0, 0], [target_size[1] - 1, target_size[0] - 1]).astype(int)
513
+
514
+ # Get x and y coordinates from landmarks
515
+ x, y = landmarks[:, 0], landmarks[:, 1]
516
+
517
+ # Define a grid offset based on point_size around each landmark
518
+ offset = np.arange(-point_size // 2, point_size // 2 + 1)
519
+ grid_x, grid_y = np.meshgrid(offset, offset, indexing="ij")
520
+
521
+ # Calculate the full set of x and y coordinates for the points
522
+ full_x = x[:, np.newaxis, np.newaxis] + grid_x[np.newaxis, :, :]
523
+ full_y = y[:, np.newaxis, np.newaxis] + grid_y[np.newaxis, :, :]
524
+
525
+ # Clip the coordinates to stay within image boundaries
526
+ full_x = np.clip(full_x, 0, target_size[1] - 1)
527
+ full_y = np.clip(full_y, 0, target_size[0] - 1)
528
+
529
+ # Flatten the arrays to use them as indices
530
+ full_x = full_x.ravel()
531
+ full_y = full_y.ravel()
532
+
533
+ # Set the points in the image
534
+ img[full_y, full_x] = 255
535
+
536
+ return np.stack([img] * dim, axis=0)
537
+
538
+
539
+ def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
540
+ len_file = audio.shape[-1]
541
+
542
+ if max_len_sec or max_len_raw:
543
+ max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
544
+ if len_file < int(max_len):
545
+ # dummy = np.zeros((1, int(max_len_sec * sr) - len_file))
546
+ # extened_wav = np.concatenate((audio_data, dummy[0]))
547
+ extened_wav = torch.nn.functional.pad(audio, (0, int(max_len) - len_file), "constant")
548
+ else:
549
+ extened_wav = audio[:, : int(max_len)]
550
+ else:
551
+ extened_wav = audio
552
+
553
+ return extened_wav
554
+
555
+
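Usage sketch: pad a short clip up to a duration in seconds, or trim it to a raw sample count.

    import torch
    from sgm.data.data_utils import trim_pad_audio

    sr = 16000
    audio = torch.randn(1, sr)                             # one second of mono audio
    padded = trim_pad_audio(audio, sr, max_len_sec=2.0)    # zero-padded on the right to 2 s
    trimmed = trim_pad_audio(audio, sr, max_len_raw=8000)  # trimmed to 8000 samples
    print(padded.shape, trimmed.shape)                     # torch.Size([1, 32000]) torch.Size([1, 8000])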
556
+ def ssim_to_bin(ssim_score):
557
+ # Normalize the SSIM score to a 0-100 scale
558
+ normalized_diff_ssim = (1 - ((ssim_score + 1) / 2)) * 100
559
+ # Assign to one of the 100 bins
560
+ bin_index = float(min(np.floor(normalized_diff_ssim), 99))
561
+ return bin_index
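Worked arithmetic for the binning above: SSIM 1.0 (identical images) falls in bin 0, SSIM 0.0 in bin 50, and SSIM -1.0 is clamped to the last bin.

    from sgm.data.data_utils import ssim_to_bin

    print(ssim_to_bin(1.0), ssim_to_bin(0.0), ssim_to_bin(-1.0))  # 0.0 50.0 99.0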
sgm/data/dataset.py ADDED
@@ -0,0 +1,80 @@
1
+ from typing import Optional
2
+
3
+ import torchdata.datapipes.iter
4
+ import webdataset as wds
5
+ from omegaconf import DictConfig
6
+ from pytorch_lightning import LightningDataModule
7
+
8
+ try:
9
+ from sdata import create_dataset, create_dummy_dataset, create_loader
10
+ except ImportError as e:
11
+ print("#" * 100)
12
+ print("Datasets not yet available")
13
+ print("to enable, we need to add stable-datasets as a submodule")
14
+ print("please use ``git submodule update --init --recursive``")
15
+ print("and do ``pip install -e stable-datasets/`` from the root of this repo")
16
+ print("#" * 100)
17
+ exit(1)
18
+
19
+
20
+ class StableDataModuleFromConfig(LightningDataModule):
21
+ def __init__(
22
+ self,
23
+ train: DictConfig,
24
+ validation: Optional[DictConfig] = None,
25
+ test: Optional[DictConfig] = None,
26
+ skip_val_loader: bool = False,
27
+ dummy: bool = False,
28
+ ):
29
+ super().__init__()
30
+ self.train_config = train
31
+ assert (
32
+ "datapipeline" in self.train_config and "loader" in self.train_config
33
+ ), "train config requires the fields `datapipeline` and `loader`"
34
+
35
+ self.val_config = validation
36
+ if not skip_val_loader:
37
+ if self.val_config is not None:
38
+ assert (
39
+ "datapipeline" in self.val_config and "loader" in self.val_config
40
+ ), "validation config requires the fields `datapipeline` and `loader`"
41
+ else:
42
+ print(
43
+ "Warning: No Validation datapipeline defined, using that one from training"
44
+ )
45
+ self.val_config = train
46
+
47
+ self.test_config = test
48
+ if self.test_config is not None:
49
+ assert (
50
+ "datapipeline" in self.test_config and "loader" in self.test_config
51
+ ), "test config requires the fields `datapipeline` and `loader`"
52
+
53
+ self.dummy = dummy
54
+ if self.dummy:
55
+ print("#" * 100)
56
+ print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
57
+ print("#" * 100)
58
+
59
+ def setup(self, stage: str) -> None:
60
+ print("Preparing datasets")
61
+ if self.dummy:
62
+ data_fn = create_dummy_dataset
63
+ else:
64
+ data_fn = create_dataset
65
+
66
+ self.train_datapipeline = data_fn(**self.train_config.datapipeline)
67
+ if self.val_config:
68
+ self.val_datapipeline = data_fn(**self.val_config.datapipeline)
69
+ if self.test_config:
70
+ self.test_datapipeline = data_fn(**self.test_config.datapipeline)
71
+
72
+ def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
73
+ loader = create_loader(self.train_datapipeline, **self.train_config.loader)
74
+ return loader
75
+
76
+ def val_dataloader(self) -> wds.DataPipeline:
77
+ return create_loader(self.val_datapipeline, **self.val_config.loader)
78
+
79
+ def test_dataloader(self) -> wds.DataPipeline:
80
+ return create_loader(self.test_datapipeline, **self.test_config.loader)
sgm/data/mask.py ADDED
@@ -0,0 +1,525 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+
3
+ """
4
+ Functions taken from https://github.com/DanBigioi/DiffusionVideoEditing
5
+
6
+
7
+ """
8
+
9
+ import cv2
10
+ import numpy as np
11
+ import torch
12
+
13
+ " Countour from 2:15 not good for head poses "
14
+
15
+
16
+ def face_mask(img_shape, landmark_list, dtype="uint8"):
17
+ height, width = img_shape[:2]
18
+ mask = np.ones((height, width, 1), dtype=dtype)
19
+ cv2.drawContours(
20
+ mask, np.int32([landmark_list[2:15]]), -1, color=(0), thickness=cv2.FILLED
21
+ )
22
+
23
+ return mask
24
+
25
+
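Quick usage sketch for the tight jaw-contour mask (dummy landmarks; the module path sgm.data.mask matches the import used by the video dataset below):

    import numpy as np
    from sgm.data.mask import face_mask

    landmarks = np.random.rand(68, 2) * 256  # dummy 68-point landmarks for a 256x256 frame
    mask = face_mask((256, 256), landmarks)
    print(mask.shape)  # (256, 256, 1); 0 inside the jaw contour (landmarks 2:15), 1 elsewhere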
26
+ def face_mask_jaw_box(img_shape, landmark_list, dtype="uint8", kernel_size=10):
27
+ nose = 33
28
+ jaw = 8
29
+
30
+ height, width = img_shape[:2]
31
+ mask = np.ones((height, width, 1), dtype=dtype)
32
+ combined_landmarks = np.concatenate((landmark_list[2:15], [landmark_list[33]]))
33
+
34
+ # Draw the combined contour on the mask
35
+ cv2.drawContours(
36
+ mask, [np.int32(combined_landmarks)], -1, color=(0), thickness=cv2.FILLED
37
+ )
38
+
39
+ inverted_mask = 1 - mask
40
+ kernel = np.ones((kernel_size, kernel_size), np.uint8)
41
+ mask = cv2.dilate(inverted_mask, kernel, iterations=1)
42
+ mask = np.expand_dims(
43
+ mask, axis=-1
44
+ ) # Add a singleton dimension to match the number of channels
45
+ mask = 1 - mask
46
+
47
+ cut_h = landmark_list[nose][1]
48
+
49
+ far_left = int(np.argmin(landmark_list[:, 0]))
50
+ far_right = int(np.argmax(landmark_list[:, 0]))
51
+ left_up_point = np.int32([landmark_list[far_left][0], cut_h]) # 2
52
+ right_up_point = np.int32([landmark_list[far_right][0], cut_h]) # 15
53
+ height_landmarks = min(landmark_list[jaw, 1] + 20, height)
54
+ left_down_point = np.int32([landmark_list[far_left][0], height_landmarks])
55
+ right_down_point = np.int32([landmark_list[far_right][0], height_landmarks])
56
+
57
+ # print(cut_h, cut_h + 10, height_landmarks)
58
+
59
+ mask_box = [left_up_point, left_down_point, right_down_point, right_up_point]
60
+
61
+ return mask, mask_box
62
+
63
+
64
+ " Stretch the tight face mask - Countour from 2:15 but dilate, not good for extreme head poses "
65
+
66
+
67
+ def face_mask_stretch(img_shape, landmark_list, dtype="uint8", kernel_size=10):
68
+ height, width = img_shape[:2]
69
+ mask = np.ones((height, width, 1), dtype=dtype)
70
+ combined_landmarks = np.concatenate((landmark_list[2:15], [landmark_list[33]]))
71
+
72
+ # Draw the combined contour on the mask
73
+ cv2.drawContours(
74
+ mask, [np.int32(combined_landmarks)], -1, color=(0), thickness=cv2.FILLED
75
+ )
76
+
77
+ # cv2.drawContours(mask, np.int32([landmark_list[2:15]]), -1, color=(0), thickness=cv2.FILLED)
78
+ inverted_mask = 1 - mask
79
+
80
+ kernel = np.ones((kernel_size, kernel_size), np.uint8)
81
+ mask = cv2.dilate(inverted_mask, kernel, iterations=1)
82
+ mask = np.expand_dims(
83
+ mask, axis=-1
84
+ ) # Add a singleton dimension to match the number of channels
85
+ mask = 1 - mask
86
+
87
+ return mask
88
+
89
+
90
+ " Small box around mouth - Use far left, far right points for extreme head poses, cut between nose and upper mouth point"
91
+
92
+
93
+ def face_mask_box_pose(img_shape, landmark_list, dtype="uint8"):
94
+ """
95
+ When the head pose is different than frontal then the normal cropping with landmarks does not work correctly.
96
+ Crop using as height the middle nose point
97
+ Take the left/right corners using the far_left and far_right landmarks
98
+ TODO: Maybe it is better to add some more pixels to have a bigger mask, especially on large head poses
99
+ """
100
+
101
+ height, width = img_shape[:2]
102
+
103
+ nose = 33
104
+ upper_lip = 51
105
+ jaw = 8
106
+
107
+ nose_point_h = landmark_list[nose, 1]
108
+ upper_lip_point = landmark_list[upper_lip, 1]
109
+ cut_h = (upper_lip_point - nose_point_h) / 2 + nose_point_h
110
+
111
+ # cut_h = landmark_list[nose][1]
112
+
113
+ mask = np.ones((height, width, 1), dtype=dtype)
114
+
115
+ far_left = int(np.argmin(landmark_list[:, 0]))
116
+ far_right = int(np.argmax(landmark_list[:, 0]))
117
+
118
+ left_up_point = np.int32([landmark_list[far_left][0], cut_h]) # 2
119
+ right_up_point = np.int32([landmark_list[far_right][0], cut_h]) # 15
120
+
121
+ height_landmarks = min(landmark_list[jaw, 1] + 20, height)
122
+ left_down_point = np.int32([landmark_list[far_left][0], height_landmarks])
123
+ right_down_point = np.int32([landmark_list[far_right][0], height_landmarks])
124
+
125
+ cv2.drawContours(
126
+ mask,
127
+ np.int32(
128
+ [
129
+ [
130
+ left_up_point,
131
+ left_down_point,
132
+ right_up_point,
133
+ right_down_point,
134
+ left_up_point,
135
+ right_up_point,
136
+ left_down_point,
137
+ right_down_point,
138
+ ]
139
+ ]
140
+ ),
141
+ -1,
142
+ color=(0),
143
+ thickness=cv2.FILLED,
144
+ )
145
+
146
+ return mask
147
+
148
+
149
+ " Small box around mouth - Use far left, far right points for extreme head poses, cut from nose"
150
+
151
+
152
+ def face_mask_box_pose_nose(
153
+ img_shape,
154
+ landmark_list,
155
+ dtype="uint8",
156
+ get_box=False,
157
+ pixels_above_nose=None,
158
+ pixels_under_jaw=None,
159
+ ):
160
+ height, width = img_shape[:2]
161
+
162
+ nose = 33
163
+ jaw = 8
164
+
165
+ cut_h = landmark_list[nose][1]
166
+ if pixels_above_nose is not None:
167
+ # this is only for inference to take a bigger mask and blend it back to the original frame
168
+ cut_h = cut_h - pixels_above_nose
169
+
170
+ mask = np.ones((height, width, 1), dtype=dtype)
171
+
172
+ far_left = int(np.argmin(landmark_list[:, 0]))
173
+ far_right = int(np.argmax(landmark_list[:, 0]))
174
+
175
+ left_up_point = np.int32([landmark_list[far_left][0], cut_h]) # 2
176
+ right_up_point = np.int32([landmark_list[far_right][0], cut_h]) # 15
177
+
178
+ height_landmarks = min(landmark_list[jaw, 1] + 20, height)
179
+ if pixels_under_jaw is not None:
180
+ height_landmarks = min(landmark_list[jaw, 1] + pixels_under_jaw, height)
181
+ left_down_point = np.int32([landmark_list[far_left][0], height_landmarks])
182
+ right_down_point = np.int32([landmark_list[far_right][0], height_landmarks])
183
+
184
+ cv2.drawContours(
185
+ mask,
186
+ np.int32(
187
+ [
188
+ [
189
+ left_up_point,
190
+ left_down_point,
191
+ right_up_point,
192
+ right_down_point,
193
+ left_up_point,
194
+ right_up_point,
195
+ left_down_point,
196
+ right_down_point,
197
+ ]
198
+ ]
199
+ ),
200
+ -1,
201
+ color=(0),
202
+ thickness=cv2.FILLED,
203
+ )
204
+
205
+ if get_box:
206
+ mask_box = [left_up_point, left_down_point, right_down_point, right_up_point]
207
+ return mask, mask_box
208
+ else:
209
+ return mask
210
+
211
+
212
+ def face_mask_box_pose_big(
213
+ img_shape, landmark_list, dtype="uint8", cut_h=None, far_left=None, far_right=None
214
+ ):
215
+ height, width = img_shape[:2]
216
+ mask = np.ones((height, width, 1), dtype=dtype)
217
+ nose = 33
218
+ nose_point_h = landmark_list[nose, 1]
219
+ if cut_h is None:
220
+ cut_h = nose_point_h
221
+
222
+ if far_right is None and far_left is None:
223
+ far_left = int(np.argmin(landmark_list[:, 0]))
224
+ far_right = int(np.argmax(landmark_list[:, 0]))
225
+
226
+ left_up_point = np.int32([landmark_list[far_left][0], cut_h])
227
+ left_down_point = np.int32([landmark_list[far_left][0], height])
228
+
229
+ right_up_point = np.int32([landmark_list[far_right][0], cut_h])
230
+ right_down_point = np.int32([landmark_list[far_right][0], height])
231
+ else:
232
+ left_up_point = np.int32([far_left, cut_h])
233
+ left_down_point = np.int32([far_left, height])
234
+
235
+ right_up_point = np.int32([far_right, cut_h])
236
+ right_down_point = np.int32([far_right, height])
237
+
238
+ cv2.drawContours(
239
+ mask,
240
+ np.int32(
241
+ [
242
+ [
243
+ left_up_point,
244
+ left_down_point,
245
+ right_up_point,
246
+ right_down_point,
247
+ left_up_point,
248
+ right_up_point,
249
+ left_down_point,
250
+ right_down_point,
251
+ ]
252
+ ]
253
+ ),
254
+ -1,
255
+ color=(0),
256
+ thickness=cv2.FILLED,
257
+ )
258
+
259
+ return mask
260
+
261
+
262
+ def face_mask_box_pose_big_cover_nose(img_shape, landmark_list, dtype="uint8"):
263
+ height, width = img_shape[:2]
264
+
265
+ middle_nose_point = 29
266
+
267
+ cut_h = landmark_list[middle_nose_point, 1]
268
+
269
+ mask = np.ones((height, width, 1), dtype=dtype)
270
+
271
+ far_left = int(np.argmin(landmark_list[:, 0]))
272
+ far_right = int(np.argmax(landmark_list[:, 0]))
273
+
274
+ left_up_point = np.int32([landmark_list[far_left][0], cut_h])
275
+ left_down_point = np.int32([landmark_list[far_left][0], height])
276
+
277
+ right_up_point = np.int32([landmark_list[far_right][0], cut_h])
278
+ right_down_point = np.int32([landmark_list[far_right][0], height])
279
+
280
+ cv2.drawContours(
281
+ mask,
282
+ np.int32(
283
+ [
284
+ [
285
+ left_up_point,
286
+ left_down_point,
287
+ right_up_point,
288
+ right_down_point,
289
+ left_up_point,
290
+ right_up_point,
291
+ left_down_point,
292
+ right_down_point,
293
+ ]
294
+ ]
295
+ ),
296
+ -1,
297
+ color=(0),
298
+ thickness=cv2.FILLED,
299
+ )
300
+
301
+ return mask
302
+
303
+
304
+ def face_mask_square(img_shape, landmark_list, dtype="uint8"):
305
+ height, width = img_shape[:2]
306
+
307
+ mask = np.ones((height, width, 1), dtype=dtype)
308
+
309
+ far_left = np.min(landmark_list[:, 0])
310
+ far_right = np.max(landmark_list[:, 0])  # x-coordinate, to match far_left above
311
+ print("far_left {}, far_right {}".format(far_left, far_right))
312
+
313
+ left_p = 2
314
+ right_p = 14
315
+
316
+ print(
317
+ "left_p {}, right_p {}".format(
318
+ landmark_list[left_p][0], landmark_list[right_p][0]
319
+ )
320
+ )
321
+
322
+ cv2.drawContours(
323
+ mask,
324
+ np.int32(
325
+ [
326
+ [
327
+ landmark_list[left_p],
328
+ [landmark_list[left_p][0], height],
329
+ landmark_list[right_p],
330
+ [landmark_list[right_p][0], height],
331
+ landmark_list[left_p],
332
+ landmark_list[right_p],
333
+ [landmark_list[left_p][0], height],
334
+ [landmark_list[right_p][0], height],
335
+ ]
336
+ ]
337
+ ),
338
+ -1,
339
+ color=(0),
340
+ thickness=cv2.FILLED,
341
+ )
342
+
343
+ return mask
344
+
345
+
346
+ " Used for half face "
347
+
348
+
349
+ def bbox2mask(img_shape, bbox, dtype="uint8"):
350
+ """Generate mask in ndarray from bbox.
351
+
352
+ The returned mask has the shape of (h, w, 1). '0' indicates the
+ hole (the bbox region) and '1' indicates the valid regions.
354
+
355
+ We prefer to use `uint8` as the data type of masks, which may be different
356
+ from other codes in the community.
357
+
358
+ Args:
359
+ img_shape (tuple[int]): The size of the image.
360
+ bbox (tuple[int]): Configuration tuple, (top, left, height, width)
361
+ dtype (str): Indicate the data type of returned masks. Default: 'uint8'
362
+
363
+ Return:
364
+ numpy.ndarray: Mask in the shape of (h, w, 1).
365
+ """
366
+
367
+ height, width = img_shape[:2]
368
+
369
+ mask = np.ones((height, width, 1), dtype=dtype)
370
+ mask[bbox[0] : bbox[0] + bbox[2], bbox[1] : bbox[1] + bbox[3], :] = 0.0
371
+
372
+ return mask
373
+
374
+
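Usage sketch, consistent with the docstring above: the bbox region is zeroed out of an all-ones mask.

    from sgm.data.mask import bbox2mask

    mask = bbox2mask((256, 256), bbox=(100, 50, 64, 64))  # (top, left, height, width)
    print(mask.shape, int(mask[100, 50, 0]), int(mask[99, 50, 0]))  # (256, 256, 1) 0 1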
375
+ def face_mask_cheeks(img_shape, landmark_list, dtype="uint8"):
376
+ height, width = img_shape[:2]
377
+ mask = np.ones((height, width, 1), dtype=dtype)
378
+
379
+ middle_nose_point = 29
380
+ nose = 33
381
+ cut_h = int(landmark_list[middle_nose_point, 1])
382
+
383
+ far_left = int(np.argmin(landmark_list[:, 0]))
384
+ far_right = int(np.argmax(landmark_list[:, 0]))
385
+
386
+ left_up_point = np.int32([landmark_list[far_left][0], cut_h])
387
+ left_down_point = np.int32([landmark_list[far_left][0], height])
388
+
389
+ right_up_point = np.int32([landmark_list[far_right][0], cut_h])
390
+ right_down_point = np.int32([landmark_list[far_right][0], height])
391
+
392
+ cv2.drawContours(
393
+ mask,
394
+ np.int32(
395
+ [
396
+ [
397
+ left_up_point,
398
+ left_down_point,
399
+ right_up_point,
400
+ right_down_point,
401
+ left_up_point,
402
+ right_up_point,
403
+ left_down_point,
404
+ right_down_point,
405
+ ]
406
+ ]
407
+ ),
408
+ -1,
409
+ color=(0),
410
+ thickness=cv2.FILLED,
411
+ )
412
+
413
+ # Calculate the bounding box coordinates for the nose
414
+ nose_jaw_dist = (
415
+ abs(landmark_list[2][0] - landmark_list[middle_nose_point][0]) * 0.10
416
+ ) # 1, 15
417
+ # nose_right_dist = (landmark_list[middle_nose_point][0] - landmark_list[1][0]) * 0.10
418
+ # nose_left_dist = (landmark_list[15][0] - landmark_list[middle_nose_point][0]) * 0.10
419
+ #
420
+
421
+ nose_min_x = int(landmark_list[31][0] - nose_jaw_dist)
422
+ nose_max_x = int(landmark_list[35][0] + nose_jaw_dist)
423
+ # nose_min_x = int(landmark_list[31][0] - nose_right_dist)
424
+ # nose_max_x = int(landmark_list[35][0] + nose_left_dist)
425
+ nose_min_y = cut_h
426
+ nose_max_y = int(landmark_list[nose, 1])
427
+
428
+ # Clear the nose area from the mask using a rectangle
429
+ mask_nose = np.ones((height, width, 1), dtype=dtype)
430
+ cv2.rectangle(
431
+ mask_nose,
432
+ (nose_min_x, nose_min_y),
433
+ (nose_max_x, nose_max_y),
434
+ color=(0),
435
+ thickness=cv2.FILLED,
436
+ )
437
+
438
+ mask_nose = 1 - mask_nose
439
+ mask = mask + mask_nose
440
+
441
+ return mask
442
+
443
+
444
+ def face_mask_cheeks_batch(
445
+ img_shape, landmark_list, dtype="uint8", box_expand=0.0, show_nose=True
446
+ ):
447
+ height, width = img_shape[:2]
448
+
449
+ # Handle both single and multiple landmarks
450
+ if len(landmark_list.shape) == 2:
451
+ landmark_list = landmark_list[None, ...] # Add batch dimension
452
+ num_frames = landmark_list.shape[0]
453
+
454
+ # Initialize masks for all frames
455
+ masks = np.ones((num_frames, height, width), dtype=dtype)
456
+
457
+ for i in range(num_frames):
458
+ landmarks = landmark_list[i]
459
+ middle_nose_point = 29
460
+ nose = 33
461
+ cut_h = int(landmarks[middle_nose_point, 1])
462
+
463
+ # Add height expansion
464
+ if box_expand > 0:
465
+ cut_h = max(0, cut_h - int(box_expand * height))
466
+
467
+ far_left = int(np.argmin(landmarks[:, 0]))
468
+ far_right = int(np.argmax(landmarks[:, 0]))
469
+
470
+ left_up_point = np.int32([landmarks[far_left][0], cut_h])
471
+ left_down_point = np.int32([landmarks[far_left][0], height])
472
+
473
+ right_up_point = np.int32([landmarks[far_right][0], cut_h])
474
+ right_down_point = np.int32([landmarks[far_right][0], height])
475
+
476
+ cv2.drawContours(
477
+ masks[i],
478
+ np.int32(
479
+ [
480
+ [
481
+ left_up_point,
482
+ left_down_point,
483
+ right_up_point,
484
+ right_down_point,
485
+ left_up_point,
486
+ right_up_point,
487
+ left_down_point,
488
+ right_down_point,
489
+ ]
490
+ ]
491
+ ),
492
+ -1,
493
+ color=(0),
494
+ thickness=cv2.FILLED,
495
+ )
496
+
497
+ if show_nose:
498
+ # Calculate the bounding box coordinates for the nose
499
+ nose_jaw_dist = (
500
+ abs(landmarks[2][0] - landmarks[middle_nose_point][0]) * 0.10
501
+ ) # 1, 15
502
+
503
+ nose_min_x = int(landmarks[31][0] - nose_jaw_dist)
504
+ nose_max_x = int(landmarks[35][0] + nose_jaw_dist)
505
+ nose_min_y = cut_h
506
+ nose_max_y = int(landmarks[nose, 1])
507
+
508
+ # Clear the nose area from the mask using a rectangle
509
+ mask_nose = np.ones((height, width), dtype=dtype)
510
+ cv2.rectangle(
511
+ mask_nose,
512
+ (nose_min_x, nose_min_y),
513
+ (nose_max_x, nose_max_y),
514
+ color=(0),
515
+ thickness=cv2.FILLED,
516
+ )
517
+
518
+ mask_nose = 1 - mask_nose
519
+ masks[i] = masks[i] + mask_nose
520
+
521
+ # If input was single frame, return single mask
522
+ if landmark_list.shape[0] == 1:
523
+ return masks[0]
524
+
525
+ return 1 - torch.from_numpy(masks)
sgm/data/video_datamodule_latent.py ADDED
@@ -0,0 +1,138 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ from pytorch_lightning import LightningDataModule
4
+ from torch.utils.data import DataLoader
5
+ from omegaconf import DictConfig
6
+
7
+ import sys
8
+ import pyrootutils
9
+
10
+ root = pyrootutils.setup_root(__file__, pythonpath=True)
11
+ sys.path.append(root)
12
+ from sgm.data.video_dataset_latent import VideoDataset
13
+
14
+
15
+ class VideoDataModule(LightningDataModule):
16
+ """
17
+ A DataModule implements 5 key methods:
18
+
19
+ def prepare_data(self):
20
+ # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
21
+ # download data, pre-process, split, save to disk, etc...
22
+ def setup(self, stage):
23
+ # things to do on every process in DDP
24
+ # load data, set variables, etc...
25
+ def train_dataloader(self):
26
+ # return train dataloader
27
+ def val_dataloader(self):
28
+ # return validation dataloader
29
+ def test_dataloader(self):
30
+ # return test dataloader
31
+ def teardown(self):
32
+ # called on every process in DDP
33
+ # clean up after fit or test
34
+
35
+ This allows you to share a full dataset without explaining how to download,
36
+ split, transform and process the data.
37
+
38
+ Read the docs:
39
+ https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ train: DictConfig,
45
+ validation: Optional[DictConfig] = None,
46
+ test: Optional[DictConfig] = None,
47
+ skip_val_loader: bool = False,
48
+ ):
49
+ super().__init__()
50
+
51
+ # this line allows to access init params with 'self.hparams' attribute
52
+ # also ensures init params will be stored in ckpt
53
+ self.train_config = train
54
+ assert "datapipeline" in self.train_config and "loader" in self.train_config, (
55
+ "train config requires the fields `datapipeline` and `loader`"
56
+ )
57
+
58
+ self.val_config = validation
59
+ if not skip_val_loader:
60
+ if self.val_config is not None:
61
+ assert (
62
+ "datapipeline" in self.val_config and "loader" in self.val_config
63
+ ), "validation config requires the fields `datapipeline` and `loader`"
64
+ else:
65
+ print(
66
+ "Warning: No Validation datapipeline defined, using that one from training"
67
+ )
68
+ self.val_config = train
69
+
70
+ self.test_config = test
71
+ if self.test_config is not None:
72
+ assert (
73
+ "datapipeline" in self.test_config and "loader" in self.test_config
74
+ ), "test config requires the fields `datapipeline` and `loader`"
75
+
76
+ def setup(self, stage: Optional[str] = None):
77
+ """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
78
+
79
+ This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
80
+ careful not to execute things like random split twice!
81
+ """
82
+ print("Preparing datasets")
83
+
84
+ self.train_datapipeline = VideoDataset(**self.train_config.datapipeline)
85
+ if self.val_config:
86
+ self.val_datapipeline = VideoDataset(**self.val_config.datapipeline)
87
+ if self.test_config:
88
+ self.test_datapipeline = VideoDataset(**self.test_config.datapipeline)
89
+
90
+ def train_dataloader(self):
91
+ return DataLoader(self.train_datapipeline, **self.train_config.loader)
92
+
93
+ def val_dataloader(self):
94
+ if self.val_datapipeline:
95
+ return DataLoader(self.val_datapipeline, **self.val_config.loader)
96
+ else:
97
+ return None
98
+
99
+ def test_dataloader(self):
100
+ if self.test_datapipeline:
101
+ return DataLoader(self.test_datapipeline, **self.test_config.loader)
102
+ else:
103
+ return None
104
+
105
+ def teardown(self, stage: Optional[str] = None):
106
+ """Clean up after fit or test."""
107
+ pass
108
+
109
+ def state_dict(self):
110
+ """Extra things to save to checkpoint."""
111
+ return {}
112
+
113
+ def load_state_dict(self, state_dict: Dict[str, Any]):
114
+ """Things to do when loading checkpoint."""
115
+ pass
116
+
117
+
118
+ if __name__ == "__main__":
119
+ import hydra
120
+ import omegaconf
121
+ import pyrootutils
122
+ import cv2
123
+
124
+ root = pyrootutils.setup_root(__file__, pythonpath=True)
125
+ cfg = omegaconf.OmegaConf.load(
126
+ root / "configs" / "datamodule" / "image_datamodule.yaml"
127
+ )
128
+ # cfg.data_dir = str(root / "data")
129
+ data = hydra.utils.instantiate(cfg)
130
+ data.prepare_data()
131
+ data.setup()
132
+ print(data.data_train.__getitem__(0)[0].shape)
133
+ batch = next(iter(data.train_dataloader()))
134
+ identity, target = batch
135
+ image_identity = (identity[0].permute(1, 2, 0).numpy() + 1) / 2 * 255
136
+ image_other = (target[0].permute(1, 2, 0).numpy() + 1) / 2 * 255
137
+ cv2.imwrite("image_identity.png", image_identity[:, :, ::-1])
138
+ cv2.imwrite("image_other.png", image_other[:, :, ::-1])
sgm/data/video_dataset_latent.py ADDED
@@ -0,0 +1,780 @@
1
+ import random
2
+ import numpy as np
3
+ from functools import partial
4
+ from torch.utils.data import Dataset, WeightedRandomSampler
5
+ import torch.nn.functional as F
6
+ import torch
7
+ import math
8
+ import decord
9
+ from einops import rearrange
10
+ from more_itertools import sliding_window
11
+ from omegaconf import ListConfig
12
+ import torchaudio
13
+ import soundfile as sf
14
+ from torchvision.transforms import RandomHorizontalFlip
15
+ from audiomentations import Compose, AddGaussianNoise, PitchShift
16
+ from safetensors.torch import load_file
17
+ from tqdm import tqdm
18
+ import cv2
19
+ from sgm.data.data_utils import (
20
+ create_masks_from_landmarks_full_size,
21
+ create_face_mask_from_landmarks,
22
+ create_masks_from_landmarks_box,
23
+ create_masks_from_landmarks_mouth,
24
+ )
25
+ from sgm.data.mask import face_mask_cheeks_batch
26
+
27
+ torchaudio.set_audio_backend("sox_io")
28
+ decord.bridge.set_bridge("torch")
29
+
30
+
31
+ def exists(x):
32
+ return x is not None
33
+
34
+
35
+ def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
36
+ len_file = audio.shape[-1]
37
+
38
+ if max_len_sec or max_len_raw:
39
+ max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
40
+ if len_file < int(max_len):
41
+ extened_wav = torch.nn.functional.pad(
42
+ audio, (0, int(max_len) - len_file), "constant"
43
+ )
44
+ else:
45
+ extened_wav = audio[:, : int(max_len)]
46
+ else:
47
+ extened_wav = audio
48
+
49
+ return extened_wav
50
+
51
+
52
+ # Similar to regular video dataset but trades flexibility for speed
53
+ class VideoDataset(Dataset):
54
+ def __init__(
55
+ self,
56
+ filelist,
57
+ resize_size=None,
58
+ audio_folder="Audio",
59
+ video_folder="CroppedVideos",
60
+ emotions_folder="emotions",
61
+ landmarks_folder=None,
62
+ audio_emb_folder=None,
63
+ video_extension=".avi",
64
+ audio_extension=".wav",
65
+ audio_rate=16000,
66
+ latent_folder=None,
67
+ audio_in_video=False,
68
+ fps=25,
69
+ num_frames=5,
70
+ need_cond=True,
71
+ step=1,
72
+ mode="prediction",
73
+ scale_audio=False,
74
+ augment=False,
75
+ augment_audio=False,
76
+ use_latent=False,
77
+ latent_type="stable",
78
+ latent_scale=1, # For backwards compatibility
79
+ from_audio_embedding=False,
80
+ load_all_possible_indexes=False,
81
+ audio_emb_type="wavlm",
82
+ cond_noise=[-3.0, 0.5],
83
+ motion_id=255.0,
84
+ data_mean=None,
85
+ data_std=None,
86
+ use_latent_condition=False,
87
+ skip_frames=0,
88
+ get_separate_id=False,
89
+ virtual_increase=1,
90
+ filter_by_length=False,
91
+ select_randomly=False,
92
+ balance_datasets=True,
93
+ use_emotions=False,
94
+ get_original_frames=False,
95
+ add_extra_audio_emb=False,
96
+ expand_box=0.0,
97
+ nose_index=28,
98
+ what_mask="full",
99
+ get_masks=False,
100
+ ):
101
+ self.audio_folder = audio_folder
102
+ self.from_audio_embedding = from_audio_embedding
103
+ self.audio_emb_type = audio_emb_type
104
+ self.cond_noise = cond_noise
105
+ self.latent_condition = use_latent_condition
106
+ precomputed_latent = latent_type
107
+ self.audio_emb_folder = (
108
+ audio_emb_folder if audio_emb_folder is not None else audio_folder
109
+ )
110
+ self.skip_frames = skip_frames
111
+ self.get_separate_id = get_separate_id
112
+ self.fps = fps
113
+ self.virtual_increase = virtual_increase
114
+ self.select_randomly = select_randomly
115
+ self.use_emotions = use_emotions
116
+ self.emotions_folder = emotions_folder
117
+ self.get_original_frames = get_original_frames
118
+ self.add_extra_audio_emb = add_extra_audio_emb
119
+ self.expand_box = expand_box
120
+ self.nose_index = nose_index
121
+ self.landmarks_folder = landmarks_folder
122
+ self.what_mask = what_mask
123
+ self.get_masks = get_masks
124
+
125
+ assert not (exists(data_mean) ^ exists(data_std)), (
126
+ "Both data_mean and data_std should be provided"
127
+ )
128
+
129
+ if data_mean is not None:
130
+ data_mean = rearrange(torch.as_tensor(data_mean), "c -> c () () ()")
131
+ data_std = rearrange(torch.as_tensor(data_std), "c -> c () () ()")
132
+ self.data_mean = data_mean
133
+ self.data_std = data_std
134
+ self.motion_id = motion_id
135
+ self.latent_folder = (
136
+ latent_folder if latent_folder is not None else video_folder
137
+ )
138
+ self.audio_in_video = audio_in_video
139
+
140
+ self.filelist = []
141
+ self.audio_filelist = []
142
+ self.landmark_filelist = [] if get_masks else None
143
+ with open(filelist, "r") as files:
144
+ for f in files.readlines():
145
+ f = f.rstrip()
146
+
147
+ audio_path = f.replace(video_folder, audio_folder).replace(
148
+ video_extension, audio_extension
149
+ )
150
+
151
+ self.filelist += [f]
152
+ self.audio_filelist += [audio_path]
153
+ if self.get_masks:
154
+ landmark_path = f.replace(video_folder, landmarks_folder).replace(
155
+ video_extension, ".npy"
156
+ )
157
+ self.landmark_filelist += [landmark_path]
158
+
159
+ self.resize_size = resize_size
160
+ if use_latent and not precomputed_latent:
161
+ self.resize_size *= 4 if latent_type in ["stable", "ldm"] else 8
162
+ self.scale_audio = scale_audio
163
+ self.step = step
164
+ self.use_latent = use_latent
165
+ self.precomputed_latent = precomputed_latent
166
+ self.latent_type = latent_type
167
+ self.latent_scale = latent_scale
168
+ self.video_ext = video_extension
169
+ self.video_folder = video_folder
170
+
171
+ self.augment = augment
172
+ self.maybe_augment = RandomHorizontalFlip(p=0.5) if augment else lambda x: x
173
+ self.maybe_augment_audio = (
174
+ Compose(
175
+ [
176
+ AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.002, p=0.25),
177
+ # TimeStretch(min_rate=0.8, max_rate=1.25, p=0.3),
178
+ PitchShift(min_semitones=-1, max_semitones=1, p=0.25),
179
+ # Shift(min_fraction=-0.5, max_fraction=0.5, p=0.333),
180
+ ]
181
+ )
182
+ if augment_audio
183
+ else lambda x, sample_rate: x
184
+ )
185
+ self.maybe_augment_audio = partial(
186
+ self.maybe_augment_audio, sample_rate=audio_rate
187
+ )
188
+
189
+ self.mode = mode
190
+ if mode == "interpolation":
191
+ need_cond = False # Interpolation does not need condition as first and last frame becomes the condition
192
+ self.need_cond = need_cond # If need cond will extract one more frame than the number of frames
193
+ if get_separate_id:
194
+ self.need_cond = True
195
+ # It is used for the conditional model when the condition is not on the temporal dimension
196
+ num_frames = num_frames if not self.need_cond else num_frames + 1
197
+
198
+ vr = decord.VideoReader(self.filelist[0])
199
+ self.video_rate = math.ceil(vr.get_avg_fps())
200
+ print(f"Video rate: {self.video_rate}")
201
+ self.audio_rate = audio_rate
202
+ a2v_ratio = fps / float(self.audio_rate)
203
+ self.samples_per_frame = math.ceil(1 / a2v_ratio)
204
+
205
+ if get_separate_id:
206
+ assert mode == "prediction", (
207
+ "Separate identity frame is only supported for prediction mode"
208
+ )
209
+ # No need for extra frame if we are getting a separate identity frame
210
+ self.need_cond = True
211
+ num_frames -= 1
212
+ self.num_frames = num_frames
213
+ self.load_all_possible_indexes = load_all_possible_indexes
214
+ if load_all_possible_indexes:
215
+ self._indexes = self._get_indexes(
216
+ self.filelist, self.audio_filelist, self.landmark_filelist
217
+ )
218
+ else:
219
+ if filter_by_length:
220
+ self._indexes = self.filter_by_length(
221
+ self.filelist, self.audio_filelist, self.landmark_filelist
222
+ )
223
+ else:
224
+ if self.get_masks:
225
+ self._indexes = list(
226
+ zip(self.filelist, self.audio_filelist, self.landmark_filelist)
227
+ )
228
+ else:
229
+ self._indexes = list(
230
+ zip(
231
+ self.filelist,
232
+ self.audio_filelist,
233
+ [None] * len(self.filelist),
234
+ )
235
+ )
236
+
237
+ self.balance_datasets = balance_datasets
238
+ if self.balance_datasets:
239
+ self.weights = self._calculate_weights()
240
+ self.sampler = WeightedRandomSampler(
241
+ self.weights, num_samples=len(self._indexes), replacement=True
242
+ )
243
+
244
+ def __len__(self):
245
+ return len(self._indexes) * self.virtual_increase
246
+
247
+ def _load_landmarks(self, filename, original_size, target_size, indexes):
248
+ landmarks = np.load(filename, allow_pickle=True)[indexes, :]
249
+ if self.what_mask == "full":
250
+ mask = create_masks_from_landmarks_full_size(
251
+ landmarks,
252
+ original_size[0],
253
+ original_size[1],
254
+ offset=self.expand_box,
255
+ nose_index=self.nose_index,
256
+ )
257
+ elif self.what_mask == "box":
258
+ mask = create_masks_from_landmarks_box(
259
+ landmarks,
260
+ (original_size[0], original_size[1]),
261
+ box_expand=self.expand_box,
262
+ nose_index=self.nose_index,
263
+ )
264
+ elif self.what_mask == "heart":
265
+ mask = face_mask_cheeks_batch(
266
+ original_size, landmarks, box_expand=0.0, show_nose=True
267
+ )
268
+ elif self.what_mask == "mouth":
269
+ mask = create_masks_from_landmarks_mouth(
270
+ landmarks,
271
+ (original_size[0], original_size[1]),
272
+ box_expand=0.01,
273
+ nose_index=self.nose_index,
274
+ )
275
+ else:
276
+ mask = create_face_mask_from_landmarks(
277
+ landmarks, original_size[0], original_size[1], mask_expand=0.05
278
+ )
279
+ # Interpolate the mask to the target size
280
+ mask = F.interpolate(
281
+ mask.unsqueeze(1).float(), size=target_size, mode="nearest"
282
+ )
283
+
284
+ return mask, landmarks
285
+
286
+ def get_emotions(self, video_file, video_indexes):
287
+ emotions_path = video_file.replace(
288
+ self.video_folder, self.emotions_folder
289
+ ).replace(self.video_ext, ".pt")
290
+ emotions = torch.load(emotions_path)
291
+ return (
292
+ emotions["valence"][video_indexes],
293
+ emotions["arousal"][video_indexes],
294
+ emotions["labels"][video_indexes],
295
+ )
296
+
297
+ def get_frame_indices(self, total_video_frames, select_randomly=False, start_idx=0):
298
+ if select_randomly:
299
+ # Randomly select self.num_frames indices from the available range
300
+ available_indices = list(range(start_idx, total_video_frames))
301
+ if len(available_indices) < self.num_frames:
302
+ raise ValueError(
303
+ "Not enough frames in the video to sample with given parameters."
304
+ )
305
+ indexes = random.sample(available_indices, self.num_frames)
306
+ return sorted(indexes) # Sort to maintain temporal order
307
+ else:
308
+ # Calculate the maximum possible start index
309
+ max_start_idx = total_video_frames - (
310
+ (self.num_frames - 1) * (self.skip_frames + 1) + 1
311
+ )
312
+
313
+ # Generate a random start index
314
+ if max_start_idx > 0:
315
+ start_idx = np.random.randint(start_idx, max_start_idx)
316
+ else:
317
+ raise ValueError(
318
+ "Not enough frames in the video to sample with given parameters."
319
+ )
320
+
321
+ # Generate the indices
322
+ indexes = [
323
+ start_idx + i * (self.skip_frames + 1) for i in range(self.num_frames)
324
+ ]
325
+
326
+ return indexes
327
+
328
+ def _load_audio(self, filename, max_len_sec, start=None, indexes=None):
329
+ audio, sr = sf.read(
330
+ filename,
331
+ start=math.ceil(start * self.audio_rate),
332
+ frames=math.ceil(self.audio_rate * max_len_sec),
333
+ always_2d=True,
334
+ ) # e.g (16000, 1)
335
+ audio = audio.T # (1, 16000)
336
+ assert sr == self.audio_rate, (
337
+ f"Audio rate is {sr} but should be {self.audio_rate}"
338
+ )
339
+ audio = audio.mean(0, keepdims=True)
340
+ audio = self.maybe_augment_audio(audio)
341
+ audio = torch.from_numpy(audio).float()
342
+ # audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=self.audio_rate)
343
+ audio = trim_pad_audio(audio, self.audio_rate, max_len_sec=max_len_sec)
344
+ return audio[0]
345
+
346
+ def ensure_shape(self, tensors):
347
+ target_length = self.samples_per_frame
348
+ processed_tensors = []
349
+ for tensor in tensors:
350
+ current_length = tensor.shape[1]
351
+ diff = current_length - target_length
352
+ assert abs(diff) <= 5, (
353
+ f"Expected shape {target_length}, but got {current_length}"
354
+ )
355
+ if diff < 0:
356
+ # Calculate how much padding is needed
357
+ padding_needed = target_length - current_length
358
+ # Pad the tensor
359
+ padded_tensor = F.pad(tensor, (0, padding_needed))
360
+ processed_tensors.append(padded_tensor)
361
+ elif diff > 0:
362
+ # Trim the tensor
363
+ trimmed_tensor = tensor[:, :target_length]
364
+ processed_tensors.append(trimmed_tensor)
365
+ else:
366
+ # If it's already the correct size
367
+ processed_tensors.append(tensor)
368
+ return torch.cat(processed_tensors)
369
+
370
+ def normalize_latents(self, latents):
371
+ if self.data_mean is not None:
372
+ # Normalize latents to 0 mean and 0.5 std
373
+ latents = ((latents - self.data_mean) / self.data_std) * 0.5
374
+ return latents
375
+
376
+ def convert_indexes(self, indexes_25fps, fps_from=25, fps_to=60):
377
+ ratio = fps_to / fps_from
378
+ indexes_60fps = [int(index * ratio) for index in indexes_25fps]
379
+ return indexes_60fps
380
+
381
+ def _get_frames_and_audio(self, idx):
382
+ if self.load_all_possible_indexes:
383
+ indexes, video_file, audio_file, land_file = self._indexes[idx]
384
+ if self.audio_in_video:
385
+ vr = decord.AVReader(video_file, sample_rate=self.audio_rate)
386
+ else:
387
+ vr = decord.VideoReader(video_file)
388
+ len_video = len(vr)
389
+ if "AA_processed" in video_file or "1000actors_nsv" in video_file:
390
+ len_video *= 25 / 60
391
+ len_video = int(len_video)
392
+ else:
393
+ video_file, audio_file, land_file = self._indexes[idx]
394
+ if self.audio_in_video:
395
+ vr = decord.AVReader(video_file, sample_rate=self.audio_rate)
396
+ else:
397
+ vr = decord.VideoReader(video_file)
398
+ len_video = len(vr)
399
+ if "AA_processed" in video_file or "1000actors_nsv" in video_file:
400
+ len_video *= 25 / 60
401
+ len_video = int(len_video)
402
+
403
+ indexes = self.get_frame_indices(
404
+ len_video,
405
+ select_randomly=self.select_randomly,
406
+ start_idx=120 if "1000actors_nsv" in video_file else 0,
407
+ )
408
+
409
+ if self.get_separate_id:
410
+ id_idx = np.random.randint(0, len_video)
411
+ indexes.insert(0, id_idx)
412
+
413
+ if "AA_processed" in video_file or "1000actors_nsv" in video_file:
414
+ video_indexes = self.convert_indexes(indexes, fps_from=25, fps_to=60)
415
+ audio_file = audio_file.replace("_output_output", "")
416
+ if self.audio_emb_type == "wav2vec2" and "AA_processed" in video_file:
417
+ audio_path_extra = ".safetensors"
418
+ else:
419
+ audio_path_extra = f"_{self.audio_emb_type}_emb.safetensors"
420
+
421
+ video_path_extra = f"_{self.latent_type}_512_latent.safetensors"
422
+ audio_path_extra_extra = (
423
+ ".pt" if "AA_processed" in video_file else "_beats_emb.pt"
424
+ )
425
+
426
+ else:
427
+ video_indexes = indexes
428
+ audio_path_extra = f"_{self.audio_emb_type}_emb.safetensors"
429
+ video_path_extra = f"_{self.latent_type}_512_latent.safetensors"
430
+ audio_path_extra_extra = "_beats_emb.pt"
431
+
432
+ emotions = None
433
+ if self.use_emotions:
434
+ emotions = self.get_emotions(video_file, video_indexes)
435
+ if self.get_separate_id:
436
+ emotions = (emotions[0][1:], emotions[1][1:], emotions[2][1:])
437
+
438
+ raw_audio = None
439
+ if self.audio_in_video:
440
+ raw_audio, frames_video = vr.get_batch(video_indexes)
441
+ raw_audio = rearrange(self.ensure_shape(raw_audio), "f s -> (f s)")
442
+
443
+ if self.use_latent and self.precomputed_latent:
444
+ latent_file = video_file.replace(self.video_ext, video_path_extra).replace(
445
+ self.video_folder, self.latent_folder
446
+ )
447
+ frames = load_file(latent_file)["latents"][video_indexes, :, :, :]
448
+
449
+ if frames.shape[-1] != 64:
450
+ print(f"Frames shape: {frames.shape}, video file: {video_file}")
451
+
452
+ frames = rearrange(frames, "t c h w -> c t h w") * self.latent_scale
453
+ frames = self.normalize_latents(frames)
454
+ else:
455
+ if self.audio_in_video:
456
+ frames = frames_video.permute(3, 0, 1, 2).float()
457
+ else:
458
+ frames = vr.get_batch(video_indexes).permute(3, 0, 1, 2).float()
459
+
460
+ if raw_audio is None:
461
+ # Audio is not in video
462
+ raw_audio = self._load_audio(
463
+ audio_file,
464
+ max_len_sec=frames.shape[1] / self.fps,
465
+ start=indexes[0] / self.fps,
466
+ # indexes=indexes,
467
+ )
468
+ if not self.from_audio_embedding:
469
+ audio = raw_audio
470
+ audio_frames = rearrange(audio, "(f s) -> f s", s=self.samples_per_frame)
471
+ else:
472
+ audio = load_file(
473
+ audio_file.replace(self.audio_folder, self.audio_emb_folder).split(".")[
474
+ 0
475
+ ]
476
+ + audio_path_extra
477
+ )["audio"]
478
+ audio_frames = audio[indexes, :]
479
+ if self.add_extra_audio_emb:
480
+ audio_extra = torch.load(
481
+ audio_file.replace(self.audio_folder, self.audio_emb_folder).split(
482
+ "."
483
+ )[0]
484
+ + audio_path_extra_extra
485
+ )
486
+ audio_extra = audio_extra[indexes, :]
487
+ audio_frames = torch.cat([audio_frames, audio_extra], dim=-1)
488
+
489
+ audio_frames = (
490
+ audio_frames[1:] if self.need_cond else audio_frames
491
+ ) # Remove audio of first frame
492
+
493
+ if self.get_original_frames:
494
+ original_frames = vr.get_batch(video_indexes).permute(3, 0, 1, 2).float()
495
+ original_frames = self.scale_and_crop((original_frames / 255.0) * 2 - 1)
496
+ original_frames = (
497
+ original_frames[:, 1:] if self.need_cond else original_frames
498
+ )
499
+ else:
500
+ original_frames = None
501
+
502
+ if not self.use_latent or (self.use_latent and not self.precomputed_latent):
503
+ frames = self.scale_and_crop((frames / 255.0) * 2 - 1)
504
+
505
+ target = frames[:, 1:] if self.need_cond else frames
506
+ if self.mode == "prediction":
507
+ if self.use_latent:
508
+ if self.audio_in_video:
509
+ clean_cond = (
510
+ frames_video[0].unsqueeze(0).permute(3, 0, 1, 2).float()
511
+ )
512
+ else:
513
+ clean_cond = (
514
+ vr[video_indexes[0]].unsqueeze(0).permute(3, 0, 1, 2).float()
515
+ )
516
+ original_size = clean_cond.shape[-2:]
517
+ clean_cond = self.scale_and_crop((clean_cond / 255.0) * 2 - 1).squeeze(
518
+ 0
519
+ )
520
+ if self.latent_condition:
521
+ noisy_cond = frames[:, 0]
522
+ else:
523
+ noisy_cond = clean_cond
524
+ else:
525
+ clean_cond = frames[:, 0]
526
+ noisy_cond = clean_cond
527
+ elif self.mode == "interpolation":
528
+ if self.use_latent:
529
+ if self.audio_in_video:
530
+ clean_cond = frames_video[[0, -1]].permute(3, 0, 1, 2).float()
531
+ else:
532
+ clean_cond = (
533
+ vr.get_batch([video_indexes[0], video_indexes[-1]])
534
+ .permute(3, 0, 1, 2)
535
+ .float()
536
+ )
537
+ original_size = clean_cond.shape[-2:]
538
+ clean_cond = self.scale_and_crop((clean_cond / 255.0) * 2 - 1)
539
+ if self.latent_condition:
540
+ noisy_cond = torch.stack([target[:, 0], target[:, -1]], dim=1)
541
+ else:
542
+ noisy_cond = clean_cond
543
+ else:
544
+ clean_cond = torch.stack([target[:, 0], target[:, -1]], dim=1)
545
+ noisy_cond = clean_cond
546
+
547
+ # Add noise to conditional frame
548
+ if self.cond_noise and isinstance(self.cond_noise, ListConfig):
549
+ cond_noise = (
550
+ self.cond_noise[0] + self.cond_noise[1] * torch.randn((1,))
551
+ ).exp()
552
+ noisy_cond = noisy_cond + cond_noise * torch.randn_like(noisy_cond)
553
+ else:
554
+ noisy_cond = noisy_cond + self.cond_noise * torch.randn_like(noisy_cond)
555
+ cond_noise = self.cond_noise
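+ # (When cond_noise is a (mean, std) pair the augmentation strength is sampled
+ # log-normally per example; otherwise a fixed strength is used.)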
556
+
557
+ if self.get_masks:
558
+ target_size = (
559
+ (self.resize_size, self.resize_size)
560
+ if not self.use_latent
561
+ else (self.resize_size // 8, self.resize_size // 8)
562
+ )
563
+ masks, landmarks = self._load_landmarks(
564
+ land_file, original_size, target_size, video_indexes
565
+ )
566
+
567
+ landmarks = None
568
+ masks = (
569
+ masks.permute(1, 0, 2, 3)[:, 1:]
570
+ if self.need_cond
571
+ else masks.permute(1, 0, 2, 3)
572
+ )
573
+ else:
574
+ masks = None
575
+ landmarks = None
576
+
577
+ return (
578
+ original_frames,
579
+ clean_cond,
580
+ noisy_cond,
581
+ target,
582
+ audio_frames,
583
+ raw_audio,
584
+ cond_noise,
585
+ emotions,
586
+ masks,
587
+ landmarks,
588
+ )
589
+
590
+ def filter_by_length(self, video_filelist, audio_filelist):
591
+ def with_opencv(filename):
592
+ video = cv2.VideoCapture(filename)
593
+ frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
594
+
595
+ return int(frame_count)
596
+
597
+ filtered_video = []
598
+ filtered_audio = []
599
+ min_length = (self.num_frames - 1) * (self.skip_frames + 1) + 1
600
+ for vid_file, audio_file in tqdm(
601
+ zip(video_filelist, audio_filelist),
602
+ total=len(video_filelist),
603
+ desc="Filtering",
604
+ ):
605
+ # vr = decord.VideoReader(vid_file)
606
+
607
+ len_video = with_opencv(vid_file)
608
+ # Short videos
609
+ if len_video < min_length:
610
+ continue
611
+ filtered_video.append(vid_file)
612
+ filtered_audio.append(audio_file)
613
+ print(f"New number of files: {len(filtered_video)}")
614
+ return filtered_video, filtered_audio
615
+
616
+ def _get_indexes(self, video_filelist, audio_filelist):
617
+ indexes = []
618
+ self.og_shape = None
619
+ for vid_file, audio_file in zip(video_filelist, audio_filelist):
620
+ vr = decord.VideoReader(vid_file)
621
+ if self.og_shape is None:
622
+ self.og_shape = vr[0].shape[-2]
623
+ len_video = len(vr)
624
+ # Short videos
625
+ if len_video < self.num_frames:
626
+ continue
627
+ else:
628
+ possible_indexes = list(
629
+ sliding_window(range(len_video), self.num_frames)
630
+ )[:: self.step]
631
+ possible_indexes = list(
632
+ map(lambda x: (x, vid_file, audio_file), possible_indexes)
633
+ )
634
+ indexes.extend(possible_indexes)
635
+ print("Indexes", len(indexes), "\n")
636
+ return indexes
637
+
638
+ def scale_and_crop(self, video):
639
+ h, w = video.shape[-2], video.shape[-1]
640
+ # scale shorter side to resolution
641
+
642
+ if self.resize_size is not None:
643
+ scale = self.resize_size / min(h, w)
644
+ if h < w:
645
+ target_size = (self.resize_size, math.ceil(w * scale))
646
+ else:
647
+ target_size = (math.ceil(h * scale), self.resize_size)
648
+ video = F.interpolate(
649
+ video,
650
+ size=target_size,
651
+ mode="bilinear",
652
+ align_corners=False,
653
+ antialias=True,
654
+ )
655
+
656
+ # center crop
657
+ h, w = video.shape[-2], video.shape[-1]
658
+ w_start = (w - self.resize_size) // 2
659
+ h_start = (h - self.resize_size) // 2
660
+ video = video[
661
+ :,
662
+ :,
663
+ h_start : h_start + self.resize_size,
664
+ w_start : w_start + self.resize_size,
665
+ ]
666
+ return self.maybe_augment(video)
667
+
668
+ def _calculate_weights(self):
669
+ aa_processed_count = sum(
670
+ 1
671
+ for item in self._indexes
672
+ if "AA_processed" in (item[1] if len(item) == 3 else item[0])
673
+ )
674
+ nsv_processed_count = sum(
675
+ 1
676
+ for item in self._indexes
677
+ if "1000actors_nsv" in (item[1] if len(item) == 3 else item[0])
678
+ )
679
+ other_count = len(self._indexes) - aa_processed_count - nsv_processed_count
680
+
681
+ aa_processed_weight = 1 / aa_processed_count if aa_processed_count > 0 else 0
682
+ nsv_processed_weight = 1 / nsv_processed_count if nsv_processed_count > 0 else 0
683
+ other_weight = 1 / other_count if other_count > 0 else 0
684
+
685
+ print(
686
+ f"AA processed count: {aa_processed_count}, NSV processed count: {nsv_processed_count}, other count: {other_count}"
687
+ )
688
+ print(f"AA processed weight: {aa_processed_weight}")
689
+ print(f"NSV processed weight: {nsv_processed_weight}")
690
+ print(f"Other weight: {other_weight}")
691
+
692
+ weights = [
693
+ aa_processed_weight
694
+ if "AA_processed" in (item[1] if len(item) == 3 else item[0])
695
+ else nsv_processed_weight
696
+ if "1000actors_nsv" in (item[1] if len(item) == 3 else item[0])
697
+ else other_weight
698
+ for item in self._indexes
699
+ ]
700
+ return weights
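+ # These inverse-frequency weights let the sampler (used when balance_datasets
+ # is enabled) draw roughly evenly from AA_processed, 1000actors_nsv and the
+ # remaining sources.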
701
+
702
+ def __getitem__(self, idx):
703
+ if self.balance_datasets:
704
+ idx = self.sampler.__iter__().__next__()
705
+
706
+ try:
707
+ (
708
+ original_frames,
709
+ clean_cond,
710
+ noisy_cond,
711
+ target,
712
+ audio,
713
+ raw_audio,
714
+ cond_noise,
715
+ emotions,
716
+ masks,
717
+ landmarks,
718
+ ) = self._get_frames_and_audio(idx % len(self._indexes))
719
+ except Exception as e:
720
+ print(f"Error with index {idx}: {e}")
721
+ return self.__getitem__(np.random.randint(0, len(self)))
722
+ out_data = {}
723
+
724
+ if original_frames is not None:
725
+ out_data["original_frames"] = original_frames
726
+
727
+ if audio is not None:
728
+ out_data["audio_emb"] = audio
729
+ out_data["raw_audio"] = raw_audio
730
+
731
+ if self.use_emotions:
732
+ out_data["valence"] = emotions[0]
733
+ out_data["arousal"] = emotions[1]
734
+ out_data["emo_labels"] = emotions[2]
735
+ if self.use_latent:
736
+ input_key = "latents"
737
+ else:
738
+ input_key = "frames"
739
+ out_data[input_key] = target
740
+ if noisy_cond is not None:
741
+ out_data["cond_frames"] = noisy_cond
742
+ out_data["cond_frames_without_noise"] = clean_cond
743
+ if cond_noise is not None:
744
+ out_data["cond_aug"] = cond_noise
745
+
746
+ if masks is not None:
747
+ out_data["masks"] = masks
748
+ out_data["gt"] = target
749
+ if landmarks is not None:
750
+ out_data["landmarks"] = landmarks
751
+
752
+ out_data["motion_bucket_id"] = torch.tensor([self.motion_id])
753
+ out_data["fps_id"] = torch.tensor([self.fps - 1])
754
+ out_data["num_video_frames"] = self.num_frames
755
+ out_data["image_only_indicator"] = torch.zeros(self.num_frames)
756
+ return out_data
757
+
758
+
759
+ if __name__ == "__main__":
760
+ import torchvision.transforms as transforms
761
+ import cv2
762
+
763
+ transform = transforms.Compose(transforms=[transforms.Resize((256, 256))])
764
+ dataset = VideoDataset(
765
+ "/vol/paramonos2/projects/antoni/datasets/mahnob/filelist_videos_val.txt",
766
+ transform=transform,
767
+ num_frames=25,
768
+ )
769
+ print(len(dataset))
770
+ idx = np.random.randint(0, len(dataset))
771
+
772
+ for i in range(10):
773
+ print(dataset[i][0].shape, dataset[i][1].shape)
774
+
775
+ image_identity = (dataset[idx][0].permute(1, 2, 0).numpy() + 1) / 2 * 255
776
+ image_other = (dataset[idx][1][:, -1].permute(1, 2, 0).numpy() + 1) / 2 * 255
777
+ cv2.imwrite("image_identity.png", image_identity[:, :, ::-1])
778
+ for i in range(25):
779
+ image = (dataset[idx][1][:, i].permute(1, 2, 0).numpy() + 1) / 2 * 255
780
+ cv2.imwrite(f"tmp_vid_dataset/image_{i}.png", image[:, :, ::-1])
sgm/inference/api.py ADDED
@@ -0,0 +1,385 @@
1
+ import pathlib
2
+ from dataclasses import asdict, dataclass
3
+ from enum import Enum
4
+ from typing import Optional
5
+
6
+ from omegaconf import OmegaConf
7
+
8
+ from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
9
+ do_sample)
10
+ from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
11
+ DPMPP2SAncestralSampler,
12
+ EulerAncestralSampler,
13
+ EulerEDMSampler,
14
+ HeunEDMSampler,
15
+ LinearMultistepSampler)
16
+ from sgm.util import load_model_from_config
17
+
18
+
19
+ class ModelArchitecture(str, Enum):
20
+ SD_2_1 = "stable-diffusion-v2-1"
21
+ SD_2_1_768 = "stable-diffusion-v2-1-768"
22
+ SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
23
+ SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
24
+ SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
25
+ SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
26
+
27
+
28
+ class Sampler(str, Enum):
29
+ EULER_EDM = "EulerEDMSampler"
30
+ HEUN_EDM = "HeunEDMSampler"
31
+ EULER_ANCESTRAL = "EulerAncestralSampler"
32
+ DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
33
+ DPMPP2M = "DPMPP2MSampler"
34
+ LINEAR_MULTISTEP = "LinearMultistepSampler"
35
+
36
+
37
+ class Discretization(str, Enum):
38
+ LEGACY_DDPM = "LegacyDDPMDiscretization"
39
+ EDM = "EDMDiscretization"
40
+
41
+
42
+ class Guider(str, Enum):
43
+ VANILLA = "VanillaCFG"
44
+ IDENTITY = "IdentityGuider"
45
+
46
+
47
+ class Thresholder(str, Enum):
48
+ NONE = "None"
49
+
50
+
51
+ @dataclass
52
+ class SamplingParams:
53
+ width: int = 1024
54
+ height: int = 1024
55
+ steps: int = 50
56
+ sampler: Sampler = Sampler.DPMPP2M
57
+ discretization: Discretization = Discretization.LEGACY_DDPM
58
+ guider: Guider = Guider.VANILLA
59
+ thresholder: Thresholder = Thresholder.NONE
60
+ scale: float = 6.0
61
+ aesthetic_score: float = 5.0
62
+ negative_aesthetic_score: float = 5.0
63
+ img2img_strength: float = 1.0
64
+ orig_width: int = 1024
65
+ orig_height: int = 1024
66
+ crop_coords_top: int = 0
67
+ crop_coords_left: int = 0
68
+ sigma_min: float = 0.0292
69
+ sigma_max: float = 14.6146
70
+ rho: float = 3.0
71
+ s_churn: float = 0.0
72
+ s_tmin: float = 0.0
73
+ s_tmax: float = 999.0
74
+ s_noise: float = 1.0
75
+ eta: float = 1.0
76
+ order: int = 4
77
+
78
+
79
+ @dataclass
80
+ class SamplingSpec:
81
+ width: int
82
+ height: int
83
+ channels: int
84
+ factor: int
85
+ is_legacy: bool
86
+ config: str
87
+ ckpt: str
88
+ is_guided: bool
89
+
90
+
91
+ model_specs = {
92
+ ModelArchitecture.SD_2_1: SamplingSpec(
93
+ height=512,
94
+ width=512,
95
+ channels=4,
96
+ factor=8,
97
+ is_legacy=True,
98
+ config="sd_2_1.yaml",
99
+ ckpt="v2-1_512-ema-pruned.safetensors",
100
+ is_guided=True,
101
+ ),
102
+ ModelArchitecture.SD_2_1_768: SamplingSpec(
103
+ height=768,
104
+ width=768,
105
+ channels=4,
106
+ factor=8,
107
+ is_legacy=True,
108
+ config="sd_2_1_768.yaml",
109
+ ckpt="v2-1_768-ema-pruned.safetensors",
110
+ is_guided=True,
111
+ ),
112
+ ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
113
+ height=1024,
114
+ width=1024,
115
+ channels=4,
116
+ factor=8,
117
+ is_legacy=False,
118
+ config="sd_xl_base.yaml",
119
+ ckpt="sd_xl_base_0.9.safetensors",
120
+ is_guided=True,
121
+ ),
122
+ ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
123
+ height=1024,
124
+ width=1024,
125
+ channels=4,
126
+ factor=8,
127
+ is_legacy=True,
128
+ config="sd_xl_refiner.yaml",
129
+ ckpt="sd_xl_refiner_0.9.safetensors",
130
+ is_guided=True,
131
+ ),
132
+ ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
133
+ height=1024,
134
+ width=1024,
135
+ channels=4,
136
+ factor=8,
137
+ is_legacy=False,
138
+ config="sd_xl_base.yaml",
139
+ ckpt="sd_xl_base_1.0.safetensors",
140
+ is_guided=True,
141
+ ),
142
+ ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
143
+ height=1024,
144
+ width=1024,
145
+ channels=4,
146
+ factor=8,
147
+ is_legacy=True,
148
+ config="sd_xl_refiner.yaml",
149
+ ckpt="sd_xl_refiner_1.0.safetensors",
150
+ is_guided=True,
151
+ ),
152
+ }
153
+
154
+
155
+ class SamplingPipeline:
156
+ def __init__(
157
+ self,
158
+ model_id: ModelArchitecture,
159
+ model_path="checkpoints",
160
+ config_path="configs/inference",
161
+ device="cuda",
162
+ use_fp16=True,
163
+ ) -> None:
164
+ if model_id not in model_specs:
165
+ raise ValueError(f"Model {model_id} not supported")
166
+ self.model_id = model_id
167
+ self.specs = model_specs[self.model_id]
168
+ self.config = str(pathlib.Path(config_path, self.specs.config))
169
+ self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
170
+ self.device = device
171
+ self.model = self._load_model(device=device, use_fp16=use_fp16)
172
+
173
+ def _load_model(self, device="cuda", use_fp16=True):
174
+ config = OmegaConf.load(self.config)
175
+ model = load_model_from_config(config, self.ckpt)
176
+ if model is None:
177
+ raise ValueError(f"Model {self.model_id} could not be loaded")
178
+ model.to(device)
179
+ if use_fp16:
180
+ model.conditioner.half()
181
+ model.model.half()
182
+ return model
183
+
184
+ def text_to_image(
185
+ self,
186
+ params: SamplingParams,
187
+ prompt: str,
188
+ negative_prompt: str = "",
189
+ samples: int = 1,
190
+ return_latents: bool = False,
191
+ ):
192
+ sampler = get_sampler_config(params)
193
+ value_dict = asdict(params)
194
+ value_dict["prompt"] = prompt
195
+ value_dict["negative_prompt"] = negative_prompt
196
+ value_dict["target_width"] = params.width
197
+ value_dict["target_height"] = params.height
198
+ return do_sample(
199
+ self.model,
200
+ sampler,
201
+ value_dict,
202
+ samples,
203
+ params.height,
204
+ params.width,
205
+ self.specs.channels,
206
+ self.specs.factor,
207
+ force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
208
+ return_latents=return_latents,
209
+ filter=None,
210
+ )
211
+
212
+ def image_to_image(
213
+ self,
214
+ params: SamplingParams,
215
+ image,
216
+ prompt: str,
217
+ negative_prompt: str = "",
218
+ samples: int = 1,
219
+ return_latents: bool = False,
220
+ ):
221
+ sampler = get_sampler_config(params)
222
+
223
+ if params.img2img_strength < 1.0:
224
+ sampler.discretization = Img2ImgDiscretizationWrapper(
225
+ sampler.discretization,
226
+ strength=params.img2img_strength,
227
+ )
228
+ height, width = image.shape[2], image.shape[3]
229
+ value_dict = asdict(params)
230
+ value_dict["prompt"] = prompt
231
+ value_dict["negative_prompt"] = negative_prompt
232
+ value_dict["target_width"] = width
233
+ value_dict["target_height"] = height
234
+ return do_img2img(
235
+ image,
236
+ self.model,
237
+ sampler,
238
+ value_dict,
239
+ samples,
240
+ force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
241
+ return_latents=return_latents,
242
+ filter=None,
243
+ )
244
+
245
+ def refiner(
246
+ self,
247
+ params: SamplingParams,
248
+ image,
249
+ prompt: str,
250
+ negative_prompt: Optional[str] = None,
251
+ samples: int = 1,
252
+ return_latents: bool = False,
253
+ ):
254
+ sampler = get_sampler_config(params)
255
+ value_dict = {
256
+ "orig_width": image.shape[3] * 8,
257
+ "orig_height": image.shape[2] * 8,
258
+ "target_width": image.shape[3] * 8,
259
+ "target_height": image.shape[2] * 8,
260
+ "prompt": prompt,
261
+ "negative_prompt": negative_prompt,
262
+ "crop_coords_top": 0,
263
+ "crop_coords_left": 0,
264
+ "aesthetic_score": 6.0,
265
+ "negative_aesthetic_score": 2.5,
266
+ }
267
+
268
+ return do_img2img(
269
+ image,
270
+ self.model,
271
+ sampler,
272
+ value_dict,
273
+ samples,
274
+ skip_encode=True,
275
+ return_latents=return_latents,
276
+ filter=None,
277
+ )
278
+
279
+
280
+ def get_guider_config(params: SamplingParams):
281
+ if params.guider == Guider.IDENTITY:
282
+ guider_config = {
283
+ "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
284
+ }
285
+ elif params.guider == Guider.VANILLA:
286
+ scale = params.scale
287
+
288
+ thresholder = params.thresholder
289
+
290
+ if thresholder == Thresholder.NONE:
291
+ dyn_thresh_config = {
292
+ "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
293
+ }
294
+ else:
295
+ raise NotImplementedError
296
+
297
+ guider_config = {
298
+ "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
299
+ "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
300
+ }
301
+ else:
302
+ raise NotImplementedError
303
+ return guider_config
304
+
305
+
306
+ def get_discretization_config(params: SamplingParams):
307
+ if params.discretization == Discretization.LEGACY_DDPM:
308
+ discretization_config = {
309
+ "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
310
+ }
311
+ elif params.discretization == Discretization.EDM:
312
+ discretization_config = {
313
+ "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
314
+ "params": {
315
+ "sigma_min": params.sigma_min,
316
+ "sigma_max": params.sigma_max,
317
+ "rho": params.rho,
318
+ },
319
+ }
320
+ else:
321
+ raise ValueError(f"unknown discretization {params.discretization}")
322
+ return discretization_config
323
+
324
+
325
+ def get_sampler_config(params: SamplingParams):
326
+ discretization_config = get_discretization_config(params)
327
+ guider_config = get_guider_config(params)
328
+ sampler = None
329
+ if params.sampler == Sampler.EULER_EDM:
330
+ return EulerEDMSampler(
331
+ num_steps=params.steps,
332
+ discretization_config=discretization_config,
333
+ guider_config=guider_config,
334
+ s_churn=params.s_churn,
335
+ s_tmin=params.s_tmin,
336
+ s_tmax=params.s_tmax,
337
+ s_noise=params.s_noise,
338
+ verbose=True,
339
+ )
340
+ if params.sampler == Sampler.HEUN_EDM:
341
+ return HeunEDMSampler(
342
+ num_steps=params.steps,
343
+ discretization_config=discretization_config,
344
+ guider_config=guider_config,
345
+ s_churn=params.s_churn,
346
+ s_tmin=params.s_tmin,
347
+ s_tmax=params.s_tmax,
348
+ s_noise=params.s_noise,
349
+ verbose=True,
350
+ )
351
+ if params.sampler == Sampler.EULER_ANCESTRAL:
352
+ return EulerAncestralSampler(
353
+ num_steps=params.steps,
354
+ discretization_config=discretization_config,
355
+ guider_config=guider_config,
356
+ eta=params.eta,
357
+ s_noise=params.s_noise,
358
+ verbose=True,
359
+ )
360
+ if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
361
+ return DPMPP2SAncestralSampler(
362
+ num_steps=params.steps,
363
+ discretization_config=discretization_config,
364
+ guider_config=guider_config,
365
+ eta=params.eta,
366
+ s_noise=params.s_noise,
367
+ verbose=True,
368
+ )
369
+ if params.sampler == Sampler.DPMPP2M:
370
+ return DPMPP2MSampler(
371
+ num_steps=params.steps,
372
+ discretization_config=discretization_config,
373
+ guider_config=guider_config,
374
+ verbose=True,
375
+ )
376
+ if params.sampler == Sampler.LINEAR_MULTISTEP:
377
+ return LinearMultistepSampler(
378
+ num_steps=params.steps,
379
+ discretization_config=discretization_config,
380
+ guider_config=guider_config,
381
+ order=params.order,
382
+ verbose=True,
383
+ )
384
+
385
+ raise ValueError(f"unknown sampler {params.sampler}!")
sgm/inference/helpers.py ADDED
@@ -0,0 +1,305 @@
1
+ import math
2
+ import os
3
+ from typing import List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from einops import rearrange
8
+ from imwatermark import WatermarkEncoder
9
+ from omegaconf import ListConfig
10
+ from PIL import Image
11
+ from torch import autocast
12
+
13
+ from sgm.util import append_dims
14
+
15
+
16
+ class WatermarkEmbedder:
17
+ def __init__(self, watermark):
18
+ self.watermark = watermark
19
+ self.num_bits = len(WATERMARK_BITS)
20
+ self.encoder = WatermarkEncoder()
21
+ self.encoder.set_watermark("bits", self.watermark)
22
+
23
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
24
+ """
25
+ Adds a predefined watermark to the input image
26
+
27
+ Args:
28
+ image: ([N,] B, RGB, H, W) in range [0, 1]
29
+
30
+ Returns:
31
+ same as input but watermarked
32
+ """
33
+ squeeze = len(image.shape) == 4
34
+ if squeeze:
35
+ image = image[None, ...]
36
+ n = image.shape[0]
37
+ image_np = rearrange(
38
+ (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
39
+ ).numpy()[:, :, :, ::-1]
40
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
41
+ # watermarking library expects input in cv2 BGR format
42
+ for k in range(image_np.shape[0]):
43
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
44
+ image = torch.from_numpy(
45
+ rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
46
+ ).to(image.device)
47
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
48
+ if squeeze:
49
+ image = image[0]
50
+ return image
51
+
52
+
53
+ # A fixed 48-bit message that was chosen at random
54
+ # WATERMARK_MESSAGE = 0xB3EC907BB19E
55
+ WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
56
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
57
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
58
+ embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
59
+
60
+
61
+ def get_unique_embedder_keys_from_conditioner(conditioner):
62
+ return list({x.input_key for x in conditioner.embedders})
63
+
64
+
65
+ def perform_save_locally(save_path, samples):
66
+ os.makedirs(os.path.join(save_path), exist_ok=True)
67
+ base_count = len(os.listdir(os.path.join(save_path)))
68
+ samples = embed_watermark(samples)
69
+ for sample in samples:
70
+ sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
71
+ Image.fromarray(sample.astype(np.uint8)).save(
72
+ os.path.join(save_path, f"{base_count:09}.png")
73
+ )
74
+ base_count += 1
75
+
76
+
77
+ class Img2ImgDiscretizationWrapper:
78
+ """
79
+ wraps a discretizer, and prunes the sigmas
80
+ params:
81
+ strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
82
+ """
83
+
84
+ def __init__(self, discretization, strength: float = 1.0):
85
+ self.discretization = discretization
86
+ self.strength = strength
87
+ assert 0.0 <= self.strength <= 1.0
88
+
89
+ def __call__(self, *args, **kwargs):
90
+ # sigmas start large first, and decrease then
91
+ sigmas = self.discretization(*args, **kwargs)
92
+ print(f"sigmas after discretization, before pruning img2img: ", sigmas)
93
+ sigmas = torch.flip(sigmas, (0,))
94
+ sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
95
+ print("prune index:", max(int(self.strength * len(sigmas)), 1))
96
+ sigmas = torch.flip(sigmas, (0,))
97
+ print(f"sigmas after pruning: ", sigmas)
98
+ return sigmas
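+ # With n discretized sigmas and strength s, only the max(int(s * n), 1)
+ # smallest sigmas are kept, so img2img starts from a partially noised input
+ # rather than from pure noise.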
99
+
100
+
101
+ def do_sample(
102
+ model,
103
+ sampler,
104
+ value_dict,
105
+ num_samples,
106
+ H,
107
+ W,
108
+ C,
109
+ F,
110
+ force_uc_zero_embeddings: Optional[List] = None,
111
+ batch2model_input: Optional[List] = None,
112
+ return_latents=False,
113
+ filter=None,
114
+ device="cuda",
115
+ ):
116
+ if force_uc_zero_embeddings is None:
117
+ force_uc_zero_embeddings = []
118
+ if batch2model_input is None:
119
+ batch2model_input = []
120
+
121
+ with torch.no_grad():
122
+ with autocast(device) as precision_scope:
123
+ with model.ema_scope():
124
+ num_samples = [num_samples]
125
+ batch, batch_uc = get_batch(
126
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
127
+ value_dict,
128
+ num_samples,
129
+ )
130
+ for key in batch:
131
+ if isinstance(batch[key], torch.Tensor):
132
+ print(key, batch[key].shape)
133
+ elif isinstance(batch[key], list):
134
+ print(key, [len(l) for l in batch[key]])
135
+ else:
136
+ print(key, batch[key])
137
+ c, uc = model.conditioner.get_unconditional_conditioning(
138
+ batch,
139
+ batch_uc=batch_uc,
140
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
141
+ )
142
+
143
+ for k in c:
144
+ if not k == "crossattn":
145
+ c[k], uc[k] = map(
146
+ lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
147
+ )
148
+
149
+ additional_model_inputs = {}
150
+ for k in batch2model_input:
151
+ additional_model_inputs[k] = batch[k]
152
+
153
+ shape = (math.prod(num_samples), C, H // F, W // F)
154
+ randn = torch.randn(shape).to(device)
155
+
156
+ def denoiser(input, sigma, c):
157
+ return model.denoiser(
158
+ model.model, input, sigma, c, **additional_model_inputs
159
+ )
160
+
161
+ samples_z = sampler(denoiser, randn, cond=c, uc=uc)
162
+ samples_x = model.decode_first_stage(samples_z)
163
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
164
+
165
+ if filter is not None:
166
+ samples = filter(samples)
167
+
168
+ if return_latents:
169
+ return samples, samples_z
170
+ return samples
171
+
172
+
173
+ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
174
+ # Hardcoded demo setups; might undergo some changes in the future
175
+
176
+ batch = {}
177
+ batch_uc = {}
178
+
179
+ for key in keys:
180
+ if key == "txt":
181
+ batch["txt"] = (
182
+ np.repeat([value_dict["prompt"]], repeats=math.prod(N))
183
+ .reshape(N)
184
+ .tolist()
185
+ )
186
+ batch_uc["txt"] = (
187
+ np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
188
+ .reshape(N)
189
+ .tolist()
190
+ )
191
+ elif key == "original_size_as_tuple":
192
+ batch["original_size_as_tuple"] = (
193
+ torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
194
+ .to(device)
195
+ .repeat(*N, 1)
196
+ )
197
+ elif key == "crop_coords_top_left":
198
+ batch["crop_coords_top_left"] = (
199
+ torch.tensor(
200
+ [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
201
+ )
202
+ .to(device)
203
+ .repeat(*N, 1)
204
+ )
205
+ elif key == "aesthetic_score":
206
+ batch["aesthetic_score"] = (
207
+ torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
208
+ )
209
+ batch_uc["aesthetic_score"] = (
210
+ torch.tensor([value_dict["negative_aesthetic_score"]])
211
+ .to(device)
212
+ .repeat(*N, 1)
213
+ )
214
+
215
+ elif key == "target_size_as_tuple":
216
+ batch["target_size_as_tuple"] = (
217
+ torch.tensor([value_dict["target_height"], value_dict["target_width"]])
218
+ .to(device)
219
+ .repeat(*N, 1)
220
+ )
221
+ else:
222
+ batch[key] = value_dict[key]
223
+
224
+ for key in batch.keys():
225
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
226
+ batch_uc[key] = torch.clone(batch[key])
227
+ return batch, batch_uc
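+ # Keys without a dedicated branch above are passed through unchanged, and any
+ # tensor entry missing from batch_uc is copied over from batch.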
228
+
229
+
230
+ def get_input_image_tensor(image: Image.Image, device="cuda"):
231
+ w, h = image.size
232
+ print(f"loaded input image of size ({w}, {h})")
233
+ width, height = map(
234
+ lambda x: x - x % 64, (w, h)
235
+ ) # resize to integer multiple of 64
236
+ image = image.resize((width, height))
237
+ image_array = np.array(image.convert("RGB"))
238
+ image_array = image_array[None].transpose(0, 3, 1, 2)
239
+ image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
240
+ return image_tensor.to(device)
241
+
242
+
243
+ def do_img2img(
244
+ img,
245
+ model,
246
+ sampler,
247
+ value_dict,
248
+ num_samples,
249
+ force_uc_zero_embeddings=[],
250
+ additional_kwargs={},
251
+ offset_noise_level: float = 0.0,
252
+ return_latents=False,
253
+ skip_encode=False,
254
+ filter=None,
255
+ device="cuda",
256
+ ):
257
+ with torch.no_grad():
258
+ with autocast(device) as precision_scope:
259
+ with model.ema_scope():
260
+ batch, batch_uc = get_batch(
261
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
262
+ value_dict,
263
+ [num_samples],
264
+ )
265
+ c, uc = model.conditioner.get_unconditional_conditioning(
266
+ batch,
267
+ batch_uc=batch_uc,
268
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
269
+ )
270
+
271
+ for k in c:
272
+ c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
273
+
274
+ for k in additional_kwargs:
275
+ c[k] = uc[k] = additional_kwargs[k]
276
+ if skip_encode:
277
+ z = img
278
+ else:
279
+ z = model.encode_first_stage(img)
280
+ noise = torch.randn_like(z)
281
+ sigmas = sampler.discretization(sampler.num_steps)
282
+ sigma = sigmas[0].to(z.device)
283
+
284
+ if offset_noise_level > 0.0:
285
+ noise = noise + offset_noise_level * append_dims(
286
+ torch.randn(z.shape[0], device=z.device), z.ndim
287
+ )
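+ # Offset noise adds a single per-sample Gaussian value, broadcast over all
+ # channels and spatial positions, on top of the i.i.d. noise.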
288
+ noised_z = z + noise * append_dims(sigma, z.ndim)
289
+ noised_z = noised_z / torch.sqrt(
290
+ 1.0 + sigmas[0] ** 2.0
291
+ ) # Note: hardcoded to DDPM-like scaling. need to generalize later.
292
+
293
+ def denoiser(x, sigma, c):
294
+ return model.denoiser(model.model, x, sigma, c)
295
+
296
+ samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
297
+ samples_x = model.decode_first_stage(samples_z)
298
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
299
+
300
+ if filter is not None:
301
+ samples = filter(samples)
302
+
303
+ if return_latents:
304
+ return samples, samples_z
305
+ return samples
sgm/lr_scheduler.py ADDED
@@ -0,0 +1,135 @@
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ warm_up_steps,
12
+ lr_min,
13
+ lr_max,
14
+ lr_start,
15
+ max_decay_steps,
16
+ verbosity_interval=0,
17
+ ):
18
+ self.lr_warm_up_steps = warm_up_steps
19
+ self.lr_start = lr_start
20
+ self.lr_min = lr_min
21
+ self.lr_max = lr_max
22
+ self.lr_max_decay_steps = max_decay_steps
23
+ self.last_lr = 0.0
24
+ self.verbosity_interval = verbosity_interval
25
+
26
+ def schedule(self, n, **kwargs):
27
+ if self.verbosity_interval > 0:
28
+ if n % self.verbosity_interval == 0:
29
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
30
+ if n < self.lr_warm_up_steps:
31
+ lr = (
32
+ self.lr_max - self.lr_start
33
+ ) / self.lr_warm_up_steps * n + self.lr_start
34
+ self.last_lr = lr
35
+ return lr
36
+ else:
37
+ t = (n - self.lr_warm_up_steps) / (
38
+ self.lr_max_decay_steps - self.lr_warm_up_steps
39
+ )
40
+ t = min(t, 1.0)
41
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
42
+ 1 + np.cos(t * np.pi)
43
+ )
44
+ self.last_lr = lr
45
+ return lr
46
+
47
+ def __call__(self, n, **kwargs):
48
+ return self.schedule(n, **kwargs)
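+
+ # Illustrative usage sketch (hypothetical values; pair with a base_lr of 1.0,
+ # e.g. via torch.optim.lr_scheduler.LambdaLR):
+ #   sched = LambdaWarmUpCosineScheduler(
+ #       warm_up_steps=1000, lr_min=1e-6, lr_max=1e-4, lr_start=0.0,
+ #       max_decay_steps=100_000,
+ #   )
+ #   lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched)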
49
+
50
+
51
+ class LambdaWarmUpCosineScheduler2:
52
+ """
53
+ supports repeated iterations, configurable via lists
54
+ note: use with a base_lr of 1.0.
55
+ """
56
+
57
+ def __init__(
58
+ self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
59
+ ):
60
+ assert (
61
+ len(warm_up_steps)
62
+ == len(f_min)
63
+ == len(f_max)
64
+ == len(f_start)
65
+ == len(cycle_lengths)
66
+ )
67
+ self.lr_warm_up_steps = warm_up_steps
68
+ self.f_start = f_start
69
+ self.f_min = f_min
70
+ self.f_max = f_max
71
+ self.cycle_lengths = cycle_lengths
72
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
73
+ self.last_f = 0.0
74
+ self.verbosity_interval = verbosity_interval
75
+
76
+ def find_in_interval(self, n):
77
+ interval = 0
78
+ for cl in self.cum_cycles[1:]:
79
+ if n <= cl:
80
+ return interval
81
+ interval += 1
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0:
88
+ print(
89
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
90
+ f"current cycle {cycle}"
91
+ )
92
+ if n < self.lr_warm_up_steps[cycle]:
93
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
94
+ cycle
95
+ ] * n + self.f_start[cycle]
96
+ self.last_f = f
97
+ return f
98
+ else:
99
+ t = (n - self.lr_warm_up_steps[cycle]) / (
100
+ self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
101
+ )
102
+ t = min(t, 1.0)
103
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
104
+ 1 + np.cos(t * np.pi)
105
+ )
106
+ self.last_f = f
107
+ return f
108
+
109
+ def __call__(self, n, **kwargs):
110
+ return self.schedule(n, **kwargs)
111
+
112
+
113
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
114
+ def schedule(self, n, **kwargs):
115
+ cycle = self.find_in_interval(n)
116
+ n = n - self.cum_cycles[cycle]
117
+ if self.verbosity_interval > 0:
118
+ if n % self.verbosity_interval == 0:
119
+ print(
120
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
121
+ f"current cycle {cycle}"
122
+ )
123
+
124
+ if n < self.lr_warm_up_steps[cycle]:
125
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
126
+ cycle
127
+ ] * n + self.f_start[cycle]
128
+ self.last_f = f
129
+ return f
130
+ else:
131
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
132
+ self.cycle_lengths[cycle] - n
133
+ ) / (self.cycle_lengths[cycle])
134
+ self.last_f = f
135
+ return f
sgm/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .autoencoder import AutoencodingEngine
2
+ from .diffusion import DiffusionEngine
sgm/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (335 Bytes). View file
 
sgm/models/__pycache__/autoencoder.cpython-311.pyc ADDED
Binary file (35.8 kB). View file
 
sgm/models/__pycache__/diffusion.cpython-311.pyc ADDED
Binary file (37.1 kB). View file
 
sgm/models/autoencoder.py ADDED
@@ -0,0 +1,615 @@
1
+ import logging
2
+ import math
3
+ import re
4
+ from abc import abstractmethod
5
+ from contextlib import contextmanager
6
+ from typing import Any, Dict, List, Optional, Tuple, Union
7
+
8
+ import pytorch_lightning as pl
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+ from packaging import version
13
+
14
+ from ..modules.autoencoding.regularizers import AbstractRegularizer
15
+ from ..modules.ema import LitEma
16
+ from ..util import (default, get_nested_attribute, get_obj_from_str,
17
+ instantiate_from_config)
18
+
19
+ logpy = logging.getLogger(__name__)
20
+
21
+
22
+ class AbstractAutoencoder(pl.LightningModule):
23
+ """
24
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
25
+ unCLIP models, etc. Hence, it is fairly general, and specific features
26
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ ema_decay: Union[None, float] = None,
32
+ monitor: Union[None, str] = None,
33
+ input_key: str = "jpg",
34
+ ):
35
+ super().__init__()
36
+
37
+ self.input_key = input_key
38
+ self.use_ema = ema_decay is not None
39
+ if monitor is not None:
40
+ self.monitor = monitor
41
+
42
+ if self.use_ema:
43
+ self.model_ema = LitEma(self, decay=ema_decay)
44
+ logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
45
+
46
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
47
+ self.automatic_optimization = False
48
+
49
+ def apply_ckpt(self, ckpt: Union[None, str, dict]):
50
+ if ckpt is None:
51
+ return
52
+ if isinstance(ckpt, str):
53
+ ckpt = {
54
+ "target": "sgm.modules.checkpoint.CheckpointEngine",
55
+ "params": {"ckpt_path": ckpt},
56
+ }
57
+ engine = instantiate_from_config(ckpt)
58
+ engine(self)
59
+
60
+ @abstractmethod
61
+ def get_input(self, batch) -> Any:
62
+ raise NotImplementedError()
63
+
64
+ def on_train_batch_end(self, *args, **kwargs):
65
+ # for EMA computation
66
+ if self.use_ema:
67
+ self.model_ema(self)
68
+
69
+ @contextmanager
70
+ def ema_scope(self, context=None):
71
+ if self.use_ema:
72
+ self.model_ema.store(self.parameters())
73
+ self.model_ema.copy_to(self)
74
+ if context is not None:
75
+ logpy.info(f"{context}: Switched to EMA weights")
76
+ try:
77
+ yield None
78
+ finally:
79
+ if self.use_ema:
80
+ self.model_ema.restore(self.parameters())
81
+ if context is not None:
82
+ logpy.info(f"{context}: Restored training weights")
83
+
84
+ @abstractmethod
85
+ def encode(self, *args, **kwargs) -> torch.Tensor:
86
+ raise NotImplementedError("encode()-method of abstract base class called")
87
+
88
+ @abstractmethod
89
+ def decode(self, *args, **kwargs) -> torch.Tensor:
90
+ raise NotImplementedError("decode()-method of abstract base class called")
91
+
92
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
93
+ logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
94
+ return get_obj_from_str(cfg["target"])(
95
+ params, lr=lr, **cfg.get("params", dict())
96
+ )
97
+
98
+ def configure_optimizers(self) -> Any:
99
+ raise NotImplementedError()
100
+
101
+
102
+ class AutoencodingEngine(AbstractAutoencoder):
103
+ """
104
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
105
+ (we also restore them explicitly as special cases for legacy reasons).
106
+ Regularizations such as KL or VQ are moved to the regularizer class.
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ *args,
112
+ encoder_config: Dict,
113
+ decoder_config: Dict,
114
+ loss_config: Dict,
115
+ regularizer_config: Dict,
116
+ optimizer_config: Union[Dict, None] = None,
117
+ lr_g_factor: float = 1.0,
118
+ trainable_ae_params: Optional[List[List[str]]] = None,
119
+ ae_optimizer_args: Optional[List[dict]] = None,
120
+ trainable_disc_params: Optional[List[List[str]]] = None,
121
+ disc_optimizer_args: Optional[List[dict]] = None,
122
+ disc_start_iter: int = 0,
123
+ diff_boost_factor: float = 3.0,
124
+ ckpt_engine: Union[None, str, dict] = None,
125
+ ckpt_path: Optional[str] = None,
126
+ additional_decode_keys: Optional[List[str]] = None,
127
+ **kwargs,
128
+ ):
129
+ super().__init__(*args, **kwargs)
130
+ self.automatic_optimization = False # pytorch lightning
131
+
132
+ self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
133
+ self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
134
+ self.loss: torch.nn.Module = instantiate_from_config(loss_config)
135
+ self.regularization: AbstractRegularizer = instantiate_from_config(
136
+ regularizer_config
137
+ )
138
+ self.optimizer_config = default(
139
+ optimizer_config, {"target": "torch.optim.Adam"}
140
+ )
141
+ self.diff_boost_factor = diff_boost_factor
142
+ self.disc_start_iter = disc_start_iter
143
+ self.lr_g_factor = lr_g_factor
144
+ self.trainable_ae_params = trainable_ae_params
145
+ if self.trainable_ae_params is not None:
146
+ self.ae_optimizer_args = default(
147
+ ae_optimizer_args,
148
+ [{} for _ in range(len(self.trainable_ae_params))],
149
+ )
150
+ assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
151
+ else:
152
+ self.ae_optimizer_args = [{}] # makes type consistent
153
+
154
+ self.trainable_disc_params = trainable_disc_params
155
+ if self.trainable_disc_params is not None:
156
+ self.disc_optimizer_args = default(
157
+ disc_optimizer_args,
158
+ [{} for _ in range(len(self.trainable_disc_params))],
159
+ )
160
+ assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
161
+ else:
162
+ self.disc_optimizer_args = [{}] # makes type consistent
163
+
164
+ if ckpt_path is not None:
165
+ assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
166
+ logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
167
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
168
+ self.additional_decode_keys = set(default(additional_decode_keys, []))
169
+
170
+ def get_input(self, batch: Dict) -> torch.Tensor:
171
+ # assuming unified data format, dataloader returns a dict.
172
+ # image tensors should be scaled to -1 ... 1 and in channels-first
173
+ # format (e.g., bchw instead of bhwc)
174
+ return batch[self.input_key]
175
+
176
+ def get_autoencoder_params(self) -> list:
177
+ params = []
178
+ if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
179
+ params += list(self.loss.get_trainable_autoencoder_parameters())
180
+ if hasattr(self.regularization, "get_trainable_parameters"):
181
+ params += list(self.regularization.get_trainable_parameters())
182
+ params = params + list(self.encoder.parameters())
183
+ params = params + list(self.decoder.parameters())
184
+ return params
185
+
186
+ def get_discriminator_params(self) -> list:
187
+ if hasattr(self.loss, "get_trainable_parameters"):
188
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
189
+ else:
190
+ params = []
191
+ return params
192
+
193
+ def get_last_layer(self):
194
+ return self.decoder.get_last_layer()
195
+
196
+ def encode(
197
+ self,
198
+ x: torch.Tensor,
199
+ return_reg_log: bool = False,
200
+ unregularized: bool = False,
201
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
202
+ z = self.encoder(x)
203
+ if unregularized:
204
+ return z, dict()
205
+ z, reg_log = self.regularization(z)
206
+ if return_reg_log:
207
+ return z, reg_log
208
+ return z
209
+
210
+ def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
211
+ x = self.decoder(z, **kwargs)
212
+ return x
213
+
214
+ def forward(
215
+ self, x: torch.Tensor, **additional_decode_kwargs
216
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
217
+ z, reg_log = self.encode(x, return_reg_log=True)
218
+ dec = self.decode(z, **additional_decode_kwargs)
219
+ return z, dec, reg_log
220
+
221
+ def inner_training_step(
222
+ self, batch: dict, batch_idx: int, optimizer_idx: int = 0
223
+ ) -> torch.Tensor:
224
+ x = self.get_input(batch)
225
+ additional_decode_kwargs = {
226
+ key: batch[key] for key in self.additional_decode_keys.intersection(batch)
227
+ }
228
+ z, xrec, regularization_log = self(x, **additional_decode_kwargs)
229
+ if hasattr(self.loss, "forward_keys"):
230
+ extra_info = {
231
+ "z": z,
232
+ "optimizer_idx": optimizer_idx,
233
+ "global_step": self.global_step,
234
+ "last_layer": self.get_last_layer(),
235
+ "split": "train",
236
+ "regularization_log": regularization_log,
237
+ "autoencoder": self,
238
+ }
239
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
240
+ else:
241
+ extra_info = dict()
242
+
243
+ if optimizer_idx == 0:
244
+ # autoencode
245
+ out_loss = self.loss(x, xrec, **extra_info)
246
+ if isinstance(out_loss, tuple):
247
+ aeloss, log_dict_ae = out_loss
248
+ else:
249
+ # simple loss function
250
+ aeloss = out_loss
251
+ log_dict_ae = {"train/loss/rec": aeloss.detach()}
252
+
253
+ self.log_dict(
254
+ log_dict_ae,
255
+ prog_bar=False,
256
+ logger=True,
257
+ on_step=True,
258
+ on_epoch=True,
259
+ sync_dist=False,
260
+ )
261
+ self.log(
262
+ "loss",
263
+ aeloss.mean().detach(),
264
+ prog_bar=True,
265
+ logger=False,
266
+ on_epoch=False,
267
+ on_step=True,
268
+ )
269
+ return aeloss
270
+ elif optimizer_idx == 1:
271
+ # discriminator
272
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
273
+ # -> discriminator always needs to return a tuple
274
+ self.log_dict(
275
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
276
+ )
277
+ return discloss
278
+ else:
279
+ raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
280
+
281
+ def training_step(self, batch: dict, batch_idx: int):
282
+ opts = self.optimizers()
283
+ if not isinstance(opts, list):
284
+ # Non-adversarial case
285
+ opts = [opts]
286
+ optimizer_idx = batch_idx % len(opts)
287
+ if self.global_step < self.disc_start_iter:
288
+ optimizer_idx = 0
289
+ opt = opts[optimizer_idx]
290
+ opt.zero_grad()
291
+ with opt.toggle_model():
292
+ loss = self.inner_training_step(
293
+ batch, batch_idx, optimizer_idx=optimizer_idx
294
+ )
295
+ self.manual_backward(loss)
296
+ opt.step()
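+ # Manual optimization: the autoencoder and (optional) discriminator optimizers
+ # alternate by batch index, and the discriminator step is skipped until
+ # global_step reaches disc_start_iter.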
297
+
298
+ def validation_step(self, batch: dict, batch_idx: int) -> Dict:
299
+ log_dict = self._validation_step(batch, batch_idx)
300
+ with self.ema_scope():
301
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
302
+ log_dict.update(log_dict_ema)
303
+ return log_dict
304
+
305
+ def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
306
+ x = self.get_input(batch)
307
+
308
+ z, xrec, regularization_log = self(x)
309
+ if hasattr(self.loss, "forward_keys"):
310
+ extra_info = {
311
+ "z": z,
312
+ "optimizer_idx": 0,
313
+ "global_step": self.global_step,
314
+ "last_layer": self.get_last_layer(),
315
+ "split": "val" + postfix,
316
+ "regularization_log": regularization_log,
317
+ "autoencoder": self,
318
+ }
319
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
320
+ else:
321
+ extra_info = dict()
322
+ out_loss = self.loss(x, xrec, **extra_info)
323
+ if isinstance(out_loss, tuple):
324
+ aeloss, log_dict_ae = out_loss
325
+ else:
326
+ # simple loss function
327
+ aeloss = out_loss
328
+ log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
329
+ full_log_dict = log_dict_ae
330
+
331
+ if "optimizer_idx" in extra_info:
332
+ extra_info["optimizer_idx"] = 1
333
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
334
+ full_log_dict.update(log_dict_disc)
335
+ self.log(
336
+ f"val{postfix}/loss/rec",
337
+ log_dict_ae[f"val{postfix}/loss/rec"],
338
+ sync_dist=True,
339
+ )
340
+ self.log_dict(full_log_dict, sync_dist=True)
341
+ return full_log_dict
342
+
343
+ def get_param_groups(
344
+ self, parameter_names: List[List[str]], optimizer_args: List[dict]
345
+ ) -> Tuple[List[Dict[str, Any]], int]:
346
+ groups = []
347
+ num_params = 0
348
+ for names, args in zip(parameter_names, optimizer_args):
349
+ params = []
350
+ for pattern_ in names:
351
+ pattern_params = []
352
+ pattern = re.compile(pattern_)
353
+ for p_name, param in self.named_parameters():
354
+ if re.match(pattern, p_name):
355
+ pattern_params.append(param)
356
+ num_params += param.numel()
357
+ if len(pattern_params) == 0:
358
+ logpy.warn(f"Did not find parameters for pattern {pattern_}")
359
+ params.extend(pattern_params)
360
+ groups.append({"params": params, **args})
361
+ return groups, num_params
362
+
363
+ def configure_optimizers(self) -> List[torch.optim.Optimizer]:
364
+ if self.trainable_ae_params is None:
365
+ ae_params = self.get_autoencoder_params()
366
+ else:
367
+ ae_params, num_ae_params = self.get_param_groups(
368
+ self.trainable_ae_params, self.ae_optimizer_args
369
+ )
370
+ logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
371
+ if self.trainable_disc_params is None:
372
+ disc_params = self.get_discriminator_params()
373
+ else:
374
+ disc_params, num_disc_params = self.get_param_groups(
375
+ self.trainable_disc_params, self.disc_optimizer_args
376
+ )
377
+ logpy.info(
378
+ f"Number of trainable discriminator parameters: {num_disc_params:,}"
379
+ )
380
+ opt_ae = self.instantiate_optimizer_from_config(
381
+ ae_params,
382
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
383
+ self.optimizer_config,
384
+ )
385
+ opts = [opt_ae]
386
+ if len(disc_params) > 0:
387
+ opt_disc = self.instantiate_optimizer_from_config(
388
+ disc_params, self.learning_rate, self.optimizer_config
389
+ )
390
+ opts.append(opt_disc)
391
+
392
+ return opts
393
+
394
+ @torch.no_grad()
395
+ def log_images(
396
+ self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
397
+ ) -> dict:
398
+ log = dict()
399
+ additional_decode_kwargs = {}
400
+ x = self.get_input(batch)
401
+ additional_decode_kwargs.update(
402
+ {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
403
+ )
404
+
405
+ _, xrec, _ = self(x, **additional_decode_kwargs)
406
+ log["inputs"] = x
407
+ log["reconstructions"] = xrec
408
+ diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
409
+ diff.clamp_(0, 1.0)
410
+ log["diff"] = 2.0 * diff - 1.0
411
+ # diff_boost shows location of small errors, by boosting their
412
+ # brightness.
413
+ log["diff_boost"] = (
414
+ 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
415
+ )
416
+ if hasattr(self.loss, "log_images"):
417
+ log.update(self.loss.log_images(x, xrec))
418
+ with self.ema_scope():
419
+ _, xrec_ema, _ = self(x, **additional_decode_kwargs)
420
+ log["reconstructions_ema"] = xrec_ema
421
+ diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
422
+ diff_ema.clamp_(0, 1.0)
423
+ log["diff_ema"] = 2.0 * diff_ema - 1.0
424
+ log["diff_boost_ema"] = (
425
+ 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
426
+ )
427
+ if additional_log_kwargs:
428
+ additional_decode_kwargs.update(additional_log_kwargs)
429
+ _, xrec_add, _ = self(x, **additional_decode_kwargs)
430
+ log_str = "reconstructions-" + "-".join(
431
+ [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
432
+ )
433
+ log[log_str] = xrec_add
434
+ return log
435
+
436
+
437
+ class AutoencodingEngineLegacy(AutoencodingEngine):
438
+ def __init__(self, embed_dim: int, **kwargs):
439
+ self.max_batch_size = kwargs.pop("max_batch_size", None)
440
+ ddconfig = kwargs.pop("ddconfig")
441
+ ckpt_path = kwargs.pop("ckpt_path", None)
442
+ ckpt_engine = kwargs.pop("ckpt_engine", None)
443
+ super().__init__(
444
+ encoder_config={
445
+ "target": "sgm.modules.diffusionmodules.model.Encoder",
446
+ "params": ddconfig,
447
+ },
448
+ decoder_config={
449
+ "target": "sgm.modules.diffusionmodules.model.Decoder",
450
+ "params": ddconfig,
451
+ },
452
+ **kwargs,
453
+ )
454
+ self.quant_conv = torch.nn.Conv2d(
455
+ (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
456
+ (1 + ddconfig["double_z"]) * embed_dim,
457
+ 1,
458
+ )
459
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
460
+ self.embed_dim = embed_dim
461
+
462
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
463
+
464
+ def get_autoencoder_params(self) -> list:
465
+ params = super().get_autoencoder_params()
466
+ return params
467
+
468
+ def encode(
469
+ self, x: torch.Tensor, return_reg_log: bool = False
470
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
471
+ if self.max_batch_size is None:
472
+ z = self.encoder(x)
473
+ z = self.quant_conv(z)
474
+ else:
475
+ N = x.shape[0]
476
+ bs = self.max_batch_size
477
+ n_batches = int(math.ceil(N / bs))
478
+ z = list()
479
+ for i_batch in range(n_batches):
480
+ z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
481
+ z_batch = self.quant_conv(z_batch)
482
+ z.append(z_batch)
483
+ z = torch.cat(z, 0)
484
+
485
+ z, reg_log = self.regularization(z)
486
+ if return_reg_log:
487
+ return z, reg_log
488
+ return z
489
+
490
+ def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
491
+ if self.max_batch_size is None:
492
+ dec = self.post_quant_conv(z)
493
+ dec = self.decoder(dec, **decoder_kwargs)
494
+ else:
495
+ N = z.shape[0]
496
+ bs = self.max_batch_size
497
+ n_batches = int(math.ceil(N / bs))
498
+ dec = list()
499
+ for i_batch in range(n_batches):
500
+ dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
501
+ dec_batch = self.decoder(dec_batch, **decoder_kwargs)
502
+ dec.append(dec_batch)
503
+ dec = torch.cat(dec, 0)
504
+
505
+ return dec
506
+
507
+
508
+ class AutoencoderKL(AutoencodingEngineLegacy):
509
+ def __init__(self, **kwargs):
510
+ if "lossconfig" in kwargs:
511
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
512
+ super().__init__(
513
+ regularizer_config={
514
+ "target": (
515
+ "sgm.modules.autoencoding.regularizers"
516
+ ".DiagonalGaussianRegularizer"
517
+ )
518
+ },
519
+ **kwargs,
520
+ )
521
+
522
+
523
+ class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
524
+ def __init__(
525
+ self,
526
+ embed_dim: int,
527
+ n_embed: int,
528
+ sane_index_shape: bool = False,
529
+ **kwargs,
530
+ ):
531
+ if "lossconfig" in kwargs:
532
+ logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
533
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
534
+ super().__init__(
535
+ regularizer_config={
536
+ "target": (
537
+ "sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
538
+ ),
539
+ "params": {
540
+ "n_e": n_embed,
541
+ "e_dim": embed_dim,
542
+ "sane_index_shape": sane_index_shape,
543
+ },
544
+ },
545
+ **kwargs,
546
+ )
547
+
548
+
549
+ class IdentityFirstStage(AbstractAutoencoder):
550
+ def __init__(self, *args, **kwargs):
551
+ super().__init__(*args, **kwargs)
552
+
553
+ def get_input(self, x: Any) -> Any:
554
+ return x
555
+
556
+ def encode(self, x: Any, *args, **kwargs) -> Any:
557
+ return x
558
+
559
+ def decode(self, x: Any, *args, **kwargs) -> Any:
560
+ return x
561
+
562
+
563
+ class AEIntegerWrapper(nn.Module):
564
+ def __init__(
565
+ self,
566
+ model: nn.Module,
567
+ shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
568
+ regularization_key: str = "regularization",
569
+ encoder_kwargs: Optional[Dict[str, Any]] = None,
570
+ ):
571
+ super().__init__()
572
+ self.model = model
573
+ assert hasattr(model, "encode") and hasattr(
574
+ model, "decode"
575
+ ), "Need AE interface"
576
+ self.regularization = get_nested_attribute(model, regularization_key)
577
+ self.shape = shape
578
+ self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
579
+
580
+ def encode(self, x) -> torch.Tensor:
581
+ assert (
582
+ not self.training
583
+ ), f"{self.__class__.__name__} only supports inference currently"
584
+ _, log = self.model.encode(x, **self.encoder_kwargs)
585
+ assert isinstance(log, dict)
586
+ inds = log["min_encoding_indices"]
587
+ return rearrange(inds, "b ... -> b (...)")
588
+
589
+ def decode(
590
+ self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
591
+ ) -> torch.Tensor:
592
+ # expect inds shape (b, s) with s = h*w
593
+ shape = default(shape, self.shape) # Optional[(h, w)]
594
+ if shape is not None:
595
+ assert len(shape) == 2, f"Unhandled shape {shape}"
596
+ inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
597
+ h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
598
+ h = rearrange(h, "b h w c -> b c h w")
599
+ return self.model.decode(h)
600
+
601
+
602
+ class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
603
+ def __init__(self, **kwargs):
604
+ if "lossconfig" in kwargs:
605
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
606
+ super().__init__(
607
+ regularizer_config={
608
+ "target": (
609
+ "sgm.modules.autoencoding.regularizers"
610
+ ".DiagonalGaussianRegularizer"
611
+ ),
612
+ "params": {"sample": False},
613
+ },
614
+ **kwargs,
615
+ )
sgm/models/diffusion.py ADDED
@@ -0,0 +1,747 @@
1
+ import os
2
+ import math
3
+ from contextlib import contextmanager
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+ import re
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ from omegaconf import ListConfig, OmegaConf
9
+ from safetensors.torch import load_file as load_safetensors
10
+ from torch.optim.lr_scheduler import LambdaLR
11
+ from einops import rearrange
12
+ from diffusers.models.attention_processor import IPAdapterAttnProcessor2_0
13
+
14
+ from ..modules import UNCONDITIONAL_CONFIG
15
+ from ..modules.autoencoding.temporal_ae import VideoDecoder
16
+ from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
17
+ from ..modules.ema import LitEma
18
+ from ..util import (
19
+ default,
20
+ disabled_train,
21
+ get_obj_from_str,
22
+ instantiate_from_config,
23
+ log_txt_as_img,
24
+ )
25
+
26
+
27
+ class DiffusionEngine(pl.LightningModule):
28
+ def __init__(
29
+ self,
30
+ network_config,
31
+ denoiser_config,
32
+ first_stage_config,
33
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
34
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
35
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
36
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
37
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
38
+ network_wrapper: Union[None, str, Dict, ListConfig, OmegaConf] = None,
39
+ ckpt_path: Union[None, str] = None,
40
+ remove_keys_from_weights: Union[None, List, Tuple] = None,
41
+ pattern_to_remove: Union[None, str] = None,
42
+ remove_keys_from_unet_weights: Union[None, List, Tuple] = None,
43
+ use_ema: bool = False,
44
+ ema_decay_rate: float = 0.9999,
45
+ scale_factor: float = 1.0,
46
+ disable_first_stage_autocast=False,
47
+ input_key: str = "jpg",
48
+ log_keys: Union[List, None] = None,
49
+ no_log_keys: Union[List, None] = None,
50
+ no_cond_log: bool = False,
51
+ compile_model: bool = False,
52
+ en_and_decode_n_samples_a_time: Optional[int] = None,
53
+ only_train_ipadapter: Optional[bool] = False,
54
+ to_unfreeze: Optional[List[str]] = [],
55
+ to_freeze: Optional[List[str]] = [],
56
+ separate_unet_ckpt: Optional[str] = None,
57
+ use_thunder: Optional[bool] = False,
58
+ is_dubbing: Optional[bool] = False,
59
+ bad_model_path: Optional[str] = None,
60
+ bad_model_config: Optional[Dict] = None,
61
+ ):
62
+ super().__init__()
63
+
64
+ # self.automatic_optimization = False
65
+ self.log_keys = log_keys
66
+ self.no_log_keys = no_log_keys
67
+ self.input_key = input_key
68
+ self.is_dubbing = is_dubbing
69
+ self.optimizer_config = default(
70
+ optimizer_config, {"target": "torch.optim.AdamW"}
71
+ )
72
+ self.model = self.initialize_network(
73
+ network_config, network_wrapper, compile_model=compile_model
74
+ )
75
+
76
+ self.denoiser = instantiate_from_config(denoiser_config)
77
+
78
+ self.sampler = (
79
+ instantiate_from_config(sampler_config)
80
+ if sampler_config is not None
81
+ else None
82
+ )
83
+ self.is_guided = True
84
+ if (
85
+ self.sampler
86
+ and "IdentityGuider" in sampler_config["params"]["guider_config"]["target"]
87
+ ):
88
+ self.is_guided = False
89
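+ # Also build a sampler with the guider removed so samples can be logged without guidance (see sample_no_guider).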
+ if self.sampler is not None:
90
+ config_guider = sampler_config["params"]["guider_config"]
91
+ sampler_config["params"]["guider_config"] = None
92
+ self.sampler_no_guidance = instantiate_from_config(sampler_config)
93
+ sampler_config["params"]["guider_config"] = config_guider
94
+ self.conditioner = instantiate_from_config(
95
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
96
+ )
97
+ self.scheduler_config = scheduler_config
98
+ self._init_first_stage(first_stage_config)
99
+
100
+ self.loss_fn = (
101
+ instantiate_from_config(loss_fn_config)
102
+ if loss_fn_config is not None
103
+ else None
104
+ )
105
+
106
+ self.use_ema = use_ema
107
+ if self.use_ema:
108
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
109
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
110
+
111
+ self.scale_factor = scale_factor
112
+ self.disable_first_stage_autocast = disable_first_stage_autocast
113
+ self.no_cond_log = no_cond_log
114
+
115
+ if ckpt_path is not None:
116
+ self.init_from_ckpt(
117
+ ckpt_path,
118
+ remove_keys_from_weights=remove_keys_from_weights,
119
+ pattern_to_remove=pattern_to_remove,
120
+ )
121
+ if separate_unet_ckpt is not None:
122
+ sd = torch.load(separate_unet_ckpt)["state_dict"]
123
+ if remove_keys_from_unet_weights is not None:
124
+ for k in list(sd.keys()):
125
+ for remove_key in remove_keys_from_unet_weights:
126
+ if remove_key in k:
127
+ del sd[k]
128
+ self.model.diffusion_model.load_state_dict(sd, strict=False)
129
+
130
+ self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
131
+ print(
132
+ "Using",
133
+ self.en_and_decode_n_samples_a_time,
134
+ "samples at a time for encoding and decoding",
135
+ )
136
+
137
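+ # Freeze matching parameters; a leading "!" inverts the match (freeze every parameter whose name does NOT contain the pattern).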
+ if to_freeze:
138
+ for name, p in self.model.diffusion_model.named_parameters():
139
+ for layer in to_freeze:
140
+ if layer[0] == "!":
141
+ if layer[1:] not in name:
142
+ # print("Freezing", name)
143
+ p.requires_grad = False
144
+ else:
145
+ if layer in name:
146
+ # print("Freezing", name)
147
+ p.requires_grad = False
148
+ # if "time_" in name:
149
+ # print("Freezing", name)
150
+ # p.requires_grad = False
151
+
152
+ if only_train_ipadapter:
153
+ # Freeze the model
154
+ for p in self.model.parameters():
155
+ p.requires_grad = False
156
+ # Unfreeze the adapter projection layer
157
+ for p in self.model.diffusion_model.encoder_hid_proj.parameters():
158
+ p.requires_grad = True
159
+ # Unfreeze the cross-attention layer
160
+ for att_layer in self.model.diffusion_model.attn_processors.values():
161
+ if isinstance(att_layer, IPAdapterAttnProcessor2_0):
162
+ for p in att_layer.parameters():
163
+ p.requires_grad = True
164
+
165
+ # for name, p in self.named_parameters():
166
+ # if p.requires_grad:
167
+ # print(name)
168
+
169
+ if to_unfreeze:
170
+ for name in to_unfreeze:
171
+ for p in getattr(self.model.diffusion_model, name).parameters():
172
+ p.requires_grad = True
173
+
174
+ if use_thunder:
175
+ import thunder
176
+
177
+ self.model.diffusion_model = thunder.jit(self.model.diffusion_model)
178
+
179
+ if "Karras" in denoiser_config.target:
180
+ assert bad_model_path is not None, (
181
+ "bad_model_path must be provided for KarrasGuidanceDenoiser"
182
+ )
183
+ karras_config = default(bad_model_config, network_config)
184
+ bad_model = self.initialize_network(
185
+ karras_config, network_wrapper, compile_model=compile_model
186
+ )
187
+ state_dict = self.load_bad_model_weights(bad_model_path)
188
+ bad_model.load_state_dict(state_dict)
189
+ self.denoiser.set_bad_network(bad_model)
190
+
191
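+ # Keep only learned_mask / diffusion_model weights and strip the "_forward_module." / "model." wrapper prefixes.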
+ def load_bad_model_weights(self, path: str) -> Dict:
192
+ print(f"Restoring bad model from {path}")
193
+ state_dict = torch.load(path, map_location="cpu")
194
+ new_dict = {}
195
+ for k, v in state_dict["module"].items():
196
+ if "learned_mask" in k:
197
+ new_dict[k.replace("_forward_module.", "").replace("model.", "")] = v
198
+ if "diffusion_model" in k:
199
+ new_dict["diffusion_model" + k.split("diffusion_model")[1]] = v
200
+ return new_dict
201
+
202
+ def initialize_network(self, network_config, network_wrapper, compile_model=False):
203
+ model = instantiate_from_config(network_config)
204
+ if isinstance(network_wrapper, str) or network_wrapper is None:
205
+ model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
206
+ model, compile_model=compile_model
207
+ )
208
+ else:
209
+ target = network_wrapper["target"]
210
+ params = network_wrapper.get("params", dict())
211
+ model = get_obj_from_str(target)(
212
+ model, compile_model=compile_model, **params
213
+ )
214
+ return model
215
+
216
+ def init_from_ckpt(
217
+ self,
218
+ path: str,
219
+ remove_keys_from_weights: Optional[Union[List, Tuple]] = None,
220
+ pattern_to_remove: Optional[str] = None,
221
+ ) -> None:
222
+ print(f"Restoring from {path}")
223
+ if path.endswith("ckpt"):
224
+ sd = torch.load(path, map_location="cpu")["state_dict"]
225
+ elif path.endswith("pt"):
226
+ sd = torch.load(path, map_location="cpu")["module"]
227
+ # Remove leading _forward_module from keys
228
+ sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
229
+ elif path.endswith("bin"):
230
+ sd = torch.load(path, map_location="cpu")
231
+ # Remove leading _forward_module from keys
232
+ sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
233
+ elif path.endswith("safetensors"):
234
+ sd = load_safetensors(path)
235
+ else:
236
+ raise NotImplementedError
237
+
238
+ print(f"Loaded state dict from {path} with {len(sd)} keys")
239
+
240
+ # if remove_keys_from_weights is not None:
241
+ # for k in list(sd.keys()):
242
+ # for remove_key in remove_keys_from_weights:
243
+ # if remove_key in k:
244
+ # del sd[k]
245
+ if pattern_to_remove is not None or remove_keys_from_weights is not None:
246
+ sd = self.remove_mismatched_keys(
247
+ sd, pattern_to_remove, remove_keys_from_weights
248
+ )
249
+
250
+ missing, unexpected = self.load_state_dict(sd, strict=False)
251
+ print(
252
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
253
+ )
254
+ if len(missing) > 0:
255
+ print(f"Missing Keys: {missing}")
256
+ if len(unexpected) > 0:
257
+ print(f"Unexpected Keys: {unexpected}")
258
+
259
+ def remove_mismatched_keys(self, state_dict, pattern=None, additional_keys=None):
260
+ """Remove keys from the state dictionary based on a pattern and a list of additional specific keys."""
261
+ # Find keys that match the pattern
262
+ if pattern is not None:
263
+ mismatched_keys = [key for key in state_dict if re.search(pattern, key)]
264
+ else:
265
+ mismatched_keys = []
266
+
267
+ print(f"Removing {len(mismatched_keys)} keys based on pattern {pattern}")
268
+ print(mismatched_keys)
269
+
270
+ # Add specific keys to be removed
271
+ if additional_keys:
272
+ mismatched_keys.extend(
273
+ [key for key in additional_keys if key in state_dict]
274
+ )
275
+
276
+ # Remove all identified keys
277
+ for key in mismatched_keys:
278
+ if key in state_dict:
279
+ del state_dict[key]
280
+ return state_dict
281
+
282
+ def _init_first_stage(self, config):
283
+ model = instantiate_from_config(config).eval()
284
+ model.train = disabled_train
285
+ for param in model.parameters():
286
+ param.requires_grad = False
287
+ self.first_stage_model = model
288
+ if self.input_key == "latents":
289
+ # Remove encoder to save memory
290
+ self.first_stage_model.encoder = None
291
+ torch.cuda.empty_cache()
292
+
293
+ def get_input(self, batch):
294
+ # assuming unified data format, dataloader returns a dict.
295
+ # image tensors should be scaled to -1 ... 1 and in bchw format
296
+ return batch[self.input_key]
297
+
298
+ @torch.no_grad()
299
+ def decode_first_stage(self, z):
300
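+ # Video latents (b, c, t, h, w) are flattened to (b t) c h w, decoded in chunks, then reshaped back.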
+ is_video = False
301
+ if len(z.shape) == 5:
302
+ is_video = True
303
+ T = z.shape[2]
304
+ z = rearrange(z, "b c t h w -> (b t) c h w")
305
+
306
+ z = 1.0 / self.scale_factor * z
307
+ n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
308
+
309
+ n_rounds = math.ceil(z.shape[0] / n_samples)
310
+ all_out = []
311
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
312
+ for n in range(n_rounds):
313
+ if isinstance(self.first_stage_model.decoder, VideoDecoder):
314
+ kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
315
+ else:
316
+ kwargs = {}
317
+ out = self.first_stage_model.decode(
318
+ z[n * n_samples : (n + 1) * n_samples], **kwargs
319
+ )
320
+ all_out.append(out)
321
+ out = torch.cat(all_out, dim=0)
322
+ if is_video:
323
+ out = rearrange(out, "(b t) c h w -> b c t h w", t=T)
324
+ torch.cuda.empty_cache()
325
+ return out
326
+
327
+ @torch.no_grad()
328
+ def encode_first_stage(self, x):
329
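+ # Same chunking scheme as decode_first_stage: flatten video frames, encode in chunks, then apply scale_factor.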
+ is_video = False
330
+ if len(x.shape) == 5:
331
+ is_video = True
332
+ T = x.shape[2]
333
+ x = rearrange(x, "b c t h w -> (b t) c h w")
334
+ n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
335
+ n_rounds = math.ceil(x.shape[0] / n_samples)
336
+ all_out = []
337
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
338
+ for n in range(n_rounds):
339
+ out = self.first_stage_model.encode(
340
+ x[n * n_samples : (n + 1) * n_samples]
341
+ )
342
+ all_out.append(out)
343
+ z = torch.cat(all_out, dim=0)
344
+ z = self.scale_factor * z
345
+ if is_video:
346
+ z = rearrange(z, "(b t) c h w -> b c t h w", t=T)
347
+ return z
348
+
349
+ def forward(self, x, batch):
350
+ loss_dict = self.loss_fn(
351
+ self.model,
352
+ self.denoiser,
353
+ self.conditioner,
354
+ x,
355
+ batch,
356
+ self.first_stage_model,
357
+ )
358
+ # loss_mean = loss.mean()
359
+ for k in loss_dict:
360
+ loss_dict[k] = loss_dict[k].mean()
361
+ # loss_dict = {"loss": loss_mean}
362
+ return loss_dict["loss"], loss_dict
363
+
364
+ def shared_step(self, batch: Dict) -> Any:
365
+ x = self.get_input(batch)
366
+ if self.input_key != "latents":
367
+ x = self.encode_first_stage(x)
368
+ batch["global_step"] = self.global_step
369
+ loss, loss_dict = self(x, batch)
370
+ return loss, loss_dict
371
+
372
+ def training_step(self, batch, batch_idx):
373
+ loss, loss_dict = self.shared_step(batch)
374
+ # debugging_message = "Training step"
375
+ # print(f"RANK - {self.trainer.global_rank}: {debugging_message}")
376
+
377
+ self.log_dict(
378
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
379
+ )
380
+
381
+ self.log(
382
+ "global_step",
383
+ self.global_step,
384
+ prog_bar=True,
385
+ logger=True,
386
+ on_step=True,
387
+ on_epoch=False,
388
+ )
389
+
390
+ # debugging_message = "Training step - log"
391
+ # print(f"RANK - {self.trainer.global_rank}: {debugging_message}")
392
+
393
+ if self.scheduler_config is not None:
394
+ lr = self.optimizers().param_groups[0]["lr"]
395
+ self.log(
396
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
397
+ )
398
+
399
+ # # to prevent other processes from moving forward until all processes are in sync
400
+ # self.trainer.strategy.barrier()
401
+
402
+ return loss
403
+
404
+ # def validation_step(self, batch, batch_idx):
405
+ # # loss, loss_dict = self.shared_step(batch)
406
+ # # self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False)
407
+ # self.log(
408
+ # "global_step",
409
+ # self.global_step,
410
+ # prog_bar=True,
411
+ # logger=True,
412
+ # on_step=True,
413
+ # on_epoch=False,
414
+ # )
415
+ # return 0
416
+
417
+ # def on_train_epoch_start(self, *args, **kwargs):
418
+ # print(f"RANK - {self.trainer.global_rank}: on_train_epoch_start")
419
+
420
+ def on_train_start(self, *args, **kwargs):
421
+ # os.environ["CUDA_VISIBLE_DEVICES"] = str(self.trainer.global_rank)
422
+ # torch.cuda.set_device(self.trainer.global_rank)
423
+ # torch.cuda.empty_cache()
424
+ if self.sampler is None or self.loss_fn is None:
425
+ raise ValueError("Sampler and loss function need to be set for training.")
426
+
427
+ # def on_before_batch_transfer(self, batch, dataloader_idx):
428
+ # print(f"RANK - {self.trainer.global_rank}: on_before_batch_transfer - {dataloader_idx}")
429
+ # return batch
430
+
431
+ # def on_after_batch_transfer(self, batch, dataloader_idx):
432
+ # print(f"RANK - {self.trainer.global_rank}: on_after_batch_transfer - {dataloader_idx}")
433
+ # return batch
434
+
435
+ def on_train_batch_end(self, *args, **kwargs):
436
+ # print(f"RANK - {self.trainer.global_rank}: on_train_batch_end")
437
+ if self.use_ema:
438
+ self.model_ema(self.model)
439
+
440
+ @contextmanager
441
+ def ema_scope(self, context=None):
442
+ if self.use_ema:
443
+ self.model_ema.store(self.model.parameters())
444
+ self.model_ema.copy_to(self.model)
445
+ if context is not None:
446
+ print(f"{context}: Switched to EMA weights")
447
+ try:
448
+ yield None
449
+ finally:
450
+ if self.use_ema:
451
+ self.model_ema.restore(self.model.parameters())
452
+ if context is not None:
453
+ print(f"{context}: Restored training weights")
454
+
455
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
456
+ return get_obj_from_str(cfg["target"])(
457
+ params, lr=lr, **cfg.get("params", dict())
458
+ )
459
+
460
+ def configure_optimizers(self):
461
+ lr = self.learning_rate
462
+ params = list(self.model.parameters())
463
+ for embedder in self.conditioner.embedders:
464
+ if embedder.is_trainable:
465
+ params = params + list(embedder.parameters())
466
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
467
+ if self.scheduler_config is not None:
468
+ scheduler = instantiate_from_config(self.scheduler_config)
469
+ print("Setting up LambdaLR scheduler...")
470
+ scheduler = [
471
+ {
472
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
473
+ "interval": "step",
474
+ "frequency": 1,
475
+ }
476
+ ]
477
+ return [opt], scheduler
478
+ return opt
479
+
480
+ @torch.no_grad()
481
+ def sample(
482
+ self,
483
+ cond: Dict,
484
+ uc: Union[Dict, None] = None,
485
+ batch_size: int = 16,
486
+ shape: Union[None, Tuple, List] = None,
487
+ **kwargs,
488
+ ):
489
+ randn = torch.randn(batch_size, *shape).to(self.device)
490
+
491
+ denoiser = lambda input, sigma, c: self.denoiser(
492
+ self.model, input, sigma, c, **kwargs
493
+ )
494
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
495
+
496
+ return samples
497
+
498
+ @torch.no_grad()
499
+ def sample_no_guider(
500
+ self,
501
+ cond: Dict,
502
+ uc: Union[Dict, None] = None,
503
+ batch_size: int = 16,
504
+ shape: Union[None, Tuple, List] = None,
505
+ **kwargs,
506
+ ):
507
+ randn = torch.randn(batch_size, *shape).to(self.device)
508
+
509
+ denoiser = lambda input, sigma, c: self.denoiser(
510
+ self.model, input, sigma, c, **kwargs
511
+ )
512
+ samples = self.sampler_no_guidance(denoiser, randn, cond, uc=uc)
513
+
514
+ return samples
515
+
516
+ @torch.no_grad()
517
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
518
+ """
519
+ Defines heuristics to log different conditionings.
520
+ These can be lists of strings (text-to-image), tensors, ints, ...
521
+ """
522
+ image_h, image_w = batch[self.input_key].shape[-2:]
523
+ log = dict()
524
+
525
+ for embedder in self.conditioner.embedders:
526
+ if (
527
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
528
+ ) and not self.no_cond_log:
529
+ if embedder.input_key in self.no_log_keys:
530
+ continue
531
+ x = batch[embedder.input_key][:n]
532
+ if isinstance(x, torch.Tensor):
533
+ if x.dim() == 1:
534
+ # class-conditional, convert integer to string
535
+ x = [str(x[i].item()) for i in range(x.shape[0])]
536
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
537
+ elif x.dim() == 2:
538
+ # size and crop cond and the like
539
+ x = [
540
+ "x".join([str(xx) for xx in x[i].tolist()])
541
+ for i in range(x.shape[0])
542
+ ]
543
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
544
+ elif x.dim() == 4: # already an image
545
+ xc = x
546
+ elif x.dim() == 5:
547
+ xc = torch.cat([x[:, :, i] for i in range(x.shape[2])], dim=-1)
548
+ else:
549
+ print(x.shape, embedder.input_key)
550
+ raise NotImplementedError()
551
+ elif isinstance(x, (List, ListConfig)):
552
+ if isinstance(x[0], str):
553
+ # strings
554
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
555
+ else:
556
+ raise NotImplementedError()
557
+ else:
558
+ raise NotImplementedError()
559
+ log[embedder.input_key] = xc
560
+ return log
561
+
562
+ @torch.no_grad()
563
+ def log_images(
564
+ self,
565
+ batch: Dict,
566
+ N: int = 8,
567
+ sample: bool = True,
568
+ ucg_keys: List[str] = None,
569
+ **kwargs,
570
+ ) -> Dict:
571
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
572
+ if ucg_keys:
573
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
574
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
575
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
576
+ )
577
+ else:
578
+ ucg_keys = conditioner_input_keys
579
+ log = dict()
580
+
581
+ x = self.get_input(batch)
582
+
583
+ c, uc = self.conditioner.get_unconditional_conditioning(
584
+ batch,
585
+ force_uc_zero_embeddings=ucg_keys
586
+ if len(self.conditioner.embedders) > 0
587
+ else [],
588
+ )
589
+
590
+ sampling_kwargs = {}
591
+
592
+ N = min(x.shape[0], N)
593
+ x = x.to(self.device)[:N]
594
+ if self.input_key != "latents":
595
+ log["inputs"] = x
596
+ z = self.encode_first_stage(x)
597
+ else:
598
+ z = x
599
+ log["reconstructions"] = self.decode_first_stage(z)
600
+ log.update(self.log_conditionings(batch, N))
601
+
602
+ for k in c:
603
+ if isinstance(c[k], torch.Tensor):
604
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
605
+
606
+ if sample:
607
+ with self.ema_scope("Plotting"):
608
+ samples = self.sample(
609
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
610
+ )
611
+ samples = self.decode_first_stage(samples)
612
+
613
+ log["samples"] = samples
614
+
615
+ with self.ema_scope("Plotting"):
616
+ samples = self.sample_no_guider(
617
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
618
+ )
619
+ samples = self.decode_first_stage(samples)
620
+
621
+ log["samples_no_guidance"] = samples
622
+ return log
623
+
624
+ @torch.no_grad()
625
+ def log_videos(
626
+ self,
627
+ batch: Dict,
628
+ N: int = 8,
629
+ sample: bool = True,
630
+ ucg_keys: List[str] = None,
631
+ **kwargs,
632
+ ) -> Dict:
633
+ # conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
634
+ # if ucg_keys:
635
+ # assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
636
+ # "Each defined ucg key for sampling must be in the provided conditioner input keys,"
637
+ # f"but we have {ucg_keys} vs. {conditioner_input_keys}"
638
+ # )
639
+ # else:
640
+ # ucg_keys = conditioner_input_keys
641
+ log = dict()
642
+ batch_uc = {}
643
+
644
+ x = self.get_input(batch)
645
+ num_frames = x.shape[2] # assuming bcthw format
646
+
647
+ for key in batch.keys():
648
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
649
+ batch_uc[key] = torch.clone(batch[key])
650
+
651
+ c, uc = self.conditioner.get_unconditional_conditioning(
652
+ batch,
653
+ batch_uc=batch_uc,
654
+ force_uc_zero_embeddings=ucg_keys
655
+ if ucg_keys is not None
656
+ else [
657
+ "cond_frames",
658
+ "cond_frames_without_noise",
659
+ ],
660
+ )
661
+
662
+ # for k in ["crossattn", "concat"]:
663
+ # uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
664
+ # uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
665
+ # c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
666
+ # c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
667
+
668
+ sampling_kwargs = {}
669
+
670
+ N = min(x.shape[0], N)
671
+ x = x.to(self.device)[:N]
672
+
673
+ if self.input_key != "latents":
674
+ log["inputs"] = x
675
+ z = self.encode_first_stage(x)
676
+ else:
677
+ z = x
678
+ log["reconstructions"] = self.decode_first_stage(z)
679
+ log.update(self.log_conditionings(batch, N))
680
+
681
+ if c.get("masks", None) is not None:
682
+ # Create a mask reconstruction
683
+ masks = 1 - c["masks"]
684
+ t = masks.shape[2]
685
+ masks = rearrange(masks, "b c t h w -> (b t) c h w")
686
+ target_size = (
687
+ log["reconstructions"].shape[-2],
688
+ log["reconstructions"].shape[-1],
689
+ )
690
+ masks = torch.nn.functional.interpolate(
691
+ masks, size=target_size, mode="nearest"
692
+ )
693
+ masks = rearrange(masks, "(b t) c h w -> b c t h w", t=t)
694
+ log["mask_reconstructions"] = log["reconstructions"] * masks
695
+
696
+ for k in c:
697
+ if isinstance(c[k], torch.Tensor):
698
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
699
+ elif isinstance(c[k], list):
700
+ for i in range(len(c[k])):
701
+ c[k][i], uc[k][i] = map(
702
+ lambda y: y[k][i][:N].to(self.device), (c, uc)
703
+ )
704
+
705
+ if sample:
706
+ n = 2 if self.is_guided else 1
707
+ # if num_frames == 1:
708
+ # sampling_kwargs["image_only_indicator"] = torch.ones(n, num_frames).to(self.device)
709
+ # else:
710
+ sampling_kwargs["image_only_indicator"] = torch.zeros(n, num_frames).to(
711
+ self.device
712
+ )
713
+ sampling_kwargs["num_video_frames"] = batch["num_video_frames"]
714
+
715
+ with self.ema_scope("Plotting"):
716
+ samples = self.sample(
717
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
718
+ )
719
+ samples = self.decode_first_stage(samples)
720
+ if self.is_dubbing:
721
+ samples[:, :, :, : samples.shape[-2] // 2] = log["reconstructions"][
722
+ :, :, :, : samples.shape[-2] // 2
723
+ ]
724
+ log["samples"] = samples
725
+
726
+ # Without guidance
727
+ # if num_frames == 1:
728
+ # sampling_kwargs["image_only_indicator"] = torch.ones(1, num_frames).to(self.device)
729
+ # else:
730
+ sampling_kwargs["image_only_indicator"] = torch.zeros(1, num_frames).to(
731
+ self.device
732
+ )
733
+ sampling_kwargs["num_video_frames"] = batch["num_video_frames"]
734
+
735
+ with self.ema_scope("Plotting"):
736
+ samples = self.sample_no_guider(
737
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
738
+ )
739
+ samples = self.decode_first_stage(samples)
740
+ if self.is_dubbing:
741
+ samples[:, :, :, : samples.shape[-2] // 2] = log["reconstructions"][
742
+ :, :, :, : samples.shape[-2] // 2
743
+ ]
744
+ log["samples_no_guidance"] = samples
745
+
746
+ torch.cuda.empty_cache()
747
+ return log
sgm/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .encoders.modules import GeneralConditioner
2
+
3
+ UNCONDITIONAL_CONFIG = {
4
+ "target": "sgm.modules.GeneralConditioner",
5
+ "params": {"emb_models": []},
6
+ }
sgm/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (400 Bytes).
sgm/modules/__pycache__/attention.cpython-311.pyc ADDED
Binary file (39.1 kB).
sgm/modules/__pycache__/ema.cpython-311.pyc ADDED
Binary file (5.87 kB).
sgm/modules/__pycache__/video_attention.cpython-311.pyc ADDED
Binary file (14.2 kB).
sgm/modules/attention.py ADDED
@@ -0,0 +1,889 @@
1
+ import logging
2
+ import math
3
+ from inspect import isfunction
4
+ from typing import Any, Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+ from packaging import version
10
+ from torch import nn
11
+ from torch.utils.checkpoint import checkpoint
12
+
13
+ logpy = logging.getLogger(__name__)
14
+
15
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
16
+ SDP_IS_AVAILABLE = True
17
+ from torch.backends.cuda import SDPBackend, sdp_kernel
18
+
19
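+ # Flag sets for torch.backends.cuda.sdp_kernel; the None entry leaves every backend enabled.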
+ BACKEND_MAP = {
20
+ SDPBackend.MATH: {
21
+ "enable_math": True,
22
+ "enable_flash": False,
23
+ "enable_mem_efficient": False,
24
+ },
25
+ SDPBackend.FLASH_ATTENTION: {
26
+ "enable_math": False,
27
+ "enable_flash": True,
28
+ "enable_mem_efficient": False,
29
+ },
30
+ SDPBackend.EFFICIENT_ATTENTION: {
31
+ "enable_math": False,
32
+ "enable_flash": False,
33
+ "enable_mem_efficient": True,
34
+ },
35
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
36
+ }
37
+ else:
38
+ from contextlib import nullcontext
39
+
40
+ SDP_IS_AVAILABLE = False
41
+ sdp_kernel = nullcontext
42
+ BACKEND_MAP = {}
43
+ logpy.warn(
44
+ f"No SDP backend available, likely because you are running in pytorch "
45
+ f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
46
+ f"You might want to consider upgrading."
47
+ )
48
+
49
+ try:
50
+ import xformers
51
+ import xformers.ops
52
+
53
+ XFORMERS_IS_AVAILABLE = True
54
+ except:
55
+ XFORMERS_IS_AVAILABLE = False
56
+ logpy.warn("no module 'xformers'. Processing without...")
57
+
58
+ # from .diffusionmodules.util import mixed_checkpoint as checkpoint
59
+
60
+
61
+ def exists(val):
62
+ return val is not None
63
+
64
+
65
+ def uniq(arr):
66
+ return {el: True for el in arr}.keys()
67
+
68
+
69
+ def default(val, d):
70
+ if exists(val):
71
+ return val
72
+ return d() if isfunction(d) else d
73
+
74
+
75
+ def max_neg_value(t):
76
+ return -torch.finfo(t.dtype).max
77
+
78
+
79
+ def init_(tensor):
80
+ dim = tensor.shape[-1]
81
+ std = 1 / math.sqrt(dim)
82
+ tensor.uniform_(-std, std)
83
+ return tensor
84
+
85
+
86
+ # feedforward
87
+ class GEGLU(nn.Module):
88
+ def __init__(self, dim_in, dim_out):
89
+ super().__init__()
90
+ self.proj = nn.Linear(dim_in, dim_out * 2)
91
+
92
+ def forward(self, x):
93
+ x, gate = self.proj(x).chunk(2, dim=-1)
94
+ return x * F.gelu(gate)
95
+
96
+
97
+ class FeedForward(nn.Module):
98
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
99
+ super().__init__()
100
+ inner_dim = int(dim * mult)
101
+ dim_out = default(dim_out, dim)
102
+ project_in = (
103
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
104
+ if not glu
105
+ else GEGLU(dim, inner_dim)
106
+ )
107
+
108
+ self.net = nn.Sequential(
109
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
110
+ )
111
+
112
+ def forward(self, x):
113
+ return self.net(x)
114
+
115
+
116
+ def zero_module(module):
117
+ """
118
+ Zero out the parameters of a module and return it.
119
+ """
120
+ for p in module.parameters():
121
+ p.detach().zero_()
122
+ return module
123
+
124
+
125
+ def Normalize(in_channels):
126
+ return torch.nn.GroupNorm(
127
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
128
+ )
129
+
130
+
131
+ class LinearAttention(nn.Module):
132
+ def __init__(self, dim, heads=4, dim_head=32):
133
+ super().__init__()
134
+ self.heads = heads
135
+ hidden_dim = dim_head * heads
136
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
137
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
138
+
139
+ def forward(self, x):
140
+ b, c, h, w = x.shape
141
+ qkv = self.to_qkv(x)
142
+ q, k, v = rearrange(
143
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
144
+ )
145
+ k = k.softmax(dim=-1)
146
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
147
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
148
+ out = rearrange(
149
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
150
+ )
151
+ return self.to_out(out)
152
+
153
+
154
+ class SelfAttention(nn.Module):
155
+ ATTENTION_MODES = ("xformers", "torch", "math")
156
+
157
+ def __init__(
158
+ self,
159
+ dim: int,
160
+ num_heads: int = 8,
161
+ qkv_bias: bool = False,
162
+ qk_scale: Optional[float] = None,
163
+ attn_drop: float = 0.0,
164
+ proj_drop: float = 0.0,
165
+ attn_mode: str = "xformers",
166
+ ):
167
+ super().__init__()
168
+ self.num_heads = num_heads
169
+ head_dim = dim // num_heads
170
+ self.scale = qk_scale or head_dim**-0.5
171
+
172
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
173
+ self.attn_drop = nn.Dropout(attn_drop)
174
+ self.proj = nn.Linear(dim, dim)
175
+ self.proj_drop = nn.Dropout(proj_drop)
176
+ assert attn_mode in self.ATTENTION_MODES
177
+ self.attn_mode = attn_mode
178
+
179
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
180
+ B, L, C = x.shape
181
+
182
+ qkv = self.qkv(x)
183
+ if self.attn_mode == "torch":
184
+ qkv = rearrange(
185
+ qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
186
+ ).float()
187
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
188
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
189
+ x = rearrange(x, "B H L D -> B L (H D)")
190
+ elif self.attn_mode == "xformers":
191
+ qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
192
+ q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
193
+ x = xformers.ops.memory_efficient_attention(q, k, v)
194
+ x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
195
+ elif self.attn_mode == "math":
196
+ qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
197
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
198
+ attn = (q @ k.transpose(-2, -1)) * self.scale
199
+ attn = attn.softmax(dim=-1)
200
+ attn = self.attn_drop(attn)
201
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
202
+ else:
203
+ raise NotImplementedError
204
+
205
+ x = self.proj(x)
206
+ x = self.proj_drop(x)
207
+ return x
208
+
209
+
210
+ class SpatialSelfAttention(nn.Module):
211
+ def __init__(self, in_channels):
212
+ super().__init__()
213
+ self.in_channels = in_channels
214
+
215
+ self.norm = Normalize(in_channels)
216
+ self.q = torch.nn.Conv2d(
217
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
218
+ )
219
+ self.k = torch.nn.Conv2d(
220
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
221
+ )
222
+ self.v = torch.nn.Conv2d(
223
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
224
+ )
225
+ self.proj_out = torch.nn.Conv2d(
226
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
227
+ )
228
+
229
+ def forward(self, x):
230
+ h_ = x
231
+ h_ = self.norm(h_)
232
+ q = self.q(h_)
233
+ k = self.k(h_)
234
+ v = self.v(h_)
235
+
236
+ # compute attention
237
+ b, c, h, w = q.shape
238
+ q = rearrange(q, "b c h w -> b (h w) c")
239
+ k = rearrange(k, "b c h w -> b c (h w)")
240
+ w_ = torch.einsum("bij,bjk->bik", q, k)
241
+
242
+ w_ = w_ * (int(c) ** (-0.5))
243
+ w_ = torch.nn.functional.softmax(w_, dim=2)
244
+
245
+ # attend to values
246
+ v = rearrange(v, "b c h w -> b c (h w)")
247
+ w_ = rearrange(w_, "b i j -> b j i")
248
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
249
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
250
+ h_ = self.proj_out(h_)
251
+
252
+ return x + h_
253
+
254
+
255
+ class CrossAttention(nn.Module):
256
+ def __init__(
257
+ self,
258
+ query_dim,
259
+ context_dim=None,
260
+ heads=8,
261
+ dim_head=64,
262
+ dropout=0.0,
263
+ backend=None,
264
+ **kwargs,
265
+ ):
266
+ super().__init__()
267
+ inner_dim = dim_head * heads
268
+ context_dim = default(context_dim, query_dim)
269
+
270
+ self.scale = dim_head**-0.5
271
+ self.heads = heads
272
+
273
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
274
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
275
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
276
+
277
+ self.to_out = nn.Sequential(
278
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
279
+ )
280
+ self.backend = backend
281
+
282
+ def forward(
283
+ self,
284
+ x,
285
+ context=None,
286
+ mask=None,
287
+ additional_tokens=None,
288
+ n_times_crossframe_attn_in_self=0,
289
+ skip_attention=None,
290
+ **kwargs,
291
+ ):
292
+ h = self.heads
293
+
294
+ if additional_tokens is not None:
295
+ # get the number of masked tokens at the beginning of the output sequence
296
+ n_tokens_to_mask = additional_tokens.shape[1]
297
+ # add additional token
298
+ x = torch.cat([additional_tokens, x], dim=1)
299
+
300
+ # Ensure skip_attention is a B×1 boolean tensor
301
+ if skip_attention is None:
302
+ skip_attention = torch.zeros(x.shape[0], 1, dtype=torch.bool, device=x.device)
303
+
304
+ assert isinstance(skip_attention, torch.Tensor)
305
+ assert skip_attention.shape[1] == 1 and skip_attention.dtype == torch.bool
306
+
307
+ # Split the batch into skip and non-skip parts
308
+ skip_indices = skip_attention.squeeze(1)
309
+ non_skip_indices = ~skip_indices
310
+
311
+ # Process skip attention samples
312
+ if skip_indices.any():
313
+ x_skip = x[skip_indices]
314
+ out_skip = self.to_v(x_skip)
315
+ out_skip = rearrange(out_skip, "b n (h d) -> b n (h d)", h=h)
316
+
317
+ # If all samples are skipped, return early
318
+ if not non_skip_indices.any():
319
+ if additional_tokens is not None:
320
+ out_skip = out_skip[:, n_tokens_to_mask:]
321
+ return self.to_out(out_skip)
322
+
323
+ # Process non-skip samples with attention
324
+ x_non_skip = x[non_skip_indices]
325
+ q = self.to_q(x_non_skip)
326
+ context = default(context, x_non_skip)
327
+ k = self.to_k(context)
328
+ v = self.to_v(context)
329
+
330
+ if n_times_crossframe_attn_in_self:
331
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
332
+ assert x_non_skip.shape[0] % n_times_crossframe_attn_in_self == 0
333
+ k = repeat(
334
+ k[::n_times_crossframe_attn_in_self],
335
+ "b ... -> (b n) ...",
336
+ n=n_times_crossframe_attn_in_self,
337
+ )
338
+ v = repeat(
339
+ v[::n_times_crossframe_attn_in_self],
340
+ "b ... -> (b n) ...",
341
+ n=n_times_crossframe_attn_in_self,
342
+ )
343
+
344
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
345
+
346
+ with sdp_kernel(**BACKEND_MAP[self.backend]):
347
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
348
+
349
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
350
+
351
+ # Combine skip and non-skip results
352
+ combined_out = torch.zeros(
353
+ (x.shape[0], out.shape[1], out.shape[2]), dtype=out.dtype, device=out.device
354
+ )
355
+ combined_out[non_skip_indices] = out
356
+ if skip_indices.any():
357
+ combined_out[skip_indices] = out_skip
358
+
359
+ if additional_tokens is not None:
360
+ combined_out = combined_out[:, n_tokens_to_mask:]
361
+ return self.to_out(combined_out)
362
+
363
+
364
+ class MemoryEfficientCrossAttention(nn.Module):
365
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
366
+ def __init__(
367
+ self,
368
+ query_dim,
369
+ context_dim=None,
370
+ heads=8,
371
+ dim_head=64,
372
+ dropout=0.0,
373
+ use_reference=False,
374
+ extra_linear=False,
375
+ **kwargs,
376
+ ):
377
+ super().__init__()
378
+ logpy.debug(
379
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
380
+ f"context_dim is {context_dim} and using {heads} heads with a "
381
+ f"dimension of {dim_head}."
382
+ )
383
+ inner_dim = dim_head * heads
384
+ self.is_context = context_dim is not None
385
+ context_dim = default(context_dim, query_dim)
386
+
387
+ self.heads = heads
388
+ self.dim_head = dim_head
389
+ self.use_reference = use_reference and self.is_context
390
+
391
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
392
+ if not self.use_reference:
393
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
394
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
395
+ else:
396
+ if extra_linear:
397
+ self.to_k = nn.Linear(inner_dim, inner_dim, bias=False)
398
+ self.to_v = nn.Linear(inner_dim, inner_dim, bias=False)
399
+ self.extra_linear = extra_linear
400
+
401
+ self.to_out = nn.Sequential(
402
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
403
+ )
404
+ self.attention_op: Optional[Any] = None
405
+
406
+ def forward(
407
+ self,
408
+ x,
409
+ context=None,
410
+ mask=None,
411
+ additional_tokens=None,
412
+ n_times_crossframe_attn_in_self=0,
413
+ skip_attention=None,
414
+ ):
415
+ if additional_tokens is not None:
416
+ # get the number of masked tokens at the beginning of the output sequence
417
+ n_tokens_to_mask = additional_tokens.shape[1]
418
+ # add additional token
419
+ x = torch.cat([additional_tokens, x], dim=1)
420
+
421
+ # Ensure skip_attention is a B×1 boolean tensor
422
+ if skip_attention is None:
423
+ skip_attention = torch.zeros(x.shape[0], 1, dtype=torch.bool, device=x.device)
424
+ # print(x.shape)
425
+ # print(skip_attention)
426
+ # print(skip_attention.shape)
427
+ # print(any(skip_attention))
428
+ assert isinstance(skip_attention, torch.Tensor)
429
+ assert skip_attention.shape[1] == 1 and skip_attention.dtype == torch.bool
430
+
431
+ # Split the batch into skip and non-skip parts
432
+ skip_indices = skip_attention.squeeze(1)
433
+ non_skip_indices = ~skip_indices
434
+
435
+ # Process skip attention samples
436
+ if skip_indices.any():
437
+ x_skip = x[skip_indices]
438
+ out_skip = self.to_v(x_skip)
439
+ out_skip = (
440
+ out_skip.unsqueeze(0)
441
+ .reshape(-1, self.heads, out_skip.shape[1], self.dim_head)
442
+ .permute(0, 2, 1, 3)
443
+ .reshape(-1, out_skip.shape[1], self.heads * self.dim_head)
444
+ )
445
+ # If all samples are skipped, return early
446
+ if not non_skip_indices.any():
447
+ if additional_tokens is not None:
448
+ out_skip = out_skip[:, n_tokens_to_mask:]
449
+ return self.to_out(out_skip)
450
+
451
+ x_non_skip = x[non_skip_indices]
452
+ q = self.to_q(x_non_skip)
453
+ if not self.use_reference:
454
+ context = default(context, x_non_skip)
455
+ k = self.to_k(context)
456
+ v = self.to_v(context)
457
+ else:
458
+ # Reference has already correct shape
459
+ assert context is not None
460
+ if self.extra_linear:
461
+ k = self.to_k(context)
462
+ v = self.to_v(context)
463
+ else:
+ k, v = context, context
464
+
465
+ if n_times_crossframe_attn_in_self:
466
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
467
+ assert x_non_skip.shape[0] % n_times_crossframe_attn_in_self == 0
468
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
469
+ k = repeat(
470
+ k[::n_times_crossframe_attn_in_self],
471
+ "b ... -> (b n) ...",
472
+ n=n_times_crossframe_attn_in_self,
473
+ )
474
+ v = repeat(
475
+ v[::n_times_crossframe_attn_in_self],
476
+ "b ... -> (b n) ...",
477
+ n=n_times_crossframe_attn_in_self,
478
+ )
479
+
480
+ b, _, _ = q.shape
481
+ q, k, v = map(
482
+ lambda t: t.unsqueeze(3)
483
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
484
+ .permute(0, 2, 1, 3)
485
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
486
+ .contiguous(),
487
+ (q, k, v),
488
+ )
489
+ if q.dtype != k.dtype:
490
+ k = k.to(q.dtype)
491
+ v = v.to(q.dtype)
492
+
493
+ # actually compute the attention, what we cannot get enough of
494
+ if version.parse(xformers.__version__) >= version.parse("0.0.21"):
495
+ # NOTE: workaround for
496
+ # https://github.com/facebookresearch/xformers/issues/845
497
+ max_bs = 32768
498
+ N = q.shape[0]
499
+ n_batches = math.ceil(N / max_bs)
500
+ out = list()
501
+ for i_batch in range(n_batches):
502
+ batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
503
+ out.append(
504
+ xformers.ops.memory_efficient_attention(
505
+ q[batch],
506
+ k[batch],
507
+ v[batch],
508
+ attn_bias=None,
509
+ op=self.attention_op,
510
+ )
511
+ )
512
+ out = torch.cat(out, 0)
513
+ else:
514
+ out = xformers.ops.memory_efficient_attention(
515
+ q, k, v, attn_bias=None, op=self.attention_op
516
+ )
517
+
518
+ # TODO: Use this directly in the attention operation, as a bias
519
+ if exists(mask):
520
+ raise NotImplementedError
521
+ out = (
522
+ out.unsqueeze(0)
523
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
524
+ .permute(0, 2, 1, 3)
525
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
526
+ )
527
+ # Combine skip and non-skip results
528
+ combined_out = torch.zeros(
529
+ (x.shape[0], out.shape[1], out.shape[2]), dtype=out.dtype, device=out.device
530
+ )
531
+ combined_out[non_skip_indices] = out
532
+ if skip_indices.any():
533
+ combined_out[skip_indices] = out_skip
534
+ else:
535
+ combined_out = out
536
+
537
+ if additional_tokens is not None:
538
+ # remove additional token
539
+ combined_out = combined_out[:, n_tokens_to_mask:]
540
+ return self.to_out(combined_out)
541
+
542
+
543
+ class BasicTransformerBlock(nn.Module):
544
+ ATTENTION_MODES = {
545
+ "softmax": CrossAttention, # vanilla attention
546
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
547
+ }
548
+
549
+ def __init__(
550
+ self,
551
+ dim,
552
+ n_heads,
553
+ d_head,
554
+ dropout=0.0,
555
+ context_dim=None,
556
+ gated_ff=True,
557
+ checkpoint=True,
558
+ disable_self_attn=False,
559
+ attn_mode="softmax",
560
+ sdp_backend=None,
561
+ reference_to=None,
562
+ ):
563
+ super().__init__()
564
+ assert attn_mode in self.ATTENTION_MODES
565
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
566
+ logpy.warn(
567
+ f"Attention mode '{attn_mode}' is not available. Falling "
568
+ f"back to native attention. This is not a problem in "
569
+ f"Pytorch >= 2.0. FYI, you are running with PyTorch "
570
+ f"version {torch.__version__}."
571
+ )
572
+ attn_mode = "softmax"
573
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
574
+ logpy.warn(
575
+ "We do not support vanilla attention anymore, as it is too "
576
+ "expensive. Sorry."
577
+ )
578
+ if not XFORMERS_IS_AVAILABLE:
579
+ assert False, (
580
+ "Please install xformers via e.g. 'pip install xformers==0.0.16'"
581
+ )
582
+ else:
583
+ logpy.info("Falling back to xformers efficient attention.")
584
+ attn_mode = "softmax-xformers"
585
+ attn_cls = self.ATTENTION_MODES[attn_mode]
586
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
587
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
588
+ else:
589
+ assert sdp_backend is None
590
+ self.disable_self_attn = disable_self_attn
591
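+ # reference_to routes reference_context into self- or cross-attention; an "_extra" suffix additionally enables learnable k/v projections (extra_linear).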
+ extra_linear = (reference_to is not None) and ("extra" in reference_to)
592
+ if extra_linear:
593
+ reference_to = reference_to.replace("_extra", "")
594
+ assert reference_to in [None, "self", "cross"]
595
+ self.reference_to = reference_to
596
+ self.attn1 = attn_cls(
597
+ query_dim=dim,
598
+ heads=n_heads,
599
+ dim_head=d_head,
600
+ dropout=dropout,
601
+ context_dim=context_dim
602
+ if (self.disable_self_attn or reference_to == "self")
603
+ else None,
604
+ backend=sdp_backend,
605
+ use_reference=reference_to == "self",
606
+ extra_linear=extra_linear,
607
+ ) # is a self-attention if not self.disable_self_attn
608
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
609
+ self.attn2 = attn_cls(
610
+ query_dim=dim,
611
+ context_dim=context_dim,
612
+ heads=n_heads,
613
+ dim_head=d_head,
614
+ dropout=dropout,
615
+ backend=sdp_backend,
616
+ use_reference=reference_to == "cross",
617
+ extra_linear=extra_linear,
618
+ ) # is self-attn if context is none
619
+ self.norm1 = nn.LayerNorm(dim)
620
+ self.norm2 = nn.LayerNorm(dim)
621
+ self.norm3 = nn.LayerNorm(dim)
622
+ self.checkpoint = checkpoint
623
+ if self.checkpoint:
624
+ logpy.debug(f"{self.__class__.__name__} is using checkpointing")
625
+
626
+ def forward(
627
+ self,
628
+ x,
629
+ context=None,
630
+ reference_context=None,
631
+ additional_tokens=None,
632
+ n_times_crossframe_attn_in_self=0,
633
+ skip_attention=None,
634
+ ):
635
+ kwargs = {"x": x}
636
+
637
+ if context is not None:
638
+ kwargs.update({"context": context})
639
+
640
+ if reference_context is not None:
641
+ kwargs.update({"reference_context": reference_context})
642
+
643
+ if additional_tokens is not None:
644
+ kwargs.update({"additional_tokens": additional_tokens})
645
+
646
+ if n_times_crossframe_attn_in_self:
647
+ kwargs.update(
648
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
649
+ )
650
+
651
+ # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
652
+ if self.checkpoint:
653
+ # inputs = {"x": x, "context": context}
654
+ return checkpoint(
655
+ self._forward, x, context, reference_context, None, 0, skip_attention
656
+ )
657
+ # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
658
+ else:
659
+ return self._forward(**kwargs)
660
+
661
+ def _forward(
662
+ self,
663
+ x,
664
+ context=None,
665
+ reference_context=None,
666
+ additional_tokens=None,
667
+ n_times_crossframe_attn_in_self=0,
668
+ skip_attention=None,
669
+ ):
670
+ self_context = reference_context if self.reference_to == "self" else context
671
+ # print(self.reference_to)
672
+ # print("context: ", context.shape if context is not None else None)
673
+ # print("reference_context: ", reference_context.shape if reference_context is not None else None)
674
+ # print("x: ", x.shape)
675
+
676
+ x = (
677
+ self.attn1(
678
+ self.norm1(x),
679
+ context=self_context
680
+ if (self.disable_self_attn or self.reference_to == "self")
681
+ else None,
682
+ additional_tokens=additional_tokens,
683
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
684
+ if not self.disable_self_attn
685
+ else 0,
686
+ skip_attention=skip_attention,
687
+ )
688
+ + x
689
+ )
690
+ cross_context = reference_context if self.reference_to == "cross" else context
691
+ x = (
692
+ self.attn2(
693
+ self.norm2(x),
694
+ context=cross_context,
695
+ additional_tokens=additional_tokens,
696
+ )
697
+ + x
698
+ )
699
+ x = self.ff(self.norm3(x)) + x
700
+ return x
701
+
702
+
703
+ class BasicTransformerSingleLayerBlock(nn.Module):
704
+ ATTENTION_MODES = {
705
+ "softmax": CrossAttention, # vanilla attention
706
+ "softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version
707
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
708
+ }
709
+
710
+ def __init__(
711
+ self,
712
+ dim,
713
+ n_heads,
714
+ d_head,
715
+ dropout=0.0,
716
+ context_dim=None,
717
+ gated_ff=True,
718
+ checkpoint=True,
719
+ attn_mode="softmax",
720
+ ):
721
+ super().__init__()
722
+ assert attn_mode in self.ATTENTION_MODES
723
+ attn_cls = self.ATTENTION_MODES[attn_mode]
724
+ self.attn1 = attn_cls(
725
+ query_dim=dim,
726
+ heads=n_heads,
727
+ dim_head=d_head,
728
+ dropout=dropout,
729
+ context_dim=context_dim,
730
+ )
731
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
732
+ self.norm1 = nn.LayerNorm(dim)
733
+ self.norm2 = nn.LayerNorm(dim)
734
+ self.checkpoint = checkpoint
735
+
736
+ def forward(self, x, context=None):
737
+ # inputs = {"x": x, "context": context}
738
+ # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
739
+ return checkpoint(self._forward, x, context)
740
+
741
+ def _forward(self, x, context=None):
742
+ x = self.attn1(self.norm1(x), context=context) + x
743
+ x = self.ff(self.norm2(x)) + x
744
+ return x
745
+
746
+
747
+ class SpatialTransformer(nn.Module):
748
+ """
749
+ Transformer block for image-like data.
750
+ First, project the input (aka embedding)
751
+ and reshape to b, t, d.
752
+ Then apply standard transformer action.
753
+ Finally, reshape to image
754
+ NEW: use_linear for more efficiency instead of the 1x1 convs
755
+ """
756
+
757
+ def __init__(
758
+ self,
759
+ in_channels,
760
+ n_heads,
761
+ d_head,
762
+ depth=1,
763
+ dropout=0.0,
764
+ context_dim=None,
765
+ disable_self_attn=False,
766
+ use_linear=False,
767
+ attn_type="softmax",
768
+ use_checkpoint=True,
769
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
770
+ sdp_backend=None,
771
+ reference_to=None,
772
+ ):
773
+ super().__init__()
774
+ logpy.debug(
775
+ f"constructing {self.__class__.__name__} of depth {depth} w/ "
776
+ f"{in_channels} channels and {n_heads} heads."
777
+ )
778
+
779
+ if exists(context_dim) and not isinstance(context_dim, list):
780
+ context_dim = [context_dim]
781
+ if exists(context_dim) and isinstance(context_dim, list):
782
+ if depth != len(context_dim):
783
+ logpy.warn(
784
+ f"{self.__class__.__name__}: Found context dims "
785
+ f"{context_dim} of depth {len(context_dim)}, which does not "
786
+ f"match the specified 'depth' of {depth}. Setting context_dim "
787
+ f"to {depth * [context_dim[0]]} now."
788
+ )
789
+ # depth does not match context dims.
790
+ assert all(map(lambda x: x == context_dim[0], context_dim)), (
791
+ "need homogeneous context_dim to match depth automatically"
792
+ )
793
+ context_dim = depth * [context_dim[0]]
794
+ elif context_dim is None:
795
+ context_dim = [None] * depth
796
+ self.in_channels = in_channels
797
+ inner_dim = n_heads * d_head
798
+ self.norm = Normalize(in_channels)
799
+ if not use_linear:
800
+ self.proj_in = nn.Conv2d(
801
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
802
+ )
803
+ else:
804
+ self.proj_in = nn.Linear(in_channels, inner_dim)
805
+
806
+ self.transformer_blocks = nn.ModuleList(
807
+ [
808
+ BasicTransformerBlock(
809
+ inner_dim,
810
+ n_heads,
811
+ d_head,
812
+ dropout=dropout,
813
+ context_dim=context_dim[d],
814
+ disable_self_attn=disable_self_attn,
815
+ attn_mode=attn_type,
816
+ checkpoint=use_checkpoint,
817
+ sdp_backend=sdp_backend,
818
+ reference_to=reference_to,
819
+ )
820
+ for d in range(depth)
821
+ ]
822
+ )
823
+ if not use_linear:
824
+ self.proj_out = zero_module(
825
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
826
+ )
827
+ else:
828
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
829
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
830
+ self.use_linear = use_linear
831
+
832
+ def forward(self, x, context=None, skip_attention=None):
833
+ # note: if no context is given, cross-attention defaults to self-attention
834
+ if not isinstance(context, list):
835
+ context = [context]
836
+ b, c, h, w = x.shape
837
+ x_in = x
838
+ x = self.norm(x)
839
+ if not self.use_linear:
840
+ x = self.proj_in(x)
841
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
842
+ if self.use_linear:
843
+ x = self.proj_in(x)
844
+ for i, block in enumerate(self.transformer_blocks):
845
+ if i > 0 and len(context) == 1:
846
+ i = 0 # use same context for each block
847
+ x = block(x, context=context[i], skip_attention=skip_attention)
848
+ if self.use_linear:
849
+ x = self.proj_out(x)
850
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
851
+ if not self.use_linear:
852
+ x = self.proj_out(x)
853
+ return x + x_in
854
+
855
+
856
+ class SimpleTransformer(nn.Module):
857
+ def __init__(
858
+ self,
859
+ dim: int,
860
+ depth: int,
861
+ heads: int,
862
+ dim_head: int,
863
+ context_dim: Optional[int] = None,
864
+ dropout: float = 0.0,
865
+ checkpoint: bool = True,
866
+ ):
867
+ super().__init__()
868
+ self.layers = nn.ModuleList([])
869
+ for _ in range(depth):
870
+ self.layers.append(
871
+ BasicTransformerBlock(
872
+ dim,
873
+ heads,
874
+ dim_head,
875
+ dropout=dropout,
876
+ context_dim=context_dim,
877
+ attn_mode="softmax-xformers",
878
+ checkpoint=checkpoint,
879
+ )
880
+ )
881
+
882
+ def forward(
883
+ self,
884
+ x: torch.Tensor,
885
+ context: Optional[torch.Tensor] = None,
886
+ ) -> torch.Tensor:
887
+ for layer in self.layers:
888
+ x = layer(x, context)
889
+ return x