Spaces:
Runtime error
Runtime error
Antoni Bigata
commited on
Commit
Β·
b5ce381
1
Parent(s):
17d618b
first commit
Browse filesThis view is limited to 50 files because it contains too many changes. Β
See raw diff
- WavLM.py +854 -0
- WavLM_modules.py +765 -0
- __pycache__/WavLM.cpython-311.pyc +0 -0
- __pycache__/WavLM_modules.cpython-311.pyc +0 -0
- __pycache__/data_utils.cpython-311.pyc +0 -0
- __pycache__/dino_game.cpython-311.pyc +0 -0
- __pycache__/inference_functions.cpython-311.pyc +0 -0
- __pycache__/landmarks_extractor.cpython-311.pyc +0 -0
- __pycache__/utils.cpython-311.pyc +0 -0
- __pycache__/vae_wrapper.cpython-311.pyc +0 -0
- __pycache__/wordle_game.cpython-311.pyc +0 -0
- app.py +978 -0
- data_utils.py +635 -0
- inference_functions.py +493 -0
- landmarks_extractor.py +35 -0
- sgm/__init__.py +4 -0
- sgm/__pycache__/__init__.cpython-311.pyc +0 -0
- sgm/__pycache__/lr_scheduler.cpython-311.pyc +0 -0
- sgm/__pycache__/util.cpython-311.pyc +0 -0
- sgm/callbacks/__pycache__/video_logger.cpython-311.pyc +0 -0
- sgm/callbacks/custom_ddp.py +10 -0
- sgm/callbacks/image_logger.py +193 -0
- sgm/callbacks/setup_callback.py +86 -0
- sgm/callbacks/video_logger.py +294 -0
- sgm/data/__init__.py +1 -0
- sgm/data/__pycache__/__init__.cpython-311.pyc +0 -0
- sgm/data/__pycache__/data_utils.cpython-311.pyc +0 -0
- sgm/data/__pycache__/mask.cpython-311.pyc +0 -0
- sgm/data/__pycache__/video_datamodule_latent.cpython-311.pyc +0 -0
- sgm/data/__pycache__/video_dataset_latent.cpython-311.pyc +0 -0
- sgm/data/data_utils.py +561 -0
- sgm/data/dataset.py +80 -0
- sgm/data/mask.py +525 -0
- sgm/data/video_datamodule_latent.py +138 -0
- sgm/data/video_dataset_latent.py +780 -0
- sgm/inference/api.py +385 -0
- sgm/inference/helpers.py +305 -0
- sgm/lr_scheduler.py +135 -0
- sgm/models/__init__.py +2 -0
- sgm/models/__pycache__/__init__.cpython-311.pyc +0 -0
- sgm/models/__pycache__/autoencoder.cpython-311.pyc +0 -0
- sgm/models/__pycache__/diffusion.cpython-311.pyc +0 -0
- sgm/models/autoencoder.py +615 -0
- sgm/models/diffusion.py +747 -0
- sgm/modules/__init__.py +6 -0
- sgm/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- sgm/modules/__pycache__/attention.cpython-311.pyc +0 -0
- sgm/modules/__pycache__/ema.cpython-311.pyc +0 -0
- sgm/modules/__pycache__/video_attention.cpython-311.pyc +0 -0
- sgm/modules/attention.py +889 -0
WavLM.py
ADDED
@@ -0,0 +1,854 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
4 |
+
# Copyright (c) 2021 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import math
|
11 |
+
import logging
|
12 |
+
from typing import List, Optional, Tuple
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch.nn import LayerNorm
|
20 |
+
from einops import rearrange
|
21 |
+
import requests
|
22 |
+
from clint.textui import progress
|
23 |
+
import os
|
24 |
+
from WavLM_modules import (
|
25 |
+
Fp32GroupNorm,
|
26 |
+
Fp32LayerNorm,
|
27 |
+
GradMultiply,
|
28 |
+
MultiheadAttention,
|
29 |
+
SamePad,
|
30 |
+
init_bert_params,
|
31 |
+
get_activation_fn,
|
32 |
+
TransposeLast,
|
33 |
+
GLU_Linear,
|
34 |
+
)
|
35 |
+
|
36 |
+
logger = logging.getLogger(__name__)
|
37 |
+
|
38 |
+
|
39 |
+
class WavLM_wrapper(nn.Module):
|
40 |
+
def __init__(
|
41 |
+
self, model_size="Base+", feed_as_frames=True, merge_type="cat", model_path=None
|
42 |
+
):
|
43 |
+
super().__init__()
|
44 |
+
assert model_size in ["Base+", "Large"]
|
45 |
+
if model_path is None:
|
46 |
+
model_path = os.path.join(
|
47 |
+
os.path.dirname(__file__), f"WavLM-{model_size}.pt"
|
48 |
+
)
|
49 |
+
if not os.path.exists(model_path):
|
50 |
+
self.download_model(model_path, model_size)
|
51 |
+
checkpoint = torch.load(model_path)
|
52 |
+
cfg = WavLMConfig(checkpoint["cfg"])
|
53 |
+
self.cfg = cfg
|
54 |
+
self.model = WavLM(cfg)
|
55 |
+
self.model.load_state_dict(checkpoint["model"])
|
56 |
+
self.model.eval()
|
57 |
+
for param in self.model.parameters():
|
58 |
+
param.requires_grad = False
|
59 |
+
self.code_size = 768 * 2 if merge_type == "cat" else 768
|
60 |
+
self.merge_type = merge_type
|
61 |
+
self.feed_as_frames = feed_as_frames
|
62 |
+
|
63 |
+
def download_model(self, out_path, size: str = "Base+"):
|
64 |
+
print("Downloading model...")
|
65 |
+
if size == "Base+":
|
66 |
+
url = "https://valle.blob.core.windows.net/share/wavlm/WavLM-Base+.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D"
|
67 |
+
else:
|
68 |
+
url = "https://valle.blob.core.windows.net/share/wavlm/WavLM-Large.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D"
|
69 |
+
r = requests.get(url, allow_redirects=True, stream=True)
|
70 |
+
with open(out_path, "wb") as f:
|
71 |
+
total_length = int(r.headers.get("content-length"))
|
72 |
+
for chunk in progress.bar(
|
73 |
+
r.iter_content(chunk_size=1024), expected_size=(total_length / 1024) + 1
|
74 |
+
):
|
75 |
+
if chunk:
|
76 |
+
f.write(chunk)
|
77 |
+
f.flush()
|
78 |
+
print("Model downloaded to %s" % out_path)
|
79 |
+
|
80 |
+
def forward(self, x):
|
81 |
+
"""
|
82 |
+
Args:
|
83 |
+
x: (batch, n_frames, audio_features)
|
84 |
+
"""
|
85 |
+
T = x.shape[1]
|
86 |
+
|
87 |
+
if self.feed_as_frames:
|
88 |
+
x = rearrange(x, "b f d -> (b f) d")
|
89 |
+
else:
|
90 |
+
x = rearrange(x, "b ... -> b (...)")
|
91 |
+
|
92 |
+
if self.cfg.normalize:
|
93 |
+
x = torch.nn.functional.layer_norm(x, x.shape)
|
94 |
+
|
95 |
+
x = self.model.extract_features(x)[0] # B, new_features, C
|
96 |
+
if self.feed_as_frames:
|
97 |
+
x = rearrange(x, "(b f) d c -> b f d c", f=T)
|
98 |
+
else:
|
99 |
+
x = torch.nn.functional.interpolate(
|
100 |
+
x.permute(0, 2, 1), T * 2, mode="nearest"
|
101 |
+
)
|
102 |
+
x = rearrange(x, "b c (f d) -> b f d c", d=2)
|
103 |
+
|
104 |
+
if self.merge_type == "cat":
|
105 |
+
if x.dim() == 3:
|
106 |
+
return rearrange(x, "b d c -> b (d c)")
|
107 |
+
return rearrange(x, "b f d c -> b f (d c)")
|
108 |
+
elif self.merge_type == "sum":
|
109 |
+
return x.sum(dim=-2)
|
110 |
+
elif self.merge_type == "mean":
|
111 |
+
return x.mean(dim=-2)
|
112 |
+
elif self.merge_type == "None":
|
113 |
+
return x
|
114 |
+
else:
|
115 |
+
raise NotImplementedError
|
116 |
+
|
117 |
+
|
118 |
+
def compute_mask_indices(
|
119 |
+
shape: Tuple[int, int],
|
120 |
+
padding_mask: Optional[torch.Tensor],
|
121 |
+
mask_prob: float,
|
122 |
+
mask_length: int,
|
123 |
+
mask_type: str = "static",
|
124 |
+
mask_other: float = 0.0,
|
125 |
+
min_masks: int = 0,
|
126 |
+
no_overlap: bool = False,
|
127 |
+
min_space: int = 0,
|
128 |
+
) -> np.ndarray:
|
129 |
+
"""
|
130 |
+
Computes random mask spans for a given shape
|
131 |
+
Args:
|
132 |
+
shape: the the shape for which to compute masks.
|
133 |
+
should be of size 2 where first element is batch size and 2nd is timesteps
|
134 |
+
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
135 |
+
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
136 |
+
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
137 |
+
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
138 |
+
mask_type: how to compute mask lengths
|
139 |
+
static = fixed size
|
140 |
+
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
141 |
+
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
142 |
+
poisson = sample from possion distribution with lambda = mask length
|
143 |
+
min_masks: minimum number of masked spans
|
144 |
+
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
145 |
+
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
146 |
+
"""
|
147 |
+
|
148 |
+
bsz, all_sz = shape
|
149 |
+
mask = np.full((bsz, all_sz), False)
|
150 |
+
|
151 |
+
all_num_mask = int(
|
152 |
+
# add a random number for probabilistic rounding
|
153 |
+
mask_prob * all_sz / float(mask_length) + np.random.rand()
|
154 |
+
)
|
155 |
+
|
156 |
+
all_num_mask = max(min_masks, all_num_mask)
|
157 |
+
|
158 |
+
mask_idcs = []
|
159 |
+
for i in range(bsz):
|
160 |
+
if padding_mask is not None:
|
161 |
+
sz = all_sz - padding_mask[i].long().sum().item()
|
162 |
+
num_mask = int(
|
163 |
+
# add a random number for probabilistic rounding
|
164 |
+
mask_prob * sz / float(mask_length) + np.random.rand()
|
165 |
+
)
|
166 |
+
num_mask = max(min_masks, num_mask)
|
167 |
+
else:
|
168 |
+
sz = all_sz
|
169 |
+
num_mask = all_num_mask
|
170 |
+
|
171 |
+
if mask_type == "static":
|
172 |
+
lengths = np.full(num_mask, mask_length)
|
173 |
+
elif mask_type == "uniform":
|
174 |
+
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
175 |
+
elif mask_type == "normal":
|
176 |
+
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
|
177 |
+
lengths = [max(1, int(round(x))) for x in lengths]
|
178 |
+
elif mask_type == "poisson":
|
179 |
+
lengths = np.random.poisson(mask_length, size=num_mask)
|
180 |
+
lengths = [int(round(x)) for x in lengths]
|
181 |
+
else:
|
182 |
+
raise Exception("unknown mask selection " + mask_type)
|
183 |
+
|
184 |
+
if sum(lengths) == 0:
|
185 |
+
lengths[0] = min(mask_length, sz - 1)
|
186 |
+
|
187 |
+
if no_overlap:
|
188 |
+
mask_idc = []
|
189 |
+
|
190 |
+
def arrange(s, e, length, keep_length):
|
191 |
+
span_start = np.random.randint(s, e - length)
|
192 |
+
mask_idc.extend(span_start + i for i in range(length))
|
193 |
+
|
194 |
+
new_parts = []
|
195 |
+
if span_start - s - min_space >= keep_length:
|
196 |
+
new_parts.append((s, span_start - min_space + 1))
|
197 |
+
if e - span_start - keep_length - min_space > keep_length:
|
198 |
+
new_parts.append((span_start + length + min_space, e))
|
199 |
+
return new_parts
|
200 |
+
|
201 |
+
parts = [(0, sz)]
|
202 |
+
min_length = min(lengths)
|
203 |
+
for length in sorted(lengths, reverse=True):
|
204 |
+
lens = np.fromiter(
|
205 |
+
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
206 |
+
np.int,
|
207 |
+
)
|
208 |
+
l_sum = np.sum(lens)
|
209 |
+
if l_sum == 0:
|
210 |
+
break
|
211 |
+
probs = lens / np.sum(lens)
|
212 |
+
c = np.random.choice(len(parts), p=probs)
|
213 |
+
s, e = parts.pop(c)
|
214 |
+
parts.extend(arrange(s, e, length, min_length))
|
215 |
+
mask_idc = np.asarray(mask_idc)
|
216 |
+
else:
|
217 |
+
min_len = min(lengths)
|
218 |
+
if sz - min_len <= num_mask:
|
219 |
+
min_len = sz - num_mask - 1
|
220 |
+
|
221 |
+
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
222 |
+
|
223 |
+
mask_idc = np.asarray(
|
224 |
+
[
|
225 |
+
mask_idc[j] + offset
|
226 |
+
for j in range(len(mask_idc))
|
227 |
+
for offset in range(lengths[j])
|
228 |
+
]
|
229 |
+
)
|
230 |
+
|
231 |
+
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
232 |
+
|
233 |
+
min_len = min([len(m) for m in mask_idcs])
|
234 |
+
for i, mask_idc in enumerate(mask_idcs):
|
235 |
+
if len(mask_idc) > min_len:
|
236 |
+
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
237 |
+
mask[i, mask_idc] = True
|
238 |
+
|
239 |
+
return mask
|
240 |
+
|
241 |
+
|
242 |
+
class WavLMConfig:
|
243 |
+
def __init__(self, cfg=None):
|
244 |
+
self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
|
245 |
+
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
246 |
+
|
247 |
+
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
248 |
+
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
249 |
+
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
250 |
+
self.activation_fn: str = "gelu" # activation function to use
|
251 |
+
|
252 |
+
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
253 |
+
self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
|
254 |
+
self.conv_bias: bool = False # include bias in conv encoder
|
255 |
+
self.feature_grad_mult: float = (
|
256 |
+
1.0 # multiply feature extractor var grads by this
|
257 |
+
)
|
258 |
+
|
259 |
+
self.normalize: bool = (
|
260 |
+
False # normalize input to have 0 mean and unit variance during training
|
261 |
+
)
|
262 |
+
|
263 |
+
# dropouts
|
264 |
+
self.dropout: float = 0.1 # dropout probability for the transformer
|
265 |
+
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
266 |
+
self.activation_dropout: float = (
|
267 |
+
0.0 # dropout probability after activation in FFN
|
268 |
+
)
|
269 |
+
self.encoder_layerdrop: float = (
|
270 |
+
0.0 # probability of dropping a tarnsformer layer
|
271 |
+
)
|
272 |
+
self.dropout_input: float = (
|
273 |
+
0.0 # dropout to apply to the input (after feat extr)
|
274 |
+
)
|
275 |
+
self.dropout_features: float = (
|
276 |
+
0.0 # dropout to apply to the features (after feat extr)
|
277 |
+
)
|
278 |
+
|
279 |
+
# masking
|
280 |
+
self.mask_length: int = 10 # mask length
|
281 |
+
self.mask_prob: float = 0.65 # probability of replacing a token with mask
|
282 |
+
self.mask_selection: str = "static" # how to choose mask length
|
283 |
+
self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh
|
284 |
+
self.no_mask_overlap: bool = False # whether to allow masks to overlap
|
285 |
+
self.mask_min_space: int = (
|
286 |
+
1 # min space between spans (if no overlap is enabled)
|
287 |
+
)
|
288 |
+
|
289 |
+
# channel masking
|
290 |
+
self.mask_channel_length: int = 10 # length of the mask for features (channels)
|
291 |
+
self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
|
292 |
+
self.mask_channel_selection: str = (
|
293 |
+
"static" # how to choose mask length for channel masking
|
294 |
+
)
|
295 |
+
self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
|
296 |
+
self.no_mask_channel_overlap: bool = (
|
297 |
+
False # whether to allow channel masks to overlap
|
298 |
+
)
|
299 |
+
self.mask_channel_min_space: int = (
|
300 |
+
1 # min space between spans (if no overlap is enabled)
|
301 |
+
)
|
302 |
+
|
303 |
+
# positional embeddings
|
304 |
+
self.conv_pos: int = (
|
305 |
+
128 # number of filters for convolutional positional embeddings
|
306 |
+
)
|
307 |
+
self.conv_pos_groups: int = (
|
308 |
+
16 # number of groups for convolutional positional embedding
|
309 |
+
)
|
310 |
+
|
311 |
+
# relative position embedding
|
312 |
+
self.relative_position_embedding: bool = (
|
313 |
+
False # apply relative position embedding
|
314 |
+
)
|
315 |
+
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
316 |
+
self.max_distance: int = (
|
317 |
+
1280 # maximum distance for relative position embedding
|
318 |
+
)
|
319 |
+
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
320 |
+
|
321 |
+
if cfg is not None:
|
322 |
+
self.update(cfg)
|
323 |
+
|
324 |
+
def update(self, cfg: dict):
|
325 |
+
self.__dict__.update(cfg)
|
326 |
+
|
327 |
+
|
328 |
+
class WavLM(nn.Module):
|
329 |
+
def __init__(
|
330 |
+
self,
|
331 |
+
cfg: WavLMConfig,
|
332 |
+
) -> None:
|
333 |
+
super().__init__()
|
334 |
+
logger.info(f"WavLM Config: {cfg.__dict__}")
|
335 |
+
|
336 |
+
self.cfg = cfg
|
337 |
+
feature_enc_layers = eval(cfg.conv_feature_layers)
|
338 |
+
self.embed = feature_enc_layers[-1][0]
|
339 |
+
|
340 |
+
self.feature_extractor = ConvFeatureExtractionModel(
|
341 |
+
conv_layers=feature_enc_layers,
|
342 |
+
dropout=0.0,
|
343 |
+
mode=cfg.extractor_mode,
|
344 |
+
conv_bias=cfg.conv_bias,
|
345 |
+
)
|
346 |
+
|
347 |
+
self.post_extract_proj = (
|
348 |
+
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
349 |
+
if self.embed != cfg.encoder_embed_dim
|
350 |
+
else None
|
351 |
+
)
|
352 |
+
|
353 |
+
self.mask_prob = cfg.mask_prob
|
354 |
+
self.mask_selection = cfg.mask_selection
|
355 |
+
self.mask_other = cfg.mask_other
|
356 |
+
self.mask_length = cfg.mask_length
|
357 |
+
self.no_mask_overlap = cfg.no_mask_overlap
|
358 |
+
self.mask_min_space = cfg.mask_min_space
|
359 |
+
|
360 |
+
self.mask_channel_prob = cfg.mask_channel_prob
|
361 |
+
self.mask_channel_selection = cfg.mask_channel_selection
|
362 |
+
self.mask_channel_other = cfg.mask_channel_other
|
363 |
+
self.mask_channel_length = cfg.mask_channel_length
|
364 |
+
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
365 |
+
self.mask_channel_min_space = cfg.mask_channel_min_space
|
366 |
+
|
367 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
368 |
+
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
369 |
+
|
370 |
+
self.feature_grad_mult = cfg.feature_grad_mult
|
371 |
+
|
372 |
+
self.mask_emb = nn.Parameter(
|
373 |
+
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
|
374 |
+
)
|
375 |
+
|
376 |
+
self.encoder = TransformerEncoder(cfg)
|
377 |
+
self.layer_norm = LayerNorm(self.embed)
|
378 |
+
|
379 |
+
def apply_mask(self, x, padding_mask):
|
380 |
+
B, T, C = x.shape
|
381 |
+
if self.mask_prob > 0:
|
382 |
+
mask_indices = compute_mask_indices(
|
383 |
+
(B, T),
|
384 |
+
padding_mask,
|
385 |
+
self.mask_prob,
|
386 |
+
self.mask_length,
|
387 |
+
self.mask_selection,
|
388 |
+
self.mask_other,
|
389 |
+
min_masks=2,
|
390 |
+
no_overlap=self.no_mask_overlap,
|
391 |
+
min_space=self.mask_min_space,
|
392 |
+
)
|
393 |
+
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
394 |
+
x[mask_indices] = self.mask_emb
|
395 |
+
else:
|
396 |
+
mask_indices = None
|
397 |
+
|
398 |
+
if self.mask_channel_prob > 0:
|
399 |
+
mask_channel_indices = compute_mask_indices(
|
400 |
+
(B, C),
|
401 |
+
None,
|
402 |
+
self.mask_channel_prob,
|
403 |
+
self.mask_channel_length,
|
404 |
+
self.mask_channel_selection,
|
405 |
+
self.mask_channel_other,
|
406 |
+
no_overlap=self.no_mask_channel_overlap,
|
407 |
+
min_space=self.mask_channel_min_space,
|
408 |
+
)
|
409 |
+
mask_channel_indices = (
|
410 |
+
torch.from_numpy(mask_channel_indices)
|
411 |
+
.to(x.device)
|
412 |
+
.unsqueeze(1)
|
413 |
+
.expand(-1, T, -1)
|
414 |
+
)
|
415 |
+
x[mask_channel_indices] = 0
|
416 |
+
|
417 |
+
return x, mask_indices
|
418 |
+
|
419 |
+
def forward_padding_mask(
|
420 |
+
self,
|
421 |
+
features: torch.Tensor,
|
422 |
+
padding_mask: torch.Tensor,
|
423 |
+
) -> torch.Tensor:
|
424 |
+
extra = padding_mask.size(1) % features.size(1)
|
425 |
+
if extra > 0:
|
426 |
+
padding_mask = padding_mask[:, :-extra]
|
427 |
+
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
|
428 |
+
padding_mask = padding_mask.all(-1)
|
429 |
+
return padding_mask
|
430 |
+
|
431 |
+
def extract_features(
|
432 |
+
self,
|
433 |
+
source: torch.Tensor,
|
434 |
+
padding_mask: Optional[torch.Tensor] = None,
|
435 |
+
mask: bool = False,
|
436 |
+
ret_conv: bool = False,
|
437 |
+
output_layer: Optional[int] = None,
|
438 |
+
ret_layer_results: bool = False,
|
439 |
+
):
|
440 |
+
if self.feature_grad_mult > 0:
|
441 |
+
features = self.feature_extractor(source)
|
442 |
+
if self.feature_grad_mult != 1.0:
|
443 |
+
features = GradMultiply.apply(features, self.feature_grad_mult)
|
444 |
+
else:
|
445 |
+
with torch.no_grad():
|
446 |
+
features = self.feature_extractor(source)
|
447 |
+
|
448 |
+
features = features.transpose(1, 2)
|
449 |
+
features = self.layer_norm(features)
|
450 |
+
|
451 |
+
if padding_mask is not None:
|
452 |
+
padding_mask = self.forward_padding_mask(features, padding_mask)
|
453 |
+
|
454 |
+
if self.post_extract_proj is not None:
|
455 |
+
features = self.post_extract_proj(features)
|
456 |
+
|
457 |
+
features = self.dropout_input(features)
|
458 |
+
|
459 |
+
if mask:
|
460 |
+
x, mask_indices = self.apply_mask(features, padding_mask)
|
461 |
+
else:
|
462 |
+
x = features
|
463 |
+
|
464 |
+
# feature: (B, T, D), float
|
465 |
+
# target: (B, T), long
|
466 |
+
# x: (B, T, D), float
|
467 |
+
# padding_mask: (B, T), bool
|
468 |
+
# mask_indices: (B, T), bool
|
469 |
+
x, layer_results = self.encoder(
|
470 |
+
x,
|
471 |
+
padding_mask=padding_mask,
|
472 |
+
layer=None if output_layer is None else output_layer - 1,
|
473 |
+
)
|
474 |
+
|
475 |
+
res = {
|
476 |
+
"x": x,
|
477 |
+
"padding_mask": padding_mask,
|
478 |
+
"features": features,
|
479 |
+
"layer_results": layer_results,
|
480 |
+
}
|
481 |
+
|
482 |
+
feature = res["features"] if ret_conv else res["x"]
|
483 |
+
if ret_layer_results:
|
484 |
+
feature = (feature, res["layer_results"])
|
485 |
+
return feature, res["padding_mask"]
|
486 |
+
|
487 |
+
|
488 |
+
class ConvFeatureExtractionModel(nn.Module):
|
489 |
+
def __init__(
|
490 |
+
self,
|
491 |
+
conv_layers: List[Tuple[int, int, int]],
|
492 |
+
dropout: float = 0.0,
|
493 |
+
mode: str = "default",
|
494 |
+
conv_bias: bool = False,
|
495 |
+
conv_type: str = "default",
|
496 |
+
):
|
497 |
+
super().__init__()
|
498 |
+
|
499 |
+
assert mode in {"default", "layer_norm"}
|
500 |
+
|
501 |
+
def block(
|
502 |
+
n_in,
|
503 |
+
n_out,
|
504 |
+
k,
|
505 |
+
stride,
|
506 |
+
is_layer_norm=False,
|
507 |
+
is_group_norm=False,
|
508 |
+
conv_bias=False,
|
509 |
+
):
|
510 |
+
def make_conv():
|
511 |
+
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
|
512 |
+
nn.init.kaiming_normal_(conv.weight)
|
513 |
+
return conv
|
514 |
+
|
515 |
+
assert not (is_layer_norm and is_group_norm), (
|
516 |
+
"layer norm and group norm are exclusive"
|
517 |
+
)
|
518 |
+
|
519 |
+
if is_layer_norm:
|
520 |
+
return nn.Sequential(
|
521 |
+
make_conv(),
|
522 |
+
nn.Dropout(p=dropout),
|
523 |
+
nn.Sequential(
|
524 |
+
TransposeLast(),
|
525 |
+
Fp32LayerNorm(dim, elementwise_affine=True),
|
526 |
+
TransposeLast(),
|
527 |
+
),
|
528 |
+
nn.GELU(),
|
529 |
+
)
|
530 |
+
elif is_group_norm:
|
531 |
+
return nn.Sequential(
|
532 |
+
make_conv(),
|
533 |
+
nn.Dropout(p=dropout),
|
534 |
+
Fp32GroupNorm(dim, dim, affine=True),
|
535 |
+
nn.GELU(),
|
536 |
+
)
|
537 |
+
else:
|
538 |
+
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
|
539 |
+
|
540 |
+
self.conv_type = conv_type
|
541 |
+
if self.conv_type == "default":
|
542 |
+
in_d = 1
|
543 |
+
self.conv_layers = nn.ModuleList()
|
544 |
+
for i, cl in enumerate(conv_layers):
|
545 |
+
assert len(cl) == 3, "invalid conv definition: " + str(cl)
|
546 |
+
(dim, k, stride) = cl
|
547 |
+
|
548 |
+
self.conv_layers.append(
|
549 |
+
block(
|
550 |
+
in_d,
|
551 |
+
dim,
|
552 |
+
k,
|
553 |
+
stride,
|
554 |
+
is_layer_norm=mode == "layer_norm",
|
555 |
+
is_group_norm=mode == "default" and i == 0,
|
556 |
+
conv_bias=conv_bias,
|
557 |
+
)
|
558 |
+
)
|
559 |
+
in_d = dim
|
560 |
+
elif self.conv_type == "conv2d":
|
561 |
+
in_d = 1
|
562 |
+
self.conv_layers = nn.ModuleList()
|
563 |
+
for i, cl in enumerate(conv_layers):
|
564 |
+
assert len(cl) == 3
|
565 |
+
(dim, k, stride) = cl
|
566 |
+
|
567 |
+
self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride))
|
568 |
+
self.conv_layers.append(torch.nn.ReLU())
|
569 |
+
in_d = dim
|
570 |
+
elif self.conv_type == "custom":
|
571 |
+
in_d = 1
|
572 |
+
idim = 80
|
573 |
+
self.conv_layers = nn.ModuleList()
|
574 |
+
for i, cl in enumerate(conv_layers):
|
575 |
+
assert len(cl) == 3
|
576 |
+
(dim, k, stride) = cl
|
577 |
+
self.conv_layers.append(
|
578 |
+
torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
|
579 |
+
)
|
580 |
+
self.conv_layers.append(torch.nn.LayerNorm([dim, idim]))
|
581 |
+
self.conv_layers.append(torch.nn.ReLU())
|
582 |
+
in_d = dim
|
583 |
+
if (i + 1) % 2 == 0:
|
584 |
+
self.conv_layers.append(
|
585 |
+
torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
586 |
+
)
|
587 |
+
idim = int(math.ceil(idim / 2))
|
588 |
+
else:
|
589 |
+
pass
|
590 |
+
|
591 |
+
def forward(self, x, mask=None):
|
592 |
+
# BxT -> BxCxT
|
593 |
+
x = x.unsqueeze(1)
|
594 |
+
if self.conv_type == "custom":
|
595 |
+
for conv in self.conv_layers:
|
596 |
+
if isinstance(conv, nn.LayerNorm):
|
597 |
+
x = x.transpose(1, 2)
|
598 |
+
x = conv(x).transpose(1, 2)
|
599 |
+
else:
|
600 |
+
x = conv(x)
|
601 |
+
x = x.transpose(2, 3).contiguous()
|
602 |
+
x = x.view(x.size(0), -1, x.size(-1))
|
603 |
+
else:
|
604 |
+
for conv in self.conv_layers:
|
605 |
+
x = conv(x)
|
606 |
+
if self.conv_type == "conv2d":
|
607 |
+
b, c, t, f = x.size()
|
608 |
+
x = x.transpose(2, 3).contiguous().view(b, c * f, t)
|
609 |
+
return x
|
610 |
+
|
611 |
+
|
612 |
+
class TransformerEncoder(nn.Module):
|
613 |
+
def __init__(self, args):
|
614 |
+
super().__init__()
|
615 |
+
|
616 |
+
self.dropout = args.dropout
|
617 |
+
self.embedding_dim = args.encoder_embed_dim
|
618 |
+
|
619 |
+
self.pos_conv = nn.Conv1d(
|
620 |
+
self.embedding_dim,
|
621 |
+
self.embedding_dim,
|
622 |
+
kernel_size=args.conv_pos,
|
623 |
+
padding=args.conv_pos // 2,
|
624 |
+
groups=args.conv_pos_groups,
|
625 |
+
)
|
626 |
+
dropout = 0
|
627 |
+
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
|
628 |
+
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
629 |
+
nn.init.constant_(self.pos_conv.bias, 0)
|
630 |
+
|
631 |
+
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
|
632 |
+
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
633 |
+
|
634 |
+
if hasattr(args, "relative_position_embedding"):
|
635 |
+
self.relative_position_embedding = args.relative_position_embedding
|
636 |
+
self.num_buckets = args.num_buckets
|
637 |
+
self.max_distance = args.max_distance
|
638 |
+
else:
|
639 |
+
self.relative_position_embedding = False
|
640 |
+
self.num_buckets = 0
|
641 |
+
self.max_distance = 0
|
642 |
+
|
643 |
+
self.layers = nn.ModuleList(
|
644 |
+
[
|
645 |
+
TransformerSentenceEncoderLayer(
|
646 |
+
embedding_dim=self.embedding_dim,
|
647 |
+
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
648 |
+
num_attention_heads=args.encoder_attention_heads,
|
649 |
+
dropout=self.dropout,
|
650 |
+
attention_dropout=args.attention_dropout,
|
651 |
+
activation_dropout=args.activation_dropout,
|
652 |
+
activation_fn=args.activation_fn,
|
653 |
+
layer_norm_first=args.layer_norm_first,
|
654 |
+
has_relative_attention_bias=(
|
655 |
+
self.relative_position_embedding and i == 0
|
656 |
+
),
|
657 |
+
num_buckets=self.num_buckets,
|
658 |
+
max_distance=self.max_distance,
|
659 |
+
gru_rel_pos=args.gru_rel_pos,
|
660 |
+
)
|
661 |
+
for i in range(args.encoder_layers)
|
662 |
+
]
|
663 |
+
)
|
664 |
+
|
665 |
+
self.layer_norm_first = args.layer_norm_first
|
666 |
+
self.layer_norm = LayerNorm(self.embedding_dim)
|
667 |
+
self.layerdrop = args.encoder_layerdrop
|
668 |
+
|
669 |
+
self.apply(init_bert_params)
|
670 |
+
|
671 |
+
def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
|
672 |
+
x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
|
673 |
+
|
674 |
+
if self.layer_norm_first and layer is None:
|
675 |
+
x = self.layer_norm(x)
|
676 |
+
|
677 |
+
return x, layer_results
|
678 |
+
|
679 |
+
def extract_features(
|
680 |
+
self, x, padding_mask=None, streaming_mask=None, tgt_layer=None
|
681 |
+
):
|
682 |
+
if padding_mask is not None:
|
683 |
+
x[padding_mask] = 0
|
684 |
+
|
685 |
+
x_conv = self.pos_conv(x.transpose(1, 2))
|
686 |
+
x_conv = x_conv.transpose(1, 2)
|
687 |
+
x += x_conv
|
688 |
+
|
689 |
+
if not self.layer_norm_first:
|
690 |
+
x = self.layer_norm(x)
|
691 |
+
|
692 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
693 |
+
|
694 |
+
# B x T x C -> T x B x C
|
695 |
+
x = x.transpose(0, 1)
|
696 |
+
|
697 |
+
layer_results = []
|
698 |
+
z = None
|
699 |
+
if tgt_layer is not None:
|
700 |
+
layer_results.append((x, z))
|
701 |
+
r = None
|
702 |
+
pos_bias = None
|
703 |
+
for i, layer in enumerate(self.layers):
|
704 |
+
dropout_probability = np.random.random()
|
705 |
+
if not self.training or (dropout_probability > self.layerdrop):
|
706 |
+
x, z, pos_bias = layer(
|
707 |
+
x,
|
708 |
+
self_attn_padding_mask=padding_mask,
|
709 |
+
need_weights=False,
|
710 |
+
self_attn_mask=streaming_mask,
|
711 |
+
pos_bias=pos_bias,
|
712 |
+
)
|
713 |
+
if tgt_layer is not None:
|
714 |
+
layer_results.append((x, z))
|
715 |
+
if i == tgt_layer:
|
716 |
+
r = x
|
717 |
+
break
|
718 |
+
|
719 |
+
if r is not None:
|
720 |
+
x = r
|
721 |
+
|
722 |
+
# T x B x C -> B x T x C
|
723 |
+
x = x.transpose(0, 1)
|
724 |
+
|
725 |
+
return x, layer_results
|
726 |
+
|
727 |
+
|
728 |
+
class TransformerSentenceEncoderLayer(nn.Module):
|
729 |
+
"""
|
730 |
+
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
|
731 |
+
models.
|
732 |
+
"""
|
733 |
+
|
734 |
+
def __init__(
|
735 |
+
self,
|
736 |
+
embedding_dim: float = 768,
|
737 |
+
ffn_embedding_dim: float = 3072,
|
738 |
+
num_attention_heads: float = 8,
|
739 |
+
dropout: float = 0.1,
|
740 |
+
attention_dropout: float = 0.1,
|
741 |
+
activation_dropout: float = 0.1,
|
742 |
+
activation_fn: str = "relu",
|
743 |
+
layer_norm_first: bool = False,
|
744 |
+
has_relative_attention_bias: bool = False,
|
745 |
+
num_buckets: int = 0,
|
746 |
+
max_distance: int = 0,
|
747 |
+
rescale_init: bool = False,
|
748 |
+
gru_rel_pos: bool = False,
|
749 |
+
) -> None:
|
750 |
+
super().__init__()
|
751 |
+
# Initialize parameters
|
752 |
+
self.embedding_dim = embedding_dim
|
753 |
+
self.dropout = dropout
|
754 |
+
self.activation_dropout = activation_dropout
|
755 |
+
|
756 |
+
# Initialize blocks
|
757 |
+
self.activation_name = activation_fn
|
758 |
+
self.activation_fn = get_activation_fn(activation_fn)
|
759 |
+
self.self_attn = MultiheadAttention(
|
760 |
+
self.embedding_dim,
|
761 |
+
num_attention_heads,
|
762 |
+
dropout=attention_dropout,
|
763 |
+
self_attention=True,
|
764 |
+
has_relative_attention_bias=has_relative_attention_bias,
|
765 |
+
num_buckets=num_buckets,
|
766 |
+
max_distance=max_distance,
|
767 |
+
rescale_init=rescale_init,
|
768 |
+
gru_rel_pos=gru_rel_pos,
|
769 |
+
)
|
770 |
+
|
771 |
+
self.dropout1 = nn.Dropout(dropout)
|
772 |
+
self.dropout2 = nn.Dropout(self.activation_dropout)
|
773 |
+
self.dropout3 = nn.Dropout(dropout)
|
774 |
+
|
775 |
+
self.layer_norm_first = layer_norm_first
|
776 |
+
|
777 |
+
# layer norm associated with the self attention layer
|
778 |
+
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
779 |
+
|
780 |
+
if self.activation_name == "glu":
|
781 |
+
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
|
782 |
+
else:
|
783 |
+
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
784 |
+
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
785 |
+
|
786 |
+
# layer norm associated with the position wise feed-forward NN
|
787 |
+
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
788 |
+
|
789 |
+
def forward(
|
790 |
+
self,
|
791 |
+
x: torch.Tensor,
|
792 |
+
self_attn_mask: torch.Tensor = None,
|
793 |
+
self_attn_padding_mask: torch.Tensor = None,
|
794 |
+
need_weights: bool = False,
|
795 |
+
pos_bias=None,
|
796 |
+
):
|
797 |
+
"""
|
798 |
+
LayerNorm is applied either before or after the self-attention/ffn
|
799 |
+
modules similar to the original Transformer imlementation.
|
800 |
+
"""
|
801 |
+
residual = x
|
802 |
+
|
803 |
+
if self.layer_norm_first:
|
804 |
+
x = self.self_attn_layer_norm(x)
|
805 |
+
x, attn, pos_bias = self.self_attn(
|
806 |
+
query=x,
|
807 |
+
key=x,
|
808 |
+
value=x,
|
809 |
+
key_padding_mask=self_attn_padding_mask,
|
810 |
+
need_weights=False,
|
811 |
+
attn_mask=self_attn_mask,
|
812 |
+
position_bias=pos_bias,
|
813 |
+
)
|
814 |
+
x = self.dropout1(x)
|
815 |
+
x = residual + x
|
816 |
+
|
817 |
+
residual = x
|
818 |
+
x = self.final_layer_norm(x)
|
819 |
+
if self.activation_name == "glu":
|
820 |
+
x = self.fc1(x)
|
821 |
+
else:
|
822 |
+
x = self.activation_fn(self.fc1(x))
|
823 |
+
x = self.dropout2(x)
|
824 |
+
x = self.fc2(x)
|
825 |
+
x = self.dropout3(x)
|
826 |
+
x = residual + x
|
827 |
+
else:
|
828 |
+
x, attn, pos_bias = self.self_attn(
|
829 |
+
query=x,
|
830 |
+
key=x,
|
831 |
+
value=x,
|
832 |
+
key_padding_mask=self_attn_padding_mask,
|
833 |
+
need_weights=need_weights,
|
834 |
+
attn_mask=self_attn_mask,
|
835 |
+
position_bias=pos_bias,
|
836 |
+
)
|
837 |
+
|
838 |
+
x = self.dropout1(x)
|
839 |
+
x = residual + x
|
840 |
+
|
841 |
+
x = self.self_attn_layer_norm(x)
|
842 |
+
|
843 |
+
residual = x
|
844 |
+
if self.activation_name == "glu":
|
845 |
+
x = self.fc1(x)
|
846 |
+
else:
|
847 |
+
x = self.activation_fn(self.fc1(x))
|
848 |
+
x = self.dropout2(x)
|
849 |
+
x = self.fc2(x)
|
850 |
+
x = self.dropout3(x)
|
851 |
+
x = residual + x
|
852 |
+
x = self.final_layer_norm(x)
|
853 |
+
|
854 |
+
return x, attn, pos_bias
|
WavLM_modules.py
ADDED
@@ -0,0 +1,765 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
4 |
+
# Copyright (c) 2021 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import math
|
11 |
+
import warnings
|
12 |
+
from typing import Dict, Optional, Tuple
|
13 |
+
import torch
|
14 |
+
from torch import Tensor, nn
|
15 |
+
from torch.nn import Parameter
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
|
19 |
+
class TransposeLast(nn.Module):
|
20 |
+
def __init__(self, deconstruct_idx=None):
|
21 |
+
super().__init__()
|
22 |
+
self.deconstruct_idx = deconstruct_idx
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
if self.deconstruct_idx is not None:
|
26 |
+
x = x[self.deconstruct_idx]
|
27 |
+
return x.transpose(-2, -1)
|
28 |
+
|
29 |
+
|
30 |
+
class Fp32LayerNorm(nn.LayerNorm):
|
31 |
+
def __init__(self, *args, **kwargs):
|
32 |
+
super().__init__(*args, **kwargs)
|
33 |
+
|
34 |
+
def forward(self, input):
|
35 |
+
output = F.layer_norm(
|
36 |
+
input.float(),
|
37 |
+
self.normalized_shape,
|
38 |
+
self.weight.float() if self.weight is not None else None,
|
39 |
+
self.bias.float() if self.bias is not None else None,
|
40 |
+
self.eps,
|
41 |
+
)
|
42 |
+
return output.type_as(input)
|
43 |
+
|
44 |
+
|
45 |
+
class Fp32GroupNorm(nn.GroupNorm):
|
46 |
+
def __init__(self, *args, **kwargs):
|
47 |
+
super().__init__(*args, **kwargs)
|
48 |
+
|
49 |
+
def forward(self, input):
|
50 |
+
output = F.group_norm(
|
51 |
+
input.float(),
|
52 |
+
self.num_groups,
|
53 |
+
self.weight.float() if self.weight is not None else None,
|
54 |
+
self.bias.float() if self.bias is not None else None,
|
55 |
+
self.eps,
|
56 |
+
)
|
57 |
+
return output.type_as(input)
|
58 |
+
|
59 |
+
|
60 |
+
class GradMultiply(torch.autograd.Function):
|
61 |
+
@staticmethod
|
62 |
+
def forward(ctx, x, scale):
|
63 |
+
ctx.scale = scale
|
64 |
+
res = x.new(x)
|
65 |
+
return res
|
66 |
+
|
67 |
+
@staticmethod
|
68 |
+
def backward(ctx, grad):
|
69 |
+
return grad * ctx.scale, None
|
70 |
+
|
71 |
+
|
72 |
+
class SamePad(nn.Module):
|
73 |
+
def __init__(self, kernel_size, causal=False):
|
74 |
+
super().__init__()
|
75 |
+
if causal:
|
76 |
+
self.remove = kernel_size - 1
|
77 |
+
else:
|
78 |
+
self.remove = 1 if kernel_size % 2 == 0 else 0
|
79 |
+
|
80 |
+
def forward(self, x):
|
81 |
+
if self.remove > 0:
|
82 |
+
x = x[:, :, : -self.remove]
|
83 |
+
return x
|
84 |
+
|
85 |
+
|
86 |
+
class Swish(nn.Module):
|
87 |
+
"""Swish function"""
|
88 |
+
|
89 |
+
def __init__(self):
|
90 |
+
"""Construct an MultiHeadedAttention object."""
|
91 |
+
super(Swish, self).__init__()
|
92 |
+
self.act = torch.nn.Sigmoid()
|
93 |
+
|
94 |
+
def forward(self, x):
|
95 |
+
return x * self.act(x)
|
96 |
+
|
97 |
+
|
98 |
+
class GLU_Linear(nn.Module):
|
99 |
+
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
|
100 |
+
super(GLU_Linear, self).__init__()
|
101 |
+
|
102 |
+
self.glu_type = glu_type
|
103 |
+
self.output_dim = output_dim
|
104 |
+
|
105 |
+
if glu_type == "sigmoid":
|
106 |
+
self.glu_act = torch.nn.Sigmoid()
|
107 |
+
elif glu_type == "swish":
|
108 |
+
self.glu_act = Swish()
|
109 |
+
elif glu_type == "relu":
|
110 |
+
self.glu_act = torch.nn.ReLU()
|
111 |
+
elif glu_type == "gelu":
|
112 |
+
self.glu_act = torch.nn.GELU()
|
113 |
+
|
114 |
+
if bias_in_glu:
|
115 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
116 |
+
else:
|
117 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
118 |
+
|
119 |
+
def forward(self, x):
|
120 |
+
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
|
121 |
+
x = self.linear(x)
|
122 |
+
|
123 |
+
if self.glu_type == "bilinear":
|
124 |
+
x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2]
|
125 |
+
else:
|
126 |
+
x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2])
|
127 |
+
|
128 |
+
return x
|
129 |
+
|
130 |
+
|
131 |
+
def gelu_accurate(x):
|
132 |
+
if not hasattr(gelu_accurate, "_a"):
|
133 |
+
gelu_accurate._a = math.sqrt(2 / math.pi)
|
134 |
+
return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
135 |
+
|
136 |
+
|
137 |
+
def gelu(x: torch.Tensor) -> torch.Tensor:
|
138 |
+
return torch.nn.functional.gelu(x.float()).type_as(x)
|
139 |
+
|
140 |
+
|
141 |
+
def get_activation_fn(activation: str):
|
142 |
+
"""Returns the activation function corresponding to `activation`"""
|
143 |
+
|
144 |
+
if activation == "relu":
|
145 |
+
return F.relu
|
146 |
+
elif activation == "gelu":
|
147 |
+
return gelu
|
148 |
+
elif activation == "gelu_fast":
|
149 |
+
warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate")
|
150 |
+
return gelu_accurate
|
151 |
+
elif activation == "gelu_accurate":
|
152 |
+
return gelu_accurate
|
153 |
+
elif activation == "tanh":
|
154 |
+
return torch.tanh
|
155 |
+
elif activation == "linear":
|
156 |
+
return lambda x: x
|
157 |
+
elif activation == "glu":
|
158 |
+
return lambda x: x
|
159 |
+
else:
|
160 |
+
raise RuntimeError("--activation-fn {} not supported".format(activation))
|
161 |
+
|
162 |
+
|
163 |
+
def init_bert_params(module):
|
164 |
+
"""
|
165 |
+
Initialize the weights specific to the BERT Model.
|
166 |
+
This overrides the default initializations depending on the specified arguments.
|
167 |
+
1. If normal_init_linear_weights is set then weights of linear
|
168 |
+
layer will be initialized using the normal distribution and
|
169 |
+
bais will be set to the specified value.
|
170 |
+
2. If normal_init_embed_weights is set then weights of embedding
|
171 |
+
layer will be initialized using the normal distribution.
|
172 |
+
3. If normal_init_proj_weights is set then weights of
|
173 |
+
in_project_weight for MultiHeadAttention initialized using
|
174 |
+
the normal distribution (to be validated).
|
175 |
+
"""
|
176 |
+
|
177 |
+
def normal_(data):
|
178 |
+
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
179 |
+
# so that the RNG is consistent with and without FSDP
|
180 |
+
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
|
181 |
+
|
182 |
+
if isinstance(module, nn.Linear):
|
183 |
+
normal_(module.weight.data)
|
184 |
+
if module.bias is not None:
|
185 |
+
module.bias.data.zero_()
|
186 |
+
if isinstance(module, nn.Embedding):
|
187 |
+
normal_(module.weight.data)
|
188 |
+
if module.padding_idx is not None:
|
189 |
+
module.weight.data[module.padding_idx].zero_()
|
190 |
+
if isinstance(module, MultiheadAttention):
|
191 |
+
normal_(module.q_proj.weight.data)
|
192 |
+
normal_(module.k_proj.weight.data)
|
193 |
+
normal_(module.v_proj.weight.data)
|
194 |
+
|
195 |
+
|
196 |
+
def quant_noise(module, p, block_size):
|
197 |
+
"""
|
198 |
+
Wraps modules and applies quantization noise to the weights for
|
199 |
+
subsequent quantization with Iterative Product Quantization as
|
200 |
+
described in "Training with Quantization Noise for Extreme Model Compression"
|
201 |
+
Args:
|
202 |
+
- module: nn.Module
|
203 |
+
- p: amount of Quantization Noise
|
204 |
+
- block_size: size of the blocks for subsequent quantization with iPQ
|
205 |
+
Remarks:
|
206 |
+
- Module weights must have the right sizes wrt the block size
|
207 |
+
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
208 |
+
- For more detail on how to quantize by blocks with convolutional weights,
|
209 |
+
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
210 |
+
- We implement the simplest form of noise here as stated in the paper
|
211 |
+
which consists in randomly dropping blocks
|
212 |
+
"""
|
213 |
+
|
214 |
+
# if no quantization noise, don't register hook
|
215 |
+
if p <= 0:
|
216 |
+
return module
|
217 |
+
|
218 |
+
# supported modules
|
219 |
+
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
220 |
+
|
221 |
+
# test whether module.weight has the right sizes wrt block_size
|
222 |
+
is_conv = module.weight.ndim == 4
|
223 |
+
|
224 |
+
# 2D matrix
|
225 |
+
if not is_conv:
|
226 |
+
assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes"
|
227 |
+
|
228 |
+
# 4D matrix
|
229 |
+
else:
|
230 |
+
# 1x1 convolutions
|
231 |
+
if module.kernel_size == (1, 1):
|
232 |
+
assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes"
|
233 |
+
# regular convolutions
|
234 |
+
else:
|
235 |
+
k = module.kernel_size[0] * module.kernel_size[1]
|
236 |
+
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
237 |
+
|
238 |
+
def _forward_pre_hook(mod, input):
|
239 |
+
# no noise for evaluation
|
240 |
+
if mod.training:
|
241 |
+
if not is_conv:
|
242 |
+
# gather weight and sizes
|
243 |
+
weight = mod.weight
|
244 |
+
in_features = weight.size(1)
|
245 |
+
out_features = weight.size(0)
|
246 |
+
|
247 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
248 |
+
mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
|
249 |
+
mask.bernoulli_(p)
|
250 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
251 |
+
|
252 |
+
else:
|
253 |
+
# gather weight and sizes
|
254 |
+
weight = mod.weight
|
255 |
+
in_channels = mod.in_channels
|
256 |
+
out_channels = mod.out_channels
|
257 |
+
|
258 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
259 |
+
if mod.kernel_size == (1, 1):
|
260 |
+
mask = torch.zeros(
|
261 |
+
int(in_channels // block_size * out_channels),
|
262 |
+
device=weight.device,
|
263 |
+
)
|
264 |
+
mask.bernoulli_(p)
|
265 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
266 |
+
else:
|
267 |
+
mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
|
268 |
+
mask.bernoulli_(p)
|
269 |
+
mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
270 |
+
|
271 |
+
# scale weights and apply mask
|
272 |
+
mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript
|
273 |
+
s = 1 / (1 - p)
|
274 |
+
mod.weight.data = s * weight.masked_fill(mask, 0)
|
275 |
+
|
276 |
+
module.register_forward_pre_hook(_forward_pre_hook)
|
277 |
+
return module
|
278 |
+
|
279 |
+
|
280 |
+
class MultiheadAttention(nn.Module):
|
281 |
+
"""Multi-headed attention.
|
282 |
+
See "Attention Is All You Need" for more details.
|
283 |
+
"""
|
284 |
+
|
285 |
+
def __init__(
|
286 |
+
self,
|
287 |
+
embed_dim,
|
288 |
+
num_heads,
|
289 |
+
kdim=None,
|
290 |
+
vdim=None,
|
291 |
+
dropout=0.0,
|
292 |
+
bias=True,
|
293 |
+
add_bias_kv=False,
|
294 |
+
add_zero_attn=False,
|
295 |
+
self_attention=False,
|
296 |
+
encoder_decoder_attention=False,
|
297 |
+
q_noise=0.0,
|
298 |
+
qn_block_size=8,
|
299 |
+
has_relative_attention_bias=False,
|
300 |
+
num_buckets=32,
|
301 |
+
max_distance=128,
|
302 |
+
gru_rel_pos=False,
|
303 |
+
rescale_init=False,
|
304 |
+
):
|
305 |
+
super().__init__()
|
306 |
+
self.embed_dim = embed_dim
|
307 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
308 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
309 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
310 |
+
|
311 |
+
self.num_heads = num_heads
|
312 |
+
self.dropout_module = nn.Dropout(dropout)
|
313 |
+
|
314 |
+
self.has_relative_attention_bias = has_relative_attention_bias
|
315 |
+
self.num_buckets = num_buckets
|
316 |
+
self.max_distance = max_distance
|
317 |
+
if self.has_relative_attention_bias:
|
318 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
319 |
+
|
320 |
+
self.head_dim = embed_dim // num_heads
|
321 |
+
self.q_head_dim = self.head_dim
|
322 |
+
self.k_head_dim = self.head_dim
|
323 |
+
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
324 |
+
self.scaling = self.head_dim**-0.5
|
325 |
+
|
326 |
+
self.self_attention = self_attention
|
327 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
328 |
+
|
329 |
+
assert not self.self_attention or self.qkv_same_dim, (
|
330 |
+
"Self-attention requires query, key and " "value to be of the same size"
|
331 |
+
)
|
332 |
+
|
333 |
+
k_bias = True
|
334 |
+
if rescale_init:
|
335 |
+
k_bias = False
|
336 |
+
|
337 |
+
k_embed_dim = embed_dim
|
338 |
+
q_embed_dim = embed_dim
|
339 |
+
|
340 |
+
self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size)
|
341 |
+
self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
|
342 |
+
self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size)
|
343 |
+
|
344 |
+
self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
|
345 |
+
|
346 |
+
if add_bias_kv:
|
347 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
348 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
349 |
+
else:
|
350 |
+
self.bias_k = self.bias_v = None
|
351 |
+
|
352 |
+
self.add_zero_attn = add_zero_attn
|
353 |
+
|
354 |
+
self.gru_rel_pos = gru_rel_pos
|
355 |
+
if self.gru_rel_pos:
|
356 |
+
self.grep_linear = nn.Linear(self.q_head_dim, 8)
|
357 |
+
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
|
358 |
+
|
359 |
+
self.reset_parameters()
|
360 |
+
|
361 |
+
def reset_parameters(self):
|
362 |
+
if self.qkv_same_dim:
|
363 |
+
# Empirically observed the convergence to be much better with
|
364 |
+
# the scaled initialization
|
365 |
+
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
366 |
+
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
367 |
+
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
368 |
+
else:
|
369 |
+
nn.init.xavier_uniform_(self.k_proj.weight)
|
370 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
371 |
+
nn.init.xavier_uniform_(self.q_proj.weight)
|
372 |
+
|
373 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
374 |
+
if self.out_proj.bias is not None:
|
375 |
+
nn.init.constant_(self.out_proj.bias, 0.0)
|
376 |
+
if self.bias_k is not None:
|
377 |
+
nn.init.xavier_normal_(self.bias_k)
|
378 |
+
if self.bias_v is not None:
|
379 |
+
nn.init.xavier_normal_(self.bias_v)
|
380 |
+
if self.has_relative_attention_bias:
|
381 |
+
nn.init.xavier_normal_(self.relative_attention_bias.weight)
|
382 |
+
|
383 |
+
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
|
384 |
+
num_buckets = self.num_buckets
|
385 |
+
max_distance = self.max_distance
|
386 |
+
relative_buckets = 0
|
387 |
+
|
388 |
+
if bidirectional:
|
389 |
+
num_buckets = num_buckets // 2
|
390 |
+
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
|
391 |
+
relative_positions = torch.abs(relative_positions)
|
392 |
+
else:
|
393 |
+
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
|
394 |
+
|
395 |
+
max_exact = num_buckets // 2
|
396 |
+
is_small = relative_positions < max_exact
|
397 |
+
|
398 |
+
relative_postion_if_large = max_exact + (
|
399 |
+
torch.log(relative_positions.float() / max_exact)
|
400 |
+
/ math.log(max_distance / max_exact)
|
401 |
+
* (num_buckets - max_exact)
|
402 |
+
).to(torch.long)
|
403 |
+
relative_postion_if_large = torch.min(
|
404 |
+
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
|
405 |
+
)
|
406 |
+
|
407 |
+
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
|
408 |
+
return relative_buckets
|
409 |
+
|
410 |
+
def compute_bias(self, query_length, key_length):
|
411 |
+
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
412 |
+
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
413 |
+
relative_position = memory_position - context_position
|
414 |
+
relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
|
415 |
+
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
|
416 |
+
values = self.relative_attention_bias(relative_position_bucket)
|
417 |
+
values = values.permute([2, 0, 1])
|
418 |
+
return values
|
419 |
+
|
420 |
+
def forward(
|
421 |
+
self,
|
422 |
+
query,
|
423 |
+
key: Optional[Tensor],
|
424 |
+
value: Optional[Tensor],
|
425 |
+
key_padding_mask: Optional[Tensor] = None,
|
426 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
427 |
+
need_weights: bool = True,
|
428 |
+
static_kv: bool = False,
|
429 |
+
attn_mask: Optional[Tensor] = None,
|
430 |
+
before_softmax: bool = False,
|
431 |
+
need_head_weights: bool = False,
|
432 |
+
position_bias: Optional[Tensor] = None,
|
433 |
+
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
434 |
+
"""Input shape: Time x Batch x Channel
|
435 |
+
Args:
|
436 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
437 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
438 |
+
padding elements are indicated by 1s.
|
439 |
+
need_weights (bool, optional): return the attention weights,
|
440 |
+
averaged over heads (default: False).
|
441 |
+
attn_mask (ByteTensor, optional): typically used to
|
442 |
+
implement causal attention, where the mask prevents the
|
443 |
+
attention from looking forward in time (default: None).
|
444 |
+
before_softmax (bool, optional): return the raw attention
|
445 |
+
weights and values before the attention softmax.
|
446 |
+
need_head_weights (bool, optional): return the attention
|
447 |
+
weights for each head. Implies *need_weights*. Default:
|
448 |
+
return the average attention weights over all heads.
|
449 |
+
"""
|
450 |
+
if need_head_weights:
|
451 |
+
need_weights = True
|
452 |
+
|
453 |
+
is_tpu = query.device.type == "xla"
|
454 |
+
|
455 |
+
tgt_len, bsz, embed_dim = query.size()
|
456 |
+
src_len = tgt_len
|
457 |
+
assert embed_dim == self.embed_dim
|
458 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
459 |
+
if key is not None:
|
460 |
+
src_len, key_bsz, _ = key.size()
|
461 |
+
if not torch.jit.is_scripting():
|
462 |
+
assert key_bsz == bsz
|
463 |
+
assert value is not None
|
464 |
+
assert src_len, bsz == value.shape[:2]
|
465 |
+
|
466 |
+
if self.has_relative_attention_bias and position_bias is None:
|
467 |
+
position_bias = self.compute_bias(tgt_len, src_len)
|
468 |
+
position_bias = (
|
469 |
+
position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
|
470 |
+
)
|
471 |
+
|
472 |
+
if (
|
473 |
+
not is_tpu # don't use PyTorch version on TPUs
|
474 |
+
and incremental_state is None
|
475 |
+
and not static_kv
|
476 |
+
# A workaround for quantization to work. Otherwise JIT compilation
|
477 |
+
# treats bias in linear module as method.
|
478 |
+
and not torch.jit.is_scripting()
|
479 |
+
and self.q_head_dim == self.head_dim
|
480 |
+
):
|
481 |
+
assert key is not None and value is not None
|
482 |
+
assert attn_mask is None
|
483 |
+
|
484 |
+
attn_mask_rel_pos = None
|
485 |
+
if position_bias is not None:
|
486 |
+
attn_mask_rel_pos = position_bias
|
487 |
+
if self.gru_rel_pos:
|
488 |
+
query_layer = query.transpose(0, 1)
|
489 |
+
new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
|
490 |
+
query_layer = query_layer.view(*new_x_shape)
|
491 |
+
query_layer = query_layer.permute(0, 2, 1, 3)
|
492 |
+
_B, _H, _L, __ = query_layer.size()
|
493 |
+
|
494 |
+
gate_a, gate_b = torch.sigmoid(
|
495 |
+
self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
|
496 |
+
).chunk(2, dim=-1)
|
497 |
+
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
498 |
+
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
|
499 |
+
|
500 |
+
attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
|
501 |
+
k_proj_bias = self.k_proj.bias
|
502 |
+
if k_proj_bias is None:
|
503 |
+
k_proj_bias = torch.zeros_like(self.q_proj.bias)
|
504 |
+
|
505 |
+
x, attn = F.multi_head_attention_forward(
|
506 |
+
query,
|
507 |
+
key,
|
508 |
+
value,
|
509 |
+
self.embed_dim,
|
510 |
+
self.num_heads,
|
511 |
+
torch.empty([0]),
|
512 |
+
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
|
513 |
+
self.bias_k,
|
514 |
+
self.bias_v,
|
515 |
+
self.add_zero_attn,
|
516 |
+
self.dropout_module.p,
|
517 |
+
self.out_proj.weight,
|
518 |
+
self.out_proj.bias,
|
519 |
+
self.training,
|
520 |
+
# self.training or self.dropout_module.apply_during_inference,
|
521 |
+
key_padding_mask,
|
522 |
+
need_weights,
|
523 |
+
attn_mask_rel_pos,
|
524 |
+
use_separate_proj_weight=True,
|
525 |
+
q_proj_weight=self.q_proj.weight,
|
526 |
+
k_proj_weight=self.k_proj.weight,
|
527 |
+
v_proj_weight=self.v_proj.weight,
|
528 |
+
)
|
529 |
+
return x, attn, position_bias
|
530 |
+
|
531 |
+
if incremental_state is not None:
|
532 |
+
saved_state = self._get_input_buffer(incremental_state)
|
533 |
+
if saved_state is not None and "prev_key" in saved_state:
|
534 |
+
# previous time steps are cached - no need to recompute
|
535 |
+
# key and value if they are static
|
536 |
+
if static_kv:
|
537 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
538 |
+
key = value = None
|
539 |
+
else:
|
540 |
+
saved_state = None
|
541 |
+
|
542 |
+
if self.self_attention:
|
543 |
+
q = self.q_proj(query)
|
544 |
+
k = self.k_proj(query)
|
545 |
+
v = self.v_proj(query)
|
546 |
+
elif self.encoder_decoder_attention:
|
547 |
+
# encoder-decoder attention
|
548 |
+
q = self.q_proj(query)
|
549 |
+
if key is None:
|
550 |
+
assert value is None
|
551 |
+
k = v = None
|
552 |
+
else:
|
553 |
+
k = self.k_proj(key)
|
554 |
+
v = self.v_proj(key)
|
555 |
+
|
556 |
+
else:
|
557 |
+
assert key is not None and value is not None
|
558 |
+
q = self.q_proj(query)
|
559 |
+
k = self.k_proj(key)
|
560 |
+
v = self.v_proj(value)
|
561 |
+
q *= self.scaling
|
562 |
+
|
563 |
+
if self.bias_k is not None:
|
564 |
+
assert self.bias_v is not None
|
565 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
566 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
567 |
+
if attn_mask is not None:
|
568 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
569 |
+
if key_padding_mask is not None:
|
570 |
+
key_padding_mask = torch.cat(
|
571 |
+
[
|
572 |
+
key_padding_mask,
|
573 |
+
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
574 |
+
],
|
575 |
+
dim=1,
|
576 |
+
)
|
577 |
+
|
578 |
+
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1)
|
579 |
+
if k is not None:
|
580 |
+
k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1)
|
581 |
+
if v is not None:
|
582 |
+
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
583 |
+
|
584 |
+
if saved_state is not None:
|
585 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
586 |
+
if "prev_key" in saved_state:
|
587 |
+
_prev_key = saved_state["prev_key"]
|
588 |
+
assert _prev_key is not None
|
589 |
+
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
590 |
+
if static_kv:
|
591 |
+
k = prev_key
|
592 |
+
else:
|
593 |
+
assert k is not None
|
594 |
+
k = torch.cat([prev_key, k], dim=1)
|
595 |
+
src_len = k.size(1)
|
596 |
+
if "prev_value" in saved_state:
|
597 |
+
_prev_value = saved_state["prev_value"]
|
598 |
+
assert _prev_value is not None
|
599 |
+
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
600 |
+
if static_kv:
|
601 |
+
v = prev_value
|
602 |
+
else:
|
603 |
+
assert v is not None
|
604 |
+
v = torch.cat([prev_value, v], dim=1)
|
605 |
+
prev_key_padding_mask: Optional[Tensor] = None
|
606 |
+
if "prev_key_padding_mask" in saved_state:
|
607 |
+
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
608 |
+
assert k is not None and v is not None
|
609 |
+
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
610 |
+
key_padding_mask=key_padding_mask,
|
611 |
+
prev_key_padding_mask=prev_key_padding_mask,
|
612 |
+
batch_size=bsz,
|
613 |
+
src_len=k.size(1),
|
614 |
+
static_kv=static_kv,
|
615 |
+
)
|
616 |
+
|
617 |
+
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
618 |
+
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
619 |
+
saved_state["prev_key_padding_mask"] = key_padding_mask
|
620 |
+
# In this branch incremental_state is never None
|
621 |
+
assert incremental_state is not None
|
622 |
+
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
623 |
+
assert k is not None
|
624 |
+
assert k.size(1) == src_len
|
625 |
+
|
626 |
+
# This is part of a workaround to get around fork/join parallelism
|
627 |
+
# not supporting Optional types.
|
628 |
+
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
629 |
+
key_padding_mask = None
|
630 |
+
|
631 |
+
if key_padding_mask is not None:
|
632 |
+
assert key_padding_mask.size(0) == bsz
|
633 |
+
assert key_padding_mask.size(1) == src_len
|
634 |
+
|
635 |
+
if self.add_zero_attn:
|
636 |
+
assert v is not None
|
637 |
+
src_len += 1
|
638 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
639 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
640 |
+
if attn_mask is not None:
|
641 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
642 |
+
if key_padding_mask is not None:
|
643 |
+
key_padding_mask = torch.cat(
|
644 |
+
[
|
645 |
+
key_padding_mask,
|
646 |
+
torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
|
647 |
+
],
|
648 |
+
dim=1,
|
649 |
+
)
|
650 |
+
|
651 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
652 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
653 |
+
|
654 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
655 |
+
|
656 |
+
if attn_mask is not None:
|
657 |
+
attn_mask = attn_mask.unsqueeze(0)
|
658 |
+
attn_weights += attn_mask
|
659 |
+
|
660 |
+
if key_padding_mask is not None:
|
661 |
+
# don't attend to padding symbols
|
662 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
663 |
+
if not is_tpu:
|
664 |
+
attn_weights = attn_weights.masked_fill(
|
665 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
666 |
+
float("-inf"),
|
667 |
+
)
|
668 |
+
else:
|
669 |
+
attn_weights = attn_weights.transpose(0, 2)
|
670 |
+
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
671 |
+
attn_weights = attn_weights.transpose(0, 2)
|
672 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
673 |
+
|
674 |
+
if before_softmax:
|
675 |
+
return attn_weights, v, position_bias
|
676 |
+
|
677 |
+
if position_bias is not None:
|
678 |
+
if self.gru_rel_pos == 1:
|
679 |
+
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
|
680 |
+
_B, _H, _L, __ = query_layer.size()
|
681 |
+
gate_a, gate_b = torch.sigmoid(
|
682 |
+
self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
|
683 |
+
).chunk(2, dim=-1)
|
684 |
+
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
685 |
+
position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
|
686 |
+
|
687 |
+
position_bias = position_bias.view(attn_weights.size())
|
688 |
+
|
689 |
+
attn_weights = attn_weights + position_bias
|
690 |
+
|
691 |
+
attn_weights_float = F.softmax(attn_weights, dim=-1)
|
692 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
693 |
+
attn_probs = self.dropout_module(attn_weights)
|
694 |
+
|
695 |
+
assert v is not None
|
696 |
+
attn = torch.bmm(attn_probs, v)
|
697 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
698 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
699 |
+
attn = self.out_proj(attn)
|
700 |
+
attn_weights: Optional[Tensor] = None
|
701 |
+
if need_weights:
|
702 |
+
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
|
703 |
+
if not need_head_weights:
|
704 |
+
# average attention weights over heads
|
705 |
+
attn_weights = attn_weights.mean(dim=0)
|
706 |
+
|
707 |
+
return attn, attn_weights, position_bias
|
708 |
+
|
709 |
+
@staticmethod
|
710 |
+
def _append_prev_key_padding_mask(
|
711 |
+
key_padding_mask: Optional[Tensor],
|
712 |
+
prev_key_padding_mask: Optional[Tensor],
|
713 |
+
batch_size: int,
|
714 |
+
src_len: int,
|
715 |
+
static_kv: bool,
|
716 |
+
) -> Optional[Tensor]:
|
717 |
+
# saved key padding masks have shape (bsz, seq_len)
|
718 |
+
if prev_key_padding_mask is not None and static_kv:
|
719 |
+
new_key_padding_mask = prev_key_padding_mask
|
720 |
+
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
721 |
+
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
|
722 |
+
# During incremental decoding, as the padding token enters and
|
723 |
+
# leaves the frame, there will be a time when prev or current
|
724 |
+
# is None
|
725 |
+
elif prev_key_padding_mask is not None:
|
726 |
+
if src_len > prev_key_padding_mask.size(1):
|
727 |
+
filler = torch.zeros(
|
728 |
+
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
729 |
+
device=prev_key_padding_mask.device,
|
730 |
+
)
|
731 |
+
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
|
732 |
+
else:
|
733 |
+
new_key_padding_mask = prev_key_padding_mask.float()
|
734 |
+
elif key_padding_mask is not None:
|
735 |
+
if src_len > key_padding_mask.size(1):
|
736 |
+
filler = torch.zeros(
|
737 |
+
(batch_size, src_len - key_padding_mask.size(1)),
|
738 |
+
device=key_padding_mask.device,
|
739 |
+
)
|
740 |
+
new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
|
741 |
+
else:
|
742 |
+
new_key_padding_mask = key_padding_mask.float()
|
743 |
+
else:
|
744 |
+
new_key_padding_mask = prev_key_padding_mask
|
745 |
+
return new_key_padding_mask
|
746 |
+
|
747 |
+
def _get_input_buffer(
|
748 |
+
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
749 |
+
) -> Dict[str, Optional[Tensor]]:
|
750 |
+
result = self.get_incremental_state(incremental_state, "attn_state")
|
751 |
+
if result is not None:
|
752 |
+
return result
|
753 |
+
else:
|
754 |
+
empty_result: Dict[str, Optional[Tensor]] = {}
|
755 |
+
return empty_result
|
756 |
+
|
757 |
+
def _set_input_buffer(
|
758 |
+
self,
|
759 |
+
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
760 |
+
buffer: Dict[str, Optional[Tensor]],
|
761 |
+
):
|
762 |
+
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
763 |
+
|
764 |
+
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
765 |
+
return attn_weights
|
__pycache__/WavLM.cpython-311.pyc
ADDED
Binary file (38.8 kB). View file
|
|
__pycache__/WavLM_modules.cpython-311.pyc
ADDED
Binary file (39.9 kB). View file
|
|
__pycache__/data_utils.cpython-311.pyc
ADDED
Binary file (28.5 kB). View file
|
|
__pycache__/dino_game.cpython-311.pyc
ADDED
Binary file (5.35 kB). View file
|
|
__pycache__/inference_functions.cpython-311.pyc
ADDED
Binary file (20.2 kB). View file
|
|
__pycache__/landmarks_extractor.cpython-311.pyc
ADDED
Binary file (1.93 kB). View file
|
|
__pycache__/utils.cpython-311.pyc
ADDED
Binary file (13.7 kB). View file
|
|
__pycache__/vae_wrapper.cpython-311.pyc
ADDED
Binary file (8.83 kB). View file
|
|
__pycache__/wordle_game.cpython-311.pyc
ADDED
Binary file (6.88 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,978 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import tempfile
|
4 |
+
import os
|
5 |
+
from vae_wrapper import VaeWrapper, encode_video_chunk
|
6 |
+
from landmarks_extractor import LandmarksExtractor
|
7 |
+
import decord
|
8 |
+
from utils import (
|
9 |
+
get_raw_audio,
|
10 |
+
save_audio_video,
|
11 |
+
calculate_splits,
|
12 |
+
instantiate_from_config,
|
13 |
+
create_pipeline_inputs,
|
14 |
+
)
|
15 |
+
from transformers import HubertModel
|
16 |
+
from einops import rearrange
|
17 |
+
import numpy as np
|
18 |
+
from WavLM import WavLM_wrapper
|
19 |
+
from omegaconf import OmegaConf
|
20 |
+
from inference_functions import (
|
21 |
+
sample_keyframes,
|
22 |
+
sample_interpolation,
|
23 |
+
)
|
24 |
+
from wordle_game import WordleGame
|
25 |
+
import torch.cuda.amp as amp # Import amp for mixed precision
|
26 |
+
|
27 |
+
|
28 |
+
# Set default tensor type to float16 for faster computation
|
29 |
+
if torch.cuda.is_available():
|
30 |
+
# torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
31 |
+
# Enable TF32 precision for better performance on Ampere+ GPUs
|
32 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
33 |
+
torch.backends.cudnn.allow_tf32 = True
|
34 |
+
|
35 |
+
# Cache for video and audio processing
|
36 |
+
cache = {
|
37 |
+
"video": {
|
38 |
+
"path": None,
|
39 |
+
"embedding": None,
|
40 |
+
"frames": None,
|
41 |
+
"landmarks": None,
|
42 |
+
},
|
43 |
+
"audio": {
|
44 |
+
"path": None,
|
45 |
+
"raw_audio": None,
|
46 |
+
"hubert_embedding": None,
|
47 |
+
"wavlm_embedding": None,
|
48 |
+
},
|
49 |
+
}
|
50 |
+
|
51 |
+
# Create mixed precision scaler
|
52 |
+
scaler = amp.GradScaler()
|
53 |
+
|
54 |
+
|
55 |
+
def load_model(
|
56 |
+
config: str,
|
57 |
+
device: str = "cuda",
|
58 |
+
ckpt: str = None,
|
59 |
+
):
|
60 |
+
"""
|
61 |
+
Load a model from configuration.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
config: Path to model configuration file
|
65 |
+
device: Device to load the model on
|
66 |
+
num_frames: Number of frames to process
|
67 |
+
input_key: Input key for the model
|
68 |
+
ckpt: Optional checkpoint path
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
Tuple of (model, filter, batch size)
|
72 |
+
"""
|
73 |
+
config = OmegaConf.load(config)
|
74 |
+
|
75 |
+
config["model"]["params"]["input_key"] = "latents"
|
76 |
+
|
77 |
+
if ckpt is not None:
|
78 |
+
config.model.params.ckpt_path = ckpt
|
79 |
+
|
80 |
+
with torch.device(device):
|
81 |
+
model = instantiate_from_config(config.model).to(device).eval()
|
82 |
+
# Convert model to half precision
|
83 |
+
if torch.cuda.is_available():
|
84 |
+
model = model.half()
|
85 |
+
model.first_stage_model = model.first_stage_model.float()
|
86 |
+
print("Converted model to FP16 precision")
|
87 |
+
|
88 |
+
# Compile model for faster inference
|
89 |
+
if torch.cuda.is_available():
|
90 |
+
try:
|
91 |
+
model = torch.compile(model)
|
92 |
+
print(f"Successfully compiled model with torch.compile()")
|
93 |
+
except Exception as e:
|
94 |
+
print(f"Warning: Failed to compile model: {e}")
|
95 |
+
|
96 |
+
return model
|
97 |
+
|
98 |
+
|
99 |
+
# keyframe_model = KeyframeModel(device=device)
|
100 |
+
# interpolation_model = InterpolationModel(device=device)
|
101 |
+
vae_model = VaeWrapper("video")
|
102 |
+
if torch.cuda.is_available():
|
103 |
+
vae_model = vae_model.half() # Convert to half precision
|
104 |
+
try:
|
105 |
+
vae_model = torch.compile(vae_model)
|
106 |
+
print("Successfully compiled vae_model in FP16")
|
107 |
+
except Exception as e:
|
108 |
+
print(f"Warning: Failed to compile vae_model: {e}")
|
109 |
+
|
110 |
+
hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
|
111 |
+
if torch.cuda.is_available():
|
112 |
+
hubert_model = hubert_model.half() # Convert to half precision
|
113 |
+
try:
|
114 |
+
hubert_model = torch.compile(hubert_model)
|
115 |
+
print("Successfully compiled hubert_model in FP16")
|
116 |
+
except Exception as e:
|
117 |
+
print(f"Warning: Failed to compile hubert_model: {e}")
|
118 |
+
|
119 |
+
wavlm_model = WavLM_wrapper(
|
120 |
+
model_size="Base+",
|
121 |
+
feed_as_frames=False,
|
122 |
+
merge_type="None",
|
123 |
+
model_path="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/WavLM-Base+.pt",
|
124 |
+
).cuda()
|
125 |
+
if torch.cuda.is_available():
|
126 |
+
wavlm_model = wavlm_model.half() # Convert to half precision
|
127 |
+
try:
|
128 |
+
wavlm_model = torch.compile(wavlm_model)
|
129 |
+
print("Successfully compiled wavlm_model in FP16")
|
130 |
+
except Exception as e:
|
131 |
+
print(f"Warning: Failed to compile wavlm_model: {e}")
|
132 |
+
|
133 |
+
landmarks_extractor = LandmarksExtractor()
|
134 |
+
keyframe_model = load_model(
|
135 |
+
config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/keyframe.yaml",
|
136 |
+
ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/keyframe_dub.pt",
|
137 |
+
)
|
138 |
+
interpolation_model = load_model(
|
139 |
+
config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/interpolation.yaml",
|
140 |
+
ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/interpolation_dub.pt",
|
141 |
+
)
|
142 |
+
keyframe_model.en_and_decode_n_samples_a_time = 2
|
143 |
+
interpolation_model.en_and_decode_n_samples_a_time = 2
|
144 |
+
|
145 |
+
# Default media paths
|
146 |
+
DEFAULT_VIDEO_PATH = os.path.join(
|
147 |
+
os.path.dirname(__file__), "assets", "sample_video.mp4"
|
148 |
+
)
|
149 |
+
DEFAULT_AUDIO_PATH = os.path.join(
|
150 |
+
os.path.dirname(__file__), "assets", "sample_audio.wav"
|
151 |
+
)
|
152 |
+
|
153 |
+
|
154 |
+
@torch.no_grad()
|
155 |
+
def compute_video_embedding(video_reader, min_len):
|
156 |
+
"""Compute embeddings from video"""
|
157 |
+
|
158 |
+
total_frames = min_len
|
159 |
+
|
160 |
+
encoded = []
|
161 |
+
video_frames = []
|
162 |
+
chunk_size = 16
|
163 |
+
resolution = 512
|
164 |
+
|
165 |
+
# # Create a progress bar for Gradio
|
166 |
+
progress = gr.Progress()
|
167 |
+
|
168 |
+
# Calculate total chunks for progress tracking
|
169 |
+
total_chunks = (total_frames + chunk_size - 1) // chunk_size
|
170 |
+
|
171 |
+
for i, start_idx in enumerate(range(0, total_frames, chunk_size)):
|
172 |
+
# Update progress bar
|
173 |
+
progress(i / total_chunks, desc="Processing video chunks")
|
174 |
+
|
175 |
+
end_idx = min(start_idx + chunk_size, total_frames)
|
176 |
+
video_chunk = video_reader.get_batch(range(start_idx, end_idx))
|
177 |
+
# Interpolate video chunk to the target resolution
|
178 |
+
video_chunk = rearrange(video_chunk, "f h w c -> f c h w")
|
179 |
+
video_chunk = torch.nn.functional.interpolate(
|
180 |
+
video_chunk,
|
181 |
+
size=(resolution, resolution),
|
182 |
+
mode="bilinear",
|
183 |
+
align_corners=False,
|
184 |
+
)
|
185 |
+
video_chunk = rearrange(video_chunk, "f c h w -> f h w c")
|
186 |
+
video_frames.append(video_chunk)
|
187 |
+
|
188 |
+
# Convert chunk to FP16 if using CUDA
|
189 |
+
if torch.cuda.is_available():
|
190 |
+
video_chunk = video_chunk.half()
|
191 |
+
|
192 |
+
# Always use autocast for FP16 computation
|
193 |
+
with amp.autocast(enabled=True):
|
194 |
+
encoded.append(encode_video_chunk(vae_model, video_chunk, resolution))
|
195 |
+
|
196 |
+
encoded = torch.cat(encoded, dim=0)
|
197 |
+
video_frames = torch.cat(video_frames, dim=0)
|
198 |
+
video_frames = rearrange(video_frames, "f h w c -> f c h w")
|
199 |
+
torch.cuda.empty_cache()
|
200 |
+
return encoded, video_frames
|
201 |
+
|
202 |
+
|
203 |
+
@torch.no_grad()
|
204 |
+
def compute_hubert_embedding(raw_audio):
|
205 |
+
"""Compute embeddings from audio"""
|
206 |
+
print(f"Computing audio embedding from {raw_audio.shape}")
|
207 |
+
|
208 |
+
audio = (
|
209 |
+
(raw_audio - raw_audio.mean()) / torch.sqrt(raw_audio.var() + 1e-7)
|
210 |
+
).unsqueeze(0)
|
211 |
+
chunks = 16000 * 20
|
212 |
+
|
213 |
+
# Create a progress bar for Gradio
|
214 |
+
progress = gr.Progress()
|
215 |
+
|
216 |
+
# Get audio embeddings
|
217 |
+
audio_embeddings = []
|
218 |
+
splits = list(calculate_splits(audio, chunks))
|
219 |
+
total_splits = len(splits)
|
220 |
+
|
221 |
+
for i, chunk in enumerate(splits):
|
222 |
+
# Update progress bar
|
223 |
+
progress(i / total_splits, desc="Processing audio chunks")
|
224 |
+
|
225 |
+
# Convert audio chunk to half precision
|
226 |
+
if torch.cuda.is_available():
|
227 |
+
chunk_cuda = chunk.cuda().half()
|
228 |
+
else:
|
229 |
+
chunk_cuda = chunk.cuda()
|
230 |
+
|
231 |
+
# Always use autocast for FP16 computation
|
232 |
+
with amp.autocast(enabled=True):
|
233 |
+
hidden_states = hubert_model(chunk_cuda)[0]
|
234 |
+
|
235 |
+
audio_embeddings.append(hidden_states)
|
236 |
+
audio_embeddings = torch.cat(audio_embeddings, dim=1)
|
237 |
+
|
238 |
+
# audio_embeddings = self.model.wav2vec2(rearrange(audio_frames, "f s -> () (f s)"))[0]
|
239 |
+
if audio_embeddings.shape[1] % 2 != 0:
|
240 |
+
audio_embeddings = torch.cat(
|
241 |
+
[audio_embeddings, torch.zeros_like(audio_embeddings[:, :1])], dim=1
|
242 |
+
)
|
243 |
+
audio_embeddings = rearrange(audio_embeddings, "() (f d) c -> f d c", d=2)
|
244 |
+
torch.cuda.empty_cache()
|
245 |
+
|
246 |
+
return audio_embeddings
|
247 |
+
|
248 |
+
|
249 |
+
@torch.no_grad()
|
250 |
+
def compute_wavlm_embedding(raw_audio):
|
251 |
+
"""Compute embeddings from audio"""
|
252 |
+
audio = rearrange(raw_audio, "(f s) -> f s", s=640)
|
253 |
+
|
254 |
+
if audio.shape[0] % 2 != 0:
|
255 |
+
audio = torch.cat([audio, torch.zeros(1, 640)], dim=0)
|
256 |
+
chunks = 500
|
257 |
+
|
258 |
+
# Create a progress bar for Gradio
|
259 |
+
progress = gr.Progress()
|
260 |
+
|
261 |
+
# Get audio embeddings
|
262 |
+
audio_embeddings = []
|
263 |
+
splits = list(calculate_splits(audio, chunks))
|
264 |
+
total_splits = len(splits)
|
265 |
+
|
266 |
+
for i, chunk in enumerate(splits):
|
267 |
+
# Update progress bar
|
268 |
+
progress(i / total_splits, desc="Processing audio chunks")
|
269 |
+
|
270 |
+
# Convert chunk to half precision
|
271 |
+
if torch.cuda.is_available():
|
272 |
+
chunk_cuda = chunk.unsqueeze(0).cuda().half()
|
273 |
+
else:
|
274 |
+
chunk_cuda = chunk.unsqueeze(0).cuda()
|
275 |
+
|
276 |
+
# Always use autocast for FP16 computation
|
277 |
+
with amp.autocast(enabled=True):
|
278 |
+
wavlm_hidden_states = wavlm_model(chunk_cuda).squeeze(0)
|
279 |
+
|
280 |
+
audio_embeddings.append(wavlm_hidden_states)
|
281 |
+
audio_embeddings = torch.cat(audio_embeddings, dim=0)
|
282 |
+
|
283 |
+
torch.cuda.empty_cache()
|
284 |
+
|
285 |
+
return audio_embeddings
|
286 |
+
|
287 |
+
|
288 |
+
@torch.no_grad()
|
289 |
+
def extract_video_landmarks(video_frames):
|
290 |
+
"""Extract landmarks from video frames"""
|
291 |
+
|
292 |
+
# Create a progress bar for Gradio
|
293 |
+
progress = gr.Progress()
|
294 |
+
|
295 |
+
landmarks = []
|
296 |
+
batch_size = 10
|
297 |
+
|
298 |
+
for i in range(0, len(video_frames), batch_size):
|
299 |
+
# Update progress bar
|
300 |
+
progress(i / len(video_frames), desc="Extracting facial landmarks")
|
301 |
+
|
302 |
+
batch = video_frames[i : i + batch_size].cpu().float()
|
303 |
+
batch_landmarks = landmarks_extractor.extract_landmarks(batch)
|
304 |
+
landmarks.extend(batch_landmarks)
|
305 |
+
|
306 |
+
torch.cuda.empty_cache()
|
307 |
+
|
308 |
+
# Convert landmarks to a list of numpy arrays with consistent shape
|
309 |
+
processed_landmarks = []
|
310 |
+
|
311 |
+
expected_shape = (68, 2) # Common shape for facial landmarks
|
312 |
+
|
313 |
+
# Process each landmark to ensure consistent shape
|
314 |
+
last_valid_landmark = None
|
315 |
+
for i, lm in enumerate(landmarks):
|
316 |
+
if lm is not None and isinstance(lm, np.ndarray) and lm.shape == expected_shape:
|
317 |
+
processed_landmarks.append(lm)
|
318 |
+
last_valid_landmark = lm
|
319 |
+
else:
|
320 |
+
# Print information about inconsistent landmarks
|
321 |
+
if lm is None:
|
322 |
+
print(f"Warning: Landmark at index {i} is None")
|
323 |
+
elif not isinstance(lm, np.ndarray):
|
324 |
+
print(
|
325 |
+
f"Warning: Landmark at index {i} is not a numpy array, type: {type(lm)}"
|
326 |
+
)
|
327 |
+
elif lm.shape != expected_shape:
|
328 |
+
print(
|
329 |
+
f"Warning: Landmark at index {i} has shape {lm.shape}, expected {expected_shape}"
|
330 |
+
)
|
331 |
+
|
332 |
+
# Replace invalid landmarks with the closest valid landmark if available
|
333 |
+
if last_valid_landmark is not None:
|
334 |
+
processed_landmarks.append(last_valid_landmark.copy())
|
335 |
+
else:
|
336 |
+
# If no valid landmark has been seen yet, look ahead for a valid one
|
337 |
+
found_future_valid = False
|
338 |
+
for future_lm in landmarks[i + 1 :]:
|
339 |
+
if (
|
340 |
+
future_lm is not None
|
341 |
+
and isinstance(future_lm, np.ndarray)
|
342 |
+
and future_lm.shape == expected_shape
|
343 |
+
):
|
344 |
+
processed_landmarks.append(future_lm.copy())
|
345 |
+
found_future_valid = True
|
346 |
+
break
|
347 |
+
|
348 |
+
# If no valid landmark found in the future, use zeros
|
349 |
+
if not found_future_valid:
|
350 |
+
processed_landmarks.append(np.zeros(expected_shape))
|
351 |
+
|
352 |
+
return np.array(processed_landmarks)
|
353 |
+
|
354 |
+
|
355 |
+
@torch.no_grad()
|
356 |
+
def sample(
|
357 |
+
audio_list,
|
358 |
+
gt_keyframes,
|
359 |
+
masks_keyframes,
|
360 |
+
to_remove,
|
361 |
+
test_keyframes_list,
|
362 |
+
num_frames,
|
363 |
+
device,
|
364 |
+
emb,
|
365 |
+
force_uc_zero_embeddings,
|
366 |
+
n_batch_keyframes,
|
367 |
+
n_batch,
|
368 |
+
test_interpolation_list,
|
369 |
+
audio_interpolation_list,
|
370 |
+
masks_interpolation,
|
371 |
+
gt_interpolation,
|
372 |
+
model_keyframes,
|
373 |
+
model,
|
374 |
+
):
|
375 |
+
# Create a progress bar for Gradio
|
376 |
+
progress = gr.Progress()
|
377 |
+
|
378 |
+
condition = torch.zeros(1, 3, 512, 512).to(device)
|
379 |
+
if torch.cuda.is_available():
|
380 |
+
condition = condition.half()
|
381 |
+
|
382 |
+
audio_list = rearrange(audio_list, "(b t) c d -> b t c d", t=num_frames)
|
383 |
+
gt_keyframes = rearrange(gt_keyframes, "(b t) c h w -> b t c h w", t=num_frames)
|
384 |
+
# Rearrange masks_keyframes and save locally
|
385 |
+
masks_keyframes = rearrange(
|
386 |
+
masks_keyframes, "(b t) c h w -> b t c h w", t=num_frames
|
387 |
+
)
|
388 |
+
|
389 |
+
# Convert to_remove into chunks of num_frames
|
390 |
+
to_remove_chunks = [
|
391 |
+
to_remove[i : i + num_frames] for i in range(0, len(to_remove), num_frames)
|
392 |
+
]
|
393 |
+
test_keyframes_list = [
|
394 |
+
test_keyframes_list[i : i + num_frames]
|
395 |
+
for i in range(0, len(test_keyframes_list), num_frames)
|
396 |
+
]
|
397 |
+
|
398 |
+
audio_cond = audio_list
|
399 |
+
if emb is not None:
|
400 |
+
embbedings = emb.unsqueeze(0).to(device)
|
401 |
+
if torch.cuda.is_available():
|
402 |
+
embbedings = embbedings.half()
|
403 |
+
else:
|
404 |
+
embbedings = None
|
405 |
+
|
406 |
+
# One batch of keframes is approximately 7 seconds
|
407 |
+
chunk_size = 2
|
408 |
+
complete_video = []
|
409 |
+
start_idx = 0
|
410 |
+
last_frame_z = None
|
411 |
+
last_frame_x = None
|
412 |
+
last_keyframe_idx = None
|
413 |
+
last_to_remove = None
|
414 |
+
|
415 |
+
total_chunks = (len(audio_cond) + chunk_size - 1) // chunk_size
|
416 |
+
|
417 |
+
for chunk_idx, chunk_start in enumerate(range(0, len(audio_cond), chunk_size)):
|
418 |
+
# Update progress bar
|
419 |
+
progress(chunk_idx / total_chunks, desc="Generating video")
|
420 |
+
|
421 |
+
# Clear GPU cache between chunks
|
422 |
+
torch.cuda.empty_cache()
|
423 |
+
|
424 |
+
chunk_end = min(chunk_start + chunk_size, len(audio_cond))
|
425 |
+
|
426 |
+
chunk_audio_cond = audio_cond[chunk_start:chunk_end].cuda()
|
427 |
+
if torch.cuda.is_available():
|
428 |
+
chunk_audio_cond = chunk_audio_cond.half()
|
429 |
+
|
430 |
+
chunk_gt_keyframes = gt_keyframes[chunk_start:chunk_end].cuda()
|
431 |
+
chunk_masks = masks_keyframes[chunk_start:chunk_end].cuda()
|
432 |
+
|
433 |
+
if torch.cuda.is_available():
|
434 |
+
chunk_gt_keyframes = chunk_gt_keyframes.half()
|
435 |
+
chunk_masks = chunk_masks.half()
|
436 |
+
|
437 |
+
test_keyframes_list_unwrapped = [
|
438 |
+
elem
|
439 |
+
for sublist in test_keyframes_list[chunk_start:chunk_end]
|
440 |
+
for elem in sublist
|
441 |
+
]
|
442 |
+
to_remove_chunks_unwrapped = [
|
443 |
+
elem
|
444 |
+
for sublist in to_remove_chunks[chunk_start:chunk_end]
|
445 |
+
for elem in sublist
|
446 |
+
]
|
447 |
+
|
448 |
+
if last_keyframe_idx is not None:
|
449 |
+
test_keyframes_list_unwrapped = [
|
450 |
+
last_keyframe_idx
|
451 |
+
] + test_keyframes_list_unwrapped
|
452 |
+
to_remove_chunks_unwrapped = [last_to_remove] + to_remove_chunks_unwrapped
|
453 |
+
|
454 |
+
last_keyframe_idx = test_keyframes_list_unwrapped[-1]
|
455 |
+
last_to_remove = to_remove_chunks_unwrapped[-1]
|
456 |
+
# Find the first non-None keyframe in the chunk
|
457 |
+
first_keyframe = next(
|
458 |
+
(kf for kf in test_keyframes_list_unwrapped if kf is not None), None
|
459 |
+
)
|
460 |
+
|
461 |
+
# Find the last non-None keyframe in the chunk
|
462 |
+
last_keyframe = next(
|
463 |
+
(kf for kf in reversed(test_keyframes_list_unwrapped) if kf is not None),
|
464 |
+
None,
|
465 |
+
)
|
466 |
+
|
467 |
+
start_idx = next(
|
468 |
+
(
|
469 |
+
idx
|
470 |
+
for idx, comb in enumerate(test_interpolation_list)
|
471 |
+
if comb[0] == first_keyframe
|
472 |
+
),
|
473 |
+
None,
|
474 |
+
)
|
475 |
+
end_idx = next(
|
476 |
+
(
|
477 |
+
idx
|
478 |
+
for idx, comb in enumerate(reversed(test_interpolation_list))
|
479 |
+
if comb[1] == last_keyframe
|
480 |
+
),
|
481 |
+
None,
|
482 |
+
)
|
483 |
+
|
484 |
+
if start_idx is not None and end_idx is not None:
|
485 |
+
end_idx = (
|
486 |
+
len(test_interpolation_list) - 1 - end_idx
|
487 |
+
) # Adjust for reversed enumeration
|
488 |
+
end_idx += 1
|
489 |
+
if start_idx is None:
|
490 |
+
break
|
491 |
+
if end_idx < start_idx:
|
492 |
+
end_idx = len(audio_interpolation_list)
|
493 |
+
|
494 |
+
audio_interpolation_list_chunk = audio_interpolation_list[start_idx:end_idx]
|
495 |
+
chunk_masks_interpolation = masks_interpolation[start_idx:end_idx]
|
496 |
+
gt_interpolation_chunks = gt_interpolation[start_idx:end_idx]
|
497 |
+
|
498 |
+
if torch.cuda.is_available():
|
499 |
+
audio_interpolation_list_chunk = [
|
500 |
+
chunk.half() for chunk in audio_interpolation_list_chunk
|
501 |
+
]
|
502 |
+
chunk_masks_interpolation = [
|
503 |
+
chunk.half() for chunk in chunk_masks_interpolation
|
504 |
+
]
|
505 |
+
gt_interpolation_chunks = [
|
506 |
+
chunk.half() for chunk in gt_interpolation_chunks
|
507 |
+
]
|
508 |
+
|
509 |
+
progress(chunk_idx / total_chunks, desc="Generating keyframes")
|
510 |
+
|
511 |
+
# Always use autocast for FP16 computation
|
512 |
+
with amp.autocast(enabled=True):
|
513 |
+
samples_z = sample_keyframes(
|
514 |
+
model_keyframes,
|
515 |
+
chunk_audio_cond,
|
516 |
+
chunk_gt_keyframes,
|
517 |
+
chunk_masks,
|
518 |
+
condition.cuda(),
|
519 |
+
num_frames,
|
520 |
+
24,
|
521 |
+
0.0,
|
522 |
+
device,
|
523 |
+
embbedings.cuda() if embbedings is not None else None,
|
524 |
+
force_uc_zero_embeddings,
|
525 |
+
n_batch_keyframes,
|
526 |
+
0,
|
527 |
+
1.0,
|
528 |
+
None,
|
529 |
+
gt_as_cond=False,
|
530 |
+
)
|
531 |
+
|
532 |
+
if last_frame_x is not None:
|
533 |
+
# samples_x = torch.cat([last_frame_x.unsqueeze(0), samples_x], axis=0)
|
534 |
+
samples_z = torch.cat([last_frame_z.unsqueeze(0), samples_z], axis=0)
|
535 |
+
|
536 |
+
# last_frame_x = samples_x[-1]
|
537 |
+
last_frame_z = samples_z[-1]
|
538 |
+
|
539 |
+
progress(chunk_idx / total_chunks, desc="Interpolating frames")
|
540 |
+
|
541 |
+
# Always use autocast for FP16 computation
|
542 |
+
with amp.autocast(enabled=True):
|
543 |
+
vid = sample_interpolation(
|
544 |
+
model,
|
545 |
+
samples_z,
|
546 |
+
# samples_x,
|
547 |
+
audio_interpolation_list_chunk,
|
548 |
+
gt_interpolation_chunks,
|
549 |
+
chunk_masks_interpolation,
|
550 |
+
condition.cuda(),
|
551 |
+
num_frames,
|
552 |
+
device,
|
553 |
+
1,
|
554 |
+
24,
|
555 |
+
0.0,
|
556 |
+
force_uc_zero_embeddings,
|
557 |
+
n_batch,
|
558 |
+
chunk_size,
|
559 |
+
1.0,
|
560 |
+
None,
|
561 |
+
cut_audio=False,
|
562 |
+
to_remove=to_remove_chunks_unwrapped,
|
563 |
+
)
|
564 |
+
|
565 |
+
if chunk_start == 0:
|
566 |
+
complete_video = vid
|
567 |
+
else:
|
568 |
+
complete_video = np.concatenate([complete_video[:-1], vid], axis=0)
|
569 |
+
|
570 |
+
return complete_video
|
571 |
+
|
572 |
+
|
573 |
+
def process_video(video_input, audio_input, max_num_seconds):
|
574 |
+
"""Main processing function to generate synchronized video"""
|
575 |
+
|
576 |
+
# Display a message to the user about the processing time
|
577 |
+
gr.Info("Processing video. This may take a while...", duration=10)
|
578 |
+
gr.Info(
|
579 |
+
"If you're tired of waiting, try playing the Wordle game in the other tab!",
|
580 |
+
duration=10,
|
581 |
+
)
|
582 |
+
|
583 |
+
# Use default media if none provided
|
584 |
+
if video_input is None:
|
585 |
+
video_input = DEFAULT_VIDEO_PATH
|
586 |
+
print(f"Using default video: {DEFAULT_VIDEO_PATH}")
|
587 |
+
|
588 |
+
if audio_input is None:
|
589 |
+
audio_input = DEFAULT_AUDIO_PATH
|
590 |
+
print(f"Using default audio: {DEFAULT_AUDIO_PATH}")
|
591 |
+
|
592 |
+
try:
|
593 |
+
# Calculate hashes for cache keys
|
594 |
+
video_path_hash = video_input
|
595 |
+
audio_path_hash = audio_input
|
596 |
+
|
597 |
+
# Check if we need to recompute video embeddings
|
598 |
+
video_cache_hit = cache["video"]["path"] == video_path_hash
|
599 |
+
audio_cache_hit = cache["audio"]["path"] == audio_path_hash
|
600 |
+
|
601 |
+
if video_cache_hit and audio_cache_hit:
|
602 |
+
print("Using cached video and audio computations")
|
603 |
+
# Make copies of cached data to avoid modifying cache
|
604 |
+
video_embedding = cache["video"]["embedding"].clone()
|
605 |
+
video_frames = cache["video"]["frames"].clone()
|
606 |
+
video_landmarks = cache["video"]["landmarks"].copy()
|
607 |
+
raw_audio = cache["audio"]["raw_audio"].clone()
|
608 |
+
raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
|
609 |
+
hubert_embedding = cache["audio"]["hubert_embedding"].clone()
|
610 |
+
wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()
|
611 |
+
|
612 |
+
# Ensure all data is truncated to the same length if needed
|
613 |
+
min_len = min(
|
614 |
+
len(video_frames),
|
615 |
+
len(raw_audio),
|
616 |
+
len(hubert_embedding),
|
617 |
+
len(wavlm_embedding),
|
618 |
+
)
|
619 |
+
video_frames = video_frames[:min_len]
|
620 |
+
video_embedding = video_embedding[:min_len]
|
621 |
+
video_landmarks = video_landmarks[:min_len]
|
622 |
+
raw_audio = raw_audio[:min_len]
|
623 |
+
hubert_embedding = hubert_embedding[:min_len]
|
624 |
+
wavlm_embedding = wavlm_embedding[:min_len]
|
625 |
+
raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
|
626 |
+
|
627 |
+
else:
|
628 |
+
# Process video if needed
|
629 |
+
if not video_cache_hit:
|
630 |
+
print("Computing video embeddings and landmarks")
|
631 |
+
video_reader = decord.VideoReader(video_input)
|
632 |
+
decord.bridge.set_bridge("torch")
|
633 |
+
|
634 |
+
if not audio_cache_hit:
|
635 |
+
# Need to process audio to determine min_len
|
636 |
+
raw_audio = get_raw_audio(audio_input, 16000)
|
637 |
+
if len(raw_audio) == 0 or len(video_reader) == 0:
|
638 |
+
raise ValueError("Empty audio or video input")
|
639 |
+
|
640 |
+
min_len = min(len(raw_audio), len(video_reader))
|
641 |
+
|
642 |
+
# Store full audio in cache
|
643 |
+
cache["audio"]["path"] = audio_path_hash
|
644 |
+
cache["audio"]["raw_audio"] = raw_audio.clone()
|
645 |
+
|
646 |
+
# Create truncated copy for processing
|
647 |
+
raw_audio = raw_audio[:min_len]
|
648 |
+
raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
|
649 |
+
else:
|
650 |
+
# Use cached audio - make a copy
|
651 |
+
if cache["audio"]["raw_audio"] is None:
|
652 |
+
raise ValueError("Cached audio is None")
|
653 |
+
|
654 |
+
raw_audio = cache["audio"]["raw_audio"].clone()
|
655 |
+
if len(raw_audio) == 0 or len(video_reader) == 0:
|
656 |
+
raise ValueError("Empty cached audio or video input")
|
657 |
+
|
658 |
+
min_len = min(len(raw_audio), len(video_reader))
|
659 |
+
|
660 |
+
# Create truncated copy for processing
|
661 |
+
raw_audio = raw_audio[:min_len]
|
662 |
+
raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
|
663 |
+
|
664 |
+
# Compute video embeddings and landmarks - store full version in cache
|
665 |
+
video_embedding, video_frames = compute_video_embedding(
|
666 |
+
video_reader, len(video_reader)
|
667 |
+
)
|
668 |
+
video_landmarks = extract_video_landmarks(video_frames)
|
669 |
+
|
670 |
+
# Update video cache with full versions
|
671 |
+
cache["video"]["path"] = video_path_hash
|
672 |
+
cache["video"]["embedding"] = video_embedding
|
673 |
+
cache["video"]["frames"] = video_frames
|
674 |
+
cache["video"]["landmarks"] = video_landmarks
|
675 |
+
|
676 |
+
# Create truncated copies for processing
|
677 |
+
video_embedding = video_embedding[:min_len]
|
678 |
+
video_frames = video_frames[:min_len]
|
679 |
+
video_landmarks = video_landmarks[:min_len]
|
680 |
+
|
681 |
+
else:
|
682 |
+
# Use cached video data - make copies
|
683 |
+
print("Using cached video computations")
|
684 |
+
|
685 |
+
if (
|
686 |
+
cache["video"]["embedding"] is None
|
687 |
+
or cache["video"]["frames"] is None
|
688 |
+
or cache["video"]["landmarks"] is None
|
689 |
+
):
|
690 |
+
raise ValueError("One or more video cache entries are None")
|
691 |
+
|
692 |
+
if not audio_cache_hit:
|
693 |
+
# New audio with cached video
|
694 |
+
raw_audio = get_raw_audio(audio_input, 16000)
|
695 |
+
if len(raw_audio) == 0:
|
696 |
+
raise ValueError("Empty audio input")
|
697 |
+
|
698 |
+
# Store full audio in cache
|
699 |
+
cache["audio"]["path"] = audio_path_hash
|
700 |
+
cache["audio"]["raw_audio"] = raw_audio.clone()
|
701 |
+
|
702 |
+
# Make copies of video data
|
703 |
+
video_embedding = cache["video"]["embedding"].clone()
|
704 |
+
video_frames = cache["video"]["frames"].clone()
|
705 |
+
video_landmarks = cache["video"]["landmarks"].copy()
|
706 |
+
|
707 |
+
# Determine truncation length and create truncated copies
|
708 |
+
min_len = min(len(raw_audio), len(video_frames))
|
709 |
+
raw_audio = raw_audio[:min_len]
|
710 |
+
raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
|
711 |
+
video_frames = video_frames[:min_len]
|
712 |
+
video_embedding = video_embedding[:min_len]
|
713 |
+
video_landmarks = video_landmarks[:min_len]
|
714 |
+
else:
|
715 |
+
# Both video and audio are cached - should not reach here
|
716 |
+
# as it's handled in the first if statement
|
717 |
+
pass
|
718 |
+
|
719 |
+
# Process audio if needed
|
720 |
+
if not audio_cache_hit:
|
721 |
+
print("Computing audio embeddings")
|
722 |
+
|
723 |
+
# Compute audio embeddings with the truncated audio
|
724 |
+
hubert_embedding = compute_hubert_embedding(raw_audio_reshape)
|
725 |
+
wavlm_embedding = compute_wavlm_embedding(raw_audio_reshape)
|
726 |
+
|
727 |
+
# Update audio cache with full embeddings
|
728 |
+
# Note: raw_audio was already cached above
|
729 |
+
cache["audio"]["hubert_embedding"] = hubert_embedding.clone()
|
730 |
+
cache["audio"]["wavlm_embedding"] = wavlm_embedding.clone()
|
731 |
+
else:
|
732 |
+
# Use cached audio data - make copies
|
733 |
+
if (
|
734 |
+
cache["audio"]["hubert_embedding"] is None
|
735 |
+
or cache["audio"]["wavlm_embedding"] is None
|
736 |
+
):
|
737 |
+
raise ValueError(
|
738 |
+
"One or more audio embedding cache entries are None"
|
739 |
+
)
|
740 |
+
|
741 |
+
hubert_embedding = cache["audio"]["hubert_embedding"].clone()
|
742 |
+
wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()
|
743 |
+
|
744 |
+
# Make sure embeddings match the truncated video length if needed
|
745 |
+
if "min_len" in locals() and (
|
746 |
+
min_len < len(hubert_embedding) or min_len < len(wavlm_embedding)
|
747 |
+
):
|
748 |
+
hubert_embedding = hubert_embedding[:min_len]
|
749 |
+
wavlm_embedding = wavlm_embedding[:min_len]
|
750 |
+
|
751 |
+
# Apply max_num_seconds limit if specified
|
752 |
+
if max_num_seconds > 0:
|
753 |
+
# Convert seconds to frames (assuming 25 fps)
|
754 |
+
max_frames = int(max_num_seconds * 25)
|
755 |
+
|
756 |
+
# Truncate all data to max_frames
|
757 |
+
video_embedding = video_embedding[:max_frames]
|
758 |
+
video_frames = video_frames[:max_frames]
|
759 |
+
video_landmarks = video_landmarks[:max_frames]
|
760 |
+
hubert_embedding = hubert_embedding[:max_frames]
|
761 |
+
wavlm_embedding = wavlm_embedding[:max_frames]
|
762 |
+
raw_audio = raw_audio[:max_frames]
|
763 |
+
raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
|
764 |
+
|
765 |
+
# Validate shapes before proceeding
|
766 |
+
assert video_embedding.shape[0] == hubert_embedding.shape[0], (
|
767 |
+
f"Video embedding length ({video_embedding.shape[0]}) doesn't match Hubert embedding length ({hubert_embedding.shape[0]})"
|
768 |
+
)
|
769 |
+
assert video_embedding.shape[0] == wavlm_embedding.shape[0], (
|
770 |
+
f"Video embedding length ({video_embedding.shape[0]}) doesn't match WavLM embedding length ({wavlm_embedding.shape[0]})"
|
771 |
+
)
|
772 |
+
assert video_embedding.shape[0] == video_landmarks.shape[0], (
|
773 |
+
f"Video embedding length ({video_embedding.shape[0]}) doesn't match landmarks length ({video_landmarks.shape[0]})"
|
774 |
+
)
|
775 |
+
|
776 |
+
print(f"Hubert embedding shape: {hubert_embedding.shape}")
|
777 |
+
print(f"WavLM embedding shape: {wavlm_embedding.shape}")
|
778 |
+
print(f"Video embedding shape: {video_embedding.shape}")
|
779 |
+
print(f"Video landmarks shape: {video_landmarks.shape}")
|
780 |
+
|
781 |
+
# Create pipeline inputs for models
|
782 |
+
(
|
783 |
+
interpolation_chunks,
|
784 |
+
keyframe_chunks,
|
785 |
+
audio_interpolation_chunks,
|
786 |
+
audio_keyframe_chunks,
|
787 |
+
emb_cond,
|
788 |
+
masks_keyframe_chunks,
|
789 |
+
masks_interpolation_chunks,
|
790 |
+
to_remove,
|
791 |
+
audio_interpolation_idx,
|
792 |
+
audio_keyframe_idx,
|
793 |
+
) = create_pipeline_inputs(
|
794 |
+
hubert_embedding,
|
795 |
+
wavlm_embedding,
|
796 |
+
14,
|
797 |
+
video_embedding,
|
798 |
+
video_landmarks,
|
799 |
+
overlap=1,
|
800 |
+
add_zero_flag=True,
|
801 |
+
mask_arms=None,
|
802 |
+
nose_index=28,
|
803 |
+
)
|
804 |
+
|
805 |
+
complete_video = sample(
|
806 |
+
audio_keyframe_chunks,
|
807 |
+
keyframe_chunks,
|
808 |
+
masks_keyframe_chunks,
|
809 |
+
to_remove,
|
810 |
+
audio_keyframe_idx,
|
811 |
+
14,
|
812 |
+
"cuda",
|
813 |
+
emb_cond,
|
814 |
+
[],
|
815 |
+
3,
|
816 |
+
3,
|
817 |
+
audio_interpolation_idx,
|
818 |
+
audio_interpolation_chunks,
|
819 |
+
masks_interpolation_chunks,
|
820 |
+
interpolation_chunks,
|
821 |
+
keyframe_model,
|
822 |
+
interpolation_model,
|
823 |
+
)
|
824 |
+
|
825 |
+
complete_audio = rearrange(
|
826 |
+
raw_audio[: complete_video.shape[0]], "f s -> () (f s)"
|
827 |
+
)
|
828 |
+
|
829 |
+
# 4. Convert frames to video and combine with audio
|
830 |
+
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
|
831 |
+
output_path = temp_video.name
|
832 |
+
|
833 |
+
print("Saving video to", output_path)
|
834 |
+
|
835 |
+
save_audio_video(complete_video, audio=complete_audio, save_path=output_path)
|
836 |
+
torch.cuda.empty_cache()
|
837 |
+
return output_path
|
838 |
+
|
839 |
+
except Exception as e:
|
840 |
+
raise e
|
841 |
+
print(f"Error processing video: {str(e)}")
|
842 |
+
return None
|
843 |
+
|
844 |
+
|
845 |
+
def get_max_duration(video_input, audio_input):
|
846 |
+
"""Get the maximum duration in seconds for the slider"""
|
847 |
+
try:
|
848 |
+
# Default to 60 seconds if files don't exist
|
849 |
+
if video_input is None or not os.path.exists(video_input):
|
850 |
+
video_input = DEFAULT_VIDEO_PATH
|
851 |
+
|
852 |
+
if audio_input is None or not os.path.exists(audio_input):
|
853 |
+
audio_input = DEFAULT_AUDIO_PATH
|
854 |
+
|
855 |
+
# Get video duration
|
856 |
+
video_reader = decord.VideoReader(video_input)
|
857 |
+
video_duration = len(video_reader) / video_reader.get_avg_fps()
|
858 |
+
|
859 |
+
# Get audio duration
|
860 |
+
raw_audio = get_raw_audio(audio_input, 16000)
|
861 |
+
audio_duration = len(raw_audio) / 25 # Assuming 25 fps
|
862 |
+
|
863 |
+
# Return the minimum of the two durations
|
864 |
+
return min(video_duration, audio_duration)
|
865 |
+
except Exception as e:
|
866 |
+
print(f"Error getting max duration: {str(e)}")
|
867 |
+
return 60 # Default to 60 seconds
|
868 |
+
|
869 |
+
|
870 |
+
def new_game_click(state):
|
871 |
+
"""Handle the 'New Game' button click."""
|
872 |
+
message = state.new_game()
|
873 |
+
feedback_history = state.get_feedback_history()
|
874 |
+
return state, feedback_history, message
|
875 |
+
|
876 |
+
|
877 |
+
def submit_guess_click(guess, state):
|
878 |
+
"""Handle the 'Submit Guess' button click."""
|
879 |
+
message = state.submit_guess(guess)
|
880 |
+
feedback_history = state.get_feedback_history()
|
881 |
+
return state, feedback_history, message
|
882 |
+
|
883 |
+
|
884 |
+
# Create Gradio interface
|
885 |
+
with gr.Blocks(title="Video Synchronization with Diffusion Models") as demo:
|
886 |
+
gr.Markdown("# Video Synchronization with Diffusion Models")
|
887 |
+
gr.Markdown(
|
888 |
+
"Upload a video and audio to create a synchronized video with the same visuals but synchronized to the new audio."
|
889 |
+
)
|
890 |
+
|
891 |
+
with gr.Tabs():
|
892 |
+
with gr.TabItem("Video Synchronization"):
|
893 |
+
with gr.Row():
|
894 |
+
with gr.Column():
|
895 |
+
video_input = gr.Video(
|
896 |
+
label="Input Video",
|
897 |
+
value=DEFAULT_VIDEO_PATH
|
898 |
+
if os.path.exists(DEFAULT_VIDEO_PATH)
|
899 |
+
else None,
|
900 |
+
width=512,
|
901 |
+
height=512,
|
902 |
+
)
|
903 |
+
audio_input = gr.Audio(
|
904 |
+
label="Input Audio",
|
905 |
+
type="filepath",
|
906 |
+
value=DEFAULT_AUDIO_PATH
|
907 |
+
if os.path.exists(DEFAULT_AUDIO_PATH)
|
908 |
+
else None,
|
909 |
+
)
|
910 |
+
|
911 |
+
max_duration = gr.State(value=60) # Default max duration
|
912 |
+
|
913 |
+
max_seconds_slider = gr.Slider(
|
914 |
+
minimum=0,
|
915 |
+
maximum=60, # Will be updated dynamically
|
916 |
+
value=0,
|
917 |
+
step=1,
|
918 |
+
label="Max Duration (seconds, 0 = full length)",
|
919 |
+
info="Limit the processing duration (0 means use full length)",
|
920 |
+
)
|
921 |
+
|
922 |
+
process_button = gr.Button("Generate Synchronized Video")
|
923 |
+
|
924 |
+
with gr.Column("Output Video"):
|
925 |
+
video_output = gr.Video(label="Output Video", width=512, height=512)
|
926 |
+
|
927 |
+
# Update slider max value when inputs change
|
928 |
+
def update_slider_max(video, audio):
|
929 |
+
max_dur = get_max_duration(video, audio)
|
930 |
+
return {"maximum": max_dur, "__type__": "update"}
|
931 |
+
|
932 |
+
video_input.change(
|
933 |
+
update_slider_max, [video_input, audio_input], [max_seconds_slider]
|
934 |
+
)
|
935 |
+
audio_input.change(
|
936 |
+
update_slider_max, [video_input, audio_input], [max_seconds_slider]
|
937 |
+
)
|
938 |
+
|
939 |
+
# Show Wordle message when processing starts and hide when complete
|
940 |
+
process_button.click(
|
941 |
+
fn=process_video,
|
942 |
+
inputs=[video_input, audio_input, max_seconds_slider],
|
943 |
+
outputs=video_output,
|
944 |
+
)
|
945 |
+
|
946 |
+
with gr.TabItem("Wordle Game"):
|
947 |
+
state = gr.State(WordleGame()) # Persist the WordleGame instance
|
948 |
+
guess_input = gr.Textbox(label="Your guess (5 letters)", max_length=5)
|
949 |
+
submit_btn = gr.Button("Submit Guess")
|
950 |
+
new_game_btn = gr.Button("New Game")
|
951 |
+
feedback_display = gr.HTML(label="Guesses")
|
952 |
+
message_display = gr.Textbox(
|
953 |
+
label="Message", interactive=False, value="Click 'New Game' to start."
|
954 |
+
)
|
955 |
+
# Connect the 'New Game' button
|
956 |
+
new_game_btn.click(
|
957 |
+
fn=new_game_click,
|
958 |
+
inputs=[state],
|
959 |
+
outputs=[state, feedback_display, message_display],
|
960 |
+
)
|
961 |
+
# Connect the 'Submit Guess' button
|
962 |
+
submit_btn.click(
|
963 |
+
fn=submit_guess_click,
|
964 |
+
inputs=[guess_input, state],
|
965 |
+
outputs=[state, feedback_display, message_display],
|
966 |
+
)
|
967 |
+
|
968 |
+
gr.Markdown("## How it works")
|
969 |
+
gr.Markdown("""
|
970 |
+
1. The system extracts embeddings and landmarks from the input video
|
971 |
+
2. Audio embeddings are computed from the input audio
|
972 |
+
3. A keyframe model generates key visual frames
|
973 |
+
4. An interpolation model creates a smooth video between keyframes
|
974 |
+
5. The final video is rendered with the new audio
|
975 |
+
""")
|
976 |
+
|
977 |
+
if __name__ == "__main__":
|
978 |
+
demo.launch()
|
data_utils.py
ADDED
@@ -0,0 +1,635 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image, ImageDraw
|
4 |
+
import cv2
|
5 |
+
from functools import partial
|
6 |
+
import math
|
7 |
+
|
8 |
+
|
9 |
+
def get_size(img):
|
10 |
+
if isinstance(img, (np.ndarray, torch.Tensor)):
|
11 |
+
return img.shape[1::-1]
|
12 |
+
else:
|
13 |
+
return img.size
|
14 |
+
|
15 |
+
|
16 |
+
def imresample(img, sz):
|
17 |
+
im_data = torch.nn.functional.interpolate(img, size=sz, mode="area")
|
18 |
+
return im_data
|
19 |
+
|
20 |
+
|
21 |
+
def crop_resize(img, box, image_size):
|
22 |
+
if isinstance(img, np.ndarray):
|
23 |
+
img = img[box[1] : box[3], box[0] : box[2]]
|
24 |
+
out = cv2.resize(
|
25 |
+
img, (image_size, image_size), interpolation=cv2.INTER_AREA
|
26 |
+
).copy()
|
27 |
+
elif isinstance(img, torch.Tensor):
|
28 |
+
img = img[box[1] : box[3], box[0] : box[2]]
|
29 |
+
out = (
|
30 |
+
imresample(
|
31 |
+
img.permute(2, 0, 1).unsqueeze(0).float(), (image_size, image_size)
|
32 |
+
)
|
33 |
+
.byte()
|
34 |
+
.squeeze(0)
|
35 |
+
.permute(1, 2, 0)
|
36 |
+
)
|
37 |
+
else:
|
38 |
+
out = img.crop(box).copy().resize((image_size, image_size), Image.BILINEAR)
|
39 |
+
return out
|
40 |
+
|
41 |
+
|
42 |
+
def fixed_image_standardization(image_tensor):
|
43 |
+
processed_tensor = (image_tensor - 127.5) / 128.0
|
44 |
+
return processed_tensor
|
45 |
+
|
46 |
+
|
47 |
+
def extract_face(img, landmarks, image_size=160, margin=0, postprocess=False):
|
48 |
+
"""Extract face + margin from images given facial landmarks.
|
49 |
+
|
50 |
+
Arguments:
|
51 |
+
img {PIL.Image/torch.Tensor/np.ndarray} -- Input image(s) with shape (B, H, W, C)
|
52 |
+
landmarks {numpy.ndarray} -- Facial landmarks with shape (B, 68, 2)
|
53 |
+
image_size {int} -- Output image size in pixels. The image will be square.
|
54 |
+
margin {int} -- Margin to add to bounding box, in terms of pixels in the final image.
|
55 |
+
postprocess {bool} -- Whether to apply standardization
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
torch.tensor -- tensor representing the extracted faces with shape (B, H, W, C)
|
59 |
+
"""
|
60 |
+
# Calculate bounding boxes from landmarks for all faces in batch
|
61 |
+
x_min = np.min(landmarks, axis=1)[:, 0] # Shape: (B,)
|
62 |
+
y_min = np.min(landmarks, axis=1)[:, 1] # Shape: (B,)
|
63 |
+
x_max = np.max(landmarks, axis=1)[:, 0] # Shape: (B,)
|
64 |
+
y_max = np.max(landmarks, axis=1)[:, 1] # Shape: (B,)
|
65 |
+
|
66 |
+
# Calculate margin for top only
|
67 |
+
box_height = y_max - y_min
|
68 |
+
top_margin = margin * box_height / (image_size - margin)
|
69 |
+
|
70 |
+
# Create boxes for all faces
|
71 |
+
boxes = np.stack(
|
72 |
+
[
|
73 |
+
x_min,
|
74 |
+
np.maximum(y_min - top_margin, 0), # Only add margin to top
|
75 |
+
x_max,
|
76 |
+
y_max,
|
77 |
+
],
|
78 |
+
axis=1,
|
79 |
+
).astype(int) # Shape: (B, 4)
|
80 |
+
|
81 |
+
# Process each face in the batch
|
82 |
+
faces = []
|
83 |
+
for i in range(len(boxes)):
|
84 |
+
face = crop_resize(img[i], boxes[i], image_size)
|
85 |
+
faces.append(face)
|
86 |
+
|
87 |
+
faces = torch.stack(faces, dim=0)
|
88 |
+
faces = faces.float()
|
89 |
+
|
90 |
+
if postprocess:
|
91 |
+
faces = fixed_image_standardization(faces)
|
92 |
+
|
93 |
+
return faces
|
94 |
+
|
95 |
+
|
96 |
+
def crop_mouth_region(images, landmarks, crop_size=96):
|
97 |
+
"""
|
98 |
+
Takes a fixed-size square crop centered on the mouth region.
|
99 |
+
|
100 |
+
Parameters:
|
101 |
+
- images: tensor/array of shape (num_frames, height, width, channels) or (height, width, channels)
|
102 |
+
- landmarks: numpy array of shape (num_frames, 68, 2) or (68, 2)
|
103 |
+
- crop_size: size of the square crop (both height and width)
|
104 |
+
- padding: percentage of padding around the mouth region (0.0 to 1.0)
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
- List of fixed-size crops or single crop if input is single image
|
108 |
+
"""
|
109 |
+
# Handle single image case
|
110 |
+
single_image = False
|
111 |
+
if len(images.shape) == 3:
|
112 |
+
images = images[None]
|
113 |
+
landmarks = landmarks[None]
|
114 |
+
single_image = True
|
115 |
+
|
116 |
+
num_frames = len(images)
|
117 |
+
crops = []
|
118 |
+
|
119 |
+
# Mouth landmarks indices (48-67 for mouth region)
|
120 |
+
mouth_indices = range(48, 68)
|
121 |
+
|
122 |
+
for i in range(num_frames):
|
123 |
+
# Get mouth landmarks for current frame
|
124 |
+
mouth_landmarks = landmarks[i][mouth_indices]
|
125 |
+
|
126 |
+
# Find center of mouth
|
127 |
+
center_x = int(np.mean(mouth_landmarks[:, 0]))
|
128 |
+
center_y = int(np.mean(mouth_landmarks[:, 1]))
|
129 |
+
|
130 |
+
# Calculate crop boundaries
|
131 |
+
half_size = crop_size // 2
|
132 |
+
left = max(0, center_x - half_size)
|
133 |
+
right = min(images.shape[2], center_x + half_size)
|
134 |
+
top = max(0, center_y - half_size)
|
135 |
+
bottom = min(images.shape[1], center_y + half_size)
|
136 |
+
|
137 |
+
# Adjust if crop would go out of bounds
|
138 |
+
if left == 0:
|
139 |
+
right = crop_size
|
140 |
+
if right == images.shape[2]:
|
141 |
+
left = images.shape[2] - crop_size
|
142 |
+
if top == 0:
|
143 |
+
bottom = crop_size
|
144 |
+
if bottom == images.shape[1]:
|
145 |
+
top = images.shape[1] - crop_size
|
146 |
+
|
147 |
+
# Take the crop
|
148 |
+
crop = images[i, top:bottom, left:right]
|
149 |
+
crops.append(crop)
|
150 |
+
|
151 |
+
return crops[0] if single_image else crops
|
152 |
+
|
153 |
+
|
154 |
+
def create_masks_from_landmarks_box(
|
155 |
+
landmark_list, img_shape, nose_index=28, dtype="uint8", box_expand=0.0
|
156 |
+
):
|
157 |
+
height, width = img_shape[:2]
|
158 |
+
num_frames = landmark_list.shape[0]
|
159 |
+
|
160 |
+
# Initialize the masks array
|
161 |
+
masks = np.zeros((num_frames, height, width), dtype=dtype)
|
162 |
+
|
163 |
+
if 0 <= box_expand < 1:
|
164 |
+
box_expand = int(box_expand * width)
|
165 |
+
|
166 |
+
for i in range(num_frames):
|
167 |
+
# Get the landmarks for the current frame
|
168 |
+
landmarks = landmark_list[i]
|
169 |
+
|
170 |
+
# Get the y-coordinate of the nose landmark
|
171 |
+
nose_point_h = landmarks[nose_index, 1]
|
172 |
+
cut_h = nose_point_h
|
173 |
+
|
174 |
+
# Find the leftmost and rightmost landmarks
|
175 |
+
far_left_index = np.argmin(landmarks[:, 0])
|
176 |
+
far_right_index = np.argmax(landmarks[:, 0])
|
177 |
+
|
178 |
+
# Define the points for the mask contour
|
179 |
+
left_up_point = np.array(
|
180 |
+
[landmarks[far_left_index][0], cut_h - box_expand], dtype=np.int32
|
181 |
+
)
|
182 |
+
left_down_point = np.array(
|
183 |
+
[landmarks[far_left_index][0], height], dtype=np.int32
|
184 |
+
)
|
185 |
+
right_up_point = np.array(
|
186 |
+
[landmarks[far_right_index][0], cut_h - box_expand], dtype=np.int32
|
187 |
+
)
|
188 |
+
right_down_point = np.array(
|
189 |
+
[landmarks[far_right_index][0], height], dtype=np.int32
|
190 |
+
)
|
191 |
+
|
192 |
+
# Define the contour
|
193 |
+
contour = np.array(
|
194 |
+
[[left_up_point, left_down_point, right_down_point, right_up_point]]
|
195 |
+
)
|
196 |
+
|
197 |
+
# Draw the contour on the mask
|
198 |
+
cv2.drawContours(masks[i], [contour], -1, color=(1), thickness=cv2.FILLED)
|
199 |
+
|
200 |
+
return torch.from_numpy(masks)
|
201 |
+
|
202 |
+
|
203 |
+
def create_masks_from_landmarks_full_size(
|
204 |
+
landmarks_batch,
|
205 |
+
image_height,
|
206 |
+
image_width,
|
207 |
+
start_index=48,
|
208 |
+
end_index=68,
|
209 |
+
offset=0,
|
210 |
+
nose_index=33,
|
211 |
+
):
|
212 |
+
"""
|
213 |
+
Efficiently creates a batch of masks using vectorized operations where each mask has ones from the highest
|
214 |
+
landmark in the specified range (adjusted by an offset) to the bottom of the image, and zeros otherwise.
|
215 |
+
|
216 |
+
Parameters:
|
217 |
+
- landmarks_batch (np.array): An array of shape (B, 68, 2) containing facial landmarks for multiple samples.
|
218 |
+
- image_height (int): The height of the image for which masks are created.
|
219 |
+
- image_width (int): The width of the image for which masks are created.
|
220 |
+
- start_index (int): The starting index of the range to check (inclusive).
|
221 |
+
- end_index (int): The ending index of the range to check (inclusive).
|
222 |
+
- offset (int): An offset to add or subtract from the y-coordinate of the highest landmark.
|
223 |
+
|
224 |
+
Returns:
|
225 |
+
- np.array: An array of masks of shape (B, image_height, image_width) for each batch.
|
226 |
+
"""
|
227 |
+
# Extract the y-coordinates for the specified range across all batches
|
228 |
+
y_coords = landmarks_batch[:, nose_index : nose_index + 1, 1]
|
229 |
+
|
230 |
+
# Find the index of the minimum y-coordinate in the specified range for each batch
|
231 |
+
min_y_indices = np.argmin(y_coords, axis=1)
|
232 |
+
|
233 |
+
# Gather the highest landmarks' y-coordinates using the indices found
|
234 |
+
highest_y_coords = y_coords[np.arange(len(y_coords)), min_y_indices]
|
235 |
+
|
236 |
+
if abs(offset) < 1 and abs(offset) > 0:
|
237 |
+
offset = int(offset * image_height)
|
238 |
+
|
239 |
+
# Apply the offset to the highest y-coordinate
|
240 |
+
adjusted_y_coords = highest_y_coords + offset
|
241 |
+
|
242 |
+
# Clip the coordinates to stay within image boundaries
|
243 |
+
adjusted_y_coords = np.clip(adjusted_y_coords, 0, image_height - 1)
|
244 |
+
|
245 |
+
# Use broadcasting to create a mask without loops
|
246 |
+
# Create a range of indices from 0 to image_height - 1
|
247 |
+
all_indices = np.arange(image_height)
|
248 |
+
|
249 |
+
# Compare each index in 'all_indices' to each 'adjusted_y_coord' in the batch
|
250 |
+
# 'all_indices' has shape (image_height,), we reshape to (1, image_height) to broadcast against (B, 1)
|
251 |
+
mask_2d = (all_indices >= adjusted_y_coords[:, None]).astype(int)
|
252 |
+
|
253 |
+
# Extend the 2D mask to a full 3D mask of size (B, image_height, image_width)
|
254 |
+
full_mask = np.tile(mask_2d[:, :, np.newaxis], (1, 1, image_width))
|
255 |
+
|
256 |
+
return torch.from_numpy(full_mask)
|
257 |
+
|
258 |
+
|
259 |
+
def expand_polygon(polygon, expand_size):
|
260 |
+
"""
|
261 |
+
Expands the polygon outward by a specified number of pixels.
|
262 |
+
|
263 |
+
Parameters:
|
264 |
+
- polygon (list of tuples): The polygon points as (x, y).
|
265 |
+
- expand_size (int): The number of pixels to expand the polygon outward.
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
- expanded_polygon (list of tuples): The expanded polygon points as (x, y).
|
269 |
+
"""
|
270 |
+
if expand_size == 0:
|
271 |
+
return polygon
|
272 |
+
|
273 |
+
# Calculate centroid of the polygon
|
274 |
+
centroid_x = sum([point[0] for point in polygon]) / len(polygon)
|
275 |
+
centroid_y = sum([point[1] for point in polygon]) / len(polygon)
|
276 |
+
|
277 |
+
# Expand each point outward from the centroid
|
278 |
+
expanded_polygon = []
|
279 |
+
for x, y in polygon:
|
280 |
+
vector_x = x - centroid_x
|
281 |
+
vector_y = y - centroid_y
|
282 |
+
length = np.sqrt(vector_x**2 + vector_y**2)
|
283 |
+
if length == 0:
|
284 |
+
expanded_polygon.append((x, y))
|
285 |
+
else:
|
286 |
+
new_x = x + expand_size * (vector_x / length)
|
287 |
+
new_y = y + expand_size * (vector_y / length)
|
288 |
+
expanded_polygon.append((int(new_x), int(new_y)))
|
289 |
+
|
290 |
+
return expanded_polygon
|
291 |
+
|
292 |
+
|
293 |
+
def create_masks_from_landmarks_mouth(
|
294 |
+
landmark_list, img_shape, nose_index=33, dtype="uint8", box_expand=0.0
|
295 |
+
):
|
296 |
+
height, width = img_shape[:2]
|
297 |
+
num_frames = landmark_list.shape[0]
|
298 |
+
|
299 |
+
# Initialize the masks array
|
300 |
+
masks = np.zeros((num_frames, height, width), dtype=dtype)
|
301 |
+
|
302 |
+
if 0 <= box_expand < 1:
|
303 |
+
box_expand = int(box_expand * width)
|
304 |
+
|
305 |
+
for i in range(num_frames):
|
306 |
+
# Get the landmarks for the current frame
|
307 |
+
landmarks = landmark_list[i]
|
308 |
+
|
309 |
+
# Get the y-coordinate of the nose landmark
|
310 |
+
nose_point_h = landmarks[nose_index, 1]
|
311 |
+
cut_h = nose_point_h
|
312 |
+
|
313 |
+
# Find the leftmost and rightmost landmarks
|
314 |
+
far_left_index = np.argmin(landmarks[:, 0])
|
315 |
+
far_right_index = np.argmax(landmarks[:, 0])
|
316 |
+
|
317 |
+
# Find lowest landmark y-coordinate
|
318 |
+
lowest_y = np.max(landmarks[:, 1])
|
319 |
+
# Add box_expand to the lowest point
|
320 |
+
lowest_y = min(height, lowest_y + box_expand)
|
321 |
+
|
322 |
+
# Define the points for the mask contour
|
323 |
+
left_up_point = np.array(
|
324 |
+
[landmarks[far_left_index][0], cut_h - box_expand], dtype=np.int32
|
325 |
+
)
|
326 |
+
left_down_point = np.array(
|
327 |
+
[landmarks[far_left_index][0], lowest_y], dtype=np.int32
|
328 |
+
)
|
329 |
+
right_up_point = np.array(
|
330 |
+
[landmarks[far_right_index][0], cut_h - box_expand], dtype=np.int32
|
331 |
+
)
|
332 |
+
right_down_point = np.array(
|
333 |
+
[landmarks[far_right_index][0], lowest_y], dtype=np.int32
|
334 |
+
)
|
335 |
+
|
336 |
+
# Define the contour
|
337 |
+
contour = np.array(
|
338 |
+
[[left_up_point, left_down_point, right_down_point, right_up_point]]
|
339 |
+
)
|
340 |
+
|
341 |
+
# Draw the contour on the mask
|
342 |
+
cv2.drawContours(masks[i], [contour], -1, color=(1), thickness=cv2.FILLED)
|
343 |
+
|
344 |
+
return torch.from_numpy(masks)
|
345 |
+
|
346 |
+
|
347 |
+
def create_face_mask_from_landmarks(
|
348 |
+
landmarks_batch, image_height, image_width, mask_expand=0
|
349 |
+
):
|
350 |
+
"""
|
351 |
+
Creates a batch of masks where each mask covers the face region using landmarks.
|
352 |
+
|
353 |
+
Parameters:
|
354 |
+
- landmarks_batch (np.array): An array of shape (B, 68, 2) containing facial landmarks for multiple samples.
|
355 |
+
- image_height (int): The height of the image for which masks are created.
|
356 |
+
- image_width (int): The width of the image for which masks are created.
|
357 |
+
- mask_expand (int): The number of pixels to expand the mask outward.
|
358 |
+
|
359 |
+
Returns:
|
360 |
+
- np.array: An array of masks of shape (B, image_height, image_width) for each batch.
|
361 |
+
"""
|
362 |
+
# Initialize an array to hold all masks
|
363 |
+
masks = np.zeros(
|
364 |
+
(landmarks_batch.shape[0], image_height, image_width), dtype=np.uint8
|
365 |
+
)
|
366 |
+
|
367 |
+
if abs(mask_expand) < 1 and abs(mask_expand) > 0:
|
368 |
+
mask_expand = int(mask_expand * image_height)
|
369 |
+
|
370 |
+
for i, landmarks in enumerate(landmarks_batch):
|
371 |
+
# Create a blank image for each mask
|
372 |
+
mask = Image.new("L", (image_width, image_height), 0)
|
373 |
+
draw = ImageDraw.Draw(mask)
|
374 |
+
|
375 |
+
# Extract relevant landmarks for the face
|
376 |
+
jawline_landmarks = landmarks[2:15] # Jawline
|
377 |
+
# upper_face_landmarks = landmarks[17:27] # Eyebrows and top of nose bridge
|
378 |
+
|
379 |
+
# Combine landmarks to form a polygon around the face
|
380 |
+
# face_polygon = np.concatenate((jawline_landmarks, upper_face_landmarks[::-1]), axis=0)
|
381 |
+
face_polygon = jawline_landmarks
|
382 |
+
|
383 |
+
# Convert landmarks to a list of tuples
|
384 |
+
face_polygon = [(int(x), int(y)) for x, y in face_polygon]
|
385 |
+
|
386 |
+
# Expand the polygon if necessary
|
387 |
+
expanded_polygon = expand_polygon(face_polygon, mask_expand)
|
388 |
+
|
389 |
+
# Draw the polygon and fill it
|
390 |
+
draw.polygon(expanded_polygon, outline=1, fill=1)
|
391 |
+
|
392 |
+
# Convert mask to numpy array and add it to the batch of masks
|
393 |
+
masks[i] = np.array(mask)
|
394 |
+
|
395 |
+
return torch.from_numpy(masks)
|
396 |
+
|
397 |
+
|
398 |
+
ALL_FIXED_POINTS = (
|
399 |
+
[i for i in range(0, 4)]
|
400 |
+
+ [i for i in range(13, 17)]
|
401 |
+
+ [i for i in range(27, 36)]
|
402 |
+
+ [36, 39, 42, 45]
|
403 |
+
)
|
404 |
+
|
405 |
+
|
406 |
+
def gaussian_kernel(sigma, width, height):
|
407 |
+
"""Create a 2D Gaussian kernel."""
|
408 |
+
x = torch.arange(0, width, 1) - width // 2
|
409 |
+
y = torch.arange(0, height, 1) - height // 2
|
410 |
+
x = x.float()
|
411 |
+
y = y.float()
|
412 |
+
x2 = x**2
|
413 |
+
y2 = y[:, None] ** 2
|
414 |
+
g = torch.exp(-(x2 + y2) / (2 * sigma**2))
|
415 |
+
return g / g.sum()
|
416 |
+
|
417 |
+
|
418 |
+
def generate_hm(landmarks, height, width, n_points="all", sigma=3):
|
419 |
+
if n_points == "all":
|
420 |
+
Nlandmarks = range(len(landmarks))
|
421 |
+
elif n_points == "fixed":
|
422 |
+
Nlandmarks = ALL_FIXED_POINTS
|
423 |
+
elif n_points == "stable":
|
424 |
+
Nlandmarks = [33, 36, 39, 42, 45]
|
425 |
+
|
426 |
+
kernel = gaussian_kernel(sigma, width, height)
|
427 |
+
hm = torch.zeros((height, width))
|
428 |
+
for I in Nlandmarks:
|
429 |
+
x0, y0 = landmarks[I]
|
430 |
+
x0, y0 = int(x0), int(y0)
|
431 |
+
left, right = max(0, x0 - width // 2), min(width, x0 + width // 2)
|
432 |
+
top, bottom = max(0, y0 - height // 2), min(height, y0 + height // 2)
|
433 |
+
hm[top:bottom, left:right] += kernel[
|
434 |
+
max(0, -y0 + height // 2) : min(height, height - y0 + height // 2),
|
435 |
+
max(0, -x0 + width // 2) : min(width, width - x0 + width // 2),
|
436 |
+
]
|
437 |
+
# Normalize the heatmap to have values between 0 and 1
|
438 |
+
max_val = hm.max()
|
439 |
+
if max_val > 0:
|
440 |
+
hm /= max_val
|
441 |
+
return hm
|
442 |
+
|
443 |
+
|
444 |
+
def get_heatmap(landmarks, image_size, or_im_size, n_points="stable", sigma=4):
|
445 |
+
stack = []
|
446 |
+
seq_length = landmarks.shape[0]
|
447 |
+
if or_im_size[0] != image_size[0] or or_im_size[1] != image_size[1]:
|
448 |
+
landmarks = scale_landmarks(landmarks, or_im_size, image_size)
|
449 |
+
gen_single_heatmap = partial(
|
450 |
+
generate_hm,
|
451 |
+
height=image_size[0],
|
452 |
+
width=image_size[1],
|
453 |
+
n_points=n_points,
|
454 |
+
sigma=sigma,
|
455 |
+
)
|
456 |
+
for i in range(seq_length):
|
457 |
+
stack.append(gen_single_heatmap(landmarks[i]))
|
458 |
+
|
459 |
+
return torch.stack(stack, axis=0).unsqueeze(0) # (1, seq_length, height, width)
|
460 |
+
|
461 |
+
|
462 |
+
def scale_landmarks(landmarks, original_size, target_size):
|
463 |
+
"""
|
464 |
+
Scale landmarks from original size to target size.
|
465 |
+
|
466 |
+
Parameters:
|
467 |
+
- landmarks (np.array): An array of shape (N, 2) containing facial landmarks.
|
468 |
+
- original_size (tuple): The size (height, width) for which the landmarks are currently scaled.
|
469 |
+
- target_size (tuple): The size (height, width) to which landmarks should be scaled.
|
470 |
+
|
471 |
+
Returns:
|
472 |
+
- scaled_landmarks (np.array): Scaled landmarks.
|
473 |
+
"""
|
474 |
+
scale_y = target_size[0] / original_size[0]
|
475 |
+
scale_x = target_size[1] / original_size[1]
|
476 |
+
scaled_landmarks = landmarks * np.array([scale_x, scale_y])
|
477 |
+
return scaled_landmarks.astype(int)
|
478 |
+
|
479 |
+
|
480 |
+
def draw_kps_image(
|
481 |
+
image_shape,
|
482 |
+
original_size,
|
483 |
+
landmarks,
|
484 |
+
color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255)],
|
485 |
+
rgb=True,
|
486 |
+
pts_width=4,
|
487 |
+
):
|
488 |
+
stick_width = pts_width
|
489 |
+
limb_seq = np.array([[0, 2], [1, 2]])
|
490 |
+
kps = landmarks[[36, 45, 33], :]
|
491 |
+
kps = scale_landmarks(kps, original_size, image_shape)
|
492 |
+
if not rgb: # Grayscale image
|
493 |
+
canvas = np.zeros((image_shape[0], image_shape[1], 1))
|
494 |
+
color_mode = "grayscale"
|
495 |
+
else: # Color image
|
496 |
+
canvas = np.zeros((image_shape[0], image_shape[1], 3))
|
497 |
+
color_mode = "color"
|
498 |
+
|
499 |
+
polygon_cache = {}
|
500 |
+
|
501 |
+
for index in limb_seq:
|
502 |
+
color = color_list[index[0]]
|
503 |
+
if color_mode == "grayscale":
|
504 |
+
color = (
|
505 |
+
int(0.299 * color[2] + 0.587 * color[1] + 0.114 * color[0]),
|
506 |
+
) # Convert to grayscale intensity
|
507 |
+
|
508 |
+
x = kps[index][:, 0]
|
509 |
+
y = kps[index][:, 1]
|
510 |
+
length = np.sqrt((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2)
|
511 |
+
angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
|
512 |
+
|
513 |
+
cache_key = (
|
514 |
+
color,
|
515 |
+
int(np.mean(x)),
|
516 |
+
int(np.mean(y)),
|
517 |
+
int(length / 2),
|
518 |
+
int(angle),
|
519 |
+
)
|
520 |
+
if cache_key not in polygon_cache:
|
521 |
+
polygon_cache[cache_key] = cv2.ellipse2Poly(
|
522 |
+
(int(np.mean(x)), int(np.mean(y))),
|
523 |
+
(int(length / 2), stick_width),
|
524 |
+
int(angle),
|
525 |
+
0,
|
526 |
+
360,
|
527 |
+
1,
|
528 |
+
)
|
529 |
+
|
530 |
+
polygon = polygon_cache[cache_key]
|
531 |
+
cv2.fillConvexPoly(canvas, polygon, [int(c * 0.6) for c in color])
|
532 |
+
|
533 |
+
for idx, kp in enumerate(kps):
|
534 |
+
if color_mode == "grayscale":
|
535 |
+
color = (
|
536 |
+
int(
|
537 |
+
0.299 * color_list[idx][2]
|
538 |
+
+ 0.587 * color_list[idx][1]
|
539 |
+
+ 0.114 * color_list[idx][0]
|
540 |
+
),
|
541 |
+
)
|
542 |
+
else:
|
543 |
+
color = color_list[idx]
|
544 |
+
cv2.circle(canvas, (int(kp[0]), int(kp[1])), pts_width, color, -1)
|
545 |
+
|
546 |
+
return canvas.transpose(2, 0, 1)
|
547 |
+
|
548 |
+
|
549 |
+
def create_landmarks_image(
|
550 |
+
landmarks,
|
551 |
+
original_size=(772, 772),
|
552 |
+
target_size=(772, 772),
|
553 |
+
point_size=3,
|
554 |
+
n_points="all",
|
555 |
+
dim=3,
|
556 |
+
):
|
557 |
+
"""
|
558 |
+
Creates an image of landmarks on a black background using efficient NumPy operations.
|
559 |
+
|
560 |
+
Parameters:
|
561 |
+
- landmarks (np.array): An array of shape (68, 2) containing facial landmarks.
|
562 |
+
- image_size (tuple): The size of the output image (height, width).
|
563 |
+
- point_size (int): The radius of each landmark point in pixels.
|
564 |
+
|
565 |
+
Returns:
|
566 |
+
- img (np.array): An image array with landmarks plotted.
|
567 |
+
"""
|
568 |
+
if n_points == "all":
|
569 |
+
indexes = range(len(landmarks))
|
570 |
+
elif n_points == "fixed":
|
571 |
+
indexes = ALL_FIXED_POINTS
|
572 |
+
elif n_points == "stable":
|
573 |
+
indexes = [33, 36, 39, 42, 45]
|
574 |
+
|
575 |
+
landmarks = landmarks[indexes]
|
576 |
+
|
577 |
+
img = np.zeros(target_size, dtype=np.uint8)
|
578 |
+
|
579 |
+
landmarks = scale_landmarks(landmarks, original_size, target_size)
|
580 |
+
|
581 |
+
# Ensure the landmarks are in bounds and integer
|
582 |
+
landmarks = np.clip(
|
583 |
+
landmarks, [0, 0], [target_size[1] - 1, target_size[0] - 1]
|
584 |
+
).astype(int)
|
585 |
+
|
586 |
+
# Get x and y coordinates from landmarks
|
587 |
+
x, y = landmarks[:, 0], landmarks[:, 1]
|
588 |
+
|
589 |
+
# Define a grid offset based on point_size around each landmark
|
590 |
+
offset = np.arange(-point_size // 2, point_size // 2 + 1)
|
591 |
+
grid_x, grid_y = np.meshgrid(offset, offset, indexing="ij")
|
592 |
+
|
593 |
+
# Calculate the full set of x and y coordinates for the points
|
594 |
+
full_x = x[:, np.newaxis, np.newaxis] + grid_x[np.newaxis, :, :]
|
595 |
+
full_y = y[:, np.newaxis, np.newaxis] + grid_y[np.newaxis, :, :]
|
596 |
+
|
597 |
+
# Clip the coordinates to stay within image boundaries
|
598 |
+
full_x = np.clip(full_x, 0, target_size[1] - 1)
|
599 |
+
full_y = np.clip(full_y, 0, target_size[0] - 1)
|
600 |
+
|
601 |
+
# Flatten the arrays to use them as indices
|
602 |
+
full_x = full_x.ravel()
|
603 |
+
full_y = full_y.ravel()
|
604 |
+
|
605 |
+
# Set the points in the image
|
606 |
+
img[full_y, full_x] = 255
|
607 |
+
|
608 |
+
return np.stack([img] * dim, axis=0)
|
609 |
+
|
610 |
+
|
611 |
+
def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
|
612 |
+
len_file = audio.shape[-1]
|
613 |
+
|
614 |
+
if max_len_sec or max_len_raw:
|
615 |
+
max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
|
616 |
+
if len_file < int(max_len):
|
617 |
+
# dummy = np.zeros((1, int(max_len_sec * sr) - len_file))
|
618 |
+
# extened_wav = np.concatenate((audio_data, dummy[0]))
|
619 |
+
extened_wav = torch.nn.functional.pad(
|
620 |
+
audio, (0, int(max_len) - len_file), "constant"
|
621 |
+
)
|
622 |
+
else:
|
623 |
+
extened_wav = audio[:, : int(max_len)]
|
624 |
+
else:
|
625 |
+
extened_wav = audio
|
626 |
+
|
627 |
+
return extened_wav
|
628 |
+
|
629 |
+
|
630 |
+
def ssim_to_bin(ssim_score):
|
631 |
+
# Normalize the SSIM score to a 0-100 scale
|
632 |
+
normalized_diff_ssim = (1 - ((ssim_score + 1) / 2)) * 100
|
633 |
+
# Assign to one of the 100 bins
|
634 |
+
bin_index = float(min(np.floor(normalized_diff_ssim), 99))
|
635 |
+
return bin_index
|
inference_functions.py
ADDED
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
3 |
+
import numpy as np
|
4 |
+
from einops import rearrange, repeat
|
5 |
+
import math
|
6 |
+
|
7 |
+
|
8 |
+
def get_unique_embedder_keys_from_conditioner(conditioner):
|
9 |
+
return list(set([x.input_key for x in conditioner.embedders]))
|
10 |
+
|
11 |
+
|
12 |
+
def get_batch(keys, value_dict, N, T, device):
|
13 |
+
batch = {}
|
14 |
+
batch_uc = {}
|
15 |
+
|
16 |
+
for key in keys:
|
17 |
+
if key == "fps_id":
|
18 |
+
batch[key] = (
|
19 |
+
torch.tensor([value_dict["fps_id"]])
|
20 |
+
.to(device)
|
21 |
+
.repeat(int(math.prod(N)))
|
22 |
+
)
|
23 |
+
elif key == "motion_bucket_id":
|
24 |
+
batch[key] = (
|
25 |
+
torch.tensor([value_dict["motion_bucket_id"]])
|
26 |
+
.to(device)
|
27 |
+
.repeat(int(math.prod(N)))
|
28 |
+
)
|
29 |
+
elif key == "cond_aug":
|
30 |
+
batch[key] = repeat(
|
31 |
+
torch.tensor([value_dict["cond_aug"]]).to(device),
|
32 |
+
"1 -> b",
|
33 |
+
b=math.prod(N),
|
34 |
+
)
|
35 |
+
elif key == "cond_frames":
|
36 |
+
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
|
37 |
+
elif key == "cond_frames_without_noise":
|
38 |
+
batch[key] = repeat(
|
39 |
+
value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
|
40 |
+
)
|
41 |
+
else:
|
42 |
+
batch[key] = value_dict[key]
|
43 |
+
|
44 |
+
if T is not None:
|
45 |
+
batch["num_video_frames"] = T
|
46 |
+
|
47 |
+
for key in batch.keys():
|
48 |
+
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
49 |
+
batch_uc[key] = torch.clone(batch[key])
|
50 |
+
return batch, batch_uc
|
51 |
+
|
52 |
+
|
53 |
+
def merge_overlapping_segments(segments: torch.Tensor, overlap: int) -> torch.Tensor:
|
54 |
+
"""
|
55 |
+
Merges overlapping segments by averaging overlapping frames.
|
56 |
+
Segments have shape (b, t, ...), where 'b' is the number of segments,
|
57 |
+
't' is frames per segment, and '...' are other dimensions.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
segments: Tensor of shape (b, t, ...)
|
61 |
+
overlap: Integer, number of frames that overlap between consecutive segments
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
Tensor of the merged video
|
65 |
+
"""
|
66 |
+
# Get the shape details
|
67 |
+
b, t, *other_dims = segments.shape
|
68 |
+
num_frames = (b - 1) * (
|
69 |
+
t - overlap
|
70 |
+
) + t # Calculate the total number of frames in the merged video
|
71 |
+
|
72 |
+
# Initialize the output tensor and a count tensor to keep track of contributions for averaging
|
73 |
+
output_shape = [num_frames] + other_dims
|
74 |
+
output = torch.zeros(output_shape, dtype=segments.dtype, device=segments.device)
|
75 |
+
count = torch.zeros(output_shape, dtype=torch.float32, device=segments.device)
|
76 |
+
|
77 |
+
current_index = 0
|
78 |
+
for i in range(b):
|
79 |
+
end_index = current_index + t
|
80 |
+
# Add the segment to the output tensor
|
81 |
+
output[current_index:end_index] += rearrange(segments[i], "... -> ...")
|
82 |
+
# Increment the count tensor for each frame that's added
|
83 |
+
count[current_index:end_index] += 1
|
84 |
+
# Update the starting index for the next segment
|
85 |
+
current_index += t - overlap
|
86 |
+
|
87 |
+
# Avoid division by zero
|
88 |
+
count[count == 0] = 1
|
89 |
+
# Average the frames where there's overlap
|
90 |
+
output /= count
|
91 |
+
|
92 |
+
return output
|
93 |
+
|
94 |
+
|
95 |
+
def get_batch_overlap(
|
96 |
+
keys: List[str],
|
97 |
+
value_dict: Dict[str, Any],
|
98 |
+
N: Tuple[int, ...],
|
99 |
+
T: Optional[int],
|
100 |
+
device: str,
|
101 |
+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
102 |
+
"""
|
103 |
+
Create a batch dictionary with overlapping frames for model input.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
keys: List of keys to include in the batch
|
107 |
+
value_dict: Dictionary containing values for each key
|
108 |
+
N: Batch dimensions
|
109 |
+
T: Number of frames (optional)
|
110 |
+
device: Device to place tensors on
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
Tuple of (batch dictionary, unconditional batch dictionary)
|
114 |
+
"""
|
115 |
+
batch = {}
|
116 |
+
batch_uc = {}
|
117 |
+
|
118 |
+
for key in keys:
|
119 |
+
if key == "fps_id":
|
120 |
+
batch[key] = (
|
121 |
+
torch.tensor([value_dict["fps_id"]])
|
122 |
+
.to(device)
|
123 |
+
.repeat(int(math.prod(N)))
|
124 |
+
)
|
125 |
+
elif key == "motion_bucket_id":
|
126 |
+
batch[key] = (
|
127 |
+
torch.tensor([value_dict["motion_bucket_id"]])
|
128 |
+
.to(device)
|
129 |
+
.repeat(int(math.prod(N)))
|
130 |
+
)
|
131 |
+
elif key == "cond_aug":
|
132 |
+
batch[key] = repeat(
|
133 |
+
torch.tensor([value_dict["cond_aug"]]).to(device),
|
134 |
+
"1 -> b",
|
135 |
+
b=math.prod(N),
|
136 |
+
)
|
137 |
+
elif key == "cond_frames":
|
138 |
+
batch[key] = repeat(value_dict["cond_frames"], "b ... -> (b t) ...", t=N[0])
|
139 |
+
elif key == "cond_frames_without_noise":
|
140 |
+
batch[key] = repeat(
|
141 |
+
value_dict["cond_frames_without_noise"], "b ... -> (b t) ...", t=N[0]
|
142 |
+
)
|
143 |
+
else:
|
144 |
+
batch[key] = value_dict[key]
|
145 |
+
|
146 |
+
if T is not None:
|
147 |
+
batch["num_video_frames"] = T
|
148 |
+
|
149 |
+
for key in batch.keys():
|
150 |
+
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
151 |
+
batch_uc[key] = torch.clone(batch[key])
|
152 |
+
return batch, batch_uc
|
153 |
+
|
154 |
+
|
155 |
+
@torch.inference_mode()
|
156 |
+
def sample_keyframes(
|
157 |
+
model_keyframes: Any,
|
158 |
+
audio_list: torch.Tensor,
|
159 |
+
gt_list: torch.Tensor,
|
160 |
+
masks_list: torch.Tensor,
|
161 |
+
condition: torch.Tensor,
|
162 |
+
num_frames: int,
|
163 |
+
fps_id: int,
|
164 |
+
cond_aug: float,
|
165 |
+
device: str,
|
166 |
+
embbedings: Optional[torch.Tensor],
|
167 |
+
force_uc_zero_embeddings: List[str],
|
168 |
+
n_batch_keyframes: int,
|
169 |
+
added_frames: int,
|
170 |
+
strength: float,
|
171 |
+
scale: Optional[Union[float, List[float]]],
|
172 |
+
gt_as_cond: bool = False,
|
173 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
174 |
+
"""
|
175 |
+
Sample keyframes using the keyframe generation model.
|
176 |
+
|
177 |
+
Args:
|
178 |
+
model_keyframes: The keyframe generation model
|
179 |
+
audio_list: List of audio embeddings
|
180 |
+
gt_list: List of ground truth frames
|
181 |
+
masks_list: List of masks
|
182 |
+
condition: Conditioning tensor
|
183 |
+
num_frames: Number of frames to generate
|
184 |
+
fps_id: FPS ID
|
185 |
+
cond_aug: Conditioning augmentation factor
|
186 |
+
device: Device to use for computation
|
187 |
+
embbedings: Optional embeddings
|
188 |
+
force_uc_zero_embeddings: List of embeddings to force to zero in unconditional case
|
189 |
+
n_batch_keyframes: Batch size for keyframe generation
|
190 |
+
added_frames: Number of additional frames
|
191 |
+
strength: Strength parameter for sampling
|
192 |
+
scale: Scale parameter for guidance
|
193 |
+
gt_as_cond: Whether to use ground truth as conditioning
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
Tuple of (latent samples, decoded samples)
|
197 |
+
"""
|
198 |
+
if scale is not None:
|
199 |
+
model_keyframes.sampler.guider.set_scale(scale)
|
200 |
+
# samples_list = []
|
201 |
+
samples_z_list = []
|
202 |
+
# samples_x_list = []
|
203 |
+
|
204 |
+
for i in range(audio_list.shape[0]):
|
205 |
+
H, W = condition.shape[-2:]
|
206 |
+
assert condition.shape[1] == 3
|
207 |
+
F = 8
|
208 |
+
C = 4
|
209 |
+
shape = (num_frames, C, H // F, W // F)
|
210 |
+
|
211 |
+
audio_cond = audio_list[i].unsqueeze(0)
|
212 |
+
|
213 |
+
value_dict: Dict[str, Any] = {}
|
214 |
+
value_dict["fps_id"] = fps_id
|
215 |
+
value_dict["cond_aug"] = cond_aug
|
216 |
+
value_dict["cond_frames_without_noise"] = condition
|
217 |
+
if embbedings is not None:
|
218 |
+
value_dict["cond_frames"] = embbedings + cond_aug * torch.randn_like(
|
219 |
+
embbedings
|
220 |
+
)
|
221 |
+
else:
|
222 |
+
value_dict["cond_frames"] = condition + cond_aug * torch.randn_like(
|
223 |
+
condition
|
224 |
+
)
|
225 |
+
gt = rearrange(gt_list[i].unsqueeze(0), "b t c h w -> b c t h w").to(device)
|
226 |
+
|
227 |
+
if gt_as_cond:
|
228 |
+
value_dict["cond_frames"] = gt[:, :, 0]
|
229 |
+
|
230 |
+
value_dict["cond_aug"] = cond_aug
|
231 |
+
value_dict["audio_emb"] = audio_cond
|
232 |
+
|
233 |
+
value_dict["gt"] = gt
|
234 |
+
value_dict["masks"] = masks_list[i].unsqueeze(0).transpose(1, 2).to(device)
|
235 |
+
|
236 |
+
with torch.no_grad():
|
237 |
+
batch, batch_uc = get_batch(
|
238 |
+
get_unique_embedder_keys_from_conditioner(model_keyframes.conditioner),
|
239 |
+
value_dict,
|
240 |
+
[1, 1],
|
241 |
+
T=num_frames,
|
242 |
+
device=device,
|
243 |
+
)
|
244 |
+
|
245 |
+
c, uc = model_keyframes.conditioner.get_unconditional_conditioning(
|
246 |
+
batch,
|
247 |
+
batch_uc=batch_uc,
|
248 |
+
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
249 |
+
)
|
250 |
+
|
251 |
+
for k in ["crossattn"]:
|
252 |
+
if c[k].shape[1] != num_frames:
|
253 |
+
uc[k] = repeat(
|
254 |
+
uc[k],
|
255 |
+
"b ... -> b t ...",
|
256 |
+
t=num_frames,
|
257 |
+
)
|
258 |
+
uc[k] = rearrange(
|
259 |
+
uc[k],
|
260 |
+
"b t ... -> (b t) ...",
|
261 |
+
t=num_frames,
|
262 |
+
)
|
263 |
+
c[k] = repeat(
|
264 |
+
c[k],
|
265 |
+
"b ... -> b t ...",
|
266 |
+
t=num_frames,
|
267 |
+
)
|
268 |
+
c[k] = rearrange(
|
269 |
+
c[k],
|
270 |
+
"b t ... -> (b t) ...",
|
271 |
+
t=num_frames,
|
272 |
+
)
|
273 |
+
|
274 |
+
video = torch.randn(shape, device=device)
|
275 |
+
|
276 |
+
additional_model_inputs: Dict[str, torch.Tensor] = {}
|
277 |
+
additional_model_inputs["image_only_indicator"] = torch.zeros(
|
278 |
+
n_batch_keyframes, num_frames
|
279 |
+
).to(device)
|
280 |
+
additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
|
281 |
+
|
282 |
+
def denoiser(
|
283 |
+
input: torch.Tensor, sigma: torch.Tensor, c: Dict[str, torch.Tensor]
|
284 |
+
) -> torch.Tensor:
|
285 |
+
return model_keyframes.denoiser(
|
286 |
+
model_keyframes.model,
|
287 |
+
input,
|
288 |
+
sigma,
|
289 |
+
c,
|
290 |
+
**additional_model_inputs,
|
291 |
+
)
|
292 |
+
|
293 |
+
samples_z = model_keyframes.sampler(
|
294 |
+
denoiser, video, cond=c, uc=uc, strength=strength
|
295 |
+
)
|
296 |
+
samples_z_list.append(samples_z)
|
297 |
+
# samples_x_list.append(samples_x)
|
298 |
+
# samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
299 |
+
# samples_list.append(samples)
|
300 |
+
|
301 |
+
video = None
|
302 |
+
|
303 |
+
# samples = (
|
304 |
+
# torch.concat(samples_list)[:-added_frames]
|
305 |
+
# if added_frames > 0
|
306 |
+
# else torch.concat(samples_list)
|
307 |
+
# )
|
308 |
+
samples_z = (
|
309 |
+
torch.concat(samples_z_list)[:-added_frames]
|
310 |
+
if added_frames > 0
|
311 |
+
else torch.concat(samples_z_list)
|
312 |
+
)
|
313 |
+
# samples_x = (
|
314 |
+
# torch.concat(samples_x_list)[:-added_frames]
|
315 |
+
# if added_frames > 0
|
316 |
+
# else torch.concat(samples_x_list)
|
317 |
+
# )
|
318 |
+
|
319 |
+
return samples_z
|
320 |
+
|
321 |
+
|
322 |
+
@torch.inference_mode()
|
323 |
+
def sample_interpolation(
|
324 |
+
model: Any,
|
325 |
+
samples_z: torch.Tensor,
|
326 |
+
# samples_x: torch.Tensor,
|
327 |
+
audio_interpolation_list: List[torch.Tensor],
|
328 |
+
gt_chunks: List[torch.Tensor],
|
329 |
+
masks_chunks: List[torch.Tensor],
|
330 |
+
condition: torch.Tensor,
|
331 |
+
num_frames: int,
|
332 |
+
device: str,
|
333 |
+
overlap: int,
|
334 |
+
fps_id: int,
|
335 |
+
cond_aug: float,
|
336 |
+
force_uc_zero_embeddings: List[str],
|
337 |
+
n_batch: int,
|
338 |
+
chunk_size: Optional[int],
|
339 |
+
strength: float,
|
340 |
+
scale: Optional[float] = None,
|
341 |
+
cut_audio: bool = False,
|
342 |
+
to_remove: List[bool] = [],
|
343 |
+
) -> np.ndarray:
|
344 |
+
"""
|
345 |
+
Sample interpolation frames between keyframes.
|
346 |
+
|
347 |
+
Args:
|
348 |
+
model: The interpolation model
|
349 |
+
samples_z: Latent samples from keyframe generation
|
350 |
+
samples_x: Decoded samples from keyframe generation
|
351 |
+
audio_interpolation_list: List of audio embeddings for interpolation
|
352 |
+
gt_chunks: Ground truth video chunks
|
353 |
+
masks_chunks: Mask chunks for conditional generation
|
354 |
+
condition: Visual conditioning
|
355 |
+
num_frames: Number of frames to generate
|
356 |
+
device: Device to run inference on
|
357 |
+
overlap: Number of frames to overlap between segments
|
358 |
+
fps_id: FPS ID for conditioning
|
359 |
+
motion_bucket_id: Motion bucket ID for conditioning
|
360 |
+
cond_aug: Conditioning augmentation strength
|
361 |
+
force_uc_zero_embeddings: Keys to zero out in unconditional embeddings
|
362 |
+
n_batch: Batch size for generation
|
363 |
+
chunk_size: Size of chunks for processing (to manage memory)
|
364 |
+
strength: Strength of the conditioning
|
365 |
+
scale: Optional scale for classifier-free guidance
|
366 |
+
cut_audio: Whether to cut audio embeddings
|
367 |
+
to_remove: List of flags indicating which frames to remove
|
368 |
+
|
369 |
+
Returns:
|
370 |
+
Generated video frames as numpy array
|
371 |
+
"""
|
372 |
+
if scale is not None:
|
373 |
+
model.sampler.guider.set_scale(scale)
|
374 |
+
|
375 |
+
# Creating condition for interpolation model. We need to create a list of inputs, each input is [first, last]
|
376 |
+
# The first and last are the first and last frames of the interpolation
|
377 |
+
# interpolation_cond_list = []
|
378 |
+
interpolation_cond_list_emb = []
|
379 |
+
|
380 |
+
# samples_x = [sample for i, sample in zip(to_remove, samples_x) if not i]
|
381 |
+
samples_z = [sample for i, sample in zip(to_remove, samples_z) if not i]
|
382 |
+
|
383 |
+
for i in range(0, len(samples_z) - 1, overlap if overlap > 0 else 2):
|
384 |
+
# interpolation_cond_list.append(
|
385 |
+
# torch.stack([samples_x[i], samples_x[i + 1]], dim=1)
|
386 |
+
# )
|
387 |
+
interpolation_cond_list_emb.append(
|
388 |
+
torch.stack([samples_z[i], samples_z[i + 1]], dim=1)
|
389 |
+
)
|
390 |
+
|
391 |
+
# condition = torch.stack(interpolation_cond_list).to(device)
|
392 |
+
audio_cond = torch.stack(audio_interpolation_list).to(device)
|
393 |
+
embbedings = torch.stack(interpolation_cond_list_emb).to(device)
|
394 |
+
|
395 |
+
gt_chunks = torch.stack(gt_chunks).to(device)
|
396 |
+
masks_chunks = torch.stack(masks_chunks).to(device)
|
397 |
+
|
398 |
+
H, W = 512, 512
|
399 |
+
F = 8
|
400 |
+
C = 4
|
401 |
+
shape = (num_frames * audio_cond.shape[0], C, H // F, W // F)
|
402 |
+
|
403 |
+
value_dict: Dict[str, Any] = {}
|
404 |
+
value_dict["fps_id"] = fps_id
|
405 |
+
value_dict["cond_aug"] = cond_aug
|
406 |
+
# value_dict["cond_frames_without_noise"] = condition
|
407 |
+
|
408 |
+
value_dict["cond_frames"] = embbedings
|
409 |
+
value_dict["cond_aug"] = cond_aug
|
410 |
+
if cut_audio:
|
411 |
+
value_dict["audio_emb"] = audio_cond[:, :, :, :768]
|
412 |
+
else:
|
413 |
+
value_dict["audio_emb"] = audio_cond
|
414 |
+
|
415 |
+
value_dict["gt"] = rearrange(gt_chunks, "b t c h w -> b c t h w").to(device)
|
416 |
+
value_dict["masks"] = masks_chunks.transpose(1, 2).to(device)
|
417 |
+
|
418 |
+
with torch.no_grad():
|
419 |
+
with torch.autocast(device):
|
420 |
+
batch, batch_uc = get_batch_overlap(
|
421 |
+
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
422 |
+
value_dict,
|
423 |
+
[1, num_frames],
|
424 |
+
T=num_frames,
|
425 |
+
device=device,
|
426 |
+
)
|
427 |
+
|
428 |
+
c, uc = model.conditioner.get_unconditional_conditioning(
|
429 |
+
batch,
|
430 |
+
batch_uc=batch_uc,
|
431 |
+
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
432 |
+
)
|
433 |
+
|
434 |
+
for k in ["crossattn"]:
|
435 |
+
if c[k].shape[1] != num_frames:
|
436 |
+
uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
|
437 |
+
uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
|
438 |
+
c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
|
439 |
+
c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
|
440 |
+
|
441 |
+
video = torch.randn(shape, device=device)
|
442 |
+
|
443 |
+
additional_model_inputs: Dict[str, torch.Tensor] = {}
|
444 |
+
additional_model_inputs["image_only_indicator"] = torch.zeros(
|
445 |
+
n_batch, num_frames
|
446 |
+
).to(device)
|
447 |
+
additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
|
448 |
+
|
449 |
+
# Debug information
|
450 |
+
print(
|
451 |
+
f"Shapes - Embeddings: {embbedings.shape}, "
|
452 |
+
f"Audio: {audio_cond.shape}, Video: {shape}, Additional inputs: {additional_model_inputs}"
|
453 |
+
)
|
454 |
+
|
455 |
+
if chunk_size is not None:
|
456 |
+
chunk_size = chunk_size * num_frames
|
457 |
+
|
458 |
+
def denoiser(
|
459 |
+
input: torch.Tensor, sigma: torch.Tensor, c: Dict[str, torch.Tensor]
|
460 |
+
) -> torch.Tensor:
|
461 |
+
return model.denoiser(
|
462 |
+
model.model,
|
463 |
+
input,
|
464 |
+
sigma,
|
465 |
+
c,
|
466 |
+
num_overlap_frames=overlap,
|
467 |
+
num_frames=num_frames,
|
468 |
+
n_skips=n_batch,
|
469 |
+
chunk_size=chunk_size,
|
470 |
+
**additional_model_inputs,
|
471 |
+
)
|
472 |
+
|
473 |
+
samples_z = model.sampler(denoiser, video, cond=c, uc=uc, strength=strength)
|
474 |
+
samples_z = rearrange(samples_z, "(b t) c h w -> b t c h w", t=num_frames)
|
475 |
+
samples_z[:, 0] = embbedings[:, :, 0]
|
476 |
+
samples_z[:, -1] = embbedings[:, :, 1]
|
477 |
+
samples_z = rearrange(samples_z, "b t c h w -> (b t) c h w")
|
478 |
+
|
479 |
+
samples_x = model.decode_first_stage(samples_z)
|
480 |
+
|
481 |
+
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
482 |
+
|
483 |
+
# Free up memory
|
484 |
+
video = None
|
485 |
+
|
486 |
+
samples = rearrange(samples, "(b t) c h w -> b t c h w", t=num_frames)
|
487 |
+
samples = merge_overlapping_segments(samples, overlap)
|
488 |
+
|
489 |
+
vid = (
|
490 |
+
(rearrange(samples, "t c h w -> t c h w") * 255).cpu().numpy().astype(np.uint8)
|
491 |
+
)
|
492 |
+
|
493 |
+
return vid
|
landmarks_extractor.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from skimage import io
|
2 |
+
import face_alignment
|
3 |
+
|
4 |
+
|
5 |
+
class LandmarksExtractor:
|
6 |
+
def __init__(self, device="cuda", landmarks_type="2D", flip=False):
|
7 |
+
self.fa = face_alignment.FaceAlignment(
|
8 |
+
face_alignment.LandmarksType.TWO_D
|
9 |
+
if landmarks_type == "2D"
|
10 |
+
else face_alignment.LandmarksType.THREE_D,
|
11 |
+
flip_input=flip,
|
12 |
+
device=device,
|
13 |
+
face_detector="sfd",
|
14 |
+
)
|
15 |
+
|
16 |
+
self.landmarks = []
|
17 |
+
|
18 |
+
def cuda(self):
|
19 |
+
return self
|
20 |
+
|
21 |
+
def extract_landmarks(self, image):
|
22 |
+
# image: either a path to an image or a numpy array (H, W, C) or tensor batch (B, C, H, W)
|
23 |
+
if isinstance(image, str):
|
24 |
+
image = io.imread(image)
|
25 |
+
|
26 |
+
# Ensure image is on CPU
|
27 |
+
if hasattr(image, "device"):
|
28 |
+
image = image.cpu()
|
29 |
+
|
30 |
+
if len(image.shape) == 3:
|
31 |
+
preds = self.fa.get_landmarks(image)
|
32 |
+
else:
|
33 |
+
preds = self.fa.get_landmarks_from_batch(image)
|
34 |
+
|
35 |
+
return preds
|
sgm/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .models import AutoencodingEngine, DiffusionEngine
|
2 |
+
from .util import get_configs_path, instantiate_from_config
|
3 |
+
|
4 |
+
__version__ = "0.1.0"
|
sgm/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (418 Bytes). View file
|
|
sgm/__pycache__/lr_scheduler.cpython-311.pyc
ADDED
Binary file (6.6 kB). View file
|
|
sgm/__pycache__/util.cpython-311.pyc
ADDED
Binary file (21.5 kB). View file
|
|
sgm/callbacks/__pycache__/video_logger.cpython-311.pyc
ADDED
Binary file (14.3 kB). View file
|
|
sgm/callbacks/custom_ddp.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from pytorch_lightning.overrides import LightningDistributedModule
|
2 |
+
from pytorch_lightning.strategies import DDPStrategy
|
3 |
+
|
4 |
+
|
5 |
+
class CustomDDPPlugin(DDPStrategy):
|
6 |
+
def configure_ddp(self):
|
7 |
+
# self.pre_configure_ddp()
|
8 |
+
self._model = self._setup_model((self.model))
|
9 |
+
self._register_ddp_hooks()
|
10 |
+
self._model._set_static_graph() # THIS IS THE MAGIC LINE
|
sgm/callbacks/image_logger.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pytorch_lightning.callbacks import Callback
|
2 |
+
from pytorch_lightning.loggers import WandbLogger
|
3 |
+
import numpy as np
|
4 |
+
from pytorch_lightning.utilities import rank_zero_only
|
5 |
+
from typing import Union
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
import os
|
8 |
+
from matplotlib import pyplot as plt
|
9 |
+
from sgm.util import exists, isheatmap
|
10 |
+
import torchvision
|
11 |
+
from PIL import Image
|
12 |
+
import torch
|
13 |
+
import wandb
|
14 |
+
from einops import rearrange
|
15 |
+
|
16 |
+
|
17 |
+
class ImageLogger(Callback):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
batch_frequency,
|
21 |
+
max_images,
|
22 |
+
clamp=True,
|
23 |
+
increase_log_steps=True,
|
24 |
+
rescale=True,
|
25 |
+
disabled=False,
|
26 |
+
log_on_batch_idx=False,
|
27 |
+
log_first_step=False,
|
28 |
+
log_images_kwargs=None,
|
29 |
+
log_before_first_step=False,
|
30 |
+
enable_autocast=True,
|
31 |
+
):
|
32 |
+
super().__init__()
|
33 |
+
self.enable_autocast = enable_autocast
|
34 |
+
self.rescale = rescale
|
35 |
+
self.batch_freq = batch_frequency
|
36 |
+
self.max_images = max_images
|
37 |
+
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
|
38 |
+
if not increase_log_steps:
|
39 |
+
self.log_steps = [self.batch_freq]
|
40 |
+
self.clamp = clamp
|
41 |
+
self.disabled = disabled
|
42 |
+
self.log_on_batch_idx = log_on_batch_idx
|
43 |
+
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
|
44 |
+
self.log_first_step = log_first_step
|
45 |
+
self.log_before_first_step = log_before_first_step
|
46 |
+
|
47 |
+
@rank_zero_only
|
48 |
+
def log_local(
|
49 |
+
self,
|
50 |
+
save_dir,
|
51 |
+
split,
|
52 |
+
images,
|
53 |
+
global_step,
|
54 |
+
current_epoch,
|
55 |
+
batch_idx,
|
56 |
+
pl_module: Union[None, pl.LightningModule] = None,
|
57 |
+
):
|
58 |
+
root = os.path.join(save_dir, "images", split)
|
59 |
+
for k in images:
|
60 |
+
if isheatmap(images[k]):
|
61 |
+
fig, ax = plt.subplots()
|
62 |
+
ax = ax.matshow(images[k].cpu().numpy(), cmap="hot", interpolation="lanczos")
|
63 |
+
plt.colorbar(ax)
|
64 |
+
plt.axis("off")
|
65 |
+
|
66 |
+
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
|
67 |
+
os.makedirs(root, exist_ok=True)
|
68 |
+
path = os.path.join(root, filename)
|
69 |
+
plt.savefig(path)
|
70 |
+
plt.close()
|
71 |
+
# TODO: support wandb
|
72 |
+
else:
|
73 |
+
grid = torchvision.utils.make_grid(images[k].squeeze(2), nrow=4)
|
74 |
+
if self.rescale:
|
75 |
+
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
76 |
+
# print(grid.shape, grid.dtype, grid.min(), grid.max(), k)
|
77 |
+
grid = rearrange(grid.squeeze(1), "c h w -> h w c")
|
78 |
+
# grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
79 |
+
grid = grid.numpy()
|
80 |
+
grid = (grid * 255).astype(np.uint8)
|
81 |
+
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
|
82 |
+
path = os.path.join(root, filename)
|
83 |
+
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
84 |
+
img = Image.fromarray(grid)
|
85 |
+
img.save(path)
|
86 |
+
if exists(pl_module):
|
87 |
+
assert isinstance(
|
88 |
+
pl_module.logger, WandbLogger
|
89 |
+
), "logger_log_image only supports WandbLogger currently"
|
90 |
+
pl_module.logger.log_image(
|
91 |
+
key=f"{split}/{k}",
|
92 |
+
images=[
|
93 |
+
img,
|
94 |
+
],
|
95 |
+
step=pl_module.global_step,
|
96 |
+
)
|
97 |
+
|
98 |
+
@rank_zero_only
|
99 |
+
def log_img(self, pl_module, batch, batch_idx, split="train"):
|
100 |
+
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
|
101 |
+
if (
|
102 |
+
self.check_frequency(check_idx)
|
103 |
+
and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0
|
104 |
+
and callable(pl_module.log_images)
|
105 |
+
and
|
106 |
+
# batch_idx > 5 and
|
107 |
+
self.max_images > 0
|
108 |
+
):
|
109 |
+
logger = type(pl_module.logger)
|
110 |
+
is_train = pl_module.training
|
111 |
+
if is_train:
|
112 |
+
pl_module.eval()
|
113 |
+
|
114 |
+
gpu_autocast_kwargs = {
|
115 |
+
"enabled": self.enable_autocast, # torch.is_autocast_enabled(),
|
116 |
+
"dtype": torch.get_autocast_gpu_dtype(),
|
117 |
+
"cache_enabled": torch.is_autocast_cache_enabled(),
|
118 |
+
}
|
119 |
+
with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
|
120 |
+
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
|
121 |
+
|
122 |
+
for k in images:
|
123 |
+
N = min(images[k].shape[0], self.max_images)
|
124 |
+
if not isheatmap(images[k]):
|
125 |
+
images[k] = images[k][:N]
|
126 |
+
if isinstance(images[k], torch.Tensor):
|
127 |
+
images[k] = images[k].detach().float().cpu()
|
128 |
+
if self.clamp and not isheatmap(images[k]):
|
129 |
+
images[k] = torch.clamp(images[k], -1.0, 1.0)
|
130 |
+
|
131 |
+
self.log_local(
|
132 |
+
pl_module.logger.save_dir,
|
133 |
+
split,
|
134 |
+
images,
|
135 |
+
pl_module.global_step,
|
136 |
+
pl_module.current_epoch,
|
137 |
+
batch_idx,
|
138 |
+
pl_module=pl_module if isinstance(pl_module.logger, WandbLogger) else None,
|
139 |
+
)
|
140 |
+
|
141 |
+
if is_train:
|
142 |
+
pl_module.train()
|
143 |
+
|
144 |
+
def check_frequency(self, check_idx):
|
145 |
+
if check_idx:
|
146 |
+
check_idx -= 1
|
147 |
+
if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
|
148 |
+
check_idx > 0 or self.log_first_step
|
149 |
+
):
|
150 |
+
try:
|
151 |
+
self.log_steps.pop(0)
|
152 |
+
except IndexError as e:
|
153 |
+
print(e)
|
154 |
+
pass
|
155 |
+
return True
|
156 |
+
return False
|
157 |
+
|
158 |
+
@rank_zero_only
|
159 |
+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
160 |
+
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
|
161 |
+
self.log_img(pl_module, batch, batch_idx, split="train")
|
162 |
+
|
163 |
+
@rank_zero_only
|
164 |
+
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
|
165 |
+
if self.log_before_first_step and pl_module.global_step == 0:
|
166 |
+
print(f"{self.__class__.__name__}: logging before training")
|
167 |
+
self.log_img(pl_module, batch, batch_idx, split="train")
|
168 |
+
|
169 |
+
@rank_zero_only
|
170 |
+
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs):
|
171 |
+
if not self.disabled and pl_module.global_step > 0:
|
172 |
+
self.log_img(pl_module, batch, batch_idx, split="val")
|
173 |
+
if hasattr(pl_module, "calibrate_grad_norm"):
|
174 |
+
if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
|
175 |
+
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
|
176 |
+
|
177 |
+
|
178 |
+
@rank_zero_only
|
179 |
+
def init_wandb(save_dir, opt, config, group_name, name_str):
|
180 |
+
print(f"setting WANDB_DIR to {save_dir}")
|
181 |
+
os.makedirs(save_dir, exist_ok=True)
|
182 |
+
|
183 |
+
os.environ["WANDB_DIR"] = save_dir
|
184 |
+
if opt.debug:
|
185 |
+
wandb.init(project=opt.projectname, mode="offline", group=group_name)
|
186 |
+
else:
|
187 |
+
wandb.init(
|
188 |
+
project=opt.projectname,
|
189 |
+
config=config,
|
190 |
+
settings=wandb.Settings(code_dir="./sgm"),
|
191 |
+
group=group_name,
|
192 |
+
name=name_str,
|
193 |
+
)
|
sgm/callbacks/setup_callback.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pytorch_lightning.callbacks import Callback
|
2 |
+
import pytorch_lightning as pl
|
3 |
+
import os
|
4 |
+
from omegaconf import OmegaConf
|
5 |
+
from pytorch_lightning.utilities import rank_zero_only
|
6 |
+
|
7 |
+
MULTINODE_HACKS = True
|
8 |
+
|
9 |
+
|
10 |
+
class SetupCallback(Callback):
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
resume,
|
14 |
+
now,
|
15 |
+
logdir,
|
16 |
+
ckptdir,
|
17 |
+
cfgdir,
|
18 |
+
config,
|
19 |
+
lightning_config,
|
20 |
+
debug,
|
21 |
+
ckpt_name=None,
|
22 |
+
):
|
23 |
+
super().__init__()
|
24 |
+
self.resume = resume
|
25 |
+
self.now = now
|
26 |
+
self.logdir = logdir
|
27 |
+
self.ckptdir = ckptdir
|
28 |
+
self.cfgdir = cfgdir
|
29 |
+
self.config = config
|
30 |
+
self.lightning_config = lightning_config
|
31 |
+
self.debug = debug
|
32 |
+
self.ckpt_name = ckpt_name
|
33 |
+
|
34 |
+
@rank_zero_only
|
35 |
+
def on_exception(self, trainer: pl.Trainer, pl_module, exception):
|
36 |
+
print("Exception occurred: {}".format(exception))
|
37 |
+
if not self.debug and trainer.global_rank == 0:
|
38 |
+
print("Summoning checkpoint.")
|
39 |
+
if self.ckpt_name is None:
|
40 |
+
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
|
41 |
+
else:
|
42 |
+
ckpt_path = os.path.join(self.ckptdir, self.ckpt_name)
|
43 |
+
trainer.save_checkpoint(ckpt_path)
|
44 |
+
|
45 |
+
@rank_zero_only
|
46 |
+
def on_fit_start(self, trainer, pl_module):
|
47 |
+
if trainer.global_rank == 0:
|
48 |
+
# Create logdirs and save configs
|
49 |
+
os.makedirs(self.logdir, exist_ok=True)
|
50 |
+
os.makedirs(self.ckptdir, exist_ok=True)
|
51 |
+
os.makedirs(self.cfgdir, exist_ok=True)
|
52 |
+
|
53 |
+
if "callbacks" in self.lightning_config:
|
54 |
+
if "metrics_over_trainsteps_checkpoint" in self.lightning_config["callbacks"]:
|
55 |
+
os.makedirs(
|
56 |
+
os.path.join(self.ckptdir, "trainstep_checkpoints"),
|
57 |
+
exist_ok=True,
|
58 |
+
)
|
59 |
+
print("Project config")
|
60 |
+
print(OmegaConf.to_yaml(self.config))
|
61 |
+
if MULTINODE_HACKS:
|
62 |
+
import time
|
63 |
+
|
64 |
+
time.sleep(5)
|
65 |
+
OmegaConf.save(
|
66 |
+
self.config,
|
67 |
+
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
|
68 |
+
)
|
69 |
+
|
70 |
+
print("Lightning config")
|
71 |
+
print(OmegaConf.to_yaml(self.lightning_config))
|
72 |
+
OmegaConf.save(
|
73 |
+
OmegaConf.create({"lightning": self.lightning_config}),
|
74 |
+
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
|
75 |
+
)
|
76 |
+
|
77 |
+
else:
|
78 |
+
# ModelCheckpoint callback created log directory --- remove it
|
79 |
+
if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
|
80 |
+
dst, name = os.path.split(self.logdir)
|
81 |
+
dst = os.path.join(dst, "child_runs", name)
|
82 |
+
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
83 |
+
try:
|
84 |
+
os.rename(self.logdir, dst)
|
85 |
+
except FileNotFoundError:
|
86 |
+
pass
|
sgm/callbacks/video_logger.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pytorch_lightning.callbacks import Callback
|
2 |
+
from pytorch_lightning.loggers import WandbLogger
|
3 |
+
import numpy as np
|
4 |
+
from pytorch_lightning.utilities import rank_zero_only
|
5 |
+
from typing import Union
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
import os
|
8 |
+
from sgm.util import exists, suppress_output, default
|
9 |
+
import torchvision
|
10 |
+
from PIL import Image
|
11 |
+
import torch
|
12 |
+
import wandb
|
13 |
+
import moviepy.editor as mpy
|
14 |
+
from einops import rearrange
|
15 |
+
import torchaudio
|
16 |
+
# import tempfile
|
17 |
+
# import cv2
|
18 |
+
# import scipy.io.wavfile as wav
|
19 |
+
# import ffmpeg
|
20 |
+
|
21 |
+
|
22 |
+
@suppress_output
|
23 |
+
def save_audio_video(
|
24 |
+
video, audio=None, frame_rate=25, sample_rate=16000, save_path="temp.mp4", keep_intermediate=False
|
25 |
+
):
|
26 |
+
"""Save audio and video to a single file.
|
27 |
+
video: (t, c, h, w)
|
28 |
+
audio: (channels t)
|
29 |
+
"""
|
30 |
+
|
31 |
+
# temp_filename = next(tempfile._get_candidate_names())
|
32 |
+
# if save_path:
|
33 |
+
# save_path = save_path
|
34 |
+
# else:
|
35 |
+
# save_path = "/tmp/" + next(tempfile._get_candidate_names()) + ".mp4"
|
36 |
+
save_path = str(save_path)
|
37 |
+
try:
|
38 |
+
torchvision.io.write_video(
|
39 |
+
"temp_video.mp4", rearrange(video.detach().cpu(), "t c h w -> t h w c").to(torch.uint8), frame_rate
|
40 |
+
)
|
41 |
+
video_clip = mpy.VideoFileClip("temp_video.mp4")
|
42 |
+
if audio is not None:
|
43 |
+
torchaudio.save("temp_audio.wav", audio.detach().cpu(), sample_rate)
|
44 |
+
audio_clip = mpy.AudioFileClip("temp_audio.wav")
|
45 |
+
video_clip = video_clip.set_audio(audio_clip)
|
46 |
+
video_clip.write_videofile(save_path, fps=frame_rate, codec="libx264", audio_codec="aac", verbose=False)
|
47 |
+
if not keep_intermediate:
|
48 |
+
os.remove("temp_video.mp4")
|
49 |
+
if audio is not None:
|
50 |
+
os.remove("temp_audio.wav")
|
51 |
+
return 1
|
52 |
+
except Exception as e:
|
53 |
+
print(e)
|
54 |
+
print("Saving video to file failed")
|
55 |
+
return 0
|
56 |
+
|
57 |
+
|
58 |
+
# def write_video_opencv(video, video_rate, video_path):
|
59 |
+
# fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
60 |
+
# out = cv2.VideoWriter(video_path, fourcc, video_rate, (video.shape[2], video.shape[3]), 0)
|
61 |
+
# for frame in list(video):
|
62 |
+
# frame = np.squeeze(frame)
|
63 |
+
# out.write(np.squeeze(frame))
|
64 |
+
# out.release()
|
65 |
+
|
66 |
+
|
67 |
+
# # Code mostly inherited from bulletin
|
68 |
+
# def save_av_sample(video, video_rate, audio=None, audio_rate=16_000, path=None):
|
69 |
+
# # Save video sample in train dir for debugging
|
70 |
+
# # video_save = 0.5 * video.detach().cpu().numpy() + 0.5
|
71 |
+
# video_save = rearrange(video, "t c h w -> t h w c").detach().cpu().numpy()
|
72 |
+
# temp_filename = next(tempfile._get_candidate_names())
|
73 |
+
# if path:
|
74 |
+
# video_path = path
|
75 |
+
# else:
|
76 |
+
# video_path = "/tmp/" + next(tempfile._get_candidate_names()) + ".mp4"
|
77 |
+
# write_video_opencv((video_save).astype(np.uint8), video_rate, "/tmp/" + temp_filename + ".mp4")
|
78 |
+
# audio_save = audio.detach().squeeze().cpu().numpy()
|
79 |
+
# wav.write("/tmp/" + temp_filename + ".wav", audio_rate, audio_save)
|
80 |
+
# try:
|
81 |
+
# in1 = ffmpeg.input("/tmp/" + temp_filename + ".mp4")
|
82 |
+
# in2 = ffmpeg.input("/tmp/" + temp_filename + ".wav")
|
83 |
+
# out = ffmpeg.output(in1["v"], in2["a"], video_path, loglevel="panic").overwrite_output()
|
84 |
+
# out.run(capture_stdout=True, capture_stderr=True)
|
85 |
+
# except ffmpeg.Error as e:
|
86 |
+
# print("stdout:", e.stdout.decode("utf8"))
|
87 |
+
# print("stderr:", e.stderr.decode("utf8"))
|
88 |
+
# raise e
|
89 |
+
# return video_path
|
90 |
+
|
91 |
+
|
92 |
+
class VideoLogger(Callback):
|
93 |
+
def __init__(
|
94 |
+
self,
|
95 |
+
batch_frequency,
|
96 |
+
max_videos,
|
97 |
+
clamp=True,
|
98 |
+
increase_log_steps=True,
|
99 |
+
rescale=True,
|
100 |
+
disabled=False,
|
101 |
+
log_on_batch_idx=False,
|
102 |
+
log_first_step=False,
|
103 |
+
log_videos_kwargs=None,
|
104 |
+
log_before_first_step=False,
|
105 |
+
enable_autocast=True,
|
106 |
+
batch_frequency_val=None,
|
107 |
+
):
|
108 |
+
super().__init__()
|
109 |
+
self.enable_autocast = enable_autocast
|
110 |
+
self.rescale = rescale
|
111 |
+
self.batch_freq = batch_frequency
|
112 |
+
self.max_videos = max_videos
|
113 |
+
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
|
114 |
+
if not increase_log_steps:
|
115 |
+
self.log_steps = [self.batch_freq]
|
116 |
+
self.batch_freq_val = default(batch_frequency_val, self.batch_freq)
|
117 |
+
self.log_steps_val = [2**n for n in range(int(np.log2(self.batch_freq_val)) + 1)]
|
118 |
+
if not increase_log_steps:
|
119 |
+
self.log_steps_val = [self.batch_freq_val]
|
120 |
+
self.clamp = clamp
|
121 |
+
self.disabled = disabled
|
122 |
+
self.log_on_batch_idx = log_on_batch_idx
|
123 |
+
self.log_videos_kwargs = log_videos_kwargs if log_videos_kwargs else {}
|
124 |
+
self.log_first_step = log_first_step
|
125 |
+
self.log_before_first_step = log_before_first_step
|
126 |
+
|
127 |
+
@rank_zero_only
|
128 |
+
def log_local(
|
129 |
+
self,
|
130 |
+
save_dir,
|
131 |
+
split,
|
132 |
+
log_elements,
|
133 |
+
raw_audio,
|
134 |
+
global_step,
|
135 |
+
current_epoch,
|
136 |
+
batch_idx,
|
137 |
+
pl_module: Union[None, pl.LightningModule] = None,
|
138 |
+
):
|
139 |
+
root = os.path.join(save_dir, "videos", split)
|
140 |
+
for k in log_elements:
|
141 |
+
element = log_elements[k]
|
142 |
+
if len(element.shape) == 4:
|
143 |
+
grid = torchvision.utils.make_grid(element, nrow=4)
|
144 |
+
if self.rescale:
|
145 |
+
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
146 |
+
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
147 |
+
grid = grid.numpy()
|
148 |
+
grid = (grid * 255).astype(np.uint8)
|
149 |
+
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
|
150 |
+
path = os.path.join(root, filename)
|
151 |
+
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
152 |
+
img = Image.fromarray(grid)
|
153 |
+
img.save(path)
|
154 |
+
if exists(pl_module):
|
155 |
+
assert isinstance(
|
156 |
+
pl_module.logger, WandbLogger
|
157 |
+
), "logger_log_image only supports WandbLogger currently"
|
158 |
+
pl_module.logger.log_image(
|
159 |
+
key=f"{split}/{k}",
|
160 |
+
images=[
|
161 |
+
img,
|
162 |
+
],
|
163 |
+
step=pl_module.global_step,
|
164 |
+
)
|
165 |
+
elif len(element.shape) == 5:
|
166 |
+
video = element
|
167 |
+
if self.rescale:
|
168 |
+
video = (video + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
169 |
+
video = video * 255.0
|
170 |
+
video = video.permute(0, 2, 1, 3, 4).cpu().detach().to(torch.uint8) # b,t,c,h,w
|
171 |
+
for i in range(video.shape[0]):
|
172 |
+
filename = "{}_gs-{:06}_e-{:06}_b-{:06}_{}.mp4".format(k, global_step, current_epoch, batch_idx, i)
|
173 |
+
path = os.path.join(root, filename)
|
174 |
+
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
175 |
+
log_audio = raw_audio[i] if raw_audio is not None else None
|
176 |
+
success = save_audio_video(
|
177 |
+
video[i],
|
178 |
+
audio=log_audio.unsqueeze(0) if log_audio is not None else None,
|
179 |
+
frame_rate=25,
|
180 |
+
sample_rate=16000,
|
181 |
+
save_path=path,
|
182 |
+
keep_intermediate=False,
|
183 |
+
)
|
184 |
+
|
185 |
+
# video_path = save_av_sample(video[i], 25, audio=raw_audio, audio_rate=16000, path=None)
|
186 |
+
if exists(pl_module):
|
187 |
+
assert isinstance(
|
188 |
+
pl_module.logger, WandbLogger
|
189 |
+
), "logger_log_image only supports WandbLogger currently"
|
190 |
+
pl_module.logger.experiment.log(
|
191 |
+
{
|
192 |
+
f"{split}/{k}": wandb.Video(
|
193 |
+
path if success else video,
|
194 |
+
# caption=f"diffused videos w {n_frames} frames (condition left, generated right)",
|
195 |
+
fps=25,
|
196 |
+
format="mp4",
|
197 |
+
)
|
198 |
+
},
|
199 |
+
)
|
200 |
+
|
201 |
+
@rank_zero_only
|
202 |
+
def log_video(self, pl_module, batch, batch_idx, split="train"):
|
203 |
+
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
|
204 |
+
# print(f"check_idx: {check_idx}", f"split: {split}")
|
205 |
+
if (
|
206 |
+
self.check_frequency(check_idx, split=split)
|
207 |
+
and hasattr(pl_module, "log_videos") # batch_idx % self.batch_freq == 0
|
208 |
+
and callable(pl_module.log_videos)
|
209 |
+
and
|
210 |
+
# batch_idx > 5 and
|
211 |
+
self.max_videos > 0
|
212 |
+
):
|
213 |
+
logger = type(pl_module.logger)
|
214 |
+
is_train = pl_module.training
|
215 |
+
if is_train:
|
216 |
+
pl_module.eval()
|
217 |
+
|
218 |
+
gpu_autocast_kwargs = {
|
219 |
+
"enabled": self.enable_autocast, # torch.is_autocast_enabled(),
|
220 |
+
"dtype": torch.get_autocast_gpu_dtype(),
|
221 |
+
"cache_enabled": torch.is_autocast_cache_enabled(),
|
222 |
+
}
|
223 |
+
with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
|
224 |
+
videos = pl_module.log_videos(batch, split=split, **self.log_videos_kwargs)
|
225 |
+
|
226 |
+
for k in videos:
|
227 |
+
N = min(videos[k].shape[0], self.max_videos)
|
228 |
+
videos[k] = videos[k][:N]
|
229 |
+
if isinstance(videos[k], torch.Tensor):
|
230 |
+
videos[k] = videos[k].detach().float().cpu()
|
231 |
+
if self.clamp:
|
232 |
+
videos[k] = torch.clamp(videos[k], -1.0, 1.0)
|
233 |
+
|
234 |
+
raw_audio = batch.get("raw_audio", None)
|
235 |
+
|
236 |
+
self.log_local(
|
237 |
+
pl_module.logger.save_dir,
|
238 |
+
split,
|
239 |
+
videos,
|
240 |
+
raw_audio,
|
241 |
+
pl_module.global_step,
|
242 |
+
pl_module.current_epoch,
|
243 |
+
batch_idx,
|
244 |
+
pl_module=pl_module if isinstance(pl_module.logger, WandbLogger) else None,
|
245 |
+
)
|
246 |
+
|
247 |
+
if is_train:
|
248 |
+
pl_module.train()
|
249 |
+
|
250 |
+
def check_frequency(self, check_idx, split="train"):
|
251 |
+
if split == "val":
|
252 |
+
if check_idx:
|
253 |
+
check_idx -= 1
|
254 |
+
if ((check_idx % self.batch_freq_val) == 0 or (check_idx in self.log_steps_val)) and (
|
255 |
+
check_idx > 0 or self.log_first_step
|
256 |
+
):
|
257 |
+
try:
|
258 |
+
self.log_steps_val.pop(0)
|
259 |
+
except IndexError as e:
|
260 |
+
print(e)
|
261 |
+
pass
|
262 |
+
return True
|
263 |
+
return False
|
264 |
+
if check_idx:
|
265 |
+
check_idx -= 1
|
266 |
+
if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
|
267 |
+
check_idx > 0 or self.log_first_step
|
268 |
+
):
|
269 |
+
try:
|
270 |
+
self.log_steps.pop(0)
|
271 |
+
except IndexError as e:
|
272 |
+
print(e)
|
273 |
+
pass
|
274 |
+
return True
|
275 |
+
return False
|
276 |
+
|
277 |
+
@rank_zero_only
|
278 |
+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
279 |
+
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
|
280 |
+
self.log_video(pl_module, batch, batch_idx, split="train")
|
281 |
+
|
282 |
+
@rank_zero_only
|
283 |
+
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
|
284 |
+
if self.log_before_first_step and pl_module.global_step == 0:
|
285 |
+
print(f"{self.__class__.__name__}: logging before training")
|
286 |
+
self.log_video(pl_module, batch, batch_idx, split="train")
|
287 |
+
|
288 |
+
@rank_zero_only
|
289 |
+
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs):
|
290 |
+
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
|
291 |
+
self.log_video(pl_module, batch, batch_idx, split="val")
|
292 |
+
if hasattr(pl_module, "calibrate_grad_norm"):
|
293 |
+
if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
|
294 |
+
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
|
sgm/data/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# from .dataset import StableDataModuleFromConfig
|
sgm/data/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (193 Bytes). View file
|
|
sgm/data/__pycache__/data_utils.cpython-311.pyc
ADDED
Binary file (28.3 kB). View file
|
|
sgm/data/__pycache__/mask.cpython-311.pyc
ADDED
Binary file (17.9 kB). View file
|
|
sgm/data/__pycache__/video_datamodule_latent.cpython-311.pyc
ADDED
Binary file (7.4 kB). View file
|
|
sgm/data/__pycache__/video_dataset_latent.cpython-311.pyc
ADDED
Binary file (34.4 kB). View file
|
|
sgm/data/data_utils.py
ADDED
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image, ImageDraw
|
4 |
+
import cv2
|
5 |
+
from functools import partial
|
6 |
+
import math
|
7 |
+
|
8 |
+
|
9 |
+
def get_size(img):
|
10 |
+
if isinstance(img, (np.ndarray, torch.Tensor)):
|
11 |
+
return img.shape[1::-1]
|
12 |
+
else:
|
13 |
+
return img.size
|
14 |
+
|
15 |
+
|
16 |
+
def imresample(img, sz):
|
17 |
+
im_data = torch.nn.functional.interpolate(img, size=sz, mode="area")
|
18 |
+
return im_data
|
19 |
+
|
20 |
+
|
21 |
+
def crop_resize(img, box, image_size):
|
22 |
+
if isinstance(img, np.ndarray):
|
23 |
+
img = img[box[1] : box[3], box[0] : box[2]]
|
24 |
+
out = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_AREA).copy()
|
25 |
+
elif isinstance(img, torch.Tensor):
|
26 |
+
img = img[box[1] : box[3], box[0] : box[2]]
|
27 |
+
out = (
|
28 |
+
imresample(img.permute(2, 0, 1).unsqueeze(0).float(), (image_size, image_size))
|
29 |
+
.byte()
|
30 |
+
.squeeze(0)
|
31 |
+
.permute(1, 2, 0)
|
32 |
+
)
|
33 |
+
else:
|
34 |
+
out = img.crop(box).copy().resize((image_size, image_size), Image.BILINEAR)
|
35 |
+
return out
|
36 |
+
|
37 |
+
|
38 |
+
def fixed_image_standardization(image_tensor):
|
39 |
+
processed_tensor = (image_tensor - 127.5) / 128.0
|
40 |
+
return processed_tensor
|
41 |
+
|
42 |
+
|
43 |
+
def extract_face(img, landmarks, image_size=160, margin=0, postprocess=False):
|
44 |
+
"""Extract face + margin from images given facial landmarks.
|
45 |
+
|
46 |
+
Arguments:
|
47 |
+
img {PIL.Image/torch.Tensor/np.ndarray} -- Input image(s) with shape (B, H, W, C)
|
48 |
+
landmarks {numpy.ndarray} -- Facial landmarks with shape (B, 68, 2)
|
49 |
+
image_size {int} -- Output image size in pixels. The image will be square.
|
50 |
+
margin {int} -- Margin to add to bounding box, in terms of pixels in the final image.
|
51 |
+
postprocess {bool} -- Whether to apply standardization
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
torch.tensor -- tensor representing the extracted faces with shape (B, H, W, C)
|
55 |
+
"""
|
56 |
+
# Calculate bounding boxes from landmarks for all faces in batch
|
57 |
+
x_min = np.min(landmarks, axis=1)[:, 0] # Shape: (B,)
|
58 |
+
y_min = np.min(landmarks, axis=1)[:, 1] # Shape: (B,)
|
59 |
+
x_max = np.max(landmarks, axis=1)[:, 0] # Shape: (B,)
|
60 |
+
y_max = np.max(landmarks, axis=1)[:, 1] # Shape: (B,)
|
61 |
+
|
62 |
+
# Calculate margin for top only
|
63 |
+
box_height = y_max - y_min
|
64 |
+
top_margin = margin * box_height / (image_size - margin)
|
65 |
+
|
66 |
+
# Create boxes for all faces
|
67 |
+
boxes = np.stack(
|
68 |
+
[
|
69 |
+
x_min,
|
70 |
+
np.maximum(y_min - top_margin, 0), # Only add margin to top
|
71 |
+
x_max,
|
72 |
+
y_max,
|
73 |
+
],
|
74 |
+
axis=1,
|
75 |
+
).astype(int) # Shape: (B, 4)
|
76 |
+
|
77 |
+
# Process each face in the batch
|
78 |
+
faces = []
|
79 |
+
for i in range(len(boxes)):
|
80 |
+
face = crop_resize(img[i], boxes[i], image_size)
|
81 |
+
faces.append(face)
|
82 |
+
|
83 |
+
faces = torch.stack(faces, dim=0)
|
84 |
+
faces = faces.float()
|
85 |
+
|
86 |
+
if postprocess:
|
87 |
+
faces = fixed_image_standardization(faces)
|
88 |
+
|
89 |
+
return faces
|
90 |
+
|
91 |
+
|
92 |
+
def crop_mouth_region(images, landmarks, crop_size=96):
|
93 |
+
"""
|
94 |
+
Takes a fixed-size square crop centered on the mouth region.
|
95 |
+
|
96 |
+
Parameters:
|
97 |
+
- images: tensor/array of shape (num_frames, height, width, channels) or (height, width, channels)
|
98 |
+
- landmarks: numpy array of shape (num_frames, 68, 2) or (68, 2)
|
99 |
+
- crop_size: size of the square crop (both height and width)
|
100 |
+
- padding: percentage of padding around the mouth region (0.0 to 1.0)
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
- List of fixed-size crops or single crop if input is single image
|
104 |
+
"""
|
105 |
+
# Handle single image case
|
106 |
+
single_image = False
|
107 |
+
if len(images.shape) == 3:
|
108 |
+
images = images[None]
|
109 |
+
landmarks = landmarks[None]
|
110 |
+
single_image = True
|
111 |
+
|
112 |
+
num_frames = len(images)
|
113 |
+
crops = []
|
114 |
+
|
115 |
+
# Mouth landmarks indices (48-67 for mouth region)
|
116 |
+
mouth_indices = range(48, 68)
|
117 |
+
|
118 |
+
for i in range(num_frames):
|
119 |
+
# Get mouth landmarks for current frame
|
120 |
+
mouth_landmarks = landmarks[i][mouth_indices]
|
121 |
+
|
122 |
+
# Find center of mouth
|
123 |
+
center_x = int(np.mean(mouth_landmarks[:, 0]))
|
124 |
+
center_y = int(np.mean(mouth_landmarks[:, 1]))
|
125 |
+
|
126 |
+
# Calculate crop boundaries
|
127 |
+
half_size = crop_size // 2
|
128 |
+
left = max(0, center_x - half_size)
|
129 |
+
right = min(images.shape[2], center_x + half_size)
|
130 |
+
top = max(0, center_y - half_size)
|
131 |
+
bottom = min(images.shape[1], center_y + half_size)
|
132 |
+
|
133 |
+
# Adjust if crop would go out of bounds
|
134 |
+
if left == 0:
|
135 |
+
right = crop_size
|
136 |
+
if right == images.shape[2]:
|
137 |
+
left = images.shape[2] - crop_size
|
138 |
+
if top == 0:
|
139 |
+
bottom = crop_size
|
140 |
+
if bottom == images.shape[1]:
|
141 |
+
top = images.shape[1] - crop_size
|
142 |
+
|
143 |
+
# Take the crop
|
144 |
+
crop = images[i, top:bottom, left:right]
|
145 |
+
crops.append(crop)
|
146 |
+
|
147 |
+
return crops[0] if single_image else crops
|
148 |
+
|
149 |
+
|
150 |
+
def create_masks_from_landmarks_box(landmark_list, img_shape, nose_index=28, dtype="uint8", box_expand=0.0):
|
151 |
+
height, width = img_shape[:2]
|
152 |
+
num_frames = landmark_list.shape[0]
|
153 |
+
|
154 |
+
# Initialize the masks array
|
155 |
+
masks = np.zeros((num_frames, height, width), dtype=dtype)
|
156 |
+
|
157 |
+
if 0 <= box_expand < 1:
|
158 |
+
box_expand = int(box_expand * width)
|
159 |
+
|
160 |
+
for i in range(num_frames):
|
161 |
+
# Get the landmarks for the current frame
|
162 |
+
landmarks = landmark_list[i]
|
163 |
+
|
164 |
+
# Get the y-coordinate of the nose landmark
|
165 |
+
nose_point_h = landmarks[nose_index, 1]
|
166 |
+
cut_h = nose_point_h
|
167 |
+
|
168 |
+
# Find the leftmost and rightmost landmarks
|
169 |
+
far_left_index = np.argmin(landmarks[:, 0])
|
170 |
+
far_right_index = np.argmax(landmarks[:, 0])
|
171 |
+
|
172 |
+
# Define the points for the mask contour
|
173 |
+
left_up_point = np.array([landmarks[far_left_index][0], cut_h - box_expand], dtype=np.int32)
|
174 |
+
left_down_point = np.array([landmarks[far_left_index][0], height], dtype=np.int32)
|
175 |
+
right_up_point = np.array([landmarks[far_right_index][0], cut_h - box_expand], dtype=np.int32)
|
176 |
+
right_down_point = np.array([landmarks[far_right_index][0], height], dtype=np.int32)
|
177 |
+
|
178 |
+
# Define the contour
|
179 |
+
contour = np.array([[left_up_point, left_down_point, right_down_point, right_up_point]])
|
180 |
+
|
181 |
+
# Draw the contour on the mask
|
182 |
+
cv2.drawContours(masks[i], [contour], -1, color=(1), thickness=cv2.FILLED)
|
183 |
+
|
184 |
+
return torch.from_numpy(masks)
|
185 |
+
|
186 |
+
|
187 |
+
def create_masks_from_landmarks_full_size(
|
188 |
+
landmarks_batch, image_height, image_width, start_index=48, end_index=68, offset=0, nose_index=33
|
189 |
+
):
|
190 |
+
"""
|
191 |
+
Efficiently creates a batch of masks using vectorized operations where each mask has ones from the highest
|
192 |
+
landmark in the specified range (adjusted by an offset) to the bottom of the image, and zeros otherwise.
|
193 |
+
|
194 |
+
Parameters:
|
195 |
+
- landmarks_batch (np.array): An array of shape (B, 68, 2) containing facial landmarks for multiple samples.
|
196 |
+
- image_height (int): The height of the image for which masks are created.
|
197 |
+
- image_width (int): The width of the image for which masks are created.
|
198 |
+
- start_index (int): The starting index of the range to check (inclusive).
|
199 |
+
- end_index (int): The ending index of the range to check (inclusive).
|
200 |
+
- offset (int): An offset to add or subtract from the y-coordinate of the highest landmark.
|
201 |
+
|
202 |
+
Returns:
|
203 |
+
- np.array: An array of masks of shape (B, image_height, image_width) for each batch.
|
204 |
+
"""
|
205 |
+
# Extract the y-coordinates for the specified range across all batches
|
206 |
+
y_coords = landmarks_batch[:, nose_index : nose_index + 1, 1]
|
207 |
+
|
208 |
+
# Find the index of the minimum y-coordinate in the specified range for each batch
|
209 |
+
min_y_indices = np.argmin(y_coords, axis=1)
|
210 |
+
|
211 |
+
# Gather the highest landmarks' y-coordinates using the indices found
|
212 |
+
highest_y_coords = y_coords[np.arange(len(y_coords)), min_y_indices]
|
213 |
+
|
214 |
+
if abs(offset) < 1 and abs(offset) > 0:
|
215 |
+
offset = int(offset * image_height)
|
216 |
+
|
217 |
+
# Apply the offset to the highest y-coordinate
|
218 |
+
adjusted_y_coords = highest_y_coords + offset
|
219 |
+
|
220 |
+
# Clip the coordinates to stay within image boundaries
|
221 |
+
adjusted_y_coords = np.clip(adjusted_y_coords, 0, image_height - 1)
|
222 |
+
|
223 |
+
# Use broadcasting to create a mask without loops
|
224 |
+
# Create a range of indices from 0 to image_height - 1
|
225 |
+
all_indices = np.arange(image_height)
|
226 |
+
|
227 |
+
# Compare each index in 'all_indices' to each 'adjusted_y_coord' in the batch
|
228 |
+
# 'all_indices' has shape (image_height,), we reshape to (1, image_height) to broadcast against (B, 1)
|
229 |
+
mask_2d = (all_indices >= adjusted_y_coords[:, None]).astype(int)
|
230 |
+
|
231 |
+
# Extend the 2D mask to a full 3D mask of size (B, image_height, image_width)
|
232 |
+
full_mask = np.tile(mask_2d[:, :, np.newaxis], (1, 1, image_width))
|
233 |
+
|
234 |
+
return torch.from_numpy(full_mask)
|
235 |
+
|
236 |
+
|
237 |
+
def expand_polygon(polygon, expand_size):
|
238 |
+
"""
|
239 |
+
Expands the polygon outward by a specified number of pixels.
|
240 |
+
|
241 |
+
Parameters:
|
242 |
+
- polygon (list of tuples): The polygon points as (x, y).
|
243 |
+
- expand_size (int): The number of pixels to expand the polygon outward.
|
244 |
+
|
245 |
+
Returns:
|
246 |
+
- expanded_polygon (list of tuples): The expanded polygon points as (x, y).
|
247 |
+
"""
|
248 |
+
if expand_size == 0:
|
249 |
+
return polygon
|
250 |
+
|
251 |
+
# Calculate centroid of the polygon
|
252 |
+
centroid_x = sum([point[0] for point in polygon]) / len(polygon)
|
253 |
+
centroid_y = sum([point[1] for point in polygon]) / len(polygon)
|
254 |
+
|
255 |
+
# Expand each point outward from the centroid
|
256 |
+
expanded_polygon = []
|
257 |
+
for x, y in polygon:
|
258 |
+
vector_x = x - centroid_x
|
259 |
+
vector_y = y - centroid_y
|
260 |
+
length = np.sqrt(vector_x**2 + vector_y**2)
|
261 |
+
if length == 0:
|
262 |
+
expanded_polygon.append((x, y))
|
263 |
+
else:
|
264 |
+
new_x = x + expand_size * (vector_x / length)
|
265 |
+
new_y = y + expand_size * (vector_y / length)
|
266 |
+
expanded_polygon.append((int(new_x), int(new_y)))
|
267 |
+
|
268 |
+
return expanded_polygon
|
269 |
+
|
270 |
+
|
271 |
+
def create_masks_from_landmarks_mouth(landmark_list, img_shape, nose_index=33, dtype="uint8", box_expand=0.0):
|
272 |
+
height, width = img_shape[:2]
|
273 |
+
num_frames = landmark_list.shape[0]
|
274 |
+
|
275 |
+
# Initialize the masks array
|
276 |
+
masks = np.zeros((num_frames, height, width), dtype=dtype)
|
277 |
+
|
278 |
+
if 0 <= box_expand < 1:
|
279 |
+
box_expand = int(box_expand * width)
|
280 |
+
|
281 |
+
for i in range(num_frames):
|
282 |
+
# Get the landmarks for the current frame
|
283 |
+
landmarks = landmark_list[i]
|
284 |
+
|
285 |
+
# Get the y-coordinate of the nose landmark
|
286 |
+
nose_point_h = landmarks[nose_index, 1]
|
287 |
+
cut_h = nose_point_h
|
288 |
+
|
289 |
+
# Find the leftmost and rightmost landmarks
|
290 |
+
far_left_index = np.argmin(landmarks[:, 0])
|
291 |
+
far_right_index = np.argmax(landmarks[:, 0])
|
292 |
+
|
293 |
+
# Find lowest landmark y-coordinate
|
294 |
+
lowest_y = np.max(landmarks[:, 1])
|
295 |
+
# Add box_expand to the lowest point
|
296 |
+
lowest_y = min(height, lowest_y + box_expand)
|
297 |
+
|
298 |
+
# Define the points for the mask contour
|
299 |
+
left_up_point = np.array([landmarks[far_left_index][0], cut_h - box_expand], dtype=np.int32)
|
300 |
+
left_down_point = np.array([landmarks[far_left_index][0], lowest_y], dtype=np.int32)
|
301 |
+
right_up_point = np.array([landmarks[far_right_index][0], cut_h - box_expand], dtype=np.int32)
|
302 |
+
right_down_point = np.array([landmarks[far_right_index][0], lowest_y], dtype=np.int32)
|
303 |
+
|
304 |
+
# Define the contour
|
305 |
+
contour = np.array([[left_up_point, left_down_point, right_down_point, right_up_point]])
|
306 |
+
|
307 |
+
# Draw the contour on the mask
|
308 |
+
cv2.drawContours(masks[i], [contour], -1, color=(1), thickness=cv2.FILLED)
|
309 |
+
|
310 |
+
return torch.from_numpy(masks)
|
311 |
+
|
312 |
+
|
313 |
+
def create_face_mask_from_landmarks(landmarks_batch, image_height, image_width, mask_expand=0):
|
314 |
+
"""
|
315 |
+
Creates a batch of masks where each mask covers the face region using landmarks.
|
316 |
+
|
317 |
+
Parameters:
|
318 |
+
- landmarks_batch (np.array): An array of shape (B, 68, 2) containing facial landmarks for multiple samples.
|
319 |
+
- image_height (int): The height of the image for which masks are created.
|
320 |
+
- image_width (int): The width of the image for which masks are created.
|
321 |
+
- mask_expand (int): The number of pixels to expand the mask outward.
|
322 |
+
|
323 |
+
Returns:
|
324 |
+
- np.array: An array of masks of shape (B, image_height, image_width) for each batch.
|
325 |
+
"""
|
326 |
+
# Initialize an array to hold all masks
|
327 |
+
masks = np.zeros((landmarks_batch.shape[0], image_height, image_width), dtype=np.uint8)
|
328 |
+
|
329 |
+
if abs(mask_expand) < 1 and abs(mask_expand) > 0:
|
330 |
+
mask_expand = int(mask_expand * image_height)
|
331 |
+
|
332 |
+
for i, landmarks in enumerate(landmarks_batch):
|
333 |
+
# Create a blank image for each mask
|
334 |
+
mask = Image.new("L", (image_width, image_height), 0)
|
335 |
+
draw = ImageDraw.Draw(mask)
|
336 |
+
|
337 |
+
# Extract relevant landmarks for the face
|
338 |
+
jawline_landmarks = landmarks[2:15] # Jawline
|
339 |
+
# upper_face_landmarks = landmarks[17:27] # Eyebrows and top of nose bridge
|
340 |
+
|
341 |
+
# Combine landmarks to form a polygon around the face
|
342 |
+
# face_polygon = np.concatenate((jawline_landmarks, upper_face_landmarks[::-1]), axis=0)
|
343 |
+
face_polygon = jawline_landmarks
|
344 |
+
|
345 |
+
# Convert landmarks to a list of tuples
|
346 |
+
face_polygon = [(int(x), int(y)) for x, y in face_polygon]
|
347 |
+
|
348 |
+
# Expand the polygon if necessary
|
349 |
+
expanded_polygon = expand_polygon(face_polygon, mask_expand)
|
350 |
+
|
351 |
+
# Draw the polygon and fill it
|
352 |
+
draw.polygon(expanded_polygon, outline=1, fill=1)
|
353 |
+
|
354 |
+
# Convert mask to numpy array and add it to the batch of masks
|
355 |
+
masks[i] = np.array(mask)
|
356 |
+
|
357 |
+
return torch.from_numpy(masks)
|
358 |
+
|
359 |
+
|
360 |
+
ALL_FIXED_POINTS = (
|
361 |
+
[i for i in range(0, 4)] + [i for i in range(13, 17)] + [i for i in range(27, 36)] + [36, 39, 42, 45]
|
362 |
+
)
|
363 |
+
|
364 |
+
|
365 |
+
def gaussian_kernel(sigma, width, height):
|
366 |
+
"""Create a 2D Gaussian kernel."""
|
367 |
+
x = torch.arange(0, width, 1) - width // 2
|
368 |
+
y = torch.arange(0, height, 1) - height // 2
|
369 |
+
x = x.float()
|
370 |
+
y = y.float()
|
371 |
+
x2 = x**2
|
372 |
+
y2 = y[:, None] ** 2
|
373 |
+
g = torch.exp(-(x2 + y2) / (2 * sigma**2))
|
374 |
+
return g / g.sum()
|
375 |
+
|
376 |
+
|
377 |
+
def generate_hm(landmarks, height, width, n_points="all", sigma=3):
|
378 |
+
if n_points == "all":
|
379 |
+
Nlandmarks = range(len(landmarks))
|
380 |
+
elif n_points == "fixed":
|
381 |
+
Nlandmarks = ALL_FIXED_POINTS
|
382 |
+
elif n_points == "stable":
|
383 |
+
Nlandmarks = [33, 36, 39, 42, 45]
|
384 |
+
|
385 |
+
kernel = gaussian_kernel(sigma, width, height)
|
386 |
+
hm = torch.zeros((height, width))
|
387 |
+
for I in Nlandmarks:
|
388 |
+
x0, y0 = landmarks[I]
|
389 |
+
x0, y0 = int(x0), int(y0)
|
390 |
+
left, right = max(0, x0 - width // 2), min(width, x0 + width // 2)
|
391 |
+
top, bottom = max(0, y0 - height // 2), min(height, y0 + height // 2)
|
392 |
+
hm[top:bottom, left:right] += kernel[
|
393 |
+
max(0, -y0 + height // 2) : min(height, height - y0 + height // 2),
|
394 |
+
max(0, -x0 + width // 2) : min(width, width - x0 + width // 2),
|
395 |
+
]
|
396 |
+
# Normalize the heatmap to have values between 0 and 1
|
397 |
+
max_val = hm.max()
|
398 |
+
if max_val > 0:
|
399 |
+
hm /= max_val
|
400 |
+
return hm
|
401 |
+
|
402 |
+
|
403 |
+
def get_heatmap(landmarks, image_size, or_im_size, n_points="stable", sigma=4):
|
404 |
+
stack = []
|
405 |
+
seq_length = landmarks.shape[0]
|
406 |
+
if or_im_size[0] != image_size[0] or or_im_size[1] != image_size[1]:
|
407 |
+
landmarks = scale_landmarks(landmarks, or_im_size, image_size)
|
408 |
+
gen_single_heatmap = partial(
|
409 |
+
generate_hm,
|
410 |
+
height=image_size[0],
|
411 |
+
width=image_size[1],
|
412 |
+
n_points=n_points,
|
413 |
+
sigma=sigma,
|
414 |
+
)
|
415 |
+
for i in range(seq_length):
|
416 |
+
stack.append(gen_single_heatmap(landmarks[i]))
|
417 |
+
|
418 |
+
return torch.stack(stack, axis=0).unsqueeze(0) # (1, seq_length, height, width)
|
419 |
+
|
420 |
+
|
421 |
+
def scale_landmarks(landmarks, original_size, target_size):
|
422 |
+
"""
|
423 |
+
Scale landmarks from original size to target size.
|
424 |
+
|
425 |
+
Parameters:
|
426 |
+
- landmarks (np.array): An array of shape (N, 2) containing facial landmarks.
|
427 |
+
- original_size (tuple): The size (height, width) for which the landmarks are currently scaled.
|
428 |
+
- target_size (tuple): The size (height, width) to which landmarks should be scaled.
|
429 |
+
|
430 |
+
Returns:
|
431 |
+
- scaled_landmarks (np.array): Scaled landmarks.
|
432 |
+
"""
|
433 |
+
scale_y = target_size[0] / original_size[0]
|
434 |
+
scale_x = target_size[1] / original_size[1]
|
435 |
+
scaled_landmarks = landmarks * np.array([scale_x, scale_y])
|
436 |
+
return scaled_landmarks.astype(int)
|
437 |
+
|
438 |
+
|
439 |
+
def draw_kps_image(
|
440 |
+
image_shape, original_size, landmarks, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255)], rgb=True, pts_width=4
|
441 |
+
):
|
442 |
+
stick_width = pts_width
|
443 |
+
limb_seq = np.array([[0, 2], [1, 2]])
|
444 |
+
kps = landmarks[[36, 45, 33], :]
|
445 |
+
kps = scale_landmarks(kps, original_size, image_shape)
|
446 |
+
if not rgb: # Grayscale image
|
447 |
+
canvas = np.zeros((image_shape[0], image_shape[1], 1))
|
448 |
+
color_mode = "grayscale"
|
449 |
+
else: # Color image
|
450 |
+
canvas = np.zeros((image_shape[0], image_shape[1], 3))
|
451 |
+
color_mode = "color"
|
452 |
+
|
453 |
+
polygon_cache = {}
|
454 |
+
|
455 |
+
for index in limb_seq:
|
456 |
+
color = color_list[index[0]]
|
457 |
+
if color_mode == "grayscale":
|
458 |
+
color = (int(0.299 * color[2] + 0.587 * color[1] + 0.114 * color[0]),) # Convert to grayscale intensity
|
459 |
+
|
460 |
+
x = kps[index][:, 0]
|
461 |
+
y = kps[index][:, 1]
|
462 |
+
length = np.sqrt((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2)
|
463 |
+
angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
|
464 |
+
|
465 |
+
cache_key = (color, int(np.mean(x)), int(np.mean(y)), int(length / 2), int(angle))
|
466 |
+
if cache_key not in polygon_cache:
|
467 |
+
polygon_cache[cache_key] = cv2.ellipse2Poly(
|
468 |
+
(int(np.mean(x)), int(np.mean(y))), (int(length / 2), stick_width), int(angle), 0, 360, 1
|
469 |
+
)
|
470 |
+
|
471 |
+
polygon = polygon_cache[cache_key]
|
472 |
+
cv2.fillConvexPoly(canvas, polygon, [int(c * 0.6) for c in color])
|
473 |
+
|
474 |
+
for idx, kp in enumerate(kps):
|
475 |
+
if color_mode == "grayscale":
|
476 |
+
color = (int(0.299 * color_list[idx][2] + 0.587 * color_list[idx][1] + 0.114 * color_list[idx][0]),)
|
477 |
+
else:
|
478 |
+
color = color_list[idx]
|
479 |
+
cv2.circle(canvas, (int(kp[0]), int(kp[1])), pts_width, color, -1)
|
480 |
+
|
481 |
+
return canvas.transpose(2, 0, 1)
|
482 |
+
|
483 |
+
|
484 |
+
def create_landmarks_image(
|
485 |
+
landmarks, original_size=(772, 772), target_size=(772, 772), point_size=3, n_points="all", dim=3
|
486 |
+
):
|
487 |
+
"""
|
488 |
+
Creates an image of landmarks on a black background using efficient NumPy operations.
|
489 |
+
|
490 |
+
Parameters:
|
491 |
+
- landmarks (np.array): An array of shape (68, 2) containing facial landmarks.
|
492 |
+
- image_size (tuple): The size of the output image (height, width).
|
493 |
+
- point_size (int): The radius of each landmark point in pixels.
|
494 |
+
|
495 |
+
Returns:
|
496 |
+
- img (np.array): An image array with landmarks plotted.
|
497 |
+
"""
|
498 |
+
if n_points == "all":
|
499 |
+
indexes = range(len(landmarks))
|
500 |
+
elif n_points == "fixed":
|
501 |
+
indexes = ALL_FIXED_POINTS
|
502 |
+
elif n_points == "stable":
|
503 |
+
indexes = [33, 36, 39, 42, 45]
|
504 |
+
|
505 |
+
landmarks = landmarks[indexes]
|
506 |
+
|
507 |
+
img = np.zeros(target_size, dtype=np.uint8)
|
508 |
+
|
509 |
+
landmarks = scale_landmarks(landmarks, original_size, target_size)
|
510 |
+
|
511 |
+
# Ensure the landmarks are in bounds and integer
|
512 |
+
landmarks = np.clip(landmarks, [0, 0], [target_size[1] - 1, target_size[0] - 1]).astype(int)
|
513 |
+
|
514 |
+
# Get x and y coordinates from landmarks
|
515 |
+
x, y = landmarks[:, 0], landmarks[:, 1]
|
516 |
+
|
517 |
+
# Define a grid offset based on point_size around each landmark
|
518 |
+
offset = np.arange(-point_size // 2, point_size // 2 + 1)
|
519 |
+
grid_x, grid_y = np.meshgrid(offset, offset, indexing="ij")
|
520 |
+
|
521 |
+
# Calculate the full set of x and y coordinates for the points
|
522 |
+
full_x = x[:, np.newaxis, np.newaxis] + grid_x[np.newaxis, :, :]
|
523 |
+
full_y = y[:, np.newaxis, np.newaxis] + grid_y[np.newaxis, :, :]
|
524 |
+
|
525 |
+
# Clip the coordinates to stay within image boundaries
|
526 |
+
full_x = np.clip(full_x, 0, target_size[1] - 1)
|
527 |
+
full_y = np.clip(full_y, 0, target_size[0] - 1)
|
528 |
+
|
529 |
+
# Flatten the arrays to use them as indices
|
530 |
+
full_x = full_x.ravel()
|
531 |
+
full_y = full_y.ravel()
|
532 |
+
|
533 |
+
# Set the points in the image
|
534 |
+
img[full_y, full_x] = 255
|
535 |
+
|
536 |
+
return np.stack([img] * dim, axis=0)
|
537 |
+
|
538 |
+
|
539 |
+
def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
|
540 |
+
len_file = audio.shape[-1]
|
541 |
+
|
542 |
+
if max_len_sec or max_len_raw:
|
543 |
+
max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
|
544 |
+
if len_file < int(max_len):
|
545 |
+
# dummy = np.zeros((1, int(max_len_sec * sr) - len_file))
|
546 |
+
# extened_wav = np.concatenate((audio_data, dummy[0]))
|
547 |
+
extened_wav = torch.nn.functional.pad(audio, (0, int(max_len) - len_file), "constant")
|
548 |
+
else:
|
549 |
+
extened_wav = audio[:, : int(max_len)]
|
550 |
+
else:
|
551 |
+
extened_wav = audio
|
552 |
+
|
553 |
+
return extened_wav
|
554 |
+
|
555 |
+
|
556 |
+
def ssim_to_bin(ssim_score):
|
557 |
+
# Normalize the SSIM score to a 0-100 scale
|
558 |
+
normalized_diff_ssim = (1 - ((ssim_score + 1) / 2)) * 100
|
559 |
+
# Assign to one of the 100 bins
|
560 |
+
bin_index = float(min(np.floor(normalized_diff_ssim), 99))
|
561 |
+
return bin_index
|
sgm/data/dataset.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torchdata.datapipes.iter
|
4 |
+
import webdataset as wds
|
5 |
+
from omegaconf import DictConfig
|
6 |
+
from pytorch_lightning import LightningDataModule
|
7 |
+
|
8 |
+
try:
|
9 |
+
from sdata import create_dataset, create_dummy_dataset, create_loader
|
10 |
+
except ImportError as e:
|
11 |
+
print("#" * 100)
|
12 |
+
print("Datasets not yet available")
|
13 |
+
print("to enable, we need to add stable-datasets as a submodule")
|
14 |
+
print("please use ``git submodule update --init --recursive``")
|
15 |
+
print("and do ``pip install -e stable-datasets/`` from the root of this repo")
|
16 |
+
print("#" * 100)
|
17 |
+
exit(1)
|
18 |
+
|
19 |
+
|
20 |
+
class StableDataModuleFromConfig(LightningDataModule):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
train: DictConfig,
|
24 |
+
validation: Optional[DictConfig] = None,
|
25 |
+
test: Optional[DictConfig] = None,
|
26 |
+
skip_val_loader: bool = False,
|
27 |
+
dummy: bool = False,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
self.train_config = train
|
31 |
+
assert (
|
32 |
+
"datapipeline" in self.train_config and "loader" in self.train_config
|
33 |
+
), "train config requires the fields `datapipeline` and `loader`"
|
34 |
+
|
35 |
+
self.val_config = validation
|
36 |
+
if not skip_val_loader:
|
37 |
+
if self.val_config is not None:
|
38 |
+
assert (
|
39 |
+
"datapipeline" in self.val_config and "loader" in self.val_config
|
40 |
+
), "validation config requires the fields `datapipeline` and `loader`"
|
41 |
+
else:
|
42 |
+
print(
|
43 |
+
"Warning: No Validation datapipeline defined, using that one from training"
|
44 |
+
)
|
45 |
+
self.val_config = train
|
46 |
+
|
47 |
+
self.test_config = test
|
48 |
+
if self.test_config is not None:
|
49 |
+
assert (
|
50 |
+
"datapipeline" in self.test_config and "loader" in self.test_config
|
51 |
+
), "test config requires the fields `datapipeline` and `loader`"
|
52 |
+
|
53 |
+
self.dummy = dummy
|
54 |
+
if self.dummy:
|
55 |
+
print("#" * 100)
|
56 |
+
print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
|
57 |
+
print("#" * 100)
|
58 |
+
|
59 |
+
def setup(self, stage: str) -> None:
|
60 |
+
print("Preparing datasets")
|
61 |
+
if self.dummy:
|
62 |
+
data_fn = create_dummy_dataset
|
63 |
+
else:
|
64 |
+
data_fn = create_dataset
|
65 |
+
|
66 |
+
self.train_datapipeline = data_fn(**self.train_config.datapipeline)
|
67 |
+
if self.val_config:
|
68 |
+
self.val_datapipeline = data_fn(**self.val_config.datapipeline)
|
69 |
+
if self.test_config:
|
70 |
+
self.test_datapipeline = data_fn(**self.test_config.datapipeline)
|
71 |
+
|
72 |
+
def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
|
73 |
+
loader = create_loader(self.train_datapipeline, **self.train_config.loader)
|
74 |
+
return loader
|
75 |
+
|
76 |
+
def val_dataloader(self) -> wds.DataPipeline:
|
77 |
+
return create_loader(self.val_datapipeline, **self.val_config.loader)
|
78 |
+
|
79 |
+
def test_dataloader(self) -> wds.DataPipeline:
|
80 |
+
return create_loader(self.test_datapipeline, **self.test_config.loader)
|
sgm/data/mask.py
ADDED
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
|
3 |
+
"""
|
4 |
+
Functions taken from https://github.com/DanBigioi/DiffusionVideoEditing
|
5 |
+
|
6 |
+
|
7 |
+
"""
|
8 |
+
|
9 |
+
import cv2
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
|
13 |
+
" Countour from 2:15 not good for head poses "
|
14 |
+
|
15 |
+
|
16 |
+
def face_mask(img_shape, landmark_list, dtype="uint8"):
|
17 |
+
height, width = img_shape[:2]
|
18 |
+
mask = np.ones((height, width, 1), dtype=dtype)
|
19 |
+
cv2.drawContours(
|
20 |
+
mask, np.int32([landmark_list[2:15]]), -1, color=(0), thickness=cv2.FILLED
|
21 |
+
)
|
22 |
+
|
23 |
+
return mask
|
24 |
+
|
25 |
+
|
26 |
+
def face_mask_jaw_box(img_shape, landmark_list, dtype="uint8", kernel_size=10):
|
27 |
+
nose = 33
|
28 |
+
jaw = 8
|
29 |
+
|
30 |
+
height, width = img_shape[:2]
|
31 |
+
mask = np.ones((height, width, 1), dtype=dtype)
|
32 |
+
combined_landmarks = np.concatenate((landmark_list[2:15], [landmark_list[33]]))
|
33 |
+
|
34 |
+
# Draw the combined contour on the mask
|
35 |
+
cv2.drawContours(
|
36 |
+
mask, [np.int32(combined_landmarks)], -1, color=(0), thickness=cv2.FILLED
|
37 |
+
)
|
38 |
+
|
39 |
+
inverted_mask = 1 - mask
|
40 |
+
kernel = np.ones((kernel_size, kernel_size), np.uint8)
|
41 |
+
mask = cv2.dilate(inverted_mask, kernel, iterations=1)
|
42 |
+
mask = np.expand_dims(
|
43 |
+
mask, axis=-1
|
44 |
+
) # Add a singleton dimension to match the number of channels
|
45 |
+
mask = 1 - mask
|
46 |
+
|
47 |
+
cut_h = landmark_list[nose][1]
|
48 |
+
|
49 |
+
far_left = int(np.argmin(landmark_list[:, 0]))
|
50 |
+
far_right = int(np.argmax(landmark_list[:, 0]))
|
51 |
+
left_up_point = np.int32([landmark_list[far_left][0], cut_h]) # 2
|
52 |
+
right_up_point = np.int32([landmark_list[far_right][0], cut_h]) # 15
|
53 |
+
height_landmarks = min(landmark_list[jaw, 1] + 20, height)
|
54 |
+
left_down_point = np.int32([landmark_list[far_left][0], height_landmarks])
|
55 |
+
right_down_point = np.int32([landmark_list[far_right][0], height_landmarks])
|
56 |
+
|
57 |
+
# print(cut_h, cut_h + 10, height_landmarks)
|
58 |
+
|
59 |
+
mask_box = [left_up_point, left_down_point, right_down_point, right_up_point]
|
60 |
+
|
61 |
+
return mask, mask_box
|
62 |
+
|
63 |
+
|
64 |
+
" Stretch the tight face mask - Countour from 2:15 but dilate, not good for extreme head poses "
|
65 |
+
|
66 |
+
|
67 |
+
def face_mask_stretch(img_shape, landmark_list, dtype="uint8", kernel_size=10):
|
68 |
+
height, width = img_shape[:2]
|
69 |
+
mask = np.ones((height, width, 1), dtype=dtype)
|
70 |
+
combined_landmarks = np.concatenate((landmark_list[2:15], [landmark_list[33]]))
|
71 |
+
|
72 |
+
# Draw the combined contour on the mask
|
73 |
+
cv2.drawContours(
|
74 |
+
mask, [np.int32(combined_landmarks)], -1, color=(0), thickness=cv2.FILLED
|
75 |
+
)
|
76 |
+
|
77 |
+
# cv2.drawContours(mask, np.int32([landmark_list[2:15]]), -1, color=(0), thickness=cv2.FILLED)
|
78 |
+
inverted_mask = 1 - mask
|
79 |
+
|
80 |
+
kernel = np.ones((kernel_size, kernel_size), np.uint8)
|
81 |
+
mask = cv2.dilate(inverted_mask, kernel, iterations=1)
|
82 |
+
mask = np.expand_dims(
|
83 |
+
mask, axis=-1
|
84 |
+
) # Add a singleton dimension to match the number of channels
|
85 |
+
mask = 1 - mask
|
86 |
+
|
87 |
+
return mask
|
88 |
+
|
89 |
+
|
90 |
+
" Small box around mouth - Use far left, far right points for extreme head poses, cut between nose and upper mouth point"
|
91 |
+
|
92 |
+
|
93 |
+
def face_mask_box_pose(img_shape, landmark_list, dtype="uint8"):
|
94 |
+
"""
|
95 |
+
When the head pose is different than frontal then the normal cropping with landmarks does not work correctly.
|
96 |
+
Crop using as height the middle nose point
|
97 |
+
Take the left/right corners using the far_left and far_right landmarks
|
98 |
+
TODO: Maybe it is better to add some more pixels to have a bigger mask, especially on large head poses
|
99 |
+
"""
|
100 |
+
|
101 |
+
height, width = img_shape[:2]
|
102 |
+
|
103 |
+
nose = 33
|
104 |
+
upper_lip = 51
|
105 |
+
jaw = 8
|
106 |
+
|
107 |
+
nose_point_h = landmark_list[nose, 1]
|
108 |
+
upper_lip_point = landmark_list[upper_lip, 1]
|
109 |
+
cut_h = (upper_lip_point - nose_point_h) / 2 + nose_point_h
|
110 |
+
|
111 |
+
# cut_h = landmark_list[nose][1]
|
112 |
+
|
113 |
+
mask = np.ones((height, width, 1), dtype=dtype)
|
114 |
+
|
115 |
+
far_left = int(np.argmin(landmark_list[:, 0]))
|
116 |
+
far_right = int(np.argmax(landmark_list[:, 0]))
|
117 |
+
|
118 |
+
left_up_point = np.int32([landmark_list[far_left][0], cut_h]) # 2
|
119 |
+
right_up_point = np.int32([landmark_list[far_right][0], cut_h]) # 15
|
120 |
+
|
121 |
+
height_landmarks = min(landmark_list[jaw, 1] + 20, height)
|
122 |
+
left_down_point = np.int32([landmark_list[far_left][0], height_landmarks])
|
123 |
+
right_down_point = np.int32([landmark_list[far_right][0], height_landmarks])
|
124 |
+
|
125 |
+
cv2.drawContours(
|
126 |
+
mask,
|
127 |
+
np.int32(
|
128 |
+
[
|
129 |
+
[
|
130 |
+
left_up_point,
|
131 |
+
left_down_point,
|
132 |
+
right_up_point,
|
133 |
+
right_down_point,
|
134 |
+
left_up_point,
|
135 |
+
right_up_point,
|
136 |
+
left_down_point,
|
137 |
+
right_down_point,
|
138 |
+
]
|
139 |
+
]
|
140 |
+
),
|
141 |
+
-1,
|
142 |
+
color=(0),
|
143 |
+
thickness=cv2.FILLED,
|
144 |
+
)
|
145 |
+
|
146 |
+
return mask
|
147 |
+
|
148 |
+
|
149 |
+
" Small box around mouth - Use far left, far right points for extreme head poses, cut from nose"
|
150 |
+
|
151 |
+
|
152 |
+
def face_mask_box_pose_nose(
|
153 |
+
img_shape,
|
154 |
+
landmark_list,
|
155 |
+
dtype="uint8",
|
156 |
+
get_box=False,
|
157 |
+
pixels_above_nose=None,
|
158 |
+
pixels_under_jaw=None,
|
159 |
+
):
|
160 |
+
height, width = img_shape[:2]
|
161 |
+
|
162 |
+
nose = 33
|
163 |
+
jaw = 8
|
164 |
+
|
165 |
+
cut_h = landmark_list[nose][1]
|
166 |
+
if pixels_above_nose is not None:
|
167 |
+
# this is only for inference to take a bigger mask and blend it back to the original frame
|
168 |
+
cut_h = cut_h - pixels_above_nose
|
169 |
+
|
170 |
+
mask = np.ones((height, width, 1), dtype=dtype)
|
171 |
+
|
172 |
+
far_left = int(np.argmin(landmark_list[:, 0]))
|
173 |
+
far_right = int(np.argmax(landmark_list[:, 0]))
|
174 |
+
|
175 |
+
left_up_point = np.int32([landmark_list[far_left][0], cut_h]) # 2
|
176 |
+
right_up_point = np.int32([landmark_list[far_right][0], cut_h]) # 15
|
177 |
+
|
178 |
+
height_landmarks = min(landmark_list[jaw, 1] + 20, height)
|
179 |
+
if pixels_under_jaw is not None:
|
180 |
+
height_landmarks = min(landmark_list[jaw, 1] + pixels_under_jaw, height)
|
181 |
+
left_down_point = np.int32([landmark_list[far_left][0], height_landmarks])
|
182 |
+
right_down_point = np.int32([landmark_list[far_right][0], height_landmarks])
|
183 |
+
|
184 |
+
cv2.drawContours(
|
185 |
+
mask,
|
186 |
+
np.int32(
|
187 |
+
[
|
188 |
+
[
|
189 |
+
left_up_point,
|
190 |
+
left_down_point,
|
191 |
+
right_up_point,
|
192 |
+
right_down_point,
|
193 |
+
left_up_point,
|
194 |
+
right_up_point,
|
195 |
+
left_down_point,
|
196 |
+
right_down_point,
|
197 |
+
]
|
198 |
+
]
|
199 |
+
),
|
200 |
+
-1,
|
201 |
+
color=(0),
|
202 |
+
thickness=cv2.FILLED,
|
203 |
+
)
|
204 |
+
|
205 |
+
if get_box:
|
206 |
+
mask_box = [left_up_point, left_down_point, right_down_point, right_up_point]
|
207 |
+
return mask, mask_box
|
208 |
+
else:
|
209 |
+
return mask
|
210 |
+
|
211 |
+
|
212 |
+
def face_mask_box_pose_big(
|
213 |
+
img_shape, landmark_list, dtype="uint8", cut_h=None, far_left=None, far_right=None
|
214 |
+
):
|
215 |
+
height, width = img_shape[:2]
|
216 |
+
mask = np.ones((height, width, 1), dtype=dtype)
|
217 |
+
nose = 33
|
218 |
+
nose_point_h = landmark_list[nose, 1]
|
219 |
+
if cut_h is None:
|
220 |
+
cut_h = nose_point_h
|
221 |
+
|
222 |
+
if far_right is None and far_left is None:
|
223 |
+
far_left = int(np.argmin(landmark_list[:, 0]))
|
224 |
+
far_right = int(np.argmax(landmark_list[:, 0]))
|
225 |
+
|
226 |
+
left_up_point = np.int32([landmark_list[far_left][0], cut_h])
|
227 |
+
left_down_point = np.int32([landmark_list[far_left][0], height])
|
228 |
+
|
229 |
+
right_up_point = np.int32([landmark_list[far_right][0], cut_h])
|
230 |
+
right_down_point = np.int32([landmark_list[far_right][0], height])
|
231 |
+
else:
|
232 |
+
left_up_point = np.int32([far_left, cut_h])
|
233 |
+
left_down_point = np.int32([far_left, height])
|
234 |
+
|
235 |
+
right_up_point = np.int32([far_right, cut_h])
|
236 |
+
right_down_point = np.int32([far_right, height])
|
237 |
+
|
238 |
+
cv2.drawContours(
|
239 |
+
mask,
|
240 |
+
np.int32(
|
241 |
+
[
|
242 |
+
[
|
243 |
+
left_up_point,
|
244 |
+
left_down_point,
|
245 |
+
right_up_point,
|
246 |
+
right_down_point,
|
247 |
+
left_up_point,
|
248 |
+
right_up_point,
|
249 |
+
left_down_point,
|
250 |
+
right_down_point,
|
251 |
+
]
|
252 |
+
]
|
253 |
+
),
|
254 |
+
-1,
|
255 |
+
color=(0),
|
256 |
+
thickness=cv2.FILLED,
|
257 |
+
)
|
258 |
+
|
259 |
+
return mask
|
260 |
+
|
261 |
+
|
262 |
+
def face_mask_box_pose_big_cover_nose(img_shape, landmark_list, dtype="uint8"):
|
263 |
+
height, width = img_shape[:2]
|
264 |
+
|
265 |
+
middle_nose_point = 29
|
266 |
+
|
267 |
+
cut_h = landmark_list[middle_nose_point, 1]
|
268 |
+
|
269 |
+
mask = np.ones((height, width, 1), dtype=dtype)
|
270 |
+
|
271 |
+
far_left = int(np.argmin(landmark_list[:, 0]))
|
272 |
+
far_right = int(np.argmax(landmark_list[:, 0]))
|
273 |
+
|
274 |
+
left_up_point = np.int32([landmark_list[far_left][0], cut_h])
|
275 |
+
left_down_point = np.int32([landmark_list[far_left][0], height])
|
276 |
+
|
277 |
+
right_up_point = np.int32([landmark_list[far_right][0], cut_h])
|
278 |
+
right_down_point = np.int32([landmark_list[far_right][0], height])
|
279 |
+
|
280 |
+
cv2.drawContours(
|
281 |
+
mask,
|
282 |
+
np.int32(
|
283 |
+
[
|
284 |
+
[
|
285 |
+
left_up_point,
|
286 |
+
left_down_point,
|
287 |
+
right_up_point,
|
288 |
+
right_down_point,
|
289 |
+
left_up_point,
|
290 |
+
right_up_point,
|
291 |
+
left_down_point,
|
292 |
+
right_down_point,
|
293 |
+
]
|
294 |
+
]
|
295 |
+
),
|
296 |
+
-1,
|
297 |
+
color=(0),
|
298 |
+
thickness=cv2.FILLED,
|
299 |
+
)
|
300 |
+
|
301 |
+
return mask
|
302 |
+
|
303 |
+
|
304 |
+
def face_mask_square(img_shape, landmark_list, dtype="uint8"):
|
305 |
+
height, width = img_shape[:2]
|
306 |
+
|
307 |
+
mask = np.ones((height, width, 1), dtype=dtype)
|
308 |
+
|
309 |
+
far_left = np.min(landmark_list[:, 0])
|
310 |
+
far_right = np.max(landmark_list[:, 1])
|
311 |
+
print("far_left {}, far_right {}".format(far_left, far_right))
|
312 |
+
|
313 |
+
left_p = 2
|
314 |
+
right_p = 14
|
315 |
+
|
316 |
+
print(
|
317 |
+
"left_p {}, right_p {}".format(
|
318 |
+
landmark_list[left_p][0], landmark_list[right_p][0]
|
319 |
+
)
|
320 |
+
)
|
321 |
+
|
322 |
+
cv2.drawContours(
|
323 |
+
mask,
|
324 |
+
np.int32(
|
325 |
+
[
|
326 |
+
[
|
327 |
+
landmark_list[left_p],
|
328 |
+
[landmark_list[left_p][0], height],
|
329 |
+
landmark_list[right_p],
|
330 |
+
[landmark_list[right_p][0], height],
|
331 |
+
landmark_list[left_p],
|
332 |
+
landmark_list[right_p],
|
333 |
+
[landmark_list[left_p][0], height],
|
334 |
+
[landmark_list[right_p][0], height],
|
335 |
+
]
|
336 |
+
]
|
337 |
+
),
|
338 |
+
-1,
|
339 |
+
color=(0),
|
340 |
+
thickness=cv2.FILLED,
|
341 |
+
)
|
342 |
+
|
343 |
+
return mask
|
344 |
+
|
345 |
+
|
346 |
+
" Used for half face "
|
347 |
+
|
348 |
+
|
349 |
+
def bbox2mask(img_shape, bbox, dtype="uint8"):
|
350 |
+
"""Generate mask in ndarray from bbox.
|
351 |
+
|
352 |
+
The returned mask has the shape of (h, w, 1). '1' indicates the
|
353 |
+
hole and '0' indicates the valid regions.
|
354 |
+
|
355 |
+
We prefer to use `uint8` as the data type of masks, which may be different
|
356 |
+
from other codes in the community.
|
357 |
+
|
358 |
+
Args:
|
359 |
+
img_shape (tuple[int]): The size of the image.
|
360 |
+
bbox (tuple[int]): Configuration tuple, (top, left, height, width)
|
361 |
+
dtype (str): Indicate the data type of returned masks. Default: 'uint8'
|
362 |
+
|
363 |
+
Return:
|
364 |
+
numpy.ndarray: Mask in the shape of (h, w, 1).
|
365 |
+
"""
|
366 |
+
|
367 |
+
height, width = img_shape[:2]
|
368 |
+
|
369 |
+
mask = np.ones((height, width, 1), dtype=dtype)
|
370 |
+
mask[bbox[0] : bbox[0] + bbox[2], bbox[1] : bbox[1] + bbox[3], :] = 0.0
|
371 |
+
|
372 |
+
return mask
|
373 |
+
|
374 |
+
|
375 |
+
def face_mask_cheeks(img_shape, landmark_list, dtype="uint8"):
|
376 |
+
height, width = img_shape[:2]
|
377 |
+
mask = np.ones((height, width, 1), dtype=dtype)
|
378 |
+
|
379 |
+
middle_nose_point = 29
|
380 |
+
nose = 33
|
381 |
+
cut_h = int(landmark_list[middle_nose_point, 1])
|
382 |
+
|
383 |
+
far_left = int(np.argmin(landmark_list[:, 0]))
|
384 |
+
far_right = int(np.argmax(landmark_list[:, 0]))
|
385 |
+
|
386 |
+
left_up_point = np.int32([landmark_list[far_left][0], cut_h])
|
387 |
+
left_down_point = np.int32([landmark_list[far_left][0], height])
|
388 |
+
|
389 |
+
right_up_point = np.int32([landmark_list[far_right][0], cut_h])
|
390 |
+
right_down_point = np.int32([landmark_list[far_right][0], height])
|
391 |
+
|
392 |
+
cv2.drawContours(
|
393 |
+
mask,
|
394 |
+
np.int32(
|
395 |
+
[
|
396 |
+
[
|
397 |
+
left_up_point,
|
398 |
+
left_down_point,
|
399 |
+
right_up_point,
|
400 |
+
right_down_point,
|
401 |
+
left_up_point,
|
402 |
+
right_up_point,
|
403 |
+
left_down_point,
|
404 |
+
right_down_point,
|
405 |
+
]
|
406 |
+
]
|
407 |
+
),
|
408 |
+
-1,
|
409 |
+
color=(0),
|
410 |
+
thickness=cv2.FILLED,
|
411 |
+
)
|
412 |
+
|
413 |
+
# Calculate the bounding box coordinates for the nose
|
414 |
+
nose_jaw_dist = (
|
415 |
+
abs(landmark_list[2][0] - landmark_list[middle_nose_point][0]) * 0.10
|
416 |
+
) # 1, 15
|
417 |
+
# nose_right_dist = (landmark_list[middle_nose_point][0] - landmark_list[1][0]) * 0.10
|
418 |
+
# nose_left_dist = (landmark_list[15][0] - landmark_list[middle_nose_point][0]) * 0.10
|
419 |
+
#
|
420 |
+
|
421 |
+
nose_min_x = int(landmark_list[31][0] - nose_jaw_dist)
|
422 |
+
nose_max_x = int(landmark_list[35][0] + nose_jaw_dist)
|
423 |
+
# nose_min_x = int(landmark_list[31][0] - nose_right_dist)
|
424 |
+
# nose_max_x = int(landmark_list[35][0] + nose_left_dist)
|
425 |
+
nose_min_y = cut_h
|
426 |
+
nose_max_y = int(landmark_list[nose, 1])
|
427 |
+
|
428 |
+
# Clear the nose area from the mask using a rectangle
|
429 |
+
mask_nose = np.ones((height, width, 1), dtype=dtype)
|
430 |
+
cv2.rectangle(
|
431 |
+
mask_nose,
|
432 |
+
(nose_min_x, nose_min_y),
|
433 |
+
(nose_max_x, nose_max_y),
|
434 |
+
color=(0),
|
435 |
+
thickness=cv2.FILLED,
|
436 |
+
)
|
437 |
+
|
438 |
+
mask_nose = 1 - mask_nose
|
439 |
+
mask = mask + mask_nose
|
440 |
+
|
441 |
+
return mask
|
442 |
+
|
443 |
+
|
444 |
+
def face_mask_cheeks_batch(
|
445 |
+
img_shape, landmark_list, dtype="uint8", box_expand=0.0, show_nose=True
|
446 |
+
):
|
447 |
+
height, width = img_shape[:2]
|
448 |
+
|
449 |
+
# Handle both single and multiple landmarks
|
450 |
+
if len(landmark_list.shape) == 2:
|
451 |
+
landmark_list = landmark_list[None, ...] # Add batch dimension
|
452 |
+
num_frames = landmark_list.shape[0]
|
453 |
+
|
454 |
+
# Initialize masks for all frames
|
455 |
+
masks = np.ones((num_frames, height, width), dtype=dtype)
|
456 |
+
|
457 |
+
for i in range(num_frames):
|
458 |
+
landmarks = landmark_list[i]
|
459 |
+
middle_nose_point = 29
|
460 |
+
nose = 33
|
461 |
+
cut_h = int(landmarks[middle_nose_point, 1])
|
462 |
+
|
463 |
+
# Add height expansion
|
464 |
+
if box_expand > 0:
|
465 |
+
cut_h = max(0, cut_h - int(box_expand * height))
|
466 |
+
|
467 |
+
far_left = int(np.argmin(landmarks[:, 0]))
|
468 |
+
far_right = int(np.argmax(landmarks[:, 0]))
|
469 |
+
|
470 |
+
left_up_point = np.int32([landmarks[far_left][0], cut_h])
|
471 |
+
left_down_point = np.int32([landmarks[far_left][0], height])
|
472 |
+
|
473 |
+
right_up_point = np.int32([landmarks[far_right][0], cut_h])
|
474 |
+
right_down_point = np.int32([landmarks[far_right][0], height])
|
475 |
+
|
476 |
+
cv2.drawContours(
|
477 |
+
masks[i],
|
478 |
+
np.int32(
|
479 |
+
[
|
480 |
+
[
|
481 |
+
left_up_point,
|
482 |
+
left_down_point,
|
483 |
+
right_up_point,
|
484 |
+
right_down_point,
|
485 |
+
left_up_point,
|
486 |
+
right_up_point,
|
487 |
+
left_down_point,
|
488 |
+
right_down_point,
|
489 |
+
]
|
490 |
+
]
|
491 |
+
),
|
492 |
+
-1,
|
493 |
+
color=(0),
|
494 |
+
thickness=cv2.FILLED,
|
495 |
+
)
|
496 |
+
|
497 |
+
if show_nose:
|
498 |
+
# Calculate the bounding box coordinates for the nose
|
499 |
+
nose_jaw_dist = (
|
500 |
+
abs(landmarks[2][0] - landmarks[middle_nose_point][0]) * 0.10
|
501 |
+
) # 1, 15
|
502 |
+
|
503 |
+
nose_min_x = int(landmarks[31][0] - nose_jaw_dist)
|
504 |
+
nose_max_x = int(landmarks[35][0] + nose_jaw_dist)
|
505 |
+
nose_min_y = cut_h
|
506 |
+
nose_max_y = int(landmarks[nose, 1])
|
507 |
+
|
508 |
+
# Clear the nose area from the mask using a rectangle
|
509 |
+
mask_nose = np.ones((height, width), dtype=dtype)
|
510 |
+
cv2.rectangle(
|
511 |
+
mask_nose,
|
512 |
+
(nose_min_x, nose_min_y),
|
513 |
+
(nose_max_x, nose_max_y),
|
514 |
+
color=(0),
|
515 |
+
thickness=cv2.FILLED,
|
516 |
+
)
|
517 |
+
|
518 |
+
mask_nose = 1 - mask_nose
|
519 |
+
masks[i] = masks[i] + mask_nose
|
520 |
+
|
521 |
+
# If input was single frame, return single mask
|
522 |
+
if landmark_list.shape[0] == 1:
|
523 |
+
return masks[0]
|
524 |
+
|
525 |
+
return 1 - torch.from_numpy(masks)
|
sgm/data/video_datamodule_latent.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, Optional
|
2 |
+
|
3 |
+
from pytorch_lightning import LightningDataModule
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from omegaconf import DictConfig
|
6 |
+
|
7 |
+
import sys
|
8 |
+
import pyrootutils
|
9 |
+
|
10 |
+
root = pyrootutils.setup_root(__file__, pythonpath=True)
|
11 |
+
sys.path.append(root)
|
12 |
+
from sgm.data.video_dataset_latent import VideoDataset
|
13 |
+
|
14 |
+
|
15 |
+
class VideoDataModule(LightningDataModule):
|
16 |
+
"""
|
17 |
+
A DataModule implements 5 key methods:
|
18 |
+
|
19 |
+
def prepare_data(self):
|
20 |
+
# things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
|
21 |
+
# download data, pre-process, split, save to disk, etc...
|
22 |
+
def setup(self, stage):
|
23 |
+
# things to do on every process in DDP
|
24 |
+
# load data, set variables, etc...
|
25 |
+
def train_dataloader(self):
|
26 |
+
# return train dataloader
|
27 |
+
def val_dataloader(self):
|
28 |
+
# return validation dataloader
|
29 |
+
def test_dataloader(self):
|
30 |
+
# return test dataloader
|
31 |
+
def teardown(self):
|
32 |
+
# called on every process in DDP
|
33 |
+
# clean up after fit or test
|
34 |
+
|
35 |
+
This allows you to share a full dataset without explaining how to download,
|
36 |
+
split, transform and process the data.
|
37 |
+
|
38 |
+
Read the docs:
|
39 |
+
https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
train: DictConfig,
|
45 |
+
validation: Optional[DictConfig] = None,
|
46 |
+
test: Optional[DictConfig] = None,
|
47 |
+
skip_val_loader: bool = False,
|
48 |
+
):
|
49 |
+
super().__init__()
|
50 |
+
|
51 |
+
# this line allows to access init params with 'self.hparams' attribute
|
52 |
+
# also ensures init params will be stored in ckpt
|
53 |
+
self.train_config = train
|
54 |
+
assert "datapipeline" in self.train_config and "loader" in self.train_config, (
|
55 |
+
"train config requires the fields `datapipeline` and `loader`"
|
56 |
+
)
|
57 |
+
|
58 |
+
self.val_config = validation
|
59 |
+
if not skip_val_loader:
|
60 |
+
if self.val_config is not None:
|
61 |
+
assert (
|
62 |
+
"datapipeline" in self.val_config and "loader" in self.val_config
|
63 |
+
), "validation config requires the fields `datapipeline` and `loader`"
|
64 |
+
else:
|
65 |
+
print(
|
66 |
+
"Warning: No Validation datapipeline defined, using that one from training"
|
67 |
+
)
|
68 |
+
self.val_config = train
|
69 |
+
|
70 |
+
self.test_config = test
|
71 |
+
if self.test_config is not None:
|
72 |
+
assert (
|
73 |
+
"datapipeline" in self.test_config and "loader" in self.test_config
|
74 |
+
), "test config requires the fields `datapipeline` and `loader`"
|
75 |
+
|
76 |
+
def setup(self, stage: Optional[str] = None):
|
77 |
+
"""Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
|
78 |
+
|
79 |
+
This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
|
80 |
+
careful not to execute things like random split twice!
|
81 |
+
"""
|
82 |
+
print("Preparing datasets")
|
83 |
+
|
84 |
+
self.train_datapipeline = VideoDataset(**self.train_config.datapipeline)
|
85 |
+
if self.val_config:
|
86 |
+
self.val_datapipeline = VideoDataset(**self.val_config.datapipeline)
|
87 |
+
if self.test_config:
|
88 |
+
self.test_datapipeline = VideoDataset(**self.test_config.datapipeline)
|
89 |
+
|
90 |
+
def train_dataloader(self):
|
91 |
+
return DataLoader(self.train_datapipeline, **self.train_config.loader)
|
92 |
+
|
93 |
+
def val_dataloader(self):
|
94 |
+
if self.val_datapipeline:
|
95 |
+
return DataLoader(self.val_datapipeline, **self.val_config.loader)
|
96 |
+
else:
|
97 |
+
return None
|
98 |
+
|
99 |
+
def test_dataloader(self):
|
100 |
+
if self.test_datapipeline:
|
101 |
+
return DataLoader(self.test_datapipeline, **self.test_config.loader)
|
102 |
+
else:
|
103 |
+
return None
|
104 |
+
|
105 |
+
def teardown(self, stage: Optional[str] = None):
|
106 |
+
"""Clean up after fit or test."""
|
107 |
+
pass
|
108 |
+
|
109 |
+
def state_dict(self):
|
110 |
+
"""Extra things to save to checkpoint."""
|
111 |
+
return {}
|
112 |
+
|
113 |
+
def load_state_dict(self, state_dict: Dict[str, Any]):
|
114 |
+
"""Things to do when loading checkpoint."""
|
115 |
+
pass
|
116 |
+
|
117 |
+
|
118 |
+
if __name__ == "__main__":
|
119 |
+
import hydra
|
120 |
+
import omegaconf
|
121 |
+
import pyrootutils
|
122 |
+
import cv2
|
123 |
+
|
124 |
+
root = pyrootutils.setup_root(__file__, pythonpath=True)
|
125 |
+
cfg = omegaconf.OmegaConf.load(
|
126 |
+
root / "configs" / "datamodule" / "image_datamodule.yaml"
|
127 |
+
)
|
128 |
+
# cfg.data_dir = str(root / "data")
|
129 |
+
data = hydra.utils.instantiate(cfg)
|
130 |
+
data.prepare_data()
|
131 |
+
data.setup()
|
132 |
+
print(data.data_train.__getitem__(0)[0].shape)
|
133 |
+
batch = next(iter(data.train_dataloader()))
|
134 |
+
identity, target = batch
|
135 |
+
image_identity = (identity[0].permute(1, 2, 0).numpy() + 1) / 2 * 255
|
136 |
+
image_other = (target[0].permute(1, 2, 0).numpy() + 1) / 2 * 255
|
137 |
+
cv2.imwrite("image_identity.png", image_identity[:, :, ::-1])
|
138 |
+
cv2.imwrite("image_other.png", image_other[:, :, ::-1])
|
sgm/data/video_dataset_latent.py
ADDED
@@ -0,0 +1,780 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import numpy as np
|
3 |
+
from functools import partial
|
4 |
+
from torch.utils.data import Dataset, WeightedRandomSampler
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torch
|
7 |
+
import math
|
8 |
+
import decord
|
9 |
+
from einops import rearrange
|
10 |
+
from more_itertools import sliding_window
|
11 |
+
from omegaconf import ListConfig
|
12 |
+
import torchaudio
|
13 |
+
import soundfile as sf
|
14 |
+
from torchvision.transforms import RandomHorizontalFlip
|
15 |
+
from audiomentations import Compose, AddGaussianNoise, PitchShift
|
16 |
+
from safetensors.torch import load_file
|
17 |
+
from tqdm import tqdm
|
18 |
+
import cv2
|
19 |
+
from sgm.data.data_utils import (
|
20 |
+
create_masks_from_landmarks_full_size,
|
21 |
+
create_face_mask_from_landmarks,
|
22 |
+
create_masks_from_landmarks_box,
|
23 |
+
create_masks_from_landmarks_mouth,
|
24 |
+
)
|
25 |
+
from sgm.data.mask import face_mask_cheeks_batch
|
26 |
+
|
27 |
+
torchaudio.set_audio_backend("sox_io")
|
28 |
+
decord.bridge.set_bridge("torch")
|
29 |
+
|
30 |
+
|
31 |
+
def exists(x):
|
32 |
+
return x is not None
|
33 |
+
|
34 |
+
|
35 |
+
def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
|
36 |
+
len_file = audio.shape[-1]
|
37 |
+
|
38 |
+
if max_len_sec or max_len_raw:
|
39 |
+
max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
|
40 |
+
if len_file < int(max_len):
|
41 |
+
extened_wav = torch.nn.functional.pad(
|
42 |
+
audio, (0, int(max_len) - len_file), "constant"
|
43 |
+
)
|
44 |
+
else:
|
45 |
+
extened_wav = audio[:, : int(max_len)]
|
46 |
+
else:
|
47 |
+
extened_wav = audio
|
48 |
+
|
49 |
+
return extened_wav
|
50 |
+
|
51 |
+
|
52 |
+
# Similar to regular video dataset but trades flexibility for speed
|
53 |
+
class VideoDataset(Dataset):
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
filelist,
|
57 |
+
resize_size=None,
|
58 |
+
audio_folder="Audio",
|
59 |
+
video_folder="CroppedVideos",
|
60 |
+
emotions_folder="emotions",
|
61 |
+
landmarks_folder=None,
|
62 |
+
audio_emb_folder=None,
|
63 |
+
video_extension=".avi",
|
64 |
+
audio_extension=".wav",
|
65 |
+
audio_rate=16000,
|
66 |
+
latent_folder=None,
|
67 |
+
audio_in_video=False,
|
68 |
+
fps=25,
|
69 |
+
num_frames=5,
|
70 |
+
need_cond=True,
|
71 |
+
step=1,
|
72 |
+
mode="prediction",
|
73 |
+
scale_audio=False,
|
74 |
+
augment=False,
|
75 |
+
augment_audio=False,
|
76 |
+
use_latent=False,
|
77 |
+
latent_type="stable",
|
78 |
+
latent_scale=1, # For backwards compatibility
|
79 |
+
from_audio_embedding=False,
|
80 |
+
load_all_possible_indexes=False,
|
81 |
+
audio_emb_type="wavlm",
|
82 |
+
cond_noise=[-3.0, 0.5],
|
83 |
+
motion_id=255.0,
|
84 |
+
data_mean=None,
|
85 |
+
data_std=None,
|
86 |
+
use_latent_condition=False,
|
87 |
+
skip_frames=0,
|
88 |
+
get_separate_id=False,
|
89 |
+
virtual_increase=1,
|
90 |
+
filter_by_length=False,
|
91 |
+
select_randomly=False,
|
92 |
+
balance_datasets=True,
|
93 |
+
use_emotions=False,
|
94 |
+
get_original_frames=False,
|
95 |
+
add_extra_audio_emb=False,
|
96 |
+
expand_box=0.0,
|
97 |
+
nose_index=28,
|
98 |
+
what_mask="full",
|
99 |
+
get_masks=False,
|
100 |
+
):
|
101 |
+
self.audio_folder = audio_folder
|
102 |
+
self.from_audio_embedding = from_audio_embedding
|
103 |
+
self.audio_emb_type = audio_emb_type
|
104 |
+
self.cond_noise = cond_noise
|
105 |
+
self.latent_condition = use_latent_condition
|
106 |
+
precomputed_latent = latent_type
|
107 |
+
self.audio_emb_folder = (
|
108 |
+
audio_emb_folder if audio_emb_folder is not None else audio_folder
|
109 |
+
)
|
110 |
+
self.skip_frames = skip_frames
|
111 |
+
self.get_separate_id = get_separate_id
|
112 |
+
self.fps = fps
|
113 |
+
self.virtual_increase = virtual_increase
|
114 |
+
self.select_randomly = select_randomly
|
115 |
+
self.use_emotions = use_emotions
|
116 |
+
self.emotions_folder = emotions_folder
|
117 |
+
self.get_original_frames = get_original_frames
|
118 |
+
self.add_extra_audio_emb = add_extra_audio_emb
|
119 |
+
self.expand_box = expand_box
|
120 |
+
self.nose_index = nose_index
|
121 |
+
self.landmarks_folder = landmarks_folder
|
122 |
+
self.what_mask = what_mask
|
123 |
+
self.get_masks = get_masks
|
124 |
+
|
125 |
+
assert not (exists(data_mean) ^ exists(data_std)), (
|
126 |
+
"Both data_mean and data_std should be provided"
|
127 |
+
)
|
128 |
+
|
129 |
+
if data_mean is not None:
|
130 |
+
data_mean = rearrange(torch.as_tensor(data_mean), "c -> c () () ()")
|
131 |
+
data_std = rearrange(torch.as_tensor(data_std), "c -> c () () ()")
|
132 |
+
self.data_mean = data_mean
|
133 |
+
self.data_std = data_std
|
134 |
+
self.motion_id = motion_id
|
135 |
+
self.latent_folder = (
|
136 |
+
latent_folder if latent_folder is not None else video_folder
|
137 |
+
)
|
138 |
+
self.audio_in_video = audio_in_video
|
139 |
+
|
140 |
+
self.filelist = []
|
141 |
+
self.audio_filelist = []
|
142 |
+
self.landmark_filelist = [] if get_masks else None
|
143 |
+
with open(filelist, "r") as files:
|
144 |
+
for f in files.readlines():
|
145 |
+
f = f.rstrip()
|
146 |
+
|
147 |
+
audio_path = f.replace(video_folder, audio_folder).replace(
|
148 |
+
video_extension, audio_extension
|
149 |
+
)
|
150 |
+
|
151 |
+
self.filelist += [f]
|
152 |
+
self.audio_filelist += [audio_path]
|
153 |
+
if self.get_masks:
|
154 |
+
landmark_path = f.replace(video_folder, landmarks_folder).replace(
|
155 |
+
video_extension, ".npy"
|
156 |
+
)
|
157 |
+
self.landmark_filelist += [landmark_path]
|
158 |
+
|
159 |
+
self.resize_size = resize_size
|
160 |
+
if use_latent and not precomputed_latent:
|
161 |
+
self.resize_size *= 4 if latent_type in ["stable", "ldm"] else 8
|
162 |
+
self.scale_audio = scale_audio
|
163 |
+
self.step = step
|
164 |
+
self.use_latent = use_latent
|
165 |
+
self.precomputed_latent = precomputed_latent
|
166 |
+
self.latent_type = latent_type
|
167 |
+
self.latent_scale = latent_scale
|
168 |
+
self.video_ext = video_extension
|
169 |
+
self.video_folder = video_folder
|
170 |
+
|
171 |
+
self.augment = augment
|
172 |
+
self.maybe_augment = RandomHorizontalFlip(p=0.5) if augment else lambda x: x
|
173 |
+
self.maybe_augment_audio = (
|
174 |
+
Compose(
|
175 |
+
[
|
176 |
+
AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.002, p=0.25),
|
177 |
+
# TimeStretch(min_rate=0.8, max_rate=1.25, p=0.3),
|
178 |
+
PitchShift(min_semitones=-1, max_semitones=1, p=0.25),
|
179 |
+
# Shift(min_fraction=-0.5, max_fraction=0.5, p=0.333),
|
180 |
+
]
|
181 |
+
)
|
182 |
+
if augment_audio
|
183 |
+
else lambda x, sample_rate: x
|
184 |
+
)
|
185 |
+
self.maybe_augment_audio = partial(
|
186 |
+
self.maybe_augment_audio, sample_rate=audio_rate
|
187 |
+
)
|
188 |
+
|
189 |
+
self.mode = mode
|
190 |
+
if mode == "interpolation":
|
191 |
+
need_cond = False # Interpolation does not need condition as first and last frame becomes the condition
|
192 |
+
self.need_cond = need_cond # If need cond will extract one more frame than the number of frames
|
193 |
+
if get_separate_id:
|
194 |
+
self.need_cond = True
|
195 |
+
# It is used for the conditional model when the condition is not on the temporal dimension
|
196 |
+
num_frames = num_frames if not self.need_cond else num_frames + 1
|
197 |
+
|
198 |
+
vr = decord.VideoReader(self.filelist[0])
|
199 |
+
self.video_rate = math.ceil(vr.get_avg_fps())
|
200 |
+
print(f"Video rate: {self.video_rate}")
|
201 |
+
self.audio_rate = audio_rate
|
202 |
+
a2v_ratio = fps / float(self.audio_rate)
|
203 |
+
self.samples_per_frame = math.ceil(1 / a2v_ratio)
|
204 |
+
|
205 |
+
if get_separate_id:
|
206 |
+
assert mode == "prediction", (
|
207 |
+
"Separate identity frame is only supported for prediction mode"
|
208 |
+
)
|
209 |
+
# No need for extra frame if we are getting a separate identity frame
|
210 |
+
self.need_cond = True
|
211 |
+
num_frames -= 1
|
212 |
+
self.num_frames = num_frames
|
213 |
+
self.load_all_possible_indexes = load_all_possible_indexes
|
214 |
+
if load_all_possible_indexes:
|
215 |
+
self._indexes = self._get_indexes(
|
216 |
+
self.filelist, self.audio_filelist, self.landmark_filelist
|
217 |
+
)
|
218 |
+
else:
|
219 |
+
if filter_by_length:
|
220 |
+
self._indexes = self.filter_by_length(
|
221 |
+
self.filelist, self.audio_filelist, self.landmark_filelist
|
222 |
+
)
|
223 |
+
else:
|
224 |
+
if self.get_masks:
|
225 |
+
self._indexes = list(
|
226 |
+
zip(self.filelist, self.audio_filelist, self.landmark_filelist)
|
227 |
+
)
|
228 |
+
else:
|
229 |
+
self._indexes = list(
|
230 |
+
zip(
|
231 |
+
self.filelist,
|
232 |
+
self.audio_filelist,
|
233 |
+
[None] * len(self.filelist),
|
234 |
+
)
|
235 |
+
)
|
236 |
+
|
237 |
+
self.balance_datasets = balance_datasets
|
238 |
+
if self.balance_datasets:
|
239 |
+
self.weights = self._calculate_weights()
|
240 |
+
self.sampler = WeightedRandomSampler(
|
241 |
+
self.weights, num_samples=len(self._indexes), replacement=True
|
242 |
+
)
|
243 |
+
|
244 |
+
def __len__(self):
|
245 |
+
return len(self._indexes) * self.virtual_increase
|
246 |
+
|
247 |
+
def _load_landmarks(self, filename, original_size, target_size, indexes):
|
248 |
+
landmarks = np.load(filename, allow_pickle=True)[indexes, :]
|
249 |
+
if self.what_mask == "full":
|
250 |
+
mask = create_masks_from_landmarks_full_size(
|
251 |
+
landmarks,
|
252 |
+
original_size[0],
|
253 |
+
original_size[1],
|
254 |
+
offset=self.expand_box,
|
255 |
+
nose_index=self.nose_index,
|
256 |
+
)
|
257 |
+
elif self.what_mask == "box":
|
258 |
+
mask = create_masks_from_landmarks_box(
|
259 |
+
landmarks,
|
260 |
+
(original_size[0], original_size[1]),
|
261 |
+
box_expand=self.expand_box,
|
262 |
+
nose_index=self.nose_index,
|
263 |
+
)
|
264 |
+
elif self.what_mask == "heart":
|
265 |
+
mask = face_mask_cheeks_batch(
|
266 |
+
original_size, landmarks, box_expand=0.0, show_nose=True
|
267 |
+
)
|
268 |
+
elif self.what_mask == "mouth":
|
269 |
+
mask = create_masks_from_landmarks_mouth(
|
270 |
+
landmarks,
|
271 |
+
(original_size[0], original_size[1]),
|
272 |
+
box_expand=0.01,
|
273 |
+
nose_index=self.nose_index,
|
274 |
+
)
|
275 |
+
else:
|
276 |
+
mask = create_face_mask_from_landmarks(
|
277 |
+
landmarks, original_size[0], original_size[1], mask_expand=0.05
|
278 |
+
)
|
279 |
+
# Interpolate the mask to the target size
|
280 |
+
mask = F.interpolate(
|
281 |
+
mask.unsqueeze(1).float(), size=target_size, mode="nearest"
|
282 |
+
)
|
283 |
+
|
284 |
+
return mask, landmarks
|
285 |
+
|
286 |
+
def get_emotions(self, video_file, video_indexes):
|
287 |
+
emotions_path = video_file.replace(
|
288 |
+
self.video_folder, self.emotions_folder
|
289 |
+
).replace(self.video_ext, ".pt")
|
290 |
+
emotions = torch.load(emotions_path)
|
291 |
+
return (
|
292 |
+
emotions["valence"][video_indexes],
|
293 |
+
emotions["arousal"][video_indexes],
|
294 |
+
emotions["labels"][video_indexes],
|
295 |
+
)
|
296 |
+
|
297 |
+
def get_frame_indices(self, total_video_frames, select_randomly=False, start_idx=0):
|
298 |
+
if select_randomly:
|
299 |
+
# Randomly select self.num_frames indices from the available range
|
300 |
+
available_indices = list(range(start_idx, total_video_frames))
|
301 |
+
if len(available_indices) < self.num_frames:
|
302 |
+
raise ValueError(
|
303 |
+
"Not enough frames in the video to sample with given parameters."
|
304 |
+
)
|
305 |
+
indexes = random.sample(available_indices, self.num_frames)
|
306 |
+
return sorted(indexes) # Sort to maintain temporal order
|
307 |
+
else:
|
308 |
+
# Calculate the maximum possible start index
|
309 |
+
max_start_idx = total_video_frames - (
|
310 |
+
(self.num_frames - 1) * (self.skip_frames + 1) + 1
|
311 |
+
)
|
312 |
+
|
313 |
+
# Generate a random start index
|
314 |
+
if max_start_idx > 0:
|
315 |
+
start_idx = np.random.randint(start_idx, max_start_idx)
|
316 |
+
else:
|
317 |
+
raise ValueError(
|
318 |
+
"Not enough frames in the video to sample with given parameters."
|
319 |
+
)
|
320 |
+
|
321 |
+
# Generate the indices
|
322 |
+
indexes = [
|
323 |
+
start_idx + i * (self.skip_frames + 1) for i in range(self.num_frames)
|
324 |
+
]
|
325 |
+
|
326 |
+
return indexes
|
327 |
+
|
328 |
+
def _load_audio(self, filename, max_len_sec, start=None, indexes=None):
|
329 |
+
audio, sr = sf.read(
|
330 |
+
filename,
|
331 |
+
start=math.ceil(start * self.audio_rate),
|
332 |
+
frames=math.ceil(self.audio_rate * max_len_sec),
|
333 |
+
always_2d=True,
|
334 |
+
) # e.g (16000, 1)
|
335 |
+
audio = audio.T # (1, 16000)
|
336 |
+
assert sr == self.audio_rate, (
|
337 |
+
f"Audio rate is {sr} but should be {self.audio_rate}"
|
338 |
+
)
|
339 |
+
audio = audio.mean(0, keepdims=True)
|
340 |
+
audio = self.maybe_augment_audio(audio)
|
341 |
+
audio = torch.from_numpy(audio).float()
|
342 |
+
# audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=self.audio_rate)
|
343 |
+
audio = trim_pad_audio(audio, self.audio_rate, max_len_sec=max_len_sec)
|
344 |
+
return audio[0]
|
345 |
+
|
346 |
+
def ensure_shape(self, tensors):
|
347 |
+
target_length = self.samples_per_frame
|
348 |
+
processed_tensors = []
|
349 |
+
for tensor in tensors:
|
350 |
+
current_length = tensor.shape[1]
|
351 |
+
diff = current_length - target_length
|
352 |
+
assert abs(diff) <= 5, (
|
353 |
+
f"Expected shape {target_length}, but got {current_length}"
|
354 |
+
)
|
355 |
+
if diff < 0:
|
356 |
+
# Calculate how much padding is needed
|
357 |
+
padding_needed = target_length - current_length
|
358 |
+
# Pad the tensor
|
359 |
+
padded_tensor = F.pad(tensor, (0, padding_needed))
|
360 |
+
processed_tensors.append(padded_tensor)
|
361 |
+
elif diff > 0:
|
362 |
+
# Trim the tensor
|
363 |
+
trimmed_tensor = tensor[:, :target_length]
|
364 |
+
processed_tensors.append(trimmed_tensor)
|
365 |
+
else:
|
366 |
+
# If it's already the correct size
|
367 |
+
processed_tensors.append(tensor)
|
368 |
+
return torch.cat(processed_tensors)
|
369 |
+
|
370 |
+
def normalize_latents(self, latents):
|
371 |
+
if self.data_mean is not None:
|
372 |
+
# Normalize latents to 0 mean and 0.5 std
|
373 |
+
latents = ((latents - self.data_mean) / self.data_std) * 0.5
|
374 |
+
return latents
|
375 |
+
|
376 |
+
def convert_indexes(self, indexes_25fps, fps_from=25, fps_to=60):
|
377 |
+
ratio = fps_to / fps_from
|
378 |
+
indexes_60fps = [int(index * ratio) for index in indexes_25fps]
|
379 |
+
return indexes_60fps
|
380 |
+
|
381 |
+
def _get_frames_and_audio(self, idx):
|
382 |
+
if self.load_all_possible_indexes:
|
383 |
+
indexes, video_file, audio_file, land_file = self._indexes[idx]
|
384 |
+
if self.audio_in_video:
|
385 |
+
vr = decord.AVReader(video_file, sample_rate=self.audio_rate)
|
386 |
+
else:
|
387 |
+
vr = decord.VideoReader(video_file)
|
388 |
+
len_video = len(vr)
|
389 |
+
if "AA_processed" in video_file or "1000actors_nsv" in video_file:
|
390 |
+
len_video *= 25 / 60
|
391 |
+
len_video = int(len_video)
|
392 |
+
else:
|
393 |
+
video_file, audio_file, land_file = self._indexes[idx]
|
394 |
+
if self.audio_in_video:
|
395 |
+
vr = decord.AVReader(video_file, sample_rate=self.audio_rate)
|
396 |
+
else:
|
397 |
+
vr = decord.VideoReader(video_file)
|
398 |
+
len_video = len(vr)
|
399 |
+
if "AA_processed" in video_file or "1000actors_nsv" in video_file:
|
400 |
+
len_video *= 25 / 60
|
401 |
+
len_video = int(len_video)
|
402 |
+
|
403 |
+
indexes = self.get_frame_indices(
|
404 |
+
len_video,
|
405 |
+
select_randomly=self.select_randomly,
|
406 |
+
start_idx=120 if "1000actors_nsv" in video_file else 0,
|
407 |
+
)
|
408 |
+
|
409 |
+
if self.get_separate_id:
|
410 |
+
id_idx = np.random.randint(0, len_video)
|
411 |
+
indexes.insert(0, id_idx)
|
412 |
+
|
413 |
+
if "AA_processed" in video_file or "1000actors_nsv" in video_file:
|
414 |
+
video_indexes = self.convert_indexes(indexes, fps_from=25, fps_to=60)
|
415 |
+
audio_file = audio_file.replace("_output_output", "")
|
416 |
+
if self.audio_emb_type == "wav2vec2" and "AA_processed" in video_file:
|
417 |
+
audio_path_extra = ".safetensors"
|
418 |
+
else:
|
419 |
+
audio_path_extra = f"_{self.audio_emb_type}_emb.safetensors"
|
420 |
+
|
421 |
+
video_path_extra = f"_{self.latent_type}_512_latent.safetensors"
|
422 |
+
audio_path_extra_extra = (
|
423 |
+
".pt" if "AA_processed" in video_file else "_beats_emb.pt"
|
424 |
+
)
|
425 |
+
|
426 |
+
else:
|
427 |
+
video_indexes = indexes
|
428 |
+
audio_path_extra = f"_{self.audio_emb_type}_emb.safetensors"
|
429 |
+
video_path_extra = f"_{self.latent_type}_512_latent.safetensors"
|
430 |
+
audio_path_extra_extra = "_beats_emb.pt"
|
431 |
+
|
432 |
+
emotions = None
|
433 |
+
if self.use_emotions:
|
434 |
+
emotions = self.get_emotions(video_file, video_indexes)
|
435 |
+
if self.get_separate_id:
|
436 |
+
emotions = (emotions[0][1:], emotions[1][1:], emotions[2][1:])
|
437 |
+
|
438 |
+
raw_audio = None
|
439 |
+
if self.audio_in_video:
|
440 |
+
raw_audio, frames_video = vr.get_batch(video_indexes)
|
441 |
+
raw_audio = rearrange(self.ensure_shape(raw_audio), "f s -> (f s)")
|
442 |
+
|
443 |
+
if self.use_latent and self.precomputed_latent:
|
444 |
+
latent_file = video_file.replace(self.video_ext, video_path_extra).replace(
|
445 |
+
self.video_folder, self.latent_folder
|
446 |
+
)
|
447 |
+
frames = load_file(latent_file)["latents"][video_indexes, :, :, :]
|
448 |
+
|
449 |
+
if frames.shape[-1] != 64:
|
450 |
+
print(f"Frames shape: {frames.shape}, video file: {video_file}")
|
451 |
+
|
452 |
+
frames = rearrange(frames, "t c h w -> c t h w") * self.latent_scale
|
453 |
+
frames = self.normalize_latents(frames)
|
454 |
+
else:
|
455 |
+
if self.audio_in_video:
|
456 |
+
frames = frames_video.permute(3, 0, 1, 2).float()
|
457 |
+
else:
|
458 |
+
frames = vr.get_batch(video_indexes).permute(3, 0, 1, 2).float()
|
459 |
+
|
460 |
+
if raw_audio is None:
|
461 |
+
# Audio is not in video
|
462 |
+
raw_audio = self._load_audio(
|
463 |
+
audio_file,
|
464 |
+
max_len_sec=frames.shape[1] / self.fps,
|
465 |
+
start=indexes[0] / self.fps,
|
466 |
+
# indexes=indexes,
|
467 |
+
)
|
468 |
+
if not self.from_audio_embedding:
|
469 |
+
audio = raw_audio
|
470 |
+
audio_frames = rearrange(audio, "(f s) -> f s", s=self.samples_per_frame)
|
471 |
+
else:
|
472 |
+
audio = load_file(
|
473 |
+
audio_file.replace(self.audio_folder, self.audio_emb_folder).split(".")[
|
474 |
+
0
|
475 |
+
]
|
476 |
+
+ audio_path_extra
|
477 |
+
)["audio"]
|
478 |
+
audio_frames = audio[indexes, :]
|
479 |
+
if self.add_extra_audio_emb:
|
480 |
+
audio_extra = torch.load(
|
481 |
+
audio_file.replace(self.audio_folder, self.audio_emb_folder).split(
|
482 |
+
"."
|
483 |
+
)[0]
|
484 |
+
+ audio_path_extra_extra
|
485 |
+
)
|
486 |
+
audio_extra = audio_extra[indexes, :]
|
487 |
+
audio_frames = torch.cat([audio_frames, audio_extra], dim=-1)
|
488 |
+
|
489 |
+
audio_frames = (
|
490 |
+
audio_frames[1:] if self.need_cond else audio_frames
|
491 |
+
) # Remove audio of first frame
|
492 |
+
|
493 |
+
if self.get_original_frames:
|
494 |
+
original_frames = vr.get_batch(video_indexes).permute(3, 0, 1, 2).float()
|
495 |
+
original_frames = self.scale_and_crop((original_frames / 255.0) * 2 - 1)
|
496 |
+
original_frames = (
|
497 |
+
original_frames[:, 1:] if self.need_cond else original_frames
|
498 |
+
)
|
499 |
+
else:
|
500 |
+
original_frames = None
|
501 |
+
|
502 |
+
if not self.use_latent or (self.use_latent and not self.precomputed_latent):
|
503 |
+
frames = self.scale_and_crop((frames / 255.0) * 2 - 1)
|
504 |
+
|
505 |
+
target = frames[:, 1:] if self.need_cond else frames
|
506 |
+
if self.mode == "prediction":
|
507 |
+
if self.use_latent:
|
508 |
+
if self.audio_in_video:
|
509 |
+
clean_cond = (
|
510 |
+
frames_video[0].unsqueeze(0).permute(3, 0, 1, 2).float()
|
511 |
+
)
|
512 |
+
else:
|
513 |
+
clean_cond = (
|
514 |
+
vr[video_indexes[0]].unsqueeze(0).permute(3, 0, 1, 2).float()
|
515 |
+
)
|
516 |
+
original_size = clean_cond.shape[-2:]
|
517 |
+
clean_cond = self.scale_and_crop((clean_cond / 255.0) * 2 - 1).squeeze(
|
518 |
+
0
|
519 |
+
)
|
520 |
+
if self.latent_condition:
|
521 |
+
noisy_cond = frames[:, 0]
|
522 |
+
else:
|
523 |
+
noisy_cond = clean_cond
|
524 |
+
else:
|
525 |
+
clean_cond = frames[:, 0]
|
526 |
+
noisy_cond = clean_cond
|
527 |
+
elif self.mode == "interpolation":
|
528 |
+
if self.use_latent:
|
529 |
+
if self.audio_in_video:
|
530 |
+
clean_cond = frames_video[[0, -1]].permute(3, 0, 1, 2).float()
|
531 |
+
else:
|
532 |
+
clean_cond = (
|
533 |
+
vr.get_batch([video_indexes[0], video_indexes[-1]])
|
534 |
+
.permute(3, 0, 1, 2)
|
535 |
+
.float()
|
536 |
+
)
|
537 |
+
original_size = clean_cond.shape[-2:]
|
538 |
+
clean_cond = self.scale_and_crop((clean_cond / 255.0) * 2 - 1)
|
539 |
+
if self.latent_condition:
|
540 |
+
noisy_cond = torch.stack([target[:, 0], target[:, -1]], dim=1)
|
541 |
+
else:
|
542 |
+
noisy_cond = clean_cond
|
543 |
+
else:
|
544 |
+
clean_cond = torch.stack([target[:, 0], target[:, -1]], dim=1)
|
545 |
+
noisy_cond = clean_cond
|
546 |
+
|
547 |
+
# Add noise to conditional frame
|
548 |
+
if self.cond_noise and isinstance(self.cond_noise, ListConfig):
|
549 |
+
cond_noise = (
|
550 |
+
self.cond_noise[0] + self.cond_noise[1] * torch.randn((1,))
|
551 |
+
).exp()
|
552 |
+
noisy_cond = noisy_cond + cond_noise * torch.randn_like(noisy_cond)
|
553 |
+
else:
|
554 |
+
noisy_cond = noisy_cond + self.cond_noise * torch.randn_like(noisy_cond)
|
555 |
+
cond_noise = self.cond_noise
|
556 |
+
|
557 |
+
if self.get_masks:
|
558 |
+
target_size = (
|
559 |
+
(self.resize_size, self.resize_size)
|
560 |
+
if not self.use_latent
|
561 |
+
else (self.resize_size // 8, self.resize_size // 8)
|
562 |
+
)
|
563 |
+
masks, landmarks = self._load_landmarks(
|
564 |
+
land_file, original_size, target_size, video_indexes
|
565 |
+
)
|
566 |
+
|
567 |
+
landmarks = None
|
568 |
+
masks = (
|
569 |
+
masks.permute(1, 0, 2, 3)[:, 1:]
|
570 |
+
if self.need_cond
|
571 |
+
else masks.permute(1, 0, 2, 3)
|
572 |
+
)
|
573 |
+
else:
|
574 |
+
masks = None
|
575 |
+
landmarks = None
|
576 |
+
|
577 |
+
return (
|
578 |
+
original_frames,
|
579 |
+
clean_cond,
|
580 |
+
noisy_cond,
|
581 |
+
target,
|
582 |
+
audio_frames,
|
583 |
+
raw_audio,
|
584 |
+
cond_noise,
|
585 |
+
emotions,
|
586 |
+
masks,
|
587 |
+
landmarks,
|
588 |
+
)
|
589 |
+
|
590 |
+
def filter_by_length(self, video_filelist, audio_filelist):
|
591 |
+
def with_opencv(filename):
|
592 |
+
video = cv2.VideoCapture(filename)
|
593 |
+
frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
|
594 |
+
|
595 |
+
return int(frame_count)
|
596 |
+
|
597 |
+
filtered_video = []
|
598 |
+
filtered_audio = []
|
599 |
+
min_length = (self.num_frames - 1) * (self.skip_frames + 1) + 1
|
600 |
+
for vid_file, audio_file in tqdm(
|
601 |
+
zip(video_filelist, audio_filelist),
|
602 |
+
total=len(video_filelist),
|
603 |
+
desc="Filtering",
|
604 |
+
):
|
605 |
+
# vr = decord.VideoReader(vid_file)
|
606 |
+
|
607 |
+
len_video = with_opencv(vid_file)
|
608 |
+
# Short videos
|
609 |
+
if len_video < min_length:
|
610 |
+
continue
|
611 |
+
filtered_video.append(vid_file)
|
612 |
+
filtered_audio.append(audio_file)
|
613 |
+
print(f"New number of files: {len(filtered_video)}")
|
614 |
+
return filtered_video, filtered_audio
|
615 |
+
|
616 |
+
def _get_indexes(self, video_filelist, audio_filelist):
|
617 |
+
indexes = []
|
618 |
+
self.og_shape = None
|
619 |
+
for vid_file, audio_file in zip(video_filelist, audio_filelist):
|
620 |
+
vr = decord.VideoReader(vid_file)
|
621 |
+
if self.og_shape is None:
|
622 |
+
self.og_shape = vr[0].shape[-2]
|
623 |
+
len_video = len(vr)
|
624 |
+
# Short videos
|
625 |
+
if len_video < self.num_frames:
|
626 |
+
continue
|
627 |
+
else:
|
628 |
+
possible_indexes = list(
|
629 |
+
sliding_window(range(len_video), self.num_frames)
|
630 |
+
)[:: self.step]
|
631 |
+
possible_indexes = list(
|
632 |
+
map(lambda x: (x, vid_file, audio_file), possible_indexes)
|
633 |
+
)
|
634 |
+
indexes.extend(possible_indexes)
|
635 |
+
print("Indexes", len(indexes), "\n")
|
636 |
+
return indexes
|
637 |
+
|
638 |
+
def scale_and_crop(self, video):
|
639 |
+
h, w = video.shape[-2], video.shape[-1]
|
640 |
+
# scale shorter side to resolution
|
641 |
+
|
642 |
+
if self.resize_size is not None:
|
643 |
+
scale = self.resize_size / min(h, w)
|
644 |
+
if h < w:
|
645 |
+
target_size = (self.resize_size, math.ceil(w * scale))
|
646 |
+
else:
|
647 |
+
target_size = (math.ceil(h * scale), self.resize_size)
|
648 |
+
video = F.interpolate(
|
649 |
+
video,
|
650 |
+
size=target_size,
|
651 |
+
mode="bilinear",
|
652 |
+
align_corners=False,
|
653 |
+
antialias=True,
|
654 |
+
)
|
655 |
+
|
656 |
+
# center crop
|
657 |
+
h, w = video.shape[-2], video.shape[-1]
|
658 |
+
w_start = (w - self.resize_size) // 2
|
659 |
+
h_start = (h - self.resize_size) // 2
|
660 |
+
video = video[
|
661 |
+
:,
|
662 |
+
:,
|
663 |
+
h_start : h_start + self.resize_size,
|
664 |
+
w_start : w_start + self.resize_size,
|
665 |
+
]
|
666 |
+
return self.maybe_augment(video)
|
667 |
+
|
668 |
+
def _calculate_weights(self):
|
669 |
+
aa_processed_count = sum(
|
670 |
+
1
|
671 |
+
for item in self._indexes
|
672 |
+
if "AA_processed" in (item[1] if len(item) == 3 else item[0])
|
673 |
+
)
|
674 |
+
nsv_processed_count = sum(
|
675 |
+
1
|
676 |
+
for item in self._indexes
|
677 |
+
if "1000actors_nsv" in (item[1] if len(item) == 3 else item[0])
|
678 |
+
)
|
679 |
+
other_count = len(self._indexes) - aa_processed_count - nsv_processed_count
|
680 |
+
|
681 |
+
aa_processed_weight = 1 / aa_processed_count if aa_processed_count > 0 else 0
|
682 |
+
nsv_processed_weight = 1 / nsv_processed_count if nsv_processed_count > 0 else 0
|
683 |
+
other_weight = 1 / other_count if other_count > 0 else 0
|
684 |
+
|
685 |
+
print(
|
686 |
+
f"AA processed count: {aa_processed_count}, NSV processed count: {nsv_processed_count}, other count: {other_count}"
|
687 |
+
)
|
688 |
+
print(f"AA processed weight: {aa_processed_weight}")
|
689 |
+
print(f"NSV processed weight: {nsv_processed_weight}")
|
690 |
+
print(f"Other weight: {other_weight}")
|
691 |
+
|
692 |
+
weights = [
|
693 |
+
aa_processed_weight
|
694 |
+
if "AA_processed" in (item[1] if len(item) == 3 else item[0])
|
695 |
+
else nsv_processed_weight
|
696 |
+
if "1000actors_nsv" in (item[1] if len(item) == 3 else item[0])
|
697 |
+
else other_weight
|
698 |
+
for item in self._indexes
|
699 |
+
]
|
700 |
+
return weights
|
701 |
+
|
702 |
+
def __getitem__(self, idx):
|
703 |
+
if self.balance_datasets:
|
704 |
+
idx = self.sampler.__iter__().__next__()
|
705 |
+
|
706 |
+
try:
|
707 |
+
(
|
708 |
+
original_frames,
|
709 |
+
clean_cond,
|
710 |
+
noisy_cond,
|
711 |
+
target,
|
712 |
+
audio,
|
713 |
+
raw_audio,
|
714 |
+
cond_noise,
|
715 |
+
emotions,
|
716 |
+
masks,
|
717 |
+
landmarks,
|
718 |
+
) = self._get_frames_and_audio(idx % len(self._indexes))
|
719 |
+
except Exception as e:
|
720 |
+
print(f"Error with index {idx}: {e}")
|
721 |
+
return self.__getitem__(np.random.randint(0, len(self)))
|
722 |
+
out_data = {}
|
723 |
+
|
724 |
+
if original_frames is not None:
|
725 |
+
out_data["original_frames"] = original_frames
|
726 |
+
|
727 |
+
if audio is not None:
|
728 |
+
out_data["audio_emb"] = audio
|
729 |
+
out_data["raw_audio"] = raw_audio
|
730 |
+
|
731 |
+
if self.use_emotions:
|
732 |
+
out_data["valence"] = emotions[0]
|
733 |
+
out_data["arousal"] = emotions[1]
|
734 |
+
out_data["emo_labels"] = emotions[2]
|
735 |
+
if self.use_latent:
|
736 |
+
input_key = "latents"
|
737 |
+
else:
|
738 |
+
input_key = "frames"
|
739 |
+
out_data[input_key] = target
|
740 |
+
if noisy_cond is not None:
|
741 |
+
out_data["cond_frames"] = noisy_cond
|
742 |
+
out_data["cond_frames_without_noise"] = clean_cond
|
743 |
+
if cond_noise is not None:
|
744 |
+
out_data["cond_aug"] = cond_noise
|
745 |
+
|
746 |
+
if masks is not None:
|
747 |
+
out_data["masks"] = masks
|
748 |
+
out_data["gt"] = target
|
749 |
+
if landmarks is not None:
|
750 |
+
out_data["landmarks"] = landmarks
|
751 |
+
|
752 |
+
out_data["motion_bucket_id"] = torch.tensor([self.motion_id])
|
753 |
+
out_data["fps_id"] = torch.tensor([self.fps - 1])
|
754 |
+
out_data["num_video_frames"] = self.num_frames
|
755 |
+
out_data["image_only_indicator"] = torch.zeros(self.num_frames)
|
756 |
+
return out_data
|
757 |
+
|
758 |
+
|
759 |
+
if __name__ == "__main__":
|
760 |
+
import torchvision.transforms as transforms
|
761 |
+
import cv2
|
762 |
+
|
763 |
+
transform = transforms.Compose(transforms=[transforms.Resize((256, 256))])
|
764 |
+
dataset = VideoDataset(
|
765 |
+
"/vol/paramonos2/projects/antoni/datasets/mahnob/filelist_videos_val.txt",
|
766 |
+
transform=transform,
|
767 |
+
num_frames=25,
|
768 |
+
)
|
769 |
+
print(len(dataset))
|
770 |
+
idx = np.random.randint(0, len(dataset))
|
771 |
+
|
772 |
+
for i in range(10):
|
773 |
+
print(dataset[i][0].shape, dataset[i][1].shape)
|
774 |
+
|
775 |
+
image_identity = (dataset[idx][0].permute(1, 2, 0).numpy() + 1) / 2 * 255
|
776 |
+
image_other = (dataset[idx][1][:, -1].permute(1, 2, 0).numpy() + 1) / 2 * 255
|
777 |
+
cv2.imwrite("image_identity.png", image_identity[:, :, ::-1])
|
778 |
+
for i in range(25):
|
779 |
+
image = (dataset[idx][1][:, i].permute(1, 2, 0).numpy() + 1) / 2 * 255
|
780 |
+
cv2.imwrite(f"tmp_vid_dataset/image_{i}.png", image[:, :, ::-1])
|
sgm/inference/api.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pathlib
|
2 |
+
from dataclasses import asdict, dataclass
|
3 |
+
from enum import Enum
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
|
8 |
+
from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
|
9 |
+
do_sample)
|
10 |
+
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
|
11 |
+
DPMPP2SAncestralSampler,
|
12 |
+
EulerAncestralSampler,
|
13 |
+
EulerEDMSampler,
|
14 |
+
HeunEDMSampler,
|
15 |
+
LinearMultistepSampler)
|
16 |
+
from sgm.util import load_model_from_config
|
17 |
+
|
18 |
+
|
19 |
+
class ModelArchitecture(str, Enum):
|
20 |
+
SD_2_1 = "stable-diffusion-v2-1"
|
21 |
+
SD_2_1_768 = "stable-diffusion-v2-1-768"
|
22 |
+
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
|
23 |
+
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
|
24 |
+
SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
|
25 |
+
SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
|
26 |
+
|
27 |
+
|
28 |
+
class Sampler(str, Enum):
|
29 |
+
EULER_EDM = "EulerEDMSampler"
|
30 |
+
HEUN_EDM = "HeunEDMSampler"
|
31 |
+
EULER_ANCESTRAL = "EulerAncestralSampler"
|
32 |
+
DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
|
33 |
+
DPMPP2M = "DPMPP2MSampler"
|
34 |
+
LINEAR_MULTISTEP = "LinearMultistepSampler"
|
35 |
+
|
36 |
+
|
37 |
+
class Discretization(str, Enum):
|
38 |
+
LEGACY_DDPM = "LegacyDDPMDiscretization"
|
39 |
+
EDM = "EDMDiscretization"
|
40 |
+
|
41 |
+
|
42 |
+
class Guider(str, Enum):
|
43 |
+
VANILLA = "VanillaCFG"
|
44 |
+
IDENTITY = "IdentityGuider"
|
45 |
+
|
46 |
+
|
47 |
+
class Thresholder(str, Enum):
|
48 |
+
NONE = "None"
|
49 |
+
|
50 |
+
|
51 |
+
@dataclass
|
52 |
+
class SamplingParams:
|
53 |
+
width: int = 1024
|
54 |
+
height: int = 1024
|
55 |
+
steps: int = 50
|
56 |
+
sampler: Sampler = Sampler.DPMPP2M
|
57 |
+
discretization: Discretization = Discretization.LEGACY_DDPM
|
58 |
+
guider: Guider = Guider.VANILLA
|
59 |
+
thresholder: Thresholder = Thresholder.NONE
|
60 |
+
scale: float = 6.0
|
61 |
+
aesthetic_score: float = 5.0
|
62 |
+
negative_aesthetic_score: float = 5.0
|
63 |
+
img2img_strength: float = 1.0
|
64 |
+
orig_width: int = 1024
|
65 |
+
orig_height: int = 1024
|
66 |
+
crop_coords_top: int = 0
|
67 |
+
crop_coords_left: int = 0
|
68 |
+
sigma_min: float = 0.0292
|
69 |
+
sigma_max: float = 14.6146
|
70 |
+
rho: float = 3.0
|
71 |
+
s_churn: float = 0.0
|
72 |
+
s_tmin: float = 0.0
|
73 |
+
s_tmax: float = 999.0
|
74 |
+
s_noise: float = 1.0
|
75 |
+
eta: float = 1.0
|
76 |
+
order: int = 4
|
77 |
+
|
78 |
+
|
79 |
+
@dataclass
|
80 |
+
class SamplingSpec:
|
81 |
+
width: int
|
82 |
+
height: int
|
83 |
+
channels: int
|
84 |
+
factor: int
|
85 |
+
is_legacy: bool
|
86 |
+
config: str
|
87 |
+
ckpt: str
|
88 |
+
is_guided: bool
|
89 |
+
|
90 |
+
|
91 |
+
model_specs = {
|
92 |
+
ModelArchitecture.SD_2_1: SamplingSpec(
|
93 |
+
height=512,
|
94 |
+
width=512,
|
95 |
+
channels=4,
|
96 |
+
factor=8,
|
97 |
+
is_legacy=True,
|
98 |
+
config="sd_2_1.yaml",
|
99 |
+
ckpt="v2-1_512-ema-pruned.safetensors",
|
100 |
+
is_guided=True,
|
101 |
+
),
|
102 |
+
ModelArchitecture.SD_2_1_768: SamplingSpec(
|
103 |
+
height=768,
|
104 |
+
width=768,
|
105 |
+
channels=4,
|
106 |
+
factor=8,
|
107 |
+
is_legacy=True,
|
108 |
+
config="sd_2_1_768.yaml",
|
109 |
+
ckpt="v2-1_768-ema-pruned.safetensors",
|
110 |
+
is_guided=True,
|
111 |
+
),
|
112 |
+
ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
|
113 |
+
height=1024,
|
114 |
+
width=1024,
|
115 |
+
channels=4,
|
116 |
+
factor=8,
|
117 |
+
is_legacy=False,
|
118 |
+
config="sd_xl_base.yaml",
|
119 |
+
ckpt="sd_xl_base_0.9.safetensors",
|
120 |
+
is_guided=True,
|
121 |
+
),
|
122 |
+
ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
|
123 |
+
height=1024,
|
124 |
+
width=1024,
|
125 |
+
channels=4,
|
126 |
+
factor=8,
|
127 |
+
is_legacy=True,
|
128 |
+
config="sd_xl_refiner.yaml",
|
129 |
+
ckpt="sd_xl_refiner_0.9.safetensors",
|
130 |
+
is_guided=True,
|
131 |
+
),
|
132 |
+
ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
|
133 |
+
height=1024,
|
134 |
+
width=1024,
|
135 |
+
channels=4,
|
136 |
+
factor=8,
|
137 |
+
is_legacy=False,
|
138 |
+
config="sd_xl_base.yaml",
|
139 |
+
ckpt="sd_xl_base_1.0.safetensors",
|
140 |
+
is_guided=True,
|
141 |
+
),
|
142 |
+
ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
|
143 |
+
height=1024,
|
144 |
+
width=1024,
|
145 |
+
channels=4,
|
146 |
+
factor=8,
|
147 |
+
is_legacy=True,
|
148 |
+
config="sd_xl_refiner.yaml",
|
149 |
+
ckpt="sd_xl_refiner_1.0.safetensors",
|
150 |
+
is_guided=True,
|
151 |
+
),
|
152 |
+
}
|
153 |
+
|
154 |
+
|
155 |
+
class SamplingPipeline:
|
156 |
+
def __init__(
|
157 |
+
self,
|
158 |
+
model_id: ModelArchitecture,
|
159 |
+
model_path="checkpoints",
|
160 |
+
config_path="configs/inference",
|
161 |
+
device="cuda",
|
162 |
+
use_fp16=True,
|
163 |
+
) -> None:
|
164 |
+
if model_id not in model_specs:
|
165 |
+
raise ValueError(f"Model {model_id} not supported")
|
166 |
+
self.model_id = model_id
|
167 |
+
self.specs = model_specs[self.model_id]
|
168 |
+
self.config = str(pathlib.Path(config_path, self.specs.config))
|
169 |
+
self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
|
170 |
+
self.device = device
|
171 |
+
self.model = self._load_model(device=device, use_fp16=use_fp16)
|
172 |
+
|
173 |
+
def _load_model(self, device="cuda", use_fp16=True):
|
174 |
+
config = OmegaConf.load(self.config)
|
175 |
+
model = load_model_from_config(config, self.ckpt)
|
176 |
+
if model is None:
|
177 |
+
raise ValueError(f"Model {self.model_id} could not be loaded")
|
178 |
+
model.to(device)
|
179 |
+
if use_fp16:
|
180 |
+
model.conditioner.half()
|
181 |
+
model.model.half()
|
182 |
+
return model
|
183 |
+
|
184 |
+
def text_to_image(
|
185 |
+
self,
|
186 |
+
params: SamplingParams,
|
187 |
+
prompt: str,
|
188 |
+
negative_prompt: str = "",
|
189 |
+
samples: int = 1,
|
190 |
+
return_latents: bool = False,
|
191 |
+
):
|
192 |
+
sampler = get_sampler_config(params)
|
193 |
+
value_dict = asdict(params)
|
194 |
+
value_dict["prompt"] = prompt
|
195 |
+
value_dict["negative_prompt"] = negative_prompt
|
196 |
+
value_dict["target_width"] = params.width
|
197 |
+
value_dict["target_height"] = params.height
|
198 |
+
return do_sample(
|
199 |
+
self.model,
|
200 |
+
sampler,
|
201 |
+
value_dict,
|
202 |
+
samples,
|
203 |
+
params.height,
|
204 |
+
params.width,
|
205 |
+
self.specs.channels,
|
206 |
+
self.specs.factor,
|
207 |
+
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
208 |
+
return_latents=return_latents,
|
209 |
+
filter=None,
|
210 |
+
)
|
211 |
+
|
212 |
+
def image_to_image(
|
213 |
+
self,
|
214 |
+
params: SamplingParams,
|
215 |
+
image,
|
216 |
+
prompt: str,
|
217 |
+
negative_prompt: str = "",
|
218 |
+
samples: int = 1,
|
219 |
+
return_latents: bool = False,
|
220 |
+
):
|
221 |
+
sampler = get_sampler_config(params)
|
222 |
+
|
223 |
+
if params.img2img_strength < 1.0:
|
224 |
+
sampler.discretization = Img2ImgDiscretizationWrapper(
|
225 |
+
sampler.discretization,
|
226 |
+
strength=params.img2img_strength,
|
227 |
+
)
|
228 |
+
height, width = image.shape[2], image.shape[3]
|
229 |
+
value_dict = asdict(params)
|
230 |
+
value_dict["prompt"] = prompt
|
231 |
+
value_dict["negative_prompt"] = negative_prompt
|
232 |
+
value_dict["target_width"] = width
|
233 |
+
value_dict["target_height"] = height
|
234 |
+
return do_img2img(
|
235 |
+
image,
|
236 |
+
self.model,
|
237 |
+
sampler,
|
238 |
+
value_dict,
|
239 |
+
samples,
|
240 |
+
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
241 |
+
return_latents=return_latents,
|
242 |
+
filter=None,
|
243 |
+
)
|
244 |
+
|
245 |
+
def refiner(
|
246 |
+
self,
|
247 |
+
params: SamplingParams,
|
248 |
+
image,
|
249 |
+
prompt: str,
|
250 |
+
negative_prompt: Optional[str] = None,
|
251 |
+
samples: int = 1,
|
252 |
+
return_latents: bool = False,
|
253 |
+
):
|
254 |
+
sampler = get_sampler_config(params)
|
255 |
+
value_dict = {
|
256 |
+
"orig_width": image.shape[3] * 8,
|
257 |
+
"orig_height": image.shape[2] * 8,
|
258 |
+
"target_width": image.shape[3] * 8,
|
259 |
+
"target_height": image.shape[2] * 8,
|
260 |
+
"prompt": prompt,
|
261 |
+
"negative_prompt": negative_prompt,
|
262 |
+
"crop_coords_top": 0,
|
263 |
+
"crop_coords_left": 0,
|
264 |
+
"aesthetic_score": 6.0,
|
265 |
+
"negative_aesthetic_score": 2.5,
|
266 |
+
}
|
267 |
+
|
268 |
+
return do_img2img(
|
269 |
+
image,
|
270 |
+
self.model,
|
271 |
+
sampler,
|
272 |
+
value_dict,
|
273 |
+
samples,
|
274 |
+
skip_encode=True,
|
275 |
+
return_latents=return_latents,
|
276 |
+
filter=None,
|
277 |
+
)
|
278 |
+
|
279 |
+
|
280 |
+
def get_guider_config(params: SamplingParams):
|
281 |
+
if params.guider == Guider.IDENTITY:
|
282 |
+
guider_config = {
|
283 |
+
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
284 |
+
}
|
285 |
+
elif params.guider == Guider.VANILLA:
|
286 |
+
scale = params.scale
|
287 |
+
|
288 |
+
thresholder = params.thresholder
|
289 |
+
|
290 |
+
if thresholder == Thresholder.NONE:
|
291 |
+
dyn_thresh_config = {
|
292 |
+
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
|
293 |
+
}
|
294 |
+
else:
|
295 |
+
raise NotImplementedError
|
296 |
+
|
297 |
+
guider_config = {
|
298 |
+
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
|
299 |
+
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
|
300 |
+
}
|
301 |
+
else:
|
302 |
+
raise NotImplementedError
|
303 |
+
return guider_config
|
304 |
+
|
305 |
+
|
306 |
+
def get_discretization_config(params: SamplingParams):
|
307 |
+
if params.discretization == Discretization.LEGACY_DDPM:
|
308 |
+
discretization_config = {
|
309 |
+
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
|
310 |
+
}
|
311 |
+
elif params.discretization == Discretization.EDM:
|
312 |
+
discretization_config = {
|
313 |
+
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
|
314 |
+
"params": {
|
315 |
+
"sigma_min": params.sigma_min,
|
316 |
+
"sigma_max": params.sigma_max,
|
317 |
+
"rho": params.rho,
|
318 |
+
},
|
319 |
+
}
|
320 |
+
else:
|
321 |
+
raise ValueError(f"unknown discretization {params.discretization}")
|
322 |
+
return discretization_config
|
323 |
+
|
324 |
+
|
325 |
+
def get_sampler_config(params: SamplingParams):
|
326 |
+
discretization_config = get_discretization_config(params)
|
327 |
+
guider_config = get_guider_config(params)
|
328 |
+
sampler = None
|
329 |
+
if params.sampler == Sampler.EULER_EDM:
|
330 |
+
return EulerEDMSampler(
|
331 |
+
num_steps=params.steps,
|
332 |
+
discretization_config=discretization_config,
|
333 |
+
guider_config=guider_config,
|
334 |
+
s_churn=params.s_churn,
|
335 |
+
s_tmin=params.s_tmin,
|
336 |
+
s_tmax=params.s_tmax,
|
337 |
+
s_noise=params.s_noise,
|
338 |
+
verbose=True,
|
339 |
+
)
|
340 |
+
if params.sampler == Sampler.HEUN_EDM:
|
341 |
+
return HeunEDMSampler(
|
342 |
+
num_steps=params.steps,
|
343 |
+
discretization_config=discretization_config,
|
344 |
+
guider_config=guider_config,
|
345 |
+
s_churn=params.s_churn,
|
346 |
+
s_tmin=params.s_tmin,
|
347 |
+
s_tmax=params.s_tmax,
|
348 |
+
s_noise=params.s_noise,
|
349 |
+
verbose=True,
|
350 |
+
)
|
351 |
+
if params.sampler == Sampler.EULER_ANCESTRAL:
|
352 |
+
return EulerAncestralSampler(
|
353 |
+
num_steps=params.steps,
|
354 |
+
discretization_config=discretization_config,
|
355 |
+
guider_config=guider_config,
|
356 |
+
eta=params.eta,
|
357 |
+
s_noise=params.s_noise,
|
358 |
+
verbose=True,
|
359 |
+
)
|
360 |
+
if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
|
361 |
+
return DPMPP2SAncestralSampler(
|
362 |
+
num_steps=params.steps,
|
363 |
+
discretization_config=discretization_config,
|
364 |
+
guider_config=guider_config,
|
365 |
+
eta=params.eta,
|
366 |
+
s_noise=params.s_noise,
|
367 |
+
verbose=True,
|
368 |
+
)
|
369 |
+
if params.sampler == Sampler.DPMPP2M:
|
370 |
+
return DPMPP2MSampler(
|
371 |
+
num_steps=params.steps,
|
372 |
+
discretization_config=discretization_config,
|
373 |
+
guider_config=guider_config,
|
374 |
+
verbose=True,
|
375 |
+
)
|
376 |
+
if params.sampler == Sampler.LINEAR_MULTISTEP:
|
377 |
+
return LinearMultistepSampler(
|
378 |
+
num_steps=params.steps,
|
379 |
+
discretization_config=discretization_config,
|
380 |
+
guider_config=guider_config,
|
381 |
+
order=params.order,
|
382 |
+
verbose=True,
|
383 |
+
)
|
384 |
+
|
385 |
+
raise ValueError(f"unknown sampler {params.sampler}!")
|
sgm/inference/helpers.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from typing import List, Optional, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from einops import rearrange
|
8 |
+
from imwatermark import WatermarkEncoder
|
9 |
+
from omegaconf import ListConfig
|
10 |
+
from PIL import Image
|
11 |
+
from torch import autocast
|
12 |
+
|
13 |
+
from sgm.util import append_dims
|
14 |
+
|
15 |
+
|
16 |
+
class WatermarkEmbedder:
|
17 |
+
def __init__(self, watermark):
|
18 |
+
self.watermark = watermark
|
19 |
+
self.num_bits = len(WATERMARK_BITS)
|
20 |
+
self.encoder = WatermarkEncoder()
|
21 |
+
self.encoder.set_watermark("bits", self.watermark)
|
22 |
+
|
23 |
+
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
24 |
+
"""
|
25 |
+
Adds a predefined watermark to the input image
|
26 |
+
|
27 |
+
Args:
|
28 |
+
image: ([N,] B, RGB, H, W) in range [0, 1]
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
same as input but watermarked
|
32 |
+
"""
|
33 |
+
squeeze = len(image.shape) == 4
|
34 |
+
if squeeze:
|
35 |
+
image = image[None, ...]
|
36 |
+
n = image.shape[0]
|
37 |
+
image_np = rearrange(
|
38 |
+
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
|
39 |
+
).numpy()[:, :, :, ::-1]
|
40 |
+
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
41 |
+
# watermarking libary expects input as cv2 BGR format
|
42 |
+
for k in range(image_np.shape[0]):
|
43 |
+
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
44 |
+
image = torch.from_numpy(
|
45 |
+
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
|
46 |
+
).to(image.device)
|
47 |
+
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
48 |
+
if squeeze:
|
49 |
+
image = image[0]
|
50 |
+
return image
|
51 |
+
|
52 |
+
|
53 |
+
# A fixed 48-bit message that was choosen at random
|
54 |
+
# WATERMARK_MESSAGE = 0xB3EC907BB19E
|
55 |
+
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
|
56 |
+
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
57 |
+
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
58 |
+
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
|
59 |
+
|
60 |
+
|
61 |
+
def get_unique_embedder_keys_from_conditioner(conditioner):
|
62 |
+
return list({x.input_key for x in conditioner.embedders})
|
63 |
+
|
64 |
+
|
65 |
+
def perform_save_locally(save_path, samples):
|
66 |
+
os.makedirs(os.path.join(save_path), exist_ok=True)
|
67 |
+
base_count = len(os.listdir(os.path.join(save_path)))
|
68 |
+
samples = embed_watermark(samples)
|
69 |
+
for sample in samples:
|
70 |
+
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
|
71 |
+
Image.fromarray(sample.astype(np.uint8)).save(
|
72 |
+
os.path.join(save_path, f"{base_count:09}.png")
|
73 |
+
)
|
74 |
+
base_count += 1
|
75 |
+
|
76 |
+
|
77 |
+
class Img2ImgDiscretizationWrapper:
|
78 |
+
"""
|
79 |
+
wraps a discretizer, and prunes the sigmas
|
80 |
+
params:
|
81 |
+
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
|
82 |
+
"""
|
83 |
+
|
84 |
+
def __init__(self, discretization, strength: float = 1.0):
|
85 |
+
self.discretization = discretization
|
86 |
+
self.strength = strength
|
87 |
+
assert 0.0 <= self.strength <= 1.0
|
88 |
+
|
89 |
+
def __call__(self, *args, **kwargs):
|
90 |
+
# sigmas start large first, and decrease then
|
91 |
+
sigmas = self.discretization(*args, **kwargs)
|
92 |
+
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
93 |
+
sigmas = torch.flip(sigmas, (0,))
|
94 |
+
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
|
95 |
+
print("prune index:", max(int(self.strength * len(sigmas)), 1))
|
96 |
+
sigmas = torch.flip(sigmas, (0,))
|
97 |
+
print(f"sigmas after pruning: ", sigmas)
|
98 |
+
return sigmas
|
99 |
+
|
100 |
+
|
101 |
+
def do_sample(
|
102 |
+
model,
|
103 |
+
sampler,
|
104 |
+
value_dict,
|
105 |
+
num_samples,
|
106 |
+
H,
|
107 |
+
W,
|
108 |
+
C,
|
109 |
+
F,
|
110 |
+
force_uc_zero_embeddings: Optional[List] = None,
|
111 |
+
batch2model_input: Optional[List] = None,
|
112 |
+
return_latents=False,
|
113 |
+
filter=None,
|
114 |
+
device="cuda",
|
115 |
+
):
|
116 |
+
if force_uc_zero_embeddings is None:
|
117 |
+
force_uc_zero_embeddings = []
|
118 |
+
if batch2model_input is None:
|
119 |
+
batch2model_input = []
|
120 |
+
|
121 |
+
with torch.no_grad():
|
122 |
+
with autocast(device) as precision_scope:
|
123 |
+
with model.ema_scope():
|
124 |
+
num_samples = [num_samples]
|
125 |
+
batch, batch_uc = get_batch(
|
126 |
+
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
127 |
+
value_dict,
|
128 |
+
num_samples,
|
129 |
+
)
|
130 |
+
for key in batch:
|
131 |
+
if isinstance(batch[key], torch.Tensor):
|
132 |
+
print(key, batch[key].shape)
|
133 |
+
elif isinstance(batch[key], list):
|
134 |
+
print(key, [len(l) for l in batch[key]])
|
135 |
+
else:
|
136 |
+
print(key, batch[key])
|
137 |
+
c, uc = model.conditioner.get_unconditional_conditioning(
|
138 |
+
batch,
|
139 |
+
batch_uc=batch_uc,
|
140 |
+
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
141 |
+
)
|
142 |
+
|
143 |
+
for k in c:
|
144 |
+
if not k == "crossattn":
|
145 |
+
c[k], uc[k] = map(
|
146 |
+
lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
|
147 |
+
)
|
148 |
+
|
149 |
+
additional_model_inputs = {}
|
150 |
+
for k in batch2model_input:
|
151 |
+
additional_model_inputs[k] = batch[k]
|
152 |
+
|
153 |
+
shape = (math.prod(num_samples), C, H // F, W // F)
|
154 |
+
randn = torch.randn(shape).to(device)
|
155 |
+
|
156 |
+
def denoiser(input, sigma, c):
|
157 |
+
return model.denoiser(
|
158 |
+
model.model, input, sigma, c, **additional_model_inputs
|
159 |
+
)
|
160 |
+
|
161 |
+
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
|
162 |
+
samples_x = model.decode_first_stage(samples_z)
|
163 |
+
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
164 |
+
|
165 |
+
if filter is not None:
|
166 |
+
samples = filter(samples)
|
167 |
+
|
168 |
+
if return_latents:
|
169 |
+
return samples, samples_z
|
170 |
+
return samples
|
171 |
+
|
172 |
+
|
173 |
+
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
174 |
+
# Hardcoded demo setups; might undergo some changes in the future
|
175 |
+
|
176 |
+
batch = {}
|
177 |
+
batch_uc = {}
|
178 |
+
|
179 |
+
for key in keys:
|
180 |
+
if key == "txt":
|
181 |
+
batch["txt"] = (
|
182 |
+
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
183 |
+
.reshape(N)
|
184 |
+
.tolist()
|
185 |
+
)
|
186 |
+
batch_uc["txt"] = (
|
187 |
+
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
188 |
+
.reshape(N)
|
189 |
+
.tolist()
|
190 |
+
)
|
191 |
+
elif key == "original_size_as_tuple":
|
192 |
+
batch["original_size_as_tuple"] = (
|
193 |
+
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
194 |
+
.to(device)
|
195 |
+
.repeat(*N, 1)
|
196 |
+
)
|
197 |
+
elif key == "crop_coords_top_left":
|
198 |
+
batch["crop_coords_top_left"] = (
|
199 |
+
torch.tensor(
|
200 |
+
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
201 |
+
)
|
202 |
+
.to(device)
|
203 |
+
.repeat(*N, 1)
|
204 |
+
)
|
205 |
+
elif key == "aesthetic_score":
|
206 |
+
batch["aesthetic_score"] = (
|
207 |
+
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
208 |
+
)
|
209 |
+
batch_uc["aesthetic_score"] = (
|
210 |
+
torch.tensor([value_dict["negative_aesthetic_score"]])
|
211 |
+
.to(device)
|
212 |
+
.repeat(*N, 1)
|
213 |
+
)
|
214 |
+
|
215 |
+
elif key == "target_size_as_tuple":
|
216 |
+
batch["target_size_as_tuple"] = (
|
217 |
+
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
218 |
+
.to(device)
|
219 |
+
.repeat(*N, 1)
|
220 |
+
)
|
221 |
+
else:
|
222 |
+
batch[key] = value_dict[key]
|
223 |
+
|
224 |
+
for key in batch.keys():
|
225 |
+
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
226 |
+
batch_uc[key] = torch.clone(batch[key])
|
227 |
+
return batch, batch_uc
|
228 |
+
|
229 |
+
|
230 |
+
def get_input_image_tensor(image: Image.Image, device="cuda"):
|
231 |
+
w, h = image.size
|
232 |
+
print(f"loaded input image of size ({w}, {h})")
|
233 |
+
width, height = map(
|
234 |
+
lambda x: x - x % 64, (w, h)
|
235 |
+
) # resize to integer multiple of 64
|
236 |
+
image = image.resize((width, height))
|
237 |
+
image_array = np.array(image.convert("RGB"))
|
238 |
+
image_array = image_array[None].transpose(0, 3, 1, 2)
|
239 |
+
image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
|
240 |
+
return image_tensor.to(device)
|
241 |
+
|
242 |
+
|
243 |
+
def do_img2img(
|
244 |
+
img,
|
245 |
+
model,
|
246 |
+
sampler,
|
247 |
+
value_dict,
|
248 |
+
num_samples,
|
249 |
+
force_uc_zero_embeddings=[],
|
250 |
+
additional_kwargs={},
|
251 |
+
offset_noise_level: float = 0.0,
|
252 |
+
return_latents=False,
|
253 |
+
skip_encode=False,
|
254 |
+
filter=None,
|
255 |
+
device="cuda",
|
256 |
+
):
|
257 |
+
with torch.no_grad():
|
258 |
+
with autocast(device) as precision_scope:
|
259 |
+
with model.ema_scope():
|
260 |
+
batch, batch_uc = get_batch(
|
261 |
+
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
262 |
+
value_dict,
|
263 |
+
[num_samples],
|
264 |
+
)
|
265 |
+
c, uc = model.conditioner.get_unconditional_conditioning(
|
266 |
+
batch,
|
267 |
+
batch_uc=batch_uc,
|
268 |
+
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
269 |
+
)
|
270 |
+
|
271 |
+
for k in c:
|
272 |
+
c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
|
273 |
+
|
274 |
+
for k in additional_kwargs:
|
275 |
+
c[k] = uc[k] = additional_kwargs[k]
|
276 |
+
if skip_encode:
|
277 |
+
z = img
|
278 |
+
else:
|
279 |
+
z = model.encode_first_stage(img)
|
280 |
+
noise = torch.randn_like(z)
|
281 |
+
sigmas = sampler.discretization(sampler.num_steps)
|
282 |
+
sigma = sigmas[0].to(z.device)
|
283 |
+
|
284 |
+
if offset_noise_level > 0.0:
|
285 |
+
noise = noise + offset_noise_level * append_dims(
|
286 |
+
torch.randn(z.shape[0], device=z.device), z.ndim
|
287 |
+
)
|
288 |
+
noised_z = z + noise * append_dims(sigma, z.ndim)
|
289 |
+
noised_z = noised_z / torch.sqrt(
|
290 |
+
1.0 + sigmas[0] ** 2.0
|
291 |
+
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
|
292 |
+
|
293 |
+
def denoiser(x, sigma, c):
|
294 |
+
return model.denoiser(model.model, x, sigma, c)
|
295 |
+
|
296 |
+
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
|
297 |
+
samples_x = model.decode_first_stage(samples_z)
|
298 |
+
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
299 |
+
|
300 |
+
if filter is not None:
|
301 |
+
samples = filter(samples)
|
302 |
+
|
303 |
+
if return_latents:
|
304 |
+
return samples, samples_z
|
305 |
+
return samples
|
sgm/lr_scheduler.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class LambdaWarmUpCosineScheduler:
|
5 |
+
"""
|
6 |
+
note: use with a base_lr of 1.0
|
7 |
+
"""
|
8 |
+
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
warm_up_steps,
|
12 |
+
lr_min,
|
13 |
+
lr_max,
|
14 |
+
lr_start,
|
15 |
+
max_decay_steps,
|
16 |
+
verbosity_interval=0,
|
17 |
+
):
|
18 |
+
self.lr_warm_up_steps = warm_up_steps
|
19 |
+
self.lr_start = lr_start
|
20 |
+
self.lr_min = lr_min
|
21 |
+
self.lr_max = lr_max
|
22 |
+
self.lr_max_decay_steps = max_decay_steps
|
23 |
+
self.last_lr = 0.0
|
24 |
+
self.verbosity_interval = verbosity_interval
|
25 |
+
|
26 |
+
def schedule(self, n, **kwargs):
|
27 |
+
if self.verbosity_interval > 0:
|
28 |
+
if n % self.verbosity_interval == 0:
|
29 |
+
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
30 |
+
if n < self.lr_warm_up_steps:
|
31 |
+
lr = (
|
32 |
+
self.lr_max - self.lr_start
|
33 |
+
) / self.lr_warm_up_steps * n + self.lr_start
|
34 |
+
self.last_lr = lr
|
35 |
+
return lr
|
36 |
+
else:
|
37 |
+
t = (n - self.lr_warm_up_steps) / (
|
38 |
+
self.lr_max_decay_steps - self.lr_warm_up_steps
|
39 |
+
)
|
40 |
+
t = min(t, 1.0)
|
41 |
+
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
42 |
+
1 + np.cos(t * np.pi)
|
43 |
+
)
|
44 |
+
self.last_lr = lr
|
45 |
+
return lr
|
46 |
+
|
47 |
+
def __call__(self, n, **kwargs):
|
48 |
+
return self.schedule(n, **kwargs)
|
49 |
+
|
50 |
+
|
51 |
+
class LambdaWarmUpCosineScheduler2:
|
52 |
+
"""
|
53 |
+
supports repeated iterations, configurable via lists
|
54 |
+
note: use with a base_lr of 1.0.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(
|
58 |
+
self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
|
59 |
+
):
|
60 |
+
assert (
|
61 |
+
len(warm_up_steps)
|
62 |
+
== len(f_min)
|
63 |
+
== len(f_max)
|
64 |
+
== len(f_start)
|
65 |
+
== len(cycle_lengths)
|
66 |
+
)
|
67 |
+
self.lr_warm_up_steps = warm_up_steps
|
68 |
+
self.f_start = f_start
|
69 |
+
self.f_min = f_min
|
70 |
+
self.f_max = f_max
|
71 |
+
self.cycle_lengths = cycle_lengths
|
72 |
+
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
73 |
+
self.last_f = 0.0
|
74 |
+
self.verbosity_interval = verbosity_interval
|
75 |
+
|
76 |
+
def find_in_interval(self, n):
|
77 |
+
interval = 0
|
78 |
+
for cl in self.cum_cycles[1:]:
|
79 |
+
if n <= cl:
|
80 |
+
return interval
|
81 |
+
interval += 1
|
82 |
+
|
83 |
+
def schedule(self, n, **kwargs):
|
84 |
+
cycle = self.find_in_interval(n)
|
85 |
+
n = n - self.cum_cycles[cycle]
|
86 |
+
if self.verbosity_interval > 0:
|
87 |
+
if n % self.verbosity_interval == 0:
|
88 |
+
print(
|
89 |
+
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
90 |
+
f"current cycle {cycle}"
|
91 |
+
)
|
92 |
+
if n < self.lr_warm_up_steps[cycle]:
|
93 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
94 |
+
cycle
|
95 |
+
] * n + self.f_start[cycle]
|
96 |
+
self.last_f = f
|
97 |
+
return f
|
98 |
+
else:
|
99 |
+
t = (n - self.lr_warm_up_steps[cycle]) / (
|
100 |
+
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
|
101 |
+
)
|
102 |
+
t = min(t, 1.0)
|
103 |
+
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
104 |
+
1 + np.cos(t * np.pi)
|
105 |
+
)
|
106 |
+
self.last_f = f
|
107 |
+
return f
|
108 |
+
|
109 |
+
def __call__(self, n, **kwargs):
|
110 |
+
return self.schedule(n, **kwargs)
|
111 |
+
|
112 |
+
|
113 |
+
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
114 |
+
def schedule(self, n, **kwargs):
|
115 |
+
cycle = self.find_in_interval(n)
|
116 |
+
n = n - self.cum_cycles[cycle]
|
117 |
+
if self.verbosity_interval > 0:
|
118 |
+
if n % self.verbosity_interval == 0:
|
119 |
+
print(
|
120 |
+
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
121 |
+
f"current cycle {cycle}"
|
122 |
+
)
|
123 |
+
|
124 |
+
if n < self.lr_warm_up_steps[cycle]:
|
125 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
126 |
+
cycle
|
127 |
+
] * n + self.f_start[cycle]
|
128 |
+
self.last_f = f
|
129 |
+
return f
|
130 |
+
else:
|
131 |
+
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
|
132 |
+
self.cycle_lengths[cycle] - n
|
133 |
+
) / (self.cycle_lengths[cycle])
|
134 |
+
self.last_f = f
|
135 |
+
return f
|
sgm/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .autoencoder import AutoencodingEngine
|
2 |
+
from .diffusion import DiffusionEngine
|
sgm/models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (335 Bytes). View file
|
|
sgm/models/__pycache__/autoencoder.cpython-311.pyc
ADDED
Binary file (35.8 kB). View file
|
|
sgm/models/__pycache__/diffusion.cpython-311.pyc
ADDED
Binary file (37.1 kB). View file
|
|
sgm/models/autoencoder.py
ADDED
@@ -0,0 +1,615 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
import re
|
4 |
+
from abc import abstractmethod
|
5 |
+
from contextlib import contextmanager
|
6 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
7 |
+
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from einops import rearrange
|
12 |
+
from packaging import version
|
13 |
+
|
14 |
+
from ..modules.autoencoding.regularizers import AbstractRegularizer
|
15 |
+
from ..modules.ema import LitEma
|
16 |
+
from ..util import (default, get_nested_attribute, get_obj_from_str,
|
17 |
+
instantiate_from_config)
|
18 |
+
|
19 |
+
logpy = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class AbstractAutoencoder(pl.LightningModule):
|
23 |
+
"""
|
24 |
+
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
|
25 |
+
unCLIP models, etc. Hence, it is fairly general, and specific features
|
26 |
+
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
ema_decay: Union[None, float] = None,
|
32 |
+
monitor: Union[None, str] = None,
|
33 |
+
input_key: str = "jpg",
|
34 |
+
):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
self.input_key = input_key
|
38 |
+
self.use_ema = ema_decay is not None
|
39 |
+
if monitor is not None:
|
40 |
+
self.monitor = monitor
|
41 |
+
|
42 |
+
if self.use_ema:
|
43 |
+
self.model_ema = LitEma(self, decay=ema_decay)
|
44 |
+
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
45 |
+
|
46 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
47 |
+
self.automatic_optimization = False
|
48 |
+
|
49 |
+
def apply_ckpt(self, ckpt: Union[None, str, dict]):
|
50 |
+
if ckpt is None:
|
51 |
+
return
|
52 |
+
if isinstance(ckpt, str):
|
53 |
+
ckpt = {
|
54 |
+
"target": "sgm.modules.checkpoint.CheckpointEngine",
|
55 |
+
"params": {"ckpt_path": ckpt},
|
56 |
+
}
|
57 |
+
engine = instantiate_from_config(ckpt)
|
58 |
+
engine(self)
|
59 |
+
|
60 |
+
@abstractmethod
|
61 |
+
def get_input(self, batch) -> Any:
|
62 |
+
raise NotImplementedError()
|
63 |
+
|
64 |
+
def on_train_batch_end(self, *args, **kwargs):
|
65 |
+
# for EMA computation
|
66 |
+
if self.use_ema:
|
67 |
+
self.model_ema(self)
|
68 |
+
|
69 |
+
@contextmanager
|
70 |
+
def ema_scope(self, context=None):
|
71 |
+
if self.use_ema:
|
72 |
+
self.model_ema.store(self.parameters())
|
73 |
+
self.model_ema.copy_to(self)
|
74 |
+
if context is not None:
|
75 |
+
logpy.info(f"{context}: Switched to EMA weights")
|
76 |
+
try:
|
77 |
+
yield None
|
78 |
+
finally:
|
79 |
+
if self.use_ema:
|
80 |
+
self.model_ema.restore(self.parameters())
|
81 |
+
if context is not None:
|
82 |
+
logpy.info(f"{context}: Restored training weights")
|
83 |
+
|
84 |
+
@abstractmethod
|
85 |
+
def encode(self, *args, **kwargs) -> torch.Tensor:
|
86 |
+
raise NotImplementedError("encode()-method of abstract base class called")
|
87 |
+
|
88 |
+
@abstractmethod
|
89 |
+
def decode(self, *args, **kwargs) -> torch.Tensor:
|
90 |
+
raise NotImplementedError("decode()-method of abstract base class called")
|
91 |
+
|
92 |
+
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
93 |
+
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
94 |
+
return get_obj_from_str(cfg["target"])(
|
95 |
+
params, lr=lr, **cfg.get("params", dict())
|
96 |
+
)
|
97 |
+
|
98 |
+
def configure_optimizers(self) -> Any:
|
99 |
+
raise NotImplementedError()
|
100 |
+
|
101 |
+
|
102 |
+
class AutoencodingEngine(AbstractAutoencoder):
|
103 |
+
"""
|
104 |
+
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
|
105 |
+
(we also restore them explicitly as special cases for legacy reasons).
|
106 |
+
Regularizations such as KL or VQ are moved to the regularizer class.
|
107 |
+
"""
|
108 |
+
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
*args,
|
112 |
+
encoder_config: Dict,
|
113 |
+
decoder_config: Dict,
|
114 |
+
loss_config: Dict,
|
115 |
+
regularizer_config: Dict,
|
116 |
+
optimizer_config: Union[Dict, None] = None,
|
117 |
+
lr_g_factor: float = 1.0,
|
118 |
+
trainable_ae_params: Optional[List[List[str]]] = None,
|
119 |
+
ae_optimizer_args: Optional[List[dict]] = None,
|
120 |
+
trainable_disc_params: Optional[List[List[str]]] = None,
|
121 |
+
disc_optimizer_args: Optional[List[dict]] = None,
|
122 |
+
disc_start_iter: int = 0,
|
123 |
+
diff_boost_factor: float = 3.0,
|
124 |
+
ckpt_engine: Union[None, str, dict] = None,
|
125 |
+
ckpt_path: Optional[str] = None,
|
126 |
+
additional_decode_keys: Optional[List[str]] = None,
|
127 |
+
**kwargs,
|
128 |
+
):
|
129 |
+
super().__init__(*args, **kwargs)
|
130 |
+
self.automatic_optimization = False # pytorch lightning
|
131 |
+
|
132 |
+
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
133 |
+
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
134 |
+
self.loss: torch.nn.Module = instantiate_from_config(loss_config)
|
135 |
+
self.regularization: AbstractRegularizer = instantiate_from_config(
|
136 |
+
regularizer_config
|
137 |
+
)
|
138 |
+
self.optimizer_config = default(
|
139 |
+
optimizer_config, {"target": "torch.optim.Adam"}
|
140 |
+
)
|
141 |
+
self.diff_boost_factor = diff_boost_factor
|
142 |
+
self.disc_start_iter = disc_start_iter
|
143 |
+
self.lr_g_factor = lr_g_factor
|
144 |
+
self.trainable_ae_params = trainable_ae_params
|
145 |
+
if self.trainable_ae_params is not None:
|
146 |
+
self.ae_optimizer_args = default(
|
147 |
+
ae_optimizer_args,
|
148 |
+
[{} for _ in range(len(self.trainable_ae_params))],
|
149 |
+
)
|
150 |
+
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
|
151 |
+
else:
|
152 |
+
self.ae_optimizer_args = [{}] # makes type consitent
|
153 |
+
|
154 |
+
self.trainable_disc_params = trainable_disc_params
|
155 |
+
if self.trainable_disc_params is not None:
|
156 |
+
self.disc_optimizer_args = default(
|
157 |
+
disc_optimizer_args,
|
158 |
+
[{} for _ in range(len(self.trainable_disc_params))],
|
159 |
+
)
|
160 |
+
assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
|
161 |
+
else:
|
162 |
+
self.disc_optimizer_args = [{}] # makes type consitent
|
163 |
+
|
164 |
+
if ckpt_path is not None:
|
165 |
+
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
|
166 |
+
logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
|
167 |
+
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
168 |
+
self.additional_decode_keys = set(default(additional_decode_keys, []))
|
169 |
+
|
170 |
+
def get_input(self, batch: Dict) -> torch.Tensor:
|
171 |
+
# assuming unified data format, dataloader returns a dict.
|
172 |
+
# image tensors should be scaled to -1 ... 1 and in channels-first
|
173 |
+
# format (e.g., bchw instead if bhwc)
|
174 |
+
return batch[self.input_key]
|
175 |
+
|
176 |
+
def get_autoencoder_params(self) -> list:
|
177 |
+
params = []
|
178 |
+
if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
|
179 |
+
params += list(self.loss.get_trainable_autoencoder_parameters())
|
180 |
+
if hasattr(self.regularization, "get_trainable_parameters"):
|
181 |
+
params += list(self.regularization.get_trainable_parameters())
|
182 |
+
params = params + list(self.encoder.parameters())
|
183 |
+
params = params + list(self.decoder.parameters())
|
184 |
+
return params
|
185 |
+
|
186 |
+
def get_discriminator_params(self) -> list:
|
187 |
+
if hasattr(self.loss, "get_trainable_parameters"):
|
188 |
+
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
|
189 |
+
else:
|
190 |
+
params = []
|
191 |
+
return params
|
192 |
+
|
193 |
+
def get_last_layer(self):
|
194 |
+
return self.decoder.get_last_layer()
|
195 |
+
|
196 |
+
def encode(
|
197 |
+
self,
|
198 |
+
x: torch.Tensor,
|
199 |
+
return_reg_log: bool = False,
|
200 |
+
unregularized: bool = False,
|
201 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
202 |
+
z = self.encoder(x)
|
203 |
+
if unregularized:
|
204 |
+
return z, dict()
|
205 |
+
z, reg_log = self.regularization(z)
|
206 |
+
if return_reg_log:
|
207 |
+
return z, reg_log
|
208 |
+
return z
|
209 |
+
|
210 |
+
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
|
211 |
+
x = self.decoder(z, **kwargs)
|
212 |
+
return x
|
213 |
+
|
214 |
+
def forward(
|
215 |
+
self, x: torch.Tensor, **additional_decode_kwargs
|
216 |
+
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
217 |
+
z, reg_log = self.encode(x, return_reg_log=True)
|
218 |
+
dec = self.decode(z, **additional_decode_kwargs)
|
219 |
+
return z, dec, reg_log
|
220 |
+
|
221 |
+
def inner_training_step(
|
222 |
+
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
|
223 |
+
) -> torch.Tensor:
|
224 |
+
x = self.get_input(batch)
|
225 |
+
additional_decode_kwargs = {
|
226 |
+
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
|
227 |
+
}
|
228 |
+
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
|
229 |
+
if hasattr(self.loss, "forward_keys"):
|
230 |
+
extra_info = {
|
231 |
+
"z": z,
|
232 |
+
"optimizer_idx": optimizer_idx,
|
233 |
+
"global_step": self.global_step,
|
234 |
+
"last_layer": self.get_last_layer(),
|
235 |
+
"split": "train",
|
236 |
+
"regularization_log": regularization_log,
|
237 |
+
"autoencoder": self,
|
238 |
+
}
|
239 |
+
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
240 |
+
else:
|
241 |
+
extra_info = dict()
|
242 |
+
|
243 |
+
if optimizer_idx == 0:
|
244 |
+
# autoencode
|
245 |
+
out_loss = self.loss(x, xrec, **extra_info)
|
246 |
+
if isinstance(out_loss, tuple):
|
247 |
+
aeloss, log_dict_ae = out_loss
|
248 |
+
else:
|
249 |
+
# simple loss function
|
250 |
+
aeloss = out_loss
|
251 |
+
log_dict_ae = {"train/loss/rec": aeloss.detach()}
|
252 |
+
|
253 |
+
self.log_dict(
|
254 |
+
log_dict_ae,
|
255 |
+
prog_bar=False,
|
256 |
+
logger=True,
|
257 |
+
on_step=True,
|
258 |
+
on_epoch=True,
|
259 |
+
sync_dist=False,
|
260 |
+
)
|
261 |
+
self.log(
|
262 |
+
"loss",
|
263 |
+
aeloss.mean().detach(),
|
264 |
+
prog_bar=True,
|
265 |
+
logger=False,
|
266 |
+
on_epoch=False,
|
267 |
+
on_step=True,
|
268 |
+
)
|
269 |
+
return aeloss
|
270 |
+
elif optimizer_idx == 1:
|
271 |
+
# discriminator
|
272 |
+
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
273 |
+
# -> discriminator always needs to return a tuple
|
274 |
+
self.log_dict(
|
275 |
+
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
|
276 |
+
)
|
277 |
+
return discloss
|
278 |
+
else:
|
279 |
+
raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
|
280 |
+
|
281 |
+
def training_step(self, batch: dict, batch_idx: int):
|
282 |
+
opts = self.optimizers()
|
283 |
+
if not isinstance(opts, list):
|
284 |
+
# Non-adversarial case
|
285 |
+
opts = [opts]
|
286 |
+
optimizer_idx = batch_idx % len(opts)
|
287 |
+
if self.global_step < self.disc_start_iter:
|
288 |
+
optimizer_idx = 0
|
289 |
+
opt = opts[optimizer_idx]
|
290 |
+
opt.zero_grad()
|
291 |
+
with opt.toggle_model():
|
292 |
+
loss = self.inner_training_step(
|
293 |
+
batch, batch_idx, optimizer_idx=optimizer_idx
|
294 |
+
)
|
295 |
+
self.manual_backward(loss)
|
296 |
+
opt.step()
|
297 |
+
|
298 |
+
def validation_step(self, batch: dict, batch_idx: int) -> Dict:
|
299 |
+
log_dict = self._validation_step(batch, batch_idx)
|
300 |
+
with self.ema_scope():
|
301 |
+
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
|
302 |
+
log_dict.update(log_dict_ema)
|
303 |
+
return log_dict
|
304 |
+
|
305 |
+
def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
|
306 |
+
x = self.get_input(batch)
|
307 |
+
|
308 |
+
z, xrec, regularization_log = self(x)
|
309 |
+
if hasattr(self.loss, "forward_keys"):
|
310 |
+
extra_info = {
|
311 |
+
"z": z,
|
312 |
+
"optimizer_idx": 0,
|
313 |
+
"global_step": self.global_step,
|
314 |
+
"last_layer": self.get_last_layer(),
|
315 |
+
"split": "val" + postfix,
|
316 |
+
"regularization_log": regularization_log,
|
317 |
+
"autoencoder": self,
|
318 |
+
}
|
319 |
+
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
320 |
+
else:
|
321 |
+
extra_info = dict()
|
322 |
+
out_loss = self.loss(x, xrec, **extra_info)
|
323 |
+
if isinstance(out_loss, tuple):
|
324 |
+
aeloss, log_dict_ae = out_loss
|
325 |
+
else:
|
326 |
+
# simple loss function
|
327 |
+
aeloss = out_loss
|
328 |
+
log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
|
329 |
+
full_log_dict = log_dict_ae
|
330 |
+
|
331 |
+
if "optimizer_idx" in extra_info:
|
332 |
+
extra_info["optimizer_idx"] = 1
|
333 |
+
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
334 |
+
full_log_dict.update(log_dict_disc)
|
335 |
+
self.log(
|
336 |
+
f"val{postfix}/loss/rec",
|
337 |
+
log_dict_ae[f"val{postfix}/loss/rec"],
|
338 |
+
sync_dist=True,
|
339 |
+
)
|
340 |
+
self.log_dict(full_log_dict, sync_dist=True)
|
341 |
+
return full_log_dict
|
342 |
+
|
343 |
+
def get_param_groups(
|
344 |
+
self, parameter_names: List[List[str]], optimizer_args: List[dict]
|
345 |
+
) -> Tuple[List[Dict[str, Any]], int]:
|
346 |
+
groups = []
|
347 |
+
num_params = 0
|
348 |
+
for names, args in zip(parameter_names, optimizer_args):
|
349 |
+
params = []
|
350 |
+
for pattern_ in names:
|
351 |
+
pattern_params = []
|
352 |
+
pattern = re.compile(pattern_)
|
353 |
+
for p_name, param in self.named_parameters():
|
354 |
+
if re.match(pattern, p_name):
|
355 |
+
pattern_params.append(param)
|
356 |
+
num_params += param.numel()
|
357 |
+
if len(pattern_params) == 0:
|
358 |
+
logpy.warn(f"Did not find parameters for pattern {pattern_}")
|
359 |
+
params.extend(pattern_params)
|
360 |
+
groups.append({"params": params, **args})
|
361 |
+
return groups, num_params
|
362 |
+
|
363 |
+
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
|
364 |
+
if self.trainable_ae_params is None:
|
365 |
+
ae_params = self.get_autoencoder_params()
|
366 |
+
else:
|
367 |
+
ae_params, num_ae_params = self.get_param_groups(
|
368 |
+
self.trainable_ae_params, self.ae_optimizer_args
|
369 |
+
)
|
370 |
+
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
|
371 |
+
if self.trainable_disc_params is None:
|
372 |
+
disc_params = self.get_discriminator_params()
|
373 |
+
else:
|
374 |
+
disc_params, num_disc_params = self.get_param_groups(
|
375 |
+
self.trainable_disc_params, self.disc_optimizer_args
|
376 |
+
)
|
377 |
+
logpy.info(
|
378 |
+
f"Number of trainable discriminator parameters: {num_disc_params:,}"
|
379 |
+
)
|
380 |
+
opt_ae = self.instantiate_optimizer_from_config(
|
381 |
+
ae_params,
|
382 |
+
default(self.lr_g_factor, 1.0) * self.learning_rate,
|
383 |
+
self.optimizer_config,
|
384 |
+
)
|
385 |
+
opts = [opt_ae]
|
386 |
+
if len(disc_params) > 0:
|
387 |
+
opt_disc = self.instantiate_optimizer_from_config(
|
388 |
+
disc_params, self.learning_rate, self.optimizer_config
|
389 |
+
)
|
390 |
+
opts.append(opt_disc)
|
391 |
+
|
392 |
+
return opts
|
393 |
+
|
394 |
+
@torch.no_grad()
|
395 |
+
def log_images(
|
396 |
+
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
|
397 |
+
) -> dict:
|
398 |
+
log = dict()
|
399 |
+
additional_decode_kwargs = {}
|
400 |
+
x = self.get_input(batch)
|
401 |
+
additional_decode_kwargs.update(
|
402 |
+
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
403 |
+
)
|
404 |
+
|
405 |
+
_, xrec, _ = self(x, **additional_decode_kwargs)
|
406 |
+
log["inputs"] = x
|
407 |
+
log["reconstructions"] = xrec
|
408 |
+
diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
|
409 |
+
diff.clamp_(0, 1.0)
|
410 |
+
log["diff"] = 2.0 * diff - 1.0
|
411 |
+
# diff_boost shows location of small errors, by boosting their
|
412 |
+
# brightness.
|
413 |
+
log["diff_boost"] = (
|
414 |
+
2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
|
415 |
+
)
|
416 |
+
if hasattr(self.loss, "log_images"):
|
417 |
+
log.update(self.loss.log_images(x, xrec))
|
418 |
+
with self.ema_scope():
|
419 |
+
_, xrec_ema, _ = self(x, **additional_decode_kwargs)
|
420 |
+
log["reconstructions_ema"] = xrec_ema
|
421 |
+
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
|
422 |
+
diff_ema.clamp_(0, 1.0)
|
423 |
+
log["diff_ema"] = 2.0 * diff_ema - 1.0
|
424 |
+
log["diff_boost_ema"] = (
|
425 |
+
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
426 |
+
)
|
427 |
+
if additional_log_kwargs:
|
428 |
+
additional_decode_kwargs.update(additional_log_kwargs)
|
429 |
+
_, xrec_add, _ = self(x, **additional_decode_kwargs)
|
430 |
+
log_str = "reconstructions-" + "-".join(
|
431 |
+
[f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
|
432 |
+
)
|
433 |
+
log[log_str] = xrec_add
|
434 |
+
return log
|
435 |
+
|
436 |
+
|
437 |
+
class AutoencodingEngineLegacy(AutoencodingEngine):
|
438 |
+
def __init__(self, embed_dim: int, **kwargs):
|
439 |
+
self.max_batch_size = kwargs.pop("max_batch_size", None)
|
440 |
+
ddconfig = kwargs.pop("ddconfig")
|
441 |
+
ckpt_path = kwargs.pop("ckpt_path", None)
|
442 |
+
ckpt_engine = kwargs.pop("ckpt_engine", None)
|
443 |
+
super().__init__(
|
444 |
+
encoder_config={
|
445 |
+
"target": "sgm.modules.diffusionmodules.model.Encoder",
|
446 |
+
"params": ddconfig,
|
447 |
+
},
|
448 |
+
decoder_config={
|
449 |
+
"target": "sgm.modules.diffusionmodules.model.Decoder",
|
450 |
+
"params": ddconfig,
|
451 |
+
},
|
452 |
+
**kwargs,
|
453 |
+
)
|
454 |
+
self.quant_conv = torch.nn.Conv2d(
|
455 |
+
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
|
456 |
+
(1 + ddconfig["double_z"]) * embed_dim,
|
457 |
+
1,
|
458 |
+
)
|
459 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
460 |
+
self.embed_dim = embed_dim
|
461 |
+
|
462 |
+
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
463 |
+
|
464 |
+
def get_autoencoder_params(self) -> list:
|
465 |
+
params = super().get_autoencoder_params()
|
466 |
+
return params
|
467 |
+
|
468 |
+
def encode(
|
469 |
+
self, x: torch.Tensor, return_reg_log: bool = False
|
470 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
471 |
+
if self.max_batch_size is None:
|
472 |
+
z = self.encoder(x)
|
473 |
+
z = self.quant_conv(z)
|
474 |
+
else:
|
475 |
+
N = x.shape[0]
|
476 |
+
bs = self.max_batch_size
|
477 |
+
n_batches = int(math.ceil(N / bs))
|
478 |
+
z = list()
|
479 |
+
for i_batch in range(n_batches):
|
480 |
+
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
|
481 |
+
z_batch = self.quant_conv(z_batch)
|
482 |
+
z.append(z_batch)
|
483 |
+
z = torch.cat(z, 0)
|
484 |
+
|
485 |
+
z, reg_log = self.regularization(z)
|
486 |
+
if return_reg_log:
|
487 |
+
return z, reg_log
|
488 |
+
return z
|
489 |
+
|
490 |
+
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
|
491 |
+
if self.max_batch_size is None:
|
492 |
+
dec = self.post_quant_conv(z)
|
493 |
+
dec = self.decoder(dec, **decoder_kwargs)
|
494 |
+
else:
|
495 |
+
N = z.shape[0]
|
496 |
+
bs = self.max_batch_size
|
497 |
+
n_batches = int(math.ceil(N / bs))
|
498 |
+
dec = list()
|
499 |
+
for i_batch in range(n_batches):
|
500 |
+
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
|
501 |
+
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
|
502 |
+
dec.append(dec_batch)
|
503 |
+
dec = torch.cat(dec, 0)
|
504 |
+
|
505 |
+
return dec
|
506 |
+
|
507 |
+
|
508 |
+
class AutoencoderKL(AutoencodingEngineLegacy):
|
509 |
+
def __init__(self, **kwargs):
|
510 |
+
if "lossconfig" in kwargs:
|
511 |
+
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
512 |
+
super().__init__(
|
513 |
+
regularizer_config={
|
514 |
+
"target": (
|
515 |
+
"sgm.modules.autoencoding.regularizers"
|
516 |
+
".DiagonalGaussianRegularizer"
|
517 |
+
)
|
518 |
+
},
|
519 |
+
**kwargs,
|
520 |
+
)
|
521 |
+
|
522 |
+
|
523 |
+
class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
|
524 |
+
def __init__(
|
525 |
+
self,
|
526 |
+
embed_dim: int,
|
527 |
+
n_embed: int,
|
528 |
+
sane_index_shape: bool = False,
|
529 |
+
**kwargs,
|
530 |
+
):
|
531 |
+
if "lossconfig" in kwargs:
|
532 |
+
logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
|
533 |
+
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
534 |
+
super().__init__(
|
535 |
+
regularizer_config={
|
536 |
+
"target": (
|
537 |
+
"sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
|
538 |
+
),
|
539 |
+
"params": {
|
540 |
+
"n_e": n_embed,
|
541 |
+
"e_dim": embed_dim,
|
542 |
+
"sane_index_shape": sane_index_shape,
|
543 |
+
},
|
544 |
+
},
|
545 |
+
**kwargs,
|
546 |
+
)
|
547 |
+
|
548 |
+
|
549 |
+
class IdentityFirstStage(AbstractAutoencoder):
|
550 |
+
def __init__(self, *args, **kwargs):
|
551 |
+
super().__init__(*args, **kwargs)
|
552 |
+
|
553 |
+
def get_input(self, x: Any) -> Any:
|
554 |
+
return x
|
555 |
+
|
556 |
+
def encode(self, x: Any, *args, **kwargs) -> Any:
|
557 |
+
return x
|
558 |
+
|
559 |
+
def decode(self, x: Any, *args, **kwargs) -> Any:
|
560 |
+
return x
|
561 |
+
|
562 |
+
|
563 |
+
class AEIntegerWrapper(nn.Module):
|
564 |
+
def __init__(
|
565 |
+
self,
|
566 |
+
model: nn.Module,
|
567 |
+
shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
|
568 |
+
regularization_key: str = "regularization",
|
569 |
+
encoder_kwargs: Optional[Dict[str, Any]] = None,
|
570 |
+
):
|
571 |
+
super().__init__()
|
572 |
+
self.model = model
|
573 |
+
assert hasattr(model, "encode") and hasattr(
|
574 |
+
model, "decode"
|
575 |
+
), "Need AE interface"
|
576 |
+
self.regularization = get_nested_attribute(model, regularization_key)
|
577 |
+
self.shape = shape
|
578 |
+
self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
|
579 |
+
|
580 |
+
def encode(self, x) -> torch.Tensor:
|
581 |
+
assert (
|
582 |
+
not self.training
|
583 |
+
), f"{self.__class__.__name__} only supports inference currently"
|
584 |
+
_, log = self.model.encode(x, **self.encoder_kwargs)
|
585 |
+
assert isinstance(log, dict)
|
586 |
+
inds = log["min_encoding_indices"]
|
587 |
+
return rearrange(inds, "b ... -> b (...)")
|
588 |
+
|
589 |
+
def decode(
|
590 |
+
self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
|
591 |
+
) -> torch.Tensor:
|
592 |
+
# expect inds shape (b, s) with s = h*w
|
593 |
+
shape = default(shape, self.shape) # Optional[(h, w)]
|
594 |
+
if shape is not None:
|
595 |
+
assert len(shape) == 2, f"Unhandeled shape {shape}"
|
596 |
+
inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
|
597 |
+
h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
|
598 |
+
h = rearrange(h, "b h w c -> b c h w")
|
599 |
+
return self.model.decode(h)
|
600 |
+
|
601 |
+
|
602 |
+
class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
|
603 |
+
def __init__(self, **kwargs):
|
604 |
+
if "lossconfig" in kwargs:
|
605 |
+
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
606 |
+
super().__init__(
|
607 |
+
regularizer_config={
|
608 |
+
"target": (
|
609 |
+
"sgm.modules.autoencoding.regularizers"
|
610 |
+
".DiagonalGaussianRegularizer"
|
611 |
+
),
|
612 |
+
"params": {"sample": False},
|
613 |
+
},
|
614 |
+
**kwargs,
|
615 |
+
)
|
sgm/models/diffusion.py
ADDED
@@ -0,0 +1,747 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
from contextlib import contextmanager
|
4 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
5 |
+
import re
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
import torch
|
8 |
+
from omegaconf import ListConfig, OmegaConf
|
9 |
+
from safetensors.torch import load_file as load_safetensors
|
10 |
+
from torch.optim.lr_scheduler import LambdaLR
|
11 |
+
from einops import rearrange
|
12 |
+
from diffusers.models.attention_processor import IPAdapterAttnProcessor2_0
|
13 |
+
|
14 |
+
from ..modules import UNCONDITIONAL_CONFIG
|
15 |
+
from ..modules.autoencoding.temporal_ae import VideoDecoder
|
16 |
+
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
|
17 |
+
from ..modules.ema import LitEma
|
18 |
+
from ..util import (
|
19 |
+
default,
|
20 |
+
disabled_train,
|
21 |
+
get_obj_from_str,
|
22 |
+
instantiate_from_config,
|
23 |
+
log_txt_as_img,
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
class DiffusionEngine(pl.LightningModule):
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
network_config,
|
31 |
+
denoiser_config,
|
32 |
+
first_stage_config,
|
33 |
+
conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
34 |
+
sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
35 |
+
optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
36 |
+
scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
37 |
+
loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
38 |
+
network_wrapper: Union[None, str, Dict, ListConfig, OmegaConf] = None,
|
39 |
+
ckpt_path: Union[None, str] = None,
|
40 |
+
remove_keys_from_weights: Union[None, List, Tuple] = None,
|
41 |
+
pattern_to_remove: Union[None, str] = None,
|
42 |
+
remove_keys_from_unet_weights: Union[None, List, Tuple] = None,
|
43 |
+
use_ema: bool = False,
|
44 |
+
ema_decay_rate: float = 0.9999,
|
45 |
+
scale_factor: float = 1.0,
|
46 |
+
disable_first_stage_autocast=False,
|
47 |
+
input_key: str = "jpg",
|
48 |
+
log_keys: Union[List, None] = None,
|
49 |
+
no_log_keys: Union[List, None] = None,
|
50 |
+
no_cond_log: bool = False,
|
51 |
+
compile_model: bool = False,
|
52 |
+
en_and_decode_n_samples_a_time: Optional[int] = None,
|
53 |
+
only_train_ipadapter: Optional[bool] = False,
|
54 |
+
to_unfreeze: Optional[List[str]] = [],
|
55 |
+
to_freeze: Optional[List[str]] = [],
|
56 |
+
separate_unet_ckpt: Optional[str] = None,
|
57 |
+
use_thunder: Optional[bool] = False,
|
58 |
+
is_dubbing: Optional[bool] = False,
|
59 |
+
bad_model_path: Optional[str] = None,
|
60 |
+
bad_model_config: Optional[Dict] = None,
|
61 |
+
):
|
62 |
+
super().__init__()
|
63 |
+
|
64 |
+
# self.automatic_optimization = False
|
65 |
+
self.log_keys = log_keys
|
66 |
+
self.no_log_keys = no_log_keys
|
67 |
+
self.input_key = input_key
|
68 |
+
self.is_dubbing = is_dubbing
|
69 |
+
self.optimizer_config = default(
|
70 |
+
optimizer_config, {"target": "torch.optim.AdamW"}
|
71 |
+
)
|
72 |
+
self.model = self.initialize_network(
|
73 |
+
network_config, network_wrapper, compile_model=compile_model
|
74 |
+
)
|
75 |
+
|
76 |
+
self.denoiser = instantiate_from_config(denoiser_config)
|
77 |
+
|
78 |
+
self.sampler = (
|
79 |
+
instantiate_from_config(sampler_config)
|
80 |
+
if sampler_config is not None
|
81 |
+
else None
|
82 |
+
)
|
83 |
+
self.is_guided = True
|
84 |
+
if (
|
85 |
+
self.sampler
|
86 |
+
and "IdentityGuider" in sampler_config["params"]["guider_config"]["target"]
|
87 |
+
):
|
88 |
+
self.is_guided = False
|
89 |
+
if self.sampler is not None:
|
90 |
+
config_guider = sampler_config["params"]["guider_config"]
|
91 |
+
sampler_config["params"]["guider_config"] = None
|
92 |
+
self.sampler_no_guidance = instantiate_from_config(sampler_config)
|
93 |
+
sampler_config["params"]["guider_config"] = config_guider
|
94 |
+
self.conditioner = instantiate_from_config(
|
95 |
+
default(conditioner_config, UNCONDITIONAL_CONFIG)
|
96 |
+
)
|
97 |
+
self.scheduler_config = scheduler_config
|
98 |
+
self._init_first_stage(first_stage_config)
|
99 |
+
|
100 |
+
self.loss_fn = (
|
101 |
+
instantiate_from_config(loss_fn_config)
|
102 |
+
if loss_fn_config is not None
|
103 |
+
else None
|
104 |
+
)
|
105 |
+
|
106 |
+
self.use_ema = use_ema
|
107 |
+
if self.use_ema:
|
108 |
+
self.model_ema = LitEma(self.model, decay=ema_decay_rate)
|
109 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
110 |
+
|
111 |
+
self.scale_factor = scale_factor
|
112 |
+
self.disable_first_stage_autocast = disable_first_stage_autocast
|
113 |
+
self.no_cond_log = no_cond_log
|
114 |
+
|
115 |
+
if ckpt_path is not None:
|
116 |
+
self.init_from_ckpt(
|
117 |
+
ckpt_path,
|
118 |
+
remove_keys_from_weights=remove_keys_from_weights,
|
119 |
+
pattern_to_remove=pattern_to_remove,
|
120 |
+
)
|
121 |
+
if separate_unet_ckpt is not None:
|
122 |
+
sd = torch.load(separate_unet_ckpt)["state_dict"]
|
123 |
+
if remove_keys_from_unet_weights is not None:
|
124 |
+
for k in list(sd.keys()):
|
125 |
+
for remove_key in remove_keys_from_unet_weights:
|
126 |
+
if remove_key in k:
|
127 |
+
del sd[k]
|
128 |
+
self.model.diffusion_model.load_state_dict(sd, strict=False)
|
129 |
+
|
130 |
+
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
|
131 |
+
print(
|
132 |
+
"Using",
|
133 |
+
self.en_and_decode_n_samples_a_time,
|
134 |
+
"samples at a time for encoding and decoding",
|
135 |
+
)
|
136 |
+
|
137 |
+
if to_freeze:
|
138 |
+
for name, p in self.model.diffusion_model.named_parameters():
|
139 |
+
for layer in to_freeze:
|
140 |
+
if layer[0] == "!":
|
141 |
+
if layer[1:] not in name:
|
142 |
+
# print("Freezing", name)
|
143 |
+
p.requires_grad = False
|
144 |
+
else:
|
145 |
+
if layer in name:
|
146 |
+
# print("Freezing", name)
|
147 |
+
p.requires_grad = False
|
148 |
+
# if "time_" in name:
|
149 |
+
# print("Freezing", name)
|
150 |
+
# p.requires_grad = False
|
151 |
+
|
152 |
+
if only_train_ipadapter:
|
153 |
+
# Freeze the model
|
154 |
+
for p in self.model.parameters():
|
155 |
+
p.requires_grad = False
|
156 |
+
# Unfreeze the adapter projection layer
|
157 |
+
for p in self.model.diffusion_model.encoder_hid_proj.parameters():
|
158 |
+
p.requires_grad = True
|
159 |
+
# Unfreeze the cross-attention layer
|
160 |
+
for att_layer in self.model.diffusion_model.attn_processors.values():
|
161 |
+
if isinstance(att_layer, IPAdapterAttnProcessor2_0):
|
162 |
+
for p in att_layer.parameters():
|
163 |
+
p.requires_grad = True
|
164 |
+
|
165 |
+
# for name, p in self.named_parameters():
|
166 |
+
# if p.requires_grad:
|
167 |
+
# print(name)
|
168 |
+
|
169 |
+
if to_unfreeze:
|
170 |
+
for name in to_unfreeze:
|
171 |
+
for p in getattr(self.model.diffusion_model, name).parameters():
|
172 |
+
p.requires_grad = True
|
173 |
+
|
174 |
+
if use_thunder:
|
175 |
+
import thunder
|
176 |
+
|
177 |
+
self.model.diffusion_model = thunder.jit(self.model.diffusion_model)
|
178 |
+
|
179 |
+
if "Karras" in denoiser_config.target:
|
180 |
+
assert bad_model_path is not None, (
|
181 |
+
"bad_model_path must be provided for KarrasGuidanceDenoiser"
|
182 |
+
)
|
183 |
+
karras_config = default(bad_model_config, network_config)
|
184 |
+
bad_model = self.initialize_network(
|
185 |
+
karras_config, network_wrapper, compile_model=compile_model
|
186 |
+
)
|
187 |
+
state_dict = self.load_bad_model_weights(bad_model_path)
|
188 |
+
bad_model.load_state_dict(state_dict)
|
189 |
+
self.denoiser.set_bad_network(bad_model)
|
190 |
+
|
191 |
+
def load_bad_model_weights(self, path: str) -> None:
|
192 |
+
print(f"Restoring bad model from {path}")
|
193 |
+
state_dict = torch.load(path, map_location="cpu")
|
194 |
+
new_dict = {}
|
195 |
+
for k, v in state_dict["module"].items():
|
196 |
+
if "learned_mask" in k:
|
197 |
+
new_dict[k.replace("_forward_module.", "").replace("model.", "")] = v
|
198 |
+
if "diffusion_model" in k:
|
199 |
+
new_dict["diffusion_model" + k.split("diffusion_model")[1]] = v
|
200 |
+
return new_dict
|
201 |
+
|
202 |
+
def initialize_network(self, network_config, network_wrapper, compile_model=False):
|
203 |
+
model = instantiate_from_config(network_config)
|
204 |
+
if isinstance(network_wrapper, str) or network_wrapper is None:
|
205 |
+
model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
|
206 |
+
model, compile_model=compile_model
|
207 |
+
)
|
208 |
+
else:
|
209 |
+
target = network_wrapper["target"]
|
210 |
+
params = network_wrapper.get("params", dict())
|
211 |
+
model = get_obj_from_str(target)(
|
212 |
+
model, compile_model=compile_model, **params
|
213 |
+
)
|
214 |
+
return model
|
215 |
+
|
216 |
+
def init_from_ckpt(
|
217 |
+
self,
|
218 |
+
path: str,
|
219 |
+
remove_keys_from_weights: Optional[Union[List, Tuple]] = None,
|
220 |
+
pattern_to_remove: str = None,
|
221 |
+
) -> None:
|
222 |
+
print(f"Restoring from {path}")
|
223 |
+
if path.endswith("ckpt"):
|
224 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
225 |
+
elif path.endswith("pt"):
|
226 |
+
sd = torch.load(path, map_location="cpu")["module"]
|
227 |
+
# Remove leading _forward_module from keys
|
228 |
+
sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
|
229 |
+
elif path.endswith("bin"):
|
230 |
+
sd = torch.load(path, map_location="cpu")
|
231 |
+
# Remove leading _forward_module from keys
|
232 |
+
sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
|
233 |
+
elif path.endswith("safetensors"):
|
234 |
+
sd = load_safetensors(path)
|
235 |
+
else:
|
236 |
+
raise NotImplementedError
|
237 |
+
|
238 |
+
print(f"Loaded state dict from {path} with {len(sd)} keys")
|
239 |
+
|
240 |
+
# if remove_keys_from_weights is not None:
|
241 |
+
# for k in list(sd.keys()):
|
242 |
+
# for remove_key in remove_keys_from_weights:
|
243 |
+
# if remove_key in k:
|
244 |
+
# del sd[k]
|
245 |
+
if pattern_to_remove is not None or remove_keys_from_weights is not None:
|
246 |
+
sd = self.remove_mismatched_keys(
|
247 |
+
sd, pattern_to_remove, remove_keys_from_weights
|
248 |
+
)
|
249 |
+
|
250 |
+
missing, unexpected = self.load_state_dict(sd, strict=False)
|
251 |
+
print(
|
252 |
+
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
|
253 |
+
)
|
254 |
+
if len(missing) > 0:
|
255 |
+
print(f"Missing Keys: {missing}")
|
256 |
+
if len(unexpected) > 0:
|
257 |
+
print(f"Unexpected Keys: {unexpected}")
|
258 |
+
|
259 |
+
def remove_mismatched_keys(self, state_dict, pattern=None, additional_keys=None):
|
260 |
+
"""Remove keys from the state dictionary based on a pattern and a list of additional specific keys."""
|
261 |
+
# Find keys that match the pattern
|
262 |
+
if pattern is not None:
|
263 |
+
mismatched_keys = [key for key in state_dict if re.search(pattern, key)]
|
264 |
+
else:
|
265 |
+
mismatched_keys = []
|
266 |
+
|
267 |
+
print(f"Removing {len(mismatched_keys)} keys based on pattern {pattern}")
|
268 |
+
print(mismatched_keys)
|
269 |
+
|
270 |
+
# Add specific keys to be removed
|
271 |
+
if additional_keys:
|
272 |
+
mismatched_keys.extend(
|
273 |
+
[key for key in additional_keys if key in state_dict]
|
274 |
+
)
|
275 |
+
|
276 |
+
# Remove all identified keys
|
277 |
+
for key in mismatched_keys:
|
278 |
+
if key in state_dict:
|
279 |
+
del state_dict[key]
|
280 |
+
return state_dict
|
281 |
+
|
282 |
+
def _init_first_stage(self, config):
|
283 |
+
model = instantiate_from_config(config).eval()
|
284 |
+
model.train = disabled_train
|
285 |
+
for param in model.parameters():
|
286 |
+
param.requires_grad = False
|
287 |
+
self.first_stage_model = model
|
288 |
+
if self.input_key == "latents":
|
289 |
+
# Remove encoder to save memory
|
290 |
+
self.first_stage_model.encoder = None
|
291 |
+
torch.cuda.empty_cache()
|
292 |
+
|
293 |
+
def get_input(self, batch):
|
294 |
+
# assuming unified data format, dataloader returns a dict.
|
295 |
+
# image tensors should be scaled to -1 ... 1 and in bchw format
|
296 |
+
return batch[self.input_key]
|
297 |
+
|
298 |
+
@torch.no_grad()
|
299 |
+
def decode_first_stage(self, z):
|
300 |
+
is_video = False
|
301 |
+
if len(z.shape) == 5:
|
302 |
+
is_video = True
|
303 |
+
T = z.shape[2]
|
304 |
+
z = rearrange(z, "b c t h w -> (b t) c h w")
|
305 |
+
|
306 |
+
z = 1.0 / self.scale_factor * z
|
307 |
+
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
|
308 |
+
|
309 |
+
n_rounds = math.ceil(z.shape[0] / n_samples)
|
310 |
+
all_out = []
|
311 |
+
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
312 |
+
for n in range(n_rounds):
|
313 |
+
if isinstance(self.first_stage_model.decoder, VideoDecoder):
|
314 |
+
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
|
315 |
+
else:
|
316 |
+
kwargs = {}
|
317 |
+
out = self.first_stage_model.decode(
|
318 |
+
z[n * n_samples : (n + 1) * n_samples], **kwargs
|
319 |
+
)
|
320 |
+
all_out.append(out)
|
321 |
+
out = torch.cat(all_out, dim=0)
|
322 |
+
if is_video:
|
323 |
+
out = rearrange(out, "(b t) c h w -> b c t h w", t=T)
|
324 |
+
torch.cuda.empty_cache()
|
325 |
+
return out
|
326 |
+
|
327 |
+
@torch.no_grad()
|
328 |
+
def encode_first_stage(self, x):
|
329 |
+
is_video = False
|
330 |
+
if len(x.shape) == 5:
|
331 |
+
is_video = True
|
332 |
+
T = x.shape[2]
|
333 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
334 |
+
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
|
335 |
+
n_rounds = math.ceil(x.shape[0] / n_samples)
|
336 |
+
all_out = []
|
337 |
+
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
338 |
+
for n in range(n_rounds):
|
339 |
+
out = self.first_stage_model.encode(
|
340 |
+
x[n * n_samples : (n + 1) * n_samples]
|
341 |
+
)
|
342 |
+
all_out.append(out)
|
343 |
+
z = torch.cat(all_out, dim=0)
|
344 |
+
z = self.scale_factor * z
|
345 |
+
if is_video:
|
346 |
+
z = rearrange(z, "(b t) c h w -> b c t h w", t=T)
|
347 |
+
return z
|
348 |
+
|
349 |
+
def forward(self, x, batch):
|
350 |
+
loss_dict = self.loss_fn(
|
351 |
+
self.model,
|
352 |
+
self.denoiser,
|
353 |
+
self.conditioner,
|
354 |
+
x,
|
355 |
+
batch,
|
356 |
+
self.first_stage_model,
|
357 |
+
)
|
358 |
+
# loss_mean = loss.mean()
|
359 |
+
for k in loss_dict:
|
360 |
+
loss_dict[k] = loss_dict[k].mean()
|
361 |
+
# loss_dict = {"loss": loss_mean}
|
362 |
+
return loss_dict["loss"], loss_dict
|
363 |
+
|
364 |
+
def shared_step(self, batch: Dict) -> Any:
|
365 |
+
x = self.get_input(batch)
|
366 |
+
if self.input_key != "latents":
|
367 |
+
x = self.encode_first_stage(x)
|
368 |
+
batch["global_step"] = self.global_step
|
369 |
+
loss, loss_dict = self(x, batch)
|
370 |
+
return loss, loss_dict
|
371 |
+
|
372 |
+
def training_step(self, batch, batch_idx):
|
373 |
+
loss, loss_dict = self.shared_step(batch)
|
374 |
+
# debugging_message = "Training step"
|
375 |
+
# print(f"RANK - {self.trainer.global_rank}: {debugging_message}")
|
376 |
+
|
377 |
+
self.log_dict(
|
378 |
+
loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
379 |
+
)
|
380 |
+
|
381 |
+
self.log(
|
382 |
+
"global_step",
|
383 |
+
self.global_step,
|
384 |
+
prog_bar=True,
|
385 |
+
logger=True,
|
386 |
+
on_step=True,
|
387 |
+
on_epoch=False,
|
388 |
+
)
|
389 |
+
|
390 |
+
# debugging_message = "Training step - log"
|
391 |
+
# print(f"RANK - {self.trainer.global_rank}: {debugging_message}")
|
392 |
+
|
393 |
+
if self.scheduler_config is not None:
|
394 |
+
lr = self.optimizers().param_groups[0]["lr"]
|
395 |
+
self.log(
|
396 |
+
"lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
397 |
+
)
|
398 |
+
|
399 |
+
# # to prevent other processes from moving forward until all processes are in sync
|
400 |
+
# self.trainer.strategy.barrier()
|
401 |
+
|
402 |
+
return loss
|
403 |
+
|
404 |
+
# def validation_step(self, batch, batch_idx):
|
405 |
+
# # loss, loss_dict = self.shared_step(batch)
|
406 |
+
# # self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False)
|
407 |
+
# self.log(
|
408 |
+
# "global_step",
|
409 |
+
# self.global_step,
|
410 |
+
# prog_bar=True,
|
411 |
+
# logger=True,
|
412 |
+
# on_step=True,
|
413 |
+
# on_epoch=False,
|
414 |
+
# )
|
415 |
+
# return 0
|
416 |
+
|
417 |
+
# def on_train_epoch_start(self, *args, **kwargs):
|
418 |
+
# print(f"RANK - {self.trainer.global_rank}: on_train_epoch_start")
|
419 |
+
|
420 |
+
def on_train_start(self, *args, **kwargs):
|
421 |
+
# os.environ["CUDA_VISIBLE_DEVICES"] = str(self.trainer.global_rank)
|
422 |
+
# torch.cuda.set_device(self.trainer.global_rank)
|
423 |
+
# torch.cuda.empty_cache()
|
424 |
+
if self.sampler is None or self.loss_fn is None:
|
425 |
+
raise ValueError("Sampler and loss function need to be set for training.")
|
426 |
+
|
427 |
+
# def on_before_batch_transfer(self, batch, dataloader_idx):
|
428 |
+
# print(f"RANK - {self.trainer.global_rank}: on_before_batch_transfer - {dataloader_idx}")
|
429 |
+
# return batch
|
430 |
+
|
431 |
+
# def on_after_batch_transfer(self, batch, dataloader_idx):
|
432 |
+
# print(f"RANK - {self.trainer.global_rank}: on_after_batch_transfer - {dataloader_idx}")
|
433 |
+
# return batch
|
434 |
+
|
435 |
+
def on_train_batch_end(self, *args, **kwargs):
|
436 |
+
# print(f"RANK - {self.trainer.global_rank}: on_train_batch_end")
|
437 |
+
if self.use_ema:
|
438 |
+
self.model_ema(self.model)
|
439 |
+
|
440 |
+
@contextmanager
|
441 |
+
def ema_scope(self, context=None):
|
442 |
+
if self.use_ema:
|
443 |
+
self.model_ema.store(self.model.parameters())
|
444 |
+
self.model_ema.copy_to(self.model)
|
445 |
+
if context is not None:
|
446 |
+
print(f"{context}: Switched to EMA weights")
|
447 |
+
try:
|
448 |
+
yield None
|
449 |
+
finally:
|
450 |
+
if self.use_ema:
|
451 |
+
self.model_ema.restore(self.model.parameters())
|
452 |
+
if context is not None:
|
453 |
+
print(f"{context}: Restored training weights")
|
454 |
+
|
455 |
+
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
456 |
+
return get_obj_from_str(cfg["target"])(
|
457 |
+
params, lr=lr, **cfg.get("params", dict())
|
458 |
+
)
|
459 |
+
|
460 |
+
def configure_optimizers(self):
|
461 |
+
lr = self.learning_rate
|
462 |
+
params = list(self.model.parameters())
|
463 |
+
for embedder in self.conditioner.embedders:
|
464 |
+
if embedder.is_trainable:
|
465 |
+
params = params + list(embedder.parameters())
|
466 |
+
opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
|
467 |
+
if self.scheduler_config is not None:
|
468 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
469 |
+
print("Setting up LambdaLR scheduler...")
|
470 |
+
scheduler = [
|
471 |
+
{
|
472 |
+
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
|
473 |
+
"interval": "step",
|
474 |
+
"frequency": 1,
|
475 |
+
}
|
476 |
+
]
|
477 |
+
return [opt], scheduler
|
478 |
+
return opt
|
479 |
+
|
480 |
+
@torch.no_grad()
|
481 |
+
def sample(
|
482 |
+
self,
|
483 |
+
cond: Dict,
|
484 |
+
uc: Union[Dict, None] = None,
|
485 |
+
batch_size: int = 16,
|
486 |
+
shape: Union[None, Tuple, List] = None,
|
487 |
+
**kwargs,
|
488 |
+
):
|
489 |
+
randn = torch.randn(batch_size, *shape).to(self.device)
|
490 |
+
|
491 |
+
denoiser = lambda input, sigma, c: self.denoiser(
|
492 |
+
self.model, input, sigma, c, **kwargs
|
493 |
+
)
|
494 |
+
samples = self.sampler(denoiser, randn, cond, uc=uc)
|
495 |
+
|
496 |
+
return samples
|
497 |
+
|
498 |
+
@torch.no_grad()
|
499 |
+
def sample_no_guider(
|
500 |
+
self,
|
501 |
+
cond: Dict,
|
502 |
+
uc: Union[Dict, None] = None,
|
503 |
+
batch_size: int = 16,
|
504 |
+
shape: Union[None, Tuple, List] = None,
|
505 |
+
**kwargs,
|
506 |
+
):
|
507 |
+
randn = torch.randn(batch_size, *shape).to(self.device)
|
508 |
+
|
509 |
+
denoiser = lambda input, sigma, c: self.denoiser(
|
510 |
+
self.model, input, sigma, c, **kwargs
|
511 |
+
)
|
512 |
+
samples = self.sampler_no_guidance(denoiser, randn, cond, uc=uc)
|
513 |
+
|
514 |
+
return samples
|
515 |
+
|
516 |
+
@torch.no_grad()
|
517 |
+
def log_conditionings(self, batch: Dict, n: int) -> Dict:
|
518 |
+
"""
|
519 |
+
Defines heuristics to log different conditionings.
|
520 |
+
These can be lists of strings (text-to-image), tensors, ints, ...
|
521 |
+
"""
|
522 |
+
image_h, image_w = batch[self.input_key].shape[-2:]
|
523 |
+
log = dict()
|
524 |
+
|
525 |
+
for embedder in self.conditioner.embedders:
|
526 |
+
if (
|
527 |
+
(self.log_keys is None) or (embedder.input_key in self.log_keys)
|
528 |
+
) and not self.no_cond_log:
|
529 |
+
if embedder.input_key in self.no_log_keys:
|
530 |
+
continue
|
531 |
+
x = batch[embedder.input_key][:n]
|
532 |
+
if isinstance(x, torch.Tensor):
|
533 |
+
if x.dim() == 1:
|
534 |
+
# class-conditional, convert integer to string
|
535 |
+
x = [str(x[i].item()) for i in range(x.shape[0])]
|
536 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
|
537 |
+
elif x.dim() == 2:
|
538 |
+
# size and crop cond and the like
|
539 |
+
x = [
|
540 |
+
"x".join([str(xx) for xx in x[i].tolist()])
|
541 |
+
for i in range(x.shape[0])
|
542 |
+
]
|
543 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
544 |
+
elif x.dim() == 4: # already an image
|
545 |
+
xc = x
|
546 |
+
elif x.dim() == 5:
|
547 |
+
xc = torch.cat([x[:, :, i] for i in range(x.shape[2])], dim=-1)
|
548 |
+
else:
|
549 |
+
print(x.shape, embedder.input_key)
|
550 |
+
raise NotImplementedError()
|
551 |
+
elif isinstance(x, (List, ListConfig)):
|
552 |
+
if isinstance(x[0], str):
|
553 |
+
# strings
|
554 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
555 |
+
else:
|
556 |
+
raise NotImplementedError()
|
557 |
+
else:
|
558 |
+
raise NotImplementedError()
|
559 |
+
log[embedder.input_key] = xc
|
560 |
+
return log
|
561 |
+
|
562 |
+
@torch.no_grad()
|
563 |
+
def log_images(
|
564 |
+
self,
|
565 |
+
batch: Dict,
|
566 |
+
N: int = 8,
|
567 |
+
sample: bool = True,
|
568 |
+
ucg_keys: List[str] = None,
|
569 |
+
**kwargs,
|
570 |
+
) -> Dict:
|
571 |
+
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
|
572 |
+
if ucg_keys:
|
573 |
+
assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
|
574 |
+
"Each defined ucg key for sampling must be in the provided conditioner input keys,"
|
575 |
+
f"but we have {ucg_keys} vs. {conditioner_input_keys}"
|
576 |
+
)
|
577 |
+
else:
|
578 |
+
ucg_keys = conditioner_input_keys
|
579 |
+
log = dict()
|
580 |
+
|
581 |
+
x = self.get_input(batch)
|
582 |
+
|
583 |
+
c, uc = self.conditioner.get_unconditional_conditioning(
|
584 |
+
batch,
|
585 |
+
force_uc_zero_embeddings=ucg_keys
|
586 |
+
if len(self.conditioner.embedders) > 0
|
587 |
+
else [],
|
588 |
+
)
|
589 |
+
|
590 |
+
sampling_kwargs = {}
|
591 |
+
|
592 |
+
N = min(x.shape[0], N)
|
593 |
+
x = x.to(self.device)[:N]
|
594 |
+
if self.input_key != "latents":
|
595 |
+
log["inputs"] = x
|
596 |
+
z = self.encode_first_stage(x)
|
597 |
+
else:
|
598 |
+
z = x
|
599 |
+
log["reconstructions"] = self.decode_first_stage(z)
|
600 |
+
log.update(self.log_conditionings(batch, N))
|
601 |
+
|
602 |
+
for k in c:
|
603 |
+
if isinstance(c[k], torch.Tensor):
|
604 |
+
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
|
605 |
+
|
606 |
+
if sample:
|
607 |
+
with self.ema_scope("Plotting"):
|
608 |
+
samples = self.sample(
|
609 |
+
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
|
610 |
+
)
|
611 |
+
samples = self.decode_first_stage(samples)
|
612 |
+
|
613 |
+
log["samples"] = samples
|
614 |
+
|
615 |
+
with self.ema_scope("Plotting"):
|
616 |
+
samples = self.sample_no_guider(
|
617 |
+
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
|
618 |
+
)
|
619 |
+
samples = self.decode_first_stage(samples)
|
620 |
+
|
621 |
+
log["samples_no_guidance"] = samples
|
622 |
+
return log
|
623 |
+
|
624 |
+
@torch.no_grad()
|
625 |
+
def log_videos(
|
626 |
+
self,
|
627 |
+
batch: Dict,
|
628 |
+
N: int = 8,
|
629 |
+
sample: bool = True,
|
630 |
+
ucg_keys: List[str] = None,
|
631 |
+
**kwargs,
|
632 |
+
) -> Dict:
|
633 |
+
# conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
|
634 |
+
# if ucg_keys:
|
635 |
+
# assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
|
636 |
+
# "Each defined ucg key for sampling must be in the provided conditioner input keys,"
|
637 |
+
# f"but we have {ucg_keys} vs. {conditioner_input_keys}"
|
638 |
+
# )
|
639 |
+
# else:
|
640 |
+
# ucg_keys = conditioner_input_keys
|
641 |
+
log = dict()
|
642 |
+
batch_uc = {}
|
643 |
+
|
644 |
+
x = self.get_input(batch)
|
645 |
+
num_frames = x.shape[2] # assuming bcthw format
|
646 |
+
|
647 |
+
for key in batch.keys():
|
648 |
+
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
649 |
+
batch_uc[key] = torch.clone(batch[key])
|
650 |
+
|
651 |
+
c, uc = self.conditioner.get_unconditional_conditioning(
|
652 |
+
batch,
|
653 |
+
batch_uc=batch_uc,
|
654 |
+
force_uc_zero_embeddings=ucg_keys
|
655 |
+
if ucg_keys is not None
|
656 |
+
else [
|
657 |
+
"cond_frames",
|
658 |
+
"cond_frames_without_noise",
|
659 |
+
],
|
660 |
+
)
|
661 |
+
|
662 |
+
# for k in ["crossattn", "concat"]:
|
663 |
+
# uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
|
664 |
+
# uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
|
665 |
+
# c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
|
666 |
+
# c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
|
667 |
+
|
668 |
+
sampling_kwargs = {}
|
669 |
+
|
670 |
+
N = min(x.shape[0], N)
|
671 |
+
x = x.to(self.device)[:N]
|
672 |
+
|
673 |
+
if self.input_key != "latents":
|
674 |
+
log["inputs"] = x
|
675 |
+
z = self.encode_first_stage(x)
|
676 |
+
else:
|
677 |
+
z = x
|
678 |
+
log["reconstructions"] = self.decode_first_stage(z)
|
679 |
+
log.update(self.log_conditionings(batch, N))
|
680 |
+
|
681 |
+
if c.get("masks", None) is not None:
|
682 |
+
# Create a mask reconstruction
|
683 |
+
masks = 1 - c["masks"]
|
684 |
+
t = masks.shape[2]
|
685 |
+
masks = rearrange(masks, "b c t h w -> (b t) c h w")
|
686 |
+
target_size = (
|
687 |
+
log["reconstructions"].shape[-2],
|
688 |
+
log["reconstructions"].shape[-1],
|
689 |
+
)
|
690 |
+
masks = torch.nn.functional.interpolate(
|
691 |
+
masks, size=target_size, mode="nearest"
|
692 |
+
)
|
693 |
+
masks = rearrange(masks, "(b t) c h w -> b c t h w", t=t)
|
694 |
+
log["mask_reconstructions"] = log["reconstructions"] * masks
|
695 |
+
|
696 |
+
for k in c:
|
697 |
+
if isinstance(c[k], torch.Tensor):
|
698 |
+
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
|
699 |
+
elif isinstance(c[k], list):
|
700 |
+
for i in range(len(c[k])):
|
701 |
+
c[k][i], uc[k][i] = map(
|
702 |
+
lambda y: y[k][i][:N].to(self.device), (c, uc)
|
703 |
+
)
|
704 |
+
|
705 |
+
if sample:
|
706 |
+
n = 2 if self.is_guided else 1
|
707 |
+
# if num_frames == 1:
|
708 |
+
# sampling_kwargs["image_only_indicator"] = torch.ones(n, num_frames).to(self.device)
|
709 |
+
# else:
|
710 |
+
sampling_kwargs["image_only_indicator"] = torch.zeros(n, num_frames).to(
|
711 |
+
self.device
|
712 |
+
)
|
713 |
+
sampling_kwargs["num_video_frames"] = batch["num_video_frames"]
|
714 |
+
|
715 |
+
with self.ema_scope("Plotting"):
|
716 |
+
samples = self.sample(
|
717 |
+
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
|
718 |
+
)
|
719 |
+
samples = self.decode_first_stage(samples)
|
720 |
+
if self.is_dubbing:
|
721 |
+
samples[:, :, :, : samples.shape[-2] // 2] = log["reconstructions"][
|
722 |
+
:, :, :, : samples.shape[-2] // 2
|
723 |
+
]
|
724 |
+
log["samples"] = samples
|
725 |
+
|
726 |
+
# Without guidance
|
727 |
+
# if num_frames == 1:
|
728 |
+
# sampling_kwargs["image_only_indicator"] = torch.ones(1, num_frames).to(self.device)
|
729 |
+
# else:
|
730 |
+
sampling_kwargs["image_only_indicator"] = torch.zeros(1, num_frames).to(
|
731 |
+
self.device
|
732 |
+
)
|
733 |
+
sampling_kwargs["num_video_frames"] = batch["num_video_frames"]
|
734 |
+
|
735 |
+
with self.ema_scope("Plotting"):
|
736 |
+
samples = self.sample_no_guider(
|
737 |
+
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
|
738 |
+
)
|
739 |
+
samples = self.decode_first_stage(samples)
|
740 |
+
if self.is_dubbing:
|
741 |
+
samples[:, :, :, : samples.shape[-2] // 2] = log["reconstructions"][
|
742 |
+
:, :, :, : samples.shape[-2] // 2
|
743 |
+
]
|
744 |
+
log["samples_no_guidance"] = samples
|
745 |
+
|
746 |
+
torch.cuda.empty_cache()
|
747 |
+
return log
|
sgm/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .encoders.modules import GeneralConditioner
|
2 |
+
|
3 |
+
UNCONDITIONAL_CONFIG = {
|
4 |
+
"target": "sgm.modules.GeneralConditioner",
|
5 |
+
"params": {"emb_models": []},
|
6 |
+
}
|
sgm/modules/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (400 Bytes). View file
|
|
sgm/modules/__pycache__/attention.cpython-311.pyc
ADDED
Binary file (39.1 kB). View file
|
|
sgm/modules/__pycache__/ema.cpython-311.pyc
ADDED
Binary file (5.87 kB). View file
|
|
sgm/modules/__pycache__/video_attention.cpython-311.pyc
ADDED
Binary file (14.2 kB). View file
|
|
sgm/modules/attention.py
ADDED
@@ -0,0 +1,889 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
from inspect import isfunction
|
4 |
+
from typing import Any, Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange, repeat
|
9 |
+
from packaging import version
|
10 |
+
from torch import nn
|
11 |
+
from torch.utils.checkpoint import checkpoint
|
12 |
+
|
13 |
+
logpy = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
16 |
+
SDP_IS_AVAILABLE = True
|
17 |
+
from torch.backends.cuda import SDPBackend, sdp_kernel
|
18 |
+
|
19 |
+
BACKEND_MAP = {
|
20 |
+
SDPBackend.MATH: {
|
21 |
+
"enable_math": True,
|
22 |
+
"enable_flash": False,
|
23 |
+
"enable_mem_efficient": False,
|
24 |
+
},
|
25 |
+
SDPBackend.FLASH_ATTENTION: {
|
26 |
+
"enable_math": False,
|
27 |
+
"enable_flash": True,
|
28 |
+
"enable_mem_efficient": False,
|
29 |
+
},
|
30 |
+
SDPBackend.EFFICIENT_ATTENTION: {
|
31 |
+
"enable_math": False,
|
32 |
+
"enable_flash": False,
|
33 |
+
"enable_mem_efficient": True,
|
34 |
+
},
|
35 |
+
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
|
36 |
+
}
|
37 |
+
else:
|
38 |
+
from contextlib import nullcontext
|
39 |
+
|
40 |
+
SDP_IS_AVAILABLE = False
|
41 |
+
sdp_kernel = nullcontext
|
42 |
+
BACKEND_MAP = {}
|
43 |
+
logpy.warn(
|
44 |
+
f"No SDP backend available, likely because you are running in pytorch "
|
45 |
+
f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
|
46 |
+
f"You might want to consider upgrading."
|
47 |
+
)
|
48 |
+
|
49 |
+
try:
|
50 |
+
import xformers
|
51 |
+
import xformers.ops
|
52 |
+
|
53 |
+
XFORMERS_IS_AVAILABLE = True
|
54 |
+
except:
|
55 |
+
XFORMERS_IS_AVAILABLE = False
|
56 |
+
logpy.warn("no module 'xformers'. Processing without...")
|
57 |
+
|
58 |
+
# from .diffusionmodules.util import mixed_checkpoint as checkpoint
|
59 |
+
|
60 |
+
|
61 |
+
def exists(val):
|
62 |
+
return val is not None
|
63 |
+
|
64 |
+
|
65 |
+
def uniq(arr):
|
66 |
+
return {el: True for el in arr}.keys()
|
67 |
+
|
68 |
+
|
69 |
+
def default(val, d):
|
70 |
+
if exists(val):
|
71 |
+
return val
|
72 |
+
return d() if isfunction(d) else d
|
73 |
+
|
74 |
+
|
75 |
+
def max_neg_value(t):
|
76 |
+
return -torch.finfo(t.dtype).max
|
77 |
+
|
78 |
+
|
79 |
+
def init_(tensor):
|
80 |
+
dim = tensor.shape[-1]
|
81 |
+
std = 1 / math.sqrt(dim)
|
82 |
+
tensor.uniform_(-std, std)
|
83 |
+
return tensor
|
84 |
+
|
85 |
+
|
86 |
+
# feedforward
|
87 |
+
class GEGLU(nn.Module):
|
88 |
+
def __init__(self, dim_in, dim_out):
|
89 |
+
super().__init__()
|
90 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
91 |
+
|
92 |
+
def forward(self, x):
|
93 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
94 |
+
return x * F.gelu(gate)
|
95 |
+
|
96 |
+
|
97 |
+
class FeedForward(nn.Module):
|
98 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
99 |
+
super().__init__()
|
100 |
+
inner_dim = int(dim * mult)
|
101 |
+
dim_out = default(dim_out, dim)
|
102 |
+
project_in = (
|
103 |
+
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
104 |
+
if not glu
|
105 |
+
else GEGLU(dim, inner_dim)
|
106 |
+
)
|
107 |
+
|
108 |
+
self.net = nn.Sequential(
|
109 |
+
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
110 |
+
)
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
return self.net(x)
|
114 |
+
|
115 |
+
|
116 |
+
def zero_module(module):
|
117 |
+
"""
|
118 |
+
Zero out the parameters of a module and return it.
|
119 |
+
"""
|
120 |
+
for p in module.parameters():
|
121 |
+
p.detach().zero_()
|
122 |
+
return module
|
123 |
+
|
124 |
+
|
125 |
+
def Normalize(in_channels):
|
126 |
+
return torch.nn.GroupNorm(
|
127 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
class LinearAttention(nn.Module):
|
132 |
+
def __init__(self, dim, heads=4, dim_head=32):
|
133 |
+
super().__init__()
|
134 |
+
self.heads = heads
|
135 |
+
hidden_dim = dim_head * heads
|
136 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
137 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
138 |
+
|
139 |
+
def forward(self, x):
|
140 |
+
b, c, h, w = x.shape
|
141 |
+
qkv = self.to_qkv(x)
|
142 |
+
q, k, v = rearrange(
|
143 |
+
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
|
144 |
+
)
|
145 |
+
k = k.softmax(dim=-1)
|
146 |
+
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
147 |
+
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
148 |
+
out = rearrange(
|
149 |
+
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
|
150 |
+
)
|
151 |
+
return self.to_out(out)
|
152 |
+
|
153 |
+
|
154 |
+
class SelfAttention(nn.Module):
|
155 |
+
ATTENTION_MODES = ("xformers", "torch", "math")
|
156 |
+
|
157 |
+
def __init__(
|
158 |
+
self,
|
159 |
+
dim: int,
|
160 |
+
num_heads: int = 8,
|
161 |
+
qkv_bias: bool = False,
|
162 |
+
qk_scale: Optional[float] = None,
|
163 |
+
attn_drop: float = 0.0,
|
164 |
+
proj_drop: float = 0.0,
|
165 |
+
attn_mode: str = "xformers",
|
166 |
+
):
|
167 |
+
super().__init__()
|
168 |
+
self.num_heads = num_heads
|
169 |
+
head_dim = dim // num_heads
|
170 |
+
self.scale = qk_scale or head_dim**-0.5
|
171 |
+
|
172 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
173 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
174 |
+
self.proj = nn.Linear(dim, dim)
|
175 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
176 |
+
assert attn_mode in self.ATTENTION_MODES
|
177 |
+
self.attn_mode = attn_mode
|
178 |
+
|
179 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
180 |
+
B, L, C = x.shape
|
181 |
+
|
182 |
+
qkv = self.qkv(x)
|
183 |
+
if self.attn_mode == "torch":
|
184 |
+
qkv = rearrange(
|
185 |
+
qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
186 |
+
).float()
|
187 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
188 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
189 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
190 |
+
elif self.attn_mode == "xformers":
|
191 |
+
qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
192 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
|
193 |
+
x = xformers.ops.memory_efficient_attention(q, k, v)
|
194 |
+
x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
|
195 |
+
elif self.attn_mode == "math":
|
196 |
+
qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
197 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
198 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
199 |
+
attn = attn.softmax(dim=-1)
|
200 |
+
attn = self.attn_drop(attn)
|
201 |
+
x = (attn @ v).transpose(1, 2).reshape(B, L, C)
|
202 |
+
else:
|
203 |
+
raise NotImplemented
|
204 |
+
|
205 |
+
x = self.proj(x)
|
206 |
+
x = self.proj_drop(x)
|
207 |
+
return x
|
208 |
+
|
209 |
+
|
210 |
+
class SpatialSelfAttention(nn.Module):
|
211 |
+
def __init__(self, in_channels):
|
212 |
+
super().__init__()
|
213 |
+
self.in_channels = in_channels
|
214 |
+
|
215 |
+
self.norm = Normalize(in_channels)
|
216 |
+
self.q = torch.nn.Conv2d(
|
217 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
218 |
+
)
|
219 |
+
self.k = torch.nn.Conv2d(
|
220 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
221 |
+
)
|
222 |
+
self.v = torch.nn.Conv2d(
|
223 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
224 |
+
)
|
225 |
+
self.proj_out = torch.nn.Conv2d(
|
226 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
227 |
+
)
|
228 |
+
|
229 |
+
def forward(self, x):
|
230 |
+
h_ = x
|
231 |
+
h_ = self.norm(h_)
|
232 |
+
q = self.q(h_)
|
233 |
+
k = self.k(h_)
|
234 |
+
v = self.v(h_)
|
235 |
+
|
236 |
+
# compute attention
|
237 |
+
b, c, h, w = q.shape
|
238 |
+
q = rearrange(q, "b c h w -> b (h w) c")
|
239 |
+
k = rearrange(k, "b c h w -> b c (h w)")
|
240 |
+
w_ = torch.einsum("bij,bjk->bik", q, k)
|
241 |
+
|
242 |
+
w_ = w_ * (int(c) ** (-0.5))
|
243 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
244 |
+
|
245 |
+
# attend to values
|
246 |
+
v = rearrange(v, "b c h w -> b c (h w)")
|
247 |
+
w_ = rearrange(w_, "b i j -> b j i")
|
248 |
+
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
249 |
+
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
250 |
+
h_ = self.proj_out(h_)
|
251 |
+
|
252 |
+
return x + h_
|
253 |
+
|
254 |
+
|
255 |
+
class CrossAttention(nn.Module):
|
256 |
+
def __init__(
|
257 |
+
self,
|
258 |
+
query_dim,
|
259 |
+
context_dim=None,
|
260 |
+
heads=8,
|
261 |
+
dim_head=64,
|
262 |
+
dropout=0.0,
|
263 |
+
backend=None,
|
264 |
+
**kwargs,
|
265 |
+
):
|
266 |
+
super().__init__()
|
267 |
+
inner_dim = dim_head * heads
|
268 |
+
context_dim = default(context_dim, query_dim)
|
269 |
+
|
270 |
+
self.scale = dim_head**-0.5
|
271 |
+
self.heads = heads
|
272 |
+
|
273 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
274 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
275 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
276 |
+
|
277 |
+
self.to_out = nn.Sequential(
|
278 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
279 |
+
)
|
280 |
+
self.backend = backend
|
281 |
+
|
282 |
+
def forward(
|
283 |
+
self,
|
284 |
+
x,
|
285 |
+
context=None,
|
286 |
+
mask=None,
|
287 |
+
additional_tokens=None,
|
288 |
+
n_times_crossframe_attn_in_self=0,
|
289 |
+
skip_attention=None,
|
290 |
+
**kwargs,
|
291 |
+
):
|
292 |
+
h = self.heads
|
293 |
+
|
294 |
+
if additional_tokens is not None:
|
295 |
+
# get the number of masked tokens at the beginning of the output sequence
|
296 |
+
n_tokens_to_mask = additional_tokens.shape[1]
|
297 |
+
# add additional token
|
298 |
+
x = torch.cat([additional_tokens, x], dim=1)
|
299 |
+
|
300 |
+
# Ensure skip_attention is a BΓ1 boolean tensor
|
301 |
+
if skip_attention is None:
|
302 |
+
skip_attention = torch.zeros_like(x[:, :1], dtype=torch.bool)
|
303 |
+
|
304 |
+
assert isinstance(skip_attention, torch.Tensor)
|
305 |
+
assert skip_attention.shape[1] == 1 and skip_attention.dtype == torch.bool
|
306 |
+
|
307 |
+
# Split the batch into skip and non-skip parts
|
308 |
+
skip_indices = skip_attention.squeeze(1)
|
309 |
+
non_skip_indices = ~skip_indices
|
310 |
+
|
311 |
+
# Process skip attention samples
|
312 |
+
if skip_indices.any():
|
313 |
+
x_skip = x[skip_indices]
|
314 |
+
out_skip = self.to_v(x_skip)
|
315 |
+
out_skip = rearrange(out_skip, "b n (h d) -> b n (h d)", h=h)
|
316 |
+
|
317 |
+
# If all samples are skipped, return early
|
318 |
+
if not non_skip_indices.any():
|
319 |
+
if additional_tokens is not None:
|
320 |
+
out_skip = out_skip[:, n_tokens_to_mask:]
|
321 |
+
return self.to_out(out_skip)
|
322 |
+
|
323 |
+
# Process non-skip samples with attention
|
324 |
+
x_non_skip = x[non_skip_indices]
|
325 |
+
q = self.to_q(x_non_skip)
|
326 |
+
context = default(context, x_non_skip)
|
327 |
+
k = self.to_k(context)
|
328 |
+
v = self.to_v(context)
|
329 |
+
|
330 |
+
if n_times_crossframe_attn_in_self:
|
331 |
+
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
332 |
+
assert x_non_skip.shape[0] % n_times_crossframe_attn_in_self == 0
|
333 |
+
k = repeat(
|
334 |
+
k[::n_times_crossframe_attn_in_self],
|
335 |
+
"b ... -> (b n) ...",
|
336 |
+
n=n_times_crossframe_attn_in_self,
|
337 |
+
)
|
338 |
+
v = repeat(
|
339 |
+
v[::n_times_crossframe_attn_in_self],
|
340 |
+
"b ... -> (b n) ...",
|
341 |
+
n=n_times_crossframe_attn_in_self,
|
342 |
+
)
|
343 |
+
|
344 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
345 |
+
|
346 |
+
with sdp_kernel(**BACKEND_MAP[self.backend]):
|
347 |
+
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
|
348 |
+
|
349 |
+
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
350 |
+
|
351 |
+
# Combine skip and non-skip results
|
352 |
+
combined_out = torch.zeros(
|
353 |
+
(x.shape[0], out.shape[1], out.shape[2]), dtype=out.dtype, device=out.device
|
354 |
+
)
|
355 |
+
combined_out[non_skip_indices] = out
|
356 |
+
if skip_indices.any():
|
357 |
+
combined_out[skip_indices] = out_skip
|
358 |
+
|
359 |
+
if additional_tokens is not None:
|
360 |
+
combined_out = combined_out[:, n_tokens_to_mask:]
|
361 |
+
return self.to_out(combined_out)
|
362 |
+
|
363 |
+
|
364 |
+
class MemoryEfficientCrossAttention(nn.Module):
|
365 |
+
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
366 |
+
def __init__(
|
367 |
+
self,
|
368 |
+
query_dim,
|
369 |
+
context_dim=None,
|
370 |
+
heads=8,
|
371 |
+
dim_head=64,
|
372 |
+
dropout=0.0,
|
373 |
+
use_reference=False,
|
374 |
+
extra_linear=False,
|
375 |
+
**kwargs,
|
376 |
+
):
|
377 |
+
super().__init__()
|
378 |
+
logpy.debug(
|
379 |
+
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
|
380 |
+
f"context_dim is {context_dim} and using {heads} heads with a "
|
381 |
+
f"dimension of {dim_head}."
|
382 |
+
)
|
383 |
+
inner_dim = dim_head * heads
|
384 |
+
self.is_context = context_dim is not None
|
385 |
+
context_dim = default(context_dim, query_dim)
|
386 |
+
|
387 |
+
self.heads = heads
|
388 |
+
self.dim_head = dim_head
|
389 |
+
self.use_reference = use_reference and self.is_context
|
390 |
+
|
391 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
392 |
+
if not self.use_reference:
|
393 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
394 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
395 |
+
else:
|
396 |
+
if extra_linear:
|
397 |
+
self.to_k = nn.Linear(inner_dim, inner_dim, bias=False)
|
398 |
+
self.to_v = nn.Linear(inner_dim, inner_dim, bias=False)
|
399 |
+
self.extra_linear = extra_linear
|
400 |
+
|
401 |
+
self.to_out = nn.Sequential(
|
402 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
403 |
+
)
|
404 |
+
self.attention_op: Optional[Any] = None
|
405 |
+
|
406 |
+
def forward(
|
407 |
+
self,
|
408 |
+
x,
|
409 |
+
context=None,
|
410 |
+
mask=None,
|
411 |
+
additional_tokens=None,
|
412 |
+
n_times_crossframe_attn_in_self=0,
|
413 |
+
skip_attention=None,
|
414 |
+
):
|
415 |
+
if additional_tokens is not None:
|
416 |
+
# get the number of masked tokens at the beginning of the output sequence
|
417 |
+
n_tokens_to_mask = additional_tokens.shape[1]
|
418 |
+
# add additional token
|
419 |
+
x = torch.cat([additional_tokens, x], dim=1)
|
420 |
+
|
421 |
+
# Ensure skip_attention is a BΓ1 boolean tensor
|
422 |
+
if skip_attention is None:
|
423 |
+
skip_attention = torch.zeros(x.shape[0], 1, dtype=torch.bool)
|
424 |
+
# print(x.shape)
|
425 |
+
# print(skip_attention)
|
426 |
+
# print(skip_attention.shape)
|
427 |
+
# print(any(skip_attention))
|
428 |
+
assert isinstance(skip_attention, torch.Tensor)
|
429 |
+
assert skip_attention.shape[1] == 1 and skip_attention.dtype == torch.bool
|
430 |
+
|
431 |
+
# Split the batch into skip and non-skip parts
|
432 |
+
skip_indices = skip_attention.squeeze(1)
|
433 |
+
non_skip_indices = ~skip_indices
|
434 |
+
|
435 |
+
# Process skip attention samples
|
436 |
+
if skip_indices.any():
|
437 |
+
x_skip = x[skip_indices]
|
438 |
+
out_skip = self.to_v(x_skip)
|
439 |
+
out_skip = (
|
440 |
+
out_skip.unsqueeze(0)
|
441 |
+
.reshape(-1, self.heads, out_skip.shape[1], self.dim_head)
|
442 |
+
.permute(0, 2, 1, 3)
|
443 |
+
.reshape(-1, out_skip.shape[1], self.heads * self.dim_head)
|
444 |
+
)
|
445 |
+
# If all samples are skipped, return early
|
446 |
+
if not non_skip_indices.any():
|
447 |
+
if additional_tokens is not None:
|
448 |
+
out_skip = out_skip[:, n_tokens_to_mask:]
|
449 |
+
return self.to_out(out_skip)
|
450 |
+
|
451 |
+
x_non_skip = x[non_skip_indices]
|
452 |
+
q = self.to_q(x_non_skip)
|
453 |
+
if not self.use_reference:
|
454 |
+
context = default(context, x_non_skip)
|
455 |
+
k = self.to_k(context)
|
456 |
+
v = self.to_v(context)
|
457 |
+
else:
|
458 |
+
# Reference has already correct shape
|
459 |
+
assert context is not None
|
460 |
+
if self.extra_linear:
|
461 |
+
k = self.to_k(context)
|
462 |
+
v = self.to_v(context)
|
463 |
+
k, v = context, context
|
464 |
+
|
465 |
+
if n_times_crossframe_attn_in_self:
|
466 |
+
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
467 |
+
assert x_non_skip.shape[0] % n_times_crossframe_attn_in_self == 0
|
468 |
+
# n_cp = x.shape[0]//n_times_crossframe_attn_in_self
|
469 |
+
k = repeat(
|
470 |
+
k[::n_times_crossframe_attn_in_self],
|
471 |
+
"b ... -> (b n) ...",
|
472 |
+
n=n_times_crossframe_attn_in_self,
|
473 |
+
)
|
474 |
+
v = repeat(
|
475 |
+
v[::n_times_crossframe_attn_in_self],
|
476 |
+
"b ... -> (b n) ...",
|
477 |
+
n=n_times_crossframe_attn_in_self,
|
478 |
+
)
|
479 |
+
|
480 |
+
b, _, _ = q.shape
|
481 |
+
q, k, v = map(
|
482 |
+
lambda t: t.unsqueeze(3)
|
483 |
+
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
484 |
+
.permute(0, 2, 1, 3)
|
485 |
+
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
486 |
+
.contiguous(),
|
487 |
+
(q, k, v),
|
488 |
+
)
|
489 |
+
if q.dtype != k.dtype:
|
490 |
+
k = k.to(q.dtype)
|
491 |
+
v = v.to(q.dtype)
|
492 |
+
|
493 |
+
# actually compute the attention, what we cannot get enough of
|
494 |
+
if version.parse(xformers.__version__) >= version.parse("0.0.21"):
|
495 |
+
# NOTE: workaround for
|
496 |
+
# https://github.com/facebookresearch/xformers/issues/845
|
497 |
+
max_bs = 32768
|
498 |
+
N = q.shape[0]
|
499 |
+
n_batches = math.ceil(N / max_bs)
|
500 |
+
out = list()
|
501 |
+
for i_batch in range(n_batches):
|
502 |
+
batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
|
503 |
+
out.append(
|
504 |
+
xformers.ops.memory_efficient_attention(
|
505 |
+
q[batch],
|
506 |
+
k[batch],
|
507 |
+
v[batch],
|
508 |
+
attn_bias=None,
|
509 |
+
op=self.attention_op,
|
510 |
+
)
|
511 |
+
)
|
512 |
+
out = torch.cat(out, 0)
|
513 |
+
else:
|
514 |
+
out = xformers.ops.memory_efficient_attention(
|
515 |
+
q, k, v, attn_bias=None, op=self.attention_op
|
516 |
+
)
|
517 |
+
|
518 |
+
# TODO: Use this directly in the attention operation, as a bias
|
519 |
+
if exists(mask):
|
520 |
+
raise NotImplementedError
|
521 |
+
out = (
|
522 |
+
out.unsqueeze(0)
|
523 |
+
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
524 |
+
.permute(0, 2, 1, 3)
|
525 |
+
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
526 |
+
)
|
527 |
+
# Combine skip and non-skip results
|
528 |
+
combined_out = torch.zeros(
|
529 |
+
(x.shape[0], out.shape[1], out.shape[2]), dtype=out.dtype, device=out.device
|
530 |
+
)
|
531 |
+
combined_out[non_skip_indices] = out
|
532 |
+
if skip_indices.any():
|
533 |
+
combined_out[skip_indices] = out_skip
|
534 |
+
else:
|
535 |
+
combined_out = out
|
536 |
+
|
537 |
+
if additional_tokens is not None:
|
538 |
+
# remove additional token
|
539 |
+
combined_out = combined_out[:, n_tokens_to_mask:]
|
540 |
+
return self.to_out(combined_out)
|
541 |
+
|
542 |
+
|
543 |
+
class BasicTransformerBlock(nn.Module):
|
544 |
+
ATTENTION_MODES = {
|
545 |
+
"softmax": CrossAttention, # vanilla attention
|
546 |
+
"softmax-xformers": MemoryEfficientCrossAttention, # ampere
|
547 |
+
}
|
548 |
+
|
549 |
+
def __init__(
|
550 |
+
self,
|
551 |
+
dim,
|
552 |
+
n_heads,
|
553 |
+
d_head,
|
554 |
+
dropout=0.0,
|
555 |
+
context_dim=None,
|
556 |
+
gated_ff=True,
|
557 |
+
checkpoint=True,
|
558 |
+
disable_self_attn=False,
|
559 |
+
attn_mode="softmax",
|
560 |
+
sdp_backend=None,
|
561 |
+
reference_to=None,
|
562 |
+
):
|
563 |
+
super().__init__()
|
564 |
+
assert attn_mode in self.ATTENTION_MODES
|
565 |
+
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
|
566 |
+
logpy.warn(
|
567 |
+
f"Attention mode '{attn_mode}' is not available. Falling "
|
568 |
+
f"back to native attention. This is not a problem in "
|
569 |
+
f"Pytorch >= 2.0. FYI, you are running with PyTorch "
|
570 |
+
f"version {torch.__version__}."
|
571 |
+
)
|
572 |
+
attn_mode = "softmax"
|
573 |
+
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
|
574 |
+
logpy.warn(
|
575 |
+
"We do not support vanilla attention anymore, as it is too "
|
576 |
+
"expensive. Sorry."
|
577 |
+
)
|
578 |
+
if not XFORMERS_IS_AVAILABLE:
|
579 |
+
assert False, (
|
580 |
+
"Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
581 |
+
)
|
582 |
+
else:
|
583 |
+
logpy.info("Falling back to xformers efficient attention.")
|
584 |
+
attn_mode = "softmax-xformers"
|
585 |
+
attn_cls = self.ATTENTION_MODES[attn_mode]
|
586 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
587 |
+
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
|
588 |
+
else:
|
589 |
+
assert sdp_backend is None
|
590 |
+
self.disable_self_attn = disable_self_attn
|
591 |
+
extra_linear = (reference_to is not None) and ("extra" in reference_to)
|
592 |
+
if extra_linear:
|
593 |
+
reference_to = reference_to.replace("_extra", "")
|
594 |
+
assert reference_to in [None, "self", "cross"]
|
595 |
+
self.reference_to = reference_to
|
596 |
+
self.attn1 = attn_cls(
|
597 |
+
query_dim=dim,
|
598 |
+
heads=n_heads,
|
599 |
+
dim_head=d_head,
|
600 |
+
dropout=dropout,
|
601 |
+
context_dim=context_dim
|
602 |
+
if (self.disable_self_attn or reference_to == "self")
|
603 |
+
else None,
|
604 |
+
backend=sdp_backend,
|
605 |
+
use_reference=reference_to == "self",
|
606 |
+
extra_linear=extra_linear,
|
607 |
+
) # is a self-attention if not self.disable_self_attn
|
608 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
609 |
+
self.attn2 = attn_cls(
|
610 |
+
query_dim=dim,
|
611 |
+
context_dim=context_dim,
|
612 |
+
heads=n_heads,
|
613 |
+
dim_head=d_head,
|
614 |
+
dropout=dropout,
|
615 |
+
backend=sdp_backend,
|
616 |
+
use_reference=reference_to == "cross",
|
617 |
+
extra_linear=extra_linear,
|
618 |
+
) # is self-attn if context is none
|
619 |
+
self.norm1 = nn.LayerNorm(dim)
|
620 |
+
self.norm2 = nn.LayerNorm(dim)
|
621 |
+
self.norm3 = nn.LayerNorm(dim)
|
622 |
+
self.checkpoint = checkpoint
|
623 |
+
if self.checkpoint:
|
624 |
+
logpy.debug(f"{self.__class__.__name__} is using checkpointing")
|
625 |
+
|
626 |
+
def forward(
|
627 |
+
self,
|
628 |
+
x,
|
629 |
+
context=None,
|
630 |
+
reference_context=None,
|
631 |
+
additional_tokens=None,
|
632 |
+
n_times_crossframe_attn_in_self=0,
|
633 |
+
skip_attention=None,
|
634 |
+
):
|
635 |
+
kwargs = {"x": x}
|
636 |
+
|
637 |
+
if context is not None:
|
638 |
+
kwargs.update({"context": context})
|
639 |
+
|
640 |
+
if reference_context is not None:
|
641 |
+
kwargs.update({"reference_context": reference_context})
|
642 |
+
|
643 |
+
if additional_tokens is not None:
|
644 |
+
kwargs.update({"additional_tokens": additional_tokens})
|
645 |
+
|
646 |
+
if n_times_crossframe_attn_in_self:
|
647 |
+
kwargs.update(
|
648 |
+
{"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
|
649 |
+
)
|
650 |
+
|
651 |
+
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
|
652 |
+
if self.checkpoint:
|
653 |
+
# inputs = {"x": x, "context": context}
|
654 |
+
return checkpoint(
|
655 |
+
self._forward, x, context, reference_context, None, 0, skip_attention
|
656 |
+
)
|
657 |
+
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
|
658 |
+
else:
|
659 |
+
return self._forward(**kwargs)
|
660 |
+
|
661 |
+
def _forward(
|
662 |
+
self,
|
663 |
+
x,
|
664 |
+
context=None,
|
665 |
+
reference_context=None,
|
666 |
+
additional_tokens=None,
|
667 |
+
n_times_crossframe_attn_in_self=0,
|
668 |
+
skip_attention=None,
|
669 |
+
):
|
670 |
+
self_context = reference_context if self.reference_to == "self" else context
|
671 |
+
# print(self.reference_to)
|
672 |
+
# print("context: ", context.shape if context is not None else None)
|
673 |
+
# print("reference_context: ", reference_context.shape if reference_context is not None else None)
|
674 |
+
# print("x: ", x.shape)
|
675 |
+
|
676 |
+
x = (
|
677 |
+
self.attn1(
|
678 |
+
self.norm1(x),
|
679 |
+
context=self_context
|
680 |
+
if (self.disable_self_attn or self.reference_to == "self")
|
681 |
+
else None,
|
682 |
+
additional_tokens=additional_tokens,
|
683 |
+
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
|
684 |
+
if not self.disable_self_attn
|
685 |
+
else 0,
|
686 |
+
skip_attention=skip_attention,
|
687 |
+
)
|
688 |
+
+ x
|
689 |
+
)
|
690 |
+
cross_context = reference_context if self.reference_to == "cross" else context
|
691 |
+
x = (
|
692 |
+
self.attn2(
|
693 |
+
self.norm2(x),
|
694 |
+
context=cross_context,
|
695 |
+
additional_tokens=additional_tokens,
|
696 |
+
)
|
697 |
+
+ x
|
698 |
+
)
|
699 |
+
x = self.ff(self.norm3(x)) + x
|
700 |
+
return x
|
701 |
+
|
702 |
+
|
703 |
+
class BasicTransformerSingleLayerBlock(nn.Module):
|
704 |
+
ATTENTION_MODES = {
|
705 |
+
"softmax": CrossAttention, # vanilla attention
|
706 |
+
"softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version
|
707 |
+
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
|
708 |
+
}
|
709 |
+
|
710 |
+
def __init__(
|
711 |
+
self,
|
712 |
+
dim,
|
713 |
+
n_heads,
|
714 |
+
d_head,
|
715 |
+
dropout=0.0,
|
716 |
+
context_dim=None,
|
717 |
+
gated_ff=True,
|
718 |
+
checkpoint=True,
|
719 |
+
attn_mode="softmax",
|
720 |
+
):
|
721 |
+
super().__init__()
|
722 |
+
assert attn_mode in self.ATTENTION_MODES
|
723 |
+
attn_cls = self.ATTENTION_MODES[attn_mode]
|
724 |
+
self.attn1 = attn_cls(
|
725 |
+
query_dim=dim,
|
726 |
+
heads=n_heads,
|
727 |
+
dim_head=d_head,
|
728 |
+
dropout=dropout,
|
729 |
+
context_dim=context_dim,
|
730 |
+
)
|
731 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
732 |
+
self.norm1 = nn.LayerNorm(dim)
|
733 |
+
self.norm2 = nn.LayerNorm(dim)
|
734 |
+
self.checkpoint = checkpoint
|
735 |
+
|
736 |
+
def forward(self, x, context=None):
|
737 |
+
# inputs = {"x": x, "context": context}
|
738 |
+
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
|
739 |
+
return checkpoint(self._forward, x, context)
|
740 |
+
|
741 |
+
def _forward(self, x, context=None):
|
742 |
+
x = self.attn1(self.norm1(x), context=context) + x
|
743 |
+
x = self.ff(self.norm2(x)) + x
|
744 |
+
return x
|
745 |
+
|
746 |
+
|
747 |
+
class SpatialTransformer(nn.Module):
|
748 |
+
"""
|
749 |
+
Transformer block for image-like data.
|
750 |
+
First, project the input (aka embedding)
|
751 |
+
and reshape to b, t, d.
|
752 |
+
Then apply standard transformer action.
|
753 |
+
Finally, reshape to image
|
754 |
+
NEW: use_linear for more efficiency instead of the 1x1 convs
|
755 |
+
"""
|
756 |
+
|
757 |
+
def __init__(
|
758 |
+
self,
|
759 |
+
in_channels,
|
760 |
+
n_heads,
|
761 |
+
d_head,
|
762 |
+
depth=1,
|
763 |
+
dropout=0.0,
|
764 |
+
context_dim=None,
|
765 |
+
disable_self_attn=False,
|
766 |
+
use_linear=False,
|
767 |
+
attn_type="softmax",
|
768 |
+
use_checkpoint=True,
|
769 |
+
# sdp_backend=SDPBackend.FLASH_ATTENTION
|
770 |
+
sdp_backend=None,
|
771 |
+
reference_to=None,
|
772 |
+
):
|
773 |
+
super().__init__()
|
774 |
+
logpy.debug(
|
775 |
+
f"constructing {self.__class__.__name__} of depth {depth} w/ "
|
776 |
+
f"{in_channels} channels and {n_heads} heads."
|
777 |
+
)
|
778 |
+
|
779 |
+
if exists(context_dim) and not isinstance(context_dim, list):
|
780 |
+
context_dim = [context_dim]
|
781 |
+
if exists(context_dim) and isinstance(context_dim, list):
|
782 |
+
if depth != len(context_dim):
|
783 |
+
logpy.warn(
|
784 |
+
f"{self.__class__.__name__}: Found context dims "
|
785 |
+
f"{context_dim} of depth {len(context_dim)}, which does not "
|
786 |
+
f"match the specified 'depth' of {depth}. Setting context_dim "
|
787 |
+
f"to {depth * [context_dim[0]]} now."
|
788 |
+
)
|
789 |
+
# depth does not match context dims.
|
790 |
+
assert all(map(lambda x: x == context_dim[0], context_dim)), (
|
791 |
+
"need homogenous context_dim to match depth automatically"
|
792 |
+
)
|
793 |
+
context_dim = depth * [context_dim[0]]
|
794 |
+
elif context_dim is None:
|
795 |
+
context_dim = [None] * depth
|
796 |
+
self.in_channels = in_channels
|
797 |
+
inner_dim = n_heads * d_head
|
798 |
+
self.norm = Normalize(in_channels)
|
799 |
+
if not use_linear:
|
800 |
+
self.proj_in = nn.Conv2d(
|
801 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
802 |
+
)
|
803 |
+
else:
|
804 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
805 |
+
|
806 |
+
self.transformer_blocks = nn.ModuleList(
|
807 |
+
[
|
808 |
+
BasicTransformerBlock(
|
809 |
+
inner_dim,
|
810 |
+
n_heads,
|
811 |
+
d_head,
|
812 |
+
dropout=dropout,
|
813 |
+
context_dim=context_dim[d],
|
814 |
+
disable_self_attn=disable_self_attn,
|
815 |
+
attn_mode=attn_type,
|
816 |
+
checkpoint=use_checkpoint,
|
817 |
+
sdp_backend=sdp_backend,
|
818 |
+
reference_to=reference_to,
|
819 |
+
)
|
820 |
+
for d in range(depth)
|
821 |
+
]
|
822 |
+
)
|
823 |
+
if not use_linear:
|
824 |
+
self.proj_out = zero_module(
|
825 |
+
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
826 |
+
)
|
827 |
+
else:
|
828 |
+
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
829 |
+
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
830 |
+
self.use_linear = use_linear
|
831 |
+
|
832 |
+
def forward(self, x, context=None, skip_attention=None):
|
833 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
834 |
+
if not isinstance(context, list):
|
835 |
+
context = [context]
|
836 |
+
b, c, h, w = x.shape
|
837 |
+
x_in = x
|
838 |
+
x = self.norm(x)
|
839 |
+
if not self.use_linear:
|
840 |
+
x = self.proj_in(x)
|
841 |
+
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
842 |
+
if self.use_linear:
|
843 |
+
x = self.proj_in(x)
|
844 |
+
for i, block in enumerate(self.transformer_blocks):
|
845 |
+
if i > 0 and len(context) == 1:
|
846 |
+
i = 0 # use same context for each block
|
847 |
+
x = block(x, context=context[i], skip_attention=skip_attention)
|
848 |
+
if self.use_linear:
|
849 |
+
x = self.proj_out(x)
|
850 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
851 |
+
if not self.use_linear:
|
852 |
+
x = self.proj_out(x)
|
853 |
+
return x + x_in
|
854 |
+
|
855 |
+
|
856 |
+
class SimpleTransformer(nn.Module):
|
857 |
+
def __init__(
|
858 |
+
self,
|
859 |
+
dim: int,
|
860 |
+
depth: int,
|
861 |
+
heads: int,
|
862 |
+
dim_head: int,
|
863 |
+
context_dim: Optional[int] = None,
|
864 |
+
dropout: float = 0.0,
|
865 |
+
checkpoint: bool = True,
|
866 |
+
):
|
867 |
+
super().__init__()
|
868 |
+
self.layers = nn.ModuleList([])
|
869 |
+
for _ in range(depth):
|
870 |
+
self.layers.append(
|
871 |
+
BasicTransformerBlock(
|
872 |
+
dim,
|
873 |
+
heads,
|
874 |
+
dim_head,
|
875 |
+
dropout=dropout,
|
876 |
+
context_dim=context_dim,
|
877 |
+
attn_mode="softmax-xformers",
|
878 |
+
checkpoint=checkpoint,
|
879 |
+
)
|
880 |
+
)
|
881 |
+
|
882 |
+
def forward(
|
883 |
+
self,
|
884 |
+
x: torch.Tensor,
|
885 |
+
context: Optional[torch.Tensor] = None,
|
886 |
+
) -> torch.Tensor:
|
887 |
+
for layer in self.layers:
|
888 |
+
x = layer(x, context)
|
889 |
+
return x
|