arlaz committed
Commit 9bb001a · 0 Parent(s)

initial commit
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ .DS_Store
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: OnlyFlow
+ emoji: 🐢
+ colorFrom: pink
+ colorTo: red
+ sdk: gradio
+ sdk_version: 5.16.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ short_description: 'Optical flow based motion conditioned video generation'
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
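For local testing outside the Space, here is a minimal launch sketch. It is only a sketch under assumptions: the repository's dependencies (gradio, torch, diffusers, transformers, and the `spaces` package) are installed and a GPU is available; the OnlyFlow checkpoint is downloaded from the Hub when `app` is imported.

```python
# Minimal local-launch sketch (assumes the Space's dependencies are installed).
# This mirrors the __main__ block of app.py shown below.
from app import ui

demo = ui()              # build the Gradio Blocks interface defined in app.py
demo.queue(max_size=20)  # same queue size as the Space uses
demo.launch()            # serve the demo locally
```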
app.py ADDED
@@ -0,0 +1,641 @@
1
+ import os
2
+
3
+ import imageio
4
+ import numpy as np
5
+ import torch
6
+ import random
7
+
8
+ import spaces
9
+
10
+ import gradio as gr
11
+
12
+ import torchvision
13
+ import torchvision.transforms as T
14
+ from einops import rearrange
15
+ from huggingface_hub import hf_hub_download
16
+ from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
17
+ from torchvision.utils import flow_to_image
18
+
19
+ from diffusers import AutoencoderKL, MotionAdapter, UNet2DConditionModel
20
+ from diffusers import DDIMScheduler
21
+ from transformers import CLIPTextModel, CLIPTokenizer
22
+
23
+ from onlyflow.models.flow_adaptor import FlowEncoder, FlowAdaptor
24
+ from onlyflow.models.unet import UNetMotionModel
25
+ from onlyflow.pipelines.pipeline_animation_long import FlowCtrlPipeline
26
+ from tools.optical_flow import get_optical_flow
27
+
28
+
29
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
30
+ videos = rearrange(videos, "b c t h w -> t b c h w")
31
+ outputs = []
32
+ for x in videos:
33
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
34
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
35
+ if rescale:
36
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
37
+ x = (x * 255).numpy().astype(np.uint8)
38
+ outputs.append(x)
39
+
40
+ os.makedirs(os.path.dirname(path), exist_ok=True)
41
+ imageio.mimsave(path, outputs, fps=fps)
42
+
43
+ css = """
44
+ .toolbutton {
45
+ margin-bottom: 0em;
46
+ max-width: 2.5em;
47
+ min-width: 2.5em !important;
48
+ height: 2.5em;
49
+ }
50
+ """
51
+
52
+
53
+ class AnimateController:
54
+ def __init__(self):
55
+
56
+ # config dirs
57
+ self.basedir = os.getcwd()
58
+ self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
59
+ self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
60
+ self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
61
+ self.savedir = os.path.join(self.basedir, "samples")
62
+ os.makedirs(self.savedir, exist_ok=True)
63
+
64
+
65
+ ckpt_path = hf_hub_download('obvious-research/onlyflow', 'weights_fp16.ckpt')
66
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
67
+ self.flow_encoder_state_dict = ckpt['flow_encoder_state_dict']
68
+ self.attention_processor_state_dict = ckpt['attention_processor_state_dict']
69
+
70
+ self.tokenizer = None
71
+ self.text_encoder = None
72
+ self.vae = None
73
+ self.unet = None
74
+ self.motion_adapter = None
75
+
76
+ def update_base_model(self, base_model_id, progress=gr.Progress()):
77
+
78
+ progress(0, desc="Starting...")
79
+
80
+ self.tokenizer = CLIPTokenizer.from_pretrained(base_model_id, subfolder="tokenizer")
81
+ self.text_encoder = CLIPTextModel.from_pretrained(base_model_id, subfolder="text_encoder")
82
+ self.vae = AutoencoderKL.from_pretrained(base_model_id, subfolder="vae")
83
+ self.unet = UNet2DConditionModel.from_pretrained(base_model_id, subfolder="unet")
84
+
85
+ return base_model_id
86
+
87
+ def update_motion_module(self, motion_module_id, progress=gr.Progress()):
88
+ self.motion_adapter = MotionAdapter.from_pretrained(motion_module_id)
89
+
90
+ def animate(
91
+ self,
92
+ id_base_model,
93
+ id_motion_module,
94
+ prompt_textbox_positive,
95
+ prompt_textbox_negative,
96
+ seed_textbox,
97
+ input_video,
98
+ height,
99
+ width,
100
+ flow_scale,
101
+ cfg,
102
+ diffusion_steps,
103
+ temporal_ds,
104
+ ctx_stride
105
+ ):
106
+ #if any([x is None for x in [self.tokenizer, self.text_encoder, self.vae, self.unet, self.motion_adapter]]) or isinstance(self.unet, str):
107
+ self.update_base_model(id_base_model)
108
+ self.update_motion_module(id_motion_module)
109
+
110
+ self.unet = UNetMotionModel.from_unet2d(
111
+ self.unet,
112
+ motion_adapter=self.motion_adapter
113
+ )
114
+
115
+ self.raft = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).eval()
116
+
117
+ self.flow_encoder = FlowEncoder(
118
+ downscale_factor=8,
119
+ channels=[320, 640, 1280, 1280],
120
+ nums_rb=2,
121
+ ksize=1,
122
+ sk=True,
123
+ use_conv=False,
124
+ compression_factor=1,
125
+ temporal_attention_nhead=8,
126
+ positional_embeddings="sinusoidal",
127
+ num_positional_embeddings=16,
128
+ checkpointing=False
129
+ ).eval()
130
+
131
+ self.vae.requires_grad_(False)
132
+ self.text_encoder.requires_grad_(False)
133
+ self.unet.requires_grad_(False)
134
+ self.raft.requires_grad_(False)
135
+ self.flow_encoder.requires_grad_(False)
136
+
137
+ self.unet.set_all_attn(
138
+ flow_channels=[320, 640, 1280, 1280],
139
+ add_spatial=False,
140
+ add_temporal=True,
141
+ encoder_only=False,
142
+ query_condition=True,
143
+ key_value_condition=True,
144
+ flow_scale=1.0,
145
+ )
146
+
147
+ self.flow_adaptor = FlowAdaptor(self.unet, self.flow_encoder).eval()
148
+
149
+ # load the flow encoder weights
150
+ pose_enc_m, pose_enc_u = self.flow_adaptor.flow_encoder.load_state_dict(
151
+ self.flow_encoder_state_dict,
152
+ strict=False
153
+ )
154
+ assert len(pose_enc_m) == 0 and len(pose_enc_u) == 0
155
+
156
+ # load the attention processor weights
157
+ _, attention_processor_u = self.flow_adaptor.unet.load_state_dict(
158
+ self.attention_processor_state_dict,
159
+ strict=False
160
+ )
161
+ assert len(attention_processor_u) == 0
162
+
163
+ pipeline = FlowCtrlPipeline(
164
+ vae=self.vae,
165
+ text_encoder=self.text_encoder,
166
+ tokenizer=self.tokenizer,
167
+ unet=self.unet,
168
+ motion_adapter=self.motion_adapter,
169
+ flow_encoder=self.flow_encoder,
170
+ scheduler=DDIMScheduler.from_pretrained(id_base_model, subfolder="scheduler"),
171
+ )
172
+
173
+ if int(seed_textbox) > 0:
174
+ seed = int(seed_textbox)
175
+ else:
176
+ seed = random.randint(1, int(1e16))
177
+
178
+ return animate_diffusion(seed, pipeline, self.raft, input_video, prompt_textbox_positive, prompt_textbox_negative, width, height, flow_scale, cfg, diffusion_steps, temporal_ds, ctx_stride)
179
+
180
+ @spaces.GPU(duration=150)
181
+ def animate_diffusion(seed, pipeline, raft_model, base_video, prompt_textbox, negative_prompt_textbox, width_slider, height_slider, flow_scale, cfg, diffusion_steps, temporal_ds, context_stride):
182
+ savedir = './samples'
183
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
184
+ generator = torch.Generator(device="cpu")
185
+ generator.manual_seed(seed)
186
+
187
+ raft_model = raft_model.to(device)
188
+ pipeline = pipeline.to(device)
189
+
190
+ pixel_values = torchvision.io.read_video(base_video, output_format="TCHW", pts_unit='sec')[0][::temporal_ds]
191
+ print("Video loaded, shape:", pixel_values.shape)
192
+ if width_slider/height_slider > pixel_values.shape[3]/pixel_values.shape[2]:
193
+ print("Resizing video to fit width cause input video is not wide enough")
194
+ temp_height = int(width_slider * pixel_values.shape[2]/pixel_values.shape[3])
195
+ temp_width = width_slider
196
+ else:
197
+ print("Resizing video to fit height cause input video is not tall enough")
198
+ temp_height = height_slider
199
+ temp_width = int(height_slider * pixel_values.shape[3]/pixel_values.shape[2])
200
+ print("Resizing video to:", temp_height, temp_width)
201
+ pixel_values = T.Resize((temp_height, temp_width))(pixel_values)
202
+ pixel_values = T.CenterCrop((height_slider, width_slider))(pixel_values)
203
+ pixel_values = T.ConvertImageDtype(torch.float32)(pixel_values)[None, ...].contiguous().to(device)
204
+
205
+ save_sample_path_input = os.path.join(savedir, f"input.mp4")
206
+ pixel_values_save = pixel_values[0] * 255
207
+ pixel_values_save = pixel_values_save.cpu()
208
+ pixel_values_save = torch.permute(pixel_values_save, (0, 2, 3, 1))
209
+ torchvision.io.write_video(save_sample_path_input, pixel_values_save, fps=8)
210
+ del pixel_values_save
211
+
212
+ print("Video loaded, shape:", pixel_values.shape)
213
+ flow = get_optical_flow(
214
+ raft_model,
215
+ (pixel_values * 2) - 1,
216
+ pixel_values.shape[1] - 1,
217
+ encode_chunk_size=16,
218
+ ).to('cpu')
219
+
220
+ sample_flow = (flow_to_image(rearrange(flow[0], "c f h w -> f c h w"))) # N, 3, H, W
221
+ save_sample_path_flow = os.path.join(savedir, f"flow.mp4")
222
+ sample_flow = (sample_flow).cpu().to(torch.uint8).permute(0, 2, 3, 1)
223
+ torchvision.io.write_video(save_sample_path_flow, sample_flow, fps=8)
224
+ del sample_flow
225
+
226
+ original_flow_shape = flow.shape
227
+ print("Optical flow computed, shape:", flow.shape)
228
+ if flow.shape[2] < 16:
229
+ print("Video is too short, padding to 16 frames")
230
+ video_length = 16
231
+ n = 16 - flow.shape[2]
232
+ # create a tensor containing the last frame optical flow repeated n times
233
+ to_add = flow[:, :, -1].unsqueeze(2).expand(-1, -1, n, -1, -1)
234
+ flow = torch.cat([flow, to_add], dim=2).to(device)
235
+ elif flow.shape[2] > 16:
236
+ print("Video is too long, enabling windowing")
237
+ print("Enabling model CPU offload")
238
+ pipeline.enable_model_cpu_offload()
239
+ print("Enabling VAE slicing")
240
+ pipeline.enable_vae_slicing()
241
+ print("Enabling VAE tiling")
242
+ pipeline.enable_vae_tiling()
243
+
244
+ print("Enabling free noise")
245
+ pipeline.enable_free_noise(
246
+ context_length=16,
247
+ context_stride=context_stride,
248
+ )
249
+
250
+ import math
251
+
252
+ def find_divisors(n: int):
253
+ """
254
+ Return sorted list of all positive divisors of n.
255
+ Uses a sqrt(n) approach for efficiency.
256
+ """
257
+ divs = set()
258
+ limit = int(math.isqrt(n))
259
+ for i in range(1, limit + 1):
260
+ if n % i == 0:
261
+ divs.add(i)
262
+ divs.add(n // i)
263
+ return sorted(divs)
264
+
265
+ def multiples_in_range(k: int, min_val: int, max_val: int):
266
+ """
267
+ Return all multiples of k within [min_val, max_val].
268
+ """
269
+ if k == 0:
270
+ return []
271
+
272
+ # First multiple of k >= min_val
273
+ start = ((min_val + k - 1) // k) * k
274
+ # Last multiple of k <= max_val
275
+ end = (max_val // k) * k
276
+
277
+ return list(range(start, end + 1, k)) if start <= end else []
278
+
279
+ def adjust_video_length(original_length: int,
280
+ context_stride: int,
281
+ chunk_size: int,
282
+ temporal_split_size: int) -> int:
283
+ """
284
+ Find the minimal video_length >= original_length satisfying:
285
+ 1) (video_length - 16) is divisible by context_stride.
286
+ 2) EITHER (2*video_length) is divisible by temporal_split_size
287
+ OR (2*video_length) is divisible by chunk_size
288
+ (when 2*video_length is not multiple of temporal_split_size).
289
+ """
290
+
291
+ # We start at least at 16 (though in practice original_length likely > 16)
292
+ candidate = max(original_length, 16)
293
+
294
+ # We want (candidate - 16) % context_stride == 0
295
+ # so let n be the multiple to step.
296
+ # n is how many times we add `context_stride` beyond 16.
297
+ # This ensures (candidate - 16) is a multiple of context_stride.
298
+ # Then we check the second condition, else keep stepping.
299
+
300
+ # If candidate < 16, bump it to 16
301
+ if candidate < 16:
302
+ candidate = 16
303
+
304
+ # Make sure we jump to the correct "starting multiple" of context_stride
305
+ offset = (candidate - 16) % context_stride
306
+ if offset != 0:
307
+ candidate += (context_stride - offset) # jump to the next multiple
308
+
309
+ while True:
310
+ # Condition: (candidate - 16) is multiple of context_stride (already enforced by stepping)
311
+ # Check second part:
312
+ # - if (2*candidate) % temporal_split_size == 0, we are good
313
+ # - else we require (2*candidate) % chunk_size == 0
314
+ twoL = 2 * candidate
315
+ if (twoL % temporal_split_size == 0) or (twoL % chunk_size == 0):
316
+ return candidate
317
+
318
+ # Go to next valid candidate
319
+ candidate += context_stride
320
+
321
+ def find_valid_configs(original_video_length: int,
322
+ width: int,
323
+ height: int,
324
+ context_stride: int):
325
+ """
326
+ Generate all valid tuples (chunk_size, spatial_split_size, temporal_split_size, video_length)
327
+ subject to the constraints:
328
+ 1) chunk_size divides temporal_split_size
329
+ 2) chunk_size divides spatial_split_size
330
+ 3) chunk_size divides (2 * (width//64) * (height//64))
331
+ 4) if (2*video_length) % temporal_split_size != 0, then chunk_size divides (2*video_length)
332
+ 5) context_stride divides (video_length - 16)
333
+ 6) 480 <= spatial_split_size <= 512
334
+ 7) 1 <= temporal_split_size <= 32
335
+ 8) 1 <= chunk_size <= 32
336
+
337
+ We allow increasing original_video_length minimally if needed to satisfy constraints #4 and #5.
338
+ """
339
+
340
+ factor = 2 * (width // 64) * (height // 64)
341
+
342
+ # 1) find all possible chunk_size as divisors of factor, in [1..32]
343
+ possible_chunks = [d for d in find_divisors(factor) if 1 <= d <= 32]
344
+
345
+ # For storing results
346
+ valid_tuples = []
347
+
348
+ for chunk_size in possible_chunks:
349
+ # 2) generate all spatial_split_size in [480..512] that are multiples of chunk_size
350
+ spatial_splits = multiples_in_range(chunk_size, 480, 512)
351
+
352
+ # 3) generate all temporal_split_size in [1..32] that are multiples of chunk_size
353
+ temporal_splits = multiples_in_range(chunk_size, 1, 32)
354
+
355
+ for ssp in spatial_splits:
356
+ for tsp in temporal_splits:
357
+ # 4) & 5) Adjust video_length minimally to satisfy constraints
358
+ final_length = adjust_video_length(original_video_length,
359
+ context_stride,
360
+ chunk_size,
361
+ tsp)
362
+ # Now we have a valid (chunk_size, ssp, tsp, final_length)
363
+ valid_tuples.append((chunk_size, ssp, tsp, final_length))
364
+
365
+ return valid_tuples
366
+
367
+ def find_pareto_optimal(configs):
368
+ """
369
+ Given a list of tuples (chunk_size, spatial_split_size, temporal_split_size, video_length),
370
+ return the Pareto-optimal subset under the criteria:
371
+ - chunk_size: larger is better
372
+ - spatial_split_size: larger is better
373
+ - temporal_split_size: larger is better
374
+ - video_length: smaller is better
375
+ """
376
+
377
+ def dominates(A, B):
378
+ cA, sA, tA, lA = A
379
+ cB, sB, tB, lB = B
380
+
381
+ # A dominates B if:
382
+ # cA >= cB, tA >= tB, and lA <= lB (spatial_split_size is not used in the dominance test)
383
+ # AND at least one of these is a strict inequality.
384
+
385
+ better_or_equal = (cA >= cB) and (tA >= tB) and (lA <= lB)
386
+ strictly_better = (cA > cB) or (tA > tB) or (lA < lB)
387
+
388
+ return better_or_equal and strictly_better
389
+
390
+ pareto = []
391
+ for i, cfg_i in enumerate(configs):
392
+ # Check if cfg_i is dominated by any cfg_j
393
+ is_dominated = False
394
+ for j, cfg_j in enumerate(configs):
395
+ if i == j:
396
+ continue
397
+ if dominates(cfg_j, cfg_i):
398
+ is_dominated = True
399
+ break
400
+ if not is_dominated:
401
+ pareto.append(cfg_i)
402
+
403
+ return pareto
404
+
405
+ print("Finding valid configurations...")
406
+ valid_configs = find_valid_configs(
407
+ original_video_length=flow.shape[2],
408
+ width=width_slider,
409
+ height=height_slider,
410
+ context_stride=context_stride
411
+ )
412
+
413
+ print("Found", len(valid_configs), "valid configurations")
414
+ print("Finding Pareto-optimal configurations...")
415
+ pareto_optimal = find_pareto_optimal(valid_configs)
416
+
417
+ print("Found", pareto_optimal)
418
+
419
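+ # Tie-break among Pareto-optimal configs: favor larger chunk_size (cs) and temporal_split_size (tss),
+ # and subtract 3 points per ~10 frames of padding beyond the original flow length.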
+ criteria = lambda cs, sss, tss, vl: cs + tss - 3 * int(abs(flow.shape[2] - vl) / 10)
420
+ pareto_optimal.sort(key=lambda x: criteria(*x), reverse=True)
421
+
422
+ print("Found sorted", pareto_optimal)
423
+
424
+ solution = pareto_optimal[0]
425
+ chunk_size, spatial_split_size, temporal_split_size, video_length = solution
426
+
427
+ n = video_length - original_flow_shape[2]
428
+ to_add = flow[:, :, -1].unsqueeze(2).expand(-1, -1, n, -1, -1)
429
+ flow = torch.cat([flow, to_add], dim=2)
430
+
431
+ pipeline.enable_free_noise_split_inference(
432
+ temporal_split_size=temporal_split_size,
433
+ spatial_split_size=spatial_split_size
434
+ )
435
+ pipeline.unet.enable_forward_chunking(chunk_size)
436
+
437
+ print("Chunking enabled with chunk size:", chunk_size)
438
+ print("Temporal split size:", temporal_split_size)
439
+ print("Spatial split size:", spatial_split_size)
440
+ print("Context stride:", context_stride)
441
+ print("Temporal downscale:", temporal_ds)
442
+ print("Video length:", video_length)
443
+ print("Flow shape:", flow.shape)
444
+ else:
445
+ print("Video is just right, no padding or windowing needed")
446
+ flow = flow.to(device)
447
+ video_length = flow.shape[2]
448
+
449
+ sample_vid = pipeline(
450
+ prompt_textbox,
451
+ negative_prompt=negative_prompt_textbox,
452
+ optical_flow=flow,
453
+ num_inference_steps=diffusion_steps,
454
+ guidance_scale=cfg,
455
+ width=width_slider,
456
+ height=height_slider,
457
+ num_frames=video_length,
458
+ val_scale_factor_temporal=flow_scale,
459
+ generator=generator,
460
+ ).frames[0]
461
+
462
+ del flow
463
+ if device == "cuda":
464
+ torch.cuda.synchronize()
465
+ torch.cuda.empty_cache()
466
+
467
+ save_sample_path_video = os.path.join(savedir, f"sample.mp4")
468
+ sample_vid = sample_vid[:original_flow_shape[2]] * 255.
469
+ sample_vid = sample_vid.cpu().numpy()
470
+ sample_vid = np.transpose(sample_vid, axes=(0, 2, 3, 1))
471
+ torchvision.io.write_video(save_sample_path_video, sample_vid, fps=8)
472
+
473
+ return gr.Video(value=save_sample_path_flow), gr.Video(value=save_sample_path_video)
474
+
475
+ controller = AnimateController()
476
+
477
+
478
+ def find_closest_ratio(target_ratio):
479
+ width_list = list(reversed(range(256, 1025, 64)))
480
+ height_list = list(reversed(range(256, 1025, 64)))
481
+ ratio_list = [(h, w, w/h) for h in height_list for w in width_list]
482
+ ratio_list.sort(key=lambda x: abs(x[2] - target_ratio))
483
+ ratio_list = list(filter(lambda x: x[2] == ratio_list[0][2], ratio_list))
484
+ ratio_list.sort(key=lambda x: abs(x[0]*x[1] - 512*512))
485
+ return ratio_list[0][:2]
486
+
487
+
488
+ def find_dimension(video):
489
+ import av
490
+ container = av.open(open(video, 'rb'))
491
+ height, width = container.streams.video[0].height, container.streams.video[0].width
492
+ target_ratio = width / height
493
+ return find_closest_ratio(target_ratio)
494
+
495
+
496
+ def ui():
497
+ with gr.Blocks(css=css) as demo:
498
+ gr.Markdown(
499
+ """
500
+ # <p style="text-align:center;">OnlyFlow: Optical Flow based Motion Conditioning for Video Diffusion Models</p>
501
+ Mathis Koroglu, Hugo Caselles-Dupré, Guillaume Jeanneret Sanmiguel, Matthieu Cord<br>
502
+ [Arxiv Report](https://arxiv.org/abs/2411.10501) | [Project Page](https://obvious-research.github.io/onlyflow/) | [Github](https://github.com/obvious-research/onlyflow/)
503
+ """
504
+ )
505
+ gr.Markdown(
506
+ """
507
+ ### Quick Start:
508
+
509
+ 1. Select desired `Base Model`.
510
+ 2. Select `Motion Module`. We recommend trying guoyww/animatediff-motion-adapter-v1-5-3 for the best results.
511
+ 3. Provide `Positive Prompt` and `Negative Prompt`. You are encouraged to refer to each model's webpage on HuggingFace Hub or CivitAI to learn how to write prompts for them.
512
+ 4. Upload a video to extract optical flow from.
513
+ 5. Select a 'Flow Scale' to modulate the input video optical flow conditioning.
514
+ 6. Select a 'CFG' and 'Diffusion Steps' to control the quality of the generated video and prompt adherence.
515
+ 7. Select a 'Temporal Downsample' to reduce the number of frames in the input video.
516
+ 8. If you want to use a custom dimension, check the `Custom Dimension` box and adjust the `Width` and `Height` sliders.
517
+ 9. If the video is too long, you can adjust the generation window offset with the `Context Stride` slider.
518
+ 10. Click `Generate`, wait for ~1–3 min, and enjoy the result!
519
+
520
+ If you have any error concerning GPU limits, please try again later when your ZeroGPU quota is reset, or try with a shorter video.
521
+ Otherwise, you can also duplicate this space and select a custom GPU plan.
522
+ """
523
+ )
524
+ with gr.Row():
525
+ with gr.Column():
526
+
527
+ gr.Markdown("# INPUTS")
528
+
529
+ with gr.Row(equal_height=True, show_progress=True):
530
+ base_model = gr.Dropdown(
531
+ label="Select or type a base model id",
532
+ choices=[
533
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
534
+ "digiplay/Photon_v1",
535
+ ],
536
+ interactive=True,
537
+ scale=4,
538
+ allow_custom_value=True,
539
+ show_label=True
540
+ )
541
+ base_model_btn = gr.Button(value="Update", scale=1, size='lg')
542
+ with gr.Row(equal_height=True, show_progress=True):
543
+ motion_module = gr.Dropdown(
544
+ label="Select or type a motion module id",
545
+ choices=[
546
+ "guoyww/animatediff-motion-adapter-v1-5-3",
547
+ "guoyww/animatediff-motion-adapter-v1-5-2"
548
+ ],
549
+ interactive=True,
550
+ scale=4
551
+ )
552
+ motion_module_btn = gr.Button(value="Update", scale=1, size='lg')
553
+
554
+ base_model_btn.click(fn=controller.update_base_model, inputs=[base_model])
555
+ motion_module_btn.click(fn=controller.update_motion_module, inputs=[motion_module])
556
+
557
+ prompt_textbox_positive = gr.Textbox(label="Positive Prompt", lines=3)
558
+ prompt_textbox_negative = gr.Textbox(label="Negative Prompt", lines=2, value="worst quality, low quality, nsfw, logo")
559
+
560
+ flow_scale = gr.Slider(label="Flow Scale", value=1.0, minimum=0, maximum=2, step=0.025)
561
+ diffusion_steps = gr.Slider(label="Diffusion Steps", value=25, minimum=0, maximum=100, step=1)
562
+ cfg = gr.Slider(label="CFG", value=7.5, minimum=0, maximum=30, step=0.1)
563
+
564
+ temporal_ds = gr.Slider(label="Temporal Downsample", value=1, minimum=1, maximum=30, step=1)
565
+
566
+ input_video = gr.Video(label="Input Video", interactive=True)
567
+ ctx_stride = gr.State(12)
568
+
569
+ with gr.Accordion("Advanced", open=False):
570
+ use_custom_dim = gr.Checkbox(label="Custom Dimension", value=False)
571
+
572
+ with gr.Row(equal_height=True):
573
+
574
+ height, width = gr.State(512), gr.State(512)
575
+
576
+ @gr.render(inputs=[use_custom_dim, input_video])
577
+ def render_custom_dim(use_custom_dim, input_video):
578
+ if input_video is not None:
579
+ loc_height, loc_width = find_dimension(input_video)
580
+ else:
581
+ loc_height, loc_width = 512, 512
582
+ slider_width = gr.Slider(label="Width", value=loc_width, minimum=256, maximum=1024,
583
+ step=64, visible=use_custom_dim)
584
+ slider_height = gr.Slider(label="Height", value=loc_height, minimum=256, maximum=1024,
585
+ step=64, visible=use_custom_dim)
586
+
587
+ slider_width.change(lambda x: x, inputs=[slider_width], outputs=[width])
588
+ slider_height.change(lambda x: x, inputs=[slider_height], outputs=[height])
589
+
590
+
591
+ with gr.Row():
592
+ @gr.render(inputs=input_video)
593
+ def render_ctx_stride(input_video):
594
+ if input_video is not None:
595
+ video = open(input_video, 'rb')
596
+ import av
597
+ container = av.open(video)
598
+ num_frames = container.streams.video[0].frames
599
+ if num_frames > 17:
600
+ stride_slider = gr.Slider(label="Context Stride", value=12, minimum=1, maximum=16, step=1)
601
+ stride_slider.input(lambda x: x, inputs=[stride_slider], outputs=[ctx_stride])
602
+ if num_frames > 64:
603
+ raise gr.Error(f"Video is too long ({num_frames} frames), please use a shorter video, increase the context stride, or select a custom GPU plan. The current parameters won't allow generation on ZeroGPU.")
604
+ elif num_frames > 32:
605
+ gr.Warning(f"Video is long ({num_frames} frames), consider using a shorter video, increasing the context stride, or selecting a custom GPU plan.")
606
+
607
+ with gr.Row(equal_height=True):
608
+ seed_textbox = gr.Textbox(label="Seed", value='-1')
609
+
610
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
611
+ seed_button.click(
612
+ fn=lambda: random.randint(1, int(1e16)),
613
+ inputs=[],
614
+ outputs=[seed_textbox]
615
+ )
616
+
617
+ with gr.Row():
618
+ clear_btn = gr.ClearButton(value="Clear & Reset", size='lg', variant='secondary', scale=1)
619
+ generate_button = gr.Button(value="Generate", variant='primary', scale=2, size='lg')
620
+
621
+ clear_btn.add([base_model, motion_module, input_video, prompt_textbox_positive, prompt_textbox_negative, seed_textbox, use_custom_dim, ctx_stride])
622
+
623
+ with gr.Column():
624
+
625
+ gr.Markdown("# OUTPUTS")
626
+
627
+ result_optical_flow = gr.Video(label="Optical Flow", interactive=False)
628
+ result_video = gr.Video(label="Generated Animation", interactive=False)
629
+
630
+ inputs = [base_model, motion_module, prompt_textbox_positive, prompt_textbox_negative, seed_textbox, input_video, height, width, flow_scale, cfg, diffusion_steps, temporal_ds, ctx_stride]
631
+ outputs = [result_optical_flow, result_video]
632
+
633
+ generate_button.click(fn=controller.animate, inputs=inputs, outputs=outputs)
634
+
635
+ return demo
636
+
637
+
638
+ if __name__ == "__main__":
639
+ demo = ui()
640
+ demo.queue(max_size=20)
641
+ demo.launch()
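For reference, here is a condensed sketch of the flow-extraction step used in `animate_diffusion` above. It is a sketch under assumptions: `clip.mp4` is a hypothetical short local video, the 512×512 resolution is illustrative, and the repository's `tools.optical_flow.get_optical_flow` helper is importable.

```python
# Sketch of building the optical-flow conditioning tensor, mirroring animate_diffusion above.
import torch
import torchvision
import torchvision.transforms as T
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights

from tools.optical_flow import get_optical_flow

device = "cuda" if torch.cuda.is_available() else "cpu"
raft = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).eval().to(device)

# Load frames as (T, C, H, W), resize/crop to the target resolution, scale to [0, 1].
frames = torchvision.io.read_video("clip.mp4", output_format="TCHW", pts_unit="sec")[0]
frames = T.CenterCrop((512, 512))(T.Resize(512)(frames))
frames = T.ConvertImageDtype(torch.float32)(frames)[None].to(device)  # (1, T, C, H, W)

with torch.no_grad():
    # RAFT expects inputs in [-1, 1]; one flow field per consecutive frame pair.
    flow = get_optical_flow(raft, frames * 2 - 1, frames.shape[1] - 1, encode_chunk_size=16)

print(flow.shape)  # (1, 2, T - 1, H, W), the tensor passed to FlowCtrlPipeline as optical_flow
```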
onlyflow/data/dataset_idx.py ADDED
@@ -0,0 +1,88 @@
1
+ import functools
2
+ from io import BytesIO
3
+
4
+ import torch
5
+ import torchvision
6
+ import torchvision.transforms.v2 as transforms
7
+ import wids
8
+ from torch.utils.data import DataLoader
9
+
10
+
11
+ def _video_shortener(video_tensor, length, generator=None):
12
+ start = torch.randint(0, video_tensor.shape[0] - length, (1,), generator=generator)
13
+ return video_tensor[start:start + length]
14
+
15
+
16
+ def select_video_extract(length=16, generator=None):
17
+ return functools.partial(_video_shortener, length=length, generator=generator)
18
+
19
+
20
+ def my_collate_fn(batch):
21
+ videos = torch.stack([sample[0] for sample in batch])
22
+ txts = [sample[1] for sample in batch]
23
+
24
+ return videos, txts
25
+
26
+
27
+ class WebVidDataset(wids.ShardListDataset):
28
+
29
+ def __init__(self, shards, cache_dir, video_length=16, video_size=256, video_length_offset=1, val=False, seed=42,
30
+ **kwargs):
31
+
32
+ self.val = val
33
+ self.generator = torch.Generator()
34
+ self.generator.manual_seed(seed)
35
+ self.generator_init_state = self.generator.get_state()
36
+ super().__init__(shards, cache_dir=cache_dir, keep=True, **kwargs)
37
+
38
+ if isinstance(video_size, int):
39
+ video_size = (video_size, video_size)
40
+
41
+ self.video_size = video_size
42
+
43
+ for size in video_size:
44
+ if size % 8 != 0:
45
+ raise ValueError("video_size must be divisible by 8")
46
+
47
+ self.transform = transforms.Compose(
48
+ [
49
+ select_video_extract(length=video_length + video_length_offset, generator=self.generator),
50
+ transforms.Resize(size=video_size),
51
+ transforms.RandomCrop(size=video_size) if not self.val else transforms.CenterCrop(size=video_size),
52
+ transforms.RandomHorizontalFlip() if not self.val else transforms.Identity(),
53
+ ]
54
+ )
55
+
56
+ self.add_transform(self._make_sample)
57
+
58
+ def _make_sample(self, sample):
59
+ if self.val:
60
+ self.generator.set_state(self.generator_init_state)
61
+ video = torchvision.io.read_video(BytesIO(sample[".mp4"].read()), output_format="TCHW", pts_unit='sec')[0]
62
+ label = sample[".txt"]
63
+ return self.transform(video), label
64
+
65
+
66
+ if __name__ == "__main__":
67
+
68
+ dataset = WebVidDataset(
69
+ shards='/users/Etu9/3711799/onlyflow/data/webvid/desc.json',  # shard index read by wids.ShardListDataset
70
+ cache_dir='./webvid_cache',  # hypothetical local cache directory
71
+ video_length=16,
72
+ video_size=256,
73
+ video_length_offset=0,
74
+ )
75
+
76
+ sampler = wids.DistributedChunkedSampler(dataset, chunksize=1000, shuffle=True)
77
+ dataloader = DataLoader(
78
+ dataset,
79
+ collate_fn=my_collate_fn,
80
+ batch_size=4,
81
+ sampler=sampler,
82
+ num_workers=4
83
+ )
84
+
85
+ for i, (images, labels) in enumerate(dataloader):
86
+ print(i, images.shape, labels)
87
+ if i > 10:
88
+ break
onlyflow/data/dataset_itr.py ADDED
@@ -0,0 +1,81 @@
1
+ import functools
2
+ import os
3
+ from io import BytesIO
4
+
5
+ import torch
6
+ import torchvision
7
+ import torchvision.transforms.v2 as transforms
8
+ import webdataset as wds
9
+
10
+
11
+ def _video_shortener(video_tensor, length):
12
+ start = torch.randint(0, video_tensor.shape[0] - length, (1,))
13
+ return video_tensor[start:start + length]
14
+
15
+
16
+ def select_video_extract(length=16):
17
+ return functools.partial(_video_shortener, length=length)
18
+
19
+
20
+ def my_collate_fn(batch):
21
+ output = {}
22
+ for key in batch[0].keys():
23
+ if key == 'video':
24
+ output[key] = torch.stack([sample[key] for sample in batch])
25
+ else:
26
+ output[key] = [sample[key] for sample in batch]
27
+
28
+ return output
29
+
30
+
31
+ def map_mp4(sample):
32
+ return torchvision.io.read_video(BytesIO(sample), output_format="TCHW", pts_unit='sec')[0]
33
+
34
+
35
+ def map_txt(sample):
36
+ return sample.decode("utf-8")
37
+
38
+
39
+ class WebVidDataset(wds.DataPipeline):
40
+ def __init__(self, batch_size, tar_index, root_path, video_length=16, video_size=256, video_length_offset=0,
41
+ horizontal_flip=True, seed=None):
42
+
43
+ self.dataset_full_path = os.path.join(root_path, f'webvid-uw-{{{tar_index}}}.tar')
44
+
45
+ if isinstance(video_size, int):
46
+ video_size = (video_size, video_size)
47
+
48
+ for size in video_size:
49
+ if size % 8 != 0:
50
+ raise ValueError("video_size must be divisible by 8")
51
+
52
+ self.pipeline = [
53
+ wds.SimpleShardList('file:' + str(self.dataset_full_path), seed=seed),
54
+ wds.shuffle(50),
55
+ wds.split_by_node,
56
+ wds.tarfile_to_samples(),
57
+ wds.shuffle(100),
58
+ wds.split_by_worker,
59
+ wds.map_dict(
60
+ mp4=map_mp4,
61
+ txt=map_txt,
62
+ ),
63
+ wds.map_dict(
64
+ mp4=transforms.Compose(
65
+ [
66
+ select_video_extract(length=video_length + video_length_offset),
67
+ transforms.Resize(size=video_size),
68
+ transforms.RandomCrop(size=video_size),
69
+ transforms.RandomHorizontalFlip() if horizontal_flip else transforms.Identity(),
70
+ ]
71
+ )
72
+ ),
73
+ wds.rename_keys(video="mp4", text='txt', keep_unselected=True),
74
+ wds.batched(batch_size, collation_fn=my_collate_fn, partial=True)
75
+ ]
76
+
77
+ super().__init__(self.pipeline)
78
+
79
+ self.batch_size = batch_size
80
+ self.video_length = video_length
81
+ self.video_size = video_size
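Since this `WebVidDataset` is a `webdataset` pipeline (an iterable dataset that already batches internally), a minimal consumption sketch could look like the following. It is a sketch under assumptions: the shard location and the `tar_index` brace-expansion range are hypothetical, and shards follow the `webvid-uw-{...}.tar` naming used above.

```python
# Minimal iteration sketch for the iterable WebVid pipeline defined above.
from torch.utils.data import DataLoader

from onlyflow.data.dataset_itr import WebVidDataset

dataset = WebVidDataset(
    batch_size=2,
    tar_index="000000..000009",   # hypothetical brace-expansion range for the shards
    root_path="/data/webvid",     # hypothetical shard directory
    video_length=16,
    video_size=256,
)

# batch_size=None: batching is already performed inside the webdataset pipeline.
loader = DataLoader(dataset, batch_size=None, num_workers=4)

for batch in loader:
    print(batch["video"].shape, batch["text"][:2])
    break
```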
onlyflow/models/attention.py ADDED
@@ -0,0 +1,359 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ from diffusers.models.attention import GatedSelfAttentionDense, FeedForward, _chunked_feed_forward
18
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
19
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero
20
+ from diffusers.utils import logging
21
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
22
+ from torch import nn
23
+
24
+ from onlyflow.models.attention_processor import Attention
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ @maybe_allow_in_graph
30
+ class BasicTransformerBlock(nn.Module):
31
+ r"""
32
+ A basic Transformer block.
33
+
34
+ Parameters:
35
+ dim (`int`): The number of channels in the input and output.
36
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
37
+ attention_head_dim (`int`): The number of channels in each head.
38
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
39
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
40
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
41
+ num_embeds_ada_norm (:
42
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
43
+ attention_bias (:
44
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
45
+ only_cross_attention (`bool`, *optional*):
46
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
47
+ double_self_attention (`bool`, *optional*):
48
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
49
+ upcast_attention (`bool`, *optional*):
50
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
51
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
52
+ Whether to use learnable elementwise affine parameters for normalization.
53
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
54
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
55
+ final_dropout (`bool` *optional*, defaults to False):
56
+ Whether to apply a final dropout after the last feed-forward layer.
57
+ attention_type (`str`, *optional*, defaults to `"default"`):
58
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
59
+ positional_embeddings (`str`, *optional*, defaults to `None`):
60
+ The type of positional embeddings to apply to.
61
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
62
+ The maximum number of positional embeddings to apply.
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ dim: int,
68
+ num_attention_heads: int,
69
+ attention_head_dim: int,
70
+ dropout=0.0,
71
+ cross_attention_dim: Optional[int] = None,
72
+ activation_fn: str = "geglu",
73
+ num_embeds_ada_norm: Optional[int] = None,
74
+ attention_bias: bool = False,
75
+ only_cross_attention: bool = False,
76
+ double_self_attention: bool = False,
77
+ upcast_attention: bool = False,
78
+ norm_elementwise_affine: bool = True,
79
+ norm_type: str = "layer_norm",
80
+ # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
81
+ norm_eps: float = 1e-5,
82
+ final_dropout: bool = False,
83
+ attention_type: str = "default",
84
+ positional_embeddings: Optional[str] = None,
85
+ num_positional_embeddings: Optional[int] = None,
86
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
87
+ ada_norm_bias: Optional[int] = None,
88
+ ff_inner_dim: Optional[int] = None,
89
+ ff_bias: bool = True,
90
+ attention_out_bias: bool = True,
91
+ ):
92
+ super().__init__()
93
+ self.dim = dim
94
+ self.num_attention_heads = num_attention_heads
95
+ self.attention_head_dim = attention_head_dim
96
+ self.dropout = dropout
97
+ self.cross_attention_dim = cross_attention_dim
98
+ self.activation_fn = activation_fn
99
+ self.attention_bias = attention_bias
100
+ self.double_self_attention = double_self_attention
101
+ self.norm_elementwise_affine = norm_elementwise_affine
102
+ self.positional_embeddings = positional_embeddings
103
+ self.num_positional_embeddings = num_positional_embeddings
104
+ self.only_cross_attention = only_cross_attention
105
+
106
+ # We keep these boolean flags for backward-compatibility.
107
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
108
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
109
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
110
+ self.use_layer_norm = norm_type == "layer_norm"
111
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
112
+
113
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
114
+ raise ValueError(
115
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
116
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
117
+ )
118
+
119
+ self.norm_type = norm_type
120
+ self.num_embeds_ada_norm = num_embeds_ada_norm
121
+
122
+ if positional_embeddings and (num_positional_embeddings is None):
123
+ raise ValueError(
124
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
125
+ )
126
+
127
+ if positional_embeddings == "sinusoidal":
128
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
129
+ else:
130
+ self.pos_embed = None
131
+
132
+ # Define 3 blocks. Each block has its own normalization layer.
133
+ # 1. Self-Attn
134
+ if norm_type == "ada_norm":
135
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
136
+ elif norm_type == "ada_norm_zero":
137
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
138
+ elif norm_type == "ada_norm_continuous":
139
+ self.norm1 = AdaLayerNormContinuous(
140
+ dim,
141
+ ada_norm_continous_conditioning_embedding_dim,
142
+ norm_elementwise_affine,
143
+ norm_eps,
144
+ ada_norm_bias,
145
+ "rms_norm",
146
+ )
147
+ else:
148
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
149
+
150
+ self.attn1 = Attention(
151
+ query_dim=dim,
152
+ heads=num_attention_heads,
153
+ dim_head=attention_head_dim,
154
+ dropout=dropout,
155
+ bias=attention_bias,
156
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
157
+ upcast_attention=upcast_attention,
158
+ out_bias=attention_out_bias,
159
+ )
160
+
161
+ # 2. Cross-Attn
162
+ if cross_attention_dim is not None or double_self_attention:
163
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
164
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
165
+ # the second cross attention block.
166
+ if norm_type == "ada_norm":
167
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
168
+ elif norm_type == "ada_norm_continuous":
169
+ self.norm2 = AdaLayerNormContinuous(
170
+ dim,
171
+ ada_norm_continous_conditioning_embedding_dim,
172
+ norm_elementwise_affine,
173
+ norm_eps,
174
+ ada_norm_bias,
175
+ "rms_norm",
176
+ )
177
+ else:
178
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
179
+
180
+ self.attn2 = Attention(
181
+ query_dim=dim,
182
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
183
+ heads=num_attention_heads,
184
+ dim_head=attention_head_dim,
185
+ dropout=dropout,
186
+ bias=attention_bias,
187
+ upcast_attention=upcast_attention,
188
+ out_bias=attention_out_bias,
189
+ ) # is self-attn if encoder_hidden_states is none
190
+ else:
191
+ if norm_type == "ada_norm_single": # For Latte
192
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
193
+ else:
194
+ self.norm2 = None
195
+ self.attn2 = None
196
+
197
+ # 3. Feed-forward
198
+ if norm_type == "ada_norm_continuous":
199
+ self.norm3 = AdaLayerNormContinuous(
200
+ dim,
201
+ ada_norm_continous_conditioning_embedding_dim,
202
+ norm_elementwise_affine,
203
+ norm_eps,
204
+ ada_norm_bias,
205
+ "layer_norm",
206
+ )
207
+
208
+ elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
209
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
210
+ elif norm_type == "layer_norm_i2vgen":
211
+ self.norm3 = None
212
+
213
+ self.ff = FeedForward(
214
+ dim,
215
+ dropout=dropout,
216
+ activation_fn=activation_fn,
217
+ final_dropout=final_dropout,
218
+ inner_dim=ff_inner_dim,
219
+ bias=ff_bias,
220
+ )
221
+
222
+ # 4. Fuser
223
+ if attention_type == "gated" or attention_type == "gated-text-image":
224
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
225
+
226
+ # 5. Scale-shift for PixArt-Alpha.
227
+ if norm_type == "ada_norm_single":
228
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim ** 0.5)
229
+
230
+ # let chunk size default to None
231
+ self._chunk_size = None
232
+ self._chunk_dim = 0
233
+
234
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
235
+ # Sets chunk feed-forward
236
+ self._chunk_size = chunk_size
237
+ self._chunk_dim = dim
238
+
239
+ def forward(
240
+ self,
241
+ hidden_states: torch.Tensor,
242
+ attention_mask: Optional[torch.Tensor] = None,
243
+ encoder_hidden_states: Optional[torch.Tensor] = None,
244
+ encoder_attention_mask: Optional[torch.Tensor] = None,
245
+ timestep: Optional[torch.LongTensor] = None,
246
+ cross_attention_kwargs: Dict[str, Any] = None,
247
+ class_labels: Optional[torch.LongTensor] = None,
248
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
249
+ ) -> torch.Tensor:
250
+ if cross_attention_kwargs is not None:
251
+ if cross_attention_kwargs.get("scale", None) is not None:
252
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
253
+
254
+ # Notice that normalization is always applied before the real computation in the following blocks.
255
+ # 0. Self-Attention
256
+ batch_size = hidden_states.shape[0]
257
+
258
+ if self.norm_type == "ada_norm":
259
+ norm_hidden_states = self.norm1(hidden_states, timestep)
260
+ elif self.norm_type == "ada_norm_zero":
261
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
262
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
263
+ )
264
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
265
+ norm_hidden_states = self.norm1(hidden_states)
266
+ elif self.norm_type == "ada_norm_continuous":
267
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
268
+ elif self.norm_type == "ada_norm_single":
269
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
270
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
271
+ ).chunk(6, dim=1)
272
+ norm_hidden_states = self.norm1(hidden_states)
273
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
274
+ else:
275
+ raise ValueError("Incorrect norm used")
276
+
277
+ if self.pos_embed is not None:
278
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
279
+
280
+ # 1. Prepare GLIGEN inputs
281
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
282
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
283
+
284
+ attn_output = self.attn1(
285
+ hidden_states=norm_hidden_states,
286
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
287
+ attention_mask=attention_mask,
288
+ **cross_attention_kwargs,
289
+ )
290
+
291
+ if self.norm_type == "ada_norm_zero":
292
+ attn_output = gate_msa.unsqueeze(1) * attn_output
293
+ elif self.norm_type == "ada_norm_single":
294
+ attn_output = gate_msa * attn_output
295
+
296
+ hidden_states = attn_output + hidden_states
297
+ if hidden_states.ndim == 4:
298
+ hidden_states = hidden_states.squeeze(1)
299
+
300
+ # 1.2 GLIGEN Control
301
+ if gligen_kwargs is not None:
302
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
303
+
304
+ # 3. Cross-Attention
305
+ if self.attn2 is not None:
306
+ if self.norm_type == "ada_norm":
307
+ norm_hidden_states = self.norm2(hidden_states, timestep)
308
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
309
+ norm_hidden_states = self.norm2(hidden_states)
310
+ elif self.norm_type == "ada_norm_single":
311
+ # For PixArt norm2 isn't applied here:
312
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
313
+ norm_hidden_states = hidden_states
314
+ elif self.norm_type == "ada_norm_continuous":
315
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
316
+ else:
317
+ raise ValueError("Incorrect norm")
318
+
319
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
320
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
321
+
322
+ attn_output = self.attn2(
323
+ hidden_states=norm_hidden_states,
324
+ encoder_hidden_states=encoder_hidden_states,
325
+ attention_mask=encoder_attention_mask,
326
+ **cross_attention_kwargs,
327
+ )
328
+ hidden_states = attn_output + hidden_states
329
+
330
+ # 4. Feed-forward
331
+ # i2vgen doesn't have this norm 🤷‍♂️
332
+ if self.norm_type == "ada_norm_continuous":
333
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
334
+ elif not self.norm_type == "ada_norm_single":
335
+ norm_hidden_states = self.norm3(hidden_states)
336
+
337
+ if self.norm_type == "ada_norm_zero":
338
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
339
+
340
+ if self.norm_type == "ada_norm_single":
341
+ norm_hidden_states = self.norm2(hidden_states)
342
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
343
+
344
+ if self._chunk_size is not None:
345
+ # "feed_forward_chunk_size" can be used to save memory
346
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
347
+ else:
348
+ ff_output = self.ff(norm_hidden_states)
349
+
350
+ if self.norm_type == "ada_norm_zero":
351
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
352
+ elif self.norm_type == "ada_norm_single":
353
+ ff_output = gate_mlp * ff_output
354
+
355
+ hidden_states = ff_output + hidden_states
356
+ if hidden_states.ndim == 4:
357
+ hidden_states = hidden_states.squeeze(1)
358
+
359
+ return hidden_states
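As a quick sanity check of the block's interface, a minimal forward pass might look like the following. This is a sketch under assumptions: it uses the default `layer_norm` configuration, illustrative sizes (320-dim hidden states, 768-dim text embeddings as in SD 1.5), and presumes it is run from the repository root.

```python
# Smoke-test sketch for BasicTransformerBlock (illustrative sizes, default norm_type).
import torch

from onlyflow.models.attention import BasicTransformerBlock

block = BasicTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,
    cross_attention_dim=768,   # e.g. CLIP text embedding width for SD 1.5
).eval()

hidden_states = torch.randn(2, 64, 320)          # (batch, tokens, dim)
encoder_hidden_states = torch.randn(2, 77, 768)  # (batch, text tokens, cross-attention dim)

with torch.no_grad():
    out = block(hidden_states, encoder_hidden_states=encoder_hidden_states)

print(out.shape)  # expected: torch.Size([2, 64, 320])
```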
onlyflow/models/attention_processor.py ADDED
@@ -0,0 +1,456 @@
1
+ import inspect
2
+ import logging
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.nn.init as init
9
+ from diffusers.models.attention_processor import Attention as AttentionBase
10
+ from diffusers.models.attention_processor import AttnProcessor2_0 as AttnProcessor2_0_Base, SpatialNorm, AttnProcessor
11
+ from diffusers.models.attention_processor import IPAdapterAttnProcessor2_0 as IPAdapterAttnProcessor2_0_Base
12
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ @maybe_allow_in_graph
18
+ class Attention(AttentionBase):
19
+ r"""
20
+ A cross attention layer.
21
+
22
+ Parameters:
23
+ query_dim (`int`):
24
+ The number of channels in the query.
25
+ cross_attention_dim (`int`, *optional*):
26
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
27
+ heads (`int`, *optional*, defaults to 8):
28
+ The number of heads to use for multi-head attention.
29
+ kv_heads (`int`, *optional*, defaults to `None`):
30
+ The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
31
+ `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
32
+ Query Attention (MQA) otherwise GQA is used.
33
+ dim_head (`int`, *optional*, defaults to 64):
34
+ The number of channels in each head.
35
+ dropout (`float`, *optional*, defaults to 0.0):
36
+ The dropout probability to use.
37
+ bias (`bool`, *optional*, defaults to False):
38
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
39
+ upcast_attention (`bool`, *optional*, defaults to False):
40
+ Set to `True` to upcast the attention computation to `float32`.
41
+ upcast_softmax (`bool`, *optional*, defaults to False):
42
+ Set to `True` to upcast the softmax computation to `float32`.
43
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
44
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
45
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
46
+ The number of groups to use for the group norm in the cross attention.
47
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
48
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
49
+ norm_num_groups (`int`, *optional*, defaults to `None`):
50
+ The number of groups to use for the group norm in the attention.
51
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
52
+ The number of channels to use for the spatial normalization.
53
+ out_bias (`bool`, *optional*, defaults to `True`):
54
+ Set to `True` to use a bias in the output linear layer.
55
+ scale_qk (`bool`, *optional*, defaults to `True`):
56
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
57
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
58
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
59
+ `added_kv_proj_dim` is not `None`.
60
+ eps (`float`, *optional*, defaults to 1e-5):
61
+ An additional value added to the denominator in group normalization that is used for numerical stability.
62
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
63
+ A factor to rescale the output by dividing it with this value.
64
+ residual_connection (`bool`, *optional*, defaults to `False`):
65
+ Set to `True` to add the residual connection to the output.
66
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
67
+ Set to `True` if the attention block is loaded from a deprecated state dict.
68
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
69
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
70
+ `AttnProcessor` otherwise.
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ query_dim: int,
76
+ cross_attention_dim: Optional[int] = None,
77
+ heads: int = 8,
78
+ kv_heads: Optional[int] = None,
79
+ dim_head: int = 64,
80
+ dropout: float = 0.0,
81
+ bias: bool = False,
82
+ upcast_attention: bool = False,
83
+ upcast_softmax: bool = False,
84
+ cross_attention_norm: Optional[str] = None,
85
+ cross_attention_norm_num_groups: int = 32,
86
+ qk_norm: Optional[str] = None,
87
+ added_kv_proj_dim: Optional[int] = None,
88
+ added_proj_bias: Optional[bool] = True,
89
+ norm_num_groups: Optional[int] = None,
90
+ spatial_norm_dim: Optional[int] = None,
91
+ out_bias: bool = True,
92
+ scale_qk: bool = True,
93
+ only_cross_attention: bool = False,
94
+ eps: float = 1e-5,
95
+ rescale_output_factor: float = 1.0,
96
+ residual_connection: bool = False,
97
+ _from_deprecated_attn_block: bool = False,
98
+ processor: Optional["AttnProcessor"] = None,
99
+ out_dim: int = None,
100
+ context_pre_only=None,
101
+ pre_only=False,
102
+ ):
103
+ nn.Module.__init__(self)
104
+
105
+ # To prevent circular import.
106
+ from diffusers.models.normalization import FP32LayerNorm, RMSNorm
107
+
108
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
109
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
110
+ self.query_dim = query_dim
111
+ self.use_bias = bias
112
+ self.is_cross_attention = cross_attention_dim is not None
113
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
114
+ self.upcast_attention = upcast_attention
115
+ self.upcast_softmax = upcast_softmax
116
+ self.rescale_output_factor = rescale_output_factor
117
+ self.residual_connection = residual_connection
118
+ self.dropout = dropout
119
+ self.fused_projections = False
120
+ self.out_dim = out_dim if out_dim is not None else query_dim
121
+ self.context_pre_only = context_pre_only
122
+ self.pre_only = pre_only
123
+
124
+ # we make use of this private variable to know whether this class is loaded
125
+ # with a deprecated state dict so that we can convert it on the fly
126
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
127
+
128
+ self.scale_qk = scale_qk
129
+ self.scale = dim_head ** -0.5 if self.scale_qk else 1.0
130
+
131
+ self.heads = out_dim // dim_head if out_dim is not None else heads
132
+ # for slice_size > 0 the attention score computation
133
+ # is split across the batch axis to save memory
134
+ # You can set slice_size with `set_attention_slice`
135
+ self.sliceable_head_dim = heads
136
+
137
+ self.added_kv_proj_dim = added_kv_proj_dim
138
+ self.only_cross_attention = only_cross_attention
139
+
140
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
141
+ raise ValueError(
142
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
143
+ )
144
+
145
+ if norm_num_groups is not None:
146
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
147
+ else:
148
+ self.group_norm = None
149
+
150
+ if spatial_norm_dim is not None:
151
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
152
+ else:
153
+ self.spatial_norm = None
154
+
155
+ if qk_norm is None:
156
+ self.norm_q = None
157
+ self.norm_k = None
158
+ elif qk_norm == "layer_norm":
159
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps)
160
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps)
161
+ elif qk_norm == "fp32_layer_norm":
162
+ self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
163
+ self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
164
+ elif qk_norm == "layer_norm_across_heads":
165
+ # Lumina applies QK norm across all heads
166
+ self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
167
+ self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
168
+ elif qk_norm == "rms_norm":
169
+ self.norm_q = RMSNorm(dim_head, eps=eps)
170
+ self.norm_k = RMSNorm(dim_head, eps=eps)
171
+ else:
172
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads' or 'rms_norm'")
173
+
174
+ if cross_attention_norm is None:
175
+ self.norm_cross = None
176
+ elif cross_attention_norm == "layer_norm":
177
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
178
+ elif cross_attention_norm == "group_norm":
179
+ if self.added_kv_proj_dim is not None:
180
+ # The given `encoder_hidden_states` are initially of shape
181
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
182
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
183
+ # before the projection, so we need to use `added_kv_proj_dim` as
184
+ # the number of channels for the group norm.
185
+ norm_cross_num_channels = added_kv_proj_dim
186
+ else:
187
+ norm_cross_num_channels = self.cross_attention_dim
188
+
189
+ self.norm_cross = nn.GroupNorm(
190
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
191
+ )
192
+ else:
193
+ raise ValueError(
194
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
195
+ )
196
+
197
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
198
+
199
+ if not self.only_cross_attention:
200
+ # only relevant for the `AddedKVProcessor` classes
201
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
202
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
203
+ else:
204
+ self.to_k = None
205
+ self.to_v = None
206
+
207
+ self.added_proj_bias = added_proj_bias
208
+ if self.added_kv_proj_dim is not None:
209
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
210
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
211
+ if self.context_pre_only is not None:
212
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
213
+
214
+ if not self.pre_only:
215
+ self.to_out = nn.ModuleList([])
216
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
217
+ self.to_out.append(nn.Dropout(dropout))
218
+
219
+ if self.context_pre_only is not None and not self.context_pre_only:
220
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
221
+
222
+ if qk_norm is not None and added_kv_proj_dim is not None:
223
+ if qk_norm == "fp32_layer_norm":
224
+ self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
225
+ self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
226
+ elif qk_norm == "rms_norm":
227
+ self.norm_added_q = RMSNorm(dim_head, eps=eps)
228
+ self.norm_added_k = RMSNorm(dim_head, eps=eps)
229
+ else:
230
+ self.norm_added_q = None
231
+ self.norm_added_k = None
232
+
233
+ # set attention processor
234
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
235
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
236
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
237
+ if processor is None:
238
+ processor = (
239
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
240
+ )
241
+ self.set_processor(processor)
242
+
243
+ def forward(
244
+ self,
245
+ hidden_states: torch.Tensor,
246
+ encoder_hidden_states: Optional[torch.Tensor] = None,
247
+ attention_mask: Optional[torch.Tensor] = None,
248
+ **cross_attention_kwargs,
249
+ ) -> torch.Tensor:
250
+ r"""
251
+ The forward method of the `Attention` class.
252
+
253
+ Args:
254
+ hidden_states (`torch.Tensor`):
255
+ The hidden states of the query.
256
+ encoder_hidden_states (`torch.Tensor`, *optional*):
257
+ The hidden states of the encoder.
258
+ attention_mask (`torch.Tensor`, *optional*):
259
+ The attention mask to use. If `None`, no mask is applied.
260
+ **cross_attention_kwargs:
261
+ Additional keyword arguments to pass along to the cross attention.
262
+
263
+ Returns:
264
+ `torch.Tensor`: The output of the attention layer.
265
+ """
266
+ # The `Attention` class can call different attention processors / attention functions
267
+ # here we simply pass along all tensors to the selected processor class
268
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
269
+
270
+ return self.processor(
271
+ self,
272
+ hidden_states=hidden_states,
273
+ encoder_hidden_states=encoder_hidden_states,
274
+ attention_mask=attention_mask,
275
+ **cross_attention_kwargs,
276
+ )
277
+
278
+
279
+ class AttnProcessor2_0(AttnProcessor2_0_Base):
280
+ def __call__(
281
+ self,
282
+ attn: Attention,
283
+ hidden_states: torch.Tensor,
284
+ encoder_hidden_states: Optional[torch.Tensor] = None,
285
+ attention_mask: Optional[torch.Tensor] = None,
286
+ temb: Optional[torch.Tensor] = None,
287
+ flow_feature: Optional[torch.Tensor] = None,
288
+ flow_scale: Optional[float] = None,
289
+ *args,
290
+ **kwargs,
291
+ ) -> torch.Tensor:
292
+
293
+ old_scale = attn.scale
294
+ attn.scale *= kwargs.get("attn_scale", 1.0)
295
+
296
+ output = super().__call__(
297
+ attn,
298
+ hidden_states,
299
+ encoder_hidden_states=encoder_hidden_states,
300
+ attention_mask=attention_mask,
301
+ temb=temb,
302
+ *args,
303
+ **kwargs,
304
+ )
305
+
306
+ attn.scale = old_scale
307
+ return output
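Note on the wrapper above: it temporarily multiplies `attn.scale` by an `attn_scale` value pulled out of the extra keyword arguments, delegates to the base processor, then restores the original scale. A minimal usage sketch with hypothetical sizes, assuming the `Attention` class defined above, that `AttnProcessor2_0_Base` is the stock diffusers processor, and torch >= 2.0 for `F.scaled_dot_product_attention`:

import torch

# illustrative wiring only; shapes and sizes are made up
attn = Attention(query_dim=320, heads=8, dim_head=40, processor=AttnProcessor2_0())
x = torch.randn(2, 64, 320)      # (batch, tokens, channels)
out = attn(x, attn_scale=0.5)    # routed through **cross_attention_kwargs; attn.scale is halved only for this call
assert out.shape == x.shape
assert attn.scale == 40 ** -0.5  # the original 1 / sqrt(dim_head) scale is restored afterwards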
308
+
309
+ class IPAdapterAttnProcessor2_0(IPAdapterAttnProcessor2_0_Base):
310
+ def __call__(
311
+ self,
312
+ attn: Attention,
313
+ hidden_states: torch.Tensor,
314
+ encoder_hidden_states: Optional[torch.Tensor] = None,
315
+ attention_mask: Optional[torch.Tensor] = None,
316
+ temb: Optional[torch.Tensor] = None,
317
+ scale: float = 1.0,
318
+ ip_adapter_masks: Optional[torch.Tensor] = None,
319
+ flow_feature: Optional[torch.Tensor] = None,
320
+ flow_scale: Optional[float] = None,
321
+ *args,
322
+ **kwargs,
323
+ ) -> torch.Tensor:
324
+ return super().__call__(
325
+ attn=attn,
326
+ hidden_states=hidden_states,
327
+ encoder_hidden_states=encoder_hidden_states,
328
+ attention_mask=attention_mask,
329
+ temb=temb,
330
+ scale=scale,
331
+ ip_adapter_masks=ip_adapter_masks,
332
+ )
333
+
334
+
335
+ class FlowAdaptorAttnProcessor(nn.Module):
336
+ def __init__(self,
337
+ type: str,
338
+ hidden_size, # dimension of hidden state
339
+ flow_feature_dim=None, # dimension of the flow feature
340
+ cross_attention_dim=None, # dimension of the text embedding
341
+ query_condition=False,
342
+ key_value_condition=False,
343
+ flow_scale=1.0
344
+ ):
345
+ super().__init__()
346
+
347
+ self.type = type
348
+ self.hidden_size = hidden_size
349
+ self.flow_feature_dim = flow_feature_dim
350
+ self.cross_attention_dim = cross_attention_dim
351
+ self.flow_scale = flow_scale
352
+ self.query_condition = query_condition
353
+ self.key_value_condition = key_value_condition
354
+ assert hidden_size == flow_feature_dim
355
+ if self.query_condition and self.key_value_condition:
356
+ self.qkv_merge = nn.Linear(hidden_size, hidden_size)
357
+ init.zeros_(self.qkv_merge.weight)
358
+ init.zeros_(self.qkv_merge.bias)
359
+ elif self.query_condition:
360
+ self.q_merge = nn.Linear(hidden_size, hidden_size)
361
+ init.zeros_(self.q_merge.weight)
362
+ init.zeros_(self.q_merge.bias)
363
+ else:
364
+ self.kv_merge = nn.Linear(hidden_size, hidden_size)
365
+ init.zeros_(self.kv_merge.weight)
366
+ init.zeros_(self.kv_merge.bias)
367
+
368
+ def forward(self,
369
+ attn: Attention,
370
+ hidden_states,
371
+ flow_feature,
372
+ encoder_hidden_states=None,
373
+ attention_mask=None,
374
+ temb=None,
375
+ flow_scale=None,
376
+ *args,
377
+ **kwargs,
378
+ ):
379
+ assert flow_feature is not None
380
+ flow_embedding_scale = (flow_scale if flow_scale is not None else self.flow_scale)
381
+
382
+ residual = hidden_states
383
+ if attn.spatial_norm is not None:
384
+ hidden_states = attn.spatial_norm(hidden_states, temb)
385
+
386
+ if self.query_condition and self.key_value_condition:
387
+ assert encoder_hidden_states is None
388
+
389
+ if encoder_hidden_states is None:
390
+ encoder_hidden_states = hidden_states
391
+
392
+ batch_size, ehs_sequence_length, _ = encoder_hidden_states.shape
393
+
394
+ if attention_mask is not None:
395
+ attention_mask = attn.prepare_attention_mask(attention_mask, ehs_sequence_length, batch_size)
396
+ # scaled_dot_product_attention expects attention_mask shape to be
397
+ # (batch, heads, source_length, target_length)
398
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
399
+
400
+ if attn.group_norm is not None:
401
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
402
+
403
+ if attn.norm_cross:
404
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
405
+
406
+ if self.query_condition and self.key_value_condition: # only self attention
407
+ query_hidden_state = self.qkv_merge(hidden_states + flow_feature) * flow_embedding_scale + hidden_states
408
+ key_value_hidden_state = query_hidden_state
409
+ elif self.query_condition:
410
+ query_hidden_state = self.q_merge(hidden_states + flow_feature) * flow_embedding_scale + hidden_states
411
+ key_value_hidden_state = encoder_hidden_states
412
+ else:
413
+ key_value_hidden_state = self.kv_merge(
414
+ encoder_hidden_states + flow_feature) * flow_embedding_scale + encoder_hidden_states
415
+ query_hidden_state = hidden_states
416
+
417
+ # original attention
418
+ key = attn.to_k(key_value_hidden_state)
419
+ value = attn.to_v(key_value_hidden_state)
420
+ query = attn.to_q(query_hidden_state)
421
+
422
+ inner_dim = key.shape[-1]
423
+ head_dim = inner_dim // attn.heads
424
+
425
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
426
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
427
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
428
+
429
+ if attn.norm_q is not None:
430
+ query = attn.norm_q(query)
431
+ if attn.norm_k is not None:
432
+ key = attn.norm_k(key)
433
+
434
+ hidden_states = F.scaled_dot_product_attention(
435
+ query, key, value,
436
+ attn_mask=attention_mask,
437
+ dropout_p=0.0,
438
+ is_causal=False,
439
+ scale=attn.scale * kwargs.get("attn_scale_flow", 1.0),
440
+ )
441
+
442
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
443
+ hidden_states = hidden_states.to(query.dtype)
444
+
445
+ # linear proj
446
+ hidden_states = attn.to_out[0](hidden_states)
447
+
448
+ # dropout
449
+ hidden_states = attn.to_out[1](hidden_states)
450
+
451
+ if attn.residual_connection:
452
+ hidden_states = hidden_states + residual
453
+
454
+ hidden_states = hidden_states / attn.rescale_output_factor
455
+
456
+ return hidden_states
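To make the flow conditioning concrete, here is a minimal wiring sketch for `FlowAdaptorAttnProcessor` with the `Attention` block above. The sizes are hypothetical, and torch >= 2.1 is assumed because the processor passes `scale=` to `F.scaled_dot_product_attention`. Since the merge layer is zero-initialized, the block initially behaves like plain self-attention:

import torch

attn = Attention(query_dim=320, heads=8, dim_head=40)   # self-attention block
proc = FlowAdaptorAttnProcessor(
    type="temporal_self",            # illustrative label; only stored on the processor
    hidden_size=320, flow_feature_dim=320,
    query_condition=True, key_value_condition=True, flow_scale=1.0,
)
attn.set_processor(proc)

x = torch.randn(2, 64, 320)       # (batch, tokens, channels)
flow = torch.randn(2, 64, 320)    # flow features laid out like the hidden states
out = attn(x, flow_feature=flow)  # extra kwargs are forwarded to the processor
assert out.shape == x.shape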
onlyflow/models/flow_adaptor.py ADDED
@@ -0,0 +1,247 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from torch.utils import checkpoint
7
+
8
+ from onlyflow.models.attention import BasicTransformerBlock
9
+
10
+
11
+ def get_parameter_dtype(parameter: torch.nn.Module):
12
+ params = tuple(parameter.parameters())
13
+ if len(params) > 0:
14
+ return params[0].dtype
15
+
16
+ buffers = tuple(parameter.buffers())
17
+ if len(buffers) > 0:
18
+ return buffers[0].dtype
19
+
20
+
21
+ def conv_nd(dims, *args, **kwargs):
22
+ """
23
+ Create a 1D, 2D, or 3D convolution module.
24
+ """
25
+ if dims == 1:
26
+ return nn.Conv1d(*args, **kwargs)
27
+ elif dims == 2:
28
+ return nn.Conv2d(*args, **kwargs)
29
+ elif dims == 3:
30
+ return nn.Conv3d(*args, **kwargs)
31
+ raise ValueError(f"unsupported dimensions: {dims}")
32
+
33
+
34
+ def avg_pool_nd(dims, *args, **kwargs):
35
+ """
36
+ Create a 1D, 2D, or 3D average pooling module.
37
+ """
38
+ if dims == 1:
39
+ return nn.AvgPool1d(*args, **kwargs)
40
+ elif dims == 2:
41
+ return nn.AvgPool2d(*args, **kwargs)
42
+ elif dims == 3:
43
+ return nn.AvgPool3d(*args, **kwargs)
44
+ raise ValueError(f"unsupported dimensions: {dims}")
45
+
46
+
47
+ class FlowAdaptor(nn.Module):
48
+ def __init__(self, unet, flow_encoder, ckpt_act=True):
49
+ super().__init__()
50
+ self.unet = unet
51
+ self.flow_encoder = flow_encoder
52
+ self.ckpt_act = ckpt_act
53
+
54
+ def forward(self, noisy_latents, timesteps, encoder_hidden_states, flow_embedding):
55
+ assert flow_embedding.ndim == 5
56
+ bs = flow_embedding.shape[0] # b c f h w
57
+ flow_embedding_features = self.flow_encoder(flow_embedding) # list of (b*f, c, h, w) feature maps
58
+ flow_embedding_features = [rearrange(x, '(b f) c h w -> b c f h w', b=bs)
59
+ for x in flow_embedding_features]
60
+
61
+ added_cond_kwargs = {'flow_embedding_features': flow_embedding_features}
62
+
63
+ noise_pred = self.unet(noisy_latents,
64
+ timesteps,
65
+ encoder_hidden_states,
66
+ added_cond_kwargs=added_cond_kwargs,
67
+ )
68
+
69
+ return noise_pred.sample
70
+
71
+
72
+ class Downsample(nn.Module):
73
+ """
74
+ A downsampling layer with an optional convolution.
75
+ :param channels: channels in the inputs and outputs.
76
+ :param use_conv: a bool determining if a convolution is applied.
77
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
78
+ downsampling occurs in the inner-two dimensions.
79
+ """
80
+
81
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
82
+ super().__init__()
83
+ self.channels = channels
84
+ self.out_channels = out_channels or channels
85
+ self.use_conv = use_conv
86
+ self.dims = dims
87
+ stride = 2 if dims != 3 else (1, 2, 2)
88
+ if use_conv:
89
+ self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
90
+ else:
91
+ assert self.channels == self.out_channels
92
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
93
+
94
+ def forward(self, x):
95
+ assert x.shape[1] == self.channels
96
+ return self.op(x)
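For intuition, a small hypothetical check of the 3D case: with `dims=3` the stride is `(1, 2, 2)`, so the temporal axis is preserved while the spatial axes are halved.

import torch

down = Downsample(channels=64, use_conv=True, dims=3)
x = torch.randn(1, 64, 16, 32, 32)           # (batch, channels, frames, height, width)
assert down(x).shape == (1, 64, 16, 16, 16)  # frames kept, spatial dims halved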
97
+
98
+
99
+ class ResnetBlock(nn.Module):
100
+
101
+ def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
102
+ super().__init__()
103
+ ps = ksize // 2
104
+ if in_c != out_c or not sk:
105
+ self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
106
+ else:
107
+ self.in_conv = None
108
+ self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
109
+ self.act = nn.ReLU()
110
+ self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
111
+ if not sk:
112
+ self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
113
+ else:
114
+ self.skep = None
115
+
116
+ self.down = down
117
+ if self.down:
118
+ self.down_opt = Downsample(in_c, use_conv=use_conv)
119
+
120
+ def forward(self, x):
121
+ if self.down:
122
+ x = self.down_opt(x)
123
+ if self.in_conv is not None: # edit
124
+ x = self.in_conv(x)
125
+
126
+ h = self.block1(x)
127
+ h = self.act(h)
128
+ h = self.block2(h)
129
+ if self.skep is not None:
130
+ return h + self.skep(x)
131
+ else:
132
+ return h + x
133
+
134
+
135
+ class PositionalEncoding(nn.Module):
136
+ def __init__(
137
+ self,
138
+ d_model,
139
+ dropout=0.,
140
+ max_len=32,
141
+ ):
142
+ super().__init__()
143
+ self.dropout = nn.Dropout(p=dropout)
144
+ position = torch.arange(max_len).unsqueeze(1)
145
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
146
+ pe = torch.zeros(1, max_len, d_model)
147
+ pe[0, :, 0::2, ...] = torch.sin(position * div_term)
148
+ pe[0, :, 1::2, ...] = torch.cos(position * div_term)
149
+ pe.unsqueeze_(-1).unsqueeze_(-1)
150
+ self.register_buffer('pe', pe)
151
+
152
+ def forward(self, x):
153
+ x = x + self.pe[:, :x.size(1), ...]
154
+ return self.dropout(x)
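This encoding targets 5-D video features: after the two trailing `unsqueeze_` calls, `pe` has shape `(1, max_len, d_model, 1, 1)`, so the sinusoidal terms are added along the frame axis and broadcast over the spatial axes. A small illustrative check with hypothetical shapes:

import torch

pos_enc = PositionalEncoding(d_model=320, max_len=32)
x = torch.zeros(2, 16, 320, 8, 8)   # (batch, frames, channels, height, width)
y = pos_enc(x)                       # adds per-frame sin/cos terms; shape is unchanged
assert y.shape == x.shape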
155
+
156
+
157
+ class FlowEncoder(nn.Module):
158
+
159
+ def __init__(self,
160
+ downscale_factor,
161
+ channels=None,
162
+ nums_rb=3,
163
+ ksize=3,
164
+ sk=False,
165
+ use_conv=True,
166
+ compression_factor=1,
167
+ temporal_attention_nhead=8,
168
+ positional_embeddings=None,
169
+ num_positional_embeddings=16,
170
+ rescale_output_factor=1.0,
171
+ checkpointing=False):
172
+ super(FlowEncoder, self).__init__()
173
+ if channels is None:
174
+ channels = [320, 640, 1280, 1280]
175
+
176
+ self.checkpointing = checkpointing
177
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
178
+ self.channels = channels
179
+ self.nums_rb = nums_rb
180
+ self.encoder_down_conv_blocks = nn.ModuleList()
181
+ self.encoder_down_attention_blocks = nn.ModuleList()
182
+ for i in range(len(channels)):
183
+ conv_layers = nn.ModuleList()
184
+ temporal_attention_layers = nn.ModuleList()
185
+ for j in range(nums_rb):
186
+ if j == 0 and i != 0:
187
+ in_dim = channels[i - 1]
188
+ out_dim = int(channels[i] / compression_factor)
189
+ conv_layer = ResnetBlock(in_dim, out_dim, down=True, ksize=ksize, sk=sk, use_conv=use_conv)
190
+ elif j == 0:
191
+ in_dim = channels[0]
192
+ out_dim = int(channels[i] / compression_factor)
193
+ conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv)
194
+ elif j == nums_rb - 1:
195
+ in_dim = int(channels[i] / compression_factor) # cast so Conv2d receives an integer channel count
196
+ out_dim = channels[i]
197
+ conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv)
198
+ else:
199
+ in_dim = int(channels[i] / compression_factor)
200
+ out_dim = int(channels[i] / compression_factor)
201
+ conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv)
202
+ temporal_attention_layer = BasicTransformerBlock(
203
+ dim=out_dim,
204
+ num_attention_heads=temporal_attention_nhead,
205
+ attention_head_dim=int(out_dim / temporal_attention_nhead),
206
+ dropout=0.0,
207
+ positional_embeddings=positional_embeddings,
208
+ num_positional_embeddings=num_positional_embeddings
209
+ )
210
+ conv_layers.append(conv_layer)
211
+ temporal_attention_layers.append(temporal_attention_layer)
212
+ self.encoder_down_conv_blocks.append(conv_layers)
213
+ self.encoder_down_attention_blocks.append(temporal_attention_layers)
214
+
215
+ self.encoder_conv_in = nn.Conv2d(2 * (downscale_factor ** 2), channels[0], 3, 1, 1)
216
+
217
+ @property
218
+ def dtype(self) -> torch.dtype:
219
+ """
220
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
221
+ """
222
+ return get_parameter_dtype(self)
223
+
224
+ def forward(self, x):
225
+ # unshuffle
226
+ bs = x.shape[0]
227
+ x = rearrange(x, "b c f h w -> (b f) c h w")
228
+ x = self.unshuffle(x)
229
+ # extract features
230
+ features = []
231
+ x = self.encoder_conv_in(x)
232
+ for i, (res_block, attention_block) in enumerate(
233
+ zip(self.encoder_down_conv_blocks, self.encoder_down_attention_blocks)):
234
+ for j, (res_layer, attention_layer) in enumerate(zip(res_block, attention_block)):
235
+ if self.checkpointing:
236
+ x = checkpoint.checkpoint(res_layer, x, use_reentrant=False)
237
+ else:
238
+ x = res_layer(x)
239
+ h, w = x.shape[-2:]
240
+ x = rearrange(x, '(b f) c h w -> (b h w) f c', b=bs)
241
+ if self.checkpointing:
242
+ x = checkpoint.checkpoint(attention_layer, x, use_reentrant=False)
243
+ else:
244
+ x = attention_layer(x)
245
+ x = rearrange(x, '(b h w) f c -> (b f) c h w', h=h, w=w)
246
+ features.append(x)
247
+ return features
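End to end, the encoder pixel-unshuffles a 2-channel optical-flow clip and then alternates ResNet and temporal-attention blocks per resolution, returning one feature map per entry of `channels`. A hypothetical shape trace, assuming the default configuration and that `BasicTransformerBlock` from `onlyflow.models.attention` follows the usual diffusers calling convention:

import torch

enc = FlowEncoder(downscale_factor=8)    # channels default to [320, 640, 1280, 1280]
flow = torch.randn(1, 2, 16, 256, 256)   # (batch, 2 flow channels, frames, height, width)
feats = enc(flow)                        # list with one tensor per resolution level
# expected shapes, with frames folded into the batch axis:
#   (16, 320, 32, 32), (16, 640, 16, 16), (16, 1280, 8, 8), (16, 1280, 4, 4)
# FlowAdaptor then rearranges each to (batch, channels, frames, h, w) before feeding the UNet.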
onlyflow/models/transformer_2d.py ADDED
@@ -0,0 +1,566 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from diffusers.configuration_utils import LegacyConfigMixin, register_to_config
19
+ from diffusers.models.embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
20
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
21
+ from diffusers.models.modeling_utils import LegacyModelMixin
22
+ from diffusers.models.normalization import AdaLayerNormSingle
23
+ from diffusers.utils import deprecate, is_torch_version, logging
24
+ from torch import nn
25
+
26
+ from onlyflow.models.attention import BasicTransformerBlock
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+
31
+ class Transformer2DModelOutput(Transformer2DModelOutput):
32
+ def __init__(self, *args, **kwargs):
33
+ deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead."
34
+ deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
35
+ super().__init__(*args, **kwargs)
36
+
37
+
38
+ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
39
+ """
40
+ A 2D Transformer model for image-like data.
41
+
42
+ Parameters:
43
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
44
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
45
+ in_channels (`int`, *optional*):
46
+ The number of channels in the input and output (specify if the input is **continuous**).
47
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
48
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
49
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
50
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
51
+ This is fixed during training since it is used to learn a number of position embeddings.
52
+ num_vector_embeds (`int`, *optional*):
53
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
54
+ Includes the class for the masked latent pixel.
55
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
56
+ num_embeds_ada_norm ( `int`, *optional*):
57
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
58
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
59
+ added to the hidden states.
60
+
61
+ During inference, you can denoise for up to, but not more than, `num_embeds_ada_norm` steps.
62
+ attention_bias (`bool`, *optional*):
63
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
64
+ """
65
+
66
+ _supports_gradient_checkpointing = True
67
+ _no_split_modules = ["BasicTransformerBlock"]
68
+
69
+ @register_to_config
70
+ def __init__(
71
+ self,
72
+ num_attention_heads: int = 16,
73
+ attention_head_dim: int = 88,
74
+ in_channels: Optional[int] = None,
75
+ out_channels: Optional[int] = None,
76
+ num_layers: int = 1,
77
+ dropout: float = 0.0,
78
+ norm_num_groups: int = 32,
79
+ cross_attention_dim: Optional[int] = None,
80
+ attention_bias: bool = False,
81
+ sample_size: Optional[int] = None,
82
+ num_vector_embeds: Optional[int] = None,
83
+ patch_size: Optional[int] = None,
84
+ activation_fn: str = "geglu",
85
+ num_embeds_ada_norm: Optional[int] = None,
86
+ use_linear_projection: bool = False,
87
+ only_cross_attention: bool = False,
88
+ double_self_attention: bool = False,
89
+ upcast_attention: bool = False,
90
+ norm_type: str = "layer_norm",
91
+ # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
92
+ norm_elementwise_affine: bool = True,
93
+ norm_eps: float = 1e-5,
94
+ attention_type: str = "default",
95
+ caption_channels: int = None,
96
+ interpolation_scale: float = None,
97
+ use_additional_conditions: Optional[bool] = None,
98
+ ):
99
+ super().__init__()
100
+
101
+ # Validate inputs.
102
+ if patch_size is not None:
103
+ if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
104
+ raise NotImplementedError(
105
+ f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
106
+ )
107
+ elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None:
108
+ raise ValueError(
109
+ f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
110
+ )
111
+
112
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
113
+ # Define whether input is continuous or discrete depending on configuration
114
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
115
+ self.is_input_vectorized = num_vector_embeds is not None
116
+ self.is_input_patches = in_channels is not None and patch_size is not None
117
+
118
+ if self.is_input_continuous and self.is_input_vectorized:
119
+ raise ValueError(
120
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
121
+ " sure that either `in_channels` or `num_vector_embeds` is None."
122
+ )
123
+ elif self.is_input_vectorized and self.is_input_patches:
124
+ raise ValueError(
125
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
126
+ " sure that either `num_vector_embeds` or `num_patches` is None."
127
+ )
128
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
129
+ raise ValueError(
130
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
131
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
132
+ )
133
+
134
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
135
+ deprecation_message = (
136
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
137
+ " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
138
+ " Please make sure to update the config accordingly as leaving `norm_type` might lead to incorrect"
139
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
140
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
141
+ )
142
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
143
+ norm_type = "ada_norm"
144
+
145
+ # Set some common variables used across the board.
146
+ self.use_linear_projection = use_linear_projection
147
+ self.interpolation_scale = interpolation_scale
148
+ self.caption_channels = caption_channels
149
+ self.num_attention_heads = num_attention_heads
150
+ self.attention_head_dim = attention_head_dim
151
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
152
+ self.in_channels = in_channels
153
+ self.out_channels = in_channels if out_channels is None else out_channels
154
+ self.gradient_checkpointing = False
155
+
156
+ if use_additional_conditions is None:
157
+ if norm_type == "ada_norm_single" and sample_size == 128:
158
+ use_additional_conditions = True
159
+ else:
160
+ use_additional_conditions = False
161
+ self.use_additional_conditions = use_additional_conditions
162
+
163
+ # 2. Initialize the right blocks.
164
+ # These functions follow a common structure:
165
+ # a. Initialize the input blocks. b. Initialize the transformer blocks.
166
+ # c. Initialize the output blocks and other projection blocks when necessary.
167
+ if self.is_input_continuous:
168
+ self._init_continuous_input(norm_type=norm_type)
169
+ elif self.is_input_vectorized:
170
+ self._init_vectorized_inputs(norm_type=norm_type)
171
+ elif self.is_input_patches:
172
+ self._init_patched_inputs(norm_type=norm_type)
173
+
174
+ def _init_continuous_input(self, norm_type):
175
+ self.norm = torch.nn.GroupNorm(
176
+ num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True
177
+ )
178
+ if self.use_linear_projection:
179
+ self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim)
180
+ else:
181
+ self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0)
182
+
183
+ self.transformer_blocks = nn.ModuleList(
184
+ [
185
+ BasicTransformerBlock(
186
+ self.inner_dim,
187
+ self.config.num_attention_heads,
188
+ self.config.attention_head_dim,
189
+ dropout=self.config.dropout,
190
+ cross_attention_dim=self.config.cross_attention_dim,
191
+ activation_fn=self.config.activation_fn,
192
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
193
+ attention_bias=self.config.attention_bias,
194
+ only_cross_attention=self.config.only_cross_attention,
195
+ double_self_attention=self.config.double_self_attention,
196
+ upcast_attention=self.config.upcast_attention,
197
+ norm_type=norm_type,
198
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
199
+ norm_eps=self.config.norm_eps,
200
+ attention_type=self.config.attention_type,
201
+ )
202
+ for _ in range(self.config.num_layers)
203
+ ]
204
+ )
205
+
206
+ if self.use_linear_projection:
207
+ self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels)
208
+ else:
209
+ self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0)
210
+
211
+ def _init_vectorized_inputs(self, norm_type):
212
+ assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
213
+ assert (
214
+ self.config.num_vector_embeds is not None
215
+ ), "Transformer2DModel over discrete input must provide num_embed"
216
+
217
+ self.height = self.config.sample_size
218
+ self.width = self.config.sample_size
219
+ self.num_latent_pixels = self.height * self.width
220
+
221
+ self.latent_image_embedding = ImagePositionalEmbeddings(
222
+ num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width
223
+ )
224
+
225
+ self.transformer_blocks = nn.ModuleList(
226
+ [
227
+ BasicTransformerBlock(
228
+ self.inner_dim,
229
+ self.config.num_attention_heads,
230
+ self.config.attention_head_dim,
231
+ dropout=self.config.dropout,
232
+ cross_attention_dim=self.config.cross_attention_dim,
233
+ activation_fn=self.config.activation_fn,
234
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
235
+ attention_bias=self.config.attention_bias,
236
+ only_cross_attention=self.config.only_cross_attention,
237
+ double_self_attention=self.config.double_self_attention,
238
+ upcast_attention=self.config.upcast_attention,
239
+ norm_type=norm_type,
240
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
241
+ norm_eps=self.config.norm_eps,
242
+ attention_type=self.config.attention_type,
243
+ )
244
+ for _ in range(self.config.num_layers)
245
+ ]
246
+ )
247
+
248
+ self.norm_out = nn.LayerNorm(self.inner_dim)
249
+ self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1)
250
+
251
+ def _init_patched_inputs(self, norm_type):
252
+ assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
253
+
254
+ self.height = self.config.sample_size
255
+ self.width = self.config.sample_size
256
+
257
+ self.patch_size = self.config.patch_size
258
+ interpolation_scale = (
259
+ self.config.interpolation_scale
260
+ if self.config.interpolation_scale is not None
261
+ else max(self.config.sample_size // 64, 1)
262
+ )
263
+ self.pos_embed = PatchEmbed(
264
+ height=self.config.sample_size,
265
+ width=self.config.sample_size,
266
+ patch_size=self.config.patch_size,
267
+ in_channels=self.in_channels,
268
+ embed_dim=self.inner_dim,
269
+ interpolation_scale=interpolation_scale,
270
+ )
271
+
272
+ self.transformer_blocks = nn.ModuleList(
273
+ [
274
+ BasicTransformerBlock(
275
+ self.inner_dim,
276
+ self.config.num_attention_heads,
277
+ self.config.attention_head_dim,
278
+ dropout=self.config.dropout,
279
+ cross_attention_dim=self.config.cross_attention_dim,
280
+ activation_fn=self.config.activation_fn,
281
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
282
+ attention_bias=self.config.attention_bias,
283
+ only_cross_attention=self.config.only_cross_attention,
284
+ double_self_attention=self.config.double_self_attention,
285
+ upcast_attention=self.config.upcast_attention,
286
+ norm_type=norm_type,
287
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
288
+ norm_eps=self.config.norm_eps,
289
+ attention_type=self.config.attention_type,
290
+ )
291
+ for _ in range(self.config.num_layers)
292
+ ]
293
+ )
294
+
295
+ if self.config.norm_type != "ada_norm_single":
296
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
297
+ self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
298
+ self.proj_out_2 = nn.Linear(
299
+ self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
300
+ )
301
+ elif self.config.norm_type == "ada_norm_single":
302
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
303
+ self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim ** 0.5)
304
+ self.proj_out = nn.Linear(
305
+ self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
306
+ )
307
+
308
+ # PixArt-Alpha blocks.
309
+ self.adaln_single = None
310
+ if self.config.norm_type == "ada_norm_single":
311
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
312
+ # additional conditions until we find better name
313
+ self.adaln_single = AdaLayerNormSingle(
314
+ self.inner_dim, use_additional_conditions=self.use_additional_conditions
315
+ )
316
+
317
+ self.caption_projection = None
318
+ if self.caption_channels is not None:
319
+ self.caption_projection = PixArtAlphaTextProjection(
320
+ in_features=self.caption_channels, hidden_size=self.inner_dim
321
+ )
322
+
323
+ def _set_gradient_checkpointing(self, module, value=False):
324
+ if hasattr(module, "gradient_checkpointing"):
325
+ module.gradient_checkpointing = value
326
+
327
+ def forward(
328
+ self,
329
+ hidden_states: torch.Tensor,
330
+ encoder_hidden_states: Optional[torch.Tensor] = None,
331
+ timestep: Optional[torch.LongTensor] = None,
332
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
333
+ class_labels: Optional[torch.LongTensor] = None,
334
+ cross_attention_kwargs: Dict[str, Any] = None,
335
+ attention_mask: Optional[torch.Tensor] = None,
336
+ encoder_attention_mask: Optional[torch.Tensor] = None,
337
+ return_dict: bool = True,
338
+ ):
339
+ """
340
+ The [`Transformer2DModel`] forward method.
341
+
342
+ Args:
343
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
344
+ Input `hidden_states`.
345
+ encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
346
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
347
+ self-attention.
348
+ timestep ( `torch.LongTensor`, *optional*):
349
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
350
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
351
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
352
+ `AdaLayerZeroNorm`.
353
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
354
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
355
+ `self.processor` in
356
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
357
+ attention_mask ( `torch.Tensor`, *optional*):
358
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
359
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
360
+ negative values to the attention scores corresponding to "discard" tokens.
361
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
362
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
363
+
364
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
365
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
366
+
367
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
368
+ above. This bias will be added to the cross-attention scores.
369
+ return_dict (`bool`, *optional*, defaults to `True`):
370
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
371
+ tuple.
372
+
373
+ Returns:
374
+ If `return_dict` is True, an [`~models.transformers.transformer_2d.Transformer2DModelOutput`] is returned,
375
+ otherwise a `tuple` where the first element is the sample tensor.
376
+ """
377
+ if cross_attention_kwargs is not None:
378
+ if cross_attention_kwargs.get("scale", None) is not None:
379
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
380
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
381
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
382
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
383
+ # expects mask of shape:
384
+ # [batch, key_tokens]
385
+ # adds singleton query_tokens dimension:
386
+ # [batch, 1, key_tokens]
387
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
388
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
389
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
390
+ if attention_mask is not None and attention_mask.ndim == 2:
391
+ # assume that mask is expressed as:
392
+ # (1 = keep, 0 = discard)
393
+ # convert mask into a bias that can be added to attention scores:
394
+ # (keep = +0, discard = -10000.0)
395
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
396
+ attention_mask = attention_mask.unsqueeze(1)
397
+
398
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
399
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
400
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
401
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
402
+
403
+ # 1. Input
404
+ if self.is_input_continuous:
405
+ batch_size, _, height, width = hidden_states.shape
406
+ residual = hidden_states
407
+ hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
408
+ elif self.is_input_vectorized:
409
+ hidden_states = self.latent_image_embedding(hidden_states)
410
+ elif self.is_input_patches:
411
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
412
+ hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
413
+ hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
414
+ )
415
+
416
+ # 2. Blocks
417
+ for block in self.transformer_blocks:
418
+ if self.training and self.gradient_checkpointing:
419
+
420
+ def create_custom_forward(module, return_dict=None):
421
+ def custom_forward(*inputs):
422
+ if return_dict is not None:
423
+ return module(*inputs, return_dict=return_dict)
424
+ else:
425
+ return module(*inputs)
426
+
427
+ return custom_forward
428
+
429
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
430
+ hidden_states = torch.utils.checkpoint.checkpoint(
431
+ create_custom_forward(block),
432
+ hidden_states,
433
+ attention_mask,
434
+ encoder_hidden_states,
435
+ encoder_attention_mask,
436
+ timestep,
437
+ cross_attention_kwargs,
438
+ class_labels,
439
+ **ckpt_kwargs,
440
+ )
441
+ else:
442
+ hidden_states = block(
443
+ hidden_states=hidden_states,
444
+ attention_mask=attention_mask,
445
+ encoder_hidden_states=encoder_hidden_states,
446
+ encoder_attention_mask=encoder_attention_mask,
447
+ timestep=timestep,
448
+ cross_attention_kwargs=cross_attention_kwargs,
449
+ class_labels=class_labels,
450
+ )
451
+
452
+ # 3. Output
453
+ if self.is_input_continuous:
454
+ output = self._get_output_for_continuous_inputs(
455
+ hidden_states=hidden_states,
456
+ residual=residual,
457
+ batch_size=batch_size,
458
+ height=height,
459
+ width=width,
460
+ inner_dim=inner_dim,
461
+ )
462
+ elif self.is_input_vectorized:
463
+ output = self._get_output_for_vectorized_inputs(hidden_states)
464
+ elif self.is_input_patches:
465
+ output = self._get_output_for_patched_inputs(
466
+ hidden_states=hidden_states,
467
+ timestep=timestep,
468
+ class_labels=class_labels,
469
+ embedded_timestep=embedded_timestep,
470
+ height=height,
471
+ width=width,
472
+ )
473
+
474
+ if not return_dict:
475
+ return (output,)
476
+
477
+ return Transformer2DModelOutput(sample=output)
478
+
479
+ def _operate_on_continuous_inputs(self, hidden_states):
480
+ batch, _, height, width = hidden_states.shape
481
+ hidden_states = self.norm(hidden_states)
482
+
483
+ if not self.use_linear_projection:
484
+ hidden_states = self.proj_in(hidden_states)
485
+ inner_dim = hidden_states.shape[1]
486
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
487
+ else:
488
+ inner_dim = hidden_states.shape[1]
489
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
490
+ hidden_states = self.proj_in(hidden_states)
491
+
492
+ return hidden_states, inner_dim
493
+
494
+ def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs):
495
+ batch_size = hidden_states.shape[0]
496
+ hidden_states = self.pos_embed(hidden_states)
497
+ embedded_timestep = None
498
+
499
+ if self.adaln_single is not None:
500
+ if self.use_additional_conditions and added_cond_kwargs is None:
501
+ raise ValueError(
502
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
503
+ )
504
+ timestep, embedded_timestep = self.adaln_single(
505
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
506
+ )
507
+
508
+ if self.caption_projection is not None:
509
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
510
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
511
+
512
+ return hidden_states, encoder_hidden_states, timestep, embedded_timestep
513
+
514
+ def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
515
+ if not self.use_linear_projection:
516
+ hidden_states = (
517
+ hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
518
+ )
519
+ hidden_states = self.proj_out(hidden_states)
520
+ else:
521
+ hidden_states = self.proj_out(hidden_states)
522
+ hidden_states = (
523
+ hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
524
+ )
525
+
526
+ output = hidden_states + residual
527
+ return output
528
+
529
+ def _get_output_for_vectorized_inputs(self, hidden_states):
530
+ hidden_states = self.norm_out(hidden_states)
531
+ logits = self.out(hidden_states)
532
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
533
+ logits = logits.permute(0, 2, 1)
534
+ # log(p(x_0))
535
+ output = F.log_softmax(logits.double(), dim=1).float()
536
+ return output
537
+
538
+ def _get_output_for_patched_inputs(
539
+ self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
540
+ ):
541
+ if self.config.norm_type != "ada_norm_single":
542
+ conditioning = self.transformer_blocks[0].norm1.emb(
543
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
544
+ )
545
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
546
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
547
+ hidden_states = self.proj_out_2(hidden_states)
548
+ elif self.config.norm_type == "ada_norm_single":
549
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
550
+ hidden_states = self.norm_out(hidden_states)
551
+ # Modulation
552
+ hidden_states = hidden_states * (1 + scale) + shift
553
+ hidden_states = self.proj_out(hidden_states)
554
+ hidden_states = hidden_states.squeeze(1)
555
+
556
+ # unpatchify
557
+ if self.adaln_single is None:
558
+ height = width = int(hidden_states.shape[1] ** 0.5)
559
+ hidden_states = hidden_states.reshape(
560
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
561
+ )
562
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
563
+ output = hidden_states.reshape(
564
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
565
+ )
566
+ return output
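This file is essentially the upstream diffusers `Transformer2DModel`, re-pointed at OnlyFlow's `BasicTransformerBlock`. A hypothetical smoke test of the continuous-input path, assuming that block keeps a diffusers-compatible signature; all sizes are illustrative:

import torch

model = Transformer2DModel(
    num_attention_heads=8, attention_head_dim=40,
    in_channels=320, norm_num_groups=32, cross_attention_dim=768,
)
x = torch.randn(1, 320, 32, 32)   # latent feature map
ctx = torch.randn(1, 77, 768)     # text-encoder hidden states
out = model(x, encoder_hidden_states=ctx).sample
assert out.shape == x.shape       # the continuous path projects back to the input resolution and adds the residual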
onlyflow/models/unet.py ADDED
The diff for this file is too large to render. See raw diff
 
onlyflow/pipelines/pipeline_animation.py ADDED
@@ -0,0 +1,497 @@
1
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2
+
3
+
4
+ # TODO: rebase on diffusers/pipelines/animatediff/pipeline_animatediff.py
5
+
6
+ import copy
7
+ from dataclasses import dataclass
8
+ from typing import Callable, Optional, Dict, Any
9
+ from typing import List, Union
10
+
11
+ import PIL.Image
12
+ import numpy as np
13
+ import torch
14
+ from diffusers import AnimateDiffPipeline
15
+ from diffusers.image_processor import PipelineImageInput
16
+ from diffusers.models import AutoencoderKL
17
+ from diffusers.pipelines.animatediff import AnimateDiffPipelineOutput
18
+ from diffusers.pipelines.animatediff.pipeline_animatediff import EXAMPLE_DOC_STRING
19
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
20
+ from diffusers.schedulers import (
21
+ DDIMScheduler,
22
+ DPMSolverMultistepScheduler,
23
+ EulerAncestralDiscreteScheduler,
24
+ EulerDiscreteScheduler,
25
+ LMSDiscreteScheduler,
26
+ PNDMScheduler,
27
+ )
28
+ from diffusers.utils import BaseOutput
29
+ from diffusers.utils import deprecate, logging, replace_example_docstring
30
+ from einops import rearrange
31
+ from transformers import CLIPTextModel, CLIPTokenizer
32
+
33
+ from onlyflow.models.flow_adaptor import FlowEncoder
34
+ from onlyflow.models.unet import UNetMotionModel
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+
39
+ @dataclass
40
+ class AnimateDiffPipelineOutput(BaseOutput):
41
+ frames_no_flow: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
42
+ frames_flow: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
43
+
44
+
45
+ class FlowCtrlPipeline(AnimateDiffPipeline, DiffusionPipeline):
46
+ _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
47
+
48
+ def __init__(self,
49
+ vae: AutoencoderKL,
50
+ text_encoder: CLIPTextModel,
51
+ tokenizer: CLIPTokenizer,
52
+ unet: UNetMotionModel,
53
+ scheduler: Union[
54
+ DDIMScheduler,
55
+ PNDMScheduler,
56
+ LMSDiscreteScheduler,
57
+ EulerDiscreteScheduler,
58
+ EulerAncestralDiscreteScheduler,
59
+ DPMSolverMultistepScheduler],
60
+ flow_encoder: FlowEncoder,
61
+ feature_extractor=None,
62
+ image_encoder=None,
63
+ motion_adapter=None,
64
+ ):
65
+
66
+ super().__init__(
67
+ vae=vae,
68
+ text_encoder=text_encoder,
69
+ tokenizer=tokenizer,
70
+ unet=unet,
71
+ motion_adapter=motion_adapter,
72
+ scheduler=scheduler,
73
+ feature_extractor=feature_extractor,
74
+ image_encoder=image_encoder,
75
+ )
76
+
77
+ # deepcopy the scheduler
78
+ self.scheduler_no_flow = copy.deepcopy(scheduler)
79
+
80
+ self.unet = unet
81
+
82
+ self.register_modules(
83
+ flow_encoder=flow_encoder
84
+ )
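For orientation, a hypothetical assembly of the pipeline. The component names below are placeholders: the SD text encoder, tokenizer, VAE, scheduler, and the OnlyFlow `UNetMotionModel` / `FlowEncoder` weights are assumed to have been loaded elsewhere, so this is a sketch of the call shape rather than a runnable recipe.

# pipe = FlowCtrlPipeline(
#     vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
#     unet=unet, scheduler=scheduler, flow_encoder=flow_encoder,
#     motion_adapter=motion_adapter,
# )
# result = pipe(
#     prompt="a boat drifting on a lake",
#     flow_embedding=flow,              # (batch, 2, frames, height, width) optical flow
#     num_frames=16, num_inference_steps=25,
# )
# result.frames_flow and result.frames_no_flow hold the flow-conditioned and unconditioned videos.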
85
+
86
+ @torch.no_grad()
87
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
88
+ def __call__(
89
+ self,
90
+ prompt: Union[str, List[str]] = None,
91
+ flow_embedding: torch.FloatTensor = None,
92
+
93
+ num_frames: Optional[int] = 16,
94
+ height: Optional[int] = None,
95
+ width: Optional[int] = None,
96
+
97
+ num_inference_steps: int = 50,
98
+ guidance_scale: float = 7.5,
99
+ negative_prompt: Optional[Union[str, List[str]]] = None,
100
+ eta: float = 0.0,
101
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
102
+ latents: Optional[torch.Tensor] = None,
103
+
104
+ prompt_embeds: Optional[torch.Tensor] = None,
105
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
106
+ ip_adapter_image: Optional[PipelineImageInput] = None,
107
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
108
+
109
+ output_type: Optional[str] = "pt",
110
+ return_dict: bool = True,
111
+
112
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
113
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
114
+
115
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
116
+ motion_cross_attention_kwargs: Optional[Dict[str, Any]] = None,
117
+
118
+ clip_skip: Optional[int] = None,
119
+ decode_chunk_size: int = 16,
120
+
121
+ val_scale_factor_spatial: float = 1.,
122
+ val_scale_factor_temporal: float = 1.,
123
+
124
+ generate_no_flow: bool = False,
125
+
126
+ **kwargs,
127
+ ):
128
+ r"""
129
+ The call function to the pipeline for generation.
130
+
131
+ Args:
132
+ prompt (`str` or `List[str]`, *optional*):
133
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
134
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
135
+ The height in pixels of the generated video.
136
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
137
+ The width in pixels of the generated video.
138
+ num_frames (`int`, *optional*, defaults to 16):
139
+ The number of video frames that are generated. Defaults to 16 frames, which at 8 frames per second
140
+ amounts to 2 seconds of video.
141
+ num_inference_steps (`int`, *optional*, defaults to 50):
142
+ The number of denoising steps. More denoising steps usually lead to higher quality videos at the
143
+ expense of slower inference.
144
+ guidance_scale (`float`, *optional*, defaults to 7.5):
145
+ A higher guidance scale value encourages the model to generate images closely linked to the text
146
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
147
+ negative_prompt (`str` or `List[str]`, *optional*):
148
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
149
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
150
+ eta (`float`, *optional*, defaults to 0.0):
151
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
152
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
153
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
154
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
155
+ generation deterministic.
156
+ latents (`torch.Tensor`, *optional*):
157
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
158
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
159
+ tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
160
+ `(batch_size, num_channel, num_frames, height, width)`.
161
+ prompt_embeds (`torch.Tensor`, *optional*):
162
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
163
+ provided, text embeddings are generated from the `prompt` input argument.
164
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
165
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
166
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
167
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
168
+ Optional image input to work with IP Adapters.
169
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
170
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
171
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
172
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
173
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
174
+ output_type (`str`, *optional*, defaults to `"pt"`):
175
+ The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
176
+ return_dict (`bool`, *optional*, defaults to `True`):
177
+ Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
178
+ of a plain tuple.
179
+ cross_attention_kwargs (`dict`, *optional*):
180
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
181
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
182
+ clip_skip (`int`, *optional*):
183
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
184
+ the output of the pre-final layer will be used for computing the prompt embeddings.
185
+ callback_on_step_end (`Callable`, *optional*):
186
+ A function that calls at the end of each denoising steps during the inference. The function is called
187
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
188
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
189
+ `callback_on_step_end_tensor_inputs`.
190
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
191
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
192
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
193
+ `._callback_tensor_inputs` attribute of your pipeline class.
194
+ decode_chunk_size (`int`, defaults to `16`):
195
+ The number of frames to decode at a time when calling `decode_latents` method.
196
+
197
+ Examples:
198
+
199
+ Returns:
200
+ [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
201
+ If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
202
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
203
+ """
204
+
205
+ callback = kwargs.pop("callback", None)
206
+ callback_steps = kwargs.pop("callback_steps", None)
207
+
208
+ if callback is not None:
209
+ deprecate(
210
+ "callback",
211
+ "1.0.0",
212
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
213
+ )
214
+ if callback_steps is not None:
215
+ deprecate(
216
+ "callback_steps",
217
+ "1.0.0",
218
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
219
+ )
220
+
221
+ # 0. Default height and width to unet
222
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
223
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
224
+
225
+ num_videos_per_prompt = 1
226
+
227
+ # 1. Check inputs. Raise error if not correct
228
+ self.check_inputs(
229
+ prompt,
230
+ height,
231
+ width,
232
+ callback_steps,
233
+ negative_prompt,
234
+ prompt_embeds,
235
+ negative_prompt_embeds,
236
+ ip_adapter_image,
237
+ ip_adapter_image_embeds,
238
+ callback_on_step_end_tensor_inputs,
239
+ )
240
+
241
+ self._guidance_scale = guidance_scale
242
+ self._clip_skip = clip_skip
243
+ self._cross_attention_kwargs = cross_attention_kwargs
244
+
245
+ # 2. Define call parameters
246
+ if prompt is not None and isinstance(prompt, str):
247
+ batch_size = 1
248
+ elif prompt is not None and isinstance(prompt, list):
249
+ batch_size = len(prompt)
250
+ else:
251
+ batch_size = prompt_embeds.shape[0]
252
+
253
+ device = self._execution_device  # respects model CPU offload hooks
254
+
255
+ # 3. Encode input prompt
256
+ text_encoder_lora_scale = (
257
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
258
+ )
259
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
260
+ prompt,
261
+ device,
262
+ num_videos_per_prompt,
263
+ self.do_classifier_free_guidance,
264
+ negative_prompt,
265
+ prompt_embeds=prompt_embeds,
266
+ negative_prompt_embeds=negative_prompt_embeds,
267
+ lora_scale=text_encoder_lora_scale,
268
+ clip_skip=self.clip_skip,
269
+ )
270
+ # For classifier free guidance, we need to do two forward passes.
271
+ # Here we concatenate the unconditional and text embeddings into a single batch
272
+ # to avoid doing two forward passes
273
+ if self.do_classifier_free_guidance:
274
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
275
+
276
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
277
+ image_embeds = self.prepare_ip_adapter_image_embeds(
278
+ ip_adapter_image,
279
+ ip_adapter_image_embeds,
280
+ device,
281
+ batch_size * num_videos_per_prompt,
282
+ self.do_classifier_free_guidance,
283
+ )
284
+
285
+ # 4. Prepare timesteps
286
+ single_model_length = num_frames
287
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
288
+ timesteps = self.scheduler.timesteps
289
+
290
+ # 5. Prepare latent variables
291
+ num_channels_latents = self.unet.config.in_channels
292
+ latents = self.prepare_latents(
293
+ batch_size * num_videos_per_prompt,
294
+ num_channels_latents,
295
+ num_frames,
296
+ height,
297
+ width,
298
+ prompt_embeds.dtype,
299
+ device,
300
+ generator,
301
+ latents,
302
+ )
303
+
304
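+ # clone the initial noise so the optional flow-free baseline starts from exactly the same latents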
+ if generate_no_flow:
305
+ latents_no_flow = latents.clone()
306
+
307
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
308
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
309
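+ # encode the raw optical flow into multi-scale features; the flow encoder returns (b*f, c, h, w) maps that are reshaped to (b, c, f, h, w)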
+ if isinstance(flow_embedding, list):
310
+ assert all([x.ndim == 5 for x in flow_embedding])
311
+ bs = flow_embedding[0].shape[0]
312
+ flow_embedding_features = []
313
+ for pe in flow_embedding:
314
+ flow_embedding_feature = self.flow_encoder(pe)
315
+ flow_embedding_feature = [rearrange(x, '(b f) c h w -> b c f h w', b=bs) for x in
316
+ flow_embedding_feature]
317
+ flow_embedding_features.append(flow_embedding_feature)
318
+ else:
319
+ bs = flow_embedding.shape[0]
320
+ assert flow_embedding.ndim == 5
321
+ flow_embedding_features = self.flow_encoder(flow_embedding)  # multi-scale features shaped (b*f, c, h, w)
322
+ flow_embedding_features = [rearrange(x, '(b f) c h w -> b c f h w', b=bs)
323
+ for x in flow_embedding_features]
324
+
325
+ # 7. Add image embeds for IP-Adapter
326
+ added_cond_kwargs = {
327
+ "image_embeds": image_embeds} if ip_adapter_image is not None or ip_adapter_image_embeds is not None else None
328
+
329
+ num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
330
+ for free_init_iter in range(num_free_init_iters):
331
+ if self.free_init_enabled:
332
+ latents, timesteps = self._apply_free_init(
333
+ latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
334
+ )
335
+ if generate_no_flow:
336
+ latents_no_flow = latents.clone()
337
+
338
+ self._num_timesteps = len(timesteps)
339
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
340
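+ # duplicate the flow features along the batch dimension so they match the [uncond, cond] latents used for classifier-free guidance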
+ if isinstance(flow_embedding_features[0], list):
341
+ flow_embedding_features = [[torch.cat([x, x], dim=0) for x in flow_embedding_feature]
342
+ for flow_embedding_feature in flow_embedding_features] \
343
+ if self.do_classifier_free_guidance else flow_embedding_features
344
+ else:
345
+ flow_embedding_features = [torch.cat([x, x], dim=0) for x in flow_embedding_features] \
346
+ if self.do_classifier_free_guidance else flow_embedding_features # [2b c f h w]
347
+
348
+ # 8. Denoising loop
349
+ with self.progress_bar(total=self._num_timesteps) as progress_bar:
350
+ for i, t in enumerate(timesteps):
351
+
352
+ # expand the latents if we are doing classifier free guidance
353
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
354
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
355
+
356
+ if added_cond_kwargs is not None:
357
+ added_cond_kwargs.update({"flow_embedding_features": flow_embedding_features})
358
+ else:
359
+ added_cond_kwargs = {"flow_embedding_features": flow_embedding_features}
360
+
361
+ if cross_attention_kwargs is not None:
362
+ cross_attention_kwargs.update({"flow_scale": val_scale_factor_spatial})
363
+ else:
364
+ cross_attention_kwargs = {"flow_scale": val_scale_factor_spatial}
365
+
366
+ if motion_cross_attention_kwargs is not None:
367
+ motion_cross_attention_kwargs.update({"flow_scale": val_scale_factor_temporal})
368
+ else:
369
+ motion_cross_attention_kwargs = {"flow_scale": val_scale_factor_temporal}
370
+
371
+ # predict the noise residual
372
+ noise_pred = self.unet(
373
+ latent_model_input,
374
+ t,
375
+ encoder_hidden_states=prompt_embeds,
376
+ cross_attention_kwargs=cross_attention_kwargs,
377
+ motion_cross_attention_kwargs=motion_cross_attention_kwargs,
378
+ added_cond_kwargs=added_cond_kwargs,
379
+ ).sample
380
+
381
+ del latent_model_input
382
+
383
+ # perform guidance
384
+ if self.do_classifier_free_guidance:
385
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
386
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
387
+ del noise_pred_uncond, noise_pred_text
388
+
389
+ # compute the previous noisy sample x_t -> x_t-1
390
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
391
+ del noise_pred
392
+
393
+ if callback_on_step_end is not None:
394
+ callback_kwargs = {}
395
+ for k in callback_on_step_end_tensor_inputs:
396
+ callback_kwargs[k] = locals()[k]
397
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
398
+
399
+ latents = callback_outputs.pop("latents", latents)
400
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
401
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
402
+
403
+ # call the callback, if provided
404
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
405
+ progress_bar.update()
406
+ if callback is not None and i % callback_steps == 0:
407
+ callback(i, t, latents)
408
+
409
+ # 8b. Denoising loop for the flow-free baseline (flow_scale forced to 0)
410
+ if generate_no_flow:
411
+ with self.progress_bar(total=self._num_timesteps) as progress_bar:
412
+ for i, t in enumerate(timesteps):
413
+
414
+ # expand the latents if we are doing classifier free guidance
415
+ latent_model_input_no_flow = torch.cat(
416
+ [latents_no_flow] * 2) if self.do_classifier_free_guidance else latents_no_flow
417
+ latent_model_input_no_flow = self.scheduler.scale_model_input(latent_model_input_no_flow, t)
418
+
419
+ if added_cond_kwargs is not None:
420
+ added_cond_kwargs.update({"flow_embedding_features": flow_embedding_features})
421
+ else:
422
+ added_cond_kwargs = {"flow_embedding_features": flow_embedding_features}
423
+
424
+ if cross_attention_kwargs is not None:
425
+ cross_attention_kwargs.update({"flow_scale": 0.})
426
+ else:
427
+ cross_attention_kwargs = {"flow_scale": 0.}
428
+
429
+ if motion_cross_attention_kwargs is not None:
430
+ motion_cross_attention_kwargs.update({"flow_scale": 0.})
431
+ else:
432
+ motion_cross_attention_kwargs = {"flow_scale": 0.}
433
+
434
+ noise_pred_no_flow = self.unet(
435
+ latent_model_input_no_flow,
436
+ t,
437
+ encoder_hidden_states=prompt_embeds,
438
+ cross_attention_kwargs=cross_attention_kwargs,
439
+ motion_cross_attention_kwargs=motion_cross_attention_kwargs,
440
+ added_cond_kwargs=added_cond_kwargs,
441
+ ).sample
442
+
443
+ del latent_model_input_no_flow
444
+
445
+ # perform guidance
446
+ if self.do_classifier_free_guidance:
447
+ noise_pred_no_flow_uncond, noise_pred_no_flow_text = noise_pred_no_flow.chunk(2)
448
+ noise_pred_no_flow = noise_pred_no_flow_uncond + guidance_scale * (
449
+ noise_pred_no_flow_text - noise_pred_no_flow_uncond)
450
+ del noise_pred_no_flow_uncond, noise_pred_no_flow_text
451
+
452
+ # compute the previous noisy sample x_t -> x_t-1
453
+ latents_no_flow = self.scheduler.step(noise_pred_no_flow, t, latents_no_flow,
454
+ **extra_step_kwargs).prev_sample
455
+ del noise_pred_no_flow
456
+
457
+ if callback_on_step_end is not None:
458
+ callback_kwargs = {}
459
+ for k in callback_on_step_end_tensor_inputs:
460
+ # expose the no-flow latents under the "latents" key so the callback operates on this loop's tensors
+ callback_kwargs[k] = latents_no_flow if k == "latents" else locals()[k]
461
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
462
+
463
+ latents_no_flow = callback_outputs.pop("latents", latents_no_flow)
464
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
465
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds",
466
+ negative_prompt_embeds)
467
+
468
+ # call the callback, if provided
469
+ if i == len(timesteps) - 1 or (
470
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
471
+ progress_bar.update()
472
+ if callback is not None and i % callback_steps == 0:
473
+ callback(i, t, latents_no_flow)
474
+
475
+ # 9. Post processing
476
+ if output_type == "latent":
477
+ video = latents
478
+ if generate_no_flow:
479
+ video_no_flow = latents_no_flow
480
+ else:
481
+ video_tensor = self.decode_latents(latents, decode_chunk_size)
482
+ video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
483
+
484
+ if generate_no_flow:
485
+ video_tensor_no_flow = self.decode_latents(latents_no_flow, decode_chunk_size)
486
+ video_no_flow = self.video_processor.postprocess_video(video=video_tensor_no_flow,
487
+ output_type=output_type)
488
+
489
+ # 10. Offload all models
490
+ self.maybe_free_model_hooks()
491
+
492
+ video_no_flow = None if not generate_no_flow else video_no_flow
493
+
494
+ if not return_dict:
495
+ return (video, video_no_flow)
496
+
497
+ return AnimateDiffPipelineOutput(frames_flow=video, frames_no_flow=video_no_flow)
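
A minimal, hypothetical usage sketch of the `__call__` above. `pipe` (an already-built instance of the flow-conditioned pipeline in this file), the checkpoint it was assembled from, and the random `flow` tensor are all assumptions for illustration; in practice the flow conditioning would come from `tools/optical_flow.get_optical_flow`.

```python
import torch

# Placeholder optical-flow conditioning shaped (batch, 2, frames, height, width);
# the pipeline only checks that it is 5-D before running it through the flow encoder.
flow = torch.randn(1, 2, 16, 256, 384, device="cuda", dtype=torch.float16)

out = pipe(                                  # `pipe`: hypothetical pipeline instance
    prompt="a ship sailing on a stormy sea",
    flow_embedding=flow,
    num_frames=16,
    num_inference_steps=25,
    guidance_scale=7.5,
    val_scale_factor_spatial=1.0,            # flow strength in the spatial attention layers
    val_scale_factor_temporal=1.0,           # flow strength in the motion (temporal) modules
    generate_no_flow=True,                   # also denoise a flow-free baseline from the same noise
    generator=torch.Generator("cuda").manual_seed(42),
)
video_flow, video_no_flow = out.frames_flow, out.frames_no_flow
```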
onlyflow/pipelines/pipeline_animation_long.py ADDED
@@ -0,0 +1,555 @@
1
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2
+
3
+
4
+ # TODO: rebase on diffusers/pipelines/animatediff/pipeline_animatediff.py
5
+
6
+ import copy
7
+ import gc
8
+ from dataclasses import dataclass
9
+ from typing import Callable, Optional, Dict, Any, Tuple
10
+ from typing import List, Union
11
+
12
+ import PIL.Image
13
+ import numpy as np
14
+ import torch
15
+ from diffusers import AnimateDiffPipeline
16
+ from diffusers.image_processor import PipelineImageInput
17
+ from diffusers.models import AutoencoderKL
18
+ from diffusers.models.attention import FreeNoiseTransformerBlock
19
+ from diffusers.pipelines.animatediff.pipeline_animatediff import EXAMPLE_DOC_STRING
20
+ from diffusers.pipelines.free_noise_utils import AnimateDiffFreeNoiseMixin, SplitInferenceModule
21
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
22
+ from diffusers.schedulers import (
23
+ DDIMScheduler,
24
+ DPMSolverMultistepScheduler,
25
+ EulerAncestralDiscreteScheduler,
26
+ EulerDiscreteScheduler,
27
+ LMSDiscreteScheduler,
28
+ PNDMScheduler,
29
+ )
30
+ from diffusers.utils import BaseOutput
31
+ from diffusers.utils import deprecate, logging, replace_example_docstring
32
+ from einops import rearrange
33
+ from transformers import CLIPTextModel, CLIPTokenizer
34
+
35
+ from onlyflow.models.flow_adaptor import FlowEncoder
36
+ from onlyflow.models.unet import UNetMotionModel, AnimateDiffTransformer3D, \
37
+ CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion, CrossAttnUpBlockMotion
38
+ from ..models.attention import BasicTransformerBlock
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ @dataclass
43
+ class FlowCtrlPipelineOutput(BaseOutput):
44
+ r"""
45
+ Output class for the OnlyFlow `FlowCtrlPipeline`.
46
+
47
+ Args:
48
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
49
+ List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
50
+ denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
52
+ `(batch_size, num_frames, channels, height, width)`
53
+ """
54
+
55
+ frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
56
+
57
+
58
+ class FlowCtrlPipeline(AnimateDiffPipeline):
59
+ model_cpu_offload_seq = "text_encoder->flow_encoder->image_encoder->unet->vae"
60
+ _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
61
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
62
+
63
+ def __init__(self,
64
+ vae: AutoencoderKL,
65
+ text_encoder: CLIPTextModel,
66
+ tokenizer: CLIPTokenizer,
67
+ unet: UNetMotionModel,
68
+ scheduler: Union[
69
+ DDIMScheduler,
70
+ PNDMScheduler,
71
+ LMSDiscreteScheduler,
72
+ EulerDiscreteScheduler,
73
+ EulerAncestralDiscreteScheduler,
74
+ DPMSolverMultistepScheduler],
75
+ flow_encoder: FlowEncoder,
76
+ feature_extractor=None,
77
+ image_encoder=None,
78
+ motion_adapter=None,
79
+ ):
80
+
81
+ super().__init__(
82
+ vae=vae,
83
+ text_encoder=text_encoder,
84
+ tokenizer=tokenizer,
85
+ unet=unet,
86
+ motion_adapter=motion_adapter,
87
+ scheduler=scheduler,
88
+ feature_extractor=feature_extractor,
89
+ image_encoder=image_encoder,
90
+ )
91
+ self.register_modules(
92
+ flow_encoder=flow_encoder
93
+ )
94
+
95
+ def _enable_split_inference_motion_modules_(
96
+ self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int
97
+ ) -> None:
98
+ for motion_module in motion_modules:
99
+ motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"])
100
+
101
+ for i in range(len(motion_module.transformer_blocks)):
102
+ motion_module.transformer_blocks[i] = SplitInferenceModule(
103
+ motion_module.transformer_blocks[i],
104
+ spatial_split_size,
105
+ 0,
106
+ ["hidden_states", "encoder_hidden_states", "cross_attention_kwargs"],
107
+ )
108
+
109
+ motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"])
110
+
111
+
112
+ def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion, CrossAttnUpBlockMotion]):
113
+ r"""Helper function to enable FreeNoise in transformer blocks."""
114
+
115
+ for motion_module in block.motion_modules:
116
+ num_transformer_blocks = len(motion_module.transformer_blocks)
117
+
118
+ for i in range(num_transformer_blocks):
119
+ if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock):
120
+ motion_module.transformer_blocks[i].set_free_noise_properties(
121
+ self._free_noise_context_length,
122
+ self._free_noise_context_stride,
123
+ self._free_noise_weighting_scheme,
124
+ )
125
+ else:
126
+ basic_transfomer_block = motion_module.transformer_blocks[i]
127
+
128
+ motion_module.transformer_blocks[i] = FreeNoiseTransformerBlock(
129
+ dim=basic_transfomer_block.dim,
130
+ num_attention_heads=basic_transfomer_block.num_attention_heads,
131
+ attention_head_dim=basic_transfomer_block.attention_head_dim,
132
+ dropout=basic_transfomer_block.dropout,
133
+ cross_attention_dim=basic_transfomer_block.cross_attention_dim,
134
+ activation_fn=basic_transfomer_block.activation_fn,
135
+ attention_bias=basic_transfomer_block.attention_bias,
136
+ only_cross_attention=basic_transfomer_block.only_cross_attention,
137
+ double_self_attention=basic_transfomer_block.double_self_attention,
138
+ positional_embeddings=basic_transfomer_block.positional_embeddings,
139
+ num_positional_embeddings=basic_transfomer_block.num_positional_embeddings,
140
+ context_length=self._free_noise_context_length,
141
+ context_stride=self._free_noise_context_stride,
142
+ weighting_scheme=self._free_noise_weighting_scheme,
143
+ ).to(device=self._execution_device, dtype=self.dtype)
144
+
145
+ # reuse the existing attention modules (and their processors) instead of the freshly created ones; the weights are copied via load_state_dict below
146
+ motion_module.transformer_blocks[i].attn1 = basic_transfomer_block.attn1
147
+ motion_module.transformer_blocks[i].attn2 = basic_transfomer_block.attn2
148
+
149
+ motion_module.transformer_blocks[i].load_state_dict(
150
+ basic_transfomer_block.state_dict(), strict=True
151
+ )
152
+ motion_module.transformer_blocks[i].set_chunk_feed_forward(
153
+ basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim
154
+ )
155
+
156
+ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion, CrossAttnUpBlockMotion]):
157
+ r"""Helper function to disable FreeNoise in transformer blocks."""
158
+
159
+ for motion_module in block.motion_modules:
160
+ num_transformer_blocks = len(motion_module.transformer_blocks)
161
+
162
+ for i in range(num_transformer_blocks):
163
+ if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock):
164
+ free_noise_transfomer_block = motion_module.transformer_blocks[i]
165
+
166
+ motion_module.transformer_blocks[i] = BasicTransformerBlock(
167
+ dim=free_noise_transfomer_block.dim,
168
+ num_attention_heads=free_noise_transfomer_block.num_attention_heads,
169
+ attention_head_dim=free_noise_transfomer_block.attention_head_dim,
170
+ dropout=free_noise_transfomer_block.dropout,
171
+ cross_attention_dim=free_noise_transfomer_block.cross_attention_dim,
172
+ activation_fn=free_noise_transfomer_block.activation_fn,
173
+ attention_bias=free_noise_transfomer_block.attention_bias,
174
+ only_cross_attention=free_noise_transfomer_block.only_cross_attention,
175
+ double_self_attention=free_noise_transfomer_block.double_self_attention,
176
+ positional_embeddings=free_noise_transfomer_block.positional_embeddings,
177
+ num_positional_embeddings=free_noise_transfomer_block.num_positional_embeddings,
178
+ ).to(device=self._execution_device, dtype=self.dtype)
179
+
180
+ motion_module.transformer_blocks[i].load_state_dict(
181
+ free_noise_transfomer_block.state_dict(), strict=True
182
+ )
183
+ motion_module.transformer_blocks[i].set_chunk_feed_forward(
184
+ free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim
185
+ )
186
+
187
+
188
+ @torch.no_grad()
189
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
190
+ def __call__(
191
+ self,
192
+ prompt: Union[str, List[str]] = None,
193
+ optical_flow: torch.FloatTensor = None,
194
+
195
+ num_frames: Optional[int] = 16,
196
+ height: Optional[int] = None,
197
+ width: Optional[int] = None,
198
+
199
+ num_inference_steps: int = 50,
200
+ guidance_scale: float = 7.5,
201
+ negative_prompt: Optional[Union[str, List[str]]] = None,
202
+ eta: float = 0.0,
203
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
204
+ latents: Optional[torch.Tensor] = None,
205
+
206
+ prompt_embeds: Optional[torch.Tensor] = None,
207
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
208
+ ip_adapter_image: Optional[PipelineImageInput] = None,
209
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
210
+
211
+ output_type: Optional[str] = "pt",
212
+ return_dict: bool = True,
213
+
214
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
215
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
216
+
217
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
218
+ motion_cross_attention_kwargs: Optional[Dict[str, Any]] = None,
219
+
220
+ clip_skip: Optional[int] = None,
221
+ decode_chunk_size: int = 16,
222
+
223
+ val_scale_factor_spatial: float = 0.,
224
+ val_scale_factor_temporal: float = 0.,
225
+
226
+ **kwargs,
227
+ ):
228
+ r"""
229
+ The call function to the pipeline for generation.
230
+
231
+ Args:
232
+ prompt (`str` or `List[str]`, *optional*):
233
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
234
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
235
+ The height in pixels of the generated video.
236
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
237
+ The width in pixels of the generated video.
238
+ num_frames (`int`, *optional*, defaults to 16):
239
+ The number of video frames that are generated. Defaults to 16 frames, which at 8 frames per second
240
+ amounts to 2 seconds of video.
241
+ num_inference_steps (`int`, *optional*, defaults to 50):
242
+ The number of denoising steps. More denoising steps usually lead to higher quality videos at the
243
+ expense of slower inference.
244
+ guidance_scale (`float`, *optional*, defaults to 7.5):
245
+ A higher guidance scale value encourages the model to generate images closely linked to the text
246
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
247
+ negative_prompt (`str` or `List[str]`, *optional*):
248
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
249
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
250
+ eta (`float`, *optional*, defaults to 0.0):
251
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
252
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
253
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
254
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
255
+ generation deterministic.
256
+ latents (`torch.Tensor`, *optional*):
257
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
258
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
259
+ tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
260
+ `(batch_size, num_channel, num_frames, height, width)`.
261
+ prompt_embeds (`torch.Tensor`, *optional*):
262
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
263
+ provided, text embeddings are generated from the `prompt` input argument.
264
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
265
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
266
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
267
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
268
+ Optional image input to work with IP Adapters.
269
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
270
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
271
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
272
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
273
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
274
+ output_type (`str`, *optional*, defaults to `"pt"`):
275
+ The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
276
+ return_dict (`bool`, *optional*, defaults to `True`):
277
+ Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
278
+ of a plain tuple.
279
+ cross_attention_kwargs (`dict`, *optional*):
280
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
281
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
282
+ clip_skip (`int`, *optional*):
283
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
284
+ the output of the pre-final layer will be used for computing the prompt embeddings.
285
+ callback_on_step_end (`Callable`, *optional*):
286
+ A function that calls at the end of each denoising steps during the inference. The function is called
287
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
288
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
289
+ `callback_on_step_end_tensor_inputs`.
290
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
291
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
292
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
293
+ `._callback_tensor_inputs` attribute of your pipeline class.
294
+ decode_chunk_size (`int`, defaults to `16`):
295
+ The number of frames to decode at a time when calling `decode_latents` method.
296
+
297
+ Examples:
298
+
299
+ Returns:
300
+ [`FlowCtrlPipelineOutput`] or `tuple`:
301
+ If `return_dict` is `True`, a [`FlowCtrlPipelineOutput`] is
302
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
303
+ """
304
+
305
+ callback = kwargs.pop("callback", None)
306
+ callback_steps = kwargs.pop("callback_steps", None)
307
+
308
+ if callback is not None:
309
+ deprecate(
310
+ "callback",
311
+ "1.0.0",
312
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
313
+ )
314
+ if callback_steps is not None:
315
+ deprecate(
316
+ "callback_steps",
317
+ "1.0.0",
318
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
319
+ )
320
+
321
+ # 0. Default height and width to unet
322
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
323
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
324
+
325
+ num_videos_per_prompt = 1
326
+
327
+ # 1. Check inputs. Raise error if not correct
328
+ self.check_inputs(
329
+ prompt,
330
+ height,
331
+ width,
332
+ callback_steps,
333
+ negative_prompt,
334
+ prompt_embeds,
335
+ negative_prompt_embeds,
336
+ ip_adapter_image,
337
+ ip_adapter_image_embeds,
338
+ callback_on_step_end_tensor_inputs,
339
+ )
340
+
341
+ self._guidance_scale = guidance_scale
342
+ self._clip_skip = clip_skip
343
+ self._cross_attention_kwargs = cross_attention_kwargs
344
+ self._interrupt = False
345
+
346
+ # 2. Define call parameters
347
+ if prompt is not None and isinstance(prompt, (str, dict)):
348
+ batch_size = 1
349
+ elif prompt is not None and isinstance(prompt, list):
350
+ batch_size = len(prompt)
351
+ else:
352
+ batch_size = prompt_embeds.shape[0]
353
+
354
+ device = self._execution_device
355
+
356
+ # 3. Encode input prompt
357
+ text_encoder_lora_scale = (
358
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
359
+ )
360
+ if self.free_noise_enabled:
361
+ prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
362
+ prompt=prompt,
363
+ num_frames=num_frames,
364
+ device=device,
365
+ num_videos_per_prompt=num_videos_per_prompt,
366
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
367
+ negative_prompt=negative_prompt,
368
+ prompt_embeds=prompt_embeds,
369
+ negative_prompt_embeds=negative_prompt_embeds,
370
+ lora_scale=text_encoder_lora_scale,
371
+ clip_skip=self.clip_skip,
372
+ )
373
+ else:
374
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
375
+ prompt,
376
+ device,
377
+ num_videos_per_prompt,
378
+ self.do_classifier_free_guidance,
379
+ negative_prompt,
380
+ prompt_embeds=prompt_embeds,
381
+ negative_prompt_embeds=negative_prompt_embeds,
382
+ lora_scale=text_encoder_lora_scale,
383
+ clip_skip=self.clip_skip,
384
+ )
385
+
386
+ # For classifier free guidance, we need to do two forward passes.
387
+ # Here we concatenate the unconditional and text embeddings into a single batch
388
+ # to avoid doing two forward passes
389
+ if self.do_classifier_free_guidance:
390
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
391
+
392
+ prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
393
+
394
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
395
+ image_embeds = self.prepare_ip_adapter_image_embeds(
396
+ ip_adapter_image,
397
+ ip_adapter_image_embeds,
398
+ device,
399
+ batch_size * num_videos_per_prompt,
400
+ self.do_classifier_free_guidance,
401
+ )
402
+
403
+ # 4. Prepare timesteps
404
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
405
+ timesteps = self.scheduler.timesteps
406
+
407
+ # 5. Prepare latent variables
408
+ num_channels_latents = self.unet.config.in_channels
409
+ latents = self.prepare_latents(
410
+ batch_size * num_videos_per_prompt,
411
+ num_channels_latents,
412
+ num_frames,
413
+ height,
414
+ width,
415
+ prompt_embeds.dtype,
416
+ device,
417
+ generator,
418
+ latents,
419
+ )
420
+
421
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
422
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
423
+
424
+ if torch.cuda.is_available():
425
+ torch.cuda.empty_cache()
426
+ torch.cuda.reset_peak_memory_stats()
427
+ torch.cuda.synchronize()
428
+ assert optical_flow.ndim == 5
429
+ bs = optical_flow.shape[0]
430
+ if self.free_noise_enabled:
431
+ length = optical_flow.shape[2]
432
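+ # probe the flow encoder on the first 16 frames (presumably one FreeNoise context window) just to get the per-level feature shapes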
+ flow_embedding_features = [
433
+ torch.zeros((bs, length, *test_size.shape[1:]), device=self._execution_device)
434
+ for test_size in self.flow_encoder(optical_flow[:,:,:16].to(self._execution_device))
435
+ ]
436
+ weight_factor = torch.zeros(length, device=self._execution_device)
437
+ for start_idx in range(0, length, self._free_noise_context_stride):
438
+ weight_factor[start_idx:start_idx + self._free_noise_context_length] += 1.0
439
+ window_features = self.flow_encoder(optical_flow[:, :, start_idx:start_idx + self._free_noise_context_length].to(self._execution_device))
440
+ for flow_emb, window_feature in zip(flow_embedding_features, window_features):
441
+ flow_emb[:, start_idx:start_idx + self._free_noise_context_length] += rearrange(window_feature, '(b f) c h w -> b f c h w', b=bs).to(self._execution_device)
442
+
443
+ flow_embedding_features = [flow_emb / weight_factor[None,:,None,None,None] for flow_emb in flow_embedding_features]
444
+ flow_embedding_features = [rearrange(x, 'b f c h w -> b c f h w') for x in flow_embedding_features]
445
+ else:
446
+ flow_embedding_features = self.flow_encoder(optical_flow.to(self._execution_device))  # takes (b, c, f, h, w) flow, returns multi-scale (b*f, c, h, w) features
447
+ flow_embedding_features = [rearrange(x, '(b f) c h w -> b c f h w', b=bs).to(self._execution_device)
448
+ for x in flow_embedding_features]
449
+
450
+ del optical_flow
451
+ gc.collect()
452
+ if torch.cuda.is_available():
453
+ torch.cuda.empty_cache()
454
+ torch.cuda.reset_peak_memory_stats()
455
+ torch.cuda.synchronize()
456
+
457
+ # 7. Add image embeds for IP-Adapter
458
+ added_cond_kwargs = (
459
+ {"image_embeds": image_embeds}
460
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
461
+ else None
462
+ )
463
+
464
+ num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
465
+ for free_init_iter in range(num_free_init_iters):
466
+ if self.free_init_enabled:
467
+ latents, timesteps = self._apply_free_init(
468
+ latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
469
+ )
470
+
471
+ self._num_timesteps = len(timesteps)
472
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
473
+
474
+ if isinstance(flow_embedding_features[0], list):
475
+ flow_embedding_features = [[torch.cat([x, x], dim=0) for x in flow_embedding_feature]
476
+ for flow_embedding_feature in flow_embedding_features] \
477
+ if self.do_classifier_free_guidance else flow_embedding_features
478
+ else:
479
+ flow_embedding_features = [torch.cat([x, x], dim=0) for x in flow_embedding_features] \
480
+ if self.do_classifier_free_guidance else flow_embedding_features # [2b c f h w]
481
+
482
+ # 8. Denoising loop
483
+ with self.progress_bar(total=self._num_timesteps) as progress_bar:
484
+ for i, t in enumerate(timesteps):
485
+ if self.interrupt:
486
+ continue
487
+
488
+ # expand the latents if we are doing classifier free guidance
489
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
490
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
491
+
492
+ if added_cond_kwargs is not None:
493
+ added_cond_kwargs.update({"flow_embedding_features": flow_embedding_features})
494
+ else:
495
+ added_cond_kwargs = {"flow_embedding_features": flow_embedding_features}
496
+
497
+ if cross_attention_kwargs is not None:
498
+ cross_attention_kwargs.update({"flow_scale": val_scale_factor_spatial})
499
+ else:
500
+ cross_attention_kwargs = {"flow_scale": val_scale_factor_spatial}
501
+
502
+ if motion_cross_attention_kwargs is not None:
503
+ motion_cross_attention_kwargs.update({"flow_scale": val_scale_factor_temporal})
504
+ else:
505
+ motion_cross_attention_kwargs = {"flow_scale": val_scale_factor_temporal}
506
+
507
+ # predict the noise residual
508
+
509
+ noise_pred = self.unet(
510
+ latent_model_input,
511
+ t,
512
+ encoder_hidden_states=prompt_embeds,
513
+ cross_attention_kwargs=cross_attention_kwargs,
514
+ motion_cross_attention_kwargs=motion_cross_attention_kwargs,
515
+ added_cond_kwargs=added_cond_kwargs,
516
+ ).sample
517
+
518
+ # perform guidance
519
+ if self.do_classifier_free_guidance:
520
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
521
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
522
+
523
+ # compute the previous noisy sample x_t -> x_t-1
524
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
525
+
526
+ if callback_on_step_end is not None:
527
+ callback_kwargs = {}
528
+ for k in callback_on_step_end_tensor_inputs:
529
+ callback_kwargs[k] = locals()[k]
530
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
531
+
532
+ latents = callback_outputs.pop("latents", latents)
533
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
534
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
535
+
536
+ # call the callback, if provided
537
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
538
+ progress_bar.update()
539
+ if callback is not None and i % callback_steps == 0:
540
+ callback(i, t, latents)
541
+
542
+ # 9. Post processing
543
+ if output_type == "latent":
544
+ video = latents
545
+ else:
546
+ video_tensor = self.decode_latents(latents, decode_chunk_size)
547
+ video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
548
+
549
+ # 10. Offload all models
550
+ self.maybe_free_model_hooks()
551
+
552
+ if not return_dict:
553
+ return (video,)
554
+
555
+ return FlowCtrlPipelineOutput(frames=video)
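
A minimal, hypothetical sketch of driving the FreeNoise path of `FlowCtrlPipeline.__call__` for a clip longer than 16 frames. `pipe` (an assembled `FlowCtrlPipeline`), the prompt, and the random flow tensor are assumptions for illustration; `enable_free_noise` is the standard diffusers AnimateDiff mixin method, assuming a diffusers version that provides it.

```python
import torch

# Window the motion modules with FreeNoise so a 64-frame clip fits in memory;
# the pipeline reuses the same context length/stride to window the flow encoder.
pipe.enable_free_noise(context_length=16, context_stride=4)

num_frames = 64
# Raw optical flow for the whole clip, shaped (batch, 2, num_frames, height, width).
optical_flow = torch.randn(1, 2, num_frames, 256, 384, device=pipe.device, dtype=pipe.dtype)

result = pipe(
    prompt="a drone shot flying over a foggy forest",
    optical_flow=optical_flow,
    num_frames=num_frames,
    num_inference_steps=25,
    val_scale_factor_spatial=1.0,
    val_scale_factor_temporal=1.0,
    output_type="pil",
)
frames = result.frames  # FlowCtrlPipelineOutput.frames
```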
onlyflow/utils/util.py ADDED
@@ -0,0 +1,140 @@
1
+ import atexit
2
+ import functools
3
+ import importlib
4
+ import io
5
+ import logging
6
+ import os
7
+ import sys
8
+
9
+ import imageio
10
+ import numpy as np
11
+ import torch
12
+ from termcolor import colored
13
+
14
+
15
+ def instantiate_from_config(config, **additional_kwargs):
16
+ if not "target" in config:
17
+ if config == '__is_first_stage__':
18
+ return None
19
+ elif config == "__is_unconditional__":
20
+ return None
21
+ raise KeyError("Expected key `target` to instantiate.")
22
+
23
+ additional_kwargs.update(config.get("kwargs", dict()))
24
+ return get_obj_from_str(config["target"])(**additional_kwargs)
25
+
26
+
27
+ def get_obj_from_str(string, reload=False):
28
+ module, cls = string.rsplit(".", 1)
29
+ if reload:
30
+ module_imp = importlib.import_module(module)
31
+ importlib.reload(module_imp)
32
+ return getattr(importlib.import_module(module, package=None), cls)
33
+
34
+
35
+ def get_video(videos: torch.Tensor, path: str, rescale=False, fps=8):
36
+ if rescale:
37
+ videos = (videos + 1.0) / 2.0 # -1,1 -> 0,1
38
+ videos = (videos * 255).numpy().astype(np.uint8)
39
+ videos = np.transpose(videos, axes=(1, 2, 3, 0))
40
+
41
+ binary_object = io.BytesIO()
42
+
43
+ imageio.mimsave(binary_object, list(videos), fps=fps, format='gif')
44
+
45
+ return binary_object
46
+
47
+
48
+ # Logger utils are copied from detectron2
49
+ class _ColorfulFormatter(logging.Formatter):
50
+ def __init__(self, *args, **kwargs):
51
+ self._root_name = kwargs.pop("root_name") + "."
52
+ self._abbrev_name = kwargs.pop("abbrev_name", "")
53
+ if len(self._abbrev_name):
54
+ self._abbrev_name = self._abbrev_name + "."
55
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
56
+
57
+ def formatMessage(self, record):
58
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
59
+ log = super(_ColorfulFormatter, self).formatMessage(record)
60
+ if record.levelno == logging.WARNING:
61
+ prefix = colored("WARNING", "red", attrs=["blink"])
62
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
63
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
64
+ else:
65
+ return log
66
+ return prefix + " " + log
67
+
68
+
69
+ # cache the opened file object, so that different calls to `setup_logger`
70
+ # with the same file name can safely write to the same file.
71
+ @functools.lru_cache(maxsize=None)
72
+ def _cached_log_stream(filename):
73
+ # use 1K buffer if writing to cloud storage
74
+ stream = open(filename, "a", buffering=1024 if "://" in filename else -1)  # avoid shadowing the imported `io` module
75
+ atexit.register(stream.close)
76
+ return stream
77
+
78
+
79
+ @functools.lru_cache()
80
+ def setup_logger(output, distributed_rank, color=True, name='AnimateDiff', abbrev_name=None):
81
+ logger = logging.getLogger(name)
82
+ logger.setLevel(logging.DEBUG)
83
+ logger.propagate = False
84
+
85
+ if abbrev_name is None:
86
+ abbrev_name = 'AD'
87
+ plain_formatter = logging.Formatter(
88
+ "[%(asctime)s] %(name)s:%(lineno)d %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
89
+ )
90
+
91
+ # stdout logging: master only
92
+ if distributed_rank == 0:
93
+ ch = logging.StreamHandler(stream=sys.stdout)
94
+ ch.setLevel(logging.DEBUG)
95
+ if color:
96
+ formatter = _ColorfulFormatter(
97
+ colored("[%(asctime)s %(name)s:%(lineno)d]: ", "green") + "%(message)s",
98
+ datefmt="%m/%d %H:%M:%S",
99
+ root_name=name,
100
+ abbrev_name=str(abbrev_name),
101
+ )
102
+ else:
103
+ formatter = plain_formatter
104
+ ch.setFormatter(formatter)
105
+ logger.addHandler(ch)
106
+
107
+ # file logging: all workers
108
+ if output is not None:
109
+ filename = os.path.join(output, "ranks_logs", f"log.{distributed_rank}.txt")
110
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
111
+ fh = logging.StreamHandler(_cached_log_stream(filename))
112
+ fh.setLevel(logging.DEBUG)
113
+ fh.setFormatter(plain_formatter)
114
+ logger.addHandler(fh)
115
+
116
+ return logger
117
+
118
+
119
+ def format_time(elapsed_time):
120
+ # Time thresholds
121
+ minute = 60
122
+ hour = 60 * minute
123
+ day = 24 * hour
124
+
125
+ days, remainder = divmod(elapsed_time, day)
126
+ hours, remainder = divmod(remainder, hour)
127
+ minutes, seconds = divmod(remainder, minute)
128
+
129
+ formatted_time = ""
130
+
131
+ if days > 0:
132
+ formatted_time += f"{int(days)} days "
133
+ if hours > 0:
134
+ formatted_time += f"{int(hours)} hours "
135
+ if minutes > 0:
136
+ formatted_time += f"{int(minutes)} minutes "
137
+ if seconds > 0:
138
+ formatted_time += f"{seconds:.2f} seconds"
139
+
140
+ return formatted_time.strip()
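
A small, self-contained illustration of the config and time helpers above; the `torch.nn.Linear` target is an arbitrary example class, not something the repository necessarily instantiates this way.

```python
from onlyflow.utils.util import format_time, instantiate_from_config

# `kwargs` from the config dict are merged with any extra keyword arguments.
cfg = {"target": "torch.nn.Linear", "kwargs": {"in_features": 8, "out_features": 4}}
layer = instantiate_from_config(cfg, bias=False)
print(type(layer).__name__)  # Linear

print(format_time(3725.5))   # "1 hours 2 minutes 5.50 seconds"
```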
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ torch
2
+ torchvision
3
+ diffusers
4
+ transformers
5
+ accelerate
6
+ git+https://github.com/obvious-research/diffusers.git
7
+ numpy
8
+ einops
9
+ imageio
10
+ omegaconf
11
+ av==12.0.0
tools/optical_flow.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+ from einops import rearrange
3
+
4
+ @torch.no_grad()
5
+ def get_optical_flow(raft_model, pixel_values, video_length, encode_chunk_size=48, num_flow_updates=14):
6
+ imgs_1 = pixel_values[:, :-1]
7
+ imgs_2 = pixel_values[:, 1:]
8
+ imgs_1 = rearrange(imgs_1, "b f c h w -> (b f) c h w")
9
+ imgs_2 = rearrange(imgs_2, "b f c h w -> (b f) c h w")
10
+
11
+ flow_embedding = []
12
+
13
+ for i in range(0, imgs_1.shape[0], encode_chunk_size):
14
+ imgs_1_chunk = imgs_1[i:i + encode_chunk_size]
15
+ imgs_2_chunk = imgs_2[i:i + encode_chunk_size]
16
+ flow_embedding_chunk = raft_model(imgs_1_chunk, imgs_2_chunk, num_flow_updates)[-1]
17
+ flow_embedding.append(flow_embedding_chunk)
18
+
19
+ flow_embedding = torch.cat(flow_embedding).contiguous()
20
+ # note: `video_length` must equal the number of flow maps per clip, i.e. the number of input frames minus one
+ flow_embedding = rearrange(flow_embedding, "(b f) c h w -> b c f h w", f=video_length)
21
+
22
+ return flow_embedding
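
A hypothetical end-to-end example of the helper above with torchvision's pretrained RAFT; the clip here is random data purely for shape illustration.

```python
import torch
from torchvision.models.optical_flow import Raft_Large_Weights, raft_large

from tools.optical_flow import get_optical_flow

device = "cuda" if torch.cuda.is_available() else "cpu"
raft = raft_large(weights=Raft_Large_Weights.DEFAULT).eval().to(device)

# One clip of 17 frames with values in [-1, 1]; RAFT expects spatial sizes divisible by 8.
frames = torch.rand(1, 17, 3, 256, 384, device=device) * 2 - 1

# 17 frames give 16 consecutive flow maps, so video_length is the frame count minus one.
flow = get_optical_flow(raft, frames, video_length=16)
print(flow.shape)  # torch.Size([1, 2, 16, 256, 384])
```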