Spaces:
Runtime error
Runtime error
this isn't very nice.
Browse files- makeavid_sd/LICENSE +661 -0
- makeavid_sd/README.md +1 -0
- makeavid_sd/makeavid_sd/__init__.py +1 -0
- makeavid_sd/makeavid_sd/flax_impl/__init__.py +0 -0
- makeavid_sd/makeavid_sd/flax_impl/dataset.py +159 -0
- makeavid_sd/makeavid_sd/flax_impl/flax_attention_pseudo3d.py +212 -0
- makeavid_sd/makeavid_sd/flax_impl/flax_embeddings.py +62 -0
- makeavid_sd/makeavid_sd/flax_impl/flax_resnet_pseudo3d.py +175 -0
- makeavid_sd/makeavid_sd/flax_impl/flax_trainer.py +608 -0
- makeavid_sd/makeavid_sd/flax_impl/flax_unet_pseudo3d_blocks.py +254 -0
- makeavid_sd/makeavid_sd/flax_impl/flax_unet_pseudo3d_condition.py +251 -0
- makeavid_sd/makeavid_sd/flax_impl/train.py +143 -0
- makeavid_sd/makeavid_sd/flax_impl/train.sh +34 -0
- makeavid_sd/makeavid_sd/inference.py +486 -0
- makeavid_sd/makeavid_sd/torch_impl/__init__.py +0 -0
- makeavid_sd/makeavid_sd/torch_impl/torch_attention_pseudo3d.py +294 -0
- makeavid_sd/makeavid_sd/torch_impl/torch_cross_attention.py +171 -0
- makeavid_sd/makeavid_sd/torch_impl/torch_embeddings.py +92 -0
- makeavid_sd/makeavid_sd/torch_impl/torch_resnet_pseudo3d.py +295 -0
- makeavid_sd/makeavid_sd/torch_impl/torch_unet_pseudo3d_blocks.py +493 -0
- makeavid_sd/makeavid_sd/torch_impl/torch_unet_pseudo3d_condition.py +235 -0
- makeavid_sd/requirements.txt +2 -0
- makeavid_sd/setup.py +11 -0
- makeavid_sd/trainer_xla.py +104 -0
makeavid_sd/LICENSE
ADDED
@@ -0,0 +1,661 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
GNU AFFERO GENERAL PUBLIC LICENSE
|
2 |
+
Version 3, 19 November 2007
|
3 |
+
|
4 |
+
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
5 |
+
Everyone is permitted to copy and distribute verbatim copies
|
6 |
+
of this license document, but changing it is not allowed.
|
7 |
+
|
8 |
+
Preamble
|
9 |
+
|
10 |
+
The GNU Affero General Public License is a free, copyleft license for
|
11 |
+
software and other kinds of works, specifically designed to ensure
|
12 |
+
cooperation with the community in the case of network server software.
|
13 |
+
|
14 |
+
The licenses for most software and other practical works are designed
|
15 |
+
to take away your freedom to share and change the works. By contrast,
|
16 |
+
our General Public Licenses are intended to guarantee your freedom to
|
17 |
+
share and change all versions of a program--to make sure it remains free
|
18 |
+
software for all its users.
|
19 |
+
|
20 |
+
When we speak of free software, we are referring to freedom, not
|
21 |
+
price. Our General Public Licenses are designed to make sure that you
|
22 |
+
have the freedom to distribute copies of free software (and charge for
|
23 |
+
them if you wish), that you receive source code or can get it if you
|
24 |
+
want it, that you can change the software or use pieces of it in new
|
25 |
+
free programs, and that you know you can do these things.
|
26 |
+
|
27 |
+
Developers that use our General Public Licenses protect your rights
|
28 |
+
with two steps: (1) assert copyright on the software, and (2) offer
|
29 |
+
you this License which gives you legal permission to copy, distribute
|
30 |
+
and/or modify the software.
|
31 |
+
|
32 |
+
A secondary benefit of defending all users' freedom is that
|
33 |
+
improvements made in alternate versions of the program, if they
|
34 |
+
receive widespread use, become available for other developers to
|
35 |
+
incorporate. Many developers of free software are heartened and
|
36 |
+
encouraged by the resulting cooperation. However, in the case of
|
37 |
+
software used on network servers, this result may fail to come about.
|
38 |
+
The GNU General Public License permits making a modified version and
|
39 |
+
letting the public access it on a server without ever releasing its
|
40 |
+
source code to the public.
|
41 |
+
|
42 |
+
The GNU Affero General Public License is designed specifically to
|
43 |
+
ensure that, in such cases, the modified source code becomes available
|
44 |
+
to the community. It requires the operator of a network server to
|
45 |
+
provide the source code of the modified version running there to the
|
46 |
+
users of that server. Therefore, public use of a modified version, on
|
47 |
+
a publicly accessible server, gives the public access to the source
|
48 |
+
code of the modified version.
|
49 |
+
|
50 |
+
An older license, called the Affero General Public License and
|
51 |
+
published by Affero, was designed to accomplish similar goals. This is
|
52 |
+
a different license, not a version of the Affero GPL, but Affero has
|
53 |
+
released a new version of the Affero GPL which permits relicensing under
|
54 |
+
this license.
|
55 |
+
|
56 |
+
The precise terms and conditions for copying, distribution and
|
57 |
+
modification follow.
|
58 |
+
|
59 |
+
TERMS AND CONDITIONS
|
60 |
+
|
61 |
+
0. Definitions.
|
62 |
+
|
63 |
+
"This License" refers to version 3 of the GNU Affero General Public License.
|
64 |
+
|
65 |
+
"Copyright" also means copyright-like laws that apply to other kinds of
|
66 |
+
works, such as semiconductor masks.
|
67 |
+
|
68 |
+
"The Program" refers to any copyrightable work licensed under this
|
69 |
+
License. Each licensee is addressed as "you". "Licensees" and
|
70 |
+
"recipients" may be individuals or organizations.
|
71 |
+
|
72 |
+
To "modify" a work means to copy from or adapt all or part of the work
|
73 |
+
in a fashion requiring copyright permission, other than the making of an
|
74 |
+
exact copy. The resulting work is called a "modified version" of the
|
75 |
+
earlier work or a work "based on" the earlier work.
|
76 |
+
|
77 |
+
A "covered work" means either the unmodified Program or a work based
|
78 |
+
on the Program.
|
79 |
+
|
80 |
+
To "propagate" a work means to do anything with it that, without
|
81 |
+
permission, would make you directly or secondarily liable for
|
82 |
+
infringement under applicable copyright law, except executing it on a
|
83 |
+
computer or modifying a private copy. Propagation includes copying,
|
84 |
+
distribution (with or without modification), making available to the
|
85 |
+
public, and in some countries other activities as well.
|
86 |
+
|
87 |
+
To "convey" a work means any kind of propagation that enables other
|
88 |
+
parties to make or receive copies. Mere interaction with a user through
|
89 |
+
a computer network, with no transfer of a copy, is not conveying.
|
90 |
+
|
91 |
+
An interactive user interface displays "Appropriate Legal Notices"
|
92 |
+
to the extent that it includes a convenient and prominently visible
|
93 |
+
feature that (1) displays an appropriate copyright notice, and (2)
|
94 |
+
tells the user that there is no warranty for the work (except to the
|
95 |
+
extent that warranties are provided), that licensees may convey the
|
96 |
+
work under this License, and how to view a copy of this License. If
|
97 |
+
the interface presents a list of user commands or options, such as a
|
98 |
+
menu, a prominent item in the list meets this criterion.
|
99 |
+
|
100 |
+
1. Source Code.
|
101 |
+
|
102 |
+
The "source code" for a work means the preferred form of the work
|
103 |
+
for making modifications to it. "Object code" means any non-source
|
104 |
+
form of a work.
|
105 |
+
|
106 |
+
A "Standard Interface" means an interface that either is an official
|
107 |
+
standard defined by a recognized standards body, or, in the case of
|
108 |
+
interfaces specified for a particular programming language, one that
|
109 |
+
is widely used among developers working in that language.
|
110 |
+
|
111 |
+
The "System Libraries" of an executable work include anything, other
|
112 |
+
than the work as a whole, that (a) is included in the normal form of
|
113 |
+
packaging a Major Component, but which is not part of that Major
|
114 |
+
Component, and (b) serves only to enable use of the work with that
|
115 |
+
Major Component, or to implement a Standard Interface for which an
|
116 |
+
implementation is available to the public in source code form. A
|
117 |
+
"Major Component", in this context, means a major essential component
|
118 |
+
(kernel, window system, and so on) of the specific operating system
|
119 |
+
(if any) on which the executable work runs, or a compiler used to
|
120 |
+
produce the work, or an object code interpreter used to run it.
|
121 |
+
|
122 |
+
The "Corresponding Source" for a work in object code form means all
|
123 |
+
the source code needed to generate, install, and (for an executable
|
124 |
+
work) run the object code and to modify the work, including scripts to
|
125 |
+
control those activities. However, it does not include the work's
|
126 |
+
System Libraries, or general-purpose tools or generally available free
|
127 |
+
programs which are used unmodified in performing those activities but
|
128 |
+
which are not part of the work. For example, Corresponding Source
|
129 |
+
includes interface definition files associated with source files for
|
130 |
+
the work, and the source code for shared libraries and dynamically
|
131 |
+
linked subprograms that the work is specifically designed to require,
|
132 |
+
such as by intimate data communication or control flow between those
|
133 |
+
subprograms and other parts of the work.
|
134 |
+
|
135 |
+
The Corresponding Source need not include anything that users
|
136 |
+
can regenerate automatically from other parts of the Corresponding
|
137 |
+
Source.
|
138 |
+
|
139 |
+
The Corresponding Source for a work in source code form is that
|
140 |
+
same work.
|
141 |
+
|
142 |
+
2. Basic Permissions.
|
143 |
+
|
144 |
+
All rights granted under this License are granted for the term of
|
145 |
+
copyright on the Program, and are irrevocable provided the stated
|
146 |
+
conditions are met. This License explicitly affirms your unlimited
|
147 |
+
permission to run the unmodified Program. The output from running a
|
148 |
+
covered work is covered by this License only if the output, given its
|
149 |
+
content, constitutes a covered work. This License acknowledges your
|
150 |
+
rights of fair use or other equivalent, as provided by copyright law.
|
151 |
+
|
152 |
+
You may make, run and propagate covered works that you do not
|
153 |
+
convey, without conditions so long as your license otherwise remains
|
154 |
+
in force. You may convey covered works to others for the sole purpose
|
155 |
+
of having them make modifications exclusively for you, or provide you
|
156 |
+
with facilities for running those works, provided that you comply with
|
157 |
+
the terms of this License in conveying all material for which you do
|
158 |
+
not control copyright. Those thus making or running the covered works
|
159 |
+
for you must do so exclusively on your behalf, under your direction
|
160 |
+
and control, on terms that prohibit them from making any copies of
|
161 |
+
your copyrighted material outside their relationship with you.
|
162 |
+
|
163 |
+
Conveying under any other circumstances is permitted solely under
|
164 |
+
the conditions stated below. Sublicensing is not allowed; section 10
|
165 |
+
makes it unnecessary.
|
166 |
+
|
167 |
+
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
168 |
+
|
169 |
+
No covered work shall be deemed part of an effective technological
|
170 |
+
measure under any applicable law fulfilling obligations under article
|
171 |
+
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
172 |
+
similar laws prohibiting or restricting circumvention of such
|
173 |
+
measures.
|
174 |
+
|
175 |
+
When you convey a covered work, you waive any legal power to forbid
|
176 |
+
circumvention of technological measures to the extent such circumvention
|
177 |
+
is effected by exercising rights under this License with respect to
|
178 |
+
the covered work, and you disclaim any intention to limit operation or
|
179 |
+
modification of the work as a means of enforcing, against the work's
|
180 |
+
users, your or third parties' legal rights to forbid circumvention of
|
181 |
+
technological measures.
|
182 |
+
|
183 |
+
4. Conveying Verbatim Copies.
|
184 |
+
|
185 |
+
You may convey verbatim copies of the Program's source code as you
|
186 |
+
receive it, in any medium, provided that you conspicuously and
|
187 |
+
appropriately publish on each copy an appropriate copyright notice;
|
188 |
+
keep intact all notices stating that this License and any
|
189 |
+
non-permissive terms added in accord with section 7 apply to the code;
|
190 |
+
keep intact all notices of the absence of any warranty; and give all
|
191 |
+
recipients a copy of this License along with the Program.
|
192 |
+
|
193 |
+
You may charge any price or no price for each copy that you convey,
|
194 |
+
and you may offer support or warranty protection for a fee.
|
195 |
+
|
196 |
+
5. Conveying Modified Source Versions.
|
197 |
+
|
198 |
+
You may convey a work based on the Program, or the modifications to
|
199 |
+
produce it from the Program, in the form of source code under the
|
200 |
+
terms of section 4, provided that you also meet all of these conditions:
|
201 |
+
|
202 |
+
a) The work must carry prominent notices stating that you modified
|
203 |
+
it, and giving a relevant date.
|
204 |
+
|
205 |
+
b) The work must carry prominent notices stating that it is
|
206 |
+
released under this License and any conditions added under section
|
207 |
+
7. This requirement modifies the requirement in section 4 to
|
208 |
+
"keep intact all notices".
|
209 |
+
|
210 |
+
c) You must license the entire work, as a whole, under this
|
211 |
+
License to anyone who comes into possession of a copy. This
|
212 |
+
License will therefore apply, along with any applicable section 7
|
213 |
+
additional terms, to the whole of the work, and all its parts,
|
214 |
+
regardless of how they are packaged. This License gives no
|
215 |
+
permission to license the work in any other way, but it does not
|
216 |
+
invalidate such permission if you have separately received it.
|
217 |
+
|
218 |
+
d) If the work has interactive user interfaces, each must display
|
219 |
+
Appropriate Legal Notices; however, if the Program has interactive
|
220 |
+
interfaces that do not display Appropriate Legal Notices, your
|
221 |
+
work need not make them do so.
|
222 |
+
|
223 |
+
A compilation of a covered work with other separate and independent
|
224 |
+
works, which are not by their nature extensions of the covered work,
|
225 |
+
and which are not combined with it such as to form a larger program,
|
226 |
+
in or on a volume of a storage or distribution medium, is called an
|
227 |
+
"aggregate" if the compilation and its resulting copyright are not
|
228 |
+
used to limit the access or legal rights of the compilation's users
|
229 |
+
beyond what the individual works permit. Inclusion of a covered work
|
230 |
+
in an aggregate does not cause this License to apply to the other
|
231 |
+
parts of the aggregate.
|
232 |
+
|
233 |
+
6. Conveying Non-Source Forms.
|
234 |
+
|
235 |
+
You may convey a covered work in object code form under the terms
|
236 |
+
of sections 4 and 5, provided that you also convey the
|
237 |
+
machine-readable Corresponding Source under the terms of this License,
|
238 |
+
in one of these ways:
|
239 |
+
|
240 |
+
a) Convey the object code in, or embodied in, a physical product
|
241 |
+
(including a physical distribution medium), accompanied by the
|
242 |
+
Corresponding Source fixed on a durable physical medium
|
243 |
+
customarily used for software interchange.
|
244 |
+
|
245 |
+
b) Convey the object code in, or embodied in, a physical product
|
246 |
+
(including a physical distribution medium), accompanied by a
|
247 |
+
written offer, valid for at least three years and valid for as
|
248 |
+
long as you offer spare parts or customer support for that product
|
249 |
+
model, to give anyone who possesses the object code either (1) a
|
250 |
+
copy of the Corresponding Source for all the software in the
|
251 |
+
product that is covered by this License, on a durable physical
|
252 |
+
medium customarily used for software interchange, for a price no
|
253 |
+
more than your reasonable cost of physically performing this
|
254 |
+
conveying of source, or (2) access to copy the
|
255 |
+
Corresponding Source from a network server at no charge.
|
256 |
+
|
257 |
+
c) Convey individual copies of the object code with a copy of the
|
258 |
+
written offer to provide the Corresponding Source. This
|
259 |
+
alternative is allowed only occasionally and noncommercially, and
|
260 |
+
only if you received the object code with such an offer, in accord
|
261 |
+
with subsection 6b.
|
262 |
+
|
263 |
+
d) Convey the object code by offering access from a designated
|
264 |
+
place (gratis or for a charge), and offer equivalent access to the
|
265 |
+
Corresponding Source in the same way through the same place at no
|
266 |
+
further charge. You need not require recipients to copy the
|
267 |
+
Corresponding Source along with the object code. If the place to
|
268 |
+
copy the object code is a network server, the Corresponding Source
|
269 |
+
may be on a different server (operated by you or a third party)
|
270 |
+
that supports equivalent copying facilities, provided you maintain
|
271 |
+
clear directions next to the object code saying where to find the
|
272 |
+
Corresponding Source. Regardless of what server hosts the
|
273 |
+
Corresponding Source, you remain obligated to ensure that it is
|
274 |
+
available for as long as needed to satisfy these requirements.
|
275 |
+
|
276 |
+
e) Convey the object code using peer-to-peer transmission, provided
|
277 |
+
you inform other peers where the object code and Corresponding
|
278 |
+
Source of the work are being offered to the general public at no
|
279 |
+
charge under subsection 6d.
|
280 |
+
|
281 |
+
A separable portion of the object code, whose source code is excluded
|
282 |
+
from the Corresponding Source as a System Library, need not be
|
283 |
+
included in conveying the object code work.
|
284 |
+
|
285 |
+
A "User Product" is either (1) a "consumer product", which means any
|
286 |
+
tangible personal property which is normally used for personal, family,
|
287 |
+
or household purposes, or (2) anything designed or sold for incorporation
|
288 |
+
into a dwelling. In determining whether a product is a consumer product,
|
289 |
+
doubtful cases shall be resolved in favor of coverage. For a particular
|
290 |
+
product received by a particular user, "normally used" refers to a
|
291 |
+
typical or common use of that class of product, regardless of the status
|
292 |
+
of the particular user or of the way in which the particular user
|
293 |
+
actually uses, or expects or is expected to use, the product. A product
|
294 |
+
is a consumer product regardless of whether the product has substantial
|
295 |
+
commercial, industrial or non-consumer uses, unless such uses represent
|
296 |
+
the only significant mode of use of the product.
|
297 |
+
|
298 |
+
"Installation Information" for a User Product means any methods,
|
299 |
+
procedures, authorization keys, or other information required to install
|
300 |
+
and execute modified versions of a covered work in that User Product from
|
301 |
+
a modified version of its Corresponding Source. The information must
|
302 |
+
suffice to ensure that the continued functioning of the modified object
|
303 |
+
code is in no case prevented or interfered with solely because
|
304 |
+
modification has been made.
|
305 |
+
|
306 |
+
If you convey an object code work under this section in, or with, or
|
307 |
+
specifically for use in, a User Product, and the conveying occurs as
|
308 |
+
part of a transaction in which the right of possession and use of the
|
309 |
+
User Product is transferred to the recipient in perpetuity or for a
|
310 |
+
fixed term (regardless of how the transaction is characterized), the
|
311 |
+
Corresponding Source conveyed under this section must be accompanied
|
312 |
+
by the Installation Information. But this requirement does not apply
|
313 |
+
if neither you nor any third party retains the ability to install
|
314 |
+
modified object code on the User Product (for example, the work has
|
315 |
+
been installed in ROM).
|
316 |
+
|
317 |
+
The requirement to provide Installation Information does not include a
|
318 |
+
requirement to continue to provide support service, warranty, or updates
|
319 |
+
for a work that has been modified or installed by the recipient, or for
|
320 |
+
the User Product in which it has been modified or installed. Access to a
|
321 |
+
network may be denied when the modification itself materially and
|
322 |
+
adversely affects the operation of the network or violates the rules and
|
323 |
+
protocols for communication across the network.
|
324 |
+
|
325 |
+
Corresponding Source conveyed, and Installation Information provided,
|
326 |
+
in accord with this section must be in a format that is publicly
|
327 |
+
documented (and with an implementation available to the public in
|
328 |
+
source code form), and must require no special password or key for
|
329 |
+
unpacking, reading or copying.
|
330 |
+
|
331 |
+
7. Additional Terms.
|
332 |
+
|
333 |
+
"Additional permissions" are terms that supplement the terms of this
|
334 |
+
License by making exceptions from one or more of its conditions.
|
335 |
+
Additional permissions that are applicable to the entire Program shall
|
336 |
+
be treated as though they were included in this License, to the extent
|
337 |
+
that they are valid under applicable law. If additional permissions
|
338 |
+
apply only to part of the Program, that part may be used separately
|
339 |
+
under those permissions, but the entire Program remains governed by
|
340 |
+
this License without regard to the additional permissions.
|
341 |
+
|
342 |
+
When you convey a copy of a covered work, you may at your option
|
343 |
+
remove any additional permissions from that copy, or from any part of
|
344 |
+
it. (Additional permissions may be written to require their own
|
345 |
+
removal in certain cases when you modify the work.) You may place
|
346 |
+
additional permissions on material, added by you to a covered work,
|
347 |
+
for which you have or can give appropriate copyright permission.
|
348 |
+
|
349 |
+
Notwithstanding any other provision of this License, for material you
|
350 |
+
add to a covered work, you may (if authorized by the copyright holders of
|
351 |
+
that material) supplement the terms of this License with terms:
|
352 |
+
|
353 |
+
a) Disclaiming warranty or limiting liability differently from the
|
354 |
+
terms of sections 15 and 16 of this License; or
|
355 |
+
|
356 |
+
b) Requiring preservation of specified reasonable legal notices or
|
357 |
+
author attributions in that material or in the Appropriate Legal
|
358 |
+
Notices displayed by works containing it; or
|
359 |
+
|
360 |
+
c) Prohibiting misrepresentation of the origin of that material, or
|
361 |
+
requiring that modified versions of such material be marked in
|
362 |
+
reasonable ways as different from the original version; or
|
363 |
+
|
364 |
+
d) Limiting the use for publicity purposes of names of licensors or
|
365 |
+
authors of the material; or
|
366 |
+
|
367 |
+
e) Declining to grant rights under trademark law for use of some
|
368 |
+
trade names, trademarks, or service marks; or
|
369 |
+
|
370 |
+
f) Requiring indemnification of licensors and authors of that
|
371 |
+
material by anyone who conveys the material (or modified versions of
|
372 |
+
it) with contractual assumptions of liability to the recipient, for
|
373 |
+
any liability that these contractual assumptions directly impose on
|
374 |
+
those licensors and authors.
|
375 |
+
|
376 |
+
All other non-permissive additional terms are considered "further
|
377 |
+
restrictions" within the meaning of section 10. If the Program as you
|
378 |
+
received it, or any part of it, contains a notice stating that it is
|
379 |
+
governed by this License along with a term that is a further
|
380 |
+
restriction, you may remove that term. If a license document contains
|
381 |
+
a further restriction but permits relicensing or conveying under this
|
382 |
+
License, you may add to a covered work material governed by the terms
|
383 |
+
of that license document, provided that the further restriction does
|
384 |
+
not survive such relicensing or conveying.
|
385 |
+
|
386 |
+
If you add terms to a covered work in accord with this section, you
|
387 |
+
must place, in the relevant source files, a statement of the
|
388 |
+
additional terms that apply to those files, or a notice indicating
|
389 |
+
where to find the applicable terms.
|
390 |
+
|
391 |
+
Additional terms, permissive or non-permissive, may be stated in the
|
392 |
+
form of a separately written license, or stated as exceptions;
|
393 |
+
the above requirements apply either way.
|
394 |
+
|
395 |
+
8. Termination.
|
396 |
+
|
397 |
+
You may not propagate or modify a covered work except as expressly
|
398 |
+
provided under this License. Any attempt otherwise to propagate or
|
399 |
+
modify it is void, and will automatically terminate your rights under
|
400 |
+
this License (including any patent licenses granted under the third
|
401 |
+
paragraph of section 11).
|
402 |
+
|
403 |
+
However, if you cease all violation of this License, then your
|
404 |
+
license from a particular copyright holder is reinstated (a)
|
405 |
+
provisionally, unless and until the copyright holder explicitly and
|
406 |
+
finally terminates your license, and (b) permanently, if the copyright
|
407 |
+
holder fails to notify you of the violation by some reasonable means
|
408 |
+
prior to 60 days after the cessation.
|
409 |
+
|
410 |
+
Moreover, your license from a particular copyright holder is
|
411 |
+
reinstated permanently if the copyright holder notifies you of the
|
412 |
+
violation by some reasonable means, this is the first time you have
|
413 |
+
received notice of violation of this License (for any work) from that
|
414 |
+
copyright holder, and you cure the violation prior to 30 days after
|
415 |
+
your receipt of the notice.
|
416 |
+
|
417 |
+
Termination of your rights under this section does not terminate the
|
418 |
+
licenses of parties who have received copies or rights from you under
|
419 |
+
this License. If your rights have been terminated and not permanently
|
420 |
+
reinstated, you do not qualify to receive new licenses for the same
|
421 |
+
material under section 10.
|
422 |
+
|
423 |
+
9. Acceptance Not Required for Having Copies.
|
424 |
+
|
425 |
+
You are not required to accept this License in order to receive or
|
426 |
+
run a copy of the Program. Ancillary propagation of a covered work
|
427 |
+
occurring solely as a consequence of using peer-to-peer transmission
|
428 |
+
to receive a copy likewise does not require acceptance. However,
|
429 |
+
nothing other than this License grants you permission to propagate or
|
430 |
+
modify any covered work. These actions infringe copyright if you do
|
431 |
+
not accept this License. Therefore, by modifying or propagating a
|
432 |
+
covered work, you indicate your acceptance of this License to do so.
|
433 |
+
|
434 |
+
10. Automatic Licensing of Downstream Recipients.
|
435 |
+
|
436 |
+
Each time you convey a covered work, the recipient automatically
|
437 |
+
receives a license from the original licensors, to run, modify and
|
438 |
+
propagate that work, subject to this License. You are not responsible
|
439 |
+
for enforcing compliance by third parties with this License.
|
440 |
+
|
441 |
+
An "entity transaction" is a transaction transferring control of an
|
442 |
+
organization, or substantially all assets of one, or subdividing an
|
443 |
+
organization, or merging organizations. If propagation of a covered
|
444 |
+
work results from an entity transaction, each party to that
|
445 |
+
transaction who receives a copy of the work also receives whatever
|
446 |
+
licenses to the work the party's predecessor in interest had or could
|
447 |
+
give under the previous paragraph, plus a right to possession of the
|
448 |
+
Corresponding Source of the work from the predecessor in interest, if
|
449 |
+
the predecessor has it or can get it with reasonable efforts.
|
450 |
+
|
451 |
+
You may not impose any further restrictions on the exercise of the
|
452 |
+
rights granted or affirmed under this License. For example, you may
|
453 |
+
not impose a license fee, royalty, or other charge for exercise of
|
454 |
+
rights granted under this License, and you may not initiate litigation
|
455 |
+
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
456 |
+
any patent claim is infringed by making, using, selling, offering for
|
457 |
+
sale, or importing the Program or any portion of it.
|
458 |
+
|
459 |
+
11. Patents.
|
460 |
+
|
461 |
+
A "contributor" is a copyright holder who authorizes use under this
|
462 |
+
License of the Program or a work on which the Program is based. The
|
463 |
+
work thus licensed is called the contributor's "contributor version".
|
464 |
+
|
465 |
+
A contributor's "essential patent claims" are all patent claims
|
466 |
+
owned or controlled by the contributor, whether already acquired or
|
467 |
+
hereafter acquired, that would be infringed by some manner, permitted
|
468 |
+
by this License, of making, using, or selling its contributor version,
|
469 |
+
but do not include claims that would be infringed only as a
|
470 |
+
consequence of further modification of the contributor version. For
|
471 |
+
purposes of this definition, "control" includes the right to grant
|
472 |
+
patent sublicenses in a manner consistent with the requirements of
|
473 |
+
this License.
|
474 |
+
|
475 |
+
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
476 |
+
patent license under the contributor's essential patent claims, to
|
477 |
+
make, use, sell, offer for sale, import and otherwise run, modify and
|
478 |
+
propagate the contents of its contributor version.
|
479 |
+
|
480 |
+
In the following three paragraphs, a "patent license" is any express
|
481 |
+
agreement or commitment, however denominated, not to enforce a patent
|
482 |
+
(such as an express permission to practice a patent or covenant not to
|
483 |
+
sue for patent infringement). To "grant" such a patent license to a
|
484 |
+
party means to make such an agreement or commitment not to enforce a
|
485 |
+
patent against the party.
|
486 |
+
|
487 |
+
If you convey a covered work, knowingly relying on a patent license,
|
488 |
+
and the Corresponding Source of the work is not available for anyone
|
489 |
+
to copy, free of charge and under the terms of this License, through a
|
490 |
+
publicly available network server or other readily accessible means,
|
491 |
+
then you must either (1) cause the Corresponding Source to be so
|
492 |
+
available, or (2) arrange to deprive yourself of the benefit of the
|
493 |
+
patent license for this particular work, or (3) arrange, in a manner
|
494 |
+
consistent with the requirements of this License, to extend the patent
|
495 |
+
license to downstream recipients. "Knowingly relying" means you have
|
496 |
+
actual knowledge that, but for the patent license, your conveying the
|
497 |
+
covered work in a country, or your recipient's use of the covered work
|
498 |
+
in a country, would infringe one or more identifiable patents in that
|
499 |
+
country that you have reason to believe are valid.
|
500 |
+
|
501 |
+
If, pursuant to or in connection with a single transaction or
|
502 |
+
arrangement, you convey, or propagate by procuring conveyance of, a
|
503 |
+
covered work, and grant a patent license to some of the parties
|
504 |
+
receiving the covered work authorizing them to use, propagate, modify
|
505 |
+
or convey a specific copy of the covered work, then the patent license
|
506 |
+
you grant is automatically extended to all recipients of the covered
|
507 |
+
work and works based on it.
|
508 |
+
|
509 |
+
A patent license is "discriminatory" if it does not include within
|
510 |
+
the scope of its coverage, prohibits the exercise of, or is
|
511 |
+
conditioned on the non-exercise of one or more of the rights that are
|
512 |
+
specifically granted under this License. You may not convey a covered
|
513 |
+
work if you are a party to an arrangement with a third party that is
|
514 |
+
in the business of distributing software, under which you make payment
|
515 |
+
to the third party based on the extent of your activity of conveying
|
516 |
+
the work, and under which the third party grants, to any of the
|
517 |
+
parties who would receive the covered work from you, a discriminatory
|
518 |
+
patent license (a) in connection with copies of the covered work
|
519 |
+
conveyed by you (or copies made from those copies), or (b) primarily
|
520 |
+
for and in connection with specific products or compilations that
|
521 |
+
contain the covered work, unless you entered into that arrangement,
|
522 |
+
or that patent license was granted, prior to 28 March 2007.
|
523 |
+
|
524 |
+
Nothing in this License shall be construed as excluding or limiting
|
525 |
+
any implied license or other defenses to infringement that may
|
526 |
+
otherwise be available to you under applicable patent law.
|
527 |
+
|
528 |
+
12. No Surrender of Others' Freedom.
|
529 |
+
|
530 |
+
If conditions are imposed on you (whether by court order, agreement or
|
531 |
+
otherwise) that contradict the conditions of this License, they do not
|
532 |
+
excuse you from the conditions of this License. If you cannot convey a
|
533 |
+
covered work so as to satisfy simultaneously your obligations under this
|
534 |
+
License and any other pertinent obligations, then as a consequence you may
|
535 |
+
not convey it at all. For example, if you agree to terms that obligate you
|
536 |
+
to collect a royalty for further conveying from those to whom you convey
|
537 |
+
the Program, the only way you could satisfy both those terms and this
|
538 |
+
License would be to refrain entirely from conveying the Program.
|
539 |
+
|
540 |
+
13. Remote Network Interaction; Use with the GNU General Public License.
|
541 |
+
|
542 |
+
Notwithstanding any other provision of this License, if you modify the
|
543 |
+
Program, your modified version must prominently offer all users
|
544 |
+
interacting with it remotely through a computer network (if your version
|
545 |
+
supports such interaction) an opportunity to receive the Corresponding
|
546 |
+
Source of your version by providing access to the Corresponding Source
|
547 |
+
from a network server at no charge, through some standard or customary
|
548 |
+
means of facilitating copying of software. This Corresponding Source
|
549 |
+
shall include the Corresponding Source for any work covered by version 3
|
550 |
+
of the GNU General Public License that is incorporated pursuant to the
|
551 |
+
following paragraph.
|
552 |
+
|
553 |
+
Notwithstanding any other provision of this License, you have
|
554 |
+
permission to link or combine any covered work with a work licensed
|
555 |
+
under version 3 of the GNU General Public License into a single
|
556 |
+
combined work, and to convey the resulting work. The terms of this
|
557 |
+
License will continue to apply to the part which is the covered work,
|
558 |
+
but the work with which it is combined will remain governed by version
|
559 |
+
3 of the GNU General Public License.
|
560 |
+
|
561 |
+
14. Revised Versions of this License.
|
562 |
+
|
563 |
+
The Free Software Foundation may publish revised and/or new versions of
|
564 |
+
the GNU Affero General Public License from time to time. Such new versions
|
565 |
+
will be similar in spirit to the present version, but may differ in detail to
|
566 |
+
address new problems or concerns.
|
567 |
+
|
568 |
+
Each version is given a distinguishing version number. If the
|
569 |
+
Program specifies that a certain numbered version of the GNU Affero General
|
570 |
+
Public License "or any later version" applies to it, you have the
|
571 |
+
option of following the terms and conditions either of that numbered
|
572 |
+
version or of any later version published by the Free Software
|
573 |
+
Foundation. If the Program does not specify a version number of the
|
574 |
+
GNU Affero General Public License, you may choose any version ever published
|
575 |
+
by the Free Software Foundation.
|
576 |
+
|
577 |
+
If the Program specifies that a proxy can decide which future
|
578 |
+
versions of the GNU Affero General Public License can be used, that proxy's
|
579 |
+
public statement of acceptance of a version permanently authorizes you
|
580 |
+
to choose that version for the Program.
|
581 |
+
|
582 |
+
Later license versions may give you additional or different
|
583 |
+
permissions. However, no additional obligations are imposed on any
|
584 |
+
author or copyright holder as a result of your choosing to follow a
|
585 |
+
later version.
|
586 |
+
|
587 |
+
15. Disclaimer of Warranty.
|
588 |
+
|
589 |
+
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
590 |
+
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
591 |
+
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
592 |
+
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
593 |
+
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
594 |
+
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
595 |
+
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
596 |
+
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
597 |
+
|
598 |
+
16. Limitation of Liability.
|
599 |
+
|
600 |
+
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
601 |
+
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
602 |
+
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
603 |
+
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
604 |
+
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
605 |
+
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
606 |
+
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
607 |
+
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
608 |
+
SUCH DAMAGES.
|
609 |
+
|
610 |
+
17. Interpretation of Sections 15 and 16.
|
611 |
+
|
612 |
+
If the disclaimer of warranty and limitation of liability provided
|
613 |
+
above cannot be given local legal effect according to their terms,
|
614 |
+
reviewing courts shall apply local law that most closely approximates
|
615 |
+
an absolute waiver of all civil liability in connection with the
|
616 |
+
Program, unless a warranty or assumption of liability accompanies a
|
617 |
+
copy of the Program in return for a fee.
|
618 |
+
|
619 |
+
END OF TERMS AND CONDITIONS
|
620 |
+
|
621 |
+
How to Apply These Terms to Your New Programs
|
622 |
+
|
623 |
+
If you develop a new program, and you want it to be of the greatest
|
624 |
+
possible use to the public, the best way to achieve this is to make it
|
625 |
+
free software which everyone can redistribute and change under these terms.
|
626 |
+
|
627 |
+
To do so, attach the following notices to the program. It is safest
|
628 |
+
to attach them to the start of each source file to most effectively
|
629 |
+
state the exclusion of warranty; and each file should have at least
|
630 |
+
the "copyright" line and a pointer to where the full notice is found.
|
631 |
+
|
632 |
+
<one line to give the program's name and a brief idea of what it does.>
|
633 |
+
Copyright (C) <year> <name of author>
|
634 |
+
|
635 |
+
This program is free software: you can redistribute it and/or modify
|
636 |
+
it under the terms of the GNU Affero General Public License as published
|
637 |
+
by the Free Software Foundation, either version 3 of the License, or
|
638 |
+
(at your option) any later version.
|
639 |
+
|
640 |
+
This program is distributed in the hope that it will be useful,
|
641 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
642 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
643 |
+
GNU Affero General Public License for more details.
|
644 |
+
|
645 |
+
You should have received a copy of the GNU Affero General Public License
|
646 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
647 |
+
|
648 |
+
Also add information on how to contact you by electronic and paper mail.
|
649 |
+
|
650 |
+
If your software can interact with users remotely through a computer
|
651 |
+
network, you should also make sure that it provides a way for users to
|
652 |
+
get its source. For example, if your program is a web application, its
|
653 |
+
interface could display a "Source" link that leads users to an archive
|
654 |
+
of the code. There are many ways you could offer source, and different
|
655 |
+
solutions will be better for different programs; see section 13 for the
|
656 |
+
specific requirements.
|
657 |
+
|
658 |
+
You should also get your employer (if you work as a programmer) or school,
|
659 |
+
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
660 |
+
For more information on this, and how to apply and follow the GNU AGPL, see
|
661 |
+
<https://www.gnu.org/licenses/>.
|
makeavid_sd/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# makeavid-sd-tpu
|
makeavid_sd/makeavid_sd/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__version__ = '0.1.0'
|
makeavid_sd/makeavid_sd/flax_impl/__init__.py
ADDED
File without changes
|
makeavid_sd/makeavid_sd/flax_impl/dataset.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import List, Dict, Any, Union, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch.utils.data import DataLoader, ConcatDataset
|
6 |
+
import datasets
|
7 |
+
from diffusers import DDPMScheduler
|
8 |
+
from functools import partial
|
9 |
+
import random
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
|
14 |
+
@torch.no_grad()
|
15 |
+
def collate_fn(
|
16 |
+
batch: List[Dict[str, Any]],
|
17 |
+
noise_scheduler: DDPMScheduler,
|
18 |
+
num_frames: int,
|
19 |
+
hint_spacing: Optional[int] = None,
|
20 |
+
as_numpy: bool = True
|
21 |
+
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
|
22 |
+
if hint_spacing is None or hint_spacing < 1:
|
23 |
+
hint_spacing = num_frames
|
24 |
+
if as_numpy:
|
25 |
+
dtype = np.float32
|
26 |
+
else:
|
27 |
+
dtype = torch.float32
|
28 |
+
prompts = []
|
29 |
+
videos = []
|
30 |
+
for s in batch:
|
31 |
+
# prompt
|
32 |
+
prompts.append(torch.tensor(s['prompt']).to(dtype = torch.float32))
|
33 |
+
# frames
|
34 |
+
frames = torch.tensor(s['video']).to(dtype = torch.float32)
|
35 |
+
max_frames = len(frames)
|
36 |
+
assert max_frames >= num_frames
|
37 |
+
video_slice = random.randint(0, max_frames - num_frames)
|
38 |
+
frames = frames[video_slice:video_slice + num_frames]
|
39 |
+
frames = frames.permute(1, 0, 2, 3) # f, c, h, w -> c, f, h, w
|
40 |
+
videos.append(frames)
|
41 |
+
|
42 |
+
encoder_hidden_states = torch.cat(prompts) # b, 77, 768
|
43 |
+
|
44 |
+
latents = torch.stack(videos) # b, c, f, h, w
|
45 |
+
latents = latents * 0.18215
|
46 |
+
hint_latents = latents[:, :, ::hint_spacing, :, :]
|
47 |
+
hint_latents = hint_latents.repeat_interleave(hint_spacing, 2)
|
48 |
+
#hint_latents = hint_latents[:, :, :num_frames-1, :, :]
|
49 |
+
#input_latents = latents[:, :, 1:, :, :]
|
50 |
+
input_latents = latents
|
51 |
+
noise = torch.randn_like(input_latents)
|
52 |
+
bsz = input_latents.shape[0]
|
53 |
+
timesteps = torch.randint(
|
54 |
+
0,
|
55 |
+
noise_scheduler.config.num_train_timesteps,
|
56 |
+
(bsz,),
|
57 |
+
dtype = torch.int64
|
58 |
+
)
|
59 |
+
noisy_latents = noise_scheduler.add_noise(input_latents, noise, timesteps)
|
60 |
+
mask = torch.zeros([
|
61 |
+
noisy_latents.shape[0],
|
62 |
+
1,
|
63 |
+
noisy_latents.shape[2],
|
64 |
+
noisy_latents.shape[3],
|
65 |
+
noisy_latents.shape[4]
|
66 |
+
])
|
67 |
+
latent_model_input = torch.cat([noisy_latents, mask, hint_latents], dim = 1)
|
68 |
+
|
69 |
+
latent_model_input = latent_model_input.to(memory_format = torch.contiguous_format)
|
70 |
+
encoder_hidden_states = encoder_hidden_states.to(memory_format = torch.contiguous_format)
|
71 |
+
timesteps = timesteps.to(memory_format = torch.contiguous_format)
|
72 |
+
noise = noise.to(memory_format = torch.contiguous_format)
|
73 |
+
|
74 |
+
if as_numpy:
|
75 |
+
latent_model_input = latent_model_input.numpy().astype(dtype)
|
76 |
+
encoder_hidden_states = encoder_hidden_states.numpy().astype(dtype)
|
77 |
+
timesteps = timesteps.numpy().astype(np.int32)
|
78 |
+
noise = noise.numpy().astype(dtype)
|
79 |
+
else:
|
80 |
+
latent_model_input = latent_model_input.to(dtype = dtype)
|
81 |
+
encoder_hidden_states = encoder_hidden_states.to(dtype = dtype)
|
82 |
+
noise = noise.to(dtype = dtype)
|
83 |
+
|
84 |
+
return {
|
85 |
+
'latent_model_input': latent_model_input,
|
86 |
+
'encoder_hidden_states': encoder_hidden_states,
|
87 |
+
'timesteps': timesteps,
|
88 |
+
'noise': noise
|
89 |
+
}
|
90 |
+
|
91 |
+
def worker_init_fn(worker_id: int):
|
92 |
+
wseed = torch.initial_seed() % 4294967294 # max val for random 2**32 - 1
|
93 |
+
random.seed(wseed)
|
94 |
+
np.random.seed(wseed)
|
95 |
+
|
96 |
+
|
97 |
+
def load_dataset(
|
98 |
+
dataset_path: str,
|
99 |
+
model_path: str,
|
100 |
+
cache_dir: Optional[str] = None,
|
101 |
+
batch_size: int = 1,
|
102 |
+
num_frames: int = 24,
|
103 |
+
hint_spacing: Optional[int] = None,
|
104 |
+
num_workers: int = 0,
|
105 |
+
shuffle: bool = False,
|
106 |
+
as_numpy: bool = True,
|
107 |
+
pin_memory: bool = False,
|
108 |
+
pin_memory_device: str = ''
|
109 |
+
) -> DataLoader:
|
110 |
+
noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
|
111 |
+
model_path,
|
112 |
+
subfolder = 'scheduler'
|
113 |
+
)
|
114 |
+
dataset = datasets.load_dataset(
|
115 |
+
dataset_path,
|
116 |
+
streaming = False,
|
117 |
+
cache_dir = cache_dir
|
118 |
+
)
|
119 |
+
merged_dataset = ConcatDataset([ dataset[s] for s in dataset ])
|
120 |
+
dataloader = DataLoader(
|
121 |
+
merged_dataset,
|
122 |
+
batch_size = batch_size,
|
123 |
+
num_workers = num_workers,
|
124 |
+
persistent_workers = num_workers > 0,
|
125 |
+
drop_last = True,
|
126 |
+
shuffle = shuffle,
|
127 |
+
worker_init_fn = worker_init_fn,
|
128 |
+
collate_fn = partial(collate_fn,
|
129 |
+
noise_scheduler = noise_scheduler,
|
130 |
+
num_frames = num_frames,
|
131 |
+
hint_spacing = hint_spacing,
|
132 |
+
as_numpy = as_numpy
|
133 |
+
),
|
134 |
+
pin_memory = pin_memory,
|
135 |
+
pin_memory_device = pin_memory_device
|
136 |
+
)
|
137 |
+
return dataloader
|
138 |
+
|
139 |
+
|
140 |
+
def validate_dataset(
|
141 |
+
dataset_path: str
|
142 |
+
) -> List[int]:
|
143 |
+
import os
|
144 |
+
import json
|
145 |
+
data_path = os.path.join(dataset_path, 'data')
|
146 |
+
meta = set(os.path.splitext(x)[0] for x in os.listdir(os.path.join(data_path, 'metadata')))
|
147 |
+
prompts = set(os.path.splitext(x)[0] for x in os.listdir(os.path.join(data_path, 'prompts')))
|
148 |
+
videos = set(os.path.splitext(x)[0] for x in os.listdir(os.path.join(data_path, 'videos')))
|
149 |
+
ok = meta.intersection(prompts).intersection(videos)
|
150 |
+
all_of_em = meta.union(prompts).union(videos)
|
151 |
+
not_ok = []
|
152 |
+
for a in all_of_em:
|
153 |
+
if a not in ok:
|
154 |
+
not_ok.append(a)
|
155 |
+
ok = list(ok)
|
156 |
+
ok.sort()
|
157 |
+
with open(os.path.join(data_path, 'id_list.json'), 'w') as f:
|
158 |
+
json.dump(ok, f)
|
159 |
+
|
makeavid_sd/makeavid_sd/flax_impl/flax_attention_pseudo3d.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import jax
|
5 |
+
import jax.numpy as jnp
|
6 |
+
import flax.linen as nn
|
7 |
+
|
8 |
+
import einops
|
9 |
+
|
10 |
+
#from flax_memory_efficient_attention import jax_memory_efficient_attention
|
11 |
+
#from flax_attention import FlaxAttention
|
12 |
+
from diffusers.models.attention_flax import FlaxAttention
|
13 |
+
|
14 |
+
|
15 |
+
class TransformerPseudo3DModel(nn.Module):
|
16 |
+
in_channels: int
|
17 |
+
num_attention_heads: int
|
18 |
+
attention_head_dim: int
|
19 |
+
num_layers: int = 1
|
20 |
+
use_memory_efficient_attention: bool = False
|
21 |
+
dtype: jnp.dtype = jnp.float32
|
22 |
+
|
23 |
+
def setup(self) -> None:
|
24 |
+
inner_dim = self.num_attention_heads * self.attention_head_dim
|
25 |
+
self.norm = nn.GroupNorm(
|
26 |
+
num_groups = 32,
|
27 |
+
epsilon = 1e-5
|
28 |
+
)
|
29 |
+
self.proj_in = nn.Conv(
|
30 |
+
inner_dim,
|
31 |
+
kernel_size = (1, 1),
|
32 |
+
strides = (1, 1),
|
33 |
+
padding = 'VALID',
|
34 |
+
dtype = self.dtype
|
35 |
+
)
|
36 |
+
transformer_blocks = []
|
37 |
+
#CheckpointTransformerBlock = nn.checkpoint(
|
38 |
+
# BasicTransformerBlockPseudo3D,
|
39 |
+
# static_argnums = (2,3,4)
|
40 |
+
# #prevent_cse = False
|
41 |
+
#)
|
42 |
+
CheckpointTransformerBlock = BasicTransformerBlockPseudo3D
|
43 |
+
for _ in range(self.num_layers):
|
44 |
+
transformer_blocks.append(CheckpointTransformerBlock(
|
45 |
+
dim = inner_dim,
|
46 |
+
num_attention_heads = self.num_attention_heads,
|
47 |
+
attention_head_dim = self.attention_head_dim,
|
48 |
+
use_memory_efficient_attention = self.use_memory_efficient_attention,
|
49 |
+
dtype = self.dtype
|
50 |
+
))
|
51 |
+
self.transformer_blocks = transformer_blocks
|
52 |
+
self.proj_out = nn.Conv(
|
53 |
+
inner_dim,
|
54 |
+
kernel_size = (1, 1),
|
55 |
+
strides = (1, 1),
|
56 |
+
padding = 'VALID',
|
57 |
+
dtype = self.dtype
|
58 |
+
)
|
59 |
+
|
60 |
+
def __call__(self,
|
61 |
+
hidden_states: jax.Array,
|
62 |
+
encoder_hidden_states: Optional[jax.Array] = None
|
63 |
+
) -> jax.Array:
|
64 |
+
is_video = hidden_states.ndim == 5
|
65 |
+
f: Optional[int] = None
|
66 |
+
if is_video:
|
67 |
+
# jax is channels last
|
68 |
+
# b,c,f,h,w WRONG
|
69 |
+
# b,f,h,w,c CORRECT
|
70 |
+
# b, c, f, h, w = hidden_states.shape
|
71 |
+
#hidden_states = einops.rearrange(hidden_states, 'b c f h w -> (b f) c h w')
|
72 |
+
b, f, h, w, c = hidden_states.shape
|
73 |
+
hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
|
74 |
+
|
75 |
+
batch, height, width, channels = hidden_states.shape
|
76 |
+
residual = hidden_states
|
77 |
+
hidden_states = self.norm(hidden_states)
|
78 |
+
hidden_states = self.proj_in(hidden_states)
|
79 |
+
hidden_states = hidden_states.reshape(batch, height * width, channels)
|
80 |
+
for block in self.transformer_blocks:
|
81 |
+
hidden_states = block(
|
82 |
+
hidden_states,
|
83 |
+
encoder_hidden_states,
|
84 |
+
f,
|
85 |
+
height,
|
86 |
+
width
|
87 |
+
)
|
88 |
+
hidden_states = hidden_states.reshape(batch, height, width, channels)
|
89 |
+
hidden_states = self.proj_out(hidden_states)
|
90 |
+
hidden_states = hidden_states + residual
|
91 |
+
if is_video:
|
92 |
+
hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
|
93 |
+
return hidden_states
|
94 |
+
|
95 |
+
|
96 |
+
class BasicTransformerBlockPseudo3D(nn.Module):
|
97 |
+
dim: int
|
98 |
+
num_attention_heads: int
|
99 |
+
attention_head_dim: int
|
100 |
+
use_memory_efficient_attention: bool = False
|
101 |
+
dtype: jnp.dtype = jnp.float32
|
102 |
+
|
103 |
+
def setup(self) -> None:
|
104 |
+
self.attn1 = FlaxAttention(
|
105 |
+
query_dim = self.dim,
|
106 |
+
heads = self.num_attention_heads,
|
107 |
+
dim_head = self.attention_head_dim,
|
108 |
+
use_memory_efficient_attention = self.use_memory_efficient_attention,
|
109 |
+
dtype = self.dtype
|
110 |
+
)
|
111 |
+
self.ff = FeedForward(dim = self.dim, dtype = self.dtype)
|
112 |
+
self.attn2 = FlaxAttention(
|
113 |
+
query_dim = self.dim,
|
114 |
+
heads = self.num_attention_heads,
|
115 |
+
dim_head = self.attention_head_dim,
|
116 |
+
use_memory_efficient_attention = self.use_memory_efficient_attention,
|
117 |
+
dtype = self.dtype
|
118 |
+
)
|
119 |
+
self.attn_temporal = FlaxAttention(
|
120 |
+
query_dim = self.dim,
|
121 |
+
heads = self.num_attention_heads,
|
122 |
+
dim_head = self.attention_head_dim,
|
123 |
+
use_memory_efficient_attention = self.use_memory_efficient_attention,
|
124 |
+
dtype = self.dtype
|
125 |
+
)
|
126 |
+
self.norm1 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
|
127 |
+
self.norm2 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
|
128 |
+
self.norm_temporal = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
|
129 |
+
self.norm3 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
|
130 |
+
|
131 |
+
def __call__(self,
|
132 |
+
hidden_states: jax.Array,
|
133 |
+
context: Optional[jax.Array] = None,
|
134 |
+
frames_length: Optional[int] = None,
|
135 |
+
height: Optional[int] = None,
|
136 |
+
width: Optional[int] = None
|
137 |
+
) -> jax.Array:
|
138 |
+
if context is not None and frames_length is not None:
|
139 |
+
context = context.repeat(frames_length, axis = 0)
|
140 |
+
# self attention
|
141 |
+
norm_hidden_states = self.norm1(hidden_states)
|
142 |
+
hidden_states = self.attn1(norm_hidden_states) + hidden_states
|
143 |
+
# cross attention
|
144 |
+
norm_hidden_states = self.norm2(hidden_states)
|
145 |
+
hidden_states = self.attn2(
|
146 |
+
norm_hidden_states,
|
147 |
+
context = context
|
148 |
+
) + hidden_states
|
149 |
+
# temporal attention
|
150 |
+
if frames_length is not None:
|
151 |
+
#bf, hw, c = hidden_states.shape
|
152 |
+
# (b f) (h w) c -> b f (h w) c
|
153 |
+
#hidden_states = hidden_states.reshape(bf // frames_length, frames_length, hw, c)
|
154 |
+
#b, f, hw, c = hidden_states.shape
|
155 |
+
# b f (h w) c -> b (h w) f c
|
156 |
+
#hidden_states = hidden_states.transpose(0, 2, 1, 3)
|
157 |
+
# b (h w) f c -> (b h w) f c
|
158 |
+
#hidden_states = hidden_states.reshape(b * hw, frames_length, c)
|
159 |
+
hidden_states = einops.rearrange(
|
160 |
+
hidden_states,
|
161 |
+
'(b f) (h w) c -> (b h w) f c',
|
162 |
+
f = frames_length,
|
163 |
+
h = height,
|
164 |
+
w = width
|
165 |
+
)
|
166 |
+
norm_hidden_states = self.norm_temporal(hidden_states)
|
167 |
+
hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states
|
168 |
+
# (b h w) f c -> b (h w) f c
|
169 |
+
#hidden_states = hidden_states.reshape(b, hw, f, c)
|
170 |
+
# b (h w) f c -> b f (h w) c
|
171 |
+
#hidden_states = hidden_states.transpose(0, 2, 1, 3)
|
172 |
+
# b f h w c -> (b f) (h w) c
|
173 |
+
#hidden_states = hidden_states.reshape(bf, hw, c)
|
174 |
+
hidden_states = einops.rearrange(
|
175 |
+
hidden_states,
|
176 |
+
'(b h w) f c -> (b f) (h w) c',
|
177 |
+
f = frames_length,
|
178 |
+
h = height,
|
179 |
+
w = width
|
180 |
+
)
|
181 |
+
norm_hidden_states = self.norm3(hidden_states)
|
182 |
+
hidden_states = self.ff(norm_hidden_states) + hidden_states
|
183 |
+
return hidden_states
|
184 |
+
|
185 |
+
|
186 |
+
class FeedForward(nn.Module):
|
187 |
+
dim: int
|
188 |
+
dtype: jnp.dtype = jnp.float32
|
189 |
+
|
190 |
+
def setup(self) -> None:
|
191 |
+
self.net_0 = GEGLU(self.dim, self.dtype)
|
192 |
+
self.net_2 = nn.Dense(self.dim, dtype = self.dtype)
|
193 |
+
|
194 |
+
def __call__(self, hidden_states: jax.Array) -> jax.Array:
|
195 |
+
hidden_states = self.net_0(hidden_states)
|
196 |
+
hidden_states = self.net_2(hidden_states)
|
197 |
+
return hidden_states
|
198 |
+
|
199 |
+
|
200 |
+
class GEGLU(nn.Module):
|
201 |
+
dim: int
|
202 |
+
dtype: jnp.dtype = jnp.float32
|
203 |
+
|
204 |
+
def setup(self) -> None:
|
205 |
+
inner_dim = self.dim * 4
|
206 |
+
self.proj = nn.Dense(inner_dim * 2, dtype = self.dtype)
|
207 |
+
|
208 |
+
def __call__(self, hidden_states: jax.Array) -> jax.Array:
|
209 |
+
hidden_states = self.proj(hidden_states)
|
210 |
+
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis = 2)
|
211 |
+
return hidden_linear * nn.gelu(hidden_gelu)
|
212 |
+
|
makeavid_sd/makeavid_sd/flax_impl/flax_embeddings.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import jax
|
3 |
+
import jax.numpy as jnp
|
4 |
+
import flax.linen as nn
|
5 |
+
|
6 |
+
|
7 |
+
def get_sinusoidal_embeddings(
|
8 |
+
timesteps: jax.Array,
|
9 |
+
embedding_dim: int,
|
10 |
+
freq_shift: float = 1,
|
11 |
+
min_timescale: float = 1,
|
12 |
+
max_timescale: float = 1.0e4,
|
13 |
+
flip_sin_to_cos: bool = False,
|
14 |
+
scale: float = 1.0,
|
15 |
+
dtype: jnp.dtype = jnp.float32
|
16 |
+
) -> jax.Array:
|
17 |
+
assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
|
18 |
+
assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
|
19 |
+
num_timescales = float(embedding_dim // 2)
|
20 |
+
log_timescale_increment = jnp.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
|
21 |
+
inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype = dtype) * -log_timescale_increment)
|
22 |
+
emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
|
23 |
+
|
24 |
+
# scale embeddings
|
25 |
+
scaled_time = scale * emb
|
26 |
+
|
27 |
+
if flip_sin_to_cos:
|
28 |
+
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis = 1)
|
29 |
+
else:
|
30 |
+
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis = 1)
|
31 |
+
signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
|
32 |
+
return signal
|
33 |
+
|
34 |
+
|
35 |
+
class TimestepEmbedding(nn.Module):
|
36 |
+
time_embed_dim: int = 32
|
37 |
+
dtype: jnp.dtype = jnp.float32
|
38 |
+
|
39 |
+
@nn.compact
|
40 |
+
def __call__(self, temb: jax.Array) -> jax.Array:
|
41 |
+
temb = nn.Dense(self.time_embed_dim, dtype = self.dtype, name = "linear_1")(temb)
|
42 |
+
temb = nn.silu(temb)
|
43 |
+
temb = nn.Dense(self.time_embed_dim, dtype = self.dtype, name = "linear_2")(temb)
|
44 |
+
return temb
|
45 |
+
|
46 |
+
|
47 |
+
class Timesteps(nn.Module):
|
48 |
+
dim: int = 32
|
49 |
+
flip_sin_to_cos: bool = False
|
50 |
+
freq_shift: float = 1
|
51 |
+
dtype: jnp.dtype = jnp.float32
|
52 |
+
|
53 |
+
@nn.compact
|
54 |
+
def __call__(self, timesteps: jax.Array) -> jax.Array:
|
55 |
+
return get_sinusoidal_embeddings(
|
56 |
+
timesteps = timesteps,
|
57 |
+
embedding_dim = self.dim,
|
58 |
+
flip_sin_to_cos = self.flip_sin_to_cos,
|
59 |
+
freq_shift = self.freq_shift,
|
60 |
+
dtype = self.dtype
|
61 |
+
)
|
62 |
+
|
makeavid_sd/makeavid_sd/flax_impl/flax_resnet_pseudo3d.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import Optional, Union, Sequence
|
3 |
+
|
4 |
+
import jax
|
5 |
+
import jax.numpy as jnp
|
6 |
+
import flax.linen as nn
|
7 |
+
|
8 |
+
import einops
|
9 |
+
|
10 |
+
|
11 |
+
class ConvPseudo3D(nn.Module):
|
12 |
+
features: int
|
13 |
+
kernel_size: Sequence[int]
|
14 |
+
strides: Union[None, int, Sequence[int]] = 1
|
15 |
+
padding: nn.linear.PaddingLike = 'SAME'
|
16 |
+
dtype: jnp.dtype = jnp.float32
|
17 |
+
|
18 |
+
def setup(self) -> None:
|
19 |
+
self.spatial_conv = nn.Conv(
|
20 |
+
features = self.features,
|
21 |
+
kernel_size = self.kernel_size,
|
22 |
+
strides = self.strides,
|
23 |
+
padding = self.padding,
|
24 |
+
dtype = self.dtype
|
25 |
+
)
|
26 |
+
self.temporal_conv = nn.Conv(
|
27 |
+
features = self.features,
|
28 |
+
kernel_size = (3,),
|
29 |
+
padding = 'SAME',
|
30 |
+
dtype = self.dtype,
|
31 |
+
bias_init = nn.initializers.zeros_init()
|
32 |
+
# TODO dirac delta (identity) initialization impl
|
33 |
+
# kernel_init = torch.nn.init.dirac_ <-> jax/lax
|
34 |
+
)
|
35 |
+
|
36 |
+
def __call__(self, x: jax.Array, convolve_across_time: bool = True) -> jax.Array:
|
37 |
+
is_video = x.ndim == 5
|
38 |
+
convolve_across_time = convolve_across_time and is_video
|
39 |
+
if is_video:
|
40 |
+
b, f, h, w, c = x.shape
|
41 |
+
x = einops.rearrange(x, 'b f h w c -> (b f) h w c')
|
42 |
+
x = self.spatial_conv(x)
|
43 |
+
if is_video:
|
44 |
+
x = einops.rearrange(x, '(b f) h w c -> b f h w c', b = b)
|
45 |
+
b, f, h, w, c = x.shape
|
46 |
+
if not convolve_across_time:
|
47 |
+
return x
|
48 |
+
if is_video:
|
49 |
+
x = einops.rearrange(x, 'b f h w c -> (b h w) f c')
|
50 |
+
x = self.temporal_conv(x)
|
51 |
+
x = einops.rearrange(x, '(b h w) f c -> b f h w c', h = h, w = w)
|
52 |
+
return x
|
53 |
+
|
54 |
+
|
55 |
+
class UpsamplePseudo3D(nn.Module):
|
56 |
+
out_channels: int
|
57 |
+
dtype: jnp.dtype = jnp.float32
|
58 |
+
|
59 |
+
def setup(self) -> None:
|
60 |
+
self.conv = ConvPseudo3D(
|
61 |
+
features = self.out_channels,
|
62 |
+
kernel_size = (3, 3),
|
63 |
+
strides = (1, 1),
|
64 |
+
padding = ((1, 1), (1, 1)),
|
65 |
+
dtype = self.dtype
|
66 |
+
)
|
67 |
+
|
68 |
+
def __call__(self, hidden_states: jax.Array) -> jax.Array:
|
69 |
+
is_video = hidden_states.ndim == 5
|
70 |
+
if is_video:
|
71 |
+
b, *_ = hidden_states.shape
|
72 |
+
hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
|
73 |
+
batch, h, w, c = hidden_states.shape
|
74 |
+
hidden_states = jax.image.resize(
|
75 |
+
image = hidden_states,
|
76 |
+
shape = (batch, h * 2, w * 2, c),
|
77 |
+
method = 'nearest'
|
78 |
+
)
|
79 |
+
if is_video:
|
80 |
+
hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
|
81 |
+
hidden_states = self.conv(hidden_states)
|
82 |
+
return hidden_states
|
83 |
+
|
84 |
+
|
85 |
+
class DownsamplePseudo3D(nn.Module):
|
86 |
+
out_channels: int
|
87 |
+
dtype: jnp.dtype = jnp.float32
|
88 |
+
|
89 |
+
def setup(self) -> None:
|
90 |
+
self.conv = ConvPseudo3D(
|
91 |
+
features = self.out_channels,
|
92 |
+
kernel_size = (3, 3),
|
93 |
+
strides = (2, 2),
|
94 |
+
padding = ((1, 1), (1, 1)),
|
95 |
+
dtype = self.dtype
|
96 |
+
)
|
97 |
+
|
98 |
+
def __call__(self, hidden_states: jax.Array) -> jax.Array:
|
99 |
+
hidden_states = self.conv(hidden_states)
|
100 |
+
return hidden_states
|
101 |
+
|
102 |
+
|
103 |
+
class ResnetBlockPseudo3D(nn.Module):
|
104 |
+
in_channels: int
|
105 |
+
out_channels: Optional[int] = None
|
106 |
+
use_nin_shortcut: Optional[bool] = None
|
107 |
+
dtype: jnp.dtype = jnp.float32
|
108 |
+
|
109 |
+
def setup(self) -> None:
|
110 |
+
out_channels = self.in_channels if self.out_channels is None else self.out_channels
|
111 |
+
self.norm1 = nn.GroupNorm(
|
112 |
+
num_groups = 32,
|
113 |
+
epsilon = 1e-5
|
114 |
+
)
|
115 |
+
self.conv1 = ConvPseudo3D(
|
116 |
+
features = out_channels,
|
117 |
+
kernel_size = (3, 3),
|
118 |
+
strides = (1, 1),
|
119 |
+
padding = ((1, 1), (1, 1)),
|
120 |
+
dtype = self.dtype
|
121 |
+
)
|
122 |
+
self.time_emb_proj = nn.Dense(
|
123 |
+
out_channels,
|
124 |
+
dtype = self.dtype
|
125 |
+
)
|
126 |
+
self.norm2 = nn.GroupNorm(
|
127 |
+
num_groups = 32,
|
128 |
+
epsilon = 1e-5
|
129 |
+
)
|
130 |
+
self.conv2 = ConvPseudo3D(
|
131 |
+
features = out_channels,
|
132 |
+
kernel_size = (3, 3),
|
133 |
+
strides = (1, 1),
|
134 |
+
padding = ((1, 1), (1, 1)),
|
135 |
+
dtype = self.dtype
|
136 |
+
)
|
137 |
+
use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
|
138 |
+
self.conv_shortcut = None
|
139 |
+
if use_nin_shortcut:
|
140 |
+
self.conv_shortcut = ConvPseudo3D(
|
141 |
+
features = self.out_channels,
|
142 |
+
kernel_size = (1, 1),
|
143 |
+
strides = (1, 1),
|
144 |
+
padding = 'VALID',
|
145 |
+
dtype = self.dtype
|
146 |
+
)
|
147 |
+
|
148 |
+
def __call__(self,
|
149 |
+
hidden_states: jax.Array,
|
150 |
+
temb: jax.Array
|
151 |
+
) -> jax.Array:
|
152 |
+
is_video = hidden_states.ndim == 5
|
153 |
+
residual = hidden_states
|
154 |
+
hidden_states = self.norm1(hidden_states)
|
155 |
+
hidden_states = nn.silu(hidden_states)
|
156 |
+
hidden_states = self.conv1(hidden_states)
|
157 |
+
temb = nn.silu(temb)
|
158 |
+
temb = self.time_emb_proj(temb)
|
159 |
+
temb = jnp.expand_dims(temb, 1)
|
160 |
+
temb = jnp.expand_dims(temb, 1)
|
161 |
+
if is_video:
|
162 |
+
b, f, *_ = hidden_states.shape
|
163 |
+
hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
|
164 |
+
hidden_states = hidden_states + temb.repeat(f, 0)
|
165 |
+
hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
|
166 |
+
else:
|
167 |
+
hidden_states = hidden_states + temb
|
168 |
+
hidden_states = self.norm2(hidden_states)
|
169 |
+
hidden_states = nn.silu(hidden_states)
|
170 |
+
hidden_states = self.conv2(hidden_states)
|
171 |
+
if self.conv_shortcut is not None:
|
172 |
+
residual = self.conv_shortcut(residual)
|
173 |
+
hidden_states = hidden_states + residual
|
174 |
+
return hidden_states
|
175 |
+
|
makeavid_sd/makeavid_sd/flax_impl/flax_trainer.py
ADDED
@@ -0,0 +1,608 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import Any, Optional, Union, Tuple, Dict, List
|
3 |
+
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import math
|
7 |
+
import time
|
8 |
+
import numpy as np
|
9 |
+
from tqdm.auto import tqdm, trange
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch.utils.data import DataLoader
|
13 |
+
|
14 |
+
import jax
|
15 |
+
import jax.numpy as jnp
|
16 |
+
import optax
|
17 |
+
from flax import jax_utils, traverse_util
|
18 |
+
from flax.core.frozen_dict import FrozenDict
|
19 |
+
from flax.training.train_state import TrainState
|
20 |
+
from flax.training.common_utils import shard
|
21 |
+
|
22 |
+
# convert 2D -> 3D
|
23 |
+
from diffusers import FlaxUNet2DConditionModel
|
24 |
+
|
25 |
+
# inference test, run on these on cpu
|
26 |
+
from diffusers import AutoencoderKL
|
27 |
+
from diffusers.schedulers.scheduling_ddim_flax import FlaxDDIMScheduler, DDIMSchedulerState
|
28 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
29 |
+
from PIL import Image
|
30 |
+
|
31 |
+
|
32 |
+
from .flax_unet_pseudo3d_condition import UNetPseudo3DConditionModel
|
33 |
+
|
34 |
+
|
35 |
+
def seed_all(seed: int) -> jax.random.PRNGKeyArray:
|
36 |
+
random.seed(seed)
|
37 |
+
np.random.seed(seed)
|
38 |
+
torch.manual_seed(seed)
|
39 |
+
rng = jax.random.PRNGKey(seed)
|
40 |
+
return rng
|
41 |
+
|
42 |
+
def count_params(
|
43 |
+
params: Union[Dict[str, Any],
|
44 |
+
FrozenDict[str, Any]],
|
45 |
+
filter_name: Optional[str] = None
|
46 |
+
) -> int:
|
47 |
+
p: Dict[Tuple[str], jax.Array] = traverse_util.flatten_dict(params)
|
48 |
+
cc = 0
|
49 |
+
for k in p:
|
50 |
+
if filter_name is not None:
|
51 |
+
if filter_name in ' '.join(k):
|
52 |
+
cc += len(p[k].flatten())
|
53 |
+
else:
|
54 |
+
cc += len(p[k].flatten())
|
55 |
+
return cc
|
56 |
+
|
57 |
+
def map_2d_to_pseudo3d(
|
58 |
+
params2d: Dict[str, Any],
|
59 |
+
params3d: Dict[str, Any],
|
60 |
+
verbose: bool = True
|
61 |
+
) -> Dict[str, Any]:
|
62 |
+
params2d = traverse_util.flatten_dict(params2d)
|
63 |
+
params3d = traverse_util.flatten_dict(params3d)
|
64 |
+
new_params = dict()
|
65 |
+
for k in params3d:
|
66 |
+
if 'spatial_conv' in k:
|
67 |
+
k2d = list(k)
|
68 |
+
k2d.remove('spatial_conv')
|
69 |
+
k2d = tuple(k2d)
|
70 |
+
if verbose:
|
71 |
+
tqdm.write(f'Spatial: {k} <- {k2d}')
|
72 |
+
p = params2d[k2d]
|
73 |
+
elif k not in params2d:
|
74 |
+
if verbose:
|
75 |
+
tqdm.write(f'Missing: {k}')
|
76 |
+
p = params3d[k]
|
77 |
+
else:
|
78 |
+
p = params2d[k]
|
79 |
+
assert p.shape == params3d[k].shape, f'shape mismatch: {k}: {p.shape} != {params3d[k].shape}'
|
80 |
+
new_params[k] = p
|
81 |
+
new_params = traverse_util.unflatten_dict(new_params)
|
82 |
+
return new_params
|
83 |
+
|
84 |
+
|
85 |
+
class FlaxTrainerUNetPseudo3D:
|
86 |
+
def __init__(self,
|
87 |
+
model_path: str,
|
88 |
+
from_pt: bool = True,
|
89 |
+
convert2d: bool = False,
|
90 |
+
sample_size: Tuple[int, int] = (64, 64),
|
91 |
+
seed: int = 0,
|
92 |
+
dtype: str = 'float32',
|
93 |
+
param_dtype: str = 'float32',
|
94 |
+
only_temporal: bool = True,
|
95 |
+
use_memory_efficient_attention = False,
|
96 |
+
verbose: bool = True
|
97 |
+
) -> None:
|
98 |
+
self.verbose = verbose
|
99 |
+
self.tracker: Optional['wandb.sdk.wandb_run.Run'] = None
|
100 |
+
self._use_wandb: bool = False
|
101 |
+
self._tracker_meta: Dict[str, Union[float, int]] = {
|
102 |
+
't00': 0.0,
|
103 |
+
't0': 0.0,
|
104 |
+
'step0': 0
|
105 |
+
}
|
106 |
+
|
107 |
+
self.log('Init JAX')
|
108 |
+
self.num_devices = jax.device_count()
|
109 |
+
self.log(f'Device count: {self.num_devices}')
|
110 |
+
|
111 |
+
self.seed = seed
|
112 |
+
self.rng: jax.random.PRNGKeyArray = seed_all(self.seed)
|
113 |
+
|
114 |
+
self.sample_size = sample_size
|
115 |
+
if dtype == 'float32':
|
116 |
+
self.dtype = jnp.float32
|
117 |
+
elif dtype == 'bfloat16':
|
118 |
+
self.dtype = jnp.bfloat16
|
119 |
+
elif dtype == 'float16':
|
120 |
+
self.dtype = jnp.float16
|
121 |
+
else:
|
122 |
+
raise ValueError(f'unknown type: {dtype}')
|
123 |
+
self.dtype_str: str = dtype
|
124 |
+
if param_dtype not in ['float32', 'bfloat16', 'float16']:
|
125 |
+
raise ValueError(f'unknown parameter type: {param_dtype}')
|
126 |
+
self.param_dtype = param_dtype
|
127 |
+
self._load_models(
|
128 |
+
model_path = model_path,
|
129 |
+
convert2d = convert2d,
|
130 |
+
from_pt = from_pt,
|
131 |
+
use_memory_efficient_attention = use_memory_efficient_attention
|
132 |
+
)
|
133 |
+
self._mark_parameters(only_temporal = only_temporal)
|
134 |
+
# optionally for validation + sampling
|
135 |
+
self.tokenizer: Optional[CLIPTokenizer] = None
|
136 |
+
self.text_encoder: Optional[CLIPTextModel] = None
|
137 |
+
self.vae: Optional[AutoencoderKL] = None
|
138 |
+
self.ddim: Optional[Tuple[FlaxDDIMScheduler, DDIMSchedulerState]] = None
|
139 |
+
|
140 |
+
def log(self, message: Any) -> None:
|
141 |
+
if self.verbose and jax.process_index() == 0:
|
142 |
+
tqdm.write(str(message))
|
143 |
+
|
144 |
+
def log_metrics(self, metrics: dict, step: int, epoch: int) -> None:
|
145 |
+
if jax.process_index() > 0 or (not self.verbose and self.tracker is None):
|
146 |
+
return
|
147 |
+
now = time.monotonic()
|
148 |
+
log_data = {
|
149 |
+
'train/step': step,
|
150 |
+
'train/epoch': epoch,
|
151 |
+
'train/steps_per_sec': (step - self._tracker_meta['step0']) / (now - self._tracker_meta['t0']),
|
152 |
+
**{ f'train/{k}': v for k, v in metrics.items() }
|
153 |
+
}
|
154 |
+
self._tracker_meta['t0'] = now
|
155 |
+
self._tracker_meta['step0'] = step
|
156 |
+
self.log(log_data)
|
157 |
+
if self.tracker is not None:
|
158 |
+
self.tracker.log(log_data, step = step)
|
159 |
+
|
160 |
+
|
161 |
+
def enable_wandb(self, enable: bool = True) -> None:
|
162 |
+
self._use_wandb = enable
|
163 |
+
|
164 |
+
def _setup_wandb(self, config: Dict[str, Any] = dict()) -> None:
|
165 |
+
import wandb
|
166 |
+
import wandb.sdk
|
167 |
+
self.tracker: wandb.sdk.wandb_run.Run = wandb.init(
|
168 |
+
config = config,
|
169 |
+
settings = wandb.sdk.Settings(
|
170 |
+
username = 'anon',
|
171 |
+
host = 'anon',
|
172 |
+
email = 'anon',
|
173 |
+
root_dir = 'anon',
|
174 |
+
_executable = 'anon',
|
175 |
+
_disable_stats = True,
|
176 |
+
_disable_meta = True,
|
177 |
+
disable_code = True,
|
178 |
+
disable_git = True
|
179 |
+
) # pls don't log sensitive data like system user names. also, fuck you for even trying.
|
180 |
+
)
|
181 |
+
|
182 |
+
def _init_tracker_meta(self) -> None:
|
183 |
+
now = time.monotonic()
|
184 |
+
self._tracker_meta = {
|
185 |
+
't00': now,
|
186 |
+
't0': now,
|
187 |
+
'step0': 0
|
188 |
+
}
|
189 |
+
|
190 |
+
def _load_models(self,
|
191 |
+
model_path: str,
|
192 |
+
convert2d: bool,
|
193 |
+
from_pt: bool,
|
194 |
+
use_memory_efficient_attention: bool
|
195 |
+
) -> None:
|
196 |
+
self.log(f'Load pretrained from {model_path}')
|
197 |
+
if convert2d:
|
198 |
+
self.log(' Convert 2D model to Pseudo3D')
|
199 |
+
self.log(' Initiate Pseudo3D model')
|
200 |
+
config = UNetPseudo3DConditionModel.load_config(model_path, subfolder = 'unet')
|
201 |
+
model = UNetPseudo3DConditionModel.from_config(
|
202 |
+
config,
|
203 |
+
sample_size = self.sample_size,
|
204 |
+
dtype = self.dtype,
|
205 |
+
param_dtype = self.param_dtype,
|
206 |
+
use_memory_efficient_attention = use_memory_efficient_attention
|
207 |
+
)
|
208 |
+
params: Dict[str, Any] = model.init_weights(self.rng).unfreeze()
|
209 |
+
self.log(' Load 2D model')
|
210 |
+
model2d, params2d = FlaxUNet2DConditionModel.from_pretrained(
|
211 |
+
model_path,
|
212 |
+
subfolder = 'unet',
|
213 |
+
dtype = self.dtype,
|
214 |
+
from_pt = from_pt
|
215 |
+
)
|
216 |
+
self.log(' Map 2D -> 3D')
|
217 |
+
params = map_2d_to_pseudo3d(params2d, params, verbose = self.verbose)
|
218 |
+
del params2d
|
219 |
+
del model2d
|
220 |
+
del config
|
221 |
+
else:
|
222 |
+
model, params = UNetPseudo3DConditionModel.from_pretrained(
|
223 |
+
model_path,
|
224 |
+
subfolder = 'unet',
|
225 |
+
from_pt = from_pt,
|
226 |
+
sample_size = self.sample_size,
|
227 |
+
dtype = self.dtype,
|
228 |
+
param_dtype = self.param_dtype,
|
229 |
+
use_memory_efficient_attention = use_memory_efficient_attention
|
230 |
+
)
|
231 |
+
self.log(f'Cast parameters to {model.param_dtype}')
|
232 |
+
if model.param_dtype == 'float32':
|
233 |
+
params = model.to_fp32(params)
|
234 |
+
elif model.param_dtype == 'float16':
|
235 |
+
params = model.to_fp16(params)
|
236 |
+
elif model.param_dtype == 'bfloat16':
|
237 |
+
params = model.to_bf16(params)
|
238 |
+
self.pretrained_model = model_path
|
239 |
+
self.model: UNetPseudo3DConditionModel = model
|
240 |
+
self.params: FrozenDict[str, Any] = FrozenDict(params)
|
241 |
+
|
242 |
+
def _mark_parameters(self, only_temporal: bool) -> None:
|
243 |
+
self.log('Mark training parameters')
|
244 |
+
if only_temporal:
|
245 |
+
self.log('Only training temporal layers')
|
246 |
+
if only_temporal:
|
247 |
+
param_partitions = traverse_util.path_aware_map(
|
248 |
+
lambda path, _: 'trainable' if 'temporal' in ' '.join(path) else 'frozen', self.params
|
249 |
+
)
|
250 |
+
else:
|
251 |
+
param_partitions = traverse_util.path_aware_map(
|
252 |
+
lambda *_: 'trainable', self.params
|
253 |
+
)
|
254 |
+
self.only_temporal = only_temporal
|
255 |
+
self.param_partitions: FrozenDict[str, Any] = FrozenDict(param_partitions)
|
256 |
+
self.log(f'Total parameters: {count_params(self.params)}')
|
257 |
+
self.log(f'Temporal parameters: {count_params(self.params, "temporal")}')
|
258 |
+
|
259 |
+
def _load_inference_models(self) -> None:
|
260 |
+
assert jax.process_index() == 0, 'not main process'
|
261 |
+
if self.text_encoder is None:
|
262 |
+
self.log('Load text encoder')
|
263 |
+
self.text_encoder = CLIPTextModel.from_pretrained(
|
264 |
+
self.pretrained_model,
|
265 |
+
subfolder = 'text_encoder'
|
266 |
+
)
|
267 |
+
if self.tokenizer is None:
|
268 |
+
self.log('Load tokenizer')
|
269 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(
|
270 |
+
self.pretrained_model,
|
271 |
+
subfolder = 'tokenizer'
|
272 |
+
)
|
273 |
+
if self.vae is None:
|
274 |
+
self.log('Load vae')
|
275 |
+
self.vae = AutoencoderKL.from_pretrained(
|
276 |
+
self.pretrained_model,
|
277 |
+
subfolder = 'vae'
|
278 |
+
)
|
279 |
+
if self.ddim is None:
|
280 |
+
self.log('Load ddim scheduler')
|
281 |
+
# tuple(scheduler , scheduler state)
|
282 |
+
self.ddim = FlaxDDIMScheduler.from_pretrained(
|
283 |
+
self.pretrained_model,
|
284 |
+
subfolder = 'scheduler',
|
285 |
+
from_pt = True
|
286 |
+
)
|
287 |
+
|
288 |
+
def _unload_inference_models(self) -> None:
|
289 |
+
self.text_encoder = None
|
290 |
+
self.tokenizer = None
|
291 |
+
self.vae = None
|
292 |
+
self.ddim = None
|
293 |
+
|
294 |
+
def sample(self,
|
295 |
+
params: Union[Dict[str, Any], FrozenDict[str, Any]],
|
296 |
+
prompt: str,
|
297 |
+
image_path: str,
|
298 |
+
num_frames: int,
|
299 |
+
replicate_params: bool = True,
|
300 |
+
neg_prompt: str = '',
|
301 |
+
steps: int = 50,
|
302 |
+
cfg: float = 9.0,
|
303 |
+
unload_after_usage: bool = False
|
304 |
+
) -> List[Image.Image]:
|
305 |
+
assert jax.process_index() == 0, 'not main process'
|
306 |
+
self.log('Sample')
|
307 |
+
self._load_inference_models()
|
308 |
+
with torch.no_grad():
|
309 |
+
tokens = self.tokenizer(
|
310 |
+
[ prompt ],
|
311 |
+
truncation = True,
|
312 |
+
return_overflowing_tokens = False,
|
313 |
+
padding = 'max_length',
|
314 |
+
return_tensors = 'pt'
|
315 |
+
).input_ids
|
316 |
+
neg_tokens = self.tokenizer(
|
317 |
+
[ neg_prompt ],
|
318 |
+
truncation = True,
|
319 |
+
return_overflowing_tokens = False,
|
320 |
+
padding = 'max_length',
|
321 |
+
return_tensors = 'pt'
|
322 |
+
).input_ids
|
323 |
+
encoded_prompt = self.text_encoder(input_ids = tokens).last_hidden_state
|
324 |
+
encoded_neg_prompt = self.text_encoder(input_ids = neg_tokens).last_hidden_state
|
325 |
+
hint_latent = torch.tensor(np.asarray(Image.open(image_path))).permute(2,0,1).to(torch.float32).div(255).mul(2).sub(1).unsqueeze(0)
|
326 |
+
hint_latent = self.vae.encode(hint_latent).latent_dist.mean * self.vae.config.scaling_factor #0.18215 # deterministic
|
327 |
+
hint_latent = hint_latent.unsqueeze(2).repeat_interleave(num_frames, 2)
|
328 |
+
mask = torch.zeros_like(hint_latent[:,0:1,:,:,:]) # zero mask, e.g. skip masking for now
|
329 |
+
init_latent = torch.randn_like(hint_latent)
|
330 |
+
# move to devices
|
331 |
+
encoded_prompt = jnp.array(encoded_prompt.numpy())
|
332 |
+
encoded_neg_prompt = jnp.array(encoded_neg_prompt.numpy())
|
333 |
+
hint_latent = jnp.array(hint_latent.numpy())
|
334 |
+
mask = jnp.array(mask.numpy())
|
335 |
+
init_latent = init_latent.repeat(jax.device_count(), 1, 1, 1, 1)
|
336 |
+
init_latent = jnp.array(init_latent.numpy())
|
337 |
+
self.ddim = (self.ddim[0], self.ddim[0].set_timesteps(self.ddim[1], steps))
|
338 |
+
timesteps = self.ddim[1].timesteps
|
339 |
+
if replicate_params:
|
340 |
+
params = jax_utils.replicate(params)
|
341 |
+
ddim_state = jax_utils.replicate(self.ddim[1])
|
342 |
+
encoded_prompt = jax_utils.replicate(encoded_prompt)
|
343 |
+
encoded_neg_prompt = jax_utils.replicate(encoded_neg_prompt)
|
344 |
+
hint_latent = jax_utils.replicate(hint_latent)
|
345 |
+
mask = jax_utils.replicate(mask)
|
346 |
+
# sampling fun
|
347 |
+
def sample_loop(init_latent, ddim_state, t, params, encoded_prompt, encoded_neg_prompt, hint_latent, mask):
|
348 |
+
latent_model_input = jnp.concatenate([init_latent, mask, hint_latent], axis = 1)
|
349 |
+
pred = self.model.apply(
|
350 |
+
{ 'params': params },
|
351 |
+
latent_model_input,
|
352 |
+
t,
|
353 |
+
encoded_prompt
|
354 |
+
).sample
|
355 |
+
if cfg != 1.0:
|
356 |
+
neg_pred = self.model.apply(
|
357 |
+
{ 'params': params },
|
358 |
+
latent_model_input,
|
359 |
+
t,
|
360 |
+
encoded_neg_prompt
|
361 |
+
).sample
|
362 |
+
pred = neg_pred + cfg * (pred - neg_pred)
|
363 |
+
# TODO check if noise is added at the right dimension
|
364 |
+
init_latent, ddim_state = self.ddim[0].step(ddim_state, pred, t, init_latent).to_tuple()
|
365 |
+
return init_latent, ddim_state
|
366 |
+
p_sample_loop = jax.pmap(sample_loop, 'sample', donate_argnums = ())
|
367 |
+
pbar_sample = trange(len(timesteps), desc = 'Sample', dynamic_ncols = True, smoothing = 0.1, disable = not self.verbose)
|
368 |
+
init_latent = shard(init_latent)
|
369 |
+
for i in pbar_sample:
|
370 |
+
t = timesteps[i].repeat(self.num_devices)
|
371 |
+
t = shard(t)
|
372 |
+
init_latent, ddim_state = p_sample_loop(init_latent, ddim_state, t, params, encoded_prompt, encoded_neg_prompt, hint_latent, mask)
|
373 |
+
# decode
|
374 |
+
self.log('Decode')
|
375 |
+
init_latent = torch.tensor(np.array(init_latent))
|
376 |
+
init_latent = init_latent / self.vae.config.scaling_factor
|
377 |
+
# d:0 b:1 c:2 f:3 h:4 w:5 -> d b f c h w
|
378 |
+
init_latent = init_latent.permute(0, 1, 3, 2, 4, 5)
|
379 |
+
images = []
|
380 |
+
pbar_decode = trange(len(init_latent), desc = 'Decode', dynamic_ncols = True)
|
381 |
+
for sample in init_latent:
|
382 |
+
ims = self.vae.decode(sample.squeeze()).sample
|
383 |
+
ims = ims.add(1).div(2).mul(255).round().clamp(0, 255).to(torch.uint8).permute(0,2,3,1).numpy()
|
384 |
+
ims = [ Image.fromarray(x) for x in ims ]
|
385 |
+
for im in ims:
|
386 |
+
images.append(im)
|
387 |
+
pbar_decode.update(1)
|
388 |
+
if unload_after_usage:
|
389 |
+
self._unload_inference_models()
|
390 |
+
return images
|
391 |
+
|
392 |
+
def get_params_from_state(self, state: TrainState) -> FrozenDict[Any, str]:
|
393 |
+
return FrozenDict(jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)))
|
394 |
+
|
395 |
+
def train(self,
|
396 |
+
dataloader: DataLoader,
|
397 |
+
lr: float,
|
398 |
+
num_frames: int,
|
399 |
+
log_every_step: int = 10,
|
400 |
+
save_every_epoch: int = 1,
|
401 |
+
sample_every_epoch: int = 1,
|
402 |
+
output_dir: str = 'output',
|
403 |
+
warmup: float = 0,
|
404 |
+
decay: float = 0,
|
405 |
+
epochs: int = 10,
|
406 |
+
weight_decay: float = 1e-2
|
407 |
+
) -> None:
|
408 |
+
eps = 1e-8
|
409 |
+
total_steps = len(dataloader) * epochs
|
410 |
+
warmup_steps = math.ceil(warmup * total_steps) if warmup > 0 else 0
|
411 |
+
decay_steps = math.ceil(decay * total_steps) + warmup_steps if decay > 0 else warmup_steps + 1
|
412 |
+
self.log(f'Total steps: {total_steps}')
|
413 |
+
self.log(f'Warmup steps: {warmup_steps}')
|
414 |
+
self.log(f'Decay steps: {decay_steps - warmup_steps}')
|
415 |
+
if warmup > 0 or decay > 0:
|
416 |
+
if not decay > 0:
|
417 |
+
# only warmup, keep peak lr until end
|
418 |
+
self.log('Warmup schedule')
|
419 |
+
end_lr = lr
|
420 |
+
else:
|
421 |
+
# warmup + annealing to end lr
|
422 |
+
self.log('Warmup + cosine annealing schedule')
|
423 |
+
end_lr = eps
|
424 |
+
lr_schedule = optax.warmup_cosine_decay_schedule(
|
425 |
+
init_value = 0.0,
|
426 |
+
peak_value = lr,
|
427 |
+
warmup_steps = warmup_steps,
|
428 |
+
decay_steps = decay_steps,
|
429 |
+
end_value = end_lr
|
430 |
+
)
|
431 |
+
else:
|
432 |
+
# no warmup or decay -> constant lr
|
433 |
+
self.log('constant schedule')
|
434 |
+
lr_schedule = optax.constant_schedule(value = lr)
|
435 |
+
adamw = optax.adamw(
|
436 |
+
learning_rate = lr_schedule,
|
437 |
+
b1 = 0.9,
|
438 |
+
b2 = 0.999,
|
439 |
+
eps = eps,
|
440 |
+
weight_decay = weight_decay #0.01 # 0.0001
|
441 |
+
)
|
442 |
+
optim = optax.chain(
|
443 |
+
optax.clip_by_global_norm(max_norm = 1.0),
|
444 |
+
adamw
|
445 |
+
)
|
446 |
+
partition_optimizers = {
|
447 |
+
'trainable': optim,
|
448 |
+
'frozen': optax.set_to_zero()
|
449 |
+
}
|
450 |
+
tx = optax.multi_transform(partition_optimizers, self.param_partitions)
|
451 |
+
state = TrainState.create(
|
452 |
+
apply_fn = self.model.__call__,
|
453 |
+
params = self.params,
|
454 |
+
tx = tx
|
455 |
+
)
|
456 |
+
validation_rng, train_rngs = jax.random.split(self.rng)
|
457 |
+
train_rngs = jax.random.split(train_rngs, jax.local_device_count())
|
458 |
+
|
459 |
+
def train_step(state: TrainState, batch: Dict[str, jax.Array], train_rng: jax.random.PRNGKeyArray):
|
460 |
+
def compute_loss(
|
461 |
+
params: Dict[str, Any],
|
462 |
+
batch: Dict[str, jax.Array],
|
463 |
+
sample_rng: jax.random.PRNGKeyArray # unused, dataloader provides everything
|
464 |
+
) -> jax.Array:
|
465 |
+
# 'latent_model_input': latent_model_input
|
466 |
+
# 'encoder_hidden_states': encoder_hidden_states
|
467 |
+
# 'timesteps': timesteps
|
468 |
+
# 'noise': noise
|
469 |
+
latent_model_input = batch['latent_model_input']
|
470 |
+
encoder_hidden_states = batch['encoder_hidden_states']
|
471 |
+
timesteps = batch['timesteps']
|
472 |
+
noise = batch['noise']
|
473 |
+
model_pred = self.model.apply(
|
474 |
+
{ 'params': params },
|
475 |
+
latent_model_input,
|
476 |
+
timesteps,
|
477 |
+
encoder_hidden_states
|
478 |
+
).sample
|
479 |
+
loss = (noise - model_pred) ** 2
|
480 |
+
loss = loss.mean()
|
481 |
+
return loss
|
482 |
+
grad_fn = jax.value_and_grad(compute_loss)
|
483 |
+
|
484 |
+
def loss_and_grad(
|
485 |
+
train_rng: jax.random.PRNGKeyArray
|
486 |
+
) -> Tuple[jax.Array, Any, jax.random.PRNGKeyArray]:
|
487 |
+
sample_rng, train_rng = jax.random.split(train_rng, 2)
|
488 |
+
loss, grad = grad_fn(state.params, batch, sample_rng)
|
489 |
+
return loss, grad, train_rng
|
490 |
+
|
491 |
+
loss, grad, new_train_rng = loss_and_grad(train_rng)
|
492 |
+
# self.log(grad) # NOTE uncomment to visualize gradient
|
493 |
+
grad = jax.lax.pmean(grad, axis_name = 'batch')
|
494 |
+
new_state = state.apply_gradients(grads = grad)
|
495 |
+
metrics: Dict[str, Any] = { 'loss': loss }
|
496 |
+
metrics = jax.lax.pmean(metrics, axis_name = 'batch')
|
497 |
+
def l2(xs) -> jax.Array:
|
498 |
+
return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))
|
499 |
+
metrics['l2_grads'] = l2(jax.tree_util.tree_leaves(grad))
|
500 |
+
|
501 |
+
return new_state, metrics, new_train_rng
|
502 |
+
|
503 |
+
p_train_step = jax.pmap(fun = train_step, axis_name = 'batch', donate_argnums = (0, ))
|
504 |
+
state = jax_utils.replicate(state)
|
505 |
+
|
506 |
+
train_metrics = []
|
507 |
+
train_metric = None
|
508 |
+
|
509 |
+
global_step: int = 0
|
510 |
+
|
511 |
+
if jax.process_index() == 0:
|
512 |
+
self._init_tracker_meta()
|
513 |
+
hyper_params = {
|
514 |
+
'lr': lr,
|
515 |
+
'lr_warmup': warmup,
|
516 |
+
'lr_decay': decay,
|
517 |
+
'weight_decay': weight_decay,
|
518 |
+
'total_steps': total_steps,
|
519 |
+
'batch_size': dataloader.batch_size // self.num_devices,
|
520 |
+
'num_frames': num_frames,
|
521 |
+
'sample_size': self.sample_size,
|
522 |
+
'num_devices': self.num_devices,
|
523 |
+
'seed': self.seed,
|
524 |
+
'use_memory_efficient_attention': self.model.use_memory_efficient_attention,
|
525 |
+
'only_temporal': self.only_temporal,
|
526 |
+
'dtype': self.dtype_str,
|
527 |
+
'param_dtype': self.param_dtype,
|
528 |
+
'pretrained_model': self.pretrained_model,
|
529 |
+
'model_config': self.model.config
|
530 |
+
}
|
531 |
+
if self._use_wandb:
|
532 |
+
self.log('Setting up wandb')
|
533 |
+
self._setup_wandb(hyper_params)
|
534 |
+
self.log(hyper_params)
|
535 |
+
output_path = os.path.join(output_dir, str(global_step), 'unet')
|
536 |
+
self.log(f'saving checkpoint to {output_path}')
|
537 |
+
self.model.save_pretrained(
|
538 |
+
save_directory = output_path,
|
539 |
+
params = self.get_params_from_state(state),#jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)),
|
540 |
+
is_main_process = True
|
541 |
+
)
|
542 |
+
|
543 |
+
pbar_epoch = tqdm(
|
544 |
+
total = epochs,
|
545 |
+
desc = 'Epochs',
|
546 |
+
smoothing = 1,
|
547 |
+
position = 0,
|
548 |
+
dynamic_ncols = True,
|
549 |
+
leave = True,
|
550 |
+
disable = jax.process_index() > 0
|
551 |
+
)
|
552 |
+
steps_per_epoch = len(dataloader) # TODO dataloader
|
553 |
+
for epoch in range(epochs):
|
554 |
+
pbar_steps = tqdm(
|
555 |
+
total = steps_per_epoch,
|
556 |
+
desc = 'Steps',
|
557 |
+
position = 1,
|
558 |
+
smoothing = 0.1,
|
559 |
+
dynamic_ncols = True,
|
560 |
+
leave = True,
|
561 |
+
disable = jax.process_index() > 0
|
562 |
+
)
|
563 |
+
for batch in dataloader:
|
564 |
+
# keep input + gt as float32, results in fp32 loss and grad
|
565 |
+
# otherwise uncomment the following to cast to the model dtype
|
566 |
+
# batch = { k: (v.astype(self.dtype) if v.dtype == np.float32 else v) for k,v in batch.items() }
|
567 |
+
batch = shard(batch)
|
568 |
+
state, train_metric, train_rngs = p_train_step(
|
569 |
+
state, batch, train_rngs
|
570 |
+
)
|
571 |
+
train_metrics.append(train_metric)
|
572 |
+
if global_step % log_every_step == 0 and jax.process_index() == 0:
|
573 |
+
train_metrics = jax_utils.unreplicate(train_metrics)
|
574 |
+
train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
|
575 |
+
if global_step == 0:
|
576 |
+
self.log(f'grad dtype: {train_metrics["l2_grads"].dtype}')
|
577 |
+
self.log(f'loss dtype: {train_metrics["loss"].dtype}')
|
578 |
+
train_metrics_dict = { k: v.item() for k, v in train_metrics.items() }
|
579 |
+
train_metrics_dict['lr'] = lr_schedule(global_step).item()
|
580 |
+
self.log_metrics(train_metrics_dict, step = global_step, epoch = epoch)
|
581 |
+
train_metrics = []
|
582 |
+
pbar_steps.update(1)
|
583 |
+
global_step += 1
|
584 |
+
if epoch % save_every_epoch == 0 and jax.process_index() == 0:
|
585 |
+
output_path = os.path.join(output_dir, str(global_step), 'unet')
|
586 |
+
self.log(f'saving checkpoint to {output_path}')
|
587 |
+
self.model.save_pretrained(
|
588 |
+
save_directory = output_path,
|
589 |
+
params = self.get_params_from_state(state),#jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)),
|
590 |
+
is_main_process = True
|
591 |
+
)
|
592 |
+
self.log(f'checkpoint saved ')
|
593 |
+
if epoch % sample_every_epoch == 0 and jax.process_index() == 0:
|
594 |
+
images = self.sample(
|
595 |
+
params = state.params,
|
596 |
+
replicate_params = False,
|
597 |
+
prompt = 'dancing person',
|
598 |
+
image_path = 'testimage.png',
|
599 |
+
num_frames = num_frames,
|
600 |
+
steps = 50,
|
601 |
+
cfg = 9.0,
|
602 |
+
unload_after_usage = False
|
603 |
+
)
|
604 |
+
os.makedirs(os.path.join('image_output', str(epoch)), exist_ok = True)
|
605 |
+
for i, im in enumerate(images):
|
606 |
+
im.save(os.path.join('image_output', str(epoch), str(i).zfill(5) + '.png'), optimize = True)
|
607 |
+
pbar_epoch.update(1)
|
608 |
+
|
makeavid_sd/makeavid_sd/flax_impl/flax_unet_pseudo3d_blocks.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
import jax
|
5 |
+
import jax.numpy as jnp
|
6 |
+
import flax.linen as nn
|
7 |
+
|
8 |
+
from .flax_attention_pseudo3d import TransformerPseudo3DModel
|
9 |
+
from .flax_resnet_pseudo3d import ResnetBlockPseudo3D, DownsamplePseudo3D, UpsamplePseudo3D
|
10 |
+
|
11 |
+
|
12 |
+
class UNetMidBlockPseudo3DCrossAttn(nn.Module):
|
13 |
+
in_channels: int
|
14 |
+
num_layers: int = 1
|
15 |
+
attn_num_head_channels: int = 1
|
16 |
+
use_memory_efficient_attention: bool = False
|
17 |
+
dtype: jnp.dtype = jnp.float32
|
18 |
+
|
19 |
+
def setup(self) -> None:
|
20 |
+
resnets = [
|
21 |
+
ResnetBlockPseudo3D(
|
22 |
+
in_channels = self.in_channels,
|
23 |
+
out_channels = self.in_channels,
|
24 |
+
dtype = self.dtype
|
25 |
+
)
|
26 |
+
]
|
27 |
+
attentions = []
|
28 |
+
for _ in range(self.num_layers):
|
29 |
+
attn_block = TransformerPseudo3DModel(
|
30 |
+
in_channels = self.in_channels,
|
31 |
+
num_attention_heads = self.attn_num_head_channels,
|
32 |
+
attention_head_dim = self.in_channels // self.attn_num_head_channels,
|
33 |
+
num_layers = 1,
|
34 |
+
use_memory_efficient_attention = self.use_memory_efficient_attention,
|
35 |
+
dtype = self.dtype
|
36 |
+
)
|
37 |
+
attentions.append(attn_block)
|
38 |
+
res_block = ResnetBlockPseudo3D(
|
39 |
+
in_channels = self.in_channels,
|
40 |
+
out_channels = self.in_channels,
|
41 |
+
dtype = self.dtype
|
42 |
+
)
|
43 |
+
resnets.append(res_block)
|
44 |
+
self.attentions = attentions
|
45 |
+
self.resnets = resnets
|
46 |
+
|
47 |
+
def __call__(self,
|
48 |
+
hidden_states: jax.Array,
|
49 |
+
temb: jax.Array,
|
50 |
+
encoder_hidden_states = jax.Array
|
51 |
+
) -> jax.Array:
|
52 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
53 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
54 |
+
hidden_states = attn(hidden_states, encoder_hidden_states)
|
55 |
+
hidden_states = resnet(hidden_states, temb)
|
56 |
+
return hidden_states
|
57 |
+
|
58 |
+
|
59 |
+
class CrossAttnDownBlockPseudo3D(nn.Module):
|
60 |
+
in_channels: int
|
61 |
+
out_channels: int
|
62 |
+
num_layers: int = 1
|
63 |
+
attn_num_head_channels: int = 1
|
64 |
+
add_downsample: bool = True
|
65 |
+
use_memory_efficient_attention: bool = False
|
66 |
+
dtype: jnp.dtype = jnp.float32
|
67 |
+
|
68 |
+
def setup(self) -> None:
|
69 |
+
attentions = []
|
70 |
+
resnets = []
|
71 |
+
for i in range(self.num_layers):
|
72 |
+
in_channels = self.in_channels if i == 0 else self.out_channels
|
73 |
+
res_block = ResnetBlockPseudo3D(
|
74 |
+
in_channels = in_channels,
|
75 |
+
out_channels = self.out_channels,
|
76 |
+
dtype = self.dtype
|
77 |
+
)
|
78 |
+
resnets.append(res_block)
|
79 |
+
attn_block = TransformerPseudo3DModel(
|
80 |
+
in_channels = self.out_channels,
|
81 |
+
num_attention_heads = self.attn_num_head_channels,
|
82 |
+
attention_head_dim = self.out_channels // self.attn_num_head_channels,
|
83 |
+
num_layers = 1,
|
84 |
+
use_memory_efficient_attention = self.use_memory_efficient_attention,
|
85 |
+
dtype = self.dtype
|
86 |
+
)
|
87 |
+
attentions.append(attn_block)
|
88 |
+
self.resnets = resnets
|
89 |
+
self.attentions = attentions
|
90 |
+
|
91 |
+
if self.add_downsample:
|
92 |
+
self.downsamplers_0 = DownsamplePseudo3D(
|
93 |
+
out_channels = self.out_channels,
|
94 |
+
dtype = self.dtype
|
95 |
+
)
|
96 |
+
else:
|
97 |
+
self.downsamplers_0 = None
|
98 |
+
|
99 |
+
def __call__(self,
|
100 |
+
hidden_states: jax.Array,
|
101 |
+
temb: jax.Array,
|
102 |
+
encoder_hidden_states: jax.Array
|
103 |
+
) -> Tuple[jax.Array, jax.Array]:
|
104 |
+
output_states = ()
|
105 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
106 |
+
hidden_states = resnet(hidden_states, temb)
|
107 |
+
hidden_states = attn(hidden_states, encoder_hidden_states)
|
108 |
+
output_states += (hidden_states, )
|
109 |
+
if self.add_downsample:
|
110 |
+
hidden_states = self.downsamplers_0(hidden_states)
|
111 |
+
output_states += (hidden_states, )
|
112 |
+
return hidden_states, output_states
|
113 |
+
|
114 |
+
|
115 |
+
class DownBlockPseudo3D(nn.Module):
|
116 |
+
in_channels: int
|
117 |
+
out_channels: int
|
118 |
+
num_layers: int = 1
|
119 |
+
add_downsample: bool = True
|
120 |
+
dtype: jnp.dtype = jnp.float32
|
121 |
+
|
122 |
+
def setup(self) -> None:
|
123 |
+
resnets = []
|
124 |
+
for i in range(self.num_layers):
|
125 |
+
in_channels = self.in_channels if i == 0 else self.out_channels
|
126 |
+
res_block = ResnetBlockPseudo3D(
|
127 |
+
in_channels = in_channels,
|
128 |
+
out_channels = self.out_channels,
|
129 |
+
dtype = self.dtype
|
130 |
+
)
|
131 |
+
resnets.append(res_block)
|
132 |
+
self.resnets = resnets
|
133 |
+
if self.add_downsample:
|
134 |
+
self.downsamplers_0 = DownsamplePseudo3D(
|
135 |
+
out_channels = self.out_channels,
|
136 |
+
dtype = self.dtype
|
137 |
+
)
|
138 |
+
else:
|
139 |
+
self.downsamplers_0 = None
|
140 |
+
|
141 |
+
def __call__(self,
|
142 |
+
hidden_states: jax.Array,
|
143 |
+
temb: jax.Array
|
144 |
+
) -> Tuple[jax.Array, jax.Array]:
|
145 |
+
output_states = ()
|
146 |
+
for resnet in self.resnets:
|
147 |
+
hidden_states = resnet(hidden_states, temb)
|
148 |
+
output_states += (hidden_states, )
|
149 |
+
if self.add_downsample:
|
150 |
+
hidden_states = self.downsamplers_0(hidden_states)
|
151 |
+
output_states += (hidden_states, )
|
152 |
+
return hidden_states, output_states
|
153 |
+
|
154 |
+
|
155 |
+
class CrossAttnUpBlockPseudo3D(nn.Module):
|
156 |
+
in_channels: int
|
157 |
+
out_channels: int
|
158 |
+
prev_output_channels: int
|
159 |
+
num_layers: int = 1
|
160 |
+
attn_num_head_channels: int = 1
|
161 |
+
add_upsample: bool = True
|
162 |
+
use_memory_efficient_attention: bool = False
|
163 |
+
dtype: jnp.dtype = jnp.float32
|
164 |
+
|
165 |
+
def setup(self) -> None:
|
166 |
+
resnets = []
|
167 |
+
attentions = []
|
168 |
+
for i in range(self.num_layers):
|
169 |
+
res_skip_channels = self.in_channels if (i == self.num_layers -1) else self.out_channels
|
170 |
+
resnet_in_channels = self.prev_output_channels if i == 0 else self.out_channels
|
171 |
+
res_block = ResnetBlockPseudo3D(
|
172 |
+
in_channels = resnet_in_channels + res_skip_channels,
|
173 |
+
out_channels = self.out_channels,
|
174 |
+
dtype = self.dtype
|
175 |
+
)
|
176 |
+
resnets.append(res_block)
|
177 |
+
attn_block = TransformerPseudo3DModel(
|
178 |
+
in_channels = self.out_channels,
|
179 |
+
num_attention_heads = self.attn_num_head_channels,
|
180 |
+
attention_head_dim = self.out_channels // self.attn_num_head_channels,
|
181 |
+
num_layers = 1,
|
182 |
+
use_memory_efficient_attention = self.use_memory_efficient_attention,
|
183 |
+
dtype = self.dtype
|
184 |
+
)
|
185 |
+
attentions.append(attn_block)
|
186 |
+
self.resnets = resnets
|
187 |
+
self.attentions = attentions
|
188 |
+
if self.add_upsample:
|
189 |
+
self.upsamplers_0 = UpsamplePseudo3D(
|
190 |
+
out_channels = self.out_channels,
|
191 |
+
dtype = self.dtype
|
192 |
+
)
|
193 |
+
else:
|
194 |
+
self.upsamplers_0 = None
|
195 |
+
|
196 |
+
def __call__(self,
|
197 |
+
hidden_states: jax.Array,
|
198 |
+
res_hidden_states_tuple: Tuple[jax.Array, ...],
|
199 |
+
temb: jax.Array,
|
200 |
+
encoder_hidden_states: jax.Array
|
201 |
+
) -> jax.Array:
|
202 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
203 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
204 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
205 |
+
hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis = -1)
|
206 |
+
hidden_states = resnet(hidden_states, temb)
|
207 |
+
hidden_states = attn(hidden_states, encoder_hidden_states)
|
208 |
+
if self.add_upsample:
|
209 |
+
hidden_states = self.upsamplers_0(hidden_states)
|
210 |
+
return hidden_states
|
211 |
+
|
212 |
+
|
213 |
+
class UpBlockPseudo3D(nn.Module):
|
214 |
+
in_channels: int
|
215 |
+
out_channels: int
|
216 |
+
prev_output_channels: int
|
217 |
+
num_layers: int = 1
|
218 |
+
add_upsample: bool = True
|
219 |
+
dtype: jnp.dtype = jnp.float32
|
220 |
+
|
221 |
+
def setup(self) -> None:
|
222 |
+
resnets = []
|
223 |
+
for i in range(self.num_layers):
|
224 |
+
res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
|
225 |
+
resnet_in_channels = self.prev_output_channels if i == 0 else self.out_channels
|
226 |
+
res_block = ResnetBlockPseudo3D(
|
227 |
+
in_channels = resnet_in_channels + res_skip_channels,
|
228 |
+
out_channels = self.out_channels,
|
229 |
+
dtype = self.dtype
|
230 |
+
)
|
231 |
+
resnets.append(res_block)
|
232 |
+
self.resnets = resnets
|
233 |
+
if self.add_upsample:
|
234 |
+
self.upsamplers_0 = UpsamplePseudo3D(
|
235 |
+
out_channels = self.out_channels,
|
236 |
+
dtype = self.dtype
|
237 |
+
)
|
238 |
+
else:
|
239 |
+
self.upsamplers_0 = None
|
240 |
+
|
241 |
+
def __call__(self,
|
242 |
+
hidden_states: jax.Array,
|
243 |
+
res_hidden_states_tuple: Tuple[jax.Array, ...],
|
244 |
+
temb: jax.Array
|
245 |
+
) -> jax.Array:
|
246 |
+
for resnet in self.resnets:
|
247 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
248 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
249 |
+
hidden_states = jnp.concatenate([hidden_states, res_hidden_states], axis = -1)
|
250 |
+
hidden_states = resnet(hidden_states, temb)
|
251 |
+
if self.add_upsample:
|
252 |
+
hidden_states = self.upsamplers_0(hidden_states)
|
253 |
+
return hidden_states
|
254 |
+
|
makeavid_sd/makeavid_sd/flax_impl/flax_unet_pseudo3d_condition.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import Tuple, Union
|
3 |
+
|
4 |
+
import jax
|
5 |
+
import jax.numpy as jnp
|
6 |
+
import flax.linen as nn
|
7 |
+
from flax.core.frozen_dict import FrozenDict
|
8 |
+
|
9 |
+
from diffusers.configuration_utils import ConfigMixin, flax_register_to_config
|
10 |
+
from diffusers.models.modeling_flax_utils import FlaxModelMixin
|
11 |
+
from diffusers.utils import BaseOutput
|
12 |
+
|
13 |
+
from .flax_unet_pseudo3d_blocks import (
|
14 |
+
CrossAttnDownBlockPseudo3D,
|
15 |
+
CrossAttnUpBlockPseudo3D,
|
16 |
+
DownBlockPseudo3D,
|
17 |
+
UpBlockPseudo3D,
|
18 |
+
UNetMidBlockPseudo3DCrossAttn
|
19 |
+
)
|
20 |
+
#from flax_embeddings import (
|
21 |
+
# TimestepEmbedding,
|
22 |
+
# Timesteps
|
23 |
+
#)
|
24 |
+
from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
|
25 |
+
from .flax_resnet_pseudo3d import ConvPseudo3D
|
26 |
+
|
27 |
+
|
28 |
+
class UNetPseudo3DConditionOutput(BaseOutput):
|
29 |
+
sample: jax.Array
|
30 |
+
|
31 |
+
|
32 |
+
@flax_register_to_config
|
33 |
+
class UNetPseudo3DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
34 |
+
sample_size: Union[int, Tuple[int, int]] = (64, 64)
|
35 |
+
in_channels: int = 4
|
36 |
+
out_channels: int = 4
|
37 |
+
down_block_types: Tuple[str] = (
|
38 |
+
"CrossAttnDownBlockPseudo3D",
|
39 |
+
"CrossAttnDownBlockPseudo3D",
|
40 |
+
"CrossAttnDownBlockPseudo3D",
|
41 |
+
"DownBlockPseudo3D"
|
42 |
+
)
|
43 |
+
up_block_types: Tuple[str] = (
|
44 |
+
"UpBlockPseudo3D",
|
45 |
+
"CrossAttnUpBlockPseudo3D",
|
46 |
+
"CrossAttnUpBlockPseudo3D",
|
47 |
+
"CrossAttnUpBlockPseudo3D"
|
48 |
+
)
|
49 |
+
block_out_channels: Tuple[int] = (
|
50 |
+
320,
|
51 |
+
640,
|
52 |
+
1280,
|
53 |
+
1280
|
54 |
+
)
|
55 |
+
layers_per_block: int = 2
|
56 |
+
attention_head_dim: Union[int, Tuple[int]] = 8
|
57 |
+
cross_attention_dim: int = 768
|
58 |
+
flip_sin_to_cos: bool = True
|
59 |
+
freq_shift: int = 0
|
60 |
+
use_memory_efficient_attention: bool = False
|
61 |
+
dtype: jnp.dtype = jnp.float32
|
62 |
+
param_dtype: str = 'float32'
|
63 |
+
|
64 |
+
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
|
65 |
+
if self.param_dtype == 'bfloat16':
|
66 |
+
param_dtype = jnp.bfloat16
|
67 |
+
elif self.param_dtype == 'float16':
|
68 |
+
param_dtype = jnp.float16
|
69 |
+
elif self.param_dtype == 'float32':
|
70 |
+
param_dtype = jnp.float32
|
71 |
+
else:
|
72 |
+
raise ValueError(f'unknown parameter type: {self.param_dtype}')
|
73 |
+
sample_size = self.sample_size
|
74 |
+
if isinstance(sample_size, int):
|
75 |
+
sample_size = (sample_size, sample_size)
|
76 |
+
sample_shape = (1, self.in_channels, 1, *sample_size)
|
77 |
+
sample = jnp.zeros(sample_shape, dtype = param_dtype)
|
78 |
+
timesteps = jnp.ones((1, ), dtype = jnp.int32)
|
79 |
+
encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype = param_dtype)
|
80 |
+
params_rng, dropout_rng = jax.random.split(rng)
|
81 |
+
rngs = { "params": params_rng, "dropout": dropout_rng }
|
82 |
+
return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
|
83 |
+
|
84 |
+
def setup(self) -> None:
|
85 |
+
if isinstance(self.attention_head_dim, int):
|
86 |
+
attention_head_dim = (self.attention_head_dim, ) * len(self.down_block_types)
|
87 |
+
else:
|
88 |
+
attention_head_dim = self.attention_head_dim
|
89 |
+
time_embed_dim = self.block_out_channels[0] * 4
|
90 |
+
self.conv_in = ConvPseudo3D(
|
91 |
+
features = self.block_out_channels[0],
|
92 |
+
kernel_size = (3, 3),
|
93 |
+
strides = (1, 1),
|
94 |
+
padding = ((1, 1), (1, 1)),
|
95 |
+
dtype = self.dtype
|
96 |
+
)
|
97 |
+
self.time_proj = FlaxTimesteps(
|
98 |
+
dim = self.block_out_channels[0],
|
99 |
+
flip_sin_to_cos = self.flip_sin_to_cos,
|
100 |
+
freq_shift = self.freq_shift
|
101 |
+
)
|
102 |
+
self.time_embedding = FlaxTimestepEmbedding(
|
103 |
+
time_embed_dim = time_embed_dim,
|
104 |
+
dtype = self.dtype
|
105 |
+
)
|
106 |
+
down_blocks = []
|
107 |
+
output_channels = self.block_out_channels[0]
|
108 |
+
for i, down_block_type in enumerate(self.down_block_types):
|
109 |
+
input_channels = output_channels
|
110 |
+
output_channels = self.block_out_channels[i]
|
111 |
+
is_final_block = i == len(self.block_out_channels) - 1
|
112 |
+
# allows loading 3d models with old layer type names in their configs
|
113 |
+
# eg. 2D instead of Pseudo3D, like lxj's timelapse model
|
114 |
+
if down_block_type in ['CrossAttnDownBlockPseudo3D', 'CrossAttnDownBlock2D']:
|
115 |
+
down_block = CrossAttnDownBlockPseudo3D(
|
116 |
+
in_channels = input_channels,
|
117 |
+
out_channels = output_channels,
|
118 |
+
num_layers = self.layers_per_block,
|
119 |
+
attn_num_head_channels = attention_head_dim[i],
|
120 |
+
add_downsample = not is_final_block,
|
121 |
+
use_memory_efficient_attention = self.use_memory_efficient_attention,
|
122 |
+
dtype = self.dtype
|
123 |
+
)
|
124 |
+
elif down_block_type in ['DownBlockPseudo3D', 'DownBlock2D']:
|
125 |
+
down_block = DownBlockPseudo3D(
|
126 |
+
in_channels = input_channels,
|
127 |
+
out_channels = output_channels,
|
128 |
+
num_layers = self.layers_per_block,
|
129 |
+
add_downsample = not is_final_block,
|
130 |
+
dtype = self.dtype
|
131 |
+
)
|
132 |
+
else:
|
133 |
+
raise NotImplementedError(f'Unimplemented down block type: {down_block_type}')
|
134 |
+
down_blocks.append(down_block)
|
135 |
+
self.down_blocks = down_blocks
|
136 |
+
self.mid_block = UNetMidBlockPseudo3DCrossAttn(
|
137 |
+
in_channels = self.block_out_channels[-1],
|
138 |
+
attn_num_head_channels = attention_head_dim[-1],
|
139 |
+
use_memory_efficient_attention = self.use_memory_efficient_attention,
|
140 |
+
dtype = self.dtype
|
141 |
+
)
|
142 |
+
up_blocks = []
|
143 |
+
reversed_block_out_channels = list(reversed(self.block_out_channels))
|
144 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
145 |
+
output_channels = reversed_block_out_channels[0]
|
146 |
+
for i, up_block_type in enumerate(self.up_block_types):
|
147 |
+
prev_output_channels = output_channels
|
148 |
+
output_channels = reversed_block_out_channels[i]
|
149 |
+
input_channels = reversed_block_out_channels[min(i + 1, len(self.block_out_channels) - 1)]
|
150 |
+
is_final_block = i == len(self.block_out_channels) - 1
|
151 |
+
if up_block_type in ['CrossAttnUpBlockPseudo3D', 'CrossAttnUpBlock2D']:
|
152 |
+
up_block = CrossAttnUpBlockPseudo3D(
|
153 |
+
in_channels = input_channels,
|
154 |
+
out_channels = output_channels,
|
155 |
+
prev_output_channels = prev_output_channels,
|
156 |
+
num_layers = self.layers_per_block + 1,
|
157 |
+
attn_num_head_channels = reversed_attention_head_dim[i],
|
158 |
+
add_upsample = not is_final_block,
|
159 |
+
use_memory_efficient_attention = self.use_memory_efficient_attention,
|
160 |
+
dtype = self.dtype
|
161 |
+
)
|
162 |
+
elif up_block_type in ['UpBlockPseudo3D', 'UpBlock2D']:
|
163 |
+
up_block = UpBlockPseudo3D(
|
164 |
+
in_channels = input_channels,
|
165 |
+
out_channels = output_channels,
|
166 |
+
prev_output_channels = prev_output_channels,
|
167 |
+
num_layers = self.layers_per_block + 1,
|
168 |
+
add_upsample = not is_final_block,
|
169 |
+
dtype = self.dtype
|
170 |
+
)
|
171 |
+
else:
|
172 |
+
raise NotImplementedError(f'Unimplemented up block type: {up_block_type}')
|
173 |
+
up_blocks.append(up_block)
|
174 |
+
self.up_blocks = up_blocks
|
175 |
+
self.conv_norm_out = nn.GroupNorm(
|
176 |
+
num_groups = 32,
|
177 |
+
epsilon = 1e-5
|
178 |
+
)
|
179 |
+
self.conv_out = ConvPseudo3D(
|
180 |
+
features = self.out_channels,
|
181 |
+
kernel_size = (3, 3),
|
182 |
+
strides = (1, 1),
|
183 |
+
padding = ((1, 1), (1, 1)),
|
184 |
+
dtype = self.dtype
|
185 |
+
)
|
186 |
+
|
187 |
+
def __call__(self,
|
188 |
+
sample: jax.Array,
|
189 |
+
timesteps: jax.Array,
|
190 |
+
encoder_hidden_states: jax.Array,
|
191 |
+
return_dict: bool = True
|
192 |
+
) -> Union[UNetPseudo3DConditionOutput, Tuple[jax.Array]]:
|
193 |
+
if timesteps.dtype != jnp.float32:
|
194 |
+
timesteps = timesteps.astype(dtype = jnp.float32)
|
195 |
+
if len(timesteps.shape) == 0:
|
196 |
+
timesteps = jnp.expand_dims(timesteps, 0)
|
197 |
+
# b,c,f,h,w -> b,f,h,w,c
|
198 |
+
sample = sample.transpose((0, 2, 3, 4, 1))
|
199 |
+
|
200 |
+
t_emb = self.time_proj(timesteps)
|
201 |
+
t_emb = self.time_embedding(t_emb)
|
202 |
+
sample = self.conv_in(sample)
|
203 |
+
down_block_res_samples = (sample, )
|
204 |
+
for down_block in self.down_blocks:
|
205 |
+
if isinstance(down_block, CrossAttnDownBlockPseudo3D):
|
206 |
+
sample, res_samples = down_block(
|
207 |
+
hidden_states = sample,
|
208 |
+
temb = t_emb,
|
209 |
+
encoder_hidden_states = encoder_hidden_states
|
210 |
+
)
|
211 |
+
elif isinstance(down_block, DownBlockPseudo3D):
|
212 |
+
sample, res_samples = down_block(
|
213 |
+
hidden_states = sample,
|
214 |
+
temb = t_emb
|
215 |
+
)
|
216 |
+
else:
|
217 |
+
raise NotImplementedError(f'Unimplemented down block type: {down_block.__class__.__name__}')
|
218 |
+
down_block_res_samples += res_samples
|
219 |
+
sample = self.mid_block(
|
220 |
+
hidden_states = sample,
|
221 |
+
temb = t_emb,
|
222 |
+
encoder_hidden_states = encoder_hidden_states
|
223 |
+
)
|
224 |
+
for up_block in self.up_blocks:
|
225 |
+
res_samples = down_block_res_samples[-(self.layers_per_block + 1):]
|
226 |
+
down_block_res_samples = down_block_res_samples[:-(self.layers_per_block + 1)]
|
227 |
+
if isinstance(up_block, CrossAttnUpBlockPseudo3D):
|
228 |
+
sample = up_block(
|
229 |
+
hidden_states = sample,
|
230 |
+
temb = t_emb,
|
231 |
+
encoder_hidden_states = encoder_hidden_states,
|
232 |
+
res_hidden_states_tuple = res_samples
|
233 |
+
)
|
234 |
+
elif isinstance(up_block, UpBlockPseudo3D):
|
235 |
+
sample = up_block(
|
236 |
+
hidden_states = sample,
|
237 |
+
temb = t_emb,
|
238 |
+
res_hidden_states_tuple = res_samples
|
239 |
+
)
|
240 |
+
else:
|
241 |
+
raise NotImplementedError(f'Unimplemented up block type: {up_block.__class__.__name__}')
|
242 |
+
sample = self.conv_norm_out(sample)
|
243 |
+
sample = nn.silu(sample)
|
244 |
+
sample = self.conv_out(sample)
|
245 |
+
|
246 |
+
# b,f,h,w,c -> b,c,f,h,w
|
247 |
+
sample = sample.transpose((0, 4, 1, 2, 3))
|
248 |
+
if not return_dict:
|
249 |
+
return (sample, )
|
250 |
+
return UNetPseudo3DConditionOutput(sample = sample)
|
251 |
+
|
makeavid_sd/makeavid_sd/flax_impl/train.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import jax
|
3 |
+
_ = jax.device_count() # ugly hack to prevent tpu comms to lock/race or smth smh
|
4 |
+
|
5 |
+
from typing import Tuple, Optional
|
6 |
+
import os
|
7 |
+
from argparse import ArgumentParser
|
8 |
+
|
9 |
+
from flax_trainer import FlaxTrainerUNetPseudo3D
|
10 |
+
from dataset import load_dataset
|
11 |
+
|
12 |
+
def train(
|
13 |
+
dataset_path: str,
|
14 |
+
model_path: str,
|
15 |
+
output_dir: str,
|
16 |
+
dataset_cache_dir: Optional[str] = None,
|
17 |
+
from_pt: bool = True,
|
18 |
+
convert2d: bool = False,
|
19 |
+
only_temporal: bool = True,
|
20 |
+
sample_size: Tuple[int, int] = (64, 64),
|
21 |
+
lr: float = 5e-5,
|
22 |
+
batch_size: int = 1,
|
23 |
+
num_frames: int = 24,
|
24 |
+
epochs: int = 10,
|
25 |
+
warmup: float = 0.1,
|
26 |
+
decay: float = 0.0,
|
27 |
+
weight_decay: float = 1e-2,
|
28 |
+
log_every_step: int = 50,
|
29 |
+
save_every_epoch: int = 1,
|
30 |
+
sample_every_epoch: int = 1,
|
31 |
+
seed: int = 0,
|
32 |
+
dtype: str = 'bfloat16',
|
33 |
+
param_dtype: str = 'float32',
|
34 |
+
use_memory_efficient_attention: bool = True,
|
35 |
+
verbose: bool = True,
|
36 |
+
use_wandb: bool = False
|
37 |
+
) -> None:
|
38 |
+
log = lambda x: print(x) if verbose else None
|
39 |
+
log('\n----------------')
|
40 |
+
log('Init trainer')
|
41 |
+
trainer = FlaxTrainerUNetPseudo3D(
|
42 |
+
model_path = model_path,
|
43 |
+
from_pt = from_pt,
|
44 |
+
convert2d = convert2d,
|
45 |
+
sample_size = sample_size,
|
46 |
+
seed = seed,
|
47 |
+
dtype = dtype,
|
48 |
+
param_dtype = param_dtype,
|
49 |
+
use_memory_efficient_attention = use_memory_efficient_attention,
|
50 |
+
verbose = verbose,
|
51 |
+
only_temporal = only_temporal
|
52 |
+
)
|
53 |
+
log('\n----------------')
|
54 |
+
log('Init dataset')
|
55 |
+
dataloader = load_dataset(
|
56 |
+
dataset_path = dataset_path,
|
57 |
+
model_path = model_path,
|
58 |
+
cache_dir = dataset_cache_dir,
|
59 |
+
batch_size = batch_size * trainer.num_devices,
|
60 |
+
num_frames = num_frames,
|
61 |
+
num_workers = min(trainer.num_devices * 2, os.cpu_count() - 1),
|
62 |
+
as_numpy = True,
|
63 |
+
shuffle = True
|
64 |
+
)
|
65 |
+
log('\n----------------')
|
66 |
+
log('Train')
|
67 |
+
if use_wandb:
|
68 |
+
trainer.enable_wandb()
|
69 |
+
trainer.train(
|
70 |
+
dataloader = dataloader,
|
71 |
+
epochs = epochs,
|
72 |
+
num_frames = num_frames,
|
73 |
+
log_every_step = log_every_step,
|
74 |
+
save_every_epoch = save_every_epoch,
|
75 |
+
sample_every_epoch = sample_every_epoch,
|
76 |
+
lr = lr,
|
77 |
+
warmup = warmup,
|
78 |
+
decay = decay,
|
79 |
+
weight_decay = weight_decay,
|
80 |
+
output_dir = output_dir
|
81 |
+
)
|
82 |
+
log('\n----------------')
|
83 |
+
log('Done')
|
84 |
+
|
85 |
+
|
86 |
+
if __name__ == '__main__':
|
87 |
+
parser = ArgumentParser()
|
88 |
+
bool_type = lambda x: x.lower() in ['true', '1', 'yes']
|
89 |
+
parser.add_argument('-v', '--verbose', type = bool_type, default = True)
|
90 |
+
parser.add_argument('-d', '--dataset_path', required = True)
|
91 |
+
parser.add_argument('-m', '--model_path', required = True)
|
92 |
+
parser.add_argument('-o', '--output_dir', required = True)
|
93 |
+
parser.add_argument('-b', '--batch_size', type = int, default = 1)
|
94 |
+
parser.add_argument('-f', '--num_frames', type = int, default = 24)
|
95 |
+
parser.add_argument('-e', '--epochs', type = int, default = 2)
|
96 |
+
parser.add_argument('--only_temporal', type = bool_type, default = True)
|
97 |
+
parser.add_argument('--dataset_cache_dir', type = str, default = None)
|
98 |
+
parser.add_argument('--from_pt', type = bool_type, default = True)
|
99 |
+
parser.add_argument('--convert2d', type = bool_type, default = False)
|
100 |
+
parser.add_argument('--lr', type = float, default = 1e-4)
|
101 |
+
parser.add_argument('--warmup', type = float, default = 0.1)
|
102 |
+
parser.add_argument('--decay', type = float, default = 0.0)
|
103 |
+
parser.add_argument('--weight_decay', type = float, default = 1e-2)
|
104 |
+
parser.add_argument('--sample_size', type = int, nargs = 2, default = [64, 64])
|
105 |
+
parser.add_argument('--log_every_step', type = int, default = 250)
|
106 |
+
parser.add_argument('--save_every_epoch', type = int, default = 1)
|
107 |
+
parser.add_argument('--sample_every_epoch', type = int, default = 1)
|
108 |
+
parser.add_argument('--seed', type = int, default = 0)
|
109 |
+
parser.add_argument('--use_memory_efficient_attention', type = bool_type, default = True)
|
110 |
+
parser.add_argument('--dtype', choices = ['float32', 'bfloat16', 'float16'], default = 'bfloat16')
|
111 |
+
parser.add_argument('--param_dtype', choices = ['float32', 'bfloat16', 'float16'], default = 'float32')
|
112 |
+
parser.add_argument('--wandb', type = bool_type, default = False)
|
113 |
+
args = parser.parse_args()
|
114 |
+
args.sample_size = tuple(args.sample_size)
|
115 |
+
if args.verbose:
|
116 |
+
print(args)
|
117 |
+
train(
|
118 |
+
dataset_path = args.dataset_path,
|
119 |
+
model_path = args.model_path,
|
120 |
+
from_pt = args.from_pt,
|
121 |
+
convert2d = args.convert2d,
|
122 |
+
only_temporal = args.only_temporal,
|
123 |
+
output_dir = args.output_dir,
|
124 |
+
dataset_cache_dir = args.dataset_cache_dir,
|
125 |
+
batch_size = args.batch_size,
|
126 |
+
num_frames = args.num_frames,
|
127 |
+
epochs = args.epochs,
|
128 |
+
lr = args.lr,
|
129 |
+
warmup = args.warmup,
|
130 |
+
decay = args.decay,
|
131 |
+
weight_decay = args.weight_decay,
|
132 |
+
sample_size = args.sample_size,
|
133 |
+
seed = args.seed,
|
134 |
+
dtype = args.dtype,
|
135 |
+
param_dtype = args.param_dtype,
|
136 |
+
use_memory_efficient_attention = args.use_memory_efficient_attention,
|
137 |
+
log_every_step = args.log_every_step,
|
138 |
+
save_every_epoch = args.save_every_epoch,
|
139 |
+
sample_every_epoch = args.sample_every_epoch,
|
140 |
+
verbose = args.verbose,
|
141 |
+
use_wandb = args.wandb
|
142 |
+
)
|
143 |
+
|
makeavid_sd/makeavid_sd/flax_impl/train.sh
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/sh
|
2 |
+
|
3 |
+
#export WANDB_API_KEY="your_api_key"
|
4 |
+
export WANDB_ENTITY="tempofunk"
|
5 |
+
export WANDB_JOB_TYPE="train"
|
6 |
+
export WANDB_PROJECT="makeavid-sd-tpu"
|
7 |
+
|
8 |
+
python train.py \
|
9 |
+
--dataset_path ../storage/dataset/tempofunk-s \
|
10 |
+
--model_path ../storage/trained_models/ep20 \
|
11 |
+
--from_pt False \
|
12 |
+
--convert2d False \
|
13 |
+
--only_temporal True \
|
14 |
+
--output_dir ../storage/output \
|
15 |
+
--batch_size 1 \
|
16 |
+
--num_frames 24 \
|
17 |
+
--epochs 20 \
|
18 |
+
--lr 0.00005 \
|
19 |
+
--warmup 0.1 \
|
20 |
+
--decay 0.0 \
|
21 |
+
--sample_size 64 64 \
|
22 |
+
--log_every_step 50 \
|
23 |
+
--save_every_epoch 1 \
|
24 |
+
--sample_every_epoch 1 \
|
25 |
+
--seed 2 \
|
26 |
+
--use_memory_efficient_attention True \
|
27 |
+
--dtype bfloat16 \
|
28 |
+
--param_dtype float32 \
|
29 |
+
--verbose True \
|
30 |
+
--dataset_cache_dir ../storage/cache/hf/datasets \
|
31 |
+
--wandb True
|
32 |
+
|
33 |
+
# sudo rm /tmp/libtpu_lockfile
|
34 |
+
|
makeavid_sd/makeavid_sd/inference.py
ADDED
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import Any, Union, Tuple, List, Dict
|
3 |
+
import os
|
4 |
+
import gc
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
import jax
|
8 |
+
import jax.numpy as jnp
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from flax.core.frozen_dict import FrozenDict
|
12 |
+
from flax import jax_utils
|
13 |
+
from flax.training.common_utils import shard
|
14 |
+
from PIL import Image
|
15 |
+
import einops
|
16 |
+
|
17 |
+
from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel
|
18 |
+
from diffusers import (
|
19 |
+
FlaxDDIMScheduler,
|
20 |
+
FlaxDDPMScheduler,
|
21 |
+
FlaxPNDMScheduler,
|
22 |
+
FlaxLMSDiscreteScheduler,
|
23 |
+
FlaxDPMSolverMultistepScheduler,
|
24 |
+
FlaxKarrasVeScheduler,
|
25 |
+
FlaxScoreSdeVeScheduler
|
26 |
+
)
|
27 |
+
|
28 |
+
from transformers import FlaxCLIPTextModel, CLIPTokenizer
|
29 |
+
|
30 |
+
from .flax_impl.flax_unet_pseudo3d_condition import UNetPseudo3DConditionModel
|
31 |
+
|
32 |
+
SchedulerType = Union[
|
33 |
+
FlaxDDIMScheduler,
|
34 |
+
FlaxDDPMScheduler,
|
35 |
+
FlaxPNDMScheduler,
|
36 |
+
FlaxLMSDiscreteScheduler,
|
37 |
+
FlaxDPMSolverMultistepScheduler,
|
38 |
+
FlaxKarrasVeScheduler,
|
39 |
+
FlaxScoreSdeVeScheduler
|
40 |
+
]
|
41 |
+
|
42 |
+
def dtypestr(x: jnp.dtype):
|
43 |
+
if x == jnp.float32: return 'float32'
|
44 |
+
elif x == jnp.float16: return 'float16'
|
45 |
+
elif x == jnp.bfloat16: return 'bfloat16'
|
46 |
+
else: raise
|
47 |
+
def castto(dtype, m, x):
|
48 |
+
if dtype == jnp.float32: return m.to_fp32(x)
|
49 |
+
elif dtype == jnp.float16: return m.to_fp16(x)
|
50 |
+
elif dtype == jnp.bfloat16: return m.to_bf16(x)
|
51 |
+
else: raise
|
52 |
+
|
53 |
+
class InferenceUNetPseudo3D:
|
54 |
+
def __init__(self,
|
55 |
+
model_path: str,
|
56 |
+
scheduler_cls: SchedulerType = FlaxDDIMScheduler,
|
57 |
+
dtype: jnp.dtype = jnp.float16,
|
58 |
+
hf_auth_token: Union[str, None] = None
|
59 |
+
) -> None:
|
60 |
+
self.dtype = dtype
|
61 |
+
self.model_path = model_path
|
62 |
+
self.hf_auth_token = hf_auth_token
|
63 |
+
|
64 |
+
self.params: Dict[str, FrozenDict[str, Any]] = {}
|
65 |
+
unet, unet_params = UNetPseudo3DConditionModel.from_pretrained(
|
66 |
+
self.model_path,
|
67 |
+
subfolder = 'unet',
|
68 |
+
from_pt = False,
|
69 |
+
sample_size = (64, 64),
|
70 |
+
dtype = self.dtype,
|
71 |
+
param_dtype = dtypestr(self.dtype),
|
72 |
+
use_memory_efficient_attention = True,
|
73 |
+
use_auth_token = self.hf_auth_token
|
74 |
+
)
|
75 |
+
self.unet: UNetPseudo3DConditionModel = unet
|
76 |
+
unet_params = castto(self.dtype, self.unet, unet_params)
|
77 |
+
self.params['unet'] = FrozenDict(unet_params)
|
78 |
+
del unet_params
|
79 |
+
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
|
80 |
+
self.model_path,
|
81 |
+
subfolder = 'vae',
|
82 |
+
from_pt = True,
|
83 |
+
dtype = self.dtype,
|
84 |
+
use_auth_token = self.hf_auth_token
|
85 |
+
)
|
86 |
+
self.vae: FlaxAutoencoderKL = vae
|
87 |
+
vae_params = castto(self.dtype, self.vae, vae_params)
|
88 |
+
self.params['vae'] = FrozenDict(vae_params)
|
89 |
+
del vae_params
|
90 |
+
text_encoder = FlaxCLIPTextModel.from_pretrained(
|
91 |
+
self.model_path,
|
92 |
+
subfolder = 'text_encoder',
|
93 |
+
from_pt = True,
|
94 |
+
dtype = self.dtype,
|
95 |
+
use_auth_token = self.hf_auth_token
|
96 |
+
)
|
97 |
+
text_encoder_params = text_encoder.params
|
98 |
+
del text_encoder._params
|
99 |
+
text_encoder_params = castto(self.dtype, text_encoder, text_encoder_params)
|
100 |
+
self.text_encoder: FlaxCLIPTextModel = text_encoder
|
101 |
+
self.params['text_encoder'] = FrozenDict(text_encoder_params)
|
102 |
+
del text_encoder_params
|
103 |
+
imunet, imunet_params = FlaxUNet2DConditionModel.from_pretrained(
|
104 |
+
'runwayml/stable-diffusion-v1-5',
|
105 |
+
subfolder = 'unet',
|
106 |
+
from_pt = True,
|
107 |
+
dtype = self.dtype,
|
108 |
+
use_memory_efficient_attention = True,
|
109 |
+
use_auth_token = self.hf_auth_token
|
110 |
+
)
|
111 |
+
imunet_params = castto(self.dtype, imunet, imunet_params)
|
112 |
+
self.imunet: FlaxUNet2DConditionModel = imunet
|
113 |
+
self.params['imunet'] = FrozenDict(imunet_params)
|
114 |
+
del imunet_params
|
115 |
+
self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
|
116 |
+
self.model_path,
|
117 |
+
subfolder = 'tokenizer',
|
118 |
+
use_auth_token = self.hf_auth_token
|
119 |
+
)
|
120 |
+
scheduler, scheduler_state = scheduler_cls.from_pretrained(
|
121 |
+
self.model_path,
|
122 |
+
subfolder = 'scheduler',
|
123 |
+
dtype = jnp.float32,
|
124 |
+
use_auth_token = self.hf_api_key
|
125 |
+
)
|
126 |
+
self.scheduler: scheduler_cls = scheduler
|
127 |
+
self.params['scheduler'] = scheduler_state
|
128 |
+
self.vae_scale_factor: int = int(2 ** (len(self.vae.config.block_out_channels) - 1))
|
129 |
+
self.device_count = jax.device_count()
|
130 |
+
gc.collect()
|
131 |
+
|
132 |
+
def set_scheduler(self, scheduler_cls: SchedulerType) -> None:
|
133 |
+
scheduler, scheduler_state = scheduler_cls.from_pretrained(
|
134 |
+
self.model_path,
|
135 |
+
subfolder = 'scheduler',
|
136 |
+
dtype = jnp.float32,
|
137 |
+
use_auth_token = self.hf_api_key
|
138 |
+
)
|
139 |
+
self.scheduler: scheduler_cls = scheduler
|
140 |
+
self.params['scheduler'] = scheduler_state
|
141 |
+
|
142 |
+
def prepare_inputs(self,
|
143 |
+
prompt: List[str],
|
144 |
+
neg_prompt: List[str],
|
145 |
+
hint_image: List[Image.Image],
|
146 |
+
mask_image: List[Image.Image],
|
147 |
+
width: int,
|
148 |
+
height: int
|
149 |
+
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: # prompt, neg_prompt, hint_image, mask_image
|
150 |
+
tokens = self.tokenizer(
|
151 |
+
prompt,
|
152 |
+
truncation = True,
|
153 |
+
return_overflowing_tokens = False,
|
154 |
+
max_length = 77, #self.text_encoder.config.max_length defaults to 20 if its not in the config smh
|
155 |
+
padding = 'max_length',
|
156 |
+
return_tensors = 'np'
|
157 |
+
).input_ids
|
158 |
+
tokens = jnp.array(tokens, dtype = jnp.int32)
|
159 |
+
neg_tokens = self.tokenizer(
|
160 |
+
neg_prompt,
|
161 |
+
truncation = True,
|
162 |
+
return_overflowing_tokens = False,
|
163 |
+
max_length = 77,
|
164 |
+
padding = 'max_length',
|
165 |
+
return_tensors = 'np'
|
166 |
+
).input_ids
|
167 |
+
neg_tokens = jnp.array(neg_tokens, dtype = jnp.int32)
|
168 |
+
for i,im in enumerate(hint_image):
|
169 |
+
if im.size != (width, height):
|
170 |
+
hint_image[i] = hint_image[i].resize((width, height), resample = Image.Resampling.LANCZOS)
|
171 |
+
for i,im in enumerate(mask_image):
|
172 |
+
if im.size != (width, height):
|
173 |
+
mask_image[i] = mask_image[i].resize((width, height), resample = Image.Resampling.LANCZOS)
|
174 |
+
# b,h,w,c | c == 3
|
175 |
+
hint = jnp.concatenate(
|
176 |
+
[ jnp.expand_dims(np.asarray(x.convert('RGB')), axis = 0) for x in hint_image ],
|
177 |
+
axis = 0
|
178 |
+
).astype(jnp.float32)
|
179 |
+
# scale -1,1
|
180 |
+
hint = (hint / 255) * 2 - 1
|
181 |
+
# b,h,w,c | c == 1
|
182 |
+
mask = jnp.concatenate(
|
183 |
+
[ jnp.expand_dims(np.asarray(x.convert('L')), axis = (0, -1)) for x in mask_image ],
|
184 |
+
axis = 0
|
185 |
+
).astype(jnp.float32)
|
186 |
+
# scale -1,1
|
187 |
+
mask = (mask / 255) * 2 - 1
|
188 |
+
# binarize mask
|
189 |
+
mask = mask.at[mask < 0.5].set(0)
|
190 |
+
mask = mask.at[mask >= 0.5].set(1)
|
191 |
+
# mask
|
192 |
+
hint = hint * (mask < 0.5)
|
193 |
+
# b,h,w,c -> b,c,h,w
|
194 |
+
hint = hint.transpose((0,3,1,2))
|
195 |
+
mask = mask.transpose((0,3,1,2))
|
196 |
+
return tokens, neg_tokens, hint, mask
|
197 |
+
|
198 |
+
def generate(self,
|
199 |
+
prompt: Union[str, List[str]],
|
200 |
+
inference_steps: int,
|
201 |
+
hint_image: Union[Image.Image, List[Image.Image], None] = None,
|
202 |
+
mask_image: Union[Image.Image, List[Image.Image], None] = None,
|
203 |
+
neg_prompt: Union[str, List[str]] = '',
|
204 |
+
cfg: float = 10.0,
|
205 |
+
num_frames: int = 24,
|
206 |
+
width: int = 512,
|
207 |
+
height: int = 512,
|
208 |
+
seed: int = 0
|
209 |
+
) -> List[List[Image.Image]]:
|
210 |
+
assert inference_steps > 0, f'number of inference steps must be > 0 but is {inference_steps}'
|
211 |
+
assert num_frames > 0, f'number of frames must be > 0 but is {num_frames}'
|
212 |
+
assert width % 32 == 0, f'width must be divisible by 32 but is {width}'
|
213 |
+
assert height % 32 == 0, f'height must be divisible by 32 but is {height}'
|
214 |
+
if isinstance(prompt, str):
|
215 |
+
prompt = [ prompt ]
|
216 |
+
batch_size = len(prompt)
|
217 |
+
assert batch_size % self.device_count == 0, f'batch size must be multiple of {self.device_count}'
|
218 |
+
if hint_image is None:
|
219 |
+
hint_image = Image.new('RGB', (width, height), color = (0,0,0))
|
220 |
+
use_imagegen = True
|
221 |
+
else:
|
222 |
+
use_imagegen = False
|
223 |
+
if isinstance(hint_image, Image.Image):
|
224 |
+
hint_image = [ hint_image ] * batch_size
|
225 |
+
assert len(hint_image) == batch_size, f'number of hint images must be equal to batch size {batch_size} but is {len(hint_image)}'
|
226 |
+
if mask_image is None:
|
227 |
+
mask_image = Image.new('L', hint_image[0].size, color = 0)
|
228 |
+
if isinstance(mask_image, Image.Image):
|
229 |
+
mask_image = [ mask_image ] * batch_size
|
230 |
+
assert len(mask_image) == batch_size, f'number of mask images must be equal to batch size {batch_size} but is {len(mask_image)}'
|
231 |
+
if isinstance(neg_prompt, str):
|
232 |
+
neg_prompt = [ neg_prompt ] * batch_size
|
233 |
+
assert len(neg_prompt) == batch_size, f'number of negative prompts must be equal to batch size {batch_size} but is {len(neg_prompt)}'
|
234 |
+
tokens, neg_tokens, hint, mask = self.prepare_inputs(
|
235 |
+
prompt = prompt,
|
236 |
+
neg_prompt = neg_prompt,
|
237 |
+
hint_image = hint_image,
|
238 |
+
mask_image = mask_image,
|
239 |
+
width = width,
|
240 |
+
height = height
|
241 |
+
)
|
242 |
+
# NOTE splitting rngs is not deterministic,
|
243 |
+
# running on different device counts gives different seeds
|
244 |
+
#rng = jax.random.PRNGKey(seed)
|
245 |
+
#rngs = jax.random.split(rng, self.device_count)
|
246 |
+
# manually assign seeded RNGs to devices for reproducability
|
247 |
+
rngs = jnp.array([ jax.random.PRNGKey(seed + i) for i in range(self.device_count) ])
|
248 |
+
params = jax_utils.replicate(self.params)
|
249 |
+
tokens = shard(tokens)
|
250 |
+
neg_tokens = shard(neg_tokens)
|
251 |
+
hint = shard(hint)
|
252 |
+
mask = shard(mask)
|
253 |
+
images = _p_generate(self,
|
254 |
+
tokens,
|
255 |
+
neg_tokens,
|
256 |
+
hint,
|
257 |
+
mask,
|
258 |
+
inference_steps,
|
259 |
+
num_frames,
|
260 |
+
height,
|
261 |
+
width,
|
262 |
+
cfg,
|
263 |
+
rngs,
|
264 |
+
params,
|
265 |
+
use_imagegen
|
266 |
+
)
|
267 |
+
if images.ndim == 5:
|
268 |
+
images = einops.rearrange(images, 'd f c h w -> (d f) h w c')
|
269 |
+
else:
|
270 |
+
images = einops.rearrange(images, 'f c h w -> f h w c')
|
271 |
+
# to cpu
|
272 |
+
images = np.array(images)
|
273 |
+
images = [ Image.fromarray(x) for x in images ]
|
274 |
+
return images
|
275 |
+
|
276 |
+
def _generate(self,
|
277 |
+
tokens: jnp.ndarray,
|
278 |
+
neg_tokens: jnp.ndarray,
|
279 |
+
hint: jnp.ndarray,
|
280 |
+
mask: jnp.ndarray,
|
281 |
+
inference_steps: int,
|
282 |
+
num_frames,
|
283 |
+
height,
|
284 |
+
width,
|
285 |
+
cfg: float,
|
286 |
+
rng: jax.random.KeyArray,
|
287 |
+
params: Union[Dict[str, Any], FrozenDict[str, Any]],
|
288 |
+
use_imagegen: bool
|
289 |
+
) -> List[Image.Image]:
|
290 |
+
batch_size = tokens.shape[0]
|
291 |
+
latent_h = height // self.vae_scale_factor
|
292 |
+
latent_w = width // self.vae_scale_factor
|
293 |
+
latent_shape = (
|
294 |
+
batch_size,
|
295 |
+
self.vae.config.latent_channels,
|
296 |
+
num_frames,
|
297 |
+
latent_h,
|
298 |
+
latent_w
|
299 |
+
)
|
300 |
+
encoded_prompt = self.text_encoder(tokens, params = params['text_encoder'])[0]
|
301 |
+
encoded_neg_prompt = self.text_encoder(neg_tokens, params = params['text_encoder'])[0]
|
302 |
+
|
303 |
+
if use_imagegen:
|
304 |
+
image_latent_shape = (batch_size, self.vae.config.latent_channels, latent_h, latent_w)
|
305 |
+
image_latents = jax.random.normal(
|
306 |
+
rng,
|
307 |
+
shape = image_latent_shape,
|
308 |
+
dtype = jnp.float32
|
309 |
+
) * params['scheduler'].init_noise_sigma
|
310 |
+
image_scheduler_state = self.scheduler.set_timesteps(
|
311 |
+
params['scheduler'],
|
312 |
+
num_inference_steps = inference_steps,
|
313 |
+
shape = image_latents.shape
|
314 |
+
)
|
315 |
+
def image_sample_loop(step, args):
|
316 |
+
image_latents, image_scheduler_state = args
|
317 |
+
t = image_scheduler_state.timesteps[step]
|
318 |
+
tt = jnp.broadcast_to(t, image_latents.shape[0])
|
319 |
+
latents_input = self.scheduler.scale_model_input(image_scheduler_state, image_latents, t)
|
320 |
+
noise_pred = self.imunet.apply(
|
321 |
+
{'params': params['imunet']},
|
322 |
+
latents_input,
|
323 |
+
tt,
|
324 |
+
encoder_hidden_states = encoded_prompt
|
325 |
+
).sample
|
326 |
+
noise_pred_uncond = self.imunet.apply(
|
327 |
+
{'params': params['imunet']},
|
328 |
+
latents_input,
|
329 |
+
tt,
|
330 |
+
encoder_hidden_states = encoded_neg_prompt
|
331 |
+
).sample
|
332 |
+
noise_pred = noise_pred_uncond + cfg * (noise_pred - noise_pred_uncond)
|
333 |
+
image_latents, image_scheduler_state = self.scheduler.step(
|
334 |
+
image_scheduler_state,
|
335 |
+
noise_pred.astype(jnp.float32),
|
336 |
+
t,
|
337 |
+
image_latents
|
338 |
+
).to_tuple()
|
339 |
+
return image_latents, image_scheduler_state
|
340 |
+
image_latents, _ = jax.lax.fori_loop(
|
341 |
+
0, inference_steps,
|
342 |
+
image_sample_loop,
|
343 |
+
(image_latents, image_scheduler_state)
|
344 |
+
)
|
345 |
+
hint = image_latents
|
346 |
+
else:
|
347 |
+
hint = self.vae.apply(
|
348 |
+
{'params': params['vae']},
|
349 |
+
hint,
|
350 |
+
method = self.vae.encode
|
351 |
+
).latent_dist.mean * self.vae.config.scaling_factor
|
352 |
+
# NOTE vae keeps channels last for encode, but rearranges to channels first for decode
|
353 |
+
# b0 h1 w2 c3 -> b0 c3 h1 w2
|
354 |
+
hint = hint.transpose((0, 3, 1, 2))
|
355 |
+
|
356 |
+
hint = jnp.expand_dims(hint, axis = 2).repeat(num_frames, axis = 2)
|
357 |
+
mask = jax.image.resize(mask, (*mask.shape[:-2], *hint.shape[-2:]), method = 'nearest')
|
358 |
+
mask = jnp.expand_dims(mask, axis = 2).repeat(num_frames, axis = 2)
|
359 |
+
# NOTE jax normal distribution is shit with float16 + bfloat16
|
360 |
+
# SEE https://github.com/google/jax/discussions/13798
|
361 |
+
# generate random at float32
|
362 |
+
latents = jax.random.normal(
|
363 |
+
rng,
|
364 |
+
shape = latent_shape,
|
365 |
+
dtype = jnp.float32
|
366 |
+
) * params['scheduler'].init_noise_sigma
|
367 |
+
scheduler_state = self.scheduler.set_timesteps(
|
368 |
+
params['scheduler'],
|
369 |
+
num_inference_steps = inference_steps,
|
370 |
+
shape = latents.shape
|
371 |
+
)
|
372 |
+
|
373 |
+
def sample_loop(step, args):
|
374 |
+
latents, scheduler_state = args
|
375 |
+
t = scheduler_state.timesteps[step]#jnp.array(scheduler_state.timesteps, dtype = jnp.int32)[step]
|
376 |
+
tt = jnp.broadcast_to(t, latents.shape[0])
|
377 |
+
latents_input = self.scheduler.scale_model_input(scheduler_state, latents, t)
|
378 |
+
latents_input = jnp.concatenate([latents_input, mask, hint], axis = 1)
|
379 |
+
noise_pred = self.unet.apply(
|
380 |
+
{ 'params': params['unet'] },
|
381 |
+
latents_input,
|
382 |
+
tt,
|
383 |
+
encoded_prompt
|
384 |
+
).sample
|
385 |
+
noise_pred_uncond = self.unet.apply(
|
386 |
+
{ 'params': params['unet'] },
|
387 |
+
latents_input,
|
388 |
+
tt,
|
389 |
+
encoded_neg_prompt
|
390 |
+
).sample
|
391 |
+
noise_pred = noise_pred_uncond + cfg * (noise_pred - noise_pred_uncond)
|
392 |
+
latents, scheduler_state = self.scheduler.step(
|
393 |
+
scheduler_state,
|
394 |
+
noise_pred.astype(jnp.float32),
|
395 |
+
t,
|
396 |
+
latents
|
397 |
+
).to_tuple()
|
398 |
+
return latents, scheduler_state
|
399 |
+
|
400 |
+
latents, _ = jax.lax.fori_loop(
|
401 |
+
0, inference_steps,
|
402 |
+
sample_loop,
|
403 |
+
(latents, scheduler_state)
|
404 |
+
)
|
405 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
406 |
+
latents = einops.rearrange(latents, 'b c f h w -> (b f) c h w')
|
407 |
+
num_images = len(latents)
|
408 |
+
images_out = jnp.zeros(
|
409 |
+
(
|
410 |
+
num_images,
|
411 |
+
self.vae.config.out_channels,
|
412 |
+
height,
|
413 |
+
width
|
414 |
+
),
|
415 |
+
dtype = self.dtype
|
416 |
+
)
|
417 |
+
def decode_loop(step, images_out):
|
418 |
+
# NOTE vae keeps channels last for encode, but rearranges to channels first for decode
|
419 |
+
im = self.vae.apply(
|
420 |
+
{ 'params': params['vae'] },
|
421 |
+
jnp.expand_dims(latents[step], axis = 0),
|
422 |
+
method = self.vae.decode
|
423 |
+
).sample
|
424 |
+
images_out = images_out.at[step].set(im[0])
|
425 |
+
return images_out
|
426 |
+
images_out = jax.lax.fori_loop(0, num_images, decode_loop, images_out)
|
427 |
+
images_out = ((images_out / 2 + 0.5) * 255).round().clip(0, 255).astype(jnp.uint8)
|
428 |
+
return images_out
|
429 |
+
|
430 |
+
|
431 |
+
@partial(
|
432 |
+
jax.pmap,
|
433 |
+
in_axes = ( # 0 -> split across batch dim, None -> duplicate
|
434 |
+
None, # 0 inference_class
|
435 |
+
0, # 1 tokens
|
436 |
+
0, # 2 neg_tokens
|
437 |
+
0, # 3 hint
|
438 |
+
0, # 4 mask
|
439 |
+
None, # 5 inference_steps
|
440 |
+
None, # 6 num_frames
|
441 |
+
None, # 7 height
|
442 |
+
None, # 8 width
|
443 |
+
None, # 9 cfg
|
444 |
+
0, # 10 rng
|
445 |
+
0, # 11 params
|
446 |
+
None, # 12 use_imagegen
|
447 |
+
),
|
448 |
+
static_broadcasted_argnums = ( # trigger recompilation on change
|
449 |
+
0, # inference_class
|
450 |
+
5, # inference_steps
|
451 |
+
6, # num_frames
|
452 |
+
7, # height
|
453 |
+
8, # width
|
454 |
+
12, # use_imagegen
|
455 |
+
)
|
456 |
+
)
|
457 |
+
def _p_generate(
|
458 |
+
inference_class: InferenceUNetPseudo3D,
|
459 |
+
tokens,
|
460 |
+
neg_tokens,
|
461 |
+
hint,
|
462 |
+
mask,
|
463 |
+
inference_steps,
|
464 |
+
num_frames,
|
465 |
+
height,
|
466 |
+
width,
|
467 |
+
cfg,
|
468 |
+
rng,
|
469 |
+
params,
|
470 |
+
use_imagegen
|
471 |
+
):
|
472 |
+
return inference_class._generate(
|
473 |
+
tokens,
|
474 |
+
neg_tokens,
|
475 |
+
hint,
|
476 |
+
mask,
|
477 |
+
inference_steps,
|
478 |
+
num_frames,
|
479 |
+
height,
|
480 |
+
width,
|
481 |
+
cfg,
|
482 |
+
rng,
|
483 |
+
params,
|
484 |
+
use_imagegen
|
485 |
+
)
|
486 |
+
|
makeavid_sd/makeavid_sd/torch_impl/__init__.py
ADDED
File without changes
|
makeavid_sd/makeavid_sd/torch_impl/torch_attention_pseudo3d.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
from diffusers.models.attention_processor import Attention as CrossAttention
|
10 |
+
#from torch_cross_attention import CrossAttention
|
11 |
+
|
12 |
+
|
13 |
+
class TransformerPseudo3DModelOutput:
|
14 |
+
def __init__(self, sample: torch.FloatTensor) -> None:
|
15 |
+
self.sample = sample
|
16 |
+
|
17 |
+
|
18 |
+
class TransformerPseudo3DModel(nn.Module):
|
19 |
+
def __init__(self,
|
20 |
+
num_attention_heads: int = 16,
|
21 |
+
attention_head_dim: int = 88,
|
22 |
+
in_channels: Optional[int] = None,
|
23 |
+
num_layers: int = 1,
|
24 |
+
dropout: float = 0.0,
|
25 |
+
norm_num_groups: int = 32,
|
26 |
+
cross_attention_dim: Optional[int] = None,
|
27 |
+
attention_bias: bool = False
|
28 |
+
) -> None:
|
29 |
+
super().__init__()
|
30 |
+
self.num_attention_heads = num_attention_heads
|
31 |
+
self.attention_head_dim = attention_head_dim
|
32 |
+
inner_dim = num_attention_heads * attention_head_dim
|
33 |
+
|
34 |
+
# 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
35 |
+
# Define whether input is continuous or discrete depending on configuration
|
36 |
+
# its continuous
|
37 |
+
|
38 |
+
# 2. Define input layers
|
39 |
+
self.in_channels = in_channels
|
40 |
+
|
41 |
+
self.norm = torch.nn.GroupNorm(
|
42 |
+
num_groups = norm_num_groups,
|
43 |
+
num_channels = in_channels,
|
44 |
+
eps = 1e-6,
|
45 |
+
affine = True
|
46 |
+
)
|
47 |
+
self.proj_in = nn.Conv2d(
|
48 |
+
in_channels,
|
49 |
+
inner_dim,
|
50 |
+
kernel_size = 1,
|
51 |
+
stride = 1,
|
52 |
+
padding = 0
|
53 |
+
)
|
54 |
+
|
55 |
+
# 3. Define transformers blocks
|
56 |
+
self.transformer_blocks = nn.ModuleList(
|
57 |
+
[
|
58 |
+
BasicTransformerBlock(
|
59 |
+
inner_dim,
|
60 |
+
num_attention_heads,
|
61 |
+
attention_head_dim,
|
62 |
+
dropout = dropout,
|
63 |
+
cross_attention_dim = cross_attention_dim,
|
64 |
+
attention_bias = attention_bias,
|
65 |
+
)
|
66 |
+
for _ in range(num_layers)
|
67 |
+
]
|
68 |
+
)
|
69 |
+
|
70 |
+
# 4. Define output layers
|
71 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size = 1, stride = 1, padding = 0)
|
72 |
+
|
73 |
+
def forward(self,
|
74 |
+
hidden_states: torch.Tensor,
|
75 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
76 |
+
timestep: torch.long = None
|
77 |
+
) -> TransformerPseudo3DModelOutput:
|
78 |
+
"""
|
79 |
+
Args:
|
80 |
+
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
81 |
+
When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
82 |
+
hidden_states
|
83 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
|
84 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
85 |
+
self-attention.
|
86 |
+
timestep ( `torch.long`, *optional*):
|
87 |
+
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
88 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
89 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
|
93 |
+
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
|
94 |
+
tensor.
|
95 |
+
"""
|
96 |
+
b, c, *_, h, w = hidden_states.shape
|
97 |
+
is_video = hidden_states.ndim == 5
|
98 |
+
f = None
|
99 |
+
if is_video:
|
100 |
+
b, c, f, h, w = hidden_states.shape
|
101 |
+
hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w')
|
102 |
+
#encoder_hidden_states = encoder_hidden_states.repeat_interleave(f, 0)
|
103 |
+
|
104 |
+
# 1. Input
|
105 |
+
batch, channel, height, weight = hidden_states.shape
|
106 |
+
residual = hidden_states
|
107 |
+
hidden_states = self.norm(hidden_states)
|
108 |
+
hidden_states = self.proj_in(hidden_states)
|
109 |
+
inner_dim = hidden_states.shape[1]
|
110 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
111 |
+
|
112 |
+
# 2. Blocks
|
113 |
+
for block in self.transformer_blocks:
|
114 |
+
hidden_states = block(
|
115 |
+
hidden_states,
|
116 |
+
context = encoder_hidden_states,
|
117 |
+
timestep = timestep,
|
118 |
+
frames_length = f,
|
119 |
+
height = height,
|
120 |
+
weight = weight
|
121 |
+
)
|
122 |
+
|
123 |
+
# 3. Output
|
124 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
|
125 |
+
hidden_states = self.proj_out(hidden_states)
|
126 |
+
output = hidden_states + residual
|
127 |
+
|
128 |
+
if is_video:
|
129 |
+
output = rearrange(output, '(b f) c h w -> b c f h w', b = b)
|
130 |
+
|
131 |
+
return TransformerPseudo3DModelOutput(sample = output)
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
class BasicTransformerBlock(nn.Module):
|
136 |
+
r"""
|
137 |
+
A basic Transformer block.
|
138 |
+
|
139 |
+
Parameters:
|
140 |
+
dim (`int`): The number of channels in the input and output.
|
141 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
142 |
+
attention_head_dim (`int`): The number of channels in each head.
|
143 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
144 |
+
cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
|
145 |
+
num_embeds_ada_norm (:
|
146 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
147 |
+
attention_bias (:
|
148 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
149 |
+
"""
|
150 |
+
|
151 |
+
def __init__(self,
|
152 |
+
dim: int,
|
153 |
+
num_attention_heads: int,
|
154 |
+
attention_head_dim: int,
|
155 |
+
dropout: float = 0.0,
|
156 |
+
cross_attention_dim: Optional[int] = None,
|
157 |
+
attention_bias: bool = False,
|
158 |
+
) -> None:
|
159 |
+
super().__init__()
|
160 |
+
self.attn1 = CrossAttention(
|
161 |
+
query_dim = dim,
|
162 |
+
heads = num_attention_heads,
|
163 |
+
dim_head = attention_head_dim,
|
164 |
+
dropout = dropout,
|
165 |
+
bias = attention_bias
|
166 |
+
) # is a self-attention
|
167 |
+
self.ff = FeedForward(dim, dropout = dropout)
|
168 |
+
self.attn2 = CrossAttention(
|
169 |
+
query_dim = dim,
|
170 |
+
cross_attention_dim = cross_attention_dim,
|
171 |
+
heads = num_attention_heads,
|
172 |
+
dim_head = attention_head_dim,
|
173 |
+
dropout = dropout,
|
174 |
+
bias = attention_bias
|
175 |
+
) # is self-attn if context is none
|
176 |
+
self.attn_temporal = CrossAttention(
|
177 |
+
query_dim = dim,
|
178 |
+
heads = num_attention_heads,
|
179 |
+
dim_head = attention_head_dim,
|
180 |
+
dropout = dropout,
|
181 |
+
bias = attention_bias
|
182 |
+
) # is a self-attention
|
183 |
+
|
184 |
+
# layer norms
|
185 |
+
self.norm1 = nn.LayerNorm(dim)
|
186 |
+
self.norm2 = nn.LayerNorm(dim)
|
187 |
+
self.norm_temporal = nn.LayerNorm(dim)
|
188 |
+
self.norm3 = nn.LayerNorm(dim)
|
189 |
+
|
190 |
+
def forward(self,
|
191 |
+
hidden_states: torch.Tensor,
|
192 |
+
context: Optional[torch.Tensor] = None,
|
193 |
+
timestep: torch.int64 = None,
|
194 |
+
frames_length: Optional[int] = None,
|
195 |
+
height: Optional[int] = None,
|
196 |
+
weight: Optional[int] = None
|
197 |
+
) -> torch.Tensor:
|
198 |
+
if context is not None and frames_length is not None:
|
199 |
+
context = context.repeat_interleave(frames_length, 0)
|
200 |
+
# 1. Self-Attention
|
201 |
+
norm_hidden_states = (
|
202 |
+
self.norm1(hidden_states)
|
203 |
+
)
|
204 |
+
hidden_states = self.attn1(norm_hidden_states) + hidden_states
|
205 |
+
|
206 |
+
# 2. Cross-Attention
|
207 |
+
norm_hidden_states = (
|
208 |
+
self.norm2(hidden_states)
|
209 |
+
)
|
210 |
+
hidden_states = self.attn2(
|
211 |
+
norm_hidden_states,
|
212 |
+
encoder_hidden_states = context
|
213 |
+
) + hidden_states
|
214 |
+
|
215 |
+
# append temporal attention
|
216 |
+
if frames_length is not None:
|
217 |
+
hidden_states = rearrange(
|
218 |
+
hidden_states,
|
219 |
+
'(b f) (h w) c -> (b h w) f c',
|
220 |
+
f = frames_length,
|
221 |
+
h = height,
|
222 |
+
w = weight
|
223 |
+
)
|
224 |
+
norm_hidden_states = (
|
225 |
+
self.norm_temporal(hidden_states)
|
226 |
+
)
|
227 |
+
hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states
|
228 |
+
hidden_states = rearrange(
|
229 |
+
hidden_states,
|
230 |
+
'(b h w) f c -> (b f) (h w) c',
|
231 |
+
f = frames_length,
|
232 |
+
h = height,
|
233 |
+
w = weight
|
234 |
+
)
|
235 |
+
|
236 |
+
# 3. Feed-forward
|
237 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
238 |
+
return hidden_states
|
239 |
+
|
240 |
+
|
241 |
+
class FeedForward(nn.Module):
|
242 |
+
r"""
|
243 |
+
A feed-forward layer.
|
244 |
+
|
245 |
+
Parameters:
|
246 |
+
dim (`int`): The number of channels in the input.
|
247 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
248 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
249 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
250 |
+
"""
|
251 |
+
|
252 |
+
def __init__(self,
|
253 |
+
dim: int,
|
254 |
+
dim_out: Optional[int] = None,
|
255 |
+
mult: int = 4,
|
256 |
+
dropout: float = 0.0
|
257 |
+
) -> None:
|
258 |
+
super().__init__()
|
259 |
+
inner_dim = int(dim * mult)
|
260 |
+
dim_out = dim_out if dim_out is not None else dim
|
261 |
+
|
262 |
+
geglu = GEGLU(dim, inner_dim)
|
263 |
+
|
264 |
+
self.net = nn.ModuleList([])
|
265 |
+
# project in
|
266 |
+
self.net.append(geglu)
|
267 |
+
# project dropout
|
268 |
+
self.net.append(nn.Dropout(dropout))
|
269 |
+
# project out
|
270 |
+
self.net.append(nn.Linear(inner_dim, dim_out))
|
271 |
+
|
272 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
273 |
+
for module in self.net:
|
274 |
+
hidden_states = module(hidden_states)
|
275 |
+
return hidden_states
|
276 |
+
|
277 |
+
|
278 |
+
# feedforward
|
279 |
+
class GEGLU(nn.Module):
|
280 |
+
r"""
|
281 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
282 |
+
|
283 |
+
Parameters:
|
284 |
+
dim_in (`int`): The number of channels in the input.
|
285 |
+
dim_out (`int`): The number of channels in the output.
|
286 |
+
"""
|
287 |
+
|
288 |
+
def __init__(self, dim_in: int, dim_out: int) -> None:
|
289 |
+
super().__init__()
|
290 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
291 |
+
|
292 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
293 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, dim = -1)
|
294 |
+
return hidden_states * F.gelu(gate)
|
makeavid_sd/makeavid_sd/torch_impl/torch_cross_attention.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
class CrossAttention(nn.Module):
|
6 |
+
r"""
|
7 |
+
A cross attention layer.
|
8 |
+
|
9 |
+
Parameters:
|
10 |
+
query_dim (`int`): The number of channels in the query.
|
11 |
+
cross_attention_dim (`int`, *optional*):
|
12 |
+
The number of channels in the context. If not given, defaults to `query_dim`.
|
13 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
14 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
15 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
16 |
+
bias (`bool`, *optional*, defaults to False):
|
17 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self,
|
21 |
+
query_dim: int,
|
22 |
+
cross_attention_dim: Optional[int] = None,
|
23 |
+
heads: int = 8,
|
24 |
+
dim_head: int = 64,
|
25 |
+
dropout: float = 0.0,
|
26 |
+
bias: bool = False
|
27 |
+
):
|
28 |
+
super().__init__()
|
29 |
+
inner_dim = dim_head * heads
|
30 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
31 |
+
|
32 |
+
self.scale = dim_head**-0.5
|
33 |
+
self.heads = heads
|
34 |
+
self.n_heads = heads
|
35 |
+
self.d_head = dim_head
|
36 |
+
|
37 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias = bias)
|
38 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias = bias)
|
39 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias = bias)
|
40 |
+
|
41 |
+
self.to_out = nn.ModuleList([])
|
42 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
43 |
+
self.to_out.append(nn.Dropout(dropout))
|
44 |
+
try:
|
45 |
+
# You can install flash attention by cloning their Github repo,
|
46 |
+
# [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
|
47 |
+
# and then running `python setup.py install`
|
48 |
+
from flash_attn.flash_attention import FlashAttention
|
49 |
+
self.flash = FlashAttention()
|
50 |
+
# Set the scale for scaled dot-product attention.
|
51 |
+
self.flash.softmax_scale = self.scale
|
52 |
+
# Set to `None` if it's not installed
|
53 |
+
except ImportError:
|
54 |
+
self.flash = None
|
55 |
+
|
56 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
57 |
+
batch_size, seq_len, dim = tensor.shape
|
58 |
+
head_size = self.heads
|
59 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
60 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
61 |
+
return tensor
|
62 |
+
|
63 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
64 |
+
batch_size, seq_len, dim = tensor.shape
|
65 |
+
head_size = self.heads
|
66 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
67 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
68 |
+
return tensor
|
69 |
+
|
70 |
+
def forward(self,
|
71 |
+
hidden_states: torch.Tensor,
|
72 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
73 |
+
mask: Optional[torch.Tensor] = None
|
74 |
+
) -> torch.Tensor:
|
75 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
76 |
+
is_self = encoder_hidden_states is None
|
77 |
+
# attention, what we cannot get enough of
|
78 |
+
query = self.to_q(hidden_states)
|
79 |
+
has_cond = encoder_hidden_states is not None
|
80 |
+
|
81 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
82 |
+
key = self.to_k(encoder_hidden_states)
|
83 |
+
value = self.to_v(encoder_hidden_states)
|
84 |
+
|
85 |
+
dim = query.shape[-1]
|
86 |
+
|
87 |
+
if self.flash is not None and not has_cond and self.d_head <= 64:
|
88 |
+
hidden_states = self.flash_attention(query, key, value)
|
89 |
+
else:
|
90 |
+
hidden_states = self.normal_attention(query, key, value, is_self)
|
91 |
+
|
92 |
+
# linear proj
|
93 |
+
hidden_states = self.to_out[0](hidden_states)
|
94 |
+
# dropout
|
95 |
+
hidden_states = self.to_out[1](hidden_states)
|
96 |
+
return hidden_states
|
97 |
+
|
98 |
+
def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
99 |
+
"""
|
100 |
+
#### Flash Attention
|
101 |
+
:param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
102 |
+
:param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
103 |
+
:param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
104 |
+
"""
|
105 |
+
|
106 |
+
# Get batch size and number of elements along sequence axis (`width * height`)
|
107 |
+
batch_size, seq_len, _ = q.shape
|
108 |
+
|
109 |
+
# Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
|
110 |
+
# shape `[batch_size, seq_len, 3, n_heads * d_head]`
|
111 |
+
qkv = torch.stack((q, k, v), dim = 2)
|
112 |
+
# Split the heads
|
113 |
+
qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
|
114 |
+
|
115 |
+
# Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
|
116 |
+
# fit this size.
|
117 |
+
if self.d_head <= 32:
|
118 |
+
pad = 32 - self.d_head
|
119 |
+
elif self.d_head <= 64:
|
120 |
+
pad = 64 - self.d_head
|
121 |
+
elif self.d_head <= 128:
|
122 |
+
pad = 128 - self.d_head
|
123 |
+
else:
|
124 |
+
raise ValueError(f'Head size ${self.d_head} too large for Flash Attention')
|
125 |
+
|
126 |
+
# Pad the heads
|
127 |
+
if pad:
|
128 |
+
qkv = torch.cat((qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim = -1)
|
129 |
+
|
130 |
+
# Compute attention
|
131 |
+
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
|
132 |
+
# This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
|
133 |
+
out, _ = self.flash(qkv)
|
134 |
+
# Truncate the extra head size
|
135 |
+
out = out[:, :, :, :self.d_head]
|
136 |
+
# Reshape to `[batch_size, seq_len, n_heads * d_head]`
|
137 |
+
out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
|
138 |
+
|
139 |
+
# Map to `[batch_size, height * width, d_model]` with a linear layer
|
140 |
+
return out
|
141 |
+
|
142 |
+
def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, is_self: bool):
|
143 |
+
"""
|
144 |
+
#### Normal Attention
|
145 |
+
|
146 |
+
:param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
147 |
+
:param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
148 |
+
:param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
149 |
+
"""
|
150 |
+
# Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
|
151 |
+
q = q.view(*q.shape[:2], self.n_heads, -1)
|
152 |
+
k = k.view(*k.shape[:2], self.n_heads, -1)
|
153 |
+
v = v.view(*v.shape[:2], self.n_heads, -1)
|
154 |
+
|
155 |
+
# Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
|
156 |
+
attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
|
157 |
+
# Compute softmax
|
158 |
+
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
|
159 |
+
half = attn.shape[0] // 2
|
160 |
+
attn[half:] = attn[half:].softmax(dim = -1)
|
161 |
+
attn[:half] = attn[:half].softmax(dim = -1)
|
162 |
+
|
163 |
+
# Compute attention output
|
164 |
+
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
|
165 |
+
out = torch.einsum('bhij,bjhd->bihd', attn, v)
|
166 |
+
|
167 |
+
# Reshape to `[batch_size, height * width, n_heads * d_head]`
|
168 |
+
out = out.reshape(*out.shape[:2], -1)
|
169 |
+
|
170 |
+
# Map to `[batch_size, height * width, d_model]` with a linear layer
|
171 |
+
return out
|
makeavid_sd/makeavid_sd/torch_impl/torch_embeddings.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
def get_timestep_embedding(
|
6 |
+
timesteps: torch.Tensor,
|
7 |
+
embedding_dim: int,
|
8 |
+
flip_sin_to_cos: bool = False,
|
9 |
+
downscale_freq_shift: float = 1,
|
10 |
+
scale: float = 1,
|
11 |
+
max_period: int = 10000,
|
12 |
+
) -> torch.Tensor:
|
13 |
+
"""
|
14 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
15 |
+
|
16 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
17 |
+
These may be fractional.
|
18 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
19 |
+
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
20 |
+
"""
|
21 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
22 |
+
|
23 |
+
half_dim = embedding_dim // 2
|
24 |
+
exponent = -math.log(max_period) * torch.arange(
|
25 |
+
start = 0,
|
26 |
+
end = half_dim,
|
27 |
+
dtype = torch.float32,
|
28 |
+
device = timesteps.device
|
29 |
+
)
|
30 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
31 |
+
|
32 |
+
emb = torch.exp(exponent)
|
33 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
34 |
+
|
35 |
+
# scale embeddings
|
36 |
+
emb = scale * emb
|
37 |
+
|
38 |
+
# concat sine and cosine embeddings
|
39 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim = -1)
|
40 |
+
|
41 |
+
# flip sine and cosine embeddings
|
42 |
+
if flip_sin_to_cos:
|
43 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim = -1)
|
44 |
+
|
45 |
+
# zero pad
|
46 |
+
if embedding_dim % 2 == 1:
|
47 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
48 |
+
return emb
|
49 |
+
|
50 |
+
|
51 |
+
class TimestepEmbedding(nn.Module):
|
52 |
+
def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
|
53 |
+
super().__init__()
|
54 |
+
|
55 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
56 |
+
self.act = None
|
57 |
+
if act_fn == "silu":
|
58 |
+
self.act = nn.SiLU()
|
59 |
+
elif act_fn == "mish":
|
60 |
+
self.act = nn.Mish()
|
61 |
+
|
62 |
+
if out_dim is not None:
|
63 |
+
time_embed_dim_out = out_dim
|
64 |
+
else:
|
65 |
+
time_embed_dim_out = time_embed_dim
|
66 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
67 |
+
|
68 |
+
def forward(self, sample):
|
69 |
+
sample = self.linear_1(sample)
|
70 |
+
|
71 |
+
if self.act is not None:
|
72 |
+
sample = self.act(sample)
|
73 |
+
|
74 |
+
sample = self.linear_2(sample)
|
75 |
+
return sample
|
76 |
+
|
77 |
+
|
78 |
+
class Timesteps(nn.Module):
|
79 |
+
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
80 |
+
super().__init__()
|
81 |
+
self.num_channels = num_channels
|
82 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
83 |
+
self.downscale_freq_shift = downscale_freq_shift
|
84 |
+
|
85 |
+
def forward(self, timesteps):
|
86 |
+
t_emb = get_timestep_embedding(
|
87 |
+
timesteps,
|
88 |
+
self.num_channels,
|
89 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
90 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
91 |
+
)
|
92 |
+
return t_emb
|
makeavid_sd/makeavid_sd/torch_impl/torch_resnet_pseudo3d.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
class Pseudo3DConv(nn.Module):
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
dim,
|
10 |
+
dim_out,
|
11 |
+
kernel_size,
|
12 |
+
**kwargs
|
13 |
+
):
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size, **kwargs)
|
17 |
+
self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size, padding=kernel_size // 2)
|
18 |
+
self.temporal_conv = nn.Conv1d(dim_out, dim_out, 3, padding=1)
|
19 |
+
|
20 |
+
nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
|
21 |
+
nn.init.zeros_(self.temporal_conv.bias.data)
|
22 |
+
|
23 |
+
def forward(
|
24 |
+
self,
|
25 |
+
x,
|
26 |
+
convolve_across_time = True
|
27 |
+
):
|
28 |
+
b, c, *_, h, w = x.shape
|
29 |
+
|
30 |
+
is_video = x.ndim == 5
|
31 |
+
convolve_across_time &= is_video
|
32 |
+
|
33 |
+
if is_video:
|
34 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
35 |
+
|
36 |
+
#with torch.no_grad():
|
37 |
+
# x = self.spatial_conv(x)
|
38 |
+
x = self.spatial_conv(x)
|
39 |
+
|
40 |
+
if is_video:
|
41 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b = b)
|
42 |
+
b, c, *_, h, w = x.shape
|
43 |
+
|
44 |
+
if not convolve_across_time:
|
45 |
+
return x
|
46 |
+
|
47 |
+
if is_video:
|
48 |
+
x = rearrange(x, 'b c f h w -> (b h w) c f')
|
49 |
+
x = self.temporal_conv(x)
|
50 |
+
x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)
|
51 |
+
return x
|
52 |
+
|
53 |
+
class Upsample2D(nn.Module):
|
54 |
+
"""
|
55 |
+
An upsampling layer with an optional convolution.
|
56 |
+
|
57 |
+
Parameters:
|
58 |
+
channels: channels in the inputs and outputs.
|
59 |
+
use_conv: a bool determining if a convolution is applied.
|
60 |
+
use_conv_transpose:
|
61 |
+
out_channels:
|
62 |
+
"""
|
63 |
+
|
64 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
65 |
+
super().__init__()
|
66 |
+
self.channels = channels
|
67 |
+
self.out_channels = out_channels or channels
|
68 |
+
self.use_conv = use_conv
|
69 |
+
self.use_conv_transpose = use_conv_transpose
|
70 |
+
self.name = name
|
71 |
+
|
72 |
+
conv = None
|
73 |
+
if use_conv_transpose:
|
74 |
+
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
|
75 |
+
elif use_conv:
|
76 |
+
conv = Pseudo3DConv(self.channels, self.out_channels, 3, padding=1)
|
77 |
+
|
78 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
79 |
+
if name == "conv":
|
80 |
+
self.conv = conv
|
81 |
+
else:
|
82 |
+
self.Conv2d_0 = conv
|
83 |
+
|
84 |
+
def forward(self, hidden_states, output_size=None):
|
85 |
+
assert hidden_states.shape[1] == self.channels
|
86 |
+
|
87 |
+
if self.use_conv_transpose:
|
88 |
+
return self.conv(hidden_states)
|
89 |
+
|
90 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
91 |
+
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
92 |
+
# https://github.com/pytorch/pytorch/issues/86679
|
93 |
+
dtype = hidden_states.dtype
|
94 |
+
if dtype == torch.bfloat16:
|
95 |
+
hidden_states = hidden_states.to(torch.float32)
|
96 |
+
|
97 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
98 |
+
if hidden_states.shape[0] >= 64:
|
99 |
+
hidden_states = hidden_states.contiguous()
|
100 |
+
|
101 |
+
b, c, *_, h, w = hidden_states.shape
|
102 |
+
|
103 |
+
is_video = hidden_states.ndim == 5
|
104 |
+
|
105 |
+
if is_video:
|
106 |
+
hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w')
|
107 |
+
|
108 |
+
# if `output_size` is passed we force the interpolation output
|
109 |
+
# size and do not make use of `scale_factor=2`
|
110 |
+
if output_size is None:
|
111 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
112 |
+
else:
|
113 |
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
114 |
+
|
115 |
+
if is_video:
|
116 |
+
hidden_states = rearrange(hidden_states, '(b f) c h w -> b c f h w', b = b)
|
117 |
+
|
118 |
+
# If the input is bfloat16, we cast back to bfloat16
|
119 |
+
if dtype == torch.bfloat16:
|
120 |
+
hidden_states = hidden_states.to(dtype)
|
121 |
+
|
122 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
123 |
+
if self.use_conv:
|
124 |
+
if self.name == "conv":
|
125 |
+
hidden_states = self.conv(hidden_states)
|
126 |
+
else:
|
127 |
+
hidden_states = self.Conv2d_0(hidden_states)
|
128 |
+
|
129 |
+
return hidden_states
|
130 |
+
|
131 |
+
|
132 |
+
class Downsample2D(nn.Module):
|
133 |
+
"""
|
134 |
+
A downsampling layer with an optional convolution.
|
135 |
+
|
136 |
+
Parameters:
|
137 |
+
channels: channels in the inputs and outputs.
|
138 |
+
use_conv: a bool determining if a convolution is applied.
|
139 |
+
out_channels:
|
140 |
+
padding:
|
141 |
+
"""
|
142 |
+
|
143 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
144 |
+
super().__init__()
|
145 |
+
self.channels = channels
|
146 |
+
self.out_channels = out_channels or channels
|
147 |
+
self.use_conv = use_conv
|
148 |
+
self.padding = padding
|
149 |
+
stride = 2
|
150 |
+
self.name = name
|
151 |
+
|
152 |
+
if use_conv:
|
153 |
+
conv = Pseudo3DConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
154 |
+
else:
|
155 |
+
assert self.channels == self.out_channels
|
156 |
+
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
157 |
+
|
158 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
159 |
+
if name == "conv":
|
160 |
+
self.Conv2d_0 = conv
|
161 |
+
self.conv = conv
|
162 |
+
elif name == "Conv2d_0":
|
163 |
+
self.conv = conv
|
164 |
+
else:
|
165 |
+
self.conv = conv
|
166 |
+
|
167 |
+
def forward(self, hidden_states):
|
168 |
+
assert hidden_states.shape[1] == self.channels
|
169 |
+
if self.use_conv and self.padding == 0:
|
170 |
+
pad = (0, 1, 0, 1)
|
171 |
+
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
172 |
+
|
173 |
+
assert hidden_states.shape[1] == self.channels
|
174 |
+
if self.use_conv:
|
175 |
+
hidden_states = self.conv(hidden_states)
|
176 |
+
else:
|
177 |
+
b, c, *_, h, w = hidden_states.shape
|
178 |
+
is_video = hidden_states.ndim == 5
|
179 |
+
if is_video:
|
180 |
+
hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w')
|
181 |
+
hidden_states = self.conv(hidden_states)
|
182 |
+
if is_video:
|
183 |
+
hidden_states = rearrange(hidden_states, '(b f) c h w -> b c f h w', b = b)
|
184 |
+
|
185 |
+
return hidden_states
|
186 |
+
|
187 |
+
|
188 |
+
class ResnetBlockPseudo3D(nn.Module):
|
189 |
+
def __init__(
|
190 |
+
self,
|
191 |
+
*,
|
192 |
+
in_channels,
|
193 |
+
out_channels=None,
|
194 |
+
conv_shortcut=False,
|
195 |
+
dropout=0.0,
|
196 |
+
temb_channels=512,
|
197 |
+
groups=32,
|
198 |
+
groups_out=None,
|
199 |
+
pre_norm=True,
|
200 |
+
eps=1e-6,
|
201 |
+
time_embedding_norm="default",
|
202 |
+
kernel=None,
|
203 |
+
output_scale_factor=1.0,
|
204 |
+
use_in_shortcut=None,
|
205 |
+
up=False,
|
206 |
+
down=False,
|
207 |
+
):
|
208 |
+
super().__init__()
|
209 |
+
self.pre_norm = pre_norm
|
210 |
+
self.pre_norm = True
|
211 |
+
self.in_channels = in_channels
|
212 |
+
out_channels = in_channels if out_channels is None else out_channels
|
213 |
+
self.out_channels = out_channels
|
214 |
+
self.use_conv_shortcut = conv_shortcut
|
215 |
+
self.time_embedding_norm = time_embedding_norm
|
216 |
+
self.up = up
|
217 |
+
self.down = down
|
218 |
+
self.output_scale_factor = output_scale_factor
|
219 |
+
print('OUTPUT_SCALE_FACTOR:', output_scale_factor)
|
220 |
+
|
221 |
+
if groups_out is None:
|
222 |
+
groups_out = groups
|
223 |
+
|
224 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
225 |
+
|
226 |
+
self.conv1 = Pseudo3DConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
227 |
+
|
228 |
+
if temb_channels is not None:
|
229 |
+
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
|
230 |
+
else:
|
231 |
+
self.time_emb_proj = None
|
232 |
+
|
233 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
234 |
+
self.dropout = torch.nn.Dropout(dropout)
|
235 |
+
self.conv2 = Pseudo3DConv(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
236 |
+
|
237 |
+
self.nonlinearity = nn.SiLU()
|
238 |
+
|
239 |
+
self.upsample = self.downsample = None
|
240 |
+
if self.up:
|
241 |
+
self.upsample = Upsample2D(in_channels, use_conv=False)
|
242 |
+
elif self.down:
|
243 |
+
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
|
244 |
+
|
245 |
+
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
|
246 |
+
|
247 |
+
self.conv_shortcut = None
|
248 |
+
if self.use_in_shortcut:
|
249 |
+
self.conv_shortcut = Pseudo3DConv(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
250 |
+
|
251 |
+
def forward(self, input_tensor, temb):
|
252 |
+
hidden_states = input_tensor
|
253 |
+
|
254 |
+
hidden_states = self.norm1(hidden_states)
|
255 |
+
hidden_states = self.nonlinearity(hidden_states)
|
256 |
+
|
257 |
+
if self.upsample is not None:
|
258 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
259 |
+
if hidden_states.shape[0] >= 64:
|
260 |
+
input_tensor = input_tensor.contiguous()
|
261 |
+
hidden_states = hidden_states.contiguous()
|
262 |
+
input_tensor = self.upsample(input_tensor)
|
263 |
+
hidden_states = self.upsample(hidden_states)
|
264 |
+
elif self.downsample is not None:
|
265 |
+
input_tensor = self.downsample(input_tensor)
|
266 |
+
hidden_states = self.downsample(hidden_states)
|
267 |
+
|
268 |
+
hidden_states = self.conv1(hidden_states)
|
269 |
+
|
270 |
+
if temb is not None:
|
271 |
+
b, c, *_, h, w = hidden_states.shape
|
272 |
+
is_video = hidden_states.ndim == 5
|
273 |
+
if is_video:
|
274 |
+
b, c, f, h, w = hidden_states.shape
|
275 |
+
hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w')
|
276 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
277 |
+
hidden_states = hidden_states + temb.repeat_interleave(f, 0)
|
278 |
+
hidden_states = rearrange(hidden_states, '(b f) c h w -> b c f h w', b=b)
|
279 |
+
else:
|
280 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
281 |
+
hidden_states = hidden_states + temb
|
282 |
+
|
283 |
+
hidden_states = self.norm2(hidden_states)
|
284 |
+
hidden_states = self.nonlinearity(hidden_states)
|
285 |
+
|
286 |
+
hidden_states = self.dropout(hidden_states)
|
287 |
+
hidden_states = self.conv2(hidden_states)
|
288 |
+
|
289 |
+
if self.conv_shortcut is not None:
|
290 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
291 |
+
|
292 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
293 |
+
|
294 |
+
return output_tensor
|
295 |
+
|
makeavid_sd/makeavid_sd/torch_impl/torch_unet_pseudo3d_blocks.py
ADDED
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union, Optional
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
from torch_attention_pseudo3d import TransformerPseudo3DModel
|
6 |
+
from torch_resnet_pseudo3d import Downsample2D, ResnetBlockPseudo3D, Upsample2D
|
7 |
+
|
8 |
+
|
9 |
+
class UNetMidBlock2DCrossAttn(nn.Module):
|
10 |
+
def __init__(self,
|
11 |
+
in_channels: int,
|
12 |
+
temb_channels: int,
|
13 |
+
dropout: float = 0.0,
|
14 |
+
num_layers: int = 1,
|
15 |
+
resnet_eps: float = 1e-6,
|
16 |
+
resnet_time_scale_shift: str = "default",
|
17 |
+
resnet_act_fn: str = "swish",
|
18 |
+
resnet_groups: Optional[int] = 32,
|
19 |
+
resnet_pre_norm: bool = True,
|
20 |
+
attn_num_head_channels: int = 1,
|
21 |
+
attention_type: str = "default",
|
22 |
+
output_scale_factor: float =1.0,
|
23 |
+
cross_attention_dim: int = 1280,
|
24 |
+
**kwargs
|
25 |
+
) -> None:
|
26 |
+
super().__init__()
|
27 |
+
|
28 |
+
self.attention_type = attention_type
|
29 |
+
self.attn_num_head_channels = attn_num_head_channels
|
30 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
31 |
+
|
32 |
+
# there is always at least one resnet
|
33 |
+
resnets = [
|
34 |
+
ResnetBlockPseudo3D(
|
35 |
+
in_channels = in_channels,
|
36 |
+
out_channels = in_channels,
|
37 |
+
temb_channels = temb_channels,
|
38 |
+
eps = resnet_eps,
|
39 |
+
groups = resnet_groups,
|
40 |
+
dropout = dropout,
|
41 |
+
time_embedding_norm = resnet_time_scale_shift,
|
42 |
+
#non_linearity = resnet_act_fn,
|
43 |
+
output_scale_factor = output_scale_factor,
|
44 |
+
pre_norm = resnet_pre_norm
|
45 |
+
)
|
46 |
+
]
|
47 |
+
attentions = []
|
48 |
+
|
49 |
+
for _ in range(num_layers):
|
50 |
+
attentions.append(
|
51 |
+
TransformerPseudo3DModel(
|
52 |
+
in_channels = in_channels,
|
53 |
+
num_attention_heads = attn_num_head_channels,
|
54 |
+
attention_head_dim = in_channels // attn_num_head_channels,
|
55 |
+
num_layers = 1,
|
56 |
+
cross_attention_dim = cross_attention_dim,
|
57 |
+
norm_num_groups = resnet_groups
|
58 |
+
)
|
59 |
+
)
|
60 |
+
resnets.append(
|
61 |
+
ResnetBlockPseudo3D(
|
62 |
+
in_channels = in_channels,
|
63 |
+
out_channels = in_channels,
|
64 |
+
temb_channels = temb_channels,
|
65 |
+
eps = resnet_eps,
|
66 |
+
groups = resnet_groups,
|
67 |
+
dropout = dropout,
|
68 |
+
time_embedding_norm = resnet_time_scale_shift,
|
69 |
+
#non_linearity = resnet_act_fn,
|
70 |
+
output_scale_factor = output_scale_factor,
|
71 |
+
pre_norm = resnet_pre_norm
|
72 |
+
)
|
73 |
+
)
|
74 |
+
|
75 |
+
self.attentions = nn.ModuleList(attentions)
|
76 |
+
self.resnets = nn.ModuleList(resnets)
|
77 |
+
|
78 |
+
def forward(self, hidden_states, temb = None, encoder_hidden_states = None):
|
79 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
80 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
81 |
+
hidden_states = attn(hidden_states, encoder_hidden_states).sample
|
82 |
+
hidden_states = resnet(hidden_states, temb)
|
83 |
+
|
84 |
+
return hidden_states
|
85 |
+
|
86 |
+
|
87 |
+
class CrossAttnDownBlock2D(nn.Module):
|
88 |
+
def __init__(self,
|
89 |
+
in_channels: int,
|
90 |
+
out_channels: int,
|
91 |
+
temb_channels: int,
|
92 |
+
dropout: float = 0.0,
|
93 |
+
num_layers: int = 1,
|
94 |
+
resnet_eps: float = 1e-6,
|
95 |
+
resnet_time_scale_shift: str = "default",
|
96 |
+
resnet_act_fn: str = "swish",
|
97 |
+
resnet_groups: int = 32,
|
98 |
+
resnet_pre_norm: bool = True,
|
99 |
+
attn_num_head_channels: int = 1,
|
100 |
+
cross_attention_dim: int = 1280,
|
101 |
+
attention_type: str = "default",
|
102 |
+
output_scale_factor: float = 1.0,
|
103 |
+
downsample_padding: int = 1,
|
104 |
+
add_downsample: bool = True
|
105 |
+
):
|
106 |
+
super().__init__()
|
107 |
+
resnets = []
|
108 |
+
attentions = []
|
109 |
+
|
110 |
+
self.attention_type = attention_type
|
111 |
+
self.attn_num_head_channels = attn_num_head_channels
|
112 |
+
|
113 |
+
for i in range(num_layers):
|
114 |
+
in_channels = in_channels if i == 0 else out_channels
|
115 |
+
resnets.append(
|
116 |
+
ResnetBlockPseudo3D(
|
117 |
+
in_channels = in_channels,
|
118 |
+
out_channels = out_channels,
|
119 |
+
temb_channels = temb_channels,
|
120 |
+
eps = resnet_eps,
|
121 |
+
groups = resnet_groups,
|
122 |
+
dropout = dropout,
|
123 |
+
time_embedding_norm = resnet_time_scale_shift,
|
124 |
+
#non_linearity = resnet_act_fn,
|
125 |
+
output_scale_factor = output_scale_factor,
|
126 |
+
pre_norm = resnet_pre_norm
|
127 |
+
)
|
128 |
+
)
|
129 |
+
attentions.append(
|
130 |
+
TransformerPseudo3DModel(
|
131 |
+
in_channels = out_channels,
|
132 |
+
num_attention_heads = attn_num_head_channels,
|
133 |
+
attention_head_dim = out_channels // attn_num_head_channels,
|
134 |
+
num_layers = 1,
|
135 |
+
cross_attention_dim = cross_attention_dim,
|
136 |
+
norm_num_groups = resnet_groups
|
137 |
+
)
|
138 |
+
)
|
139 |
+
self.attentions = nn.ModuleList(attentions)
|
140 |
+
self.resnets = nn.ModuleList(resnets)
|
141 |
+
|
142 |
+
if add_downsample:
|
143 |
+
self.downsamplers = nn.ModuleList(
|
144 |
+
[
|
145 |
+
Downsample2D(
|
146 |
+
out_channels,
|
147 |
+
use_conv = True,
|
148 |
+
out_channels = out_channels,
|
149 |
+
padding = downsample_padding,
|
150 |
+
name = "op"
|
151 |
+
)
|
152 |
+
]
|
153 |
+
)
|
154 |
+
else:
|
155 |
+
self.downsamplers = None
|
156 |
+
|
157 |
+
def forward(self, hidden_states, temb = None, encoder_hidden_states = None):
|
158 |
+
output_states = ()
|
159 |
+
|
160 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
161 |
+
hidden_states = resnet(hidden_states, temb)
|
162 |
+
hidden_states = attn(hidden_states, encoder_hidden_states = encoder_hidden_states).sample
|
163 |
+
|
164 |
+
output_states += (hidden_states,)
|
165 |
+
|
166 |
+
if self.downsamplers is not None:
|
167 |
+
for downsampler in self.downsamplers:
|
168 |
+
hidden_states = downsampler(hidden_states)
|
169 |
+
|
170 |
+
output_states += (hidden_states,)
|
171 |
+
|
172 |
+
return hidden_states, output_states
|
173 |
+
|
174 |
+
|
175 |
+
class DownBlock2D(nn.Module):
|
176 |
+
def __init__(self,
|
177 |
+
in_channels: int,
|
178 |
+
out_channels: int,
|
179 |
+
temb_channels: int,
|
180 |
+
dropout: float = 0.0,
|
181 |
+
num_layers: int = 1,
|
182 |
+
resnet_eps: float = 1e-6,
|
183 |
+
resnet_time_scale_shift: str = "default",
|
184 |
+
resnet_act_fn: str = "swish",
|
185 |
+
resnet_groups: int = 32,
|
186 |
+
resnet_pre_norm: bool = True,
|
187 |
+
output_scale_factor: float = 1.0,
|
188 |
+
add_downsample: bool = True,
|
189 |
+
downsample_padding: int = 1
|
190 |
+
) -> None:
|
191 |
+
super().__init__()
|
192 |
+
resnets = []
|
193 |
+
|
194 |
+
for i in range(num_layers):
|
195 |
+
in_channels = in_channels if i == 0 else out_channels
|
196 |
+
resnets.append(
|
197 |
+
ResnetBlockPseudo3D(
|
198 |
+
in_channels = in_channels,
|
199 |
+
out_channels = out_channels,
|
200 |
+
temb_channels = temb_channels,
|
201 |
+
eps = resnet_eps,
|
202 |
+
groups = resnet_groups,
|
203 |
+
dropout = dropout,
|
204 |
+
time_embedding_norm = resnet_time_scale_shift,
|
205 |
+
#non_linearity = resnet_act_fn,
|
206 |
+
output_scale_factor = output_scale_factor,
|
207 |
+
pre_norm = resnet_pre_norm
|
208 |
+
)
|
209 |
+
)
|
210 |
+
|
211 |
+
self.resnets = nn.ModuleList(resnets)
|
212 |
+
|
213 |
+
if add_downsample:
|
214 |
+
self.downsamplers = nn.ModuleList(
|
215 |
+
[
|
216 |
+
Downsample2D(
|
217 |
+
out_channels,
|
218 |
+
use_conv = True,
|
219 |
+
out_channels = out_channels,
|
220 |
+
padding = downsample_padding,
|
221 |
+
name = "op"
|
222 |
+
)
|
223 |
+
]
|
224 |
+
)
|
225 |
+
else:
|
226 |
+
self.downsamplers = None
|
227 |
+
|
228 |
+
|
229 |
+
def forward(self, hidden_states, temb = None):
|
230 |
+
output_states = ()
|
231 |
+
|
232 |
+
for resnet in self.resnets:
|
233 |
+
hidden_states = resnet(hidden_states, temb)
|
234 |
+
|
235 |
+
output_states += (hidden_states,)
|
236 |
+
|
237 |
+
if self.downsamplers is not None:
|
238 |
+
for downsampler in self.downsamplers:
|
239 |
+
hidden_states = downsampler(hidden_states)
|
240 |
+
|
241 |
+
output_states += (hidden_states,)
|
242 |
+
|
243 |
+
return hidden_states, output_states
|
244 |
+
|
245 |
+
|
246 |
+
class CrossAttnUpBlock2D(nn.Module):
|
247 |
+
def __init__(self,
|
248 |
+
in_channels: int,
|
249 |
+
out_channels: int,
|
250 |
+
prev_output_channel: int,
|
251 |
+
temb_channels: int,
|
252 |
+
dropout: float = 0.0,
|
253 |
+
num_layers: int = 1,
|
254 |
+
resnet_eps: float = 1e-6,
|
255 |
+
resnet_time_scale_shift: str = "default",
|
256 |
+
resnet_act_fn: str = "swish",
|
257 |
+
resnet_groups: int = 32,
|
258 |
+
resnet_pre_norm: bool = True,
|
259 |
+
attn_num_head_channels: int = 1,
|
260 |
+
cross_attention_dim: int = 1280,
|
261 |
+
attention_type: str = "default",
|
262 |
+
output_scale_factor: float = 1.0,
|
263 |
+
add_upsample: bool = True
|
264 |
+
) -> None:
|
265 |
+
super().__init__()
|
266 |
+
resnets = []
|
267 |
+
attentions = []
|
268 |
+
|
269 |
+
self.attention_type = attention_type
|
270 |
+
self.attn_num_head_channels = attn_num_head_channels
|
271 |
+
|
272 |
+
for i in range(num_layers):
|
273 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
274 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
275 |
+
|
276 |
+
resnets.append(
|
277 |
+
ResnetBlockPseudo3D(
|
278 |
+
in_channels = resnet_in_channels + res_skip_channels,
|
279 |
+
out_channels = out_channels,
|
280 |
+
temb_channels = temb_channels,
|
281 |
+
eps = resnet_eps,
|
282 |
+
groups = resnet_groups,
|
283 |
+
dropout = dropout,
|
284 |
+
time_embedding_norm = resnet_time_scale_shift,
|
285 |
+
#non_linearity = resnet_act_fn,
|
286 |
+
output_scale_factor = output_scale_factor,
|
287 |
+
pre_norm = resnet_pre_norm
|
288 |
+
)
|
289 |
+
)
|
290 |
+
attentions.append(
|
291 |
+
TransformerPseudo3DModel(
|
292 |
+
in_channels = out_channels,
|
293 |
+
num_attention_heads = attn_num_head_channels,
|
294 |
+
attention_head_dim = out_channels // attn_num_head_channels,
|
295 |
+
num_layers = 1,
|
296 |
+
cross_attention_dim = cross_attention_dim,
|
297 |
+
norm_num_groups = resnet_groups
|
298 |
+
)
|
299 |
+
)
|
300 |
+
self.attentions = nn.ModuleList(attentions)
|
301 |
+
self.resnets = nn.ModuleList(resnets)
|
302 |
+
|
303 |
+
if add_upsample:
|
304 |
+
self.upsamplers = nn.ModuleList([
|
305 |
+
Upsample2D(
|
306 |
+
out_channels,
|
307 |
+
use_conv = True,
|
308 |
+
out_channels = out_channels
|
309 |
+
)
|
310 |
+
])
|
311 |
+
else:
|
312 |
+
self.upsamplers = None
|
313 |
+
|
314 |
+
def forward(self,
|
315 |
+
hidden_states,
|
316 |
+
res_hidden_states_tuple,
|
317 |
+
temb = None,
|
318 |
+
encoder_hidden_states = None,
|
319 |
+
upsample_size = None
|
320 |
+
):
|
321 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
322 |
+
# pop res hidden states
|
323 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
324 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
325 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
326 |
+
hidden_states = resnet(hidden_states, temb)
|
327 |
+
hidden_states = attn(hidden_states, encoder_hidden_states = encoder_hidden_states).sample
|
328 |
+
|
329 |
+
if self.upsamplers is not None:
|
330 |
+
for upsampler in self.upsamplers:
|
331 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
332 |
+
|
333 |
+
return hidden_states
|
334 |
+
|
335 |
+
|
336 |
+
class UpBlock2D(nn.Module):
|
337 |
+
def __init__(self,
|
338 |
+
in_channels: int,
|
339 |
+
prev_output_channel: int,
|
340 |
+
out_channels: int,
|
341 |
+
temb_channels: int,
|
342 |
+
dropout: float = 0.0,
|
343 |
+
num_layers: int = 1,
|
344 |
+
resnet_eps: float = 1e-6,
|
345 |
+
resnet_time_scale_shift: str = "default",
|
346 |
+
resnet_act_fn: str = "swish",
|
347 |
+
resnet_groups: int = 32,
|
348 |
+
resnet_pre_norm: bool = True,
|
349 |
+
output_scale_factor: float = 1.0,
|
350 |
+
add_upsample: bool = True
|
351 |
+
) -> None:
|
352 |
+
super().__init__()
|
353 |
+
resnets = []
|
354 |
+
|
355 |
+
for i in range(num_layers):
|
356 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
357 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
358 |
+
|
359 |
+
resnets.append(
|
360 |
+
ResnetBlockPseudo3D(
|
361 |
+
in_channels = resnet_in_channels + res_skip_channels,
|
362 |
+
out_channels = out_channels,
|
363 |
+
temb_channels = temb_channels,
|
364 |
+
eps = resnet_eps,
|
365 |
+
groups = resnet_groups,
|
366 |
+
dropout = dropout,
|
367 |
+
time_embedding_norm = resnet_time_scale_shift,
|
368 |
+
#non_linearity = resnet_act_fn,
|
369 |
+
output_scale_factor = output_scale_factor,
|
370 |
+
pre_norm = resnet_pre_norm
|
371 |
+
)
|
372 |
+
)
|
373 |
+
|
374 |
+
self.resnets = nn.ModuleList(resnets)
|
375 |
+
|
376 |
+
if add_upsample:
|
377 |
+
self.upsamplers = nn.ModuleList([
|
378 |
+
Upsample2D(
|
379 |
+
out_channels,
|
380 |
+
use_conv = True,
|
381 |
+
out_channels = out_channels
|
382 |
+
)
|
383 |
+
])
|
384 |
+
else:
|
385 |
+
self.upsamplers = None
|
386 |
+
|
387 |
+
|
388 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb = None, upsample_size = None):
|
389 |
+
for resnet in self.resnets:
|
390 |
+
# pop res hidden states
|
391 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
392 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
393 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
394 |
+
hidden_states = resnet(hidden_states, temb)
|
395 |
+
|
396 |
+
if self.upsamplers is not None:
|
397 |
+
for upsampler in self.upsamplers:
|
398 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
399 |
+
|
400 |
+
return hidden_states
|
401 |
+
|
402 |
+
|
403 |
+
def get_down_block(
|
404 |
+
down_block_type: str,
|
405 |
+
num_layers: int,
|
406 |
+
in_channels: int,
|
407 |
+
out_channels: int,
|
408 |
+
temb_channels: int,
|
409 |
+
add_downsample: bool,
|
410 |
+
resnet_eps: float,
|
411 |
+
resnet_act_fn: str,
|
412 |
+
attn_num_head_channels: int,
|
413 |
+
resnet_groups: Optional[int] = None,
|
414 |
+
cross_attention_dim: Optional[int] = None,
|
415 |
+
downsample_padding: Optional[int] = None,
|
416 |
+
) -> Union[DownBlock2D, CrossAttnDownBlock2D]:
|
417 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
418 |
+
if down_block_type == "DownBlock2D":
|
419 |
+
return DownBlock2D(
|
420 |
+
num_layers = num_layers,
|
421 |
+
in_channels = in_channels,
|
422 |
+
out_channels = out_channels,
|
423 |
+
temb_channels = temb_channels,
|
424 |
+
add_downsample = add_downsample,
|
425 |
+
resnet_eps = resnet_eps,
|
426 |
+
resnet_act_fn = resnet_act_fn,
|
427 |
+
resnet_groups = resnet_groups,
|
428 |
+
downsample_padding = downsample_padding
|
429 |
+
)
|
430 |
+
elif down_block_type == "CrossAttnDownBlock2D":
|
431 |
+
if cross_attention_dim is None:
|
432 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
|
433 |
+
return CrossAttnDownBlock2D(
|
434 |
+
num_layers = num_layers,
|
435 |
+
in_channels = in_channels,
|
436 |
+
out_channels = out_channels,
|
437 |
+
temb_channels = temb_channels,
|
438 |
+
add_downsample = add_downsample,
|
439 |
+
resnet_eps = resnet_eps,
|
440 |
+
resnet_act_fn = resnet_act_fn,
|
441 |
+
resnet_groups = resnet_groups,
|
442 |
+
downsample_padding = downsample_padding,
|
443 |
+
cross_attention_dim = cross_attention_dim,
|
444 |
+
attn_num_head_channels = attn_num_head_channels
|
445 |
+
)
|
446 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
447 |
+
|
448 |
+
|
449 |
+
def get_up_block(
|
450 |
+
up_block_type: str,
|
451 |
+
num_layers,
|
452 |
+
in_channels,
|
453 |
+
out_channels,
|
454 |
+
prev_output_channel,
|
455 |
+
temb_channels,
|
456 |
+
add_upsample,
|
457 |
+
resnet_eps,
|
458 |
+
resnet_act_fn,
|
459 |
+
attn_num_head_channels,
|
460 |
+
resnet_groups = None,
|
461 |
+
cross_attention_dim = None,
|
462 |
+
) -> Union[UpBlock2D, CrossAttnUpBlock2D]:
|
463 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
464 |
+
if up_block_type == "UpBlock2D":
|
465 |
+
return UpBlock2D(
|
466 |
+
num_layers = num_layers,
|
467 |
+
in_channels = in_channels,
|
468 |
+
out_channels = out_channels,
|
469 |
+
prev_output_channel = prev_output_channel,
|
470 |
+
temb_channels = temb_channels,
|
471 |
+
add_upsample = add_upsample,
|
472 |
+
resnet_eps = resnet_eps,
|
473 |
+
resnet_act_fn = resnet_act_fn,
|
474 |
+
resnet_groups = resnet_groups
|
475 |
+
)
|
476 |
+
elif up_block_type == "CrossAttnUpBlock2D":
|
477 |
+
if cross_attention_dim is None:
|
478 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
|
479 |
+
return CrossAttnUpBlock2D(
|
480 |
+
num_layers = num_layers,
|
481 |
+
in_channels = in_channels,
|
482 |
+
out_channels = out_channels,
|
483 |
+
prev_output_channel = prev_output_channel,
|
484 |
+
temb_channels = temb_channels,
|
485 |
+
add_upsample = add_upsample,
|
486 |
+
resnet_eps = resnet_eps,
|
487 |
+
resnet_act_fn = resnet_act_fn,
|
488 |
+
resnet_groups = resnet_groups,
|
489 |
+
cross_attention_dim = cross_attention_dim,
|
490 |
+
attn_num_head_channels = attn_num_head_channels
|
491 |
+
)
|
492 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
493 |
+
|
makeavid_sd/makeavid_sd/torch_impl/torch_unet_pseudo3d_condition.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from torch_embeddings import TimestepEmbedding, Timesteps
|
8 |
+
from torch_unet_pseudo3d_blocks import (
|
9 |
+
UNetMidBlock2DCrossAttn,
|
10 |
+
get_down_block,
|
11 |
+
get_up_block,
|
12 |
+
)
|
13 |
+
|
14 |
+
from torch_resnet_pseudo3d import Pseudo3DConv
|
15 |
+
|
16 |
+
class UNetPseudo3DConditionOutput:
|
17 |
+
sample: torch.FloatTensor
|
18 |
+
def __init__(self, sample: torch.FloatTensor) -> None:
|
19 |
+
self.sample = sample
|
20 |
+
|
21 |
+
|
22 |
+
class UNetPseudo3DConditionModel(nn.Module):
|
23 |
+
def __init__(self,
|
24 |
+
sample_size: Optional[int] = None,
|
25 |
+
in_channels: int = 9,
|
26 |
+
out_channels: int = 4,
|
27 |
+
flip_sin_to_cos: bool = True,
|
28 |
+
freq_shift: int = 0,
|
29 |
+
down_block_types: Tuple[str] = (
|
30 |
+
"CrossAttnDownBlock2D",
|
31 |
+
"CrossAttnDownBlock2D",
|
32 |
+
"CrossAttnDownBlock2D",
|
33 |
+
"DownBlock2D",
|
34 |
+
),
|
35 |
+
up_block_types: Tuple[str] = (
|
36 |
+
"UpBlock2D",
|
37 |
+
"CrossAttnUpBlock2D",
|
38 |
+
"CrossAttnUpBlock2D",
|
39 |
+
"CrossAttnUpBlock2D"
|
40 |
+
),
|
41 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
42 |
+
layers_per_block: int = 2,
|
43 |
+
downsample_padding: int = 1,
|
44 |
+
mid_block_scale_factor: float = 1,
|
45 |
+
act_fn: str = "silu",
|
46 |
+
norm_num_groups: int = 32,
|
47 |
+
norm_eps: float = 1e-5,
|
48 |
+
cross_attention_dim: int = 768,
|
49 |
+
attention_head_dim: int = 8,
|
50 |
+
**kwargs
|
51 |
+
) -> None:
|
52 |
+
super().__init__()
|
53 |
+
self.dtype = torch.float32
|
54 |
+
self.sample_size = sample_size
|
55 |
+
time_embed_dim = block_out_channels[0] * 4
|
56 |
+
|
57 |
+
# input
|
58 |
+
self.conv_in = Pseudo3DConv(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
59 |
+
|
60 |
+
# time
|
61 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
62 |
+
timestep_input_dim = block_out_channels[0]
|
63 |
+
|
64 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
65 |
+
|
66 |
+
self.down_blocks = nn.ModuleList([])
|
67 |
+
self.mid_block = None
|
68 |
+
self.up_blocks = nn.ModuleList([])
|
69 |
+
|
70 |
+
# down
|
71 |
+
output_channel = block_out_channels[0]
|
72 |
+
for i, down_block_type in enumerate(down_block_types):
|
73 |
+
input_channel = output_channel
|
74 |
+
output_channel = block_out_channels[i]
|
75 |
+
is_final_block = i == len(block_out_channels) - 1
|
76 |
+
|
77 |
+
down_block = get_down_block(
|
78 |
+
down_block_type,
|
79 |
+
num_layers = layers_per_block,
|
80 |
+
in_channels = input_channel,
|
81 |
+
out_channels = output_channel,
|
82 |
+
temb_channels = time_embed_dim,
|
83 |
+
add_downsample = not is_final_block,
|
84 |
+
resnet_eps = norm_eps,
|
85 |
+
resnet_act_fn = act_fn,
|
86 |
+
resnet_groups = norm_num_groups,
|
87 |
+
cross_attention_dim = cross_attention_dim,
|
88 |
+
attn_num_head_channels = attention_head_dim,
|
89 |
+
downsample_padding = downsample_padding
|
90 |
+
)
|
91 |
+
self.down_blocks.append(down_block)
|
92 |
+
|
93 |
+
# mid
|
94 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
95 |
+
in_channels = block_out_channels[-1],
|
96 |
+
temb_channels = time_embed_dim,
|
97 |
+
resnet_eps = norm_eps,
|
98 |
+
resnet_act_fn = act_fn,
|
99 |
+
output_scale_factor = mid_block_scale_factor,
|
100 |
+
resnet_time_scale_shift = "default",
|
101 |
+
cross_attention_dim = cross_attention_dim,
|
102 |
+
attn_num_head_channels = attention_head_dim,
|
103 |
+
resnet_groups = norm_num_groups
|
104 |
+
)
|
105 |
+
|
106 |
+
# count how many layers upsample the images
|
107 |
+
self.num_upsamplers = 0
|
108 |
+
|
109 |
+
# up
|
110 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
111 |
+
output_channel = reversed_block_out_channels[0]
|
112 |
+
for i, up_block_type in enumerate(up_block_types):
|
113 |
+
is_final_block = i == len(block_out_channels) - 1
|
114 |
+
|
115 |
+
prev_output_channel = output_channel
|
116 |
+
output_channel = reversed_block_out_channels[i]
|
117 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
118 |
+
|
119 |
+
# add upsample block for all BUT final layer
|
120 |
+
if not is_final_block:
|
121 |
+
add_upsample = True
|
122 |
+
self.num_upsamplers += 1
|
123 |
+
else:
|
124 |
+
add_upsample = False
|
125 |
+
|
126 |
+
up_block = get_up_block(
|
127 |
+
up_block_type,
|
128 |
+
num_layers = layers_per_block + 1,
|
129 |
+
in_channels = input_channel,
|
130 |
+
out_channels = output_channel,
|
131 |
+
prev_output_channel = prev_output_channel,
|
132 |
+
temb_channels = time_embed_dim,
|
133 |
+
add_upsample = add_upsample,
|
134 |
+
resnet_eps = norm_eps,
|
135 |
+
resnet_act_fn = act_fn,
|
136 |
+
resnet_groups = norm_num_groups,
|
137 |
+
cross_attention_dim = cross_attention_dim,
|
138 |
+
attn_num_head_channels = attention_head_dim
|
139 |
+
)
|
140 |
+
self.up_blocks.append(up_block)
|
141 |
+
prev_output_channel = output_channel
|
142 |
+
|
143 |
+
# out
|
144 |
+
self.conv_norm_out = nn.GroupNorm(
|
145 |
+
num_channels = block_out_channels[0],
|
146 |
+
num_groups = norm_num_groups,
|
147 |
+
eps = norm_eps
|
148 |
+
)
|
149 |
+
self.conv_act = nn.SiLU()
|
150 |
+
self.conv_out = Pseudo3DConv(block_out_channels[0], out_channels, 3, padding = 1)
|
151 |
+
|
152 |
+
|
153 |
+
def forward(
|
154 |
+
self,
|
155 |
+
sample: torch.FloatTensor,
|
156 |
+
timesteps: Union[torch.Tensor, float, int],
|
157 |
+
encoder_hidden_states: torch.Tensor
|
158 |
+
) -> Union[UNetPseudo3DConditionOutput, Tuple]:
|
159 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
160 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
161 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
162 |
+
# on the fly if necessary.
|
163 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
164 |
+
|
165 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
166 |
+
forward_upsample_size = False
|
167 |
+
upsample_size = None
|
168 |
+
|
169 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
170 |
+
forward_upsample_size = True
|
171 |
+
|
172 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
173 |
+
timesteps = timesteps.expand(sample.shape[0])
|
174 |
+
|
175 |
+
t_emb = self.time_proj(timesteps)
|
176 |
+
|
177 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
178 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
179 |
+
# there might be better ways to encapsulate this.
|
180 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
181 |
+
emb = self.time_embedding(t_emb)
|
182 |
+
|
183 |
+
# 2. pre-process
|
184 |
+
sample = self.conv_in(sample)
|
185 |
+
|
186 |
+
# 3. down
|
187 |
+
down_block_res_samples = (sample,)
|
188 |
+
for downsample_block in self.down_blocks:
|
189 |
+
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
|
190 |
+
sample, res_samples = downsample_block(
|
191 |
+
hidden_states = sample,
|
192 |
+
temb = emb,
|
193 |
+
encoder_hidden_states = encoder_hidden_states,
|
194 |
+
)
|
195 |
+
else:
|
196 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
197 |
+
|
198 |
+
down_block_res_samples += res_samples
|
199 |
+
|
200 |
+
# 4. mid
|
201 |
+
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
202 |
+
|
203 |
+
# 5. up
|
204 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
205 |
+
is_final_block = i == len(self.up_blocks) - 1
|
206 |
+
|
207 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
208 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
209 |
+
|
210 |
+
# if we have not reached the final block and need to forward the
|
211 |
+
# upsample size, we do it here
|
212 |
+
if not is_final_block and forward_upsample_size:
|
213 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
214 |
+
|
215 |
+
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
|
216 |
+
sample = upsample_block(
|
217 |
+
hidden_states = sample,
|
218 |
+
temb = emb,
|
219 |
+
res_hidden_states_tuple = res_samples,
|
220 |
+
encoder_hidden_states = encoder_hidden_states,
|
221 |
+
upsample_size = upsample_size,
|
222 |
+
)
|
223 |
+
else:
|
224 |
+
sample = upsample_block(
|
225 |
+
hidden_states = sample,
|
226 |
+
temb = emb,
|
227 |
+
res_hidden_states_tuple = res_samples,
|
228 |
+
upsample_size = upsample_size
|
229 |
+
)
|
230 |
+
# 6. post-process
|
231 |
+
sample = self.conv_norm_out(sample)
|
232 |
+
sample = self.conv_act(sample)
|
233 |
+
sample = self.conv_out(sample)
|
234 |
+
|
235 |
+
return UNetPseudo3DConditionOutput(sample = sample)
|
makeavid_sd/requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torch_xla
|
makeavid_sd/setup.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup
|
2 |
+
setup(
|
3 |
+
name = 'makeavid_sd',
|
4 |
+
version = '0.1.0',
|
5 |
+
description = 'makeavid sd',
|
6 |
+
author = 'Lopho',
|
7 |
+
author_email = '[email protected]',
|
8 |
+
platforms = ['any'],
|
9 |
+
license = 'GNU Affero General Public License v3',
|
10 |
+
url = 'http://github.com/lopho/makeavid-sd-tpu'
|
11 |
+
)
|
makeavid_sd/trainer_xla.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ['PJRT_DEVICE'] = 'TPU'
|
3 |
+
|
4 |
+
from tqdm.auto import tqdm
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
from torch_xla.core import xla_model
|
8 |
+
from diffusers import UNetPseudo3DConditionModel
|
9 |
+
from dataset import load_dataset
|
10 |
+
|
11 |
+
|
12 |
+
class TempoTrainerXLA:
|
13 |
+
def __init__(self,
|
14 |
+
pretrained: str = 'lxj616/make-a-stable-diffusion-video-timelapse',
|
15 |
+
lr: float = 1e-4,
|
16 |
+
dtype: torch.dtype = torch.float32,
|
17 |
+
) -> None:
|
18 |
+
self.dtype = dtype
|
19 |
+
self.device: torch.device = xla_model.xla_device(0)
|
20 |
+
unet: UNetPseudo3DConditionModel = UNetPseudo3DConditionModel.from_pretrained(
|
21 |
+
pretrained,
|
22 |
+
subfolder = 'unet'
|
23 |
+
).to(dtype = dtype, memory_format = torch.contiguous_format)
|
24 |
+
unfreeze_all: bool = False
|
25 |
+
unet = unet.train()
|
26 |
+
if not unfreeze_all:
|
27 |
+
unet.requires_grad_(False)
|
28 |
+
for name, param in unet.named_parameters():
|
29 |
+
if 'temporal_conv' in name:
|
30 |
+
param.requires_grad_(True)
|
31 |
+
for block in [*unet.down_blocks, unet.mid_block, *unet.up_blocks]:
|
32 |
+
if hasattr(block, 'attentions') and block.attentions is not None:
|
33 |
+
for attn_block in block.attentions:
|
34 |
+
for transformer_block in attn_block.transformer_blocks:
|
35 |
+
transformer_block.requires_grad_(False)
|
36 |
+
transformer_block.attn_temporal.requires_grad_(True)
|
37 |
+
transformer_block.norm_temporal.requires_grad_(True)
|
38 |
+
else:
|
39 |
+
unet.requires_grad_(True)
|
40 |
+
self.model: UNetPseudo3DConditionModel = unet.to(device = self.device)
|
41 |
+
#self.model = torch.compile(self.model, backend = 'aot_torchxla_trace_once')
|
42 |
+
self.params = lambda: filter(lambda p: p.requires_grad, self.model.parameters())
|
43 |
+
self.optim: torch.optim.Optimizer = torch.optim.AdamW(self.params(), lr = lr)
|
44 |
+
def lr_warmup(warmup_steps: int = 0):
|
45 |
+
def lambda_lr(step: int) -> float:
|
46 |
+
if step < warmup_steps:
|
47 |
+
return step / warmup_steps
|
48 |
+
else:
|
49 |
+
return 1.0
|
50 |
+
return lambda_lr
|
51 |
+
self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optim, lr_lambda = lr_warmup(warmup_steps = 60), last_epoch = -1)
|
52 |
+
|
53 |
+
@torch.no_grad()
|
54 |
+
def train(self, dataloader: DataLoader, epochs: int = 1, log_every: int = 1, save_every: int = 1000) -> None:
|
55 |
+
# 'latent_model_input'
|
56 |
+
# 'encoder_hidden_states'
|
57 |
+
# 'timesteps'
|
58 |
+
# 'noise'
|
59 |
+
global_step: int = 0
|
60 |
+
for epoch in range(epochs):
|
61 |
+
pbar = tqdm(dataloader, dynamic_ncols = True, smoothing = 0.01)
|
62 |
+
for b in pbar:
|
63 |
+
latent_model_input: torch.Tensor = b['latent_model_input'].to(device = self.device)
|
64 |
+
encoder_hidden_states: torch.Tensor = b['encoder_hidden_states'].to(device = self.device)
|
65 |
+
timesteps: torch.Tensor = b['timesteps'].to(device = self.device)
|
66 |
+
noise: torch.Tensor = b['noise'].to(device = self.device)
|
67 |
+
with torch.enable_grad():
|
68 |
+
self.optim.zero_grad(set_to_none = True)
|
69 |
+
y = self.model(latent_model_input, timesteps, encoder_hidden_states).sample
|
70 |
+
loss = torch.nn.functional.mse_loss(noise, y)
|
71 |
+
loss.backward()
|
72 |
+
self.optim.step()
|
73 |
+
self.scheduler.step()
|
74 |
+
xla_model.mark_step()
|
75 |
+
if global_step % log_every == 0:
|
76 |
+
pbar.set_postfix({ 'loss': loss.detach().item(), 'epoch': epoch })
|
77 |
+
|
78 |
+
def main():
|
79 |
+
pretrained: str = 'lxj616/make-a-stable-diffusion-video-timelapse'
|
80 |
+
dataset_path: str = './storage/dataset/tempofunk'
|
81 |
+
dtype: torch.dtype = torch.bfloat16
|
82 |
+
trainer = TempoTrainerXLA(
|
83 |
+
pretrained = pretrained,
|
84 |
+
lr = 1e-5,
|
85 |
+
dtype = dtype
|
86 |
+
)
|
87 |
+
dataloader: DataLoader = load_dataset(
|
88 |
+
dataset_path = dataset_path,
|
89 |
+
pretrained = pretrained,
|
90 |
+
batch_size = 1,
|
91 |
+
num_frames = 10,
|
92 |
+
num_workers = 1,
|
93 |
+
dtype = dtype
|
94 |
+
)
|
95 |
+
trainer.train(
|
96 |
+
dataloader = dataloader,
|
97 |
+
epochs = 1000,
|
98 |
+
log_every = 1,
|
99 |
+
save_every = 1000
|
100 |
+
)
|
101 |
+
|
102 |
+
if __name__ == '__main__':
|
103 |
+
main()
|
104 |
+
|