multimodalart HF staff commited on
Commit
f0533a5
1 Parent(s): a3f1b31

Upload 33 files

Browse files
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Yang Jin
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,165 @@
1
- ---
2
- title: Pyramid Video
3
- emoji: 💻
4
- colorFrom: purple
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 5.0.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ # ⚡️Pyramid Flow⚡️
4
+
5
+ [[Paper]](https://arxiv.org/abs/2410.05954) [[Project Page ✨]](https://pyramid-flow.github.io) [[Model 🤗]](https://huggingface.co/rain1011/pyramid-flow-sd3)
6
+
7
+ </div>
8
+
9
+ This is the official repository for Pyramid Flow, a training-efficient **Autoregressive Video Generation** method based on **Flow Matching**. By training only on **open-source datasets**, it can generate high-quality 10-second videos at 768p resolution and 24 FPS, and naturally supports image-to-video generation.
10
+
11
+ <table class="center" border="0" style="width: 100%; text-align: left;">
12
+ <tr>
13
+ <th>10s, 768p, 24fps</th>
14
+ <th>5s, 768p, 24fps</th>
15
+ <th>Image-to-video</th>
16
+ </tr>
17
+ <tr>
18
+ <td><video src="https://github.com/user-attachments/assets/9935da83-ae56-4672-8747-0f46e90f7b2b" autoplay muted loop playsinline></video></td>
19
+ <td><video src="https://github.com/user-attachments/assets/3412848b-64db-4d9e-8dbf-11403f6d02c5" autoplay muted loop playsinline></video></td>
20
+ <td><video src="https://github.com/user-attachments/assets/3bd7251f-7b2c-4bee-951d-656fdb45f427" autoplay muted loop playsinline></video></td>
21
+ </tr>
22
+ </table>
23
+
24
+ ## News
25
+
26
+ * `COMING SOON` ⚡️⚡️⚡️ Training code for both the Video VAE and DiT; New model checkpoints trained from scratch.
27
+
28
+ > We are training Pyramid Flow from scratch to fix human structure issues related to the currently adopted SD3 initialization and hope to release it in the next few days.
29
+ * `2024.10.10` 🚀🚀🚀 We release the [technical report](https://arxiv.org/abs/2410.05954), [project page](https://pyramid-flow.github.io) and [model checkpoint](https://huggingface.co/rain1011/pyramid-flow-sd3) of Pyramid Flow.
30
+
31
+ ## Introduction
32
+
33
+ ![motivation](assets/motivation.jpg)
34
+
35
+ Existing video diffusion models operate at full resolution, spending a lot of computation on very noisy latents. By contrast, our method harnesses the flexibility of flow matching ([Lipman et al., 2023](https://openreview.net/forum?id=PqvMRDCJT9t); [Liu et al., 2023](https://openreview.net/forum?id=XVjTT1nw5z); [Albergo & Vanden-Eijnden, 2023](https://openreview.net/forum?id=li7qeBbCR1t)) to interpolate between latents of different resolutions and noise levels, allowing for simultaneous generation and decompression of visual content with better computational efficiency. The entire framework is end-to-end optimized with a single DiT ([Peebles & Xie, 2023](http://openaccess.thecvf.com/content/ICCV2023/html/Peebles_Scalable_Diffusion_Models_with_Transformers_ICCV_2023_paper.html)), generating high-quality 10-second videos at 768p resolution and 24 FPS within 20.7k A100 GPU training hours.
36
+
37
+ ## Usage
38
+
39
+ You can directly download the model from [Huggingface](https://huggingface.co/rain1011/pyramid-flow-sd3). We provide both model checkpoints for 768p and 384p video generation. The 384p checkpoint supports 5-second video generation at 24FPS, while the 768p checkpoint supports up to 10-second video generation at 24FPS.
40
+
41
+ ```python
42
+ from huggingface_hub import snapshot_download
43
+
44
+ model_path = 'PATH' # The local directory to save downloaded checkpoint
45
+ snapshot_download("rain1011/pyramid-flow-sd3", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
46
+ ```
47
+
48
+
49
+ To use our model, please follow the inference code in `video_generation_demo.ipynb` at [this link](https://github.com/jy0205/Pyramid-Flow/blob/main/video_generation_demo.ipynb). We further simplify it into the following two-step procedure. First, load the downloaded model:
50
+
51
+ ```python
52
+ import torch
53
+ from PIL import Image
54
+ from pyramid_dit import PyramidDiTForVideoGeneration
55
+ from diffusers.utils import load_image, export_to_video
56
+
57
+ torch.cuda.set_device(0)
58
+ model_dtype, torch_dtype = 'bf16', torch.bfloat16 # Use bf16, fp16 or fp32
59
+
60
+ model = PyramidDiTForVideoGeneration(
61
+ 'PATH', # The downloaded checkpoint dir
62
+ model_dtype,
63
+ model_variant='diffusion_transformer_768p', # 'diffusion_transformer_384p'
64
+ )
65
+
66
+ model.vae.to("cuda")
67
+ model.dit.to("cuda")
68
+ model.text_encoder.to("cuda")
69
+ model.vae.enable_tiling()
70
+ ```
71
+
72
+ Then, you can try text-to-video generation on your own prompts:
73
+
74
+ ```python
75
+ prompt = "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors"
76
+
77
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
78
+ frames = model.generate(
79
+ prompt=prompt,
80
+ num_inference_steps=[20, 20, 20],
81
+ video_num_inference_steps=[10, 10, 10],
82
+ height=768,
83
+ width=1280,
84
+ temp=16, # temp=16: 5s, temp=31: 10s
85
+ guidance_scale=9.0, # The guidance for the first frame
86
+ video_guidance_scale=5.0, # The guidance for the other video latent
87
+ output_type="pil",
88
+ save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
89
+ )
90
+
91
+ export_to_video(frames, "./text_to_video_sample.mp4", fps=24)
92
+ ```
93
+
94
+ As an autoregressive model, our model also supports (text conditioned) image-to-video generation:
95
+
96
+ ```python
97
+ image = Image.open('assets/the_great_wall.jpg').convert("RGB").resize((1280, 768))
98
+ prompt = "FPV flying over the Great Wall"
99
+
100
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
101
+ frames = model.generate_i2v(
102
+ prompt=prompt,
103
+ input_image=image,
104
+ num_inference_steps=[10, 10, 10],
105
+ temp=16,
106
+ video_guidance_scale=4.0,
107
+ output_type="pil",
108
+ save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
109
+ )
110
+
111
+ export_to_video(frames, "./image_to_video_sample.mp4", fps=24)
112
+ ```
113
+
114
+ Usage tips:
115
+
116
+ * The `guidance_scale` parameter controls the visual quality. We suggest using a guidance within [7, 9] for the 768p checkpoint during text-to-video generation, and 7 for the 384p checkpoint.
117
+ * The `video_guidance_scale` parameter controls the motion. A larger value increases the dynamic degree and mitigates the autoregressive generation degradation, while a smaller value stabilizes the video.
118
+ * For 10-second video generation, we recommend using a guidance scale of 7 and a video guidance scale of 5.
119
+
120
+ ## Gallery
121
+
122
+ The following video examples are generated at 5s, 768p, 24fps. For more results, please visit our [project page](https://pyramid-flow.github.io).
123
+
124
+ <table class="center" border="0" style="width: 100%; text-align: left;">
125
+ <tr>
126
+ <td><video src="https://github.com/user-attachments/assets/5b44a57e-fa08-4554-84a2-2c7a99f2b343" autoplay muted loop playsinline></video></td>
127
+ <td><video src="https://github.com/user-attachments/assets/5afd5970-de72-40e2-900d-a20d18308e8e" autoplay muted loop playsinline></video></td>
128
+ </tr>
129
+ <tr>
130
+ <td><video src="https://github.com/user-attachments/assets/1d44daf8-017f-40e9-bf18-1e19c0a8983b" autoplay muted loop playsinline></video></td>
131
+ <td><video src="https://github.com/user-attachments/assets/7f5dd901-b7d7-48cc-b67a-3c5f9e1546d2" autoplay muted loop playsinline></video></td>
132
+ </tr>
133
+ </table>
134
+
135
+ ## Comparison
136
+
137
+ On VBench ([Huang et al., 2024](https://huggingface.co/spaces/Vchitect/VBench_Leaderboard)), our method surpasses all the compared open-source baselines. Even with only public video data, it achieves comparable performance to commercial models like Kling ([Kuaishou, 2024](https://kling.kuaishou.com/en)) and Gen-3 Alpha ([Runway, 2024](https://runwayml.com/research/introducing-gen-3-alpha)), especially in the quality score (84.74 vs. 84.11 of Gen-3) and motion smoothness.
138
+
139
+ ![vbench](assets/vbench.jpg)
140
+
141
+ We conduct an additional user study with 20+ participants. As can be seen, our method is preferred over open-source models such as [Open-Sora](https://github.com/hpcaitech/Open-Sora) and [CogVideoX-2B](https://github.com/THUDM/CogVideo) especially in terms of motion smoothness.
142
+
143
+ ![user_study](assets/user_study.jpg)
144
+
145
+ ## Acknowledgement
146
+
147
+ We are grateful for the following awesome projects when implementing Pyramid Flow:
148
+
149
+ * [SD3 Medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium) and [Flux 1.0](https://huggingface.co/black-forest-labs/FLUX.1-dev): State-of-the-art image generation models based on flow matching.
150
+ * [Diffusion Forcing](https://boyuan.space/diffusion-forcing) and [GameNGen](https://gamengen.github.io): Next-token prediction meets full-sequence diffusion.
151
+ * [WebVid-10M](https://github.com/m-bain/webvid), [OpenVid-1M](https://github.com/NJU-PCALab/OpenVid-1M) and [Open-Sora Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan): Large-scale datasets for text-to-video generation.
152
+ * [CogVideoX](https://github.com/THUDM/CogVideo): An open-source text-to-video generation model that shares many training details.
153
+ * [Video-LLaMA2](https://github.com/DAMO-NLP-SG/VideoLLaMA2): An open-source video LLM for our video recaptioning.
154
+
155
+ ## Citation
156
+
157
+ Consider giving this repository a star and cite Pyramid Flow in your publications if it helps your research.
158
+ ```
159
+ @article{jin2024pyramidal,
160
+ title={Pyramidal Flow Matching for Efficient Video Generative Modeling},
161
+ author={Jin, Yang and Sun, Zhicheng and Li, Ningyuan and Xu, Kun and Xu, Kun and Jiang, Hao and Zhuang, Nan and Huang, Quzhe and Song, Yang and Mu, Yadong and Lin, Zhouchen},
162
+ jounal={arXiv preprint arXiv:2410.05954},
163
+ year={2024}
164
+ }
165
+ ```
assets/motivation.jpg ADDED
assets/the_great_wall.jpg ADDED
assets/user_study.jpg ADDED
assets/vbench.jpg ADDED
diffusion_schedulers/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .scheduling_cosine_ddpm import DDPMCosineScheduler
2
+ from .scheduling_flow_matching import PyramidFlowMatchEulerDiscreteScheduler
diffusion_schedulers/scheduling_cosine_ddpm.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.utils import BaseOutput
9
+ from diffusers.utils.torch_utils import randn_tensor
10
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
11
+
12
+
13
+ @dataclass
14
+ class DDPMSchedulerOutput(BaseOutput):
15
+ """
16
+ Output class for the scheduler's step function output.
17
+
18
+ Args:
19
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
20
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
21
+ denoising loop.
22
+ """
23
+
24
+ prev_sample: torch.Tensor
25
+
26
+
27
+ class DDPMCosineScheduler(SchedulerMixin, ConfigMixin):
28
+
29
+ @register_to_config
30
+ def __init__(
31
+ self,
32
+ scaler: float = 1.0,
33
+ s: float = 0.008,
34
+ ):
35
+ self.scaler = scaler
36
+ self.s = torch.tensor([s])
37
+ self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2
38
+
39
+ # standard deviation of the initial noise distribution
40
+ self.init_noise_sigma = 1.0
41
+
42
+ def _alpha_cumprod(self, t, device):
43
+ if self.scaler > 1:
44
+ t = 1 - (1 - t) ** self.scaler
45
+ elif self.scaler < 1:
46
+ t = t**self.scaler
47
+ alpha_cumprod = torch.cos(
48
+ (t + self.s.to(device)) / (1 + self.s.to(device)) * torch.pi * 0.5
49
+ ) ** 2 / self._init_alpha_cumprod.to(device)
50
+ return alpha_cumprod.clamp(0.0001, 0.9999)
51
+
52
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
53
+ """
54
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
55
+ current timestep.
56
+
57
+ Args:
58
+ sample (`torch.Tensor`): input sample
59
+ timestep (`int`, optional): current timestep
60
+
61
+ Returns:
62
+ `torch.Tensor`: scaled input sample
63
+ """
64
+ return sample
65
+
66
+ def set_timesteps(
67
+ self,
68
+ num_inference_steps: int = None,
69
+ timesteps: Optional[List[int]] = None,
70
+ device: Union[str, torch.device] = None,
71
+ ):
72
+ """
73
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
74
+
75
+ Args:
76
+ num_inference_steps (`Dict[float, int]`):
77
+ the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
78
+ `timesteps` must be `None`.
79
+ device (`str` or `torch.device`, optional):
80
+ the device to which the timesteps are moved to. {2 / 3: 20, 0.0: 10}
81
+ """
82
+ if timesteps is None:
83
+ timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device)
84
+ if not isinstance(timesteps, torch.Tensor):
85
+ timesteps = torch.Tensor(timesteps).to(device)
86
+ self.timesteps = timesteps
87
+
88
+ def step(
89
+ self,
90
+ model_output: torch.Tensor,
91
+ timestep: int,
92
+ sample: torch.Tensor,
93
+ generator=None,
94
+ return_dict: bool = True,
95
+ ) -> Union[DDPMSchedulerOutput, Tuple]:
96
+ dtype = model_output.dtype
97
+ device = model_output.device
98
+ t = timestep
99
+
100
+ prev_t = self.previous_timestep(t)
101
+
102
+ alpha_cumprod = self._alpha_cumprod(t, device).view(t.size(0), *[1 for _ in sample.shape[1:]])
103
+ alpha_cumprod_prev = self._alpha_cumprod(prev_t, device).view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
104
+ alpha = alpha_cumprod / alpha_cumprod_prev
105
+
106
+ mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt())
107
+
108
+ std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype)
109
+ std = ((1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)).sqrt() * std_noise
110
+ pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
111
+
112
+ if not return_dict:
113
+ return (pred.to(dtype),)
114
+
115
+ return DDPMSchedulerOutput(prev_sample=pred.to(dtype))
116
+
117
+ def add_noise(
118
+ self,
119
+ original_samples: torch.Tensor,
120
+ noise: torch.Tensor,
121
+ timesteps: torch.Tensor,
122
+ ) -> torch.Tensor:
123
+ device = original_samples.device
124
+ dtype = original_samples.dtype
125
+ alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view(
126
+ timesteps.size(0), *[1 for _ in original_samples.shape[1:]]
127
+ )
128
+ noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise
129
+ return noisy_samples.to(dtype=dtype)
130
+
131
+ def __len__(self):
132
+ return self.config.num_train_timesteps
133
+
134
+ def previous_timestep(self, timestep):
135
+ index = (self.timesteps - timestep[0]).abs().argmin().item()
136
+ prev_t = self.timesteps[index + 1][None].expand(timestep.shape[0])
137
+ return prev_t
diffusion_schedulers/scheduling_flow_matching.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Union, List
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.utils import BaseOutput, logging
9
+ from diffusers.utils.torch_utils import randn_tensor
10
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
11
+ from IPython import embed
12
+
13
+
14
+ @dataclass
15
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
16
+ """
17
+ Output class for the scheduler's `step` function output.
18
+
19
+ Args:
20
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
21
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
22
+ denoising loop.
23
+ """
24
+
25
+ prev_sample: torch.FloatTensor
26
+
27
+
28
+ class PyramidFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
29
+ """
30
+ Euler scheduler.
31
+
32
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
33
+ methods the library implements for all schedulers such as loading and saving.
34
+
35
+ Args:
36
+ num_train_timesteps (`int`, defaults to 1000):
37
+ The number of diffusion steps to train the model.
38
+ timestep_spacing (`str`, defaults to `"linspace"`):
39
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
40
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
41
+ shift (`float`, defaults to 1.0):
42
+ The shift value for the timestep schedule.
43
+ """
44
+
45
+ _compatibles = []
46
+ order = 1
47
+
48
+ @register_to_config
49
+ def __init__(
50
+ self,
51
+ num_train_timesteps: int = 1000,
52
+ shift: float = 1.0, # Following Stable diffusion 3,
53
+ stages: int = 3,
54
+ stage_range: List = [0, 1/3, 2/3, 1],
55
+ gamma: float = 1/3,
56
+ ):
57
+
58
+ self.timestep_ratios = {} # The timestep ratio for each stage
59
+ self.timesteps_per_stage = {} # The detailed timesteps per stage
60
+ self.sigmas_per_stage = {}
61
+ self.start_sigmas = {}
62
+ self.end_sigmas = {}
63
+ self.ori_start_sigmas = {}
64
+
65
+ # self.init_sigmas()
66
+ self.init_sigmas_for_each_stage()
67
+ self.sigma_min = self.sigmas[-1].item()
68
+ self.sigma_max = self.sigmas[0].item()
69
+ self.gamma = gamma
70
+
71
+ def init_sigmas(self):
72
+ """
73
+ initialize the global timesteps and sigmas
74
+ """
75
+ num_train_timesteps = self.config.num_train_timesteps
76
+ shift = self.config.shift
77
+
78
+ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
79
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
80
+
81
+ sigmas = timesteps / num_train_timesteps
82
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
83
+
84
+ self.timesteps = sigmas * num_train_timesteps
85
+
86
+ self._step_index = None
87
+ self._begin_index = None
88
+
89
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
90
+
91
+ def init_sigmas_for_each_stage(self):
92
+ """
93
+ Init the timesteps for each stage
94
+ """
95
+ self.init_sigmas()
96
+
97
+ stage_distance = []
98
+ stages = self.config.stages
99
+ training_steps = self.config.num_train_timesteps
100
+ stage_range = self.config.stage_range
101
+
102
+ # Init the start and end point of each stage
103
+ for i_s in range(stages):
104
+ # To decide the start and ends point
105
+ start_indice = int(stage_range[i_s] * training_steps)
106
+ start_indice = max(start_indice, 0)
107
+ end_indice = int(stage_range[i_s+1] * training_steps)
108
+ end_indice = min(end_indice, training_steps)
109
+ start_sigma = self.sigmas[start_indice].item()
110
+ end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
111
+ self.ori_start_sigmas[i_s] = start_sigma
112
+
113
+ if i_s != 0:
114
+ ori_sigma = 1 - start_sigma
115
+ gamma = self.config.gamma
116
+ corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
117
+ # corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma
118
+ start_sigma = 1 - corrected_sigma
119
+
120
+ stage_distance.append(start_sigma - end_sigma)
121
+ self.start_sigmas[i_s] = start_sigma
122
+ self.end_sigmas[i_s] = end_sigma
123
+
124
+ # Determine the ratio of each stage according to flow length
125
+ tot_distance = sum(stage_distance)
126
+ for i_s in range(stages):
127
+ if i_s == 0:
128
+ start_ratio = 0.0
129
+ else:
130
+ start_ratio = sum(stage_distance[:i_s]) / tot_distance
131
+ if i_s == stages - 1:
132
+ end_ratio = 1.0
133
+ else:
134
+ end_ratio = sum(stage_distance[:i_s+1]) / tot_distance
135
+
136
+ self.timestep_ratios[i_s] = (start_ratio, end_ratio)
137
+
138
+ # Determine the timesteps and sigmas for each stage
139
+ for i_s in range(stages):
140
+ timestep_ratio = self.timestep_ratios[i_s]
141
+ timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)]
142
+ timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
143
+ timesteps = np.linspace(
144
+ timestep_max, timestep_min, training_steps + 1,
145
+ )
146
+ self.timesteps_per_stage[i_s] = torch.from_numpy(timesteps[:-1])
147
+ stage_sigmas = np.linspace(
148
+ 1, 0, training_steps + 1,
149
+ )
150
+ self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])
151
+
152
+ @property
153
+ def step_index(self):
154
+ """
155
+ The index counter for current timestep. It will increase 1 after each scheduler step.
156
+ """
157
+ return self._step_index
158
+
159
+ @property
160
+ def begin_index(self):
161
+ """
162
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
163
+ """
164
+ return self._begin_index
165
+
166
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
167
+ def set_begin_index(self, begin_index: int = 0):
168
+ """
169
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
170
+
171
+ Args:
172
+ begin_index (`int`):
173
+ The begin index for the scheduler.
174
+ """
175
+ self._begin_index = begin_index
176
+
177
+ def _sigma_to_t(self, sigma):
178
+ return sigma * self.config.num_train_timesteps
179
+
180
+ def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None):
181
+ """
182
+ Setting the timesteps and sigmas for each stage
183
+ """
184
+ self.num_inference_steps = num_inference_steps
185
+ training_steps = self.config.num_train_timesteps
186
+ self.init_sigmas()
187
+
188
+ stage_timesteps = self.timesteps_per_stage[stage_index]
189
+ timestep_max = stage_timesteps[0].item()
190
+ timestep_min = stage_timesteps[-1].item()
191
+
192
+ timesteps = np.linspace(
193
+ timestep_max, timestep_min, num_inference_steps,
194
+ )
195
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
196
+
197
+ stage_sigmas = self.sigmas_per_stage[stage_index]
198
+ sigma_max = stage_sigmas[0].item()
199
+ sigma_min = stage_sigmas[-1].item()
200
+
201
+ ratios = np.linspace(
202
+ sigma_max, sigma_min, num_inference_steps
203
+ )
204
+ sigmas = torch.from_numpy(ratios).to(device=device)
205
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
206
+
207
+ self._step_index = None
208
+
209
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
210
+ if schedule_timesteps is None:
211
+ schedule_timesteps = self.timesteps
212
+
213
+ indices = (schedule_timesteps == timestep).nonzero()
214
+
215
+ # The sigma index that is taken for the **very** first `step`
216
+ # is always the second index (or the last index if there is only 1)
217
+ # This way we can ensure we don't accidentally skip a sigma in
218
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
219
+ pos = 1 if len(indices) > 1 else 0
220
+
221
+ return indices[pos].item()
222
+
223
+ def _init_step_index(self, timestep):
224
+ if self.begin_index is None:
225
+ if isinstance(timestep, torch.Tensor):
226
+ timestep = timestep.to(self.timesteps.device)
227
+ self._step_index = self.index_for_timestep(timestep)
228
+ else:
229
+ self._step_index = self._begin_index
230
+
231
+ def step(
232
+ self,
233
+ model_output: torch.FloatTensor,
234
+ timestep: Union[float, torch.FloatTensor],
235
+ sample: torch.FloatTensor,
236
+ generator: Optional[torch.Generator] = None,
237
+ return_dict: bool = True,
238
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
239
+ """
240
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
241
+ process from the learned model outputs (most often the predicted noise).
242
+
243
+ Args:
244
+ model_output (`torch.FloatTensor`):
245
+ The direct output from learned diffusion model.
246
+ timestep (`float`):
247
+ The current discrete timestep in the diffusion chain.
248
+ sample (`torch.FloatTensor`):
249
+ A current instance of a sample created by the diffusion process.
250
+ generator (`torch.Generator`, *optional*):
251
+ A random number generator.
252
+ return_dict (`bool`):
253
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
254
+ tuple.
255
+
256
+ Returns:
257
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
258
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
259
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
260
+ """
261
+
262
+ if (
263
+ isinstance(timestep, int)
264
+ or isinstance(timestep, torch.IntTensor)
265
+ or isinstance(timestep, torch.LongTensor)
266
+ ):
267
+ raise ValueError(
268
+ (
269
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
270
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
271
+ " one of the `scheduler.timesteps` as a timestep."
272
+ ),
273
+ )
274
+
275
+ if self.step_index is None:
276
+ self._step_index = 0
277
+
278
+ # Upcast to avoid precision issues when computing prev_sample
279
+ sample = sample.to(torch.float32)
280
+
281
+ sigma = self.sigmas[self.step_index]
282
+ sigma_next = self.sigmas[self.step_index + 1]
283
+
284
+ prev_sample = sample + (sigma_next - sigma) * model_output
285
+
286
+ # Cast sample back to model compatible dtype
287
+ prev_sample = prev_sample.to(model_output.dtype)
288
+
289
+ # upon completion increase step index by one
290
+ self._step_index += 1
291
+
292
+ if not return_dict:
293
+ return (prev_sample,)
294
+
295
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
296
+
297
+ def __len__(self):
298
+ return self.config.num_train_timesteps
pyramid_dit/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .modeling_pyramid_mmdit import PyramidDiffusionMMDiT
2
+ from .pyramid_dit_for_video_gen_pipeline import PyramidDiTForVideoGeneration
3
+ from .modeling_text_encoder import SD3TextEncoderWithMask
pyramid_dit/modeling_embedding.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import math
7
+
8
+ from diffusers.models.activations import get_activation
9
+ from einops import rearrange
10
+
11
+
12
+ def get_1d_sincos_pos_embed(
13
+ embed_dim, num_frames, cls_token=False, extra_tokens=0,
14
+ ):
15
+ t = np.arange(num_frames, dtype=np.float32)
16
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, t) # (T, D)
17
+ if cls_token and extra_tokens > 0:
18
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
19
+ return pos_embed
20
+
21
+
22
+ def get_2d_sincos_pos_embed(
23
+ embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
24
+ ):
25
+ """
26
+ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
27
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
28
+ """
29
+ if isinstance(grid_size, int):
30
+ grid_size = (grid_size, grid_size)
31
+
32
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
33
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
34
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
35
+ grid = np.stack(grid, axis=0)
36
+
37
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
38
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
39
+ if cls_token and extra_tokens > 0:
40
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
41
+ return pos_embed
42
+
43
+
44
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
45
+ if embed_dim % 2 != 0:
46
+ raise ValueError("embed_dim must be divisible by 2")
47
+
48
+ # use half of dimensions to encode grid_h
49
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
50
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
51
+
52
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
53
+ return emb
54
+
55
+
56
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
57
+ """
58
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
59
+ """
60
+ if embed_dim % 2 != 0:
61
+ raise ValueError("embed_dim must be divisible by 2")
62
+
63
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
64
+ omega /= embed_dim / 2.0
65
+ omega = 1.0 / 10000**omega # (D/2,)
66
+
67
+ pos = pos.reshape(-1) # (M,)
68
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
69
+
70
+ emb_sin = np.sin(out) # (M, D/2)
71
+ emb_cos = np.cos(out) # (M, D/2)
72
+
73
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
74
+ return emb
75
+
76
+
77
+ def get_timestep_embedding(
78
+ timesteps: torch.Tensor,
79
+ embedding_dim: int,
80
+ flip_sin_to_cos: bool = False,
81
+ downscale_freq_shift: float = 1,
82
+ scale: float = 1,
83
+ max_period: int = 10000,
84
+ ):
85
+ """
86
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
87
+ :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
88
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
89
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
90
+ """
91
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
92
+
93
+ half_dim = embedding_dim // 2
94
+ exponent = -math.log(max_period) * torch.arange(
95
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
96
+ )
97
+ exponent = exponent / (half_dim - downscale_freq_shift)
98
+
99
+ emb = torch.exp(exponent)
100
+ emb = timesteps[:, None].float() * emb[None, :]
101
+
102
+ # scale embeddings
103
+ emb = scale * emb
104
+
105
+ # concat sine and cosine embeddings
106
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
107
+
108
+ # flip sine and cosine embeddings
109
+ if flip_sin_to_cos:
110
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
111
+
112
+ # zero pad
113
+ if embedding_dim % 2 == 1:
114
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
115
+ return emb
116
+
117
+
118
+ class Timesteps(nn.Module):
119
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
120
+ super().__init__()
121
+ self.num_channels = num_channels
122
+ self.flip_sin_to_cos = flip_sin_to_cos
123
+ self.downscale_freq_shift = downscale_freq_shift
124
+
125
+ def forward(self, timesteps):
126
+ t_emb = get_timestep_embedding(
127
+ timesteps,
128
+ self.num_channels,
129
+ flip_sin_to_cos=self.flip_sin_to_cos,
130
+ downscale_freq_shift=self.downscale_freq_shift,
131
+ )
132
+ return t_emb
133
+
134
+
135
+ class TimestepEmbedding(nn.Module):
136
+ def __init__(
137
+ self,
138
+ in_channels: int,
139
+ time_embed_dim: int,
140
+ act_fn: str = "silu",
141
+ out_dim: int = None,
142
+ post_act_fn: Optional[str] = None,
143
+ sample_proj_bias=True,
144
+ ):
145
+ super().__init__()
146
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
147
+ self.act = get_activation(act_fn)
148
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, sample_proj_bias)
149
+
150
+ def forward(self, sample):
151
+ sample = self.linear_1(sample)
152
+ sample = self.act(sample)
153
+ sample = self.linear_2(sample)
154
+ return sample
155
+
156
+
157
+ class TextProjection(nn.Module):
158
+ def __init__(self, in_features, hidden_size, act_fn="silu"):
159
+ super().__init__()
160
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
161
+ self.act_1 = get_activation(act_fn)
162
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
163
+
164
+ def forward(self, caption):
165
+ hidden_states = self.linear_1(caption)
166
+ hidden_states = self.act_1(hidden_states)
167
+ hidden_states = self.linear_2(hidden_states)
168
+ return hidden_states
169
+
170
+
171
+ class CombinedTimestepConditionEmbeddings(nn.Module):
172
+ def __init__(self, embedding_dim, pooled_projection_dim):
173
+ super().__init__()
174
+
175
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
176
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
177
+ self.text_embedder = TextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
178
+
179
+ def forward(self, timestep, pooled_projection):
180
+ timesteps_proj = self.time_proj(timestep)
181
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
182
+ pooled_projections = self.text_embedder(pooled_projection)
183
+ conditioning = timesteps_emb + pooled_projections
184
+ return conditioning
185
+
186
+
187
+ class CombinedTimestepEmbeddings(nn.Module):
188
+ def __init__(self, embedding_dim):
189
+ super().__init__()
190
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
191
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
192
+
193
+ def forward(self, timestep):
194
+ timesteps_proj = self.time_proj(timestep)
195
+ timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
196
+ return timesteps_emb
197
+
198
+
199
+ class PatchEmbed3D(nn.Module):
200
+ """Support the 3D Tensor input"""
201
+
202
+ def __init__(
203
+ self,
204
+ height=128,
205
+ width=128,
206
+ patch_size=2,
207
+ in_channels=16,
208
+ embed_dim=1536,
209
+ layer_norm=False,
210
+ bias=True,
211
+ interpolation_scale=1,
212
+ pos_embed_type="sincos",
213
+ temp_pos_embed_type='rope',
214
+ pos_embed_max_size=192, # For SD3 cropping
215
+ max_num_frames=64,
216
+ add_temp_pos_embed=False,
217
+ interp_condition_pos=False,
218
+ ):
219
+ super().__init__()
220
+
221
+ num_patches = (height // patch_size) * (width // patch_size)
222
+ self.layer_norm = layer_norm
223
+ self.pos_embed_max_size = pos_embed_max_size
224
+
225
+ self.proj = nn.Conv2d(
226
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
227
+ )
228
+ if layer_norm:
229
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
230
+ else:
231
+ self.norm = None
232
+
233
+ self.patch_size = patch_size
234
+ self.height, self.width = height // patch_size, width // patch_size
235
+ self.base_size = height // patch_size
236
+ self.interpolation_scale = interpolation_scale
237
+ self.add_temp_pos_embed = add_temp_pos_embed
238
+
239
+ # Calculate positional embeddings based on max size or default
240
+ if pos_embed_max_size:
241
+ grid_size = pos_embed_max_size
242
+ else:
243
+ grid_size = int(num_patches**0.5)
244
+
245
+ if pos_embed_type is None:
246
+ self.pos_embed = None
247
+
248
+ elif pos_embed_type == "sincos":
249
+ pos_embed = get_2d_sincos_pos_embed(
250
+ embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
251
+ )
252
+ persistent = True if pos_embed_max_size else False
253
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
254
+
255
+ if add_temp_pos_embed and temp_pos_embed_type == 'sincos':
256
+ time_pos_embed = get_1d_sincos_pos_embed(embed_dim, max_num_frames)
257
+ self.register_buffer("temp_pos_embed", torch.from_numpy(time_pos_embed).float().unsqueeze(0), persistent=True)
258
+
259
+ elif pos_embed_type == "rope":
260
+ print("Using the rotary position embedding")
261
+
262
+ else:
263
+ raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
264
+
265
+ self.pos_embed_type = pos_embed_type
266
+ self.temp_pos_embed_type = temp_pos_embed_type
267
+ self.interp_condition_pos = interp_condition_pos
268
+
269
+ def cropped_pos_embed(self, height, width, ori_height, ori_width):
270
+ """Crops positional embeddings for SD3 compatibility."""
271
+ if self.pos_embed_max_size is None:
272
+ raise ValueError("`pos_embed_max_size` must be set for cropping.")
273
+
274
+ height = height // self.patch_size
275
+ width = width // self.patch_size
276
+ ori_height = ori_height // self.patch_size
277
+ ori_width = ori_width // self.patch_size
278
+
279
+ assert ori_height >= height, "The ori_height needs >= height"
280
+ assert ori_width >= width, "The ori_width needs >= width"
281
+
282
+ if height > self.pos_embed_max_size:
283
+ raise ValueError(
284
+ f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
285
+ )
286
+ if width > self.pos_embed_max_size:
287
+ raise ValueError(
288
+ f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
289
+ )
290
+
291
+ if self.interp_condition_pos:
292
+ top = (self.pos_embed_max_size - ori_height) // 2
293
+ left = (self.pos_embed_max_size - ori_width) // 2
294
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
295
+ spatial_pos_embed = spatial_pos_embed[:, top : top + ori_height, left : left + ori_width, :] # [b h w c]
296
+ if ori_height != height or ori_width != width:
297
+ spatial_pos_embed = spatial_pos_embed.permute(0, 3, 1, 2)
298
+ spatial_pos_embed = torch.nn.functional.interpolate(spatial_pos_embed, size=(height, width), mode='bilinear')
299
+ spatial_pos_embed = spatial_pos_embed.permute(0, 2, 3, 1)
300
+ else:
301
+ top = (self.pos_embed_max_size - height) // 2
302
+ left = (self.pos_embed_max_size - width) // 2
303
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
304
+ spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
305
+
306
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
307
+
308
+ return spatial_pos_embed
309
+
310
+ def forward_func(self, latent, time_index=0, ori_height=None, ori_width=None):
311
+ if self.pos_embed_max_size is not None:
312
+ height, width = latent.shape[-2:]
313
+ else:
314
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
315
+
316
+ bs = latent.shape[0]
317
+ temp = latent.shape[2]
318
+
319
+ latent = rearrange(latent, 'b c t h w -> (b t) c h w')
320
+ latent = self.proj(latent)
321
+ latent = latent.flatten(2).transpose(1, 2) # (BT)CHW -> (BT)NC
322
+
323
+ if self.layer_norm:
324
+ latent = self.norm(latent)
325
+
326
+ if self.pos_embed_type == 'sincos':
327
+ # Spatial position embedding, Interpolate or crop positional embeddings as needed
328
+ if self.pos_embed_max_size:
329
+ pos_embed = self.cropped_pos_embed(height, width, ori_height, ori_width)
330
+ else:
331
+ raise NotImplementedError("Not implemented sincos pos embed without sd3 max pos crop")
332
+ if self.height != height or self.width != width:
333
+ pos_embed = get_2d_sincos_pos_embed(
334
+ embed_dim=self.pos_embed.shape[-1],
335
+ grid_size=(height, width),
336
+ base_size=self.base_size,
337
+ interpolation_scale=self.interpolation_scale,
338
+ )
339
+ pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
340
+ else:
341
+ pos_embed = self.pos_embed
342
+
343
+ if self.add_temp_pos_embed and self.temp_pos_embed_type == 'sincos':
344
+ latent_dtype = latent.dtype
345
+ latent = latent + pos_embed
346
+ latent = rearrange(latent, '(b t) n c -> (b n) t c', t=temp)
347
+ latent = latent + self.temp_pos_embed[:, time_index:time_index + temp, :]
348
+ latent = latent.to(latent_dtype)
349
+ latent = rearrange(latent, '(b n) t c -> b t n c', b=bs)
350
+ else:
351
+ latent = (latent + pos_embed).to(latent.dtype)
352
+ latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp)
353
+
354
+ else:
355
+ assert self.pos_embed_type == "rope", "Only supporting the sincos and rope embedding"
356
+ latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp)
357
+
358
+ return latent
359
+
360
+ def forward(self, latent):
361
+ """
362
+ Arguments:
363
+ past_condition_latents (Torch.FloatTensor): The past latent during the generation
364
+ flatten_input (bool): True indicate flatten the latent into 1D sequence
365
+ """
366
+
367
+ if isinstance(latent, list):
368
+ output_list = []
369
+
370
+ for latent_ in latent:
371
+ if not isinstance(latent_, list):
372
+ latent_ = [latent_]
373
+
374
+ output_latent = []
375
+ time_index = 0
376
+ ori_height, ori_width = latent_[-1].shape[-2:]
377
+ for each_latent in latent_:
378
+ hidden_state = self.forward_func(each_latent, time_index=time_index, ori_height=ori_height, ori_width=ori_width)
379
+ time_index += each_latent.shape[2]
380
+ hidden_state = rearrange(hidden_state, "b t n c -> b (t n) c")
381
+ output_latent.append(hidden_state)
382
+
383
+ output_latent = torch.cat(output_latent, dim=1)
384
+ output_list.append(output_latent)
385
+
386
+ return output_list
387
+ else:
388
+ hidden_states = self.forward_func(latent)
389
+ hidden_states = rearrange(hidden_states, "b t n c -> b (t n) c")
390
+ return hidden_states
pyramid_dit/modeling_mmdit_block.py ADDED
@@ -0,0 +1,672 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Tuple, List
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
7
+
8
+ try:
9
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
10
+ from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis
11
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
12
+ except:
13
+ flash_attn_func = None
14
+ flash_attn_qkvpacked_func = None
15
+ flash_attn_varlen_func = None
16
+ print("Please install flash attention")
17
+
18
+ from trainer_misc import (
19
+ is_sequence_parallel_initialized,
20
+ get_sequence_parallel_group,
21
+ get_sequence_parallel_world_size,
22
+ all_to_all,
23
+ )
24
+
25
+ from .modeling_normalization import AdaLayerNormZero, AdaLayerNormContinuous, RMSNorm
26
+
27
+
28
+ class FeedForward(nn.Module):
29
+ r"""
30
+ A feed-forward layer.
31
+
32
+ Parameters:
33
+ dim (`int`): The number of channels in the input.
34
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
35
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
36
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
37
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
38
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
39
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
40
+ """
41
+ def __init__(
42
+ self,
43
+ dim: int,
44
+ dim_out: Optional[int] = None,
45
+ mult: int = 4,
46
+ dropout: float = 0.0,
47
+ activation_fn: str = "geglu",
48
+ final_dropout: bool = False,
49
+ inner_dim=None,
50
+ bias: bool = True,
51
+ ):
52
+ super().__init__()
53
+ if inner_dim is None:
54
+ inner_dim = int(dim * mult)
55
+ dim_out = dim_out if dim_out is not None else dim
56
+
57
+ if activation_fn == "gelu":
58
+ act_fn = GELU(dim, inner_dim, bias=bias)
59
+ if activation_fn == "gelu-approximate":
60
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
61
+ elif activation_fn == "geglu":
62
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
63
+ elif activation_fn == "geglu-approximate":
64
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
65
+
66
+ self.net = nn.ModuleList([])
67
+ # project in
68
+ self.net.append(act_fn)
69
+ # project dropout
70
+ self.net.append(nn.Dropout(dropout))
71
+ # project out
72
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
73
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
74
+ if final_dropout:
75
+ self.net.append(nn.Dropout(dropout))
76
+
77
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
78
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
79
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
80
+ deprecate("scale", "1.0.0", deprecation_message)
81
+ for module in self.net:
82
+ hidden_states = module(hidden_states)
83
+ return hidden_states
84
+
85
+
86
+ class VarlenFlashSelfAttentionWithT5Mask:
87
+
88
+ def __init__(self):
89
+ pass
90
+
91
+ def apply_rope(self, xq, xk, freqs_cis):
92
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
93
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
94
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
95
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
96
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
97
+
98
+ def __call__(
99
+ self, query, key, value, encoder_query, encoder_key, encoder_value,
100
+ heads, scale, hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
101
+ ):
102
+ assert encoder_attention_mask is not None, "The encoder-hidden mask needed to be set"
103
+
104
+ batch_size = query.shape[0]
105
+ output_hidden = torch.zeros_like(query)
106
+ output_encoder_hidden = torch.zeros_like(encoder_query)
107
+ encoder_length = encoder_query.shape[1]
108
+
109
+ qkv_list = []
110
+ num_stages = len(hidden_length)
111
+
112
+ encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
113
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
114
+
115
+ i_sum = 0
116
+ for i_p, length in enumerate(hidden_length):
117
+ encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
118
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
119
+ concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, tot_seq, 3, nhead, dim]
120
+
121
+ if image_rotary_emb is not None:
122
+ concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
123
+
124
+ indices = encoder_attention_mask[i_p]['indices']
125
+ qkv_list.append(index_first_axis(rearrange(concat_qkv_tokens, "b s ... -> (b s) ..."), indices))
126
+ i_sum += length
127
+
128
+ token_lengths = [x_.shape[0] for x_ in qkv_list]
129
+ qkv = torch.cat(qkv_list, dim=0)
130
+ query, key, value = qkv.unbind(1)
131
+
132
+ cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
133
+ max_seqlen_q = cu_seqlens.max().item()
134
+ max_seqlen_k = max_seqlen_q
135
+ cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
136
+ cu_seqlens_k = cu_seqlens_q.clone()
137
+
138
+ output = flash_attn_varlen_func(
139
+ query,
140
+ key,
141
+ value,
142
+ cu_seqlens_q=cu_seqlens_q,
143
+ cu_seqlens_k=cu_seqlens_k,
144
+ max_seqlen_q=max_seqlen_q,
145
+ max_seqlen_k=max_seqlen_k,
146
+ dropout_p=0.0,
147
+ causal=False,
148
+ softmax_scale=scale,
149
+ )
150
+
151
+ # To merge the tokens
152
+ i_sum = 0;token_sum = 0
153
+ for i_p, length in enumerate(hidden_length):
154
+ tot_token_num = token_lengths[i_p]
155
+ stage_output = output[token_sum : token_sum + tot_token_num]
156
+ stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, encoder_length + length)
157
+ stage_encoder_hidden_output = stage_output[:, :encoder_length]
158
+ stage_hidden_output = stage_output[:, encoder_length:]
159
+ output_hidden[:, i_sum:i_sum+length] = stage_hidden_output
160
+ output_encoder_hidden[i_p::num_stages] = stage_encoder_hidden_output
161
+ token_sum += tot_token_num
162
+ i_sum += length
163
+
164
+ output_hidden = output_hidden.flatten(2, 3)
165
+ output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
166
+
167
+ return output_hidden, output_encoder_hidden
168
+
169
+
170
+ class SequenceParallelVarlenFlashSelfAttentionWithT5Mask:
171
+
172
+ def __init__(self):
173
+ pass
174
+
175
+ def apply_rope(self, xq, xk, freqs_cis):
176
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
177
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
178
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
179
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
180
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
181
+
182
+ def __call__(
183
+ self, query, key, value, encoder_query, encoder_key, encoder_value,
184
+ heads, scale, hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
185
+ ):
186
+ assert encoder_attention_mask is not None, "The encoder-hidden mask needed to be set"
187
+
188
+ batch_size = query.shape[0]
189
+ qkv_list = []
190
+ num_stages = len(hidden_length)
191
+
192
+ encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
193
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
194
+
195
+ # To sync the encoder query, key and values
196
+ sp_group = get_sequence_parallel_group()
197
+ sp_group_size = get_sequence_parallel_world_size()
198
+ encoder_qkv = all_to_all(encoder_qkv, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
199
+
200
+ output_hidden = torch.zeros_like(qkv[:,:,0])
201
+ output_encoder_hidden = torch.zeros_like(encoder_qkv[:,:,0])
202
+ encoder_length = encoder_qkv.shape[1]
203
+
204
+ i_sum = 0
205
+ for i_p, length in enumerate(hidden_length):
206
+ # get the query, key, value from padding sequence
207
+ encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
208
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
209
+ qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
210
+ concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, pad_seq, 3, nhead, dim]
211
+
212
+ if image_rotary_emb is not None:
213
+ concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
214
+
215
+ indices = encoder_attention_mask[i_p]['indices']
216
+ qkv_list.append(index_first_axis(rearrange(concat_qkv_tokens, "b s ... -> (b s) ..."), indices))
217
+ i_sum += length
218
+
219
+ token_lengths = [x_.shape[0] for x_ in qkv_list]
220
+ qkv = torch.cat(qkv_list, dim=0)
221
+ query, key, value = qkv.unbind(1)
222
+
223
+ cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
224
+ max_seqlen_q = cu_seqlens.max().item()
225
+ max_seqlen_k = max_seqlen_q
226
+ cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
227
+ cu_seqlens_k = cu_seqlens_q.clone()
228
+
229
+ output = flash_attn_varlen_func(
230
+ query,
231
+ key,
232
+ value,
233
+ cu_seqlens_q=cu_seqlens_q,
234
+ cu_seqlens_k=cu_seqlens_k,
235
+ max_seqlen_q=max_seqlen_q,
236
+ max_seqlen_k=max_seqlen_k,
237
+ dropout_p=0.0,
238
+ causal=False,
239
+ softmax_scale=scale,
240
+ )
241
+
242
+ # To merge the tokens
243
+ i_sum = 0;token_sum = 0
244
+ for i_p, length in enumerate(hidden_length):
245
+ tot_token_num = token_lengths[i_p]
246
+ stage_output = output[token_sum : token_sum + tot_token_num]
247
+ stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, encoder_length + length * sp_group_size)
248
+ stage_encoder_hidden_output = stage_output[:, :encoder_length]
249
+ stage_hidden_output = stage_output[:, encoder_length:]
250
+ stage_hidden_output = all_to_all(stage_hidden_output, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
251
+ output_hidden[:, i_sum:i_sum+length] = stage_hidden_output
252
+ output_encoder_hidden[i_p::num_stages] = stage_encoder_hidden_output
253
+ token_sum += tot_token_num
254
+ i_sum += length
255
+
256
+ output_encoder_hidden = all_to_all(output_encoder_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
257
+ output_hidden = output_hidden.flatten(2, 3)
258
+ output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
259
+
260
+ return output_hidden, output_encoder_hidden
261
+
262
+
263
+ class VarlenSelfAttentionWithT5Mask:
264
+
265
+ """
266
+ For chunk stage attention without using flash attention
267
+ """
268
+
269
+ def __init__(self):
270
+ pass
271
+
272
+ def apply_rope(self, xq, xk, freqs_cis):
273
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
274
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
275
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
276
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
277
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
278
+
279
+ def __call__(
280
+ self, query, key, value, encoder_query, encoder_key, encoder_value,
281
+ heads, scale, hidden_length=None, image_rotary_emb=None, attention_mask=None,
282
+ ):
283
+ assert attention_mask is not None, "The attention mask needed to be set"
284
+
285
+ encoder_length = encoder_query.shape[1]
286
+ num_stages = len(hidden_length)
287
+
288
+ encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
289
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
290
+
291
+ i_sum = 0
292
+ output_encoder_hidden_list = []
293
+ output_hidden_list = []
294
+
295
+ for i_p, length in enumerate(hidden_length):
296
+ encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
297
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
298
+ concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, tot_seq, 3, nhead, dim]
299
+
300
+ if image_rotary_emb is not None:
301
+ concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
302
+
303
+ query, key, value = concat_qkv_tokens.unbind(2) # [bs, tot_seq, nhead, dim]
304
+ query = query.transpose(1, 2)
305
+ key = key.transpose(1, 2)
306
+ value = value.transpose(1, 2)
307
+
308
+ # with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True):
309
+ stage_hidden_states = F.scaled_dot_product_attention(
310
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
311
+ )
312
+ stage_hidden_states = stage_hidden_states.transpose(1, 2).flatten(2, 3) # [bs, tot_seq, dim]
313
+
314
+ output_encoder_hidden_list.append(stage_hidden_states[:, :encoder_length])
315
+ output_hidden_list.append(stage_hidden_states[:, encoder_length:])
316
+ i_sum += length
317
+
318
+ output_encoder_hidden = torch.stack(output_encoder_hidden_list, dim=1) # [b n s d]
319
+ output_encoder_hidden = rearrange(output_encoder_hidden, 'b n s d -> (b n) s d')
320
+ output_hidden = torch.cat(output_hidden_list, dim=1)
321
+
322
+ return output_hidden, output_encoder_hidden
323
+
324
+
325
+ class SequenceParallelVarlenSelfAttentionWithT5Mask:
326
+ """
327
+ For chunk stage attention without using flash attention
328
+ """
329
+
330
+ def __init__(self):
331
+ pass
332
+
333
+ def apply_rope(self, xq, xk, freqs_cis):
334
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
335
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
336
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
337
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
338
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
339
+
340
+ def __call__(
341
+ self, query, key, value, encoder_query, encoder_key, encoder_value,
342
+ heads, scale, hidden_length=None, image_rotary_emb=None, attention_mask=None,
343
+ ):
344
+ assert attention_mask is not None, "The attention mask needed to be set"
345
+
346
+ num_stages = len(hidden_length)
347
+
348
+ encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
349
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
350
+
351
+ # To sync the encoder query, key and values
352
+ sp_group = get_sequence_parallel_group()
353
+ sp_group_size = get_sequence_parallel_world_size()
354
+ encoder_qkv = all_to_all(encoder_qkv, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
355
+ encoder_length = encoder_qkv.shape[1]
356
+
357
+ i_sum = 0
358
+ output_encoder_hidden_list = []
359
+ output_hidden_list = []
360
+
361
+ for i_p, length in enumerate(hidden_length):
362
+ encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
363
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
364
+ qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
365
+ concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, tot_seq, 3, nhead, dim]
366
+
367
+ if image_rotary_emb is not None:
368
+ concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
369
+
370
+ query, key, value = concat_qkv_tokens.unbind(2) # [bs, tot_seq, nhead, dim]
371
+ query = query.transpose(1, 2)
372
+ key = key.transpose(1, 2)
373
+ value = value.transpose(1, 2)
374
+
375
+ stage_hidden_states = F.scaled_dot_product_attention(
376
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
377
+ )
378
+ stage_hidden_states = stage_hidden_states.transpose(1, 2) # [bs, tot_seq, nhead, dim]
379
+
380
+ output_encoder_hidden_list.append(stage_hidden_states[:, :encoder_length])
381
+
382
+ output_hidden = stage_hidden_states[:, encoder_length:]
383
+ output_hidden = all_to_all(output_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
384
+ output_hidden_list.append(output_hidden)
385
+
386
+ i_sum += length
387
+
388
+ output_encoder_hidden = torch.stack(output_encoder_hidden_list, dim=1) # [b n s nhead d]
389
+ output_encoder_hidden = rearrange(output_encoder_hidden, 'b n s h d -> (b n) s h d')
390
+ output_encoder_hidden = all_to_all(output_encoder_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
391
+ output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
392
+ output_hidden = torch.cat(output_hidden_list, dim=1).flatten(2, 3)
393
+
394
+ return output_hidden, output_encoder_hidden
395
+
396
+
397
+ class JointAttention(nn.Module):
398
+
399
+ def __init__(
400
+ self,
401
+ query_dim: int,
402
+ cross_attention_dim: Optional[int] = None,
403
+ heads: int = 8,
404
+ dim_head: int = 64,
405
+ dropout: float = 0.0,
406
+ bias: bool = False,
407
+ qk_norm: Optional[str] = None,
408
+ added_kv_proj_dim: Optional[int] = None,
409
+ out_bias: bool = True,
410
+ eps: float = 1e-5,
411
+ out_dim: int = None,
412
+ context_pre_only=None,
413
+ use_flash_attn=True,
414
+ ):
415
+ """
416
+ Fixing the QKNorm, following the flux, norm the head dimension
417
+ """
418
+ super().__init__()
419
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
420
+ self.query_dim = query_dim
421
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
422
+ self.use_bias = bias
423
+ self.dropout = dropout
424
+
425
+ self.out_dim = out_dim if out_dim is not None else query_dim
426
+ self.context_pre_only = context_pre_only
427
+
428
+ self.scale = dim_head**-0.5
429
+ self.heads = out_dim // dim_head if out_dim is not None else heads
430
+ self.added_kv_proj_dim = added_kv_proj_dim
431
+
432
+ if qk_norm is None:
433
+ self.norm_q = None
434
+ self.norm_k = None
435
+ elif qk_norm == "layer_norm":
436
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps)
437
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps)
438
+ elif qk_norm == 'rms_norm':
439
+ self.norm_q = RMSNorm(dim_head, eps=eps)
440
+ self.norm_k = RMSNorm(dim_head, eps=eps)
441
+ else:
442
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
443
+
444
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
445
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
446
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
447
+
448
+ if self.added_kv_proj_dim is not None:
449
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
450
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
451
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
452
+
453
+ if qk_norm is None:
454
+ self.norm_add_q = None
455
+ self.norm_add_k = None
456
+ elif qk_norm == "layer_norm":
457
+ self.norm_add_q = nn.LayerNorm(dim_head, eps=eps)
458
+ self.norm_add_k = nn.LayerNorm(dim_head, eps=eps)
459
+ elif qk_norm == 'rms_norm':
460
+ self.norm_add_q = RMSNorm(dim_head, eps=eps)
461
+ self.norm_add_k = RMSNorm(dim_head, eps=eps)
462
+ else:
463
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
464
+
465
+ self.to_out = nn.ModuleList([])
466
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
467
+ self.to_out.append(nn.Dropout(dropout))
468
+
469
+ if not self.context_pre_only:
470
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
471
+
472
+ self.use_flash_attn = use_flash_attn
473
+
474
+ if flash_attn_func is None:
475
+ self.use_flash_attn = False
476
+
477
+ # print(f"Using flash-attention: {self.use_flash_attn}")
478
+ if self.use_flash_attn:
479
+ if is_sequence_parallel_initialized():
480
+ self.var_flash_attn = SequenceParallelVarlenFlashSelfAttentionWithT5Mask()
481
+ else:
482
+ self.var_flash_attn = VarlenFlashSelfAttentionWithT5Mask()
483
+ else:
484
+ if is_sequence_parallel_initialized():
485
+ self.var_len_attn = SequenceParallelVarlenSelfAttentionWithT5Mask()
486
+ else:
487
+ self.var_len_attn = VarlenSelfAttentionWithT5Mask()
488
+
489
+
490
+ def forward(
491
+ self,
492
+ hidden_states: torch.FloatTensor,
493
+ encoder_hidden_states: torch.FloatTensor = None,
494
+ encoder_attention_mask: torch.FloatTensor = None,
495
+ attention_mask: torch.FloatTensor = None, # [B, L, S]
496
+ hidden_length: torch.Tensor = None,
497
+ image_rotary_emb: torch.Tensor = None,
498
+ **kwargs,
499
+ ) -> torch.FloatTensor:
500
+ # This function is only used during training
501
+ # `sample` projections.
502
+ query = self.to_q(hidden_states)
503
+ key = self.to_k(hidden_states)
504
+ value = self.to_v(hidden_states)
505
+
506
+ inner_dim = key.shape[-1]
507
+ head_dim = inner_dim // self.heads
508
+
509
+ query = query.view(query.shape[0], -1, self.heads, head_dim)
510
+ key = key.view(key.shape[0], -1, self.heads, head_dim)
511
+ value = value.view(value.shape[0], -1, self.heads, head_dim)
512
+
513
+ if self.norm_q is not None:
514
+ query = self.norm_q(query)
515
+
516
+ if self.norm_k is not None:
517
+ key = self.norm_k(key)
518
+
519
+ # `context` projections.
520
+ encoder_hidden_states_query_proj = self.add_q_proj(encoder_hidden_states)
521
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
522
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
523
+
524
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
525
+ encoder_hidden_states_query_proj.shape[0], -1, self.heads, head_dim
526
+ )
527
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
528
+ encoder_hidden_states_key_proj.shape[0], -1, self.heads, head_dim
529
+ )
530
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
531
+ encoder_hidden_states_value_proj.shape[0], -1, self.heads, head_dim
532
+ )
533
+
534
+ if self.norm_add_q is not None:
535
+ encoder_hidden_states_query_proj = self.norm_add_q(encoder_hidden_states_query_proj)
536
+
537
+ if self.norm_add_k is not None:
538
+ encoder_hidden_states_key_proj = self.norm_add_k(encoder_hidden_states_key_proj)
539
+
540
+ # To cat the hidden and encoder hidden, perform attention compuataion, and then split
541
+ if self.use_flash_attn:
542
+ hidden_states, encoder_hidden_states = self.var_flash_attn(
543
+ query, key, value,
544
+ encoder_hidden_states_query_proj, encoder_hidden_states_key_proj,
545
+ encoder_hidden_states_value_proj, self.heads, self.scale, hidden_length,
546
+ image_rotary_emb, encoder_attention_mask,
547
+ )
548
+ else:
549
+ hidden_states, encoder_hidden_states = self.var_len_attn(
550
+ query, key, value,
551
+ encoder_hidden_states_query_proj, encoder_hidden_states_key_proj,
552
+ encoder_hidden_states_value_proj, self.heads, self.scale, hidden_length,
553
+ image_rotary_emb, attention_mask,
554
+ )
555
+
556
+ # linear proj
557
+ hidden_states = self.to_out[0](hidden_states)
558
+ # dropout
559
+ hidden_states = self.to_out[1](hidden_states)
560
+ if not self.context_pre_only:
561
+ encoder_hidden_states = self.to_add_out(encoder_hidden_states)
562
+
563
+ return hidden_states, encoder_hidden_states
564
+
565
+
566
+ class JointTransformerBlock(nn.Module):
567
+ r"""
568
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
569
+
570
+ Reference: https://arxiv.org/abs/2403.03206
571
+
572
+ Parameters:
573
+ dim (`int`): The number of channels in the input and output.
574
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
575
+ attention_head_dim (`int`): The number of channels in each head.
576
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
577
+ processing of `context` conditions.
578
+ """
579
+
580
+ def __init__(
581
+ self, dim, num_attention_heads, attention_head_dim, qk_norm=None,
582
+ context_pre_only=False, use_flash_attn=True,
583
+ ):
584
+ super().__init__()
585
+
586
+ self.context_pre_only = context_pre_only
587
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
588
+
589
+ self.norm1 = AdaLayerNormZero(dim)
590
+
591
+ if context_norm_type == "ada_norm_continous":
592
+ self.norm1_context = AdaLayerNormContinuous(
593
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
594
+ )
595
+ elif context_norm_type == "ada_norm_zero":
596
+ self.norm1_context = AdaLayerNormZero(dim)
597
+ else:
598
+ raise ValueError(
599
+ f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
600
+ )
601
+
602
+ self.attn = JointAttention(
603
+ query_dim=dim,
604
+ cross_attention_dim=None,
605
+ added_kv_proj_dim=dim,
606
+ dim_head=attention_head_dim // num_attention_heads,
607
+ heads=num_attention_heads,
608
+ out_dim=attention_head_dim,
609
+ qk_norm=qk_norm,
610
+ context_pre_only=context_pre_only,
611
+ bias=True,
612
+ use_flash_attn=use_flash_attn,
613
+ )
614
+
615
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
616
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
617
+
618
+ if not context_pre_only:
619
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
620
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
621
+ else:
622
+ self.norm2_context = None
623
+ self.ff_context = None
624
+
625
+ def forward(
626
+ self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor,
627
+ encoder_attention_mask: torch.FloatTensor, temb: torch.FloatTensor,
628
+ attention_mask: torch.FloatTensor = None, hidden_length: List = None,
629
+ image_rotary_emb: torch.FloatTensor = None,
630
+ ):
631
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb, hidden_length=hidden_length)
632
+
633
+ if self.context_pre_only:
634
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
635
+ else:
636
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
637
+ encoder_hidden_states, emb=temb,
638
+ )
639
+
640
+ # Attention
641
+ attn_output, context_attn_output = self.attn(
642
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states,
643
+ encoder_attention_mask=encoder_attention_mask, attention_mask=attention_mask,
644
+ hidden_length=hidden_length, image_rotary_emb=image_rotary_emb,
645
+ )
646
+
647
+ # Process attention outputs for the `hidden_states`.
648
+ attn_output = gate_msa * attn_output
649
+ hidden_states = hidden_states + attn_output
650
+
651
+ norm_hidden_states = self.norm2(hidden_states)
652
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
653
+
654
+ ff_output = self.ff(norm_hidden_states)
655
+ ff_output = gate_mlp * ff_output
656
+
657
+ hidden_states = hidden_states + ff_output
658
+
659
+ # Process attention outputs for the `encoder_hidden_states`.
660
+ if self.context_pre_only:
661
+ encoder_hidden_states = None
662
+ else:
663
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
664
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
665
+
666
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
667
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
668
+
669
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
670
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
671
+
672
+ return encoder_hidden_states, hidden_states
pyramid_dit/modeling_normalization.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numbers
2
+ from typing import Dict, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from diffusers.utils import is_torch_version
9
+
10
+
11
+ if is_torch_version(">=", "2.1.0"):
12
+ LayerNorm = nn.LayerNorm
13
+ else:
14
+ # Has optional bias parameter compared to torch layer norm
15
+ # TODO: replace with torch layernorm once min required torch version >= 2.1
16
+ class LayerNorm(nn.Module):
17
+ def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
18
+ super().__init__()
19
+
20
+ self.eps = eps
21
+
22
+ if isinstance(dim, numbers.Integral):
23
+ dim = (dim,)
24
+
25
+ self.dim = torch.Size(dim)
26
+
27
+ if elementwise_affine:
28
+ self.weight = nn.Parameter(torch.ones(dim))
29
+ self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
30
+ else:
31
+ self.weight = None
32
+ self.bias = None
33
+
34
+ def forward(self, input):
35
+ return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
36
+
37
+
38
+ class RMSNorm(nn.Module):
39
+ def __init__(self, dim, eps: float, elementwise_affine: bool = True):
40
+ super().__init__()
41
+
42
+ self.eps = eps
43
+
44
+ if isinstance(dim, numbers.Integral):
45
+ dim = (dim,)
46
+
47
+ self.dim = torch.Size(dim)
48
+
49
+ if elementwise_affine:
50
+ self.weight = nn.Parameter(torch.ones(dim))
51
+ else:
52
+ self.weight = None
53
+
54
+ def forward(self, hidden_states):
55
+ input_dtype = hidden_states.dtype
56
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
57
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
58
+
59
+ if self.weight is not None:
60
+ # convert into half-precision if necessary
61
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
62
+ hidden_states = hidden_states.to(self.weight.dtype)
63
+ hidden_states = hidden_states * self.weight
64
+
65
+ hidden_states = hidden_states.to(input_dtype)
66
+
67
+ return hidden_states
68
+
69
+
70
+ class AdaLayerNormContinuous(nn.Module):
71
+ def __init__(
72
+ self,
73
+ embedding_dim: int,
74
+ conditioning_embedding_dim: int,
75
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
76
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
77
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
78
+ # However, this is how it was implemented in the original code, and it's rather likely you should
79
+ # set `elementwise_affine` to False.
80
+ elementwise_affine=True,
81
+ eps=1e-5,
82
+ bias=True,
83
+ norm_type="layer_norm",
84
+ ):
85
+ super().__init__()
86
+ self.silu = nn.SiLU()
87
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
88
+ if norm_type == "layer_norm":
89
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
90
+ elif norm_type == "rms_norm":
91
+ self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
92
+ else:
93
+ raise ValueError(f"unknown norm_type {norm_type}")
94
+
95
+ def forward_with_pad(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor:
96
+ assert hidden_length is not None
97
+
98
+ emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
99
+ batch_emb = torch.zeros_like(x).repeat(1, 1, 2)
100
+
101
+ i_sum = 0
102
+ num_stages = len(hidden_length)
103
+ for i_p, length in enumerate(hidden_length):
104
+ batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None]
105
+ i_sum += length
106
+
107
+ batch_scale, batch_shift = torch.chunk(batch_emb, 2, dim=2)
108
+ x = self.norm(x) * (1 + batch_scale) + batch_shift
109
+ return x
110
+
111
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor:
112
+ # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
113
+ if hidden_length is not None:
114
+ return self.forward_with_pad(x, conditioning_embedding, hidden_length)
115
+ emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
116
+ scale, shift = torch.chunk(emb, 2, dim=1)
117
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
118
+ return x
119
+
120
+
121
+ class AdaLayerNormZero(nn.Module):
122
+ r"""
123
+ Norm layer adaptive layer norm zero (adaLN-Zero).
124
+
125
+ Parameters:
126
+ embedding_dim (`int`): The size of each embedding vector.
127
+ num_embeddings (`int`): The size of the embeddings dictionary.
128
+ """
129
+
130
+ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
131
+ super().__init__()
132
+ self.emb = None
133
+ self.silu = nn.SiLU()
134
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
135
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
136
+
137
+ def forward_with_pad(
138
+ self,
139
+ x: torch.Tensor,
140
+ timestep: Optional[torch.Tensor] = None,
141
+ class_labels: Optional[torch.LongTensor] = None,
142
+ hidden_dtype: Optional[torch.dtype] = None,
143
+ emb: Optional[torch.Tensor] = None,
144
+ hidden_length: Optional[torch.Tensor] = None,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
146
+ # x: [bs, seq_len, dim]
147
+ if self.emb is not None:
148
+ emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
149
+
150
+ emb = self.linear(self.silu(emb))
151
+ batch_emb = torch.zeros_like(x).repeat(1, 1, 6)
152
+
153
+ i_sum = 0
154
+ num_stages = len(hidden_length)
155
+ for i_p, length in enumerate(hidden_length):
156
+ batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None]
157
+ i_sum += length
158
+
159
+ batch_shift_msa, batch_scale_msa, batch_gate_msa, batch_shift_mlp, batch_scale_mlp, batch_gate_mlp = batch_emb.chunk(6, dim=2)
160
+ x = self.norm(x) * (1 + batch_scale_msa) + batch_shift_msa
161
+ return x, batch_gate_msa, batch_shift_mlp, batch_scale_mlp, batch_gate_mlp
162
+
163
+ def forward(
164
+ self,
165
+ x: torch.Tensor,
166
+ timestep: Optional[torch.Tensor] = None,
167
+ class_labels: Optional[torch.LongTensor] = None,
168
+ hidden_dtype: Optional[torch.dtype] = None,
169
+ emb: Optional[torch.Tensor] = None,
170
+ hidden_length: Optional[torch.Tensor] = None,
171
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
172
+ if hidden_length is not None:
173
+ return self.forward_with_pad(x, timestep, class_labels, hidden_dtype, emb, hidden_length)
174
+ if self.emb is not None:
175
+ emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
176
+ emb = self.linear(self.silu(emb))
177
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
178
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
179
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
pyramid_dit/modeling_pyramid_mmdit.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import os
4
+ import torch.nn.functional as F
5
+
6
+ from einops import rearrange
7
+ from diffusers.utils.torch_utils import randn_tensor
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.utils import is_torch_version
11
+ from typing import Any, Callable, Dict, List, Optional, Union
12
+ from tqdm import tqdm
13
+
14
+ from .modeling_embedding import PatchEmbed3D, CombinedTimestepConditionEmbeddings
15
+ from .modeling_normalization import AdaLayerNormContinuous
16
+ from .modeling_mmdit_block import JointTransformerBlock
17
+
18
+ from trainer_misc import (
19
+ is_sequence_parallel_initialized,
20
+ get_sequence_parallel_group,
21
+ get_sequence_parallel_world_size,
22
+ get_sequence_parallel_rank,
23
+ all_to_all,
24
+ )
25
+
26
+ from IPython import embed
27
+
28
+
29
+ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
30
+ assert dim % 2 == 0, "The dimension must be even."
31
+
32
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
33
+ omega = 1.0 / (theta**scale)
34
+
35
+ batch_size, seq_length = pos.shape
36
+ out = torch.einsum("...n,d->...nd", pos, omega)
37
+ cos_out = torch.cos(out)
38
+ sin_out = torch.sin(out)
39
+
40
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
41
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
42
+ return out.float()
43
+
44
+
45
+ class EmbedNDRoPE(nn.Module):
46
+ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
47
+ super().__init__()
48
+ self.dim = dim
49
+ self.theta = theta
50
+ self.axes_dim = axes_dim
51
+
52
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
53
+ n_axes = ids.shape[-1]
54
+ emb = torch.cat(
55
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
56
+ dim=-3,
57
+ )
58
+ return emb.unsqueeze(2)
59
+
60
+
61
+ class PyramidDiffusionMMDiT(ModelMixin, ConfigMixin):
62
+ _supports_gradient_checkpointing = True
63
+
64
+ @register_to_config
65
+ def __init__(
66
+ self,
67
+ sample_size: int = 128,
68
+ patch_size: int = 2,
69
+ in_channels: int = 16,
70
+ num_layers: int = 24,
71
+ attention_head_dim: int = 64,
72
+ num_attention_heads: int = 24,
73
+ caption_projection_dim: int = 1152,
74
+ pooled_projection_dim: int = 2048,
75
+ pos_embed_max_size: int = 192,
76
+ max_num_frames: int = 200,
77
+ qk_norm: str = 'rms_norm',
78
+ pos_embed_type: str = 'rope',
79
+ temp_pos_embed_type: str = 'sincos',
80
+ joint_attention_dim: int = 4096,
81
+ use_gradient_checkpointing: bool = False,
82
+ use_flash_attn: bool = True,
83
+ use_temporal_causal: bool = False,
84
+ use_t5_mask: bool = False,
85
+ add_temp_pos_embed: bool = False,
86
+ interp_condition_pos: bool = False,
87
+ ):
88
+ super().__init__()
89
+
90
+ self.out_channels = in_channels
91
+ self.inner_dim = num_attention_heads * attention_head_dim
92
+ assert temp_pos_embed_type in ['rope', 'sincos']
93
+
94
+ # The input latent embeder, using the name pos_embed to remain the same with SD#
95
+ self.pos_embed = PatchEmbed3D(
96
+ height=sample_size,
97
+ width=sample_size,
98
+ patch_size=patch_size,
99
+ in_channels=in_channels,
100
+ embed_dim=self.inner_dim,
101
+ pos_embed_max_size=pos_embed_max_size, # hard-code for now.
102
+ max_num_frames=max_num_frames,
103
+ pos_embed_type=pos_embed_type,
104
+ temp_pos_embed_type=temp_pos_embed_type,
105
+ add_temp_pos_embed=add_temp_pos_embed,
106
+ interp_condition_pos=interp_condition_pos,
107
+ )
108
+
109
+ # The RoPE EMbedding
110
+ if pos_embed_type == 'rope':
111
+ self.rope_embed = EmbedNDRoPE(self.inner_dim, 10000, axes_dim=[16, 24, 24])
112
+ else:
113
+ self.rope_embed = None
114
+
115
+ if temp_pos_embed_type == 'rope':
116
+ self.temp_rope_embed = EmbedNDRoPE(self.inner_dim, 10000, axes_dim=[attention_head_dim])
117
+ else:
118
+ self.temp_rope_embed = None
119
+
120
+ self.time_text_embed = CombinedTimestepConditionEmbeddings(
121
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim,
122
+ )
123
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
124
+
125
+ self.transformer_blocks = nn.ModuleList(
126
+ [
127
+ JointTransformerBlock(
128
+ dim=self.inner_dim,
129
+ num_attention_heads=num_attention_heads,
130
+ attention_head_dim=self.inner_dim,
131
+ qk_norm=qk_norm,
132
+ context_pre_only=i == num_layers - 1,
133
+ use_flash_attn=use_flash_attn,
134
+ )
135
+ for i in range(num_layers)
136
+ ]
137
+ )
138
+
139
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
140
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
141
+ self.gradient_checkpointing = use_gradient_checkpointing
142
+ self.patch_size = patch_size
143
+ self.use_flash_attn = use_flash_attn
144
+ self.use_temporal_causal = use_temporal_causal
145
+ self.pos_embed_type = pos_embed_type
146
+ self.temp_pos_embed_type = temp_pos_embed_type
147
+ self.add_temp_pos_embed = add_temp_pos_embed
148
+
149
+ if self.use_temporal_causal:
150
+ print("Using temporal causal attention")
151
+ assert self.use_flash_attn is False, "The flash attention does not support temporal causal"
152
+
153
+ if interp_condition_pos:
154
+ print("We interp the position embedding of condition latents")
155
+
156
+ # init weights
157
+ self.initialize_weights()
158
+
159
+ def initialize_weights(self):
160
+ # Initialize transformer layers:
161
+ def _basic_init(module):
162
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
163
+ torch.nn.init.xavier_uniform_(module.weight)
164
+ if module.bias is not None:
165
+ nn.init.constant_(module.bias, 0)
166
+ self.apply(_basic_init)
167
+
168
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
169
+ w = self.pos_embed.proj.weight.data
170
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
171
+ nn.init.constant_(self.pos_embed.proj.bias, 0)
172
+
173
+ # Initialize all the conditioning to normal init
174
+ nn.init.normal_(self.time_text_embed.timestep_embedder.linear_1.weight, std=0.02)
175
+ nn.init.normal_(self.time_text_embed.timestep_embedder.linear_2.weight, std=0.02)
176
+ nn.init.normal_(self.time_text_embed.text_embedder.linear_1.weight, std=0.02)
177
+ nn.init.normal_(self.time_text_embed.text_embedder.linear_2.weight, std=0.02)
178
+ nn.init.normal_(self.context_embedder.weight, std=0.02)
179
+
180
+ # Zero-out adaLN modulation layers in DiT blocks:
181
+ for block in self.transformer_blocks:
182
+ nn.init.constant_(block.norm1.linear.weight, 0)
183
+ nn.init.constant_(block.norm1.linear.bias, 0)
184
+ nn.init.constant_(block.norm1_context.linear.weight, 0)
185
+ nn.init.constant_(block.norm1_context.linear.bias, 0)
186
+
187
+ # Zero-out output layers:
188
+ nn.init.constant_(self.norm_out.linear.weight, 0)
189
+ nn.init.constant_(self.norm_out.linear.bias, 0)
190
+ nn.init.constant_(self.proj_out.weight, 0)
191
+ nn.init.constant_(self.proj_out.bias, 0)
192
+
193
+ @torch.no_grad()
194
+ def _prepare_latent_image_ids(self, batch_size, temp, height, width, device):
195
+ latent_image_ids = torch.zeros(temp, height, width, 3)
196
+ latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(temp)[:, None, None]
197
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[None, :, None]
198
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, None, :]
199
+
200
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
201
+ latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c')
202
+ return latent_image_ids.to(device=device)
203
+
204
+ @torch.no_grad()
205
+ def _prepare_pyramid_latent_image_ids(self, batch_size, temp_list, height_list, width_list, device):
206
+ base_width = width_list[-1]; base_height = height_list[-1]
207
+ assert base_width == max(width_list)
208
+ assert base_height == max(height_list)
209
+
210
+ image_ids_list = []
211
+ for temp, height, width in zip(temp_list, height_list, width_list):
212
+ latent_image_ids = torch.zeros(temp, height, width, 3)
213
+
214
+ if height != base_height:
215
+ height_pos = F.interpolate(torch.arange(base_height)[None, None, :].float(), height, mode='linear').squeeze(0, 1)
216
+ else:
217
+ height_pos = torch.arange(base_height).float()
218
+ if width != base_width:
219
+ width_pos = F.interpolate(torch.arange(base_width)[None, None, :].float(), width, mode='linear').squeeze(0, 1)
220
+ else:
221
+ width_pos = torch.arange(base_width).float()
222
+
223
+ latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(temp)[:, None, None]
224
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + height_pos[None, :, None]
225
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + width_pos[None, None, :]
226
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
227
+ latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c').to(device)
228
+ image_ids_list.append(latent_image_ids)
229
+
230
+ return image_ids_list
231
+
232
+ @torch.no_grad()
233
+ def _prepare_temporal_rope_ids(self, batch_size, temp, height, width, device, start_time_stamp=0):
234
+ latent_image_ids = torch.zeros(temp, height, width, 1)
235
+ latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(start_time_stamp, start_time_stamp + temp)[:, None, None]
236
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
237
+ latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c')
238
+ return latent_image_ids.to(device=device)
239
+
240
+ @torch.no_grad()
241
+ def _prepare_pyramid_temporal_rope_ids(self, sample, batch_size, device):
242
+ image_ids_list = []
243
+
244
+ for i_b, sample_ in enumerate(sample):
245
+ if not isinstance(sample_, list):
246
+ sample_ = [sample_]
247
+
248
+ cur_image_ids = []
249
+ start_time_stamp = 0
250
+
251
+ for clip_ in sample_:
252
+ _, _, temp, height, width = clip_.shape
253
+ height = height // self.patch_size
254
+ width = width // self.patch_size
255
+ cur_image_ids.append(self._prepare_temporal_rope_ids(batch_size, temp, height, width, device, start_time_stamp=start_time_stamp))
256
+ start_time_stamp += temp
257
+
258
+ cur_image_ids = torch.cat(cur_image_ids, dim=1)
259
+ image_ids_list.append(cur_image_ids)
260
+
261
+ return image_ids_list
262
+
263
+ def merge_input(self, sample, encoder_hidden_length, encoder_attention_mask):
264
+ """
265
+ Merge the input video with different resolutions into one sequence
266
+ Sample: From low resolution to high resolution
267
+ """
268
+ if isinstance(sample[0], list):
269
+ device = sample[0][-1].device
270
+ pad_batch_size = sample[0][-1].shape[0]
271
+ else:
272
+ device = sample[0].device
273
+ pad_batch_size = sample[0].shape[0]
274
+
275
+ num_stages = len(sample)
276
+ height_list = [];width_list = [];temp_list = []
277
+ trainable_token_list = []
278
+
279
+ for i_b, sample_ in enumerate(sample):
280
+ if isinstance(sample_, list):
281
+ sample_ = sample_[-1]
282
+ _, _, temp, height, width = sample_.shape
283
+ height = height // self.patch_size
284
+ width = width // self.patch_size
285
+ temp_list.append(temp)
286
+ height_list.append(height)
287
+ width_list.append(width)
288
+ trainable_token_list.append(height * width * temp)
289
+
290
+ # prepare the RoPE embedding if needed
291
+ if self.pos_embed_type == 'rope':
292
+ # TODO: support the 3D Rope for video
293
+ raise NotImplementedError("Not compatible with video generation now")
294
+ text_ids = torch.zeros(pad_batch_size, encoder_hidden_length, 3).to(device=device)
295
+ image_ids_list = self._prepare_pyramid_latent_image_ids(pad_batch_size, temp_list, height_list, width_list, device)
296
+ input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list]
297
+ image_rotary_emb = [self.rope_embed(input_ids) for input_ids in input_ids_list] # [bs, seq_len, 1, head_dim // 2, 2, 2]
298
+ else:
299
+ if self.temp_pos_embed_type == 'rope' and self.add_temp_pos_embed:
300
+ image_ids_list = self._prepare_pyramid_temporal_rope_ids(sample, pad_batch_size, device)
301
+ text_ids = torch.zeros(pad_batch_size, encoder_attention_mask.shape[1], 1).to(device=device)
302
+ input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list]
303
+ image_rotary_emb = [self.temp_rope_embed(input_ids) for input_ids in input_ids_list] # [bs, seq_len, 1, head_dim // 2, 2, 2]
304
+
305
+ if is_sequence_parallel_initialized():
306
+ sp_group = get_sequence_parallel_group()
307
+ sp_group_size = get_sequence_parallel_world_size()
308
+ image_rotary_emb = [all_to_all(x_.repeat(1, 1, sp_group_size, 1, 1, 1), sp_group, sp_group_size, scatter_dim=2, gather_dim=0) for x_ in image_rotary_emb]
309
+ input_ids_list = [all_to_all(input_ids.repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0) for input_ids in input_ids_list]
310
+
311
+ else:
312
+ image_rotary_emb = None
313
+
314
+ hidden_states = self.pos_embed(sample) # hidden states is a list of [b c t h w] b = real_b // num_stages
315
+ hidden_length = []
316
+
317
+ for i_b in range(num_stages):
318
+ hidden_length.append(hidden_states[i_b].shape[1])
319
+
320
+ # prepare the attention mask
321
+ if self.use_flash_attn:
322
+ attention_mask = None
323
+ indices_list = []
324
+ for i_p, length in enumerate(hidden_length):
325
+ pad_attention_mask = torch.ones((pad_batch_size, length), dtype=encoder_attention_mask.dtype).to(device)
326
+ pad_attention_mask = torch.cat([encoder_attention_mask[i_p::num_stages], pad_attention_mask], dim=1)
327
+
328
+ if is_sequence_parallel_initialized():
329
+ sp_group = get_sequence_parallel_group()
330
+ sp_group_size = get_sequence_parallel_world_size()
331
+ pad_attention_mask = all_to_all(pad_attention_mask.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0)
332
+ pad_attention_mask = pad_attention_mask.squeeze(2)
333
+
334
+ seqlens_in_batch = pad_attention_mask.sum(dim=-1, dtype=torch.int32)
335
+ indices = torch.nonzero(pad_attention_mask.flatten(), as_tuple=False).flatten()
336
+
337
+ indices_list.append(
338
+ {
339
+ 'indices': indices,
340
+ 'seqlens_in_batch': seqlens_in_batch,
341
+ }
342
+ )
343
+ encoder_attention_mask = indices_list
344
+ else:
345
+ assert encoder_attention_mask.shape[1] == encoder_hidden_length
346
+ real_batch_size = encoder_attention_mask.shape[0]
347
+ # prepare text ids
348
+ text_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, encoder_hidden_length)
349
+ text_ids = text_ids.to(device)
350
+ text_ids[encoder_attention_mask == 0] = 0
351
+
352
+ # prepare image ids
353
+ image_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, max(hidden_length))
354
+ image_ids = image_ids.to(device)
355
+ image_ids_list = []
356
+ for i_p, length in enumerate(hidden_length):
357
+ image_ids_list.append(image_ids[i_p::num_stages][:, :length])
358
+
359
+ if is_sequence_parallel_initialized():
360
+ sp_group = get_sequence_parallel_group()
361
+ sp_group_size = get_sequence_parallel_world_size()
362
+ text_ids = all_to_all(text_ids.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0).squeeze(2)
363
+ image_ids_list = [all_to_all(image_ids_.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0).squeeze(2) for image_ids_ in image_ids_list]
364
+
365
+ attention_mask = []
366
+ for i_p in range(len(hidden_length)):
367
+ image_ids = image_ids_list[i_p]
368
+ token_ids = torch.cat([text_ids[i_p::num_stages], image_ids], dim=1)
369
+ stage_attention_mask = rearrange(token_ids, 'b i -> b 1 i 1') == rearrange(token_ids, 'b j -> b 1 1 j') # [bs, 1, q_len, k_len]
370
+ if self.use_temporal_causal:
371
+ input_order_ids = input_ids_list[i_p].squeeze(2)
372
+ temporal_causal_mask = rearrange(input_order_ids, 'b i -> b 1 i 1') >= rearrange(input_order_ids, 'b j -> b 1 1 j')
373
+ stage_attention_mask = stage_attention_mask & temporal_causal_mask
374
+ attention_mask.append(stage_attention_mask)
375
+
376
+ return hidden_states, hidden_length, temp_list, height_list, width_list, trainable_token_list, encoder_attention_mask, attention_mask, image_rotary_emb
377
+
378
+ def split_output(self, batch_hidden_states, hidden_length, temps, heights, widths, trainable_token_list):
379
+ # To split the hidden states
380
+ batch_size = batch_hidden_states.shape[0]
381
+ output_hidden_list = []
382
+ batch_hidden_states = torch.split(batch_hidden_states, hidden_length, dim=1)
383
+
384
+ if is_sequence_parallel_initialized():
385
+ sp_group_size = get_sequence_parallel_world_size()
386
+ batch_size = batch_size // sp_group_size
387
+
388
+ for i_p, length in enumerate(hidden_length):
389
+ width, height, temp = widths[i_p], heights[i_p], temps[i_p]
390
+ trainable_token_num = trainable_token_list[i_p]
391
+ hidden_states = batch_hidden_states[i_p]
392
+
393
+ if is_sequence_parallel_initialized():
394
+ sp_group = get_sequence_parallel_group()
395
+ sp_group_size = get_sequence_parallel_world_size()
396
+ hidden_states = all_to_all(hidden_states, sp_group, sp_group_size, scatter_dim=0, gather_dim=1)
397
+
398
+ # only the trainable token are taking part in loss computation
399
+ hidden_states = hidden_states[:, -trainable_token_num:]
400
+
401
+ # unpatchify
402
+ hidden_states = hidden_states.reshape(
403
+ shape=(batch_size, temp, height, width, self.patch_size, self.patch_size, self.out_channels)
404
+ )
405
+ hidden_states = rearrange(hidden_states, "b t h w p1 p2 c -> b t (h p1) (w p2) c")
406
+ hidden_states = rearrange(hidden_states, "b t h w c -> b c t h w")
407
+ output_hidden_list.append(hidden_states)
408
+
409
+ return output_hidden_list
410
+
411
+ def forward(
412
+ self,
413
+ sample: torch.FloatTensor, # [num_stages]
414
+ encoder_hidden_states: torch.FloatTensor = None,
415
+ encoder_attention_mask: torch.FloatTensor = None,
416
+ pooled_projections: torch.FloatTensor = None,
417
+ timestep_ratio: torch.FloatTensor = None,
418
+ ):
419
+ # Get the timestep embedding
420
+ temb = self.time_text_embed(timestep_ratio, pooled_projections)
421
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
422
+ encoder_hidden_length = encoder_hidden_states.shape[1]
423
+
424
+ # Get the input sequence
425
+ hidden_states, hidden_length, temps, heights, widths, trainable_token_list, encoder_attention_mask, \
426
+ attention_mask, image_rotary_emb = self.merge_input(sample, encoder_hidden_length, encoder_attention_mask)
427
+
428
+ # split the long latents if necessary
429
+ if is_sequence_parallel_initialized():
430
+ sp_group = get_sequence_parallel_group()
431
+ sp_group_size = get_sequence_parallel_world_size()
432
+
433
+ # sync the input hidden states
434
+ batch_hidden_states = []
435
+ for i_p, hidden_states_ in enumerate(hidden_states):
436
+ assert hidden_states_.shape[1] % sp_group_size == 0, "The sequence length should be divided by sequence parallel size"
437
+ hidden_states_ = all_to_all(hidden_states_, sp_group, sp_group_size, scatter_dim=1, gather_dim=0)
438
+ hidden_length[i_p] = hidden_length[i_p] // sp_group_size
439
+ batch_hidden_states.append(hidden_states_)
440
+
441
+ # sync the encoder hidden states
442
+ hidden_states = torch.cat(batch_hidden_states, dim=1)
443
+ encoder_hidden_states = all_to_all(encoder_hidden_states, sp_group, sp_group_size, scatter_dim=1, gather_dim=0)
444
+ temb = all_to_all(temb.unsqueeze(1).repeat(1, sp_group_size, 1), sp_group, sp_group_size, scatter_dim=1, gather_dim=0)
445
+ temb = temb.squeeze(1)
446
+ else:
447
+ hidden_states = torch.cat(hidden_states, dim=1)
448
+
449
+ # print(hidden_length)
450
+ for i_b, block in enumerate(self.transformer_blocks):
451
+ if self.training and self.gradient_checkpointing and (i_b >= 2):
452
+ def create_custom_forward(module):
453
+ def custom_forward(*inputs):
454
+ return module(*inputs)
455
+
456
+ return custom_forward
457
+
458
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
459
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
460
+ create_custom_forward(block),
461
+ hidden_states,
462
+ encoder_hidden_states,
463
+ encoder_attention_mask,
464
+ temb,
465
+ attention_mask,
466
+ hidden_length,
467
+ image_rotary_emb,
468
+ **ckpt_kwargs,
469
+ )
470
+
471
+ else:
472
+ encoder_hidden_states, hidden_states = block(
473
+ hidden_states=hidden_states,
474
+ encoder_hidden_states=encoder_hidden_states,
475
+ encoder_attention_mask=encoder_attention_mask,
476
+ temb=temb,
477
+ attention_mask=attention_mask,
478
+ hidden_length=hidden_length,
479
+ image_rotary_emb=image_rotary_emb,
480
+ )
481
+
482
+ hidden_states = self.norm_out(hidden_states, temb, hidden_length=hidden_length)
483
+ hidden_states = self.proj_out(hidden_states)
484
+
485
+ output = self.split_output(hidden_states, hidden_length, temps, heights, widths, trainable_token_list)
486
+
487
+ return output
pyramid_dit/modeling_text_encoder.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import os
4
+
5
+ from transformers import (
6
+ CLIPTextModelWithProjection,
7
+ CLIPTokenizer,
8
+ T5EncoderModel,
9
+ T5TokenizerFast,
10
+ )
11
+
12
+ from typing import Any, Callable, Dict, List, Optional, Union
13
+
14
+
15
+ class SD3TextEncoderWithMask(nn.Module):
16
+ def __init__(self, model_path, torch_dtype):
17
+ super().__init__()
18
+ # CLIP-L
19
+ self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'))
20
+ self.tokenizer_max_length = self.tokenizer.model_max_length
21
+ self.text_encoder = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder'), torch_dtype=torch_dtype)
22
+
23
+ # CLIP-G
24
+ self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
25
+ self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder_2'), torch_dtype=torch_dtype)
26
+
27
+ # T5
28
+ self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3'))
29
+ self.text_encoder_3 = T5EncoderModel.from_pretrained(os.path.join(model_path, 'text_encoder_3'), torch_dtype=torch_dtype)
30
+
31
+ self._freeze()
32
+
33
+ def _freeze(self):
34
+ for param in self.parameters():
35
+ param.requires_grad = False
36
+
37
+ def _get_t5_prompt_embeds(
38
+ self,
39
+ prompt: Union[str, List[str]] = None,
40
+ num_images_per_prompt: int = 1,
41
+ device: Optional[torch.device] = None,
42
+ max_sequence_length: int = 128,
43
+ ):
44
+ prompt = [prompt] if isinstance(prompt, str) else prompt
45
+ batch_size = len(prompt)
46
+
47
+ text_inputs = self.tokenizer_3(
48
+ prompt,
49
+ padding="max_length",
50
+ max_length=max_sequence_length,
51
+ truncation=True,
52
+ add_special_tokens=True,
53
+ return_tensors="pt",
54
+ )
55
+ text_input_ids = text_inputs.input_ids
56
+ prompt_attention_mask = text_inputs.attention_mask
57
+ prompt_attention_mask = prompt_attention_mask.to(device)
58
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
59
+ dtype = self.text_encoder_3.dtype
60
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
61
+
62
+ _, seq_len, _ = prompt_embeds.shape
63
+
64
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
65
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
66
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
67
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
68
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
69
+
70
+ return prompt_embeds, prompt_attention_mask
71
+
72
+ def _get_clip_prompt_embeds(
73
+ self,
74
+ prompt: Union[str, List[str]],
75
+ num_images_per_prompt: int = 1,
76
+ device: Optional[torch.device] = None,
77
+ clip_skip: Optional[int] = None,
78
+ clip_model_index: int = 0,
79
+ ):
80
+
81
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
82
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
83
+
84
+ tokenizer = clip_tokenizers[clip_model_index]
85
+ text_encoder = clip_text_encoders[clip_model_index]
86
+
87
+ batch_size = len(prompt)
88
+
89
+ text_inputs = tokenizer(
90
+ prompt,
91
+ padding="max_length",
92
+ max_length=self.tokenizer_max_length,
93
+ truncation=True,
94
+ return_tensors="pt",
95
+ )
96
+
97
+ text_input_ids = text_inputs.input_ids
98
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
99
+ pooled_prompt_embeds = prompt_embeds[0]
100
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
101
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
102
+
103
+ return pooled_prompt_embeds
104
+
105
+ def encode_prompt(self,
106
+ prompt,
107
+ num_images_per_prompt=1,
108
+ clip_skip: Optional[int] = None,
109
+ device=None,
110
+ ):
111
+ prompt = [prompt] if isinstance(prompt, str) else prompt
112
+
113
+ pooled_prompt_embed = self._get_clip_prompt_embeds(
114
+ prompt=prompt,
115
+ device=device,
116
+ num_images_per_prompt=num_images_per_prompt,
117
+ clip_skip=clip_skip,
118
+ clip_model_index=0,
119
+ )
120
+ pooled_prompt_2_embed = self._get_clip_prompt_embeds(
121
+ prompt=prompt,
122
+ device=device,
123
+ num_images_per_prompt=num_images_per_prompt,
124
+ clip_skip=clip_skip,
125
+ clip_model_index=1,
126
+ )
127
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
128
+
129
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
130
+ prompt=prompt,
131
+ num_images_per_prompt=num_images_per_prompt,
132
+ device=device,
133
+ )
134
+ return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
135
+
136
+ def forward(self, input_prompts, device):
137
+ with torch.no_grad():
138
+ prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts, 1, clip_skip=None, device=device)
139
+
140
+ return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
pyramid_dit/pyramid_dit_for_video_gen_pipeline.py ADDED
@@ -0,0 +1,672 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import sys
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from collections import OrderedDict
8
+ from einops import rearrange
9
+ from diffusers.utils.torch_utils import randn_tensor
10
+ import numpy as np
11
+ import math
12
+ import random
13
+ import PIL
14
+ from PIL import Image
15
+ from tqdm import tqdm
16
+ from torchvision import transforms
17
+ from copy import deepcopy
18
+ from typing import Any, Callable, Dict, List, Optional, Union
19
+ from accelerate import Accelerator
20
+ from diffusion_schedulers import PyramidFlowMatchEulerDiscreteScheduler
21
+ from video_vae.modeling_causal_vae import CausalVideoVAE
22
+
23
+ from trainer_misc import (
24
+ all_to_all,
25
+ is_sequence_parallel_initialized,
26
+ get_sequence_parallel_group,
27
+ get_sequence_parallel_group_rank,
28
+ get_sequence_parallel_rank,
29
+ get_sequence_parallel_world_size,
30
+ get_rank,
31
+ )
32
+
33
+ from .modeling_pyramid_mmdit import PyramidDiffusionMMDiT
34
+ from .modeling_text_encoder import SD3TextEncoderWithMask
35
+
36
+
37
+ def compute_density_for_timestep_sampling(
38
+ weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
39
+ ):
40
+ if weighting_scheme == "logit_normal":
41
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
42
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
43
+ u = torch.nn.functional.sigmoid(u)
44
+ elif weighting_scheme == "mode":
45
+ u = torch.rand(size=(batch_size,), device="cpu")
46
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
47
+ else:
48
+ u = torch.rand(size=(batch_size,), device="cpu")
49
+ return u
50
+
51
+
52
+ class PyramidDiTForVideoGeneration:
53
+ """
54
+ The pyramid dit for both image and video generation, The running class wrapper
55
+ This class is mainly for fixed unit implementation: 1 + n + n + n
56
+ """
57
+ def __init__(self, model_path, model_dtype='bf16', use_gradient_checkpointing=False, return_log=True,
58
+ model_variant="diffusion_transformer_768p", timestep_shift=1.0, stage_range=[0, 1/3, 2/3, 1],
59
+ sample_ratios=[1, 1, 1], scheduler_gamma=1/3, use_mixed_training=False, use_flash_attn=False,
60
+ load_text_encoder=True, load_vae=True, max_temporal_length=31, frame_per_unit=1, use_temporal_causal=True,
61
+ corrupt_ratio=1/3, interp_condition_pos=True, stages=[1, 2, 4], **kwargs,
62
+ ):
63
+ super().__init__()
64
+
65
+ if model_dtype == 'bf16':
66
+ torch_dtype = torch.bfloat16
67
+ elif model_dtype == 'fp16':
68
+ torch_dtype = torch.float16
69
+ else:
70
+ torch_dtype = torch.float32
71
+
72
+ self.stages = stages
73
+ self.sample_ratios = sample_ratios
74
+ self.corrupt_ratio = corrupt_ratio
75
+
76
+ dit_path = os.path.join(model_path, model_variant)
77
+
78
+ # The dit
79
+ if use_mixed_training:
80
+ print("using mixed precision training, do not explicitly casting models")
81
+ self.dit = PyramidDiffusionMMDiT.from_pretrained(
82
+ dit_path, use_gradient_checkpointing=use_gradient_checkpointing,
83
+ use_flash_attn=use_flash_attn, use_t5_mask=True,
84
+ add_temp_pos_embed=True, temp_pos_embed_type='rope',
85
+ use_temporal_causal=use_temporal_causal, interp_condition_pos=interp_condition_pos,
86
+ )
87
+ else:
88
+ print("using half precision")
89
+ self.dit = PyramidDiffusionMMDiT.from_pretrained(
90
+ dit_path, torch_dtype=torch_dtype,
91
+ use_gradient_checkpointing=use_gradient_checkpointing,
92
+ use_flash_attn=use_flash_attn, use_t5_mask=True,
93
+ add_temp_pos_embed=True, temp_pos_embed_type='rope',
94
+ use_temporal_causal=use_temporal_causal, interp_condition_pos=interp_condition_pos,
95
+ )
96
+
97
+ # The text encoder
98
+ if load_text_encoder:
99
+ self.text_encoder = SD3TextEncoderWithMask(model_path, torch_dtype=torch_dtype)
100
+ else:
101
+ self.text_encoder = None
102
+
103
+ # The base video vae decoder
104
+ if load_vae:
105
+ self.vae = CausalVideoVAE.from_pretrained(os.path.join(model_path, 'causal_video_vae'), torch_dtype=torch_dtype, interpolate=False)
106
+ # Freeze vae
107
+ for parameter in self.vae.parameters():
108
+ parameter.requires_grad = False
109
+ else:
110
+ self.vae = None
111
+
112
+ # For the image latent
113
+ self.vae_shift_factor = 0.1490
114
+ self.vae_scale_factor = 1 / 1.8415
115
+
116
+ # For the video latent
117
+ self.vae_video_shift_factor = -0.2343
118
+ self.vae_video_scale_factor = 1 / 3.0986
119
+
120
+ self.downsample = 8
121
+
122
+ # Configure the video training hyper-parameters
123
+ # The video sequence: one frame + N * unit
124
+ self.frame_per_unit = frame_per_unit
125
+ self.max_temporal_length = max_temporal_length
126
+ assert (max_temporal_length - 1) % frame_per_unit == 0, "The frame number should be divided by the frame number per unit"
127
+ self.num_units_per_video = 1 + ((max_temporal_length - 1) // frame_per_unit) + int(sum(sample_ratios))
128
+
129
+ self.scheduler = PyramidFlowMatchEulerDiscreteScheduler(
130
+ shift=timestep_shift, stages=len(self.stages),
131
+ stage_range=stage_range, gamma=scheduler_gamma,
132
+ )
133
+ print(f"The start sigmas and end sigmas of each stage is Start: {self.scheduler.start_sigmas}, End: {self.scheduler.end_sigmas}, Ori_start: {self.scheduler.ori_start_sigmas}")
134
+
135
+ self.cfg_rate = 0.1
136
+ self.return_log = return_log
137
+ self.use_flash_attn = use_flash_attn
138
+
139
+ def load_checkpoint(self, checkpoint_path, model_key='model', **kwargs):
140
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
141
+ dit_checkpoint = OrderedDict()
142
+ for key in checkpoint:
143
+ if key.startswith('vae') or key.startswith('text_encoder'):
144
+ continue
145
+ if key.startswith('dit'):
146
+ new_key = key.split('.')
147
+ new_key = '.'.join(new_key[1:])
148
+ dit_checkpoint[new_key] = checkpoint[key]
149
+ else:
150
+ dit_checkpoint[key] = checkpoint[key]
151
+
152
+ load_result = self.dit.load_state_dict(dit_checkpoint, strict=True)
153
+ print(f"Load checkpoint from {checkpoint_path}, load result: {load_result}")
154
+
155
+ def load_vae_checkpoint(self, vae_checkpoint_path, model_key='model'):
156
+ checkpoint = torch.load(vae_checkpoint_path, map_location='cpu')
157
+ checkpoint = checkpoint[model_key]
158
+ loaded_checkpoint = OrderedDict()
159
+
160
+ for key in checkpoint.keys():
161
+ if key.startswith('vae.'):
162
+ new_key = key.split('.')
163
+ new_key = '.'.join(new_key[1:])
164
+ loaded_checkpoint[new_key] = checkpoint[key]
165
+
166
+ load_result = self.vae.load_state_dict(loaded_checkpoint)
167
+ print(f"Load the VAE from {vae_checkpoint_path}, load result: {load_result}")
168
+
169
+ @torch.no_grad()
170
+ def get_pyramid_latent(self, x, stage_num):
171
+ # x is the origin vae latent
172
+ vae_latent_list = []
173
+ vae_latent_list.append(x)
174
+
175
+ temp, height, width = x.shape[-3], x.shape[-2], x.shape[-1]
176
+ for _ in range(stage_num):
177
+ height //= 2
178
+ width //= 2
179
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
180
+ x = torch.nn.functional.interpolate(x, size=(height, width), mode='bilinear')
181
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=temp)
182
+ vae_latent_list.append(x)
183
+
184
+ vae_latent_list = list(reversed(vae_latent_list))
185
+ return vae_latent_list
186
+
187
+ def prepare_latents(
188
+ self,
189
+ batch_size,
190
+ num_channels_latents,
191
+ temp,
192
+ height,
193
+ width,
194
+ dtype,
195
+ device,
196
+ generator,
197
+ ):
198
+ shape = (
199
+ batch_size,
200
+ num_channels_latents,
201
+ int(temp),
202
+ int(height) // self.downsample,
203
+ int(width) // self.downsample,
204
+ )
205
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
206
+ return latents
207
+
208
+ def sample_block_noise(self, bs, ch, temp, height, width):
209
+ gamma = self.scheduler.config.gamma
210
+ dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma)
211
+ block_number = bs * ch * temp * (height // 2) * (width // 2)
212
+ noise = torch.stack([dist.sample() for _ in range(block_number)]) # [block number, 4]
213
+ noise = rearrange(noise, '(b c t h w) (p q) -> b c t (h p) (w q)',b=bs,c=ch,t=temp,h=height//2,w=width//2,p=2,q=2)
214
+ return noise
215
+
216
+ @torch.no_grad()
217
+ def generate_one_unit(
218
+ self,
219
+ latents,
220
+ past_conditions, # List of past conditions, contains the conditions of each stage
221
+ prompt_embeds,
222
+ prompt_attention_mask,
223
+ pooled_prompt_embeds,
224
+ num_inference_steps,
225
+ height,
226
+ width,
227
+ temp,
228
+ device,
229
+ dtype,
230
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
231
+ is_first_frame: bool = False,
232
+ ):
233
+ stages = self.stages
234
+ intermed_latents = []
235
+
236
+ for i_s in range(len(stages)):
237
+ self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device)
238
+ timesteps = self.scheduler.timesteps
239
+
240
+ if i_s > 0:
241
+ height *= 2; width *= 2
242
+ latents = rearrange(latents, 'b c t h w -> (b t) c h w')
243
+ latents = F.interpolate(latents, size=(height, width), mode='nearest')
244
+ latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
245
+ # Fix the stage
246
+ ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s] # the original coeff of signal
247
+ gamma = self.scheduler.config.gamma
248
+ alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
249
+ beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)
250
+
251
+ bs, ch, temp, height, width = latents.shape
252
+ noise = self.sample_block_noise(bs, ch, temp, height, width)
253
+ noise = noise.to(device=device, dtype=dtype)
254
+ latents = alpha * latents + beta * noise # To fix the block artifact
255
+
256
+ for idx, t in enumerate(timesteps):
257
+ # expand the latents if we are doing classifier free guidance
258
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
259
+
260
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
261
+ timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
262
+
263
+ latent_model_input = past_conditions[i_s] + [latent_model_input]
264
+
265
+ noise_pred = self.dit(
266
+ sample=[latent_model_input],
267
+ timestep_ratio=timestep,
268
+ encoder_hidden_states=prompt_embeds,
269
+ encoder_attention_mask=prompt_attention_mask,
270
+ pooled_projections=pooled_prompt_embeds,
271
+ )
272
+
273
+ noise_pred = noise_pred[0]
274
+
275
+ # perform guidance
276
+ if self.do_classifier_free_guidance:
277
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
278
+ if is_first_frame:
279
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
280
+ else:
281
+ noise_pred = noise_pred_uncond + self.video_guidance_scale * (noise_pred_text - noise_pred_uncond)
282
+
283
+ # compute the previous noisy sample x_t -> x_t-1
284
+ latents = self.scheduler.step(
285
+ model_output=noise_pred,
286
+ timestep=timestep,
287
+ sample=latents,
288
+ generator=generator,
289
+ ).prev_sample
290
+
291
+ intermed_latents.append(latents)
292
+
293
+ return intermed_latents
294
+
295
+ @torch.no_grad()
296
+ def generate_i2v(
297
+ self,
298
+ prompt: Union[str, List[str]] = '',
299
+ input_image: PIL.Image = None,
300
+ temp: int = 1,
301
+ num_inference_steps: Optional[Union[int, List[int]]] = 28,
302
+ guidance_scale: float = 7.0,
303
+ video_guidance_scale: float = 4.0,
304
+ min_guidance_scale: float = 2.0,
305
+ use_linear_guidance: bool = False,
306
+ alpha: float = 0.5,
307
+ negative_prompt: Optional[Union[str, List[str]]]="cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror",
308
+ num_images_per_prompt: Optional[int] = 1,
309
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
310
+ output_type: Optional[str] = "pil",
311
+ save_memory: bool = True,
312
+ ):
313
+ device = self.device
314
+ dtype = self.dtype
315
+
316
+ width = input_image.width
317
+ height = input_image.height
318
+
319
+ assert temp % self.frame_per_unit == 0, "The frames should be divided by frame_per unit"
320
+
321
+ if isinstance(prompt, str):
322
+ batch_size = 1
323
+ prompt = prompt + ", hyper quality, Ultra HD, 8K" # adding this prompt to improve aesthetics
324
+ else:
325
+ assert isinstance(prompt, list)
326
+ batch_size = len(prompt)
327
+ prompt = [_ + ", hyper quality, Ultra HD, 8K" for _ in prompt]
328
+
329
+ if isinstance(num_inference_steps, int):
330
+ num_inference_steps = [num_inference_steps] * len(self.stages)
331
+
332
+ negative_prompt = negative_prompt or ""
333
+
334
+ # Get the text embeddings
335
+ prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
336
+ negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)
337
+
338
+ if use_linear_guidance:
339
+ max_guidance_scale = guidance_scale
340
+ guidance_scale_list = [max(max_guidance_scale - alpha * t_, min_guidance_scale) for t_ in range(temp+1)]
341
+ print(guidance_scale_list)
342
+
343
+ self._guidance_scale = guidance_scale
344
+ self._video_guidance_scale = video_guidance_scale
345
+
346
+ if self.do_classifier_free_guidance:
347
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
348
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
349
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
350
+
351
+ # Create the initial random noise
352
+ num_channels_latents = self.dit.config.in_channels
353
+ latents = self.prepare_latents(
354
+ batch_size * num_images_per_prompt,
355
+ num_channels_latents,
356
+ temp,
357
+ height,
358
+ width,
359
+ prompt_embeds.dtype,
360
+ device,
361
+ generator,
362
+ )
363
+
364
+ temp, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]
365
+
366
+ latents = rearrange(latents, 'b c t h w -> (b t) c h w')
367
+ # by defalut, we needs to start from the block noise
368
+ for _ in range(len(self.stages)-1):
369
+ height //= 2;width //= 2
370
+ latents = F.interpolate(latents, size=(height, width), mode='bilinear') * 2
371
+
372
+ latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
373
+
374
+ num_units = temp // self.frame_per_unit
375
+ stages = self.stages
376
+
377
+ # encode the image latents
378
+ image_transform = transforms.Compose([
379
+ transforms.ToTensor(),
380
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
381
+ ])
382
+ input_image_tensor = image_transform(input_image).unsqueeze(0).unsqueeze(2) # [b c 1 h w]
383
+ input_image_latent = (self.vae.encode(input_image_tensor.to(device)).latent_dist.sample() - self.vae_shift_factor) * self.vae_scale_factor # [b c 1 h w]
384
+
385
+ generated_latents_list = [input_image_latent] # The generated results
386
+ last_generated_latents = input_image_latent
387
+
388
+ for unit_index in tqdm(range(1, num_units + 1)):
389
+ if use_linear_guidance:
390
+ self._guidance_scale = guidance_scale_list[unit_index]
391
+ self._video_guidance_scale = guidance_scale_list[unit_index]
392
+
393
+ # prepare the condition latents
394
+ past_condition_latents = []
395
+ clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), len(stages) - 1)
396
+
397
+ for i_s in range(len(stages)):
398
+ last_cond_latent = clean_latents_list[i_s][:,:,-self.frame_per_unit:]
399
+
400
+ stage_input = [torch.cat([last_cond_latent] * 2) if self.do_classifier_free_guidance else last_cond_latent]
401
+
402
+ # pad the past clean latents
403
+ cur_unit_num = unit_index
404
+ cur_stage = i_s
405
+ cur_unit_ptx = 1
406
+
407
+ while cur_unit_ptx < cur_unit_num:
408
+ cur_stage = max(cur_stage - 1, 0)
409
+ if cur_stage == 0:
410
+ break
411
+ cur_unit_ptx += 1
412
+ cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * self.frame_per_unit) : -((cur_unit_ptx - 1) * self.frame_per_unit)]
413
+ stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
414
+
415
+ if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
416
+ cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * self.frame_per_unit)]
417
+ stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
418
+
419
+ stage_input = list(reversed(stage_input))
420
+ past_condition_latents.append(stage_input)
421
+
422
+ intermed_latents = self.generate_one_unit(
423
+ latents[:,:,(unit_index - 1) * self.frame_per_unit:unit_index * self.frame_per_unit],
424
+ past_condition_latents,
425
+ prompt_embeds,
426
+ prompt_attention_mask,
427
+ pooled_prompt_embeds,
428
+ num_inference_steps,
429
+ height,
430
+ width,
431
+ self.frame_per_unit,
432
+ device,
433
+ dtype,
434
+ generator,
435
+ is_first_frame=False,
436
+ )
437
+
438
+ generated_latents_list.append(intermed_latents[-1])
439
+ last_generated_latents = intermed_latents
440
+
441
+ generated_latents = torch.cat(generated_latents_list, dim=2)
442
+
443
+ if output_type == "latent":
444
+ image = generated_latents
445
+ else:
446
+ image = self.decode_latent(generated_latents, save_memory=save_memory)
447
+
448
+ return image
449
+
450
+ @torch.no_grad()
451
+ def generate(
452
+ self,
453
+ prompt: Union[str, List[str]] = None,
454
+ height: Optional[int] = None,
455
+ width: Optional[int] = None,
456
+ temp: int = 1,
457
+ num_inference_steps: Optional[Union[int, List[int]]] = 28,
458
+ video_num_inference_steps: Optional[Union[int, List[int]]] = 28,
459
+ guidance_scale: float = 7.0,
460
+ video_guidance_scale: float = 7.0,
461
+ min_guidance_scale: float = 2.0,
462
+ use_linear_guidance: bool = False,
463
+ alpha: float = 0.5,
464
+ negative_prompt: Optional[Union[str, List[str]]]="cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror",
465
+ num_images_per_prompt: Optional[int] = 1,
466
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
467
+ output_type: Optional[str] = "pil",
468
+ save_memory: bool = True,
469
+ ):
470
+ device = self.device
471
+ dtype = self.dtype
472
+
473
+ assert (temp - 1) % self.frame_per_unit == 0, "The frames should be divided by frame_per unit"
474
+
475
+ if isinstance(prompt, str):
476
+ batch_size = 1
477
+ prompt = prompt + ", hyper quality, Ultra HD, 8K" # adding this prompt to improve aesthetics
478
+ else:
479
+ assert isinstance(prompt, list)
480
+ batch_size = len(prompt)
481
+ prompt = [_ + ", hyper quality, Ultra HD, 8K" for _ in prompt]
482
+
483
+ if isinstance(num_inference_steps, int):
484
+ num_inference_steps = [num_inference_steps] * len(self.stages)
485
+
486
+ if isinstance(video_num_inference_steps, int):
487
+ video_num_inference_steps = [video_num_inference_steps] * len(self.stages)
488
+
489
+ negative_prompt = negative_prompt or ""
490
+
491
+ # Get the text embeddings
492
+ prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
493
+ negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)
494
+
495
+ if use_linear_guidance:
496
+ max_guidance_scale = guidance_scale
497
+ # guidance_scale_list = torch.linspace(max_guidance_scale, min_guidance_scale, temp).tolist()
498
+ guidance_scale_list = [max(max_guidance_scale - alpha * t_, min_guidance_scale) for t_ in range(temp)]
499
+ print(guidance_scale_list)
500
+
501
+ self._guidance_scale = guidance_scale
502
+ self._video_guidance_scale = video_guidance_scale
503
+
504
+ if self.do_classifier_free_guidance:
505
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
506
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
507
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
508
+
509
+ # Create the initial random noise
510
+ num_channels_latents = self.dit.config.in_channels
511
+ latents = self.prepare_latents(
512
+ batch_size * num_images_per_prompt,
513
+ num_channels_latents,
514
+ temp,
515
+ height,
516
+ width,
517
+ prompt_embeds.dtype,
518
+ device,
519
+ generator,
520
+ )
521
+
522
+ temp, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]
523
+
524
+ latents = rearrange(latents, 'b c t h w -> (b t) c h w')
525
+ # by defalut, we needs to start from the block noise
526
+ for _ in range(len(self.stages)-1):
527
+ height //= 2;width //= 2
528
+ latents = F.interpolate(latents, size=(height, width), mode='bilinear') * 2
529
+
530
+ latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
531
+
532
+ num_units = 1 + (temp - 1) // self.frame_per_unit
533
+ stages = self.stages
534
+
535
+ generated_latents_list = [] # The generated results
536
+ last_generated_latents = None
537
+
538
+ for unit_index in tqdm(range(num_units)):
539
+ if use_linear_guidance:
540
+ self._guidance_scale = guidance_scale_list[unit_index]
541
+ self._video_guidance_scale = guidance_scale_list[unit_index]
542
+
543
+ if unit_index == 0:
544
+ past_condition_latents = [[] for _ in range(len(stages))]
545
+ intermed_latents = self.generate_one_unit(
546
+ latents[:,:,:1],
547
+ past_condition_latents,
548
+ prompt_embeds,
549
+ prompt_attention_mask,
550
+ pooled_prompt_embeds,
551
+ num_inference_steps,
552
+ height,
553
+ width,
554
+ 1,
555
+ device,
556
+ dtype,
557
+ generator,
558
+ is_first_frame=True,
559
+ )
560
+ else:
561
+ # prepare the condition latents
562
+ past_condition_latents = []
563
+ clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), len(stages) - 1)
564
+
565
+ for i_s in range(len(stages)):
566
+ last_cond_latent = clean_latents_list[i_s][:,:,-(self.frame_per_unit):]
567
+
568
+ stage_input = [torch.cat([last_cond_latent] * 2) if self.do_classifier_free_guidance else last_cond_latent]
569
+
570
+ # pad the past clean latents
571
+ cur_unit_num = unit_index
572
+ cur_stage = i_s
573
+ cur_unit_ptx = 1
574
+
575
+ while cur_unit_ptx < cur_unit_num:
576
+ cur_stage = max(cur_stage - 1, 0)
577
+ if cur_stage == 0:
578
+ break
579
+ cur_unit_ptx += 1
580
+ cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * self.frame_per_unit) : -((cur_unit_ptx - 1) * self.frame_per_unit)]
581
+ stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
582
+
583
+ if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
584
+ cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * self.frame_per_unit)]
585
+ stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
586
+
587
+ stage_input = list(reversed(stage_input))
588
+ past_condition_latents.append(stage_input)
589
+
590
+ intermed_latents = self.generate_one_unit(
591
+ latents[:,:, 1 + (unit_index - 1) * self.frame_per_unit:1 + unit_index * self.frame_per_unit],
592
+ past_condition_latents,
593
+ prompt_embeds,
594
+ prompt_attention_mask,
595
+ pooled_prompt_embeds,
596
+ video_num_inference_steps,
597
+ height,
598
+ width,
599
+ self.frame_per_unit,
600
+ device,
601
+ dtype,
602
+ generator,
603
+ is_first_frame=False,
604
+ )
605
+
606
+ generated_latents_list.append(intermed_latents[-1])
607
+ last_generated_latents = intermed_latents
608
+
609
+ generated_latents = torch.cat(generated_latents_list, dim=2)
610
+
611
+ if output_type == "latent":
612
+ image = generated_latents
613
+ else:
614
+ image = self.decode_latent(generated_latents, save_memory=save_memory)
615
+
616
+ return image
617
+
618
+ def decode_latent(self, latents, save_memory=True):
619
+ if latents.shape[2] == 1:
620
+ latents = (latents / self.vae_scale_factor) + self.vae_shift_factor
621
+ else:
622
+ latents[:, :, :1] = (latents[:, :, :1] / self.vae_scale_factor) + self.vae_shift_factor
623
+ latents[:, :, 1:] = (latents[:, :, 1:] / self.vae_video_scale_factor) + self.vae_video_shift_factor
624
+
625
+ if save_memory:
626
+ # reducing the tile size and temporal chunk window size
627
+ image = self.vae.decode(latents, temporal_chunk=True, window_size=1, tile_sample_min_size=256).sample
628
+ else:
629
+ image = self.vae.decode(latents, temporal_chunk=True, window_size=2, tile_sample_min_size=512).sample
630
+
631
+ image = image.float()
632
+ image = (image / 2 + 0.5).clamp(0, 1)
633
+ image = rearrange(image, "B C T H W -> (B T) C H W")
634
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
635
+ image = self.numpy_to_pil(image)
636
+ return image
637
+
638
+ @staticmethod
639
+ def numpy_to_pil(images):
640
+ """
641
+ Convert a numpy image or a batch of images to a PIL image.
642
+ """
643
+ if images.ndim == 3:
644
+ images = images[None, ...]
645
+ images = (images * 255).round().astype("uint8")
646
+ if images.shape[-1] == 1:
647
+ # special case for grayscale (single channel) images
648
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
649
+ else:
650
+ pil_images = [Image.fromarray(image) for image in images]
651
+
652
+ return pil_images
653
+
654
+ @property
655
+ def device(self):
656
+ return next(self.dit.parameters()).device
657
+
658
+ @property
659
+ def dtype(self):
660
+ return next(self.dit.parameters()).dtype
661
+
662
+ @property
663
+ def guidance_scale(self):
664
+ return self._guidance_scale
665
+
666
+ @property
667
+ def video_guidance_scale(self):
668
+ return self._video_guidance_scale
669
+
670
+ @property
671
+ def do_classifier_free_guidance(self):
672
+ return self._guidance_scale > 0
requirements.txt ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ contexttimer
2
+ decord
3
+ diffusers>=0.30.1
4
+ accelerate==0.30.0
5
+ torch==2.1.2
6
+ torchvision==0.16.2
7
+ numpy==1.24.4
8
+ einops
9
+ ftfy
10
+ ipython
11
+ opencv-python-headless==4.10.0.84
12
+ imageio==2.33.1
13
+ imageio-ffmpeg==0.5.1
14
+ packaging
15
+ pandas
16
+ plotly
17
+ pre-commit
18
+ pycocoevalcap
19
+ pycocotools
20
+ python-magic
21
+ scikit-image
22
+ sentencepiece
23
+ spacy
24
+ streamlit
25
+ timm==0.6.12
26
+ tqdm
27
+ transformers==4.39.3
28
+ wheel
29
+ torchmetrics
30
+ tiktoken
31
+ jsonlines
32
+ tensorboardX
trainer_misc/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .utils import (
2
+ create_optimizer,
3
+ get_rank,
4
+ get_world_size,
5
+ is_main_process,
6
+ is_dist_avail_and_initialized,
7
+ init_distributed_mode,
8
+ setup_for_distributed,
9
+ cosine_scheduler,
10
+ constant_scheduler,
11
+ )
12
+
13
+ from .sp_utils import (
14
+ is_sequence_parallel_initialized,
15
+ init_sequence_parallel_group,
16
+ get_sequence_parallel_group,
17
+ get_sequence_parallel_world_size,
18
+ get_sequence_parallel_rank,
19
+ get_sequence_parallel_group_rank,
20
+ get_sequence_parallel_proc_num,
21
+ init_sync_input_group,
22
+ get_sync_input_group,
23
+ )
24
+
25
+ from .communicate import all_to_all
trainer_misc/communicate.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ import torch.distributed as dist
5
+
6
+
7
+ def _all_to_all(
8
+ input_: torch.Tensor,
9
+ world_size: int,
10
+ group: dist.ProcessGroup,
11
+ scatter_dim: int,
12
+ gather_dim: int,
13
+ ):
14
+ if world_size == 1:
15
+ return input_
16
+ input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
17
+ output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
18
+ dist.all_to_all(output_list, input_list, group=group)
19
+ return torch.cat(output_list, dim=gather_dim).contiguous()
20
+
21
+
22
+ class _AllToAll(torch.autograd.Function):
23
+
24
+ @staticmethod
25
+ def forward(ctx, input_, process_group, world_size, scatter_dim, gather_dim):
26
+ ctx.process_group = process_group
27
+ ctx.scatter_dim = scatter_dim
28
+ ctx.gather_dim = gather_dim
29
+ ctx.world_size = world_size
30
+ output = _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim)
31
+ return output
32
+
33
+ @staticmethod
34
+ def backward(ctx, grad_output):
35
+ grad_output = _all_to_all(
36
+ grad_output,
37
+ ctx.world_size,
38
+ ctx.process_group,
39
+ ctx.gather_dim,
40
+ ctx.scatter_dim,
41
+ )
42
+ return (
43
+ grad_output,
44
+ None,
45
+ None,
46
+ None,
47
+ None,
48
+ )
49
+
50
+
51
+ def all_to_all(
52
+ input_: torch.Tensor,
53
+ process_group: dist.ProcessGroup,
54
+ world_size: int = 1,
55
+ scatter_dim: int = 2,
56
+ gather_dim: int = 1,
57
+ ):
58
+ return _AllToAll.apply(input_, process_group, world_size, scatter_dim, gather_dim)
trainer_misc/sp_utils.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.distributed as dist
4
+ from .utils import is_dist_avail_and_initialized, get_rank
5
+
6
+
7
+ SEQ_PARALLEL_GROUP = None
8
+ SEQ_PARALLEL_SIZE = None
9
+ SEQ_PARALLEL_PROC_NUM = None # using how many process for sequence parallel
10
+
11
+ SYNC_INPUT_GROUP = None
12
+ SYNC_INPUT_SIZE = None
13
+
14
+ def is_sequence_parallel_initialized():
15
+ if SEQ_PARALLEL_GROUP is None:
16
+ return False
17
+ else:
18
+ return True
19
+
20
+
21
+ def init_sequence_parallel_group(args):
22
+ global SEQ_PARALLEL_GROUP
23
+ global SEQ_PARALLEL_SIZE
24
+ global SEQ_PARALLEL_PROC_NUM
25
+
26
+ assert SEQ_PARALLEL_GROUP is None, "sequence parallel group is already initialized"
27
+ assert is_dist_avail_and_initialized(), "The pytorch distributed should be initialized"
28
+ SEQ_PARALLEL_SIZE = args.sp_group_size
29
+
30
+ print(f"Setting the Sequence Parallel Size {SEQ_PARALLEL_SIZE}")
31
+
32
+ rank = torch.distributed.get_rank()
33
+ world_size = torch.distributed.get_world_size()
34
+
35
+ if args.sp_proc_num == -1:
36
+ SEQ_PARALLEL_PROC_NUM = world_size
37
+ else:
38
+ SEQ_PARALLEL_PROC_NUM = args.sp_proc_num
39
+
40
+ assert SEQ_PARALLEL_PROC_NUM % SEQ_PARALLEL_SIZE == 0, "The process needs to be evenly divided"
41
+
42
+ for i in range(0, SEQ_PARALLEL_PROC_NUM, SEQ_PARALLEL_SIZE):
43
+ ranks = list(range(i, i + SEQ_PARALLEL_SIZE))
44
+ group = torch.distributed.new_group(ranks)
45
+ if rank in ranks:
46
+ SEQ_PARALLEL_GROUP = group
47
+ break
48
+
49
+
50
+ def init_sync_input_group(args):
51
+ global SYNC_INPUT_GROUP
52
+ global SYNC_INPUT_SIZE
53
+
54
+ assert SYNC_INPUT_GROUP is None, "parallel group is already initialized"
55
+ assert is_dist_avail_and_initialized(), "The pytorch distributed should be initialized"
56
+ SYNC_INPUT_SIZE = args.max_frames
57
+
58
+ rank = torch.distributed.get_rank()
59
+ world_size = torch.distributed.get_world_size()
60
+
61
+ for i in range(0, world_size, SYNC_INPUT_SIZE):
62
+ ranks = list(range(i, i + SYNC_INPUT_SIZE))
63
+ group = torch.distributed.new_group(ranks)
64
+ if rank in ranks:
65
+ SYNC_INPUT_GROUP = group
66
+ break
67
+
68
+
69
+ def get_sequence_parallel_group():
70
+ assert SEQ_PARALLEL_GROUP is not None, "sequence parallel group is not initialized"
71
+ return SEQ_PARALLEL_GROUP
72
+
73
+
74
+ def get_sync_input_group():
75
+ return SYNC_INPUT_GROUP
76
+
77
+
78
+ def get_sequence_parallel_world_size():
79
+ assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
80
+ return SEQ_PARALLEL_SIZE
81
+
82
+
83
+ def get_sequence_parallel_rank():
84
+ assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
85
+ rank = get_rank()
86
+ cp_rank = rank % SEQ_PARALLEL_SIZE
87
+ return cp_rank
88
+
89
+
90
+ def get_sequence_parallel_group_rank():
91
+ assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
92
+ rank = get_rank()
93
+ cp_group_rank = rank // SEQ_PARALLEL_SIZE
94
+ return cp_group_rank
95
+
96
+
97
+ def get_sequence_parallel_proc_num():
98
+ return SEQ_PARALLEL_PROC_NUM
trainer_misc/utils.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import math
4
+ import time
5
+ import json
6
+ import glob
7
+ from collections import defaultdict, deque, OrderedDict
8
+ import datetime
9
+ import numpy as np
10
+
11
+
12
+ from pathlib import Path
13
+ import argparse
14
+
15
+ import torch
16
+ from torch import optim as optim
17
+ import torch.distributed as dist
18
+ from tensorboardX import SummaryWriter
19
+
20
+
21
+ def is_dist_avail_and_initialized():
22
+ if not dist.is_available():
23
+ return False
24
+ if not dist.is_initialized():
25
+ return False
26
+ return True
27
+
28
+
29
+ def get_world_size():
30
+ if not is_dist_avail_and_initialized():
31
+ return 1
32
+ return dist.get_world_size()
33
+
34
+
35
+ def get_rank():
36
+ if not is_dist_avail_and_initialized():
37
+ return 0
38
+ return dist.get_rank()
39
+
40
+
41
+ def is_main_process():
42
+ return get_rank() == 0
43
+
44
+
45
+ def save_on_master(*args, **kwargs):
46
+ if is_main_process():
47
+ torch.save(*args, **kwargs)
48
+
49
+
50
+ def setup_for_distributed(is_master):
51
+ """
52
+ This function disables printing when not in master process
53
+ """
54
+ import builtins as __builtin__
55
+ builtin_print = __builtin__.print
56
+
57
+ def print(*args, **kwargs):
58
+ force = kwargs.pop('force', False)
59
+ if is_master or force:
60
+ builtin_print(*args, **kwargs)
61
+
62
+ __builtin__.print = print
63
+
64
+
65
+ def init_distributed_mode(args):
66
+ if int(os.getenv('OMPI_COMM_WORLD_SIZE', '0')) > 0:
67
+ rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
68
+ local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
69
+ world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
70
+
71
+ os.environ["LOCAL_RANK"] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
72
+ os.environ["RANK"] = os.environ['OMPI_COMM_WORLD_RANK']
73
+ os.environ["WORLD_SIZE"] = os.environ['OMPI_COMM_WORLD_SIZE']
74
+
75
+ args.rank = int(os.environ["RANK"])
76
+ args.world_size = int(os.environ["WORLD_SIZE"])
77
+ args.gpu = int(os.environ["LOCAL_RANK"])
78
+
79
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
80
+ args.rank = int(os.environ["RANK"])
81
+ args.world_size = int(os.environ['WORLD_SIZE'])
82
+ args.gpu = int(os.environ['LOCAL_RANK'])
83
+
84
+ else:
85
+ print('Not using distributed mode')
86
+ args.distributed = False
87
+ return
88
+
89
+ args.distributed = True
90
+ args.dist_backend = 'nccl'
91
+ args.dist_url = "env://"
92
+ print('| distributed init (rank {}): {}, gpu {}'.format(
93
+ args.rank, args.dist_url, args.gpu), flush=True)
94
+
95
+
96
+ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
97
+ start_warmup_value=0, warmup_steps=-1):
98
+ warmup_schedule = np.array([])
99
+ warmup_iters = warmup_epochs * niter_per_ep
100
+ if warmup_steps > 0:
101
+ warmup_iters = warmup_steps
102
+ print("Set warmup steps = %d" % warmup_iters)
103
+ if warmup_epochs > 0:
104
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
105
+
106
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
107
+ schedule = np.array(
108
+ [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
109
+
110
+ schedule = np.concatenate((warmup_schedule, schedule))
111
+
112
+ assert len(schedule) == epochs * niter_per_ep
113
+ return schedule
114
+
115
+
116
+ def constant_scheduler(base_value, epochs, niter_per_ep, warmup_epochs=0,
117
+ start_warmup_value=1e-6, warmup_steps=-1):
118
+ warmup_schedule = np.array([])
119
+ warmup_iters = warmup_epochs * niter_per_ep
120
+ if warmup_steps > 0:
121
+ warmup_iters = warmup_steps
122
+ print("Set warmup steps = %d" % warmup_iters)
123
+ if warmup_iters > 0:
124
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
125
+
126
+ iters = epochs * niter_per_ep - warmup_iters
127
+ schedule = np.array([base_value] * iters)
128
+
129
+ schedule = np.concatenate((warmup_schedule, schedule))
130
+
131
+ assert len(schedule) == epochs * niter_per_ep
132
+ return schedule
133
+
134
+
135
+ def get_parameter_groups(model, weight_decay=1e-5, base_lr=1e-4, skip_list=(), get_num_layer=None, get_layer_scale=None, **kwargs):
136
+ parameter_group_names = {}
137
+ parameter_group_vars = {}
138
+
139
+ for name, param in model.named_parameters():
140
+ if not param.requires_grad:
141
+ continue # frozen weights
142
+ if len(kwargs.get('filter_name', [])) > 0:
143
+ flag = False
144
+ for filter_n in kwargs.get('filter_name', []):
145
+ if filter_n in name:
146
+ print(f"filter {name} because of the pattern {filter_n}")
147
+ flag = True
148
+ if flag:
149
+ continue
150
+
151
+ default_scale=1.
152
+
153
+ if param.ndim <= 1 or name.endswith(".bias") or name in skip_list: # param.ndim <= 1 len(param.shape) == 1
154
+ group_name = "no_decay"
155
+ this_weight_decay = 0.
156
+ else:
157
+ group_name = "decay"
158
+ this_weight_decay = weight_decay
159
+
160
+ if get_num_layer is not None:
161
+ layer_id = get_num_layer(name)
162
+ group_name = "layer_%d_%s" % (layer_id, group_name)
163
+ else:
164
+ layer_id = None
165
+
166
+ if group_name not in parameter_group_names:
167
+ if get_layer_scale is not None:
168
+ scale = get_layer_scale(layer_id)
169
+ else:
170
+ scale = default_scale
171
+
172
+ parameter_group_names[group_name] = {
173
+ "weight_decay": this_weight_decay,
174
+ "params": [],
175
+ "lr": base_lr,
176
+ "lr_scale": scale,
177
+ }
178
+
179
+ parameter_group_vars[group_name] = {
180
+ "weight_decay": this_weight_decay,
181
+ "params": [],
182
+ "lr": base_lr,
183
+ "lr_scale": scale,
184
+ }
185
+
186
+ parameter_group_vars[group_name]["params"].append(param)
187
+ parameter_group_names[group_name]["params"].append(name)
188
+
189
+ print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
190
+ return list(parameter_group_vars.values())
191
+
192
+
193
+ def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None, **kwargs):
194
+ opt_lower = args.opt.lower()
195
+ weight_decay = args.weight_decay
196
+
197
+ skip = {}
198
+ if skip_list is not None:
199
+ skip = skip_list
200
+ elif hasattr(model, 'no_weight_decay'):
201
+ skip = model.no_weight_decay()
202
+ print(f"Skip weight decay name marked in model: {skip}")
203
+ parameters = get_parameter_groups(model, weight_decay, args.lr, skip, get_num_layer, get_layer_scale, **kwargs)
204
+ weight_decay = 0.
205
+
206
+ if 'fused' in opt_lower:
207
+ assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
208
+
209
+ opt_args = dict(lr=args.lr, weight_decay=weight_decay)
210
+ if hasattr(args, 'opt_eps') and args.opt_eps is not None:
211
+ opt_args['eps'] = args.opt_eps
212
+ if hasattr(args, 'opt_beta1') and args.opt_beta1 is not None:
213
+ opt_args['betas'] = (args.opt_beta1, args.opt_beta2)
214
+
215
+ print('Optimizer config:', opt_args)
216
+ opt_split = opt_lower.split('_')
217
+ opt_lower = opt_split[-1]
218
+ if opt_lower == 'sgd' or opt_lower == 'nesterov':
219
+ opt_args.pop('eps', None)
220
+ optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
221
+ elif opt_lower == 'momentum':
222
+ opt_args.pop('eps', None)
223
+ optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
224
+ elif opt_lower == 'adam':
225
+ optimizer = optim.Adam(parameters, **opt_args)
226
+ elif opt_lower == 'adamw':
227
+ optimizer = optim.AdamW(parameters, **opt_args)
228
+ elif opt_lower == 'adadelta':
229
+ optimizer = optim.Adadelta(parameters, **opt_args)
230
+ elif opt_lower == 'rmsprop':
231
+ optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
232
+ else:
233
+ assert False and "Invalid optimizer"
234
+ raise ValueError
235
+
236
+ return optimizer
237
+
238
+
239
+ class SmoothedValue(object):
240
+ """Track a series of values and provide access to smoothed values over a
241
+ window or the global series average.
242
+ """
243
+
244
+ def __init__(self, window_size=20, fmt=None):
245
+ if fmt is None:
246
+ fmt = "{median:.4f} ({global_avg:.4f})"
247
+ self.deque = deque(maxlen=window_size)
248
+ self.total = 0.0
249
+ self.count = 0
250
+ self.fmt = fmt
251
+
252
+ def update(self, value, n=1):
253
+ self.deque.append(value)
254
+ self.count += n
255
+ self.total += value * n
256
+
257
+ def synchronize_between_processes(self):
258
+ """
259
+ Warning: does not synchronize the deque!
260
+ """
261
+ if not is_dist_avail_and_initialized():
262
+ return
263
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
264
+ dist.barrier()
265
+ dist.all_reduce(t)
266
+ t = t.tolist()
267
+ self.count = int(t[0])
268
+ self.total = t[1]
269
+
270
+ @property
271
+ def median(self):
272
+ d = torch.tensor(list(self.deque))
273
+ return d.median().item()
274
+
275
+ @property
276
+ def avg(self):
277
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
278
+ return d.mean().item()
279
+
280
+ @property
281
+ def global_avg(self):
282
+ return self.total / self.count
283
+
284
+ @property
285
+ def max(self):
286
+ return max(self.deque)
287
+
288
+ @property
289
+ def value(self):
290
+ return self.deque[-1]
291
+
292
+ def __str__(self):
293
+ return self.fmt.format(
294
+ median=self.median,
295
+ avg=self.avg,
296
+ global_avg=self.global_avg,
297
+ max=self.max,
298
+ value=self.value)
299
+
300
+
301
+ class MetricLogger(object):
302
+ def __init__(self, delimiter="\t"):
303
+ self.meters = defaultdict(SmoothedValue)
304
+ self.delimiter = delimiter
305
+
306
+ def update(self, **kwargs):
307
+ for k, v in kwargs.items():
308
+ if v is None:
309
+ continue
310
+ if isinstance(v, torch.Tensor):
311
+ v = v.item()
312
+ assert isinstance(v, (float, int))
313
+ self.meters[k].update(v)
314
+
315
+ def __getattr__(self, attr):
316
+ if attr in self.meters:
317
+ return self.meters[attr]
318
+ if attr in self.__dict__:
319
+ return self.__dict__[attr]
320
+ raise AttributeError("'{}' object has no attribute '{}'".format(
321
+ type(self).__name__, attr))
322
+
323
+ def __str__(self):
324
+ loss_str = []
325
+ for name, meter in self.meters.items():
326
+ loss_str.append(
327
+ "{}: {}".format(name, str(meter))
328
+ )
329
+ return self.delimiter.join(loss_str)
330
+
331
+ def synchronize_between_processes(self):
332
+ for meter in self.meters.values():
333
+ meter.synchronize_between_processes()
334
+
335
+ def add_meter(self, name, meter):
336
+ self.meters[name] = meter
337
+
338
+ def log_every(self, iterable, print_freq, header=None):
339
+ i = 0
340
+ if not header:
341
+ header = ''
342
+ start_time = time.time()
343
+ end = time.time()
344
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
345
+ data_time = SmoothedValue(fmt='{avg:.4f}')
346
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
347
+ log_msg = [
348
+ header,
349
+ '[{0' + space_fmt + '}/{1}]',
350
+ 'eta: {eta}',
351
+ '{meters}',
352
+ 'time: {time}',
353
+ 'data: {data}'
354
+ ]
355
+ if torch.cuda.is_available():
356
+ log_msg.append('max mem: {memory:.0f}')
357
+ log_msg = self.delimiter.join(log_msg)
358
+ MB = 1024.0 * 1024.0
359
+ for obj in iterable:
360
+ data_time.update(time.time() - end)
361
+ yield obj
362
+ iter_time.update(time.time() - end)
363
+ if i % print_freq == 0 or i == len(iterable) - 1:
364
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
365
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
366
+ if torch.cuda.is_available():
367
+ print(log_msg.format(
368
+ i, len(iterable), eta=eta_string,
369
+ meters=str(self),
370
+ time=str(iter_time), data=str(data_time),
371
+ memory=torch.cuda.max_memory_allocated() / MB))
372
+ else:
373
+ print(log_msg.format(
374
+ i, len(iterable), eta=eta_string,
375
+ meters=str(self),
376
+ time=str(iter_time), data=str(data_time)))
377
+ i += 1
378
+ end = time.time()
379
+ total_time = time.time() - start_time
380
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
381
+ print('{} Total time: {} ({:.4f} s / it)'.format(
382
+ header, total_time_str, total_time / len(iterable)))
utils.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import PIL.Image
4
+ import numpy as np
5
+ from torch import nn
6
+ import torch.distributed as dist
7
+ import timm.models.hub as timm_hub
8
+
9
+ """Modified from https://github.com/CompVis/taming-transformers.git"""
10
+
11
+ import hashlib
12
+ import requests
13
+ from tqdm import tqdm
14
+ try:
15
+ import piq
16
+ except:
17
+ pass
18
+
19
+ _CONTEXT_PARALLEL_GROUP = None
20
+ _CONTEXT_PARALLEL_SIZE = None
21
+
22
+
23
+ def is_dist_avail_and_initialized():
24
+ if not dist.is_available():
25
+ return False
26
+ if not dist.is_initialized():
27
+ return False
28
+ return True
29
+
30
+
31
+ def get_world_size():
32
+ if not is_dist_avail_and_initialized():
33
+ return 1
34
+ return dist.get_world_size()
35
+
36
+
37
+ def get_rank():
38
+ if not is_dist_avail_and_initialized():
39
+ return 0
40
+ return dist.get_rank()
41
+
42
+
43
+ def is_main_process():
44
+ return get_rank() == 0
45
+
46
+
47
+ def is_context_parallel_initialized():
48
+ if _CONTEXT_PARALLEL_GROUP is None:
49
+ return False
50
+ else:
51
+ return True
52
+
53
+
54
+ def set_context_parallel_group(size, group):
55
+ global _CONTEXT_PARALLEL_GROUP
56
+ global _CONTEXT_PARALLEL_SIZE
57
+ _CONTEXT_PARALLEL_GROUP = group
58
+ _CONTEXT_PARALLEL_SIZE = size
59
+
60
+
61
+ def initialize_context_parallel(context_parallel_size):
62
+ global _CONTEXT_PARALLEL_GROUP
63
+ global _CONTEXT_PARALLEL_SIZE
64
+
65
+ assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
66
+ _CONTEXT_PARALLEL_SIZE = context_parallel_size
67
+
68
+ rank = torch.distributed.get_rank()
69
+ world_size = torch.distributed.get_world_size()
70
+
71
+ for i in range(0, world_size, context_parallel_size):
72
+ ranks = range(i, i + context_parallel_size)
73
+ group = torch.distributed.new_group(ranks)
74
+ if rank in ranks:
75
+ _CONTEXT_PARALLEL_GROUP = group
76
+ break
77
+
78
+
79
+ def get_context_parallel_group():
80
+ assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"
81
+
82
+ return _CONTEXT_PARALLEL_GROUP
83
+
84
+
85
+ def get_context_parallel_world_size():
86
+ assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
87
+
88
+ return _CONTEXT_PARALLEL_SIZE
89
+
90
+
91
+ def get_context_parallel_rank():
92
+ assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
93
+
94
+ rank = get_rank()
95
+ cp_rank = rank % _CONTEXT_PARALLEL_SIZE
96
+ return cp_rank
97
+
98
+
99
+ def get_context_parallel_group_rank():
100
+ assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
101
+
102
+ rank = get_rank()
103
+ cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE
104
+
105
+ return cp_group_rank
106
+
107
+
108
+ def download_cached_file(url, check_hash=True, progress=False):
109
+ """
110
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
111
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
112
+ """
113
+
114
+ def get_cached_file_path():
115
+ # a hack to sync the file path across processes
116
+ parts = torch.hub.urlparse(url)
117
+ filename = os.path.basename(parts.path)
118
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
119
+
120
+ return cached_file
121
+
122
+ if is_main_process():
123
+ timm_hub.download_cached_file(url, check_hash, progress)
124
+
125
+ if is_dist_avail_and_initialized():
126
+ dist.barrier()
127
+
128
+ return get_cached_file_path()
129
+
130
+
131
+ def convert_weights_to_fp16(model: nn.Module):
132
+ """Convert applicable model parameters to fp16"""
133
+
134
+ def _convert_weights_to_fp16(l):
135
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
136
+ l.weight.data = l.weight.data.to(torch.float16)
137
+ if l.bias is not None:
138
+ l.bias.data = l.bias.data.to(torch.float16)
139
+
140
+ model.apply(_convert_weights_to_fp16)
141
+
142
+
143
+ def convert_weights_to_bf16(model: nn.Module):
144
+ """Convert applicable model parameters to fp16"""
145
+
146
+ def _convert_weights_to_bf16(l):
147
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
148
+ l.weight.data = l.weight.data.to(torch.bfloat16)
149
+ if l.bias is not None:
150
+ l.bias.data = l.bias.data.to(torch.bfloat16)
151
+
152
+ model.apply(_convert_weights_to_bf16)
153
+
154
+
155
+ def save_result(result, result_dir, filename, remove_duplicate="", save_format='json'):
156
+ import json
157
+ import jsonlines
158
+ print("Dump result")
159
+
160
+ # Make the temp dir for saving results
161
+ if not os.path.exists(result_dir):
162
+ if is_main_process():
163
+ os.makedirs(result_dir)
164
+ if is_dist_avail_and_initialized():
165
+ torch.distributed.barrier()
166
+
167
+ result_file = os.path.join(
168
+ result_dir, "%s_rank%d.json" % (filename, get_rank())
169
+ )
170
+
171
+ final_result_file = os.path.join(result_dir, f"{filename}.{save_format}")
172
+
173
+ json.dump(result, open(result_file, "w"))
174
+
175
+ if is_dist_avail_and_initialized():
176
+ torch.distributed.barrier()
177
+
178
+ if is_main_process():
179
+ # print("rank %d starts merging results." % get_rank())
180
+ # combine results from all processes
181
+ result = []
182
+
183
+ for rank in range(get_world_size()):
184
+ result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, rank))
185
+ res = json.load(open(result_file, "r"))
186
+ result += res
187
+
188
+ # print("Remove duplicate")
189
+ if remove_duplicate:
190
+ result_new = []
191
+ id_set = set()
192
+ for res in result:
193
+ if res[remove_duplicate] not in id_set:
194
+ id_set.add(res[remove_duplicate])
195
+ result_new.append(res)
196
+ result = result_new
197
+
198
+ if save_format == 'json':
199
+ json.dump(result, open(final_result_file, "w"))
200
+ else:
201
+ assert save_format == 'jsonl', "Only support json adn jsonl format"
202
+ with jsonlines.open(final_result_file, "w") as writer:
203
+ writer.write_all(result)
204
+
205
+ # print("result file saved to %s" % final_result_file)
206
+
207
+ return final_result_file
208
+
209
+
210
+ # resizing utils
211
+ # TODO: clean up later
212
+ def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
213
+ h, w = input.shape[-2:]
214
+ factors = (h / size[0], w / size[1])
215
+
216
+ # First, we have to determine sigma
217
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
218
+ sigmas = (
219
+ max((factors[0] - 1.0) / 2.0, 0.001),
220
+ max((factors[1] - 1.0) / 2.0, 0.001),
221
+ )
222
+
223
+ # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
224
+ # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
225
+ # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
226
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
227
+
228
+ # Make sure it is odd
229
+ if (ks[0] % 2) == 0:
230
+ ks = ks[0] + 1, ks[1]
231
+
232
+ if (ks[1] % 2) == 0:
233
+ ks = ks[0], ks[1] + 1
234
+
235
+ input = _gaussian_blur2d(input, ks, sigmas)
236
+
237
+ output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
238
+ return output
239
+
240
+
241
+ def _compute_padding(kernel_size):
242
+ """Compute padding tuple."""
243
+ # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
244
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
245
+ if len(kernel_size) < 2:
246
+ raise AssertionError(kernel_size)
247
+ computed = [k - 1 for k in kernel_size]
248
+
249
+ # for even kernels we need to do asymmetric padding :(
250
+ out_padding = 2 * len(kernel_size) * [0]
251
+
252
+ for i in range(len(kernel_size)):
253
+ computed_tmp = computed[-(i + 1)]
254
+
255
+ pad_front = computed_tmp // 2
256
+ pad_rear = computed_tmp - pad_front
257
+
258
+ out_padding[2 * i + 0] = pad_front
259
+ out_padding[2 * i + 1] = pad_rear
260
+
261
+ return out_padding
262
+
263
+
264
+ def _filter2d(input, kernel):
265
+ # prepare kernel
266
+ b, c, h, w = input.shape
267
+ tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
268
+
269
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
270
+
271
+ height, width = tmp_kernel.shape[-2:]
272
+
273
+ padding_shape: list[int] = _compute_padding([height, width])
274
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
275
+
276
+ # kernel and input tensor reshape to align element-wise or batch-wise params
277
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
278
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
279
+
280
+ # convolve the tensor with the kernel.
281
+ output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
282
+
283
+ out = output.view(b, c, h, w)
284
+ return out
285
+
286
+
287
+ def _gaussian(window_size: int, sigma):
288
+ if isinstance(sigma, float):
289
+ sigma = torch.tensor([[sigma]])
290
+
291
+ batch_size = sigma.shape[0]
292
+
293
+ x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
294
+
295
+ if window_size % 2 == 0:
296
+ x = x + 0.5
297
+
298
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
299
+
300
+ return gauss / gauss.sum(-1, keepdim=True)
301
+
302
+
303
+ def _gaussian_blur2d(input, kernel_size, sigma):
304
+ if isinstance(sigma, tuple):
305
+ sigma = torch.tensor([sigma], dtype=input.dtype)
306
+ else:
307
+ sigma = sigma.to(dtype=input.dtype)
308
+
309
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
310
+ bs = sigma.shape[0]
311
+ kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
312
+ kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
313
+ out_x = _filter2d(input, kernel_x[..., None, :])
314
+ out = _filter2d(out_x, kernel_y[..., None])
315
+
316
+ return out
317
+
318
+
319
+ URL_MAP = {
320
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
321
+ }
322
+
323
+ CKPT_MAP = {
324
+ "vgg_lpips": "vgg.pth"
325
+ }
326
+
327
+ MD5_MAP = {
328
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
329
+ }
330
+
331
+
332
+ def download(url, local_path, chunk_size=1024):
333
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
334
+ with requests.get(url, stream=True) as r:
335
+ total_size = int(r.headers.get("content-length", 0))
336
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
337
+ with open(local_path, "wb") as f:
338
+ for data in r.iter_content(chunk_size=chunk_size):
339
+ if data:
340
+ f.write(data)
341
+ pbar.update(chunk_size)
342
+
343
+
344
+ def md5_hash(path):
345
+ with open(path, "rb") as f:
346
+ content = f.read()
347
+ return hashlib.md5(content).hexdigest()
348
+
349
+
350
+ def get_ckpt_path(name, root, check=False):
351
+ assert name in URL_MAP
352
+ path = os.path.join(root, CKPT_MAP[name])
353
+ print(md5_hash(path))
354
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
355
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
356
+ download(URL_MAP[name], path)
357
+ md5 = md5_hash(path)
358
+ assert md5 == MD5_MAP[name], md5
359
+ return path
360
+
361
+
362
+ class KeyNotFoundError(Exception):
363
+ def __init__(self, cause, keys=None, visited=None):
364
+ self.cause = cause
365
+ self.keys = keys
366
+ self.visited = visited
367
+ messages = list()
368
+ if keys is not None:
369
+ messages.append("Key not found: {}".format(keys))
370
+ if visited is not None:
371
+ messages.append("Visited: {}".format(visited))
372
+ messages.append("Cause:\n{}".format(cause))
373
+ message = "\n".join(messages)
374
+ super().__init__(message)
375
+
376
+
377
+ def retrieve(
378
+ list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
379
+ ):
380
+ """Given a nested list or dict return the desired value at key expanding
381
+ callable nodes if necessary and :attr:`expand` is ``True``. The expansion
382
+ is done in-place.
383
+
384
+ Parameters
385
+ ----------
386
+ list_or_dict : list or dict
387
+ Possibly nested list or dictionary.
388
+ key : str
389
+ key/to/value, path like string describing all keys necessary to
390
+ consider to get to the desired value. List indices can also be
391
+ passed here.
392
+ splitval : str
393
+ String that defines the delimiter between keys of the
394
+ different depth levels in `key`.
395
+ default : obj
396
+ Value returned if :attr:`key` is not found.
397
+ expand : bool
398
+ Whether to expand callable nodes on the path or not.
399
+
400
+ Returns
401
+ -------
402
+ The desired value or if :attr:`default` is not ``None`` and the
403
+ :attr:`key` is not found returns ``default``.
404
+
405
+ Raises
406
+ ------
407
+ Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
408
+ ``None``.
409
+ """
410
+
411
+ keys = key.split(splitval)
412
+
413
+ success = True
414
+ try:
415
+ visited = []
416
+ parent = None
417
+ last_key = None
418
+ for key in keys:
419
+ if callable(list_or_dict):
420
+ if not expand:
421
+ raise KeyNotFoundError(
422
+ ValueError(
423
+ "Trying to get past callable node with expand=False."
424
+ ),
425
+ keys=keys,
426
+ visited=visited,
427
+ )
428
+ list_or_dict = list_or_dict()
429
+ parent[last_key] = list_or_dict
430
+
431
+ last_key = key
432
+ parent = list_or_dict
433
+
434
+ try:
435
+ if isinstance(list_or_dict, dict):
436
+ list_or_dict = list_or_dict[key]
437
+ else:
438
+ list_or_dict = list_or_dict[int(key)]
439
+ except (KeyError, IndexError, ValueError) as e:
440
+ raise KeyNotFoundError(e, keys=keys, visited=visited)
441
+
442
+ visited += [key]
443
+ # final expansion of retrieved value
444
+ if expand and callable(list_or_dict):
445
+ list_or_dict = list_or_dict()
446
+ parent[last_key] = list_or_dict
447
+ except KeyNotFoundError as e:
448
+ if default is None:
449
+ raise e
450
+ else:
451
+ list_or_dict = default
452
+ success = False
453
+
454
+ if not pass_success:
455
+ return list_or_dict
456
+ else:
457
+ return list_or_dict, success
video_generation_demo.ipynb ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "import json\n",
11
+ "import torch\n",
12
+ "import numpy as np\n",
13
+ "import PIL\n",
14
+ "from PIL import Image\n",
15
+ "from IPython.display import HTML\n",
16
+ "from pyramid_dit import PyramidDiTForVideoGeneration\n",
17
+ "from IPython.display import Image as ipython_image\n",
18
+ "from diffusers.utils import load_image, export_to_video, export_to_gif"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "variant='diffusion_transformer_768p' # For high resolution\n",
28
+ "# variant='diffusion_transformer_384p' # For low resolution\n",
29
+ "\n",
30
+ "model_path = \"/home/jinyang06/models/pyramid-flow\" # The downloaded checkpoint dir\n",
31
+ "model_dtype = 'bf16'\n",
32
+ "\n",
33
+ "device_id = 0\n",
34
+ "torch.cuda.set_device(device_id)\n",
35
+ "\n",
36
+ "model = PyramidDiTForVideoGeneration(\n",
37
+ " model_path,\n",
38
+ " model_dtype,\n",
39
+ " model_variant=variant,\n",
40
+ ")\n",
41
+ "\n",
42
+ "model.vae.to(\"cuda\")\n",
43
+ "model.dit.to(\"cuda\")\n",
44
+ "model.text_encoder.to(\"cuda\")\n",
45
+ "\n",
46
+ "if model_dtype == \"bf16\":\n",
47
+ " torch_dtype = torch.bfloat16 \n",
48
+ "elif model_dtype == \"fp16\":\n",
49
+ " torch_dtype = torch.float16\n",
50
+ "else:\n",
51
+ " torch_dtype = torch.float32\n",
52
+ "\n",
53
+ "\n",
54
+ "def show_video(ori_path, rec_path, width=\"100%\"):\n",
55
+ " html = ''\n",
56
+ " if ori_path is not None:\n",
57
+ " html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
58
+ " <source src=\"{ori_path}\" type=\"video/mp4\">\n",
59
+ " </video>\n",
60
+ " \"\"\"\n",
61
+ " \n",
62
+ " html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
63
+ " <source src=\"{rec_path}\" type=\"video/mp4\">\n",
64
+ " </video>\n",
65
+ " \"\"\"\n",
66
+ " return HTML(html)"
67
+ ]
68
+ },
69
+ {
70
+ "attachments": {},
71
+ "cell_type": "markdown",
72
+ "metadata": {},
73
+ "source": [
74
+ "#### Text-to-Video"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": null,
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "prompt = \"A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors\"\n",
84
+ "\n",
85
+ "# used for 384p model variant\n",
86
+ "# width = 640\n",
87
+ "# height = 384\n",
88
+ "\n",
89
+ "# used for 768p model variant\n",
90
+ "width = 1280\n",
91
+ "height = 768\n",
92
+ "\n",
93
+ "temp = 16 # temp in [1, 31] <=> frame in [1, 241] <=> duration in [0, 10s]\n",
94
+ "\n",
95
+ "model.vae.enable_tiling()\n",
96
+ "\n",
97
+ "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):\n",
98
+ " frames = model.generate(\n",
99
+ " prompt=prompt,\n",
100
+ " num_inference_steps=[20, 20, 20],\n",
101
+ " video_num_inference_steps=[10, 10, 10],\n",
102
+ " height=height,\n",
103
+ " width=width,\n",
104
+ " temp=temp,\n",
105
+ " guidance_scale=9.0, # The guidance for the first frame\n",
106
+ " video_guidance_scale=5.0, # The guidance for the other video latent\n",
107
+ " output_type=\"pil\",\n",
108
+ " save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed\n",
109
+ " )\n",
110
+ "\n",
111
+ "export_to_video(frames, \"./text_to_video_sample.mp4\", fps=24)\n",
112
+ "show_video(None, \"./text_to_video_sample.mp4\", \"70%\")"
113
+ ]
114
+ },
115
+ {
116
+ "attachments": {},
117
+ "cell_type": "markdown",
118
+ "metadata": {},
119
+ "source": [
120
+ "#### Image-to-Video"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": [
129
+ "image_path = 'assets/the_great_wall.jpg'\n",
130
+ "image = Image.open(image_path).convert(\"RGB\")\n",
131
+ "\n",
132
+ "width = 1280\n",
133
+ "height = 768\n",
134
+ "temp = 16\n",
135
+ "\n",
136
+ "image = image.resize((width, height))\n",
137
+ "\n",
138
+ "display(image)\n",
139
+ "\n",
140
+ "prompt = \"FPV flying over the Great Wall\"\n",
141
+ "\n",
142
+ "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):\n",
143
+ " frames = model.generate_i2v(\n",
144
+ " prompt=prompt,\n",
145
+ " input_image=image,\n",
146
+ " num_inference_steps=[10, 10, 10],\n",
147
+ " temp=temp,\n",
148
+ " guidance_scale=7.0,\n",
149
+ " video_guidance_scale=4.0,\n",
150
+ " output_type=\"pil\",\n",
151
+ " save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed\n",
152
+ " )\n",
153
+ "\n",
154
+ "export_to_video(frames, \"./image_to_video_sample.mp4\", fps=24)\n",
155
+ "show_video(None, \"./image_to_video_sample.mp4\", \"70%\")"
156
+ ]
157
+ }
158
+ ],
159
+ "metadata": {
160
+ "kernelspec": {
161
+ "display_name": "Python 3",
162
+ "language": "python",
163
+ "name": "python3"
164
+ },
165
+ "language_info": {
166
+ "codemirror_mode": {
167
+ "name": "ipython",
168
+ "version": 3
169
+ },
170
+ "file_extension": ".py",
171
+ "mimetype": "text/x-python",
172
+ "name": "python",
173
+ "nbconvert_exporter": "python",
174
+ "pygments_lexer": "ipython3",
175
+ "version": "3.8.10"
176
+ },
177
+ "orig_nbformat": 4
178
+ },
179
+ "nbformat": 4,
180
+ "nbformat_minor": 2
181
+ }
video_vae/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .modeling_loss import LPIPSWithDiscriminator
2
+ from .modeling_causal_vae import CausalVideoVAE
video_vae/context_parallel_ops.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from cogvideoX
2
+ import torch
3
+ import torch.nn as nn
4
+ import math
5
+
6
+ from utils import (
7
+ get_context_parallel_group,
8
+ get_context_parallel_rank,
9
+ get_context_parallel_world_size,
10
+ get_context_parallel_group_rank,
11
+ )
12
+
13
+
14
+ def _conv_split(input_, dim=2, kernel_size=1):
15
+ cp_world_size = get_context_parallel_world_size()
16
+
17
+ # Bypass the function if context parallel is 1
18
+ if cp_world_size == 1:
19
+ return input_
20
+
21
+ # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
22
+
23
+ cp_rank = get_context_parallel_rank()
24
+
25
+ dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
26
+
27
+ if cp_rank == 0:
28
+ output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
29
+ else:
30
+ # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
31
+ output = input_.transpose(dim, 0)[
32
+ cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
33
+ ].transpose(dim, 0)
34
+ output = output.contiguous()
35
+
36
+ # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
37
+
38
+ return output
39
+
40
+
41
+ def _conv_gather(input_, dim=2, kernel_size=1):
42
+ cp_world_size = get_context_parallel_world_size()
43
+
44
+ # Bypass the function if context parallel is 1
45
+ if cp_world_size == 1:
46
+ return input_
47
+
48
+ group = get_context_parallel_group()
49
+ cp_rank = get_context_parallel_rank()
50
+
51
+ # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
52
+
53
+ input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
54
+ if cp_rank == 0:
55
+ input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
56
+ else:
57
+ input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0) :].transpose(0, dim).contiguous()
58
+
59
+ tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
60
+ torch.empty_like(input_) for _ in range(cp_world_size - 1)
61
+ ]
62
+ if cp_rank == 0:
63
+ input_ = torch.cat([input_first_kernel_, input_], dim=dim)
64
+
65
+ tensor_list[cp_rank] = input_
66
+ torch.distributed.all_gather(tensor_list, input_, group=group)
67
+
68
+ # Note: torch.cat already creates a contiguous tensor.
69
+ output = torch.cat(tensor_list, dim=dim).contiguous()
70
+
71
+ # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
72
+
73
+ return output
74
+
75
+
76
+ def _cp_pass_from_previous_rank(input_, dim, kernel_size):
77
+ # Bypass the function if kernel size is 1
78
+ if kernel_size == 1:
79
+ return input_
80
+
81
+ group = get_context_parallel_group()
82
+ cp_rank = get_context_parallel_rank()
83
+ cp_group_rank = get_context_parallel_group_rank()
84
+ cp_world_size = get_context_parallel_world_size()
85
+
86
+ # print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
87
+
88
+ global_rank = torch.distributed.get_rank()
89
+ global_world_size = torch.distributed.get_world_size()
90
+
91
+ input_ = input_.transpose(0, dim)
92
+
93
+ # pass from last rank
94
+ send_rank = global_rank + 1
95
+ recv_rank = global_rank - 1
96
+ if send_rank % cp_world_size == 0:
97
+ send_rank -= cp_world_size
98
+ if recv_rank % cp_world_size == cp_world_size - 1:
99
+ recv_rank += cp_world_size
100
+
101
+ recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
102
+ if cp_rank < cp_world_size - 1:
103
+ req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
104
+ if cp_rank > 0:
105
+ req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
106
+
107
+ if cp_rank == 0:
108
+ input_ = torch.cat([torch.zeros_like(input_[:1])] * (kernel_size - 1) + [input_], dim=0)
109
+ else:
110
+ req_recv.wait()
111
+ input_ = torch.cat([recv_buffer, input_], dim=0)
112
+
113
+ input_ = input_.transpose(0, dim).contiguous()
114
+ return input_
115
+
116
+
117
+ def _drop_from_previous_rank(input_, dim, kernel_size):
118
+ input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim)
119
+ return input_
120
+
121
+
122
+ class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
123
+ @staticmethod
124
+ def forward(ctx, input_, dim, kernel_size):
125
+ ctx.dim = dim
126
+ ctx.kernel_size = kernel_size
127
+ return _conv_split(input_, dim, kernel_size)
128
+
129
+ @staticmethod
130
+ def backward(ctx, grad_output):
131
+ return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None
132
+
133
+
134
+ class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
135
+ @staticmethod
136
+ def forward(ctx, input_, dim, kernel_size):
137
+ ctx.dim = dim
138
+ ctx.kernel_size = kernel_size
139
+ return _conv_gather(input_, dim, kernel_size)
140
+
141
+ @staticmethod
142
+ def backward(ctx, grad_output):
143
+ return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None
144
+
145
+
146
+ class _CPConvolutionPassFromPreviousRank(torch.autograd.Function):
147
+ @staticmethod
148
+ def forward(ctx, input_, dim, kernel_size):
149
+ ctx.dim = dim
150
+ ctx.kernel_size = kernel_size
151
+ return _cp_pass_from_previous_rank(input_, dim, kernel_size)
152
+
153
+ @staticmethod
154
+ def backward(ctx, grad_output):
155
+ return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None
156
+
157
+
158
+ def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
159
+ return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)
160
+
161
+
162
+ def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
163
+ return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)
164
+
165
+
166
+ def cp_pass_from_previous_rank(input_, dim, kernel_size):
167
+ return _CPConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)
168
+
169
+
170
+
171
+
172
+
video_vae/modeling_block.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple, Union
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+ from einops import rearrange
21
+
22
+ from diffusers.utils import logging
23
+ from diffusers.models.attention_processor import Attention
24
+ from .modeling_resnet import (
25
+ Downsample2D, ResnetBlock2D, CausalResnetBlock3D, Upsample2D,
26
+ TemporalDownsample2x, TemporalUpsample2x,
27
+ CausalDownsample2x, CausalTemporalDownsample2x,
28
+ CausalUpsample2x, CausalTemporalUpsample2x,
29
+ )
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ def get_input_layer(
35
+ in_channels: int,
36
+ out_channels: int,
37
+ norm_num_groups: int,
38
+ layer_type: str,
39
+ norm_type: str = 'group',
40
+ affine: bool = True,
41
+ ):
42
+ if layer_type == 'conv':
43
+ input_layer = nn.Conv3d(
44
+ in_channels,
45
+ out_channels,
46
+ kernel_size=3,
47
+ stride=1,
48
+ padding=1,
49
+ )
50
+
51
+ elif layer_type == 'pixel_shuffle':
52
+ input_layer = nn.Sequential(
53
+ nn.PixelUnshuffle(2),
54
+ nn.Conv2d(in_channels * 4, out_channels, kernel_size=1),
55
+ )
56
+ else:
57
+ raise NotImplementedError(f"Not support input layer {layer_type}")
58
+
59
+ return input_layer
60
+
61
+
62
+ def get_output_layer(
63
+ in_channels: int,
64
+ out_channels: int,
65
+ norm_num_groups: int,
66
+ layer_type: str,
67
+ norm_type: str = 'group',
68
+ affine: bool = True,
69
+ ):
70
+ if layer_type == 'norm_act_conv':
71
+ output_layer = nn.Sequential(
72
+ nn.GroupNorm(num_channels=in_channels, num_groups=norm_num_groups, eps=1e-6, affine=affine),
73
+ nn.SiLU(),
74
+ nn.Conv3d(in_channels, out_channels, 3, stride=1, padding=1),
75
+ )
76
+
77
+ elif layer_type == 'pixel_shuffle':
78
+ output_layer = nn.Sequential(
79
+ nn.Conv2d(in_channels, out_channels * 4, kernel_size=1),
80
+ nn.PixelShuffle(2),
81
+ )
82
+
83
+ else:
84
+ raise NotImplementedError(f"Not support output layer {layer_type}")
85
+
86
+ return output_layer
87
+
88
+
89
+ def get_down_block(
90
+ down_block_type: str,
91
+ num_layers: int,
92
+ in_channels: int,
93
+ out_channels: int = None,
94
+ temb_channels: int = None,
95
+ add_spatial_downsample: bool = None,
96
+ add_temporal_downsample: bool = None,
97
+ resnet_eps: float = 1e-6,
98
+ resnet_act_fn: str = 'silu',
99
+ resnet_groups: Optional[int] = None,
100
+ downsample_padding: Optional[int] = None,
101
+ resnet_time_scale_shift: str = "default",
102
+ attention_head_dim: Optional[int] = None,
103
+ dropout: float = 0.0,
104
+ norm_affline: bool = True,
105
+ norm_layer: str = 'layer',
106
+ ):
107
+
108
+ if down_block_type == "DownEncoderBlock2D":
109
+ return DownEncoderBlock2D(
110
+ num_layers=num_layers,
111
+ in_channels=in_channels,
112
+ out_channels=out_channels,
113
+ dropout=dropout,
114
+ add_spatial_downsample=add_spatial_downsample,
115
+ add_temporal_downsample=add_temporal_downsample,
116
+ resnet_eps=resnet_eps,
117
+ resnet_act_fn=resnet_act_fn,
118
+ resnet_groups=resnet_groups,
119
+ downsample_padding=downsample_padding,
120
+ resnet_time_scale_shift=resnet_time_scale_shift,
121
+ )
122
+
123
+ elif down_block_type == "DownEncoderBlockCausal3D":
124
+ return DownEncoderBlockCausal3D(
125
+ num_layers=num_layers,
126
+ in_channels=in_channels,
127
+ out_channels=out_channels,
128
+ dropout=dropout,
129
+ add_spatial_downsample=add_spatial_downsample,
130
+ add_temporal_downsample=add_temporal_downsample,
131
+ resnet_eps=resnet_eps,
132
+ resnet_act_fn=resnet_act_fn,
133
+ resnet_groups=resnet_groups,
134
+ downsample_padding=downsample_padding,
135
+ resnet_time_scale_shift=resnet_time_scale_shift,
136
+ )
137
+
138
+ raise ValueError(f"{down_block_type} does not exist.")
139
+
140
+
141
+ def get_up_block(
142
+ up_block_type: str,
143
+ num_layers: int,
144
+ in_channels: int,
145
+ out_channels: int,
146
+ prev_output_channel: int = None,
147
+ temb_channels: int = None,
148
+ add_spatial_upsample: bool = None,
149
+ add_temporal_upsample: bool = None,
150
+ resnet_eps: float = 1e-6,
151
+ resnet_act_fn: str = 'silu',
152
+ resolution_idx: Optional[int] = None,
153
+ resnet_groups: Optional[int] = None,
154
+ resnet_time_scale_shift: str = "default",
155
+ attention_head_dim: Optional[int] = None,
156
+ dropout: float = 0.0,
157
+ interpolate: bool = True,
158
+ norm_affline: bool = True,
159
+ norm_layer: str = 'layer',
160
+ ) -> nn.Module:
161
+
162
+ if up_block_type == "UpDecoderBlock2D":
163
+ return UpDecoderBlock2D(
164
+ num_layers=num_layers,
165
+ in_channels=in_channels,
166
+ out_channels=out_channels,
167
+ resolution_idx=resolution_idx,
168
+ dropout=dropout,
169
+ add_spatial_upsample=add_spatial_upsample,
170
+ add_temporal_upsample=add_temporal_upsample,
171
+ resnet_eps=resnet_eps,
172
+ resnet_act_fn=resnet_act_fn,
173
+ resnet_groups=resnet_groups,
174
+ resnet_time_scale_shift=resnet_time_scale_shift,
175
+ temb_channels=temb_channels,
176
+ interpolate=interpolate,
177
+ )
178
+
179
+ elif up_block_type == "UpDecoderBlockCausal3D":
180
+ return UpDecoderBlockCausal3D(
181
+ num_layers=num_layers,
182
+ in_channels=in_channels,
183
+ out_channels=out_channels,
184
+ resolution_idx=resolution_idx,
185
+ dropout=dropout,
186
+ add_spatial_upsample=add_spatial_upsample,
187
+ add_temporal_upsample=add_temporal_upsample,
188
+ resnet_eps=resnet_eps,
189
+ resnet_act_fn=resnet_act_fn,
190
+ resnet_groups=resnet_groups,
191
+ resnet_time_scale_shift=resnet_time_scale_shift,
192
+ temb_channels=temb_channels,
193
+ interpolate=interpolate,
194
+ )
195
+
196
+ raise ValueError(f"{up_block_type} does not exist.")
197
+
198
+
199
+
200
+ class UNetMidBlock2D(nn.Module):
201
+ """
202
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
203
+
204
+ Args:
205
+ in_channels (`int`): The number of input channels.
206
+ temb_channels (`int`): The number of temporal embedding channels.
207
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
208
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
209
+ resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
210
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
211
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
212
+ model on tasks with long-range temporal dependencies.
213
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
214
+ resnet_groups (`int`, *optional*, defaults to 32):
215
+ The number of groups to use in the group normalization layers of the resnet blocks.
216
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
217
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
218
+ Whether to use pre-normalization for the resnet blocks.
219
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
220
+ attention_head_dim (`int`, *optional*, defaults to 1):
221
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
222
+ the number of input channels.
223
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
224
+
225
+ Returns:
226
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
227
+ in_channels, height, width)`.
228
+
229
+ """
230
+
231
+ def __init__(
232
+ self,
233
+ in_channels: int,
234
+ temb_channels: int,
235
+ dropout: float = 0.0,
236
+ num_layers: int = 1,
237
+ resnet_eps: float = 1e-6,
238
+ resnet_time_scale_shift: str = "default", # default, spatial
239
+ resnet_act_fn: str = "swish",
240
+ resnet_groups: int = 32,
241
+ attn_groups: Optional[int] = None,
242
+ resnet_pre_norm: bool = True,
243
+ add_attention: bool = True,
244
+ attention_head_dim: int = 1,
245
+ output_scale_factor: float = 1.0,
246
+ ):
247
+ super().__init__()
248
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
249
+ self.add_attention = add_attention
250
+
251
+ if attn_groups is None:
252
+ attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
253
+
254
+ # there is always at least one resnet
255
+ resnets = [
256
+ ResnetBlock2D(
257
+ in_channels=in_channels,
258
+ out_channels=in_channels,
259
+ temb_channels=temb_channels,
260
+ eps=resnet_eps,
261
+ groups=resnet_groups,
262
+ dropout=dropout,
263
+ time_embedding_norm=resnet_time_scale_shift,
264
+ non_linearity=resnet_act_fn,
265
+ output_scale_factor=output_scale_factor,
266
+ pre_norm=resnet_pre_norm,
267
+ )
268
+ ]
269
+ attentions = []
270
+
271
+ if attention_head_dim is None:
272
+ logger.warn(
273
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
274
+ )
275
+ attention_head_dim = in_channels
276
+
277
+ for _ in range(num_layers):
278
+ if self.add_attention:
279
+ # Spatial attention
280
+ attentions.append(
281
+ Attention(
282
+ in_channels,
283
+ heads=in_channels // attention_head_dim,
284
+ dim_head=attention_head_dim,
285
+ rescale_output_factor=output_scale_factor,
286
+ eps=resnet_eps,
287
+ norm_num_groups=attn_groups,
288
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
289
+ residual_connection=True,
290
+ bias=True,
291
+ upcast_softmax=True,
292
+ _from_deprecated_attn_block=True,
293
+ )
294
+ )
295
+ else:
296
+ attentions.append(None)
297
+
298
+ resnets.append(
299
+ ResnetBlock2D(
300
+ in_channels=in_channels,
301
+ out_channels=in_channels,
302
+ temb_channels=temb_channels,
303
+ eps=resnet_eps,
304
+ groups=resnet_groups,
305
+ dropout=dropout,
306
+ time_embedding_norm=resnet_time_scale_shift,
307
+ non_linearity=resnet_act_fn,
308
+ output_scale_factor=output_scale_factor,
309
+ pre_norm=resnet_pre_norm,
310
+ )
311
+ )
312
+
313
+ self.attentions = nn.ModuleList(attentions)
314
+ self.resnets = nn.ModuleList(resnets)
315
+
316
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
317
+ hidden_states = self.resnets[0](hidden_states, temb)
318
+ t = hidden_states.shape[2]
319
+
320
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
321
+ if attn is not None:
322
+ hidden_states = rearrange(hidden_states, 'b c t h w -> b t c h w')
323
+ hidden_states = rearrange(hidden_states, 'b t c h w -> (b t) c h w')
324
+ hidden_states = attn(hidden_states, temb=temb)
325
+ hidden_states = rearrange(hidden_states, '(b t) c h w -> b t c h w', t=t)
326
+ hidden_states = rearrange(hidden_states, 'b t c h w -> b c t h w')
327
+
328
+ hidden_states = resnet(hidden_states, temb)
329
+
330
+ return hidden_states
331
+
332
+
333
+ class CausalUNetMidBlock2D(nn.Module):
334
+ """
335
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
336
+
337
+ Args:
338
+ in_channels (`int`): The number of input channels.
339
+ temb_channels (`int`): The number of temporal embedding channels.
340
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
341
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
342
+ resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
343
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
344
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
345
+ model on tasks with long-range temporal dependencies.
346
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
347
+ resnet_groups (`int`, *optional*, defaults to 32):
348
+ The number of groups to use in the group normalization layers of the resnet blocks.
349
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
350
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
351
+ Whether to use pre-normalization for the resnet blocks.
352
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
353
+ attention_head_dim (`int`, *optional*, defaults to 1):
354
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
355
+ the number of input channels.
356
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
357
+
358
+ Returns:
359
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
360
+ in_channels, height, width)`.
361
+
362
+ """
363
+
364
+ def __init__(
365
+ self,
366
+ in_channels: int,
367
+ temb_channels: int,
368
+ dropout: float = 0.0,
369
+ num_layers: int = 1,
370
+ resnet_eps: float = 1e-6,
371
+ resnet_time_scale_shift: str = "default", # default, spatial
372
+ resnet_act_fn: str = "swish",
373
+ resnet_groups: int = 32,
374
+ attn_groups: Optional[int] = None,
375
+ resnet_pre_norm: bool = True,
376
+ add_attention: bool = True,
377
+ attention_head_dim: int = 1,
378
+ output_scale_factor: float = 1.0,
379
+ ):
380
+ super().__init__()
381
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
382
+ self.add_attention = add_attention
383
+
384
+ if attn_groups is None:
385
+ attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
386
+
387
+ # there is always at least one resnet
388
+ resnets = [
389
+ CausalResnetBlock3D(
390
+ in_channels=in_channels,
391
+ out_channels=in_channels,
392
+ temb_channels=temb_channels,
393
+ eps=resnet_eps,
394
+ groups=resnet_groups,
395
+ dropout=dropout,
396
+ time_embedding_norm=resnet_time_scale_shift,
397
+ non_linearity=resnet_act_fn,
398
+ output_scale_factor=output_scale_factor,
399
+ pre_norm=resnet_pre_norm,
400
+ )
401
+ ]
402
+ attentions = []
403
+
404
+ if attention_head_dim is None:
405
+ logger.warn(
406
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
407
+ )
408
+ attention_head_dim = in_channels
409
+
410
+ for _ in range(num_layers):
411
+ if self.add_attention:
412
+ # Spatial attention
413
+ attentions.append(
414
+ Attention(
415
+ in_channels,
416
+ heads=in_channels // attention_head_dim,
417
+ dim_head=attention_head_dim,
418
+ rescale_output_factor=output_scale_factor,
419
+ eps=resnet_eps,
420
+ norm_num_groups=attn_groups,
421
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
422
+ residual_connection=True,
423
+ bias=True,
424
+ upcast_softmax=True,
425
+ _from_deprecated_attn_block=True,
426
+ )
427
+ )
428
+ else:
429
+ attentions.append(None)
430
+
431
+ resnets.append(
432
+ CausalResnetBlock3D(
433
+ in_channels=in_channels,
434
+ out_channels=in_channels,
435
+ temb_channels=temb_channels,
436
+ eps=resnet_eps,
437
+ groups=resnet_groups,
438
+ dropout=dropout,
439
+ time_embedding_norm=resnet_time_scale_shift,
440
+ non_linearity=resnet_act_fn,
441
+ output_scale_factor=output_scale_factor,
442
+ pre_norm=resnet_pre_norm,
443
+ )
444
+ )
445
+
446
+ self.attentions = nn.ModuleList(attentions)
447
+ self.resnets = nn.ModuleList(resnets)
448
+
449
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None,
450
+ is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
451
+ hidden_states = self.resnets[0](hidden_states, temb, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
452
+ t = hidden_states.shape[2]
453
+
454
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
455
+ if attn is not None:
456
+ hidden_states = rearrange(hidden_states, 'b c t h w -> b t c h w')
457
+ hidden_states = rearrange(hidden_states, 'b t c h w -> (b t) c h w')
458
+ hidden_states = attn(hidden_states, temb=temb)
459
+ hidden_states = rearrange(hidden_states, '(b t) c h w -> b t c h w', t=t)
460
+ hidden_states = rearrange(hidden_states, 'b t c h w -> b c t h w')
461
+
462
+ hidden_states = resnet(hidden_states, temb, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
463
+
464
+ return hidden_states
465
+
466
+
467
+ class DownEncoderBlockCausal3D(nn.Module):
468
+ def __init__(
469
+ self,
470
+ in_channels: int,
471
+ out_channels: int,
472
+ dropout: float = 0.0,
473
+ num_layers: int = 1,
474
+ resnet_eps: float = 1e-6,
475
+ resnet_time_scale_shift: str = "default",
476
+ resnet_act_fn: str = "swish",
477
+ resnet_groups: int = 32,
478
+ resnet_pre_norm: bool = True,
479
+ output_scale_factor: float = 1.0,
480
+ add_spatial_downsample: bool = True,
481
+ add_temporal_downsample: bool = False,
482
+ downsample_padding: int = 1,
483
+ ):
484
+ super().__init__()
485
+ resnets = []
486
+
487
+ for i in range(num_layers):
488
+ in_channels = in_channels if i == 0 else out_channels
489
+ resnets.append(
490
+ CausalResnetBlock3D(
491
+ in_channels=in_channels,
492
+ out_channels=out_channels,
493
+ temb_channels=None,
494
+ eps=resnet_eps,
495
+ groups=resnet_groups,
496
+ dropout=dropout,
497
+ time_embedding_norm=resnet_time_scale_shift,
498
+ non_linearity=resnet_act_fn,
499
+ output_scale_factor=output_scale_factor,
500
+ pre_norm=resnet_pre_norm,
501
+ )
502
+ )
503
+
504
+ self.resnets = nn.ModuleList(resnets)
505
+
506
+ if add_spatial_downsample:
507
+ self.downsamplers = nn.ModuleList(
508
+ [
509
+ CausalDownsample2x(
510
+ out_channels, use_conv=True, out_channels=out_channels,
511
+ )
512
+ ]
513
+ )
514
+ else:
515
+ self.downsamplers = None
516
+
517
+ if add_temporal_downsample:
518
+ self.temporal_downsamplers = nn.ModuleList(
519
+ [
520
+ CausalTemporalDownsample2x(
521
+ out_channels, use_conv=True, out_channels=out_channels,
522
+ )
523
+ ]
524
+ )
525
+ else:
526
+ self.temporal_downsamplers = None
527
+
528
+ def forward(self, hidden_states: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
529
+ for resnet in self.resnets:
530
+ hidden_states = resnet(hidden_states, temb=None, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
531
+
532
+ if self.downsamplers is not None:
533
+ for downsampler in self.downsamplers:
534
+ hidden_states = downsampler(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
535
+
536
+ if self.temporal_downsamplers is not None:
537
+ for temporal_downsampler in self.temporal_downsamplers:
538
+ hidden_states = temporal_downsampler(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
539
+
540
+ return hidden_states
541
+
542
+
543
+ class DownEncoderBlock2D(nn.Module):
544
+ def __init__(
545
+ self,
546
+ in_channels: int,
547
+ out_channels: int,
548
+ dropout: float = 0.0,
549
+ num_layers: int = 1,
550
+ resnet_eps: float = 1e-6,
551
+ resnet_time_scale_shift: str = "default",
552
+ resnet_act_fn: str = "swish",
553
+ resnet_groups: int = 32,
554
+ resnet_pre_norm: bool = True,
555
+ output_scale_factor: float = 1.0,
556
+ add_spatial_downsample: bool = True,
557
+ add_temporal_downsample: bool = False,
558
+ downsample_padding: int = 1,
559
+ ):
560
+ super().__init__()
561
+ resnets = []
562
+
563
+ for i in range(num_layers):
564
+ in_channels = in_channels if i == 0 else out_channels
565
+ resnets.append(
566
+ ResnetBlock2D(
567
+ in_channels=in_channels,
568
+ out_channels=out_channels,
569
+ temb_channels=None,
570
+ eps=resnet_eps,
571
+ groups=resnet_groups,
572
+ dropout=dropout,
573
+ time_embedding_norm=resnet_time_scale_shift,
574
+ non_linearity=resnet_act_fn,
575
+ output_scale_factor=output_scale_factor,
576
+ pre_norm=resnet_pre_norm,
577
+ )
578
+ )
579
+
580
+ self.resnets = nn.ModuleList(resnets)
581
+
582
+ if add_spatial_downsample:
583
+ self.downsamplers = nn.ModuleList(
584
+ [
585
+ Downsample2D(
586
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
587
+ )
588
+ ]
589
+ )
590
+ else:
591
+ self.downsamplers = None
592
+
593
+ if add_temporal_downsample:
594
+ self.temporal_downsamplers = nn.ModuleList(
595
+ [
596
+ TemporalDownsample2x(
597
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding,
598
+ )
599
+ ]
600
+ )
601
+ else:
602
+ self.temporal_downsamplers = None
603
+
604
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
605
+ for resnet in self.resnets:
606
+ hidden_states = resnet(hidden_states, temb=None)
607
+
608
+ if self.downsamplers is not None:
609
+ for downsampler in self.downsamplers:
610
+ hidden_states = downsampler(hidden_states)
611
+
612
+ if self.temporal_downsamplers is not None:
613
+ for temporal_downsampler in self.temporal_downsamplers:
614
+ hidden_states = temporal_downsampler(hidden_states)
615
+
616
+ return hidden_states
617
+
618
+
619
+ class UpDecoderBlock2D(nn.Module):
620
+ def __init__(
621
+ self,
622
+ in_channels: int,
623
+ out_channels: int,
624
+ resolution_idx: Optional[int] = None,
625
+ dropout: float = 0.0,
626
+ num_layers: int = 1,
627
+ resnet_eps: float = 1e-6,
628
+ resnet_time_scale_shift: str = "default", # default, spatial
629
+ resnet_act_fn: str = "swish",
630
+ resnet_groups: int = 32,
631
+ resnet_pre_norm: bool = True,
632
+ output_scale_factor: float = 1.0,
633
+ add_spatial_upsample: bool = True,
634
+ add_temporal_upsample: bool = False,
635
+ temb_channels: Optional[int] = None,
636
+ interpolate: bool = True,
637
+ ):
638
+ super().__init__()
639
+ resnets = []
640
+
641
+ for i in range(num_layers):
642
+ input_channels = in_channels if i == 0 else out_channels
643
+
644
+ resnets.append(
645
+ ResnetBlock2D(
646
+ in_channels=input_channels,
647
+ out_channels=out_channels,
648
+ temb_channels=temb_channels,
649
+ eps=resnet_eps,
650
+ groups=resnet_groups,
651
+ dropout=dropout,
652
+ time_embedding_norm=resnet_time_scale_shift,
653
+ non_linearity=resnet_act_fn,
654
+ output_scale_factor=output_scale_factor,
655
+ pre_norm=resnet_pre_norm,
656
+ )
657
+ )
658
+
659
+ self.resnets = nn.ModuleList(resnets)
660
+
661
+ if add_spatial_upsample:
662
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels, interpolate=interpolate)])
663
+ else:
664
+ self.upsamplers = None
665
+
666
+ if add_temporal_upsample:
667
+ self.temporal_upsamplers = nn.ModuleList([TemporalUpsample2x(out_channels, use_conv=True, out_channels=out_channels, interpolate=interpolate)])
668
+ else:
669
+ self.temporal_upsamplers = None
670
+
671
+ self.resolution_idx = resolution_idx
672
+
673
+ def forward(
674
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, is_image: bool = False,
675
+ ) -> torch.FloatTensor:
676
+ for resnet in self.resnets:
677
+ hidden_states = resnet(hidden_states, temb=temb, scale=scale)
678
+
679
+ if self.upsamplers is not None:
680
+ for upsampler in self.upsamplers:
681
+ hidden_states = upsampler(hidden_states)
682
+
683
+ if self.temporal_upsamplers is not None:
684
+ for temporal_upsampler in self.temporal_upsamplers:
685
+ hidden_states = temporal_upsampler(hidden_states, is_image=is_image)
686
+
687
+ return hidden_states
688
+
689
+
690
+ class UpDecoderBlockCausal3D(nn.Module):
691
+ def __init__(
692
+ self,
693
+ in_channels: int,
694
+ out_channels: int,
695
+ resolution_idx: Optional[int] = None,
696
+ dropout: float = 0.0,
697
+ num_layers: int = 1,
698
+ resnet_eps: float = 1e-6,
699
+ resnet_time_scale_shift: str = "default", # default, spatial
700
+ resnet_act_fn: str = "swish",
701
+ resnet_groups: int = 32,
702
+ resnet_pre_norm: bool = True,
703
+ output_scale_factor: float = 1.0,
704
+ add_spatial_upsample: bool = True,
705
+ add_temporal_upsample: bool = False,
706
+ temb_channels: Optional[int] = None,
707
+ interpolate: bool = True,
708
+ ):
709
+ super().__init__()
710
+ resnets = []
711
+
712
+ for i in range(num_layers):
713
+ input_channels = in_channels if i == 0 else out_channels
714
+
715
+ resnets.append(
716
+ CausalResnetBlock3D(
717
+ in_channels=input_channels,
718
+ out_channels=out_channels,
719
+ temb_channels=temb_channels,
720
+ eps=resnet_eps,
721
+ groups=resnet_groups,
722
+ dropout=dropout,
723
+ time_embedding_norm=resnet_time_scale_shift,
724
+ non_linearity=resnet_act_fn,
725
+ output_scale_factor=output_scale_factor,
726
+ pre_norm=resnet_pre_norm,
727
+ )
728
+ )
729
+
730
+ self.resnets = nn.ModuleList(resnets)
731
+
732
+ if add_spatial_upsample:
733
+ self.upsamplers = nn.ModuleList([CausalUpsample2x(out_channels, use_conv=True, out_channels=out_channels, interpolate=interpolate)])
734
+ else:
735
+ self.upsamplers = None
736
+
737
+ if add_temporal_upsample:
738
+ self.temporal_upsamplers = nn.ModuleList([CausalTemporalUpsample2x(out_channels, use_conv=True, out_channels=out_channels, interpolate=interpolate)])
739
+ else:
740
+ self.temporal_upsamplers = None
741
+
742
+ self.resolution_idx = resolution_idx
743
+
744
+ def forward(
745
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None,
746
+ is_init_image=True, temporal_chunk=False,
747
+ ) -> torch.FloatTensor:
748
+ for resnet in self.resnets:
749
+ hidden_states = resnet(hidden_states, temb=temb, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
750
+
751
+ if self.upsamplers is not None:
752
+ for upsampler in self.upsamplers:
753
+ hidden_states = upsampler(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
754
+
755
+ if self.temporal_upsamplers is not None:
756
+ for temporal_upsampler in self.temporal_upsamplers:
757
+ hidden_states = temporal_upsampler(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
758
+
759
+ return hidden_states
760
+
video_vae/modeling_causal_conv.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Union
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.utils.checkpoint import checkpoint
5
+ import torch.nn.functional as F
6
+ from collections import deque
7
+ from einops import rearrange
8
+ from timm.models.layers import trunc_normal_
9
+ from IPython import embed
10
+ from torch import Tensor
11
+
12
+ from utils import (
13
+ is_context_parallel_initialized,
14
+ get_context_parallel_group,
15
+ get_context_parallel_world_size,
16
+ get_context_parallel_rank,
17
+ get_context_parallel_group_rank,
18
+ )
19
+
20
+ from .context_parallel_ops import (
21
+ conv_scatter_to_context_parallel_region,
22
+ conv_gather_from_context_parallel_region,
23
+ cp_pass_from_previous_rank,
24
+ )
25
+
26
+
27
+ def divisible_by(num, den):
28
+ return (num % den) == 0
29
+
30
+ def cast_tuple(t, length = 1):
31
+ return t if isinstance(t, tuple) else ((t,) * length)
32
+
33
+ def is_odd(n):
34
+ return not divisible_by(n, 2)
35
+
36
+
37
+ class CausalGroupNorm(nn.GroupNorm):
38
+
39
+ def forward(self, x: Tensor) -> Tensor:
40
+ t = x.shape[2]
41
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
42
+ x = super().forward(x)
43
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
44
+ return x
45
+
46
+
47
+ class CausalConv3d(nn.Module):
48
+
49
+ def __init__(
50
+ self,
51
+ in_channels,
52
+ out_channels,
53
+ kernel_size: Union[int, Tuple[int, int, int]],
54
+ stride: Union[int, Tuple[int, int, int]] = 1,
55
+ pad_mode: str ='constant',
56
+ **kwargs
57
+ ):
58
+ super().__init__()
59
+ if isinstance(kernel_size, int):
60
+ kernel_size = cast_tuple(kernel_size, 3)
61
+
62
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
63
+ self.time_kernel_size = time_kernel_size
64
+ assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
65
+ dilation = kwargs.pop('dilation', 1)
66
+ self.pad_mode = pad_mode
67
+
68
+ if isinstance(stride, int):
69
+ stride = (stride, 1, 1)
70
+
71
+ time_pad = dilation * (time_kernel_size - 1)
72
+ height_pad = height_kernel_size // 2
73
+ width_pad = width_kernel_size // 2
74
+
75
+ self.temporal_stride = stride[0]
76
+ self.time_pad = time_pad
77
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
78
+ self.time_uncausal_padding = (width_pad, width_pad, height_pad, height_pad, 0, 0)
79
+
80
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=0, dilation=dilation, **kwargs)
81
+ self.cache_front_feat = deque()
82
+
83
+ def _clear_context_parallel_cache(self):
84
+ del self.cache_front_feat
85
+ self.cache_front_feat = deque()
86
+
87
+ def _init_weights(self, m):
88
+ if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
89
+ trunc_normal_(m.weight, std=.02)
90
+ if m.bias is not None:
91
+ nn.init.constant_(m.bias, 0)
92
+ elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
93
+ nn.init.constant_(m.bias, 0)
94
+ nn.init.constant_(m.weight, 1.0)
95
+
96
+ def context_parallel_forward(self, x):
97
+ x = cp_pass_from_previous_rank(x, dim=2, kernel_size=self.time_kernel_size)
98
+
99
+ x = F.pad(x, self.time_uncausal_padding, mode='constant')
100
+
101
+ cp_rank = get_context_parallel_rank()
102
+ if cp_rank != 0:
103
+ if self.temporal_stride == 2 and self.time_kernel_size == 3:
104
+ x = x[:,:,1:]
105
+
106
+ x = self.conv(x)
107
+ return x
108
+
109
+ def forward(self, x, is_init_image=True, temporal_chunk=False):
110
+ # temporal_chunk: whether to use the temporal chunk
111
+
112
+ if is_context_parallel_initialized():
113
+ return self.context_parallel_forward(x)
114
+
115
+ pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'
116
+
117
+ if not temporal_chunk:
118
+ x = F.pad(x, self.time_causal_padding, mode=pad_mode)
119
+ else:
120
+ assert not self.training, "The feature cache should not be used in training"
121
+ if is_init_image:
122
+ # Encode the first chunk
123
+ x = F.pad(x, self.time_causal_padding, mode=pad_mode)
124
+ self._clear_context_parallel_cache()
125
+ self.cache_front_feat.append(x[:, :, -2:].clone().detach())
126
+ else:
127
+ x = F.pad(x, self.time_uncausal_padding, mode=pad_mode)
128
+ video_front_context = self.cache_front_feat.pop()
129
+ self._clear_context_parallel_cache()
130
+
131
+ if self.temporal_stride == 1 and self.time_kernel_size == 3:
132
+ x = torch.cat([video_front_context, x], dim=2)
133
+ elif self.temporal_stride == 2 and self.time_kernel_size == 3:
134
+ x = torch.cat([video_front_context[:,:,-1:], x], dim=2)
135
+
136
+ self.cache_front_feat.append(x[:, :, -2:].clone().detach())
137
+
138
+ x = self.conv(x)
139
+ return x
video_vae/modeling_causal_vae.py ADDED
@@ -0,0 +1,625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Tuple, Union
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers.models.attention_processor import (
7
+ ADDED_KV_ATTENTION_PROCESSORS,
8
+ CROSS_ATTENTION_PROCESSORS,
9
+ Attention,
10
+ AttentionProcessor,
11
+ AttnAddedKVProcessor,
12
+ AttnProcessor,
13
+ )
14
+
15
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+
18
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
19
+ from .modeling_enc_dec import (
20
+ DecoderOutput, DiagonalGaussianDistribution,
21
+ CausalVaeDecoder, CausalVaeEncoder,
22
+ )
23
+ from .modeling_causal_conv import CausalConv3d
24
+ from IPython import embed
25
+
26
+ from utils import (
27
+ is_context_parallel_initialized,
28
+ get_context_parallel_group,
29
+ get_context_parallel_world_size,
30
+ get_context_parallel_rank,
31
+ get_context_parallel_group_rank,
32
+ )
33
+
34
+ from .context_parallel_ops import (
35
+ conv_scatter_to_context_parallel_region,
36
+ conv_gather_from_context_parallel_region,
37
+ )
38
+
39
+
40
+ class CausalVideoVAE(ModelMixin, ConfigMixin):
41
+ r"""
42
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
43
+
44
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
45
+ for all models (such as downloading or saving).
46
+
47
+ Parameters:
48
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
49
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
50
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
51
+ Tuple of downsample block types.
52
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
53
+ Tuple of upsample block types.
54
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
55
+ Tuple of block output channels.
56
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
57
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
58
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
59
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
60
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
61
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
62
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
63
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
64
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
65
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
66
+ force_upcast (`bool`, *optional*, default to `True`):
67
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
68
+ can be fine-tuned / trained to a lower range without loosing too much precision in which case
69
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
70
+ """
71
+
72
+ _supports_gradient_checkpointing = True
73
+
74
+ @register_to_config
75
+ def __init__(
76
+ self,
77
+ # encoder related parameters
78
+ encoder_in_channels: int = 3,
79
+ encoder_out_channels: int = 4,
80
+ encoder_layers_per_block: Tuple[int, ...] = (2, 2, 2, 2),
81
+ encoder_down_block_types: Tuple[str, ...] = (
82
+ "DownEncoderBlockCausal3D",
83
+ "DownEncoderBlockCausal3D",
84
+ "DownEncoderBlockCausal3D",
85
+ "DownEncoderBlockCausal3D",
86
+ ),
87
+ encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
88
+ encoder_spatial_down_sample: Tuple[bool, ...] = (True, True, True, False),
89
+ encoder_temporal_down_sample: Tuple[bool, ...] = (True, True, True, False),
90
+ encoder_block_dropout: Tuple[int, ...] = (0.0, 0.0, 0.0, 0.0),
91
+ encoder_act_fn: str = "silu",
92
+ encoder_norm_num_groups: int = 32,
93
+ encoder_double_z: bool = True,
94
+ encoder_type: str = 'causal_vae_conv',
95
+ # decoder related
96
+ decoder_in_channels: int = 4,
97
+ decoder_out_channels: int = 3,
98
+ decoder_layers_per_block: Tuple[int, ...] = (3, 3, 3, 3),
99
+ decoder_up_block_types: Tuple[str, ...] = (
100
+ "UpDecoderBlockCausal3D",
101
+ "UpDecoderBlockCausal3D",
102
+ "UpDecoderBlockCausal3D",
103
+ "UpDecoderBlockCausal3D",
104
+ ),
105
+ decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
106
+ decoder_spatial_up_sample: Tuple[bool, ...] = (True, True, True, False),
107
+ decoder_temporal_up_sample: Tuple[bool, ...] = (True, True, True, False),
108
+ decoder_block_dropout: Tuple[int, ...] = (0.0, 0.0, 0.0, 0.0),
109
+ decoder_act_fn: str = "silu",
110
+ decoder_norm_num_groups: int = 32,
111
+ decoder_type: str = 'causal_vae_conv',
112
+ sample_size: int = 256,
113
+ scaling_factor: float = 0.18215,
114
+ add_post_quant_conv: bool = True,
115
+ interpolate: bool = False,
116
+ downsample_scale: int = 8,
117
+ ):
118
+ super().__init__()
119
+
120
+ print(f"The latent dimmension channes is {encoder_out_channels}")
121
+ # pass init params to Encoder
122
+
123
+ self.encoder = CausalVaeEncoder(
124
+ in_channels=encoder_in_channels,
125
+ out_channels=encoder_out_channels,
126
+ down_block_types=encoder_down_block_types,
127
+ spatial_down_sample=encoder_spatial_down_sample,
128
+ temporal_down_sample=encoder_temporal_down_sample,
129
+ block_out_channels=encoder_block_out_channels,
130
+ layers_per_block=encoder_layers_per_block,
131
+ act_fn=encoder_act_fn,
132
+ norm_num_groups=encoder_norm_num_groups,
133
+ double_z=True,
134
+ block_dropout=encoder_block_dropout,
135
+ )
136
+
137
+ # pass init params to Decoder
138
+ self.decoder = CausalVaeDecoder(
139
+ in_channels=decoder_in_channels,
140
+ out_channels=decoder_out_channels,
141
+ up_block_types=decoder_up_block_types,
142
+ spatial_up_sample=decoder_spatial_up_sample,
143
+ temporal_up_sample=decoder_temporal_up_sample,
144
+ block_out_channels=decoder_block_out_channels,
145
+ layers_per_block=decoder_layers_per_block,
146
+ norm_num_groups=decoder_norm_num_groups,
147
+ act_fn=decoder_act_fn,
148
+ interpolate=interpolate,
149
+ block_dropout=decoder_block_dropout,
150
+ )
151
+
152
+ self.quant_conv = CausalConv3d(2 * encoder_out_channels, 2 * encoder_out_channels, kernel_size=1, stride=1)
153
+ self.post_quant_conv = CausalConv3d(encoder_out_channels, encoder_out_channels, kernel_size=1, stride=1)
154
+ self.use_tiling = False
155
+
156
+ # only relevant if vae tiling is enabled
157
+ self.tile_sample_min_size = self.config.sample_size
158
+
159
+ sample_size = (
160
+ self.config.sample_size[0]
161
+ if isinstance(self.config.sample_size, (list, tuple))
162
+ else self.config.sample_size
163
+ )
164
+ self.tile_latent_min_size = int(sample_size / downsample_scale)
165
+ self.encode_tile_overlap_factor = 1 / 8
166
+ self.decode_tile_overlap_factor = 1 / 8
167
+ self.downsample_scale = downsample_scale
168
+
169
+ self.apply(self._init_weights)
170
+
171
+ def _init_weights(self, m):
172
+ if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
173
+ trunc_normal_(m.weight, std=.02)
174
+ if m.bias is not None:
175
+ nn.init.constant_(m.bias, 0)
176
+ elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
177
+ nn.init.constant_(m.bias, 0)
178
+ nn.init.constant_(m.weight, 1.0)
179
+
180
+ def _set_gradient_checkpointing(self, module, value=False):
181
+ if isinstance(module, (Encoder, Decoder)):
182
+ module.gradient_checkpointing = value
183
+
184
+ def enable_tiling(self, use_tiling: bool = True):
185
+ r"""
186
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
187
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
188
+ processing larger images.
189
+ """
190
+ self.use_tiling = use_tiling
191
+
192
+ def disable_tiling(self):
193
+ r"""
194
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
195
+ decoding in one step.
196
+ """
197
+ self.enable_tiling(False)
198
+
199
+ @property
200
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
201
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
202
+ r"""
203
+ Returns:
204
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
205
+ indexed by its weight name.
206
+ """
207
+ # set recursively
208
+ processors = {}
209
+
210
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
211
+ if hasattr(module, "get_processor"):
212
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
213
+
214
+ for sub_name, child in module.named_children():
215
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
216
+
217
+ return processors
218
+
219
+ for name, module in self.named_children():
220
+ fn_recursive_add_processors(name, module, processors)
221
+
222
+ return processors
223
+
224
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
225
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
226
+ r"""
227
+ Sets the attention processor to use to compute attention.
228
+
229
+ Parameters:
230
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
231
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
232
+ for **all** `Attention` layers.
233
+
234
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
235
+ processor. This is strongly recommended when setting trainable attention processors.
236
+
237
+ """
238
+ count = len(self.attn_processors.keys())
239
+
240
+ if isinstance(processor, dict) and len(processor) != count:
241
+ raise ValueError(
242
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
243
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
244
+ )
245
+
246
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
247
+ if hasattr(module, "set_processor"):
248
+ if not isinstance(processor, dict):
249
+ module.set_processor(processor)
250
+ else:
251
+ module.set_processor(processor.pop(f"{name}.processor"))
252
+
253
+ for sub_name, child in module.named_children():
254
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
255
+
256
+ for name, module in self.named_children():
257
+ fn_recursive_attn_processor(name, module, processor)
258
+
259
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
260
+ def set_default_attn_processor(self):
261
+ """
262
+ Disables custom attention processors and sets the default attention implementation.
263
+ """
264
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
265
+ processor = AttnAddedKVProcessor()
266
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
267
+ processor = AttnProcessor()
268
+ else:
269
+ raise ValueError(
270
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
271
+ )
272
+
273
+ self.set_attn_processor(processor)
274
+
275
+ def encode(
276
+ self, x: torch.FloatTensor, return_dict: bool = True,
277
+ is_init_image=True, temporal_chunk=False, window_size=16, tile_sample_min_size=256,
278
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
279
+ """
280
+ Encode a batch of images into latents.
281
+
282
+ Args:
283
+ x (`torch.FloatTensor`): Input batch of images.
284
+ return_dict (`bool`, *optional*, defaults to `True`):
285
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
286
+
287
+ Returns:
288
+ The latent representations of the encoded images. If `return_dict` is True, a
289
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
290
+ """
291
+ self.tile_sample_min_size = tile_sample_min_size
292
+ self.tile_latent_min_size = int(tile_sample_min_size / self.downsample_scale)
293
+
294
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
295
+ return self.tiled_encode(x, return_dict=return_dict, is_init_image=is_init_image,
296
+ temporal_chunk=temporal_chunk, window_size=window_size)
297
+
298
+ if temporal_chunk:
299
+ moments = self.chunk_encode(x, window_size=window_size)
300
+ else:
301
+ h = self.encoder(x, is_init_image=is_init_image, temporal_chunk=False)
302
+ moments = self.quant_conv(h, is_init_image=is_init_image, temporal_chunk=False)
303
+
304
+ posterior = DiagonalGaussianDistribution(moments)
305
+
306
+ if not return_dict:
307
+ return (posterior,)
308
+
309
+ return AutoencoderKLOutput(latent_dist=posterior)
310
+
311
+ @torch.no_grad()
312
+ def chunk_encode(self, x: torch.FloatTensor, window_size=16):
313
+ # Only used during inference
314
+ # Encode a long video clips through sliding window
315
+ num_frames = x.shape[2]
316
+ assert (num_frames - 1) % self.downsample_scale == 0
317
+ init_window_size = window_size + 1
318
+ frame_list = [x[:,:,:init_window_size]]
319
+
320
+ # To chunk the long video
321
+ full_chunk_size = (num_frames - init_window_size) // window_size
322
+ fid = init_window_size
323
+ for idx in range(full_chunk_size):
324
+ frame_list.append(x[:, :, fid:fid+window_size])
325
+ fid += window_size
326
+
327
+ if fid < num_frames:
328
+ frame_list.append(x[:, :, fid:])
329
+
330
+ latent_list = []
331
+ for idx, frames in enumerate(frame_list):
332
+ if idx == 0:
333
+ h = self.encoder(frames, is_init_image=True, temporal_chunk=True)
334
+ moments = self.quant_conv(h, is_init_image=True, temporal_chunk=True)
335
+ else:
336
+ h = self.encoder(frames, is_init_image=False, temporal_chunk=True)
337
+ moments = self.quant_conv(h, is_init_image=False, temporal_chunk=True)
338
+
339
+ latent_list.append(moments)
340
+
341
+ latent = torch.cat(latent_list, dim=2)
342
+ return latent
343
+
344
+ def get_last_layer(self):
345
+ return self.decoder.conv_out.conv.weight
346
+
347
+ @torch.no_grad()
348
+ def chunk_decode(self, z: torch.FloatTensor, window_size=2):
349
+ num_frames = z.shape[2]
350
+ init_window_size = window_size + 1
351
+ frame_list = [z[:,:,:init_window_size]]
352
+
353
+ # To chunk the long video
354
+ full_chunk_size = (num_frames - init_window_size) // window_size
355
+ fid = init_window_size
356
+ for idx in range(full_chunk_size):
357
+ frame_list.append(z[:, :, fid:fid+window_size])
358
+ fid += window_size
359
+
360
+ if fid < num_frames:
361
+ frame_list.append(z[:, :, fid:])
362
+
363
+ dec_list = []
364
+ for idx, frames in enumerate(frame_list):
365
+ if idx == 0:
366
+ z_h = self.post_quant_conv(frames, is_init_image=True, temporal_chunk=True)
367
+ dec = self.decoder(z_h, is_init_image=True, temporal_chunk=True)
368
+ else:
369
+ z_h = self.post_quant_conv(frames, is_init_image=False, temporal_chunk=True)
370
+ dec = self.decoder(z_h, is_init_image=False, temporal_chunk=True)
371
+
372
+ dec_list.append(dec)
373
+
374
+ dec = torch.cat(dec_list, dim=2)
375
+ return dec
376
+
377
+ def decode(self, z: torch.FloatTensor, is_init_image=True, temporal_chunk=False,
378
+ return_dict: bool = True, window_size: int = 2, tile_sample_min_size: int = 256,) -> Union[DecoderOutput, torch.FloatTensor]:
379
+
380
+ self.tile_sample_min_size = tile_sample_min_size
381
+ self.tile_latent_min_size = int(tile_sample_min_size / self.downsample_scale)
382
+
383
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
384
+ return self.tiled_decode(z, is_init_image=is_init_image,
385
+ temporal_chunk=temporal_chunk, window_size=window_size, return_dict=return_dict)
386
+
387
+ if temporal_chunk:
388
+ dec = self.chunk_decode(z, window_size=window_size)
389
+ else:
390
+ z = self.post_quant_conv(z, is_init_image=is_init_image, temporal_chunk=False)
391
+ dec = self.decoder(z, is_init_image=is_init_image, temporal_chunk=False)
392
+
393
+ if not return_dict:
394
+ return (dec,)
395
+
396
+ return DecoderOutput(sample=dec)
397
+
398
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
399
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
400
+ for y in range(blend_extent):
401
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
402
+ return b
403
+
404
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
405
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
406
+ for x in range(blend_extent):
407
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
408
+ return b
409
+
410
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True,
411
+ is_init_image=True, temporal_chunk=False, window_size=16,) -> AutoencoderKLOutput:
412
+ r"""Encode a batch of images using a tiled encoder.
413
+
414
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
415
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
416
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
417
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
418
+ output, but they should be much less noticeable.
419
+
420
+ Args:
421
+ x (`torch.FloatTensor`): Input batch of images.
422
+ return_dict (`bool`, *optional*, defaults to `True`):
423
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
424
+
425
+ Returns:
426
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
427
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
428
+ `tuple` is returned.
429
+ """
430
+ overlap_size = int(self.tile_sample_min_size * (1 - self.encode_tile_overlap_factor))
431
+ blend_extent = int(self.tile_latent_min_size * self.encode_tile_overlap_factor)
432
+ row_limit = self.tile_latent_min_size - blend_extent
433
+
434
+ # Split the image into 512x512 tiles and encode them separately.
435
+ rows = []
436
+ for i in range(0, x.shape[3], overlap_size):
437
+ row = []
438
+ for j in range(0, x.shape[4], overlap_size):
439
+ tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
440
+ if temporal_chunk:
441
+ tile = self.chunk_encode(tile, window_size=window_size)
442
+ else:
443
+ tile = self.encoder(tile, is_init_image=True, temporal_chunk=False)
444
+ tile = self.quant_conv(tile, is_init_image=True, temporal_chunk=False)
445
+ row.append(tile)
446
+ rows.append(row)
447
+ result_rows = []
448
+ for i, row in enumerate(rows):
449
+ result_row = []
450
+ for j, tile in enumerate(row):
451
+ # blend the above tile and the left tile
452
+ # to the current tile and add the current tile to the result row
453
+ if i > 0:
454
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
455
+ if j > 0:
456
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
457
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
458
+ result_rows.append(torch.cat(result_row, dim=4))
459
+
460
+ moments = torch.cat(result_rows, dim=3)
461
+
462
+ posterior = DiagonalGaussianDistribution(moments)
463
+
464
+ if not return_dict:
465
+ return (posterior,)
466
+
467
+ return AutoencoderKLOutput(latent_dist=posterior)
468
+
469
+ def tiled_decode(self, z: torch.FloatTensor, is_init_image=True,
470
+ temporal_chunk=False, window_size=2, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
471
+ r"""
472
+ Decode a batch of images using a tiled decoder.
473
+
474
+ Args:
475
+ z (`torch.FloatTensor`): Input batch of latent vectors.
476
+ return_dict (`bool`, *optional*, defaults to `True`):
477
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
478
+
479
+ Returns:
480
+ [`~models.vae.DecoderOutput`] or `tuple`:
481
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
482
+ returned.
483
+ """
484
+ overlap_size = int(self.tile_latent_min_size * (1 - self.decode_tile_overlap_factor))
485
+ blend_extent = int(self.tile_sample_min_size * self.decode_tile_overlap_factor)
486
+ row_limit = self.tile_sample_min_size - blend_extent
487
+
488
+ # Split z into overlapping 64x64 tiles and decode them separately.
489
+ # The tiles have an overlap to avoid seams between tiles.
490
+ rows = []
491
+ for i in range(0, z.shape[3], overlap_size):
492
+ row = []
493
+ for j in range(0, z.shape[4], overlap_size):
494
+ tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
495
+ if temporal_chunk:
496
+ decoded = self.chunk_decode(tile, window_size=window_size)
497
+ else:
498
+ tile = self.post_quant_conv(tile, is_init_image=True, temporal_chunk=False)
499
+ decoded = self.decoder(tile, is_init_image=True, temporal_chunk=False)
500
+ row.append(decoded)
501
+ rows.append(row)
502
+ result_rows = []
503
+
504
+ for i, row in enumerate(rows):
505
+ result_row = []
506
+ for j, tile in enumerate(row):
507
+ # blend the above tile and the left tile
508
+ # to the current tile and add the current tile to the result row
509
+ if i > 0:
510
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
511
+ if j > 0:
512
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
513
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
514
+ result_rows.append(torch.cat(result_row, dim=4))
515
+
516
+ dec = torch.cat(result_rows, dim=3)
517
+ if not return_dict:
518
+ return (dec,)
519
+
520
+ return DecoderOutput(sample=dec)
521
+
522
+ def forward(
523
+ self,
524
+ sample: torch.FloatTensor,
525
+ sample_posterior: bool = True,
526
+ generator: Optional[torch.Generator] = None,
527
+ freeze_encoder: bool = False,
528
+ is_init_image=True,
529
+ temporal_chunk=False,
530
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
531
+ r"""
532
+ Args:
533
+ sample (`torch.FloatTensor`): Input sample.
534
+ sample_posterior (`bool`, *optional*, defaults to `False`):
535
+ Whether to sample from the posterior.
536
+ return_dict (`bool`, *optional*, defaults to `True`):
537
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
538
+ """
539
+ x = sample
540
+
541
+ if is_context_parallel_initialized():
542
+ assert self.training, "Only supports during training now"
543
+
544
+ if freeze_encoder:
545
+ with torch.no_grad():
546
+ h = self.encoder(x, is_init_image=True, temporal_chunk=False)
547
+ moments = self.quant_conv(h, is_init_image=True, temporal_chunk=False)
548
+ posterior = DiagonalGaussianDistribution(moments)
549
+ global_posterior = posterior
550
+ else:
551
+ h = self.encoder(x, is_init_image=True, temporal_chunk=False)
552
+ moments = self.quant_conv(h, is_init_image=True, temporal_chunk=False)
553
+ posterior = DiagonalGaussianDistribution(moments)
554
+ global_moments = conv_gather_from_context_parallel_region(moments, dim=2, kernel_size=1)
555
+ global_posterior = DiagonalGaussianDistribution(global_moments)
556
+
557
+ if sample_posterior:
558
+ z = posterior.sample(generator=generator)
559
+ else:
560
+ z = posterior.mode()
561
+
562
+ if get_context_parallel_rank() == 0:
563
+ dec = self.decode(z, is_init_image=True).sample
564
+ else:
565
+ # Do not drop the first upsampled frame
566
+ dec = self.decode(z, is_init_image=False).sample
567
+
568
+ return global_posterior, dec
569
+
570
+ else:
571
+ # The normal training
572
+ if freeze_encoder:
573
+ with torch.no_grad():
574
+ posterior = self.encode(x, is_init_image=is_init_image,
575
+ temporal_chunk=temporal_chunk).latent_dist
576
+ else:
577
+ posterior = self.encode(x, is_init_image=is_init_image,
578
+ temporal_chunk=temporal_chunk).latent_dist
579
+
580
+ if sample_posterior:
581
+ z = posterior.sample(generator=generator)
582
+ else:
583
+ z = posterior.mode()
584
+
585
+ dec = self.decode(z, is_init_image=is_init_image, temporal_chunk=temporal_chunk).sample
586
+
587
+ return posterior, dec
588
+
589
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
590
+ def fuse_qkv_projections(self):
591
+ """
592
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
593
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
594
+
595
+ <Tip warning={true}>
596
+
597
+ This API is 🧪 experimental.
598
+
599
+ </Tip>
600
+ """
601
+ self.original_attn_processors = None
602
+
603
+ for _, attn_processor in self.attn_processors.items():
604
+ if "Added" in str(attn_processor.__class__.__name__):
605
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
606
+
607
+ self.original_attn_processors = self.attn_processors
608
+
609
+ for module in self.modules():
610
+ if isinstance(module, Attention):
611
+ module.fuse_projections(fuse=True)
612
+
613
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
614
+ def unfuse_qkv_projections(self):
615
+ """Disables the fused QKV projection if enabled.
616
+
617
+ <Tip warning={true}>
618
+
619
+ This API is 🧪 experimental.
620
+
621
+ </Tip>
622
+
623
+ """
624
+ if self.original_attn_processors is not None:
625
+ self.set_attn_processor(self.original_attn_processors)
video_vae/modeling_discriminator.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import torch.nn as nn
3
+ from einops import rearrange
4
+ import torch
5
+
6
+
7
+ def weights_init(m):
8
+ classname = m.__class__.__name__
9
+ if classname.find('Conv') != -1:
10
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
11
+ nn.init.constant_(m.bias.data, 0)
12
+ elif classname.find('BatchNorm') != -1:
13
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
14
+ nn.init.constant_(m.bias.data, 0)
15
+
16
+
17
+ class NLayerDiscriminator(nn.Module):
18
+ """Defines a PatchGAN discriminator as in Pix2Pix
19
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
+ """
21
+ def __init__(self, input_nc=3, ndf=64, n_layers=4):
22
+ """Construct a PatchGAN discriminator
23
+ Parameters:
24
+ input_nc (int) -- the number of channels in input images
25
+ ndf (int) -- the number of filters in the last conv layer
26
+ n_layers (int) -- the number of conv layers in the discriminator
27
+ norm_layer -- normalization layer
28
+ """
29
+ super(NLayerDiscriminator, self).__init__()
30
+
31
+ # norm_layer = nn.BatchNorm2d
32
+ norm_layer = nn.InstanceNorm2d
33
+
34
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35
+ use_bias = norm_layer.func != nn.BatchNorm2d
36
+ else:
37
+ use_bias = norm_layer != nn.BatchNorm2d
38
+
39
+ kw = 4
40
+ padw = 1
41
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42
+ nf_mult = 1
43
+ nf_mult_prev = 1
44
+ for n in range(1, n_layers): # gradually increase the number of filters
45
+ nf_mult_prev = nf_mult
46
+ nf_mult = min(2 ** n, 8)
47
+ sequence += [
48
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49
+ norm_layer(ndf * nf_mult),
50
+ nn.LeakyReLU(0.2, True)
51
+ ]
52
+
53
+ nf_mult_prev = nf_mult
54
+ nf_mult = min(2 ** n_layers, 8)
55
+ sequence += [
56
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57
+ norm_layer(ndf * nf_mult),
58
+ nn.LeakyReLU(0.2, True)
59
+ ]
60
+
61
+ sequence += [
62
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63
+ self.main = nn.Sequential(*sequence)
64
+
65
+ def forward(self, input):
66
+ """Standard forward."""
67
+ return self.main(input)
68
+
69
+
70
+ class NLayerDiscriminator3D(nn.Module):
71
+ """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs."""
72
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
73
+ """
74
+ Construct a 3D PatchGAN discriminator
75
+
76
+ Parameters:
77
+ input_nc (int) -- the number of channels in input volumes
78
+ ndf (int) -- the number of filters in the last conv layer
79
+ n_layers (int) -- the number of conv layers in the discriminator
80
+ use_actnorm (bool) -- flag to use actnorm instead of batchnorm
81
+ """
82
+ super(NLayerDiscriminator3D, self).__init__()
83
+ # if not use_actnorm:
84
+ # norm_layer = nn.BatchNorm3d
85
+ # else:
86
+ # raise NotImplementedError("Not implemented.")
87
+
88
+ norm_layer = nn.InstanceNorm3d
89
+
90
+ if type(norm_layer) == functools.partial:
91
+ use_bias = norm_layer.func != nn.BatchNorm3d
92
+ else:
93
+ use_bias = norm_layer != nn.BatchNorm3d
94
+
95
+ kw = 4
96
+ padw = 1
97
+ sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
98
+ nf_mult = 1
99
+ nf_mult_prev = 1
100
+ for n in range(1, n_layers): # gradually increase the number of filters
101
+ nf_mult_prev = nf_mult
102
+ nf_mult = min(2 ** n, 8)
103
+ sequence += [
104
+ nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=(1,2,2), padding=padw, bias=use_bias),
105
+ norm_layer(ndf * nf_mult),
106
+ nn.LeakyReLU(0.2, True)
107
+ ]
108
+
109
+ nf_mult_prev = nf_mult
110
+ nf_mult = min(2 ** n_layers, 8)
111
+ sequence += [
112
+ nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias),
113
+ norm_layer(ndf * nf_mult),
114
+ nn.LeakyReLU(0.2, True)
115
+ ]
116
+
117
+ sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
118
+ self.main = nn.Sequential(*sequence)
119
+
120
+ def forward(self, input):
121
+ """Standard forward."""
122
+ return self.main(input)
video_vae/modeling_enc_dec.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ from einops import rearrange
21
+
22
+ from diffusers.utils import BaseOutput, is_torch_version
23
+ from diffusers.utils.torch_utils import randn_tensor
24
+ from diffusers.models.attention_processor import SpatialNorm
25
+ from .modeling_block import (
26
+ UNetMidBlock2D,
27
+ CausalUNetMidBlock2D,
28
+ get_down_block,
29
+ get_up_block,
30
+ get_input_layer,
31
+ get_output_layer,
32
+ )
33
+ from .modeling_resnet import (
34
+ Downsample2D,
35
+ Upsample2D,
36
+ TemporalDownsample2x,
37
+ TemporalUpsample2x,
38
+ )
39
+ from .modeling_causal_conv import CausalConv3d, CausalGroupNorm
40
+
41
+
42
+ @dataclass
43
+ class DecoderOutput(BaseOutput):
44
+ r"""
45
+ Output of decoding method.
46
+
47
+ Args:
48
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
49
+ The decoded output sample from the last layer of the model.
50
+ """
51
+
52
+ sample: torch.FloatTensor
53
+
54
+
55
+ class CausalVaeEncoder(nn.Module):
56
+ r"""
57
+ The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
58
+
59
+ Args:
60
+ in_channels (`int`, *optional*, defaults to 3):
61
+ The number of input channels.
62
+ out_channels (`int`, *optional*, defaults to 3):
63
+ The number of output channels.
64
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
65
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
66
+ options.
67
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
68
+ The number of output channels for each block.
69
+ layers_per_block (`int`, *optional*, defaults to 2):
70
+ The number of layers per block.
71
+ norm_num_groups (`int`, *optional*, defaults to 32):
72
+ The number of groups for normalization.
73
+ act_fn (`str`, *optional*, defaults to `"silu"`):
74
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
75
+ double_z (`bool`, *optional*, defaults to `True`):
76
+ Whether to double the number of output channels for the last block.
77
+ """
78
+
79
+ def __init__(
80
+ self,
81
+ in_channels: int = 3,
82
+ out_channels: int = 3,
83
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
84
+ spatial_down_sample: Tuple[bool, ...] = (True,),
85
+ temporal_down_sample: Tuple[bool, ...] = (False,),
86
+ block_out_channels: Tuple[int, ...] = (64,),
87
+ layers_per_block: Tuple[int, ...] = (2,),
88
+ norm_num_groups: int = 32,
89
+ act_fn: str = "silu",
90
+ double_z: bool = True,
91
+ block_dropout: Tuple[int, ...] = (0.0,),
92
+ mid_block_add_attention=True,
93
+ ):
94
+ super().__init__()
95
+ self.layers_per_block = layers_per_block
96
+
97
+ self.conv_in = CausalConv3d(
98
+ in_channels,
99
+ block_out_channels[0],
100
+ kernel_size=3,
101
+ stride=1,
102
+ )
103
+
104
+ self.mid_block = None
105
+ self.down_blocks = nn.ModuleList([])
106
+
107
+ # down
108
+ output_channel = block_out_channels[0]
109
+ for i, down_block_type in enumerate(down_block_types):
110
+ input_channel = output_channel
111
+ output_channel = block_out_channels[i]
112
+
113
+ down_block = get_down_block(
114
+ down_block_type,
115
+ num_layers=self.layers_per_block[i],
116
+ in_channels=input_channel,
117
+ out_channels=output_channel,
118
+ add_spatial_downsample=spatial_down_sample[i],
119
+ add_temporal_downsample=temporal_down_sample[i],
120
+ resnet_eps=1e-6,
121
+ downsample_padding=0,
122
+ resnet_act_fn=act_fn,
123
+ resnet_groups=norm_num_groups,
124
+ attention_head_dim=output_channel,
125
+ temb_channels=None,
126
+ dropout=block_dropout[i],
127
+ )
128
+ self.down_blocks.append(down_block)
129
+
130
+ # mid
131
+ self.mid_block = CausalUNetMidBlock2D(
132
+ in_channels=block_out_channels[-1],
133
+ resnet_eps=1e-6,
134
+ resnet_act_fn=act_fn,
135
+ output_scale_factor=1,
136
+ resnet_time_scale_shift="default",
137
+ attention_head_dim=block_out_channels[-1],
138
+ resnet_groups=norm_num_groups,
139
+ temb_channels=None,
140
+ add_attention=mid_block_add_attention,
141
+ dropout=block_dropout[-1],
142
+ )
143
+
144
+ # out
145
+
146
+ self.conv_norm_out = CausalGroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
147
+ self.conv_act = nn.SiLU()
148
+
149
+ conv_out_channels = 2 * out_channels if double_z else out_channels
150
+ self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3, stride=1)
151
+
152
+ self.gradient_checkpointing = False
153
+
154
+ def forward(self, sample: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
155
+ r"""The forward method of the `Encoder` class."""
156
+
157
+ sample = self.conv_in(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
158
+
159
+ if self.training and self.gradient_checkpointing:
160
+
161
+ def create_custom_forward(module):
162
+ def custom_forward(*inputs):
163
+ return module(*inputs)
164
+
165
+ return custom_forward
166
+
167
+ # down
168
+ if is_torch_version(">=", "1.11.0"):
169
+ for down_block in self.down_blocks:
170
+ sample = torch.utils.checkpoint.checkpoint(
171
+ create_custom_forward(down_block), sample, is_init_image,
172
+ temporal_chunk, use_reentrant=False
173
+ )
174
+ # middle
175
+ sample = torch.utils.checkpoint.checkpoint(
176
+ create_custom_forward(self.mid_block), sample, is_init_image,
177
+ temporal_chunk, use_reentrant=False
178
+ )
179
+ else:
180
+ for down_block in self.down_blocks:
181
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample, is_init_image, temporal_chunk)
182
+ # middle
183
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, is_init_image, temporal_chunk)
184
+
185
+ else:
186
+ # down
187
+ for down_block in self.down_blocks:
188
+ sample = down_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
189
+
190
+ # middle
191
+ sample = self.mid_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
192
+
193
+ # post-process
194
+ sample = self.conv_norm_out(sample)
195
+ sample = self.conv_act(sample)
196
+ sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
197
+
198
+ return sample
199
+
200
+
201
+ class CausalVaeDecoder(nn.Module):
202
+ r"""
203
+ The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
204
+
205
+ Args:
206
+ in_channels (`int`, *optional*, defaults to 3):
207
+ The number of input channels.
208
+ out_channels (`int`, *optional*, defaults to 3):
209
+ The number of output channels.
210
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
211
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
212
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
213
+ The number of output channels for each block.
214
+ layers_per_block (`int`, *optional*, defaults to 2):
215
+ The number of layers per block.
216
+ norm_num_groups (`int`, *optional*, defaults to 32):
217
+ The number of groups for normalization.
218
+ act_fn (`str`, *optional*, defaults to `"silu"`):
219
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
220
+ norm_type (`str`, *optional*, defaults to `"group"`):
221
+ The normalization type to use. Can be either `"group"` or `"spatial"`.
222
+ """
223
+
224
+ def __init__(
225
+ self,
226
+ in_channels: int = 3,
227
+ out_channels: int = 3,
228
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
229
+ spatial_up_sample: Tuple[bool, ...] = (True,),
230
+ temporal_up_sample: Tuple[bool, ...] = (False,),
231
+ block_out_channels: Tuple[int, ...] = (64,),
232
+ layers_per_block: Tuple[int, ...] = (2,),
233
+ norm_num_groups: int = 32,
234
+ act_fn: str = "silu",
235
+ mid_block_add_attention=True,
236
+ interpolate: bool = True,
237
+ block_dropout: Tuple[int, ...] = (0.0,),
238
+ ):
239
+ super().__init__()
240
+ self.layers_per_block = layers_per_block
241
+
242
+ self.conv_in = CausalConv3d(
243
+ in_channels,
244
+ block_out_channels[-1],
245
+ kernel_size=3,
246
+ stride=1,
247
+ )
248
+
249
+ self.mid_block = None
250
+ self.up_blocks = nn.ModuleList([])
251
+
252
+ # mid
253
+ self.mid_block = CausalUNetMidBlock2D(
254
+ in_channels=block_out_channels[-1],
255
+ resnet_eps=1e-6,
256
+ resnet_act_fn=act_fn,
257
+ output_scale_factor=1,
258
+ resnet_time_scale_shift="default",
259
+ attention_head_dim=block_out_channels[-1],
260
+ resnet_groups=norm_num_groups,
261
+ temb_channels=None,
262
+ add_attention=mid_block_add_attention,
263
+ dropout=block_dropout[-1],
264
+ )
265
+
266
+ # up
267
+ reversed_block_out_channels = list(reversed(block_out_channels))
268
+ output_channel = reversed_block_out_channels[0]
269
+ for i, up_block_type in enumerate(up_block_types):
270
+ prev_output_channel = output_channel
271
+ output_channel = reversed_block_out_channels[i]
272
+
273
+ is_final_block = i == len(block_out_channels) - 1
274
+
275
+ up_block = get_up_block(
276
+ up_block_type,
277
+ num_layers=self.layers_per_block[i],
278
+ in_channels=prev_output_channel,
279
+ out_channels=output_channel,
280
+ prev_output_channel=None,
281
+ add_spatial_upsample=spatial_up_sample[i],
282
+ add_temporal_upsample=temporal_up_sample[i],
283
+ resnet_eps=1e-6,
284
+ resnet_act_fn=act_fn,
285
+ resnet_groups=norm_num_groups,
286
+ attention_head_dim=output_channel,
287
+ temb_channels=None,
288
+ resnet_time_scale_shift='default',
289
+ interpolate=interpolate,
290
+ dropout=block_dropout[i],
291
+ )
292
+ self.up_blocks.append(up_block)
293
+ prev_output_channel = output_channel
294
+
295
+ # out
296
+ self.conv_norm_out = CausalGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
297
+ self.conv_act = nn.SiLU()
298
+ self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3, stride=1)
299
+
300
+ self.gradient_checkpointing = False
301
+
302
+ def forward(
303
+ self,
304
+ sample: torch.FloatTensor,
305
+ is_init_image=True,
306
+ temporal_chunk=False,
307
+ ) -> torch.FloatTensor:
308
+ r"""The forward method of the `Decoder` class."""
309
+
310
+ sample = self.conv_in(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
311
+
312
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
313
+ if self.training and self.gradient_checkpointing:
314
+
315
+ def create_custom_forward(module):
316
+ def custom_forward(*inputs):
317
+ return module(*inputs)
318
+
319
+ return custom_forward
320
+
321
+ if is_torch_version(">=", "1.11.0"):
322
+ # middle
323
+ sample = torch.utils.checkpoint.checkpoint(
324
+ create_custom_forward(self.mid_block),
325
+ sample,
326
+ is_init_image=is_init_image,
327
+ temporal_chunk=temporal_chunk,
328
+ use_reentrant=False,
329
+ )
330
+ sample = sample.to(upscale_dtype)
331
+
332
+ # up
333
+ for up_block in self.up_blocks:
334
+ sample = torch.utils.checkpoint.checkpoint(
335
+ create_custom_forward(up_block),
336
+ sample,
337
+ is_init_image=is_init_image,
338
+ temporal_chunk=temporal_chunk,
339
+ use_reentrant=False,
340
+ )
341
+ else:
342
+ # middle
343
+ sample = torch.utils.checkpoint.checkpoint(
344
+ create_custom_forward(self.mid_block), sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk,
345
+ )
346
+ sample = sample.to(upscale_dtype)
347
+
348
+ # up
349
+ for up_block in self.up_blocks:
350
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample,
351
+ is_init_image=is_init_image, temporal_chunk=temporal_chunk,)
352
+ else:
353
+ # middle
354
+ sample = self.mid_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
355
+ sample = sample.to(upscale_dtype)
356
+
357
+ # up
358
+ for up_block in self.up_blocks:
359
+ sample = up_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk,)
360
+
361
+ # post-process
362
+ sample = self.conv_norm_out(sample)
363
+ sample = self.conv_act(sample)
364
+ sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
365
+
366
+ return sample
367
+
368
+
369
+ class DiagonalGaussianDistribution(object):
370
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
371
+ self.parameters = parameters
372
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
373
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
374
+ self.deterministic = deterministic
375
+ self.std = torch.exp(0.5 * self.logvar)
376
+ self.var = torch.exp(self.logvar)
377
+ if self.deterministic:
378
+ self.var = self.std = torch.zeros_like(
379
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
380
+ )
381
+
382
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
383
+ # make sure sample is on the same device as the parameters and has same dtype
384
+ sample = randn_tensor(
385
+ self.mean.shape,
386
+ generator=generator,
387
+ device=self.parameters.device,
388
+ dtype=self.parameters.dtype,
389
+ )
390
+ x = self.mean + self.std * sample
391
+ return x
392
+
393
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
394
+ if self.deterministic:
395
+ return torch.Tensor([0.0])
396
+ else:
397
+ if other is None:
398
+ return 0.5 * torch.sum(
399
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
400
+ dim=[2, 3, 4],
401
+ )
402
+ else:
403
+ return 0.5 * torch.sum(
404
+ torch.pow(self.mean - other.mean, 2) / other.var
405
+ + self.var / other.var
406
+ - 1.0
407
+ - self.logvar
408
+ + other.logvar,
409
+ dim=[2, 3, 4],
410
+ )
411
+
412
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
413
+ if self.deterministic:
414
+ return torch.Tensor([0.0])
415
+ logtwopi = np.log(2.0 * np.pi)
416
+ return 0.5 * torch.sum(
417
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
418
+ dim=dims,
419
+ )
420
+
421
+ def mode(self) -> torch.Tensor:
422
+ return self.mean
video_vae/modeling_loss.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from .modeling_lpips import LPIPS
7
+ from .modeling_discriminator import NLayerDiscriminator, NLayerDiscriminator3D, weights_init
8
+ from IPython import embed
9
+
10
+
11
+ class AdaptiveLossWeight:
12
+ def __init__(self, timestep_range=[0, 1], buckets=300, weight_range=[1e-7, 1e7]):
13
+ self.bucket_ranges = torch.linspace(timestep_range[0], timestep_range[1], buckets-1)
14
+ self.bucket_losses = torch.ones(buckets)
15
+ self.weight_range = weight_range
16
+
17
+ def weight(self, timestep):
18
+ indices = torch.searchsorted(self.bucket_ranges.to(timestep.device), timestep)
19
+ return (1/self.bucket_losses.to(timestep.device)[indices]).clamp(*self.weight_range)
20
+
21
+ def update_buckets(self, timestep, loss, beta=0.99):
22
+ indices = torch.searchsorted(self.bucket_ranges.to(timestep.device), timestep).cpu()
23
+ self.bucket_losses[indices] = self.bucket_losses[indices]*beta + loss.detach().cpu() * (1-beta)
24
+
25
+
26
+ def hinge_d_loss(logits_real, logits_fake):
27
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
28
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
29
+ d_loss = 0.5 * (loss_real + loss_fake)
30
+ return d_loss
31
+
32
+
33
+ def vanilla_d_loss(logits_real, logits_fake):
34
+ d_loss = 0.5 * (
35
+ torch.mean(torch.nn.functional.softplus(-logits_real))
36
+ + torch.mean(torch.nn.functional.softplus(logits_fake))
37
+ )
38
+ return d_loss
39
+
40
+
41
+ def adopt_weight(weight, global_step, threshold=0, value=0.0):
42
+ if global_step < threshold:
43
+ weight = value
44
+ return weight
45
+
46
+
47
+ class LPIPSWithDiscriminator(nn.Module):
48
+ def __init__(
49
+ self,
50
+ disc_start,
51
+ logvar_init=0.0,
52
+ kl_weight=1.0,
53
+ pixelloss_weight=1.0,
54
+ perceptual_weight=1.0,
55
+ # --- Discriminator Loss ---
56
+ disc_num_layers=4,
57
+ disc_in_channels=3,
58
+ disc_factor=1.0,
59
+ disc_weight=0.5,
60
+ disc_loss="hinge",
61
+ add_discriminator=True,
62
+ using_3d_discriminator=False,
63
+ ):
64
+
65
+ super().__init__()
66
+ assert disc_loss in ["hinge", "vanilla"]
67
+ self.kl_weight = kl_weight
68
+ self.pixel_weight = pixelloss_weight
69
+ self.perceptual_loss = LPIPS().eval()
70
+ self.perceptual_weight = perceptual_weight
71
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
72
+
73
+ if add_discriminator:
74
+ disc_cls = NLayerDiscriminator3D if using_3d_discriminator else NLayerDiscriminator
75
+ self.discriminator = disc_cls(
76
+ input_nc=disc_in_channels, n_layers=disc_num_layers,
77
+ ).apply(weights_init)
78
+ else:
79
+ self.discriminator = None
80
+
81
+ self.discriminator_iter_start = disc_start
82
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
83
+ self.disc_factor = disc_factor
84
+ self.discriminator_weight = disc_weight
85
+ self.using_3d_discriminator = using_3d_discriminator
86
+
87
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
88
+ if last_layer is not None:
89
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
90
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
91
+ else:
92
+ nll_grads = torch.autograd.grad(
93
+ nll_loss, self.last_layer[0], retain_graph=True
94
+ )[0]
95
+ g_grads = torch.autograd.grad(
96
+ g_loss, self.last_layer[0], retain_graph=True
97
+ )[0]
98
+
99
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
100
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
101
+ d_weight = d_weight * self.discriminator_weight
102
+ return d_weight
103
+
104
+ def forward(
105
+ self,
106
+ inputs,
107
+ reconstructions,
108
+ posteriors,
109
+ optimizer_idx,
110
+ global_step,
111
+ split="train",
112
+ last_layer=None,
113
+ ):
114
+ t = reconstructions.shape[2]
115
+ inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous()
116
+ reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w").contiguous()
117
+
118
+ if optimizer_idx == 0:
119
+ # rec_loss = torch.mean(torch.abs(inputs - reconstructions), dim=(1,2,3), keepdim=True)
120
+ rec_loss = torch.mean(F.mse_loss(inputs, reconstructions, reduction='none'), dim=(1,2,3), keepdim=True)
121
+
122
+ if self.perceptual_weight > 0:
123
+ p_loss = self.perceptual_loss(inputs, reconstructions)
124
+ nll_loss = self.pixel_weight * rec_loss + self.perceptual_weight * p_loss
125
+
126
+ nll_loss = nll_loss / torch.exp(self.logvar) + self.logvar
127
+ weighted_nll_loss = nll_loss
128
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
129
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
130
+
131
+ kl_loss = posteriors.kl()
132
+ kl_loss = torch.mean(kl_loss)
133
+
134
+ disc_factor = adopt_weight(
135
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
136
+ )
137
+
138
+ if disc_factor > 0.0:
139
+ if self.using_3d_discriminator:
140
+ reconstructions = rearrange(reconstructions, '(b t) c h w -> b c t h w', t=t)
141
+
142
+ logits_fake = self.discriminator(reconstructions.contiguous())
143
+ g_loss = -torch.mean(logits_fake)
144
+ try:
145
+ d_weight = self.calculate_adaptive_weight(
146
+ nll_loss, g_loss, last_layer=last_layer
147
+ )
148
+ except RuntimeError:
149
+ assert not self.training
150
+ d_weight = torch.tensor(0.0)
151
+ else:
152
+ d_weight = torch.tensor(0.0)
153
+ g_loss = torch.tensor(0.0)
154
+
155
+
156
+ loss = (
157
+ weighted_nll_loss
158
+ + self.kl_weight * kl_loss
159
+ + d_weight * disc_factor * g_loss
160
+ )
161
+ log = {
162
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
163
+ "{}/logvar".format(split): self.logvar.detach(),
164
+ "{}/kl_loss".format(split): kl_loss.detach().mean(),
165
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
166
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
167
+ "{}/perception_loss".format(split): p_loss.detach().mean(),
168
+ "{}/d_weight".format(split): d_weight.detach(),
169
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
170
+ "{}/g_loss".format(split): g_loss.detach().mean(),
171
+ }
172
+ return loss, log
173
+
174
+ if optimizer_idx == 1:
175
+ if self.using_3d_discriminator:
176
+ inputs = rearrange(inputs, '(b t) c h w -> b c t h w', t=t)
177
+ reconstructions = rearrange(reconstructions, '(b t) c h w -> b c t h w', t=t)
178
+
179
+ logits_real = self.discriminator(inputs.contiguous().detach())
180
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
181
+
182
+ disc_factor = adopt_weight(
183
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
184
+ )
185
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
186
+
187
+ log = {
188
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
189
+ "{}/logits_real".format(split): logits_real.detach().mean(),
190
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
191
+ }
192
+ return d_loss, log
video_vae/modeling_lpips.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchvision import models
6
+ from collections import namedtuple
7
+
8
+
9
+ class LPIPS(nn.Module):
10
+ # Learned perceptual metric
11
+ def __init__(self, use_dropout=True):
12
+ super().__init__()
13
+ self.scaling_layer = ScalingLayer()
14
+ self.chns = [64, 128, 256, 512, 512] # vg16 features
15
+ self.net = vgg16(pretrained=False, requires_grad=False)
16
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
17
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
18
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
19
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
20
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
21
+ self.load_from_pretrained()
22
+ for param in self.parameters():
23
+ param.requires_grad = False
24
+
25
+ def load_from_pretrained(self):
26
+ ckpt = "/home/jinyang/models/vae/video_vae_baseline/vgg_lpips.pth" # replace with your lpips
27
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=True)
28
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
29
+
30
+ def forward(self, input, target):
31
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
32
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
33
+ feats0, feats1, diffs = {}, {}, {}
34
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
35
+ for kk in range(len(self.chns)):
36
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
37
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
38
+
39
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
40
+ val = res[0]
41
+ for l in range(1, len(self.chns)):
42
+ val += res[l]
43
+ return val
44
+
45
+
46
+ class ScalingLayer(nn.Module):
47
+ def __init__(self):
48
+ super(ScalingLayer, self).__init__()
49
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
50
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
51
+
52
+ def forward(self, inp):
53
+ return (inp - self.shift) / self.scale
54
+
55
+
56
+ class NetLinLayer(nn.Module):
57
+ """ A single linear layer which does a 1x1 conv """
58
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
59
+ super(NetLinLayer, self).__init__()
60
+ layers = [nn.Dropout(), ] if (use_dropout) else []
61
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
62
+ self.model = nn.Sequential(*layers)
63
+
64
+
65
+ class vgg16(torch.nn.Module):
66
+ def __init__(self, requires_grad=False, pretrained=True):
67
+ super(vgg16, self).__init__()
68
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
69
+ self.slice1 = torch.nn.Sequential()
70
+ self.slice2 = torch.nn.Sequential()
71
+ self.slice3 = torch.nn.Sequential()
72
+ self.slice4 = torch.nn.Sequential()
73
+ self.slice5 = torch.nn.Sequential()
74
+ self.N_slices = 5
75
+ for x in range(4):
76
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
77
+ for x in range(4, 9):
78
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
79
+ for x in range(9, 16):
80
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
81
+ for x in range(16, 23):
82
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
83
+ for x in range(23, 30):
84
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
85
+ if not requires_grad:
86
+ for param in self.parameters():
87
+ param.requires_grad = False
88
+
89
+ def forward(self, X):
90
+ h = self.slice1(X)
91
+ h_relu1_2 = h
92
+ h = self.slice2(h)
93
+ h_relu2_2 = h
94
+ h = self.slice3(h)
95
+ h_relu3_3 = h
96
+ h = self.slice4(h)
97
+ h_relu4_3 = h
98
+ h = self.slice5(h)
99
+ h_relu5_3 = h
100
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
101
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
102
+ return out
103
+
104
+
105
+ def normalize_tensor(x,eps=1e-10):
106
+ norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
107
+ return x/(norm_factor+eps)
108
+
109
+
110
+ def spatial_average(x, keepdim=True):
111
+ return x.mean([2,3],keepdim=keepdim)
112
+
113
+
114
+ if __name__ == "__main__":
115
+ model = LPIPS().eval()
116
+ _ = torch.manual_seed(123)
117
+ img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
118
+ img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
119
+ print(model(img1, img2).shape)
120
+ # embed()
video_vae/modeling_resnet.py ADDED
@@ -0,0 +1,729 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from diffusers.models.activations import get_activation
9
+ from diffusers.models.attention_processor import SpatialNorm
10
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
11
+ from diffusers.models.normalization import AdaGroupNorm
12
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
13
+ from .modeling_causal_conv import CausalConv3d, CausalGroupNorm
14
+
15
+
16
+ class CausalResnetBlock3D(nn.Module):
17
+ r"""
18
+ A Resnet block.
19
+
20
+ Parameters:
21
+ in_channels (`int`): The number of channels in the input.
22
+ out_channels (`int`, *optional*, default to be `None`):
23
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
24
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
25
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
26
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
27
+ groups_out (`int`, *optional*, default to None):
28
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
29
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
30
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
31
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
32
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
33
+ "ada_group" for a stronger conditioning with scale and shift.
34
+ kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
35
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
36
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
37
+ use_in_shortcut (`bool`, *optional*, default to `True`):
38
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
39
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
40
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
41
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
42
+ `conv_shortcut` output.
43
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
44
+ If None, same as `out_channels`.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ *,
50
+ in_channels: int,
51
+ out_channels: Optional[int] = None,
52
+ conv_shortcut: bool = False,
53
+ dropout: float = 0.0,
54
+ temb_channels: int = 512,
55
+ groups: int = 32,
56
+ groups_out: Optional[int] = None,
57
+ pre_norm: bool = True,
58
+ eps: float = 1e-6,
59
+ non_linearity: str = "swish",
60
+ time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
61
+ output_scale_factor: float = 1.0,
62
+ use_in_shortcut: Optional[bool] = None,
63
+ conv_shortcut_bias: bool = True,
64
+ conv_2d_out_channels: Optional[int] = None,
65
+ ):
66
+ super().__init__()
67
+ self.pre_norm = pre_norm
68
+ self.pre_norm = True
69
+ self.in_channels = in_channels
70
+ out_channels = in_channels if out_channels is None else out_channels
71
+ self.out_channels = out_channels
72
+ self.use_conv_shortcut = conv_shortcut
73
+ self.output_scale_factor = output_scale_factor
74
+ self.time_embedding_norm = time_embedding_norm
75
+
76
+ linear_cls = nn.Linear
77
+
78
+ if groups_out is None:
79
+ groups_out = groups
80
+
81
+ if self.time_embedding_norm == "ada_group":
82
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
83
+ elif self.time_embedding_norm == "spatial":
84
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
85
+ else:
86
+ self.norm1 = CausalGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
87
+
88
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
89
+
90
+ if self.time_embedding_norm == "ada_group":
91
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
92
+ elif self.time_embedding_norm == "spatial":
93
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
94
+ else:
95
+ self.norm2 = CausalGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
96
+
97
+ self.dropout = torch.nn.Dropout(dropout)
98
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
99
+ self.conv2 = CausalConv3d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1)
100
+
101
+ self.nonlinearity = get_activation(non_linearity)
102
+ self.upsample = self.downsample = None
103
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
104
+
105
+ self.conv_shortcut = None
106
+ if self.use_in_shortcut:
107
+ self.conv_shortcut = CausalConv3d(
108
+ in_channels,
109
+ conv_2d_out_channels,
110
+ kernel_size=1,
111
+ stride=1,
112
+ bias=conv_shortcut_bias,
113
+ )
114
+
115
+ def forward(
116
+ self,
117
+ input_tensor: torch.FloatTensor,
118
+ temb: torch.FloatTensor = None,
119
+ is_init_image=True,
120
+ temporal_chunk=False,
121
+ ) -> torch.FloatTensor:
122
+ hidden_states = input_tensor
123
+
124
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
125
+ hidden_states = self.norm1(hidden_states, temb)
126
+ else:
127
+ hidden_states = self.norm1(hidden_states)
128
+
129
+ hidden_states = self.nonlinearity(hidden_states)
130
+
131
+ hidden_states = self.conv1(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
132
+
133
+ if temb is not None and self.time_embedding_norm == "default":
134
+ hidden_states = hidden_states + temb
135
+
136
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
137
+ hidden_states = self.norm2(hidden_states, temb)
138
+ else:
139
+ hidden_states = self.norm2(hidden_states)
140
+
141
+ hidden_states = self.nonlinearity(hidden_states)
142
+ hidden_states = self.dropout(hidden_states)
143
+ hidden_states = self.conv2(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
144
+
145
+ if self.conv_shortcut is not None:
146
+ input_tensor = self.conv_shortcut(input_tensor, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
147
+
148
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
149
+
150
+ return output_tensor
151
+
152
+
153
+ class ResnetBlock2D(nn.Module):
154
+ r"""
155
+ A Resnet block.
156
+
157
+ Parameters:
158
+ in_channels (`int`): The number of channels in the input.
159
+ out_channels (`int`, *optional*, default to be `None`):
160
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
161
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
162
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
163
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
164
+ groups_out (`int`, *optional*, default to None):
165
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
166
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
167
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
168
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
169
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
170
+ "ada_group" for a stronger conditioning with scale and shift.
171
+ kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
172
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
173
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
174
+ use_in_shortcut (`bool`, *optional*, default to `True`):
175
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
176
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
177
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
178
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
179
+ `conv_shortcut` output.
180
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
181
+ If None, same as `out_channels`.
182
+ """
183
+
184
+ def __init__(
185
+ self,
186
+ *,
187
+ in_channels: int,
188
+ out_channels: Optional[int] = None,
189
+ conv_shortcut: bool = False,
190
+ dropout: float = 0.0,
191
+ temb_channels: int = 512,
192
+ groups: int = 32,
193
+ groups_out: Optional[int] = None,
194
+ pre_norm: bool = True,
195
+ eps: float = 1e-6,
196
+ non_linearity: str = "swish",
197
+ time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
198
+ output_scale_factor: float = 1.0,
199
+ use_in_shortcut: Optional[bool] = None,
200
+ conv_shortcut_bias: bool = True,
201
+ conv_2d_out_channels: Optional[int] = None,
202
+ ):
203
+ super().__init__()
204
+ self.pre_norm = pre_norm
205
+ self.pre_norm = True
206
+ self.in_channels = in_channels
207
+ out_channels = in_channels if out_channels is None else out_channels
208
+ self.out_channels = out_channels
209
+ self.use_conv_shortcut = conv_shortcut
210
+ self.output_scale_factor = output_scale_factor
211
+ self.time_embedding_norm = time_embedding_norm
212
+
213
+ linear_cls = nn.Linear
214
+ conv_cls = nn.Conv3d
215
+
216
+ if groups_out is None:
217
+ groups_out = groups
218
+
219
+ if self.time_embedding_norm == "ada_group":
220
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
221
+ elif self.time_embedding_norm == "spatial":
222
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
223
+ else:
224
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
225
+
226
+ self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
227
+
228
+ if self.time_embedding_norm == "ada_group":
229
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
230
+ elif self.time_embedding_norm == "spatial":
231
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
232
+ else:
233
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
234
+
235
+ self.dropout = torch.nn.Dropout(dropout)
236
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
237
+ self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
238
+
239
+ self.nonlinearity = get_activation(non_linearity)
240
+ self.upsample = self.downsample = None
241
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
242
+
243
+ self.conv_shortcut = None
244
+ if self.use_in_shortcut:
245
+ self.conv_shortcut = conv_cls(
246
+ in_channels,
247
+ conv_2d_out_channels,
248
+ kernel_size=1,
249
+ stride=1,
250
+ padding=0,
251
+ bias=conv_shortcut_bias,
252
+ )
253
+
254
+ def forward(
255
+ self,
256
+ input_tensor: torch.FloatTensor,
257
+ temb: torch.FloatTensor = None,
258
+ scale: float = 1.0,
259
+ ) -> torch.FloatTensor:
260
+ hidden_states = input_tensor
261
+
262
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
263
+ hidden_states = self.norm1(hidden_states, temb)
264
+ else:
265
+ hidden_states = self.norm1(hidden_states)
266
+
267
+ hidden_states = self.nonlinearity(hidden_states)
268
+
269
+ hidden_states = self.conv1(hidden_states)
270
+
271
+ if temb is not None and self.time_embedding_norm == "default":
272
+ hidden_states = hidden_states + temb
273
+
274
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
275
+ hidden_states = self.norm2(hidden_states, temb)
276
+ else:
277
+ hidden_states = self.norm2(hidden_states)
278
+
279
+ hidden_states = self.nonlinearity(hidden_states)
280
+ hidden_states = self.dropout(hidden_states)
281
+ hidden_states = self.conv2(hidden_states)
282
+
283
+ if self.conv_shortcut is not None:
284
+ input_tensor = self.conv_shortcut(input_tensor)
285
+
286
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
287
+
288
+ return output_tensor
289
+
290
+
291
+ class CausalDownsample2x(nn.Module):
292
+ """A 2D downsampling layer with an optional convolution.
293
+
294
+ Parameters:
295
+ channels (`int`):
296
+ number of channels in the inputs and outputs.
297
+ use_conv (`bool`, default `False`):
298
+ option to use a convolution.
299
+ out_channels (`int`, optional):
300
+ number of output channels. Defaults to `channels`.
301
+ padding (`int`, default `1`):
302
+ padding for the convolution.
303
+ name (`str`, default `conv`):
304
+ name of the downsampling 2D layer.
305
+ """
306
+
307
+ def __init__(
308
+ self,
309
+ channels: int,
310
+ use_conv: bool = True,
311
+ out_channels: Optional[int] = None,
312
+ name: str = "conv",
313
+ kernel_size=3,
314
+ bias=True,
315
+ ):
316
+ super().__init__()
317
+ self.channels = channels
318
+ self.out_channels = out_channels or channels
319
+ self.use_conv = use_conv
320
+ stride = (1, 2, 2)
321
+ self.name = name
322
+
323
+ if use_conv:
324
+ conv = CausalConv3d(
325
+ self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
326
+ )
327
+ else:
328
+ assert self.channels == self.out_channels
329
+ conv = nn.AvgPool3d(kernel_size=stride, stride=stride)
330
+
331
+ self.conv = conv
332
+
333
+ def forward(self, hidden_states: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
334
+ assert hidden_states.shape[1] == self.channels
335
+ hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
336
+ return hidden_states
337
+
338
+
339
+ class Downsample2D(nn.Module):
340
+ """A 2D downsampling layer with an optional convolution.
341
+
342
+ Parameters:
343
+ channels (`int`):
344
+ number of channels in the inputs and outputs.
345
+ use_conv (`bool`, default `False`):
346
+ option to use a convolution.
347
+ out_channels (`int`, optional):
348
+ number of output channels. Defaults to `channels`.
349
+ padding (`int`, default `1`):
350
+ padding for the convolution.
351
+ name (`str`, default `conv`):
352
+ name of the downsampling 2D layer.
353
+ """
354
+
355
+ def __init__(
356
+ self,
357
+ channels: int,
358
+ use_conv: bool = True,
359
+ out_channels: Optional[int] = None,
360
+ padding: int = 0,
361
+ name: str = "conv",
362
+ kernel_size=3,
363
+ bias=True,
364
+ ):
365
+ super().__init__()
366
+ self.channels = channels
367
+ self.out_channels = out_channels or channels
368
+ self.use_conv = use_conv
369
+ self.padding = padding
370
+ stride = (1, 2, 2)
371
+ self.name = name
372
+ conv_cls = nn.Conv3d
373
+
374
+ if use_conv:
375
+ conv = conv_cls(
376
+ self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
377
+ )
378
+ else:
379
+ assert self.channels == self.out_channels
380
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
381
+
382
+ self.conv = conv
383
+
384
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
385
+ assert hidden_states.shape[1] == self.channels
386
+
387
+ if self.use_conv and self.padding == 0:
388
+ pad = (0, 1, 0, 1, 1, 1)
389
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
390
+
391
+ assert hidden_states.shape[1] == self.channels
392
+
393
+ hidden_states = self.conv(hidden_states)
394
+
395
+ return hidden_states
396
+
397
+
398
+ class TemporalDownsample2x(nn.Module):
399
+ """A Temporal downsampling layer with an optional convolution.
400
+
401
+ Parameters:
402
+ channels (`int`):
403
+ number of channels in the inputs and outputs.
404
+ use_conv (`bool`, default `False`):
405
+ option to use a convolution.
406
+ out_channels (`int`, optional):
407
+ number of output channels. Defaults to `channels`.
408
+ padding (`int`, default `1`):
409
+ padding for the convolution.
410
+ name (`str`, default `conv`):
411
+ name of the downsampling 2D layer.
412
+ """
413
+
414
+ def __init__(
415
+ self,
416
+ channels: int,
417
+ use_conv: bool = False,
418
+ out_channels: Optional[int] = None,
419
+ padding: int = 0,
420
+ kernel_size=3,
421
+ bias=True,
422
+ ):
423
+ super().__init__()
424
+ self.channels = channels
425
+ self.out_channels = out_channels or channels
426
+ self.use_conv = use_conv
427
+ self.padding = padding
428
+ stride = (2, 1, 1)
429
+
430
+ conv_cls = nn.Conv3d
431
+
432
+ if use_conv:
433
+ conv = conv_cls(
434
+ self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
435
+ )
436
+ else:
437
+ raise NotImplementedError("Not implemented for temporal downsample without")
438
+
439
+ self.conv = conv
440
+
441
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
442
+ assert hidden_states.shape[1] == self.channels
443
+
444
+ if self.use_conv and self.padding == 0:
445
+ if hidden_states.shape[2] == 1:
446
+ # image
447
+ pad = (1, 1, 1, 1, 1, 1)
448
+ else:
449
+ # video
450
+ pad = (1, 1, 1, 1, 0, 1)
451
+
452
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
453
+
454
+ hidden_states = self.conv(hidden_states)
455
+ return hidden_states
456
+
457
+
458
+ class CausalTemporalDownsample2x(nn.Module):
459
+ """A Temporal downsampling layer with an optional convolution.
460
+
461
+ Parameters:
462
+ channels (`int`):
463
+ number of channels in the inputs and outputs.
464
+ use_conv (`bool`, default `False`):
465
+ option to use a convolution.
466
+ out_channels (`int`, optional):
467
+ number of output channels. Defaults to `channels`.
468
+ padding (`int`, default `1`):
469
+ padding for the convolution.
470
+ name (`str`, default `conv`):
471
+ name of the downsampling 2D layer.
472
+ """
473
+
474
+ def __init__(
475
+ self,
476
+ channels: int,
477
+ use_conv: bool = False,
478
+ out_channels: Optional[int] = None,
479
+ kernel_size=3,
480
+ bias=True,
481
+ ):
482
+ super().__init__()
483
+ self.channels = channels
484
+ self.out_channels = out_channels or channels
485
+ self.use_conv = use_conv
486
+ stride = (2, 1, 1)
487
+
488
+ conv_cls = nn.Conv3d
489
+
490
+ if use_conv:
491
+ conv = CausalConv3d(
492
+ self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
493
+ )
494
+ else:
495
+ raise NotImplementedError("Not implemented for temporal downsample without")
496
+
497
+ self.conv = conv
498
+
499
+ def forward(self, hidden_states: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
500
+ assert hidden_states.shape[1] == self.channels
501
+ hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
502
+ return hidden_states
503
+
504
+
505
+ class Upsample2D(nn.Module):
506
+ """A 2D upsampling layer with an optional convolution.
507
+
508
+ Parameters:
509
+ channels (`int`):
510
+ number of channels in the inputs and outputs.
511
+ use_conv (`bool`, default `False`):
512
+ option to use a convolution.
513
+ out_channels (`int`, optional):
514
+ number of output channels. Defaults to `channels`.
515
+ name (`str`, default `conv`):
516
+ name of the upsampling 2D layer.
517
+ """
518
+
519
+ def __init__(
520
+ self,
521
+ channels: int,
522
+ use_conv: bool = False,
523
+ out_channels: Optional[int] = None,
524
+ name: str = "conv",
525
+ kernel_size: Optional[int] = None,
526
+ padding=1,
527
+ bias=True,
528
+ interpolate=False,
529
+ ):
530
+ super().__init__()
531
+ self.channels = channels
532
+ self.out_channels = out_channels or channels
533
+ self.use_conv = use_conv
534
+ self.name = name
535
+ self.interpolate = interpolate
536
+ conv_cls = nn.Conv3d
537
+ conv = None
538
+
539
+ if interpolate:
540
+ raise NotImplementedError("Not implemented for spatial upsample with interpolate")
541
+ else:
542
+ if kernel_size is None:
543
+ kernel_size = 3
544
+ conv = conv_cls(self.channels, self.out_channels * 4, kernel_size=kernel_size, padding=padding, bias=bias)
545
+
546
+ self.conv = conv
547
+ self.conv.apply(self._init_weights)
548
+
549
+ def _init_weights(self, m):
550
+ if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
551
+ trunc_normal_(m.weight, std=.02)
552
+ if m.bias is not None:
553
+ nn.init.constant_(m.bias, 0)
554
+ elif isinstance(m, nn.LayerNorm):
555
+ nn.init.constant_(m.bias, 0)
556
+ nn.init.constant_(m.weight, 1.0)
557
+
558
+ def forward(
559
+ self,
560
+ hidden_states: torch.FloatTensor,
561
+ ) -> torch.FloatTensor:
562
+ assert hidden_states.shape[1] == self.channels
563
+
564
+ hidden_states = self.conv(hidden_states)
565
+ hidden_states = rearrange(hidden_states, 'b (c p1 p2) t h w -> b c t (h p1) (w p2)', p1=2, p2=2)
566
+
567
+ return hidden_states
568
+
569
+
570
+ class CausalUpsample2x(nn.Module):
571
+ """A 2D upsampling layer with an optional convolution.
572
+
573
+ Parameters:
574
+ channels (`int`):
575
+ number of channels in the inputs and outputs.
576
+ use_conv (`bool`, default `False`):
577
+ option to use a convolution.
578
+ out_channels (`int`, optional):
579
+ number of output channels. Defaults to `channels`.
580
+ name (`str`, default `conv`):
581
+ name of the upsampling 2D layer.
582
+ """
583
+
584
+ def __init__(
585
+ self,
586
+ channels: int,
587
+ use_conv: bool = False,
588
+ out_channels: Optional[int] = None,
589
+ name: str = "conv",
590
+ kernel_size: Optional[int] = 3,
591
+ bias=True,
592
+ interpolate=False,
593
+ ):
594
+ super().__init__()
595
+ self.channels = channels
596
+ self.out_channels = out_channels or channels
597
+ self.use_conv = use_conv
598
+ self.name = name
599
+ self.interpolate = interpolate
600
+ conv = None
601
+
602
+ if interpolate:
603
+ raise NotImplementedError("Not implemented for spatial upsample with interpolate")
604
+ else:
605
+ conv = CausalConv3d(self.channels, self.out_channels * 4, kernel_size=kernel_size, stride=1, bias=bias)
606
+
607
+ self.conv = conv
608
+
609
+ def forward(
610
+ self,
611
+ hidden_states: torch.FloatTensor,
612
+ is_init_image=True, temporal_chunk=False,
613
+ ) -> torch.FloatTensor:
614
+ assert hidden_states.shape[1] == self.channels
615
+ hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
616
+ hidden_states = rearrange(hidden_states, 'b (c p1 p2) t h w -> b c t (h p1) (w p2)', p1=2, p2=2)
617
+ return hidden_states
618
+
619
+
620
+ class TemporalUpsample2x(nn.Module):
621
+ """A 2D upsampling layer with an optional convolution.
622
+
623
+ Parameters:
624
+ channels (`int`):
625
+ number of channels in the inputs and outputs.
626
+ use_conv (`bool`, default `False`):
627
+ option to use a convolution.
628
+ out_channels (`int`, optional):
629
+ number of output channels. Defaults to `channels`.
630
+ name (`str`, default `conv`):
631
+ name of the upsampling 2D layer.
632
+ """
633
+
634
+ def __init__(
635
+ self,
636
+ channels: int,
637
+ use_conv: bool = True,
638
+ out_channels: Optional[int] = None,
639
+ kernel_size: Optional[int] = None,
640
+ padding=1,
641
+ bias=True,
642
+ interpolate=False,
643
+ ):
644
+ super().__init__()
645
+ self.channels = channels
646
+ self.out_channels = out_channels or channels
647
+ self.use_conv = use_conv
648
+ self.interpolate = interpolate
649
+ conv_cls = nn.Conv3d
650
+
651
+ conv = None
652
+ if interpolate:
653
+ raise NotImplementedError("Not implemented for spatial upsample with interpolate")
654
+ else:
655
+ # depth to space operator
656
+ if kernel_size is None:
657
+ kernel_size = 3
658
+ conv = conv_cls(self.channels, self.out_channels * 2, kernel_size=kernel_size, padding=padding, bias=bias)
659
+
660
+ self.conv = conv
661
+
662
+ def forward(
663
+ self,
664
+ hidden_states: torch.FloatTensor,
665
+ is_image: bool = False,
666
+ ) -> torch.FloatTensor:
667
+ assert hidden_states.shape[1] == self.channels
668
+ t = hidden_states.shape[2]
669
+ hidden_states = self.conv(hidden_states)
670
+ hidden_states = rearrange(hidden_states, 'b (c p) t h w -> b c (p t) h w', p=2)
671
+
672
+ if t == 1 and is_image:
673
+ hidden_states = hidden_states[:, :, 1:]
674
+
675
+ return hidden_states
676
+
677
+
678
+ class CausalTemporalUpsample2x(nn.Module):
679
+ """A 2D upsampling layer with an optional convolution.
680
+
681
+ Parameters:
682
+ channels (`int`):
683
+ number of channels in the inputs and outputs.
684
+ use_conv (`bool`, default `False`):
685
+ option to use a convolution.
686
+ out_channels (`int`, optional):
687
+ number of output channels. Defaults to `channels`.
688
+ name (`str`, default `conv`):
689
+ name of the upsampling 2D layer.
690
+ """
691
+
692
+ def __init__(
693
+ self,
694
+ channels: int,
695
+ use_conv: bool = True,
696
+ out_channels: Optional[int] = None,
697
+ kernel_size: Optional[int] = 3,
698
+ bias=True,
699
+ interpolate=False,
700
+ ):
701
+ super().__init__()
702
+ self.channels = channels
703
+ self.out_channels = out_channels or channels
704
+ self.use_conv = use_conv
705
+ self.interpolate = interpolate
706
+
707
+ conv = None
708
+ if interpolate:
709
+ raise NotImplementedError("Not implemented for spatial upsample with interpolate")
710
+ else:
711
+ # depth to space operator
712
+ conv = CausalConv3d(self.channels, self.out_channels * 2, kernel_size=kernel_size, stride=1, bias=bias)
713
+
714
+ self.conv = conv
715
+
716
+ def forward(
717
+ self,
718
+ hidden_states: torch.FloatTensor,
719
+ is_init_image=True, temporal_chunk=False,
720
+ ) -> torch.FloatTensor:
721
+ assert hidden_states.shape[1] == self.channels
722
+ t = hidden_states.shape[2]
723
+ hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
724
+ hidden_states = rearrange(hidden_states, 'b (c p) t h w -> b c (t p) h w', p=2)
725
+
726
+ if is_init_image:
727
+ hidden_states = hidden_states[:, :, 1:]
728
+
729
+ return hidden_states