svjack commited on
Commit
3b3fc8d
·
verified ·
1 Parent(s): a1b5dd1

Upload 129 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. LICENSE +201 -0
  3. README.md +87 -0
  4. assets/00.gif +0 -0
  5. assets/01.gif +0 -0
  6. assets/02.gif +0 -0
  7. assets/03.gif +0 -0
  8. assets/04.gif +0 -0
  9. assets/05.gif +0 -0
  10. assets/06.gif +0 -0
  11. assets/07.gif +0 -0
  12. assets/08.gif +0 -0
  13. assets/09.gif +0 -0
  14. assets/10.gif +0 -0
  15. assets/11.gif +0 -0
  16. assets/12.gif +0 -0
  17. assets/13.gif +3 -0
  18. assets/72105_388.mp4_00-00.png +0 -0
  19. assets/72105_388.mp4_00-01.png +0 -0
  20. assets/72109_125.mp4_00-00.png +0 -0
  21. assets/72109_125.mp4_00-01.png +0 -0
  22. assets/72110_255.mp4_00-00.png +0 -0
  23. assets/72110_255.mp4_00-01.png +0 -0
  24. assets/74302_1349_frame1.png +0 -0
  25. assets/74302_1349_frame3.png +0 -0
  26. assets/Japan_v2_1_070321_s3_frame1.png +0 -0
  27. assets/Japan_v2_1_070321_s3_frame3.png +0 -0
  28. assets/Japan_v2_2_062266_s2_frame1.png +0 -0
  29. assets/Japan_v2_2_062266_s2_frame3.png +0 -0
  30. assets/frame0001_05.png +0 -0
  31. assets/frame0001_09.png +0 -0
  32. assets/frame0001_10.png +0 -0
  33. assets/frame0001_11.png +0 -0
  34. assets/frame0016_10.png +0 -0
  35. assets/frame0016_11.png +0 -0
  36. assets/sketch_sample/frame_1.png +0 -0
  37. assets/sketch_sample/frame_2.png +0 -0
  38. assets/sketch_sample/sample.mov +0 -0
  39. cldm/cldm.py +478 -0
  40. cldm/ddim_hacked.py +317 -0
  41. cldm/hack.py +111 -0
  42. cldm/logger.py +76 -0
  43. cldm/model.py +28 -0
  44. configs/cldm_v21.yaml +17 -0
  45. configs/inference_512_v1.0.yaml +103 -0
  46. configs/training_1024_v1.0/config.yaml +166 -0
  47. configs/training_1024_v1.0/run.sh +37 -0
  48. configs/training_512_v1.0/config.yaml +166 -0
  49. configs/training_512_v1.0/run.sh +37 -0
  50. control_models/_put_your_control_models_.txt +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/13.gif filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright Tencent
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## ___***ToonCrafter_with_SketchGuidance***___
2
+ This repository is an implementation that recreates the SketchGuidance feature of "ToonCrafter".
3
+
4
+ - https://github.com/ToonCrafter/ToonCrafter
5
+ - https://arxiv.org/pdf/2405.17933
6
+
7
+ https://github.com/user-attachments/assets/f72f287d-f848-4982-8f91-43c49d037007
8
+
9
+ # ToonCrafter with Sketch Guidance
10
+
11
+ ## Overview
12
+ This repository contains the ToonCrafter model with additional sketch guidance functionality. The model is hosted on GitHub and can be accessed via the following link:
13
+
14
+ - **Repository:** [https://github.com/svjack/ToonCrafter_with_SketchGuidance](https://github.com/svjack/ToonCrafter_with_SketchGuidance)
15
+
16
+ ## Installation
17
+
18
+ ### Clone the Repository
19
+ To clone the repository and set up the environment, follow these steps:
20
+
21
+ ```bash
22
+ git clone https://github.com/svjack/ToonCrafter_with_SketchGuidance && cd ToonCrafter_with_SketchGuidance
23
+ ```
24
+
25
+ ### List Control Models
26
+ After cloning the repository, you can list the control models available:
27
+
28
+ ```bash
29
+ ls control_models
30
+ ```
31
+
32
+ ### Download Sketch Encoder
33
+ Download the `sketch_encoder.ckpt` file from the Hugging Face repository:
34
+
35
+ ```bash
36
+ wget https://huggingface.co/Doubiiu/ToonCrafter/resolve/main/sketch_encoder.ckpt
37
+ ```
38
+
39
+ ### Copy Sketch Encoder to Control Models
40
+ Move the downloaded `sketch_encoder.ckpt` file to the `control_models` directory:
41
+
42
+ ```bash
43
+ cp sketch_encoder.ckpt control_models
44
+ ```
45
+
46
+ ## Usage
47
+
48
+ ### Run the Gradio App
49
+ Once the repository is cloned and the necessary files are in place, you can run the Gradio app using the following command:
50
+
51
+ ```bash
52
+ python gradio_app.py
53
+ ```
54
+
55
+ This will start the Gradio interface, allowing you to interact with the ToonCrafter model with sketch guidance.
56
+
57
+ ## 🧰 Models
58
+
59
+ |Model|Resolution|GPU Mem. & Inference Time (A100, ddim 50steps)|Checkpoint|
60
+ |:---------|:---------|:--------|:--------|
61
+ |ToonCrafter_512|320x512| TBD (`perframe_ae=True`)|[Hugging Face](https://huggingface.co/Doubiiu/ToonCrafter/blob/main/model.ckpt)|
62
+ |SketchEncoder|TBD| TBD |[Hugging Face](https://huggingface.co/Doubiiu/ToonCrafter/blob/main/sketch_encoder.ckpt)|
63
+
64
+
65
+ Currently, ToonCrafter can support generating videos of up to 16 frames with a resolution of 512x320. The inference time can be reduced by using fewer DDIM steps.
66
+
67
+
68
+
69
+ ## ⚙️ Setup
70
+
71
+ ### Install Environment via Anaconda (Recommended)
72
+ ```bash
73
+ conda create -n tooncrafter python=3.8.5
74
+ conda activate tooncrafter
75
+ pip install -r requirements.txt
76
+ ```
77
+
78
+
79
+ ## 💫 Inference
80
+
81
+ ### 1. Local Gradio demo
82
+ 1. Download pretrained ToonCrafter_512 and put the model.ckpt in checkpoints/tooncrafter_512_interp_v1/model.ckpt.
83
+ 2. Download pretrained SketchEncoder and put the model.ckpt in control_models/sketch_encoder.ckpt.
84
+
85
+ ```bash
86
+ python gradio_app.py
87
+ ```
assets/00.gif ADDED
assets/01.gif ADDED
assets/02.gif ADDED
assets/03.gif ADDED
assets/04.gif ADDED
assets/05.gif ADDED
assets/06.gif ADDED
assets/07.gif ADDED
assets/08.gif ADDED
assets/09.gif ADDED
assets/10.gif ADDED
assets/11.gif ADDED
assets/12.gif ADDED
assets/13.gif ADDED

Git LFS Details

  • SHA256: 179af7d265d8790c0ca31a5898f870961b0a738b02d9fd0c991a3a75651cbb56
  • Pointer size: 132 Bytes
  • Size of remote file: 1.03 MB
assets/72105_388.mp4_00-00.png ADDED
assets/72105_388.mp4_00-01.png ADDED
assets/72109_125.mp4_00-00.png ADDED
assets/72109_125.mp4_00-01.png ADDED
assets/72110_255.mp4_00-00.png ADDED
assets/72110_255.mp4_00-01.png ADDED
assets/74302_1349_frame1.png ADDED
assets/74302_1349_frame3.png ADDED
assets/Japan_v2_1_070321_s3_frame1.png ADDED
assets/Japan_v2_1_070321_s3_frame3.png ADDED
assets/Japan_v2_2_062266_s2_frame1.png ADDED
assets/Japan_v2_2_062266_s2_frame3.png ADDED
assets/frame0001_05.png ADDED
assets/frame0001_09.png ADDED
assets/frame0001_10.png ADDED
assets/frame0001_11.png ADDED
assets/frame0016_10.png ADDED
assets/frame0016_11.png ADDED
assets/sketch_sample/frame_1.png ADDED
assets/sketch_sample/frame_2.png ADDED
assets/sketch_sample/sample.mov ADDED
Binary file (228 kB). View file
 
cldm/cldm.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import einops
2
+ import torch
3
+ import torch as th
4
+ import torch.nn as nn
5
+
6
+ from ldm.modules.diffusionmodules.util import (
7
+ conv_nd,
8
+ linear,
9
+ zero_module,
10
+ timestep_embedding,
11
+ )
12
+
13
+ from einops import rearrange, repeat
14
+ from torchvision.utils import make_grid
15
+ from ldm.modules.attention import SpatialTransformer
16
+ from ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
17
+ from lvdm.modules.networks.openaimodel3d import UNetModel
18
+ from ldm.models.diffusion.ddpm import LatentDiffusion
19
+ from ldm.util import log_txt_as_img, exists, instantiate_from_config
20
+ from ldm.models.diffusion.ddim import DDIMSampler
21
+
22
+
23
+ class ControlledUnetModel(UNetModel):
24
+ def forward(self, x, timesteps, context=None, features_adapter=None, fs=None, control = None, **kwargs):
25
+ b,_,t,_,_ = x.shape
26
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).type(x.dtype)
27
+ emb = self.time_embed(t_emb)
28
+ ## repeat t times for context [(b t) 77 768] & time embedding
29
+ ## check if we use per-frame image conditioning
30
+ _, l_context, _ = context.shape
31
+ if l_context == 77 + t*16: ## !!! HARD CODE here
32
+ context_text, context_img = context[:,:77,:], context[:,77:,:]
33
+ context_text = context_text.repeat_interleave(repeats=t, dim=0)
34
+ context_img = rearrange(context_img, 'b (t l) c -> (b t) l c', t=t)
35
+ context = torch.cat([context_text, context_img], dim=1)
36
+ else:
37
+ context = context.repeat_interleave(repeats=t, dim=0)
38
+ emb = emb.repeat_interleave(repeats=t, dim=0)
39
+
40
+ ## always in shape (b t) c h w, except for temporal layer
41
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
42
+
43
+ ## combine emb
44
+ if self.fs_condition:
45
+ if fs is None:
46
+ fs = torch.tensor(
47
+ [self.default_fs] * b, dtype=torch.long, device=x.device)
48
+ fs_emb = timestep_embedding(fs, self.model_channels, repeat_only=False).type(x.dtype)
49
+
50
+ fs_embed = self.fps_embedding(fs_emb)
51
+ fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
52
+ emb = emb + fs_embed
53
+
54
+ h = x.type(self.dtype)
55
+ adapter_idx = 0
56
+ hs = []
57
+ with torch.no_grad():
58
+ for id, module in enumerate(self.input_blocks):
59
+ h = module(h, emb, context=context, batch_size=b)
60
+ if id ==0 and self.addition_attention:
61
+ h = self.init_attn(h, emb, context=context, batch_size=b)
62
+ ## plug-in adapter features
63
+ if ((id+1)%3 == 0) and features_adapter is not None:
64
+ h = h + features_adapter[adapter_idx]
65
+ adapter_idx += 1
66
+ hs.append(h)
67
+ if features_adapter is not None:
68
+ assert len(features_adapter)==adapter_idx, 'Wrong features_adapter'
69
+
70
+ h = self.middle_block(h, emb, context=context, batch_size=b)
71
+
72
+ if control is not None:
73
+ h += control.pop()
74
+
75
+ for module in self.output_blocks:
76
+ if control is None:
77
+ h = torch.cat([h, hs.pop()], dim=1)
78
+ else:
79
+ h = torch.cat([h, hs.pop() + control.pop()], dim=1)
80
+ h = module(h, emb, context=context, batch_size=b)
81
+
82
+ h = h.type(x.dtype)
83
+ y = self.out(h)
84
+
85
+ # reshape back to (b c t h w)
86
+ y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
87
+ return y
88
+
89
+
90
+ class ControlNet(nn.Module):
91
+ def __init__(
92
+ self,
93
+ image_size,
94
+ in_channels,
95
+ model_channels,
96
+ hint_channels,
97
+ num_res_blocks,
98
+ attention_resolutions,
99
+ dropout=0,
100
+ channel_mult=(1, 2, 4, 8),
101
+ conv_resample=True,
102
+ dims=2,
103
+ use_checkpoint=False,
104
+ use_fp16=False,
105
+ num_heads=-1,
106
+ num_head_channels=-1,
107
+ num_heads_upsample=-1,
108
+ use_scale_shift_norm=False,
109
+ resblock_updown=False,
110
+ use_new_attention_order=False,
111
+ use_spatial_transformer=False, # custom transformer support
112
+ transformer_depth=1, # custom transformer support
113
+ context_dim=None, # custom transformer support
114
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
115
+ legacy=True,
116
+ disable_self_attentions=None,
117
+ num_attention_blocks=None,
118
+ disable_middle_self_attn=False,
119
+ use_linear_in_transformer=False,
120
+ ):
121
+ super().__init__()
122
+ if use_spatial_transformer:
123
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
124
+
125
+ if context_dim is not None:
126
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
127
+ from omegaconf.listconfig import ListConfig
128
+ if type(context_dim) == ListConfig:
129
+ context_dim = list(context_dim)
130
+
131
+ if num_heads_upsample == -1:
132
+ num_heads_upsample = num_heads
133
+
134
+ if num_heads == -1:
135
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
136
+
137
+ if num_head_channels == -1:
138
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
139
+
140
+ self.dims = dims
141
+ self.image_size = image_size
142
+ self.in_channels = in_channels
143
+ self.model_channels = model_channels
144
+ if isinstance(num_res_blocks, int):
145
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
146
+ else:
147
+ if len(num_res_blocks) != len(channel_mult):
148
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
149
+ "as a list/tuple (per-level) with the same length as channel_mult")
150
+ self.num_res_blocks = num_res_blocks
151
+ if disable_self_attentions is not None:
152
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
153
+ assert len(disable_self_attentions) == len(channel_mult)
154
+ if num_attention_blocks is not None:
155
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
156
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
157
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
158
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
159
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
160
+ f"attention will still not be set.")
161
+
162
+ self.attention_resolutions = attention_resolutions
163
+ self.dropout = dropout
164
+ self.channel_mult = channel_mult
165
+ self.conv_resample = conv_resample
166
+ self.use_checkpoint = use_checkpoint
167
+ self.dtype = th.float16 if use_fp16 else th.float32
168
+ self.num_heads = num_heads
169
+ self.num_head_channels = num_head_channels
170
+ self.num_heads_upsample = num_heads_upsample
171
+ self.predict_codebook_ids = n_embed is not None
172
+
173
+ time_embed_dim = model_channels * 4
174
+ self.time_embed = nn.Sequential(
175
+ linear(model_channels, time_embed_dim),
176
+ nn.SiLU(),
177
+ linear(time_embed_dim, time_embed_dim),
178
+ )
179
+
180
+ self.input_blocks = nn.ModuleList(
181
+ [
182
+ TimestepEmbedSequential(
183
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
184
+ )
185
+ ]
186
+ )
187
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
188
+
189
+ self.input_hint_block = TimestepEmbedSequential(
190
+ conv_nd(dims, hint_channels, 16, 3, padding=1),
191
+ nn.SiLU(),
192
+ conv_nd(dims, 16, 16, 3, padding=1),
193
+ nn.SiLU(),
194
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
195
+ nn.SiLU(),
196
+ conv_nd(dims, 32, 32, 3, padding=1),
197
+ nn.SiLU(),
198
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
199
+ nn.SiLU(),
200
+ conv_nd(dims, 96, 96, 3, padding=1),
201
+ nn.SiLU(),
202
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
203
+ nn.SiLU(),
204
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
205
+ )
206
+
207
+ self._feature_size = model_channels
208
+ input_block_chans = [model_channels]
209
+ ch = model_channels
210
+ ds = 1
211
+ for level, mult in enumerate(channel_mult):
212
+ for nr in range(self.num_res_blocks[level]):
213
+ layers = [
214
+ ResBlock(
215
+ ch,
216
+ time_embed_dim,
217
+ dropout,
218
+ out_channels=mult * model_channels,
219
+ dims=dims,
220
+ use_checkpoint=use_checkpoint,
221
+ use_scale_shift_norm=use_scale_shift_norm,
222
+ )
223
+ ]
224
+ ch = mult * model_channels
225
+ if ds in attention_resolutions:
226
+ if num_head_channels == -1:
227
+ dim_head = ch // num_heads
228
+ else:
229
+ num_heads = ch // num_head_channels
230
+ dim_head = num_head_channels
231
+ if legacy:
232
+ # num_heads = 1
233
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
234
+ if exists(disable_self_attentions):
235
+ disabled_sa = disable_self_attentions[level]
236
+ else:
237
+ disabled_sa = False
238
+
239
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
240
+ layers.append(
241
+ AttentionBlock(
242
+ ch,
243
+ use_checkpoint=use_checkpoint,
244
+ num_heads=num_heads,
245
+ num_head_channels=dim_head,
246
+ use_new_attention_order=use_new_attention_order,
247
+ ) if not use_spatial_transformer else SpatialTransformer(
248
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
249
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
250
+ use_checkpoint=use_checkpoint
251
+ )
252
+ )
253
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
254
+ self.zero_convs.append(self.make_zero_conv(ch))
255
+ self._feature_size += ch
256
+ input_block_chans.append(ch)
257
+ if level != len(channel_mult) - 1:
258
+ out_ch = ch
259
+ self.input_blocks.append(
260
+ TimestepEmbedSequential(
261
+ ResBlock(
262
+ ch,
263
+ time_embed_dim,
264
+ dropout,
265
+ out_channels=out_ch,
266
+ dims=dims,
267
+ use_checkpoint=use_checkpoint,
268
+ use_scale_shift_norm=use_scale_shift_norm,
269
+ down=True,
270
+ )
271
+ if resblock_updown
272
+ else Downsample(
273
+ ch, conv_resample, dims=dims, out_channels=out_ch
274
+ )
275
+ )
276
+ )
277
+ ch = out_ch
278
+ input_block_chans.append(ch)
279
+ self.zero_convs.append(self.make_zero_conv(ch))
280
+ ds *= 2
281
+ self._feature_size += ch
282
+
283
+ if num_head_channels == -1:
284
+ dim_head = ch // num_heads
285
+ else:
286
+ num_heads = ch // num_head_channels
287
+ dim_head = num_head_channels
288
+ if legacy:
289
+ # num_heads = 1
290
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
291
+ self.middle_block = TimestepEmbedSequential(
292
+ ResBlock(
293
+ ch,
294
+ time_embed_dim,
295
+ dropout,
296
+ dims=dims,
297
+ use_checkpoint=use_checkpoint,
298
+ use_scale_shift_norm=use_scale_shift_norm,
299
+ ),
300
+ AttentionBlock(
301
+ ch,
302
+ use_checkpoint=use_checkpoint,
303
+ num_heads=num_heads,
304
+ num_head_channels=dim_head,
305
+ use_new_attention_order=use_new_attention_order,
306
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
307
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
308
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
309
+ use_checkpoint=use_checkpoint
310
+ ),
311
+ ResBlock(
312
+ ch,
313
+ time_embed_dim,
314
+ dropout,
315
+ dims=dims,
316
+ use_checkpoint=use_checkpoint,
317
+ use_scale_shift_norm=use_scale_shift_norm,
318
+ ),
319
+ )
320
+ self.middle_block_out = self.make_zero_conv(ch)
321
+ self._feature_size += ch
322
+
323
+ def make_zero_conv(self, channels):
324
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
325
+
326
+ def forward(self, x, hint, timesteps, context, **kwargs):
327
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
328
+ emb = self.time_embed(t_emb)
329
+
330
+ guided_hint = self.input_hint_block(hint, emb, context)
331
+
332
+ outs = []
333
+
334
+ h = x.type(self.dtype)
335
+
336
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
337
+ if guided_hint is not None:
338
+ h = module(h, emb, context)
339
+ h += guided_hint
340
+ guided_hint = None
341
+ else:
342
+ h = module(h, emb, context)
343
+ outs.append(zero_conv(h, emb, context, True))
344
+
345
+ h = self.middle_block(h, emb, context)
346
+ outs.append(self.middle_block_out(h, emb, context))
347
+
348
+ return outs
349
+
350
+
351
+ class ControlLDM(LatentDiffusion):
352
+
353
+ def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
354
+ super().__init__(*args, **kwargs)
355
+ self.control_model = instantiate_from_config(control_stage_config)
356
+ self.control_key = control_key
357
+ self.only_mid_control = only_mid_control
358
+ self.control_scales = [1.0] * 13
359
+
360
+ @torch.no_grad()
361
+ def get_input(self, batch, k, bs=None, *args, **kwargs):
362
+ x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
363
+ control = batch[self.control_key]
364
+ if bs is not None:
365
+ control = control[:bs]
366
+ control = control.to(self.device)
367
+ control = einops.rearrange(control, 'b h w c -> b c h w')
368
+ control = control.to(memory_format=torch.contiguous_format).float()
369
+ return x, dict(c_crossattn=[c], c_concat=[control])
370
+
371
+ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
372
+ assert isinstance(cond, dict)
373
+ diffusion_model = self.model.diffusion_model
374
+
375
+ cond_txt = torch.cat(cond['c_crossattn'], 1)
376
+
377
+ if cond['c_concat'] is None:
378
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
379
+ else:
380
+ control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
381
+ control = [c * scale for c, scale in zip(control, self.control_scales)]
382
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
383
+
384
+ return eps
385
+
386
+ @torch.no_grad()
387
+ def get_unconditional_conditioning(self, N):
388
+ return self.get_learned_conditioning([""] * N)
389
+
390
+ @torch.no_grad()
391
+ def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
392
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
393
+ plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
394
+ use_ema_scope=True,
395
+ **kwargs):
396
+ use_ddim = ddim_steps is not None
397
+
398
+ log = dict()
399
+ z, c = self.get_input(batch, self.first_stage_key, bs=N)
400
+ c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
401
+ N = min(z.shape[0], N)
402
+ n_row = min(z.shape[0], n_row)
403
+ log["reconstruction"] = self.decode_first_stage(z)
404
+ log["control"] = c_cat * 2.0 - 1.0
405
+ log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
406
+
407
+ if plot_diffusion_rows:
408
+ # get diffusion row
409
+ diffusion_row = list()
410
+ z_start = z[:n_row]
411
+ for t in range(self.num_timesteps):
412
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
413
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
414
+ t = t.to(self.device).long()
415
+ noise = torch.randn_like(z_start)
416
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
417
+ diffusion_row.append(self.decode_first_stage(z_noisy))
418
+
419
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
420
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
421
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
422
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
423
+ log["diffusion_row"] = diffusion_grid
424
+
425
+ if sample:
426
+ # get denoise row
427
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
428
+ batch_size=N, ddim=use_ddim,
429
+ ddim_steps=ddim_steps, eta=ddim_eta)
430
+ x_samples = self.decode_first_stage(samples)
431
+ log["samples"] = x_samples
432
+ if plot_denoise_rows:
433
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
434
+ log["denoise_row"] = denoise_grid
435
+
436
+ if unconditional_guidance_scale > 1.0:
437
+ uc_cross = self.get_unconditional_conditioning(N)
438
+ uc_cat = c_cat # torch.zeros_like(c_cat)
439
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
440
+ samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
441
+ batch_size=N, ddim=use_ddim,
442
+ ddim_steps=ddim_steps, eta=ddim_eta,
443
+ unconditional_guidance_scale=unconditional_guidance_scale,
444
+ unconditional_conditioning=uc_full,
445
+ )
446
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
447
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
448
+
449
+ return log
450
+
451
+ @torch.no_grad()
452
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
453
+ ddim_sampler = DDIMSampler(self)
454
+ b, c, h, w = cond["c_concat"][0].shape
455
+ shape = (self.channels, h // 8, w // 8)
456
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
457
+ return samples, intermediates
458
+
459
+ def configure_optimizers(self):
460
+ lr = self.learning_rate
461
+ params = list(self.control_model.parameters())
462
+ if not self.sd_locked:
463
+ params += list(self.model.diffusion_model.output_blocks.parameters())
464
+ params += list(self.model.diffusion_model.out.parameters())
465
+ opt = torch.optim.AdamW(params, lr=lr)
466
+ return opt
467
+
468
+ def low_vram_shift(self, is_diffusing):
469
+ if is_diffusing:
470
+ self.model = self.model.cuda()
471
+ self.control_model = self.control_model.cuda()
472
+ self.first_stage_model = self.first_stage_model.cpu()
473
+ self.cond_stage_model = self.cond_stage_model.cpu()
474
+ else:
475
+ self.model = self.model.cpu()
476
+ self.control_model = self.control_model.cpu()
477
+ self.first_stage_model = self.first_stage_model.cuda()
478
+ self.cond_stage_model = self.cond_stage_model.cuda()
cldm/ddim_hacked.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
8
+
9
+
10
+ class DDIMSampler(object):
11
+ def __init__(self, model, schedule="linear", **kwargs):
12
+ super().__init__()
13
+ self.model = model
14
+ self.ddpm_num_timesteps = model.num_timesteps
15
+ self.schedule = schedule
16
+
17
+ def register_buffer(self, name, attr):
18
+ if type(attr) == torch.Tensor:
19
+ if attr.device != torch.device("cuda"):
20
+ attr = attr.to(torch.device("cuda"))
21
+ setattr(self, name, attr)
22
+
23
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
24
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
25
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
26
+ alphas_cumprod = self.model.alphas_cumprod
27
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
28
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
29
+
30
+ self.register_buffer('betas', to_torch(self.model.betas))
31
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
32
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
33
+
34
+ # calculations for diffusion q(x_t | x_{t-1}) and others
35
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
36
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
37
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
38
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
40
+
41
+ # ddim sampling parameters
42
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
43
+ ddim_timesteps=self.ddim_timesteps,
44
+ eta=ddim_eta,verbose=verbose)
45
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
46
+ self.register_buffer('ddim_alphas', ddim_alphas)
47
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
48
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
49
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
50
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
51
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
52
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
53
+
54
+ @torch.no_grad()
55
+ def sample(self,
56
+ S,
57
+ batch_size,
58
+ shape,
59
+ conditioning=None,
60
+ callback=None,
61
+ normals_sequence=None,
62
+ img_callback=None,
63
+ quantize_x0=False,
64
+ eta=0.,
65
+ mask=None,
66
+ x0=None,
67
+ temperature=1.,
68
+ noise_dropout=0.,
69
+ score_corrector=None,
70
+ corrector_kwargs=None,
71
+ verbose=True,
72
+ x_T=None,
73
+ log_every_t=100,
74
+ unconditional_guidance_scale=1.,
75
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
76
+ dynamic_threshold=None,
77
+ ucg_schedule=None,
78
+ **kwargs
79
+ ):
80
+ if conditioning is not None:
81
+ if isinstance(conditioning, dict):
82
+ ctmp = conditioning[list(conditioning.keys())[0]]
83
+ while isinstance(ctmp, list): ctmp = ctmp[0]
84
+ cbs = ctmp.shape[0]
85
+ if cbs != batch_size:
86
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87
+
88
+ elif isinstance(conditioning, list):
89
+ for ctmp in conditioning:
90
+ if ctmp.shape[0] != batch_size:
91
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
92
+
93
+ else:
94
+ if conditioning.shape[0] != batch_size:
95
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
96
+
97
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
98
+ # sampling
99
+ C, H, W = shape
100
+ size = (batch_size, C, H, W)
101
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
102
+
103
+ samples, intermediates = self.ddim_sampling(conditioning, size,
104
+ callback=callback,
105
+ img_callback=img_callback,
106
+ quantize_denoised=quantize_x0,
107
+ mask=mask, x0=x0,
108
+ ddim_use_original_steps=False,
109
+ noise_dropout=noise_dropout,
110
+ temperature=temperature,
111
+ score_corrector=score_corrector,
112
+ corrector_kwargs=corrector_kwargs,
113
+ x_T=x_T,
114
+ log_every_t=log_every_t,
115
+ unconditional_guidance_scale=unconditional_guidance_scale,
116
+ unconditional_conditioning=unconditional_conditioning,
117
+ dynamic_threshold=dynamic_threshold,
118
+ ucg_schedule=ucg_schedule
119
+ )
120
+ return samples, intermediates
121
+
122
+ @torch.no_grad()
123
+ def ddim_sampling(self, cond, shape,
124
+ x_T=None, ddim_use_original_steps=False,
125
+ callback=None, timesteps=None, quantize_denoised=False,
126
+ mask=None, x0=None, img_callback=None, log_every_t=100,
127
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
128
+ unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
129
+ ucg_schedule=None):
130
+ device = self.model.betas.device
131
+ b = shape[0]
132
+ if x_T is None:
133
+ img = torch.randn(shape, device=device)
134
+ else:
135
+ img = x_T
136
+
137
+ if timesteps is None:
138
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
139
+ elif timesteps is not None and not ddim_use_original_steps:
140
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
141
+ timesteps = self.ddim_timesteps[:subset_end]
142
+
143
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
144
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
145
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
146
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
147
+
148
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
149
+
150
+ for i, step in enumerate(iterator):
151
+ index = total_steps - i - 1
152
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
153
+
154
+ if mask is not None:
155
+ assert x0 is not None
156
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
157
+ img = img_orig * mask + (1. - mask) * img
158
+
159
+ if ucg_schedule is not None:
160
+ assert len(ucg_schedule) == len(time_range)
161
+ unconditional_guidance_scale = ucg_schedule[i]
162
+
163
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
164
+ quantize_denoised=quantize_denoised, temperature=temperature,
165
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
166
+ corrector_kwargs=corrector_kwargs,
167
+ unconditional_guidance_scale=unconditional_guidance_scale,
168
+ unconditional_conditioning=unconditional_conditioning,
169
+ dynamic_threshold=dynamic_threshold)
170
+ img, pred_x0 = outs
171
+ if callback: callback(i)
172
+ if img_callback: img_callback(pred_x0, i)
173
+
174
+ if index % log_every_t == 0 or index == total_steps - 1:
175
+ intermediates['x_inter'].append(img)
176
+ intermediates['pred_x0'].append(pred_x0)
177
+
178
+ return img, intermediates
179
+
180
+ @torch.no_grad()
181
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
182
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
183
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
184
+ dynamic_threshold=None):
185
+ b, *_, device = *x.shape, x.device
186
+
187
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
188
+ model_output = self.model.apply_model(x, t, c)
189
+ else:
190
+ model_t = self.model.apply_model(x, t, c)
191
+ model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
192
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
193
+
194
+ if self.model.parameterization == "v":
195
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
196
+ else:
197
+ e_t = model_output
198
+
199
+ if score_corrector is not None:
200
+ assert self.model.parameterization == "eps", 'not implemented'
201
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
202
+
203
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
204
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
205
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
206
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
207
+ # select parameters corresponding to the currently considered timestep
208
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
209
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
210
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
211
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
212
+
213
+ # current prediction for x_0
214
+ if self.model.parameterization != "v":
215
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
216
+ else:
217
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
218
+
219
+ if quantize_denoised:
220
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
221
+
222
+ if dynamic_threshold is not None:
223
+ raise NotImplementedError()
224
+
225
+ # direction pointing to x_t
226
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
227
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
228
+ if noise_dropout > 0.:
229
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
230
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
231
+ return x_prev, pred_x0
232
+
233
+ @torch.no_grad()
234
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
235
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
236
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
237
+ num_reference_steps = timesteps.shape[0]
238
+
239
+ assert t_enc <= num_reference_steps
240
+ num_steps = t_enc
241
+
242
+ if use_original_steps:
243
+ alphas_next = self.alphas_cumprod[:num_steps]
244
+ alphas = self.alphas_cumprod_prev[:num_steps]
245
+ else:
246
+ alphas_next = self.ddim_alphas[:num_steps]
247
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
248
+
249
+ x_next = x0
250
+ intermediates = []
251
+ inter_steps = []
252
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
253
+ t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
254
+ if unconditional_guidance_scale == 1.:
255
+ noise_pred = self.model.apply_model(x_next, t, c)
256
+ else:
257
+ assert unconditional_conditioning is not None
258
+ e_t_uncond, noise_pred = torch.chunk(
259
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
260
+ torch.cat((unconditional_conditioning, c))), 2)
261
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
262
+
263
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
264
+ weighted_noise_pred = alphas_next[i].sqrt() * (
265
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
266
+ x_next = xt_weighted + weighted_noise_pred
267
+ if return_intermediates and i % (
268
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
269
+ intermediates.append(x_next)
270
+ inter_steps.append(i)
271
+ elif return_intermediates and i >= num_steps - 2:
272
+ intermediates.append(x_next)
273
+ inter_steps.append(i)
274
+ if callback: callback(i)
275
+
276
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
277
+ if return_intermediates:
278
+ out.update({'intermediates': intermediates})
279
+ return x_next, out
280
+
281
+ @torch.no_grad()
282
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
283
+ # fast, but does not allow for exact reconstruction
284
+ # t serves as an index to gather the correct alphas
285
+ if use_original_steps:
286
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
287
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
288
+ else:
289
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
290
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
291
+
292
+ if noise is None:
293
+ noise = torch.randn_like(x0)
294
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
295
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
296
+
297
+ @torch.no_grad()
298
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
299
+ use_original_steps=False, callback=None):
300
+
301
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
302
+ timesteps = timesteps[:t_start]
303
+
304
+ time_range = np.flip(timesteps)
305
+ total_steps = timesteps.shape[0]
306
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
307
+
308
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
309
+ x_dec = x_latent
310
+ for i, step in enumerate(iterator):
311
+ index = total_steps - i - 1
312
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
313
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
314
+ unconditional_guidance_scale=unconditional_guidance_scale,
315
+ unconditional_conditioning=unconditional_conditioning)
316
+ if callback: callback(i)
317
+ return x_dec
cldm/hack.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import einops
3
+
4
+ import ldm.modules.encoders.modules
5
+ import ldm.modules.attention
6
+
7
+ from transformers import logging
8
+ from ldm.modules.attention import default
9
+
10
+
11
+ def disable_verbosity():
12
+ logging.set_verbosity_error()
13
+ print('logging improved.')
14
+ return
15
+
16
+
17
+ def enable_sliced_attention():
18
+ ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
19
+ print('Enabled sliced_attention.')
20
+ return
21
+
22
+
23
+ def hack_everything(clip_skip=0):
24
+ disable_verbosity()
25
+ ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
26
+ ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
27
+ print('Enabled clip hacks.')
28
+ return
29
+
30
+
31
+ # Written by Lvmin
32
+ def _hacked_clip_forward(self, text):
33
+ PAD = self.tokenizer.pad_token_id
34
+ EOS = self.tokenizer.eos_token_id
35
+ BOS = self.tokenizer.bos_token_id
36
+
37
+ def tokenize(t):
38
+ return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
39
+
40
+ def transformer_encode(t):
41
+ if self.clip_skip > 1:
42
+ rt = self.transformer(input_ids=t, output_hidden_states=True)
43
+ return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
44
+ else:
45
+ return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
46
+
47
+ def split(x):
48
+ return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
49
+
50
+ def pad(x, p, i):
51
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
52
+
53
+ raw_tokens_list = tokenize(text)
54
+ tokens_list = []
55
+
56
+ for raw_tokens in raw_tokens_list:
57
+ raw_tokens_123 = split(raw_tokens)
58
+ raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
59
+ raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
60
+ tokens_list.append(raw_tokens_123)
61
+
62
+ tokens_list = torch.IntTensor(tokens_list).to(self.device)
63
+
64
+ feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
65
+ y = transformer_encode(feed)
66
+ z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
67
+
68
+ return z
69
+
70
+
71
+ # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
72
+ def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
73
+ h = self.heads
74
+
75
+ q = self.to_q(x)
76
+ context = default(context, x)
77
+ k = self.to_k(context)
78
+ v = self.to_v(context)
79
+ del context, x
80
+
81
+ q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
82
+
83
+ limit = k.shape[0]
84
+ att_step = 1
85
+ q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
86
+ k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
87
+ v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
88
+
89
+ q_chunks.reverse()
90
+ k_chunks.reverse()
91
+ v_chunks.reverse()
92
+ sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
93
+ del k, q, v
94
+ for i in range(0, limit, att_step):
95
+ q_buffer = q_chunks.pop()
96
+ k_buffer = k_chunks.pop()
97
+ v_buffer = v_chunks.pop()
98
+ sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
99
+
100
+ del k_buffer, q_buffer
101
+ # attention, what we cannot get enough of, by chunks
102
+
103
+ sim_buffer = sim_buffer.softmax(dim=-1)
104
+
105
+ sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
106
+ del v_buffer
107
+ sim[i:i + att_step, :, :] = sim_buffer
108
+
109
+ del sim_buffer
110
+ sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
111
+ return self.to_out(sim)
cldm/logger.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchvision
6
+ from PIL import Image
7
+ from pytorch_lightning.callbacks import Callback
8
+ from pytorch_lightning.utilities.distributed import rank_zero_only
9
+
10
+
11
+ class ImageLogger(Callback):
12
+ def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
13
+ rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
14
+ log_images_kwargs=None):
15
+ super().__init__()
16
+ self.rescale = rescale
17
+ self.batch_freq = batch_frequency
18
+ self.max_images = max_images
19
+ if not increase_log_steps:
20
+ self.log_steps = [self.batch_freq]
21
+ self.clamp = clamp
22
+ self.disabled = disabled
23
+ self.log_on_batch_idx = log_on_batch_idx
24
+ self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
25
+ self.log_first_step = log_first_step
26
+
27
+ @rank_zero_only
28
+ def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
29
+ root = os.path.join(save_dir, "image_log", split)
30
+ for k in images:
31
+ grid = torchvision.utils.make_grid(images[k], nrow=4)
32
+ if self.rescale:
33
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
34
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
35
+ grid = grid.numpy()
36
+ grid = (grid * 255).astype(np.uint8)
37
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
38
+ path = os.path.join(root, filename)
39
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
40
+ Image.fromarray(grid).save(path)
41
+
42
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
43
+ check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step
44
+ if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
45
+ hasattr(pl_module, "log_images") and
46
+ callable(pl_module.log_images) and
47
+ self.max_images > 0):
48
+ logger = type(pl_module.logger)
49
+
50
+ is_train = pl_module.training
51
+ if is_train:
52
+ pl_module.eval()
53
+
54
+ with torch.no_grad():
55
+ images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
56
+
57
+ for k in images:
58
+ N = min(images[k].shape[0], self.max_images)
59
+ images[k] = images[k][:N]
60
+ if isinstance(images[k], torch.Tensor):
61
+ images[k] = images[k].detach().cpu()
62
+ if self.clamp:
63
+ images[k] = torch.clamp(images[k], -1., 1.)
64
+
65
+ self.log_local(pl_module.logger.save_dir, split, images,
66
+ pl_module.global_step, pl_module.current_epoch, batch_idx)
67
+
68
+ if is_train:
69
+ pl_module.train()
70
+
71
+ def check_frequency(self, check_idx):
72
+ return check_idx % self.batch_freq == 0
73
+
74
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
75
+ if not self.disabled:
76
+ self.log_img(pl_module, batch, batch_idx, split="train")
cldm/model.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+ from omegaconf import OmegaConf
5
+ from ldm.util import instantiate_from_config
6
+
7
+
8
+ def get_state_dict(d):
9
+ return d.get('state_dict', d)
10
+
11
+
12
+ def load_state_dict(ckpt_path, location='cpu'):
13
+ _, extension = os.path.splitext(ckpt_path)
14
+ if extension.lower() == ".safetensors":
15
+ import safetensors.torch
16
+ state_dict = safetensors.torch.load_file(ckpt_path, device=location)
17
+ else:
18
+ state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
19
+ state_dict = get_state_dict(state_dict)
20
+ print(f'Loaded state_dict from [{ckpt_path}]')
21
+ return state_dict
22
+
23
+
24
+ def create_model(config_path):
25
+ config = OmegaConf.load(config_path)
26
+ model = instantiate_from_config(config.model).cpu()
27
+ print(f'Loaded model config from [{config_path}]')
28
+ return model
configs/cldm_v21.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ control_stage_config:
2
+ target: cldm.cldm.ControlNet
3
+ params:
4
+ use_checkpoint: True
5
+ image_size: 32 # unused
6
+ in_channels: 4
7
+ hint_channels: 1
8
+ model_channels: 320
9
+ attention_resolutions: [ 4, 2, 1 ]
10
+ num_res_blocks: 2
11
+ channel_mult: [ 1, 2, 4, 4 ]
12
+ num_head_channels: 64 # need to fix for flash-attn
13
+ use_spatial_transformer: True
14
+ use_linear_in_transformer: True
15
+ transformer_depth: 1
16
+ context_dim: 1024
17
+ legacy: False
configs/inference_512_v1.0.yaml ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: lvdm.models.ddpm3d.LatentVisualDiffusion
3
+ params:
4
+ rescale_betas_zero_snr: True
5
+ parameterization: "v"
6
+ linear_start: 0.00085
7
+ linear_end: 0.012
8
+ num_timesteps_cond: 1
9
+ timesteps: 1000
10
+ first_stage_key: video
11
+ cond_stage_key: caption
12
+ cond_stage_trainable: False
13
+ conditioning_key: hybrid
14
+ image_size: [40, 64]
15
+ channels: 4
16
+ scale_by_std: False
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+ uncond_type: 'empty_seq'
20
+ use_dynamic_rescale: true
21
+ base_scale: 0.7
22
+ fps_condition_type: 'fps'
23
+ perframe_ae: True
24
+ loop_video: true
25
+ unet_config:
26
+ target: cldm.cldm.ControlledUnetModel
27
+ params:
28
+ in_channels: 8
29
+ out_channels: 4
30
+ model_channels: 320
31
+ attention_resolutions:
32
+ - 4
33
+ - 2
34
+ - 1
35
+ num_res_blocks: 2
36
+ channel_mult:
37
+ - 1
38
+ - 2
39
+ - 4
40
+ - 4
41
+ dropout: 0.1
42
+ num_head_channels: 64
43
+ transformer_depth: 1
44
+ context_dim: 1024
45
+ use_linear: true
46
+ use_checkpoint: True
47
+ temporal_conv: True
48
+ temporal_attention: True
49
+ temporal_selfatt_only: true
50
+ use_relative_position: false
51
+ use_causal_attention: False
52
+ temporal_length: 16
53
+ addition_attention: true
54
+ image_cross_attention: true
55
+ default_fs: 24
56
+ fs_condition: true
57
+
58
+ first_stage_config:
59
+ target: lvdm.models.autoencoder.AutoencoderKL_Dualref
60
+ params:
61
+ embed_dim: 4
62
+ monitor: val/rec_loss
63
+ ddconfig:
64
+ double_z: True
65
+ z_channels: 4
66
+ resolution: 256
67
+ in_channels: 3
68
+ out_ch: 3
69
+ ch: 128
70
+ ch_mult:
71
+ - 1
72
+ - 2
73
+ - 4
74
+ - 4
75
+ num_res_blocks: 2
76
+ attn_resolutions: []
77
+ dropout: 0.0
78
+ lossconfig:
79
+ target: torch.nn.Identity
80
+
81
+ cond_stage_config:
82
+ target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
83
+ params:
84
+ freeze: true
85
+ layer: "penultimate"
86
+
87
+ img_cond_stage_config:
88
+ target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
89
+ params:
90
+ freeze: true
91
+
92
+ image_proj_stage_config:
93
+ target: lvdm.modules.encoders.resampler.Resampler
94
+ params:
95
+ dim: 1024
96
+ depth: 4
97
+ dim_head: 64
98
+ heads: 12
99
+ num_queries: 16
100
+ embedding_dim: 1280
101
+ output_dim: 1024
102
+ ff_mult: 4
103
+ video_length: 16
configs/training_1024_v1.0/config.yaml ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ pretrained_checkpoint: checkpoints/dynamicrafter_1024_v1/model.ckpt
3
+ base_learning_rate: 1.0e-05
4
+ scale_lr: False
5
+ target: lvdm.models.ddpm3d.LatentVisualDiffusion
6
+ params:
7
+ rescale_betas_zero_snr: True
8
+ parameterization: "v"
9
+ linear_start: 0.00085
10
+ linear_end: 0.012
11
+ num_timesteps_cond: 1
12
+ log_every_t: 200
13
+ timesteps: 1000
14
+ first_stage_key: video
15
+ cond_stage_key: caption
16
+ cond_stage_trainable: False
17
+ image_proj_model_trainable: True
18
+ conditioning_key: hybrid
19
+ image_size: [72, 128]
20
+ channels: 4
21
+ scale_by_std: False
22
+ scale_factor: 0.18215
23
+ use_ema: False
24
+ uncond_prob: 0.05
25
+ uncond_type: 'empty_seq'
26
+ rand_cond_frame: true
27
+ use_dynamic_rescale: true
28
+ base_scale: 0.3
29
+ fps_condition_type: 'fps'
30
+ perframe_ae: True
31
+
32
+ unet_config:
33
+ target: lvdm.modules.networks.openaimodel3d.UNetModel
34
+ params:
35
+ in_channels: 8
36
+ out_channels: 4
37
+ model_channels: 320
38
+ attention_resolutions:
39
+ - 4
40
+ - 2
41
+ - 1
42
+ num_res_blocks: 2
43
+ channel_mult:
44
+ - 1
45
+ - 2
46
+ - 4
47
+ - 4
48
+ dropout: 0.1
49
+ num_head_channels: 64
50
+ transformer_depth: 1
51
+ context_dim: 1024
52
+ use_linear: true
53
+ use_checkpoint: True
54
+ temporal_conv: True
55
+ temporal_attention: True
56
+ temporal_selfatt_only: true
57
+ use_relative_position: false
58
+ use_causal_attention: False
59
+ temporal_length: 16
60
+ addition_attention: true
61
+ image_cross_attention: true
62
+ default_fs: 10
63
+ fs_condition: true
64
+
65
+ first_stage_config:
66
+ target: lvdm.models.autoencoder.AutoencoderKL
67
+ params:
68
+ embed_dim: 4
69
+ monitor: val/rec_loss
70
+ ddconfig:
71
+ double_z: True
72
+ z_channels: 4
73
+ resolution: 256
74
+ in_channels: 3
75
+ out_ch: 3
76
+ ch: 128
77
+ ch_mult:
78
+ - 1
79
+ - 2
80
+ - 4
81
+ - 4
82
+ num_res_blocks: 2
83
+ attn_resolutions: []
84
+ dropout: 0.0
85
+ lossconfig:
86
+ target: torch.nn.Identity
87
+
88
+ cond_stage_config:
89
+ target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
90
+ params:
91
+ freeze: true
92
+ layer: "penultimate"
93
+
94
+ img_cond_stage_config:
95
+ target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
96
+ params:
97
+ freeze: true
98
+
99
+ image_proj_stage_config:
100
+ target: lvdm.modules.encoders.resampler.Resampler
101
+ params:
102
+ dim: 1024
103
+ depth: 4
104
+ dim_head: 64
105
+ heads: 12
106
+ num_queries: 16
107
+ embedding_dim: 1280
108
+ output_dim: 1024
109
+ ff_mult: 4
110
+ video_length: 16
111
+
112
+ data:
113
+ target: utils_data.DataModuleFromConfig
114
+ params:
115
+ batch_size: 1
116
+ num_workers: 12
117
+ wrap: false
118
+ train:
119
+ target: lvdm.data.webvid.WebVid
120
+ params:
121
+ data_dir: <WebVid10M DATA>
122
+ meta_path: <.csv FILE>
123
+ video_length: 16
124
+ frame_stride: 6
125
+ load_raw_resolution: true
126
+ resolution: [576, 1024]
127
+ spatial_transform: resize_center_crop
128
+ random_fs: true ## if true, we uniformly sample fs with max_fs=frame_stride (above)
129
+
130
+ lightning:
131
+ precision: 16
132
+ # strategy: deepspeed_stage_2
133
+ trainer:
134
+ benchmark: True
135
+ accumulate_grad_batches: 2
136
+ max_steps: 100000
137
+ # logger
138
+ log_every_n_steps: 50
139
+ # val
140
+ val_check_interval: 0.5
141
+ gradient_clip_algorithm: 'norm'
142
+ gradient_clip_val: 0.5
143
+ callbacks:
144
+ model_checkpoint:
145
+ target: pytorch_lightning.callbacks.ModelCheckpoint
146
+ params:
147
+ every_n_train_steps: 9000 #1000
148
+ filename: "{epoch}-{step}"
149
+ save_weights_only: True
150
+ metrics_over_trainsteps_checkpoint:
151
+ target: pytorch_lightning.callbacks.ModelCheckpoint
152
+ params:
153
+ filename: '{epoch}-{step}'
154
+ save_weights_only: True
155
+ every_n_train_steps: 10000 #20000 # 3s/step*2w=
156
+ batch_logger:
157
+ target: callbacks.ImageLogger
158
+ params:
159
+ batch_frequency: 500
160
+ to_local: False
161
+ max_images: 8
162
+ log_images_kwargs:
163
+ ddim_steps: 50
164
+ unconditional_guidance_scale: 7.5
165
+ timestep_spacing: uniform_trailing
166
+ guidance_rescale: 0.7
configs/training_1024_v1.0/run.sh ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NCCL configuration
2
+ # export NCCL_DEBUG=INFO
3
+ # export NCCL_IB_DISABLE=0
4
+ # export NCCL_IB_GID_INDEX=3
5
+ # export NCCL_NET_GDR_LEVEL=3
6
+ # export NCCL_TOPO_FILE=/tmp/topo.txt
7
+
8
+ # args
9
+ name="training_1024_v1.0"
10
+ config_file=configs/${name}/config.yaml
11
+
12
+ # save root dir for logs, checkpoints, tensorboard record, etc.
13
+ save_root="<YOUR_SAVE_ROOT_DIR>"
14
+
15
+ mkdir -p $save_root/$name
16
+
17
+ ## run
18
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
19
+ --nproc_per_node=$HOST_GPU_NUM --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \
20
+ ./main/trainer.py \
21
+ --base $config_file \
22
+ --train \
23
+ --name $name \
24
+ --logdir $save_root \
25
+ --devices $HOST_GPU_NUM \
26
+ lightning.trainer.num_nodes=1
27
+
28
+ ## debugging
29
+ # CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch \
30
+ # --nproc_per_node=4 --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \
31
+ # ./main/trainer.py \
32
+ # --base $config_file \
33
+ # --train \
34
+ # --name $name \
35
+ # --logdir $save_root \
36
+ # --devices 4 \
37
+ # lightning.trainer.num_nodes=1
configs/training_512_v1.0/config.yaml ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ pretrained_checkpoint: checkpoints/dynamicrafter_512_v1/model.ckpt
3
+ base_learning_rate: 1.0e-05
4
+ scale_lr: False
5
+ target: lvdm.models.ddpm3d.LatentVisualDiffusion
6
+ params:
7
+ rescale_betas_zero_snr: True
8
+ parameterization: "v"
9
+ linear_start: 0.00085
10
+ linear_end: 0.012
11
+ num_timesteps_cond: 1
12
+ log_every_t: 200
13
+ timesteps: 1000
14
+ first_stage_key: video
15
+ cond_stage_key: caption
16
+ cond_stage_trainable: False
17
+ image_proj_model_trainable: True
18
+ conditioning_key: hybrid
19
+ image_size: [40, 64]
20
+ channels: 4
21
+ scale_by_std: False
22
+ scale_factor: 0.18215
23
+ use_ema: False
24
+ uncond_prob: 0.05
25
+ uncond_type: 'empty_seq'
26
+ rand_cond_frame: true
27
+ use_dynamic_rescale: true
28
+ base_scale: 0.7
29
+ fps_condition_type: 'fps'
30
+ perframe_ae: True
31
+
32
+ unet_config:
33
+ target: lvdm.modules.networks.openaimodel3d.UNetModel
34
+ params:
35
+ in_channels: 8
36
+ out_channels: 4
37
+ model_channels: 320
38
+ attention_resolutions:
39
+ - 4
40
+ - 2
41
+ - 1
42
+ num_res_blocks: 2
43
+ channel_mult:
44
+ - 1
45
+ - 2
46
+ - 4
47
+ - 4
48
+ dropout: 0.1
49
+ num_head_channels: 64
50
+ transformer_depth: 1
51
+ context_dim: 1024
52
+ use_linear: true
53
+ use_checkpoint: True
54
+ temporal_conv: True
55
+ temporal_attention: True
56
+ temporal_selfatt_only: true
57
+ use_relative_position: false
58
+ use_causal_attention: False
59
+ temporal_length: 16
60
+ addition_attention: true
61
+ image_cross_attention: true
62
+ default_fs: 10
63
+ fs_condition: true
64
+
65
+ first_stage_config:
66
+ target: lvdm.models.autoencoder.AutoencoderKL
67
+ params:
68
+ embed_dim: 4
69
+ monitor: val/rec_loss
70
+ ddconfig:
71
+ double_z: True
72
+ z_channels: 4
73
+ resolution: 256
74
+ in_channels: 3
75
+ out_ch: 3
76
+ ch: 128
77
+ ch_mult:
78
+ - 1
79
+ - 2
80
+ - 4
81
+ - 4
82
+ num_res_blocks: 2
83
+ attn_resolutions: []
84
+ dropout: 0.0
85
+ lossconfig:
86
+ target: torch.nn.Identity
87
+
88
+ cond_stage_config:
89
+ target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
90
+ params:
91
+ freeze: true
92
+ layer: "penultimate"
93
+
94
+ img_cond_stage_config:
95
+ target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
96
+ params:
97
+ freeze: true
98
+
99
+ image_proj_stage_config:
100
+ target: lvdm.modules.encoders.resampler.Resampler
101
+ params:
102
+ dim: 1024
103
+ depth: 4
104
+ dim_head: 64
105
+ heads: 12
106
+ num_queries: 16
107
+ embedding_dim: 1280
108
+ output_dim: 1024
109
+ ff_mult: 4
110
+ video_length: 16
111
+
112
+ data:
113
+ target: utils_data.DataModuleFromConfig
114
+ params:
115
+ batch_size: 2
116
+ num_workers: 12
117
+ wrap: false
118
+ train:
119
+ target: lvdm.data.webvid.WebVid
120
+ params:
121
+ data_dir: <WebVid10M DATA>
122
+ meta_path: <.csv FILE>
123
+ video_length: 16
124
+ frame_stride: 6
125
+ load_raw_resolution: true
126
+ resolution: [320, 512]
127
+ spatial_transform: resize_center_crop
128
+ random_fs: true ## if true, we uniformly sample fs with max_fs=frame_stride (above)
129
+
130
+ lightning:
131
+ precision: 16
132
+ # strategy: deepspeed_stage_2
133
+ trainer:
134
+ benchmark: True
135
+ accumulate_grad_batches: 2
136
+ max_steps: 100000
137
+ # logger
138
+ log_every_n_steps: 50
139
+ # val
140
+ val_check_interval: 0.5
141
+ gradient_clip_algorithm: 'norm'
142
+ gradient_clip_val: 0.5
143
+ callbacks:
144
+ model_checkpoint:
145
+ target: pytorch_lightning.callbacks.ModelCheckpoint
146
+ params:
147
+ every_n_train_steps: 9000 #1000
148
+ filename: "{epoch}-{step}"
149
+ save_weights_only: True
150
+ metrics_over_trainsteps_checkpoint:
151
+ target: pytorch_lightning.callbacks.ModelCheckpoint
152
+ params:
153
+ filename: '{epoch}-{step}'
154
+ save_weights_only: True
155
+ every_n_train_steps: 10000 #20000 # 3s/step*2w=
156
+ batch_logger:
157
+ target: callbacks.ImageLogger
158
+ params:
159
+ batch_frequency: 500
160
+ to_local: False
161
+ max_images: 8
162
+ log_images_kwargs:
163
+ ddim_steps: 50
164
+ unconditional_guidance_scale: 7.5
165
+ timestep_spacing: uniform_trailing
166
+ guidance_rescale: 0.7
configs/training_512_v1.0/run.sh ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NCCL configuration
2
+ # export NCCL_DEBUG=INFO
3
+ # export NCCL_IB_DISABLE=0
4
+ # export NCCL_IB_GID_INDEX=3
5
+ # export NCCL_NET_GDR_LEVEL=3
6
+ # export NCCL_TOPO_FILE=/tmp/topo.txt
7
+
8
+ # args
9
+ name="training_512_v1.0"
10
+ config_file=configs/${name}/config.yaml
11
+
12
+ # save root dir for logs, checkpoints, tensorboard record, etc.
13
+ save_root="<YOUR_SAVE_ROOT_DIR>"
14
+
15
+ mkdir -p $save_root/$name
16
+
17
+ ## run
18
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
19
+ --nproc_per_node=$HOST_GPU_NUM --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \
20
+ ./main/trainer.py \
21
+ --base $config_file \
22
+ --train \
23
+ --name $name \
24
+ --logdir $save_root \
25
+ --devices $HOST_GPU_NUM \
26
+ lightning.trainer.num_nodes=1
27
+
28
+ ## debugging
29
+ # CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch \
30
+ # --nproc_per_node=4 --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \
31
+ # ./main/trainer.py \
32
+ # --base $config_file \
33
+ # --train \
34
+ # --name $name \
35
+ # --logdir $save_root \
36
+ # --devices 4 \
37
+ # lightning.trainer.num_nodes=1
control_models/_put_your_control_models_.txt ADDED
File without changes