fffiloni committed
Commit 1a9b87d · verified · 1 Parent(s): 25fd8c3

Migrated from GitHub
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/demo.gif filter=lfs diff=lfs merge=lfs -text
+ assets/examples/speech.wav filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
ORIGINAL_README.md ADDED
@@ -0,0 +1,73 @@
# MEMO

**MEMO: Memory-Guided Diffusion for Expressive Talking Video Generation**
<br>
[Longtao Zheng](https://ltzheng.github.io)\*,
[Yifan Zhang](https://scholar.google.com/citations?user=zuYIUJEAAAAJ)\*,
[Hanzhong Guo](https://scholar.google.com/citations?user=q3x6KsgAAAAJ),
[Jiachun Pan](https://scholar.google.com/citations?user=nrOvfb4AAAAJ),
[Zhenxiong Tan](https://scholar.google.com/citations?user=HP9Be6UAAAAJ),
[Jiahao Lu](https://scholar.google.com/citations?user=h7rbA-sAAAAJ),
[Chuanxin Tang](https://scholar.google.com/citations?user=3ZC8B7MAAAAJ),
[Bo An](https://personal.ntu.edu.sg/boan/index.html),
[Shuicheng Yan](https://scholar.google.com/citations?user=DNuiPHwAAAAJ)
<br>
_[Project Page](https://memoavatar.github.io) | [arXiv](https://arxiv.org/abs/2412.04448) | [Model](https://huggingface.co/memoavatar/memo)_

This repository contains the example inference script for the MEMO-preview model. The GIF demo below is compressed; see our [project page](https://memoavatar.github.io) for full videos.

<div style="width: 100%; text-align: center;">
  <img src="assets/demo.gif" alt="Demo GIF" style="width: 100%; height: auto;">
</div>

## Installation

```bash
conda create -n memo python=3.10 -y
conda activate memo
conda install -c conda-forge ffmpeg -y
pip install -e .
```

> Our code downloads the MEMO checkpoint from Hugging Face automatically, and the models for face analysis and vocal separation are downloaded to the `misc_model_dir` set in `configs/inference.yaml`. If you want to download the models manually, download the checkpoint from [here](https://huggingface.co/memoavatar/memo) and specify its path in `model_name_or_path` of `configs/inference.yaml`.
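
For the manual route, the two relevant keys live in `configs/inference.yaml` (the full file appears later in this commit). A minimal sketch, assuming the downloaded files were placed in a local `checkpoints/` directory (the path is illustrative, not prescribed):

```yaml
# configs/inference.yaml (excerpt) -- illustrative local paths, adjust to your setup
model_name_or_path: checkpoints   # local copy of https://huggingface.co/memoavatar/memo
misc_model_dir: checkpoints       # face-analysis and vocal-separation models live under <misc_model_dir>/misc/
```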

## Inference

```bash
python inference.py --config configs/inference.yaml --input_image <IMAGE_PATH> --input_audio <AUDIO_PATH> --output_dir <SAVE_PATH>
```

For example:

```bash
python inference.py --config configs/inference.yaml --input_image assets/examples/dicaprio.jpg --input_audio assets/examples/speech.wav --output_dir outputs
```

> We tested the code on H100 and RTX 4090 GPUs with CUDA 12. Under the default settings (fps=30, inference_steps=20), inference takes around 1 second per frame on an H100 and around 2 seconds per frame on an RTX 4090. We welcome community contributions that improve inference speed or add interfaces such as ComfyUI.
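
The settings referenced above are plain keys in `configs/inference.yaml`. Below is a minimal sketch of the runtime-related knobs with the defaults shipped in this commit; how far they can be lowered before quality visibly degrades has not been benchmarked here:

```yaml
# configs/inference.yaml (excerpt) -- runtime-related settings (defaults from this commit)
fps: 30                            # output video frame rate
inference_steps: 20                # denoising steps per clip; fewer steps run faster, possibly at lower fidelity
num_generated_frames_per_clip: 16  # frames produced per diffusion pass
weight_dtype: bf16                 # bf16/fp16 generally reduce memory use and runtime vs. fp32
```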

## Acknowledgement

Our work is made possible by high-quality open-source talking video datasets (including [HDTF](https://github.com/MRzzm/HDTF), [VFHQ](https://liangbinxie.github.io/projects/vfhq), [CelebV-HQ](https://celebv-hq.github.io), [MultiTalk](https://multi-talk.github.io), and [MEAD](https://wywu.github.io/projects/MEAD/MEAD.html)) and by pioneering works such as [EMO](https://humanaigc.github.io/emote-portrait-alive) and [Hallo](https://github.com/fudan-generative-vision/hallo).

## Ethics Statement

We acknowledge the potential of AI in generating talking videos, with applications spanning education, virtual assistants, and entertainment. However, we are equally aware of the ethical, legal, and societal challenges that misuse of this technology could pose.

To reduce potential risks, we have only open-sourced a preview model for research purposes. Demos on our website use publicly available materials. If you have copyright concerns, please contact us and we will address them promptly. Users are required to ensure that their actions align with legal regulations, cultural norms, and ethical standards.

It is strictly prohibited to use the model for creating malicious, misleading, defamatory, or privacy-infringing content, such as deepfake videos for political misinformation, impersonation, harassment, or fraud. We strongly encourage users to review generated content carefully, ensuring it meets ethical guidelines and respects the rights of all parties involved. Users must also ensure that their inputs (e.g., audio and reference images) and outputs are used with proper authorization. Unauthorized use of third-party intellectual property is strictly forbidden.

While users may claim ownership of content generated by the model, they must ensure compliance with copyright laws, particularly when involving public figures' likeness, voice, or other aspects protected under personality rights.

## Citation

If you find our work useful, please use the following citation:

```bibtex
@article{zheng2024memo,
  title={MEMO: Memory-Guided Diffusion for Expressive Talking Video Generation},
  author={Longtao Zheng and Yifan Zhang and Hanzhong Guo and Jiachun Pan and Zhenxiong Tan and Jiahao Lu and Chuanxin Tang and Bo An and Shuicheng Yan},
  journal={arXiv preprint arXiv:2412.04448},
  year={2024}
}
```
assets/demo.gif ADDED

Git LFS Details

  • SHA256: 29524bac983cde4772839769aae1c175cae886d9969488bcd1a19724f62d6b47
  • Pointer size: 132 Bytes
  • Size of remote file: 4.12 MB
assets/examples/dicaprio.jpg ADDED
assets/examples/speech.wav ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:16e82c22a2e7104861943b994c40a537271653cb7d0b1b722dda2cda8ab75a7c
size 2646078
configs/inference.yaml ADDED
@@ -0,0 +1,16 @@
resolution: 512
num_generated_frames_per_clip: 16
fps: 30
num_init_past_frames: 2
num_past_frames: 16
inference_steps: 20
cfg_scale: 3.5
weight_dtype: bf16
enable_xformers_memory_efficient_attention: true

model_name_or_path: memoavatar/memo
# model_name_or_path: checkpoints
vae: stabilityai/sd-vae-ft-mse
wav2vec: facebook/wav2vec2-base-960h
emotion2vec: iic/emotion2vec_plus_large
misc_model_dir: checkpoints
inference.py ADDED
@@ -0,0 +1,251 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+
5
+ import torch
6
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
7
+ from diffusers.utils.import_utils import is_xformers_available
8
+ from omegaconf import OmegaConf
9
+ from packaging import version
10
+ from tqdm import tqdm
11
+
12
+ from memo.models.audio_proj import AudioProjModel
13
+ from memo.models.image_proj import ImageProjModel
14
+ from memo.models.unet_2d_condition import UNet2DConditionModel
15
+ from memo.models.unet_3d import UNet3DConditionModel
16
+ from memo.pipelines.video_pipeline import VideoPipeline
17
+ from memo.utils.audio_utils import extract_audio_emotion_labels, preprocess_audio, resample_audio
18
+ from memo.utils.vision_utils import preprocess_image, tensor_to_video
19
+
20
+
21
+ logger = logging.getLogger("memo")
22
+ logger.setLevel(logging.INFO)
23
+
24
+
25
+ def parse_args():
26
+ parser = argparse.ArgumentParser(description="Inference script for MEMO")
27
+
28
+ parser.add_argument("--config", type=str, default="configs/inference.yaml")
29
+ parser.add_argument("--input_image", type=str)
30
+ parser.add_argument("--input_audio", type=str)
31
+ parser.add_argument("--output_dir", type=str)
32
+ parser.add_argument("--seed", type=int, default=42)
33
+
34
+ return parser.parse_args()
35
+
36
+
37
+ def main():
38
+ # Parse arguments
39
+ args = parse_args()
40
+ input_image_path = args.input_image
41
+ input_audio_path = args.input_audio
42
+ if "wav" not in input_audio_path:
43
+ logger.warning("MEMO might not generate a full-length video for non-wav audio files.")
44
+ output_dir = args.output_dir
45
+ os.makedirs(output_dir, exist_ok=True)
46
+ output_video_path = os.path.join(
47
+ output_dir,
48
+ f"{os.path.basename(input_image_path).split('.')[0]}_{os.path.basename(input_audio_path).split('.')[0]}.mp4",
49
+ )
50
+
51
+ if os.path.exists(output_video_path):
52
+ logger.info(f"Output file {output_video_path} already exists. Skipping inference.")
53
+ return
54
+
55
+ generator = torch.manual_seed(args.seed)
56
+
57
+ logger.info(f"Loading config from {args.config}")
58
+ config = OmegaConf.load(args.config)
59
+
60
+ # Determine model paths
61
+ if config.model_name_or_path == "memoavatar/memo":
62
+ logger.info(
63
+ f"The MEMO model will be downloaded from Hugging Face to the default cache directory. The models for face analysis and vocal separation will be downloaded to {config.misc_model_dir}."
64
+ )
65
+
66
+ face_analysis = os.path.join(config.misc_model_dir, "misc/face_analysis")
67
+ os.makedirs(face_analysis, exist_ok=True)
68
+ for model in [
69
+ "1k3d68.onnx",
70
+ "2d106det.onnx",
71
+ "face_landmarker_v2_with_blendshapes.task",
72
+ "genderage.onnx",
73
+ "glintr100.onnx",
74
+ "scrfd_10g_bnkps.onnx",
75
+ ]:
76
+ if not os.path.exists(os.path.join(face_analysis, model)):
77
+ logger.info(f"Downloading {model} to {face_analysis}")
78
+ os.system(
79
+ f"wget -P {face_analysis} https://huggingface.co/memoavatar/memo/raw/main/misc/face_analysis/models/{model}"
80
+ )
81
+ logger.info(f"Use face analysis models from {face_analysis}")
82
+
83
+ vocal_separator = os.path.join(config.misc_model_dir, "misc/vocal_separator/Kim_Vocal_2.onnx")
84
+ if os.path.exists(vocal_separator):
85
+ logger.info(f"Vocal separator {vocal_separator} already exists. Skipping download.")
86
+ else:
87
+ logger.info(f"Downloading vocal separator to {vocal_separator}")
88
+ os.makedirs(os.path.dirname(vocal_separator), exist_ok=True)
89
+ os.system(
90
+ f"wget -P {os.path.dirname(vocal_separator)} https://huggingface.co/memoavatar/memo/raw/main/misc/vocal_separator/Kim_Vocal_2.onnx"
91
+ )
92
+ else:
93
+ logger.info(f"Loading manually specified model path: {config.model_name_or_path}")
94
+ face_analysis = os.path.join(config.model_name_or_path, "misc/face_analysis")
95
+ vocal_separator = os.path.join(config.model_name_or_path, "misc/vocal_separator/Kim_Vocal_2.onnx")
96
+
97
+ # Set up device and weight dtype
98
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
99
+ if config.weight_dtype == "fp16":
100
+ weight_dtype = torch.float16
101
+ elif config.weight_dtype == "bf16":
102
+ weight_dtype = torch.bfloat16
103
+ elif config.weight_dtype == "fp32":
104
+ weight_dtype = torch.float32
105
+ else:
106
+ weight_dtype = torch.float32
107
+ logger.info(f"Inference dtype: {weight_dtype}")
108
+
109
+ logger.info(f"Processing image {input_image_path}")
110
+ img_size = (config.resolution, config.resolution)
111
+ pixel_values, face_emb = preprocess_image(
112
+ face_analysis_model=face_analysis,
113
+ image_path=input_image_path,
114
+ image_size=config.resolution,
115
+ )
116
+
117
+ logger.info(f"Processing audio {input_audio_path}")
118
+ cache_dir = os.path.join(output_dir, "audio_preprocess")
119
+ os.makedirs(cache_dir, exist_ok=True)
120
+ input_audio_path = resample_audio(
121
+ input_audio_path,
122
+ os.path.join(cache_dir, f"{os.path.basename(input_audio_path).split('.')[0]}-16k.wav"),
123
+ )
124
+ audio_emb, audio_length = preprocess_audio(
125
+ wav_path=input_audio_path,
126
+ num_generated_frames_per_clip=config.num_generated_frames_per_clip,
127
+ fps=config.fps,
128
+ wav2vec_model=config.wav2vec,
129
+ vocal_separator_model=vocal_separator,
130
+ cache_dir=cache_dir,
131
+ device=device,
132
+ )
133
+
134
+ logger.info("Processing audio emotion")
135
+ audio_emotion, num_emotion_classes = extract_audio_emotion_labels(
136
+ model=config.model_name_or_path,
137
+ wav_path=input_audio_path,
138
+ emotion2vec_model=config.emotion2vec,
139
+ audio_length=audio_length,
140
+ device=device,
141
+ )
142
+
143
+ logger.info("Loading models")
144
+ vae = AutoencoderKL.from_pretrained(config.vae).to(device=device, dtype=weight_dtype)
145
+ reference_net = UNet2DConditionModel.from_pretrained(
146
+ config.model_name_or_path, subfolder="reference_net", use_safetensors=True
147
+ )
148
+ diffusion_net = UNet3DConditionModel.from_pretrained(
149
+ config.model_name_or_path, subfolder="diffusion_net", use_safetensors=True
150
+ )
151
+ image_proj = ImageProjModel.from_pretrained(
152
+ config.model_name_or_path, subfolder="image_proj", use_safetensors=True
153
+ )
154
+ audio_proj = AudioProjModel.from_pretrained(
155
+ config.model_name_or_path, subfolder="audio_proj", use_safetensors=True
156
+ )
157
+
158
+ vae.requires_grad_(False).eval()
159
+ reference_net.requires_grad_(False).eval()
160
+ diffusion_net.requires_grad_(False).eval()
161
+ image_proj.requires_grad_(False).eval()
162
+ audio_proj.requires_grad_(False).eval()
163
+
164
+ # Enable memory-efficient attention for xFormers
165
+ if config.enable_xformers_memory_efficient_attention:
166
+ if is_xformers_available():
167
+ import xformers
168
+
169
+ xformers_version = version.parse(xformers.__version__)
170
+ if xformers_version == version.parse("0.0.16"):
171
+ logger.info(
172
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
173
+ )
174
+ reference_net.enable_xformers_memory_efficient_attention()
175
+ diffusion_net.enable_xformers_memory_efficient_attention()
176
+ else:
177
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
178
+
179
+ # Create inference pipeline
180
+ noise_scheduler = FlowMatchEulerDiscreteScheduler()
181
+ pipeline = VideoPipeline(
182
+ vae=vae,
183
+ reference_net=reference_net,
184
+ diffusion_net=diffusion_net,
185
+ scheduler=noise_scheduler,
186
+ image_proj=image_proj,
187
+ )
188
+ pipeline.to(device=device, dtype=weight_dtype)
189
+
190
+ video_frames = []
191
+ num_clips = audio_emb.shape[0] // config.num_generated_frames_per_clip
192
+ for t in tqdm(range(num_clips), desc="Generating video clips"):
193
+ if len(video_frames) == 0:
194
+ # Initialize the first past frames with reference image
195
+ past_frames = pixel_values.repeat(config.num_init_past_frames, 1, 1, 1)
196
+ past_frames = past_frames.to(dtype=pixel_values.dtype, device=pixel_values.device)
197
+ pixel_values_ref_img = torch.cat([pixel_values, past_frames], dim=0)
198
+ else:
199
+ past_frames = video_frames[-1][0]
200
+ past_frames = past_frames.permute(1, 0, 2, 3)
201
+ past_frames = past_frames[0 - config.num_past_frames :]
202
+ past_frames = past_frames * 2.0 - 1.0
203
+ past_frames = past_frames.to(dtype=pixel_values.dtype, device=pixel_values.device)
204
+ pixel_values_ref_img = torch.cat([pixel_values, past_frames], dim=0)
205
+
206
+ pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
207
+
208
+ audio_tensor = (
209
+ audio_emb[
210
+ t
211
+ * config.num_generated_frames_per_clip : min(
212
+ (t + 1) * config.num_generated_frames_per_clip, audio_emb.shape[0]
213
+ )
214
+ ]
215
+ .unsqueeze(0)
216
+ .to(device=audio_proj.device, dtype=audio_proj.dtype)
217
+ )
218
+ audio_tensor = audio_proj(audio_tensor)
219
+
220
+ audio_emotion_tensor = audio_emotion[
221
+ t
222
+ * config.num_generated_frames_per_clip : min(
223
+ (t + 1) * config.num_generated_frames_per_clip, audio_emb.shape[0]
224
+ )
225
+ ]
226
+
227
+ pipeline_output = pipeline(
228
+ ref_image=pixel_values_ref_img,
229
+ audio_tensor=audio_tensor,
230
+ audio_emotion=audio_emotion_tensor,
231
+ emotion_class_num=num_emotion_classes,
232
+ face_emb=face_emb,
233
+ width=img_size[0],
234
+ height=img_size[1],
235
+ video_length=config.num_generated_frames_per_clip,
236
+ num_inference_steps=config.inference_steps,
237
+ guidance_scale=config.cfg_scale,
238
+ generator=generator,
239
+ )
240
+
241
+ video_frames.append(pipeline_output.videos)
242
+
243
+ video_frames = torch.cat(video_frames, dim=2)
244
+ video_frames = video_frames.squeeze(0)
245
+ video_frames = video_frames[:, :audio_length]
246
+
247
+ tensor_to_video(video_frames, output_video_path, input_audio_path, fps=config.fps)
248
+
249
+
250
+ if __name__ == "__main__":
251
+ main()
memo/__init__.py ADDED
File without changes
memo/models/__init__.py ADDED
File without changes
memo/models/attention.py ADDED
@@ -0,0 +1,639 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ from diffusers.models.attention import (
5
+ AdaLayerNorm,
6
+ AdaLayerNormZero,
7
+ Attention,
8
+ FeedForward,
9
+ )
10
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
11
+ from einops import rearrange
12
+ from torch import nn
13
+
14
+ from memo.models.attention_processor import Attention as CustomAttention
15
+ from memo.models.attention_processor import JointAttnProcessor2_0
16
+
17
+
18
+ class GatedSelfAttentionDense(nn.Module):
19
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
20
+ super().__init__()
21
+
22
+ self.linear = nn.Linear(context_dim, query_dim)
23
+
24
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
25
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
26
+
27
+ self.norm1 = nn.LayerNorm(query_dim)
28
+ self.norm2 = nn.LayerNorm(query_dim)
29
+
30
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
31
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
32
+
33
+ self.enabled = True
34
+
35
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
36
+ if not self.enabled:
37
+ return x
38
+
39
+ n_visual = x.shape[1]
40
+ objs = self.linear(objs)
41
+
42
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
43
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
44
+
45
+ return x
46
+
47
+
48
+ class BasicTransformerBlock(nn.Module):
49
+ def __init__(
50
+ self,
51
+ dim: int,
52
+ num_attention_heads: int,
53
+ attention_head_dim: int,
54
+ dropout=0.0,
55
+ cross_attention_dim: Optional[int] = None,
56
+ activation_fn: str = "geglu",
57
+ num_embeds_ada_norm: Optional[int] = None,
58
+ attention_bias: bool = False,
59
+ only_cross_attention: bool = False,
60
+ double_self_attention: bool = False,
61
+ upcast_attention: bool = False,
62
+ norm_elementwise_affine: bool = True,
63
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
64
+ norm_eps: float = 1e-5,
65
+ final_dropout: bool = False,
66
+ attention_type: str = "default",
67
+ positional_embeddings: Optional[str] = None,
68
+ num_positional_embeddings: Optional[int] = None,
69
+ is_final_block: bool = False,
70
+ ):
71
+ super().__init__()
72
+ self.only_cross_attention = only_cross_attention
73
+ self.is_final_block = is_final_block
74
+
75
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
76
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
77
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
78
+ self.use_layer_norm = norm_type == "layer_norm"
79
+
80
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
81
+ raise ValueError(
82
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
83
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
84
+ )
85
+
86
+ if positional_embeddings and (num_positional_embeddings is None):
87
+ raise ValueError(
88
+ "If `positional_embedding` type is defined, `num_positional_embeddings` must also be defined."
89
+ )
90
+
91
+ if positional_embeddings == "sinusoidal":
92
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
93
+ else:
94
+ self.pos_embed = None
95
+
96
+ # Define 3 blocks. Each block has its own normalization layer.
97
+ # 1. Self-Attn
98
+ if self.use_ada_layer_norm:
99
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
100
+ elif self.use_ada_layer_norm_zero:
101
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
102
+ else:
103
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
104
+
105
+ if not is_final_block:
106
+ self.attn1 = Attention(
107
+ query_dim=dim,
108
+ heads=num_attention_heads,
109
+ dim_head=attention_head_dim,
110
+ dropout=dropout,
111
+ bias=attention_bias,
112
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
113
+ upcast_attention=upcast_attention,
114
+ )
115
+
116
+ # 2. Cross-Attn
117
+ if cross_attention_dim is not None or double_self_attention:
118
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
119
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
120
+ # the second cross attention block.
121
+ self.norm2 = (
122
+ AdaLayerNorm(dim, num_embeds_ada_norm)
123
+ if self.use_ada_layer_norm
124
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
125
+ )
126
+ self.attn2 = Attention(
127
+ query_dim=dim,
128
+ cross_attention_dim=(cross_attention_dim if not double_self_attention else None),
129
+ heads=num_attention_heads,
130
+ dim_head=attention_head_dim,
131
+ dropout=dropout,
132
+ bias=attention_bias,
133
+ upcast_attention=upcast_attention,
134
+ )
135
+ else:
136
+ self.norm2 = None
137
+ self.attn2 = None
138
+
139
+ # 3. Feed-forward
140
+ if not self.use_ada_layer_norm_single:
141
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
142
+
143
+ self.ff = FeedForward(
144
+ dim,
145
+ dropout=dropout,
146
+ activation_fn=activation_fn,
147
+ final_dropout=final_dropout,
148
+ )
149
+
150
+ # 4. Fuser
151
+ if attention_type in {"gated", "gated-text-image"}: # Updated line
152
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
153
+
154
+ # 5. Scale-shift for PixArt-Alpha.
155
+ if self.use_ada_layer_norm_single:
156
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
157
+
158
+ # let chunk size default to None
159
+ self._chunk_size = None
160
+ self._chunk_dim = 0
161
+
162
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
163
+ self._chunk_size = chunk_size
164
+ self._chunk_dim = dim
165
+
166
+ def forward(
167
+ self,
168
+ hidden_states: torch.FloatTensor,
169
+ attention_mask: Optional[torch.FloatTensor] = None,
170
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
171
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
172
+ timestep: Optional[torch.LongTensor] = None,
173
+ cross_attention_kwargs: Dict[str, Any] = None,
174
+ class_labels: Optional[torch.LongTensor] = None,
175
+ ) -> torch.FloatTensor:
176
+ # Notice that normalization is always applied before the real computation in the following blocks.
177
+ # 0. Self-Attention
178
+ batch_size = hidden_states.shape[0]
179
+
180
+ gate_msa = None
181
+ scale_mlp = None
182
+ shift_mlp = None
183
+ gate_mlp = None
184
+ if self.use_ada_layer_norm:
185
+ norm_hidden_states = self.norm1(hidden_states, timestep)
186
+ elif self.use_ada_layer_norm_zero:
187
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
188
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
189
+ )
190
+ elif self.use_layer_norm:
191
+ norm_hidden_states = self.norm1(hidden_states)
192
+ elif self.use_ada_layer_norm_single:
193
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
194
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
195
+ ).chunk(6, dim=1)
196
+ norm_hidden_states = self.norm1(hidden_states)
197
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
198
+ norm_hidden_states = norm_hidden_states.squeeze(1)
199
+ else:
200
+ raise ValueError("Incorrect norm used")
201
+
202
+ if self.pos_embed is not None:
203
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
204
+
205
+ # 1. Retrieve lora scale.
206
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
207
+
208
+ # 2. Prepare GLIGEN inputs
209
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
210
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
211
+
212
+ ref_feature = norm_hidden_states
213
+ if self.is_final_block:
214
+ return None, ref_feature
215
+ attn_output = self.attn1(
216
+ norm_hidden_states,
217
+ encoder_hidden_states=(encoder_hidden_states if self.only_cross_attention else None),
218
+ attention_mask=attention_mask,
219
+ **cross_attention_kwargs,
220
+ )
221
+ if self.use_ada_layer_norm_zero:
222
+ attn_output = gate_msa.unsqueeze(1) * attn_output
223
+ elif self.use_ada_layer_norm_single:
224
+ attn_output = gate_msa * attn_output
225
+
226
+ hidden_states = attn_output + hidden_states
227
+ if hidden_states.ndim == 4:
228
+ hidden_states = hidden_states.squeeze(1)
229
+
230
+ # 2.5 GLIGEN Control
231
+ if gligen_kwargs is not None:
232
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
233
+
234
+ # 3. Cross-Attention
235
+ if self.attn2 is not None:
236
+ if self.use_ada_layer_norm:
237
+ norm_hidden_states = self.norm2(hidden_states, timestep)
238
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
239
+ norm_hidden_states = self.norm2(hidden_states)
240
+ elif self.use_ada_layer_norm_single:
241
+ # For PixArt norm2 isn't applied here:
242
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
243
+ norm_hidden_states = hidden_states
244
+ else:
245
+ raise ValueError("Incorrect norm")
246
+
247
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
248
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
249
+
250
+ attn_output = self.attn2(
251
+ norm_hidden_states,
252
+ encoder_hidden_states=encoder_hidden_states.repeat(
253
+ norm_hidden_states.shape[0] // encoder_hidden_states.shape[0], 1, 1
254
+ ),
255
+ attention_mask=encoder_attention_mask,
256
+ **cross_attention_kwargs,
257
+ )
258
+ hidden_states = attn_output + hidden_states
259
+
260
+ # 4. Feed-forward
261
+ if not self.use_ada_layer_norm_single:
262
+ norm_hidden_states = self.norm3(hidden_states)
263
+
264
+ if self.use_ada_layer_norm_zero:
265
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
266
+
267
+ if self.use_ada_layer_norm_single:
268
+ norm_hidden_states = self.norm2(hidden_states)
269
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
270
+
271
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
272
+
273
+ if self.use_ada_layer_norm_zero:
274
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
275
+ elif self.use_ada_layer_norm_single:
276
+ ff_output = gate_mlp * ff_output
277
+
278
+ hidden_states = ff_output + hidden_states
279
+ if hidden_states.ndim == 4:
280
+ hidden_states = hidden_states.squeeze(1)
281
+
282
+ return hidden_states, ref_feature
283
+
284
+
285
+ class TemporalBasicTransformerBlock(nn.Module):
286
+ def __init__(
287
+ self,
288
+ dim: int,
289
+ num_attention_heads: int,
290
+ attention_head_dim: int,
291
+ dropout=0.0,
292
+ cross_attention_dim: Optional[int] = None,
293
+ activation_fn: str = "geglu",
294
+ num_embeds_ada_norm: Optional[int] = None,
295
+ attention_bias: bool = False,
296
+ only_cross_attention: bool = False,
297
+ upcast_attention: bool = False,
298
+ unet_use_cross_frame_attention=None,
299
+ unet_use_temporal_attention=None,
300
+ ):
301
+ super().__init__()
302
+ self.only_cross_attention = only_cross_attention
303
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
304
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
305
+ self.unet_use_temporal_attention = unet_use_temporal_attention
306
+
307
+ self.attn1 = Attention(
308
+ query_dim=dim,
309
+ heads=num_attention_heads,
310
+ dim_head=attention_head_dim,
311
+ dropout=dropout,
312
+ bias=attention_bias,
313
+ upcast_attention=upcast_attention,
314
+ )
315
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
316
+
317
+ # Cross-Attn
318
+ if cross_attention_dim is not None:
319
+ self.attn2 = Attention(
320
+ query_dim=dim,
321
+ cross_attention_dim=cross_attention_dim,
322
+ heads=num_attention_heads,
323
+ dim_head=attention_head_dim,
324
+ dropout=dropout,
325
+ bias=attention_bias,
326
+ upcast_attention=upcast_attention,
327
+ )
328
+ else:
329
+ self.attn2 = None
330
+
331
+ if cross_attention_dim is not None:
332
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
333
+ else:
334
+ self.norm2 = None
335
+
336
+ # Feed-forward
337
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
338
+ self.norm3 = nn.LayerNorm(dim)
339
+ self.use_ada_layer_norm_zero = False
340
+
341
+ # Temp-Attn
342
+ if unet_use_temporal_attention is None:
343
+ unet_use_temporal_attention = False
344
+ if unet_use_temporal_attention:
345
+ self.attn_temp = Attention(
346
+ query_dim=dim,
347
+ heads=num_attention_heads,
348
+ dim_head=attention_head_dim,
349
+ dropout=dropout,
350
+ bias=attention_bias,
351
+ upcast_attention=upcast_attention,
352
+ )
353
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
354
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
355
+
356
+ def forward(
357
+ self,
358
+ hidden_states: torch.FloatTensor,
359
+ ref_img_feature: torch.FloatTensor,
360
+ attention_mask: Optional[torch.FloatTensor] = None,
361
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
362
+ timestep: Optional[torch.LongTensor] = None,
363
+ cross_attention_kwargs: Dict[str, Any] = None,
364
+ video_length=None,
365
+ uc_mask=None,
366
+ ):
367
+ norm_hidden_states = self.norm1(hidden_states)
368
+
369
+ # 1. Self-Attention
370
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
371
+ ref_img_feature = ref_img_feature.repeat(video_length, 1, 1)
372
+ modify_norm_hidden_states = torch.cat((norm_hidden_states, ref_img_feature), dim=1).to(
373
+ dtype=norm_hidden_states.dtype
374
+ )
375
+ hidden_states_uc = (
376
+ self.attn1(
377
+ norm_hidden_states,
378
+ encoder_hidden_states=modify_norm_hidden_states,
379
+ attention_mask=attention_mask,
380
+ )
381
+ + hidden_states
382
+ )
383
+ if uc_mask is not None:
384
+ hidden_states_c = hidden_states_uc.clone()
385
+ _uc_mask = uc_mask.clone()
386
+ if hidden_states.shape[0] != _uc_mask.shape[0]:
387
+ _uc_mask = (
388
+ torch.Tensor([1] * (hidden_states.shape[0] // 2) + [0] * (hidden_states.shape[0] // 2))
389
+ .to(hidden_states_uc.device)
390
+ .bool()
391
+ )
392
+ hidden_states_c[_uc_mask] = (
393
+ self.attn1(
394
+ norm_hidden_states[_uc_mask],
395
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
396
+ attention_mask=attention_mask,
397
+ )
398
+ + hidden_states[_uc_mask]
399
+ )
400
+ hidden_states = hidden_states_c.clone()
401
+ else:
402
+ hidden_states = hidden_states_uc
403
+
404
+ if self.attn2 is not None:
405
+ # Cross-Attention
406
+ norm_hidden_states = (
407
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
408
+ )
409
+ hidden_states = (
410
+ self.attn2(
411
+ norm_hidden_states,
412
+ encoder_hidden_states=encoder_hidden_states,
413
+ attention_mask=attention_mask,
414
+ )
415
+ + hidden_states
416
+ )
417
+
418
+ # Feed-forward
419
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
420
+
421
+ # Temporal-Attention
422
+ if self.unet_use_temporal_attention:
423
+ d = hidden_states.shape[1]
424
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
425
+ norm_hidden_states = (
426
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
427
+ )
428
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
429
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
430
+
431
+ return hidden_states
432
+
433
+
434
+ class LabelEmbedding(nn.Module):
435
+ def __init__(self, num_classes, hidden_size, dropout_prob):
436
+ super().__init__()
437
+ use_cfg_embedding = dropout_prob > 0
438
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
439
+ self.num_classes = num_classes
440
+ self.dropout_prob = dropout_prob
441
+
442
+ def token_drop(self, labels, force_drop_ids=None):
443
+ # Drops labels to enable classifier-free guidance.
444
+ if force_drop_ids is None:
445
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
446
+ else:
447
+ drop_ids = torch.tensor(force_drop_ids == 1)
448
+ labels = torch.where(drop_ids, self.num_classes, labels)
449
+
450
+ return labels
451
+
452
+ def forward(self, labels: torch.LongTensor, force_drop_ids=None):
453
+ use_dropout = self.dropout_prob > 0
454
+ if (self.training and use_dropout) or (force_drop_ids is not None):
455
+ labels = self.token_drop(labels, force_drop_ids)
456
+ embeddings = self.embedding_table(labels)
457
+
458
+ return embeddings
459
+
460
+
461
+ class EmoAdaLayerNorm(nn.Module):
462
+ def __init__(
463
+ self,
464
+ embedding_dim,
465
+ num_classes=9,
466
+ norm_elementwise_affine: bool = False,
467
+ norm_eps: float = 1e-5,
468
+ class_dropout_prob=0.3,
469
+ ):
470
+ super().__init__()
471
+ self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
472
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
473
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(embedding_dim, 2 * embedding_dim, bias=True))
474
+
475
+ def forward(self, x, emotion=None):
476
+ emo_embedding = self.class_embedder(emotion)
477
+ shift, scale = self.adaLN_modulation(emo_embedding).chunk(2, dim=1)
478
+ if emotion.shape[0] > 1:
479
+ repeat = x.shape[0] // emo_embedding.shape[0]
480
+ scale = scale.unsqueeze(1)
481
+ scale = torch.repeat_interleave(scale, repeats=repeat, dim=0)
482
+ shift = shift.unsqueeze(1)
483
+ shift = torch.repeat_interleave(shift, repeats=repeat, dim=0)
484
+ else:
485
+ scale = scale.unsqueeze(1)
486
+ shift = shift.unsqueeze(1)
487
+
488
+ x = self.norm(x) * (1 + scale) + shift
489
+
490
+ return x
491
+
492
+
493
+ class JointAudioTemporalBasicTransformerBlock(nn.Module):
494
+ def __init__(
495
+ self,
496
+ dim: int,
497
+ num_attention_heads: int,
498
+ attention_head_dim: int,
499
+ dropout=0.0,
500
+ cross_attention_dim: Optional[int] = None,
501
+ activation_fn: str = "geglu",
502
+ attention_bias: bool = False,
503
+ only_cross_attention: bool = False,
504
+ upcast_attention: bool = False,
505
+ unet_use_cross_frame_attention=None,
506
+ unet_use_temporal_attention=None,
507
+ depth=0,
508
+ unet_block_name=None,
509
+ use_ada_layer_norm=False,
510
+ emo_drop_rate=0.3,
511
+ is_final_block=False,
512
+ ):
513
+ super().__init__()
514
+ self.only_cross_attention = only_cross_attention
515
+ self.use_ada_layer_norm = use_ada_layer_norm
516
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
517
+ self.unet_use_temporal_attention = unet_use_temporal_attention
518
+ self.unet_block_name = unet_block_name
519
+ self.depth = depth
520
+ self.is_final_block = is_final_block
521
+
522
+ self.norm1 = (
523
+ EmoAdaLayerNorm(dim, num_classes=9, class_dropout_prob=emo_drop_rate)
524
+ if self.use_ada_layer_norm
525
+ else nn.LayerNorm(dim)
526
+ )
527
+ self.attn1 = CustomAttention(
528
+ query_dim=dim,
529
+ heads=num_attention_heads,
530
+ dim_head=attention_head_dim,
531
+ dropout=dropout,
532
+ bias=attention_bias,
533
+ upcast_attention=upcast_attention,
534
+ )
535
+
536
+ self.audio_norm1 = (
537
+ EmoAdaLayerNorm(cross_attention_dim, num_classes=9, class_dropout_prob=emo_drop_rate)
538
+ if self.use_ada_layer_norm
539
+ else nn.LayerNorm(cross_attention_dim)
540
+ )
541
+ self.audio_attn1 = CustomAttention(
542
+ query_dim=cross_attention_dim,
543
+ heads=num_attention_heads,
544
+ dim_head=attention_head_dim,
545
+ dropout=dropout,
546
+ bias=attention_bias,
547
+ upcast_attention=upcast_attention,
548
+ )
549
+
550
+ self.norm2 = (
551
+ EmoAdaLayerNorm(dim, num_classes=9, class_dropout_prob=emo_drop_rate)
552
+ if self.use_ada_layer_norm
553
+ else nn.LayerNorm(dim)
554
+ )
555
+ self.audio_norm2 = (
556
+ EmoAdaLayerNorm(cross_attention_dim, num_classes=9, class_dropout_prob=emo_drop_rate)
557
+ if self.use_ada_layer_norm
558
+ else nn.LayerNorm(cross_attention_dim)
559
+ )
560
+
561
+ # Joint Attention
562
+ self.attn2 = CustomAttention(
563
+ query_dim=dim,
564
+ heads=num_attention_heads,
565
+ dim_head=attention_head_dim,
566
+ cross_attention_dim=dim,
567
+ added_kv_proj_dim=cross_attention_dim,
568
+ dropout=dropout,
569
+ bias=attention_bias,
570
+ upcast_attention=upcast_attention,
571
+ only_cross_attention=False,
572
+ out_dim=dim,
573
+ context_out_dim=cross_attention_dim,
574
+ context_pre_only=False,
575
+ processor=JointAttnProcessor2_0(),
576
+ is_final_block=is_final_block,
577
+ )
578
+
579
+ # Feed-forward
580
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
581
+ self.norm3 = nn.LayerNorm(dim)
582
+ if not is_final_block:
583
+ self.audio_ff = FeedForward(cross_attention_dim, dropout=dropout, activation_fn=activation_fn)
584
+ self.audio_norm3 = nn.LayerNorm(cross_attention_dim)
585
+
586
+ def forward(
587
+ self,
588
+ hidden_states,
589
+ encoder_hidden_states=None,
590
+ attention_mask=None,
591
+ emotion=None,
592
+ ):
593
+ norm_hidden_states = (
594
+ self.norm1(hidden_states, emotion) if self.use_ada_layer_norm else self.norm1(hidden_states)
595
+ )
596
+ norm_encoder_hidden_states = (
597
+ self.audio_norm1(encoder_hidden_states, emotion)
598
+ if self.use_ada_layer_norm
599
+ else self.audio_norm1(encoder_hidden_states)
600
+ )
601
+
602
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
603
+
604
+ encoder_hidden_states = (
605
+ self.audio_attn1(norm_encoder_hidden_states, attention_mask=attention_mask) + encoder_hidden_states
606
+ )
607
+
608
+ norm_hidden_states = (
609
+ self.norm2(hidden_states, emotion) if self.use_ada_layer_norm else self.norm2(hidden_states)
610
+ )
611
+ norm_encoder_hidden_states = (
612
+ self.audio_norm2(encoder_hidden_states, emotion)
613
+ if self.use_ada_layer_norm
614
+ else self.audio_norm2(encoder_hidden_states)
615
+ )
616
+
617
+ joint_hidden_states, joint_encoder_hidden_states = self.attn2(
618
+ norm_hidden_states,
619
+ norm_encoder_hidden_states,
620
+ )
621
+
622
+ hidden_states = joint_hidden_states + hidden_states
623
+ if not self.is_final_block:
624
+ encoder_hidden_states = joint_encoder_hidden_states + encoder_hidden_states
625
+
626
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
627
+ if not self.is_final_block:
628
+ encoder_hidden_states = self.audio_ff(self.audio_norm3(encoder_hidden_states)) + encoder_hidden_states
629
+ else:
630
+ encoder_hidden_states = None
631
+
632
+ return hidden_states, encoder_hidden_states
633
+
634
+
635
+ def zero_module(module):
636
+ for p in module.parameters():
637
+ nn.init.zeros_(p)
638
+
639
+ return module
memo/models/attention_processor.py ADDED
@@ -0,0 +1,2299 @@
1
+ import inspect
2
+ import math
3
+ from typing import Callable, Optional, Union
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.utils import deprecate, logging
8
+ from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
9
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
10
+ from torch import nn
11
+
12
+
13
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
14
+
15
+ if is_torch_npu_available():
16
+ import torch_npu
17
+
18
+ if is_xformers_available():
19
+ import xformers
20
+ import xformers.ops
21
+ else:
22
+ xformers = None
23
+
24
+
25
+ @maybe_allow_in_graph
26
+ class Attention(nn.Module):
27
+ r"""
28
+ A cross attention layer.
29
+
30
+ Parameters:
31
+ query_dim (`int`):
32
+ The number of channels in the query.
33
+ cross_attention_dim (`int`, *optional*):
34
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
35
+ heads (`int`, *optional*, defaults to 8):
36
+ The number of heads to use for multi-head attention.
37
+ dim_head (`int`, *optional*, defaults to 64):
38
+ The number of channels in each head.
39
+ dropout (`float`, *optional*, defaults to 0.0):
40
+ The dropout probability to use.
41
+ bias (`bool`, *optional*, defaults to False):
42
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
43
+ upcast_attention (`bool`, *optional*, defaults to False):
44
+ Set to `True` to upcast the attention computation to `float32`.
45
+ upcast_softmax (`bool`, *optional*, defaults to False):
46
+ Set to `True` to upcast the softmax computation to `float32`.
47
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
48
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
49
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
50
+ The number of groups to use for the group norm in the cross attention.
51
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
52
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
53
+ norm_num_groups (`int`, *optional*, defaults to `None`):
54
+ The number of groups to use for the group norm in the attention.
55
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
56
+ The number of channels to use for the spatial normalization.
57
+ out_bias (`bool`, *optional*, defaults to `True`):
58
+ Set to `True` to use a bias in the output linear layer.
59
+ scale_qk (`bool`, *optional*, defaults to `True`):
60
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
61
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
62
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
63
+ `added_kv_proj_dim` is not `None`.
64
+ eps (`float`, *optional*, defaults to 1e-5):
65
+ An additional value added to the denominator in group normalization that is used for numerical stability.
66
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
67
+ A factor to rescale the output by dividing it with this value.
68
+ residual_connection (`bool`, *optional*, defaults to `False`):
69
+ Set to `True` to add the residual connection to the output.
70
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
71
+ Set to `True` if the attention block is loaded from a deprecated state dict.
72
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
73
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
74
+ `AttnProcessor` otherwise.
75
+ """
76
+
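As a quick orientation to the parameters documented above, a minimal sketch (toy shapes, assuming this file is importable as `memo.models.attention_processor`) of a self-attention and a cross-attention layer built from this class:

```python
import torch

from memo.models.attention_processor import Attention

# Self-attention: keys and values are projected from the query sequence itself.
self_attn = Attention(query_dim=320, heads=8, dim_head=40)
# Cross-attention: keys and values are projected from a separate context sequence.
cross_attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)

x = torch.randn(2, 64, 320)    # (batch, query tokens, query_dim)
ctx = torch.randn(2, 77, 768)  # (batch, context tokens, cross_attention_dim)

print(self_attn(x).shape)                              # torch.Size([2, 64, 320])
print(cross_attn(x, encoder_hidden_states=ctx).shape)  # torch.Size([2, 64, 320])
```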
77
+ def __init__(
78
+ self,
79
+ query_dim: int,
80
+ cross_attention_dim: Optional[int] = None,
81
+ heads: int = 8,
82
+ kv_heads: Optional[int] = None,
83
+ dim_head: int = 64,
84
+ dropout: float = 0.0,
85
+ bias: bool = False,
86
+ upcast_attention: bool = False,
87
+ upcast_softmax: bool = False,
88
+ cross_attention_norm: Optional[str] = None,
89
+ cross_attention_norm_num_groups: int = 32,
90
+ qk_norm: Optional[str] = None,
91
+ added_kv_proj_dim: Optional[int] = None,
92
+ added_proj_bias: Optional[bool] = True,
93
+ norm_num_groups: Optional[int] = None,
94
+ spatial_norm_dim: Optional[int] = None,
95
+ out_bias: bool = True,
96
+ scale_qk: bool = True,
97
+ only_cross_attention: bool = False,
98
+ eps: float = 1e-5,
99
+ rescale_output_factor: float = 1.0,
100
+ residual_connection: bool = False,
101
+ _from_deprecated_attn_block: bool = False,
102
+ processor: Optional["AttnProcessor"] = None,
103
+ out_dim: int = None,
104
+ context_out_dim: int = None,
105
+ context_pre_only=None,
106
+ is_final_block=False,
107
+ ):
108
+ super().__init__()
109
+
110
+ # To prevent circular import.
111
+ from memo.models.normalization import FP32LayerNorm
112
+
113
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
114
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
115
+ self.query_dim = query_dim
116
+ self.use_bias = bias
117
+ self.is_cross_attention = cross_attention_dim is not None
118
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
119
+ self.upcast_attention = upcast_attention
120
+ self.upcast_softmax = upcast_softmax
121
+ self.rescale_output_factor = rescale_output_factor
122
+ self.residual_connection = residual_connection
123
+ self.dropout = dropout
124
+ self.fused_projections = False
125
+ self.out_dim = out_dim if out_dim is not None else query_dim
126
+ self.context_out_dim = context_out_dim if context_out_dim is not None else self.out_dim
127
+ self.context_pre_only = context_pre_only
128
+ self.is_final_block = is_final_block
129
+
130
+ # we make use of this private variable to know whether this class is loaded
131
+ # with a deprecated state dict so that we can convert it on the fly
132
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
133
+
134
+ self.scale_qk = scale_qk
135
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
136
+
137
+ self.heads = out_dim // dim_head if out_dim is not None else heads
138
+ # for slice_size > 0 the attention score computation
139
+ # is split across the batch axis to save memory
140
+ # You can set slice_size with `set_attention_slice`
141
+ self.sliceable_head_dim = heads
142
+
143
+ self.added_kv_proj_dim = added_kv_proj_dim
144
+ self.only_cross_attention = only_cross_attention
145
+
146
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
147
+ raise ValueError(
148
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
149
+ )
150
+
151
+ if norm_num_groups is not None:
152
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
153
+ else:
154
+ self.group_norm = None
155
+
156
+ if spatial_norm_dim is not None:
157
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
158
+ else:
159
+ self.spatial_norm = None
160
+
161
+ if qk_norm is None:
162
+ self.norm_q = None
163
+ self.norm_k = None
164
+ elif qk_norm == "layer_norm":
165
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps)
166
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps)
167
+ elif qk_norm == "fp32_layer_norm":
168
+ self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
169
+ self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
170
+ elif qk_norm == "layer_norm_across_heads":
171
+ # Lumina applies qk norm across all heads
172
+ self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
173
+ self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
174
+ else:
175
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None, 'layer_norm', 'fp32_layer_norm' or 'layer_norm_across_heads'")
176
+
177
+ if cross_attention_norm is None:
178
+ self.norm_cross = None
179
+ elif cross_attention_norm == "layer_norm":
180
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
181
+ elif cross_attention_norm == "group_norm":
182
+ if self.added_kv_proj_dim is not None:
183
+ # The given `encoder_hidden_states` are initially of shape
184
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
185
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
186
+ # before the projection, so we need to use `added_kv_proj_dim` as
187
+ # the number of channels for the group norm.
188
+ norm_cross_num_channels = added_kv_proj_dim
189
+ else:
190
+ norm_cross_num_channels = self.cross_attention_dim
191
+
192
+ self.norm_cross = nn.GroupNorm(
193
+ num_channels=norm_cross_num_channels,
194
+ num_groups=cross_attention_norm_num_groups,
195
+ eps=1e-5,
196
+ affine=True,
197
+ )
198
+ else:
199
+ raise ValueError(
200
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
201
+ )
202
+
203
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
204
+
205
+ if not self.only_cross_attention:
206
+ # only relevant for the `AddedKVProcessor` classes
207
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
208
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
209
+ else:
210
+ self.to_k = None
211
+ self.to_v = None
212
+
213
+ if self.added_kv_proj_dim is not None:
214
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
215
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
216
+ if self.context_pre_only is not None:
217
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
218
+
219
+ self.to_out = nn.ModuleList([])
220
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
221
+ self.to_out.append(nn.Dropout(dropout))
222
+
223
+ if self.context_pre_only is not None and not self.context_pre_only and not is_final_block:
224
+ self.to_add_out = nn.Linear(self.inner_dim, self.context_out_dim, bias=out_bias)
225
+
226
+ if qk_norm is not None and added_kv_proj_dim is not None:
227
+ if qk_norm == "fp32_layer_norm":
228
+ self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
229
+ self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
230
+ else:
231
+ self.norm_added_q = None
232
+ self.norm_added_k = None
233
+
234
+ # set attention processor
235
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
236
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
237
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
238
+ if processor is None:
239
+ processor = (
240
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
241
+ )
242
+ self.set_processor(processor)
243
+
244
+ def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
245
+ r"""
246
+ Set whether to use npu flash attention from `torch_npu` or not.
247
+
248
+ """
249
+ if use_npu_flash_attention:
250
+ processor = AttnProcessorNPU()
251
+ else:
252
+ # set attention processor
253
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
254
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
255
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
256
+ processor = (
257
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
258
+ )
259
+ self.set_processor(processor)
260
+
261
+ def set_use_memory_efficient_attention_xformers(
262
+ self,
263
+ use_memory_efficient_attention_xformers: bool,
264
+ attention_op: Optional[Callable] = None,
265
+ ) -> None:
266
+ r"""
267
+ Set whether to use memory efficient attention from `xformers` or not.
268
+
269
+ Args:
270
+ use_memory_efficient_attention_xformers (`bool`):
271
+ Whether to use memory efficient attention from `xformers` or not.
272
+ attention_op (`Callable`, *optional*):
273
+ The attention operation to use. Defaults to `None` which uses the default attention operation from
274
+ `xformers`.
275
+ """
276
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
277
+ self.processor,
278
+ (
279
+ CustomDiffusionAttnProcessor,
280
+ CustomDiffusionXFormersAttnProcessor,
281
+ CustomDiffusionAttnProcessor2_0,
282
+ ),
283
+ )
284
+
285
+ is_joint_diffusion = hasattr(self, "processor") and isinstance(
286
+ self.processor,
287
+ (JointAttnProcessor2_0),
288
+ )
289
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
290
+ self.processor,
291
+ (
292
+ AttnAddedKVProcessor,
293
+ AttnAddedKVProcessor2_0,
294
+ SlicedAttnAddedKVProcessor,
295
+ XFormersAttnAddedKVProcessor,
296
+ ),
297
+ )
298
+
299
+ if use_memory_efficient_attention_xformers:
300
+ if is_added_kv_processor and is_custom_diffusion:
301
+ raise NotImplementedError(
302
+ f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}"
303
+ )
304
+ if not is_xformers_available():
305
+ raise ModuleNotFoundError(
306
+ (
307
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
308
+ " xformers"
309
+ ),
310
+ name="xformers",
311
+ )
312
+ elif not torch.cuda.is_available():
313
+ raise ValueError(
314
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
315
+ " only available for GPU "
316
+ )
317
+ else:
318
+ try:
319
+ # Make sure we can run the memory efficient attention
320
+ _ = xformers.ops.memory_efficient_attention(
321
+ torch.randn((1, 2, 40), device="cuda"),
322
+ torch.randn((1, 2, 40), device="cuda"),
323
+ torch.randn((1, 2, 40), device="cuda"),
324
+ )
325
+ except Exception as e:
326
+ raise e
327
+
328
+ if is_custom_diffusion:
329
+ processor = CustomDiffusionXFormersAttnProcessor(
330
+ train_kv=self.processor.train_kv,
331
+ train_q_out=self.processor.train_q_out,
332
+ hidden_size=self.processor.hidden_size,
333
+ cross_attention_dim=self.processor.cross_attention_dim,
334
+ attention_op=attention_op,
335
+ )
336
+ processor.load_state_dict(self.processor.state_dict())
337
+ if hasattr(self.processor, "to_k_custom_diffusion"):
338
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
339
+ elif is_added_kv_processor:
340
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
341
+ # which uses this type of cross attention ONLY because the attention mask of format
342
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
343
+ # throw warning
344
+ logger.info(
345
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
346
+ )
347
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
348
+ elif is_joint_diffusion:
349
+ processor = JointAttnProcessor2_0()
350
+ else:
351
+ processor = XFormersAttnProcessor(attention_op=attention_op)
352
+ else:
353
+ if is_custom_diffusion:
354
+ attn_processor_class = (
355
+ CustomDiffusionAttnProcessor2_0
356
+ if hasattr(F, "scaled_dot_product_attention")
357
+ else CustomDiffusionAttnProcessor
358
+ )
359
+ processor = attn_processor_class(
360
+ train_kv=self.processor.train_kv,
361
+ train_q_out=self.processor.train_q_out,
362
+ hidden_size=self.processor.hidden_size,
363
+ cross_attention_dim=self.processor.cross_attention_dim,
364
+ )
365
+ processor.load_state_dict(self.processor.state_dict())
366
+ if hasattr(self.processor, "to_k_custom_diffusion"):
367
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
368
+ else:
369
+ # set attention processor
370
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
371
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
372
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
373
+ processor = (
374
+ AttnProcessor2_0()
375
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
376
+ else AttnProcessor()
377
+ )
378
+
379
+ self.set_processor(processor)
380
+
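A hedged sketch of toggling the xFormers path described above (it assumes the `xformers` package is installed and a CUDA device is available; the call itself performs those checks and raises otherwise):

```python
import torch
from diffusers.utils.import_utils import is_xformers_available

from memo.models.attention_processor import Attention, XFormersAttnProcessor

attn = Attention(query_dim=320, heads=8, dim_head=40)

# Opting in swaps the processor for the xFormers-backed one.
if is_xformers_available() and torch.cuda.is_available():
    attn.set_use_memory_efficient_attention_xformers(True)
    assert isinstance(attn.processor, XFormersAttnProcessor)
```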
381
+ def set_attention_slice(self, slice_size: int) -> None:
382
+ r"""
383
+ Set the slice size for attention computation.
384
+
385
+ Args:
386
+ slice_size (`int`):
387
+ The slice size for attention computation.
388
+ """
389
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
390
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
391
+
392
+ if slice_size is not None and self.added_kv_proj_dim is not None:
393
+ processor = SlicedAttnAddedKVProcessor(slice_size)
394
+ elif slice_size is not None:
395
+ processor = SlicedAttnProcessor(slice_size)
396
+ elif self.added_kv_proj_dim is not None:
397
+ processor = AttnAddedKVProcessor()
398
+ else:
399
+ # set attention processor
400
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
401
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
402
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
403
+ processor = (
404
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
405
+ )
406
+
407
+ self.set_processor(processor)
408
+
409
+ def set_processor(self, processor: "AttnProcessor") -> None:
410
+ r"""
411
+ Set the attention processor to use.
412
+
413
+ Args:
414
+ processor (`AttnProcessor`):
415
+ The attention processor to use.
416
+ """
417
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
418
+ # pop `processor` from `self._modules`
419
+ if (
420
+ hasattr(self, "processor")
421
+ and isinstance(self.processor, torch.nn.Module)
422
+ and not isinstance(processor, torch.nn.Module)
423
+ ):
424
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
425
+ self._modules.pop("processor")
426
+ self.processor = processor
427
+
428
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
429
+ r"""
430
+ Get the attention processor in use.
431
+
432
+ Args:
433
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
434
+ Set to `True` to return the deprecated LoRA attention processor.
435
+
436
+ Returns:
437
+ "AttentionProcessor": The attention processor in use.
438
+ """
439
+ if not return_deprecated_lora:
440
+ return self.processor
441
+
442
+ def forward(
443
+ self,
444
+ hidden_states: torch.Tensor,
445
+ encoder_hidden_states: Optional[torch.Tensor] = None,
446
+ attention_mask: Optional[torch.Tensor] = None,
447
+ **cross_attention_kwargs,
448
+ ) -> torch.Tensor:
449
+ r"""
450
+ The forward method of the `Attention` class.
451
+
452
+ Args:
453
+ hidden_states (`torch.Tensor`):
454
+ The hidden states of the query.
455
+ encoder_hidden_states (`torch.Tensor`, *optional*):
456
+ The hidden states of the encoder.
457
+ attention_mask (`torch.Tensor`, *optional*):
458
+ The attention mask to use. If `None`, no mask is applied.
459
+ **cross_attention_kwargs:
460
+ Additional keyword arguments to pass along to the cross attention.
461
+
462
+ Returns:
463
+ `torch.Tensor`: The output of the attention layer.
464
+ """
465
+ # The `Attention` class can call different attention processors / attention functions
466
+ # here we simply pass along all tensors to the selected processor class
467
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
468
+
469
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
470
+ quiet_attn_parameters = {"ip_adapter_masks"}
471
+ unused_kwargs = [
472
+ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
473
+ ]
474
+ if len(unused_kwargs) > 0:
475
+ logger.warning(
476
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
477
+ )
478
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
479
+
480
+ return self.processor(
481
+ self,
482
+ hidden_states,
483
+ encoder_hidden_states=encoder_hidden_states,
484
+ attention_mask=attention_mask,
485
+ **cross_attention_kwargs,
486
+ )
487
+
488
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
489
+ r"""
490
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
491
+ is the number of heads initialized while constructing the `Attention` class.
492
+
493
+ Args:
494
+ tensor (`torch.Tensor`): The tensor to reshape.
495
+
496
+ Returns:
497
+ `torch.Tensor`: The reshaped tensor.
498
+ """
499
+ head_size = self.heads
500
+ batch_size, seq_len, dim = tensor.shape
501
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
502
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
503
+ return tensor
504
+
505
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
506
+ r"""
507
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]`. `heads` is
508
+ the number of heads initialized while constructing the `Attention` class.
509
+
510
+ Args:
511
+ tensor (`torch.Tensor`): The tensor to reshape.
512
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
513
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
514
+
515
+ Returns:
516
+ `torch.Tensor`: The reshaped tensor.
517
+ """
518
+ head_size = self.heads
519
+ if tensor.ndim == 3:
520
+ batch_size, seq_len, dim = tensor.shape
521
+ extra_dim = 1
522
+ else:
523
+ batch_size, extra_dim, seq_len, dim = tensor.shape
524
+ tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
525
+ tensor = tensor.permute(0, 2, 1, 3)
526
+
527
+ if out_dim == 3:
528
+ tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
529
+
530
+ return tensor
531
+
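To make the two reshape helpers above concrete, a small shape trace (arbitrary toy sizes) showing the round trip between the per-token layout and the per-head layout:

```python
import torch

from memo.models.attention_processor import Attention

attn = Attention(query_dim=64, heads=4, dim_head=16)

x = torch.randn(2, 10, 64)                   # (batch, seq_len, heads * head_dim)
per_head = attn.head_to_batch_dim(x)         # (batch * heads, seq_len, head_dim)
print(per_head.shape)                        # torch.Size([8, 10, 16])

restored = attn.batch_to_head_dim(per_head)  # back to (batch, seq_len, heads * head_dim)
print(restored.shape)                        # torch.Size([2, 10, 64])
assert torch.allclose(restored, x)
```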
532
+ def get_attention_scores(
533
+ self,
534
+ query: torch.Tensor,
535
+ key: torch.Tensor,
536
+ attention_mask: torch.Tensor = None,
537
+ ) -> torch.Tensor:
538
+ r"""
539
+ Compute the attention scores.
540
+
541
+ Args:
542
+ query (`torch.Tensor`): The query tensor.
543
+ key (`torch.Tensor`): The key tensor.
544
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
545
+
546
+ Returns:
547
+ `torch.Tensor`: The attention probabilities/scores.
548
+ """
549
+ dtype = query.dtype
550
+ if self.upcast_attention:
551
+ query = query.float()
552
+ key = key.float()
553
+
554
+ if attention_mask is None:
555
+ baddbmm_input = torch.empty(
556
+ query.shape[0],
557
+ query.shape[1],
558
+ key.shape[1],
559
+ dtype=query.dtype,
560
+ device=query.device,
561
+ )
562
+ beta = 0
563
+ else:
564
+ baddbmm_input = attention_mask
565
+ beta = 1
566
+
567
+ attention_scores = torch.baddbmm(
568
+ baddbmm_input,
569
+ query,
570
+ key.transpose(-1, -2),
571
+ beta=beta,
572
+ alpha=self.scale,
573
+ )
574
+ del baddbmm_input
575
+
576
+ if self.upcast_softmax:
577
+ attention_scores = attention_scores.float()
578
+
579
+ attention_probs = attention_scores.softmax(dim=-1)
580
+ del attention_scores
581
+
582
+ attention_probs = attention_probs.to(dtype)
583
+
584
+ return attention_probs
585
+
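The `baddbmm` call above folds the additive mask directly into the score computation. A small check (toy shapes, no claim beyond this sketch) of the equivalent softmax formulation, softmax(mask + scale * Q Kᵀ):

```python
import torch

from memo.models.attention_processor import Attention

attn = Attention(query_dim=64, heads=4, dim_head=16)

# Per-head tensors, as produced by head_to_batch_dim: (batch * heads, seq, head_dim).
q = torch.randn(8, 10, 16)
k = torch.randn(8, 12, 16)
mask = torch.zeros(8, 10, 12)  # additive mask: 0 keeps a position, large negatives drop it

probs = attn.get_attention_scores(q, k, mask)

reference = torch.softmax(mask + attn.scale * q @ k.transpose(-1, -2), dim=-1)
assert torch.allclose(probs, reference, atol=1e-5)
```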
586
+ def prepare_attention_mask(
587
+ self,
588
+ attention_mask: torch.Tensor,
589
+ target_length: int,
590
+ batch_size: int,
591
+ out_dim: int = 3,
592
+ ) -> torch.Tensor:
593
+ r"""
594
+ Prepare the attention mask for the attention computation.
595
+
596
+ Args:
597
+ attention_mask (`torch.Tensor`):
598
+ The attention mask to prepare.
599
+ target_length (`int`):
600
+ The target length of the attention mask. This is the length of the attention mask after padding.
601
+ batch_size (`int`):
602
+ The batch size, which is used to repeat the attention mask.
603
+ out_dim (`int`, *optional*, defaults to `3`):
604
+ The output dimension of the attention mask. Can be either `3` or `4`.
605
+
606
+ Returns:
607
+ `torch.Tensor`: The prepared attention mask.
608
+ """
609
+ head_size = self.heads
610
+ if attention_mask is None:
611
+ return attention_mask
612
+
613
+ current_length: int = attention_mask.shape[-1]
614
+ if current_length != target_length:
615
+ if attention_mask.device.type == "mps":
616
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
617
+ # Instead, we can manually construct the padding tensor.
618
+ padding_shape = (
619
+ attention_mask.shape[0],
620
+ attention_mask.shape[1],
621
+ target_length,
622
+ )
623
+ padding = torch.zeros(
624
+ padding_shape,
625
+ dtype=attention_mask.dtype,
626
+ device=attention_mask.device,
627
+ )
628
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
629
+ else:
630
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
631
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
632
+ # remaining_length: int = target_length - current_length
633
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
634
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
635
+
636
+ if out_dim == 3:
637
+ if attention_mask.shape[0] < batch_size * head_size:
638
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
639
+ elif out_dim == 4:
640
+ attention_mask = attention_mask.unsqueeze(1)
641
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
642
+
643
+ return attention_mask
644
+
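A small shape sketch of the two output layouts described above (the mask is already at the target length here, so the padding branch is not exercised):

```python
import torch

from memo.models.attention_processor import Attention

attn = Attention(query_dim=64, heads=4, dim_head=16)

# A per-sample additive mask over 8 key positions.
mask = torch.zeros(2, 1, 8)

# out_dim=3: repeated per head along the batch axis -> (batch * heads, 1, key_len).
print(attn.prepare_attention_mask(mask, target_length=8, batch_size=2).shape)
# torch.Size([8, 1, 8])

# out_dim=4: an extra head axis is inserted instead -> (batch, heads, 1, key_len).
print(attn.prepare_attention_mask(mask, target_length=8, batch_size=2, out_dim=4).shape)
# torch.Size([2, 4, 1, 8])
```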
645
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
646
+ r"""
647
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
648
+ `Attention` class.
649
+
650
+ Args:
651
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
652
+
653
+ Returns:
654
+ `torch.Tensor`: The normalized encoder hidden states.
655
+ """
656
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
657
+
658
+ if isinstance(self.norm_cross, nn.LayerNorm):
659
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
660
+ elif isinstance(self.norm_cross, nn.GroupNorm):
661
+ # Group norm norms along the channels dimension and expects
662
+ # input to be in the shape of (N, C, *). In this case, we want
663
+ # to norm along the hidden dimension, so we need to move
664
+ # (batch_size, sequence_length, hidden_size) ->
665
+ # (batch_size, hidden_size, sequence_length)
666
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
667
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
668
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
669
+ else:
670
+ assert False
671
+
672
+ return encoder_hidden_states
673
+
674
+ @torch.no_grad()
675
+ def fuse_projections(self, fuse=True):
676
+ device = self.to_q.weight.data.device
677
+ dtype = self.to_q.weight.data.dtype
678
+
679
+ if not self.is_cross_attention:
680
+ # fetch weight matrices.
681
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
682
+ in_features = concatenated_weights.shape[1]
683
+ out_features = concatenated_weights.shape[0]
684
+
685
+ # create a new single projection layer and copy over the weights.
686
+ self.to_qkv = nn.Linear(
687
+ in_features,
688
+ out_features,
689
+ bias=self.use_bias,
690
+ device=device,
691
+ dtype=dtype,
692
+ )
693
+ self.to_qkv.weight.copy_(concatenated_weights)
694
+ if self.use_bias:
695
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
696
+ self.to_qkv.bias.copy_(concatenated_bias)
697
+
698
+ else:
699
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
700
+ in_features = concatenated_weights.shape[1]
701
+ out_features = concatenated_weights.shape[0]
702
+
703
+ self.to_kv = nn.Linear(
704
+ in_features,
705
+ out_features,
706
+ bias=self.use_bias,
707
+ device=device,
708
+ dtype=dtype,
709
+ )
710
+ self.to_kv.weight.copy_(concatenated_weights)
711
+ if self.use_bias:
712
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
713
+ self.to_kv.bias.copy_(concatenated_bias)
714
+
715
+ self.fused_projections = fuse
716
+
717
+
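A brief sketch of what the fusion above buys for a self-attention layer: the three Q/K/V projections collapse into a single matmul. This is only a sketch; in practice the fused weights are consumed by the matching `Fused*AttnProcessor` classes rather than chunked by hand.

```python
import torch

from memo.models.attention_processor import Attention

attn = Attention(query_dim=64, heads=4, dim_head=16)
attn.fuse_projections()

x = torch.randn(2, 10, 64)
qkv = attn.to_qkv(x)            # one matmul instead of three
q, k, v = qkv.chunk(3, dim=-1)  # each (2, 10, 64)
assert torch.allclose(q, attn.to_q(x), atol=1e-5)
```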
718
+ class AttnProcessor:
719
+ r"""
720
+ Default processor for performing attention-related computations.
721
+ """
722
+
723
+ def __call__(
724
+ self,
725
+ attn: Attention,
726
+ hidden_states: torch.Tensor,
727
+ encoder_hidden_states: Optional[torch.Tensor] = None,
728
+ attention_mask: Optional[torch.Tensor] = None,
729
+ temb: Optional[torch.Tensor] = None,
730
+ *args,
731
+ **kwargs,
732
+ ) -> torch.Tensor:
733
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
734
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
735
+ deprecate("scale", "1.0.0", deprecation_message)
736
+
737
+ residual = hidden_states
738
+
739
+ if attn.spatial_norm is not None:
740
+ hidden_states = attn.spatial_norm(hidden_states, temb)
741
+
742
+ input_ndim = hidden_states.ndim
743
+
744
+ if input_ndim == 4:
745
+ batch_size, channel, height, width = hidden_states.shape
746
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
747
+
748
+ batch_size, sequence_length, _ = (
749
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
750
+ )
751
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
752
+
753
+ if attn.group_norm is not None:
754
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
755
+
756
+ query = attn.to_q(hidden_states)
757
+
758
+ if encoder_hidden_states is None:
759
+ encoder_hidden_states = hidden_states
760
+ elif attn.norm_cross:
761
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
762
+
763
+ key = attn.to_k(encoder_hidden_states)
764
+ value = attn.to_v(encoder_hidden_states)
765
+
766
+ query = attn.head_to_batch_dim(query)
767
+ key = attn.head_to_batch_dim(key)
768
+ value = attn.head_to_batch_dim(value)
769
+
770
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
771
+ hidden_states = torch.bmm(attention_probs, value)
772
+ hidden_states = attn.batch_to_head_dim(hidden_states)
773
+
774
+ # linear proj
775
+ hidden_states = attn.to_out[0](hidden_states)
776
+ # dropout
777
+ hidden_states = attn.to_out[1](hidden_states)
778
+
779
+ if input_ndim == 4:
780
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
781
+
782
+ if attn.residual_connection:
783
+ hidden_states = hidden_states + residual
784
+
785
+ hidden_states = hidden_states / attn.rescale_output_factor
786
+
787
+ return hidden_states
788
+
789
+
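For reference, a hedged check (toy shapes, no mask) that this default processor reduces to the familiar scaled-dot-product attention over the same projections:

```python
import torch
import torch.nn.functional as F

from memo.models.attention_processor import Attention, AttnProcessor

attn = Attention(query_dim=64, heads=4, dim_head=16)
attn.set_processor(AttnProcessor())  # force the explicit bmm-based path

x = torch.randn(2, 10, 64)
out = attn(x)

# softmax(Q K^T * scale) V, followed by the output projection.
q = attn.head_to_batch_dim(attn.to_q(x))
k = attn.head_to_batch_dim(attn.to_k(x))
v = attn.head_to_batch_dim(attn.to_v(x))
ref = attn.batch_to_head_dim(F.scaled_dot_product_attention(q, k, v))
ref = attn.to_out[1](attn.to_out[0](ref))
assert torch.allclose(out, ref, atol=1e-4)
```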
790
+ class MemoryLinearAttnProcessor:
791
+ r"""
792
+ Processor for performing linear attention-related computations.
793
+ """
794
+
795
+ def __init__(self):
796
+ self.memory = {"KV": None, "Z": None}
797
+ self.decay = 0.9
798
+
799
+ def reset_memory_state(self):
800
+ """Reset memory to the initial state."""
801
+ self.memory = {"KV": None, "Z": None}
802
+
803
+ def __call__(
804
+ self,
805
+ attn: Attention,
806
+ hidden_states: torch.Tensor,
807
+ motion_frames: torch.Tensor,
808
+ encoder_hidden_states: Optional[torch.Tensor] = None,
809
+ attention_mask: Optional[torch.Tensor] = None,
810
+ temb: Optional[torch.Tensor] = None,
811
+ is_new_audio: bool = True,
812
+ update_past_memory: bool = False,
813
+ *args,
814
+ **kwargs,
815
+ ) -> torch.Tensor:
816
+ # Reset memory if it's a new data segment
817
+ # The inference code is expected to pass is_new_audio=True at the start of each new audio segment
818
+ if is_new_audio:
819
+ self.reset_memory_state()
820
+
821
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
822
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
823
+ deprecate("scale", "1.0.0", deprecation_message)
824
+
825
+ residual = hidden_states
826
+
827
+ if attn.spatial_norm is not None:
828
+ hidden_states = attn.spatial_norm(hidden_states, temb)
829
+
830
+ with torch.no_grad():
831
+ motion_frames = attn.spatial_norm(motion_frames, temb)
832
+
833
+ input_ndim = hidden_states.ndim
834
+
835
+ if input_ndim == 4:
836
+ batch_size, channel, height, width = hidden_states.shape
837
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
838
+ with torch.no_grad():
839
+ (
840
+ motion_frames_batch_size,
841
+ motion_frames_channel,
842
+ motion_frames_height,
843
+ motion_frames_width,
844
+ ) = motion_frames.shape
845
+ motion_frames = motion_frames.view(
846
+ motion_frames_batch_size,
847
+ motion_frames_channel,
848
+ motion_frames_height * motion_frames_width,
849
+ ).transpose(1, 2)
850
+
851
+ batch_size, sequence_length, _ = (
852
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
853
+ )
854
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
855
+
856
+ if attn.group_norm is not None:
857
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
858
+ with torch.no_grad():
859
+ motion_frames = attn.group_norm(motion_frames.transpose(1, 2)).transpose(1, 2)
860
+
861
+ query = attn.to_q(hidden_states)
862
+
863
+ if encoder_hidden_states is None:
864
+ encoder_hidden_states = hidden_states
865
+ elif attn.norm_cross:
866
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
867
+
868
+ key = attn.to_k(encoder_hidden_states)
869
+ value = attn.to_v(encoder_hidden_states)
870
+
871
+ query = attn.head_to_batch_dim(query)
872
+ key = attn.head_to_batch_dim(key)
873
+ value = attn.head_to_batch_dim(value)
874
+
875
+ with torch.no_grad():
876
+ motion_frames_query = attn.to_q(motion_frames)
877
+ motion_frames_key = attn.to_k(motion_frames)
878
+ motion_frames_value = attn.to_v(motion_frames)
879
+
880
+ motion_frames_query = attn.head_to_batch_dim(motion_frames_query)
881
+ motion_frames_key = attn.head_to_batch_dim(motion_frames_key)
882
+ motion_frames_value = attn.head_to_batch_dim(motion_frames_value)
883
+
884
+ query = torch.softmax(query, dim=-1)
885
+ key = torch.softmax(key, dim=-2)
886
+
887
+ with torch.no_grad():
888
+ motion_frames_key = torch.softmax(motion_frames_key, dim=-2)
889
+
890
+ # Compute linear attention using the new formulation
891
+ query = query * attn.scale
892
+
893
+ # Update Memory
894
+ if update_past_memory or is_new_audio:
895
+ with torch.no_grad():
896
+ # frame-level decay for memory update
897
+ seq_length = motion_frames_key.size(1)
898
+ decay_factors = self.decay ** torch.arange(
899
+ seq_length - 1,
900
+ -1,
901
+ -1,
902
+ device=motion_frames_key.device,
903
+ dtype=motion_frames_key.dtype,
904
+ ) # exponents [seq_length - 1, ..., 1, 0]: older motion frames are decayed more
905
+ decay_factors = decay_factors.view(1, seq_length, 1)
906
+ decayed_motion_frames_key = motion_frames_key * decay_factors
907
+ decayed_motion_frames_value = motion_frames_value * decay_factors
908
+
909
+ batch_size, seq_length, _ = decayed_motion_frames_key.shape
910
+ keys_unsqueezed = decayed_motion_frames_key.unsqueeze(3)
911
+ values_unsqueezed = decayed_motion_frames_value.unsqueeze(2)
912
+
913
+ KV_t_all = keys_unsqueezed * values_unsqueezed
914
+ KV_cumsum = KV_t_all.sum(dim=1)
915
+
916
+ Z_cumsum = decayed_motion_frames_key.sum(dim=1)
917
+
918
+ if self.memory["KV"] is None and self.memory["Z"] is None:
919
+ self.memory["KV"] = KV_cumsum
920
+ self.memory["Z"] = Z_cumsum.unsqueeze(1) # [batch_size, 1, d_model]
921
+ else:
922
+ self.memory["KV"] = self.memory["KV"] * (self.decay**seq_length) + KV_cumsum
923
+ self.memory["Z"] = self.memory["Z"] * (self.decay**seq_length) + Z_cumsum.unsqueeze(
924
+ 1
925
+ ) # [batch_size, 1, d_model]
926
+
927
+ KV = self.decay * self.memory["KV"] + torch.einsum("bnd,bne->bde", key, value)
928
+ Z = self.decay * self.memory["Z"] + key.sum(dim=-2, keepdim=True)
929
+
930
+ # Compute Linear Attn
931
+ query_KV = torch.einsum("bnd,bde->bne", query, KV)
932
+ query_Z = torch.einsum("bnd,bod->bno", query, Z)
933
+
934
+ hidden_states = query_KV / (query_Z.clamp(min=1e-10))
935
+ hidden_states = attn.batch_to_head_dim(hidden_states)
936
+ hidden_states = attn.to_out[0](hidden_states)
937
+ hidden_states = attn.to_out[1](hidden_states)
938
+
939
+ if input_ndim == 4:
940
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
941
+
942
+ if attn.residual_connection:
943
+ hidden_states = hidden_states + residual
944
+
945
+ hidden_states = hidden_states / attn.rescale_output_factor
946
+
947
+ return hidden_states
948
+
949
+
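A hedged usage sketch of the memory processor above (toy sizes; in the actual pipeline the motion-frame tokens come from previously generated frames, and the flags are driven by the inference loop):

```python
import torch

from memo.models.attention_processor import Attention, MemoryLinearAttnProcessor

attn = Attention(query_dim=64, heads=4, dim_head=16)
attn.set_processor(MemoryLinearAttnProcessor())

frame_tokens = torch.randn(2, 10, 64)   # tokens of the current clip
motion_tokens = torch.randn(2, 4, 64)   # tokens of the preceding motion frames

# First clip of a new audio segment: the decayed motion-frame statistics
# (the KV and Z accumulators) are rebuilt before attending.
out = attn(frame_tokens, motion_frames=motion_tokens, is_new_audio=True)
print(out.shape)  # torch.Size([2, 10, 64])

# Subsequent clips reuse the running memory; update_past_memory=True folds the
# new motion frames into it with the per-frame decay.
out = attn(
    frame_tokens,
    motion_frames=motion_tokens,
    is_new_audio=False,
    update_past_memory=True,
)
```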
950
+ class CustomDiffusionAttnProcessor(nn.Module):
951
+ r"""
952
+ Processor for implementing attention for the Custom Diffusion method.
953
+
954
+ Args:
955
+ train_kv (`bool`, defaults to `True`):
956
+ Whether to newly train the key and value matrices corresponding to the text features.
957
+ train_q_out (`bool`, defaults to `True`):
958
+ Whether to newly train query matrices corresponding to the latent image features.
959
+ hidden_size (`int`, *optional*, defaults to `None`):
960
+ The hidden size of the attention layer.
961
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
962
+ The number of channels in the `encoder_hidden_states`.
963
+ out_bias (`bool`, defaults to `True`):
964
+ Whether to include the bias parameter in `train_q_out`.
965
+ dropout (`float`, *optional*, defaults to 0.0):
966
+ The dropout probability to use.
967
+ """
968
+
969
+ def __init__(
970
+ self,
971
+ train_kv: bool = True,
972
+ train_q_out: bool = True,
973
+ hidden_size: Optional[int] = None,
974
+ cross_attention_dim: Optional[int] = None,
975
+ out_bias: bool = True,
976
+ dropout: float = 0.0,
977
+ ):
978
+ super().__init__()
979
+ self.train_kv = train_kv
980
+ self.train_q_out = train_q_out
981
+
982
+ self.hidden_size = hidden_size
983
+ self.cross_attention_dim = cross_attention_dim
984
+
985
+ # `_custom_diffusion` id for easy serialization and loading.
986
+ if self.train_kv:
987
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
988
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
989
+ if self.train_q_out:
990
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
991
+ self.to_out_custom_diffusion = nn.ModuleList([])
992
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
993
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
994
+
995
+ def __call__(
996
+ self,
997
+ attn: Attention,
998
+ hidden_states: torch.Tensor,
999
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1000
+ attention_mask: Optional[torch.Tensor] = None,
1001
+ ) -> torch.Tensor:
1002
+ batch_size, sequence_length, _ = hidden_states.shape
1003
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1004
+ if self.train_q_out:
1005
+ query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
1006
+ else:
1007
+ query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
1008
+
1009
+ if encoder_hidden_states is None:
1010
+ crossattn = False
1011
+ encoder_hidden_states = hidden_states
1012
+ else:
1013
+ crossattn = True
1014
+ if attn.norm_cross:
1015
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1016
+
1017
+ if self.train_kv:
1018
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
1019
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
1020
+ key = key.to(attn.to_q.weight.dtype)
1021
+ value = value.to(attn.to_q.weight.dtype)
1022
+ else:
1023
+ key = attn.to_k(encoder_hidden_states)
1024
+ value = attn.to_v(encoder_hidden_states)
1025
+
1026
+ if crossattn:
1027
+ detach = torch.ones_like(key)
1028
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
1029
+ key = detach * key + (1 - detach) * key.detach()
1030
+ value = detach * value + (1 - detach) * value.detach()
1031
+
1032
+ query = attn.head_to_batch_dim(query)
1033
+ key = attn.head_to_batch_dim(key)
1034
+ value = attn.head_to_batch_dim(value)
1035
+
1036
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
1037
+ hidden_states = torch.bmm(attention_probs, value)
1038
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1039
+
1040
+ if self.train_q_out:
1041
+ # linear proj
1042
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
1043
+ # dropout
1044
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
1045
+ else:
1046
+ # linear proj
1047
+ hidden_states = attn.to_out[0](hidden_states)
1048
+ # dropout
1049
+ hidden_states = attn.to_out[1](hidden_states)
1050
+
1051
+ return hidden_states
1052
+
1053
+
1054
+ class AttnAddedKVProcessor:
1055
+ r"""
1056
+ Processor for performing attention-related computations with extra learnable key and value matrices for the text
1057
+ encoder.
1058
+ """
1059
+
1060
+ def __call__(
1061
+ self,
1062
+ attn: Attention,
1063
+ hidden_states: torch.Tensor,
1064
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1065
+ attention_mask: Optional[torch.Tensor] = None,
1066
+ *args,
1067
+ **kwargs,
1068
+ ) -> torch.Tensor:
1069
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1070
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1071
+ deprecate("scale", "1.0.0", deprecation_message)
1072
+
1073
+ residual = hidden_states
1074
+
1075
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1076
+ batch_size, sequence_length, _ = hidden_states.shape
1077
+
1078
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1079
+
1080
+ if encoder_hidden_states is None:
1081
+ encoder_hidden_states = hidden_states
1082
+ elif attn.norm_cross:
1083
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1084
+
1085
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1086
+
1087
+ query = attn.to_q(hidden_states)
1088
+ query = attn.head_to_batch_dim(query)
1089
+
1090
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1091
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1092
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1093
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1094
+
1095
+ if not attn.only_cross_attention:
1096
+ key = attn.to_k(hidden_states)
1097
+ value = attn.to_v(hidden_states)
1098
+ key = attn.head_to_batch_dim(key)
1099
+ value = attn.head_to_batch_dim(value)
1100
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1101
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1102
+ else:
1103
+ key = encoder_hidden_states_key_proj
1104
+ value = encoder_hidden_states_value_proj
1105
+
1106
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
1107
+ hidden_states = torch.bmm(attention_probs, value)
1108
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1109
+
1110
+ # linear proj
1111
+ hidden_states = attn.to_out[0](hidden_states)
1112
+ # dropout
1113
+ hidden_states = attn.to_out[1](hidden_states)
1114
+
1115
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1116
+ hidden_states = hidden_states + residual
1117
+
1118
+ return hidden_states
1119
+
1120
+
1121
+ class AttnAddedKVProcessor2_0:
1122
+ r"""
1123
+ Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
1124
+ learnable key and value matrices for the text encoder.
1125
+ """
1126
+
1127
+ def __init__(self):
1128
+ if not hasattr(F, "scaled_dot_product_attention"):
1129
+ raise ImportError(
1130
+ "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
1131
+ )
1132
+
1133
+ def __call__(
1134
+ self,
1135
+ attn: Attention,
1136
+ hidden_states: torch.Tensor,
1137
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1138
+ attention_mask: Optional[torch.Tensor] = None,
1139
+ *args,
1140
+ **kwargs,
1141
+ ) -> torch.Tensor:
1142
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1143
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1144
+ deprecate("scale", "1.0.0", deprecation_message)
1145
+
1146
+ residual = hidden_states
1147
+
1148
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1149
+ batch_size, sequence_length, _ = hidden_states.shape
1150
+
1151
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
1152
+
1153
+ if encoder_hidden_states is None:
1154
+ encoder_hidden_states = hidden_states
1155
+ elif attn.norm_cross:
1156
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1157
+
1158
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1159
+
1160
+ query = attn.to_q(hidden_states)
1161
+ query = attn.head_to_batch_dim(query, out_dim=4)
1162
+
1163
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1164
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1165
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
1166
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
1167
+
1168
+ if not attn.only_cross_attention:
1169
+ key = attn.to_k(hidden_states)
1170
+ value = attn.to_v(hidden_states)
1171
+ key = attn.head_to_batch_dim(key, out_dim=4)
1172
+ value = attn.head_to_batch_dim(value, out_dim=4)
1173
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
1174
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
1175
+ else:
1176
+ key = encoder_hidden_states_key_proj
1177
+ value = encoder_hidden_states_value_proj
1178
+
1179
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1180
+ # TODO: add support for attn.scale when we move to Torch 2.1
1181
+ hidden_states = F.scaled_dot_product_attention(
1182
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1183
+ )
1184
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
1185
+
1186
+ # linear proj
1187
+ hidden_states = attn.to_out[0](hidden_states)
1188
+ # dropout
1189
+ hidden_states = attn.to_out[1](hidden_states)
1190
+
1191
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1192
+ hidden_states = hidden_states + residual
1193
+
1194
+ return hidden_states
1195
+
1196
+
1197
+ class JointAttnProcessor2_0:
1198
+ """Attention processor typically used for SD3-like joint self-attention projections."""
1199
+
1200
+ def __init__(self):
1201
+ if not hasattr(F, "scaled_dot_product_attention"):
1202
+ raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1203
+
1204
+ def __call__(
1205
+ self,
1206
+ attn: Attention,
1207
+ hidden_states: torch.FloatTensor,
1208
+ encoder_hidden_states: torch.FloatTensor = None,
1209
+ attention_mask: Optional[torch.FloatTensor] = None,
1210
+ *args,
1211
+ **kwargs,
1212
+ ) -> torch.FloatTensor:
1213
+ residual = hidden_states
1214
+
1215
+ input_ndim = hidden_states.ndim
1216
+ if input_ndim == 4:
1217
+ batch_size, channel, height, width = hidden_states.shape
1218
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1219
+ context_input_ndim = encoder_hidden_states.ndim
1220
+ if context_input_ndim == 4:
1221
+ batch_size, channel, height, width = encoder_hidden_states.shape
1222
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1223
+
1224
+ batch_size = encoder_hidden_states.shape[0]
1225
+
1226
+ query = attn.to_q(hidden_states)
1227
+ key = attn.to_k(hidden_states)
1228
+ value = attn.to_v(hidden_states)
1229
+
1230
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
1231
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1232
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1233
+
1234
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
1235
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
1236
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
1237
+
1238
+ inner_dim = key.shape[-1]
1239
+ head_dim = inner_dim // attn.heads
1240
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1241
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1242
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1243
+
1244
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
1245
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1246
+ hidden_states = hidden_states.to(query.dtype)
1247
+
1248
+ # Split the attention outputs.
1249
+ hidden_states, encoder_hidden_states = (
1250
+ hidden_states[:, : residual.shape[1]],
1251
+ hidden_states[:, residual.shape[1] :],
1252
+ )
1253
+
1254
+ hidden_states = attn.to_out[0](hidden_states)
1255
+ hidden_states = attn.to_out[1](hidden_states)
1256
+ if not attn.context_pre_only and not attn.is_final_block:
1257
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
1258
+ else:
1259
+ encoder_hidden_states = None
1260
+
1261
+ if input_ndim == 4:
1262
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1263
+
1264
+ return hidden_states, encoder_hidden_states
1265
+
1266
+
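A hedged construction sketch for the joint (SD3/MMDiT-style) processor above, showing the extra context projections it expects on the `Attention` module (toy sizes):

```python
import torch

from memo.models.attention_processor import Attention, JointAttnProcessor2_0

# Sample and context tokens are projected separately, concatenated into one
# sequence, attended jointly, then split back apart.
attn = Attention(
    query_dim=64,
    heads=4,
    dim_head=16,
    added_kv_proj_dim=64,    # context K/V (and Q, since context_pre_only is set) projections
    context_pre_only=False,  # also project the context stream back out via to_add_out
    processor=JointAttnProcessor2_0(),
)

sample = torch.randn(2, 10, 64)
context = torch.randn(2, 7, 64)
sample_out, context_out = attn(sample, encoder_hidden_states=context)
print(sample_out.shape, context_out.shape)  # torch.Size([2, 10, 64]) torch.Size([2, 7, 64])
```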
1267
+ class FusedJointAttnProcessor2_0:
1268
+ """Attention processor typically used for SD3-like joint self-attention projections, operating on fused QKV projections."""
1269
+
1270
+ def __init__(self):
1271
+ if not hasattr(F, "scaled_dot_product_attention"):
1272
+ raise ImportError("FusedJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1273
+
1274
+ def __call__(
1275
+ self,
1276
+ attn: Attention,
1277
+ hidden_states: torch.FloatTensor,
1278
+ encoder_hidden_states: torch.FloatTensor = None,
1279
+ attention_mask: Optional[torch.FloatTensor] = None,
1280
+ *args,
1281
+ **kwargs,
1282
+ ) -> torch.FloatTensor:
1283
+ residual = hidden_states
1284
+
1285
+ input_ndim = hidden_states.ndim
1286
+ if input_ndim == 4:
1287
+ batch_size, channel, height, width = hidden_states.shape
1288
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1289
+ context_input_ndim = encoder_hidden_states.ndim
1290
+ if context_input_ndim == 4:
1291
+ batch_size, channel, height, width = encoder_hidden_states.shape
1292
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1293
+
1294
+ batch_size = encoder_hidden_states.shape[0]
1295
+
1296
+ # `sample` projections.
1297
+ qkv = attn.to_qkv(hidden_states)
1298
+ split_size = qkv.shape[-1] // 3
1299
+ query, key, value = torch.split(qkv, split_size, dim=-1)
1300
+
1301
+ # `context` projections.
1302
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
1303
+ split_size = encoder_qkv.shape[-1] // 3
1304
+ (
1305
+ encoder_hidden_states_query_proj,
1306
+ encoder_hidden_states_key_proj,
1307
+ encoder_hidden_states_value_proj,
1308
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
1309
+
1310
+ # attention
1311
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
1312
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
1313
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
1314
+
1315
+ inner_dim = key.shape[-1]
1316
+ head_dim = inner_dim // attn.heads
1317
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1318
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1319
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1320
+
1321
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
1322
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1323
+ hidden_states = hidden_states.to(query.dtype)
1324
+
1325
+ # Split the attention outputs.
1326
+ hidden_states, encoder_hidden_states = (
1327
+ hidden_states[:, : residual.shape[1]],
1328
+ hidden_states[:, residual.shape[1] :],
1329
+ )
1330
+
1331
+ # linear proj
1332
+ hidden_states = attn.to_out[0](hidden_states)
1333
+ # dropout
1334
+ hidden_states = attn.to_out[1](hidden_states)
1335
+ if not attn.context_pre_only:
1336
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
1337
+
1338
+ if input_ndim == 4:
1339
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1340
+ if context_input_ndim == 4:
1341
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1342
+
1343
+ return hidden_states, encoder_hidden_states
1344
+
1345
+
1346
+ class XFormersAttnAddedKVProcessor:
1347
+ r"""
1348
+ Processor for implementing memory efficient attention using xFormers.
1349
+
1350
+ Args:
1351
+ attention_op (`Callable`, *optional*, defaults to `None`):
1352
+ The base
1353
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1354
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1355
+ operator.
1356
+ """
1357
+
1358
+ def __init__(self, attention_op: Optional[Callable] = None):
1359
+ self.attention_op = attention_op
1360
+
1361
+ def __call__(
1362
+ self,
1363
+ attn: Attention,
1364
+ hidden_states: torch.Tensor,
1365
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1366
+ attention_mask: Optional[torch.Tensor] = None,
1367
+ ) -> torch.Tensor:
1368
+ residual = hidden_states
1369
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1370
+ batch_size, sequence_length, _ = hidden_states.shape
1371
+
1372
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1373
+
1374
+ if encoder_hidden_states is None:
1375
+ encoder_hidden_states = hidden_states
1376
+ elif attn.norm_cross:
1377
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1378
+
1379
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1380
+
1381
+ query = attn.to_q(hidden_states)
1382
+ query = attn.head_to_batch_dim(query)
1383
+
1384
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1385
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1386
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1387
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1388
+
1389
+ if not attn.only_cross_attention:
1390
+ key = attn.to_k(hidden_states)
1391
+ value = attn.to_v(hidden_states)
1392
+ key = attn.head_to_batch_dim(key)
1393
+ value = attn.head_to_batch_dim(value)
1394
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1395
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1396
+ else:
1397
+ key = encoder_hidden_states_key_proj
1398
+ value = encoder_hidden_states_value_proj
1399
+
1400
+ hidden_states = xformers.ops.memory_efficient_attention(
1401
+ query,
1402
+ key,
1403
+ value,
1404
+ attn_bias=attention_mask,
1405
+ op=self.attention_op,
1406
+ scale=attn.scale,
1407
+ )
1408
+ hidden_states = hidden_states.to(query.dtype)
1409
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1410
+
1411
+ # linear proj
1412
+ hidden_states = attn.to_out[0](hidden_states)
1413
+ # dropout
1414
+ hidden_states = attn.to_out[1](hidden_states)
1415
+
1416
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1417
+ hidden_states = hidden_states + residual
1418
+
1419
+ return hidden_states
1420
+
1421
+
1422
+ class XFormersAttnProcessor:
1423
+ r"""
1424
+ Processor for implementing memory efficient attention using xFormers.
1425
+
1426
+ Args:
1427
+ attention_op (`Callable`, *optional*, defaults to `None`):
1428
+ The base
1429
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1430
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1431
+ operator.
1432
+ """
1433
+
1434
+ def __init__(self, attention_op: Optional[Callable] = None):
1435
+ self.attention_op = attention_op
1436
+
1437
+ def __call__(
1438
+ self,
1439
+ attn: Attention,
1440
+ hidden_states: torch.Tensor,
1441
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1442
+ attention_mask: Optional[torch.Tensor] = None,
1443
+ temb: Optional[torch.Tensor] = None,
1444
+ *args,
1445
+ **kwargs,
1446
+ ) -> torch.Tensor:
1447
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1448
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1449
+ deprecate("scale", "1.0.0", deprecation_message)
1450
+
1451
+ residual = hidden_states
1452
+
1453
+ if attn.spatial_norm is not None:
1454
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1455
+
1456
+ input_ndim = hidden_states.ndim
1457
+
1458
+ if input_ndim == 4:
1459
+ batch_size, channel, height, width = hidden_states.shape
1460
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1461
+
1462
+ batch_size, key_tokens, _ = (
1463
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1464
+ )
1465
+
1466
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
1467
+ if attention_mask is not None:
1468
+ # expand our mask's singleton query_tokens dimension:
1469
+ # [batch*heads, 1, key_tokens] ->
1470
+ # [batch*heads, query_tokens, key_tokens]
1471
+ # so that it can be added as a bias onto the attention scores that xformers computes:
1472
+ # [batch*heads, query_tokens, key_tokens]
1473
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
1474
+ _, query_tokens, _ = hidden_states.shape
1475
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
1476
+
1477
+ if attn.group_norm is not None:
1478
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1479
+
1480
+ query = attn.to_q(hidden_states)
1481
+
1482
+ if encoder_hidden_states is None:
1483
+ encoder_hidden_states = hidden_states
1484
+ elif attn.norm_cross:
1485
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1486
+
1487
+ key = attn.to_k(encoder_hidden_states)
1488
+ value = attn.to_v(encoder_hidden_states)
1489
+
1490
+ query = attn.head_to_batch_dim(query).contiguous()
1491
+ key = attn.head_to_batch_dim(key).contiguous()
1492
+ value = attn.head_to_batch_dim(value).contiguous()
1493
+
1494
+ hidden_states = xformers.ops.memory_efficient_attention(
1495
+ query,
1496
+ key,
1497
+ value,
1498
+ attn_bias=attention_mask,
1499
+ op=self.attention_op,
1500
+ scale=attn.scale,
1501
+ )
1502
+ hidden_states = hidden_states.to(query.dtype)
1503
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1504
+
1505
+ # linear proj
1506
+ hidden_states = attn.to_out[0](hidden_states)
1507
+ # dropout
1508
+ hidden_states = attn.to_out[1](hidden_states)
1509
+
1510
+ if input_ndim == 4:
1511
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1512
+
1513
+ if attn.residual_connection:
1514
+ hidden_states = hidden_states + residual
1515
+
1516
+ hidden_states = hidden_states / attn.rescale_output_factor
1517
+
1518
+ return hidden_states
1519
+
1520
+
1521
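The xFormers processors above mirror the upstream diffusers implementations and are attached to an `Attention` module via `set_processor`. A minimal sketch, assuming xformers is installed and importing the equivalent classes from diffusers; the layer sizes below are illustrative, not values from this repository:

import torch
from diffusers.models.attention_processor import Attention, XFormersAttnProcessor

attn = Attention(query_dim=320, heads=8, dim_head=40)         # hypothetical layer sizes
attn.set_processor(XFormersAttnProcessor(attention_op=None))  # None lets xFormers pick the operator
sample = torch.randn(2, 4096, 320)                            # (batch, tokens, channels)
out = attn(sample)                                            # dispatches to the processor's __call__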
+ class AttnProcessorNPU:
1522
+ r"""
1523
+ Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
1524
+ fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
1525
+ not significant.
1526
+
1527
+ """
1528
+
1529
+ def __init__(self):
1530
+ if not is_torch_npu_available():
1531
+ raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
1532
+
1533
+ def __call__(
1534
+ self,
1535
+ attn: Attention,
1536
+ hidden_states: torch.Tensor,
1537
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1538
+ attention_mask: Optional[torch.Tensor] = None,
1539
+ temb: Optional[torch.Tensor] = None,
1540
+ *args,
1541
+ **kwargs,
1542
+ ) -> torch.Tensor:
1543
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1544
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1545
+ deprecate("scale", "1.0.0", deprecation_message)
1546
+
1547
+ residual = hidden_states
1548
+ if attn.spatial_norm is not None:
1549
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1550
+
1551
+ input_ndim = hidden_states.ndim
1552
+
1553
+ if input_ndim == 4:
1554
+ batch_size, channel, height, width = hidden_states.shape
1555
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1556
+
1557
+ batch_size, sequence_length, _ = (
1558
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1559
+ )
1560
+
1561
+ if attention_mask is not None:
1562
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1563
+ # scaled_dot_product_attention expects attention_mask shape to be
1564
+ # (batch, heads, source_length, target_length)
1565
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1566
+
1567
+ if attn.group_norm is not None:
1568
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1569
+
1570
+ query = attn.to_q(hidden_states)
1571
+
1572
+ if encoder_hidden_states is None:
1573
+ encoder_hidden_states = hidden_states
1574
+ elif attn.norm_cross:
1575
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1576
+
1577
+ key = attn.to_k(encoder_hidden_states)
1578
+ value = attn.to_v(encoder_hidden_states)
1579
+
1580
+ inner_dim = key.shape[-1]
1581
+ head_dim = inner_dim // attn.heads
1582
+
1583
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1584
+
1585
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1586
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1587
+
1588
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1589
+ if query.dtype in (torch.float16, torch.bfloat16):
1590
+ hidden_states = torch_npu.npu_fusion_attention(
1591
+ query,
1592
+ key,
1593
+ value,
1594
+ attn.heads,
1595
+ input_layout="BNSD",
1596
+ pse=None,
1597
+ atten_mask=attention_mask,
1598
+ scale=1.0 / math.sqrt(query.shape[-1]),
1599
+ pre_tockens=65536,
1600
+ next_tockens=65536,
1601
+ keep_prob=1.0,
1602
+ sync=False,
1603
+ inner_precise=0,
1604
+ )[0]
1605
+ else:
1606
+ # TODO: add support for attn.scale when we move to Torch 2.1
1607
+ hidden_states = F.scaled_dot_product_attention(
1608
+ query,
1609
+ key,
1610
+ value,
1611
+ attn_mask=attention_mask,
1612
+ dropout_p=0.0,
1613
+ is_causal=False,
1614
+ )
1615
+
1616
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1617
+ hidden_states = hidden_states.to(query.dtype)
1618
+
1619
+ # linear proj
1620
+ hidden_states = attn.to_out[0](hidden_states)
1621
+ # dropout
1622
+ hidden_states = attn.to_out[1](hidden_states)
1623
+
1624
+ if input_ndim == 4:
1625
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1626
+
1627
+ if attn.residual_connection:
1628
+ hidden_states = hidden_states + residual
1629
+
1630
+ hidden_states = hidden_states / attn.rescale_output_factor
1631
+
1632
+ return hidden_states
1633
+
1634
+
1635
+ class AttnProcessor2_0:
1636
+ r"""
1637
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
1638
+ """
1639
+
1640
+ def __init__(self):
1641
+ if not hasattr(F, "scaled_dot_product_attention"):
1642
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or newer.")
1643
+
1644
+ def __call__(
1645
+ self,
1646
+ attn: Attention,
1647
+ hidden_states: torch.Tensor,
1648
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1649
+ attention_mask: Optional[torch.Tensor] = None,
1650
+ temb: Optional[torch.Tensor] = None,
1651
+ *args,
1652
+ **kwargs,
1653
+ ) -> torch.Tensor:
1654
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1655
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1656
+ deprecate("scale", "1.0.0", deprecation_message)
1657
+
1658
+ residual = hidden_states
1659
+ if attn.spatial_norm is not None:
1660
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1661
+
1662
+ input_ndim = hidden_states.ndim
1663
+
1664
+ if input_ndim == 4:
1665
+ batch_size, channel, height, width = hidden_states.shape
1666
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1667
+
1668
+ batch_size, sequence_length, _ = (
1669
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1670
+ )
1671
+
1672
+ if attention_mask is not None:
1673
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1674
+ # scaled_dot_product_attention expects attention_mask shape to be
1675
+ # (batch, heads, source_length, target_length)
1676
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1677
+
1678
+ if attn.group_norm is not None:
1679
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1680
+
1681
+ query = attn.to_q(hidden_states)
1682
+
1683
+ if encoder_hidden_states is None:
1684
+ encoder_hidden_states = hidden_states
1685
+ elif attn.norm_cross:
1686
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1687
+
1688
+ key = attn.to_k(encoder_hidden_states)
1689
+ value = attn.to_v(encoder_hidden_states)
1690
+
1691
+ inner_dim = key.shape[-1]
1692
+ head_dim = inner_dim // attn.heads
1693
+
1694
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1695
+
1696
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1697
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1698
+
1699
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1700
+ # TODO: add support for attn.scale when we move to Torch 2.1
1701
+ hidden_states = F.scaled_dot_product_attention(
1702
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1703
+ )
1704
+
1705
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1706
+ hidden_states = hidden_states.to(query.dtype)
1707
+
1708
+ # linear proj
1709
+ hidden_states = attn.to_out[0](hidden_states)
1710
+ # dropout
1711
+ hidden_states = attn.to_out[1](hidden_states)
1712
+
1713
+ if input_ndim == 4:
1714
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1715
+
1716
+ if attn.residual_connection:
1717
+ hidden_states = hidden_states + residual
1718
+
1719
+ hidden_states = hidden_states / attn.rescale_output_factor
1720
+
1721
+ return hidden_states
1722
+
1723
+
1724
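The core of `AttnProcessor2_0` is the reshape from `(batch, seq, heads * head_dim)` to `(batch, heads, seq, head_dim)` followed by `F.scaled_dot_product_attention`. A self-contained sketch of that pattern, with assumed sizes:

import torch
import torch.nn.functional as F

batch, seq_len, heads, head_dim = 2, 77, 8, 64               # illustrative sizes
q = torch.randn(batch, seq_len, heads * head_dim)
k = torch.randn(batch, seq_len, heads * head_dim)
v = torch.randn(batch, seq_len, heads * head_dim)

# (batch, seq, heads*head_dim) -> (batch, heads, seq, head_dim), as in the processor above
q = q.view(batch, -1, heads, head_dim).transpose(1, 2)
k = k.view(batch, -1, heads, head_dim).transpose(1, 2)
v = v.view(batch, -1, heads, head_dim).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)  # back to (batch, seq, channels)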
+ class FusedAttnProcessor2_0:
1725
+ r"""
1726
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
1727
+ fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
1728
+ For cross-attention modules, key and value projection matrices are fused.
1729
+
1730
+ <Tip warning={true}>
1731
+
1732
+ This API is currently 🧪 experimental in nature and can change in future.
1733
+
1734
+ </Tip>
1735
+ """
1736
+
1737
+ def __init__(self):
1738
+ if not hasattr(F, "scaled_dot_product_attention"):
1739
+ raise ImportError(
1740
+ "FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
1741
+ )
1742
+
1743
+ def __call__(
1744
+ self,
1745
+ attn: Attention,
1746
+ hidden_states: torch.Tensor,
1747
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1748
+ attention_mask: Optional[torch.Tensor] = None,
1749
+ temb: Optional[torch.Tensor] = None,
1750
+ *args,
1751
+ **kwargs,
1752
+ ) -> torch.Tensor:
1753
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1754
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1755
+ deprecate("scale", "1.0.0", deprecation_message)
1756
+
1757
+ residual = hidden_states
1758
+ if attn.spatial_norm is not None:
1759
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1760
+
1761
+ input_ndim = hidden_states.ndim
1762
+
1763
+ if input_ndim == 4:
1764
+ batch_size, channel, height, width = hidden_states.shape
1765
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1766
+
1767
+ batch_size, sequence_length, _ = (
1768
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1769
+ )
1770
+
1771
+ if attention_mask is not None:
1772
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1773
+ # scaled_dot_product_attention expects attention_mask shape to be
1774
+ # (batch, heads, source_length, target_length)
1775
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1776
+
1777
+ if attn.group_norm is not None:
1778
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1779
+
1780
+ if encoder_hidden_states is None:
1781
+ qkv = attn.to_qkv(hidden_states)
1782
+ split_size = qkv.shape[-1] // 3
1783
+ query, key, value = torch.split(qkv, split_size, dim=-1)
1784
+ else:
1785
+ if attn.norm_cross:
1786
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1787
+ query = attn.to_q(hidden_states)
1788
+
1789
+ kv = attn.to_kv(encoder_hidden_states)
1790
+ split_size = kv.shape[-1] // 2
1791
+ key, value = torch.split(kv, split_size, dim=-1)
1792
+
1793
+ inner_dim = key.shape[-1]
1794
+ head_dim = inner_dim // attn.heads
1795
+
1796
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1797
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1798
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1799
+
1800
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1801
+ # TODO: add support for attn.scale when we move to Torch 2.1
1802
+ hidden_states = F.scaled_dot_product_attention(
1803
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1804
+ )
1805
+
1806
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1807
+ hidden_states = hidden_states.to(query.dtype)
1808
+
1809
+ # linear proj
1810
+ hidden_states = attn.to_out[0](hidden_states)
1811
+ # dropout
1812
+ hidden_states = attn.to_out[1](hidden_states)
1813
+
1814
+ if input_ndim == 4:
1815
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1816
+
1817
+ if attn.residual_connection:
1818
+ hidden_states = hidden_states + residual
1819
+
1820
+ hidden_states = hidden_states / attn.rescale_output_factor
1821
+
1822
+ return hidden_states
1823
+
1824
+
1825
+ class CustomDiffusionXFormersAttnProcessor(nn.Module):
1826
+ r"""
1827
+ Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
1828
+
1829
+ Args:
1830
+ train_kv (`bool`, defaults to `True`):
1831
+ Whether to newly train the key and value matrices corresponding to the text features.
1832
+ train_q_out (`bool`, defaults to `False`):
1833
+ Whether to newly train query matrices corresponding to the latent image features.
1834
+ hidden_size (`int`, *optional*, defaults to `None`):
1835
+ The hidden size of the attention layer.
1836
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1837
+ The number of channels in the `encoder_hidden_states`.
1838
+ out_bias (`bool`, defaults to `True`):
1839
+ Whether to include the bias parameter in `train_q_out`.
1840
+ dropout (`float`, *optional*, defaults to 0.0):
1841
+ The dropout probability to use.
1842
+ attention_op (`Callable`, *optional*, defaults to `None`):
1843
+ The base
1844
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
1845
+ as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
1846
+ """
1847
+
1848
+ def __init__(
1849
+ self,
1850
+ train_kv: bool = True,
1851
+ train_q_out: bool = False,
1852
+ hidden_size: Optional[int] = None,
1853
+ cross_attention_dim: Optional[int] = None,
1854
+ out_bias: bool = True,
1855
+ dropout: float = 0.0,
1856
+ attention_op: Optional[Callable] = None,
1857
+ ):
1858
+ super().__init__()
1859
+ self.train_kv = train_kv
1860
+ self.train_q_out = train_q_out
1861
+
1862
+ self.hidden_size = hidden_size
1863
+ self.cross_attention_dim = cross_attention_dim
1864
+ self.attention_op = attention_op
1865
+
1866
+ # `_custom_diffusion` id for easy serialization and loading.
1867
+ if self.train_kv:
1868
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1869
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1870
+ if self.train_q_out:
1871
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
1872
+ self.to_out_custom_diffusion = nn.ModuleList([])
1873
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
1874
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
1875
+
1876
+ def __call__(
1877
+ self,
1878
+ attn: Attention,
1879
+ hidden_states: torch.Tensor,
1880
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1881
+ attention_mask: Optional[torch.Tensor] = None,
1882
+ ) -> torch.Tensor:
1883
+ batch_size, sequence_length, _ = (
1884
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1885
+ )
1886
+
1887
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1888
+
1889
+ if self.train_q_out:
1890
+ query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
1891
+ else:
1892
+ query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
1893
+
1894
+ if encoder_hidden_states is None:
1895
+ crossattn = False
1896
+ encoder_hidden_states = hidden_states
1897
+ else:
1898
+ crossattn = True
1899
+ if attn.norm_cross:
1900
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1901
+
1902
+ if self.train_kv:
1903
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
1904
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
1905
+ key = key.to(attn.to_q.weight.dtype)
1906
+ value = value.to(attn.to_q.weight.dtype)
1907
+ else:
1908
+ key = attn.to_k(encoder_hidden_states)
1909
+ value = attn.to_v(encoder_hidden_states)
1910
+
1911
+ if crossattn:
1912
+ detach = torch.ones_like(key)
1913
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
1914
+ key = detach * key + (1 - detach) * key.detach()
1915
+ value = detach * value + (1 - detach) * value.detach()
1916
+
1917
+ query = attn.head_to_batch_dim(query).contiguous()
1918
+ key = attn.head_to_batch_dim(key).contiguous()
1919
+ value = attn.head_to_batch_dim(value).contiguous()
1920
+
1921
+ hidden_states = xformers.ops.memory_efficient_attention(
1922
+ query,
1923
+ key,
1924
+ value,
1925
+ attn_bias=attention_mask,
1926
+ op=self.attention_op,
1927
+ scale=attn.scale,
1928
+ )
1929
+ hidden_states = hidden_states.to(query.dtype)
1930
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1931
+
1932
+ if self.train_q_out:
1933
+ # linear proj
1934
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
1935
+ # dropout
1936
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
1937
+ else:
1938
+ # linear proj
1939
+ hidden_states = attn.to_out[0](hidden_states)
1940
+ # dropout
1941
+ hidden_states = attn.to_out[1](hidden_states)
1942
+
1943
+ return hidden_states
1944
+
1945
+
1946
+ class CustomDiffusionAttnProcessor2_0(nn.Module):
1947
+ r"""
1948
+ Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0’s memory-efficient scaled
1949
+ dot-product attention.
1950
+
1951
+ Args:
1952
+ train_kv (`bool`, defaults to `True`):
1953
+ Whether to newly train the key and value matrices corresponding to the text features.
1954
+ train_q_out (`bool`, defaults to `True`):
1955
+ Whether to newly train query matrices corresponding to the latent image features.
1956
+ hidden_size (`int`, *optional*, defaults to `None`):
1957
+ The hidden size of the attention layer.
1958
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1959
+ The number of channels in the `encoder_hidden_states`.
1960
+ out_bias (`bool`, defaults to `True`):
1961
+ Whether to include the bias parameter in `train_q_out`.
1962
+ dropout (`float`, *optional*, defaults to 0.0):
1963
+ The dropout probability to use.
1964
+ """
1965
+
1966
+ def __init__(
1967
+ self,
1968
+ train_kv: bool = True,
1969
+ train_q_out: bool = True,
1970
+ hidden_size: Optional[int] = None,
1971
+ cross_attention_dim: Optional[int] = None,
1972
+ out_bias: bool = True,
1973
+ dropout: float = 0.0,
1974
+ ):
1975
+ super().__init__()
1976
+ self.train_kv = train_kv
1977
+ self.train_q_out = train_q_out
1978
+
1979
+ self.hidden_size = hidden_size
1980
+ self.cross_attention_dim = cross_attention_dim
1981
+
1982
+ # `_custom_diffusion` id for easy serialization and loading.
1983
+ if self.train_kv:
1984
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1985
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1986
+ if self.train_q_out:
1987
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
1988
+ self.to_out_custom_diffusion = nn.ModuleList([])
1989
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
1990
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
1991
+
1992
+ def __call__(
1993
+ self,
1994
+ attn: Attention,
1995
+ hidden_states: torch.Tensor,
1996
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1997
+ attention_mask: Optional[torch.Tensor] = None,
1998
+ ) -> torch.Tensor:
1999
+ batch_size, sequence_length, _ = hidden_states.shape
2000
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
2001
+ if self.train_q_out:
2002
+ query = self.to_q_custom_diffusion(hidden_states)
2003
+ else:
2004
+ query = attn.to_q(hidden_states)
2005
+
2006
+ if encoder_hidden_states is None:
2007
+ crossattn = False
2008
+ encoder_hidden_states = hidden_states
2009
+ else:
2010
+ crossattn = True
2011
+ if attn.norm_cross:
2012
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2013
+
2014
+ if self.train_kv:
2015
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
2016
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
2017
+ key = key.to(attn.to_q.weight.dtype)
2018
+ value = value.to(attn.to_q.weight.dtype)
2019
+
2020
+ else:
2021
+ key = attn.to_k(encoder_hidden_states)
2022
+ value = attn.to_v(encoder_hidden_states)
2023
+
2024
+ if crossattn:
2025
+ detach = torch.ones_like(key)
2026
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
2027
+ key = detach * key + (1 - detach) * key.detach()
2028
+ value = detach * value + (1 - detach) * value.detach()
2029
+
2030
+ inner_dim = hidden_states.shape[-1]
2031
+
2032
+ head_dim = inner_dim // attn.heads
2033
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2034
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2035
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2036
+
2037
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
2038
+ # TODO: add support for attn.scale when we move to Torch 2.1
2039
+ hidden_states = F.scaled_dot_product_attention(
2040
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2041
+ )
2042
+
2043
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2044
+ hidden_states = hidden_states.to(query.dtype)
2045
+
2046
+ if self.train_q_out:
2047
+ # linear proj
2048
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
2049
+ # dropout
2050
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
2051
+ else:
2052
+ # linear proj
2053
+ hidden_states = attn.to_out[0](hidden_states)
2054
+ # dropout
2055
+ hidden_states = attn.to_out[1](hidden_states)
2056
+
2057
+ return hidden_states
2058
+
2059
+
2060
+ class SlicedAttnProcessor:
2061
+ r"""
2062
+ Processor for implementing sliced attention.
2063
+
2064
+ Args:
2065
+ slice_size (`int`, *optional*):
2066
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
2067
+ `attention_head_dim` must be a multiple of the `slice_size`.
2068
+ """
2069
+
2070
+ def __init__(self, slice_size: int):
2071
+ self.slice_size = slice_size
2072
+
2073
+ def __call__(
2074
+ self,
2075
+ attn: Attention,
2076
+ hidden_states: torch.Tensor,
2077
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2078
+ attention_mask: Optional[torch.Tensor] = None,
2079
+ ) -> torch.Tensor:
2080
+ residual = hidden_states
2081
+
2082
+ input_ndim = hidden_states.ndim
2083
+
2084
+ if input_ndim == 4:
2085
+ batch_size, channel, height, width = hidden_states.shape
2086
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
2087
+
2088
+ batch_size, sequence_length, _ = (
2089
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2090
+ )
2091
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
2092
+
2093
+ if attn.group_norm is not None:
2094
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
2095
+
2096
+ query = attn.to_q(hidden_states)
2097
+ dim = query.shape[-1]
2098
+ query = attn.head_to_batch_dim(query)
2099
+
2100
+ if encoder_hidden_states is None:
2101
+ encoder_hidden_states = hidden_states
2102
+ elif attn.norm_cross:
2103
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2104
+
2105
+ key = attn.to_k(encoder_hidden_states)
2106
+ value = attn.to_v(encoder_hidden_states)
2107
+ key = attn.head_to_batch_dim(key)
2108
+ value = attn.head_to_batch_dim(value)
2109
+
2110
+ batch_size_attention, query_tokens, _ = query.shape
2111
+ hidden_states = torch.zeros(
2112
+ (batch_size_attention, query_tokens, dim // attn.heads),
2113
+ device=query.device,
2114
+ dtype=query.dtype,
2115
+ )
2116
+
2117
+ for i in range(batch_size_attention // self.slice_size):
2118
+ start_idx = i * self.slice_size
2119
+ end_idx = (i + 1) * self.slice_size
2120
+
2121
+ query_slice = query[start_idx:end_idx]
2122
+ key_slice = key[start_idx:end_idx]
2123
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
2124
+
2125
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
2126
+
2127
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
2128
+
2129
+ hidden_states[start_idx:end_idx] = attn_slice
2130
+
2131
+ hidden_states = attn.batch_to_head_dim(hidden_states)
2132
+
2133
+ # linear proj
2134
+ hidden_states = attn.to_out[0](hidden_states)
2135
+ # dropout
2136
+ hidden_states = attn.to_out[1](hidden_states)
2137
+
2138
+ if input_ndim == 4:
2139
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
2140
+
2141
+ if attn.residual_connection:
2142
+ hidden_states = hidden_states + residual
2143
+
2144
+ hidden_states = hidden_states / attn.rescale_output_factor
2145
+
2146
+ return hidden_states
2147
+
2148
+
2149
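`SlicedAttnProcessor` (and the added-KV variant below) trades speed for memory: after `head_to_batch_dim`, the `batch * heads` dimension is processed in chunks of `slice_size`, so only one slice of attention scores is materialized at a time. A stripped-down sketch with assumed shapes:

import torch

slice_size = 4                                               # illustrative values
query = torch.randn(16, 4096, 40)                            # (batch*heads, query_tokens, head_dim)
key = torch.randn(16, 77, 40)
value = torch.randn(16, 77, 40)

out = torch.zeros(16, 4096, 40)
for i in range(query.shape[0] // slice_size):
    s, e = i * slice_size, (i + 1) * slice_size
    scores = torch.softmax(query[s:e] @ key[s:e].transpose(1, 2) / 40 ** 0.5, dim=-1)
    out[s:e] = torch.bmm(scores, value[s:e])                 # only slice_size score matrices live at once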
+ class SlicedAttnAddedKVProcessor:
2150
+ r"""
2151
+ Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
2152
+
2153
+ Args:
2154
+ slice_size (`int`, *optional*):
2155
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
2156
+ `attention_head_dim` must be a multiple of the `slice_size`.
2157
+ """
2158
+
2159
+ def __init__(self, slice_size):
2160
+ self.slice_size = slice_size
2161
+
2162
+ def __call__(
2163
+ self,
2164
+ attn: "Attention",
2165
+ hidden_states: torch.Tensor,
2166
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2167
+ attention_mask: Optional[torch.Tensor] = None,
2168
+ temb: Optional[torch.Tensor] = None,
2169
+ ) -> torch.Tensor:
2170
+ residual = hidden_states
2171
+
2172
+ if attn.spatial_norm is not None:
2173
+ hidden_states = attn.spatial_norm(hidden_states, temb)
2174
+
2175
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
2176
+
2177
+ batch_size, sequence_length, _ = hidden_states.shape
2178
+
2179
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
2180
+
2181
+ if encoder_hidden_states is None:
2182
+ encoder_hidden_states = hidden_states
2183
+ elif attn.norm_cross:
2184
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2185
+
2186
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
2187
+
2188
+ query = attn.to_q(hidden_states)
2189
+ dim = query.shape[-1]
2190
+ query = attn.head_to_batch_dim(query)
2191
+
2192
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
2193
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
2194
+
2195
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
2196
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
2197
+
2198
+ if not attn.only_cross_attention:
2199
+ key = attn.to_k(hidden_states)
2200
+ value = attn.to_v(hidden_states)
2201
+ key = attn.head_to_batch_dim(key)
2202
+ value = attn.head_to_batch_dim(value)
2203
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
2204
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
2205
+ else:
2206
+ key = encoder_hidden_states_key_proj
2207
+ value = encoder_hidden_states_value_proj
2208
+
2209
+ batch_size_attention, query_tokens, _ = query.shape
2210
+ hidden_states = torch.zeros(
2211
+ (batch_size_attention, query_tokens, dim // attn.heads),
2212
+ device=query.device,
2213
+ dtype=query.dtype,
2214
+ )
2215
+
2216
+ for i in range(batch_size_attention // self.slice_size):
2217
+ start_idx = i * self.slice_size
2218
+ end_idx = (i + 1) * self.slice_size
2219
+
2220
+ query_slice = query[start_idx:end_idx]
2221
+ key_slice = key[start_idx:end_idx]
2222
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
2223
+
2224
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
2225
+
2226
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
2227
+
2228
+ hidden_states[start_idx:end_idx] = attn_slice
2229
+
2230
+ hidden_states = attn.batch_to_head_dim(hidden_states)
2231
+
2232
+ # linear proj
2233
+ hidden_states = attn.to_out[0](hidden_states)
2234
+ # dropout
2235
+ hidden_states = attn.to_out[1](hidden_states)
2236
+
2237
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
2238
+ hidden_states = hidden_states + residual
2239
+
2240
+ return hidden_states
2241
+
2242
+
2243
+ class SpatialNorm(nn.Module):
2244
+ """
2245
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
2246
+
2247
+ Args:
2248
+ f_channels (`int`):
2249
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
2250
+ zq_channels (`int`):
2251
+ The number of channels for the quantized vector as described in the paper.
2252
+ """
2253
+
2254
+ def __init__(
2255
+ self,
2256
+ f_channels: int,
2257
+ zq_channels: int,
2258
+ ):
2259
+ super().__init__()
2260
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
2261
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
2262
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
2263
+
2264
+ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
2265
+ f_size = f.shape[-2:]
2266
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
2267
+ norm_f = self.norm_layer(f)
2268
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
2269
+ return new_f
2270
+
2271
+
2272
+ ADDED_KV_ATTENTION_PROCESSORS = (
2273
+ AttnAddedKVProcessor,
2274
+ SlicedAttnAddedKVProcessor,
2275
+ AttnAddedKVProcessor2_0,
2276
+ XFormersAttnAddedKVProcessor,
2277
+ )
2278
+
2279
+ CROSS_ATTENTION_PROCESSORS = (
2280
+ AttnProcessor,
2281
+ AttnProcessor2_0,
2282
+ XFormersAttnProcessor,
2283
+ SlicedAttnProcessor,
2284
+ )
2285
+
2286
+ AttentionProcessor = Union[
2287
+ AttnProcessor,
2288
+ AttnProcessor2_0,
2289
+ FusedAttnProcessor2_0,
2290
+ XFormersAttnProcessor,
2291
+ SlicedAttnProcessor,
2292
+ AttnAddedKVProcessor,
2293
+ SlicedAttnAddedKVProcessor,
2294
+ AttnAddedKVProcessor2_0,
2295
+ XFormersAttnAddedKVProcessor,
2296
+ CustomDiffusionAttnProcessor,
2297
+ CustomDiffusionXFormersAttnProcessor,
2298
+ CustomDiffusionAttnProcessor2_0,
2299
+ ]
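The `AttentionProcessor` union is what allows a diffusers-style model to swap attention backends at runtime. A hedged sketch, using a deliberately tiny, made-up UNet configuration rather than any checkpoint from this repository:

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0

unet = UNet2DConditionModel(                                  # toy config, for illustration only
    sample_size=32,
    in_channels=4,
    out_channels=4,
    block_out_channels=(32, 64),
    layers_per_block=1,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)
unet.set_attn_processor(AttnProcessor2_0())                   # PyTorch 2.0 SDPA everywhere
print({type(p).__name__ for p in unet.attn_processors.values()})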
memo/models/audio_proj.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+ from diffusers import ConfigMixin, ModelMixin
3
+ from einops import rearrange
4
+ from torch import nn
5
+
6
+
7
+ class AudioProjModel(ModelMixin, ConfigMixin):
8
+ def __init__(
9
+ self,
10
+ seq_len=5,
11
+ blocks=12, # add a new parameter blocks
12
+ channels=768, # add a new parameter channels
13
+ intermediate_dim=512,
14
+ output_dim=768,
15
+ context_tokens=32,
16
+ ):
17
+ super().__init__()
18
+
19
+ self.seq_len = seq_len
20
+ self.blocks = blocks
21
+ self.channels = channels
22
+ self.input_dim = seq_len * blocks * channels # input_dim is the product of seq_len, blocks, and channels.
23
+ self.intermediate_dim = intermediate_dim
24
+ self.context_tokens = context_tokens
25
+ self.output_dim = output_dim
26
+
27
+ # define multiple linear layers
28
+ self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
29
+ self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
30
+ self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
31
+
32
+ self.norm = nn.LayerNorm(output_dim)
33
+
34
+ def forward(self, audio_embeds):
35
+ video_length = audio_embeds.shape[1]
36
+ audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
37
+ batch_size, window_size, blocks, channels = audio_embeds.shape
38
+ audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
39
+
40
+ audio_embeds = torch.relu(self.proj1(audio_embeds))
41
+ audio_embeds = torch.relu(self.proj2(audio_embeds))
42
+
43
+ context_tokens = self.proj3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim)
44
+
45
+ context_tokens = self.norm(context_tokens)
46
+ context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
47
+
48
+ return context_tokens
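`AudioProjModel` flattens each frame's `(window, blocks, channels)` audio features into one vector and projects it to `context_tokens` embeddings of size `output_dim`. A quick shape check, assuming the package is importable; the batch and frame counts are arbitrary:

import torch
from memo.models.audio_proj import AudioProjModel

model = AudioProjModel(seq_len=5, blocks=12, channels=768, context_tokens=32, output_dim=768)
audio = torch.randn(2, 16, 5, 12, 768)   # (batch, frames, window, blocks, channels) -- assumed batch/frames
tokens = model(audio)
print(tokens.shape)                      # torch.Size([2, 16, 32, 768])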
memo/models/emotion_classifier.py ADDED
@@ -0,0 +1,31 @@
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+ from diffusers import ConfigMixin, ModelMixin
5
+
6
+
7
+ class AudioEmotionClassifierModel(ModelMixin, ConfigMixin):
8
+ num_emotion_classes = 9
9
+
10
+ def __init__(self, num_classifier_layers=5, num_classifier_channels=2048):
11
+ super().__init__()
12
+
13
+ if num_classifier_layers == 1:
14
+ self.layers = torch.nn.Linear(1024, self.num_emotion_classes)
15
+ else:
16
+ layer_list = [
17
+ ("fc1", torch.nn.Linear(1024, num_classifier_channels)),
18
+ ("relu1", torch.nn.ReLU()),
19
+ ]
20
+ for n in range(num_classifier_layers - 2):
21
+ layer_list.append((f"fc{n+2}", torch.nn.Linear(num_classifier_channels, num_classifier_channels)))
22
+ layer_list.append((f"relu{n+2}", torch.nn.ReLU()))
23
+ layer_list.append(
24
+ (f"fc{num_classifier_layers}", torch.nn.Linear(num_classifier_channels, self.num_emotion_classes))
25
+ )
26
+ self.layers = torch.nn.Sequential(OrderedDict(layer_list))
27
+
28
+ def forward(self, x):
29
+ x = self.layers(x)
30
+ x = torch.softmax(x, dim=-1)
31
+ return x
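The classifier maps a 1024-dimensional audio feature to a softmax over the nine emotion classes through an MLP of configurable depth and width. A minimal sketch, assuming the package is importable and that the input is a pooled 1024-dim feature:

import torch
from memo.models.emotion_classifier import AudioEmotionClassifierModel

clf = AudioEmotionClassifierModel(num_classifier_layers=5, num_classifier_channels=2048)
feats = torch.randn(4, 1024)             # assumed pooled audio features
probs = clf(feats)                       # (4, 9), each row sums to 1
print(probs.shape, probs.sum(dim=-1))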
memo/models/image_proj.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch
2
+ from diffusers import ConfigMixin, ModelMixin
3
+
4
+
5
+ class ImageProjModel(ModelMixin, ConfigMixin):
6
+ def __init__(
7
+ self,
8
+ cross_attention_dim=768,
9
+ clip_embeddings_dim=512,
10
+ clip_extra_context_tokens=4,
11
+ ):
12
+ super().__init__()
13
+
14
+ self.generator = None
15
+ self.cross_attention_dim = cross_attention_dim
16
+ self.clip_extra_context_tokens = clip_extra_context_tokens
17
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
18
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
19
+
20
+ def forward(self, image_embeds):
21
+ embeds = image_embeds
22
+ clip_extra_context_tokens = self.proj(embeds).reshape(
23
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
24
+ )
25
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
26
+ return clip_extra_context_tokens
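`ImageProjModel` expands a single CLIP image embedding into `clip_extra_context_tokens` cross-attention tokens. A shape sketch with the constructor defaults; the batch size is an assumption:

import torch
from memo.models.image_proj import ImageProjModel

proj = ImageProjModel(cross_attention_dim=768, clip_embeddings_dim=512, clip_extra_context_tokens=4)
clip_embeds = torch.randn(2, 512)        # assumed CLIP image embeddings
tokens = proj(clip_embeds)
print(tokens.shape)                      # torch.Size([2, 4, 768])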
memo/models/motion_module.py ADDED
@@ -0,0 +1,386 @@
1
+ import math
2
+
3
+ import torch
4
+ import xformers
5
+ import xformers.ops
6
+ from diffusers.models.attention import FeedForward
7
+ from diffusers.models.attention_processor import Attention
8
+ from diffusers.utils.import_utils import is_xformers_available
9
+ from einops import rearrange, repeat
10
+ from torch import nn
11
+
12
+ from memo.models.attention import zero_module
13
+ from memo.models.attention_processor import (
14
+ MemoryLinearAttnProcessor,
15
+ )
16
+
17
+
18
+ class PositionalEncoding(nn.Module):
19
+ def __init__(self, d_model, dropout=0.0, max_len=24):
20
+ super().__init__()
21
+ self.dropout = nn.Dropout(p=dropout)
22
+ position = torch.arange(max_len).unsqueeze(1)
23
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
24
+ pe = torch.zeros(1, max_len, d_model)
25
+ pe[0, :, 0::2] = torch.sin(position * div_term)
26
+ pe[0, :, 1::2] = torch.cos(position * div_term)
27
+ self.register_buffer("pe", pe)
28
+
29
+ def forward(self, x, offset=0):
30
+ x = x + self.pe[:, offset : offset + x.size(1)]
31
+ return self.dropout(x)
32
+
33
+
34
+ class MemoryLinearAttnTemporalModule(nn.Module):
35
+ def __init__(
36
+ self,
37
+ in_channels,
38
+ num_attention_heads=8,
39
+ num_transformer_block=2,
40
+ attention_block_types=("Temporal_Self", "Temporal_Self"),
41
+ temporal_position_encoding=False,
42
+ temporal_position_encoding_max_len=24,
43
+ temporal_attention_dim_div=1,
44
+ zero_initialize=True,
45
+ ):
46
+ super().__init__()
47
+
48
+ self.temporal_transformer = TemporalLinearAttnTransformer(
49
+ in_channels=in_channels,
50
+ num_attention_heads=num_attention_heads,
51
+ attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
52
+ num_layers=num_transformer_block,
53
+ attention_block_types=attention_block_types,
54
+ temporal_position_encoding=temporal_position_encoding,
55
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
56
+ )
57
+
58
+ if zero_initialize:
59
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
60
+
61
+ def forward(
62
+ self,
63
+ hidden_states,
64
+ motion_frames,
65
+ encoder_hidden_states,
66
+ is_new_audio=True,
67
+ update_past_memory=False,
68
+ ):
69
+ hidden_states = self.temporal_transformer(
70
+ hidden_states,
71
+ motion_frames,
72
+ encoder_hidden_states,
73
+ is_new_audio=is_new_audio,
74
+ update_past_memory=update_past_memory,
75
+ )
76
+
77
+ output = hidden_states
78
+ return output
79
+
80
+
81
+ class TemporalLinearAttnTransformer(nn.Module):
82
+ def __init__(
83
+ self,
84
+ in_channels,
85
+ num_attention_heads,
86
+ attention_head_dim,
87
+ num_layers,
88
+ attention_block_types=(
89
+ "Temporal_Self",
90
+ "Temporal_Self",
91
+ ),
92
+ dropout=0.0,
93
+ norm_num_groups=32,
94
+ cross_attention_dim=768,
95
+ activation_fn="geglu",
96
+ attention_bias=False,
97
+ upcast_attention=False,
98
+ temporal_position_encoding=False,
99
+ temporal_position_encoding_max_len=24,
100
+ ):
101
+ super().__init__()
102
+
103
+ inner_dim = num_attention_heads * attention_head_dim
104
+
105
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
106
+ self.proj_in = nn.Linear(in_channels, inner_dim)
107
+
108
+ self.transformer_blocks = nn.ModuleList(
109
+ [
110
+ TemporalLinearAttnTransformerBlock(
111
+ dim=inner_dim,
112
+ num_attention_heads=num_attention_heads,
113
+ attention_head_dim=attention_head_dim,
114
+ attention_block_types=attention_block_types,
115
+ dropout=dropout,
116
+ cross_attention_dim=cross_attention_dim,
117
+ activation_fn=activation_fn,
118
+ attention_bias=attention_bias,
119
+ upcast_attention=upcast_attention,
120
+ temporal_position_encoding=temporal_position_encoding,
121
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
122
+ )
123
+ for _ in range(num_layers)
124
+ ]
125
+ )
126
+ self.proj_out = nn.Linear(inner_dim, in_channels)
127
+
128
+ def forward(
129
+ self,
130
+ hidden_states,
131
+ motion_frames,
132
+ encoder_hidden_states=None,
133
+ is_new_audio=True,
134
+ update_past_memory=False,
135
+ ):
136
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
137
+ video_length = hidden_states.shape[2]
138
+ n_motion_frames = motion_frames.shape[2]
139
+
140
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
141
+ with torch.no_grad():
142
+ motion_frames = rearrange(motion_frames, "b c f h w -> (b f) c h w")
143
+
144
+ batch, _, height, weight = hidden_states.shape
145
+ residual = hidden_states
146
+
147
+ hidden_states = self.norm(hidden_states)
148
+ with torch.no_grad():
149
+ motion_frames = self.norm(motion_frames)
150
+
151
+ inner_dim = hidden_states.shape[1]
152
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
153
+ hidden_states = self.proj_in(hidden_states)
154
+
155
+ with torch.no_grad():
156
+ (
157
+ motion_frames_batch,
158
+ motion_frames_inner_dim,
159
+ motion_frames_height,
160
+ motion_frames_weight,
161
+ ) = motion_frames.shape
162
+
163
+ motion_frames = motion_frames.permute(0, 2, 3, 1).reshape(
164
+ motion_frames_batch,
165
+ motion_frames_height * motion_frames_weight,
166
+ motion_frames_inner_dim,
167
+ )
168
+ motion_frames = self.proj_in(motion_frames)
169
+
170
+ # Transformer Blocks
171
+ for block in self.transformer_blocks:
172
+ hidden_states = block(
173
+ hidden_states,
174
+ motion_frames,
175
+ encoder_hidden_states=encoder_hidden_states,
176
+ video_length=video_length,
177
+ n_motion_frames=n_motion_frames,
178
+ is_new_audio=is_new_audio,
179
+ update_past_memory=update_past_memory,
180
+ )
181
+
182
+ # output
183
+ hidden_states = self.proj_out(hidden_states)
184
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
185
+
186
+ output = hidden_states + residual
187
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
188
+
189
+ return output
190
+
191
+
192
+ class TemporalLinearAttnTransformerBlock(nn.Module):
193
+ def __init__(
194
+ self,
195
+ dim,
196
+ num_attention_heads,
197
+ attention_head_dim,
198
+ attention_block_types=(
199
+ "Temporal_Self",
200
+ "Temporal_Self",
201
+ ),
202
+ dropout=0.0,
203
+ cross_attention_dim=768,
204
+ activation_fn="geglu",
205
+ attention_bias=False,
206
+ upcast_attention=False,
207
+ temporal_position_encoding=False,
208
+ temporal_position_encoding_max_len=24,
209
+ ):
210
+ super().__init__()
211
+
212
+ attention_blocks = []
213
+ norms = []
214
+
215
+ for block_name in attention_block_types:
216
+ attention_blocks.append(
217
+ MemoryLinearAttention(
218
+ attention_mode=block_name.split("_", maxsplit=1)[0],
219
+ cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
220
+ query_dim=dim,
221
+ heads=num_attention_heads,
222
+ dim_head=attention_head_dim,
223
+ dropout=dropout,
224
+ bias=attention_bias,
225
+ upcast_attention=upcast_attention,
226
+ temporal_position_encoding=temporal_position_encoding,
227
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
228
+ )
229
+ )
230
+ norms.append(nn.LayerNorm(dim))
231
+
232
+ self.attention_blocks = nn.ModuleList(attention_blocks)
233
+ self.norms = nn.ModuleList(norms)
234
+
235
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
236
+ self.ff_norm = nn.LayerNorm(dim)
237
+
238
+ def forward(
239
+ self,
240
+ hidden_states,
241
+ motion_frames,
242
+ encoder_hidden_states=None,
243
+ video_length=None,
244
+ n_motion_frames=None,
245
+ is_new_audio=True,
246
+ update_past_memory=False,
247
+ ):
248
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
249
+ norm_hidden_states = norm(hidden_states)
250
+ with torch.no_grad():
251
+ norm_motion_frames = norm(motion_frames)
252
+ hidden_states = (
253
+ attention_block(
254
+ norm_hidden_states,
255
+ norm_motion_frames,
256
+ encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
257
+ video_length=video_length,
258
+ n_motion_frames=n_motion_frames,
259
+ is_new_audio=is_new_audio,
260
+ update_past_memory=update_past_memory,
261
+ )
262
+ + hidden_states
263
+ )
264
+
265
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
266
+
267
+ output = hidden_states
268
+ return output
269
+
270
+
271
+ class MemoryLinearAttention(Attention):
272
+ def __init__(
273
+ self,
274
+ *args,
275
+ attention_mode=None,
276
+ temporal_position_encoding=False,
277
+ temporal_position_encoding_max_len=24,
278
+ **kwargs,
279
+ ):
280
+ super().__init__(*args, **kwargs)
281
+ assert attention_mode == "Temporal"
282
+
283
+ self.attention_mode = attention_mode
284
+ self.is_cross_attention = kwargs.get("cross_attention_dim") is not None
285
+ self.query_dim = kwargs["query_dim"]
286
+ self.temporal_position_encoding_max_len = temporal_position_encoding_max_len
287
+ self.pos_encoder = (
288
+ PositionalEncoding(
289
+ kwargs["query_dim"],
290
+ dropout=0.0,
291
+ max_len=temporal_position_encoding_max_len,
292
+ )
293
+ if (temporal_position_encoding and attention_mode == "Temporal")
294
+ else None
295
+ )
296
+
297
+ def extra_repr(self):
298
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
299
+
300
+ def set_use_memory_efficient_attention_xformers(
301
+ self,
302
+ use_memory_efficient_attention_xformers: bool,
303
+ attention_op=None,
304
+ ):
305
+ if use_memory_efficient_attention_xformers:
306
+ if not is_xformers_available():
307
+ raise ModuleNotFoundError(
308
+ (
309
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
310
+ " xformers"
311
+ ),
312
+ name="xformers",
313
+ )
314
+
315
+ if not torch.cuda.is_available():
316
+ raise ValueError(
317
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
318
+ " only available for GPU "
319
+ )
320
+
321
+ try:
322
+ # Make sure we can run the memory efficient attention
323
+ _ = xformers.ops.memory_efficient_attention(
324
+ torch.randn((1, 2, 40), device="cuda"),
325
+ torch.randn((1, 2, 40), device="cuda"),
326
+ torch.randn((1, 2, 40), device="cuda"),
327
+ )
328
+ except Exception as e:
329
+ raise e
330
+ processor = MemoryLinearAttnProcessor()
331
+ else:
332
+ processor = MemoryLinearAttnProcessor()
333
+
334
+ self.set_processor(processor)
335
+
336
+ def forward(
337
+ self,
338
+ hidden_states,
339
+ motion_frames,
340
+ encoder_hidden_states=None,
341
+ attention_mask=None,
342
+ video_length=None,
343
+ n_motion_frames=None,
344
+ is_new_audio=True,
345
+ update_past_memory=False,
346
+ **cross_attention_kwargs,
347
+ ):
348
+ if self.attention_mode == "Temporal":
349
+ d = hidden_states.shape[1]
350
+ hidden_states = rearrange(
351
+ hidden_states,
352
+ "(b f) d c -> (b d) f c",
353
+ f=video_length,
354
+ )
355
+
356
+ if self.pos_encoder is not None:
357
+ hidden_states = self.pos_encoder(hidden_states)
358
+
359
+ with torch.no_grad():
360
+ motion_frames = rearrange(motion_frames, "(b f) d c -> (b d) f c", f=n_motion_frames)
361
+
362
+ encoder_hidden_states = (
363
+ repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
364
+ if encoder_hidden_states is not None
365
+ else encoder_hidden_states
366
+ )
367
+
368
+ else:
369
+ raise NotImplementedError
370
+
371
+ hidden_states = self.processor(
372
+ self,
373
+ hidden_states,
374
+ motion_frames,
375
+ encoder_hidden_states=encoder_hidden_states,
376
+ attention_mask=attention_mask,
377
+ n_motion_frames=n_motion_frames,
378
+ is_new_audio=is_new_audio,
379
+ update_past_memory=update_past_memory,
380
+ **cross_attention_kwargs,
381
+ )
382
+
383
+ if self.attention_mode == "Temporal":
384
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
385
+
386
+ return hidden_states
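Temporal attention in this module runs over the frame axis after the `(b f) d c -> (b d) f c` rearrange, with `PositionalEncoding` marking each frame's position (the `offset` argument shifts positions for continuation segments). A small sketch of the encoding alone, assuming the package and its xformers dependency are importable; the sizes are illustrative:

import torch
from memo.models.motion_module import PositionalEncoding

pe = PositionalEncoding(d_model=320, max_len=24)
x = torch.randn(8, 16, 320)              # (batch * spatial positions, frames, channels)
print(pe(x).shape)                       # positions 0..15 added, shape unchanged
print(pe(x, offset=4).shape)             # positions 4..19, e.g. for a later chunk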
memo/models/normalization.py ADDED
@@ -0,0 +1,15 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+
6
+ class FP32LayerNorm(nn.LayerNorm):
7
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
8
+ origin_dtype = inputs.dtype
9
+ return F.layer_norm(
10
+ inputs.float(),
11
+ self.normalized_shape,
12
+ self.weight.float() if self.weight is not None else None,
13
+ self.bias.float() if self.bias is not None else None,
14
+ self.eps,
15
+ ).to(origin_dtype)
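`FP32LayerNorm` keeps half-precision activations numerically stable by normalizing in float32 and casting back to the input dtype. A quick sketch with an assumed feature size:

import torch
from memo.models.normalization import FP32LayerNorm

norm = FP32LayerNorm(320)
x = torch.randn(2, 77, 320, dtype=torch.float16)
y = norm(x)
print(y.dtype)                           # torch.float16: computed in fp32, returned in the input dtype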
memo/models/resnet.py ADDED
@@ -0,0 +1,218 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from einops import rearrange
4
+ from torch import nn
5
+
6
+
7
+ class InflatedConv3d(nn.Conv2d):
8
+ def forward(self, x):
9
+ video_length = x.shape[2]
10
+
11
+ x = rearrange(x, "b c f h w -> (b f) c h w")
12
+ x = super().forward(x)
13
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
14
+
15
+ return x
16
+
17
+
18
+ class InflatedGroupNorm(nn.GroupNorm):
19
+ def forward(self, x):
20
+ video_length = x.shape[2]
21
+
22
+ x = rearrange(x, "b c f h w -> (b f) c h w")
23
+ x = super().forward(x)
24
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
25
+
26
+ return x
27
+
28
+
29
+ class Upsample3D(nn.Module):
30
+ def __init__(
31
+ self,
32
+ channels,
33
+ use_conv=False,
34
+ use_conv_transpose=False,
35
+ out_channels=None,
36
+ name="conv",
37
+ ):
38
+ super().__init__()
39
+ self.channels = channels
40
+ self.out_channels = out_channels or channels
41
+ self.use_conv = use_conv
42
+ self.use_conv_transpose = use_conv_transpose
43
+ self.name = name
44
+
45
+ if use_conv_transpose:
46
+ raise NotImplementedError
47
+ if use_conv:
48
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
49
+
50
+ def forward(self, hidden_states, output_size=None):
51
+ assert hidden_states.shape[1] == self.channels
52
+
53
+ if self.use_conv_transpose:
54
+ raise NotImplementedError
55
+
56
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
57
+ dtype = hidden_states.dtype
58
+ if dtype == torch.bfloat16:
59
+ hidden_states = hidden_states.to(torch.float32)
60
+
61
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
62
+ if hidden_states.shape[0] >= 64:
63
+ hidden_states = hidden_states.contiguous()
64
+
65
+ # if `output_size` is passed we force the interpolation output
66
+ # size and do not make use of `scale_factor=2`
67
+ if output_size is None:
68
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
69
+ else:
70
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
71
+
72
+ # If the input is bfloat16, we cast back to bfloat16
73
+ if dtype == torch.bfloat16:
74
+ hidden_states = hidden_states.to(dtype)
75
+
76
+ hidden_states = self.conv(hidden_states)
77
+
78
+ return hidden_states
79
+
80
+
81
+ class Downsample3D(nn.Module):
82
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
83
+ super().__init__()
84
+ self.channels = channels
85
+ self.out_channels = out_channels or channels
86
+ self.use_conv = use_conv
87
+ self.padding = padding
88
+ stride = 2
89
+ self.name = name
90
+
91
+ if use_conv:
92
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
93
+ else:
94
+ raise NotImplementedError
95
+
96
+ def forward(self, hidden_states):
97
+ assert hidden_states.shape[1] == self.channels
98
+ if self.use_conv and self.padding == 0:
99
+ raise NotImplementedError
100
+
101
+ assert hidden_states.shape[1] == self.channels
102
+ hidden_states = self.conv(hidden_states)
103
+
104
+ return hidden_states
105
+
106
+
107
+ class ResnetBlock3D(nn.Module):
108
+ def __init__(
109
+ self,
110
+ *,
111
+ in_channels,
112
+ out_channels=None,
113
+ conv_shortcut=False,
114
+ dropout=0.0,
115
+ temb_channels=512,
116
+ groups=32,
117
+ groups_out=None,
118
+ pre_norm=True,
119
+ eps=1e-6,
120
+ non_linearity="swish",
121
+ time_embedding_norm="default",
122
+ output_scale_factor=1.0,
123
+ use_in_shortcut=None,
124
+ use_inflated_groupnorm=None,
125
+ ):
126
+ super().__init__()
127
+ self.pre_norm = pre_norm
128
+ self.pre_norm = True
129
+ self.in_channels = in_channels
130
+ out_channels = in_channels if out_channels is None else out_channels
131
+ self.out_channels = out_channels
132
+ self.use_conv_shortcut = conv_shortcut
133
+ self.time_embedding_norm = time_embedding_norm
134
+ self.output_scale_factor = output_scale_factor
135
+
136
+ if groups_out is None:
137
+ groups_out = groups
138
+
139
+ assert use_inflated_groupnorm is not None
140
+ if use_inflated_groupnorm:
141
+ self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
142
+ else:
143
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
144
+
145
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
146
+
147
+ if temb_channels is not None:
148
+ if self.time_embedding_norm == "default":
149
+ time_emb_proj_out_channels = out_channels
150
+ elif self.time_embedding_norm == "scale_shift":
151
+ time_emb_proj_out_channels = out_channels * 2
152
+ else:
153
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
154
+
155
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
156
+ else:
157
+ self.time_emb_proj = None
158
+
159
+ if use_inflated_groupnorm:
160
+ self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
161
+ else:
162
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
163
+ self.dropout = torch.nn.Dropout(dropout)
164
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
165
+
166
+ if non_linearity == "swish":
167
+ self.nonlinearity = nn.SiLU()  # swish is the same activation as SiLU
168
+ elif non_linearity == "mish":
169
+ self.nonlinearity = Mish()
170
+ elif non_linearity == "silu":
171
+ self.nonlinearity = nn.SiLU()
172
+
173
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
174
+
175
+ self.conv_shortcut = None
176
+ if self.use_in_shortcut:
177
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
178
+
179
+ def forward(self, input_tensor, temb):
180
+ hidden_states = input_tensor
181
+
182
+ hidden_states = self.norm1(hidden_states)
183
+ hidden_states = self.nonlinearity(hidden_states)
184
+
185
+ hidden_states = self.conv1(hidden_states)
186
+
187
+ if temb is not None:
188
+ if temb.dim() == 3:
189
+ temb = self.time_emb_proj(self.nonlinearity(temb))
190
+ temb = temb.transpose(1, 2).unsqueeze(-1).unsqueeze(-1)
191
+ else:
192
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
193
+
194
+ if temb is not None and self.time_embedding_norm == "default":
195
+ hidden_states = hidden_states + temb
196
+
197
+ hidden_states = self.norm2(hidden_states)
198
+
199
+ if temb is not None and self.time_embedding_norm == "scale_shift":
200
+ scale, shift = torch.chunk(temb, 2, dim=1)
201
+ hidden_states = hidden_states * (1 + scale) + shift
202
+
203
+ hidden_states = self.nonlinearity(hidden_states)
204
+
205
+ hidden_states = self.dropout(hidden_states)
206
+ hidden_states = self.conv2(hidden_states)
207
+
208
+ if self.conv_shortcut is not None:
209
+ input_tensor = self.conv_shortcut(input_tensor)
210
+
211
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
212
+
213
+ return output_tensor
214
+
215
+
216
+ class Mish(torch.nn.Module):
217
+ def forward(self, hidden_states):
218
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
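A shape-flow sketch for the inflated 2D modules above (assuming the memo package is importable; sizes are illustrative): each module folds the frame axis into the batch, applies its 2D parent op, and restores the `b c f h w` layout.

```python
import torch
from memo.models.resnet import InflatedConv3d, ResnetBlock3D, Upsample3D

x = torch.randn(1, 320, 8, 32, 32)   # (batch, channels, frames, height, width)
temb = torch.randn(1, 512)           # timestep embedding

conv = InflatedConv3d(320, 320, kernel_size=3, padding=1)
block = ResnetBlock3D(
    in_channels=320,
    out_channels=320,
    temb_channels=512,
    groups=32,
    non_linearity="silu",
    use_inflated_groupnorm=True,
)
up = Upsample3D(320, use_conv=True)

y = up(block(conv(x), temb))
print(y.shape)  # torch.Size([1, 320, 8, 64, 64]) -- spatial dims doubled, frame count unchanged
```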
memo/models/transformer_2d.py ADDED
@@ -0,0 +1,280 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, Optional
3
+
4
+ import torch
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
7
+ from diffusers.models.modeling_utils import ModelMixin
8
+ from diffusers.models.normalization import AdaLayerNormSingle
9
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
10
+ from torch import nn
11
+
12
+ from memo.models.attention import BasicTransformerBlock
13
+
14
+
15
+ @dataclass
16
+ class Transformer2DModelOutput(BaseOutput):
17
+ sample: torch.FloatTensor
18
+ ref_feature_list: list[torch.FloatTensor]
19
+
20
+
21
+ class Transformer2DModel(ModelMixin, ConfigMixin):
22
+ _supports_gradient_checkpointing = True
23
+
24
+ @register_to_config
25
+ def __init__(
26
+ self,
27
+ num_attention_heads: int = 16,
28
+ attention_head_dim: int = 88,
29
+ in_channels: Optional[int] = None,
30
+ out_channels: Optional[int] = None,
31
+ num_layers: int = 1,
32
+ dropout: float = 0.0,
33
+ norm_num_groups: int = 32,
34
+ cross_attention_dim: Optional[int] = None,
35
+ attention_bias: bool = False,
36
+ num_vector_embeds: Optional[int] = None,
37
+ patch_size: Optional[int] = None,
38
+ activation_fn: str = "geglu",
39
+ num_embeds_ada_norm: Optional[int] = None,
40
+ use_linear_projection: bool = False,
41
+ only_cross_attention: bool = False,
42
+ double_self_attention: bool = False,
43
+ upcast_attention: bool = False,
44
+ norm_type: str = "layer_norm",
45
+ norm_elementwise_affine: bool = True,
46
+ norm_eps: float = 1e-5,
47
+ attention_type: str = "default",
48
+ is_final_block: bool = False,
49
+ ):
50
+ super().__init__()
51
+ self.use_linear_projection = use_linear_projection
52
+ self.num_attention_heads = num_attention_heads
53
+ self.attention_head_dim = attention_head_dim
54
+ self.is_final_block = is_final_block
55
+ inner_dim = num_attention_heads * attention_head_dim
56
+
57
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
58
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
59
+
60
+ # 1. Transformer2DModel can process both standard continuous images of
61
+ # shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of
62
+ # shape `(batch_size, num_image_vectors)`
63
+ # Define whether input is continuous or discrete depending on configuration
64
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
65
+ self.is_input_vectorized = num_vector_embeds is not None
66
+ self.is_input_patches = in_channels is not None and patch_size is not None
67
+
68
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
69
+ deprecation_message = (
70
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
71
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
72
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
73
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
74
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
75
+ )
76
+ deprecate(
77
+ "norm_type!=num_embeds_ada_norm",
78
+ "1.0.0",
79
+ deprecation_message,
80
+ standard_warn=False,
81
+ )
82
+ norm_type = "ada_norm"
83
+
84
+ if self.is_input_continuous and self.is_input_vectorized:
85
+ raise ValueError(
86
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
87
+ " sure that either `in_channels` or `num_vector_embeds` is None."
88
+ )
89
+
90
+ if self.is_input_vectorized and self.is_input_patches:
91
+ raise ValueError(
92
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
93
+ " sure that either `num_vector_embeds` or `num_patches` is None."
94
+ )
95
+
96
+ if not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
97
+ raise ValueError(
98
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
99
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
100
+ )
101
+
102
+ # 2. Define input layers
103
+ self.in_channels = in_channels
104
+
105
+ self.norm = torch.nn.GroupNorm(
106
+ num_groups=norm_num_groups,
107
+ num_channels=in_channels,
108
+ eps=1e-6,
109
+ affine=True,
110
+ )
111
+ if use_linear_projection:
112
+ self.proj_in = linear_cls(in_channels, inner_dim)
113
+ else:
114
+ self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
115
+
116
+ # 3. Define transformers blocks
117
+ self.transformer_blocks = nn.ModuleList(
118
+ [
119
+ BasicTransformerBlock(
120
+ inner_dim,
121
+ num_attention_heads,
122
+ attention_head_dim,
123
+ dropout=dropout,
124
+ cross_attention_dim=cross_attention_dim,
125
+ activation_fn=activation_fn,
126
+ num_embeds_ada_norm=num_embeds_ada_norm,
127
+ attention_bias=attention_bias,
128
+ only_cross_attention=only_cross_attention,
129
+ double_self_attention=double_self_attention,
130
+ upcast_attention=upcast_attention,
131
+ norm_type=norm_type,
132
+ norm_elementwise_affine=norm_elementwise_affine,
133
+ norm_eps=norm_eps,
134
+ attention_type=attention_type,
135
+ is_final_block=(is_final_block and d == num_layers - 1),
136
+ )
137
+ for d in range(num_layers)
138
+ ]
139
+ )
140
+
141
+ # 4. Define output layers
142
+ self.out_channels = in_channels if out_channels is None else out_channels
143
+ # TODO: should use out_channels for continuous projections
144
+ if not is_final_block:
145
+ if use_linear_projection:
146
+ self.proj_out = linear_cls(inner_dim, in_channels)
147
+ else:
148
+ self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
149
+
150
+ # 5. PixArt-Alpha blocks.
151
+ self.adaln_single = None
152
+ self.use_additional_conditions = False
153
+ if norm_type == "ada_norm_single":
154
+ self.use_additional_conditions = self.config.sample_size == 128
155
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
156
+ # additional conditions until we find better name
157
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
158
+
159
+ self.caption_projection = None
160
+
161
+ self.gradient_checkpointing = False
162
+
163
+ def _set_gradient_checkpointing(self, module, value=False):
164
+ if hasattr(module, "gradient_checkpointing"):
165
+ module.gradient_checkpointing = value
166
+
167
+ def forward(
168
+ self,
169
+ hidden_states: torch.Tensor,
170
+ encoder_hidden_states: Optional[torch.Tensor] = None,
171
+ timestep: Optional[torch.LongTensor] = None,
172
+ class_labels: Optional[torch.LongTensor] = None,
173
+ cross_attention_kwargs: Dict[str, Any] = None,
174
+ attention_mask: Optional[torch.Tensor] = None,
175
+ encoder_attention_mask: Optional[torch.Tensor] = None,
176
+ return_dict: bool = True,
177
+ ):
178
+ if attention_mask is not None and attention_mask.ndim == 2:
179
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
180
+ attention_mask = attention_mask.unsqueeze(1)
181
+
182
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
183
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
184
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
185
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
186
+
187
+ # Retrieve lora scale.
188
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
189
+
190
+ # 1. Input
191
+ batch, _, height, width = hidden_states.shape
192
+ residual = hidden_states
193
+
194
+ hidden_states = self.norm(hidden_states)
195
+ if not self.use_linear_projection:
196
+ hidden_states = (
197
+ self.proj_in(hidden_states, scale=lora_scale) if not USE_PEFT_BACKEND else self.proj_in(hidden_states)
198
+ )
199
+ inner_dim = hidden_states.shape[1]
200
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
201
+ else:
202
+ inner_dim = hidden_states.shape[1]
203
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
204
+ hidden_states = (
205
+ self.proj_in(hidden_states, scale=lora_scale) if not USE_PEFT_BACKEND else self.proj_in(hidden_states)
206
+ )
207
+
208
+ # 2. Blocks
209
+ if self.caption_projection is not None:
210
+ batch_size = hidden_states.shape[0]
211
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
212
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
213
+
214
+ ref_feature_list = []
215
+ for block in self.transformer_blocks:
216
+ if self.training and self.gradient_checkpointing:
217
+
218
+ def create_custom_forward(module, return_dict=None):
219
+ def custom_forward(*inputs):
220
+ if return_dict is not None:
221
+ return module(*inputs, return_dict=return_dict)
222
+
223
+ return module(*inputs)
224
+
225
+ return custom_forward
226
+
227
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
228
+ hidden_states, ref_feature = torch.utils.checkpoint.checkpoint(
229
+ create_custom_forward(block),
230
+ hidden_states,
231
+ attention_mask,
232
+ encoder_hidden_states,
233
+ encoder_attention_mask,
234
+ timestep,
235
+ cross_attention_kwargs,
236
+ class_labels,
237
+ **ckpt_kwargs,
238
+ )
239
+ else:
240
+ hidden_states, ref_feature = block(
241
+ hidden_states, # shape [5, 4096, 320]
242
+ attention_mask=attention_mask,
243
+ encoder_hidden_states=encoder_hidden_states, # shape [1,4,768]
244
+ encoder_attention_mask=encoder_attention_mask,
245
+ timestep=timestep,
246
+ cross_attention_kwargs=cross_attention_kwargs,
247
+ class_labels=class_labels,
248
+ )
249
+ ref_feature_list.append(ref_feature)
250
+
251
+ # 3. Output
252
+ output = None
253
+
254
+ if self.is_final_block:
255
+ if not return_dict:
256
+ return (output, ref_feature_list)
257
+
258
+ return Transformer2DModelOutput(sample=output, ref_feature_list=ref_feature_list)
259
+
260
+ if self.is_input_continuous:
261
+ if not self.use_linear_projection:
262
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
263
+ hidden_states = (
264
+ self.proj_out(hidden_states, scale=lora_scale)
265
+ if not USE_PEFT_BACKEND
266
+ else self.proj_out(hidden_states)
267
+ )
268
+ else:
269
+ hidden_states = (
270
+ self.proj_out(hidden_states, scale=lora_scale)
271
+ if not USE_PEFT_BACKEND
272
+ else self.proj_out(hidden_states)
273
+ )
274
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
275
+
276
+ output = hidden_states + residual
277
+ if not return_dict:
278
+ return (output, ref_feature_list)
279
+
280
+ return Transformer2DModelOutput(sample=output, ref_feature_list=ref_feature_list)
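Both this file and the 3-D transformer below wrap each block in a small `create_custom_forward` closure before handing it to `torch.utils.checkpoint`, so intermediate activations are recomputed in the backward pass instead of being stored. A standalone sketch of that pattern with a toy module (names and sizes are illustrative):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


def create_custom_forward(module, return_dict=None):
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward


block = nn.Sequential(nn.Linear(320, 1280), nn.GELU(), nn.Linear(1280, 320))
hidden_states = torch.randn(2, 4096, 320, requires_grad=True)

# Activations inside `block` are recomputed during backward, trading compute for memory.
out = checkpoint(create_custom_forward(block), hidden_states, use_reentrant=False)
out.sum().backward()
print(hidden_states.grad.shape)  # torch.Size([2, 4096, 320])
```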
memo/models/transformer_3d.py ADDED
@@ -0,0 +1,252 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers.models import ModelMixin
7
+ from diffusers.utils import BaseOutput
8
+ from einops import rearrange, repeat
9
+ from torch import nn
10
+
11
+ from memo.models.attention import JointAudioTemporalBasicTransformerBlock, TemporalBasicTransformerBlock
12
+
13
+
14
+ def create_custom_forward(module, return_dict=None):
15
+ def custom_forward(*inputs):
16
+ if return_dict is not None:
17
+ return module(*inputs, return_dict=return_dict)
18
+
19
+ return module(*inputs)
20
+
21
+ return custom_forward
22
+
23
+
24
+ @dataclass
25
+ class Transformer3DModelOutput(BaseOutput):
26
+ sample: torch.FloatTensor
27
+
28
+
29
+ class Transformer3DModel(ModelMixin, ConfigMixin):
30
+ _supports_gradient_checkpointing = True
31
+
32
+ @register_to_config
33
+ def __init__(
34
+ self,
35
+ num_attention_heads: int = 16,
36
+ attention_head_dim: int = 88,
37
+ in_channels: Optional[int] = None,
38
+ num_layers: int = 1,
39
+ dropout: float = 0.0,
40
+ norm_num_groups: int = 32,
41
+ cross_attention_dim: Optional[int] = None,
42
+ attention_bias: bool = False,
43
+ activation_fn: str = "geglu",
44
+ use_linear_projection: bool = False,
45
+ only_cross_attention: bool = False,
46
+ upcast_attention: bool = False,
47
+ unet_use_cross_frame_attention=None,
48
+ unet_use_temporal_attention=None,
49
+ use_audio_module=False,
50
+ depth=0,
51
+ unet_block_name=None,
52
+ emo_drop_rate=0.3,
53
+ is_final_block=False,
54
+ ):
55
+ super().__init__()
56
+ self.use_linear_projection = use_linear_projection
57
+ self.num_attention_heads = num_attention_heads
58
+ self.attention_head_dim = attention_head_dim
59
+ inner_dim = num_attention_heads * attention_head_dim
60
+ self.use_audio_module = use_audio_module
61
+ # Define input layers
62
+ self.in_channels = in_channels
63
+ self.is_final_block = is_final_block
64
+
65
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
66
+ if use_linear_projection:
67
+ self.proj_in = nn.Linear(in_channels, inner_dim)
68
+ else:
69
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
70
+
71
+ if use_audio_module:
72
+ self.transformer_blocks = nn.ModuleList(
73
+ [
74
+ JointAudioTemporalBasicTransformerBlock(
75
+ dim=inner_dim,
76
+ num_attention_heads=num_attention_heads,
77
+ attention_head_dim=attention_head_dim,
78
+ dropout=dropout,
79
+ cross_attention_dim=cross_attention_dim,
80
+ activation_fn=activation_fn,
81
+ attention_bias=attention_bias,
82
+ only_cross_attention=only_cross_attention,
83
+ upcast_attention=upcast_attention,
84
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
85
+ unet_use_temporal_attention=unet_use_temporal_attention,
86
+ depth=depth,
87
+ unet_block_name=unet_block_name,
88
+ use_ada_layer_norm=True,
89
+ emo_drop_rate=emo_drop_rate,
90
+ is_final_block=(is_final_block and d == num_layers - 1),
91
+ )
92
+ for d in range(num_layers)
93
+ ]
94
+ )
95
+ else:
96
+ self.transformer_blocks = nn.ModuleList(
97
+ [
98
+ TemporalBasicTransformerBlock(
99
+ inner_dim,
100
+ num_attention_heads,
101
+ attention_head_dim,
102
+ dropout=dropout,
103
+ cross_attention_dim=cross_attention_dim,
104
+ activation_fn=activation_fn,
105
+ attention_bias=attention_bias,
106
+ only_cross_attention=only_cross_attention,
107
+ upcast_attention=upcast_attention,
108
+ )
109
+ for _ in range(num_layers)
110
+ ]
111
+ )
112
+
113
+ # 4. Define output layers
114
+ if use_linear_projection:
115
+ self.proj_out = nn.Linear(in_channels, inner_dim)
116
+ else:
117
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
118
+
119
+ self.gradient_checkpointing = False
120
+
121
+ def _set_gradient_checkpointing(self, module, value=False):
122
+ if hasattr(module, "gradient_checkpointing"):
123
+ module.gradient_checkpointing = value
124
+
125
+ def forward(
126
+ self,
127
+ hidden_states,
128
+ ref_img_feature=None,
129
+ encoder_hidden_states=None,
130
+ attention_mask=None,
131
+ timestep=None,
132
+ emotion=None,
133
+ uc_mask=None,
134
+ return_dict: bool = True,
135
+ ):
136
+ # Input
137
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
138
+ video_length = hidden_states.shape[2]
139
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
140
+
141
+ if self.use_audio_module:
142
+ if encoder_hidden_states.dim() == 4:
143
+ encoder_hidden_states = rearrange(
144
+ encoder_hidden_states,
145
+ "bs f margin dim -> (bs f) margin dim",
146
+ )
147
+ else:
148
+ if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
149
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=video_length)
150
+
151
+ batch, _, height, width = hidden_states.shape
152
+ residual = hidden_states
153
+ if self.use_audio_module:
154
+ residual_audio = encoder_hidden_states
155
+
156
+ hidden_states = self.norm(hidden_states)
157
+ if not self.use_linear_projection:
158
+ hidden_states = self.proj_in(hidden_states)
159
+ inner_dim = hidden_states.shape[1]
160
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
161
+ else:
162
+ inner_dim = hidden_states.shape[1]
163
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
164
+ hidden_states = self.proj_in(hidden_states)
165
+
166
+ # Blocks
167
+ for block in self.transformer_blocks:
168
+ if self.training and self.gradient_checkpointing:
169
+ if isinstance(block, TemporalBasicTransformerBlock):
170
+ hidden_states = torch.utils.checkpoint.checkpoint(
171
+ create_custom_forward(block),
172
+ hidden_states,
173
+ ref_img_feature,
174
+ None, # attention_mask
175
+ encoder_hidden_states,
176
+ timestep,
177
+ None, # cross_attention_kwargs
178
+ video_length,
179
+ uc_mask,
180
+ )
181
+ elif isinstance(block, JointAudioTemporalBasicTransformerBlock):
182
+ (
183
+ hidden_states,
184
+ encoder_hidden_states,
185
+ ) = torch.utils.checkpoint.checkpoint(
186
+ create_custom_forward(block),
187
+ hidden_states,
188
+ encoder_hidden_states,
189
+ attention_mask,
190
+ emotion,
191
+ )
192
+ else:
193
+ hidden_states = torch.utils.checkpoint.checkpoint(
194
+ create_custom_forward(block),
195
+ hidden_states,
196
+ encoder_hidden_states,
197
+ timestep,
198
+ attention_mask,
199
+ video_length,
200
+ )
201
+ else:
202
+ if isinstance(block, TemporalBasicTransformerBlock):
203
+ hidden_states = block(
204
+ hidden_states=hidden_states,
205
+ ref_img_feature=ref_img_feature,
206
+ encoder_hidden_states=encoder_hidden_states,
207
+ timestep=timestep,
208
+ video_length=video_length,
209
+ uc_mask=uc_mask,
210
+ )
211
+ elif isinstance(block, JointAudioTemporalBasicTransformerBlock):
212
+ hidden_states, encoder_hidden_states = block(
213
+ hidden_states, # shape [2, 4096, 320]
214
+ encoder_hidden_states=encoder_hidden_states, # shape [2, 20, 640]
215
+ attention_mask=attention_mask,
216
+ emotion=emotion,
217
+ )
218
+ else:
219
+ hidden_states = block(
220
+ hidden_states, # shape [2, 4096, 320]
221
+ encoder_hidden_states=encoder_hidden_states, # shape [2, 20, 640]
222
+ attention_mask=attention_mask,
223
+ timestep=timestep,
224
+ video_length=video_length,
225
+ )
226
+
227
+ # Output
228
+ if not self.use_linear_projection:
229
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
230
+ hidden_states = self.proj_out(hidden_states)
231
+ else:
232
+ hidden_states = self.proj_out(hidden_states)
233
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
234
+
235
+ output = hidden_states + residual
236
+
237
+ if self.use_audio_module and not self.is_final_block:
238
+ audio_output = encoder_hidden_states + residual_audio
239
+ else:
240
+ audio_output = None
241
+
242
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
243
+ if not return_dict:
244
+ if self.use_audio_module:
245
+ return output, audio_output
246
+ else:
247
+ return output
248
+
249
+ if self.use_audio_module:
250
+ return output, audio_output
251
+ else:
252
+ return output
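A shape sketch (illustrative sizes) of how this forward aligns video latents and conditioning per frame: the frame axis is folded into the batch before the blocks, per-frame audio tokens are flattened the same way, frame-independent conditioning is repeated per frame, and the output is unfolded back to `b c f h w`.

```python
import torch
from einops import rearrange, repeat

b, c, f, h, w = 1, 320, 16, 64, 64
hidden_states = torch.randn(b, c, f, h, w)
audio_features = torch.randn(b, f, 2, 768)   # "bs f margin dim": per-frame audio tokens
image_prompt = torch.randn(b, 4, 768)        # frame-independent conditioning

video_length = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")                # (16, 320, 64, 64)
audio_features = rearrange(audio_features, "bs f margin dim -> (bs f) margin dim")  # (16, 2, 768)
image_prompt = repeat(image_prompt, "b n c -> (b f) n c", f=video_length)           # (16, 4, 768)

# ... per-frame spatial / audio attention would run here ...

output = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length)
print(output.shape)  # torch.Size([1, 320, 16, 64, 64])
```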
memo/models/unet_2d_blocks.py ADDED
@@ -0,0 +1,935 @@
1
+ from typing import Any, Dict, Optional, Tuple, Union
2
+
3
+ import torch
4
+ from diffusers.models.activations import get_activation
5
+ from diffusers.models.attention_processor import Attention
6
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
7
+ from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel
8
+ from diffusers.utils import is_torch_version, logging
9
+ from diffusers.utils.torch_utils import apply_freeu
10
+ from torch import nn
11
+
12
+ from memo.models.transformer_2d import Transformer2DModel
13
+
14
+
15
+ logger = logging.get_logger(__name__)
16
+
17
+
18
+ def create_custom_forward(module, return_dict=None):
19
+ def custom_forward(*inputs):
20
+ if return_dict is not None:
21
+ return module(*inputs, return_dict=return_dict)
22
+
23
+ return module(*inputs)
24
+
25
+ return custom_forward
26
+
27
+
28
+ def get_down_block(
29
+ down_block_type: str,
30
+ num_layers: int,
31
+ in_channels: int,
32
+ out_channels: int,
33
+ temb_channels: int,
34
+ add_downsample: bool,
35
+ resnet_eps: float,
36
+ resnet_act_fn: str,
37
+ transformer_layers_per_block: int = 1,
38
+ num_attention_heads: Optional[int] = None,
39
+ resnet_groups: Optional[int] = None,
40
+ cross_attention_dim: Optional[int] = None,
41
+ downsample_padding: Optional[int] = None,
42
+ dual_cross_attention: bool = False,
43
+ use_linear_projection: bool = False,
44
+ only_cross_attention: bool = False,
45
+ upcast_attention: bool = False,
46
+ resnet_time_scale_shift: str = "default",
47
+ attention_type: str = "default",
48
+ attention_head_dim: Optional[int] = None,
49
+ dropout: float = 0.0,
50
+ ):
51
+ # If attn head dim is not defined, we default it to the number of heads
52
+ if attention_head_dim is None:
53
+ logger.warning("It is recommended to provide `attention_head_dim` when calling `get_down_block`.")
54
+ logger.warning(f"Defaulting `attention_head_dim` to {num_attention_heads}.")
55
+ attention_head_dim = num_attention_heads
56
+
57
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
58
+ if down_block_type == "DownBlock2D":
59
+ return DownBlock2D(
60
+ num_layers=num_layers,
61
+ in_channels=in_channels,
62
+ out_channels=out_channels,
63
+ temb_channels=temb_channels,
64
+ dropout=dropout,
65
+ add_downsample=add_downsample,
66
+ resnet_eps=resnet_eps,
67
+ resnet_act_fn=resnet_act_fn,
68
+ resnet_groups=resnet_groups,
69
+ downsample_padding=downsample_padding,
70
+ resnet_time_scale_shift=resnet_time_scale_shift,
71
+ )
72
+
73
+ if down_block_type == "CrossAttnDownBlock2D":
74
+ if cross_attention_dim is None:
75
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
76
+ return CrossAttnDownBlock2D(
77
+ num_layers=num_layers,
78
+ transformer_layers_per_block=transformer_layers_per_block,
79
+ in_channels=in_channels,
80
+ out_channels=out_channels,
81
+ temb_channels=temb_channels,
82
+ dropout=dropout,
83
+ add_downsample=add_downsample,
84
+ resnet_eps=resnet_eps,
85
+ resnet_act_fn=resnet_act_fn,
86
+ resnet_groups=resnet_groups,
87
+ downsample_padding=downsample_padding,
88
+ cross_attention_dim=cross_attention_dim,
89
+ num_attention_heads=num_attention_heads,
90
+ dual_cross_attention=dual_cross_attention,
91
+ use_linear_projection=use_linear_projection,
92
+ only_cross_attention=only_cross_attention,
93
+ upcast_attention=upcast_attention,
94
+ resnet_time_scale_shift=resnet_time_scale_shift,
95
+ attention_type=attention_type,
96
+ )
97
+ raise ValueError(f"{down_block_type} does not exist.")
98
+
99
+
100
+ def get_up_block(
101
+ up_block_type: str,
102
+ num_layers: int,
103
+ in_channels: int,
104
+ out_channels: int,
105
+ prev_output_channel: int,
106
+ temb_channels: int,
107
+ add_upsample: bool,
108
+ resnet_eps: float,
109
+ resnet_act_fn: str,
110
+ resolution_idx: Optional[int] = None,
111
+ transformer_layers_per_block: int = 1,
112
+ num_attention_heads: Optional[int] = None,
113
+ resnet_groups: Optional[int] = None,
114
+ cross_attention_dim: Optional[int] = None,
115
+ dual_cross_attention: bool = False,
116
+ use_linear_projection: bool = False,
117
+ only_cross_attention: bool = False,
118
+ upcast_attention: bool = False,
119
+ resnet_time_scale_shift: str = "default",
120
+ attention_type: str = "default",
121
+ attention_head_dim: Optional[int] = None,
122
+ dropout: float = 0.0,
123
+ is_final_block: bool = False,
124
+ ) -> nn.Module:
125
+ # If attn head dim is not defined, we default it to the number of heads
126
+ if attention_head_dim is None:
127
+ logger.warning("It is recommended to provide `attention_head_dim` when calling `get_up_block`.")
128
+ logger.warning(f"Defaulting `attention_head_dim` to {num_attention_heads}.")
129
+ attention_head_dim = num_attention_heads
130
+
131
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
132
+ if up_block_type == "UpBlock2D":
133
+ return UpBlock2D(
134
+ num_layers=num_layers,
135
+ in_channels=in_channels,
136
+ out_channels=out_channels,
137
+ prev_output_channel=prev_output_channel,
138
+ temb_channels=temb_channels,
139
+ resolution_idx=resolution_idx,
140
+ dropout=dropout,
141
+ add_upsample=add_upsample,
142
+ resnet_eps=resnet_eps,
143
+ resnet_act_fn=resnet_act_fn,
144
+ resnet_groups=resnet_groups,
145
+ resnet_time_scale_shift=resnet_time_scale_shift,
146
+ )
147
+ if up_block_type == "CrossAttnUpBlock2D":
148
+ if cross_attention_dim is None:
149
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
150
+ return CrossAttnUpBlock2D(
151
+ num_layers=num_layers,
152
+ transformer_layers_per_block=transformer_layers_per_block,
153
+ in_channels=in_channels,
154
+ out_channels=out_channels,
155
+ prev_output_channel=prev_output_channel,
156
+ temb_channels=temb_channels,
157
+ resolution_idx=resolution_idx,
158
+ dropout=dropout,
159
+ add_upsample=add_upsample,
160
+ resnet_eps=resnet_eps,
161
+ resnet_act_fn=resnet_act_fn,
162
+ resnet_groups=resnet_groups,
163
+ cross_attention_dim=cross_attention_dim,
164
+ num_attention_heads=num_attention_heads,
165
+ dual_cross_attention=dual_cross_attention,
166
+ use_linear_projection=use_linear_projection,
167
+ only_cross_attention=only_cross_attention,
168
+ upcast_attention=upcast_attention,
169
+ resnet_time_scale_shift=resnet_time_scale_shift,
170
+ attention_type=attention_type,
171
+ is_final_block=is_final_block,
172
+ )
173
+
174
+ raise ValueError(f"{up_block_type} does not exist.")
175
+
176
+
177
+ class AutoencoderTinyBlock(nn.Module):
178
+ def __init__(self, in_channels: int, out_channels: int, act_fn: str):
179
+ super().__init__()
180
+ act_fn = get_activation(act_fn)
181
+ self.conv = nn.Sequential(
182
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
183
+ act_fn,
184
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
185
+ act_fn,
186
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
187
+ )
188
+ self.skip = (
189
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
190
+ if in_channels != out_channels
191
+ else nn.Identity()
192
+ )
193
+ self.fuse = nn.ReLU()
194
+
195
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
196
+ return self.fuse(self.conv(x) + self.skip(x))
197
+
198
+
199
+ class UNetMidBlock2D(nn.Module):
200
+ def __init__(
201
+ self,
202
+ in_channels: int,
203
+ temb_channels: int,
204
+ dropout: float = 0.0,
205
+ num_layers: int = 1,
206
+ resnet_eps: float = 1e-6,
207
+ resnet_time_scale_shift: str = "default", # default, spatial
208
+ resnet_act_fn: str = "swish",
209
+ resnet_groups: int = 32,
210
+ attn_groups: Optional[int] = None,
211
+ resnet_pre_norm: bool = True,
212
+ add_attention: bool = True,
213
+ attention_head_dim: int = 1,
214
+ output_scale_factor: float = 1.0,
215
+ ):
216
+ super().__init__()
217
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
218
+ self.add_attention = add_attention
219
+
220
+ if attn_groups is None:
221
+ attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
222
+
223
+ # there is always at least one resnet
224
+ resnets = [
225
+ ResnetBlock2D(
226
+ in_channels=in_channels,
227
+ out_channels=in_channels,
228
+ temb_channels=temb_channels,
229
+ eps=resnet_eps,
230
+ groups=resnet_groups,
231
+ dropout=dropout,
232
+ time_embedding_norm=resnet_time_scale_shift,
233
+ non_linearity=resnet_act_fn,
234
+ output_scale_factor=output_scale_factor,
235
+ pre_norm=resnet_pre_norm,
236
+ )
237
+ ]
238
+ attentions = []
239
+
240
+ if attention_head_dim is None:
241
+ logger.warning(
242
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
243
+ )
244
+ attention_head_dim = in_channels
245
+
246
+ for _ in range(num_layers):
247
+ if self.add_attention:
248
+ attentions.append(
249
+ Attention(
250
+ in_channels,
251
+ heads=in_channels // attention_head_dim,
252
+ dim_head=attention_head_dim,
253
+ rescale_output_factor=output_scale_factor,
254
+ eps=resnet_eps,
255
+ norm_num_groups=attn_groups,
256
+ spatial_norm_dim=(temb_channels if resnet_time_scale_shift == "spatial" else None),
257
+ residual_connection=True,
258
+ bias=True,
259
+ upcast_softmax=True,
260
+ _from_deprecated_attn_block=True,
261
+ )
262
+ )
263
+ else:
264
+ attentions.append(None)
265
+
266
+ resnets.append(
267
+ ResnetBlock2D(
268
+ in_channels=in_channels,
269
+ out_channels=in_channels,
270
+ temb_channels=temb_channels,
271
+ eps=resnet_eps,
272
+ groups=resnet_groups,
273
+ dropout=dropout,
274
+ time_embedding_norm=resnet_time_scale_shift,
275
+ non_linearity=resnet_act_fn,
276
+ output_scale_factor=output_scale_factor,
277
+ pre_norm=resnet_pre_norm,
278
+ )
279
+ )
280
+
281
+ self.attentions = nn.ModuleList(attentions)
282
+ self.resnets = nn.ModuleList(resnets)
283
+
284
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
285
+ hidden_states = self.resnets[0](hidden_states, temb)
286
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
287
+ if attn is not None:
288
+ hidden_states = attn(hidden_states, temb=temb)
289
+ hidden_states = resnet(hidden_states, temb)
290
+
291
+ return hidden_states
292
+
293
+
294
+ class UNetMidBlock2DCrossAttn(nn.Module):
295
+ def __init__(
296
+ self,
297
+ in_channels: int,
298
+ temb_channels: int,
299
+ dropout: float = 0.0,
300
+ num_layers: int = 1,
301
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
302
+ resnet_eps: float = 1e-6,
303
+ resnet_time_scale_shift: str = "default",
304
+ resnet_act_fn: str = "swish",
305
+ resnet_groups: int = 32,
306
+ resnet_pre_norm: bool = True,
307
+ num_attention_heads: int = 1,
308
+ output_scale_factor: float = 1.0,
309
+ cross_attention_dim: int = 1280,
310
+ dual_cross_attention: bool = False,
311
+ use_linear_projection: bool = False,
312
+ upcast_attention: bool = False,
313
+ attention_type: str = "default",
314
+ ):
315
+ super().__init__()
316
+
317
+ self.has_cross_attention = True
318
+ self.num_attention_heads = num_attention_heads
319
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
320
+
321
+ # support for variable transformer layers per block
322
+ if isinstance(transformer_layers_per_block, int):
323
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
324
+
325
+ # there is always at least one resnet
326
+ resnets = [
327
+ ResnetBlock2D(
328
+ in_channels=in_channels,
329
+ out_channels=in_channels,
330
+ temb_channels=temb_channels,
331
+ eps=resnet_eps,
332
+ groups=resnet_groups,
333
+ dropout=dropout,
334
+ time_embedding_norm=resnet_time_scale_shift,
335
+ non_linearity=resnet_act_fn,
336
+ output_scale_factor=output_scale_factor,
337
+ pre_norm=resnet_pre_norm,
338
+ )
339
+ ]
340
+ attentions = []
341
+
342
+ for i in range(num_layers):
343
+ if not dual_cross_attention:
344
+ attentions.append(
345
+ Transformer2DModel(
346
+ num_attention_heads,
347
+ in_channels // num_attention_heads,
348
+ in_channels=in_channels,
349
+ num_layers=transformer_layers_per_block[i],
350
+ cross_attention_dim=cross_attention_dim,
351
+ norm_num_groups=resnet_groups,
352
+ use_linear_projection=use_linear_projection,
353
+ upcast_attention=upcast_attention,
354
+ attention_type=attention_type,
355
+ )
356
+ )
357
+ else:
358
+ attentions.append(
359
+ DualTransformer2DModel(
360
+ num_attention_heads,
361
+ in_channels // num_attention_heads,
362
+ in_channels=in_channels,
363
+ num_layers=1,
364
+ cross_attention_dim=cross_attention_dim,
365
+ norm_num_groups=resnet_groups,
366
+ )
367
+ )
368
+ resnets.append(
369
+ ResnetBlock2D(
370
+ in_channels=in_channels,
371
+ out_channels=in_channels,
372
+ temb_channels=temb_channels,
373
+ eps=resnet_eps,
374
+ groups=resnet_groups,
375
+ dropout=dropout,
376
+ time_embedding_norm=resnet_time_scale_shift,
377
+ non_linearity=resnet_act_fn,
378
+ output_scale_factor=output_scale_factor,
379
+ pre_norm=resnet_pre_norm,
380
+ )
381
+ )
382
+
383
+ self.attentions = nn.ModuleList(attentions)
384
+ self.resnets = nn.ModuleList(resnets)
385
+
386
+ self.gradient_checkpointing = False
387
+
388
+ def forward(
389
+ self,
390
+ hidden_states: torch.FloatTensor,
391
+ temb: Optional[torch.FloatTensor] = None,
392
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
393
+ attention_mask: Optional[torch.FloatTensor] = None,
394
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
395
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
396
+ ) -> torch.FloatTensor:
397
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
398
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
399
+ ref_feature_list = []
400
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
401
+ hidden_states, ref_feature = attn(
402
+ hidden_states,
403
+ encoder_hidden_states=encoder_hidden_states,
404
+ cross_attention_kwargs=cross_attention_kwargs,
405
+ attention_mask=attention_mask,
406
+ encoder_attention_mask=encoder_attention_mask,
407
+ return_dict=False,
408
+ )
409
+ if self.training and self.gradient_checkpointing:
410
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
411
+ hidden_states = torch.utils.checkpoint.checkpoint(
412
+ create_custom_forward(resnet),
413
+ hidden_states,
414
+ temb,
415
+ **ckpt_kwargs,
416
+ )
417
+ else:
418
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
419
+ ref_feature_list.append(ref_feature)
420
+
421
+ return hidden_states, ref_feature_list
422
+
423
+
424
+ class CrossAttnDownBlock2D(nn.Module):
425
+ def __init__(
426
+ self,
427
+ in_channels: int,
428
+ out_channels: int,
429
+ temb_channels: int,
430
+ dropout: float = 0.0,
431
+ num_layers: int = 1,
432
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
433
+ resnet_eps: float = 1e-6,
434
+ resnet_time_scale_shift: str = "default",
435
+ resnet_act_fn: str = "swish",
436
+ resnet_groups: int = 32,
437
+ resnet_pre_norm: bool = True,
438
+ num_attention_heads: int = 1,
439
+ cross_attention_dim: int = 1280,
440
+ output_scale_factor: float = 1.0,
441
+ downsample_padding: int = 1,
442
+ add_downsample: bool = True,
443
+ dual_cross_attention: bool = False,
444
+ use_linear_projection: bool = False,
445
+ only_cross_attention: bool = False,
446
+ upcast_attention: bool = False,
447
+ attention_type: str = "default",
448
+ ):
449
+ super().__init__()
450
+ resnets = []
451
+ attentions = []
452
+
453
+ self.has_cross_attention = True
454
+ self.num_attention_heads = num_attention_heads
455
+ if isinstance(transformer_layers_per_block, int):
456
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
457
+
458
+ for i in range(num_layers):
459
+ in_channels = in_channels if i == 0 else out_channels
460
+ resnets.append(
461
+ ResnetBlock2D(
462
+ in_channels=in_channels,
463
+ out_channels=out_channels,
464
+ temb_channels=temb_channels,
465
+ eps=resnet_eps,
466
+ groups=resnet_groups,
467
+ dropout=dropout,
468
+ time_embedding_norm=resnet_time_scale_shift,
469
+ non_linearity=resnet_act_fn,
470
+ output_scale_factor=output_scale_factor,
471
+ pre_norm=resnet_pre_norm,
472
+ )
473
+ )
474
+ if not dual_cross_attention:
475
+ attentions.append(
476
+ Transformer2DModel(
477
+ num_attention_heads,
478
+ out_channels // num_attention_heads,
479
+ in_channels=out_channels,
480
+ num_layers=transformer_layers_per_block[i],
481
+ cross_attention_dim=cross_attention_dim,
482
+ norm_num_groups=resnet_groups,
483
+ use_linear_projection=use_linear_projection,
484
+ only_cross_attention=only_cross_attention,
485
+ upcast_attention=upcast_attention,
486
+ attention_type=attention_type,
487
+ )
488
+ )
489
+ else:
490
+ attentions.append(
491
+ DualTransformer2DModel(
492
+ num_attention_heads,
493
+ out_channels // num_attention_heads,
494
+ in_channels=out_channels,
495
+ num_layers=1,
496
+ cross_attention_dim=cross_attention_dim,
497
+ norm_num_groups=resnet_groups,
498
+ )
499
+ )
500
+ self.attentions = nn.ModuleList(attentions)
501
+ self.resnets = nn.ModuleList(resnets)
502
+
503
+ if add_downsample:
504
+ self.downsamplers = nn.ModuleList(
505
+ [
506
+ Downsample2D(
507
+ out_channels,
508
+ use_conv=True,
509
+ out_channels=out_channels,
510
+ padding=downsample_padding,
511
+ name="op",
512
+ )
513
+ ]
514
+ )
515
+ else:
516
+ self.downsamplers = None
517
+
518
+ self.gradient_checkpointing = False
519
+
520
+ def forward(
521
+ self,
522
+ hidden_states: torch.FloatTensor,
523
+ temb: Optional[torch.FloatTensor] = None,
524
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
525
+ attention_mask: Optional[torch.FloatTensor] = None,
526
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
527
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
528
+ additional_residuals: Optional[torch.FloatTensor] = None,
529
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
530
+ output_states = ()
531
+
532
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
533
+
534
+ blocks = list(zip(self.resnets, self.attentions))
535
+ ref_feature_list = []
536
+ for i, (resnet, attn) in enumerate(blocks):
537
+ if self.training and self.gradient_checkpointing:
538
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
539
+ hidden_states = torch.utils.checkpoint.checkpoint(
540
+ create_custom_forward(resnet),
541
+ hidden_states,
542
+ temb,
543
+ **ckpt_kwargs,
544
+ )
545
+ else:
546
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
547
+ hidden_states, ref_feature = attn(
548
+ hidden_states,
549
+ encoder_hidden_states=encoder_hidden_states,
550
+ cross_attention_kwargs=cross_attention_kwargs,
551
+ attention_mask=attention_mask,
552
+ encoder_attention_mask=encoder_attention_mask,
553
+ return_dict=False,
554
+ )
555
+ ref_feature_list.append(ref_feature)
556
+
557
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
558
+ if i == len(blocks) - 1 and additional_residuals is not None:
559
+ hidden_states = hidden_states + additional_residuals
560
+
561
+ output_states = output_states + (hidden_states,)
562
+
563
+ if self.downsamplers is not None:
564
+ for downsampler in self.downsamplers:
565
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
566
+
567
+ output_states = output_states + (hidden_states,)
568
+
569
+ return hidden_states, output_states, ref_feature_list
570
+
571
+
572
+ class DownBlock2D(nn.Module):
573
+ def __init__(
574
+ self,
575
+ in_channels: int,
576
+ out_channels: int,
577
+ temb_channels: int,
578
+ dropout: float = 0.0,
579
+ num_layers: int = 1,
580
+ resnet_eps: float = 1e-6,
581
+ resnet_time_scale_shift: str = "default",
582
+ resnet_act_fn: str = "swish",
583
+ resnet_groups: int = 32,
584
+ resnet_pre_norm: bool = True,
585
+ output_scale_factor: float = 1.0,
586
+ add_downsample: bool = True,
587
+ downsample_padding: int = 1,
588
+ ):
589
+ super().__init__()
590
+ resnets = []
591
+
592
+ for i in range(num_layers):
593
+ in_channels = in_channels if i == 0 else out_channels
594
+ resnets.append(
595
+ ResnetBlock2D(
596
+ in_channels=in_channels,
597
+ out_channels=out_channels,
598
+ temb_channels=temb_channels,
599
+ eps=resnet_eps,
600
+ groups=resnet_groups,
601
+ dropout=dropout,
602
+ time_embedding_norm=resnet_time_scale_shift,
603
+ non_linearity=resnet_act_fn,
604
+ output_scale_factor=output_scale_factor,
605
+ pre_norm=resnet_pre_norm,
606
+ )
607
+ )
608
+
609
+ self.resnets = nn.ModuleList(resnets)
610
+
611
+ if add_downsample:
612
+ self.downsamplers = nn.ModuleList(
613
+ [
614
+ Downsample2D(
615
+ out_channels,
616
+ use_conv=True,
617
+ out_channels=out_channels,
618
+ padding=downsample_padding,
619
+ name="op",
620
+ )
621
+ ]
622
+ )
623
+ else:
624
+ self.downsamplers = None
625
+
626
+ self.gradient_checkpointing = False
627
+
628
+ def forward(
629
+ self,
630
+ hidden_states: torch.FloatTensor,
631
+ temb: Optional[torch.FloatTensor] = None,
632
+ scale: float = 1.0,
633
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
634
+ output_states = ()
635
+
636
+ ref_feature_list = []
637
+ for resnet in self.resnets:
638
+ if self.training and self.gradient_checkpointing:
639
+ if is_torch_version(">=", "1.11.0"):
640
+ hidden_states = torch.utils.checkpoint.checkpoint(
641
+ create_custom_forward(resnet),
642
+ hidden_states,
643
+ temb,
644
+ use_reentrant=False,
645
+ )
646
+ else:
647
+ hidden_states = torch.utils.checkpoint.checkpoint(
648
+ create_custom_forward(resnet), hidden_states, temb
649
+ )
650
+ else:
651
+ hidden_states = resnet(hidden_states, temb, scale=scale)
652
+
653
+ ref_feature_list.append(hidden_states)
654
+
655
+ output_states = output_states + (hidden_states,)
656
+
657
+ if self.downsamplers is not None:
658
+ for downsampler in self.downsamplers:
659
+ hidden_states = downsampler(hidden_states, scale=scale)
660
+
661
+ output_states = output_states + (hidden_states,)
662
+
663
+ return hidden_states, output_states, ref_feature_list
664
+
665
+
666
+ class CrossAttnUpBlock2D(nn.Module):
667
+ def __init__(
668
+ self,
669
+ in_channels: int,
670
+ out_channels: int,
671
+ prev_output_channel: int,
672
+ temb_channels: int,
673
+ resolution_idx: Optional[int] = None,
674
+ dropout: float = 0.0,
675
+ num_layers: int = 1,
676
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
677
+ resnet_eps: float = 1e-6,
678
+ resnet_time_scale_shift: str = "default",
679
+ resnet_act_fn: str = "swish",
680
+ resnet_groups: int = 32,
681
+ resnet_pre_norm: bool = True,
682
+ num_attention_heads: int = 1,
683
+ cross_attention_dim: int = 1280,
684
+ output_scale_factor: float = 1.0,
685
+ add_upsample: bool = True,
686
+ dual_cross_attention: bool = False,
687
+ use_linear_projection: bool = False,
688
+ only_cross_attention: bool = False,
689
+ upcast_attention: bool = False,
690
+ attention_type: str = "default",
691
+ is_final_block: bool = False,
692
+ ):
693
+ super().__init__()
694
+ resnets = []
695
+ attentions = []
696
+
697
+ self.has_cross_attention = True
698
+ self.num_attention_heads = num_attention_heads
699
+ self.is_final_block = is_final_block
700
+
701
+ if isinstance(transformer_layers_per_block, int):
702
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
703
+
704
+ for i in range(num_layers):
705
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
706
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
707
+
708
+ resnets.append(
709
+ ResnetBlock2D(
710
+ in_channels=resnet_in_channels + res_skip_channels,
711
+ out_channels=out_channels,
712
+ temb_channels=temb_channels,
713
+ eps=resnet_eps,
714
+ groups=resnet_groups,
715
+ dropout=dropout,
716
+ time_embedding_norm=resnet_time_scale_shift,
717
+ non_linearity=resnet_act_fn,
718
+ output_scale_factor=output_scale_factor,
719
+ pre_norm=resnet_pre_norm,
720
+ )
721
+ )
722
+ if not dual_cross_attention:
723
+ attentions.append(
724
+ Transformer2DModel(
725
+ num_attention_heads,
726
+ out_channels // num_attention_heads,
727
+ in_channels=out_channels,
728
+ num_layers=transformer_layers_per_block[i],
729
+ cross_attention_dim=cross_attention_dim,
730
+ norm_num_groups=resnet_groups,
731
+ use_linear_projection=use_linear_projection,
732
+ only_cross_attention=only_cross_attention,
733
+ upcast_attention=upcast_attention,
734
+ attention_type=attention_type,
735
+ is_final_block=(is_final_block and i == num_layers - 1),
736
+ )
737
+ )
738
+ else:
739
+ attentions.append(
740
+ DualTransformer2DModel(
741
+ num_attention_heads,
742
+ out_channels // num_attention_heads,
743
+ in_channels=out_channels,
744
+ num_layers=1,
745
+ cross_attention_dim=cross_attention_dim,
746
+ norm_num_groups=resnet_groups,
747
+ )
748
+ )
749
+ self.attentions = nn.ModuleList(attentions)
750
+ self.resnets = nn.ModuleList(resnets)
751
+
752
+ if add_upsample:
753
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
754
+ else:
755
+ self.upsamplers = None
756
+
757
+ self.gradient_checkpointing = False
758
+ self.resolution_idx = resolution_idx
759
+
760
+ def forward(
761
+ self,
762
+ hidden_states: torch.FloatTensor,
763
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
764
+ temb: Optional[torch.FloatTensor] = None,
765
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
766
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
767
+ upsample_size: Optional[int] = None,
768
+ attention_mask: Optional[torch.FloatTensor] = None,
769
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
770
+ ) -> torch.FloatTensor:
771
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
772
+ is_freeu_enabled = (
773
+ getattr(self, "s1", None)
774
+ and getattr(self, "s2", None)
775
+ and getattr(self, "b1", None)
776
+ and getattr(self, "b2", None)
777
+ )
778
+
779
+ ref_feature_list = []
780
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
781
+ # pop res hidden states
782
+ res_hidden_states = res_hidden_states_tuple[-1]
783
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
784
+
785
+ # FreeU: Only operate on the first two stages
786
+ if is_freeu_enabled:
787
+ hidden_states, res_hidden_states = apply_freeu(
788
+ self.resolution_idx,
789
+ hidden_states,
790
+ res_hidden_states,
791
+ s1=self.s1,
792
+ s2=self.s2,
793
+ b1=self.b1,
794
+ b2=self.b2,
795
+ )
796
+
797
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
798
+
799
+ if self.training and self.gradient_checkpointing:
800
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
801
+ hidden_states = torch.utils.checkpoint.checkpoint(
802
+ create_custom_forward(resnet),
803
+ hidden_states,
804
+ temb,
805
+ **ckpt_kwargs,
806
+ )
807
+ else:
808
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
809
+ hidden_states, ref_feature = attn(
810
+ hidden_states,
811
+ encoder_hidden_states=encoder_hidden_states,
812
+ cross_attention_kwargs=cross_attention_kwargs,
813
+ attention_mask=attention_mask,
814
+ encoder_attention_mask=encoder_attention_mask,
815
+ return_dict=False,
816
+ )
817
+ ref_feature_list.append(ref_feature)
818
+
819
+ if self.is_final_block:
820
+ assert hidden_states is None
821
+ else:
822
+ if self.upsamplers is not None:
823
+ for upsampler in self.upsamplers:
824
+ hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
825
+
826
+ return hidden_states, ref_feature_list
827
+
828
+
829
+ class UpBlock2D(nn.Module):
830
+ def __init__(
831
+ self,
832
+ in_channels: int,
833
+ prev_output_channel: int,
834
+ out_channels: int,
835
+ temb_channels: int,
836
+ resolution_idx: Optional[int] = None,
837
+ dropout: float = 0.0,
838
+ num_layers: int = 1,
839
+ resnet_eps: float = 1e-6,
840
+ resnet_time_scale_shift: str = "default",
841
+ resnet_act_fn: str = "swish",
842
+ resnet_groups: int = 32,
843
+ resnet_pre_norm: bool = True,
844
+ output_scale_factor: float = 1.0,
845
+ add_upsample: bool = True,
846
+ ):
847
+ super().__init__()
848
+ resnets = []
849
+
850
+ for i in range(num_layers):
851
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
852
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
853
+
854
+ resnets.append(
855
+ ResnetBlock2D(
856
+ in_channels=resnet_in_channels + res_skip_channels,
857
+ out_channels=out_channels,
858
+ temb_channels=temb_channels,
859
+ eps=resnet_eps,
860
+ groups=resnet_groups,
861
+ dropout=dropout,
862
+ time_embedding_norm=resnet_time_scale_shift,
863
+ non_linearity=resnet_act_fn,
864
+ output_scale_factor=output_scale_factor,
865
+ pre_norm=resnet_pre_norm,
866
+ )
867
+ )
868
+
869
+ self.resnets = nn.ModuleList(resnets)
870
+
871
+ if add_upsample:
872
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
873
+ else:
874
+ self.upsamplers = None
875
+
876
+ self.gradient_checkpointing = False
877
+ self.resolution_idx = resolution_idx
878
+
879
+ def forward(
880
+ self,
881
+ hidden_states: torch.FloatTensor,
882
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
883
+ temb: Optional[torch.FloatTensor] = None,
884
+ upsample_size: Optional[int] = None,
885
+ scale: float = 1.0,
886
+ ) -> Tuple[torch.FloatTensor, list]:
887
+ is_freeu_enabled = (
888
+ getattr(self, "s1", None)
889
+ and getattr(self, "s2", None)
890
+ and getattr(self, "b1", None)
891
+ and getattr(self, "b2", None)
892
+ )
893
+
894
+ ref_feature_list = []
895
+ for resnet in self.resnets:
896
+ # pop res hidden states
897
+ res_hidden_states = res_hidden_states_tuple[-1]
898
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
899
+
900
+ # FreeU: Only operate on the first two stages
901
+ if is_freeu_enabled:
902
+ hidden_states, res_hidden_states = apply_freeu(
903
+ self.resolution_idx,
904
+ hidden_states,
905
+ res_hidden_states,
906
+ s1=self.s1,
907
+ s2=self.s2,
908
+ b1=self.b1,
909
+ b2=self.b2,
910
+ )
911
+
912
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
913
+
914
+ if self.training and self.gradient_checkpointing:
915
+ if is_torch_version(">=", "1.11.0"):
916
+ hidden_states = torch.utils.checkpoint.checkpoint(
917
+ create_custom_forward(resnet),
918
+ hidden_states,
919
+ temb,
920
+ use_reentrant=False,
921
+ )
922
+ else:
923
+ hidden_states = torch.utils.checkpoint.checkpoint(
924
+ create_custom_forward(resnet), hidden_states, temb
925
+ )
926
+ else:
927
+ hidden_states = resnet(hidden_states, temb, scale=scale)
928
+
929
+ ref_feature_list.append(hidden_states)
930
+
931
+ if self.upsamplers is not None:
932
+ for upsampler in self.upsamplers:
933
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
934
+
935
+ return hidden_states, ref_feature_list
memo/models/unet_2d_condition.py ADDED
@@ -0,0 +1,1185 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.utils.checkpoint
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.loaders import UNet2DConditionLoadersMixin
8
+ from diffusers.models.activations import get_activation
9
+ from diffusers.models.attention_processor import (
10
+ ADDED_KV_ATTENTION_PROCESSORS,
11
+ CROSS_ATTENTION_PROCESSORS,
12
+ AttentionProcessor,
13
+ AttnAddedKVProcessor,
14
+ AttnProcessor,
15
+ )
16
+ from diffusers.models.embeddings import (
17
+ GaussianFourierProjection,
18
+ GLIGENTextBoundingboxProjection,
19
+ ImageHintTimeEmbedding,
20
+ ImageProjection,
21
+ ImageTimeEmbedding,
22
+ TextImageProjection,
23
+ TextImageTimeEmbedding,
24
+ TextTimeEmbedding,
25
+ TimestepEmbedding,
26
+ Timesteps,
27
+ )
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.utils import (
30
+ USE_PEFT_BACKEND,
31
+ BaseOutput,
32
+ deprecate,
33
+ logging,
34
+ scale_lora_layers,
35
+ unscale_lora_layers,
36
+ )
37
+ from torch import nn
38
+
39
+ from memo.models.unet_2d_blocks import (
40
+ UNetMidBlock2D,
41
+ UNetMidBlock2DCrossAttn,
42
+ get_down_block,
43
+ get_up_block,
44
+ )
45
+
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+
50
+ @dataclass
51
+ class UNet2DConditionOutput(BaseOutput):
52
+ """
53
+ The output of [`UNet2DConditionModel`].
54
+
55
+ Args:
56
+ ref_features (`list` of `torch.FloatTensor`):
57
+ The reference features collected from the UNet blocks during the forward pass.
58
+ """
59
+
60
+ ref_features: list[torch.FloatTensor] = None
61
+
62
+
63
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
64
+ r"""
65
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
66
+ shaped output.
67
+
68
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
69
+ for all models (such as downloading or saving).
70
+
71
+ Parameters:
72
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
73
+ Height and width of input/output sample.
74
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
75
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
76
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
77
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
78
+ Whether to flip the sin to cos in the time embedding.
79
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
80
+ down_block_types (`Tuple[str]`, *optional*, defaults to
81
+ `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
82
+ The tuple of downsample blocks to use.
83
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
84
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
85
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
86
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
87
+ The tuple of upsample blocks to use.
88
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
89
+ Whether to include self-attention in the basic transformer blocks, see
90
+ [`~models.attention.BasicTransformerBlock`].
91
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
92
+ The tuple of output channels for each block.
93
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
94
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
95
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
96
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
97
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
98
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
99
+ If `None`, normalization and activation layers are skipped in post-processing.
100
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
101
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
102
+ The dimension of the cross attention features.
103
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
104
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
105
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
106
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
107
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
108
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
109
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
110
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
111
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
112
+ encoder_hid_dim (`int`, *optional*, defaults to None):
113
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
114
+ dimension to `cross_attention_dim`.
115
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
116
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
117
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
118
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
119
+ num_attention_heads (`int`, *optional*):
120
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
121
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
122
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
123
+ class_embed_type (`str`, *optional*, defaults to `None`):
124
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
125
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
126
+ addition_embed_type (`str`, *optional*, defaults to `None`):
127
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
128
+ "text". "text" will use the `TextTimeEmbedding` layer.
129
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
130
+ Dimension for the timestep embeddings.
131
+ num_class_embeds (`int`, *optional*, defaults to `None`):
132
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
133
+ class conditioning with `class_embed_type` equal to `None`.
134
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
135
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
136
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
137
+ An optional override for the dimension of the projected time embedding.
138
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
139
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
140
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
141
+ timestep_post_act (`str`, *optional*, defaults to `None`):
142
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
143
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
144
+ The dimension of `cond_proj` layer in the timestep embedding.
145
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
147
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
148
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input
149
+ when `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
149
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
150
+ embeddings with the class embeddings.
151
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
152
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
153
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
154
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
155
+ otherwise.
156
+ """
157
+
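+ # Illustrative usage sketch (the tensor shapes and the SD1.x-style `cross_attention_dim=768` are
+ # assumptions for the example, not values mandated by this file):
+ #     unet = UNet2DConditionModel(block_out_channels=(320, 640, 1280, 1280), cross_attention_dim=768)
+ #     noisy_latents = torch.randn(1, 4, 64, 64)
+ #     text_embeds = torch.randn(1, 77, 768)
+ #     out = unet(noisy_latents, timestep=0, encoder_hidden_states=text_embeds)
+ # `out` is expected to be a `UNet2DConditionOutput` carrying the per-block reference features
+ # described above.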
158
+ _supports_gradient_checkpointing = True
159
+
160
+ @register_to_config
161
+ def __init__(
162
+ self,
163
+ sample_size: Optional[int] = None,
164
+ in_channels: int = 4,
165
+ _out_channels: int = 4,
166
+ _center_input_sample: bool = False,
167
+ flip_sin_to_cos: bool = True,
168
+ freq_shift: int = 0,
169
+ down_block_types: Tuple[str] = (
170
+ "CrossAttnDownBlock2D",
171
+ "CrossAttnDownBlock2D",
172
+ "CrossAttnDownBlock2D",
173
+ "DownBlock2D",
174
+ ),
175
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
176
+ up_block_types: Tuple[str] = (
177
+ "UpBlock2D",
178
+ "CrossAttnUpBlock2D",
179
+ "CrossAttnUpBlock2D",
180
+ "CrossAttnUpBlock2D",
181
+ ),
182
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
183
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
184
+ layers_per_block: Union[int, Tuple[int]] = 2,
185
+ downsample_padding: int = 1,
186
+ mid_block_scale_factor: float = 1,
187
+ dropout: float = 0.0,
188
+ act_fn: str = "silu",
189
+ norm_num_groups: Optional[int] = 32,
190
+ norm_eps: float = 1e-5,
191
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
192
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
193
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
194
+ encoder_hid_dim: Optional[int] = None,
195
+ encoder_hid_dim_type: Optional[str] = None,
196
+ attention_head_dim: Union[int, Tuple[int]] = 8,
197
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
198
+ dual_cross_attention: bool = False,
199
+ use_linear_projection: bool = False,
200
+ class_embed_type: Optional[str] = None,
201
+ addition_embed_type: Optional[str] = None,
202
+ addition_time_embed_dim: Optional[int] = None,
203
+ num_class_embeds: Optional[int] = None,
204
+ upcast_attention: bool = False,
205
+ resnet_time_scale_shift: str = "default",
206
+ time_embedding_type: str = "positional",
207
+ time_embedding_dim: Optional[int] = None,
208
+ time_embedding_act_fn: Optional[str] = None,
209
+ timestep_post_act: Optional[str] = None,
210
+ time_cond_proj_dim: Optional[int] = None,
211
+ conv_in_kernel: int = 3,
212
+ projection_class_embeddings_input_dim: Optional[int] = None,
213
+ attention_type: str = "default",
214
+ class_embeddings_concat: bool = False,
215
+ mid_block_only_cross_attention: Optional[bool] = None,
216
+ addition_embed_type_num_heads=64,
217
+ ):
218
+ super().__init__()
219
+
220
+ self.sample_size = sample_size
221
+
222
+ if num_attention_heads is not None:
223
+ raise ValueError(
224
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads`"
225
+ "because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131."
226
+ "Passing `num_attention_heads` will only be supported in diffusers v0.19."
227
+ )
228
+
229
+ # If `num_attention_heads` is not defined (which is the case for most models)
230
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
231
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
232
+ # when this library was created. The incorrect naming was only discovered much later in
233
+ # https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
234
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
235
+ # which is why we correct for the naming here.
236
+ num_attention_heads = num_attention_heads or attention_head_dim
237
+
238
+ # Check inputs
239
+ if len(down_block_types) != len(up_block_types):
240
+ raise ValueError(
241
+ "Must provide the same number of `down_block_types` as `up_block_types`."
242
+ f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
243
+ )
244
+
245
+ if len(block_out_channels) != len(down_block_types):
246
+ raise ValueError(
247
+ "Must provide the same number of `block_out_channels` as `down_block_types`."
248
+ f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
249
+ )
250
+
251
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
252
+ raise ValueError(
253
+ "Must provide the same number of `only_cross_attention` as `down_block_types`."
254
+ f"`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
255
+ )
256
+
257
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
258
+ raise ValueError(
259
+ "Must provide the same number of `num_attention_heads` as `down_block_types`."
260
+ f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
261
+ )
262
+
263
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
264
+ raise ValueError(
265
+ "Must provide the same number of `attention_head_dim` as `down_block_types`."
266
+ f"`attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
267
+ )
268
+
269
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
270
+ raise ValueError(
271
+ "Must provide the same number of `cross_attention_dim` as `down_block_types`."
272
+ f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
273
+ )
274
+
275
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
276
+ raise ValueError(
277
+ "Must provide the same number of `layers_per_block` as `down_block_types`."
278
+ f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
279
+ )
280
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
281
+ for layer_number_per_block in transformer_layers_per_block:
282
+ if isinstance(layer_number_per_block, list):
283
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
284
+
285
+ # input
286
+ conv_in_padding = (conv_in_kernel - 1) // 2
287
+ self.conv_in = nn.Conv2d(
288
+ in_channels,
289
+ block_out_channels[0],
290
+ kernel_size=conv_in_kernel,
291
+ padding=conv_in_padding,
292
+ )
293
+
294
+ # time
295
+ if time_embedding_type == "fourier":
296
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
297
+ if time_embed_dim % 2 != 0:
298
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
299
+ self.time_proj = GaussianFourierProjection(
300
+ time_embed_dim // 2,
301
+ set_W_to_weight=False,
302
+ log=False,
303
+ flip_sin_to_cos=flip_sin_to_cos,
304
+ )
305
+ timestep_input_dim = time_embed_dim
306
+ elif time_embedding_type == "positional":
307
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
308
+
309
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
310
+ timestep_input_dim = block_out_channels[0]
311
+ else:
312
+ raise ValueError(
313
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
314
+ )
315
+
316
+ self.time_embedding = TimestepEmbedding(
317
+ timestep_input_dim,
318
+ time_embed_dim,
319
+ act_fn=act_fn,
320
+ post_act_fn=timestep_post_act,
321
+ cond_proj_dim=time_cond_proj_dim,
322
+ )
323
+
324
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
325
+ encoder_hid_dim_type = "text_proj"
326
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
327
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
328
+
329
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
330
+ raise ValueError(
331
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
332
+ )
333
+
334
+ if encoder_hid_dim_type == "text_proj":
335
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
336
+ elif encoder_hid_dim_type == "text_image_proj":
337
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
338
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
339
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
340
+ self.encoder_hid_proj = TextImageProjection(
341
+ text_embed_dim=encoder_hid_dim,
342
+ image_embed_dim=cross_attention_dim,
343
+ cross_attention_dim=cross_attention_dim,
344
+ )
345
+ elif encoder_hid_dim_type == "image_proj":
346
+ # Kandinsky 2.2
347
+ self.encoder_hid_proj = ImageProjection(
348
+ image_embed_dim=encoder_hid_dim,
349
+ cross_attention_dim=cross_attention_dim,
350
+ )
351
+ elif encoder_hid_dim_type is not None:
352
+ raise ValueError(
353
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
354
+ )
355
+ else:
356
+ self.encoder_hid_proj = None
357
+
358
+ # class embedding
359
+ if class_embed_type is None and num_class_embeds is not None:
360
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
361
+ elif class_embed_type == "timestep":
362
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
363
+ elif class_embed_type == "identity":
364
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
365
+ elif class_embed_type == "projection":
366
+ if projection_class_embeddings_input_dim is None:
367
+ raise ValueError(
368
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
369
+ )
370
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
371
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
372
+ # 2. it projects from an arbitrary input dimension.
373
+ #
374
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
375
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
376
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
377
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
378
+ elif class_embed_type == "simple_projection":
379
+ if projection_class_embeddings_input_dim is None:
380
+ raise ValueError(
381
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
382
+ )
383
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
384
+ else:
385
+ self.class_embedding = None
386
+
387
+ if addition_embed_type == "text":
388
+ if encoder_hid_dim is not None:
389
+ text_time_embedding_from_dim = encoder_hid_dim
390
+ else:
391
+ text_time_embedding_from_dim = cross_attention_dim
392
+
393
+ self.add_embedding = TextTimeEmbedding(
394
+ text_time_embedding_from_dim,
395
+ time_embed_dim,
396
+ num_heads=addition_embed_type_num_heads,
397
+ )
398
+ elif addition_embed_type == "text_image":
399
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
400
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
401
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
402
+ self.add_embedding = TextImageTimeEmbedding(
403
+ text_embed_dim=cross_attention_dim,
404
+ image_embed_dim=cross_attention_dim,
405
+ time_embed_dim=time_embed_dim,
406
+ )
407
+ elif addition_embed_type == "text_time":
408
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
409
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
410
+ elif addition_embed_type == "image":
411
+ # Kandinsky 2.2
412
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
413
+ elif addition_embed_type == "image_hint":
414
+ # Kandinsky 2.2 ControlNet
415
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
416
+ elif addition_embed_type is not None:
417
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
418
+
419
+ if time_embedding_act_fn is None:
420
+ self.time_embed_act = None
421
+ else:
422
+ self.time_embed_act = get_activation(time_embedding_act_fn)
423
+
424
+ self.down_blocks = nn.ModuleList([])
425
+ self.up_blocks = nn.ModuleList([])
426
+
427
+ if isinstance(only_cross_attention, bool):
428
+ if mid_block_only_cross_attention is None:
429
+ mid_block_only_cross_attention = only_cross_attention
430
+
431
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
432
+
433
+ if mid_block_only_cross_attention is None:
434
+ mid_block_only_cross_attention = False
435
+
436
+ if isinstance(num_attention_heads, int):
437
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
438
+
439
+ if isinstance(attention_head_dim, int):
440
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
441
+
442
+ if isinstance(cross_attention_dim, int):
443
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
444
+
445
+ if isinstance(layers_per_block, int):
446
+ layers_per_block = [layers_per_block] * len(down_block_types)
447
+
448
+ if isinstance(transformer_layers_per_block, int):
449
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
450
+
451
+ if class_embeddings_concat:
452
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
453
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
454
+ # regular time embeddings
455
+ blocks_time_embed_dim = time_embed_dim * 2
456
+ else:
457
+ blocks_time_embed_dim = time_embed_dim
458
+
459
+ # down
460
+ output_channel = block_out_channels[0]
461
+ for i, down_block_type in enumerate(down_block_types):
462
+ input_channel = output_channel
463
+ output_channel = block_out_channels[i]
464
+ is_final_block = i == len(block_out_channels) - 1
465
+
466
+ down_block = get_down_block(
467
+ down_block_type,
468
+ num_layers=layers_per_block[i],
469
+ transformer_layers_per_block=transformer_layers_per_block[i],
470
+ in_channels=input_channel,
471
+ out_channels=output_channel,
472
+ temb_channels=blocks_time_embed_dim,
473
+ add_downsample=not is_final_block,
474
+ resnet_eps=norm_eps,
475
+ resnet_act_fn=act_fn,
476
+ resnet_groups=norm_num_groups,
477
+ cross_attention_dim=cross_attention_dim[i],
478
+ num_attention_heads=num_attention_heads[i],
479
+ downsample_padding=downsample_padding,
480
+ dual_cross_attention=dual_cross_attention,
481
+ use_linear_projection=use_linear_projection,
482
+ only_cross_attention=only_cross_attention[i],
483
+ upcast_attention=upcast_attention,
484
+ resnet_time_scale_shift=resnet_time_scale_shift,
485
+ attention_type=attention_type,
486
+ attention_head_dim=(attention_head_dim[i] if attention_head_dim[i] is not None else output_channel),
487
+ dropout=dropout,
488
+ )
489
+ self.down_blocks.append(down_block)
490
+
491
+ # mid
492
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
493
+ self.mid_block = UNetMidBlock2DCrossAttn(
494
+ transformer_layers_per_block=transformer_layers_per_block[-1],
495
+ in_channels=block_out_channels[-1],
496
+ temb_channels=blocks_time_embed_dim,
497
+ dropout=dropout,
498
+ resnet_eps=norm_eps,
499
+ resnet_act_fn=act_fn,
500
+ output_scale_factor=mid_block_scale_factor,
501
+ resnet_time_scale_shift=resnet_time_scale_shift,
502
+ cross_attention_dim=cross_attention_dim[-1],
503
+ num_attention_heads=num_attention_heads[-1],
504
+ resnet_groups=norm_num_groups,
505
+ dual_cross_attention=dual_cross_attention,
506
+ use_linear_projection=use_linear_projection,
507
+ upcast_attention=upcast_attention,
508
+ attention_type=attention_type,
509
+ )
510
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
511
+ raise NotImplementedError(f"Unsupported mid_block_type: {mid_block_type}")
512
+ elif mid_block_type == "UNetMidBlock2D":
513
+ self.mid_block = UNetMidBlock2D(
514
+ in_channels=block_out_channels[-1],
515
+ temb_channels=blocks_time_embed_dim,
516
+ dropout=dropout,
517
+ num_layers=0,
518
+ resnet_eps=norm_eps,
519
+ resnet_act_fn=act_fn,
520
+ output_scale_factor=mid_block_scale_factor,
521
+ resnet_groups=norm_num_groups,
522
+ resnet_time_scale_shift=resnet_time_scale_shift,
523
+ add_attention=False,
524
+ )
525
+ elif mid_block_type is None:
526
+ self.mid_block = None
527
+ else:
528
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
529
+
530
+ # count how many layers upsample the images
531
+ self.num_upsamplers = 0
532
+
533
+ # up
534
+ reversed_block_out_channels = list(reversed(block_out_channels))
535
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
536
+ reversed_layers_per_block = list(reversed(layers_per_block))
537
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
538
+ reversed_transformer_layers_per_block = (
539
+ list(reversed(transformer_layers_per_block))
540
+ if reverse_transformer_layers_per_block is None
541
+ else reverse_transformer_layers_per_block
542
+ )
543
+ only_cross_attention = list(reversed(only_cross_attention))
544
+
545
+ output_channel = reversed_block_out_channels[0]
546
+ for i, up_block_type in enumerate(up_block_types):
547
+ is_final_block = i == len(block_out_channels) - 1
548
+
549
+ prev_output_channel = output_channel
550
+ output_channel = reversed_block_out_channels[i]
551
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
552
+
553
+ # add upsample block for all BUT final layer
554
+ if not is_final_block:
555
+ add_upsample = True
556
+ self.num_upsamplers += 1
557
+ else:
558
+ add_upsample = False
559
+
560
+ up_block = get_up_block(
561
+ up_block_type,
562
+ num_layers=reversed_layers_per_block[i] + 1,
563
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
564
+ in_channels=input_channel,
565
+ out_channels=output_channel,
566
+ prev_output_channel=prev_output_channel,
567
+ temb_channels=blocks_time_embed_dim,
568
+ add_upsample=add_upsample,
569
+ resnet_eps=norm_eps,
570
+ resnet_act_fn=act_fn,
571
+ resolution_idx=i,
572
+ resnet_groups=norm_num_groups,
573
+ cross_attention_dim=reversed_cross_attention_dim[i],
574
+ num_attention_heads=reversed_num_attention_heads[i],
575
+ dual_cross_attention=dual_cross_attention,
576
+ use_linear_projection=use_linear_projection,
577
+ only_cross_attention=only_cross_attention[i],
578
+ upcast_attention=upcast_attention,
579
+ resnet_time_scale_shift=resnet_time_scale_shift,
580
+ attention_type=attention_type,
581
+ attention_head_dim=(attention_head_dim[i] if attention_head_dim[i] is not None else output_channel),
582
+ dropout=dropout,
583
+ is_final_block=is_final_block,
584
+ )
585
+ self.up_blocks.append(up_block)
586
+ prev_output_channel = output_channel
587
+
588
+ # out
589
+ if norm_num_groups is not None:
590
+ self.conv_norm_out = nn.GroupNorm(
591
+ num_channels=block_out_channels[0],
592
+ num_groups=norm_num_groups,
593
+ eps=norm_eps,
594
+ )
595
+
596
+ self.conv_act = get_activation(act_fn)
597
+
598
+ else:
599
+ self.conv_norm_out = None
600
+ self.conv_act = None
602
+
603
+ if attention_type in ["gated", "gated-text-image"]:
604
+ positive_len = 768
605
+ if isinstance(cross_attention_dim, int):
606
+ positive_len = cross_attention_dim
607
+ elif isinstance(cross_attention_dim, (tuple, list)):
608
+ positive_len = cross_attention_dim[0]
609
+
610
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
611
+ self.position_net = GLIGENTextBoundingboxProjection(
612
+ positive_len=positive_len,
613
+ out_dim=cross_attention_dim,
614
+ feature_type=feature_type,
615
+ )
616
+
617
+ @property
618
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
619
+ r"""
620
+ Returns:
621
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
622
+ indexed by their weight names.
623
+ """
624
+ # set recursively
625
+ processors = {}
626
+
627
+ def fn_recursive_add_processors(
628
+ name: str,
629
+ module: torch.nn.Module,
630
+ processors: Dict[str, AttentionProcessor],
631
+ ):
632
+ if hasattr(module, "get_processor"):
633
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
634
+
635
+ for sub_name, child in module.named_children():
636
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
637
+
638
+ return processors
639
+
640
+ for name, module in self.named_children():
641
+ fn_recursive_add_processors(name, module, processors)
642
+
643
+ return processors
644
+
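+ # Illustrative note (assuming the default block layout): the dictionary returned above uses
+ # fully qualified module paths as keys, e.g.
+ #     "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"
+ # and these keys are what `set_attn_processor` expects when given a dict.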
645
+ def set_attn_processor(
646
+ self,
647
+ processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
648
+ _remove_lora=False,
649
+ ):
650
+ r"""
651
+ Sets the attention processor to use to compute attention.
652
+
653
+ Parameters:
654
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
655
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
656
+ for **all** `Attention` layers.
657
+
658
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
659
+ processor. This is strongly recommended when setting trainable attention processors.
660
+
661
+ """
662
+ count = len(self.attn_processors.keys())
663
+
664
+ if isinstance(processor, dict) and len(processor) != count:
665
+ raise ValueError(
666
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
667
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
668
+ )
669
+
670
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
671
+ if hasattr(module, "set_processor"):
672
+ if not isinstance(processor, dict):
673
+ module.set_processor(processor, _remove_lora=_remove_lora)
674
+ else:
675
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
676
+
677
+ for sub_name, child in module.named_children():
678
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
679
+
680
+ for name, module in self.named_children():
681
+ fn_recursive_attn_processor(name, module, processor)
682
+
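+ # Illustrative usage sketch (assumes an already constructed instance `unet`):
+ #     unet.set_attn_processor(AttnProcessor())  # one processor shared by all attention layers
+ #     unet.set_attn_processor({name: AttnProcessor() for name in unet.attn_processors})  # per-layer dict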
683
+ def set_default_attn_processor(self):
684
+ """
685
+ Disables custom attention processors and sets the default attention implementation.
686
+ """
687
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
688
+ processor = AttnAddedKVProcessor()
689
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
690
+ processor = AttnProcessor()
691
+ else:
692
+ raise ValueError(
693
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
694
+ )
695
+
696
+ self.set_attn_processor(processor, _remove_lora=True)
697
+
698
+ def set_attention_slice(self, slice_size):
699
+ r"""
700
+ Enable sliced attention computation.
701
+
702
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
703
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
704
+
705
+ Args:
706
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
707
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
708
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
709
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
710
+ must be a multiple of `slice_size`.
711
+ """
712
+ sliceable_head_dims = []
713
+
714
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
715
+ if hasattr(module, "set_attention_slice"):
716
+ sliceable_head_dims.append(module.sliceable_head_dim)
717
+
718
+ for child in module.children():
719
+ fn_recursive_retrieve_sliceable_dims(child)
720
+
721
+ # retrieve number of attention layers
722
+ for module in self.children():
723
+ fn_recursive_retrieve_sliceable_dims(module)
724
+
725
+ num_sliceable_layers = len(sliceable_head_dims)
726
+
727
+ if slice_size == "auto":
728
+ # half the attention head size is usually a good trade-off between
729
+ # speed and memory
730
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
731
+ elif slice_size == "max":
732
+ # make smallest slice possible
733
+ slice_size = num_sliceable_layers * [1]
734
+
735
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
736
+
737
+ if len(slice_size) != len(sliceable_head_dims):
738
+ raise ValueError(
739
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
740
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
741
+ )
742
+
743
+ for i, size in enumerate(slice_size):
744
+ dim = sliceable_head_dims[i]
745
+ if size is not None and size > dim:
746
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
747
+
748
+ # Recursively walk through all the children.
749
+ # Any children which exposes the set_attention_slice method
750
+ # gets the message
751
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
752
+ if hasattr(module, "set_attention_slice"):
753
+ module.set_attention_slice(slice_size.pop())
754
+
755
+ for child in module.children():
756
+ fn_recursive_set_attention_slice(child, slice_size)
757
+
758
+ reversed_slice_size = list(reversed(slice_size))
759
+ for module in self.children():
760
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
761
+
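+ # Illustrative usage sketch (assumes an instance `unet`): trade a little speed for memory by
+ # slicing attention, either halving each head dimension or forcing one-element slices:
+ #     unet.set_attention_slice("auto")
+ #     unet.set_attention_slice("max")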
762
+ def _set_gradient_checkpointing(self, module, value=False):
763
+ if hasattr(module, "gradient_checkpointing"):
764
+ module.gradient_checkpointing = value
765
+
766
+ def enable_freeu(self, s1, s2, b1, b2):
767
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
768
+
769
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
770
+
771
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
772
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
773
+
774
+ Args:
775
+ s1 (`float`):
776
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
777
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
778
+ s2 (`float`):
779
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
780
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
781
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
782
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
783
+ """
784
+ for _, upsample_block in enumerate(self.up_blocks):
785
+ setattr(upsample_block, "s1", s1)
786
+ setattr(upsample_block, "s2", s2)
787
+ setattr(upsample_block, "b1", b1)
788
+ setattr(upsample_block, "b2", b2)
789
+
790
+ def disable_freeu(self):
791
+ """Disables the FreeU mechanism."""
792
+ freeu_keys = {"s1", "s2", "b1", "b2"}
793
+ for _, upsample_block in enumerate(self.up_blocks):
794
+ for k in freeu_keys:
795
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
796
+ setattr(upsample_block, k, None)
797
+
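+ # Illustrative usage sketch (assumes an instance `unet`; the constants follow the values commonly
+ # suggested for Stable Diffusion 1.x in the FreeU repository and may need tuning here):
+ #     unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
+ #     ...  # run the denoising loop
+ #     unet.disable_freeu()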
798
+ def forward(
799
+ self,
800
+ sample: torch.FloatTensor,
801
+ timestep: Union[torch.Tensor, float, int],
802
+ encoder_hidden_states: torch.Tensor,
803
+ cond_tensor: torch.FloatTensor = None,
804
+ class_labels: Optional[torch.Tensor] = None,
805
+ timestep_cond: Optional[torch.Tensor] = None,
806
+ attention_mask: Optional[torch.Tensor] = None,
807
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
808
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
809
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
810
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
811
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
812
+ encoder_attention_mask: Optional[torch.Tensor] = None,
813
+ return_dict: bool = True,
814
+ post_process: bool = False,
815
+ ) -> Union[UNet2DConditionOutput, Tuple]:
816
+ r"""
817
+ The [`UNet2DConditionModel`] forward method.
818
+
819
+ Args:
820
+ sample (`torch.FloatTensor`):
821
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
822
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
823
+ encoder_hidden_states (`torch.FloatTensor`):
824
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
825
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
826
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
827
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
828
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
829
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
830
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
831
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
832
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
833
+ negative values to the attention scores corresponding to "discard" tokens.
834
+ cross_attention_kwargs (`dict`, *optional*):
835
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
836
+ `self.processor` in
837
+ [diffusers.models.attention_processor]
838
+ (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
839
+ added_cond_kwargs: (`dict`, *optional*):
840
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
841
+ are passed along to the UNet blocks.
842
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
843
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
844
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
845
+ A tensor that if specified is added to the residual of the middle unet block.
846
+ encoder_attention_mask (`torch.Tensor`):
847
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
848
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
849
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
850
+ return_dict (`bool`, *optional*, defaults to `True`):
851
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
852
+ tuple.
853
+ cross_attention_kwargs (`dict`, *optional*):
854
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
855
+ added_cond_kwargs: (`dict`, *optional*):
856
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
857
+ are passed along to the UNet blocks.
858
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
859
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
860
+ example from ControlNet side model(s)
861
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
862
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
863
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
864
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
865
+
866
+ Returns:
867
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
868
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
869
+ a `tuple` is returned where the first element is the sample tensor.
870
+ """
871
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
872
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
873
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
874
+ # on the fly if necessary.
875
+ default_overall_up_factor = 2**self.num_upsamplers
876
+
877
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
878
+ forward_upsample_size = False
879
+ upsample_size = None
880
+
881
+ for dim in sample.shape[-2:]:
882
+ if dim % default_overall_up_factor != 0:
883
+ # Forward upsample size to force interpolation output size.
884
+ forward_upsample_size = True
885
+ break
886
+
887
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
888
+ # expects mask of shape:
889
+ # [batch, key_tokens]
890
+ # adds singleton query_tokens dimension:
891
+ # [batch, 1, key_tokens]
892
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
893
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
894
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
895
+ if attention_mask is not None:
896
+ # assume that mask is expressed as:
897
+ # (1 = keep, 0 = discard)
898
+ # convert mask into a bias that can be added to attention scores:
899
+ # (keep = +0, discard = -10000.0)
900
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
901
+ attention_mask = attention_mask.unsqueeze(1)
902
+
903
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
904
+ if encoder_attention_mask is not None:
905
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
906
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
907
+
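+ # Worked example of the bias conversion above (illustrative values): a keep/discard mask
+ #     [[1, 1, 0]]                  # shape [batch, key_tokens]
+ # becomes
+ #     [[[0.0, 0.0, -10000.0]]]     # shape [batch, 1, key_tokens], added to the attention scores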
908
+ # 0. center input if necessary
909
+ if self.config.center_input_sample:
910
+ sample = 2 * sample - 1.0
911
+
912
+ # 1. time
913
+ timesteps = timestep
914
+ if not torch.is_tensor(timesteps):
915
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
916
+ # This would be a good case for the `match` statement (Python 3.10+)
917
+ is_mps = sample.device.type == "mps"
918
+ if isinstance(timestep, float):
919
+ dtype = torch.float32 if is_mps else torch.float64
920
+ else:
921
+ dtype = torch.int32 if is_mps else torch.int64
922
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
923
+ elif len(timesteps.shape) == 0:
924
+ timesteps = timesteps[None].to(sample.device)
925
+
926
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
927
+ timesteps = timesteps.expand(sample.shape[0])
928
+
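+ # Illustrative note: passing `timestep` already as a tensor on the right device, e.g.
+ #     timestep = torch.tensor([999], device=sample.device)
+ # skips the host-side construction above and avoids the CPU/GPU sync mentioned in the TODO.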
929
+ t_emb = self.time_proj(timesteps)
930
+
931
+ # `Timesteps` does not contain any weights and will always return f32 tensors
932
+ # but time_embedding might actually be running in fp16. so we need to cast here.
933
+ # there might be better ways to encapsulate this.
934
+ t_emb = t_emb.to(dtype=sample.dtype)
935
+
936
+ emb = self.time_embedding(t_emb, timestep_cond)
937
+ aug_emb = None
938
+
939
+ if self.class_embedding is not None:
940
+ if class_labels is None:
941
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
942
+
943
+ if self.config.class_embed_type == "timestep":
944
+ class_labels = self.time_proj(class_labels)
945
+
946
+ # `Timesteps` does not contain any weights and will always return f32 tensors
947
+ # there might be better ways to encapsulate this.
948
+ class_labels = class_labels.to(dtype=sample.dtype)
949
+
950
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
951
+
952
+ if self.config.class_embeddings_concat:
953
+ emb = torch.cat([emb, class_emb], dim=-1)
954
+ else:
955
+ emb = emb + class_emb
956
+
957
+ if self.config.addition_embed_type == "text":
958
+ aug_emb = self.add_embedding(encoder_hidden_states)
959
+ elif self.config.addition_embed_type == "text_image":
960
+ # Kandinsky 2.1 - style
961
+ if "image_embeds" not in added_cond_kwargs:
962
+ raise ValueError(
963
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image'"
964
+ "which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
965
+ )
966
+
967
+ image_embs = added_cond_kwargs.get("image_embeds")
968
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
969
+ aug_emb = self.add_embedding(text_embs, image_embs)
970
+ elif self.config.addition_embed_type == "text_time":
971
+ # SDXL - style
972
+ if "text_embeds" not in added_cond_kwargs:
973
+ raise ValueError(
974
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time'"
975
+ "which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
976
+ )
977
+ text_embeds = added_cond_kwargs.get("text_embeds")
978
+ if "time_ids" not in added_cond_kwargs:
979
+ raise ValueError(
980
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time'"
981
+ "which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
982
+ )
983
+ time_ids = added_cond_kwargs.get("time_ids")
984
+ time_embeds = self.add_time_proj(time_ids.flatten())
985
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
986
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
987
+ add_embeds = add_embeds.to(emb.dtype)
988
+ aug_emb = self.add_embedding(add_embeds)
989
+ elif self.config.addition_embed_type == "image":
990
+ # Kandinsky 2.2 - style
991
+ if "image_embeds" not in added_cond_kwargs:
992
+ raise ValueError(
993
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image'"
994
+ "which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
995
+ )
996
+ image_embs = added_cond_kwargs.get("image_embeds")
997
+ aug_emb = self.add_embedding(image_embs)
998
+ elif self.config.addition_embed_type == "image_hint":
999
+ # Kandinsky 2.2 - style
1000
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1001
+ raise ValueError(
1002
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint'"
1003
+ "which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1004
+ )
1005
+ image_embs = added_cond_kwargs.get("image_embeds")
1006
+ hint = added_cond_kwargs.get("hint")
1007
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1008
+ sample = torch.cat([sample, hint], dim=1)
1009
+
1010
+ emb = emb + aug_emb if aug_emb is not None else emb
1011
+
1012
+ if self.time_embed_act is not None:
1013
+ emb = self.time_embed_act(emb)
1014
+
1015
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1016
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1017
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1018
+ # Kadinsky 2.1 - style
1019
+ if "image_embeds" not in added_cond_kwargs:
1020
+ raise ValueError(
1021
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj'"
1022
+ "which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1023
+ )
1024
+
1025
+ image_embeds = added_cond_kwargs.get("image_embeds")
1026
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1027
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1028
+ # Kandinsky 2.2 - style
1029
+ if "image_embeds" not in added_cond_kwargs:
1030
+ raise ValueError(
1031
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj'"
1032
+ "which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1033
+ )
1034
+ image_embeds = added_cond_kwargs.get("image_embeds")
1035
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1036
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1037
+ if "image_embeds" not in added_cond_kwargs:
1038
+ raise ValueError(
1039
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj'"
1040
+ "which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1041
+ )
1042
+ image_embeds = added_cond_kwargs.get("image_embeds")
1043
+ image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
1044
+ encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
1045
+
1046
+ # 2. pre-process
1047
+ sample = self.conv_in(sample)
1048
+ if cond_tensor is not None:
1049
+ sample = sample + cond_tensor
1050
+
1051
+ # 2.5 GLIGEN position net
1052
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1053
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1054
+ gligen_args = cross_attention_kwargs.pop("gligen")
1055
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1056
+
1057
+ # 3. down
1058
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1059
+ if USE_PEFT_BACKEND:
1060
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1061
+ scale_lora_layers(self, lora_scale)
1062
+
1063
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1064
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1065
+ is_adapter = down_intrablock_additional_residuals is not None
1066
+ # maintain backward compatibility for legacy usage, where
1067
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1068
+ # but can only use one or the other
1069
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1070
+ deprecate(
1071
+ "T2I should not use down_block_additional_residuals",
1072
+ "1.3.0",
1073
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1074
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1075
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1076
+ standard_warn=False,
1077
+ )
1078
+ down_intrablock_additional_residuals = down_block_additional_residuals
1079
+ is_adapter = True
1080
+
1081
+ ref_features = {"down": [], "mid": [], "up": []}
1082
+ down_block_res_samples = (sample,)
1083
+ for downsample_block in self.down_blocks:
1084
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1085
+ # For t2i-adapter CrossAttnDownBlock2D
1086
+ additional_residuals = {}
1087
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1088
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1089
+
1090
+ sample, res_samples, ref_feature_list = downsample_block(
1091
+ hidden_states=sample,
1092
+ temb=emb,
1093
+ encoder_hidden_states=encoder_hidden_states,
1094
+ attention_mask=attention_mask,
1095
+ cross_attention_kwargs=cross_attention_kwargs,
1096
+ encoder_attention_mask=encoder_attention_mask,
1097
+ **additional_residuals,
1098
+ )
1099
+ else:
1100
+ sample, res_samples, ref_feature_list = downsample_block(
1101
+ hidden_states=sample, temb=emb, scale=lora_scale
1102
+ )
1103
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1104
+ sample += down_intrablock_additional_residuals.pop(0)
1105
+ ref_features["down"].append(ref_feature_list)
1106
+ down_block_res_samples += res_samples
1107
+
1108
+ if is_controlnet:
1109
+ new_down_block_res_samples = ()
1110
+
1111
+ for down_block_res_sample, down_block_additional_residual in zip(
1112
+ down_block_res_samples, down_block_additional_residuals
1113
+ ):
1114
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1115
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1116
+
1117
+ down_block_res_samples = new_down_block_res_samples
1118
+
1119
+ # 4. mid
1120
+ if self.mid_block is not None:
1121
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1122
+ sample, ref_feature_list = self.mid_block(
1123
+ sample,
1124
+ emb,
1125
+ encoder_hidden_states=encoder_hidden_states,
1126
+ attention_mask=attention_mask,
1127
+ cross_attention_kwargs=cross_attention_kwargs,
1128
+ encoder_attention_mask=encoder_attention_mask,
1129
+ )
1130
+ ref_features["mid"].append(ref_feature_list)
1131
+ else:
1132
+ sample = self.mid_block(sample, emb)
1133
+
1134
+ # To support T2I-Adapter-XL
1135
+ if (
1136
+ is_adapter
1137
+ and len(down_intrablock_additional_residuals) > 0
1138
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1139
+ ):
1140
+ sample += down_intrablock_additional_residuals.pop(0)
1141
+
1142
+ if is_controlnet:
1143
+ sample = sample + mid_block_additional_residual
1144
+
1145
+ # 5. up
1146
+ for i, upsample_block in enumerate(self.up_blocks):
1147
+ is_final_block = i == len(self.up_blocks) - 1
1148
+
1149
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1150
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1151
+
1152
+ # if we have not reached the final block and need to forward the
1153
+ # upsample size, we do it here
1154
+ if not is_final_block and forward_upsample_size:
1155
+ upsample_size = down_block_res_samples[-1].shape[2:]
1156
+
1157
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1158
+ sample, ref_feature_list = upsample_block(
1159
+ hidden_states=sample,
1160
+ temb=emb,
1161
+ res_hidden_states_tuple=res_samples,
1162
+ encoder_hidden_states=encoder_hidden_states,
1163
+ cross_attention_kwargs=cross_attention_kwargs,
1164
+ upsample_size=upsample_size,
1165
+ attention_mask=attention_mask,
1166
+ encoder_attention_mask=encoder_attention_mask,
1167
+ )
1168
+ else:
1169
+ sample, ref_feature_list = upsample_block(
1170
+ hidden_states=sample,
1171
+ temb=emb,
1172
+ res_hidden_states_tuple=res_samples,
1173
+ upsample_size=upsample_size,
1174
+ scale=lora_scale,
1175
+ )
1176
+ ref_features["up"].append(ref_feature_list)
1177
+
1178
+ if USE_PEFT_BACKEND:
1179
+ # remove `lora_scale` from each PEFT layer
1180
+ unscale_lora_layers(self, lora_scale)
1181
+
1182
+ if not return_dict:
1183
+ return ref_features
1184
+
1185
+ return UNet2DConditionOutput(ref_features=ref_features)
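Note on the contract above: unlike the stock diffusers UNet, this reference net's `forward` does not return a denoised sample at all; it only collects the hidden states emitted by each down/mid/up block into the `ref_features` dict, which the 3D UNet defined next consumes block by block. A minimal sketch of that structure with dummy tensors (batch size, frame count, channel widths, and spatial sizes are illustrative assumptions, not values from the model config):

```python
# Illustrative sketch of the ref_features layout returned above; all shapes
# and counts below are assumptions for demonstration only.
import torch

batch = 1
num_ref_frames = 3  # assumed: 1 reference-image frame + 2 motion frames

def fake_block_features(channels: int, spatial: int):
    # One entry per attention layer in a block; each entry is a tuple whose
    # first element is a feature map flattened as "(b f) (h w) c" -- the
    # layout the cross-attention 3D blocks later un-flatten with einops.
    return [(torch.randn(batch * num_ref_frames, spatial * spatial, channels),)]

ref_features = {
    "down": [fake_block_features(320, 64), fake_block_features(640, 32)],
    "mid": [fake_block_features(1280, 8)],
    "up": [fake_block_features(640, 32), fake_block_features(320, 64)],
}

# The 3D UNet indexes this positionally: down block i reads
# ref_features["down"][i], and attention layer j inside it reads [...][j][0].
print(ref_features["down"][0][0][0].shape)  # torch.Size([3, 4096, 320])
```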
memo/models/unet_3d.py ADDED
@@ -0,0 +1,583 @@
1
+ from dataclasses import dataclass
2
+ from typing import Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.utils.checkpoint
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.models.attention_processor import AttentionProcessor
9
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
10
+ from diffusers.models.modeling_utils import ModelMixin
11
+ from diffusers.utils import BaseOutput, logging
12
+
13
+ from memo.models.resnet import InflatedConv3d, InflatedGroupNorm
14
+ from memo.models.unet_3d_blocks import (
15
+ UNetMidBlock3DCrossAttn,
16
+ get_down_block,
17
+ get_up_block,
18
+ )
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ @dataclass
25
+ class UNet3DConditionOutput(BaseOutput):
26
+ sample: torch.FloatTensor
27
+
28
+
29
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
30
+ _supports_gradient_checkpointing = True
31
+
32
+ @register_to_config
33
+ def __init__(
34
+ self,
35
+ sample_size: Optional[int] = None,
36
+ in_channels: int = 8,
37
+ out_channels: int = 8,
38
+ flip_sin_to_cos: bool = True,
39
+ freq_shift: int = 0,
40
+ down_block_types: Tuple[str] = (
41
+ "CrossAttnDownBlock3D",
42
+ "CrossAttnDownBlock3D",
43
+ "CrossAttnDownBlock3D",
44
+ "DownBlock3D",
45
+ ),
46
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
47
+ up_block_types: Tuple[str] = (
48
+ "UpBlock3D",
49
+ "CrossAttnUpBlock3D",
50
+ "CrossAttnUpBlock3D",
51
+ "CrossAttnUpBlock3D",
52
+ ),
53
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
54
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
55
+ layers_per_block: int = 2,
56
+ downsample_padding: int = 1,
57
+ mid_block_scale_factor: float = 1,
58
+ act_fn: str = "silu",
59
+ norm_num_groups: int = 32,
60
+ norm_eps: float = 1e-5,
61
+ cross_attention_dim: int = 1280,
62
+ attention_head_dim: Union[int, Tuple[int]] = 8,
63
+ dual_cross_attention: bool = False,
64
+ use_linear_projection: bool = False,
65
+ class_embed_type: Optional[str] = None,
66
+ num_class_embeds: Optional[int] = None,
67
+ upcast_attention: bool = False,
68
+ resnet_time_scale_shift: str = "default",
69
+ use_inflated_groupnorm=False,
70
+ # Additional
71
+ motion_module_resolutions=(1, 2, 4, 8),
72
+ motion_module_kwargs=None,
73
+ unet_use_cross_frame_attention=None,
74
+ unet_use_temporal_attention=None,
75
+ # audio
76
+ audio_attention_dim=768,
77
+ emo_drop_rate=0.3,
78
+ ):
79
+ super().__init__()
80
+
81
+ self.sample_size = sample_size
82
+ time_embed_dim = block_out_channels[0] * 4
83
+
84
+ # input
85
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
86
+
87
+ # time
88
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
89
+ timestep_input_dim = block_out_channels[0]
90
+
91
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
92
+
93
+ # class embedding
94
+ if class_embed_type is None and num_class_embeds is not None:
95
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
96
+ elif class_embed_type == "timestep":
97
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
98
+ elif class_embed_type == "identity":
99
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
100
+ else:
101
+ self.class_embedding = None
102
+
103
+ self.down_blocks = nn.ModuleList([])
104
+ self.mid_block = None
105
+ self.up_blocks = nn.ModuleList([])
106
+
107
+ if isinstance(only_cross_attention, bool):
108
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
109
+
110
+ if isinstance(attention_head_dim, int):
111
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
112
+
113
+ # down
114
+ output_channel = block_out_channels[0]
115
+ for i, down_block_type in enumerate(down_block_types):
116
+ res = 2**i
117
+ input_channel = output_channel
118
+ output_channel = block_out_channels[i]
119
+ is_final_block = i == len(block_out_channels) - 1
120
+
121
+ down_block = get_down_block(
122
+ down_block_type,
123
+ num_layers=layers_per_block,
124
+ in_channels=input_channel,
125
+ out_channels=output_channel,
126
+ temb_channels=time_embed_dim,
127
+ add_downsample=not is_final_block,
128
+ resnet_eps=norm_eps,
129
+ resnet_act_fn=act_fn,
130
+ resnet_groups=norm_num_groups,
131
+ cross_attention_dim=cross_attention_dim,
132
+ attn_num_head_channels=attention_head_dim[i],
133
+ downsample_padding=downsample_padding,
134
+ dual_cross_attention=dual_cross_attention,
135
+ use_linear_projection=use_linear_projection,
136
+ only_cross_attention=only_cross_attention[i],
137
+ upcast_attention=upcast_attention,
138
+ resnet_time_scale_shift=resnet_time_scale_shift,
139
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
140
+ unet_use_temporal_attention=unet_use_temporal_attention,
141
+ use_inflated_groupnorm=use_inflated_groupnorm,
142
+ use_motion_module=res in motion_module_resolutions,
143
+ motion_module_kwargs=motion_module_kwargs,
144
+ audio_attention_dim=audio_attention_dim,
145
+ depth=i,
146
+ emo_drop_rate=emo_drop_rate,
147
+ )
148
+ self.down_blocks.append(down_block)
149
+
150
+ # mid
151
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
152
+ self.mid_block = UNetMidBlock3DCrossAttn(
153
+ in_channels=block_out_channels[-1],
154
+ temb_channels=time_embed_dim,
155
+ resnet_eps=norm_eps,
156
+ resnet_act_fn=act_fn,
157
+ output_scale_factor=mid_block_scale_factor,
158
+ resnet_time_scale_shift=resnet_time_scale_shift,
159
+ cross_attention_dim=cross_attention_dim,
160
+ attn_num_head_channels=attention_head_dim[-1],
161
+ resnet_groups=norm_num_groups,
162
+ dual_cross_attention=dual_cross_attention,
163
+ use_linear_projection=use_linear_projection,
164
+ upcast_attention=upcast_attention,
165
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
166
+ unet_use_temporal_attention=unet_use_temporal_attention,
167
+ use_inflated_groupnorm=use_inflated_groupnorm,
168
+ motion_module_kwargs=motion_module_kwargs,
169
+ audio_attention_dim=audio_attention_dim,
170
+ depth=3,
171
+ emo_drop_rate=emo_drop_rate,
172
+ )
173
+ else:
174
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
175
+
176
+ # count how many layers upsample the videos
177
+ self.num_upsamplers = 0
178
+
179
+ # up
180
+ reversed_block_out_channels = list(reversed(block_out_channels))
181
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
182
+ only_cross_attention = list(reversed(only_cross_attention))
183
+ output_channel = reversed_block_out_channels[0]
184
+ for i, up_block_type in enumerate(up_block_types):
185
+ res = 2 ** (3 - i)
186
+ is_final_block = i == len(block_out_channels) - 1
187
+
188
+ prev_output_channel = output_channel
189
+ output_channel = reversed_block_out_channels[i]
190
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
191
+
192
+ # add upsample block for all BUT final layer
193
+ if not is_final_block:
194
+ add_upsample = True
195
+ self.num_upsamplers += 1
196
+ else:
197
+ add_upsample = False
198
+
199
+ up_block = get_up_block(
200
+ up_block_type,
201
+ num_layers=layers_per_block + 1,
202
+ in_channels=input_channel,
203
+ out_channels=output_channel,
204
+ prev_output_channel=prev_output_channel,
205
+ temb_channels=time_embed_dim,
206
+ add_upsample=add_upsample,
207
+ resnet_eps=norm_eps,
208
+ resnet_act_fn=act_fn,
209
+ resnet_groups=norm_num_groups,
210
+ cross_attention_dim=cross_attention_dim,
211
+ attn_num_head_channels=reversed_attention_head_dim[i],
212
+ dual_cross_attention=dual_cross_attention,
213
+ use_linear_projection=use_linear_projection,
214
+ only_cross_attention=only_cross_attention[i],
215
+ upcast_attention=upcast_attention,
216
+ resnet_time_scale_shift=resnet_time_scale_shift,
217
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
218
+ unet_use_temporal_attention=unet_use_temporal_attention,
219
+ use_inflated_groupnorm=use_inflated_groupnorm,
220
+ use_motion_module=res in motion_module_resolutions,
221
+ motion_module_kwargs=motion_module_kwargs,
222
+ audio_attention_dim=audio_attention_dim,
223
+ depth=3 - i,
224
+ emo_drop_rate=emo_drop_rate,
225
+ is_final_block=is_final_block,
226
+ )
227
+ self.up_blocks.append(up_block)
228
+ prev_output_channel = output_channel
229
+
230
+ # out
231
+ if use_inflated_groupnorm:
232
+ self.conv_norm_out = InflatedGroupNorm(
233
+ num_channels=block_out_channels[0],
234
+ num_groups=norm_num_groups,
235
+ eps=norm_eps,
236
+ )
237
+ else:
238
+ self.conv_norm_out = nn.GroupNorm(
239
+ num_channels=block_out_channels[0],
240
+ num_groups=norm_num_groups,
241
+ eps=norm_eps,
242
+ )
243
+ self.conv_act = nn.SiLU()
244
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
245
+
246
+ @property
247
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
248
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
249
+ r"""
250
+ Returns:
251
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
253
+ indexed by its weight name.
253
+ """
254
+ # set recursively
255
+ processors = {}
256
+
257
+ def fn_recursive_add_processors(
258
+ name: str,
259
+ module: torch.nn.Module,
260
+ processors: Dict[str, AttentionProcessor],
261
+ ):
262
+ if hasattr(module, "set_processor"):
263
+ processors[f"{name}.processor"] = module.processor
264
+
265
+ for sub_name, child in module.named_children():
266
+ if "temporal_transformer" not in sub_name:
267
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
268
+
269
+ return processors
270
+
271
+ for name, module in self.named_children():
272
+ if "temporal_transformer" not in name:
273
+ fn_recursive_add_processors(name, module, processors)
274
+
275
+ return processors
276
+
277
+ def set_attention_slice(self, slice_size):
278
+ r"""
279
+ Enable sliced attention computation.
280
+
281
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
282
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
283
+
284
+ Args:
285
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
286
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
287
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
288
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
289
+ must be a multiple of `slice_size`.
290
+ """
291
+ sliceable_head_dims = []
292
+
293
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
294
+ if hasattr(module, "set_attention_slice"):
295
+ sliceable_head_dims.append(module.sliceable_head_dim)
296
+
297
+ for child in module.children():
298
+ fn_recursive_retrieve_slicable_dims(child)
299
+
300
+ # retrieve number of attention layers
301
+ for module in self.children():
302
+ fn_recursive_retrieve_slicable_dims(module)
303
+
304
+ num_slicable_layers = len(sliceable_head_dims)
305
+
306
+ if slice_size == "auto":
307
+ # half the attention head size is usually a good trade-off between
308
+ # speed and memory
309
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
310
+ elif slice_size == "max":
311
+ # make smallest slice possible
312
+ slice_size = num_slicable_layers * [1]
313
+
314
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
315
+
316
+ if len(slice_size) != len(sliceable_head_dims):
317
+ raise ValueError(
318
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
319
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
320
+ )
321
+
322
+ for i, size in enumerate(slice_size):
323
+ dim = sliceable_head_dims[i]
324
+ if size is not None and size > dim:
325
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
326
+
327
+ # Recursively walk through all the children.
328
+ # Any children which exposes the set_attention_slice method
329
+ # gets the message
330
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
331
+ if hasattr(module, "set_attention_slice"):
332
+ module.set_attention_slice(slice_size.pop())
333
+
334
+ for child in module.children():
335
+ fn_recursive_set_attention_slice(child, slice_size)
336
+
337
+ reversed_slice_size = list(reversed(slice_size))
338
+ for module in self.children():
339
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
340
+
341
+ def _set_gradient_checkpointing(self, module, value=False):
342
+ if hasattr(module, "gradient_checkpointing"):
343
+ module.gradient_checkpointing = value
344
+
345
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
346
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
347
+ r"""
348
+ Sets the attention processor to use to compute attention.
349
+
350
+ Parameters:
351
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
352
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
353
+ for **all** `Attention` layers.
354
+
355
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
356
+ processor. This is strongly recommended when setting trainable attention processors.
357
+
358
+ """
359
+ count = len(self.attn_processors.keys())
360
+
361
+ if isinstance(processor, dict) and len(processor) != count:
362
+ raise ValueError(
363
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
364
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
365
+ )
366
+
367
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
368
+ if hasattr(module, "set_processor"):
369
+ if not isinstance(processor, dict):
370
+ module.set_processor(processor)
371
+ else:
372
+ module.set_processor(processor.pop(f"{name}.processor"))
373
+
374
+ for sub_name, child in module.named_children():
375
+ if "temporal_transformer" not in sub_name:
376
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
377
+
378
+ for name, module in self.named_children():
379
+ if "temporal_transformer" not in name:
380
+ fn_recursive_attn_processor(name, module, processor)
381
+
382
+ def forward(
383
+ self,
384
+ sample: torch.FloatTensor,
385
+ ref_features: dict,
386
+ timestep: Union[torch.Tensor, float, int, list],
387
+ encoder_hidden_states: torch.Tensor,
388
+ audio_embedding: Optional[torch.Tensor] = None,
389
+ audio_emotion: Optional[torch.Tensor] = None,
390
+ class_labels: Optional[torch.Tensor] = None,
391
+ mask_cond_fea: Optional[torch.Tensor] = None,
392
+ attention_mask: Optional[torch.Tensor] = None,
393
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
394
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
395
+ uc_mask: Optional[torch.Tensor] = None,
396
+ return_dict: bool = True,
397
+ is_new_audio=True,
398
+ update_past_memory=False,
399
+ ) -> Union[UNet3DConditionOutput, Tuple]:
400
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
401
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
402
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
403
+ # on the fly if necessary.
404
+ default_overall_up_factor = 2**self.num_upsamplers
405
+
406
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
407
+ forward_upsample_size = False
408
+ upsample_size = None
409
+
410
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
411
+ logger.info("Forward upsample size to force interpolation output size.")
412
+ forward_upsample_size = True
413
+
414
+ # prepare attention_mask
415
+ if attention_mask is not None:
416
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
417
+ attention_mask = attention_mask.unsqueeze(1)
418
+
419
+ # center input if necessary
420
+ if self.config.center_input_sample:
421
+ sample = 2 * sample - 1.0
422
+
423
+ # time
424
+ timesteps = timestep
425
+ if isinstance(timesteps, list):
426
+ t_emb_list = []
427
+ for timesteps in timestep:
428
+ if not torch.is_tensor(timesteps):
429
+ # This would be a good case for the `match` statement (Python 3.10+)
430
+ is_mps = sample.device.type == "mps"
431
+ if isinstance(timestep, float):
432
+ dtype = torch.float32 if is_mps else torch.float64
433
+ else:
434
+ dtype = torch.int32 if is_mps else torch.int64
435
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
436
+ elif len(timesteps.shape) == 0:
437
+ timesteps = timesteps[None].to(sample.device)
438
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
439
+ timesteps = timesteps.expand(sample.shape[0])
440
+ t_emb = self.time_proj(timesteps)
441
+ t_emb_list.append(t_emb)
442
+
443
+ t_emb = torch.stack(t_emb_list, dim=1)
444
+ else:
445
+ if not torch.is_tensor(timesteps):
446
+ # This would be a good case for the `match` statement (Python 3.10+)
447
+ is_mps = sample.device.type == "mps"
448
+ if isinstance(timestep, float):
449
+ dtype = torch.float32 if is_mps else torch.float64
450
+ else:
451
+ dtype = torch.int32 if is_mps else torch.int64
452
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
453
+ elif len(timesteps.shape) == 0:
454
+ timesteps = timesteps[None].to(sample.device)
455
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
456
+ timesteps = timesteps.expand(sample.shape[0])
457
+ t_emb = self.time_proj(timesteps)
458
+
459
+ # timesteps does not contain any weights and will always return f32 tensors
460
+ # but time_embedding might actually be running in fp16. so we need to cast here.
461
+ # there might be better ways to encapsulate this.
462
+ t_emb = t_emb.to(dtype=self.dtype)
463
+ emb = self.time_embedding(t_emb)
464
+
465
+ if self.class_embedding is not None:
466
+ if class_labels is None:
467
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
468
+
469
+ if self.config.class_embed_type == "timestep":
470
+ class_labels = self.time_proj(class_labels)
471
+
472
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
473
+ emb = emb + class_emb
474
+
475
+ # pre-process
476
+ sample = self.conv_in(sample)
477
+ if mask_cond_fea is not None:
478
+ sample = sample + mask_cond_fea
479
+
480
+ # down
481
+ down_block_res_samples = (sample,)
482
+ for i, downsample_block in enumerate(self.down_blocks):
483
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
484
+ sample, res_samples, audio_embedding = downsample_block(
485
+ hidden_states=sample,
486
+ ref_feature_list=ref_features["down"][i],
487
+ temb=emb,
488
+ encoder_hidden_states=encoder_hidden_states,
489
+ attention_mask=attention_mask,
490
+ audio_embedding=audio_embedding,
491
+ emotion=audio_emotion,
492
+ uc_mask=uc_mask,
493
+ is_new_audio=is_new_audio,
494
+ update_past_memory=update_past_memory,
495
+ )
496
+ else:
497
+ sample, res_samples = downsample_block(
498
+ hidden_states=sample,
499
+ ref_feature_list=ref_features["down"][i],
500
+ temb=emb,
501
+ encoder_hidden_states=encoder_hidden_states,
502
+ is_new_audio=is_new_audio,
503
+ update_past_memory=update_past_memory,
504
+ )
505
+
506
+ down_block_res_samples += res_samples
507
+
508
+ if down_block_additional_residuals is not None:
509
+ new_down_block_res_samples = ()
510
+
511
+ for down_block_res_sample, down_block_additional_residual in zip(
512
+ down_block_res_samples, down_block_additional_residuals
513
+ ):
514
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
515
+ new_down_block_res_samples += (down_block_res_sample,)
516
+
517
+ down_block_res_samples = new_down_block_res_samples
518
+
519
+ # mid
520
+ sample, audio_embedding = self.mid_block(
521
+ sample,
522
+ ref_feature_list=ref_features["mid"][0],
523
+ temb=emb,
524
+ encoder_hidden_states=encoder_hidden_states,
525
+ attention_mask=attention_mask,
526
+ audio_embedding=audio_embedding,
527
+ emotion=audio_emotion,
528
+ uc_mask=uc_mask,
529
+ is_new_audio=is_new_audio,
530
+ update_past_memory=update_past_memory,
531
+ )
532
+
533
+ if mid_block_additional_residual is not None:
534
+ sample = sample + mid_block_additional_residual
535
+
536
+ # up
537
+ for i, upsample_block in enumerate(self.up_blocks):
538
+ is_final_block = i == len(self.up_blocks) - 1
539
+
540
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
541
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
542
+
543
+ # if we have not reached the final block and need to forward the
544
+ # upsample size, we do it here
545
+ if not is_final_block and forward_upsample_size:
546
+ upsample_size = down_block_res_samples[-1].shape[2:]
547
+
548
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
549
+ sample, audio_embedding = upsample_block(
550
+ hidden_states=sample,
551
+ ref_feature_list=ref_features["up"][i],
552
+ temb=emb,
553
+ res_hidden_states_tuple=res_samples,
554
+ encoder_hidden_states=encoder_hidden_states,
555
+ upsample_size=upsample_size,
556
+ attention_mask=attention_mask,
557
+ audio_embedding=audio_embedding,
558
+ emotion=audio_emotion,
559
+ uc_mask=uc_mask,
560
+ is_new_audio=is_new_audio,
561
+ update_past_memory=update_past_memory,
562
+ )
563
+ else:
564
+ sample = upsample_block(
565
+ hidden_states=sample,
566
+ ref_feature_list=ref_features["up"][i],
567
+ temb=emb,
568
+ res_hidden_states_tuple=res_samples,
569
+ upsample_size=upsample_size,
570
+ encoder_hidden_states=encoder_hidden_states,
571
+ is_new_audio=is_new_audio,
572
+ update_past_memory=update_past_memory,
573
+ )
574
+
575
+ # post-process
576
+ sample = self.conv_norm_out(sample)
577
+ sample = self.conv_act(sample)
578
+ sample = self.conv_out(sample)
579
+
580
+ if not return_dict:
581
+ return (sample,)
582
+
583
+ return UNet3DConditionOutput(sample=sample)
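One detail in `UNet3DConditionModel.forward` above worth calling out: `timestep` may be a scalar/tensor or a list, and in the list case one sinusoidal projection is computed per entry and the results are stacked along a new dimension before the shared `time_embedding` MLP. A small runnable sketch of that branch, reusing `Timesteps` from diffusers (the 320-channel width and the example timesteps are arbitrary assumptions):

```python
# Sketch of the list-of-timesteps branch in UNet3DConditionModel.forward,
# using diffusers' sinusoidal Timesteps projection directly.
import torch
from diffusers.models.embeddings import Timesteps

time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)

batch = 2
timestep = [999, 666, 333]  # e.g. one diffusion timestep per sub-sequence

t_emb_list = []
for t in timestep:
    t = torch.tensor([t], dtype=torch.int64)  # int64 for simplicity; the real
    t = t.expand(batch)                       # code also handles float/MPS dtypes
    t_emb_list.append(time_proj(t))           # -> (batch, 320)

t_emb = torch.stack(t_emb_list, dim=1)        # -> (batch, len(timestep), 320)
print(t_emb.shape)  # torch.Size([2, 3, 320]); this is what feeds time_embedding
```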
memo/models/unet_3d_blocks.py ADDED
@@ -0,0 +1,1024 @@
1
+ from typing import Any, Dict
2
+
3
+ import torch
4
+ from diffusers.utils import is_torch_version
5
+ from einops import rearrange
6
+ from torch import nn
7
+
8
+ from memo.models.motion_module import MemoryLinearAttnTemporalModule
9
+ from memo.models.resnet import Downsample3D, ResnetBlock3D, Upsample3D
10
+ from memo.models.transformer_3d import Transformer3DModel
11
+
12
+
13
+ def create_custom_forward(module, return_dict=None):
14
+ def custom_forward(*inputs):
15
+ if return_dict is not None:
16
+ return module(*inputs, return_dict=return_dict)
17
+
18
+ return module(*inputs)
19
+
20
+ return custom_forward
21
+
22
+
23
+ def get_down_block(
24
+ down_block_type,
25
+ num_layers,
26
+ in_channels,
27
+ out_channels,
28
+ temb_channels,
29
+ add_downsample,
30
+ resnet_eps,
31
+ resnet_act_fn,
32
+ attn_num_head_channels,
33
+ resnet_groups=None,
34
+ cross_attention_dim=None,
35
+ audio_attention_dim=None,
36
+ downsample_padding=None,
37
+ dual_cross_attention=False,
38
+ use_linear_projection=False,
39
+ only_cross_attention=False,
40
+ upcast_attention=False,
41
+ resnet_time_scale_shift="default",
42
+ unet_use_cross_frame_attention=None,
43
+ unet_use_temporal_attention=None,
44
+ use_inflated_groupnorm=None,
45
+ use_motion_module=None,
46
+ motion_module_kwargs=None,
47
+ depth=0,
48
+ emo_drop_rate=0.3,
49
+ ):
50
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
51
+ if down_block_type == "DownBlock3D":
52
+ return DownBlock3D(
53
+ num_layers=num_layers,
54
+ in_channels=in_channels,
55
+ out_channels=out_channels,
56
+ temb_channels=temb_channels,
57
+ add_downsample=add_downsample,
58
+ resnet_eps=resnet_eps,
59
+ resnet_act_fn=resnet_act_fn,
60
+ resnet_groups=resnet_groups,
61
+ downsample_padding=downsample_padding,
62
+ resnet_time_scale_shift=resnet_time_scale_shift,
63
+ use_inflated_groupnorm=use_inflated_groupnorm,
64
+ use_motion_module=use_motion_module,
65
+ motion_module_kwargs=motion_module_kwargs,
66
+ )
67
+
68
+ if down_block_type == "CrossAttnDownBlock3D":
69
+ if cross_attention_dim is None:
70
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
71
+ return CrossAttnDownBlock3D(
72
+ num_layers=num_layers,
73
+ in_channels=in_channels,
74
+ out_channels=out_channels,
75
+ temb_channels=temb_channels,
76
+ add_downsample=add_downsample,
77
+ resnet_eps=resnet_eps,
78
+ resnet_act_fn=resnet_act_fn,
79
+ resnet_groups=resnet_groups,
80
+ downsample_padding=downsample_padding,
81
+ cross_attention_dim=cross_attention_dim,
82
+ audio_attention_dim=audio_attention_dim,
83
+ attn_num_head_channels=attn_num_head_channels,
84
+ dual_cross_attention=dual_cross_attention,
85
+ use_linear_projection=use_linear_projection,
86
+ only_cross_attention=only_cross_attention,
87
+ upcast_attention=upcast_attention,
88
+ resnet_time_scale_shift=resnet_time_scale_shift,
89
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
90
+ unet_use_temporal_attention=unet_use_temporal_attention,
91
+ use_inflated_groupnorm=use_inflated_groupnorm,
92
+ use_motion_module=use_motion_module,
93
+ motion_module_kwargs=motion_module_kwargs,
94
+ depth=depth,
95
+ emo_drop_rate=emo_drop_rate,
96
+ )
97
+ raise ValueError(f"{down_block_type} does not exist.")
98
+
99
+
100
+ def get_up_block(
101
+ up_block_type,
102
+ num_layers,
103
+ in_channels,
104
+ out_channels,
105
+ prev_output_channel,
106
+ temb_channels,
107
+ add_upsample,
108
+ resnet_eps,
109
+ resnet_act_fn,
110
+ attn_num_head_channels,
111
+ resnet_groups=None,
112
+ cross_attention_dim=None,
113
+ audio_attention_dim=None,
114
+ dual_cross_attention=False,
115
+ use_linear_projection=False,
116
+ only_cross_attention=False,
117
+ upcast_attention=False,
118
+ resnet_time_scale_shift="default",
119
+ unet_use_cross_frame_attention=None,
120
+ unet_use_temporal_attention=None,
121
+ use_inflated_groupnorm=None,
122
+ use_motion_module=None,
123
+ motion_module_kwargs=None,
124
+ depth=0,
125
+ emo_drop_rate=0.3,
126
+ is_final_block=False,
127
+ ):
128
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
129
+ if up_block_type == "UpBlock3D":
130
+ return UpBlock3D(
131
+ num_layers=num_layers,
132
+ in_channels=in_channels,
133
+ out_channels=out_channels,
134
+ prev_output_channel=prev_output_channel,
135
+ temb_channels=temb_channels,
136
+ add_upsample=add_upsample,
137
+ resnet_eps=resnet_eps,
138
+ resnet_act_fn=resnet_act_fn,
139
+ resnet_groups=resnet_groups,
140
+ resnet_time_scale_shift=resnet_time_scale_shift,
141
+ use_inflated_groupnorm=use_inflated_groupnorm,
142
+ use_motion_module=use_motion_module,
143
+ motion_module_kwargs=motion_module_kwargs,
144
+ )
145
+
146
+ if up_block_type == "CrossAttnUpBlock3D":
147
+ if cross_attention_dim is None:
148
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
149
+ return CrossAttnUpBlock3D(
150
+ num_layers=num_layers,
151
+ in_channels=in_channels,
152
+ out_channels=out_channels,
153
+ prev_output_channel=prev_output_channel,
154
+ temb_channels=temb_channels,
155
+ add_upsample=add_upsample,
156
+ resnet_eps=resnet_eps,
157
+ resnet_act_fn=resnet_act_fn,
158
+ resnet_groups=resnet_groups,
159
+ cross_attention_dim=cross_attention_dim,
160
+ audio_attention_dim=audio_attention_dim,
161
+ attn_num_head_channels=attn_num_head_channels,
162
+ dual_cross_attention=dual_cross_attention,
163
+ use_linear_projection=use_linear_projection,
164
+ only_cross_attention=only_cross_attention,
165
+ upcast_attention=upcast_attention,
166
+ resnet_time_scale_shift=resnet_time_scale_shift,
167
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
168
+ unet_use_temporal_attention=unet_use_temporal_attention,
169
+ use_inflated_groupnorm=use_inflated_groupnorm,
170
+ use_motion_module=use_motion_module,
171
+ motion_module_kwargs=motion_module_kwargs,
172
+ depth=depth,
173
+ emo_drop_rate=emo_drop_rate,
174
+ is_final_block=is_final_block,
175
+ )
176
+ raise ValueError(f"{up_block_type} does not exist.")
177
+
178
+
179
+ class UNetMidBlock3DCrossAttn(nn.Module):
180
+ def __init__(
181
+ self,
182
+ in_channels: int,
183
+ temb_channels: int,
184
+ dropout: float = 0.0,
185
+ num_layers: int = 1,
186
+ resnet_eps: float = 1e-6,
187
+ resnet_time_scale_shift: str = "default",
188
+ resnet_act_fn: str = "swish",
189
+ resnet_groups: int = 32,
190
+ resnet_pre_norm: bool = True,
191
+ attn_num_head_channels=1,
192
+ output_scale_factor=1.0,
193
+ cross_attention_dim=1280,
194
+ audio_attention_dim=1024,
195
+ dual_cross_attention=False,
196
+ use_linear_projection=False,
197
+ upcast_attention=False,
198
+ unet_use_cross_frame_attention=None,
199
+ unet_use_temporal_attention=None,
200
+ use_inflated_groupnorm=None,
201
+ motion_module_kwargs=None,
202
+ depth=0,
203
+ emo_drop_rate=0.3,
204
+ ):
205
+ super().__init__()
206
+
207
+ self.has_cross_attention = True
208
+ self.attn_num_head_channels = attn_num_head_channels
209
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
210
+
211
+ # there is always at least one resnet
212
+ resnets = [
213
+ ResnetBlock3D(
214
+ in_channels=in_channels,
215
+ out_channels=in_channels,
216
+ temb_channels=temb_channels,
217
+ eps=resnet_eps,
218
+ groups=resnet_groups,
219
+ dropout=dropout,
220
+ time_embedding_norm=resnet_time_scale_shift,
221
+ non_linearity=resnet_act_fn,
222
+ output_scale_factor=output_scale_factor,
223
+ pre_norm=resnet_pre_norm,
224
+ use_inflated_groupnorm=use_inflated_groupnorm,
225
+ )
226
+ ]
227
+ attentions = []
228
+ motion_modules = []
229
+ audio_modules = []
230
+
231
+ for _ in range(num_layers):
232
+ if dual_cross_attention:
233
+ raise NotImplementedError
234
+ attentions.append(
235
+ Transformer3DModel(
236
+ attn_num_head_channels,
237
+ in_channels // attn_num_head_channels,
238
+ in_channels=in_channels,
239
+ num_layers=1,
240
+ cross_attention_dim=cross_attention_dim,
241
+ norm_num_groups=resnet_groups,
242
+ use_linear_projection=use_linear_projection,
243
+ upcast_attention=upcast_attention,
244
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
245
+ unet_use_temporal_attention=unet_use_temporal_attention,
246
+ )
247
+ )
248
+ audio_modules.append(
249
+ Transformer3DModel(
250
+ attn_num_head_channels,
251
+ in_channels // attn_num_head_channels,
252
+ in_channels=in_channels,
253
+ num_layers=1,
254
+ cross_attention_dim=audio_attention_dim,
255
+ norm_num_groups=resnet_groups,
256
+ use_linear_projection=use_linear_projection,
257
+ only_cross_attention=False,
258
+ upcast_attention=upcast_attention,
259
+ use_audio_module=True,
260
+ depth=depth,
261
+ unet_block_name="mid",
262
+ emo_drop_rate=emo_drop_rate,
263
+ )
264
+ )
265
+
266
+ motion_modules.append(
267
+ MemoryLinearAttnTemporalModule(
268
+ in_channels=in_channels,
269
+ **motion_module_kwargs,
270
+ )
271
+ )
272
+ resnets.append(
273
+ ResnetBlock3D(
274
+ in_channels=in_channels,
275
+ out_channels=in_channels,
276
+ temb_channels=temb_channels,
277
+ eps=resnet_eps,
278
+ groups=resnet_groups,
279
+ dropout=dropout,
280
+ time_embedding_norm=resnet_time_scale_shift,
281
+ non_linearity=resnet_act_fn,
282
+ output_scale_factor=output_scale_factor,
283
+ pre_norm=resnet_pre_norm,
284
+ use_inflated_groupnorm=use_inflated_groupnorm,
285
+ )
286
+ )
287
+
288
+ self.attentions = nn.ModuleList(attentions)
289
+ self.resnets = nn.ModuleList(resnets)
290
+ self.audio_modules = nn.ModuleList(audio_modules)
291
+ self.motion_modules = nn.ModuleList(motion_modules)
292
+
293
+ self.gradient_checkpointing = False
294
+
295
+ def forward(
296
+ self,
297
+ hidden_states,
298
+ ref_feature_list,
299
+ temb=None,
300
+ encoder_hidden_states=None,
301
+ attention_mask=None,
302
+ audio_embedding=None,
303
+ emotion=None,
304
+ uc_mask=None,
305
+ is_new_audio=True,
306
+ update_past_memory=False,
307
+ ):
308
+ hidden_states = self.resnets[0](hidden_states, temb)
309
+ for i, (attn, resnet, audio_module, motion_module) in enumerate(
310
+ zip(
311
+ self.attentions,
312
+ self.resnets[1:],
313
+ self.audio_modules,
314
+ self.motion_modules,
315
+ )
316
+ ):
317
+ ref_feature = ref_feature_list[i]
318
+ ref_feature = ref_feature[0]
319
+ ref_feature = rearrange(
320
+ ref_feature,
321
+ "(b f) (h w) c -> b c f h w",
322
+ b=hidden_states.shape[0],
323
+ w=hidden_states.shape[-1],
324
+ )
325
+ ref_img_feature = ref_feature[:, :, :1, :, :]
326
+ ref_img_feature = rearrange(
327
+ ref_img_feature,
328
+ "b c f h w -> (b f) (h w) c",
329
+ )
330
+ motion_frames = ref_feature[:, :, 1:, :, :]
331
+
332
+ hidden_states = attn(
333
+ hidden_states,
334
+ ref_img_feature,
335
+ encoder_hidden_states=encoder_hidden_states,
336
+ uc_mask=uc_mask,
337
+ return_dict=False,
338
+ )
339
+ if audio_module is not None:
340
+ hidden_states, audio_embedding = audio_module(
341
+ hidden_states,
342
+ ref_img_feature=None,
343
+ encoder_hidden_states=audio_embedding,
344
+ attention_mask=attention_mask,
345
+ return_dict=False,
346
+ emotion=emotion,
347
+ )
348
+ if motion_module is not None:
349
+ motion_frames = motion_frames.to(device=hidden_states.device, dtype=hidden_states.dtype)
350
+ hidden_states = motion_module(
351
+ hidden_states=hidden_states,
352
+ motion_frames=motion_frames,
353
+ encoder_hidden_states=encoder_hidden_states,
354
+ is_new_audio=is_new_audio,
355
+ update_past_memory=update_past_memory,
356
+ )
357
+
358
+ if self.training and self.gradient_checkpointing:
359
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
360
+ hidden_states = torch.utils.checkpoint.checkpoint(
361
+ create_custom_forward(resnet),
362
+ hidden_states,
363
+ temb,
364
+ **ckpt_kwargs,
365
+ )
366
+ else:
367
+ hidden_states = resnet(hidden_states, temb)
368
+
369
+ if audio_module is not None:
370
+ return hidden_states, audio_embedding
371
+ else:
372
+ return hidden_states
373
+
374
+
375
+ class CrossAttnDownBlock3D(nn.Module):
376
+ def __init__(
377
+ self,
378
+ in_channels: int,
379
+ out_channels: int,
380
+ temb_channels: int,
381
+ dropout: float = 0.0,
382
+ num_layers: int = 1,
383
+ resnet_eps: float = 1e-6,
384
+ resnet_time_scale_shift: str = "default",
385
+ resnet_act_fn: str = "swish",
386
+ resnet_groups: int = 32,
387
+ resnet_pre_norm: bool = True,
388
+ attn_num_head_channels=1,
389
+ cross_attention_dim=1280,
390
+ audio_attention_dim=1024,
391
+ output_scale_factor=1.0,
392
+ downsample_padding=1,
393
+ add_downsample=True,
394
+ dual_cross_attention=False,
395
+ use_linear_projection=False,
396
+ only_cross_attention=False,
397
+ upcast_attention=False,
398
+ unet_use_cross_frame_attention=None,
399
+ unet_use_temporal_attention=None,
400
+ use_inflated_groupnorm=None,
401
+ use_motion_module=None,
402
+ motion_module_kwargs=None,
403
+ depth=0,
404
+ emo_drop_rate=0.3,
405
+ ):
406
+ super().__init__()
407
+ resnets = []
408
+ attentions = []
409
+ audio_modules = []
410
+ motion_modules = []
411
+
412
+ self.has_cross_attention = True
413
+ self.attn_num_head_channels = attn_num_head_channels
414
+
415
+ for i in range(num_layers):
416
+ in_channels = in_channels if i == 0 else out_channels
417
+ resnets.append(
418
+ ResnetBlock3D(
419
+ in_channels=in_channels,
420
+ out_channels=out_channels,
421
+ temb_channels=temb_channels,
422
+ eps=resnet_eps,
423
+ groups=resnet_groups,
424
+ dropout=dropout,
425
+ time_embedding_norm=resnet_time_scale_shift,
426
+ non_linearity=resnet_act_fn,
427
+ output_scale_factor=output_scale_factor,
428
+ pre_norm=resnet_pre_norm,
429
+ use_inflated_groupnorm=use_inflated_groupnorm,
430
+ )
431
+ )
432
+ if dual_cross_attention:
433
+ raise NotImplementedError
434
+ attentions.append(
435
+ Transformer3DModel(
436
+ attn_num_head_channels,
437
+ out_channels // attn_num_head_channels,
438
+ in_channels=out_channels,
439
+ num_layers=1,
440
+ cross_attention_dim=cross_attention_dim,
441
+ norm_num_groups=resnet_groups,
442
+ use_linear_projection=use_linear_projection,
443
+ only_cross_attention=only_cross_attention,
444
+ upcast_attention=upcast_attention,
445
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
446
+ unet_use_temporal_attention=unet_use_temporal_attention,
447
+ )
448
+ )
449
+ audio_modules.append(
450
+ Transformer3DModel(
451
+ attn_num_head_channels,
452
+ in_channels // attn_num_head_channels,
453
+ in_channels=out_channels,
454
+ num_layers=1,
455
+ cross_attention_dim=audio_attention_dim,
456
+ norm_num_groups=resnet_groups,
457
+ use_linear_projection=use_linear_projection,
458
+ only_cross_attention=only_cross_attention,
459
+ upcast_attention=upcast_attention,
460
+ use_audio_module=True,
461
+ depth=depth,
462
+ unet_block_name="down",
463
+ emo_drop_rate=emo_drop_rate,
464
+ )
465
+ )
466
+ motion_modules.append(
467
+ MemoryLinearAttnTemporalModule(
468
+ in_channels=out_channels,
469
+ **motion_module_kwargs,
470
+ )
471
+ if use_motion_module
472
+ else None
473
+ )
474
+
475
+ self.attentions = nn.ModuleList(attentions)
476
+ self.resnets = nn.ModuleList(resnets)
477
+ self.audio_modules = nn.ModuleList(audio_modules)
478
+ self.motion_modules = nn.ModuleList(motion_modules)
479
+
480
+ if add_downsample:
481
+ self.downsamplers = nn.ModuleList(
482
+ [
483
+ Downsample3D(
484
+ out_channels,
485
+ use_conv=True,
486
+ out_channels=out_channels,
487
+ padding=downsample_padding,
488
+ name="op",
489
+ )
490
+ ]
491
+ )
492
+ else:
493
+ self.downsamplers = None
494
+
495
+ self.gradient_checkpointing = False
496
+
497
+ def forward(
498
+ self,
499
+ hidden_states,
500
+ ref_feature_list,
501
+ temb=None,
502
+ encoder_hidden_states=None,
503
+ attention_mask=None,
504
+ audio_embedding=None,
505
+ emotion=None,
506
+ uc_mask=None,
507
+ is_new_audio=True,
508
+ update_past_memory=False,
509
+ ):
510
+ output_states = ()
511
+
512
+ for i, (resnet, attn, audio_module, motion_module) in enumerate(
513
+ zip(self.resnets, self.attentions, self.audio_modules, self.motion_modules)
514
+ ):
515
+ ref_feature = ref_feature_list[i]
516
+ ref_feature = ref_feature[0]
517
+ ref_feature = rearrange(
518
+ ref_feature,
519
+ "(b f) (h w) c -> b c f h w",
520
+ b=hidden_states.shape[0],
521
+ w=hidden_states.shape[-1],
522
+ )
523
+ ref_img_feature = ref_feature[:, :, :1, :, :]
524
+ ref_img_feature = rearrange(
525
+ ref_img_feature,
526
+ "b c f h w -> (b f) (h w) c",
527
+ )
528
+ motion_frames = ref_feature[:, :, 1:, :, :]
529
+
530
+ if self.training and self.gradient_checkpointing:
531
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
532
+ hidden_states = torch.utils.checkpoint.checkpoint(
533
+ create_custom_forward(resnet),
534
+ hidden_states,
535
+ temb,
536
+ **ckpt_kwargs,
537
+ )
538
+ else:
539
+ hidden_states = resnet(hidden_states, temb)
540
+
541
+ hidden_states = attn(
542
+ hidden_states,
543
+ ref_img_feature,
544
+ encoder_hidden_states=encoder_hidden_states,
545
+ uc_mask=uc_mask,
546
+ return_dict=False,
547
+ )
548
+
549
+ if audio_module is not None:
550
+ hidden_states, audio_embedding = audio_module(
551
+ hidden_states,
552
+ ref_img_feature=None,
553
+ encoder_hidden_states=audio_embedding,
554
+ attention_mask=attention_mask,
555
+ return_dict=False,
556
+ emotion=emotion,
557
+ )
558
+
559
+ # add motion module
560
+ if motion_module is not None:
561
+ motion_frames = motion_frames.to(device=hidden_states.device, dtype=hidden_states.dtype)
562
+ hidden_states = motion_module(
563
+ hidden_states=hidden_states,
564
+ motion_frames=motion_frames,
565
+ encoder_hidden_states=encoder_hidden_states,
566
+ is_new_audio=is_new_audio,
567
+ update_past_memory=update_past_memory,
568
+ )
569
+
570
+ output_states += (hidden_states,)
571
+
572
+ if self.downsamplers is not None:
573
+ for downsampler in self.downsamplers:
574
+ hidden_states = downsampler(hidden_states)
575
+
576
+ output_states += (hidden_states,)
577
+
578
+ if audio_module is not None:
579
+ return hidden_states, output_states, audio_embedding
580
+ else:
581
+ return hidden_states, output_states
582
+
583
+
584
+ class DownBlock3D(nn.Module):
585
+ def __init__(
586
+ self,
587
+ in_channels: int,
588
+ out_channels: int,
589
+ temb_channels: int,
590
+ dropout: float = 0.0,
591
+ num_layers: int = 1,
592
+ resnet_eps: float = 1e-6,
593
+ resnet_time_scale_shift: str = "default",
594
+ resnet_act_fn: str = "swish",
595
+ resnet_groups: int = 32,
596
+ resnet_pre_norm: bool = True,
597
+ output_scale_factor=1.0,
598
+ add_downsample=True,
599
+ downsample_padding=1,
600
+ use_inflated_groupnorm=None,
601
+ use_motion_module=None,
602
+ motion_module_kwargs=None,
603
+ ):
604
+ super().__init__()
605
+ resnets = []
606
+ motion_modules = []
607
+
608
+ for i in range(num_layers):
609
+ in_channels = in_channels if i == 0 else out_channels
610
+ resnets.append(
611
+ ResnetBlock3D(
612
+ in_channels=in_channels,
613
+ out_channels=out_channels,
614
+ temb_channels=temb_channels,
615
+ eps=resnet_eps,
616
+ groups=resnet_groups,
617
+ dropout=dropout,
618
+ time_embedding_norm=resnet_time_scale_shift,
619
+ non_linearity=resnet_act_fn,
620
+ output_scale_factor=output_scale_factor,
621
+ pre_norm=resnet_pre_norm,
622
+ use_inflated_groupnorm=use_inflated_groupnorm,
623
+ )
624
+ )
625
+ motion_modules.append(
626
+ MemoryLinearAttnTemporalModule(
627
+ in_channels=out_channels,
628
+ **motion_module_kwargs,
629
+ )
630
+ if use_motion_module
631
+ else None
632
+ )
633
+
634
+ self.resnets = nn.ModuleList(resnets)
635
+ self.motion_modules = nn.ModuleList(motion_modules)
636
+
637
+ if add_downsample:
638
+ self.downsamplers = nn.ModuleList(
639
+ [
640
+ Downsample3D(
641
+ out_channels,
642
+ use_conv=True,
643
+ out_channels=out_channels,
644
+ padding=downsample_padding,
645
+ name="op",
646
+ )
647
+ ]
648
+ )
649
+ else:
650
+ self.downsamplers = None
651
+
652
+ self.gradient_checkpointing = False
653
+
654
+ def forward(
655
+ self,
656
+ hidden_states,
657
+ ref_feature_list,
658
+ temb=None,
659
+ encoder_hidden_states=None,
660
+ is_new_audio=True,
661
+ update_past_memory=False,
662
+ ):
663
+ output_states = ()
664
+
665
+ for i, (resnet, motion_module) in enumerate(zip(self.resnets, self.motion_modules)):
666
+ ref_feature = ref_feature_list[i]
667
+ ref_feature = rearrange(
668
+ ref_feature,
669
+ "(b f) c h w -> b c f h w",
670
+ b=hidden_states.shape[0],
671
+ w=hidden_states.shape[-1],
672
+ )
673
+ motion_frames = ref_feature[:, :, 1:, :, :]
674
+
675
+ if self.training and self.gradient_checkpointing:
676
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
677
+ hidden_states = torch.utils.checkpoint.checkpoint(
678
+ create_custom_forward(resnet),
679
+ hidden_states,
680
+ temb,
681
+ **ckpt_kwargs,
682
+ )
683
+ else:
684
+ hidden_states = resnet(hidden_states, temb)
685
+
686
+ if motion_module is not None:
687
+ hidden_states = motion_module(
688
+ hidden_states=hidden_states,
689
+ motion_frames=motion_frames,
690
+ encoder_hidden_states=encoder_hidden_states,
691
+ is_new_audio=is_new_audio,
692
+ update_past_memory=update_past_memory,
693
+ )
694
+
695
+ output_states += (hidden_states,)
696
+
697
+ if self.downsamplers is not None:
698
+ for downsampler in self.downsamplers:
699
+ hidden_states = downsampler(hidden_states)
700
+
701
+ output_states += (hidden_states,)
702
+
703
+ return hidden_states, output_states
704
+
705
+
706
+ class CrossAttnUpBlock3D(nn.Module):
707
+ def __init__(
708
+ self,
709
+ in_channels: int,
710
+ out_channels: int,
711
+ prev_output_channel: int,
712
+ temb_channels: int,
713
+ dropout: float = 0.0,
714
+ num_layers: int = 1,
715
+ resnet_eps: float = 1e-6,
716
+ resnet_time_scale_shift: str = "default",
717
+ resnet_act_fn: str = "swish",
718
+ resnet_groups: int = 32,
719
+ resnet_pre_norm: bool = True,
720
+ attn_num_head_channels=1,
721
+ cross_attention_dim=1280,
722
+ audio_attention_dim=1024,
723
+ output_scale_factor=1.0,
724
+ add_upsample=True,
725
+ dual_cross_attention=False,
726
+ use_linear_projection=False,
727
+ only_cross_attention=False,
728
+ upcast_attention=False,
729
+ unet_use_cross_frame_attention=None,
730
+ unet_use_temporal_attention=None,
731
+ use_motion_module=None,
732
+ use_inflated_groupnorm=None,
733
+ motion_module_kwargs=None,
734
+ depth=0,
735
+ emo_drop_rate=0.3,
736
+ is_final_block=False,
737
+ ):
738
+ super().__init__()
739
+ resnets = []
740
+ attentions = []
741
+ audio_modules = []
742
+ motion_modules = []
743
+
744
+ self.has_cross_attention = True
745
+ self.attn_num_head_channels = attn_num_head_channels
746
+ self.is_final_block = is_final_block
747
+
748
+ for i in range(num_layers):
749
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
750
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
751
+
752
+ resnets.append(
753
+ ResnetBlock3D(
754
+ in_channels=resnet_in_channels + res_skip_channels,
755
+ out_channels=out_channels,
756
+ temb_channels=temb_channels,
757
+ eps=resnet_eps,
758
+ groups=resnet_groups,
759
+ dropout=dropout,
760
+ time_embedding_norm=resnet_time_scale_shift,
761
+ non_linearity=resnet_act_fn,
762
+ output_scale_factor=output_scale_factor,
763
+ pre_norm=resnet_pre_norm,
764
+ use_inflated_groupnorm=use_inflated_groupnorm,
765
+ )
766
+ )
767
+
768
+ if dual_cross_attention:
769
+ raise NotImplementedError
770
+ attentions.append(
771
+ Transformer3DModel(
772
+ attn_num_head_channels,
773
+ out_channels // attn_num_head_channels,
774
+ in_channels=out_channels,
775
+ num_layers=1,
776
+ cross_attention_dim=cross_attention_dim,
777
+ norm_num_groups=resnet_groups,
778
+ use_linear_projection=use_linear_projection,
779
+ only_cross_attention=only_cross_attention,
780
+ upcast_attention=upcast_attention,
781
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
782
+ unet_use_temporal_attention=unet_use_temporal_attention,
783
+ )
784
+ )
785
+ audio_modules.append(
786
+ Transformer3DModel(
787
+ attn_num_head_channels,
788
+ in_channels // attn_num_head_channels,
789
+ in_channels=out_channels,
790
+ num_layers=1,
791
+ cross_attention_dim=audio_attention_dim,
792
+ norm_num_groups=resnet_groups,
793
+ use_linear_projection=use_linear_projection,
794
+ only_cross_attention=only_cross_attention,
795
+ upcast_attention=upcast_attention,
796
+ use_audio_module=True,
797
+ depth=depth,
798
+ unet_block_name="up",
799
+ emo_drop_rate=emo_drop_rate,
800
+ is_final_block=(is_final_block and i == num_layers - 1),
801
+ )
802
+ )
803
+ motion_modules.append(
804
+ MemoryLinearAttnTemporalModule(
805
+ in_channels=out_channels,
806
+ **motion_module_kwargs,
807
+ )
808
+ if use_motion_module
809
+ else None
810
+ )
811
+
812
+ self.attentions = nn.ModuleList(attentions)
813
+ self.resnets = nn.ModuleList(resnets)
814
+ self.audio_modules = nn.ModuleList(audio_modules)
815
+ self.motion_modules = nn.ModuleList(motion_modules)
816
+
817
+ if add_upsample:
818
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
819
+ else:
820
+ self.upsamplers = None
821
+
822
+ self.gradient_checkpointing = False
823
+
824
+ def forward(
825
+ self,
826
+ hidden_states,
827
+ ref_feature_list,
828
+ res_hidden_states_tuple,
829
+ temb=None,
830
+ encoder_hidden_states=None,
831
+ upsample_size=None,
832
+ attention_mask=None,
833
+ audio_embedding=None,
834
+ emotion=None,
835
+ uc_mask=None,
836
+ is_new_audio=True,
837
+ update_past_memory=False,
838
+ ):
839
+ for i, (resnet, attn, audio_module, motion_module) in enumerate(
840
+ zip(self.resnets, self.attentions, self.audio_modules, self.motion_modules)
841
+ ):
842
+ # pop res hidden states
843
+ res_hidden_states = res_hidden_states_tuple[-1]
844
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
845
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
846
+
847
+ ref_feature = ref_feature_list[i]
848
+ ref_feature = ref_feature[0]
849
+ ref_feature = rearrange(
850
+ ref_feature,
851
+ "(b f) (h w) c -> b c f h w",
852
+ b=hidden_states.shape[0],
853
+ w=hidden_states.shape[-1],
854
+ )
855
+ ref_img_feature = ref_feature[:, :, :1, :, :]
856
+ ref_img_feature = rearrange(
857
+ ref_img_feature,
858
+ "b c f h w -> (b f) (h w) c",
859
+ )
860
+ motion_frames = ref_feature[:, :, 1:, :, :]
861
+
862
+ if self.training and self.gradient_checkpointing:
863
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
864
+ hidden_states = torch.utils.checkpoint.checkpoint(
865
+ create_custom_forward(resnet),
866
+ hidden_states,
867
+ temb,
868
+ **ckpt_kwargs,
869
+ )
870
+ else:
871
+ hidden_states = resnet(hidden_states, temb)
872
+
873
+ hidden_states = attn(
874
+ hidden_states,
875
+ ref_img_feature,
876
+ encoder_hidden_states=encoder_hidden_states,
877
+ uc_mask=uc_mask,
878
+ return_dict=False,
879
+ )
880
+
881
+ if audio_module is not None:
882
+ hidden_states, audio_embedding = audio_module(
883
+ hidden_states,
884
+ ref_img_feature=None,
885
+ encoder_hidden_states=audio_embedding,
886
+ attention_mask=attention_mask,
887
+ return_dict=False,
888
+ emotion=emotion,
889
+ )
890
+
891
+ # add motion module
892
+ if motion_module is not None:
893
+ motion_frames = motion_frames.to(device=hidden_states.device, dtype=hidden_states.dtype)
894
+ hidden_states = motion_module(
895
+ hidden_states,
896
+ motion_frames,
897
+ encoder_hidden_states,
898
+ is_new_audio=is_new_audio,
899
+ update_past_memory=update_past_memory,
900
+ )
901
+
902
+ if self.upsamplers is not None:
903
+ for upsampler in self.upsamplers:
904
+ hidden_states = upsampler(hidden_states, upsample_size)
905
+
906
+ if audio_module is not None:
907
+ return hidden_states, audio_embedding
908
+ else:
909
+ return hidden_states
910
+
911
+
912
+ class UpBlock3D(nn.Module):
913
+ def __init__(
914
+ self,
915
+ in_channels: int,
916
+ prev_output_channel: int,
917
+ out_channels: int,
918
+ temb_channels: int,
919
+ dropout: float = 0.0,
920
+ num_layers: int = 1,
921
+ resnet_eps: float = 1e-6,
922
+ resnet_time_scale_shift: str = "default",
923
+ resnet_act_fn: str = "swish",
924
+ resnet_groups: int = 32,
925
+ resnet_pre_norm: bool = True,
926
+ output_scale_factor=1.0,
927
+ add_upsample=True,
928
+ use_inflated_groupnorm=None,
929
+ use_motion_module=None,
930
+ motion_module_kwargs=None,
931
+ ):
932
+ super().__init__()
933
+ resnets = []
934
+ motion_modules = []
935
+
936
+ for i in range(num_layers):
937
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
938
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
939
+
940
+ resnets.append(
941
+ ResnetBlock3D(
942
+ in_channels=resnet_in_channels + res_skip_channels,
943
+ out_channels=out_channels,
944
+ temb_channels=temb_channels,
945
+ eps=resnet_eps,
946
+ groups=resnet_groups,
947
+ dropout=dropout,
948
+ time_embedding_norm=resnet_time_scale_shift,
949
+ non_linearity=resnet_act_fn,
950
+ output_scale_factor=output_scale_factor,
951
+ pre_norm=resnet_pre_norm,
952
+ use_inflated_groupnorm=use_inflated_groupnorm,
953
+ )
954
+ )
955
+ motion_modules.append(
956
+ MemoryLinearAttnTemporalModule(
957
+ in_channels=out_channels,
958
+ **motion_module_kwargs,
959
+ )
960
+ if use_motion_module
961
+ else None
962
+ )
963
+
964
+ self.resnets = nn.ModuleList(resnets)
965
+ self.motion_modules = nn.ModuleList(motion_modules)
966
+
967
+ if add_upsample:
968
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
969
+ else:
970
+ self.upsamplers = None
971
+
972
+ self.gradient_checkpointing = False
973
+
974
+ def forward(
975
+ self,
976
+ hidden_states,
977
+ ref_feature_list,
978
+ res_hidden_states_tuple,
979
+ temb=None,
980
+ upsample_size=None,
981
+ encoder_hidden_states=None,
982
+ is_new_audio=True,
983
+ update_past_memory=False,
984
+ ):
985
+ for i, (resnet, motion_module) in enumerate(zip(self.resnets, self.motion_modules)):
986
+ # pop res hidden states
987
+ res_hidden_states = res_hidden_states_tuple[-1]
988
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
989
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
990
+
991
+ ref_feature = ref_feature_list[i]
992
+ ref_feature = rearrange(
993
+ ref_feature,
994
+ "(b f) c h w -> b c f h w",
995
+ b=hidden_states.shape[0],
996
+ w=hidden_states.shape[-1],
997
+ )
998
+ motion_frames = ref_feature[:, :, 1:, :, :]
999
+
1000
+ if self.training and self.gradient_checkpointing:
1001
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1002
+ hidden_states = torch.utils.checkpoint.checkpoint(
1003
+ create_custom_forward(resnet),
1004
+ hidden_states,
1005
+ temb,
1006
+ **ckpt_kwargs,
1007
+ )
1008
+ else:
1009
+ hidden_states = resnet(hidden_states, temb)
1010
+
1011
+ if motion_module is not None:
1012
+ hidden_states = motion_module(
1013
+ hidden_states=hidden_states,
1014
+ motion_frames=motion_frames,
1015
+ encoder_hidden_states=encoder_hidden_states,
1016
+ is_new_audio=is_new_audio,
1017
+ update_past_memory=update_past_memory,
1018
+ )
1019
+
1020
+ if self.upsamplers is not None:
1021
+ for upsampler in self.upsamplers:
1022
+ hidden_states = upsampler(hidden_states, upsample_size)
1023
+
1024
+ return hidden_states
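A note on the reference-feature handling in the up blocks above: the fused reference features arrive flattened as "(b f) (h w) c"; after reshaping, frame 0 is the reference image (routed to spatial attention) and the remaining frames are past motion frames (routed to the memory/temporal module). A minimal sketch with made-up sizes (not part of this commit):

import torch
from einops import rearrange

b, f, h, w, c = 1, 1 + 4, 32, 32, 320              # 1 reference frame + 4 assumed motion frames
ref_feature = torch.randn(b * f, h * w, c)         # "(b f) (h w) c", as produced by the reference net

ref_feature = rearrange(ref_feature, "(b f) (h w) c -> b c f h w", b=b, w=w)
ref_img_feature = rearrange(ref_feature[:, :, :1], "b c f h w -> (b f) (h w) c")  # spatial-attention input
motion_frames = ref_feature[:, :, 1:]              # consumed by MemoryLinearAttnTemporalModule
print(ref_img_feature.shape, motion_frames.shape)  # torch.Size([1, 1024, 320]) torch.Size([1, 320, 4, 32, 32])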
memo/models/wav2vec.py ADDED
@@ -0,0 +1,127 @@
1
+ import torch.nn.functional as F
2
+ from transformers import Wav2Vec2Model
3
+ from transformers.modeling_outputs import BaseModelOutput
4
+
5
+
6
+ class Wav2VecModel(Wav2Vec2Model):
7
+ def forward(
8
+ self,
9
+ input_values,
10
+ seq_len,
11
+ attention_mask=None,
12
+ mask_time_indices=None,
13
+ output_attentions=None,
14
+ output_hidden_states=None,
15
+ return_dict=None,
16
+ ):
17
+ self.config.output_attentions = True
18
+
19
+ output_hidden_states = (
20
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
21
+ )
22
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
23
+
24
+ extract_features = self.feature_extractor(input_values)
25
+ extract_features = extract_features.transpose(1, 2)
26
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
27
+
28
+ if attention_mask is not None:
29
+ # compute reduced attention_mask corresponding to feature vectors
30
+ attention_mask = self._get_feature_vector_attention_mask(
31
+ extract_features.shape[1], attention_mask, add_adapter=False
32
+ )
33
+
34
+ hidden_states, extract_features = self.feature_projection(extract_features)
35
+ hidden_states = self._mask_hidden_states(
36
+ hidden_states,
37
+ mask_time_indices=mask_time_indices,
38
+ attention_mask=attention_mask,
39
+ )
40
+
41
+ encoder_outputs = self.encoder(
42
+ hidden_states,
43
+ attention_mask=attention_mask,
44
+ output_attentions=output_attentions,
45
+ output_hidden_states=output_hidden_states,
46
+ return_dict=return_dict,
47
+ )
48
+
49
+ hidden_states = encoder_outputs[0]
50
+
51
+ if self.adapter is not None:
52
+ hidden_states = self.adapter(hidden_states)
53
+
54
+ if not return_dict:
55
+ return (hidden_states,) + encoder_outputs[1:]
56
+ return BaseModelOutput(
57
+ last_hidden_state=hidden_states,
58
+ hidden_states=encoder_outputs.hidden_states,
59
+ attentions=encoder_outputs.attentions,
60
+ )
61
+
62
+ def feature_extract(
63
+ self,
64
+ input_values,
65
+ seq_len,
66
+ ):
67
+ extract_features = self.feature_extractor(input_values)
68
+ extract_features = extract_features.transpose(1, 2)
69
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
70
+
71
+ return extract_features
72
+
73
+ def encode(
74
+ self,
75
+ extract_features,
76
+ attention_mask=None,
77
+ mask_time_indices=None,
78
+ output_attentions=None,
79
+ output_hidden_states=None,
80
+ return_dict=None,
81
+ ):
82
+ self.config.output_attentions = True
83
+
84
+ output_hidden_states = (
85
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
86
+ )
87
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
88
+
89
+ if attention_mask is not None:
90
+ # compute reduced attention_mask corresponding to feature vectors
91
+ attention_mask = self._get_feature_vector_attention_mask(
92
+ extract_features.shape[1], attention_mask, add_adapter=False
93
+ )
94
+
95
+ hidden_states, extract_features = self.feature_projection(extract_features)
96
+ hidden_states = self._mask_hidden_states(
97
+ hidden_states,
98
+ mask_time_indices=mask_time_indices,
99
+ attention_mask=attention_mask,
100
+ )
101
+
102
+ encoder_outputs = self.encoder(
103
+ hidden_states,
104
+ attention_mask=attention_mask,
105
+ output_attentions=output_attentions,
106
+ output_hidden_states=output_hidden_states,
107
+ return_dict=return_dict,
108
+ )
109
+
110
+ hidden_states = encoder_outputs[0]
111
+
112
+ if self.adapter is not None:
113
+ hidden_states = self.adapter(hidden_states)
114
+
115
+ if not return_dict:
116
+ return (hidden_states,) + encoder_outputs[1:]
117
+ return BaseModelOutput(
118
+ last_hidden_state=hidden_states,
119
+ hidden_states=encoder_outputs.hidden_states,
120
+ attentions=encoder_outputs.attentions,
121
+ )
122
+
123
+
124
+ def linear_interpolation(features, seq_len):
125
+ features = features.transpose(1, 2)
126
+ output_features = F.interpolate(features, size=seq_len, align_corners=True, mode="linear")
127
+ return output_features.transpose(1, 2)
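A hypothetical usage sketch for Wav2VecModel (the checkpoint name is an assumption; any Wav2Vec2 checkpoint should work): the extra `seq_len` argument resamples the feature sequence so that one embedding lines up with one video frame.

import torch
from memo.models.wav2vec import Wav2VecModel

audio_encoder = Wav2VecModel.from_pretrained("facebook/wav2vec2-base-960h")  # assumed checkpoint
audio_encoder.feature_extractor._freeze_parameters()

waveform = torch.randn(1, 16000 * 2)  # 2 s of 16 kHz mono audio
num_frames = 2 * 30                   # target length for a 30 fps clip

with torch.no_grad():
    out = audio_encoder(waveform, seq_len=num_frames, output_hidden_states=True)
print(out.last_hidden_state.shape)    # torch.Size([1, 60, 768]): one feature per video frame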
memo/pipelines/__init__.py ADDED
File without changes
memo/pipelines/video_pipeline.py ADDED
@@ -0,0 +1,295 @@
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Callable, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from diffusers import (
8
+ DDIMScheduler,
9
+ DiffusionPipeline,
10
+ DPMSolverMultistepScheduler,
11
+ EulerAncestralDiscreteScheduler,
12
+ EulerDiscreteScheduler,
13
+ LMSDiscreteScheduler,
14
+ PNDMScheduler,
15
+ )
16
+ from diffusers.image_processor import VaeImageProcessor
17
+ from diffusers.utils import BaseOutput
18
+ from diffusers.utils.torch_utils import randn_tensor
19
+ from einops import rearrange
20
+
21
+
22
+ @dataclass
23
+ class VideoPipelineOutput(BaseOutput):
24
+ videos: Union[torch.Tensor, np.ndarray]
25
+
26
+
27
+ class VideoPipeline(DiffusionPipeline):
28
+ def __init__(
29
+ self,
30
+ vae,
31
+ reference_net,
32
+ diffusion_net,
33
+ image_proj,
34
+ scheduler: Union[
35
+ DDIMScheduler,
36
+ PNDMScheduler,
37
+ LMSDiscreteScheduler,
38
+ EulerDiscreteScheduler,
39
+ EulerAncestralDiscreteScheduler,
40
+ DPMSolverMultistepScheduler,
41
+ ],
42
+ ) -> None:
43
+ super().__init__()
44
+
45
+ self.register_modules(
46
+ vae=vae,
47
+ reference_net=reference_net,
48
+ diffusion_net=diffusion_net,
49
+ scheduler=scheduler,
50
+ image_proj=image_proj,
51
+ )
52
+
53
+ self.vae_scale_factor: int = 2 ** (len(self.vae.config.block_out_channels) - 1)
54
+
55
+ self.ref_image_processor = VaeImageProcessor(
56
+ vae_scale_factor=self.vae_scale_factor,
57
+ do_convert_rgb=True,
58
+ )
59
+
60
+ @property
61
+ def _execution_device(self):
62
+ if self.device != torch.device("meta") or not hasattr(self.diffusion_net, "_hf_hook"):
63
+ return self.device
64
+ for module in self.diffusion_net.modules():
65
+ if (
66
+ hasattr(module, "_hf_hook")
67
+ and hasattr(module._hf_hook, "execution_device")
68
+ and module._hf_hook.execution_device is not None
69
+ ):
70
+ return torch.device(module._hf_hook.execution_device)
71
+ return self.device
72
+
73
+ def prepare_latents(
74
+ self,
75
+ batch_size: int, # Number of videos to generate in parallel
76
+ num_channels_latents: int, # Number of channels in the latents
77
+ width: int, # Width of the video frame
78
+ height: int, # Height of the video frame
79
+ video_length: int, # Length of the video in frames
80
+ dtype: torch.dtype, # Data type of the latents
81
+ device: torch.device, # Device to store the latents on
82
+ generator: Optional[torch.Generator] = None, # Random number generator for reproducibility
83
+ latents: Optional[torch.Tensor] = None, # Pre-generated latents (optional)
84
+ ):
85
+ shape = (
86
+ batch_size,
87
+ num_channels_latents,
88
+ video_length,
89
+ height // self.vae_scale_factor,
90
+ width // self.vae_scale_factor,
91
+ )
92
+ if isinstance(generator, list) and len(generator) != batch_size:
93
+ raise ValueError(
94
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
95
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
96
+ )
97
+
98
+ if latents is None:
99
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
100
+ else:
101
+ latents = latents.to(device)
102
+
103
+ # scale the initial noise by the standard deviation required by the scheduler
104
+ if hasattr(self.scheduler, "init_noise_sigma"):
105
+ latents = latents * self.scheduler.init_noise_sigma
106
+ return latents
107
+
108
+ def prepare_extra_step_kwargs(self, generator, eta):
109
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
110
+ # eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers.
111
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
112
+ # and should be between [0, 1]
113
+
114
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
115
+ extra_step_kwargs = {}
116
+ if accepts_eta:
117
+ extra_step_kwargs["eta"] = eta
118
+
119
+ # check if the scheduler accepts generator
120
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
121
+ if accepts_generator:
122
+ extra_step_kwargs["generator"] = generator
123
+ return extra_step_kwargs
124
+
125
+ def decode_latents(self, latents):
126
+ video_length = latents.shape[2]
127
+ latents = 1 / 0.18215 * latents
128
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
129
+ video = []
130
+ for frame_idx in range(latents.shape[0]):
131
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
132
+ video = torch.cat(video)
133
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
134
+ video = (video / 2 + 0.5).clamp(0, 1)
135
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
136
+ video = video.cpu().float().numpy()
137
+ return video
138
+
139
+ @torch.no_grad()
140
+ def __call__(
141
+ self,
142
+ ref_image,
143
+ face_emb,
144
+ audio_tensor,
145
+ width,
146
+ height,
147
+ video_length,
148
+ num_inference_steps,
149
+ guidance_scale,
150
+ num_images_per_prompt=1,
151
+ eta: float = 0.0,
152
+ audio_emotion=None,
153
+ emotion_class_num=None,
154
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
155
+ output_type: Optional[str] = "tensor",
156
+ return_dict: bool = True,
157
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
158
+ callback_steps: Optional[int] = 1,
159
+ ):
160
+ # Default height and width to unet
161
+ height = height or self.diffusion_net.config.sample_size * self.vae_scale_factor
162
+ width = width or self.diffusion_net.config.sample_size * self.vae_scale_factor
163
+
164
+ device = self._execution_device
165
+
166
+ do_classifier_free_guidance = guidance_scale > 1.0
167
+
168
+ # Prepare timesteps
169
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
170
+ timesteps = self.scheduler.timesteps
171
+
172
+ batch_size = 1
173
+
174
+ # prepare clip image embeddings
175
+ clip_image_embeds = face_emb
176
+ clip_image_embeds = clip_image_embeds.to(self.image_proj.device, self.image_proj.dtype)
177
+
178
+ encoder_hidden_states = self.image_proj(clip_image_embeds)
179
+ uncond_encoder_hidden_states = self.image_proj(torch.zeros_like(clip_image_embeds))
180
+
181
+ if do_classifier_free_guidance:
182
+ encoder_hidden_states = torch.cat([uncond_encoder_hidden_states, encoder_hidden_states], dim=0)
183
+
184
+ num_channels_latents = self.diffusion_net.in_channels
185
+
186
+ latents = self.prepare_latents(
187
+ batch_size * num_images_per_prompt,
188
+ num_channels_latents,
189
+ width,
190
+ height,
191
+ video_length,
192
+ clip_image_embeds.dtype,
193
+ device,
194
+ generator,
195
+ )
196
+
197
+ # Prepare extra step kwargs.
198
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
199
+
200
+ # Prepare ref image latents
201
+ ref_image_tensor = rearrange(ref_image, "b f c h w -> (b f) c h w")
202
+ ref_image_tensor = self.ref_image_processor.preprocess(
203
+ ref_image_tensor, height=height, width=width
204
+ ) # (bs, c, height, width)
205
+ ref_image_tensor = ref_image_tensor.to(dtype=self.vae.dtype, device=self.vae.device)
206
+ # To save memory on GPUs like RTX 4090, we encode each frame separately
207
+ # ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
208
+ ref_image_latents = []
209
+ for frame_idx in range(ref_image_tensor.shape[0]):
210
+ ref_image_latents.append(self.vae.encode(ref_image_tensor[frame_idx : frame_idx + 1]).latent_dist.mean)
211
+ ref_image_latents = torch.cat(ref_image_latents, dim=0)
212
+
213
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
214
+
215
+ if do_classifier_free_guidance:
216
+ uncond_audio_tensor = torch.zeros_like(audio_tensor)
217
+ audio_tensor = torch.cat([uncond_audio_tensor, audio_tensor], dim=0)
218
+ audio_tensor = audio_tensor.to(dtype=self.diffusion_net.dtype, device=self.diffusion_net.device)
219
+
220
+ # denoising loop
221
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
222
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
223
+ for i in range(len(timesteps)):
224
+ t = timesteps[i]
225
+ # Forward reference image
226
+ if i == 0:
227
+ ref_features = self.reference_net(
228
+ ref_image_latents.repeat((2 if do_classifier_free_guidance else 1), 1, 1, 1),
229
+ torch.zeros_like(t),
230
+ encoder_hidden_states=encoder_hidden_states,
231
+ return_dict=False,
232
+ )
233
+
234
+ # expand the latents if we are doing classifier free guidance
235
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
236
+ if hasattr(self.scheduler, "scale_model_input"):
237
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
238
+
239
+ audio_emotion = torch.tensor(torch.mode(audio_emotion).values.item()).to(
240
+ dtype=torch.int, device=self.diffusion_net.device
241
+ )
242
+ if do_classifier_free_guidance:
243
+ uncond_audio_emotion = torch.full_like(audio_emotion, emotion_class_num)
244
+ audio_emotion = torch.cat(
245
+ [uncond_audio_emotion.unsqueeze(0), audio_emotion.unsqueeze(0)],
246
+ dim=0,
247
+ )
248
+
249
+ uc_mask = (
250
+ torch.Tensor(
251
+ [1] * batch_size * num_images_per_prompt * 16
252
+ + [0] * batch_size * num_images_per_prompt * 16
253
+ )
254
+ .to(device)
255
+ .bool()
256
+ )
257
+ else:
258
+ uc_mask = None
259
+
260
+ noise_pred = self.diffusion_net(
261
+ latent_model_input,
262
+ ref_features,
263
+ t,
264
+ encoder_hidden_states=encoder_hidden_states,
265
+ audio_embedding=audio_tensor,
266
+ audio_emotion=audio_emotion,
267
+ uc_mask=uc_mask,
268
+ ).sample
269
+
270
+ # perform guidance
271
+ if do_classifier_free_guidance:
272
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
273
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
274
+
275
+ # compute the previous noisy sample x_t -> x_t-1
276
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
277
+
278
+ # call the callback, if provided
279
+ if i == len(timesteps) - 1 or (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
280
+ progress_bar.update()
281
+ if callback is not None and i % callback_steps == 0:
282
+ step_idx = i // getattr(self.scheduler, "order", 1)
283
+ callback(step_idx, t, latents)
284
+
285
+ # Post-processing
286
+ images = self.decode_latents(latents) # (b, c, f, h, w)
287
+
288
+ # Convert to tensor
289
+ if output_type == "tensor":
290
+ images = torch.from_numpy(images)
291
+
292
+ if not return_dict:
293
+ return images
294
+
295
+ return VideoPipelineOutput(videos=images)
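The only non-obvious tensor bookkeeping in the denoising loop above is the classifier-free guidance step: the latents are duplicated into an (unconditional, conditional) batch, the prediction is split, and the two halves are recombined with `guidance_scale`. A self-contained toy sketch (random tensors stand in for the diffusion_net output):

import torch

guidance_scale = 3.5
latents = torch.randn(1, 4, 16, 64, 64)            # (b, c, f, h, w) video latents
latent_model_input = torch.cat([latents] * 2)      # row 0: unconditional, row 1: conditional

noise_pred = torch.randn_like(latent_model_input)  # stand-in for diffusion_net(...).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)                            # torch.Size([1, 4, 16, 64, 64])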
memo/utils/__init__.py ADDED
File without changes
memo/utils/audio_utils.py ADDED
@@ -0,0 +1,293 @@
1
+ import logging
2
+ import math
3
+ import os
4
+ import subprocess
5
+ from io import BytesIO
6
+
7
+ import librosa
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torchaudio
12
+ from audio_separator.separator import Separator
13
+ from einops import rearrange
14
+ from funasr.download.download_from_hub import download_model
15
+ from funasr.models.emotion2vec.model import Emotion2vec
16
+ from transformers import Wav2Vec2FeatureExtractor
17
+
18
+ from memo.models.emotion_classifier import AudioEmotionClassifierModel
19
+ from memo.models.wav2vec import Wav2VecModel
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int = 16000):
26
+ p = subprocess.Popen(
27
+ [
28
+ "ffmpeg",
29
+ "-y",
30
+ "-v",
31
+ "error",
32
+ "-i",
33
+ input_audio_file,
34
+ "-ar",
35
+ str(sample_rate),
36
+ output_audio_file,
37
+ ]
38
+ )
39
+ ret = p.wait()
40
+ assert ret == 0, f"Resample audio failed! Input: {input_audio_file}, Output: {output_audio_file}"
41
+ return output_audio_file
42
+
43
+
44
+ @torch.no_grad()
45
+ def preprocess_audio(
46
+ wav_path: str,
47
+ fps: int,
48
+ wav2vec_model: str,
49
+ vocal_separator_model: str = None,
50
+ cache_dir: str = "",
51
+ device: str = "cuda",
52
+ sample_rate: int = 16000,
53
+ num_generated_frames_per_clip: int = -1,
54
+ ):
55
+ """
56
+ Preprocess the audio file and extract audio embeddings.
57
+
58
+ Args:
59
+ wav_path (str): Path to the input audio file.
60
+ fps (int): Frames per second for the audio processing.
61
+ wav2vec_model (str): Path to the pretrained Wav2Vec model.
62
+ vocal_separator_model (str, optional): Path to the vocal separator model. Defaults to None.
63
+ cache_dir (str, optional): Directory for cached files. Defaults to "".
64
+ device (str, optional): Device to use ('cuda' or 'cpu'). Defaults to "cuda".
65
+ sample_rate (int, optional): Sampling rate for audio processing. Defaults to 16000.
66
+ num_generated_frames_per_clip (int, optional): Number of generated frames per clip for padding. Defaults to -1.
67
+
68
+ Returns:
69
+ tuple: A tuple containing:
70
+ - audio_emb (torch.Tensor): The processed audio embeddings.
71
+ - audio_length (int): The length of the audio in frames.
72
+ """
73
+ # Initialize Wav2Vec model
74
+ audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model).to(device=device)
75
+ audio_encoder.feature_extractor._freeze_parameters()
76
+
77
+ # Initialize Wav2Vec feature extractor
78
+ wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model)
79
+
80
+ # Initialize vocal separator if provided
81
+ vocal_separator = None
82
+ if vocal_separator_model is not None:
83
+ os.makedirs(cache_dir, exist_ok=True)
84
+ vocal_separator = Separator(
85
+ output_dir=cache_dir,
86
+ output_single_stem="vocals",
87
+ model_file_dir=os.path.dirname(vocal_separator_model),
88
+ )
89
+ vocal_separator.load_model(os.path.basename(vocal_separator_model))
90
+ assert vocal_separator.model_instance is not None, "Failed to load audio separation model."
91
+
92
+ # Perform vocal separation if applicable
93
+ if vocal_separator is not None:
94
+ outputs = vocal_separator.separate(wav_path)
95
+ assert len(outputs) > 0, "Audio separation failed."
96
+ vocal_audio_file = outputs[0]
97
+ vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
98
+ vocal_audio_file = os.path.join(vocal_separator.output_dir, vocal_audio_file)
99
+ vocal_audio_file = resample_audio(
100
+ vocal_audio_file,
101
+ os.path.join(vocal_separator.output_dir, f"{vocal_audio_name}-16k.wav"),
102
+ sample_rate,
103
+ )
104
+ else:
105
+ vocal_audio_file = wav_path
106
+
107
+ # Load audio and extract Wav2Vec features
108
+ speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=sample_rate)
109
+ audio_feature = np.squeeze(wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values)
110
+ audio_length = math.ceil(len(audio_feature) / sample_rate * fps)
111
+ audio_feature = torch.from_numpy(audio_feature).float().to(device=device)
112
+
113
+ # Pad audio features to match the required length
114
+ if num_generated_frames_per_clip > 0 and audio_length % num_generated_frames_per_clip != 0:
115
+ audio_feature = torch.nn.functional.pad(
116
+ audio_feature,
117
+ (
118
+ 0,
119
+ (num_generated_frames_per_clip - audio_length % num_generated_frames_per_clip) * (sample_rate // fps),
120
+ ),
121
+ "constant",
122
+ 0.0,
123
+ )
124
+ audio_length += num_generated_frames_per_clip - audio_length % num_generated_frames_per_clip
125
+ audio_feature = audio_feature.unsqueeze(0)
126
+
127
+ # Extract audio embeddings
128
+ with torch.no_grad():
129
+ embeddings = audio_encoder(audio_feature, seq_len=audio_length, output_hidden_states=True)
130
+ assert len(embeddings) > 0, "Failed to extract audio embeddings."
131
+ audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
132
+ audio_emb = rearrange(audio_emb, "b s d -> s b d")
133
+
134
+ # Concatenate embeddings with surrounding frames
135
+ audio_emb = audio_emb.cpu().detach()
136
+ concatenated_tensors = []
137
+ for i in range(audio_emb.shape[0]):
138
+ vectors_to_concat = [audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)]
139
+ concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))
140
+ audio_emb = torch.stack(concatenated_tensors, dim=0)
141
+
142
+ if vocal_separator is not None:
143
+ del vocal_separator
144
+ del audio_encoder
145
+
146
+ return audio_emb, audio_length
147
+
148
+
149
+ @torch.no_grad()
150
+ def extract_audio_emotion_labels(
151
+ model: str,
152
+ wav_path: str,
153
+ emotion2vec_model: str,
154
+ audio_length: int,
155
+ sample_rate: int = 16000,
156
+ device: str = "cuda",
157
+ ):
158
+ """
159
+ Extract audio emotion labels from an audio file.
160
+
161
+ Args:
162
+ model (str): Path to the MEMO model.
163
+ wav_path (str): Path to the input audio file.
164
+ emotion2vec_model (str): Path to the Emotion2vec model.
165
+ audio_length (int): Target length for interpolated emotion labels.
166
+ sample_rate (int, optional): Sample rate of the input audio. Default is 16000.
167
+ device (str, optional): Device to use ('cuda' or 'cpu'). Default is "cuda".
168
+
169
+ Returns:
170
+ torch.Tensor: Processed emotion labels with shape matching the target audio length.
171
+ """
172
+ # Load models
173
+ logger.info("Downloading emotion2vec models from modelscope")
174
+ kwargs = download_model(model=emotion2vec_model)
175
+ kwargs["tokenizer"] = None
176
+ kwargs["input_size"] = None
177
+ kwargs["frontend"] = None
178
+ emotion_model = Emotion2vec(**kwargs, vocab_size=-1).to(device)
179
+ init_param = kwargs.get("init_param", None)
180
+ load_emotion2vec_model(
181
+ model=emotion_model,
182
+ path=init_param,
183
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
184
+ oss_bucket=kwargs.get("oss_bucket", None),
185
+ scope_map=kwargs.get("scope_map", []),
186
+ )
187
+ emotion_model.eval()
188
+
189
+ classifier = AudioEmotionClassifierModel.from_pretrained(
190
+ model,
191
+ subfolder="misc/audio_emotion_classifier",
192
+ use_safetensors=True,
193
+ ).to(device=device)
194
+ classifier.eval()
195
+
196
+ # Load audio
197
+ wav, sr = torchaudio.load(wav_path)
198
+ if sr != sample_rate:
199
+ wav = torchaudio.functional.resample(wav, sr, sample_rate)
200
+ wav = wav.view(-1) if wav.dim() == 1 else wav[0].view(-1)
201
+
202
+ emotion_labels = torch.full_like(wav, -1, dtype=torch.int32)
203
+
204
+ def extract_emotion(x):
205
+ """
206
+ Extract emotion for a given audio segment.
207
+ """
208
+ x = x.to(device=device)
209
+ x = F.layer_norm(x, x.shape).view(1, -1)
210
+ feats = emotion_model.extract_features(x)
211
+ x = feats["x"].mean(dim=1) # average across frames
212
+ x = classifier(x)
213
+ x = torch.softmax(x, dim=-1)
214
+ return torch.argmax(x, dim=-1)
215
+
216
+ # Process start, middle, and end segments
217
+ start_label = extract_emotion(wav[: sample_rate * 2]).item()
218
+ emotion_labels[:sample_rate] = start_label
219
+
220
+ for i in range(sample_rate, len(wav) - sample_rate, sample_rate):
221
+ mid_wav = wav[i - sample_rate : i - sample_rate + sample_rate * 3]
222
+ mid_label = extract_emotion(mid_wav).item()
223
+ emotion_labels[i : i + sample_rate] = mid_label
224
+
225
+ end_label = extract_emotion(wav[-sample_rate * 2 :]).item()
226
+ emotion_labels[-sample_rate:] = end_label
227
+
228
+ # Interpolate to match the target audio length
229
+ emotion_labels = emotion_labels.unsqueeze(0).unsqueeze(0).float()
230
+ emotion_labels = F.interpolate(emotion_labels, size=audio_length, mode="nearest").squeeze(0).squeeze(0).int()
231
+ num_emotion_classes = classifier.num_emotion_classes
232
+
233
+ del emotion_model
234
+ del classifier
235
+
236
+ return emotion_labels, num_emotion_classes
237
+
238
+
239
+ def load_emotion2vec_model(
240
+ path: str,
241
+ model: torch.nn.Module,
242
+ ignore_init_mismatch: bool = True,
243
+ map_location: str = "cpu",
244
+ oss_bucket=None,
245
+ scope_map=[],
246
+ ):
247
+ obj = model
248
+ dst_state = obj.state_dict()
249
+ logger.debug(f"Emotion2vec checkpoint: {path}")
250
+ if oss_bucket is None:
251
+ src_state = torch.load(path, map_location=map_location)
252
+ else:
253
+ buffer = BytesIO(oss_bucket.get_object(path).read())
254
+ src_state = torch.load(buffer, map_location=map_location)
255
+
256
+ src_state = src_state["state_dict"] if "state_dict" in src_state else src_state
257
+ src_state = src_state["model_state_dict"] if "model_state_dict" in src_state else src_state
258
+ src_state = src_state["model"] if "model" in src_state else src_state
259
+
260
+ if isinstance(scope_map, str):
261
+ scope_map = scope_map.split(",")
262
+ scope_map += ["module.", "None"]
263
+
264
+ for k in dst_state.keys():
265
+ k_src = k
266
+ if scope_map is not None:
267
+ src_prefix = ""
268
+ dst_prefix = ""
269
+ for i in range(0, len(scope_map), 2):
270
+ src_prefix = scope_map[i] if scope_map[i].lower() != "none" else ""
271
+ dst_prefix = scope_map[i + 1] if scope_map[i + 1].lower() != "none" else ""
272
+
273
+ if dst_prefix == "" and (src_prefix + k) in src_state.keys():
274
+ k_src = src_prefix + k
275
+ if not k_src.startswith("module."):
276
+ logger.debug(f"init param, map: {k} from {k_src} in ckpt")
277
+ elif k.startswith(dst_prefix) and k.replace(dst_prefix, src_prefix, 1) in src_state.keys():
278
+ k_src = k.replace(dst_prefix, src_prefix, 1)
279
+ if not k_src.startswith("module."):
280
+ logger.debug(f"init param, map: {k} from {k_src} in ckpt")
281
+
282
+ if k_src in src_state.keys():
283
+ if ignore_init_mismatch and dst_state[k].shape != src_state[k_src].shape:
284
+ logger.debug(
285
+ f"ignore_init_mismatch:{ignore_init_mismatch}, dst: {k, dst_state[k].shape}, src: {k_src, src_state[k_src].shape}"
286
+ )
287
+ else:
288
+ dst_state[k] = src_state[k_src]
289
+
290
+ else:
291
+ logger.debug(f"Warning, miss key in ckpt: {k}, mapped: {k_src}")
292
+
293
+ obj.load_state_dict(dst_state, strict=True)
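A hypothetical call to preprocess_audio (the wav2vec checkpoint path is a placeholder; the speech file is the sample shipped under assets/examples): it returns one embedding per video frame, padded up to a multiple of the clip length.

from memo.utils.audio_utils import preprocess_audio

audio_emb, audio_length = preprocess_audio(
    wav_path="assets/examples/speech.wav",
    fps=30,
    wav2vec_model="checkpoints/wav2vec2",  # placeholder path to a Wav2Vec2 checkpoint
    vocal_separator_model=None,            # skip vocal separation
    cache_dir="outputs/audio_preprocess",
    device="cuda",
    num_generated_frames_per_clip=16,
)
print(audio_emb.shape)  # (padded frame count, 5 neighboring frames, num hidden layers, hidden size)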
memo/utils/vision_utils.py ADDED
@@ -0,0 +1,100 @@
1
+ import logging
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from insightface.app import FaceAnalysis
7
+ from moviepy.editor import AudioFileClip, VideoClip
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ def tensor_to_video(tensor, output_video_path, input_audio_path, fps=30):
16
+ """
17
+ Converts a Tensor with shape [c, f, h, w] into a video and adds an audio track from the specified audio file.
18
+
19
+ Args:
20
+ tensor (Tensor): The Tensor to be converted, shaped [c, f, h, w].
21
+ output_video_path (str): The file path where the output video will be saved.
22
+ input_audio_path (str): The path to the audio file (WAV file) that contains the audio track to be added.
23
+ fps (int): The frame rate of the output video. Default is 30 fps.
24
+ """
25
+ tensor = tensor.permute(1, 2, 3, 0).cpu().numpy() # convert to [f, h, w, c]
26
+ tensor = np.clip(tensor * 255, 0, 255).astype(np.uint8) # to [0, 255]
27
+
28
+ def make_frame(t):
29
+ frame_index = min(int(t * fps), tensor.shape[0] - 1)
30
+ return tensor[frame_index]
31
+
32
+ video_duration = tensor.shape[0] / fps
33
+ audio_clip = AudioFileClip(input_audio_path)
34
+ audio_duration = audio_clip.duration
35
+ final_duration = min(video_duration, audio_duration)
36
+ audio_clip = audio_clip.subclip(0, final_duration)
37
+ new_video_clip = VideoClip(make_frame, duration=final_duration)
38
+ new_video_clip = new_video_clip.set_audio(audio_clip)
39
+ new_video_clip.write_videofile(output_video_path, fps=fps, audio_codec="aac")
40
+
41
+
42
+ @torch.no_grad()
43
+ def preprocess_image(face_analysis_model: str, image_path: str, image_size: int = 512):
44
+ """
45
+ Preprocess the image and extract face embedding.
46
+
47
+ Args:
48
+ face_analysis_model (str): Path to the FaceAnalysis model directory.
49
+ image_path (str): Path to the image file.
50
+ image_size (int, optional): Target size for resizing the image. Default is 512.
51
+
52
+ Returns:
53
+ tuple: A tuple containing:
54
+ - pixel_values (torch.Tensor): Tensor of the preprocessed image.
55
+ - face_emb (torch.Tensor): Tensor of the face embedding.
56
+ """
57
+ # Define the image transformation
58
+ transform = transforms.Compose(
59
+ [
60
+ transforms.Resize((image_size, image_size)),
61
+ transforms.ToTensor(),
62
+ transforms.Normalize([0.5], [0.5]),
63
+ ]
64
+ )
65
+
66
+ # Initialize the FaceAnalysis model
67
+ face_analysis = FaceAnalysis(
68
+ name="",
69
+ root=face_analysis_model,
70
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
71
+ )
72
+ face_analysis.prepare(ctx_id=0, det_size=(640, 640))
73
+
74
+ # Load and preprocess the image
75
+ image = Image.open(image_path).convert("RGB")
76
+ pixel_values = transform(image)
77
+ pixel_values = pixel_values.unsqueeze(0)
78
+
79
+ # Detect faces and extract the face embedding
80
+ image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
81
+ faces = face_analysis.get(image_bgr)
82
+ if not faces:
83
+ logger.warning("No faces detected in the image. Using a zero vector as the face embedding.")
84
+ face_emb = np.zeros(512)
85
+ else:
86
+ # Sort faces by size and select the largest one
87
+ faces_sorted = sorted(
88
+ faces,
89
+ key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]),
90
+ reverse=True,
91
+ )
92
+ face_emb = faces_sorted[0]["embedding"]
93
+
94
+ # Convert face embedding to a PyTorch tensor
95
+ face_emb = face_emb.reshape(1, -1)
96
+ face_emb = torch.tensor(face_emb)
97
+
98
+ del face_analysis
99
+
100
+ return pixel_values, face_emb
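A hypothetical pairing of the two helpers above (all paths are placeholders): preprocess a reference portrait, then write a generated [c, f, h, w] tensor in [0, 1] to an MP4 with the driving audio muxed in.

import torch
from memo.utils.vision_utils import preprocess_image, tensor_to_video

pixel_values, face_emb = preprocess_image(
    face_analysis_model="checkpoints/misc/face_analysis",  # placeholder model directory
    image_path="assets/examples/portrait.jpg",             # placeholder reference image
    image_size=512,
)
print(pixel_values.shape, face_emb.shape)  # torch.Size([1, 3, 512, 512]) torch.Size([1, 512])

video = torch.rand(3, 30, 512, 512)  # stand-in for one second of pipeline output at 30 fps
tensor_to_video(video, "outputs/demo.mp4", "assets/examples/speech.wav", fps=30)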
pyproject.toml ADDED
@@ -0,0 +1,83 @@
1
+ [project]
2
+ name = "memo"
3
+ version = "0.1.0"
4
+ description = "MEMO: Memory-Guided Diffusion for Expressive Talking Video Generation"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ license = {file = "LICENSE"}
8
+ keywords = [
9
+ 'artificial intelligence',
10
+ 'computer vision',
11
+ 'diffusion models',
12
+ 'video generation',
13
+ 'talking head',
14
+ ]
15
+
16
+ dependencies = [
17
+ 'accelerate==1.1.1',
18
+ 'albumentations==1.4.21',
19
+ 'audio-separator==0.24.1',
20
+ 'black==23.12.1',
21
+ 'diffusers==0.31.0',
22
+ 'einops==0.8.0',
23
+ 'ffmpeg-python==0.2.0',
24
+ 'funasr==1.0.27',
25
+ 'huggingface-hub==0.26.2',
26
+ 'imageio==2.36.0',
27
+ 'imageio-ffmpeg==0.5.1',
28
+ 'insightface==0.7.3',
29
+ 'hydra-core==1.3.2',
30
+ 'jax==0.4.35',
31
+ 'mediapipe==0.10.18',
32
+ 'modelscope==1.20.1',
33
+ 'moviepy==1.0.3',
34
+ 'numpy==1.26.4',
35
+ 'omegaconf==2.3.0',
36
+ 'onnxruntime-gpu>=1.20.1',
37
+ 'opencv-python-headless==4.10.0.84',
38
+ 'pillow>=10.4.0',
39
+ 'scikit-learn>=1.5.2',
40
+ 'scipy>=1.14.1',
41
+ 'torch==2.5.1',
42
+ 'torchaudio==2.5.1',
43
+ 'torchvision==0.20.1',
44
+ 'transformers==4.46.3',
45
+ 'tqdm>=4.67.1',
46
+ 'xformers==0.0.28.post3',
47
+ ]
48
+
49
+ [build-system]
50
+ requires = ["setuptools", "wheel"]
51
+ build-backend = "setuptools.build_meta"
52
+
53
+ [tool.setuptools]
54
+ packages = ["memo"]
55
+
56
+ [tool.ruff]
57
+ line-length = 119
58
+
59
+ [tool.ruff.lint]
60
+ # Never enforce `E501` (line length violations).
61
+ ignore = ["C901", "E501", "E741", "F402", "F823" ]
62
+ select = ["C", "E", "F", "I", "W"]
63
+
64
+ # Ignore import violations in all `__init__.py` files.
65
+ [tool.ruff.lint.per-file-ignores]
66
+ "__init__.py" = ["E402", "F401", "F403", "F811"]
67
+
68
+ [tool.ruff.lint.isort]
69
+ lines-after-imports = 2
70
+ known-first-party = ["memo"]
71
+
72
+ [tool.ruff.format]
73
+ # Like Black, use double quotes for strings.
74
+ quote-style = "double"
75
+
76
+ # Like Black, indent with spaces, rather than tabs.
77
+ indent-style = "space"
78
+
79
+ # Like Black, respect magic trailing commas.
80
+ skip-magic-trailing-comma = false
81
+
82
+ # Like Black, automatically detect the appropriate line ending.
83
+ line-ending = "auto"