artificialguybr committed on
Commit
eadd7b4
·
1 Parent(s): f8cfb21
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Dockerfile +29 -0
  2. LICENSE +661 -0
  3. app/app_pixart_dmd.py +343 -0
  4. app/app_pixart_sigma.py +420 -0
  5. asset/PixArt.svg +96 -0
  6. asset/docs/pixart.md +112 -0
  7. asset/examples.py +36 -0
  8. asset/logo-sigma.png +0 -0
  9. asset/logo.png +0 -0
  10. asset/samples.txt +120 -0
  11. configs/PixArt_xl2_internal.py +79 -0
  12. configs/pixart_alpha_config/PixArt_xl2_img1024_dreambooth.py +30 -0
  13. configs/pixart_alpha_config/PixArt_xl2_img1024_internal.py +29 -0
  14. configs/pixart_alpha_config/PixArt_xl2_img1024_internalms.py +32 -0
  15. configs/pixart_alpha_config/PixArt_xl2_img256_internal.py +27 -0
  16. configs/pixart_alpha_config/PixArt_xl2_img512_internal.py +29 -0
  17. configs/pixart_alpha_config/PixArt_xl2_img512_internalms.py +31 -0
  18. configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms.py +46 -0
  19. configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms_kvcompress.py +51 -0
  20. configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_lcm.py +52 -0
  21. configs/pixart_sigma_config/PixArt_sigma_xl2_img256_internal.py +41 -0
  22. configs/pixart_sigma_config/PixArt_sigma_xl2_img2K_internalms_kvcompress.py +49 -0
  23. configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py +43 -0
  24. diffusion/__init__.py +8 -0
  25. diffusion/data/__init__.py +2 -0
  26. diffusion/data/builder.py +50 -0
  27. diffusion/data/datasets/InternalData.py +312 -0
  28. diffusion/data/datasets/InternalData_ms.py +336 -0
  29. diffusion/data/datasets/__init__.py +3 -0
  30. diffusion/data/datasets/utils.py +134 -0
  31. diffusion/data/transforms.py +30 -0
  32. diffusion/dpm_solver.py +36 -0
  33. diffusion/iddpm.py +53 -0
  34. diffusion/lcm_scheduler.py +459 -0
  35. diffusion/model/__init__.py +1 -0
  36. diffusion/model/builder.py +14 -0
  37. diffusion/model/diffusion_utils.py +88 -0
  38. diffusion/model/dpm_solver.py +1337 -0
  39. diffusion/model/edm_sample.py +171 -0
  40. diffusion/model/gaussian_diffusion.py +1041 -0
  41. diffusion/model/llava/__init__.py +1 -0
  42. diffusion/model/llava/llava_mpt.py +280 -0
  43. diffusion/model/llava/mpt/attention.py +276 -0
  44. diffusion/model/llava/mpt/blocks.py +41 -0
  45. diffusion/model/llava/mpt/configuration_mpt.py +118 -0
  46. diffusion/model/llava/mpt/modeling_mpt.py +308 -0
  47. diffusion/model/llava/mpt/norm.py +56 -0
  48. diffusion/model/llava/mpt/param_init_fns.py +181 -0
  49. diffusion/model/nets/PixArt.py +315 -0
  50. diffusion/model/nets/PixArtMS.py +293 -0
Dockerfile ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is a sample Dockerfile that builds a runtime container and runs the sample Gradio app.
2
+ # Note, you must pass in the pretrained models when you run the container.
3
+
4
+ FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04
5
+
6
+ WORKDIR /workspace
7
+
8
+ RUN apt-get update && \
9
+ apt-get install -y \
10
+ git \
11
+ python3 \
12
+ python-is-python3 \
13
+ python3-pip \
14
+ python3.10-venv \
15
+ libgl1 \
16
+ libgl1-mesa-glx \
17
+ libglib2.0-0 \
18
+ && rm -rf /var/lib/apt/lists/*
19
+
20
+ ADD requirements.txt .
21
+
22
+ RUN pip install -r requirements.txt
23
+
24
+ ADD . .
25
+
26
+ RUN chmod a+x docker-entrypoint.sh
27
+
28
+ ENV DEMO_PORT=12345
29
+ ENTRYPOINT [ "/workspace/docker-entrypoint.sh" ]
LICENSE ADDED
@@ -0,0 +1,661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU AFFERO GENERAL PUBLIC LICENSE
2
+ Version 3, 19 November 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU Affero General Public License is a free, copyleft license for
11
+ software and other kinds of works, specifically designed to ensure
12
+ cooperation with the community in the case of network server software.
13
+
14
+ The licenses for most software and other practical works are designed
15
+ to take away your freedom to share and change the works. By contrast,
16
+ our General Public Licenses are intended to guarantee your freedom to
17
+ share and change all versions of a program--to make sure it remains free
18
+ software for all its users.
19
+
20
+ When we speak of free software, we are referring to freedom, not
21
+ price. Our General Public Licenses are designed to make sure that you
22
+ have the freedom to distribute copies of free software (and charge for
23
+ them if you wish), that you receive source code or can get it if you
24
+ want it, that you can change the software or use pieces of it in new
25
+ free programs, and that you know you can do these things.
26
+
27
+ Developers that use our General Public Licenses protect your rights
28
+ with two steps: (1) assert copyright on the software, and (2) offer
29
+ you this License which gives you legal permission to copy, distribute
30
+ and/or modify the software.
31
+
32
+ A secondary benefit of defending all users' freedom is that
33
+ improvements made in alternate versions of the program, if they
34
+ receive widespread use, become available for other developers to
35
+ incorporate. Many developers of free software are heartened and
36
+ encouraged by the resulting cooperation. However, in the case of
37
+ software used on network servers, this result may fail to come about.
38
+ The GNU General Public License permits making a modified version and
39
+ letting the public access it on a server without ever releasing its
40
+ source code to the public.
41
+
42
+ The GNU Affero General Public License is designed specifically to
43
+ ensure that, in such cases, the modified source code becomes available
44
+ to the community. It requires the operator of a network server to
45
+ provide the source code of the modified version running there to the
46
+ users of that server. Therefore, public use of a modified version, on
47
+ a publicly accessible server, gives the public access to the source
48
+ code of the modified version.
49
+
50
+ An older license, called the Affero General Public License and
51
+ published by Affero, was designed to accomplish similar goals. This is
52
+ a different license, not a version of the Affero GPL, but Affero has
53
+ released a new version of the Affero GPL which permits relicensing under
54
+ this license.
55
+
56
+ The precise terms and conditions for copying, distribution and
57
+ modification follow.
58
+
59
+ TERMS AND CONDITIONS
60
+
61
+ 0. Definitions.
62
+
63
+ "This License" refers to version 3 of the GNU Affero General Public License.
64
+
65
+ "Copyright" also means copyright-like laws that apply to other kinds of
66
+ works, such as semiconductor masks.
67
+
68
+ "The Program" refers to any copyrightable work licensed under this
69
+ License. Each licensee is addressed as "you". "Licensees" and
70
+ "recipients" may be individuals or organizations.
71
+
72
+ To "modify" a work means to copy from or adapt all or part of the work
73
+ in a fashion requiring copyright permission, other than the making of an
74
+ exact copy. The resulting work is called a "modified version" of the
75
+ earlier work or a work "based on" the earlier work.
76
+
77
+ A "covered work" means either the unmodified Program or a work based
78
+ on the Program.
79
+
80
+ To "propagate" a work means to do anything with it that, without
81
+ permission, would make you directly or secondarily liable for
82
+ infringement under applicable copyright law, except executing it on a
83
+ computer or modifying a private copy. Propagation includes copying,
84
+ distribution (with or without modification), making available to the
85
+ public, and in some countries other activities as well.
86
+
87
+ To "convey" a work means any kind of propagation that enables other
88
+ parties to make or receive copies. Mere interaction with a user through
89
+ a computer network, with no transfer of a copy, is not conveying.
90
+
91
+ An interactive user interface displays "Appropriate Legal Notices"
92
+ to the extent that it includes a convenient and prominently visible
93
+ feature that (1) displays an appropriate copyright notice, and (2)
94
+ tells the user that there is no warranty for the work (except to the
95
+ extent that warranties are provided), that licensees may convey the
96
+ work under this License, and how to view a copy of this License. If
97
+ the interface presents a list of user commands or options, such as a
98
+ menu, a prominent item in the list meets this criterion.
99
+
100
+ 1. Source Code.
101
+
102
+ The "source code" for a work means the preferred form of the work
103
+ for making modifications to it. "Object code" means any non-source
104
+ form of a work.
105
+
106
+ A "Standard Interface" means an interface that either is an official
107
+ standard defined by a recognized standards body, or, in the case of
108
+ interfaces specified for a particular programming language, one that
109
+ is widely used among developers working in that language.
110
+
111
+ The "System Libraries" of an executable work include anything, other
112
+ than the work as a whole, that (a) is included in the normal form of
113
+ packaging a Major Component, but which is not part of that Major
114
+ Component, and (b) serves only to enable use of the work with that
115
+ Major Component, or to implement a Standard Interface for which an
116
+ implementation is available to the public in source code form. A
117
+ "Major Component", in this context, means a major essential component
118
+ (kernel, window system, and so on) of the specific operating system
119
+ (if any) on which the executable work runs, or a compiler used to
120
+ produce the work, or an object code interpreter used to run it.
121
+
122
+ The "Corresponding Source" for a work in object code form means all
123
+ the source code needed to generate, install, and (for an executable
124
+ work) run the object code and to modify the work, including scripts to
125
+ control those activities. However, it does not include the work's
126
+ System Libraries, or general-purpose tools or generally available free
127
+ programs which are used unmodified in performing those activities but
128
+ which are not part of the work. For example, Corresponding Source
129
+ includes interface definition files associated with source files for
130
+ the work, and the source code for shared libraries and dynamically
131
+ linked subprograms that the work is specifically designed to require,
132
+ such as by intimate data communication or control flow between those
133
+ subprograms and other parts of the work.
134
+
135
+ The Corresponding Source need not include anything that users
136
+ can regenerate automatically from other parts of the Corresponding
137
+ Source.
138
+
139
+ The Corresponding Source for a work in source code form is that
140
+ same work.
141
+
142
+ 2. Basic Permissions.
143
+
144
+ All rights granted under this License are granted for the term of
145
+ copyright on the Program, and are irrevocable provided the stated
146
+ conditions are met. This License explicitly affirms your unlimited
147
+ permission to run the unmodified Program. The output from running a
148
+ covered work is covered by this License only if the output, given its
149
+ content, constitutes a covered work. This License acknowledges your
150
+ rights of fair use or other equivalent, as provided by copyright law.
151
+
152
+ You may make, run and propagate covered works that you do not
153
+ convey, without conditions so long as your license otherwise remains
154
+ in force. You may convey covered works to others for the sole purpose
155
+ of having them make modifications exclusively for you, or provide you
156
+ with facilities for running those works, provided that you comply with
157
+ the terms of this License in conveying all material for which you do
158
+ not control copyright. Those thus making or running the covered works
159
+ for you must do so exclusively on your behalf, under your direction
160
+ and control, on terms that prohibit them from making any copies of
161
+ your copyrighted material outside their relationship with you.
162
+
163
+ Conveying under any other circumstances is permitted solely under
164
+ the conditions stated below. Sublicensing is not allowed; section 10
165
+ makes it unnecessary.
166
+
167
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
168
+
169
+ No covered work shall be deemed part of an effective technological
170
+ measure under any applicable law fulfilling obligations under article
171
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
172
+ similar laws prohibiting or restricting circumvention of such
173
+ measures.
174
+
175
+ When you convey a covered work, you waive any legal power to forbid
176
+ circumvention of technological measures to the extent such circumvention
177
+ is effected by exercising rights under this License with respect to
178
+ the covered work, and you disclaim any intention to limit operation or
179
+ modification of the work as a means of enforcing, against the work's
180
+ users, your or third parties' legal rights to forbid circumvention of
181
+ technological measures.
182
+
183
+ 4. Conveying Verbatim Copies.
184
+
185
+ You may convey verbatim copies of the Program's source code as you
186
+ receive it, in any medium, provided that you conspicuously and
187
+ appropriately publish on each copy an appropriate copyright notice;
188
+ keep intact all notices stating that this License and any
189
+ non-permissive terms added in accord with section 7 apply to the code;
190
+ keep intact all notices of the absence of any warranty; and give all
191
+ recipients a copy of this License along with the Program.
192
+
193
+ You may charge any price or no price for each copy that you convey,
194
+ and you may offer support or warranty protection for a fee.
195
+
196
+ 5. Conveying Modified Source Versions.
197
+
198
+ You may convey a work based on the Program, or the modifications to
199
+ produce it from the Program, in the form of source code under the
200
+ terms of section 4, provided that you also meet all of these conditions:
201
+
202
+ a) The work must carry prominent notices stating that you modified
203
+ it, and giving a relevant date.
204
+
205
+ b) The work must carry prominent notices stating that it is
206
+ released under this License and any conditions added under section
207
+ 7. This requirement modifies the requirement in section 4 to
208
+ "keep intact all notices".
209
+
210
+ c) You must license the entire work, as a whole, under this
211
+ License to anyone who comes into possession of a copy. This
212
+ License will therefore apply, along with any applicable section 7
213
+ additional terms, to the whole of the work, and all its parts,
214
+ regardless of how they are packaged. This License gives no
215
+ permission to license the work in any other way, but it does not
216
+ invalidate such permission if you have separately received it.
217
+
218
+ d) If the work has interactive user interfaces, each must display
219
+ Appropriate Legal Notices; however, if the Program has interactive
220
+ interfaces that do not display Appropriate Legal Notices, your
221
+ work need not make them do so.
222
+
223
+ A compilation of a covered work with other separate and independent
224
+ works, which are not by their nature extensions of the covered work,
225
+ and which are not combined with it such as to form a larger program,
226
+ in or on a volume of a storage or distribution medium, is called an
227
+ "aggregate" if the compilation and its resulting copyright are not
228
+ used to limit the access or legal rights of the compilation's users
229
+ beyond what the individual works permit. Inclusion of a covered work
230
+ in an aggregate does not cause this License to apply to the other
231
+ parts of the aggregate.
232
+
233
+ 6. Conveying Non-Source Forms.
234
+
235
+ You may convey a covered work in object code form under the terms
236
+ of sections 4 and 5, provided that you also convey the
237
+ machine-readable Corresponding Source under the terms of this License,
238
+ in one of these ways:
239
+
240
+ a) Convey the object code in, or embodied in, a physical product
241
+ (including a physical distribution medium), accompanied by the
242
+ Corresponding Source fixed on a durable physical medium
243
+ customarily used for software interchange.
244
+
245
+ b) Convey the object code in, or embodied in, a physical product
246
+ (including a physical distribution medium), accompanied by a
247
+ written offer, valid for at least three years and valid for as
248
+ long as you offer spare parts or customer support for that product
249
+ model, to give anyone who possesses the object code either (1) a
250
+ copy of the Corresponding Source for all the software in the
251
+ product that is covered by this License, on a durable physical
252
+ medium customarily used for software interchange, for a price no
253
+ more than your reasonable cost of physically performing this
254
+ conveying of source, or (2) access to copy the
255
+ Corresponding Source from a network server at no charge.
256
+
257
+ c) Convey individual copies of the object code with a copy of the
258
+ written offer to provide the Corresponding Source. This
259
+ alternative is allowed only occasionally and noncommercially, and
260
+ only if you received the object code with such an offer, in accord
261
+ with subsection 6b.
262
+
263
+ d) Convey the object code by offering access from a designated
264
+ place (gratis or for a charge), and offer equivalent access to the
265
+ Corresponding Source in the same way through the same place at no
266
+ further charge. You need not require recipients to copy the
267
+ Corresponding Source along with the object code. If the place to
268
+ copy the object code is a network server, the Corresponding Source
269
+ may be on a different server (operated by you or a third party)
270
+ that supports equivalent copying facilities, provided you maintain
271
+ clear directions next to the object code saying where to find the
272
+ Corresponding Source. Regardless of what server hosts the
273
+ Corresponding Source, you remain obligated to ensure that it is
274
+ available for as long as needed to satisfy these requirements.
275
+
276
+ e) Convey the object code using peer-to-peer transmission, provided
277
+ you inform other peers where the object code and Corresponding
278
+ Source of the work are being offered to the general public at no
279
+ charge under subsection 6d.
280
+
281
+ A separable portion of the object code, whose source code is excluded
282
+ from the Corresponding Source as a System Library, need not be
283
+ included in conveying the object code work.
284
+
285
+ A "User Product" is either (1) a "consumer product", which means any
286
+ tangible personal property which is normally used for personal, family,
287
+ or household purposes, or (2) anything designed or sold for incorporation
288
+ into a dwelling. In determining whether a product is a consumer product,
289
+ doubtful cases shall be resolved in favor of coverage. For a particular
290
+ product received by a particular user, "normally used" refers to a
291
+ typical or common use of that class of product, regardless of the status
292
+ of the particular user or of the way in which the particular user
293
+ actually uses, or expects or is expected to use, the product. A product
294
+ is a consumer product regardless of whether the product has substantial
295
+ commercial, industrial or non-consumer uses, unless such uses represent
296
+ the only significant mode of use of the product.
297
+
298
+ "Installation Information" for a User Product means any methods,
299
+ procedures, authorization keys, or other information required to install
300
+ and execute modified versions of a covered work in that User Product from
301
+ a modified version of its Corresponding Source. The information must
302
+ suffice to ensure that the continued functioning of the modified object
303
+ code is in no case prevented or interfered with solely because
304
+ modification has been made.
305
+
306
+ If you convey an object code work under this section in, or with, or
307
+ specifically for use in, a User Product, and the conveying occurs as
308
+ part of a transaction in which the right of possession and use of the
309
+ User Product is transferred to the recipient in perpetuity or for a
310
+ fixed term (regardless of how the transaction is characterized), the
311
+ Corresponding Source conveyed under this section must be accompanied
312
+ by the Installation Information. But this requirement does not apply
313
+ if neither you nor any third party retains the ability to install
314
+ modified object code on the User Product (for example, the work has
315
+ been installed in ROM).
316
+
317
+ The requirement to provide Installation Information does not include a
318
+ requirement to continue to provide support service, warranty, or updates
319
+ for a work that has been modified or installed by the recipient, or for
320
+ the User Product in which it has been modified or installed. Access to a
321
+ network may be denied when the modification itself materially and
322
+ adversely affects the operation of the network or violates the rules and
323
+ protocols for communication across the network.
324
+
325
+ Corresponding Source conveyed, and Installation Information provided,
326
+ in accord with this section must be in a format that is publicly
327
+ documented (and with an implementation available to the public in
328
+ source code form), and must require no special password or key for
329
+ unpacking, reading or copying.
330
+
331
+ 7. Additional Terms.
332
+
333
+ "Additional permissions" are terms that supplement the terms of this
334
+ License by making exceptions from one or more of its conditions.
335
+ Additional permissions that are applicable to the entire Program shall
336
+ be treated as though they were included in this License, to the extent
337
+ that they are valid under applicable law. If additional permissions
338
+ apply only to part of the Program, that part may be used separately
339
+ under those permissions, but the entire Program remains governed by
340
+ this License without regard to the additional permissions.
341
+
342
+ When you convey a copy of a covered work, you may at your option
343
+ remove any additional permissions from that copy, or from any part of
344
+ it. (Additional permissions may be written to require their own
345
+ removal in certain cases when you modify the work.) You may place
346
+ additional permissions on material, added by you to a covered work,
347
+ for which you have or can give appropriate copyright permission.
348
+
349
+ Notwithstanding any other provision of this License, for material you
350
+ add to a covered work, you may (if authorized by the copyright holders of
351
+ that material) supplement the terms of this License with terms:
352
+
353
+ a) Disclaiming warranty or limiting liability differently from the
354
+ terms of sections 15 and 16 of this License; or
355
+
356
+ b) Requiring preservation of specified reasonable legal notices or
357
+ author attributions in that material or in the Appropriate Legal
358
+ Notices displayed by works containing it; or
359
+
360
+ c) Prohibiting misrepresentation of the origin of that material, or
361
+ requiring that modified versions of such material be marked in
362
+ reasonable ways as different from the original version; or
363
+
364
+ d) Limiting the use for publicity purposes of names of licensors or
365
+ authors of the material; or
366
+
367
+ e) Declining to grant rights under trademark law for use of some
368
+ trade names, trademarks, or service marks; or
369
+
370
+ f) Requiring indemnification of licensors and authors of that
371
+ material by anyone who conveys the material (or modified versions of
372
+ it) with contractual assumptions of liability to the recipient, for
373
+ any liability that these contractual assumptions directly impose on
374
+ those licensors and authors.
375
+
376
+ All other non-permissive additional terms are considered "further
377
+ restrictions" within the meaning of section 10. If the Program as you
378
+ received it, or any part of it, contains a notice stating that it is
379
+ governed by this License along with a term that is a further
380
+ restriction, you may remove that term. If a license document contains
381
+ a further restriction but permits relicensing or conveying under this
382
+ License, you may add to a covered work material governed by the terms
383
+ of that license document, provided that the further restriction does
384
+ not survive such relicensing or conveying.
385
+
386
+ If you add terms to a covered work in accord with this section, you
387
+ must place, in the relevant source files, a statement of the
388
+ additional terms that apply to those files, or a notice indicating
389
+ where to find the applicable terms.
390
+
391
+ Additional terms, permissive or non-permissive, may be stated in the
392
+ form of a separately written license, or stated as exceptions;
393
+ the above requirements apply either way.
394
+
395
+ 8. Termination.
396
+
397
+ You may not propagate or modify a covered work except as expressly
398
+ provided under this License. Any attempt otherwise to propagate or
399
+ modify it is void, and will automatically terminate your rights under
400
+ this License (including any patent licenses granted under the third
401
+ paragraph of section 11).
402
+
403
+ However, if you cease all violation of this License, then your
404
+ license from a particular copyright holder is reinstated (a)
405
+ provisionally, unless and until the copyright holder explicitly and
406
+ finally terminates your license, and (b) permanently, if the copyright
407
+ holder fails to notify you of the violation by some reasonable means
408
+ prior to 60 days after the cessation.
409
+
410
+ Moreover, your license from a particular copyright holder is
411
+ reinstated permanently if the copyright holder notifies you of the
412
+ violation by some reasonable means, this is the first time you have
413
+ received notice of violation of this License (for any work) from that
414
+ copyright holder, and you cure the violation prior to 30 days after
415
+ your receipt of the notice.
416
+
417
+ Termination of your rights under this section does not terminate the
418
+ licenses of parties who have received copies or rights from you under
419
+ this License. If your rights have been terminated and not permanently
420
+ reinstated, you do not qualify to receive new licenses for the same
421
+ material under section 10.
422
+
423
+ 9. Acceptance Not Required for Having Copies.
424
+
425
+ You are not required to accept this License in order to receive or
426
+ run a copy of the Program. Ancillary propagation of a covered work
427
+ occurring solely as a consequence of using peer-to-peer transmission
428
+ to receive a copy likewise does not require acceptance. However,
429
+ nothing other than this License grants you permission to propagate or
430
+ modify any covered work. These actions infringe copyright if you do
431
+ not accept this License. Therefore, by modifying or propagating a
432
+ covered work, you indicate your acceptance of this License to do so.
433
+
434
+ 10. Automatic Licensing of Downstream Recipients.
435
+
436
+ Each time you convey a covered work, the recipient automatically
437
+ receives a license from the original licensors, to run, modify and
438
+ propagate that work, subject to this License. You are not responsible
439
+ for enforcing compliance by third parties with this License.
440
+
441
+ An "entity transaction" is a transaction transferring control of an
442
+ organization, or substantially all assets of one, or subdividing an
443
+ organization, or merging organizations. If propagation of a covered
444
+ work results from an entity transaction, each party to that
445
+ transaction who receives a copy of the work also receives whatever
446
+ licenses to the work the party's predecessor in interest had or could
447
+ give under the previous paragraph, plus a right to possession of the
448
+ Corresponding Source of the work from the predecessor in interest, if
449
+ the predecessor has it or can get it with reasonable efforts.
450
+
451
+ You may not impose any further restrictions on the exercise of the
452
+ rights granted or affirmed under this License. For example, you may
453
+ not impose a license fee, royalty, or other charge for exercise of
454
+ rights granted under this License, and you may not initiate litigation
455
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
456
+ any patent claim is infringed by making, using, selling, offering for
457
+ sale, or importing the Program or any portion of it.
458
+
459
+ 11. Patents.
460
+
461
+ A "contributor" is a copyright holder who authorizes use under this
462
+ License of the Program or a work on which the Program is based. The
463
+ work thus licensed is called the contributor's "contributor version".
464
+
465
+ A contributor's "essential patent claims" are all patent claims
466
+ owned or controlled by the contributor, whether already acquired or
467
+ hereafter acquired, that would be infringed by some manner, permitted
468
+ by this License, of making, using, or selling its contributor version,
469
+ but do not include claims that would be infringed only as a
470
+ consequence of further modification of the contributor version. For
471
+ purposes of this definition, "control" includes the right to grant
472
+ patent sublicenses in a manner consistent with the requirements of
473
+ this License.
474
+
475
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
476
+ patent license under the contributor's essential patent claims, to
477
+ make, use, sell, offer for sale, import and otherwise run, modify and
478
+ propagate the contents of its contributor version.
479
+
480
+ In the following three paragraphs, a "patent license" is any express
481
+ agreement or commitment, however denominated, not to enforce a patent
482
+ (such as an express permission to practice a patent or covenant not to
483
+ sue for patent infringement). To "grant" such a patent license to a
484
+ party means to make such an agreement or commitment not to enforce a
485
+ patent against the party.
486
+
487
+ If you convey a covered work, knowingly relying on a patent license,
488
+ and the Corresponding Source of the work is not available for anyone
489
+ to copy, free of charge and under the terms of this License, through a
490
+ publicly available network server or other readily accessible means,
491
+ then you must either (1) cause the Corresponding Source to be so
492
+ available, or (2) arrange to deprive yourself of the benefit of the
493
+ patent license for this particular work, or (3) arrange, in a manner
494
+ consistent with the requirements of this License, to extend the patent
495
+ license to downstream recipients. "Knowingly relying" means you have
496
+ actual knowledge that, but for the patent license, your conveying the
497
+ covered work in a country, or your recipient's use of the covered work
498
+ in a country, would infringe one or more identifiable patents in that
499
+ country that you have reason to believe are valid.
500
+
501
+ If, pursuant to or in connection with a single transaction or
502
+ arrangement, you convey, or propagate by procuring conveyance of, a
503
+ covered work, and grant a patent license to some of the parties
504
+ receiving the covered work authorizing them to use, propagate, modify
505
+ or convey a specific copy of the covered work, then the patent license
506
+ you grant is automatically extended to all recipients of the covered
507
+ work and works based on it.
508
+
509
+ A patent license is "discriminatory" if it does not include within
510
+ the scope of its coverage, prohibits the exercise of, or is
511
+ conditioned on the non-exercise of one or more of the rights that are
512
+ specifically granted under this License. You may not convey a covered
513
+ work if you are a party to an arrangement with a third party that is
514
+ in the business of distributing software, under which you make payment
515
+ to the third party based on the extent of your activity of conveying
516
+ the work, and under which the third party grants, to any of the
517
+ parties who would receive the covered work from you, a discriminatory
518
+ patent license (a) in connection with copies of the covered work
519
+ conveyed by you (or copies made from those copies), or (b) primarily
520
+ for and in connection with specific products or compilations that
521
+ contain the covered work, unless you entered into that arrangement,
522
+ or that patent license was granted, prior to 28 March 2007.
523
+
524
+ Nothing in this License shall be construed as excluding or limiting
525
+ any implied license or other defenses to infringement that may
526
+ otherwise be available to you under applicable patent law.
527
+
528
+ 12. No Surrender of Others' Freedom.
529
+
530
+ If conditions are imposed on you (whether by court order, agreement or
531
+ otherwise) that contradict the conditions of this License, they do not
532
+ excuse you from the conditions of this License. If you cannot convey a
533
+ covered work so as to satisfy simultaneously your obligations under this
534
+ License and any other pertinent obligations, then as a consequence you may
535
+ not convey it at all. For example, if you agree to terms that obligate you
536
+ to collect a royalty for further conveying from those to whom you convey
537
+ the Program, the only way you could satisfy both those terms and this
538
+ License would be to refrain entirely from conveying the Program.
539
+
540
+ 13. Remote Network Interaction; Use with the GNU General Public License.
541
+
542
+ Notwithstanding any other provision of this License, if you modify the
543
+ Program, your modified version must prominently offer all users
544
+ interacting with it remotely through a computer network (if your version
545
+ supports such interaction) an opportunity to receive the Corresponding
546
+ Source of your version by providing access to the Corresponding Source
547
+ from a network server at no charge, through some standard or customary
548
+ means of facilitating copying of software. This Corresponding Source
549
+ shall include the Corresponding Source for any work covered by version 3
550
+ of the GNU General Public License that is incorporated pursuant to the
551
+ following paragraph.
552
+
553
+ Notwithstanding any other provision of this License, you have
554
+ permission to link or combine any covered work with a work licensed
555
+ under version 3 of the GNU General Public License into a single
556
+ combined work, and to convey the resulting work. The terms of this
557
+ License will continue to apply to the part which is the covered work,
558
+ but the work with which it is combined will remain governed by version
559
+ 3 of the GNU General Public License.
560
+
561
+ 14. Revised Versions of this License.
562
+
563
+ The Free Software Foundation may publish revised and/or new versions of
564
+ the GNU Affero General Public License from time to time. Such new versions
565
+ will be similar in spirit to the present version, but may differ in detail to
566
+ address new problems or concerns.
567
+
568
+ Each version is given a distinguishing version number. If the
569
+ Program specifies that a certain numbered version of the GNU Affero General
570
+ Public License "or any later version" applies to it, you have the
571
+ option of following the terms and conditions either of that numbered
572
+ version or of any later version published by the Free Software
573
+ Foundation. If the Program does not specify a version number of the
574
+ GNU Affero General Public License, you may choose any version ever published
575
+ by the Free Software Foundation.
576
+
577
+ If the Program specifies that a proxy can decide which future
578
+ versions of the GNU Affero General Public License can be used, that proxy's
579
+ public statement of acceptance of a version permanently authorizes you
580
+ to choose that version for the Program.
581
+
582
+ Later license versions may give you additional or different
583
+ permissions. However, no additional obligations are imposed on any
584
+ author or copyright holder as a result of your choosing to follow a
585
+ later version.
586
+
587
+ 15. Disclaimer of Warranty.
588
+
589
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
590
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
591
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
592
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
593
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
594
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
595
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
596
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
597
+
598
+ 16. Limitation of Liability.
599
+
600
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
602
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
603
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
604
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
605
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
606
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
607
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
608
+ SUCH DAMAGES.
609
+
610
+ 17. Interpretation of Sections 15 and 16.
611
+
612
+ If the disclaimer of warranty and limitation of liability provided
613
+ above cannot be given local legal effect according to their terms,
614
+ reviewing courts shall apply local law that most closely approximates
615
+ an absolute waiver of all civil liability in connection with the
616
+ Program, unless a warranty or assumption of liability accompanies a
617
+ copy of the Program in return for a fee.
618
+
619
+ END OF TERMS AND CONDITIONS
620
+
621
+ How to Apply These Terms to Your New Programs
622
+
623
+ If you develop a new program, and you want it to be of the greatest
624
+ possible use to the public, the best way to achieve this is to make it
625
+ free software which everyone can redistribute and change under these terms.
626
+
627
+ To do so, attach the following notices to the program. It is safest
628
+ to attach them to the start of each source file to most effectively
629
+ state the exclusion of warranty; and each file should have at least
630
+ the "copyright" line and a pointer to where the full notice is found.
631
+
632
+ <one line to give the program's name and a brief idea of what it does.>
633
+ Copyright (C) <year> <name of author>
634
+
635
+ This program is free software: you can redistribute it and/or modify
636
+ it under the terms of the GNU Affero General Public License as published
637
+ by the Free Software Foundation, either version 3 of the License, or
638
+ (at your option) any later version.
639
+
640
+ This program is distributed in the hope that it will be useful,
641
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
642
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643
+ GNU Affero General Public License for more details.
644
+
645
+ You should have received a copy of the GNU Affero General Public License
646
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
647
+
648
+ Also add information on how to contact you by electronic and paper mail.
649
+
650
+ If your software can interact with users remotely through a computer
651
+ network, you should also make sure that it provides a way for users to
652
+ get its source. For example, if your program is a web application, its
653
+ interface could display a "Source" link that leads users to an archive
654
+ of the code. There are many ways you could offer source, and different
655
+ solutions will be better for different programs; see section 13 for the
656
+ specific requirements.
657
+
658
+ You should also get your employer (if you work as a programmer) or school,
659
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
660
+ For more information on this, and how to apply and follow the GNU AGPL, see
661
+ <https://www.gnu.org/licenses/>.
app/app_pixart_dmd.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ from __future__ import annotations
3
+ import argparse
4
+ import os
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ current_file_path = Path(__file__).resolve()
9
+ sys.path.insert(0, str(current_file_path.parent.parent))
10
+ import random
11
+ import gradio as gr
12
+ import numpy as np
13
+ import uuid
14
+ from diffusers import ConsistencyDecoderVAE, PixArtAlphaPipeline, Transformer2DModel, DDPMScheduler
15
+ import torch
16
+ from typing import Tuple
17
+ from datetime import datetime
18
+ from scripts.diffusers_patches import pipeline_pixart_alpha_call
19
+
20
+ DESCRIPTION = """![Logo](https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma-project/master/static/images/logo-sigma.png)
21
+ # PixArt-Alpha One Step 512px
22
+ #### [PixArt-Alpha-DMD 512px](https://github.com/PixArt-alpha/PixArt-sigma) is a transformer-based text-to-image diffusion system trained on text embeddings from T5. This demo uses the [PixArt-Alpha-DMD-XL-2-512x512](https://huggingface.co/PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512) checkpoint.
23
+ #### English prompts ONLY; 提示词仅限英文
24
+ ### <span style='color: red;'>We only use 8 V100 GPUs for PixArt-DMD training. There's still plenty of room for improvement.
25
+ """
26
+ if not torch.cuda.is_available():
27
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
28
+
29
+ MAX_SEED = np.iinfo(np.int32).max
30
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
31
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "6000"))
32
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
33
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
34
+ PORT = int(os.getenv("DEMO_PORT", "15432"))
35
+
36
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
37
+
38
+ style_list = [
39
+ {
40
+ "name": "(No style)",
41
+ "prompt": "{prompt}",
42
+ "negative_prompt": "",
43
+ },
44
+ {
45
+ "name": "Cinematic",
46
+ "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
47
+ "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
48
+ },
49
+ {
50
+ "name": "Photographic",
51
+ "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
52
+ "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
53
+ },
54
+ {
55
+ "name": "Anime",
56
+ "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
57
+ "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
58
+ },
59
+ {
60
+ "name": "Manga",
61
+ "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
62
+ "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
63
+ },
64
+ {
65
+ "name": "Digital Art",
66
+ "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
67
+ "negative_prompt": "photo, photorealistic, realism, ugly",
68
+ },
69
+ {
70
+ "name": "Pixel art",
71
+ "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
72
+ "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
73
+ },
74
+ {
75
+ "name": "Fantasy art",
76
+ "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
77
+ "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
78
+ },
79
+ {
80
+ "name": "Neonpunk",
81
+ "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
82
+ "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
83
+ },
84
+ {
85
+ "name": "3D Model",
86
+ "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
87
+ "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
88
+ },
89
+ ]
90
+
91
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
92
+ STYLE_NAMES = list(styles.keys())
93
+ DEFAULT_STYLE_NAME = "(No style)"
94
+ SCHEDULE_NAME = ["PixArt-DMD"]
95
+ DEFAULT_SCHEDULE_NAME = "PixArt-DMD"
96
+ NUM_IMAGES_PER_PROMPT = 2
97
+
98
+
99
def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
    """Expand a prompt pair with the template of the selected style.

    Args:
        style_name: Key into the module-level ``styles`` mapping. An unknown
            name falls back to the ``DEFAULT_STYLE_NAME`` entry.
        positive: User prompt substituted for ``{prompt}`` in the style template.
        negative: Optional user negative prompt; ``None`` or ``""`` means none.

    Returns:
        A ``(styled_prompt, combined_negative_prompt)`` tuple.
    """
    template, style_negative = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    # Join with ", " so the last style term and the first user term stay
    # separate tokens; empty/None parts are dropped so no stray separator
    # appears when only one side is present.
    combined_negative = ", ".join(part for part in (style_negative, negative) if part)
    return template.replace("{prompt}", positive), combined_negative
104
+
105
+
106
def get_args():
    """Parse the command-line options of the PixArt-DMD demo.

    Returns:
        The ``argparse.Namespace`` holding ``model_path``,
        ``pipeline_load_from`` and ``T5_token_max_length``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_path",
        type=str,
        default="PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512",
    )
    cli.add_argument(
        "--pipeline_load_from",
        type=str,
        default="PixArt-alpha/PixArt-XL-2-1024-MS",
        help="Download for loading text_encoder, "
             "tokenizer and vae from https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS",
    )
    cli.add_argument(
        "--T5_token_max_length",
        type=int,
        default=120,
        help="max length of tokens for T5",
    )
    namespace = cli.parse_args()
    return namespace
115
+
116
+
117
+ args = get_args()
118
+
119
+ if torch.cuda.is_available():
120
+ weight_dtype = torch.float16
121
+ T5_token_max_length = args.T5_token_max_length
122
+ model_path = args.model_path
123
+ if 'Sigma' in args.model_path:
124
+ T5_token_max_length = 300
125
+
126
+ pipe = PixArtAlphaPipeline.from_pretrained(
127
+ args.pipeline_load_from,
128
+ transformer=None,
129
+ torch_dtype=weight_dtype,
130
+ )
131
+ pipe.transformer = Transformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=weight_dtype)
132
+ pipe.scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
133
+
134
+ print("Changing __call__ method of PixArtAlphaPipeline using scripts.diffusers_patches.pipeline_pixart_alpha_call")
135
+ setattr(PixArtAlphaPipeline, '__call__', pipeline_pixart_alpha_call)
136
+
137
+ if os.getenv('CONSISTENCY_DECODER', False):
138
+ print("Using DALL-E 3 Consistency Decoder")
139
+ pipe.vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
140
+
141
+ if ENABLE_CPU_OFFLOAD:
142
+ pipe.enable_model_cpu_offload()
143
+ else:
144
+ pipe.to(device)
145
+ print("Loaded on Device!")
146
+
147
+ # speed-up T5
148
+ pipe.text_encoder.to_bettertransformer()
149
+
150
+ if USE_TORCH_COMPILE:
151
+ pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
152
+ print("Model Compiled!")
153
+
154
+
155
def save_image(img, seed=''):
    """Save ``img`` into a per-day output folder under a unique file name.

    Args:
        img: Object exposing ``save(path)`` (the caller passes pipeline output images).
        seed: Value embedded in the file name after the random UUID.

    Returns:
        The path the image was written to.
    """
    out_dir = f'output/online_demo_img/{datetime.now().date()}'
    file_name = f"{uuid.uuid4()}_{seed}.png"
    # Clear the umask so the created directory/file are not permission-restricted
    # (file permission: 666; dir permission: 777). The previous umask is not restored.
    os.umask(0o000)
    os.makedirs(out_dir, exist_ok=True)
    target = os.path.join(out_dir, file_name)
    img.save(target)
    return target
163
+
164
+
165
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a new random seed in ``[0, MAX_SEED]`` if requested, else ``seed`` unchanged."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
169
+
170
+
171
@torch.no_grad()
@torch.inference_mode()
def generate(
        prompt: str,
        negative_prompt: str = "",
        style: str = DEFAULT_STYLE_NAME,
        use_negative_prompt: bool = False,
        num_imgs: int = 1,
        seed: int = 0,
        width: int = 1024,
        height: int = 1024,
        schedule: str = DEFAULT_SCHEDULE_NAME,
        randomize_seed: bool = False,
        use_resolution_binning: bool = True,
        progress=gr.Progress(track_tqdm=True),
):
    """Run one-step PixArt-DMD generation and save the resulting images.

    The parameter order mirrors the ``inputs`` list wired up in the Gradio
    ``gr.on(...)`` handler, which passes ``schedule`` between ``height`` and
    ``randomize_seed``. Without a ``schedule`` slot the schedule name was
    received as ``randomize_seed`` (always truthy) and the checkbox value
    as ``use_resolution_binning``.

    Args:
        prompt: User prompt.
        negative_prompt: User negative prompt; ignored unless
            ``use_negative_prompt`` is set.
        style: Style name handed to ``apply_style``.
        use_negative_prompt: Whether ``negative_prompt`` is taken into account.
        num_imgs: Number of images generated per prompt.
        seed: Seed used when ``randomize_seed`` is false.
        width: Requested image width.
        height: Requested image height.
        schedule: Sampler schedule selected in the UI; accepted so positional
            arguments line up, not used by the body.
        randomize_seed: Replace ``seed`` with a random one.
        use_resolution_binning: Forwarded to the pipeline call.
        progress: Gradio progress tracker.

    Returns:
        ``(image_paths, seed)``: the saved file paths and the seed actually used.
    """
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator().manual_seed(seed)
    print(f"{PORT}: {model_path}")
    print(prompt)

    if not use_negative_prompt:
        negative_prompt = None  # type: ignore
    # NOTE: the styled negative prompt is computed here but is not forwarded
    # to the pipeline call below, which runs with guidance_scale=1.
    prompt, negative_prompt = apply_style(style, prompt, negative_prompt)

    images = pipe(
        prompt=prompt,
        timesteps=[400],
        width=width,
        height=height,
        guidance_scale=1,
        num_inference_steps=1,
        generator=generator,
        num_images_per_prompt=num_imgs,
        use_resolution_binning=use_resolution_binning,
        output_type="pil",
        max_sequence_length=T5_token_max_length,
    ).images

    image_paths = [save_image(img, seed) for img in images]
    print(image_paths)
    return image_paths, seed
212
+
213
+
214
+ examples = [
215
+ "A small cactus with a happy face in the Sahara desert.",
216
+ "an astronaut sitting in a diner, eating fries, cinematic, analog film",
217
+ "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
218
+ "stars, water, brilliantly, gorgeous large scale scene, a little girl, in the style of dreamy realism, light gold and amber, blue and pink, brilliantly illuminated in the background.",
219
+ "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
220
+ "beautiful lady, freckles, big smile, blue eyes, short ginger hair, dark makeup, wearing a floral blue vest top, soft light, dark grey background",
221
+ "Spectacular Tiny World in the Transparent Jar On the Table, interior of the Great Hall, Elaborate, Carved Architecture, Anatomy, Symetrical, Geometric and Parameteric Details, Precision Flat line Details, Pattern, Dark fantasy, Dark errie mood and ineffably mysterious mood, Technical design, Intricate Ultra Detail, Ornate Detail, Stylized and Futuristic and Biomorphic Details, Architectural Concept, Low contrast Details, Cinematic Lighting, 8k, by moebius, Fullshot, Epic, Fullshot, Octane render, Unreal ,Photorealistic, Hyperrealism",
222
+ "anthropomorphic profile of the white snow owl Crystal priestess , art deco painting, pretty and expressive eyes, ornate costume, mythical, ethereal, intricate, elaborate, hyperrealism, hyper detailed, 3D, 8K, Ultra Realistic, high octane, ultra resolution, amazing detail, perfection, In frame, photorealistic, cinematic lighting, visual clarity, shading , Lumen Reflections, Super-Resolution, gigapixel, color grading, retouch, enhanced, PBR, Blender, V-ray, Procreate, zBrush, Unreal Engine 5, cinematic, volumetric, dramatic, neon lighting, wide angle lens ,no digital painting blur",
223
+ "The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8",
224
+ ]
225
+
226
+ with gr.Blocks(css="scripts/style.css") as demo:
227
+ gr.Markdown(DESCRIPTION)
228
+ gr.DuplicateButton(
229
+ value="Duplicate Space for private use",
230
+ elem_id="duplicate-button",
231
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
232
+ )
233
+ with gr.Row(equal_height=False):
234
+ with gr.Group():
235
+ with gr.Row():
236
+ prompt = gr.Text(
237
+ label="Prompt",
238
+ show_label=False,
239
+ max_lines=1,
240
+ placeholder="Enter your prompt",
241
+ container=False,
242
+ )
243
+ run_button = gr.Button("Run", scale=0)
244
+ result = gr.Gallery(label="Result", show_label=False)
245
+ # with gr.Accordion("Advanced options", open=False):
246
+ with gr.Group():
247
+ with gr.Row():
248
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
249
+ with gr.Row(visible=True):
250
+ schedule = gr.Radio(
251
+ show_label=True,
252
+ container=True,
253
+ interactive=True,
254
+ choices=SCHEDULE_NAME,
255
+ value=DEFAULT_SCHEDULE_NAME,
256
+ label="Sampler Schedule",
257
+ visible=True,
258
+ )
259
+ num_imgs = gr.Slider(
260
+ label="Num Images",
261
+ minimum=1,
262
+ maximum=8,
263
+ step=1,
264
+ value=NUM_IMAGES_PER_PROMPT,
265
+ )
266
+ style_selection = gr.Radio(
267
+ show_label=True,
268
+ container=True,
269
+ interactive=True,
270
+ choices=STYLE_NAMES,
271
+ value=DEFAULT_STYLE_NAME,
272
+ label="Image Style",
273
+ )
274
+ negative_prompt = gr.Text(
275
+ label="Negative prompt",
276
+ max_lines=1,
277
+ placeholder="Enter a negative prompt",
278
+ visible=True,
279
+ )
280
+ seed = gr.Slider(
281
+ label="Seed",
282
+ minimum=0,
283
+ maximum=MAX_SEED,
284
+ step=1,
285
+ value=0,
286
+ )
287
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
288
+ with gr.Row(visible=True):
289
+ width = gr.Slider(
290
+ label="Width",
291
+ minimum=256,
292
+ maximum=MAX_IMAGE_SIZE,
293
+ step=32,
294
+ value=512,
295
+ )
296
+ height = gr.Slider(
297
+ label="Height",
298
+ minimum=256,
299
+ maximum=MAX_IMAGE_SIZE,
300
+ step=32,
301
+ value=512,
302
+ )
303
+
304
+ gr.Examples(
305
+ examples=examples,
306
+ inputs=prompt,
307
+ outputs=[result, seed],
308
+ fn=generate,
309
+ cache_examples=CACHE_EXAMPLES,
310
+ )
311
+
312
+ use_negative_prompt.change(
313
+ fn=lambda x: gr.update(visible=x),
314
+ inputs=use_negative_prompt,
315
+ outputs=negative_prompt,
316
+ api_name=False,
317
+ )
318
+
319
+ gr.on(
320
+ triggers=[
321
+ prompt.submit,
322
+ negative_prompt.submit,
323
+ run_button.click,
324
+ ],
325
+ fn=generate,
326
+ inputs=[
327
+ prompt,
328
+ negative_prompt,
329
+ style_selection,
330
+ use_negative_prompt,
331
+ num_imgs,
332
+ seed,
333
+ width,
334
+ height,
335
+ schedule,
336
+ randomize_seed,
337
+ ],
338
+ outputs=[result, seed],
339
+ api_name="run",
340
+ )
341
+
342
+ if __name__ == "__main__":
343
+ demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=PORT, debug=True)
app/app_pixart_sigma.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ from __future__ import annotations
3
+ import argparse
4
+ import os
5
+ import sys
6
+ from pathlib import Path
7
+ current_file_path = Path(__file__).resolve()
8
+ sys.path.insert(0, str(current_file_path.parent.parent))
9
+ import random
10
+ import gradio as gr
11
+ import numpy as np
12
+ import uuid
13
+ from diffusers import ConsistencyDecoderVAE, DPMSolverMultistepScheduler, Transformer2DModel, AutoencoderKL
14
+ import torch
15
+ from typing import Tuple
16
+ from datetime import datetime
17
+ from diffusion.sa_solver_diffusers import SASolverScheduler
18
+ from peft import PeftModel
19
+ from scripts.diffusers_patches import pixart_sigma_init_patched_inputs, PixArtSigmaPipeline
20
+
21
+
22
+ DESCRIPTION = """![Logo](https://raw.githubusercontent.com/PixArt-alpha/PixArt-sigma-project/master/static/images/logo-sigma.png)
23
+ # PixArt-Sigma 1024px
24
+ #### [PixArt-Sigma 1024px](https://github.com/PixArt-alpha/PixArt-sigma) is a transformer-based text-to-image diffusion system trained on text embeddings from T5. This demo uses the [PixArt-alpha/PixArt-XL-2-1024-MS](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS) checkpoint.
25
+ #### English prompts ONLY; 提示词仅限英文
26
+ ### <span style='color: red;'>You may change the DPM-Solver inference steps from 14 to 20, or DPM-Solver Guidance scale from 4.5 to 3.5 if you didn't get satisfied results.
27
+ """
28
+ if not torch.cuda.is_available():
29
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
30
+
31
+ MAX_SEED = np.iinfo(np.int32).max
32
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
33
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "6000"))
34
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
35
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
36
+ PORT = int(os.getenv("DEMO_PORT", "15432"))
37
+
38
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
39
+
40
+
41
+ style_list = [
42
+ {
43
+ "name": "(No style)",
44
+ "prompt": "{prompt}",
45
+ "negative_prompt": "",
46
+ },
47
+ {
48
+ "name": "Cinematic",
49
+ "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
50
+ "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
51
+ },
52
+ {
53
+ "name": "Photographic",
54
+ "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
55
+ "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
56
+ },
57
+ {
58
+ "name": "Anime",
59
+ "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
60
+ "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
61
+ },
62
+ {
63
+ "name": "Manga",
64
+ "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
65
+ "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
66
+ },
67
+ {
68
+ "name": "Digital Art",
69
+ "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
70
+ "negative_prompt": "photo, photorealistic, realism, ugly",
71
+ },
72
+ {
73
+ "name": "Pixel art",
74
+ "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
75
+ "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
76
+ },
77
+ {
78
+ "name": "Fantasy art",
79
+ "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
80
+ "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
81
+ },
82
+ {
83
+ "name": "Neonpunk",
84
+ "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
85
+ "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
86
+ },
87
+ {
88
+ "name": "3D Model",
89
+ "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
90
+ "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
91
+ },
92
+ ]
93
+
94
+
95
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
96
+ STYLE_NAMES = list(styles.keys())
97
+ DEFAULT_STYLE_NAME = "(No style)"
98
+ SCHEDULE_NAME = ["DPM-Solver", "SA-Solver"]
99
+ DEFAULT_SCHEDULE_NAME = "DPM-Solver"
100
+ NUM_IMAGES_PER_PROMPT = 1
101
+
102
def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
    """Expand a prompt pair with the template of the selected style.

    Args:
        style_name: Key into the module-level ``styles`` mapping. An unknown
            name falls back to the ``DEFAULT_STYLE_NAME`` entry.
        positive: User prompt substituted for ``{prompt}`` in the style template.
        negative: Optional user negative prompt; ``None`` or ``""`` means none.

    Returns:
        A ``(styled_prompt, combined_negative_prompt)`` tuple.
    """
    template, style_negative = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    # Join with ", " so the last style term and the first user term stay
    # separate tokens; empty/None parts are dropped so no stray separator
    # appears when only one side is present.
    combined_negative = ", ".join(part for part in (style_negative, negative) if part)
    return template.replace("{prompt}", positive), combined_negative
107
+
108
+
109
def get_args():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with ``is_lora``, ``repo_id``, ``lora_repo_id``,
        ``model_path``, ``pipeline_load_from`` and ``T5_token_max_length``.
        ``model_path`` is never ``None``: when ``--model_path`` is omitted it
        falls back to ``repo_id``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--is_lora', action='store_true', help='enable lora ckpt loading')
    parser.add_argument('--repo_id', default="PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", type=str)
    parser.add_argument('--lora_repo_id', default=None, type=str)
    parser.add_argument('--model_path', default=None, type=str)
    parser.add_argument(
        '--pipeline_load_from', default="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", type=str,
        help="Download for loading text_encoder, tokenizer and vae "
             "from https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS")
    parser.add_argument('--T5_token_max_length', default=120, type=int, help='max length of tokens for T5')
    args = parser.parse_args()
    # The module-level setup evaluates `'Sigma' in args.model_path` and loads
    # the transformer from it; a None value would raise TypeError there, so
    # default to the repo id when no explicit path is given.
    if args.model_path is None:
        args.model_path = args.repo_id
    return args
121
+
122
+
123
+ args = get_args()
124
+
125
+ if torch.cuda.is_available():
126
+ weight_dtype = torch.float16
127
+ T5_token_max_length = args.T5_token_max_length
128
+ model_path = args.model_path
129
+ if 'Sigma' in args.model_path:
130
+ T5_token_max_length = 300
131
+
132
+ # tmp patches for diffusers PixArtSigmaPipeline Implementation
133
+ print(
134
+ "Changing _init_patched_inputs method of diffusers.models.Transformer2DModel "
135
+ "using scripts.diffusers_patches.pixart_sigma_init_patched_inputs")
136
+ setattr(Transformer2DModel, '_init_patched_inputs', pixart_sigma_init_patched_inputs)
137
+
138
+ if not args.is_lora:
139
+ transformer = Transformer2DModel.from_pretrained(
140
+ model_path,
141
+ subfolder='transformer',
142
+ torch_dtype=weight_dtype,
143
+ )
144
+ pipe = PixArtSigmaPipeline.from_pretrained(
145
+ args.pipeline_load_from,
146
+ transformer=transformer,
147
+ torch_dtype=weight_dtype,
148
+ use_safetensors=True,
149
+ )
150
+ else:
151
+ assert args.lora_repo_id is not None
152
+ transformer = Transformer2DModel.from_pretrained(args.repo_id, subfolder="transformer", torch_dtype=torch.float16)
153
+ transformer = PeftModel.from_pretrained(transformer, args.lora_repo_id)
154
+ pipe = PixArtSigmaPipeline.from_pretrained(
155
+ args.repo_id,
156
+ transformer=transformer,
157
+ torch_dtype=torch.float16,
158
+ use_safetensors=True,
159
+ )
160
+ del transformer
161
+
162
+
163
+ if os.getenv('CONSISTENCY_DECODER', False):
164
+ print("Using DALL-E 3 Consistency Decoder")
165
+ pipe.vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
166
+
167
+ if ENABLE_CPU_OFFLOAD:
168
+ pipe.enable_model_cpu_offload()
169
+ else:
170
+ pipe.to(device)
171
+ print("Loaded on Device!")
172
+
173
+ # speed-up T5
174
+ pipe.text_encoder.to_bettertransformer()
175
+
176
+ if USE_TORCH_COMPILE:
177
+ pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
178
+ print("Model Compiled!")
179
+
180
+
181
def save_image(img, seed=''):
    """Write ``img`` to a per-day output folder under a unique file name.

    Args:
        img: Object exposing ``save(path)``.
        seed: Value appended to the file name after the UUID.

    Returns:
        The path of the written PNG file.
    """
    day_dir = f'output/online_demo_img/{datetime.now().date()}'
    # Clears the umask for the whole process, not just this call.
    os.umask(0o000)  # file permission: 666; dir permission: 777
    os.makedirs(day_dir, exist_ok=True)
    target_path = os.path.join(day_dir, f"{str(uuid.uuid4())}_{seed}.png")
    img.save(target_path)
    return target_path
189
+
190
+
191
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a random seed in ``[0, MAX_SEED]`` if requested, else ``seed`` unchanged."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
195
+
196
+
197
@torch.no_grad()
@torch.inference_mode()
def generate(
    prompt: str,
    negative_prompt: str = "",
    style: str = DEFAULT_STYLE_NAME,
    use_negative_prompt: bool = False,
    num_imgs: int = 1,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    schedule: str = 'DPM-Solver',
    dpms_guidance_scale: float = 4.5,
    sas_guidance_scale: float = 3,
    dpms_inference_steps: int = 20,
    sas_inference_steps: int = 25,
    randomize_seed: bool = False,
    use_resolution_binning: bool = True,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the pipeline for one prompt and save the resulting images.

    Args:
        prompt: User prompt; wrapped in the template of ``style``.
        negative_prompt: User negative prompt; ignored unless
            ``use_negative_prompt`` is true.
        style: Style name passed to ``apply_style``.
        use_negative_prompt: Whether the user's negative prompt is applied.
        num_imgs: Number of images generated for the prompt.
        seed: Seed for the generator, replaced by a random one when
            ``randomize_seed`` is true.
        width, height: Requested image size.
        schedule: ``'DPM-Solver'`` or ``'SA-Solver'``; selects the scheduler
            and which guidance-scale / step-count pair is used.
        dpms_guidance_scale, dpms_inference_steps: Settings for DPM-Solver.
        sas_guidance_scale, sas_inference_steps: Settings for SA-Solver.
        randomize_seed: Draw a fresh seed via ``randomize_seed_fn``.
        use_resolution_binning: Forwarded to the pipeline unchanged.
        progress: Gradio progress tracker (not used directly in this body).

    Returns:
        ``(image_paths, seed)``: the saved file paths and the seed actually used.

    Raises:
        ValueError: If ``schedule`` is not one of the two supported names.
    """
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator().manual_seed(seed)
    print(f"{PORT}: {model_path}")
    print(prompt)

    # Swap the scheduler only when the requested one is not already installed.
    if schedule == 'DPM-Solver':
        if not isinstance(pipe.scheduler, DPMSolverMultistepScheduler):
            pipe.scheduler = DPMSolverMultistepScheduler()
        num_inference_steps = dpms_inference_steps
        guidance_scale = dpms_guidance_scale
    elif schedule == "SA-Solver":
        if not isinstance(pipe.scheduler, SASolverScheduler):
            pipe.scheduler = SASolverScheduler.from_config(
                pipe.scheduler.config,
                algorithm_type='data_prediction',
                tau_func=lambda t: 1 if 200 <= t <= 800 else 0,
                predictor_order=2,
                corrector_order=2,
            )
        num_inference_steps = sas_inference_steps
        guidance_scale = sas_guidance_scale
    else:
        raise ValueError(f"Unknown schedule: {schedule}")

    if not use_negative_prompt:
        negative_prompt = None  # type: ignore
    prompt, negative_prompt = apply_style(style, prompt, negative_prompt)

    images = pipe(
        prompt=prompt,
        # Forward the styled negative prompt; if it is omitted, both the
        # style's negative terms and the user's negative prompt are dropped.
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
        num_images_per_prompt=num_imgs,
        use_resolution_binning=use_resolution_binning,
        output_type="pil",
        # Module-level value: starts from args.T5_token_max_length and is
        # raised to 300 at load time when the model path contains 'Sigma'.
        max_sequence_length=T5_token_max_length,
    ).images

    image_paths = [save_image(img, seed) for img in images]
    print(image_paths)
    return image_paths, seed
255
+
256
+
257
+ examples = [
258
+ "A small cactus with a happy face in the Sahara desert.",
259
+ "an astronaut sitting in a diner, eating fries, cinematic, analog film",
260
+ "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
261
+ "stars, water, brilliantly, gorgeous large scale scene, a little girl, in the style of dreamy realism, light gold and amber, blue and pink, brilliantly illuminated in the background.",
262
+ "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
263
+ "beautiful lady, freckles, big smile, blue eyes, short ginger hair, dark makeup, wearing a floral blue vest top, soft light, dark grey background",
264
+ "Spectacular Tiny World in the Transparent Jar On the Table, interior of the Great Hall, Elaborate, Carved Architecture, Anatomy, Symetrical, Geometric and Parameteric Details, Precision Flat line Details, Pattern, Dark fantasy, Dark errie mood and ineffably mysterious mood, Technical design, Intricate Ultra Detail, Ornate Detail, Stylized and Futuristic and Biomorphic Details, Architectural Concept, Low contrast Details, Cinematic Lighting, 8k, by moebius, Fullshot, Epic, Fullshot, Octane render, Unreal ,Photorealistic, Hyperrealism",
265
+ "anthropomorphic profile of the white snow owl Crystal priestess , art deco painting, pretty and expressive eyes, ornate costume, mythical, ethereal, intricate, elaborate, hyperrealism, hyper detailed, 3D, 8K, Ultra Realistic, high octane, ultra resolution, amazing detail, perfection, In frame, photorealistic, cinematic lighting, visual clarity, shading , Lumen Reflections, Super-Resolution, gigapixel, color grading, retouch, enhanced, PBR, Blender, V-ray, Procreate, zBrush, Unreal Engine 5, cinematic, volumetric, dramatic, neon lighting, wide angle lens ,no digital painting blur",
266
+ "The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8",
267
+ ]
268
+
269
+ with gr.Blocks(css="scripts/style.css") as demo:
270
+ gr.Markdown(DESCRIPTION)
271
+ gr.DuplicateButton(
272
+ value="Duplicate Space for private use",
273
+ elem_id="duplicate-button",
274
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
275
+ )
276
+ with gr.Row(equal_height=False):
277
+ with gr.Group():
278
+ with gr.Row():
279
+ prompt = gr.Text(
280
+ label="Prompt",
281
+ show_label=False,
282
+ max_lines=1,
283
+ placeholder="Enter your prompt",
284
+ container=False,
285
+ )
286
+ run_button = gr.Button("Run", scale=0)
287
+ result = gr.Gallery(label="Result", show_label=False)
288
+ # with gr.Accordion("Advanced options", open=False):
289
+ with gr.Group():
290
+ with gr.Row():
291
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
292
+ with gr.Row(visible=True):
293
+ schedule = gr.Radio(
294
+ show_label=True,
295
+ container=True,
296
+ interactive=True,
297
+ choices=SCHEDULE_NAME,
298
+ value=DEFAULT_SCHEDULE_NAME,
299
+ label="Sampler Schedule",
300
+ visible=True,
301
+ )
302
+ num_imgs = gr.Slider(
303
+ label="Num Images",
304
+ minimum=1,
305
+ maximum=8,
306
+ step=1,
307
+ value=1,
308
+ )
309
+ style_selection = gr.Radio(
310
+ show_label=True,
311
+ container=True,
312
+ interactive=True,
313
+ choices=STYLE_NAMES,
314
+ value=DEFAULT_STYLE_NAME,
315
+ label="Image Style",
316
+ )
317
+ negative_prompt = gr.Text(
318
+ label="Negative prompt",
319
+ max_lines=1,
320
+ placeholder="Enter a negative prompt",
321
+ visible=True,
322
+ )
323
+ seed = gr.Slider(
324
+ label="Seed",
325
+ minimum=0,
326
+ maximum=MAX_SEED,
327
+ step=1,
328
+ value=0,
329
+ )
330
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
331
+ with gr.Row(visible=True):
332
+ width = gr.Slider(
333
+ label="Width",
334
+ minimum=256,
335
+ maximum=MAX_IMAGE_SIZE,
336
+ step=32,
337
+ value=1024,
338
+ )
339
+ height = gr.Slider(
340
+ label="Height",
341
+ minimum=256,
342
+ maximum=MAX_IMAGE_SIZE,
343
+ step=32,
344
+ value=1024,
345
+ )
346
+ with gr.Row():
347
+ dpms_guidance_scale = gr.Slider(
348
+ label="DPM-Solver Guidance scale",
349
+ minimum=1,
350
+ maximum=10,
351
+ step=0.1,
352
+ value=4.5,
353
+ )
354
+ dpms_inference_steps = gr.Slider(
355
+ label="DPM-Solver inference steps",
356
+ minimum=5,
357
+ maximum=40,
358
+ step=1,
359
+ value=14,
360
+ )
361
+ with gr.Row():
362
+ sas_guidance_scale = gr.Slider(
363
+ label="SA-Solver Guidance scale",
364
+ minimum=1,
365
+ maximum=10,
366
+ step=0.1,
367
+ value=3,
368
+ )
369
+ sas_inference_steps = gr.Slider(
370
+ label="SA-Solver inference steps",
371
+ minimum=10,
372
+ maximum=40,
373
+ step=1,
374
+ value=25,
375
+ )
376
+
377
+ gr.Examples(
378
+ examples=examples,
379
+ inputs=prompt,
380
+ outputs=[result, seed],
381
+ fn=generate,
382
+ cache_examples=CACHE_EXAMPLES,
383
+ )
384
+
385
+ use_negative_prompt.change(
386
+ fn=lambda x: gr.update(visible=x),
387
+ inputs=use_negative_prompt,
388
+ outputs=negative_prompt,
389
+ api_name=False,
390
+ )
391
+
392
+ gr.on(
393
+ triggers=[
394
+ prompt.submit,
395
+ negative_prompt.submit,
396
+ run_button.click,
397
+ ],
398
+ fn=generate,
399
+ inputs=[
400
+ prompt,
401
+ negative_prompt,
402
+ style_selection,
403
+ use_negative_prompt,
404
+ num_imgs,
405
+ seed,
406
+ width,
407
+ height,
408
+ schedule,
409
+ dpms_guidance_scale,
410
+ sas_guidance_scale,
411
+ dpms_inference_steps,
412
+ sas_inference_steps,
413
+ randomize_seed,
414
+ ],
415
+ outputs=[result, seed],
416
+ api_name="run",
417
+ )
418
+
419
if __name__ == "__main__":
    # Cap the request queue at 20 entries, then serve on all interfaces at PORT.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(server_name="0.0.0.0", server_port=PORT, debug=True)
asset/PixArt.svg ADDED
asset/docs/pixart.md ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!--Copyright 2023 The HuggingFace Team. All rights reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4
+ the License. You may obtain a copy of the License at
5
+
6
+ http://www.apache.org/licenses/LICENSE-2.0
7
+
8
+ Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9
+ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10
+ specific language governing permissions and limitations under the License.
11
+ -->
12
+
13
+ [//]: # (&#40;reference from [hugging Face]&#40;https://github.com/huggingface/diffusers/blob/docs/8bit-inference-pixart/docs/source/en/api/pipelines/pixart.md&#41;&#41;)
14
+
15
+ ## Running the `PixArtAlphaPipeline` in under 8GB GPU VRAM
16
+
17
+ It is possible to run the [`PixArtAlphaPipeline`] under 8GB GPU VRAM by loading the text encoder in 8-bit numerical precision. Let's walk through a full-fledged example.
18
+
19
+ First, install the `bitsandbytes` library:
20
+
21
+ ```bash
22
+ pip install -U bitsandbytes
23
+ ```
24
+
25
+ Then load the text encoder in 8-bit:
26
+
27
+ ```python
28
+ from transformers import T5EncoderModel
29
+ from diffusers import PixArtAlphaPipeline
30
+
31
+ text_encoder = T5EncoderModel.from_pretrained(
32
+ "PixArt-alpha/PixArt-XL-2-1024-MS",
33
+ subfolder="text_encoder",
34
+ load_in_8bit=True,
35
+ device_map="auto",
36
+
37
+ )
38
+ pipe = PixArtAlphaPipeline.from_pretrained(
39
+ "PixArt-alpha/PixArt-XL-2-1024-MS",
40
+ text_encoder=text_encoder,
41
+ transformer=None,
42
+ device_map="auto"
43
+ )
44
+ ```
45
+
46
+ Now, use the `pipe` to encode a prompt:
47
+
48
+ ```python
49
+ with torch.no_grad():
50
+ prompt = "cute cat"
51
+ prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)
52
+
53
+ del text_encoder
54
+ del pipe
55
+ flush()
56
+ ```
57
+
58
+ `flush()` is just a utility function to clear the GPU VRAM and is implemented like so:
59
+
60
+ ```python
61
+ import gc
62
+
63
+ def flush():
64
+ gc.collect()
65
+ torch.cuda.empty_cache()
66
+ ```
67
+
68
+ Then compute the latents providing the prompt embeddings as inputs:
69
+
70
+ ```python
71
+ pipe = PixArtAlphaPipeline.from_pretrained(
72
+ "PixArt-alpha/PixArt-XL-2-1024-MS",
73
+ text_encoder=None,
74
+ torch_dtype=torch.float16,
75
+ ).to("cuda")
76
+
77
+ latents = pipe(
78
+ negative_prompt=None,
79
+ prompt_embeds=prompt_embeds,
80
+ negative_prompt_embeds=negative_embeds,
81
+ prompt_attention_mask=prompt_attention_mask,
82
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
83
+ num_images_per_prompt=1,
84
+ output_type="latent",
85
+ ).images
86
+
87
+ del pipe.transformer
88
+ flush()
89
+ ```
90
+
91
+ Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded.
92
+
93
+ Once the latents are computed, pass them off to the VAE to decode into a real image:
94
+
95
+ ```python
96
+ with torch.no_grad():
97
+ image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
98
+ image = pipe.image_processor.postprocess(image, output_type="pil")[0]
99
+ image.save("cat.png")
100
+ ```
101
+
102
+ All of this, put together, should allow you to run [`PixArtAlphaPipeline`] under 8GB GPU VRAM.
103
+
104
+ ![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/8bits_cat.png)
105
+
106
+ Find the script [here](https://gist.github.com/sayakpaul/3ae0f847001d342af27018a96f467e4e) that can be run end-to-end to report the memory being used.
107
+
108
+ <Tip warning={true}>
109
+
110
+ Text embeddings computed in 8-bit can have an impact on the quality of the generated images because of the information loss in the representation space induced by the reduced precision. It's recommended to compare the outputs with and without 8-bit.
111
+
112
+ </Tip>
asset/examples.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
# Demo example rows. Each row is a prompt followed by "dpm-solver", 20, 4.5
# (presumably sampler name, inference steps and guidance scale; confirm
# against the consumer of this list).
# Multi-line prompts rely on implicit string concatenation, so every fragment
# except the last must end with a space.
examples = [
    [
        "A small cactus with a happy face in the Sahara desert.",
        "dpm-solver", 20, 4.5,
    ],
    [
        "An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history "
        "of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits "
        "mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt, he wears a brown beret "
        "and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile "
        "as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and "
        "the Parisian streets and city in the background, depth of field, cinematic 35mm film.",
        "dpm-solver", 20, 4.5,
    ],
    [
        "An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. "
        "Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. "
        "The quote 'Find the universe within you' is etched in bold letters across the horizon. "
        "blue and pink, brilliantly illuminated in the background.",
        "dpm-solver", 20, 4.5,
    ],
    [
        "A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.",
        "dpm-solver", 20, 4.5,
    ],
    [
        "A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.",
        "dpm-solver", 20, 4.5,
    ],
    [
        "a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, "
        "national geographic photo, 8k resolution, crayon art, interactive artwork",
        "dpm-solver", 20, 4.5,
    ],
]
asset/logo-sigma.png ADDED
asset/logo.png ADDED
asset/samples.txt ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ A small cactus with a happy face in the Sahara desert.
2
+ Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.
3
+ beautiful lady, freckles, big smile, blue eyes, short ginger hair, dark makeup, wearing a floral blue vest top, soft light, dark grey background
4
+ stars, water, brilliantly, gorgeous large scale scene, a little girl, in the style of dreamy realism, light gold and amber, blue and pink, brilliantly illuminated in the background.
5
+ nature vs human nature, surreal, UHD, 8k, hyper details, rich colors, photograph.
6
+ Spectacular Tiny World in the Transparent Jar On the Table, interior of the Great Hall, Elaborate, Carved Architecture, Anatomy, Symetrical, Geometric and Parameteric Details, Precision Flat line Details, Pattern, Dark fantasy, Dark errie mood and ineffably mysterious mood, Technical design, Intricate Ultra Detail, Ornate Detail, Stylized and Futuristic and Biomorphic Details, Architectural Concept, Low contrast Details, Cinematic Lighting, 8k, by moebius, Fullshot, Epic, Fullshot, Octane render, Unreal ,Photorealistic, Hyperrealism
7
+ anthropomorphic profile of the white snow owl Crystal priestess , art deco painting, pretty and expressive eyes, ornate costume, mythical, ethereal, intricate, elaborate, hyperrealism, hyper detailed, 3D, 8K, Ultra Realistic, high octane, ultra resolution, amazing detail, perfection, In frame, photorealistic, cinematic lighting, visual clarity, shading , Lumen Reflections, Super-Resolution, gigapixel, color grading, retouch, enhanced, PBR, Blender, V-ray, Procreate, zBrush, Unreal Engine 5, cinematic, volumetric, dramatic, neon lighting, wide angle lens ,no digital painting blur
8
+ The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8
9
+ Bright scene, aerial view, ancient city, fantasy, gorgeous light, mirror reflection, high detail, wide angle lens.
10
+ 8k uhd A man looks up at the starry sky, lonely and ethereal, Minimalism, Chaotic composition Op Art
11
+ A middle-aged woman of Asian descent, her dark hair streaked with silver, appears fractured and splintered, intricately embedded within a sea of broken porcelain. The porcelain glistens with splatter paint patterns in a harmonious blend of glossy and matte blues, greens, oranges, and reds, capturing her dance in a surreal juxtaposition of movement and stillness. Her skin tone, a light hue like the porcelain, adds an almost mystical quality to her form.
12
+ A 4k dslr image of a lemur wearing a red magician hat and a blue coat performing magic tricks with cards in a garden.
13
+ A alpaca made of colorful building blocks, cyberpunk
14
+ A baby painter trying to draw very simple picture, white background
15
+ A boy and a girl fall in love
16
+ A dog that has been meditating all the time
17
+ A man is sitting in a chair with his chin resting on his hand. The chair, along with the man's feet, are submerged in the sea. Strikingly, the man's back is on fire.
18
+ A painter study hard to learn how to draw with many concepts in the air, white background
19
+ A painter with low quality, white background, pixel art
20
+ A person standing on the desert, desert waves, gossip illustration, half red, half blue, abstract image of sand, clear style, trendy illustration, outdoor, top view, clear style, precision art, ultra high definition image
21
+ A silhouette of a grand piano overlooking a dusky cityscape viewed from a top-floor penthouse, rendered in the bold and vivid sytle of a vintage travel poster.
22
+ A sureal parallel world where mankind avoid extinction by preserving nature, epic trees, water streams, various flowers, intricate details, rich colors, rich vegetation, cinematic, symmetrical, beautiful lighting, V-Ray render, sun rays, magical lights, photography
23
+ A woman is shopping for fresh produce at the farmer's market.
24
+ A worker that looks like a mixture of cow and horse is working hard to type code
25
+ A young man dressed in ancient Chinese clothing, Asian people, White robe, Handsome, Hand gestures forming a spell, Martial arts and fairy-like vibe, Carrying a legendary-level giant sword on the back, Game character, Surrounded by runes, Cyberpunk style, neon lights, best quality, masterpiece, cg, hdr, high-definition, extremely detailed, photorealistic, epic, character design, detailed face, superhero, hero, detailed UHD, real-time, vfx, 3D rendering, 8k
26
+ An alien octopus floats through a protal reading a newspaper
27
+ An epressive oil painting of a basketbal player dunking, depicted as an explosion of a nebula
28
+ art collection style and fashion shoot, in the style of made of glass, dark blue and light pink, paul rand, solarpunk, camille vivier, beth didonato hair, barbiecore, hyper-realistic
29
+ artistic
30
+ beautiful secen
31
+ Crocodile in a sweater
32
+ Design a letter A, 3D stereoscopic Ice material Interior light blue Conceptual product design Futuristic Blind box toy Handcrafted Exquisite 3D effect Full body display Ultra-high precision Ultra-detailed Perfect lighting OC Renderer Blender 8k Ultra-sharp Ultra-noise reduction
33
+ Floating,colossal,futuristic statue in the sky, awe-inspiring and serenein the style of Stuart Lippincott:2with detailed composition and subtle geometric elements.This sanctuary-ike atmosphere features crisp clarity and soft amber tones.In contrasttiny human figures surround the statueThe pieceincorporates flowing draperiesreminiscent of Shwedoff and Philip McKay's stylesemphasizing thejuxtaposition between the powerful presence of the statue and thevulnerability of the minuscule human figuresshwedoff
34
+ knolling of a drawing tools for painter
35
+ Leonardo da Vinci's Last Supper content, Van Goph's Starry Night Style
36
+ Luffy from ONEPIECE, handsome face, fantasy
37
+ photography shot through an outdoor window of a coffee shop with neon sign lighting, window glares and reflections, depth of field, {little girl with red hair sitting at a table, portrait, kodak portra 800,105 mm f1.8
38
+ poster of a mechanical cat, techical Schematics viewed from front and side view on light white blueprint paper, illustartion drafting style, illustation, typography, conceptual art, dark fantasy steampunk, cinematic, dark fantasy
39
+ The girl in the car is filled with goldfish and flowers, goldfish can fly, Kawaguchi Renko's art, natural posture, holiday dadcore, youthful energy and pressure, body stretching, goldfish simulation movies in the sky, super details, and dreamy high photography. Colorful. Covered by water and goldfish, indoor scene, close-up shot in XT4 movie
40
+ The image features a woman wearing a red shirt with an icon. She appears to be posing for the camera, and her outfit includes a pair of jeans. The woman seems to be in a good mood, as she is smiling. The background of the image is blurry, focusing more on the woman and her attire.
41
+ The towel was on top of the hard counter.
42
+ A vast landscape made entirely of various meats spreads out before the viewer. tender, succulent hills of roast beef, chicken drumstick trees, bacon rivers, and ham boulders create a surreal, yet appetizing scene. the sky is adorned with pepperoni sun and salami clouds.
43
+ I want to supplement vitamin c, please help me paint related food.
44
+ A vibrant yellow banana-shaped couch sits in a cozy living room, its curve cradling a pile of colorful cushions. on the wooden floor, a patterned rug adds a touch of eclectic charm, and a potted plant sits in the corner, reaching towards the sunlight filtering through the window.
45
+ A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.
46
+ A blue jay standing on a large basket of rainbow macarons.
47
+ A bucket bag made of blue suede. The bag is decorated with intricate golden paisley patterns. The handle of the bag is made of rubies and pearls.
48
+ An alien octopus floats through a portal reading a newspaper.
49
+ bird's eye view of a city.
50
+ beautiful scene
51
+ A 2D animation of a folk music band composed of anthropomorphic autumn leaves, each playing traditional bluegrass instruments, amidst a rustic forest setting dappled with the soft light of a harvest moon.
52
+ In front of a deep black backdrop, a figure of middle years, her Tongan skin rich and glowing, is captured mid-twirl, her curly hair flowing like a storm behind her. Her attire resembles a whirlwind of marble and porcelain fragments. Illuminated by the gleam of scattered porcelain shards, creating a dreamlike atmosphere, the dancer manages to appear fragmented, yet maintains a harmonious and fluid form.
53
+ Digital illustration of a beach scene crafted from yarn. The sandy beach is depicted with beige yarn, waves are made of blue and white yarn crashing onto the shore. A yarn sun sets on the horizon, casting a warm glow. Yarn palm trees sway gently, and little yarn seashells dot the shoreline.
54
+ Illustration of a chic chair with a design reminiscent of a pumpkin’s form, with deep orange cushioning, in a stylish loft setting.
55
+ A detailed oil painting of an old sea captain, steering his ship through a storm. Saltwater is splashing against his weathered face, determination in his eyes. Twirling malevolent clouds are seen above and stern waves threaten to submerge the ship while seagulls dive and twirl through the chaotic landscape. Thunder and lights embark in the distance, illuminating the scene with an eerie green glow.
56
+ An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. The quote 'Find the universe within you' is etched in bold letters across the horizon.
57
+ A modern architectural building with large glass windows, situated on a cliff overlooking a serene ocean at sunset
58
+ photo of an ancient shipwreck nestled on the ocean floor. Marine plants have claimed the wooden structure, and fish swim in and out of its hollow spaces. Sunken treasures and old cannons are scattered around, providing a glimpse into the past
59
+ A 3D render of a coffee mug placed on a window sill during a stormy day. The storm outside the window is reflected in the coffee, with miniature lightning bolts and turbulent waves seen inside the mug. The room is dimly lit, adding to the dramatic atmosphere.A minimap diorama of a cafe adorned with indoor plants. Wooden beams crisscross above, and a cold brew station stands out with tiny bottles and glasses.
60
+ An antique botanical illustration drawn with fine lines and a touch of watercolour whimsy, depicting a strange lily crossed with a Venus flytrap, its petals poised as if ready to snap shut on any unsuspecting insects.An illustration inspired by old-world botanical sketches blends a cactus with lilac blooms into a Möbius strip, using detailed lines and subtle watercolor touches to capture nature's diverse beauty and mathematical intrigue.
61
+ An ink sketch style illustration of a small hedgehog holding a piece of watermelon with its tiny paws, taking little bites with its eyes closed in delight.Photo of a lychee-inspired spherical chair, with a bumpy white exterior and plush interior, set against a tropical wallpaper.
62
+ 3d digital art of an adorable ghost, glowing within, holding a heart shaped pumpkin, Halloween, super cute, spooky haunted house background
63
+ professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.
64
+ an astronaut sitting in a diner, eating fries, cinematic, analog film
65
+ Chinese architecture, ancient style,mountain, bird, lotus, pond, big tree, 4K Unity, octane rendering.
66
+ Ethereal fantasy concept art of thunder god with hammer. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy.
67
+ A Japanese girl walking along a path, surrounding by blooming oriental cherry, pink petal slowly falling down to the ground
68
+ A Ukiyoe style painting, an astronaut riding a unicorn, In the background there is an ancient Japanese architecture
69
+ Steampunk makeup, in the style of vray tracing, colorful impasto, uhd image, indonesian art, fine feather details with bright red and yellow and green and pink and orange colours, intricate patterns and details, dark cyan and amber makeup. Rich colourful plumes. Victorian style.
70
+ A cute teddy bear in front of a plain white wall, warm and brown fur, soft and fluffy
71
+ The beautiful scenery of Seattle, painting by Al Capp.
72
+ Photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang.
73
+ An astronaut riding a horse on the moon, oil painting by Van Gogh.
74
+ A deep forest clearing with a mirrored pond reflecting a galaxy-filled night sky
75
+ Realistic oil painting of a stunning model merged in multicolor splash made of finely torn paper, eye contact, walking with class in a street.
76
+ a chinese model is sitting on a train, magazine cover, clothes made of plastic, photorealistic,futuristic style, gray and green light, movie lighting, 32K HD
77
+ a handsome 24 years old boy in the middle with sky color background wearing eye glasses, it's super detailed with anime style, it's a portrait with delicated eyes and nice looking face
78
+ a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, national geographic photo, 8k resolution, crayon art, interactive artwork
79
+ 3D rendering miniature scene design, Many tall buildings, A winding urban road runs through the middle,a lot of cars on the road, transparent material pipeline transports Materials, ,there are many people around, in thestyle of light orange and yellow, graphic design- inspired illustrations, classic still-life, beeple, josan gon-zalez, manga-influenced, miniature dioramas, in thestyle of playful and whimsical designs, graphic de-sign-inspired illustrations, minimalism, hyperrealismlomo lca, e-commerce C4D style, e-commerce posterUl, UX, octane render, blender
80
+ Close-up photos of models, hazy light and shadow, laser metal hair accessories, soft and beautiful, light gold pupils, white eyelashes, low saturation, real skin details, clear pores and fine lines, light reflection and refraction, ultra-clear, cinematography, award-winning works
81
+ A cute orange kitten sliding down an aqua slide. happy excited. 16mm lens in front. we see his excitement and scared in the eye. vibrant colors. water splashing on the lens
82
+ Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.
83
+ A gorgeously rendered papercraft world of a coral reef, rife with colorful fish and sea creatures.
84
+ An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt , he wears a brown beret and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film.
85
+ A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.
86
+ A New Zealand female business owner stands and is happy that his business is growing by having good VoIP and broadband supplied by Voyager Internet. This business owner is dressed semi casual and is standing with a funky office space in the background. The image is light and bright and is well lit. This image needs to be shot like a professional photo shoot using a Canon R6 with high quality 25mm lens. This image has a shallow depth of field
87
+ The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8
88
+ Editorial photoshoot of a old woman, high fashion 2000s fashion
89
+ Mural Painted of Prince in Purple Rain on side of 5 story brick building next to zen garden vacant lot in the urban center district, rgb
90
+ Cozy Scandinavian living room, there is a cat sleeping on the couch, depth of field
91
+ Street style centered straight shot photo shot on Afga Vista 400, lense 50mm, of a two women,skin to skin touch face, emotion, hughing, natural blond hair, natural features, ultra detailed, skin texture, Rembrandt light, soft shadows
92
+ Frog, in forest, colorful, no watermark, no signature, in forest, 8k
93
+ selfie of a woman and her lion cub on the plains
94
+ A fisherman fixing his net sitting on a beautiful tropical beach at sunset with bending palm trees fishing gear and a small boat on shore
95
+ Coast, decorative painting, horizon, modern, fashionable, full of abstract feeling, full of imagination, the picture reveals the sense of time passing, there is a feeling of the end of the world
96
+ A close up of a branch of a tree and a golden bug on the top a leaf, shutterstock contest winner,ecological art, depth of field, shallow depth of field, macro photography
97
+ Outdoor style fashion photo, full – body shot of a man with short brown hair, happy and smiling, he is standing on his hipster bicycle wearing a light blue long sleeved blouse with closed buttons and dark blue jeans trousers, in the background the exterior of an Aldi store, fully lit background, natural afternoon lighting
98
+ beautiful woman sniper, wearing soviet army uniform, one eye on sniper lens, in snow ground
99
+ A very attractive and natural woman, sitting on a yoka mat, breathing, eye closed, no make up, intense satisfaction, she looks like she is intensely relaxed, yoga class, sunrise, 35mm
100
+ a close up of a helmet on a person, digital art, inspired by Han Gan, cloisonnism, female, victorian armor, ultramarine, best of behance, anton fadeev 8 k, fined detail, sci-fi character, elegant armor, fantasy art behance
101
+ a melting apple
102
+ yellow FIAT 500 Cinquecento 1957 driving through liechtenstein castle with a lot of banknotes scattered behind ,filled with wads of cash , car color yellow, license plate R-33
103
+ tented resort in the desert, rocky and sandy terrain, 5 star hotel, beautiful landscape, landscape photography, depth of view, Fujifilm GFX 100 –uplight
104
+ Full body shot, a French woman, Photography, French Streets background, backlighting, rim light, Fujifilm.
105
+ Modern luxury contemporary luxury home interiors house, in the style of mimicking ruined materials, ray tracing, haunting houses, and stone, capture the essence of nature, gray and bronze, dynamic outdoor shots.
106
+ Over the shoulder game perspective, game screen of Diablo 4, Inside the gorgeous palace is the wet ground, The necromancer knelt before the king, and a horde of skeletons he summoned stood at his side, cinematic light.
107
+ Color photo of a corgi made of transparent glass, standing on the riverside in Yosemite National Park.
108
+ Happy dreamy owl monster sitting on a tree branch, colorful glittering particles, forest background, detailed feathers.
109
+ Game-Art - An island with different geographical properties and multiple small cities floating in space
110
+ Photorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee.
111
+ A car made out of vegetables.
112
+ A serene lakeside during autumn with trees displaying a palette of fiery colors.
113
+ A realistic landscape shot of the Northern Lights dancing over a snowy mountain range in Iceland.
114
+ A deep forest clearing with a mirrored pond reflecting a galaxy-filled night sky.
115
+ Drone view of waves crashing against the rugged cliffs along Big Sur’s Garay Point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore.
116
+ A curvy timber house near a sea, designed by Zaha Hadid, represent the image of a cold, modern architecture, at night, white lighting, highly detailed.
117
+ Eiffel Tower was Made up of more than 2 million translucent straws to look like a cloud, with the bell tower at the top of the building, Michel installed huge foam-making machines in the forest to blow huge amounts of unpredictable wet clouds in the building's classic architecture.
118
+ Close-up photos of models, hazy light and shadow, laser metal hair accessories, soft and beautiful, light gold pupils, white eyelashes, low saturation, real skin details, clear pores and fine lines, light reflection and refraction, ultra-clear, cinematography, award-winning works.
119
+ A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.
120
+ A close-up photo of a person. The subject is a woman. She wore a blue coat with a gray dress underneath. She has blue eyes and blond hair, and wears a pair of earrings. Behind are blurred city buildings and streets.
configs/PixArt_xl2_internal.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_root = '/data/data'
2
+ data = dict(type='InternalData', root='images', image_list_json=['data_info.json'], transform='default_train', load_vae_feat=True, load_t5_feat=True)
3
+ image_size = 256 # the generated image resolution
4
+ train_batch_size = 32
5
+ eval_batch_size = 16
6
+ use_fsdp=False # if use FSDP mode
7
+ valid_num=0 # take as valid aspect-ratio when sample number >= valid_num
8
+ fp32_attention = True
9
+ # model setting
10
+ model = 'PixArt_XL_2'
11
+ aspect_ratio_type = None # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
12
+ multi_scale = False # if use multiscale dataset model training
13
+ pe_interpolation = 1.0 # positional embedding interpolation
14
+ # qk norm
15
+ qk_norm = False
16
+ # kv token compression
17
+ kv_compress = False
18
+ kv_compress_config = {
19
+ 'sampling': None,
20
+ 'scale_factor': 1,
21
+ 'kv_compress_layer': [],
22
+ }
23
+
24
+ # training setting
25
+ num_workers=4
26
+ train_sampling_steps = 1000
27
+ visualize=False
28
+ eval_sampling_steps = 250
29
+ model_max_length = 120
30
+ lora_rank = 4
31
+ num_epochs = 80
32
+ gradient_accumulation_steps = 1
33
+ grad_checkpointing = False
34
+ gradient_clip = 1.0
35
+ gc_step = 1
36
+ auto_lr = dict(rule='sqrt')
37
+
38
+ # we use a different weight decay from the official implementation since it gives better results
39
+ optimizer = dict(type='AdamW', lr=1e-4, weight_decay=3e-2, eps=1e-10)
40
+ lr_schedule = 'constant'
41
+ lr_schedule_args = dict(num_warmup_steps=500)
42
+
43
+ save_image_epochs = 1
44
+ save_model_epochs = 1
45
+ save_model_steps=1000000
46
+
47
+ sample_posterior = True
48
+ mixed_precision = 'fp16'
49
+ scale_factor = 0.18215 # ldm vae: 0.18215; sdxl vae: 0.13025
50
+ ema_rate = 0.9999
51
+ tensorboard_mox_interval = 50
52
+ log_interval = 50
53
+ cfg_scale = 4
54
+ mask_type='null'
55
+ num_group_tokens=0
56
+ mask_loss_coef=0.
57
+ load_mask_index=False # load prepared mask_type index
58
+ # load model settings
59
+ vae_pretrained = "/cache/pretrained_models/sd-vae-ft-ema"
60
+ load_from = None
61
+ resume_from = dict(checkpoint=None, load_ema=False, resume_optimizer=True, resume_lr_scheduler=True)
62
+ snr_loss=False
63
+ real_prompt_ratio = 1.0
64
+ # classifier free guidance
65
+ class_dropout_prob = 0.1
66
+ # work dir settings
67
+ work_dir = '/cache/exps/'
68
+ s3_work_dir = None
69
+ micro_condition = False
70
+ seed = 43
71
+ skip_step=0
72
+
73
+ # LCM
74
+ loss_type = 'huber'
75
+ huber_c = 0.001
76
+ num_ddim_timesteps=50
77
+ w_max = 15.0
78
+ w_min = 3.0
79
+ ema_decay = 0.95
configs/pixart_alpha_config/PixArt_xl2_img1024_dreambooth.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = ['../PixArt_xl2_internal.py']
2
+ data_root = 'data/dreambooth/dataset'
3
+
4
+ data = dict(type='DreamBooth', root='dog6', prompt=['a photo of sks dog'], transform='default_train', load_vae_feat=True)
5
+ image_size = 1024
6
+
7
+ # model setting
8
+ model = 'PixArtMS_XL_2' # model for multi-scale training
9
+ fp32_attention = True
10
+ load_from = 'Path/to/PixArt-XL-2-1024-MS.pth'
11
+ vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
12
+ aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_256, ASPECT_RATIO_512 or ASPECT_RATIO_1024]
13
+ multi_scale = True # if use multiscale dataset model training
14
+ pe_interpolation = 2.0
15
+
16
+ # training setting
17
+ num_workers=1
18
+ train_batch_size = 1
19
+ num_epochs = 200
20
+ gradient_accumulation_steps = 1
21
+ grad_checkpointing = True
22
+ gradient_clip = 0.01
23
+ optimizer = dict(type='AdamW', lr=5e-6, weight_decay=3e-2, eps=1e-10)
24
+ lr_schedule_args = dict(num_warmup_steps=0)
25
+ auto_lr = None
26
+
27
+ log_interval = 1
28
+ save_model_epochs=10000
29
+ save_model_steps=100
30
+ work_dir = 'output/debug'
configs/pixart_alpha_config/PixArt_xl2_img1024_internal.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = ['../PixArt_xl2_internal.py']
2
+ data_root = 'data'
3
+ image_list_json = ['data_info.json',]
4
+
5
+ data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
6
+ image_size = 1024
7
+
8
+ # model setting
9
+ model = 'PixArt_XL_2'
10
+ fp32_attention = True
11
+ load_from = None
12
+ vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
13
+ pe_interpolation = 2.0
14
+
15
+ # training setting
16
+ num_workers=10
17
+ train_batch_size = 2 # 32
18
+ num_epochs = 200 # 3
19
+ gradient_accumulation_steps = 1
20
+ grad_checkpointing = True
21
+ gradient_clip = 0.01
22
+ optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
23
+ lr_schedule_args = dict(num_warmup_steps=1000)
24
+
25
+ eval_sampling_steps = 200
26
+ log_interval = 20
27
+ save_model_epochs=1
28
+ save_model_steps=2000
29
+ work_dir = 'output/debug'
configs/pixart_alpha_config/PixArt_xl2_img1024_internalms.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = ['../PixArt_xl2_internal.py']
2
+ data_root = 'data'
3
+ image_list_json = ['data_info.json',]
4
+
5
+ data = dict(type='InternalDataMS', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
6
+ image_size = 1024
7
+
8
+ # model setting
9
+ model = 'PixArtMS_XL_2' # model for multi-scale training
10
+ fp32_attention = True
11
+ load_from = None
12
+ vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
13
+ aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_256, ASPECT_RATIO_512 or ASPECT_RATIO_1024]
14
+ multi_scale = True # if use multiscale dataset model training
15
+ pe_interpolation = 2.0
16
+
17
+ # training setting
18
+ num_workers=10
19
+ train_batch_size = 12 # max 14 for PixArt-xL/2 when grad_checkpoint
20
+ num_epochs = 10 # 3
21
+ gradient_accumulation_steps = 1
22
+ grad_checkpointing = True
23
+ gradient_clip = 0.01
24
+ optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
25
+ lr_schedule_args = dict(num_warmup_steps=1000)
26
+ save_model_epochs=1
27
+ save_model_steps=2000
28
+
29
+ log_interval = 20
30
+ eval_sampling_steps = 200
31
+ work_dir = 'output/debug'
32
+ micro_condition = True
configs/pixart_alpha_config/PixArt_xl2_img256_internal.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = ['../PixArt_xl2_internal.py']
2
+ data_root = 'data'
3
+ image_list_json = ['data_info.json',]
4
+
5
+ data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
6
+ image_size = 256
7
+
8
+ # model setting
9
+ model = 'PixArt_XL_2'
10
+ fp32_attention = True
11
+ load_from = None
12
+ vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
13
+ # training setting
14
+ eval_sampling_steps = 200
15
+
16
+ num_workers=10
17
+ train_batch_size = 176 # 32 # max 96 for PixArt-L/4 when grad_checkpoint
18
+ num_epochs = 200 # 3
19
+ gradient_accumulation_steps = 1
20
+ grad_checkpointing = True
21
+ gradient_clip = 0.01
22
+ optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
23
+ lr_schedule_args = dict(num_warmup_steps=1000)
24
+
25
+ log_interval = 20
26
+ save_model_epochs=5
27
+ work_dir = 'output/debug'
configs/pixart_alpha_config/PixArt_xl2_img512_internal.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = ['../PixArt_xl2_internal.py']
2
+ data_root = 'data'
3
+ image_list_json = ['data_info.json',]
4
+
5
+ data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
6
+ image_size = 512
7
+
8
+ # model setting
9
+ model = 'PixArt_XL_2'
10
+ fp32_attention = True
11
+ load_from = None
12
+ vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
13
+ pe_interpolation = 1.0
14
+
15
+ # training setting
16
+ use_fsdp=False # if use FSDP mode
17
+ num_workers=10
18
+ train_batch_size = 38 # 32
19
+ num_epochs = 200 # 3
20
+ gradient_accumulation_steps = 1
21
+ grad_checkpointing = True
22
+ gradient_clip = 0.01
23
+ optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
24
+ lr_schedule_args = dict(num_warmup_steps=1000)
25
+
26
+ eval_sampling_steps = 200
27
+ log_interval = 20
28
+ save_model_epochs=1
29
+ work_dir = 'output/debug'
configs/pixart_alpha_config/PixArt_xl2_img512_internalms.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = ['../PixArt_xl2_internal.py']
2
+ data_root = 'data'
3
+ image_list_json = ['data_info.json',]
4
+
5
+ data = dict(type='InternalDataMS', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
6
+ image_size = 512
7
+
8
+ # model setting
9
+ model = 'PixArtMS_XL_2' # model for multi-scale training
10
+ fp32_attention = True
11
+ load_from = None
12
+ vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
13
+ aspect_ratio_type = 'ASPECT_RATIO_512' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
14
+ multi_scale = True # if use multiscale dataset model training
15
+ pe_interpolation = 1.0
16
+
17
+ # training setting
18
+ num_workers=10
19
+ train_batch_size = 40 # max 40 for PixArt-xL/2 when grad_checkpoint
20
+ num_epochs = 20 # 3
21
+ gradient_accumulation_steps = 1
22
+ grad_checkpointing = True
23
+ gradient_clip = 0.01
24
+ optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
25
+ lr_schedule_args = dict(num_warmup_steps=1000)
26
+ save_model_epochs=1
27
+ save_model_steps=2000
28
+
29
+ log_interval = 20
30
+ eval_sampling_steps = 200
31
+ work_dir = 'output/debug'
configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = ['../PixArt_xl2_internal.py']
2
+ data_root = 'pixart-sigma-toy-dataset'
3
+ image_list_json = ['data_info.json']
4
+
5
+ data = dict(
6
+ type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
7
+ load_vae_feat=False, load_t5_feat=False
8
+ )
9
+ image_size = 1024
10
+
11
+ # model setting
12
+ model = 'PixArtMS_XL_2'
13
+ mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16']
14
+ fp32_attention = True
15
+ load_from = None
16
+ resume_from = None
17
+ vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
18
+ aspect_ratio_type = 'ASPECT_RATIO_1024'
19
+ multi_scale = True # if use multiscale dataset model training
20
+ pe_interpolation = 2.0
21
+
22
+ # training setting
23
+ num_workers = 10
24
+ train_batch_size = 2 # 3 without feature extraction; 12 for feature extraction
25
+ num_epochs = 2 # 3
26
+ gradient_accumulation_steps = 1
27
+ grad_checkpointing = True
28
+ gradient_clip = 0.01
29
+ optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
30
+ lr_schedule_args = dict(num_warmup_steps=1000)
31
+
32
+ eval_sampling_steps = 500
33
+ visualize = True
34
+ log_interval = 20
35
+ save_model_epochs = 1
36
+ save_model_steps = 1000
37
+ work_dir = 'output/debug'
38
+
39
+ # pixart-sigma
40
+ scale_factor = 0.13025
41
+ real_prompt_ratio = 0.5
42
+ model_max_length = 300
43
+ class_dropout_prob = 0.1
44
+
45
+ qk_norm = False
46
+ skip_step = 0 # skip steps during data loading
configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_internalms_kvcompress.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = ['../PixArt_xl2_internal.py']
2
+ data_root = 'data'
3
+ image_list_json = ['data_info.json']
4
+
5
+ data = dict(
6
+ type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
7
+ load_vae_feat=False, load_t5_feat=False
8
+ )
9
+ image_size = 1024
10
+
11
+ # model setting
12
+ model = 'PixArtMS_XL_2'
13
+ mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16']
14
+ fp32_attention = True
15
+ load_from = None
16
+ resume_from = None
17
+ vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
18
+ aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_256, ASPECT_RATIO_512 or ASPECT_RATIO_1024]
19
+ multi_scale = True # if use multiscale dataset model training
20
+ pe_interpolation = 2.0
21
+
22
+ # training setting
23
+ num_workers = 10
24
+ train_batch_size = 4 # 16
25
+ num_epochs = 2 # 3
26
+ gradient_accumulation_steps = 1
27
+ grad_checkpointing = True
28
+ gradient_clip = 0.01
29
+ optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
30
+ lr_schedule_args = dict(num_warmup_steps=500)
31
+
32
+ eval_sampling_steps = 250
33
+ visualize = True
34
+ log_interval = 10
35
+ save_model_epochs = 1
36
+ save_model_steps = 1000
37
+ work_dir = 'output/debug'
38
+
39
+ # pixart-sigma
40
+ scale_factor = 0.13025
41
+ real_prompt_ratio = 0.5
42
+ model_max_length = 300
43
+ class_dropout_prob = 0.1
44
+ kv_compress = True
45
+ kv_compress_config = {
46
+ 'sampling': 'conv', # ['conv', 'uniform', 'ave']
47
+ 'scale_factor': 2,
48
+ 'kv_compress_layer': [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
49
+ }
50
+ qk_norm = False
51
+ skip_step = 0 # skip steps during data loading
configs/pixart_sigma_config/PixArt_sigma_xl2_img1024_lcm.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = ['../PixArt_xl2_internal.py']
2
+ data_root = 'pixart-sigma-toy-dataset'
3
+ image_list_json = ['data_info.json']
4
+
5
+ data = dict(
6
+ type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
7
+ load_vae_feat=True, load_t5_feat=True,
8
+ )
9
+ image_size = 1024
10
+
11
+ # model setting
12
+ model = 'PixArtMS_XL_2' # model for multi-scale training
13
+ fp32_attention = False
14
+ load_from = None
15
+ resume_from = None
16
+ vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
17
+ aspect_ratio_type = 'ASPECT_RATIO_1024'
18
+ multi_scale = True # if use multiscale dataset model training
19
+ pe_interpolation = 2.0
20
+
21
+ # training setting
22
+ num_workers = 4
23
+ train_batch_size = 12 # max 12 for PixArt-xL/2 when grad_checkpoint
24
+ num_epochs = 10 # 3
25
+ gradient_accumulation_steps = 1
26
+ grad_checkpointing = True
27
+ gradient_clip = 0.01
28
+ optimizer = dict(type='CAMEWrapper', lr=1e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
29
+ lr_schedule_args = dict(num_warmup_steps=100)
30
+ save_model_epochs = 10
31
+ save_model_steps = 1000
32
+ valid_num = 0 # take as valid aspect-ratio when sample number >= valid_num
33
+
34
+ log_interval = 10
35
+ eval_sampling_steps = 5
36
+ visualize = True
37
+ work_dir = 'output/debug'
38
+
39
+ # pixart-sigma
40
+ scale_factor = 0.13025
41
+ real_prompt_ratio = 0.5
42
+ model_max_length = 300
43
+ class_dropout_prob = 0.1
44
+
45
+ # LCM
46
+ loss_type = 'huber'
47
+ huber_c = 0.001
48
+ num_ddim_timesteps = 50
49
+ w_max = 15.0
50
+ w_min = 3.0
51
+ ema_decay = 0.95
52
+ cfg_scale = 4.5
configs/pixart_sigma_config/PixArt_sigma_xl2_img256_internal.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = ['../PixArt_xl2_internal.py']
2
+ data_root = 'pixart-sigma-toy-dataset'
3
+ image_list_json = ['data_info.json']
4
+
5
+ data = dict(
6
+ type='InternalDataSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
7
+ load_vae_feat=False, load_t5_feat=False,
8
+ )
9
+ image_size = 256
10
+
11
+ # model setting
12
+ model = 'PixArt_XL_2'
13
+ mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16']
14
+ fp32_attention = True
15
+ load_from = "output/pretrained_models/PixArt-Sigma-XL-2-256x256.pth" # https://huggingface.co/PixArt-alpha/PixArt-Sigma
16
+ resume_from = None
17
+ vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
18
+ multi_scale = False # if use multiscale dataset model training
19
+ pe_interpolation = 0.5
20
+
21
+ # training setting
22
+ num_workers = 10
23
+ train_batch_size = 64 # 64 as default
24
+ num_epochs = 200 # 3
25
+ gradient_accumulation_steps = 1
26
+ grad_checkpointing = True
27
+ gradient_clip = 0.01
28
+ optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
29
+ lr_schedule_args = dict(num_warmup_steps=1000)
30
+
31
+ eval_sampling_steps = 500
32
+ log_interval = 20
33
+ save_model_epochs = 5
34
+ save_model_steps = 2500
35
+ work_dir = 'output/debug'
36
+
37
+ # pixart-sigma
38
+ scale_factor = 0.13025
39
+ real_prompt_ratio = 0.5
40
+ model_max_length = 300
41
+ class_dropout_prob = 0.1
configs/pixart_sigma_config/PixArt_sigma_xl2_img2K_internalms_kvcompress.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = ['../PixArt_xl2_internal.py']
2
+ data_root = 'data'
3
+ image_list_json = ['data_info.json']
4
+
5
+ data = dict(
6
+ type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
7
+ load_vae_feat=False, load_t5_feat=False
8
+ )
9
+ image_size = 2048
10
+
11
+ # model setting
12
+ model = 'PixArtMS_XL_2'
13
+ mixed_precision = 'fp16'
14
+ fp32_attention = True
15
+ load_from = None
16
+ resume_from = None
17
+ vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
18
+ aspect_ratio_type = 'ASPECT_RATIO_2048' # base aspect ratio [ASPECT_RATIO_256, ASPECT_RATIO_512, ASPECT_RATIO_1024 or ASPECT_RATIO_2048]
19
+ multi_scale = True # if use multiscale dataset model training
20
+ pe_interpolation = 4.0
21
+
22
+ # training setting
23
+ num_workers = 10
24
+ train_batch_size = 4 # 48
25
+ num_epochs = 10 # 3
26
+ gradient_accumulation_steps = 1
27
+ grad_checkpointing = True
28
+ gradient_clip = 0.01
29
+ optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
30
+ lr_schedule_args = dict(num_warmup_steps=100)
31
+
32
+ eval_sampling_steps = 100
33
+ visualize = True
34
+ log_interval = 10
35
+ save_model_epochs = 10
36
+ save_model_steps = 100
37
+ work_dir = 'output/debug'
38
+
39
+ # pixart-sigma
40
+ scale_factor = 0.13025
41
+ real_prompt_ratio = 0.5
42
+ model_max_length = 300
43
+ class_dropout_prob = 0.1
44
+ kv_compress = False
45
+ kv_compress_config = {
46
+ 'sampling': 'conv',
47
+ 'scale_factor': 2,
48
+ 'kv_compress_layer': [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
49
+ }
configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = ['../PixArt_xl2_internal.py']
2
+ data_root = 'pixart-sigma-toy-dataset'
3
+ image_list_json = ['data_info.json']
4
+
5
+ data = dict(
6
+ type='InternalDataMSSigma', root='InternData', image_list_json=image_list_json, transform='default_train',
7
+ load_vae_feat=False, load_t5_feat=False,
8
+ )
9
+ image_size = 512
10
+
11
+ # model setting
12
+ model = 'PixArtMS_XL_2'
13
+ mixed_precision = 'fp16' # ['fp16', 'fp32', 'bf16']
14
+ fp32_attention = True
15
+ load_from = "output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth" # https://huggingface.co/PixArt-alpha/PixArt-Sigma
16
+ resume_from = None
17
+ vae_pretrained = "output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers/vae" # sdxl vae
18
+ aspect_ratio_type = 'ASPECT_RATIO_512'
19
+ multi_scale = True # if use multiscale dataset model training
20
+ pe_interpolation = 1.0
21
+
22
+ # training setting
23
+ num_workers = 10
24
+ train_batch_size = 2 # 48 as default
25
+ num_epochs = 10 # 3
26
+ gradient_accumulation_steps = 1
27
+ grad_checkpointing = True
28
+ gradient_clip = 0.01
29
+ optimizer = dict(type='CAMEWrapper', lr=2e-5, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
30
+ lr_schedule_args = dict(num_warmup_steps=1000)
31
+
32
+ eval_sampling_steps = 500
33
+ visualize = True
34
+ log_interval = 20
35
+ save_model_epochs = 5
36
+ save_model_steps = 2500
37
+ work_dir = 'output/debug'
38
+
39
+ # pixart-sigma
40
+ scale_factor = 0.13025
41
+ real_prompt_ratio = 0.5
42
+ model_max_length = 300
43
+ class_dropout_prob = 0.1
diffusion/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from .iddpm import IDDPM
7
+ from .dpm_solver import DPMS
8
+ from .sa_sampler import SASolverSampler
diffusion/data/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .datasets import *
2
+ from .transforms import get_transform
diffusion/data/builder.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+ from mmcv import Registry, build_from_cfg
5
+ from torch.utils.data import DataLoader
6
+
7
+ from diffusion.data.transforms import get_transform
8
+ from diffusion.utils.logger import get_root_logger
9
+
10
+ DATASETS = Registry('datasets')
11
+
12
+ DATA_ROOT = '/cache/data'
13
+
14
+
15
def set_data_root(data_root):
    """Replace the module-level ``DATA_ROOT`` with ``data_root``.

    ``get_data_path`` joins relative dataset directories onto this value,
    so call this before building datasets that use relative roots.

    Args:
        data_root: New base directory for relative dataset paths.
    """
    # Rebind the module global rather than creating a function-local name.
    global DATA_ROOT
    DATA_ROOT = data_root
18
+
19
+
20
def get_data_path(data_dir):
    """Resolve ``data_dir`` to the directory a dataset should read from.

    Args:
        data_dir: Dataset directory, absolute or relative.

    Returns:
        ``data_dir`` unchanged when it is already absolute; otherwise
        ``data_dir`` joined onto the module-level ``DATA_ROOT``.
    """
    if os.path.isabs(data_dir):
        return data_dir
    # DATA_ROOT is only read here, so no ``global`` declaration is needed.
    return os.path.join(DATA_ROOT, data_dir)
25
+
26
+
27
def build_dataset(cfg, resolution=224, **kwargs):
    """Build a dataset registered in ``DATASETS`` from ``cfg`` and log the timing.

    Args:
        cfg: Dict-like dataset config. ``cfg['type']`` is logged as the
            dataset type; the optional ``cfg['transform']`` names the
            transform to build (defaults to ``'default_train'``).
        resolution: Passed to ``get_transform`` and offered to the dataset
            constructor as a default argument.
        **kwargs: Additional default arguments offered to the dataset
            constructor.

    Returns:
        The constructed dataset. It must support ``len()`` and expose
        ``ori_imgs_nums``, both of which are read for the log message.
    """
    logger = get_root_logger()

    # Work on a shallow copy: popping 'transform' from the caller's config
    # would make a second build from the same config silently fall back to
    # 'default_train'.
    cfg = cfg.copy()

    dataset_type = cfg.get('type')
    logger.info(f"Constructing dataset {dataset_type}...")
    t = time.time()
    # The transform name is removed from cfg and replaced by the built
    # transform object, which is supplied through default_args instead.
    transform = cfg.pop('transform', 'default_train')
    transform = get_transform(transform, resolution)
    dataset = build_from_cfg(cfg, DATASETS, default_args=dict(transform=transform, resolution=resolution, **kwargs))
    logger.info(f"Dataset {dataset_type} constructed. time: {(time.time() - t):.2f} s, length (use/ori): {len(dataset)}/{dataset.ori_imgs_nums}")
    return dataset
38
+
39
+
40
def build_dataloader(dataset, batch_size=256, num_workers=4, shuffle=True, **kwargs):
    """Wrap ``dataset`` in a ``DataLoader`` with ``pin_memory=True``.

    Args:
        dataset: Dataset to load from.
        batch_size: Samples per batch; ignored when ``batch_sampler`` is given.
        num_workers: Number of loader worker processes.
        shuffle: Whether to shuffle; ignored when ``batch_sampler`` is given.
        **kwargs: Extra ``DataLoader`` keyword arguments. If ``batch_sampler``
            is present it drives batching in place of ``batch_size`` and
            ``shuffle``.

    Returns:
        The configured ``DataLoader``.
    """
    if 'batch_sampler' in kwargs:
        # batch_size and shuffle are deliberately not passed alongside a
        # batch sampler. The remaining kwargs (e.g. collate_fn) are still
        # forwarded; previously they were silently dropped.
        batch_sampler = kwargs.pop('batch_sampler')
        return DataLoader(dataset,
                          batch_sampler=batch_sampler,
                          num_workers=num_workers,
                          pin_memory=True,
                          **kwargs)
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      num_workers=num_workers,
                      pin_memory=True,
                      **kwargs)
diffusion/data/datasets/InternalData.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from PIL import Image
4
+ import numpy as np
5
+ import torch
6
+ from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
7
+ from torch.utils.data import Dataset
8
+ from diffusers.utils.torch_utils import randn_tensor
9
+ from torchvision import transforms as T
10
+ from diffusion.data.builder import get_data_path, DATASETS
11
+ from diffusion.utils.logger import get_root_logger
12
+
13
+ import json
14
+
15
@DATASETS.register_module()
class InternalData(Dataset):
    """Fixed-resolution image/caption dataset driven by JSON metadata files.

    Each metadata entry must provide ``'path'``, ``'prompt'`` and ``'ratio'``
    (``get_data_info`` additionally reads ``'height'`` and ``'width'``).
    Entries with ``ratio > 4`` are discarded. For every kept entry the dataset
    records the image path, the caption-feature ``.npz`` path, the VAE-feature
    ``.npy`` path and the prompt in index-aligned lists.

    Args:
        root: Dataset directory, resolved through ``get_data_path``.
        image_list_json: One file name or a list of file names, looked up under
            ``<root>/partition``.
        transform: Callable applied to loaded images; forced to ``None`` when
            ``load_vae_feat`` is true.
        resolution: Value reported in ``img_hw`` and used in the VAE feature
            directory name.
        sample_subset: Optional ratio; when given, a random subset of that
            fraction is kept (intended for local debugging).
        load_vae_feat: Load precomputed VAE features instead of images.
        input_size, patch_size: Only used to derive ``self.N``.
        mask_ratio, load_mask_index: Stored on the instance.
        max_length: Token length that caption features are padded to.
        config: Optional config; when given, logging goes to
            ``<config.work_dir>/train_log.log``.
    """

    # Per-sample lists that must stay index-aligned when the dataset is subsampled.
    _ALIGNED_FIELDS = (
        'meta_data_clean',
        'img_samples',
        'txt_feat_samples',
        'vae_feat_samples',
        'prompt_samples',
    )

    def __init__(self,
                 root,
                 image_list_json='data_info.json',
                 transform=None,
                 resolution=256,
                 sample_subset=None,
                 load_vae_feat=False,
                 input_size=32,
                 patch_size=2,
                 mask_ratio=0.0,
                 load_mask_index=False,
                 max_length=120,
                 config=None,
                 **kwargs):
        self.root = get_data_path(root)
        self.transform = transform
        self.load_vae_feat = load_vae_feat
        self.ori_imgs_nums = 0
        self.resolution = resolution
        self.N = int(resolution // (input_size // patch_size))
        self.mask_ratio = mask_ratio
        self.load_mask_index = load_mask_index
        # NOTE: the attribute name keeps its historical spelling because
        # subclasses and other code read ``max_lenth``.
        self.max_lenth = max_length
        self.meta_data_clean = []
        self.img_samples = []
        self.txt_feat_samples = []
        self.vae_feat_samples = []
        self.mask_index_samples = []
        self.prompt_samples = []

        image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json]
        for json_file in image_list_json:
            meta_data = self.load_json(os.path.join(self.root, 'partition', json_file))
            self.ori_imgs_nums += len(meta_data)
            meta_data_clean = [item for item in meta_data if item['ratio'] <= 4]
            self.meta_data_clean.extend(meta_data_clean)
            self.img_samples.extend([os.path.join(self.root.replace('InternData', "InternImgs"), item['path']) for item in meta_data_clean])
            self.txt_feat_samples.extend([os.path.join(self.root, 'caption_feature_wmask', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npz')) for item in meta_data_clean])
            self.vae_feat_samples.extend([os.path.join(self.root, f'img_vae_features_{resolution}resolution/noflip', '_'.join(item['path'].rsplit('/', 1)).replace('.png', '.npy')) for item in meta_data_clean])
            self.prompt_samples.extend([item['prompt'] for item in meta_data_clean])

        # Set loader and extensions
        if load_vae_feat:
            self.transform = None
            self.loader = self.vae_feat_loader
        else:
            self.loader = default_loader

        if sample_subset is not None:
            self.sample_subset(sample_subset)  # sample dataset for local debug
        logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
        logger.info(f"T5 max token length: {self.max_lenth}")

    def getdata(self, index):
        """Load one sample.

        Returns:
            ``(img, txt_fea, attention_mask, data_info)`` where ``img`` is the
            loaded image or sampled VAE latent, ``txt_fea`` is the caption
            feature padded to ``max_lenth`` tokens and ``data_info`` holds
            ``img_hw``, ``aspect_ratio`` and ``prompt``.
        """
        img_path = self.img_samples[index]
        npz_path = self.txt_feat_samples[index]
        npy_path = self.vae_feat_samples[index]
        prompt = self.prompt_samples[index]
        data_info = {
            'img_hw': torch.tensor([self.resolution, self.resolution], dtype=torch.float32),
            'aspect_ratio': torch.tensor(1.)
        }

        img = self.loader(npy_path) if self.load_vae_feat else self.loader(img_path)
        txt_info = np.load(npz_path)
        txt_fea = torch.from_numpy(txt_info['caption_feature'])     # 1xTx4096
        attention_mask = torch.ones(1, 1, txt_fea.shape[1])     # 1x1xT
        if 'attention_mask' in txt_info.keys():
            attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]
        if txt_fea.shape[1] != self.max_lenth:
            # Pad by repeating the last token feature; padded positions are masked out with zeros.
            txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_lenth-txt_fea.shape[1], 1)], dim=1)
            attention_mask = torch.cat([attention_mask, torch.zeros(1, 1, self.max_lenth-attention_mask.shape[-1])], dim=-1)

        if self.transform:
            img = self.transform(img)

        data_info['prompt'] = prompt
        return img, txt_fea, attention_mask, data_info

    def __getitem__(self, idx):
        """Return sample ``idx``, retrying with random indices (up to 20 tries) on failure."""
        for _ in range(20):
            try:
                return self.getdata(idx)
            except Exception as e:
                print(f"Error details: {str(e)}")
                idx = np.random.randint(len(self))
        raise RuntimeError('Too many bad data.')

    def get_data_info(self, idx):
        """Return the ``height`` and ``width`` recorded in the metadata of sample ``idx``."""
        data_info = self.meta_data_clean[idx]
        return {'height': data_info['height'], 'width': data_info['width']}

    @staticmethod
    def vae_feat_loader(path):
        """Load a stored ``[mean, std]`` pair from ``path`` and draw ``mean + std * noise``."""
        # [mean, std]
        mean, std = torch.from_numpy(np.load(path)).chunk(2)
        sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype)
        return mean + std * sample

    def load_ori_img(self, img_path):
        """Load an image, resize and center-crop it to 256x256 and convert it to a Tensor."""
        transform = T.Compose([
            T.Resize(256),  # Image.BICUBIC
            T.CenterCrop(256),
            T.ToTensor(),
        ])
        # Use a context manager so the underlying file handle is released promptly.
        with Image.open(img_path) as img:
            return transform(img)

    def load_json(self, file_path):
        """Read and return the parsed JSON content of ``file_path``."""
        with open(file_path, 'r') as f:
            meta_data = json.load(f)

        return meta_data

    def sample_subset(self, ratio):
        """Keep a random ``ratio`` fraction of the samples.

        All index-aligned per-sample lists are subsampled with the same indices;
        subsampling only ``img_samples`` would pair images with the features,
        prompts and metadata of different samples.
        """
        total = len(self)
        sampled_idx = random.sample(list(range(total)), int(total * ratio))
        for field in self._ALIGNED_FIELDS:
            # Read through __dict__ so subclasses lacking a field are skipped
            # instead of hitting __getattr__.
            values = self.__dict__.get(field)
            if values is not None and len(values) == total:
                setattr(self, field, [values[i] for i in sampled_idx])

    def __len__(self):
        return len(self.img_samples)

    def __getattr__(self, name):
        # Provide a no-op ``set_epoch`` so callers can invoke it unconditionally.
        if name == "set_epoch":
            return lambda epoch: None
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
142
+
143
@DATASETS.register_module()
class InternalDataSigma(Dataset):
    """Fixed-resolution dataset that can serve either raw prompts or T5 features.

    Each metadata entry must provide ``'path'``, ``'prompt'`` and ``'ratio'``
    and may provide ``'sharegpt4v'`` (an alternative caption; empty string when
    absent). Entries with ``ratio > 4.5`` are discarded. For every sample a
    random draw against ``real_prompt_ratio`` decides whether the original
    prompt or the ShareGPT4V caption (and matching feature file) is used.

    Args:
        root: Dataset directory, resolved through ``get_data_path``.
        image_list_json: One file name or a list of file names under ``root``.
        transform: Callable applied to loaded images; forced to ``None`` when
            ``load_vae_feat`` is true.
        resolution: Value reported in ``img_hw`` and used in the VAE feature
            directory name.
        sample_subset: Optional ratio; when given, a random subset of that
            fraction is kept (intended for local debugging).
        load_vae_feat: Load precomputed VAE features instead of images.
        load_t5_feat: Load precomputed caption features; otherwise the raw
            caption string is returned in place of the feature tensor.
        input_size, patch_size: Only used to derive ``self.N``.
        mask_ratio, mask_type, load_mask_index: Stored on the instance;
            ``mask_type`` is also returned in ``data_info``.
        real_prompt_ratio: Probability of using the original prompt.
        max_length: Token length that caption features are padded to.
        config: Optional config; when given, logging goes to
            ``<config.work_dir>/train_log.log``.
    """

    # Per-sample lists that must stay index-aligned when the dataset is subsampled.
    _ALIGNED_FIELDS = (
        'meta_data_clean',
        'img_samples',
        'txt_samples',
        'sharegpt4v_txt_samples',
        'txt_feat_samples',
        'gpt4v_txt_feat_samples',
        'vae_feat_samples',
    )

    def __init__(self,
                 root,
                 image_list_json='data_info.json',
                 transform=None,
                 resolution=256,
                 sample_subset=None,
                 load_vae_feat=False,
                 load_t5_feat=False,
                 input_size=32,
                 patch_size=2,
                 mask_ratio=0.0,
                 mask_type='null',
                 load_mask_index=False,
                 real_prompt_ratio=1.0,
                 max_length=300,
                 config=None,
                 **kwargs):
        self.root = get_data_path(root)
        self.transform = transform
        self.load_vae_feat = load_vae_feat
        self.load_t5_feat = load_t5_feat
        self.ori_imgs_nums = 0
        self.resolution = resolution
        self.N = int(resolution // (input_size // patch_size))
        self.mask_ratio = mask_ratio
        self.load_mask_index = load_mask_index
        self.mask_type = mask_type
        self.real_prompt_ratio = real_prompt_ratio
        # NOTE: the attribute name keeps its historical spelling because
        # subclasses and other code read ``max_lenth``.
        self.max_lenth = max_length
        self.meta_data_clean = []
        self.img_samples = []
        self.txt_samples = []
        self.sharegpt4v_txt_samples = []
        self.txt_feat_samples = []
        self.vae_feat_samples = []
        self.mask_index_samples = []
        self.gpt4v_txt_feat_samples = []
        self.weight_dtype = torch.float16 if self.real_prompt_ratio > 0 else torch.float32
        logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
        logger.info(f"T5 max token length: {self.max_lenth}")
        logger.info(f"ratio of real user prompt: {self.real_prompt_ratio}")

        image_list_json = image_list_json if isinstance(image_list_json, list) else [image_list_json]
        for json_file in image_list_json:
            meta_data = self.load_json(os.path.join(self.root, json_file))
            logger.info(f"{json_file} data volume: {len(meta_data)}")
            self.ori_imgs_nums += len(meta_data)
            meta_data_clean = [item for item in meta_data if item['ratio'] <= 4.5]
            self.meta_data_clean.extend(meta_data_clean)
            self.img_samples.extend([
                os.path.join(self.root.replace('InternData', 'InternImgs'), item['path']) for item in meta_data_clean
            ])
            self.txt_samples.extend([item['prompt'] for item in meta_data_clean])
            self.sharegpt4v_txt_samples.extend([item['sharegpt4v'] if 'sharegpt4v' in item else '' for item in meta_data_clean])
            self.txt_feat_samples.extend([
                os.path.join(
                    self.root,
                    'caption_features_new',
                    item['path'].rsplit('/', 1)[-1].replace('.png', '.npz')
                ) for item in meta_data_clean
            ])
            self.gpt4v_txt_feat_samples.extend([
                os.path.join(
                    self.root,
                    'sharegpt4v_caption_features_new',
                    item['path'].rsplit('/', 1)[-1].replace('.png', '.npz')
                ) for item in meta_data_clean
            ])
            self.vae_feat_samples.extend(
                [
                    os.path.join(
                        self.root,
                        f'img_sdxl_vae_features_{resolution}resolution_new',
                        item['path'].rsplit('/', 1)[-1].replace('.png', '.npy')
                    ) for item in meta_data_clean
                ])

        # Set loader and extensions
        if load_vae_feat:
            self.transform = None
            self.loader = self.vae_feat_loader
        else:
            self.loader = default_loader

        if sample_subset is not None:
            self.sample_subset(sample_subset)  # sample dataset for local debug

    def getdata(self, index):
        """Load one sample.

        Returns:
            ``(img, txt_fea, attention_mask, data_info)``. ``txt_fea`` is the
            caption feature padded to ``max_lenth`` tokens when
            ``load_t5_feat`` is true, otherwise the caption string.
            ``attention_mask`` is cast to ``torch.int16``.
        """
        img_path = self.img_samples[index]
        # Choose between the original prompt and the ShareGPT4V caption.
        real_prompt = random.random() < self.real_prompt_ratio
        npz_path = self.txt_feat_samples[index] if real_prompt else self.gpt4v_txt_feat_samples[index]
        txt = self.txt_samples[index] if real_prompt else self.sharegpt4v_txt_samples[index]
        npy_path = self.vae_feat_samples[index]
        data_info = {'img_hw': torch.tensor([self.resolution, self.resolution], dtype=torch.float32),
                     'aspect_ratio': torch.tensor(1.)}

        if self.load_vae_feat:
            img = self.loader(npy_path)
        else:
            img = self.loader(img_path)

        attention_mask = torch.ones(1, 1, self.max_lenth)     # 1x1xT
        if self.load_t5_feat:
            txt_info = np.load(npz_path)
            txt_fea = torch.from_numpy(txt_info['caption_feature'])     # 1xTx4096
            if 'attention_mask' in txt_info.keys():
                attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]
            if txt_fea.shape[1] != self.max_lenth:
                # Pad by repeating the last token feature; padded positions are masked out with zeros.
                txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_lenth-txt_fea.shape[1], 1)], dim=1)
                attention_mask = torch.cat([attention_mask, torch.zeros(1, 1, self.max_lenth-attention_mask.shape[-1])], dim=-1)
        else:
            txt_fea = txt

        if self.transform:
            img = self.transform(img)

        data_info["mask_type"] = self.mask_type
        return img, txt_fea, attention_mask.to(torch.int16), data_info

    def __getitem__(self, idx):
        """Return sample ``idx``, retrying with random indices (up to 20 tries) on failure."""
        for _ in range(20):
            try:
                data = self.getdata(idx)
                return data
            except Exception as e:
                print(f"Error details {self.img_samples[idx]}: {str(e)}")
                idx = np.random.randint(len(self))
        raise RuntimeError('Too many bad data.')

    def get_data_info(self, idx):
        """Return the ``height`` and ``width`` recorded in the metadata of sample ``idx``."""
        data_info = self.meta_data_clean[idx]
        return {'height': data_info['height'], 'width': data_info['width']}

    @staticmethod
    def vae_feat_loader(path):
        """Load a stored ``[mean, std]`` pair from ``path`` and draw ``mean + std * noise``."""
        # [mean, std]
        mean, std = torch.from_numpy(np.load(path)).chunk(2)
        sample = randn_tensor(mean.shape, generator=None, device=mean.device, dtype=mean.dtype)
        return mean + std * sample

    def load_ori_img(self, img_path):
        """Load an image, resize and center-crop it to 256x256 and convert it to a Tensor."""
        transform = T.Compose([
            T.Resize(256),  # Image.BICUBIC
            T.CenterCrop(256),
            T.ToTensor(),
        ])
        # Use a context manager so the underlying file handle is released promptly.
        with Image.open(img_path) as opened:
            img = transform(opened)
        return img

    def load_json(self, file_path):
        """Read and return the parsed JSON content of ``file_path``."""
        with open(file_path, 'r') as f:
            meta_data = json.load(f)

        return meta_data

    def sample_subset(self, ratio):
        """Keep a random ``ratio`` fraction of the samples.

        All index-aligned per-sample lists are subsampled with the same indices;
        subsampling only ``img_samples`` would pair images with the captions,
        features and metadata of different samples.
        """
        total = len(self)
        sampled_idx = random.sample(list(range(total)), int(total * ratio))
        for field in self._ALIGNED_FIELDS:
            # Read through __dict__ so subclasses lacking a field are skipped
            # instead of hitting __getattr__.
            values = self.__dict__.get(field)
            if values is not None and len(values) == total:
                setattr(self, field, [values[i] for i in sampled_idx])

    def __len__(self):
        return len(self.img_samples)

    def __getattr__(self, name):
        # Provide a no-op ``set_epoch`` so callers can invoke it unconditionally.
        if name == "set_epoch":
            return lambda epoch: None
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
312
+
diffusion/data/datasets/InternalData_ms.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import random
5
+ from torchvision.datasets.folder import default_loader
6
+ from diffusion.data.datasets.InternalData import InternalData, InternalDataSigma
7
+ from diffusion.data.builder import get_data_path, DATASETS
8
+ from diffusion.utils.logger import get_root_logger
9
+ import torchvision.transforms as T
10
+ from torchvision.transforms.functional import InterpolationMode
11
+ from diffusion.data.datasets.utils import *
12
+
13
def get_closest_ratio(height: float, width: float, ratios: dict):
    """Find the entry of ``ratios`` whose key is closest to ``height / width``.

    Args:
        height: Image height.
        width: Image width (must be non-zero).
        ratios: Mapping from aspect-ratio keys (convertible with ``float``) to
            the value associated with that ratio.

    Returns:
        ``(ratios[key], float(key))`` for the closest key. On ties the key that
        comes first in the mapping's iteration order wins.
    """
    target = height / width

    def distance(key):
        return abs(float(key) - target)

    best_key = min(ratios, key=distance)
    return ratios[best_key], float(best_key)
17
+
18
+
19
@DATASETS.register_module()
class InternalDataMS(InternalData):
    """Multi-scale variant of ``InternalData`` that groups samples by aspect ratio.

    The required keyword argument ``aspect_ratio_type`` names one of the
    aspect-ratio tables (for example ``'ASPECT_RATIO_1024'``); its numeric
    suffix becomes ``base_size``. Each sample is assigned to the table entry
    closest to its ``height / width`` and, when images are loaded, resized and
    center-cropped to that entry's size.
    """

    def __init__(self,
                 root,
                 image_list_json='data_info.json',
                 transform=None,
                 resolution=256,
                 sample_subset=None,
                 load_vae_feat=False,
                 input_size=32,
                 patch_size=2,
                 mask_ratio=0.0,
                 mask_type='null',
                 load_mask_index=False,
                 real_prompt_ratio=1.0,
                 max_length=120,
                 config=None,
                 **kwargs):
        self.root = get_data_path(root)
        self.transform = transform
        self.load_vae_feat = load_vae_feat
        self.ori_imgs_nums = 0
        self.resolution = resolution
        self.N = int(resolution // (input_size // patch_size))
        self.mask_ratio = mask_ratio
        self.load_mask_index = load_mask_index
        self.mask_type = mask_type
        self.real_prompt_ratio = real_prompt_ratio
        self.max_lenth = max_length

        ratio_table_name = kwargs.pop('aspect_ratio_type')
        self.base_size = int(ratio_table_name.split('_')[-1])
        # NOTE(review): eval() resolves the table name against module globals;
        # this is only acceptable because the value comes from trusted config.
        self.aspect_ratio = eval(ratio_table_name)     # base aspect ratio

        self.meta_data_clean = []
        self.img_samples = []
        self.txt_feat_samples = []
        self.vae_feat_samples = []
        self.mask_index_samples = []
        # self.weight_dtype = torch.float16 if self.real_prompt_ratio > 0 else torch.float32
        self.ratio_index = {float(key): [] for key in self.aspect_ratio}     # used for self.getitem
        self.ratio_nums = {float(key): 0 for key in self.aspect_ratio}      # used for batch-sampler

        def flat_name(item, ext):
            # Join the last directory and the file name with '_' and swap the extension.
            return '_'.join(item['path'].rsplit('/', 1)).replace('.png', ext)

        json_files = image_list_json if isinstance(image_list_json, list) else [image_list_json]
        for json_file in json_files:
            meta_data = self.load_json(os.path.join(self.root, json_file))
            self.ori_imgs_nums += len(meta_data)
            kept = [item for item in meta_data if item['ratio'] <= 4]
            self.meta_data_clean.extend(kept)
            self.img_samples.extend(
                os.path.join(self.root.replace('InternData', "InternImgs"), item['path']) for item in kept
            )
            self.txt_feat_samples.extend(
                os.path.join(self.root, 'caption_features', flat_name(item, '.npz')) for item in kept
            )
            self.vae_feat_samples.extend(
                os.path.join(self.root, f'img_vae_fatures_{resolution}_multiscale/ms', flat_name(item, '.npy'))
                for item in kept
            )

        # Set loader and extensions
        if load_vae_feat:
            self.transform = None
            self.loader = self.vae_feat_loader
        else:
            self.loader = default_loader

        if sample_subset is not None:
            self.sample_subset(sample_subset)  # sample dataset for local debug

        # Scan the first third of the metadata to collect per-ratio statistics
        # and seed every encountered bucket with one index.
        scan_count = len(self.meta_data_clean) // 3
        for position, info in enumerate(self.meta_data_clean[:scan_count]):
            _, bucket = get_closest_ratio(info['height'], info['width'], self.aspect_ratio)
            self.ratio_nums[bucket] += 1
            if not self.ratio_index[bucket]:
                self.ratio_index[bucket].append(position)
        # print(self.ratio_nums)
        if config is None:
            logger = get_root_logger()
        else:
            logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
        logger.info(f"T5 max token length: {self.max_lenth}")

    def getdata(self, index):
        """Load one sample together with its aspect-ratio bucket.

        Returns:
            ``(img, txt_fea, attention_mask, data_info)`` where ``data_info``
            holds ``img_hw`` (original height/width from the metadata),
            ``aspect_ratio`` (the bucket ratio) and ``mask_type``.
        """
        img_path = self.img_samples[index]
        npz_path = self.txt_feat_samples[index]
        npy_path = self.vae_feat_samples[index]
        meta = self.meta_data_clean[index]
        ori_h, ori_w = meta['height'], meta['width']

        # Calculate the closest aspect ratio and resize & crop image[w, h]
        closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
        closest_size = [int(side) for side in closest_size]
        # Remembered so __getitem__ can retry within the same bucket.
        self.closest_ratio = closest_ratio

        if self.load_vae_feat:
            try:
                img = self.loader(npy_path)
                bucket_members = self.ratio_index[closest_ratio]
                if index not in bucket_members:
                    bucket_members.append(index)
            except Exception:
                # Fall back to a known-good index from the same bucket.
                return self.getdata(random.choice(self.ratio_index[closest_ratio]))
            h, w = img.shape[1], img.shape[2]
            # NOTE(review): this asserts only that ``h`` is truthy; the tuple
            # comparison is the assertion *message*, so shapes are never checked.
            assert h, w == (ori_h//8, ori_w//8)
        else:
            img = self.loader(img_path)
            h, w = img.size[1], img.size[0]
            # NOTE(review): same pattern as above - only ``h`` is asserted.
            assert h, w == (ori_h, ori_w)

        data_info = {
            'img_hw': torch.tensor([ori_h, ori_w], dtype=torch.float32),
            'aspect_ratio': closest_ratio,
            "mask_type": self.mask_type,
        }

        txt_info = np.load(npz_path)
        txt_fea = torch.from_numpy(txt_info['caption_feature'])
        attention_mask = torch.ones(1, 1, txt_fea.shape[1])
        if 'attention_mask' in txt_info.keys():
            attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]

        if not self.load_vae_feat:
            # Scale so the image covers the target size, then center-crop to it.
            if closest_size[0] / ori_h > closest_size[1] / ori_w:
                resize_size = (closest_size[0], int(ori_w * closest_size[0] / ori_h))
            else:
                resize_size = (int(ori_h * closest_size[1] / ori_w), closest_size[1])
            steps = [
                T.Lambda(lambda img: img.convert('RGB')),
                T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC),  # Image.BICUBIC
                T.CenterCrop(closest_size),
                T.ToTensor(),
                T.Normalize([.5], [.5]),
            ]
            self.transform = T.Compose(steps)

        if self.transform:
            img = self.transform(img)

        return img, txt_fea, attention_mask, data_info

    def __getitem__(self, idx):
        """Return sample ``idx``; on failure retry (20 tries in total) within the last bucket."""
        attempts_left = 20
        while attempts_left > 0:
            attempts_left -= 1
            try:
                return self.getdata(idx)
            except Exception as e:
                print(f"Error details: {str(e)}")
                idx = random.choice(self.ratio_index[self.closest_ratio])
        raise RuntimeError('Too many bad data.')
155
+
156
+
157
@DATASETS.register_module()
class InternalDataMSSigma(InternalDataSigma):
    """Multi-scale variant of ``InternalDataSigma`` that groups samples by aspect ratio.

    The required keyword argument ``aspect_ratio_type`` names one of the
    aspect-ratio tables (for example ``'ASPECT_RATIO_1024'``); its numeric
    suffix becomes ``base_size``. Images are resized with LANCZOS when the
    table is ``ASPECT_RATIO_2048`` or ``ASPECT_RATIO_2880`` and with BICUBIC
    otherwise.
    """

    def __init__(self,
                 root,
                 image_list_json='data_info.json',
                 transform=None,
                 resolution=256,
                 sample_subset=None,
                 load_vae_feat=False,
                 load_t5_feat=False,
                 input_size=32,
                 patch_size=2,
                 mask_ratio=0.0,
                 mask_type='null',
                 load_mask_index=False,
                 real_prompt_ratio=1.0,
                 max_length=300,
                 config=None,
                 **kwargs):
        self.root = get_data_path(root)
        self.transform = transform
        self.load_vae_feat = load_vae_feat
        self.load_t5_feat = load_t5_feat
        self.ori_imgs_nums = 0
        self.resolution = resolution
        self.N = int(resolution // (input_size // patch_size))
        self.mask_ratio = mask_ratio
        self.load_mask_index = load_mask_index
        self.mask_type = mask_type
        self.real_prompt_ratio = real_prompt_ratio
        self.max_lenth = max_length

        ratio_table_name = kwargs.pop('aspect_ratio_type')
        self.base_size = int(ratio_table_name.split('_')[-1])
        # NOTE(review): eval() resolves the table name against module globals;
        # this is only acceptable because the value comes from trusted config.
        self.aspect_ratio = eval(ratio_table_name)     # base aspect ratio

        self.meta_data_clean = []
        self.img_samples = []
        self.txt_samples = []
        self.sharegpt4v_txt_samples = []
        self.txt_feat_samples = []
        self.vae_feat_samples = []
        self.mask_index_samples = []
        self.gpt4v_txt_feat_samples = []
        self.weight_dtype = torch.float16 if self.real_prompt_ratio > 0 else torch.float32

        if self.aspect_ratio in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]:
            self.interpolate_model = InterpolationMode.LANCZOS
        else:
            self.interpolate_model = InterpolationMode.BICUBIC

        suffix = ''
        self.ratio_index = {float(key): [] for key in self.aspect_ratio}     # used for self.getitem
        self.ratio_nums = {float(key): 0 for key in self.aspect_ratio}      # used for batch-sampler

        if config is None:
            logger = get_root_logger()
        else:
            logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
        logger.info(f"T5 max token length: {self.max_lenth}")
        logger.info(f"ratio of real user prompt: {self.real_prompt_ratio}")

        def flat_name(item, ext):
            # Join the last directory and the file name with '_' and swap the extension.
            return '_'.join(item['path'].rsplit('/', 1)).replace('.png', ext)

        json_files = image_list_json if isinstance(image_list_json, list) else [image_list_json]
        for json_file in json_files:
            meta_data = self.load_json(os.path.join(self.root, json_file))
            logger.info(f"{json_file} data volume: {len(meta_data)}")
            self.ori_imgs_nums += len(meta_data)
            kept = [item for item in meta_data if item['ratio'] <= 4.5]
            self.meta_data_clean.extend(kept)
            self.img_samples.extend(
                os.path.join(self.root.replace('InternData'+suffix, 'InternImgs'), item['path']) for item in kept
            )
            self.txt_samples.extend(item['prompt'] for item in kept)
            self.sharegpt4v_txt_samples.extend(
                item['sharegpt4v'] if 'sharegpt4v' in item else '' for item in kept
            )
            self.txt_feat_samples.extend(
                os.path.join(self.root, 'caption_features_new', flat_name(item, '.npz')) for item in kept
            )
            self.gpt4v_txt_feat_samples.extend(
                os.path.join(self.root, 'sharegpt4v_caption_features_new', flat_name(item, '.npz'))
                for item in kept
            )
            self.vae_feat_samples.extend(
                os.path.join(
                    self.root + suffix,
                    f'img_sdxl_vae_features_{resolution}resolution_ms_new',
                    flat_name(item, '.npy'),
                ) for item in kept
            )

        # Mixing in ShareGPT4V captions requires that they are actually present.
        if self.real_prompt_ratio < 1:
            assert len(self.sharegpt4v_txt_samples[0]) != 0

        # Set loader and extensions
        if load_vae_feat:
            self.transform = None
            self.loader = self.vae_feat_loader
        else:
            self.loader = default_loader

        if sample_subset is not None:
            self.sample_subset(sample_subset)  # sample dataset for local debug

        # Scan the first third of the metadata to collect per-ratio statistics
        # and seed every encountered bucket with one index.
        scan_count = len(self.meta_data_clean) // 3
        for position, info in enumerate(self.meta_data_clean[:scan_count]):
            _, bucket = get_closest_ratio(info['height'], info['width'], self.aspect_ratio)
            self.ratio_nums[bucket] += 1
            if not self.ratio_index[bucket]:
                self.ratio_index[bucket].append(position)

    def getdata(self, index):
        """Load one sample together with its aspect-ratio bucket.

        Returns:
            ``(img, txt_fea, attention_mask, data_info)``. ``txt_fea`` is the
            caption feature when ``load_t5_feat`` is true, otherwise the
            caption string. ``attention_mask`` is cast to ``torch.int16`` and
            ``data_info`` holds ``img_hw``, ``aspect_ratio`` and ``mask_type``.
        """
        img_path = self.img_samples[index]
        # Choose between the original prompt and the ShareGPT4V caption.
        use_real_prompt = random.random() < self.real_prompt_ratio
        if use_real_prompt:
            npz_path = self.txt_feat_samples[index]
            txt = self.txt_samples[index]
        else:
            npz_path = self.gpt4v_txt_feat_samples[index]
            txt = self.sharegpt4v_txt_samples[index]
        npy_path = self.vae_feat_samples[index]
        data_info = {}
        meta = self.meta_data_clean[index]
        ori_h, ori_w = meta['height'], meta['width']

        # Calculate the closest aspect ratio and resize & crop image[w, h]
        closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
        closest_size = [int(side) for side in closest_size]
        # Remembered so __getitem__ can retry within the same bucket.
        self.closest_ratio = closest_ratio

        if self.load_vae_feat:
            img = self.loader(npy_path)
            bucket_members = self.ratio_index[closest_ratio]
            if index not in bucket_members:
                bucket_members.append(index)
            h, w = img.shape[1], img.shape[2]
            # NOTE(review): this asserts only that ``h`` is truthy; the tuple
            # comparison is the assertion *message*, so shapes are never checked.
            assert h, w == (ori_h//8, ori_w//8)
        else:
            img = self.loader(img_path)
            h, w = img.size[1], img.size[0]
            # NOTE(review): same pattern as above - only ``h`` is asserted.
            assert h, w == (ori_h, ori_w)

        data_info['img_hw'] = torch.tensor([ori_h, ori_w], dtype=torch.float32)
        data_info['aspect_ratio'] = closest_ratio
        data_info["mask_type"] = self.mask_type

        attention_mask = torch.ones(1, 1, self.max_lenth)
        if not self.load_t5_feat:
            txt_fea = txt
        else:
            txt_info = np.load(npz_path)
            txt_fea = torch.from_numpy(txt_info['caption_feature'])
            if 'attention_mask' in txt_info.keys():
                attention_mask = torch.from_numpy(txt_info['attention_mask'])[None]
            if txt_fea.shape[1] != self.max_lenth:
                # Pad by repeating the last token feature; padded positions are masked out with zeros.
                pad_tokens = txt_fea[:, -1:].repeat(1, self.max_lenth-txt_fea.shape[1], 1)
                txt_fea = torch.cat([txt_fea, pad_tokens], dim=1).to(self.weight_dtype)
                pad_mask = torch.zeros(1, 1, self.max_lenth-attention_mask.shape[-1])
                attention_mask = torch.cat([attention_mask, pad_mask], dim=-1)

        if not self.load_vae_feat:
            # Scale so the image covers the target size, then center-crop to it.
            if closest_size[0] / ori_h > closest_size[1] / ori_w:
                resize_size = (closest_size[0], int(ori_w * closest_size[0] / ori_h))
            else:
                resize_size = (int(ori_h * closest_size[1] / ori_w), closest_size[1])
            steps = [
                T.Lambda(lambda img: img.convert('RGB')),
                T.Resize(resize_size, interpolation=self.interpolate_model),  # Image.BICUBIC
                T.CenterCrop(closest_size),
                T.ToTensor(),
                T.Normalize([.5], [.5]),
            ]
            self.transform = T.Compose(steps)

        if self.transform:
            img = self.transform(img)

        return img, txt_fea, attention_mask.to(torch.int16), data_info

    def __getitem__(self, idx):
        """Return sample ``idx``; on failure retry (20 tries in total) within the last bucket."""
        attempts_left = 20
        while attempts_left > 0:
            attempts_left -= 1
            try:
                return self.getdata(idx)
            except Exception as e:
                print(f"Error details: {str(e)}")
                idx = random.choice(self.ratio_index[self.closest_ratio])
        raise RuntimeError('Too many bad data.')
336
+
diffusion/data/datasets/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .InternalData import InternalData, InternalDataSigma
2
# InternalDataSigma is already exported by the InternalData import; the class
# that InternalData_ms actually defines for the Sigma pipeline is InternalDataMSSigma.
from .InternalData_ms import InternalDataMS, InternalDataMSSigma
3
+ from .utils import *
diffusion/data/datasets/utils.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ASPECT_RATIO_2880 = {
3
+ '0.25': [1408.0, 5760.0], '0.26': [1408.0, 5568.0], '0.27': [1408.0, 5376.0], '0.28': [1408.0, 5184.0],
4
+ '0.32': [1600.0, 4992.0], '0.33': [1600.0, 4800.0], '0.34': [1600.0, 4672.0], '0.4': [1792.0, 4480.0],
5
+ '0.42': [1792.0, 4288.0], '0.47': [1920.0, 4096.0], '0.49': [1920.0, 3904.0], '0.51': [1920.0, 3776.0],
6
+ '0.55': [2112.0, 3840.0], '0.59': [2112.0, 3584.0], '0.68': [2304.0, 3392.0], '0.72': [2304.0, 3200.0],
7
+ '0.78': [2496.0, 3200.0], '0.83': [2496.0, 3008.0], '0.89': [2688.0, 3008.0], '0.93': [2688.0, 2880.0],
8
+ '1.0': [2880.0, 2880.0], '1.07': [2880.0, 2688.0], '1.12': [3008.0, 2688.0], '1.21': [3008.0, 2496.0],
9
+ '1.28': [3200.0, 2496.0], '1.39': [3200.0, 2304.0], '1.47': [3392.0, 2304.0], '1.7': [3584.0, 2112.0],
10
+ '1.82': [3840.0, 2112.0], '2.03': [3904.0, 1920.0], '2.13': [4096.0, 1920.0], '2.39': [4288.0, 1792.0],
11
+ '2.5': [4480.0, 1792.0], '2.92': [4672.0, 1600.0], '3.0': [4800.0, 1600.0], '3.12': [4992.0, 1600.0],
12
+ '3.68': [5184.0, 1408.0], '3.82': [5376.0, 1408.0], '3.95': [5568.0, 1408.0], '4.0': [5760.0, 1408.0]
13
+ }
14
+
15
+ ASPECT_RATIO_2048 = {
16
+ '0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0], '0.27': [1024.0, 3840.0], '0.28': [1024.0, 3712.0],
17
+ '0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0],
18
+ '0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0],
19
+ '0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0],
20
+ '0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0],
21
+ '1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0],
22
+ '1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0],
23
+ '1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0],
24
+ '2.5': [3200.0, 1280.0], '2.89': [3328.0, 1152.0], '3.0': [3456.0, 1152.0], '3.11': [3584.0, 1152.0],
25
+ '3.62': [3712.0, 1024.0], '3.75': [3840.0, 1024.0], '3.88': [3968.0, 1024.0], '4.0': [4096.0, 1024.0]
26
+ }
27
+
28
+ ASPECT_RATIO_1024 = {
29
+ '0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.],
30
+ '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
31
+ '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
32
+ '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
33
+ '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
34
+ '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
35
+ '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
36
+ '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
37
+ '2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.],
38
+ '3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.],
39
+ }
40
+
41
+ ASPECT_RATIO_512 = {
42
+ '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
43
+ '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
44
+ '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
45
+ '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
46
+ '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
47
+ '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
48
+ '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
49
+ '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
50
+ '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
51
+ '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
52
+ }
53
+
54
# Size buckets for the 256 base resolution.
# Each key is a ratio written as a string; it is (approximately) the first
# list entry divided by the second, e.g. 128.0 / 512.0 == 0.25.
# Each value is a two-element list of float sizes.
# NOTE(review): presumably [height, width] — confirm against the dataset code.
ASPECT_RATIO_256 = {
    '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0],
    '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
    '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
    '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
    '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
    '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
    '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
    '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
    '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0],
    '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0]
}
66
+
67
# Reduced bucket table for the 256 base resolution.
# Same key/value layout as ASPECT_RATIO_256, but without the '0.26', '0.27',
# '2.89', '3.11', '3.62', '3.75' and '3.88' entries.
# NOTE(review): the "_TEST" suffix suggests inference-time use — confirm with callers.
ASPECT_RATIO_256_TEST = {
    '0.25': [128.0, 512.0], '0.28': [128.0, 464.0],
    '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
    '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
    '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
    '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
    '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
    '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
    '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
    '2.5': [400.0, 160.0], '3.0': [432.0, 144.0],
    '4.0': [512.0, 128.0]
}
79
+
80
# Reduced bucket table for the 512 base resolution.
# Same key/value layout as ASPECT_RATIO_512, but without the '0.26', '0.27',
# '2.89', '3.11', '3.62', '3.75' and '3.88' entries.
# NOTE(review): the "_TEST" suffix suggests inference-time use — confirm with callers.
ASPECT_RATIO_512_TEST = {
    '0.25': [256.0, 1024.0], '0.28': [256.0, 928.0],
    '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
    '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
    '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
    '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
    '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
    '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
    '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
    '2.5': [800.0, 320.0], '3.0': [864.0, 288.0],
    '4.0': [1024.0, 256.0]
}
92
+
93
# Reduced bucket table for the 1024 base resolution.
# Each key is a ratio written as a string (first entry / second entry,
# rounded); each value is a two-element list of float sizes.
# NOTE(review): presumably [height, width], and the "_TEST" suffix suggests
# inference-time use — confirm with callers.
ASPECT_RATIO_1024_TEST = {
    '0.25': [512., 2048.], '0.28': [512., 1856.],
    '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
    '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
    '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
    '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
    '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
    '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
    '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
    '2.5': [1600., 640.], '3.0': [1728., 576.],
    '4.0': [2048., 512.],
}
105
+
106
# Reduced bucket table for the 2048 base resolution.
# Each key is a ratio written as a string (first entry / second entry,
# rounded); each value is a two-element list of float sizes.
# NOTE(review): unlike the 256/512/1024 "_TEST" tables, the second entry here
# is '0.26' rather than '0.28' — confirm whether that difference is intended.
ASPECT_RATIO_2048_TEST = {
    '0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0],
    '0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0],
    '0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0],
    '0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0],
    '0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0],
    '1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0],
    '1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0],
    '1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0],
    '2.5': [3200.0, 1280.0], '3.0': [3456.0, 1152.0],
    '4.0': [4096.0, 1024.0]
}
118
+
119
# Reduced bucket table named for a 2880 base resolution.
# Each key is a ratio written as a string (first entry / second entry,
# rounded); each value is a two-element list of float sizes.
# NOTE(review): despite the "2880" in the name, the square '1.0' bucket is
# [4096.0, 4096.0] and the sizes are multiples of 256 built around 4096 —
# confirm whether the name or the values are the intended ones.
ASPECT_RATIO_2880_TEST = {
    '0.25': [2048.0, 8192.0], '0.26': [2048.0, 7936.0],
    '0.32': [2304.0, 7168.0], '0.33': [2304.0, 6912.0], '0.35': [2304.0, 6656.0], '0.4': [2560.0, 6400.0],
    '0.42': [2560.0, 6144.0], '0.48': [2816.0, 5888.0], '0.5': [2816.0, 5632.0], '0.52': [2816.0, 5376.0],
    '0.57': [3072.0, 5376.0], '0.6': [3072.0, 5120.0], '0.68': [3328.0, 4864.0], '0.72': [3328.0, 4608.0],
    '0.78': [3584.0, 4608.0], '0.82': [3584.0, 4352.0], '0.88': [3840.0, 4352.0], '0.94': [3840.0, 4096.0],
    '1.0': [4096.0, 4096.0], '1.07': [4096.0, 3840.0], '1.13': [4352.0, 3840.0], '1.21': [4352.0, 3584.0],
    '1.29': [4608.0, 3584.0], '1.38': [4608.0, 3328.0], '1.46': [4864.0, 3328.0], '1.67': [5120.0, 3072.0],
    '1.75': [5376.0, 3072.0], '2.0': [5632.0, 2816.0], '2.09': [5888.0, 2816.0], '2.4': [6144.0, 2560.0],
    '2.5': [6400.0, 2560.0], '3.0': [6912.0, 2304.0],
    '4.0': [8192.0, 2048.0],
}
131
+
132
def get_chunks(lst, n):
    """Lazily yield consecutive slices of ``lst``, each at most ``n`` items long.

    The final slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    Any sliceable sequence with a length works; each chunk has the type that
    slicing ``lst`` produces.
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))
diffusion/data/transforms.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision.transforms as T
2
+
3
# Registry mapping a transform factory's ``__name__`` to the factory itself.
TRANSFORMS = dict()


def register_transform(transform):
    """Register ``transform`` in ``TRANSFORMS`` under its ``__name__``.

    Intended for use as a decorator. The callable is returned unchanged so
    that the decorated module-level name stays bound to the function instead
    of being rebound to ``None``.

    Args:
        transform: callable to register; its ``__name__`` is the lookup key.

    Returns:
        The same ``transform`` object that was passed in.

    Raises:
        RuntimeError: if a transform with the same name is already registered.
    """
    name = transform.__name__
    if name in TRANSFORMS:
        raise RuntimeError(f'Transform {name} has already registered.')
    TRANSFORMS[name] = transform
    return transform
11
+
12
+
13
def get_transform(type, resolution):
    """Build the composed transform registered under the name ``type``.

    The registered factory is called with ``resolution`` and its result is
    wrapped in ``T.Compose``. The resolution is also stored on the returned
    object as ``image_size``.

    Raises:
        KeyError: if no transform named ``type`` is present in ``TRANSFORMS``.
    """
    factory = TRANSFORMS[type]
    pipeline = T.Compose(factory(resolution))
    pipeline.image_size = resolution
    return pipeline
18
+
19
+
20
@register_transform
def default_train(n_px):
    """Return the default list of training transforms for size ``n_px``.

    The list converts the image to RGB, resizes and centre-crops it to
    ``n_px``, turns it into a tensor and normalises with mean 0.5 / std 0.5.
    The result is a plain list; ``get_transform`` wraps it in ``T.Compose``.
    """
    return [
        T.Lambda(lambda img: img.convert('RGB')),
        # No interpolation argument is passed, so torchvision's default applies.
        # NOTE(review): the original comment here read "Image.BICUBIC" — confirm
        # which interpolation is actually wanted.
        T.Resize(n_px),
        T.CenterCrop(n_px),
        # Horizontal flipping is intentionally not part of the pipeline.
        T.ToTensor(),
        T.Normalize([.5], [.5]),
    ]
diffusion/dpm_solver.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .model import gaussian_diffusion as gd
3
+ from .model.dpm_solver import model_wrapper, DPM_Solver, NoiseScheduleVP
4
+
5
+
6
def DPMS(
        model,
        condition,
        uncondition,
        cfg_scale,
        model_type='noise',  # or "x_start" or "v" or "score"
        noise_schedule="linear",
        guidance_type='classifier-free',
        model_kwargs=None,
        diffusion_steps=1000
):
    """Build a DPM-Solver++ sampler around ``model``.

    Args:
        model: the discrete-time diffusion model to wrap.
        condition: conditioning input forwarded to ``model_wrapper``.
        uncondition: unconditional input forwarded to ``model_wrapper`` as
            ``unconditional_condition``.
        cfg_scale: guidance scale forwarded as ``guidance_scale``.
        model_type: prediction type of ``model`` ("noise", "x_start", "v" or
            "score").
        noise_schedule: name of the beta schedule passed to
            ``gd.get_named_beta_schedule``.
        guidance_type: guidance mode forwarded to ``model_wrapper``.
        model_kwargs: extra keyword arguments for the model. ``None`` (the
            default) means an empty dict; a fresh dict is created per call so
            no state is shared between calls.
        diffusion_steps: number of steps used to build the beta schedule.

    Returns:
        A ``DPM_Solver`` configured with ``algorithm_type="dpmsolver++"``.
    """
    # Avoid a shared mutable default: every call gets its own dict.
    if model_kwargs is None:
        model_kwargs = {}

    betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))

    # 1. Define the discrete noise schedule from the betas.
    schedule = NoiseScheduleVP(schedule='discrete', betas=betas)

    # 2. Wrap the discrete-time model as the continuous-time prediction
    #    function expected by the solver, with guidance settings applied.
    model_fn = model_wrapper(
        model,
        schedule,
        model_type=model_type,
        model_kwargs=model_kwargs,
        guidance_type=guidance_type,
        condition=condition,
        unconditional_condition=uncondition,
        guidance_scale=cfg_scale,
    )

    # 3. Define the solver; sampling is done by the caller on the returned object.
    return DPM_Solver(model_fn, schedule, algorithm_type="dpmsolver++")
diffusion/iddpm.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+ from diffusion.model.respace import SpacedDiffusion, space_timesteps
6
+ from .model import gaussian_diffusion as gd
7
+
8
+
9
def IDDPM(
        timestep_respacing,
        noise_schedule="linear",
        use_kl=False,
        sigma_small=False,
        predict_xstart=False,
        learn_sigma=True,
        pred_sigma=True,
        rescale_learned_sigmas=False,
        diffusion_steps=1000,
        snr=False,
        return_startx=False,
):
    """Create a ``SpacedDiffusion`` object from high-level options.

    Args:
        timestep_respacing: respacing spec passed to ``space_timesteps``;
            ``None`` or ``""`` means "use all ``diffusion_steps`` steps".
        noise_schedule: name of the beta schedule.
        use_kl: select the rescaled-KL loss (takes priority over
            ``rescale_learned_sigmas``).
        sigma_small: with a fixed variance, pick the small instead of the
            large variant.
        predict_xstart: the model predicts x_start instead of epsilon.
        learn_sigma: use the learned-range variance type.
        pred_sigma: when false, ``model_var_type`` is ``None``.
        rescale_learned_sigmas: select the rescaled-MSE loss.
        diffusion_steps: number of steps in the base schedule.
        snr: forwarded unchanged to ``SpacedDiffusion``.
        return_startx: forwarded unchanged to ``SpacedDiffusion``.

    Returns:
        The configured ``SpacedDiffusion`` instance.
    """
    betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)

    loss_type = (
        gd.LossType.RESCALED_KL if use_kl
        else gd.LossType.RESCALED_MSE if rescale_learned_sigmas
        else gd.LossType.MSE
    )

    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]
    use_timesteps = space_timesteps(diffusion_steps, timestep_respacing)

    if predict_xstart:
        mean_type = gd.ModelMeanType.START_X
    else:
        mean_type = gd.ModelMeanType.EPSILON

    # Variance type: none at all, learned range, or one of the fixed variants.
    if not pred_sigma:
        var_type = None
    elif learn_sigma:
        var_type = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        var_type = gd.ModelVarType.FIXED_SMALL
    else:
        var_type = gd.ModelVarType.FIXED_LARGE

    return SpacedDiffusion(
        use_timesteps=use_timesteps,
        betas=betas,
        model_mean_type=mean_type,
        model_var_type=var_type,
        loss_type=loss_type,
        snr=snr,
        return_startx=return_startx,
    )
diffusion/lcm_scheduler.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from diffusers import ConfigMixin, SchedulerMixin
26
+ from diffusers.configuration_utils import register_to_config
27
+ from diffusers.utils import BaseOutput
28
+
29
+
30
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class LCMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.
    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        denoised (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images, *optional*):
            The denoised sample produced by `LCMScheduler.step`: the predicted `(x_{0})` combined with the current
            sample through the boundary-condition scalings. Defaults to `None`.
    """

    prev_sample: torch.FloatTensor
    denoised: Optional[torch.FloatTensor] = None
46
+
47
+
48
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
49
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].
    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.
    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`
    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` with the betas used by the
        scheduler to step the model outputs
    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta.
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
87
+
88
+
89
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so that the last timestep has zero SNR, following
    https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.
    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # sqrt of the cumulative alpha product for every timestep
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember the two end points before modifying anything.
    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    # Shift so the final value becomes zero, then scale so the first value
    # returns to what it was originally.
    shifted = sqrt_alpha_bar - last
    rescaled = shifted * (first / (first - last))

    # Undo the sqrt and the cumulative product to get per-step alphas again.
    alpha_bar = rescaled ** 2
    step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])

    return 1 - step_alphas
120
+
121
+
122
class LCMScheduler(SchedulerMixin, ConfigMixin):
    """
    Scheduler used for few-step LCM sampling: `set_timesteps` picks a short inference schedule out of the LCM
    training schedule, and `step` turns a model output into a denoised estimate plus the sample for the next step.
    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.
    NOTE(review): `clip_sample`, `clip_sample_range`, `steps_offset`, `thresholding` and `timestep_spacing` are
    accepted by the constructor but are not read by any method of this class.
    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        clip_sample (`bool`, defaults to `True`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
            Diffusion.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
            self,
            num_train_timesteps: int = 1000,
            beta_start: float = 0.0001,
            beta_end: float = 0.02,
            beta_schedule: str = "linear",
            trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
            clip_sample: bool = True,
            set_alpha_to_one: bool = True,
            steps_offset: int = 0,
            prediction_type: str = "epsilon",
            thresholding: bool = False,
            dynamic_thresholding_ratio: float = 0.995,
            clip_sample_range: float = 1.0,
            sample_max_value: float = 1.0,
            timestep_spacing: str = "leading",
            rescale_betas_zero_snr: bool = False,
    ):
        """Build the beta/alpha tables for the chosen schedule.

        Raises:
            NotImplementedError: if `beta_schedule` is not one of the supported names and no `trained_betas`
                are given.
        """
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.
        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.
        Returns:
            `torch.FloatTensor`:
                The input sample, returned unchanged (this scheduler applies no scaling).
        """
        return sample

    def _get_variance(self, timestep, prev_timestep):
        """Return the variance term computed from the cumulative alphas at `timestep` and `prev_timestep`.

        A negative `prev_timestep` selects `self.final_alpha_cumprod` as the previous cumulative alpha.
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."
        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, height, width = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * height * width)

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]

        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, height, width)
        sample = sample.to(dtype)

        return sample

    def set_timesteps(self, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model. Must be between
                1 and `lcm_origin_steps`.
            lcm_origin_steps (`int`):
                Length of the LCM training schedule that the inference timesteps are picked from. Must be between 1
                and `num_train_timesteps`.
            device (`str` or `torch.device`, *optional*):
                Device the resulting `self.timesteps` tensor is moved to.
        Raises:
            ValueError: if either step count is outside the ranges described above.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )
        if num_inference_steps < 1:
            raise ValueError(f"`num_inference_steps` must be at least 1, got {num_inference_steps}.")
        # Outside this range the integer division below yields c == 0 (all timesteps become -1) or divides by zero.
        if not 1 <= lcm_origin_steps <= self.config.num_train_timesteps:
            raise ValueError(
                f"`lcm_origin_steps`: {lcm_origin_steps} must be between 1 and"
                f" `self.config.num_train_timesteps`: {self.config.num_train_timesteps}."
            )
        # More inference steps than origin steps would give a skipping step of 0 ("slice step cannot be zero").
        if num_inference_steps > lcm_origin_steps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `lcm_origin_steps`:"
                f" {lcm_origin_steps}."
            )

        self.num_inference_steps = num_inference_steps

        # LCM Timesteps Setting:  # Linear Spacing
        c = self.config.num_train_timesteps // lcm_origin_steps
        lcm_origin_timesteps = np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1  # LCM Training Steps Schedule
        skipping_step = len(lcm_origin_timesteps) // num_inference_steps
        # Walk the training schedule backwards (largest timestep first) and keep the first num_inference_steps.
        timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]  # LCM Inference Steps Schedule

        self.timesteps = torch.from_numpy(timesteps.copy()).to(device)

    def get_scalings_for_boundary_condition_discrete(self, t):
        """Return the `(c_skip, c_out)` boundary-condition scalings for timestep `t`.

        Also (re)sets `self.sigma_data` to 0.5 as a side effect.
        """
        self.sigma_data = 0.5  # Default: 0.5

        # By dividing 0.1: This is almost a delta function at t=0.
        c_skip = self.sigma_data ** 2 / ((t / 0.1) ** 2 + self.sigma_data ** 2)
        c_out = ((t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data ** 2) ** 0.5)
        return c_skip, c_out

    def step(
            self,
            model_output: torch.FloatTensor,
            timeindex: int,
            timestep: int,
            sample: torch.FloatTensor,
            eta: float = 0.0,
            use_clipped_model_output: bool = False,
            generator=None,
            variance_noise: Optional[torch.FloatTensor] = None,
            return_dict: bool = True,
    ) -> Union[LCMSchedulerOutput, Tuple]:
        """
        Compute the denoised estimate for the current timestep and the sample to feed into the next step.
        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timeindex (`int`):
                Index of the current step inside `self.timesteps`; the following entry is used as the next timestep.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            eta (`float`):
                Accepted for interface compatibility; not used by this scheduler.
            use_clipped_model_output (`bool`, defaults to `False`):
                Accepted for interface compatibility; not used by this scheduler.
            generator (`torch.Generator`, *optional*):
                A random number generator used to draw the noise injected in multi-step sampling.
            variance_noise (`torch.FloatTensor`):
                Accepted for interface compatibility; not used by this scheduler.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
        Returns:
            [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
                tuple `(prev_sample, denoised)` is returned.
        Raises:
            ValueError: if `set_timesteps` has not been called, or if the configured `prediction_type` is not one of
                `epsilon`, `sample` or `v_prediction`.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # 1. get previous step value
        prev_timeindex = timeindex + 1
        if prev_timeindex < len(self.timesteps):
            prev_timestep = self.timesteps[prev_timeindex]
        else:
            # Last step of the schedule: there is no further entry, so reuse the current timestep.
            prev_timestep = timestep

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 3. Get scalings for boundary conditions
        c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)

        # 4. Different Parameterization:
        parameterization = self.config.prediction_type

        if parameterization == "epsilon":  # noise-prediction
            pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()

        elif parameterization == "sample":  # x-prediction
            pred_x0 = model_output

        elif parameterization == "v_prediction":  # v-prediction
            pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output

        else:
            # Fail clearly instead of hitting an unbound `pred_x0` below.
            raise ValueError(
                f"Unsupported prediction_type: {parameterization!r};"
                " expected 'epsilon', 'sample' or 'v_prediction'."
            )

        # 5. Denoise model output using boundary conditions
        denoised = c_out * pred_x0 + c_skip * sample

        # 6. Sample z ~ N(0, I), For MultiStep Inference
        # Noise is not used for one-step sampling.
        if len(self.timesteps) > 1:
            # Draw on the generator's own device when one is given (so its stream is honoured), otherwise on the
            # default device exactly as before, then move the noise next to the model output.
            noise_device = generator.device if generator is not None else None
            noise = torch.randn(model_output.shape, generator=generator, device=noise_device).to(model_output.device)
            prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
        else:
            prev_sample = denoised

        if not return_dict:
            return (prev_sample, denoised)

        return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
            self,
            original_samples: torch.FloatTensor,
            noise: torch.FloatTensor,
            timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """Return `sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise` for the given timesteps."""
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # Append trailing singleton dims so the per-timestep factor broadcasts over the sample dims.
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(
            self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
    ) -> torch.FloatTensor:
        """Return `sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample` for the given timesteps."""
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # Append trailing singleton dims so the per-timestep factor broadcasts over the sample dims.
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        """Return the number of training timesteps from the config."""
        return self.config.num_train_timesteps
459
+
diffusion/model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .nets import *
diffusion/model/builder.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mmcv import Registry
2
+
3
+ from diffusion.model.utils import set_grad_checkpoint
4
+
5
+ MODELS = Registry('models')
6
+
7
+
8
def build_model(cfg, use_grad_checkpoint=False, use_fp32_attention=False, gc_step=1, **kwargs):
    """Build a model through the ``MODELS`` registry.

    Args:
        cfg: Registry config; a plain string is used as the config's ``type``.
        use_grad_checkpoint: When truthy, ``set_grad_checkpoint`` is applied to the built model.
        use_fp32_attention: Forwarded to ``set_grad_checkpoint``.
        gc_step: Forwarded to ``set_grad_checkpoint``.
        **kwargs: Passed to ``MODELS.build`` as ``default_args``.

    Returns:
        The built model.
    """
    model_cfg = dict(type=cfg) if isinstance(cfg, str) else cfg
    built = MODELS.build(model_cfg, default_args=kwargs)
    if not use_grad_checkpoint:
        return built
    set_grad_checkpoint(built, use_fp32_attention=use_fp32_attention, gc_step=gc_step)
    return built
diffusion/model/diffusion_utils.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import numpy as np
7
+ import torch as th
8
+
9
+
10
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians given their means and
    log-variances.

    Inputs broadcast against each other, so a batch can be compared to a
    scalar. At least one of the four arguments must be a Tensor.
    """
    candidates = (mean1, logvar1, mean2, logvar2)
    reference = next((arg for arg in candidates if isinstance(arg, th.Tensor)), None)
    assert reference is not None, "at least one argument must be a Tensor"

    def _as_tensor(value):
        # th.exp() needs Tensor inputs; broadcasting alone does not promote
        # Python scalars for it.
        if isinstance(value, th.Tensor):
            return value
        return th.tensor(value, device=reference.device)

    logvar1 = _as_tensor(logvar1)
    logvar2 = _as_tensor(logvar2)

    variance_ratio = th.exp(logvar1 - logvar2)
    scaled_sq_gap = ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + scaled_sq_gap)
37
+
38
+
39
def approx_standard_normal_cdf(x):
    """
    Tanh-based approximation of the standard normal cumulative distribution
    function, evaluated elementwise on the Tensor `x`.
    """
    gain = np.sqrt(2.0 / np.pi)
    cubic_arg = x + 0.044715 * th.pow(x, 3)
    return 0.5 * (1.0 + th.tanh(gain * cubic_arg))
45
+
46
+
47
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Evaluate the standard-normal log-density of `x` after it has been
    centered by `means` and scaled by `exp(-log_scales)`.

    NOTE(review): `log_scales` is not subtracted from the result, so this is
    the log-density of the normalized residual only -- confirm callers expect
    that.

    :param x: the targets
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    residual = x - means
    standardized = residual * th.exp(-log_scales)
    unit_normal = th.distributions.Normal(th.zeros_like(x), th.ones_like(x))
    return unit_normal.log_prob(standardized)
60
+
61
+
62
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretized to a
    given image, using bins of half-width 1/255 around each target value.

    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    residual = x - means
    precision = th.exp(-log_scales)

    # CDF at the upper and lower edge of the bin around each target.
    upper_cdf = approx_standard_normal_cdf(precision * (residual + 1.0 / 255.0))
    lower_cdf = approx_standard_normal_cdf(precision * (residual - 1.0 / 255.0))

    # Clamp before the log so empty bins do not produce -inf.
    log_upper = th.log(upper_cdf.clamp(min=1e-12))
    log_lower_complement = th.log((1.0 - lower_cdf).clamp(min=1e-12))
    log_interior = th.log((upper_cdf - lower_cdf).clamp(min=1e-12))

    # Targets below -0.999 use the CDF up to the upper edge; targets above
    # 0.999 use the mass above the lower edge; everything else uses the bin.
    high_or_interior = th.where(x > 0.999, log_lower_complement, log_interior)
    log_probs = th.where(x < -0.999, log_upper, high_or_interior)
    assert log_probs.shape == x.shape
    return log_probs
diffusion/model/dpm_solver.py ADDED
@@ -0,0 +1,1337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm import tqdm
3
+
4
+
5
class NoiseScheduleVP:
    def __init__(
            self,
            schedule='discrete',
            betas=None,
            alphas_cumprod=None,
            continuous_beta_0=0.1,
            continuous_beta_1=20.,
            dtype=torch.float32,
    ):
        r"""Create a wrapper class for the forward SDE (VP type).

        ***
        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
                We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
        ***

        The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:

            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            lambda_t = self.marginal_lambda(t)

        Moreover, as lambda(t) is an invertible function, we also support its inverse function:

            t = self.inverse_lambda(lambda_t)

        ===============================================================

        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).

        1. For discrete-time DPMs:

            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
                t_i = (i + 1) / N
            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.

            Args:
                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)

            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
            If both are given, `betas` is used.

            **Important**:  Please pay special attention for the args for `alphas_cumprod`:
                The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
                and
                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).


        2. For continuous-time DPMs:

            We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise
            schedule are the default settings in Yang Song's ScoreSDE:

            Args:
                continuous_beta_0: A `float` number. The smallest beta for the linear schedule.
                continuous_beta_1: A `float` number. The largest beta for the linear schedule.

            The ending time T of the forward process is fixed to 1.

        ===============================================================

        Args:
            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
                    'linear' for continuous-time DPMs.
            dtype: The torch dtype that the discrete-schedule arrays (`log_alpha_array`, `t_array`) are cast to.
        Returns:
            A wrapper object of the forward SDE (VP type).
        Raises:
            ValueError: if `schedule` is neither 'discrete' nor 'linear'.

        ===============================================================

        Example:

        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', betas=betas)

        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)

        # For continuous-time DPMs (VPSDE), linear schedule:
        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)

        """

        if schedule not in ['discrete', 'linear']:
            raise ValueError(
                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear'".format(schedule))

        self.schedule = schedule
        if schedule == 'discrete':
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.T = 1.
            # Stored as (1, N) rows so they can be handed to interpolate_fn directly.
            self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1,)).to(dtype=dtype)
            self.total_N = self.log_alpha_array.shape[1]
            # t_i = (i + 1) / N, i.e. the grid 1/N, 2/N, ..., 1.
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
        else:
            self.T = 1.
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1

    def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1):
        """
        For some beta schedules such as cosine schedule, the log-SNR has numerical issues.
        We clip the log-SNR near t=T within -5.1 to ensure the stability.
        Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE.

        Args:
            log_alphas: A 1-D `torch.Tensor` of log(alpha) values.
            clipped_lambda: A `float`. Trailing entries whose half-logSNR is below this value are dropped.
        Returns:
            `log_alphas`, possibly shortened at its end.
        """
        log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas))
        lambs = log_alphas - log_sigmas
        # searchsorted needs an ascending sequence, hence the flip; `idx` is then
        # the number of trailing entries that fall below `clipped_lambda`.
        idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda)
        if idx > 0:
            log_alphas = log_alphas[:-idx]
        return log_alphas

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == 'discrete':
            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
                                  self.log_alpha_array.to(t.device)).reshape((-1))
        elif self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == 'linear':
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0 ** 2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
            # Both arrays are flipped so the interpolation keys (log_alpha) are passed in reversed order.
            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
                               torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1,))
170
+
171
+
172
def model_wrapper(
        model,
        noise_schedule,
        model_type="noise",
        model_kwargs=None,
        guidance_type="uncond",
        condition=None,
        unconditional_condition=None,
        guidance_scale=1.,
        classifier_fn=None,
        classifier_kwargs=None,
):
    """Create a wrapper function for the noise prediction model.

    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
    firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.

    We support four types of the diffusion model by setting `model_type`:

        1. "noise": noise prediction model. (Trained by predicting noise).

        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).

        3. "v": velocity prediction model. (Trained by predicting the velocity).
            The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].

            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
                arXiv preprint arXiv:2202.00512 (2022).
            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
                arXiv preprint arXiv:2210.02303 (2022).

        4. "score": marginal score function. (Trained by denoising score matching).
            Note that the score function and the noise prediction model follows a simple relationship:
            ```
                noise(x_t, t) = -sigma_t * score(x_t, t)
            ```

    We support three types of guided sampling by DPMs by setting `guidance_type`:
        1. "uncond": unconditional sampling by DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``

        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``

            The input `classifier_fn` has the following format:
            ``
                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
            ``

            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.

        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
            ``
            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.

            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
                arXiv preprint arXiv:2207.12598 (2022).


    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
    or continuous-time labels (i.e. epsilon to T).

    We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
    ``
        def model_fn(x, t_continuous) -> noise:
            t_input = get_model_input_time(t_continuous)
            return noise_pred(model, x, t_input, **model_kwargs)
    ``
    where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.

    ===============================================================

    Args:
        model: A diffusion model with the corresponding format described above.
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        model_type: A `str`. The parameterization type of the diffusion model.
                    "noise" or "x_start" or "v" or "score".
        model_kwargs: A `dict` or None. A dict for the other inputs of the model function.
                    None is treated as an empty dict.
        guidance_type: A `str`. The type of the guidance for sampling.
                    "uncond" or "classifier" or "classifier-free".
        condition: A pytorch tensor. The condition for the guided sampling.
                    Only used for "classifier" or "classifier-free" guidance type.
        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
                    Only used for "classifier-free" guidance type.
        guidance_scale: A `float`. The scale for the guided sampling.
        classifier_fn: A classifier function. Only used for the classifier guidance.
        classifier_kwargs: A `dict` or None. A dict for the other inputs of the classifier function.
                    None is treated as an empty dict.
    Returns:
        A noise prediction model that accepts the noised data and the continuous time as the inputs.
    """
    # Validate before building any closure so a bad configuration fails immediately.
    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]

    # None sentinels instead of `{}` defaults: a dict default would be one shared
    # object captured by every returned closure.
    model_kwargs = {} if model_kwargs is None else model_kwargs
    classifier_kwargs = {} if classifier_kwargs is None else classifier_kwargs

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        """
        Evaluate `model` and convert its output to a noise prediction according to `model_type`.
        """
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            return -expand_dims(sigma_t, x.dim()) * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # One batched forward pass: first half unconditional, second half conditional.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    return model_fn
337
+
338
+
339
+ class DPM_Solver:
340
def __init__(
        self,
        model_fn,
        noise_schedule,
        algorithm_type="dpmsolver++",
        correcting_x0_fn=None,
        correcting_xt_fn=None,
        thresholding_max_val=1.,
        dynamic_thresholding_ratio=0.995,
):
    """Construct a DPM-Solver.

    Both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`) are supported.

    The "dynamic thresholding" method of Imagen[1] is available through `correcting_x0_fn="dynamic_thresholding"`
    together with `algorithm_type="dpmsolver++"`. It can greatly improve the sample quality of pixel-space DPMs
    with large guidance scales, but it is **unsuitable** for latent-space DPMs (such as stable-diffusion).

    For image-to-image style applications, corrector functions for both x0 and xt can be supplied.

    Args:
        model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
            ``
            def model_fn(x, t_continuous):
                return noise
            ``
            The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
        correcting_x0_fn: A `str` or a function with the following format:
            ```
            def correcting_x0_fn(x0, t):
                x0_new = ...
                return x0_new
            ```
            It corrects the output of the data prediction model at each sampling step.
            The string "dynamic_thresholding" selects `self.dynamic_thresholding_fn`.
        correcting_xt_fn: A function with the following format:
            ```
            def correcting_xt_fn(xt, t, step):
                x_new = ...
                return x_new
            ```
            It corrects the intermediate samples xt at each sampling step.
        thresholding_max_val: A `float`. The max value for thresholding.
            Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
        dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
            Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.

    [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
        Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
        with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
    """

    def _batched_model(x, t):
        # Expand the time label to one entry per sample before calling the model.
        return model_fn(x, t.expand((x.shape[0])))

    self.model = _batched_model
    self.noise_schedule = noise_schedule
    assert algorithm_type in ["dpmsolver", "dpmsolver++"]
    self.algorithm_type = algorithm_type
    wants_builtin_thresholding = correcting_x0_fn == "dynamic_thresholding"
    self.correcting_x0_fn = self.dynamic_thresholding_fn if wants_builtin_thresholding else correcting_x0_fn
    self.correcting_xt_fn = correcting_xt_fn
    self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
    self.thresholding_max_val = thresholding_max_val
417
+
418
def dynamic_thresholding_fn(self, x0, t):
    """
    The dynamic thresholding method.

    For each sample, the `self.dynamic_thresholding_ratio` quantile of |x0| is taken, raised to at least
    `self.thresholding_max_val`, and used to clamp and then rescale `x0`. The argument `t` is not used.
    """
    rank = x0.dim()
    ratio = self.dynamic_thresholding_ratio
    flat_magnitudes = torch.abs(x0).reshape((x0.shape[0], -1))
    quantiles = torch.quantile(flat_magnitudes, ratio, dim=1)
    minimum_scale = self.thresholding_max_val * torch.ones_like(quantiles).to(quantiles.device)
    scale = expand_dims(torch.maximum(quantiles, minimum_scale), rank)
    return torch.clamp(x0, -scale, scale) / scale
428
+
429
def noise_prediction_fn(self, x, t):
    """
    Evaluate the wrapped noise prediction model `self.model` at (`x`, `t`).
    """
    predicted_noise = self.model(x, t)
    return predicted_noise
434
+
435
def data_prediction_fn(self, x, t):
    """
    Return the data prediction x0 = (x - sigma_t * noise) / alpha_t, passed through
    `self.correcting_x0_fn` when one is set.
    """
    schedule = self.noise_schedule
    predicted_noise = self.noise_prediction_fn(x, t)
    alpha_t = schedule.marginal_alpha(t)
    sigma_t = schedule.marginal_std(t)
    x0 = (x - sigma_t * predicted_noise) / alpha_t
    corrector = self.correcting_x0_fn
    if corrector is None:
        return x0
    return corrector(x0, t)
445
+
446
def model_fn(self, x, t):
    """
    Dispatch on `self.algorithm_type`: the data prediction model for "dpmsolver++",
    the noise prediction model otherwise.
    """
    if self.algorithm_type != "dpmsolver++":
        return self.noise_prediction_fn(x, t)
    return self.data_prediction_fn(x, t)
454
+
455
def get_time_steps(self, skip_type, t_T, t_0, N, device):
    """Compute the intermediate time steps for sampling.

    Args:
        skip_type: A `str`. The type for the spacing of the time steps. We support three types:
            - 'logSNR': uniform logSNR for the time steps.
            - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
            - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
        t_T: A `float`. The starting time of the sampling (default is T).
        t_0: A `float`. The ending time of the sampling (default is epsilon).
        N: A `int`. The total number of the spacing of the time steps.
        device: A torch device.
    Returns:
        A pytorch tensor of the time steps, with the shape (N + 1,).
    Raises:
        ValueError: if `skip_type` is not one of the three supported values.
    """
    num_points = N + 1
    if skip_type == 'time_uniform':
        return torch.linspace(t_T, t_0, num_points).to(device)
    if skip_type == 'time_quadratic':
        t_order = 2
        # Uniform in t ** (1 / t_order), then mapped back by raising to t_order.
        grid = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), num_points)
        return grid.pow(t_order).to(device)
    if skip_type == 'logSNR':
        schedule = self.noise_schedule
        lambda_T = schedule.marginal_lambda(torch.tensor(t_T).to(device))
        lambda_0 = schedule.marginal_lambda(torch.tensor(t_0).to(device))
        # Uniform grid in lambda, mapped back to time through the schedule's inverse.
        logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), num_points).to(device)
        return schedule.inverse_lambda(logSNR_steps)
    raise ValueError(
        "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
484
+
485
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
    """
    Get the order of each step for sampling by the singlestep DPM-Solver.

    We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
    Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
        - If order == 1:
            We take `steps` of DPM-Solver-1 (i.e. DDIM).
        - If order == 2:
            - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
            - If steps % 2 == 0, we use K steps of DPM-Solver-2.
            - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
        - If order == 3:
            - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
            - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
            - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
            - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.

    ============================================
    Args:
        order: A `int`. The max order for the solver (1, 2 or 3).
        steps: A `int`. The total number of function evaluations (NFE).
        skip_type: A `str`. The type for the spacing of the time steps. We support three types:
            - 'logSNR': uniform logSNR for the time steps.
            - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
            - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
        t_T: A `float`. The starting time of the sampling (default is T).
        t_0: A `float`. The ending time of the sampling (default is epsilon).
        device: A torch device.
    Returns:
        timesteps_outer: A pytorch tensor with one more entry than `orders`: the start/end times of each step.
        orders: A list of the solver order of each step.
    Raises:
        ValueError: if `order` is not 1, 2 or 3.
    """
    if order == 3:
        K = steps // 3 + 1
        if steps % 3 == 0:
            orders = [3, ] * (K - 2) + [2, 1]
        elif steps % 3 == 1:
            orders = [3, ] * (K - 1) + [1]
        else:
            orders = [3, ] * (K - 1) + [2]
    elif order == 2:
        if steps % 2 == 0:
            K = steps // 2
            orders = [2, ] * K
        else:
            K = steps // 2 + 1
            orders = [2, ] * (K - 1) + [1]
    elif order == 1:
        # K is the number of outer steps and must equal len(orders); with K = 1 the
        # 'logSNR' branch below returned only 2 time steps for `steps` first-order updates.
        K = steps
        orders = [1, ] * steps
    else:
        raise ValueError("'order' must be '1' or '2' or '3'.")
    if skip_type == 'logSNR':
        # To reproduce the results in DPM-Solver paper
        timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
    else:
        # Take the fine grid of `steps` intervals and keep the boundaries of each outer step.
        timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
            torch.cumsum(torch.tensor([0, ] + orders), 0).to(device)]
    return timesteps_outer, orders
544
+
545
def denoise_to_zero_fn(self, x, s):
    """
    Denoise at the final step by returning the data prediction at time `s`, which is equivalent to solving the ODE
    from lambda_s to infty by first-order discretization.
    """
    denoised = self.data_prediction_fn(x, s)
    return denoised
550
+
551
def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
    """
    DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        s: A pytorch tensor. The starting time, with the shape (1,).
        t: A pytorch tensor. The ending time, with the shape (1,).
        model_s: A pytorch tensor. The model function evaluated at time `s`.
            If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
        return_intermediate: A `bool`. If true, also return the model value at time `s`.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
            With `return_intermediate`, a tuple `(x_t, {'model_s': model_s})` is returned instead.
    """
    schedule = self.noise_schedule
    lambda_s, lambda_t = schedule.marginal_lambda(s), schedule.marginal_lambda(t)
    h = lambda_t - lambda_s
    log_alpha_s, log_alpha_t = schedule.marginal_log_mean_coeff(s), schedule.marginal_log_mean_coeff(t)
    sigma_s, sigma_t = schedule.marginal_std(s), schedule.marginal_std(t)
    alpha_t = torch.exp(log_alpha_t)

    # "dpmsolver++" updates with the data prediction, "dpmsolver" with the noise prediction.
    data_prediction_mode = self.algorithm_type == "dpmsolver++"
    phi_1 = torch.expm1(-h) if data_prediction_mode else torch.expm1(h)
    if model_s is None:
        model_s = self.model_fn(x, s)

    if data_prediction_mode:
        x_t = (
            sigma_t / sigma_s * x
            - alpha_t * phi_1 * model_s
        )
    else:
        x_t = (
            torch.exp(log_alpha_t - log_alpha_s) * x
            - (sigma_t * phi_1) * model_s
        )

    if return_intermediate:
        return x_t, {'model_s': model_s}
    return x_t
597
+
598
def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
                                        solver_type='dpmsolver'):
    """
    Singlestep solver DPM-Solver-2 from time `s` to time `t`.

    One intermediate time `s1` is placed at `lambda_s + r1 * h` (in the value
    returned by `marginal_lambda`), the model is evaluated there, and the two
    model values are combined into the update.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        s: A pytorch tensor. The starting time, with the shape (1,).
        t: A pytorch tensor. The ending time, with the shape (1,).
        r1: A `float`. The hyperparameter of the second-order solver. `None` means 0.5.
        model_s: A pytorch tensor. The model function evaluated at time `s`.
            If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
        return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1`
            (the intermediate time).
        solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
            If `return_intermediate` is true, `(x_t, {'model_s': ..., 'model_s1': ...})` is returned.
    Raises:
        ValueError: if `solver_type` is neither 'dpmsolver' nor 'taylor'.
    """
    if solver_type not in ['dpmsolver', 'taylor']:
        raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
    r1 = 0.5 if r1 is None else r1

    schedule = self.noise_schedule
    lambda_s = schedule.marginal_lambda(s)
    lambda_t = schedule.marginal_lambda(t)
    h = lambda_t - lambda_s
    s1 = schedule.inverse_lambda(lambda_s + r1 * h)

    log_alpha_s = schedule.marginal_log_mean_coeff(s)
    log_alpha_s1 = schedule.marginal_log_mean_coeff(s1)
    log_alpha_t = schedule.marginal_log_mean_coeff(t)
    sigma_s = schedule.marginal_std(s)
    sigma_s1 = schedule.marginal_std(s1)
    sigma_t = schedule.marginal_std(t)
    alpha_s1 = torch.exp(log_alpha_s1)
    alpha_t = torch.exp(log_alpha_t)

    use_dpmsolver_pp = self.algorithm_type == "dpmsolver++"
    use_taylor = solver_type == 'taylor'

    # The model value at `s` is needed by both algorithm types.
    if model_s is None:
        model_s = self.model_fn(x, s)

    if use_dpmsolver_pp:
        phi_11 = torch.expm1(-r1 * h)
        phi_1 = torch.expm1(-h)

        # First-order step to the intermediate time, then evaluate the model there.
        x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
        model_s1 = self.model_fn(x_s1, s1)

        first_order = (sigma_t / sigma_s) * x - (alpha_t * phi_1) * model_s
        if use_taylor:
            x_t = first_order + (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s)
        else:
            x_t = first_order - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
    else:
        phi_11 = torch.expm1(r1 * h)
        phi_1 = torch.expm1(h)

        # First-order step to the intermediate time, then evaluate the model there.
        x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s
        model_s1 = self.model_fn(x_s1, s1)

        first_order = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s
        if use_taylor:
            x_t = first_order - (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s)
        else:
            x_t = first_order - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)

    if not return_intermediate:
        return x_t
    return x_t, {'model_s': model_s, 'model_s1': model_s1}
680
+
681
def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
                                       return_intermediate=False, solver_type='dpmsolver'):
    """
    Singlestep solver DPM-Solver-3 from time `s` to time `t`.

    Two intermediate times `s1` and `s2` are placed at `lambda_s + r1 * h` and
    `lambda_s + r2 * h`; the model is evaluated at both (unless `model_s1` is
    supplied) and the three model values are combined into the update.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        s: A pytorch tensor. The starting time, with the shape (1,).
        t: A pytorch tensor. The ending time, with the shape (1,).
        r1: A `float`. The hyperparameter of the third-order solver. `None` means 1/3.
        r2: A `float`. The hyperparameter of the third-order solver. `None` means 2/3.
        model_s: A pytorch tensor. The model function evaluated at time `s`.
            If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
        model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
            If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
        return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2`
            (the intermediate times).
        solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
            If `return_intermediate` is true,
            `(x_t, {'model_s': ..., 'model_s1': ..., 'model_s2': ...})` is returned.
    Raises:
        ValueError: if `solver_type` is neither 'dpmsolver' nor 'taylor'.
    """
    if solver_type not in ['dpmsolver', 'taylor']:
        raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
    r1 = 1. / 3. if r1 is None else r1
    r2 = 2. / 3. if r2 is None else r2

    schedule = self.noise_schedule
    lambda_s = schedule.marginal_lambda(s)
    lambda_t = schedule.marginal_lambda(t)
    h = lambda_t - lambda_s
    s1 = schedule.inverse_lambda(lambda_s + r1 * h)
    s2 = schedule.inverse_lambda(lambda_s + r2 * h)

    log_alpha_s = schedule.marginal_log_mean_coeff(s)
    log_alpha_s1 = schedule.marginal_log_mean_coeff(s1)
    log_alpha_s2 = schedule.marginal_log_mean_coeff(s2)
    log_alpha_t = schedule.marginal_log_mean_coeff(t)
    sigma_s = schedule.marginal_std(s)
    sigma_s1 = schedule.marginal_std(s1)
    sigma_s2 = schedule.marginal_std(s2)
    sigma_t = schedule.marginal_std(t)
    alpha_s1 = torch.exp(log_alpha_s1)
    alpha_s2 = torch.exp(log_alpha_s2)
    alpha_t = torch.exp(log_alpha_t)

    use_dpmsolver_pp = self.algorithm_type == "dpmsolver++"
    use_taylor = solver_type == 'taylor'

    # The model value at `s` is needed by both algorithm types.
    if model_s is None:
        model_s = self.model_fn(x, s)

    if use_dpmsolver_pp:
        phi_11 = torch.expm1(-r1 * h)
        phi_12 = torch.expm1(-r2 * h)
        phi_1 = torch.expm1(-h)
        phi_22 = phi_12 / (r2 * h) + 1.
        phi_2 = phi_1 / h + 1.
        phi_3 = phi_2 / h - 0.5

        # Only step to `s1` when the caller did not already provide the model value there.
        if model_s1 is None:
            x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
            model_s1 = self.model_fn(x_s1, s1)
        x_s2 = (
            (sigma_s2 / sigma_s) * x
            - (alpha_s2 * phi_12) * model_s
            + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
        )
        model_s2 = self.model_fn(x_s2, s2)

        first_order = (sigma_t / sigma_s) * x - (alpha_t * phi_1) * model_s
        if use_taylor:
            slope_1 = (1. / r1) * (model_s1 - model_s)
            slope_2 = (1. / r2) * (model_s2 - model_s)
            d1 = (r2 * slope_1 - r1 * slope_2) / (r2 - r1)
            d2 = 2. * (slope_2 - slope_1) / (r2 - r1)
            x_t = first_order + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2
        else:
            x_t = first_order + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
    else:
        phi_11 = torch.expm1(r1 * h)
        phi_12 = torch.expm1(r2 * h)
        phi_1 = torch.expm1(h)
        phi_22 = phi_12 / (r2 * h) - 1.
        phi_2 = phi_1 / h - 1.
        phi_3 = phi_2 / h - 0.5

        # Only step to `s1` when the caller did not already provide the model value there.
        if model_s1 is None:
            x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s
            model_s1 = self.model_fn(x_s1, s1)
        x_s2 = (
            (torch.exp(log_alpha_s2 - log_alpha_s)) * x
            - (sigma_s2 * phi_12) * model_s
            - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
        )
        model_s2 = self.model_fn(x_s2, s2)

        first_order = (torch.exp(log_alpha_t - log_alpha_s)) * x - (sigma_t * phi_1) * model_s
        if use_taylor:
            slope_1 = (1. / r1) * (model_s1 - model_s)
            slope_2 = (1. / r2) * (model_s2 - model_s)
            d1 = (r2 * slope_1 - r1 * slope_2) / (r2 - r1)
            d2 = 2. * (slope_2 - slope_1) / (r2 - r1)
            x_t = first_order - (sigma_t * phi_2) * d1 - (sigma_t * phi_3) * d2
        else:
            x_t = first_order - (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s)

    if not return_intermediate:
        return x_t
    return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
804
+
805
def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
    """
    Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.

    Only the last two entries of `model_prev_list` and `t_prev_list` are read;
    no new model evaluation happens here.

    Args:
        x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
        model_prev_list: A list of pytorch tensor. The previous computed model values.
        t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
        t: A pytorch tensor. The ending time, with the shape (1,).
        solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    Raises:
        ValueError: if `solver_type` is neither 'dpmsolver' nor 'taylor'.
    """
    if solver_type not in ['dpmsolver', 'taylor']:
        raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))

    schedule = self.noise_schedule
    model_prev_1 = model_prev_list[-2]
    model_prev_0 = model_prev_list[-1]
    t_prev_1 = t_prev_list[-2]
    t_prev_0 = t_prev_list[-1]

    lambda_prev_1 = schedule.marginal_lambda(t_prev_1)
    lambda_prev_0 = schedule.marginal_lambda(t_prev_0)
    lambda_t = schedule.marginal_lambda(t)
    log_alpha_prev_0 = schedule.marginal_log_mean_coeff(t_prev_0)
    log_alpha_t = schedule.marginal_log_mean_coeff(t)
    sigma_prev_0 = schedule.marginal_std(t_prev_0)
    sigma_t = schedule.marginal_std(t)
    alpha_t = torch.exp(log_alpha_t)

    # Finite-difference estimate built from the two stored model values,
    # rescaled by the ratio of the previous step size to the current one.
    h_0 = lambda_prev_0 - lambda_prev_1
    h = lambda_t - lambda_prev_0
    r0 = h_0 / h
    D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)

    use_taylor = solver_type == 'taylor'
    if self.algorithm_type == "dpmsolver++":
        phi_1 = torch.expm1(-h)
        first_order = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0
        if use_taylor:
            return first_order + (alpha_t * (phi_1 / h + 1.)) * D1_0
        return first_order - 0.5 * (alpha_t * phi_1) * D1_0

    phi_1 = torch.expm1(h)
    first_order = (torch.exp(log_alpha_t - log_alpha_prev_0)) * x - (sigma_t * phi_1) * model_prev_0
    if use_taylor:
        return first_order - (sigma_t * (phi_1 / h - 1.)) * D1_0
    return first_order - 0.5 * (sigma_t * phi_1) * D1_0
863
+
864
def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'):
    """
    Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.

    Both history lists are unpacked into exactly three entries (oldest first),
    so they must each hold three items. No new model evaluation happens here.

    Args:
        x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
        model_prev_list: A list of pytorch tensor. The previous computed model values.
        t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
        t: A pytorch tensor. The ending time, with the shape (1,).
        solver_type: Accepted for signature parity with the other update methods;
            this method does not read it.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    """
    schedule = self.noise_schedule
    model_prev_2, model_prev_1, model_prev_0 = model_prev_list
    t_prev_2, t_prev_1, t_prev_0 = t_prev_list

    lambda_prev_2 = schedule.marginal_lambda(t_prev_2)
    lambda_prev_1 = schedule.marginal_lambda(t_prev_1)
    lambda_prev_0 = schedule.marginal_lambda(t_prev_0)
    lambda_t = schedule.marginal_lambda(t)
    log_alpha_prev_0 = schedule.marginal_log_mean_coeff(t_prev_0)
    log_alpha_t = schedule.marginal_log_mean_coeff(t)
    sigma_prev_0 = schedule.marginal_std(t_prev_0)
    sigma_t = schedule.marginal_std(t)
    alpha_t = torch.exp(log_alpha_t)

    # Step sizes of the two previous steps and of the current step.
    h_1 = lambda_prev_1 - lambda_prev_2
    h_0 = lambda_prev_0 - lambda_prev_1
    h = lambda_t - lambda_prev_0
    r0 = h_0 / h
    r1 = h_1 / h

    # Finite-difference estimates built from the three stored model values.
    D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
    D1_1 = (1. / r1) * (model_prev_1 - model_prev_2)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1. / (r0 + r1)) * (D1_0 - D1_1)

    if self.algorithm_type == "dpmsolver++":
        phi_1 = torch.expm1(-h)
        phi_2 = phi_1 / h + 1.
        phi_3 = phi_2 / h - 0.5
        return (
            (sigma_t / sigma_prev_0) * x
            - (alpha_t * phi_1) * model_prev_0
            + (alpha_t * phi_2) * D1
            - (alpha_t * phi_3) * D2
        )

    phi_1 = torch.expm1(h)
    phi_2 = phi_1 / h - 1.
    phi_3 = phi_2 / h - 0.5
    return (
        (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
        - (sigma_t * phi_1) * model_prev_0
        - (sigma_t * phi_2) * D1
        - (sigma_t * phi_3) * D2
    )
916
+
917
def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None,
                                 r2=None):
    """
    Singlestep DPM-Solver with the order `order` from time `s` to time `t`.

    Dispatches to the first-, second- or third-order singlestep update.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        s: A pytorch tensor. The starting time, with the shape (1,).
        t: A pytorch tensor. The ending time, with the shape (1,).
        order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
        return_intermediate: A `bool`. Forwarded to the selected update, which then also returns
            the model values it computed.
        solver_type: either 'dpmsolver' or 'taylor'. Forwarded to the second- and third-order updates only.
        r1: A `float`. The hyperparameter of the second-order or third-order solver.
        r2: A `float`. The hyperparameter of the third-order solver.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    Raises:
        ValueError: if `order` is not 1, 2 or 3.
    """
    if order == 1:
        # The first-order update takes neither `solver_type` nor r1/r2.
        return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
    if order == 2:
        return self.singlestep_dpm_solver_second_update(
            x, s, t,
            return_intermediate=return_intermediate,
            solver_type=solver_type,
            r1=r1,
        )
    if order == 3:
        return self.singlestep_dpm_solver_third_update(
            x, s, t,
            return_intermediate=return_intermediate,
            solver_type=solver_type,
            r1=r1,
            r2=r2,
        )
    raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
945
+
946
def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'):
    """
    Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.

    Dispatches to the first-order update (using the latest stored time and
    model value) or to the second- or third-order multistep update.

    Args:
        x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
        model_prev_list: A list of pytorch tensor. The previous computed model values.
        t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
        t: A pytorch tensor. The ending time, with the shape (1,).
        order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
        solver_type: either 'dpmsolver' or 'taylor'. Forwarded to the second- and third-order updates only.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    Raises:
        ValueError: if `order` is not 1, 2 or 3.
    """
    if order == 1:
        # Reuse the most recent model value instead of evaluating the model again.
        latest_time = t_prev_list[-1]
        latest_model = model_prev_list[-1]
        return self.dpm_solver_first_update(x, latest_time, t, model_s=latest_model)
    if order == 2:
        return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
    if order == 3:
        return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
    raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
969
+
970
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
                        solver_type='dpmsolver'):
    """
    The adaptive step size solver based on singlestep DPM-Solver.

    Each iteration proposes a step of size `h` (in the value returned by
    `marginal_lambda`), computes a lower-order and a higher-order estimate, and
    accepts the higher-order estimate when the scaled difference `E` is at
    most 1. The step size is rescaled after every iteration, accepted or not.

    Args:
        x: A pytorch tensor. The initial value at time `t_T`.
        order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
        t_T: A `float`. The starting time of the sampling (default is T).
        t_0: A `float`. The ending time of the sampling (default is epsilon).
        h_init: A `float`. The initial step size (for logSNR).
        atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
        rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
        theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
        t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
            current time and `t_0` is less than `t_err`. The default setting is 1e-5.
        solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
    Returns:
        x_0: A pytorch tensor. The approximated solution at time `t_0`.
    Raises:
        ValueError: if `order` is neither 2 nor 3.

    [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
    """
    schedule = self.noise_schedule
    s = t_T * torch.ones((1,)).to(x)
    lambda_s = schedule.marginal_lambda(s)
    lambda_0 = schedule.marginal_lambda(t_0 * torch.ones_like(s).to(x))
    h = h_init * torch.ones_like(s).to(x)
    x_prev = x
    nfe = 0

    if order == 2:
        r1 = 0.5

        def lower_update(x, s, t):
            return self.dpm_solver_first_update(x, s, t, return_intermediate=True)

        def higher_update(x, s, t, **kwargs):
            return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
    elif order == 3:
        r1, r2 = 1. / 3., 2. / 3.

        def lower_update(x, s, t):
            return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True,
                                                            solver_type=solver_type)

        def higher_update(x, s, t, **kwargs):
            return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type,
                                                           **kwargs)
    else:
        raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))

    def norm_fn(v):
        # Root-mean-square over everything but the first axis.
        flat = v.reshape((v.shape[0], -1))
        return torch.sqrt(torch.square(flat).mean(dim=-1, keepdim=True))

    while torch.abs((s - t_0)).mean() > t_err:
        t = schedule.inverse_lambda(lambda_s + h)
        # The lower-order update returns its model evaluations, which the
        # higher-order update receives as keyword arguments.
        x_lower, lower_noise_kwargs = lower_update(x, s, t)
        x_higher = higher_update(x, s, t, **lower_noise_kwargs)
        delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
        E = norm_fn((x_higher - x_lower) / delta).max()
        if torch.all(E <= 1.):
            # Accept the step.
            x = x_higher
            s = t
            x_prev = x_lower
            lambda_s = schedule.marginal_lambda(s)
        # Rescale the step size, capped by the remaining distance to lambda_0.
        h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
        nfe += order
    print('adaptive solver nfe', nfe)
    return x
1032
+
1033
def add_noise(self, x, t, noise=None):
    """
    Compute the noised input xt = alpha_t * x + sigma_t * noise.

    Args:
        x: A `torch.Tensor` with shape `(batch_size, *shape)`.
        t: A `torch.Tensor` with shape `(t_size,)`.
        noise: Optional noise tensor. When `None`, it is drawn with `torch.randn`
            with shape `(t_size, *x.shape)` on `x.device`.
    Returns:
        xt with shape `(t_size, batch_size, *shape)`; the leading axis is squeezed
        away when `t_size == 1`.
    """
    schedule = self.noise_schedule
    alpha_t = schedule.marginal_alpha(t)
    sigma_t = schedule.marginal_std(t)

    if noise is None:
        noise = torch.randn((t.shape[0], *x.shape), device=x.device)

    # Prepend a singleton axis so `x` broadcasts against the per-time coefficients.
    x = x.reshape((-1, *x.shape))
    n_dims = x.dim()
    xt = expand_dims(alpha_t, n_dims) * x + expand_dims(sigma_t, n_dims) * noise

    return xt.squeeze(0) if t.shape[0] == 1 else xt
1052
+
1053
def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
            method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
            atol=0.0078, rtol=0.05, return_intermediate=False,
            ):
    """
    Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver.
    For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.

    This is a thin wrapper around `self.sample`: `t_start` defaults to
    `1 / self.noise_schedule.total_N`, `t_end` defaults to `self.noise_schedule.T`,
    and every other argument is forwarded unchanged.

    Returns:
        Whatever `self.sample` returns for the forwarded arguments.
    """
    if t_start is None:
        t_0 = 1. / self.noise_schedule.total_N
    else:
        t_0 = t_start
    if t_end is None:
        t_T = self.noise_schedule.T
    else:
        t_T = t_end
    assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"

    forwarded = dict(
        steps=steps,
        t_start=t_0,
        t_end=t_T,
        order=order,
        skip_type=skip_type,
        method=method,
        lower_order_final=lower_order_final,
        denoise_to_zero=denoise_to_zero,
        solver_type=solver_type,
        atol=atol,
        rtol=rtol,
        return_intermediate=return_intermediate,
    )
    return self.sample(x, **forwarded)
1068
+
1069
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
1070
+ method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
1071
+ atol=0.0078, rtol=0.05, return_intermediate=False,
1072
+ ):
1073
+ """
1074
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
1075
+
1076
+ =====================================================
1077
+
1078
+ We support the following algorithms for both noise prediction model and data prediction model:
1079
+ - 'singlestep':
1080
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
1081
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
1082
+ The total number of function evaluations (NFE) == `steps`.
1083
+ Given a fixed NFE == `steps`, the sampling procedure is:
1084
+ - If `order` == 1:
1085
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
1086
+ - If `order` == 2:
1087
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
1088
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
1089
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1090
+ - If `order` == 3:
1091
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
1092
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1093
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
1094
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
1095
+ - 'multistep':
1096
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
1097
+ We initialize the first `order` values by lower order multistep solvers.
1098
+ Given a fixed NFE == `steps`, the sampling procedure is:
1099
+ Denote K = steps.
1100
+ - If `order` == 1:
1101
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
1102
+ - If `order` == 2:
1103
+ - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
1104
+ - If `order` == 3:
1105
+ - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
1106
+ - 'singlestep_fixed':
1107
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
1108
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
1109
+ - 'adaptive':
1110
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
1111
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
1112
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
1113
+ (NFE) and the sample quality.
1114
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
1115
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
1116
+
1117
+ =====================================================
1118
+
1119
+ Some advices for choosing the algorithm:
1120
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1121
+ Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
1122
+ e.g., DPM-Solver:
1123
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
1124
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1125
+ skip_type='time_uniform', method='singlestep')
1126
+ e.g., DPM-Solver++:
1127
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1128
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1129
+ skip_type='time_uniform', method='singlestep')
1130
+ - For **guided sampling with large guidance scale** by DPMs:
1131
+ Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
1132
+ e.g.
1133
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1134
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1135
+ skip_type='time_uniform', method='multistep')
1136
+
1137
+ We support three types of `skip_type`:
1138
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
1139
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
1140
+ - 'time_quadratic': quadratic time for the time steps.
1141
+
1142
+ =====================================================
1143
+ Args:
1144
+ x: A pytorch tensor. The initial value at time `t_start`
1145
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1146
+ steps: A `int`. The total number of function evaluations (NFE).
1147
+ t_start: A `float`. The starting time of the sampling.
1148
+ If `T` is None, we use self.noise_schedule.T (default is 1.0).
1149
+ t_end: A `float`. The ending time of the sampling.
1150
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1151
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1152
+ For discrete-time DPMs:
1153
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1154
+ For continuous-time DPMs:
1155
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1156
+ order: A `int`. The order of DPM-Solver.
1157
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1158
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1159
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1160
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1161
+
1162
+ This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1163
+ score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
1164
+ for diffusion models sampling by diffusion SDEs for low-resolutional images
1165
+ (such as CIFAR-10). However, we observed that such trick does not matter for
1166
+ high-resolutional images. As it needs an additional NFE, we do not recommend
1167
+ it for high-resolutional images.
1168
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1169
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
1170
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1171
+ (especially for steps <= 10). So we recommend to set it to be `True`.
1172
+ solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
1173
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1174
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1175
+ return_intermediate: A `bool`. Whether to save the xt at each step.
1176
+ When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0.
1177
+ Returns:
1178
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1179
+
1180
+ """
1181
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1182
+ t_T = self.noise_schedule.T if t_start is None else t_start
1183
+ assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1184
+ if return_intermediate:
1185
+ assert method in ['multistep', 'singlestep',
1186
+ 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values"
1187
+ if self.correcting_xt_fn is not None:
1188
+ assert method in ['multistep', 'singlestep',
1189
+ 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None"
1190
+ device = x.device
1191
+ intermediates = []
1192
+ with torch.no_grad():
1193
+ if method == 'adaptive':
1194
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
1195
+ solver_type=solver_type)
1196
+ elif method == 'multistep':
1197
+ assert steps >= order
1198
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1199
+ assert timesteps.shape[0] - 1 == steps
1200
+ # Init the initial values.
1201
+ step = 0
1202
+ t = timesteps[step]
1203
+ t_prev_list = [t]
1204
+ model_prev_list = [self.model_fn(x, t)]
1205
+ if self.correcting_xt_fn is not None:
1206
+ x = self.correcting_xt_fn(x, t, step)
1207
+ if return_intermediate:
1208
+ intermediates.append(x)
1209
+ # Init the first `order` values by lower order multistep DPM-Solver.
1210
+ for step in range(1, order):
1211
+ t = timesteps[step]
1212
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step,
1213
+ solver_type=solver_type)
1214
+ if self.correcting_xt_fn is not None:
1215
+ x = self.correcting_xt_fn(x, t, step)
1216
+ if return_intermediate:
1217
+ intermediates.append(x)
1218
+ t_prev_list.append(t)
1219
+ model_prev_list.append(self.model_fn(x, t))
1220
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1221
+ for step in tqdm(range(order, steps + 1)):
1222
+ t = timesteps[step]
1223
+ # We only use lower order for steps < 10
1224
+ # if lower_order_final and steps < 10:
1225
+ if lower_order_final: # recommended by Shuchen Xue
1226
+ step_order = min(order, steps + 1 - step)
1227
+ else:
1228
+ step_order = order
1229
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order,
1230
+ solver_type=solver_type)
1231
+ if self.correcting_xt_fn is not None:
1232
+ x = self.correcting_xt_fn(x, t, step)
1233
+ if return_intermediate:
1234
+ intermediates.append(x)
1235
+ for i in range(order - 1):
1236
+ t_prev_list[i] = t_prev_list[i + 1]
1237
+ model_prev_list[i] = model_prev_list[i + 1]
1238
+ t_prev_list[-1] = t
1239
+ # We do not need to evaluate the final model value.
1240
+ if step < steps:
1241
+ model_prev_list[-1] = self.model_fn(x, t)
1242
+ elif method in ['singlestep', 'singlestep_fixed']:
1243
+ if method == 'singlestep':
1244
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps,
1245
+ order=order,
1246
+ skip_type=skip_type,
1247
+ t_T=t_T, t_0=t_0,
1248
+ device=device)
1249
+ elif method == 'singlestep_fixed':
1250
+ K = steps // order
1251
+ orders = [order, ] * K
1252
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1253
+ for step, order in enumerate(orders):
1254
+ s, t = timesteps_outer[step], timesteps_outer[step + 1]
1255
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order,
1256
+ device=device)
1257
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1258
+ h = lambda_inner[-1] - lambda_inner[0]
1259
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1260
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1261
+ x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
1262
+ if self.correcting_xt_fn is not None:
1263
+ x = self.correcting_xt_fn(x, t, step)
1264
+ if return_intermediate:
1265
+ intermediates.append(x)
1266
+ else:
1267
+ raise ValueError("Got wrong method {}".format(method))
1268
+ if denoise_to_zero:
1269
+ t = torch.ones((1,)).to(device) * t_0
1270
+ x = self.denoise_to_zero_fn(x, t)
1271
+ if self.correcting_xt_fn is not None:
1272
+ x = self.correcting_xt_fn(x, t, step + 1)
1273
+ if return_intermediate:
1274
+ intermediates.append(x)
1275
+ if return_intermediate:
1276
+ return x, intermediates
1277
+ else:
1278
+ return x
1279
+
1280
+
1281
+ #############################################################
1282
+ # other utility functions
1283
+ #############################################################
1284
+
1285
def interpolate_fn(x, xp, yp):
    """
    Evaluate the piecewise-linear function through the keypoints (xp, yp) at x.

    Only tensor ops (cat / sort / gather / where) are used, so the result stays
    usable with autograd. Queries that fall outside the keypoint range are
    evaluated on the straight line through the two outermost keypoints on that side.

    Args:
        x: PyTorch tensor with shape [N, C] (N = batch size, C = number of channels).
        xp: PyTorch tensor with shape [C, K] holding the K keypoint abscissas.
            The code indexes keypoint K - 2, so K >= 2 is assumed; it also treats
            sorted positions as keypoint positions, which presumably requires each
            row of xp to be ascending -- confirm against callers.
        yp: PyTorch tensor with shape [C, K] holding the keypoint ordinates.
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_keys = x.shape[0], xp.shape[1]
    device = x.device

    # Put each query in slot 0 in front of its channel's keypoints and sort the lot.
    merged = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    merged_sorted, permutation = torch.sort(merged, dim=2)
    # The query had index 0 before sorting, so the argmin of the permutation is its rank.
    rank = torch.argmin(permutation, dim=2)
    left = rank - 1
    below_all = torch.eq(rank, 0)
    above_all = torch.eq(rank, num_keys)

    # Segment start for queries that are not below every keypoint.
    inner_or_top = torch.where(above_all, torch.tensor(num_keys - 2, device=device), left)

    # Positions of the bracketing x values inside the merged, sorted tensor
    # (which still contains the query itself, hence the +2 / +1 juggling).
    lo_sorted = torch.where(below_all, torch.tensor(1, device=device), inner_or_top)
    hi_sorted = torch.where(torch.eq(lo_sorted, left), lo_sorted + 2, lo_sorted + 1)
    start_x = torch.gather(merged_sorted, dim=2, index=lo_sorted.unsqueeze(2)).squeeze(2)
    end_x = torch.gather(merged_sorted, dim=2, index=hi_sorted.unsqueeze(2)).squeeze(2)

    # Positions of the same segment inside yp (no query inserted there).
    lo_key = torch.where(below_all, torch.tensor(0, device=device), inner_or_top)
    yp_batched = yp.unsqueeze(0).expand(batch, -1, -1)
    start_y = torch.gather(yp_batched, dim=2, index=lo_key.unsqueeze(2)).squeeze(2)
    end_y = torch.gather(yp_batched, dim=2, index=(lo_key + 1).unsqueeze(2)).squeeze(2)

    return start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1325
+
1326
+
1327
def expand_dims(v, dims):
    """
    Append singleton axes to `v` until it has `dims` dimensions in total.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`, the total number of dimensions wanted.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    # Indexing with None inserts a new axis; Ellipsis keeps the existing ones in front.
    trailing_axes = (None,) * (dims - 1)
    return v[(...,) + trailing_axes]
diffusion/model/edm_sample.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+
5
+ from diffusion.model.utils import *
6
+
7
+
8
+ # ----------------------------------------------------------------------------
9
+ # Proposed EDM sampler (Algorithm 2).
10
+
11
def edm_sampler(
        net, latents, class_labels=None, cfg_scale=None, randn_like=torch.randn_like,
        num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
        S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, **kwargs
):
    """
    Proposed EDM sampler (Algorithm 2): stochastic sampler with Heun's 2nd order correction.

    Args:
        net: denoiser. Must expose `sigma_min`, `sigma_max` and `round_sigma(sigma)`, and be
            callable as `net(x, sigma, class_labels, cfg_scale, **kwargs)` returning a mapping
            whose 'x' entry is the denoised tensor.
        latents: initial noise tensor; it is scaled by the first noise level.
        class_labels, cfg_scale: forwarded unchanged to `net`.
        randn_like: noise generator called as `randn_like(x)`.
        num_steps: number of noise levels before the final 0 level.
        sigma_min, sigma_max: requested noise range, clamped to what `net` supports.
        rho: exponent of the time step discretization.
        S_churn, S_min, S_max, S_noise: stochasticity settings; noise is only added
            while the current level lies in [S_min, S_max].
        **kwargs: extra keyword arguments forwarded to `net`.
    Returns:
        The final sample as a float64 tensor.
    """
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    # With a single step the original denominator (num_steps - 1) is 0 and 0 / 0 turned
    # every noise level into NaN; clamp it so one step simply starts at sigma_max.
    last_index = max(num_steps - 1, 1)
    t_steps = (sigma_max ** (1 / rho) + step_indices / last_index * (
            sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])  # t_N = 0

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    for i, (t_cur, t_next) in tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:])))):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = net(x_hat.float(), t_hat, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction (skipped on the last step, where t_next == 0
        # would make the slope at the end point a division by zero).
        if i < num_steps - 1:
            denoised = net(x_next.float(), t_next, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next
48
+
49
+
50
+ # ----------------------------------------------------------------------------
51
+ # Generalized ablation sampler, representing the superset of all sampling
52
+ # methods discussed in the paper.
53
+
54
def ablation_sampler(
        net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like,
        num_steps=18, sigma_min=None, sigma_max=None, rho=7,
        solver='heun', discretization='edm', schedule='linear', scaling='none',
        epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
        S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    """
    Generalized ablation sampler, representing the superset of all sampling
    methods discussed in the paper.

    Args:
        net: denoiser. Must expose `sigma_min`, `sigma_max` and `round_sigma(sigma)`, and be
            callable as `net(x, sigma, class_labels, cfg_scale, feat=feat)` returning a mapping
            whose 'x' entry is the denoised tensor.
        latents: initial noise tensor.
        class_labels, cfg_scale, feat: forwarded unchanged to `net`.
        randn_like: noise generator called as `randn_like(x)`.
        num_steps: number of noise levels before the final t = 0.
        sigma_min, sigma_max: noise range; when None a default is chosen per
            `discretization`, then clamped to what `net` supports.
        rho: exponent used by the 'edm' discretization.
        solver: 'euler' or 'heun'.
        discretization: 'vp', 've', 'iddpm' or 'edm' -- how noise levels are spaced.
        schedule: 'vp', 've' or 'linear' -- the sigma(t) noise schedule.
        scaling: 'vp' or 'none' -- the s(t) signal scaling.
        epsilon_s, C_1, C_2, M: constants of the 'vp' and 'iddpm' discretizations.
        alpha: position of the intermediate point used by the 2nd order correction.
        S_churn, S_min, S_max, S_noise: stochasticity settings.
    Returns:
        The final sample as a float64 tensor.
    """
    assert solver in ['euler', 'heun']
    assert discretization in ['vp', 've', 'iddpm', 'edm']
    assert schedule in ['vp', 've', 'linear']
    assert scaling in ['vp', 'none']

    # Helper functions for VP & VE noise level schedules.
    vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
    # NOTE: `sigma` below is the schedule chosen further down; it is looked up when the lambda is called.
    vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
    vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (
            sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
    ve_sigma = lambda t: t.sqrt()
    ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
    ve_sigma_inv = lambda sigma: sigma ** 2

    # Select default noise level range based on the specified time step discretization.
    if sigma_min is None:
        vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
        sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
    if sigma_max is None:
        vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1)
        sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Compute corresponding betas for VP.
    vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
    vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d

    # Define time steps in terms of noise level.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    # With a single step the original denominator (num_steps - 1) is 0 and 0 / 0 turned
    # every noise level into NaN; clamp it so one step simply starts at the largest level.
    last_index = max(num_steps - 1, 1)
    if discretization == 'vp':
        orig_t_steps = 1 + step_indices / last_index * (epsilon_s - 1)
        sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
    elif discretization == 've':
        orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / last_index))
        sigma_steps = ve_sigma(orig_t_steps)
    elif discretization == 'iddpm':
        u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
        alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
        for j in torch.arange(M, 0, -1, device=latents.device):  # M, ..., 1
            u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
        u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
        sigma_steps = u_filtered[((len(u_filtered) - 1) / last_index * step_indices).round().to(torch.int64)]
    else:
        assert discretization == 'edm'
        sigma_steps = (sigma_max ** (1 / rho) + step_indices / last_index * (
                sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

    # Define noise level schedule.
    if schedule == 'vp':
        sigma = vp_sigma(vp_beta_d, vp_beta_min)
        sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
        sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
    elif schedule == 've':
        sigma = ve_sigma
        sigma_deriv = ve_sigma_deriv
        sigma_inv = ve_sigma_inv
    else:
        assert schedule == 'linear'
        sigma = lambda t: t
        sigma_deriv = lambda t: 1
        sigma_inv = lambda sigma: sigma

    # Define scaling schedule.
    if scaling == 'vp':
        s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
        s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
    else:
        assert scaling == 'none'
        s = lambda t: 1
        s_deriv = lambda t: 0

    # Compute final time steps based on the corresponding noise levels.
    t_steps = sigma_inv(net.round_sigma(sigma_steps))
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0

    # Main sampling loop.
    t_next = t_steps[0]
    x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
        t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
        x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(
            t_hat) * S_noise * randn_like(x_cur)

        # Euler step.
        h = t_next - t_hat
        denoised = net(x_hat.float() / s(t_hat), sigma(t_hat), class_labels, cfg_scale, feat=feat)['x'].to(
            torch.float64)
        d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(
            t_hat) / sigma(t_hat) * denoised
        x_prime = x_hat + alpha * h * d_cur
        t_prime = t_hat + alpha * h

        # Apply 2nd order correction (never on the last step, which lands on t = 0).
        if solver == 'euler' or i == num_steps - 1:
            x_next = x_hat + h * d_cur
        else:
            assert solver == 'heun'
            denoised = net(x_prime.float() / s(t_prime), sigma(t_prime), class_labels, cfg_scale, feat=feat)['x'].to(
                torch.float64)
            d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(
                t_prime) * s(t_prime) / sigma(t_prime) * denoised
            x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)

    return x_next
diffusion/model/gaussian_diffusion.py ADDED
@@ -0,0 +1,1041 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import enum
8
+ import math
9
+
10
+ import numpy as np
11
+ import torch as th
12
+ import torch.nn.functional as F
13
+
14
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
15
+
16
+
17
def mean_flat(tensor):
    """
    Average a tensor over every axis except the first (batch) one.
    """
    non_batch_axes = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_axes)
22
+
23
+
24
class ModelMeanType(enum.Enum):
    """
    The quantity the model's output is taken to be.
    """

    # the model predicts x_{t-1}
    PREVIOUS_X = enum.auto()
    # the model predicts x_0
    START_X = enum.auto()
    # the model predicts epsilon
    EPSILON = enum.auto()
32
+
33
+
34
class ModelVarType(enum.Enum):
    """
    How the variance of the model's output distribution is obtained.

    LEARNED_RANGE lets the model predict a value between FIXED_SMALL and
    FIXED_LARGE, which makes its job easier.
    """

    LEARNED = enum.auto()
    FIXED_SMALL = enum.auto()
    FIXED_LARGE = enum.auto()
    LEARNED_RANGE = enum.auto()
45
+
46
+
47
class LossType(enum.Enum):
    """
    The training objective.
    """

    # use raw MSE loss (and KL when learning variances)
    MSE = enum.auto()
    # use raw MSE loss (with RESCALED_KL when learning variances)
    RESCALED_MSE = enum.auto()
    # use the variational lower-bound
    KL = enum.auto()
    # like KL, but rescale to estimate the full VLB
    RESCALED_KL = enum.auto()

    def is_vb(self):
        """Return True for the two variational-bound objectives (KL and RESCALED_KL)."""
        return self in (LossType.KL, LossType.RESCALED_KL)
57
+
58
+
59
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
60
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
61
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
62
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
63
+ return betas
64
+
65
+
66
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """
    Deprecated API for creating beta schedules.
    See get_named_beta_schedule() for the new library of schedules.

    :param beta_schedule: one of "quad", "linear", "warmup10", "warmup50",
                          "const" or "jsd".
    :param beta_start: first beta of the schedule (unused by "const" and "jsd").
    :param beta_end: last beta of the schedule (unused by "jsd").
    :param num_diffusion_timesteps: the number of betas to produce.
    :return: a 1-D float64 numpy array with num_diffusion_timesteps entries.
    """
    steps = num_diffusion_timesteps
    if beta_schedule == "quad":
        # linear in sqrt(beta), then squared
        roots = np.linspace(beta_start ** 0.5, beta_end ** 0.5, steps, dtype=np.float64)
        betas = roots ** 2
    elif beta_schedule == "linear":
        betas = np.linspace(beta_start, beta_end, steps, dtype=np.float64)
    elif beta_schedule == "warmup10":
        betas = _warmup_beta(beta_start, beta_end, steps, 0.1)
    elif beta_schedule == "warmup50":
        betas = _warmup_beta(beta_start, beta_end, steps, 0.5)
    elif beta_schedule == "const":
        betas = beta_end * np.ones(steps, dtype=np.float64)
    elif beta_schedule == "jsd":
        # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1.0 / np.linspace(steps, 1, steps, dtype=np.float64)
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (steps,)
    return betas
97
+
98
+
99
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Look up one of the pre-defined beta schedules by name.

    The schedules in this library stay similar in the limit of
    num_diffusion_timesteps. New schedules may be added, but existing ones
    should not be removed or changed once committed, to maintain backwards
    compatibility.

    :param schedule_name: "linear" or "squaredcos_cap_v2".
    :param num_diffusion_timesteps: the number of betas to produce.
    :return: a 1-D numpy array of betas.
    """
    if schedule_name == "squaredcos_cap_v2":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    if schedule_name != "linear":
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

    # Linear schedule from Ho et al, extended to work for any number of
    # diffusion steps.
    scale = 1000 / num_diffusion_timesteps
    return get_beta_schedule(
        "linear",
        beta_start=scale * 0.0001,
        beta_end=scale * 0.02,
        num_diffusion_timesteps=num_diffusion_timesteps,
    )
124
+
125
+
126
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a continuous alpha_t_bar function into a beta schedule.

    alpha_bar defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    :return: a numpy array with one beta per timestep.
    """
    total = num_diffusion_timesteps
    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta
    return np.array([
        min(1 - alpha_bar((k + 1) / total) / alpha_bar(k / total), max_beta)
        for k in range(total)
    ])
143
+
144
+
145
+ class GaussianDiffusion:
146
+ """
147
+ Utilities for training and sampling diffusion models.
148
+ Original ported from this codebase:
149
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
150
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
151
+ starting at T and going to 1.
152
+ """
153
+
154
def __init__(
        self,
        *,
        betas,
        model_mean_type,
        model_var_type,
        loss_type,
        snr=False,
        return_startx=False,
):
    """
    Store the configuration and precompute every per-timestep coefficient
    derived from the beta schedule.

    :param betas: a 1-D sequence of betas, each in (0, 1].
    :param model_mean_type: stored as-is on self.model_mean_type.
    :param model_var_type: stored as-is on self.model_var_type.
    :param loss_type: stored as-is on self.loss_type.
    :param snr: stored as-is on self.snr.
    :param return_startx: stored as-is on self.return_startx.
    """
    self.model_mean_type = model_mean_type
    self.model_var_type = model_var_type
    self.loss_type = loss_type
    self.snr = snr
    self.return_startx = return_startx

    # Use float64 for accuracy.
    betas = np.array(betas, dtype=np.float64)
    self.betas = betas
    assert len(betas.shape) == 1, "betas must be 1-D"
    assert (betas > 0).all() and (betas <= 1).all()

    self.num_timesteps = int(betas.shape[0])

    alphas = 1.0 - betas
    self.alphas_cumprod = np.cumprod(alphas, axis=0)
    # (An `if False:` SNR-shift experiment used to sit here; it could never run,
    # hard-coded 1000 timesteps and left unused locals behind, so it was removed.)

    self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
    self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
    assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

    # calculations for diffusion q(x_t | x_{t-1}) and others
    self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
    self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
    self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
    self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
    self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

    # calculations for posterior q(x_{t-1} | x_t, x_0)
    self.posterior_variance = (
        betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
    )
    # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
    # (entry 0 is replaced by entry 1; with a single timestep there is nothing to copy, so it is left empty)
    self.posterior_log_variance_clipped = np.log(
        np.append(self.posterior_variance[1], self.posterior_variance[1:])
    ) if len(self.posterior_variance) > 1 else np.array([])

    self.posterior_mean_coef1 = (
        betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
    )
    self.posterior_mean_coef2 = (
        (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
    )
228
+
229
def q_mean_variance(self, x_start, t):
    """
    Return the parameters of the forward distribution q(x_t | x_0).

    :param x_start: the [N x C x ...] tensor of noiseless inputs.
    :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
    :return: A tuple (mean, variance, log_variance), all of x_start's shape.
    """
    shape = x_start.shape
    signal_scale = _extract_into_tensor(self.sqrt_alphas_cumprod, t, shape)
    mean = signal_scale * x_start
    variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, shape)
    log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, shape)
    return mean, variance, log_variance
240
+
241
def q_sample(self, x_start, t, noise=None):
    """
    Sample from q(x_t | x_0), i.e. diffuse the data for a given number of steps.

    :param x_start: the initial data batch.
    :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
    :param noise: if specified, the split-out normal noise; otherwise fresh
                  standard normal noise of x_start's shape is drawn.
    :return: A noisy version of x_start.
    """
    if noise is None:
        noise = th.randn_like(x_start)
    assert noise.shape == x_start.shape
    signal_scale = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return signal_scale * x_start + noise_scale * noise
257
+
258
def q_posterior_mean_variance(self, x_start, x_t, t):
    """
    Return mean, variance and clipped log-variance of the diffusion posterior
        q(x_{t-1} | x_t, x_0)

    :param x_start: the x_0 tensor; must have the same shape as x_t.
    :param x_t: the noisy tensor at step t.
    :param t: timestep indices used to look up the per-step coefficients.
    :return: A tuple (posterior_mean, posterior_variance, posterior_log_variance_clipped).
    """
    assert x_start.shape == x_t.shape
    shape = x_t.shape
    weight_x0 = _extract_into_tensor(self.posterior_mean_coef1, t, shape)
    weight_xt = _extract_into_tensor(self.posterior_mean_coef2, t, shape)
    posterior_mean = weight_x0 * x_start + weight_xt * x_t
    posterior_variance = _extract_into_tensor(self.posterior_variance, t, shape)
    posterior_log_variance_clipped = _extract_into_tensor(
        self.posterior_log_variance_clipped, t, shape
    )
    # All outputs must agree on the batch dimension.
    batch = x_start.shape[0]
    assert (
        posterior_mean.shape[0]
        == posterior_variance.shape[0]
        == posterior_log_variance_clipped.shape[0]
        == batch
    )
    return posterior_mean, posterior_variance, posterior_log_variance_clipped
279
+
280
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
    """
    Run the model to obtain p(x_{t-1} | x_t) together with a prediction of
    the initial x, x_0.

    :param model: the model, which takes a signal and a batch of timesteps
                  as input. It may return a tensor or a (tensor, extra) tuple.
    :param x: the [N x C x ...] tensor at time t.
    :param t: a 1-D Tensor of timesteps.
    :param clip_denoised: if True, clip the denoised signal into [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
        x_start prediction before it is used to sample. Applies before
        clip_denoised.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :return: a dict with the following keys:
             - 'mean': the model mean output.
             - 'variance': the model variance output.
             - 'log_variance': the log of 'variance'.
             - 'pred_xstart': the prediction for x_0.
             - 'extra': the second element of the model's tuple output, or None.
    """
    kwargs = {} if model_kwargs is None else model_kwargs

    batch, channels = x.shape[:2]
    assert t.shape == (batch,)
    model_output = model(x, t, **kwargs)
    extra = None
    if isinstance(model_output, tuple):
        model_output, extra = model_output

    if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
        # The model emits 2*C channels: the mean-related part and the variance part.
        assert model_output.shape == (batch, channels * 2, *x.shape[2:])
        model_output, var_values = th.split(model_output, channels, dim=1)
        log_lo = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
        log_hi = _extract_into_tensor(np.log(self.betas), t, x.shape)
        # var_values is [-1, 1] for [min_var, max_var]; interpolate in log space.
        weight = (var_values + 1) / 2
        model_log_variance = weight * log_hi + (1 - weight) * log_lo
        model_variance = th.exp(model_log_variance)
    elif self.model_var_type in [ModelVarType.FIXED_LARGE, ModelVarType.FIXED_SMALL]:
        fixed_tables = {
            # for fixedlarge, we set the initial (log-)variance like so
            # to get a better decoder log likelihood.
            ModelVarType.FIXED_LARGE: (
                np.append(self.posterior_variance[1], self.betas[1:]),
                np.log(np.append(self.posterior_variance[1], self.betas[1:])),
            ),
            ModelVarType.FIXED_SMALL: (
                self.posterior_variance,
                self.posterior_log_variance_clipped,
            ),
        }
        variance_table, log_variance_table = fixed_tables[self.model_var_type]
        model_variance = _extract_into_tensor(variance_table, t, x.shape)
        model_log_variance = _extract_into_tensor(log_variance_table, t, x.shape)
    else:
        model_variance = th.zeros_like(model_output)
        model_log_variance = th.zeros_like(model_output)

    # x_0 estimate: taken directly from the model for START_X, otherwise the
    # output is treated as epsilon and converted.
    if self.model_mean_type == ModelMeanType.START_X:
        pred_xstart = model_output
    else:
        pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
    if denoised_fn is not None:
        pred_xstart = denoised_fn(pred_xstart)
    if clip_denoised:
        pred_xstart = pred_xstart.clamp(-1, 1)

    model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)

    assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
    return {
        "mean": model_mean,
        "variance": model_variance,
        "log_variance": model_log_variance,
        "pred_xstart": pred_xstart,
        "extra": extra,
    }
362
+
363
def _predict_xstart_from_eps(self, x_t, t, eps):
    """
    Recover the x_0 estimate from x_t and a noise prediction eps:
        x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps
    """
    assert x_t.shape == eps.shape
    recip = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
    recipm1 = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return recip * x_t - recipm1 * eps
369
+
370
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
    """
    Recover the noise estimate from x_t and an x_0 prediction:
        eps = (sqrt(1 / alpha_bar_t) * x_t - x_0) / sqrt(1 / alpha_bar_t - 1)
    """
    recip = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
    numerator = recip * x_t - pred_xstart
    recipm1 = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return numerator / recipm1
374
+
375
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute the mean for the previous step, given a function cond_fn that
    computes the gradient of a conditional log probability with respect to
    x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
    condition on y.
    This uses the conditioning strategy from Sohl-Dickstein et al. (2015).

    :param cond_fn: called as cond_fn(x, t, **model_kwargs); returns the gradient.
    :param p_mean_var: a dict providing at least 'mean' and 'variance'.
    :param x: the tensor passed to cond_fn.
    :param t: the timesteps passed to cond_fn.
    :param model_kwargs: if not None, a dict of extra keyword arguments for cond_fn.
    :return: the shifted mean, mean + variance * gradient, in float.
    """
    # The default None cannot be unpacked with ** (TypeError), so treat it as "no extras".
    if model_kwargs is None:
        model_kwargs = {}
    gradient = cond_fn(x, t, **model_kwargs)
    new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
    return new_mean
386
+
387
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute what the p_mean_variance output would have been, should the
    model's score function be conditioned by cond_fn.
    cond_fn is called as cond_fn(x, t, **model_kwargs) and returns the
    gradient of a conditional log probability with respect to x.
    Unlike condition_mean(), this instead uses the conditioning strategy
    from Song et al (2020).

    :param p_mean_var: a dict providing at least 'pred_xstart'; it is not
                       modified, a shallow copy is returned.
    :param model_kwargs: if not None, a dict of extra keyword arguments for cond_fn.
    :return: the copied dict with 'pred_xstart' and 'mean' replaced.
    """
    # The default None cannot be unpacked with ** (TypeError), so treat it as "no extras".
    if model_kwargs is None:
        model_kwargs = {}
    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

    # Shift the implied noise prediction by the scaled conditioning gradient.
    eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
    eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)

    out = p_mean_var.copy()
    out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
    out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
    return out
404
+
405
def p_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
):
    """
    Draw x_{t-1} from the model given the current tensor at timestep t.
    :param model: the model to sample from.
    :param x: the current tensor x_t.
    :param t: the value of t, starting at 0 for the first diffusion step.
    :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
                        x_start prediction before it is used to sample.
    :param cond_fn: if not None, this is a gradient function that acts
                    similarly to the model.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
                         pass to the model. This can be used for conditioning.
    :return: a dict containing the following keys:
             - 'sample': a random sample from the model.
             - 'pred_xstart': a prediction of x_0.
    """
    stats = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    noise = th.randn_like(x)
    # One mask entry per batch element, broadcastable over x; zero where
    # t == 0 so the final step adds no noise.
    trailing_ones = [1] * (len(x.shape) - 1)
    keep_noise = (t != 0).float().view(-1, *trailing_ones)
    if cond_fn is not None:
        stats["mean"] = self.condition_mean(cond_fn, stats, x, t, model_kwargs=model_kwargs)
    std = th.exp(0.5 * stats["log_variance"])
    sample = stats["mean"] + keep_noise * std * noise
    return {"sample": sample, "pred_xstart": stats["pred_xstart"]}
447
+
448
def p_sample_loop(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
):
    """
    Run the full sampling chain and return only the last sample.
    :param model: the model module.
    :param shape: the shape of the samples, (N, C, H, W).
    :param noise: if specified, the noise from the encoder to sample.
                  Should be of the same shape as `shape`.
    :param clip_denoised: if True, clip x_start predictions to [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
                        x_start prediction before it is used to sample.
    :param cond_fn: if not None, this is a gradient function that acts
                    similarly to the model.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
                         pass to the model. This can be used for conditioning.
    :param device: if specified, the device to create the samples on.
                   If not specified, use a model parameter's device.
    :param progress: if True, show a tqdm progress bar.
    :return: a non-differentiable batch of samples.
    """
    last = None
    # Drain the progressive generator, remembering only its final yield.
    for last in self.p_sample_loop_progressive(
        model,
        shape,
        noise=noise,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
        device=device,
        progress=progress,
    ):
        pass
    return last["sample"]
492
+
493
def p_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
):
    """
    Sample from the model, yielding the result of every diffusion step.
    Arguments are the same as p_sample_loop().
    Returns a generator over dicts, where each dict is the return value of
    p_sample().
    """
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))
    img = noise if noise is not None else th.randn(*shape, device=device)

    # Walk the chain from the last timestep down to 0.
    timesteps = list(reversed(range(self.num_timesteps)))
    if progress:
        # Lazy import so that we don't depend on tqdm.
        from tqdm.auto import tqdm

        timesteps = tqdm(timesteps)

    step_kwargs = dict(
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
    )
    batch = shape[0]
    for step in timesteps:
        t = th.tensor([step] * batch, device=device)
        with th.no_grad():
            out = self.p_sample(model, img, t, **step_kwargs)
            yield out
            img = out["sample"]
541
+
542
def ddim_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Draw x_{t-1} from the model with the DDIM update rule.
    Same usage as p_sample(); eta scales the injected noise.
    """
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    if cond_fn is not None:
        out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

    pred_xstart = out["pred_xstart"]
    # Usually our model outputs epsilon, but we re-derive it
    # in case we used x_start or x_prev prediction.
    eps = self._predict_eps_from_xstart(x, t, pred_xstart)

    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
    alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
    variance_ratio = th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
    step_shrink = th.sqrt(1 - alpha_bar / alpha_bar_prev)
    sigma = eta * variance_ratio * step_shrink

    # Equation 12.
    noise = th.randn_like(x)
    mean_pred = pred_xstart * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
    # no noise when t == 0
    trailing_ones = [1] * (len(x.shape) - 1)
    nonzero_mask = (t != 0).float().view(-1, *trailing_ones)
    sample = mean_pred + nonzero_mask * sigma * noise
    return {"sample": sample, "pred_xstart": pred_xstart}
590
+
591
def ddim_reverse_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Step forward to x_{t+1} along the deterministic DDIM reverse ODE.
    Only eta == 0.0 is accepted.
    """
    assert eta == 0.0, "Reverse ODE only for deterministic path"
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    if cond_fn is not None:
        out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

    pred_xstart = out["pred_xstart"]
    # Usually our model outputs epsilon, but we re-derive it
    # in case we used x_start or x_prev prediction.
    eps = self._predict_eps_from_xstart(x, t, pred_xstart)
    alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

    # Equation 12. reversed
    signal_part = pred_xstart * th.sqrt(alpha_bar_next)
    noise_part = th.sqrt(1 - alpha_bar_next) * eps
    mean_pred = signal_part + noise_part

    return {"sample": mean_pred, "pred_xstart": pred_xstart}
628
+
629
def ddim_sample_loop(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    eta=0.0,
):
    """
    Run the full DDIM sampling chain and return only the last sample.
    Same usage as p_sample_loop().
    """
    last = None
    # Drain the progressive generator, remembering only its final yield.
    for last in self.ddim_sample_loop_progressive(
        model,
        shape,
        noise=noise,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
        device=device,
        progress=progress,
        eta=eta,
    ):
        pass
    return last["sample"]
661
+
662
def ddim_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    eta=0.0,
):
    """
    Sample with DDIM, yielding the result of every DDIM step.
    Same usage as p_sample_loop_progressive().
    """
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))
    img = noise if noise is not None else th.randn(*shape, device=device)

    # Walk the chain from the last timestep down to 0.
    timesteps = list(reversed(range(self.num_timesteps)))
    if progress:
        # Lazy import so that we don't depend on tqdm.
        from tqdm.auto import tqdm

        timesteps = tqdm(timesteps)

    step_kwargs = dict(
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
        eta=eta,
    )
    batch = shape[0]
    for step in timesteps:
        t = th.tensor([step] * batch, device=device)
        with th.no_grad():
            out = self.ddim_sample(model, img, t, **step_kwargs)
            yield out
            img = out["sample"]
710
+
711
def _vb_terms_bpd(
    self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
):
    """
    Evaluate one term of the variational lower-bound.
    The result is expressed in bits (not nats) so that it can be compared
    with numbers reported in other papers.
    :return: a dict with the following keys:
             - 'output': a shape [N] tensor of NLLs or KLs.
             - 'pred_xstart': the x_0 predictions.
    """
    true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
        x_start=x_start, x_t=x_t, t=t
    )
    out = self.p_mean_variance(
        model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
    )
    nats_per_bit = np.log(2.0)

    kl = normal_kl(
        true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
    )
    kl = mean_flat(kl) / nats_per_bit

    decoder_nll = -discretized_gaussian_log_likelihood(
        x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
    )
    assert decoder_nll.shape == x_start.shape
    decoder_nll = mean_flat(decoder_nll) / nats_per_bit

    # At the first timestep return the decoder NLL,
    # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
    is_first_step = t == 0
    output = th.where(is_first_step, decoder_nll, kl)
    return {"output": output, "pred_xstart": out["pred_xstart"]}
743
+
744
def training_losses(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False):
    """
    Compute training losses for a single timestep.
    :param model: the model to evaluate loss on.
    :param x_start: the [N x C x ...] tensor of inputs.
    :param timestep: a batch of timestep indices.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
                         pass to the model. This can be used for conditioning.
    :param noise: if specified, the specific Gaussian noise to try to remove.
    :param skip_noise: if True, x_start is fed to the model unchanged
                       instead of being noised with q_sample.
    :return: a dict with the key "loss" containing a tensor of shape [N].
             Some mean or variance settings may also have other keys.
             When self.return_startx is set and the model predicts epsilon,
             a tuple (eps_prediction, pred_xstart, x_t) is returned instead.
    :raises ValueError: if a noise target is needed but no noise is available
                        (skip_noise=True with noise=None).
    :raises NotImplementedError: for an unsupported loss type, or when
                                 self.snr is set with a mean type other
                                 than START_X or EPSILON.
    """
    t = timestep
    if model_kwargs is None:
        model_kwargs = {}
    if skip_noise:
        x_t = x_start
    else:
        if noise is None:
            noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise=noise)

    terms = {}

    if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
        terms["loss"] = self._vb_terms_bpd(
            model=model,
            x_start=x_start,
            x_t=x_t,
            t=t,
            clip_denoised=False,
            model_kwargs=model_kwargs,
        )["output"]
        if self.loss_type == LossType.RESCALED_KL:
            terms["loss"] *= self.num_timesteps
    elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
        model_output = model(x_t, t, **model_kwargs)
        # The model may return either a tensor or a dict carrying it under 'x'.
        if isinstance(model_output, dict) and model_output.get('x', None) is not None:
            output = model_output['x']
        else:
            output = model_output

        if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON:
            B, C = x_t.shape[:2]
            assert output.shape == (B, C * 2, *x_t.shape[2:])
            # Keep only the first C channels (the epsilon prediction).
            output = th.split(output, C, dim=1)[0]
            return output, self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output), x_t

        if self.model_var_type in [
            ModelVarType.LEARNED,
            ModelVarType.LEARNED_RANGE,
        ]:
            B, C = x_t.shape[:2]
            assert output.shape == (B, C * 2, *x_t.shape[2:])
            output, model_var_values = th.split(output, C, dim=1)
            # Learn the variance using the variational bound, but don't let it affect our mean prediction.
            frozen_out = th.cat([output.detach(), model_var_values], dim=1)
            terms["vb"] = self._vb_terms_bpd(
                model=lambda *args, r=frozen_out, **kwargs: r,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
            )["output"]
            if self.loss_type == LossType.RESCALED_MSE:
                # Divide by 1000 for equivalence with initial implementation.
                # Without a factor of 1/1000, the VB term hurts the MSE term.
                terms["vb"] *= self.num_timesteps / 1000.0

        needs_noise = self.snr or self.model_mean_type == ModelMeanType.EPSILON
        if noise is None and needs_noise:
            # Only reachable with skip_noise=True; previously this surfaced
            # as an opaque failure on `None.shape` / th.where(None).
            raise ValueError(
                "noise must be provided when skip_noise=True and the loss needs a noise target"
            )

        # Select the regression target lazily so the posterior mean is only
        # computed when the model actually predicts x_{t-1}.
        if self.model_mean_type == ModelMeanType.PREVIOUS_X:
            target = self.q_posterior_mean_variance(
                x_start=x_start, x_t=x_t, t=t
            )[0]
        elif self.model_mean_type == ModelMeanType.START_X:
            target = x_start
        elif self.model_mean_type == ModelMeanType.EPSILON:
            target = noise
        else:
            # Same exception type the original dict lookup produced.
            raise KeyError(self.model_mean_type)
        assert output.shape == target.shape == x_start.shape

        if self.snr:
            if self.model_mean_type == ModelMeanType.START_X:
                pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output)
                pred_startx = output
            elif self.model_mean_type == ModelMeanType.EPSILON:
                pred_noise = output
                pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output)
            else:
                # Previously fell through to an unbound-variable error.
                raise NotImplementedError(self.model_mean_type)

            # Broadcast t over the (4-D) prediction shape, e.g. [128, 4, 32, 32].
            t_map = t[:, None, None, None].expand(pred_startx.shape)
            # Timesteps above 249 regress the noise, the rest regress x_0.
            # NOTE(review): 249 is a hard-coded switch point (marked "best"
            # in the original); confirm it matches the schedule in use.
            use_noise_target = t_map > 249
            target = th.where(use_noise_target, noise, x_start)
            output = th.where(use_noise_target, pred_noise, pred_startx)

        loss = (target - output) ** 2
        if model_kwargs.get('mask_ratio', False) and model_kwargs['mask_ratio'] > 0:
            assert 'mask' in model_output
            # Pool the per-pixel loss down to one value per patch so it lines up with the mask.
            loss = F.avg_pool2d(loss.mean(dim=1), model.model.module.patch_size).flatten(1)
            mask = model_output['mask']
            unmask = 1 - mask
            terms['mse'] = mean_flat(loss * unmask) * unmask.shape[1]/unmask.sum(1)
            if model_kwargs['mask_loss_coef'] > 0:
                terms['mae'] = model_kwargs['mask_loss_coef'] * mean_flat(loss * mask) * mask.shape[1]/mask.sum(1)
        else:
            terms["mse"] = mean_flat(loss)

        if "vb" in terms:
            terms["loss"] = terms["mse"] + terms["vb"]
        else:
            terms["loss"] = terms["mse"]
        if "mae" in terms:
            terms["loss"] = terms["loss"] + terms["mae"]
    else:
        raise NotImplementedError(self.loss_type)

    return terms
856
+
857
def training_losses_diffusers(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False):
    """
    Compute training losses for a single timestep, calling the model with
    the keyword interface model(x_t, timestep=t, ..., return_dict=False)
    and using the first element of its return value.
    :param model: the model to evaluate loss on.
    :param x_start: the [N x C x ...] tensor of inputs.
    :param timestep: a batch of timestep indices.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
                         pass to the model. This can be used for conditioning.
    :param noise: if specified, the specific Gaussian noise to try to remove.
    :param skip_noise: if True, x_start is fed to the model unchanged
                       instead of being noised with q_sample.
    :return: a dict with the key "loss" containing a tensor of shape [N].
             Some mean or variance settings may also have other keys.
             When self.return_startx is set and the model predicts epsilon,
             a tuple (eps_prediction, pred_xstart, x_t) is returned instead.
    :raises ValueError: if a noise target is needed but no noise is available
                        (skip_noise=True with noise=None).
    :raises NotImplementedError: for an unsupported loss type, or when
                                 self.snr is set with a mean type other
                                 than START_X or EPSILON.
    """
    t = timestep
    if model_kwargs is None:
        model_kwargs = {}
    if skip_noise:
        x_t = x_start
    else:
        if noise is None:
            noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise=noise)

    terms = {}

    if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
        terms["loss"] = self._vb_terms_bpd(
            model=model,
            x_start=x_start,
            x_t=x_t,
            t=t,
            clip_denoised=False,
            model_kwargs=model_kwargs,
        )["output"]
        if self.loss_type == LossType.RESCALED_KL:
            terms["loss"] *= self.num_timesteps
    elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
        output = model(x_t, timestep=t, **model_kwargs, return_dict=False)[0]

        if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON:
            B, C = x_t.shape[:2]
            assert output.shape == (B, C * 2, *x_t.shape[2:])
            # Keep only the first C channels (the epsilon prediction).
            output = th.split(output, C, dim=1)[0]
            return output, self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output), x_t

        if self.model_var_type in [
            ModelVarType.LEARNED,
            ModelVarType.LEARNED_RANGE,
        ]:
            B, C = x_t.shape[:2]
            assert output.shape == (B, C * 2, *x_t.shape[2:])
            output, model_var_values = th.split(output, C, dim=1)
            # Learn the variance using the variational bound, but don't let it affect our mean prediction.
            frozen_out = th.cat([output.detach(), model_var_values], dim=1)
            terms["vb"] = self._vb_terms_bpd(
                model=lambda *args, r=frozen_out, **kwargs: r,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
            )["output"]
            if self.loss_type == LossType.RESCALED_MSE:
                # Divide by 1000 for equivalence with initial implementation.
                # Without a factor of 1/1000, the VB term hurts the MSE term.
                terms["vb"] *= self.num_timesteps / 1000.0

        needs_noise = self.snr or self.model_mean_type == ModelMeanType.EPSILON
        if noise is None and needs_noise:
            # Only reachable with skip_noise=True; previously this surfaced
            # as an opaque failure on `None.shape` / th.where(None).
            raise ValueError(
                "noise must be provided when skip_noise=True and the loss needs a noise target"
            )

        # Select the regression target lazily so the posterior mean is only
        # computed when the model actually predicts x_{t-1}.
        if self.model_mean_type == ModelMeanType.PREVIOUS_X:
            target = self.q_posterior_mean_variance(
                x_start=x_start, x_t=x_t, t=t
            )[0]
        elif self.model_mean_type == ModelMeanType.START_X:
            target = x_start
        elif self.model_mean_type == ModelMeanType.EPSILON:
            target = noise
        else:
            # Same exception type the original dict lookup produced.
            raise KeyError(self.model_mean_type)
        assert output.shape == target.shape == x_start.shape

        if self.snr:
            if self.model_mean_type == ModelMeanType.START_X:
                pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output)
                pred_startx = output
            elif self.model_mean_type == ModelMeanType.EPSILON:
                pred_noise = output
                pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output)
            else:
                # Previously fell through to an unbound-variable error.
                raise NotImplementedError(self.model_mean_type)

            # Broadcast t over the (4-D) prediction shape, e.g. [128, 4, 32, 32].
            t_map = t[:, None, None, None].expand(pred_startx.shape)
            # Timesteps above 249 regress the noise, the rest regress x_0.
            # NOTE(review): 249 is a hard-coded switch point (marked "best"
            # in the original); confirm it matches the schedule in use.
            use_noise_target = t_map > 249
            target = th.where(use_noise_target, noise, x_start)
            output = th.where(use_noise_target, pred_noise, pred_startx)

        loss = (target - output) ** 2
        terms["mse"] = mean_flat(loss)
        # This variant has no masked-loss path, so no "mae" term is ever
        # produced; the loss is mse plus the optional vb term.
        if "vb" in terms:
            terms["loss"] = terms["mse"] + terms["vb"]
        else:
            terms["loss"] = terms["mse"]
    else:
        raise NotImplementedError(self.loss_type)

    return terms
956
+
957
def _prior_bpd(self, x_start):
    """
    Evaluate the prior KL term of the variational lower-bound, in
    bits-per-dim.
    This term can't be optimized, as it only depends on the encoder.
    :param x_start: the [N x C x ...] tensor of inputs.
    :return: a batch of [N] KL values (in bits), one per batch element.
    """
    num_samples = x_start.shape[0]
    # Every batch element is evaluated at the final timestep.
    last_step = th.tensor([self.num_timesteps - 1] * num_samples, device=x_start.device)
    qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, last_step)
    # KL against a standard normal (mean 0, log-variance 0).
    kl_prior = normal_kl(
        mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
    )
    kl_nats = mean_flat(kl_prior)
    return kl_nats / np.log(2.0)
972
+
973
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
    """
    Evaluate the whole variational lower-bound, in bits-per-dim, together
    with per-timestep error statistics.
    :param model: the model to evaluate loss on.
    :param x_start: the [N x C x ...] tensor of inputs.
    :param clip_denoised: if True, clip denoised samples.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
                         pass to the model. This can be used for conditioning.
    :return: a dict containing the following keys:
             - total_bpd: the total variational lower-bound, per batch element.
             - prior_bpd: the prior term in the lower-bound.
             - vb: an [N x T] tensor of terms in the lower-bound.
             - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
             - mse: an [N x T] tensor of epsilon MSEs for each timestep.
    """
    device = x_start.device
    batch_size = x_start.shape[0]

    vb_terms, xstart_errors, eps_errors = [], [], []
    # Visit the timesteps from the last one down to 0.
    for step in reversed(range(self.num_timesteps)):
        t_batch = th.tensor([step] * batch_size, device=device)
        noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
        # Calculate VLB term at the current timestep
        with th.no_grad():
            out = self._vb_terms_bpd(
                model,
                x_start=x_start,
                x_t=x_t,
                t=t_batch,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
            )
        pred_xstart = out["pred_xstart"]
        vb_terms.append(out["output"])
        xstart_errors.append(mean_flat((pred_xstart - x_start) ** 2))
        eps = self._predict_eps_from_xstart(x_t, t_batch, pred_xstart)
        eps_errors.append(mean_flat((eps - noise) ** 2))

    # Each list holds one [N] tensor per timestep; stack them into [N x T].
    vb = th.stack(vb_terms, dim=1)
    xstart_mse = th.stack(xstart_errors, dim=1)
    mse = th.stack(eps_errors, dim=1)

    prior_bpd = self._prior_bpd(x_start)
    total_bpd = vb.sum(dim=1) + prior_bpd
    return {
        "total_bpd": total_bpd,
        "prior_bpd": prior_bpd,
        "vb": vb,
        "xstart_mse": xstart_mse,
        "mse": mse,
    }
1027
+
1028
+
1029
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
1030
+ """
1031
+ Extract values from a 1-D numpy array for a batch of indices.
1032
+ :param arr: the 1-D numpy array.
1033
+ :param timesteps: a tensor of indices into the array to extract.
1034
+ :param broadcast_shape: a larger shape of K dimensions with the batch
1035
+ dimension equal to the length of timesteps.
1036
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
1037
+ """
1038
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
1039
+ while len(res.shape) < len(broadcast_shape):
1040
+ res = res[..., None]
1041
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
diffusion/model/llava/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from diffusion.model.llava.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig
diffusion/model/llava/llava_mpt.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+ import warnings
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from torch.nn import CrossEntropyLoss
23
+
24
+ import math
25
+
26
+ from transformers import AutoConfig, AutoModelForCausalLM, CLIPVisionModel, CLIPImageProcessor
27
+
28
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
29
+
30
+ from diffusion.model.llava.mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
31
+
32
+
33
+ DEFAULT_IMAGE_TOKEN = "<image>"
34
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
35
+ DEFAULT_IM_START_TOKEN = "<im_start>"
36
+ DEFAULT_IM_END_TOKEN = "<im_end>"
37
+
38
+
39
class LlavaMPTConfig(MPTConfig):
    """MPT configuration registered under the "llava_mpt" model type."""

    model_type = "llava_mpt"
41
+
42
+
43
class LlavaMPTModel(MPTModel):
    """
    MPT backbone that splices projected CLIP image features into the token
    embeddings before running the parent MPTModel forward pass.
    """

    config_class = LlavaMPTConfig

    def __init__(self, config: MPTConfig, mm_vision_tower=None, mm_hidden_size=None):
        """
        Build the MPT backbone and, when the config already carries the
        multimodal attributes, the vision tower and projector.

        `mm_vision_tower` and `mm_hidden_size` are accepted for interface
        compatibility but are not read; the values come from `config`.
        """
        super(LlavaMPTModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            # HACK: for FSDP
            # Wrapping in a list keeps the tower out of the registered submodules.
            self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
            # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower)

        if hasattr(config, "use_mm_proj"):
            self.mm_projector = nn.Linear(config.mm_hidden_size, config.d_model)

    def initialize_vision_modules(self, vision_tower, mm_vision_select_layer,
                                  pretrain_mm_mlp_adapter=None, tune_mm_mlp_adapter=False):
        """
        Load (or reuse) the CLIP vision tower, freeze it, cast it to float16,
        and make sure a projector into the model width exists.

        :param vision_tower: name or path passed to the CLIP `from_pretrained` loaders.
        :param mm_vision_select_layer: hidden-state index stored on the config
                                       and used by forward() to pick features.
        :param pretrain_mm_mlp_adapter: optional checkpoint path whose
                                        'mm_projector' entries are loaded into the projector.
        :param tune_mm_mlp_adapter: accepted but not used by this method.
        :return: dict with `image_processor`, `image_token_len` and `vision_config`.
        """
        self.config.mm_vision_tower = vision_tower

        image_processor = CLIPImageProcessor.from_pretrained(vision_tower)

        if not hasattr(self, 'vision_tower'):
            vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
        else:
            vision_tower = self.vision_tower[0]
        vision_tower.requires_grad_(False)
        vision_tower = vision_tower.to(torch.float16)
        self.vision_tower = [vision_tower]

        vision_config = vision_tower.config
        num_patches = (vision_config.image_size // vision_config.patch_size) ** 2

        self.config.use_mm_proj = True
        self.config.mm_hidden_size = vision_config.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer

        if not hasattr(self, 'mm_projector'):
            self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.d_model)

        if pretrain_mm_mlp_adapter is not None:
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted checkpoints.
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            # Keep only the projector entries and strip their prefix ('...mm_projector.weight' -> 'weight').
            self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items() if 'mm_projector' in k})

        return dict(
            image_processor=image_processor,
            image_token_len=num_patches,
            vision_config=vision_config
        )

    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None):
        """
        Embed `input_ids`, replace image placeholder positions with projected
        vision features when `images` is given, then run the parent forward
        on the resulting embeddings (passed as `tok_emb`).

        :param images: a batched tensor or a list of per-image tensors; None
                       skips the multimodal path entirely.
        :raises ValueError: if image start/end or patch placeholder tokens
                            are inconsistent with the image features.
        """
        # HACK: replace back original embeddings for LLaVA pretraining
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
        # if orig_embeds_params is not None:
        #     orig_embeds_params = orig_embeds_params[0]
        #     with torch.no_grad():
        #         self.get_input_embeddings().weight.data[:-2] = orig_embeds_params[:-2].data

        inputs_embeds = self.wte(input_ids)

        vision_tower = getattr(self, 'vision_tower', None)
        # A single-token input outside training skips the image path.
        if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
            # TODO: this is a modified multimodal LLM -- Haotian Liu
            vision_tower = vision_tower[0]  # HACK: for FSDP
            images_are_list = isinstance(images, list)
            select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
            with torch.no_grad():
                if images_are_list:
                    # variable length images
                    image_features = []
                    for image in images:
                        image_forward_out = vision_tower(image.unsqueeze(0), output_hidden_states=True)
                        select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
                        # Drop the first position of the selected hidden state.
                        image_feature = select_hidden_state[:, 1:]
                        image_features.append(image_feature)
                else:
                    image_forward_outs = vision_tower(images, output_hidden_states=True)
                    select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
                    image_features = select_hidden_state[:, 1:]
            if images_are_list:
                image_features = [self.mm_projector(image_feature)[0] for image_feature in image_features]
            else:
                image_features = self.mm_projector(image_features)
            # Width comes from the projector rather than a hard-coded 1024, so
            # vision towers with another hidden size also work.
            dummy_image_features = torch.zeros(256, self.mm_projector.in_features, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
            dummy_image_features = self.mm_projector(dummy_image_features)

            new_input_embeds = []
            cur_image_idx = 0
            for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
                if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0:
                    # multimodal LLM, but the current sample is not multimodal
                    # Adding a zero-weighted term keeps the projector in the graph for this sample.
                    cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
                    new_input_embeds.append(cur_input_embeds)
                    continue
                if vision_tower.config.use_im_start_end:
                    if (cur_input_ids == vision_tower.config.im_start_token).sum() != (cur_input_ids == vision_tower.config.im_end_token).sum():
                        raise ValueError("The number of image start tokens and image end tokens should be the same.")
                    image_start_tokens = torch.where(cur_input_ids == vision_tower.config.im_start_token)[0]
                    # NOTE(review): every iteration rebuilds from the original
                    # cur_input_embeds, so with several images in one sample
                    # only the last splice appears to survive -- confirm.
                    for image_start_token_pos in image_start_tokens:
                        cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device)
                        num_patches = cur_image_features.shape[0]
                        if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token:
                            raise ValueError("The image end token should follow the image start token.")
                        if orig_embeds_params is not None:
                            # Only the start/end token embeddings keep gradients; the surrounding text is detached.
                            cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
                        else:
                            cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
                        cur_image_idx += 1
                    new_input_embeds.append(cur_new_input_embeds)
                else:
                    cur_image_features = image_features[cur_image_idx]
                    num_patches = cur_image_features.shape[0]
                    if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches:
                        raise ValueError("The number of image patch tokens should be the same as the number of image patches.")
                    masked_indices = torch.where(cur_input_ids == vision_tower.config.im_patch_token)[0]
                    mask_index_start = masked_indices[0]
                    if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
                        raise ValueError("The image patch tokens should be consecutive.")
                    if orig_embeds_params is not None:
                        cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_image_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0)
                    else:
                        cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
                    new_input_embeds.append(cur_new_input_embeds)
                    # Advance to the next image; without this every sample in
                    # the batch was given the features of image 0.
                    cur_image_idx += 1
            inputs_embeds = torch.stack(new_input_embeds, dim=0)

        return super(LlavaMPTModel, self).forward(input_ids=None, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, tok_emb=inputs_embeds)
170
+
171
+
172
class LlavaMPTForCausalLM(MPTForCausalLM):
    """Causal-LM head on top of ``LlavaMPTModel`` that also accepts images.

    The output projection reuses the token-embedding matrix of the wrapped
    transformer, so the config must have ``tie_word_embeddings`` enabled.
    """
    config_class = LlavaMPTConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        """Build the multimodal transformer and resolve ``config.logit_scale``.

        Raises:
            ValueError: if word embeddings are not tied, or if ``logit_scale``
                is a string other than ``'inv_sqrt_d_model'``.
        """
        # Deliberately starts the MRO lookup *after* MPTForCausalLM, i.e. the
        # MPTForCausalLM initializer itself is skipped.
        super(MPTForCausalLM, self).__init__(config)

        if not config.tie_word_embeddings:
            raise ValueError('MPTForCausalLM only supports tied word embeddings')
        self.transformer = LlavaMPTModel(config)
        self.logit_scale = None
        if config.logit_scale is not None:
            logit_scale = config.logit_scale
            if isinstance(logit_scale, str):
                if logit_scale == 'inv_sqrt_d_model':
                    logit_scale = 1 / math.sqrt(config.d_model)
                else:
                    raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
            self.logit_scale = logit_scale

    def get_model(self):
        """Return the wrapped ``LlavaMPTModel``."""
        return self.transformer

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle ``gradient_checkpointing`` on ``LlavaMPTModel`` instances only."""
        if isinstance(module, LlavaMPTModel):
            module.gradient_checkpointing = value

    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None):
        """Run the transformer, project to vocabulary logits and optionally compute the LM loss.

        ``return_dict`` and ``use_cache`` fall back to the values stored on
        ``self.config`` when they are ``None``.

        Returns:
            CausalLMOutputWithPast with ``loss`` (``None`` when ``labels`` is
            not given), ``logits``, ``past_key_values`` and ``hidden_states``.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, images=images)
        # Tied embeddings: the token-embedding weight doubles as the output projection.
        logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
        if self.logit_scale is not None:
            if self.logit_scale == 0:
                warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
            logits *= self.logit_scale
        loss = None
        if labels is not None:
            # Shift targets left by one so position i is scored against token i+1;
            # torch.roll returns a new tensor, so the caller's labels are untouched.
            labels = torch.roll(labels, shifts=-1)
            # The last position has no next token; -100 is cross_entropy's default ignore_index.
            labels[:, -1] = -100
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Assemble the keyword arguments for one generation step.

        Requires ``kwargs['attention_mask']``; optional ``kwargs['use_cache']``
        (defaults to ``True`` in the returned dict) and ``kwargs['images']``.

        Raises:
            NotImplementedError: if ``inputs_embeds`` is given, if any sequence
                does not end on an attended position (right padding), or if
                prefix-LM mode is combined with ``use_cache=False``.
        """
        if inputs_embeds is not None:
            raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
        attention_mask = kwargs['attention_mask'].bool()
        # Every row must have its last position attended, otherwise the batch is right-padded.
        if attention_mask[:, -1].sum() != attention_mask.shape[0]:
            raise NotImplementedError('MPT does not support generation with right padding.')
        if self.transformer.attn_uses_sequence_id and self.training:
            sequence_id = torch.zeros_like(input_ids[:1])
        else:
            sequence_id = None
        if past_key_values is not None:
            # With a KV cache only the newest token needs to be fed through the model.
            input_ids = input_ids[:, -1].unsqueeze(-1)
        if self.transformer.prefix_lm:
            prefix_mask = torch.ones_like(attention_mask)
            if kwargs.get('use_cache') == False:
                raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
        else:
            prefix_mask = None
        return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)}

    def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
                                    tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None):
        """Register the image special tokens and record their ids on the vision config.

        Always adds the image-patch token. When ``mm_use_im_start_end`` is
        truthy it also adds the image start/end tokens, initialises their new
        embedding rows to the mean of the existing rows, and optionally
        prepares the embeddings for adapter tuning or loads pretrained rows.

        NOTE(review): the nesting of the ``tune_mm_mlp_adapter`` and
        ``pretrain_mm_mlp_adapter`` branches under ``mm_use_im_start_end`` is
        reconstructed from a view that lost indentation -- confirm against the
        original file.
        """
        vision_config = self.get_model().vision_tower[0].config
        vision_config.use_im_start_end = mm_use_im_start_end
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))

        if mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))
            vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                # Mean over all pre-existing rows (everything except the newly added ones).
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if tune_mm_mlp_adapter:
                # Keep a snapshot of the embedding matrix before it is made trainable.
                self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                # NOTE(review): this class requires tied embeddings; if the output
                # embeddings are the same module as the input embeddings, this loop
                # would undo the line above -- confirm against get_output_embeddings().
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if pretrain_mm_mlp_adapter:
                # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
                mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['transformer.wte.weight']
                assert num_new_tokens == 2
                # Accept either a full embedding matrix or just the rows for the new tokens.
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")

        vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
278
+
279
+ AutoConfig.register("llava_mpt", LlavaMPTConfig)
280
+ AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM)
diffusion/model/llava/mpt/attention.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Attention layers."""
2
+ import math
3
+ import warnings
4
+ from typing import Optional
5
+ import torch
6
+ import torch.nn as nn
7
+ from einops import rearrange
8
+ from torch import nn
9
+ from .norm import LPLayerNorm
10
+
11
+ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
12
+ if original_is_causal and num_query_tokens != num_key_tokens:
13
+ if num_query_tokens != 1:
14
+ raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.')
15
+ else:
16
+ return False
17
+ return original_is_causal
18
+
19
def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
    """Plain PyTorch scaled dot-product attention over ``b s (h d)`` inputs.

    Query is split into ``n_heads`` heads; key and value use a single head
    when ``multiquery`` is set, otherwise ``n_heads``. An optional additive
    ``attn_bias``, a boolean ``key_padding_mask`` and a causal mask are applied
    before the softmax, and dropout (if any) after it.

    Returns:
        ``(out, attn_weight)`` when ``needs_weights`` is truthy, otherwise
        ``(out, None)``; ``out`` is laid out as ``b s (h d)``.

    Raises:
        RuntimeError: ``attn_bias`` cannot broadcast to the score matrix.
    """
    kv_heads = 1 if multiquery else n_heads
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
    # Key is produced already transposed so that q @ k yields the score matrix.
    k_t = rearrange(key, 'b s (h d) -> b h d s', h=kv_heads)
    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_heads)
    fill_value = torch.finfo(q.dtype).min
    batch, _, q_len, head_dim = q.shape
    k_len = k_t.size(-1)
    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(head_dim)
    scores = q.matmul(k_t) * softmax_scale

    if attn_bias is not None:
        key_dim_ok = attn_bias.size(-1) in (1, k_len)
        query_dim_ok = attn_bias.size(-2) in (1, q_len)
        if not (key_dim_ok and query_dim_ok):
            raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {scores.shape}.')
        scores = scores + attn_bias

    if key_padding_mask is not None:
        if attn_bias is not None:
            warnings.warn(
                'Propogating key_padding_mask to the attention module '
                'and applying it within the attention module can cause '
                'unneccessary computation/memory usage. Consider integrating '
                'into attn_bias once and passing that to each attention '
                'module instead.'
            )
        padded_positions = ~key_padding_mask.view((batch, 1, 1, k_len))
        scores = scores.masked_fill(padded_positions, fill_value)

    if is_causal:
        side = max(q_len, k_len)
        lower = scores.new_ones(side, side, dtype=torch.float16).tril()
        # True marks positions that must not be attended to.
        blocked = ~lower.to(torch.bool)
        blocked = blocked[-q_len:, -k_len:]
        scores = scores.masked_fill(blocked.view(1, 1, q_len, k_len), fill_value)

    probs = torch.softmax(scores, dim=-1)
    if dropout_p:
        probs = torch.nn.functional.dropout(probs, p=dropout_p, training=training, inplace=True)
    context = rearrange(probs.matmul(v), 'b h s d -> b s (h d)')
    return (context, probs) if needs_weights else (context, None)
53
+
54
def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
    """Validate that every tensor has an accepted dtype and lives on CUDA.

    The dtype is checked before the device for each tensor, and the first
    failing tensor aborts the scan.

    Raises:
        TypeError: a tensor's dtype is not in ``valid_dtypes`` or the tensor
            is not a CUDA tensor.
    """
    for tensor in tensors:
        dtype_ok = tensor.dtype in valid_dtypes
        if not dtype_ok:
            raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.')
        if tensor.is_cuda:
            continue
        raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
60
+
61
def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
    """Attention through the ``flash_attn`` unpadded kernel.

    Inputs are laid out as ``b s (h d)``. Padding is stripped with
    ``bert_padding.unpad_input`` before the kernel call and restored with
    ``bert_padding.pad_input`` afterwards. Dropout is only applied when
    ``training`` is truthy.

    Returns:
        ``(output, None)`` -- attention weights are never returned.

    Raises:
        RuntimeError: ``flash_attn`` cannot be imported (the original import
            error is chained as ``__cause__``).
        TypeError: inputs fail ``check_valid_inputs``.
        NotImplementedError: ``attn_bias`` is given.
    """
    try:
        from flash_attn import bert_padding, flash_attn_interface
    except Exception as err:
        # `except Exception` instead of a bare `except:` so KeyboardInterrupt /
        # SystemExit pass through; chaining keeps the real failure visible.
        raise RuntimeError('Please install flash-attn==1.0.3.post0') from err
    check_valid_inputs(query, key, value)
    if attn_bias is not None:
        raise NotImplementedError('attn_bias not implemented for flash attn.')
    (batch_size, seqlen) = query.shape[:2]
    if key_padding_mask is None:
        key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
    # With cached keys the query is shorter than the key; take the matching mask tail.
    query_padding_mask = key_padding_mask[:, -query.size(1):]
    (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
    query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
    (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
    key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
    (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
    value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
    if multiquery:
        # Broadcast the single key/value head across all query heads.
        key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
        value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
    dropout_p = dropout_p if training else 0.0
    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
    # NOTE(review): `needs_weights` is forwarded as `return_attn_probs`; if the
    # kernel then returns more than the output tensor the rearrange below would
    # not apply -- confirm against the flash-attn version in use.
    output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
    output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
    return (output, None)
87
+
88
def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
    """Attention through the Triton flash-attention kernel from ``flash_attn``.

    Inputs are laid out as ``b s (h d)``. A ``key_padding_mask`` is folded
    into ``attn_bias`` (created as zeros when absent) by filling masked key
    positions with the dtype's minimum value.

    Returns:
        ``(output, None)`` with ``output`` flattened back to ``b s (h d)``.

    Raises:
        RuntimeError: ``flash_attn.flash_attn_triton`` cannot be imported (the
            original import error is chained as ``__cause__``).
        TypeError: inputs fail ``check_valid_inputs``.
        NotImplementedError: dropout or attention weights are requested.
    """
    try:
        from flash_attn import flash_attn_triton
    except Exception as err:
        # `except Exception` instead of a bare `except:` so KeyboardInterrupt /
        # SystemExit pass through; chaining keeps the real failure visible.
        raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202') from err
    check_valid_inputs(query, key, value)
    if dropout_p:
        raise NotImplementedError('Dropout not implemented for attn_impl: triton.')
    if needs_weights:
        raise NotImplementedError('attn_impl: triton cannot return attn weights.')
    if key_padding_mask is not None:
        warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
        (b_size, s_k) = key_padding_mask.shape[:2]
        if attn_bias is None:
            attn_bias = query.new_zeros(b_size, 1, 1, s_k)
        # Padded key positions get the most negative representable bias.
        attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
    query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
    key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
    value = rearrange(value, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
    if multiquery:
        # Broadcast the single key/value head across all query heads.
        key = key.expand(*key.shape[:2], n_heads, key.size(-1))
        value = value.expand(*value.shape[:2], n_heads, value.size(-1))
    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
    attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
    output = attn_output.view(*attn_output.shape[:2], -1)
    return (output, None)
114
+
115
class MultiheadAttention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    The attention kernel is selected by ``attn_impl`` ('flash', 'triton' or
    'torch'); any other value raises ``ValueError`` at construction time.
    """

    def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
        super().__init__()
        self.attn_impl = attn_impl
        self.clip_qkv = clip_qkv
        self.qk_ln = qk_ln
        self.d_model = d_model
        self.n_heads = n_heads
        # Default scale is 1/sqrt(head_dim).
        if softmax_scale is None:
            softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
        self.softmax_scale = softmax_scale
        self.attn_dropout_p = attn_pdrop

        # One linear layer yields query, key and value; `_fused` records the split points.
        self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
        self.Wqkv._fused = (0, (d_model, 2 * d_model))

        if self.qk_ln:
            norm_cls = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
            self.q_ln = norm_cls(self.d_model, device=device)
            self.k_ln = norm_cls(self.d_model, device=device)

        if self.attn_impl == 'flash':
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == 'triton':
            self.attn_fn = triton_flash_attn_fn
            warnings.warn(
                'While `attn_impl: triton` can be faster than `attn_impl: flash` '
                'it uses more memory. When training larger models this can trigger '
                'alloc retries which hurts performance. If encountered, we recommend '
                'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.'
            )
        elif self.attn_impl == 'torch':
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available():
                warnings.warn(
                    'Using `attn_impl: torch`. If your model does not use `alibi` or '
                    '`prefix_lm` we recommend using `attn_impl: flash` otherwise '
                    'we recommend using `attn_impl: triton`.'
                )
        else:
            raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')

        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
        self.out_proj._is_residual = True

    def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
        """Project ``x`` to q/k/v, attend, and apply the output projection.

        Returns:
            ``(output, attn_weights, past_key_value)``; ``past_key_value`` is
            the updated ``(key, value)`` pair when a cache was passed in,
            otherwise the ``None`` that was given.
        """
        fused = self.Wqkv(x)
        if self.clip_qkv:
            fused.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        query, key, value = fused.chunk(3, dim=2)

        if self.qk_ln:
            # Normalization may change the dtype; cast back to the projection dtype.
            proj_dtype = query.dtype
            query = self.q_ln(query).to(proj_dtype)
            key = self.k_ln(key).to(proj_dtype)

        if past_key_value is not None:
            if len(past_key_value) != 0:
                cached_key, cached_value = past_key_value[0], past_key_value[1]
                key = torch.cat([cached_key, key], dim=1)
                value = torch.cat([cached_value, value], dim=1)
            past_key_value = (key, value)

        if attn_bias is not None:
            # Keep only the bias rows/columns for the current query and key lengths.
            attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]

        context, attn_weights = self.attn_fn(
            query, key, value, self.n_heads,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=attention_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
        )
        return (self.out_proj(context), attn_weights, past_key_value)
173
+
174
class MultiQueryAttention(nn.Module):
    """Multi-query self-attention: ``n_heads`` query heads share one key/value head.

    The attention kernel is selected by ``attn_impl`` ('flash', 'triton' or
    'torch'); any other value raises ``ValueError`` at construction time.
    """

    def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
        super().__init__()
        self.attn_impl = attn_impl
        self.clip_qkv = clip_qkv
        self.qk_ln = qk_ln
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Default scale is 1/sqrt(head_dim).
        if softmax_scale is None:
            softmax_scale = 1 / math.sqrt(self.head_dim)
        self.softmax_scale = softmax_scale
        self.attn_dropout_p = attn_pdrop

        # Full-width query plus a single head each for key and value.
        self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
        self.Wqkv._fused = (0, (d_model, d_model + self.head_dim))

        if self.qk_ln:
            norm_cls = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
            self.q_ln = norm_cls(d_model, device=device)
            self.k_ln = norm_cls(self.head_dim, device=device)

        if self.attn_impl == 'flash':
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == 'triton':
            self.attn_fn = triton_flash_attn_fn
            warnings.warn(
                'While `attn_impl: triton` can be faster than `attn_impl: flash` '
                'it uses more memory. When training larger models this can trigger '
                'alloc retries which hurts performance. If encountered, we recommend '
                'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.'
            )
        elif self.attn_impl == 'torch':
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available():
                warnings.warn(
                    'Using `attn_impl: torch`. If your model does not use `alibi` or '
                    '`prefix_lm` we recommend using `attn_impl: flash` otherwise '
                    'we recommend using `attn_impl: triton`.'
                )
        else:
            raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')

        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
        self.out_proj._is_residual = True

    def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
        """Project ``x`` to q/k/v, attend with a shared key/value head, and project out.

        Returns:
            ``(output, attn_weights, past_key_value)``; ``past_key_value`` is
            the updated ``(key, value)`` pair when a cache was passed in,
            otherwise the ``None`` that was given.
        """
        fused = self.Wqkv(x)
        if self.clip_qkv:
            fused.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        widths = [self.d_model, self.head_dim, self.head_dim]
        query, key, value = fused.split(widths, dim=2)

        if self.qk_ln:
            # Normalization may change the dtype; cast back to the projection dtype.
            proj_dtype = query.dtype
            query = self.q_ln(query).to(proj_dtype)
            key = self.k_ln(key).to(proj_dtype)

        if past_key_value is not None:
            if len(past_key_value) != 0:
                cached_key, cached_value = past_key_value[0], past_key_value[1]
                key = torch.cat([cached_key, key], dim=1)
                value = torch.cat([cached_value, value], dim=1)
            past_key_value = (key, value)

        if attn_bias is not None:
            # Keep only the bias rows/columns for the current query and key lengths.
            attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]

        context, attn_weights = self.attn_fn(
            query, key, value, self.n_heads,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=attention_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
            multiquery=True,
        )
        return (self.out_proj(context), attn_weights, past_key_value)
233
+
234
def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
    """Return the shape of the attention-bias tensor to allocate, or ``None``.

    'flash' never uses a bias. For 'torch' and 'triton' the shape depends on
    whether alibi is used and whether per-query rows are needed (prefix-LM,
    sequence-id masking, or non-causal alibi).

    Raises:
        ValueError: ``attn_impl`` is not 'flash', 'torch' or 'triton'.
    """
    if attn_impl == 'flash':
        return None
    if attn_impl not in ('torch', 'triton'):
        raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')

    needs_query_rows = prefix_lm or use_sequence_id
    if alibi:
        # Causal alibi without extra masking only needs a single broadcast row.
        query_dim = seq_len if (needs_query_rows or not causal) else 1
        return (1, n_heads, query_dim, seq_len)
    if needs_query_rows:
        return (1, 1, seq_len, seq_len)
    return None
247
+
248
def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
    """Return the attention bias to use, adding the alibi term when requested.

    'flash' yields ``None``. For 'torch' and 'triton' the given ``attn_bias``
    is returned, with the alibi bias added out of place when ``alibi`` is
    truthy.

    Raises:
        ValueError: ``attn_impl`` is not 'flash', 'torch' or 'triton'.
    """
    if attn_impl == 'flash':
        return None
    if attn_impl not in ('torch', 'triton'):
        raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
    if not alibi:
        return attn_bias
    # Build the alibi term on the same device/dtype as the existing bias.
    alibi_term = build_alibi_bias(
        n_heads,
        seq_len,
        full=not causal,
        alibi_bias_max=alibi_bias_max,
        device=attn_bias.device,
        dtype=attn_bias.dtype,
    )
    return attn_bias.add(alibi_term)
258
+
259
def gen_slopes(n_heads, alibi_bias_max=8, device=None):
    """Compute per-head alibi slopes, shaped ``(1, n_heads, 1, 1)``.

    Slopes are generated for the next power of two at or above ``n_heads``.
    When ``n_heads`` is not itself a power of two, the odd-indexed slopes are
    placed before the even-indexed ones and the first ``n_heads`` are kept.
    """
    padded_heads = 2 ** math.ceil(math.log2(n_heads))
    exponents = torch.arange(1, padded_heads + 1, dtype=torch.float32, device=device)
    exponents = exponents.mul(alibi_bias_max / padded_heads)
    slopes = 1.0 / torch.pow(2, exponents)
    if padded_heads != n_heads:
        odd_indexed, even_indexed = slopes[1::2], slopes[::2]
        slopes = torch.cat([odd_indexed, even_indexed])[:n_heads]
    return slopes.view(1, n_heads, 1, 1)
267
+
268
def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None):
    """Build the alibi attention bias and cast it to ``dtype``.

    Without ``full`` the bias has one broadcast query row holding the offsets
    ``1 - seq_len .. 0``. With ``full`` it holds ``-|i - j|`` for every
    query/key pair. Either way it is scaled by the per-head slopes.
    """
    offsets = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device)
    bias = offsets.view(1, 1, 1, seq_len)
    if full:
        # Pairwise key-minus-query offsets, negated in absolute value.
        pairwise = bias - offsets.view(1, 1, seq_len, 1)
        bias = pairwise.abs().mul(-1)
    head_slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
    return (bias * head_slopes).to(dtype=dtype)
276
# Maps an attention-type name to the attention module class implementing it.
ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
diffusion/model/llava/mpt/blocks.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GPT Blocks used for the GPT Model."""
2
+ from typing import Dict, Optional, Tuple
3
+ import torch
4
+ import torch.nn as nn
5
+ from .attention import ATTN_CLASS_REGISTRY
6
+ from .norm import NORM_CLASS_REGISTRY
7
+
8
class MPTMLP(nn.Module):
    """Feed-forward block: up-projection, exact GELU, down-projection."""

    def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
        super().__init__()
        hidden_dim = expansion_ratio * d_model
        self.up_proj = nn.Linear(d_model, hidden_dim, device=device)
        self.act = nn.GELU(approximate='none')
        self.down_proj = nn.Linear(hidden_dim, d_model, device=device)
        # Tag the layer that feeds the residual stream.
        self.down_proj._is_residual = True

    def forward(self, x):
        """Apply ``down_proj(act(up_proj(x)))``."""
        expanded = self.up_proj(x)
        activated = self.act(expanded)
        return self.down_proj(activated)
19
+
20
class MPTBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual connection."""

    def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
        # Extra keyword arguments are accepted and discarded.
        del kwargs
        super().__init__()
        norm_cls = NORM_CLASS_REGISTRY[norm_type.lower()]
        attn_cls = ATTN_CLASS_REGISTRY[attn_config['attn_type']]

        self.norm_1 = norm_cls(d_model, device=device)
        self.attn = attn_cls(
            attn_impl=attn_config['attn_impl'],
            clip_qkv=attn_config['clip_qkv'],
            qk_ln=attn_config['qk_ln'],
            softmax_scale=attn_config['softmax_scale'],
            attn_pdrop=attn_config['attn_pdrop'],
            d_model=d_model,
            n_heads=n_heads,
            device=device,
        )
        self.norm_2 = norm_cls(d_model, device=device)
        self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

    def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
        """Run one block and return ``(hidden_states, past_key_value)``.

        The attention weights returned by the attention module are dropped.
        """
        normed = self.norm_1(x)
        attn_out, _, past_key_value = self.attn(
            normed,
            past_key_value=past_key_value,
            attn_bias=attn_bias,
            attention_mask=attention_mask,
            is_causal=is_causal,
        )
        x = x + self.resid_attn_dropout(attn_out)

        ffn_out = self.ffn(self.norm_2(x))
        x = x + self.resid_ffn_dropout(ffn_out)
        return (x, past_key_value)
diffusion/model/llava/mpt/configuration_mpt.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A HuggingFace-style model configuration."""
2
+ from typing import Dict, Optional, Union
3
+ from transformers import PretrainedConfig
4
# Default attention settings; also used to fill keys missing from a user-supplied attn_config.
attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
# Default parameter-initialization settings; also used to fill keys missing from a user-supplied init_config.
init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'}
6
+
7
class MPTConfig(PretrainedConfig):
    """HuggingFace-style configuration object for MPT models."""
    model_type = 'mpt'

    def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs):
        """The MPT configuration class.

        Args:
            d_model (int): The size of the embedding dimension of the model.
            n_heads (int): The number of attention heads.
            n_layers (int): The number of layers in the model.
            expansion_ratio (int): The ratio of the up/down scale in the MLP.
            max_seq_len (int): The maximum sequence length of the model.
            vocab_size (int): The size of the vocabulary.
            resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
            emb_pdrop (float): The dropout probability for the embedding layer.
            learned_pos_emb (bool): Whether to use learned positional embeddings
            attn_config (Dict): A dictionary used to configure the model's attention module.
                The dictionary is copied, so the caller's dict and the module-level
                defaults are never modified. Keys:
                attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention
                attn_pdrop (float): The dropout probability for the attention layers.
                attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
                qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
                clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
                    this value.
                softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
                    use the default scale of ``1/sqrt(d_keys)``.
                prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
                    extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
                    can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
                attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
                    When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
                    which sub-sequence each token belongs to.
                    Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
                alibi (bool): Whether to use the alibi bias instead of position embeddings.
                alibi_bias_max (int): The maximum value of the alibi bias.
            init_device (str): The device to use for parameter initialization.
            logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
            no_bias (bool): Stored on the config as-is; presumably disables bias terms in all layers
                when True (the consumer is not part of this class).
            verbose (int): The verbosity level. 0 is silent.
            embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
            norm_type (str): choose type of norm to use
            use_cache (bool): Whether or not the model should return the last key/values attentions
            init_config (Dict): A dictionary used to configure the model initialization.
                The dictionary is copied, like ``attn_config``. Keys:
                init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
                    'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
                    'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
                init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
                emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
                emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
                    used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
                init_std (float): The standard deviation of the normal distribution used to initialize the model,
                    if using the baseline_ parameter initialization scheme.
                init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
                fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
                init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
                ---
                See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
            **kwargs: Forwarded to ``PretrainedConfig.__init__`` after ``name`` and
                ``loss_fn`` have been removed.

        Raises:
            ValueError: a setting fails ``_validate_config``.
            NotImplementedError: an attention feature is requested with an
                ``attn_impl`` that does not support it.
        """
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.expansion_ratio = expansion_ratio
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.learned_pos_emb = learned_pos_emb
        # Copy the dict: the default argument is the shared module-level dict, so
        # storing it directly would let a later `cfg.attn_config[...] = ...` (or the
        # default-filling in _validate_config) leak into every other config.
        self.attn_config = dict(attn_config)
        self.init_device = init_device
        self.logit_scale = logit_scale
        self.no_bias = no_bias
        self.verbose = verbose
        self.embedding_fraction = embedding_fraction
        self.norm_type = norm_type
        self.use_cache = use_cache
        # Same aliasing concern as attn_config above.
        self.init_config = dict(init_config)
        # These keys are dropped rather than forwarded to PretrainedConfig.
        kwargs.pop('name', None)
        kwargs.pop('loss_fn', None)
        super().__init__(**kwargs)
        self._validate_config()

    def _set_config_defaults(self, config, config_defaults):
        """Fill keys missing from ``config`` with ``config_defaults`` (in place) and return it."""
        for (k, v) in config_defaults.items():
            config.setdefault(k, v)
        return config

    def _validate_config(self):
        """Fill in missing attention/init settings and check the configuration for consistency.

        Raises:
            ValueError: on an out-of-range or unrecognized setting.
            NotImplementedError: when prefix_lm, alibi or attn_uses_sequence_id
                is combined with an ``attn_impl`` other than 'torch' or 'triton'.
        """
        self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
        self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
        if self.d_model % self.n_heads != 0:
            raise ValueError('d_model must be divisible by n_heads')
        if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
            raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
        if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
            raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
        if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
            raise NotImplementedError('prefix_lm only implemented with torch and triton attention.')
        if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
            raise NotImplementedError('alibi only implemented with torch and triton attention.')
        if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
            raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.')
        if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
            raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')
        if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model':
            raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
        if self.init_config.get('name', None) is None:
            raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
        # At least one source of positional information is required.
        if not self.learned_pos_emb and (not self.attn_config['alibi']):
            raise ValueError('Positional information must be provided to the model using either learned_pos_emb or alibi.')
diffusion/model/llava/mpt/modeling_mpt.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A simple, flexible implementation of a GPT model.
2
+
3
+ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
4
+ """
5
+ import math
6
+ import warnings
7
+ from typing import List, Optional, Tuple, Union
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
12
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
+ from .attention import attn_bias_shape, build_attn_bias
14
+ from .blocks import MPTBlock
15
+ from .norm import NORM_CLASS_REGISTRY
16
+ from .configuration_mpt import MPTConfig
17
+ from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
18
+ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
19
+
20
+ from transformers.utils import logging
21
+ logger = logging.get_logger(__name__)
22
+
23
class MPTPreTrainedModel(PreTrainedModel):
    """Shared ``PreTrainedModel`` base for the MPT classes defined in this module."""

    # Attribute prefix consumed by the Hugging Face ``PreTrainedModel`` machinery.
    base_model_prefix = 'model'
    # Configuration class paired with every MPT model subclass.
    config_class = MPTConfig
26
+
27
class MPTModel(MPTPreTrainedModel):
    """MPT transformer backbone: token/position embeddings, a stack of ``MPTBlock``s and a final norm.

    The attention implementation, prefix-LM mode, ALiBi usage and
    sequence-id masking are all read from ``config.attn_config``.
    """

    def __init__(self, config: MPTConfig):
        """Build the embeddings, blocks and final norm described by ``config``.

        Raises:
            NotImplementedError: if ``config.norm_type`` is not a key of
                ``NORM_CLASS_REGISTRY``.
        """
        config._validate_config()
        super().__init__(config)
        self.attn_impl = config.attn_config['attn_impl']
        self.prefix_lm = config.attn_config['prefix_lm']
        self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
        self.alibi = config.attn_config['alibi']
        self.alibi_bias_max = config.attn_config['alibi_bias_max']
        if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
            norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
            raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
        norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
        self.embedding_fraction = config.embedding_fraction
        self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device)
        # Learned position embeddings are only created when ALiBi is off.
        if not self.alibi:
            self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
        self.emb_drop = nn.Dropout(config.emb_pdrop)
        self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
        self.norm_f = norm_class(config.d_model, device=config.init_device)
        # Parameters on the 'meta' device are not initialized here.
        if config.init_device != 'meta':
            self.apply(self.param_init_fn)
        self.is_causal = not self.prefix_lm
        # The attention bias is built lazily on the first `_attn_bias` call.
        self._attn_bias_initialized = False
        self.attn_bias = None
        self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id)
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
                    if config.verbose:
                        warnings.warn(f'Removing bias ({module.bias}) from {module}.')
                    module.register_parameter('bias', None)
        if config.verbose and config.verbose > 2:
            print(self)
        if 'verbose' not in self.config.init_config:
            self.config.init_config['verbose'] = self.config.verbose
        if self.config.init_config['verbose'] > 1:
            init_fn_name = self.config.init_config['name']
            warnings.warn(f'Using {init_fn_name} initialization.')
        self.gradient_checkpointing = False

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, value):
        """Replace the token embedding module with ``value``."""
        self.wte = value

    @torch.no_grad()
    def _attn_bias(self, device, dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None):
        """Return the ``(attn_bias, attention_mask)`` pair to pass to the blocks.

        For the 'flash' implementation the cached bias and the untouched
        ``attention_mask`` are returned. Otherwise the padding mask is folded
        into the bias and ``None`` is returned in the mask slot.
        """
        if not self._attn_bias_initialized:
            if self.attn_bias_shape:
                self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
                self.attn_bias = build_attn_bias(self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max)
            self._attn_bias_initialized = True
        if self.attn_impl == 'flash':
            return (self.attn_bias, attention_mask)
        if self.attn_bias is not None:
            self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
        attn_bias = self.attn_bias
        if self.prefix_lm:
            assert isinstance(attn_bias, torch.Tensor)
            assert isinstance(prefix_mask, torch.Tensor)
            attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
        if self.attn_uses_sequence_id and sequence_id is not None:
            assert isinstance(attn_bias, torch.Tensor)
            attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
        if attention_mask is not None:
            s_k = attention_mask.shape[-1]
            if attn_bias is None:
                attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
            else:
                # Keep only the trailing s_k key positions of the cached bias.
                attn_bias = attn_bias[:, :, :, -s_k:]
            if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
                raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
            min_val = torch.finfo(attn_bias.dtype).min
            # Masked-out key positions receive the most negative finite value.
            attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
        return (attn_bias, None)

    def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
        """Mask ``attn_bias`` so a position attends only causally or to prefix tokens.

        Raises:
            ValueError: if the last two dims of ``attn_bias`` are not both
                ``max_seq_len``, or ``prefix_mask`` is longer than ``max_seq_len``.
        """
        (s_k, s_q) = attn_bias.shape[-2:]
        if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
            # The check is against max_seq_len, so the message reports it too
            # (the original read the unrelated `config.max_length`).
            raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_seq_len} ' + f'but are {s_k} and {s_q}.')
        seq_len = prefix_mask.shape[-1]
        if seq_len > self.config.max_seq_len:
            raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
        attn_bias = attn_bias[..., :seq_len, :seq_len]
        causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
        prefix = prefix_mask.view(-1, 1, 1, seq_len)
        cannot_attend = ~torch.logical_or(causal, prefix.bool())
        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
        return attn_bias

    def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
        """Mask ``attn_bias`` so positions with different ``sequence_id`` cannot attend to each other.

        Raises:
            ValueError: if ``sequence_id`` is longer than ``max_seq_len``.
        """
        seq_len = sequence_id.shape[-1]
        if seq_len > self.config.max_seq_len:
            raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
        attn_bias = attn_bias[..., :seq_len, :seq_len]
        cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
        return attn_bias

    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, tok_emb: Optional[torch.FloatTensor]=None):
        """Run the transformer and return a ``BaseModelOutputWithPast``.

        Either ``input_ids`` or a precomputed ``tok_emb`` must be supplied;
        when ``input_ids`` is given, ``tok_emb`` is recomputed from it.

        Raises:
            NotImplementedError: for ``return_dict=False``, ``output_attentions``
                or left-padded masks while training.
            ValueError: for a missing ``prefix_mask``/``sequence_id`` when the
                configuration requires one, a ``past_key_values`` of the wrong
                length, or a total sequence length above ``max_seq_len``.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False
        if attention_mask is not None:
            attention_mask = attention_mask.bool()
        if prefix_mask is not None:
            prefix_mask = prefix_mask.bool()
        if not return_dict:
            raise NotImplementedError('return_dict False is not implemented yet for MPT')
        if output_attentions:
            raise NotImplementedError('output_attentions is not implemented yet for MPT')
        if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
            raise NotImplementedError('MPT does not support training with left padding.')
        if self.prefix_lm and prefix_mask is None:
            raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
        if self.training:
            if self.attn_uses_sequence_id and sequence_id is None:
                raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
            elif self.attn_uses_sequence_id is False and sequence_id is not None:
                warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
        if input_ids is not None:
            S = input_ids.size(1)
            assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
            tok_emb = self.wte(input_ids)
        else:
            assert tok_emb is not None
            S = tok_emb.size(1)
        if self.alibi:
            x = tok_emb
        else:
            past_position = 0
            if past_key_values is not None:
                if len(past_key_values) != self.config.n_layers:
                    raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
                # NOTE(review): assumes dim 1 of the cached key is the past sequence length — confirm against the attention implementation.
                past_position = past_key_values[0][0].size(1)
            if S + past_position > self.config.max_seq_len:
                raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S}, this model only supports total sequence length <= {self.config.max_seq_len}.')
            # Use tok_emb's device: input_ids may be None when embeddings are passed in directly.
            pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=tok_emb.device).unsqueeze(0)
            if attention_mask is not None:
                # Shift positions down by the running count of masked-out tokens.
                pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
            pos_emb = self.wpe(pos)
            x = tok_emb + pos_emb
        if self.embedding_fraction == 1:
            x = self.emb_drop(x)
        else:
            # Same forward value as x, but only `embedding_fraction` of the gradient reaches the embeddings.
            x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
            assert isinstance(self.emb_drop, nn.Module)
            x = self.emb_drop(x_shrunk)
        (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
        if use_cache and past_key_values is None:
            past_key_values = [() for _ in range(self.config.n_layers)]
        all_hidden_states = () if output_hidden_states else None
        for (b_idx, block) in enumerate(self.blocks):
            if output_hidden_states:
                assert all_hidden_states is not None
                all_hidden_states = all_hidden_states + (x,)
            past_key_value = past_key_values[b_idx] if past_key_values is not None else None
            if self.gradient_checkpointing and self.training:
                # Imported here so the submodule is guaranteed to be loaded; a
                # from-import avoids rebinding `torch` as a function local.
                from torch.utils.checkpoint import checkpoint
                (x, past_key_value) = checkpoint(
                    block,
                    x, past_key_value, attn_bias, attention_mask, self.is_causal
                )
            else:
                (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
            if past_key_values is not None:
                past_key_values[b_idx] = past_key_value
        x = self.norm_f(x)
        return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)

    def param_init_fn(self, module):
        """Initialize ``module`` with the function named by ``config.init_config['name']``."""
        init_fn_name = self.config.init_config['name']
        MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)

    def fsdp_wrap_fn(self, module):
        """Return True when ``module`` is an ``MPTBlock``."""
        return isinstance(module, MPTBlock)

    def activation_checkpointing_fn(self, module):
        """Return True when ``module`` is an ``MPTBlock``."""
        return isinstance(module, MPTBlock)
216
+
217
class MPTForCausalLM(MPTPreTrainedModel):
    """MPT backbone with a language-model head that reuses the token embedding matrix."""

    def __init__(self, config: MPTConfig):
        """Create the wrapped ``MPTModel`` and resolve the optional logit scale.

        Raises:
            ValueError: if word embeddings are not tied, or ``config.logit_scale``
                is a string other than ``'inv_sqrt_d_model'``.
        """
        super().__init__(config)
        if not config.tie_word_embeddings:
            raise ValueError('MPTForCausalLM only supports tied word embeddings')
        self.transformer = MPTModel(config)
        self.logit_scale = None
        scale = config.logit_scale
        if scale is None:
            return
        if isinstance(scale, str):
            if scale != 'inv_sqrt_d_model':
                raise ValueError(f"logit_scale={scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
            scale = 1 / math.sqrt(config.d_model)
        self.logit_scale = scale

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.transformer.wte

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module."""
        self.transformer.wte = value

    def get_output_embeddings(self):
        """Return the output embeddings, which are the token embeddings."""
        return self.transformer.wte

    def set_output_embeddings(self, new_embeddings):
        """Replace the output embeddings, i.e. the token embeddings."""
        self.transformer.wte = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped transformer backbone."""
        self.transformer = decoder

    def get_decoder(self):
        """Return the wrapped transformer backbone."""
        return self.transformer

    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
        """Compute logits, and a next-token cross-entropy loss when ``labels`` is given."""
        if return_dict is None:
            return_dict = self.config.return_dict
        if use_cache is None:
            use_cache = self.config.use_cache
        outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id,
            return_dict=return_dict,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
        )
        # The LM head is the token embedding matrix applied as a linear layer.
        logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
        if self.logit_scale is not None:
            if self.logit_scale == 0:
                warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
            logits *= self.logit_scale
        if labels is None:
            loss = None
        else:
            # Shift targets one step left; the last column has no target and is
            # set to -100, the default ignore_index of cross_entropy.
            targets = torch.roll(labels, shifts=-1)
            targets[:, -1] = -100
            flat_logits = logits.view(-1, logits.size(-1))
            flat_targets = targets.to(logits.device).view(-1)
            loss = F.cross_entropy(flat_logits, flat_targets)
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)

    def param_init_fn(self, module):
        """Initialize ``module`` with the function named by ``config.init_config['name']``."""
        init_cfg = self.config.init_config
        initializer = MODEL_INIT_REGISTRY[init_cfg['name']]
        initializer(module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **init_cfg)

    def fsdp_wrap_fn(self, module):
        """Return True when ``module`` is an ``MPTBlock``."""
        return isinstance(module, MPTBlock)

    def activation_checkpointing_fn(self, module):
        """Return True when ``module`` is an ``MPTBlock``."""
        return isinstance(module, MPTBlock)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Assemble the keyword arguments for one generation step.

        Raises:
            NotImplementedError: for ``inputs_embeds``, right-padded attention
                masks, or ``use_cache=False`` together with prefix-LM mode.
        """
        if inputs_embeds is not None:
            raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
        attention_mask = kwargs['attention_mask'].bool()
        # Every row must end in an unmasked token, i.e. no right padding.
        if attention_mask[:, -1].sum() != attention_mask.shape[0]:
            raise NotImplementedError('MPT does not support generation with right padding.')
        sequence_id = None
        if self.transformer.attn_uses_sequence_id and self.training:
            sequence_id = torch.zeros_like(input_ids[:1])
        if past_key_values is not None:
            # With a cache, only the newest token is fed to the model.
            input_ids = input_ids[:, -1].unsqueeze(-1)
        prefix_mask = None
        if self.transformer.prefix_lm:
            prefix_mask = torch.ones_like(attention_mask)
            if kwargs.get('use_cache') == False:
                raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'prefix_mask': prefix_mask,
            'sequence_id': sequence_id,
            'past_key_values': past_key_values,
            'use_cache': kwargs.get('use_cache', True),
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Used by HuggingFace generate when using beam search with kv-caching.

        See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
        for an example in transformers.
        """
        return [
            tuple(state.index_select(0, beam_idx) for state in layer_past)
            for layer_past in past_key_values
        ]
diffusion/model/llava/mpt/norm.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def _cast_if_autocast_enabled(tensor):
4
+ if torch.is_autocast_enabled():
5
+ if tensor.device.type == 'cuda':
6
+ dtype = torch.get_autocast_gpu_dtype()
7
+ elif tensor.device.type == 'cpu':
8
+ dtype = torch.get_autocast_cpu_dtype()
9
+ else:
10
+ raise NotImplementedError()
11
+ return tensor.to(dtype=dtype)
12
+ return tensor
13
+
14
class LPLayerNorm(torch.nn.LayerNorm):
    """LayerNorm that runs in the autocast dtype with autocast itself switched off."""

    def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
        """Forward all arguments unchanged to ``torch.nn.LayerNorm``."""
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        """Layer-normalize ``x`` after casting input and affine parameters."""
        device_type = x.device.type
        inputs = _cast_if_autocast_enabled(x)
        weight = self.weight
        if weight is not None:
            weight = _cast_if_autocast_enabled(weight)
        bias = self.bias
        if bias is not None:
            bias = _cast_if_autocast_enabled(bias)
        # Autocast is disabled so layer_norm runs in the dtype chosen above.
        with torch.autocast(enabled=False, device_type=device_type):
            return torch.nn.functional.layer_norm(inputs, self.normalized_shape, weight, bias, self.eps)
26
+
27
def rms_norm(x, weight=None, eps=1e-05):
    """Normalize ``x`` by its root-mean-square over the last dimension.

    Args:
        x: input tensor; the mean of squares is taken over dim -1.
        weight: optional scale multiplied into the normalized output.
        eps: value added to the mean of squares before the square root.

    Returns:
        ``x / sqrt(mean(x**2) + eps)``, multiplied by ``weight`` when given.
    """
    # rsqrt is already the reciprocal of the RMS, so it must be multiplied in.
    # The original divided by it, which scaled x *up* by its RMS instead.
    output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    if weight is not None:
        return output * weight
    return output
32
+
33
class RMSNorm(torch.nn.Module):
    """RMS normalization module with an optional learnable scale."""

    def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
        """Store ``eps`` and create a ones-initialized ``weight`` unless ``weight`` is falsy."""
        super().__init__()
        self.eps = eps
        if not weight:
            self.register_parameter('weight', None)
        else:
            initial_scale = torch.ones(normalized_shape, dtype=dtype, device=device)
            self.weight = torch.nn.Parameter(initial_scale)

    def forward(self, x):
        """Normalize in float32 and cast the result back to the dtype of ``x``."""
        normed = rms_norm(x.float(), self.weight, self.eps)
        return normed.to(dtype=x.dtype)
45
+
46
class LPRMSNorm(RMSNorm):
    """RMSNorm that runs in the autocast dtype with autocast itself switched off."""

    def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
        """Forward all arguments unchanged to ``RMSNorm``."""
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            weight=weight,
            dtype=dtype,
            device=device,
        )

    def forward(self, x):
        """RMS-normalize ``x`` after casting, returning the dtype of ``x``."""
        inputs = _cast_if_autocast_enabled(x)
        scale = self.weight
        if scale is not None:
            scale = _cast_if_autocast_enabled(scale)
        with torch.autocast(enabled=False, device_type=x.device.type):
            normed = rms_norm(inputs, scale, self.eps)
            return normed.to(dtype=x.dtype)
56
# Maps a lower-case norm name to the class implementing it.
NORM_CLASS_REGISTRY = {
    'layernorm': torch.nn.LayerNorm,
    'low_precision_layernorm': LPLayerNorm,
    'rmsnorm': RMSNorm,
    'low_precision_rmsnorm': LPRMSNorm,
}
diffusion/model/llava/mpt/param_init_fns.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from collections.abc import Sequence
4
+ from functools import partial
5
+ from typing import Optional, Tuple, Union
6
+ import torch
7
+ from torch import nn
8
+ from .norm import NORM_CLASS_REGISTRY
9
+
10
def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs):
    """Re-run ``module.reset_parameters()`` when the module defines it.

    Extra keyword arguments are accepted and discarded.
    """
    del kwargs
    if verbose > 1:
        warnings.warn(f"Initializing network using module's reset_parameters attribute")
    if not hasattr(module, 'reset_parameters'):
        return
    module.reset_parameters()
16
+
17
def fused_init_helper_(module: nn.Module, init_fn_):
    """Apply ``init_fn_`` separately to each fused segment of ``module.weight``.

    ``module._fused`` must be a ``(dim, splits)`` pair: ``splits`` are the
    interior boundaries along ``dim``, and the start (0) and end of that
    dimension are added automatically.

    Raises:
        RuntimeError: if ``module`` has no ``_fused`` attribute (or it is None).
    """
    _fused = getattr(module, '_fused', None)
    if _fused is None:
        raise RuntimeError(f'Internal logic error')
    (dim, splits) = _fused
    splits = (0, *splits, module.weight.size(dim))
    for (s, e) in zip(splits[:-1], splits[1:]):
        slice_indices = [slice(None)] * module.weight.ndim
        slice_indices[dim] = slice(s, e)
        # Index with a tuple: indexing with a list of slices is the deprecated
        # non-tuple sequence form of multidimensional indexing.
        init_fn_(module.weight[tuple(slice_indices)])
27
+
28
def _resolve_div_is_residual(init_div_is_residual, n_layers):
    """Translate the ``init_div_is_residual`` setting into a numeric divisor."""
    if init_div_is_residual is False:
        return 1.0
    if init_div_is_residual is True:
        return math.sqrt(2 * n_layers)
    if isinstance(init_div_is_residual, (int, float)):
        return init_div_is_residual
    if isinstance(init_div_is_residual, str):
        # float() also accepts decimal strings such as '2.5', which the
        # original str.isnumeric() check rejected.
        try:
            return float(init_div_is_residual)
        except ValueError:
            pass
    raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')


def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
    """Initialize one module in place according to its type.

    Handles ``nn.Linear``, ``nn.Embedding``, the classes in
    ``NORM_CLASS_REGISTRY`` and ``nn.MultiheadAttention``. Biases are zeroed,
    norm weights are set to one, and weights of layers flagged
    ``_is_residual`` are divided by the resolved residual divisor.

    Args:
        module: module to initialize.
        init_fn_: in-place initializer applied to weight tensors.
        n_layers: layer count, used for the default residual divisor
            ``sqrt(2 * n_layers)``.
        d_model: model width; required for a fused ``nn.MultiheadAttention``.
        init_div_is_residual: ``False`` disables residual scaling, ``True``
            uses the default divisor, a number or numeric string sets it.
        emb_init_std: if set, embeddings use a normal init with this std.
        emb_init_uniform_lim: if set (and ``emb_init_std`` is not), embeddings
            use a uniform init on ``(a, b)``, or ``[-lim, lim]`` for a scalar.
        verbose: values above 1 emit explanatory warnings.

    Raises:
        ValueError: for a non-numeric ``init_div_is_residual`` or a sequence
            ``emb_init_uniform_lim`` that does not have exactly two entries.
        NotImplementedError: for any other module type that directly owns
            parameters.
    """
    del kwargs
    if verbose > 1:
        warnings.warn(f'If model has bias parameters they are initialized to 0.')
    div_is_residual = _resolve_div_is_residual(init_div_is_residual, n_layers)
    if init_div_is_residual is not False:
        if verbose > 1:
            warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.')
    if isinstance(module, nn.Linear):
        if hasattr(module, '_fused'):
            fused_init_helper_(module, init_fn_)
        else:
            init_fn_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
        if init_div_is_residual is not False and getattr(module, '_is_residual', False):
            with torch.no_grad():
                module.weight.div_(div_is_residual)
    elif isinstance(module, nn.Embedding):
        if emb_init_std is not None:
            std = emb_init_std
            if std == 0:
                warnings.warn(f'Embedding layer initialized to 0.')
            emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
            if verbose > 1:
                warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
        elif emb_init_uniform_lim is not None:
            lim = emb_init_uniform_lim
            if isinstance(lim, Sequence):
                # Exactly two limits are needed; a shorter sequence used to
                # fall through to an IndexError on lim[1].
                if len(lim) != 2:
                    raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.')
                if lim[0] == lim[1]:
                    warnings.warn(f'Embedding layer initialized to {lim[0]}.')
            else:
                if lim == 0:
                    warnings.warn(f'Embedding layer initialized to 0.')
                lim = [-lim, lim]
            (a, b) = lim
            emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
            if verbose > 1:
                warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
        else:
            emb_init_fn_ = init_fn_
        emb_init_fn_(module.weight)
    elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
        if verbose > 1:
            warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')
        if hasattr(module, 'weight') and module.weight is not None:
            torch.nn.init.ones_(module.weight)
        if hasattr(module, 'bias') and module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.MultiheadAttention):
        if module._qkv_same_embed_dim:
            assert module.in_proj_weight is not None
            assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
            assert d_model is not None
            _d = d_model
            # Initialize the q, k and v thirds of the fused projection separately.
            splits = (0, _d, 2 * _d, 3 * _d)
            for (s, e) in zip(splits[:-1], splits[1:]):
                init_fn_(module.in_proj_weight[s:e])
        else:
            assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
            assert module.in_proj_weight is None
            init_fn_(module.q_proj_weight)
            init_fn_(module.k_proj_weight)
            init_fn_(module.v_proj_weight)
        if module.in_proj_bias is not None:
            torch.nn.init.zeros_(module.in_proj_bias)
        if module.bias_k is not None:
            torch.nn.init.zeros_(module.bias_k)
        if module.bias_v is not None:
            torch.nn.init.zeros_(module.bias_v)
        init_fn_(module.out_proj.weight)
        if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False):
            with torch.no_grad():
                module.out_proj.weight.div_(div_is_residual)
        if module.out_proj.bias is not None:
            torch.nn.init.zeros_(module.out_proj.bias)
    else:
        # Any other module that directly owns a parameter is unsupported.
        for _ in module.parameters(recurse=False):
            raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
120
+
121
+ def _normal_init_(std, mean=0.0):
122
+ return partial(torch.nn.init.normal_, mean=mean, std=std)
123
+
124
def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
    """Initialize ``module`` via ``generic_param_init_fn_`` with a zero-mean normal of the given ``std``."""
    del kwargs
    normal_fn_ = _normal_init_(std=std)
    if verbose > 1:
        warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
    generic_param_init_fn_(
        module=module,
        init_fn_=normal_fn_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
130
+
131
def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
    """Normal initialization using the user-supplied ``init_std``.

    Raises:
        ValueError: if ``init_std`` is None.
    """
    del kwargs
    if init_std is None:
        raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
    _normal_param_init_fn_(
        module=module,
        std=init_std,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
136
+
137
def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
    """Normal initialization with ``std = sqrt(2 / (5 * d_model))``."""
    del kwargs
    small_std = math.sqrt(2 / (5 * d_model))
    _normal_param_init_fn_(
        module=module,
        std=small_std,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
141
+
142
def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
    """Small init with a residual divisor of ``n_layers / sqrt(10)``.

    From section 2.3.1 of GPT-NeoX-20B: An Open-Source Autoregressive
    Language Model — Black et al. (2022).
    See https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
    and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
    """
    del kwargs
    residual_div = n_layers / math.sqrt(10)
    if verbose > 1:
        warnings.warn(f'setting init_div_is_residual to {residual_div}')
    small_param_init_fn_(
        module=module,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=residual_div,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
154
+
155
def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
    """Initialize ``module`` with ``nn.init.kaiming_uniform_``.

    ``init_gain`` is passed as the ``a`` argument, ``fan_mode`` as ``mode``
    and ``init_nonlinearity`` as ``nonlinearity``.
    """
    del kwargs
    if verbose > 1:
        warnings.warn(
            f'Using nn.init.kaiming_uniform_ init fn with parameters: '
            + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
        )
    weight_init_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
    generic_param_init_fn_(
        module=module,
        init_fn_=weight_init_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
161
+
162
def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
    """Initialize ``module`` with ``torch.nn.init.kaiming_normal_``.

    ``init_gain`` is passed as the ``a`` argument, ``fan_mode`` as ``mode``
    and ``init_nonlinearity`` as ``nonlinearity``.
    """
    del kwargs
    if verbose > 1:
        warnings.warn(
            f'Using nn.init.kaiming_normal_ init fn with parameters: '
            + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
        )
    weight_init_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
    generic_param_init_fn_(
        module=module,
        init_fn_=weight_init_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
168
+
169
def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
    """Initialize ``module`` with ``torch.nn.init.xavier_uniform_`` using ``gain=init_gain``."""
    # NOTE(review): the default init_gain of 0 is forwarded as gain=0 — confirm callers always override it.
    del kwargs
    weight_init_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
    if verbose > 1:
        warnings.warn(
            f'Using torch.nn.init.xavier_uniform_ init fn with parameters: '
            + f'gain={init_gain}'
        )
    generic_param_init_fn_(
        module=module,
        init_fn_=weight_init_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
175
+
176
def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
    """Initialize ``module`` with ``torch.nn.init.xavier_normal_`` using ``gain=init_gain``.

    Extra keyword arguments are accepted and ignored.
    """
    # NOTE(review): the default init_gain of 0 is forwarded as gain=0 — confirm callers always override it.
    weight_init_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
    if verbose > 1:
        warnings.warn(
            f'Using torch.nn.init.xavier_normal_ init fn with parameters: '
            + f'gain={init_gain}'
        )
    generic_param_init_fn_(
        module=module,
        init_fn_=weight_init_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
181
# Maps an init-scheme name to the function that applies it to one module.
MODEL_INIT_REGISTRY = {
    'default_': torch_default_param_init_fn_,
    'baseline_': baseline_param_init_fn_,
    'kaiming_uniform_': kaiming_uniform_param_init_fn_,
    'kaiming_normal_': kaiming_normal_param_init_fn_,
    'neox_init_': neox_param_init_fn_,
    'small_init_': small_param_init_fn_,
    'xavier_uniform_': xavier_uniform_param_init_fn_,
    'xavier_normal_': xavier_normal_param_init_fn_,
}
diffusion/model/nets/PixArt.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+ import os
15
+ import numpy as np
16
+ from timm.models.layers import DropPath
17
+ from timm.models.vision_transformer import PatchEmbed, Mlp
18
+
19
+ from diffusion.model.builder import MODELS
20
+ from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
21
+ from diffusion.model.nets.PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, LabelEmbedder, FinalLayer
22
+ from diffusion.utils.logger import get_root_logger
23
+
24
+
25
class PixArtBlock(nn.Module):
    """
    One PixArt transformer block: self-attention, cross-attention and an MLP.

    The self-attention and MLP branches are modulated (shift / scale / gate)
    by a learned per-block ``scale_shift_table`` added to the conditioning
    tensor ``t`` (adaLN-single).
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0, input_size=None,
                 sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs):
        """
        hidden_size: channel width of the token features
        num_heads: number of attention heads for self- and cross-attention
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``
        drop_path: stochastic-depth rate; values <= 0 disable it
        input_size: accepted for interface compatibility, not used here
        sampling, sr_ratio, qk_norm: forwarded to ``AttentionKVCompress``
        block_kwargs: forwarded to both attention modules
        """
        super().__init__()

        def _tanh_gelu():
            # Factory passed to Mlp as ``act_layer``; builds the tanh-approximated GELU.
            return nn.GELU(approximate="tanh")

        mlp_hidden = int(hidden_size * mlp_ratio)

        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = AttentionKVCompress(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            sampling=sampling,
            sr_ratio=sr_ratio,
            qk_norm=qk_norm,
            **block_kwargs
        )
        self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden, act_layer=_tanh_gelu, drop=0)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        # Six learned rows: shift/scale/gate for attention, then shift/scale/gate for the MLP.
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
        self.sampling = sampling
        self.sr_ratio = sr_ratio

    def forward(self, x, y, t, mask=None, **kwargs):
        """
        x: (B, N, C) token features
        y: conditioning tokens handed to the cross-attention module
        t: conditioning tensor, reshaped here to (B, 6, -1)
        mask: forwarded as-is to the cross-attention module
        Extra keyword arguments are ignored.
        """
        batch, tokens, channels = x.shape

        modulation = self.scale_shift_table[None] + t.reshape(batch, 6, -1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=1)

        # Gated, modulated self-attention with a residual connection.
        attn_in = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
        attn_out = self.attn(attn_in).reshape(batch, tokens, channels)
        x = x + self.drop_path(gate_msa * attn_out)

        # Cross-attention is applied without modulation or drop-path.
        x = x + self.cross_attn(x, y, mask)

        # Gated, modulated MLP with a residual connection.
        mlp_in = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + self.drop_path(gate_mlp * self.mlp(mlp_in))

        return x
57
+
58
+
59
+ #############################################################################
60
+ # Core PixArt Model #
61
+ #################################################################################
62
@MODELS.register_module()
class PixArt(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    The input is patch-embedded and summed with a fixed sin-cos positional
    embedding, passed through ``depth`` ``PixArtBlock`` layers conditioned on
    the timestep embedding and the caption embedding, and finally projected
    back to image space by ``T2IFinalLayer`` and ``unpatchify``.
    """

    def __init__(
            self,
            input_size=32,
            patch_size=2,
            in_channels=4,
            hidden_size=1152,
            depth=28,
            num_heads=16,
            mlp_ratio=4.0,
            class_dropout_prob=0.1,
            pred_sigma=True,
            drop_path: float = 0.,
            caption_channels=4096,
            pe_interpolation=1.0,
            config=None,
            model_max_length=120,
            qk_norm=False,
            kv_compress_config=None,
            **kwargs,
    ):
        """
        input_size: side length of the (square) input, in input pixels/latents
        patch_size: side length of one patch
        in_channels: number of input channels
        hidden_size: transformer width
        depth: number of transformer blocks
        num_heads: attention heads per block
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``
        class_dropout_prob: passed to ``CaptionEmbedder`` as ``uncond_prob``
        pred_sigma: when true the model outputs ``2 * in_channels`` channels
        drop_path: maximum stochastic-depth rate (linearly ramped over depth)
        caption_channels: channel width of the incoming caption features
        pe_interpolation: divisor applied to positional-embedding coordinates
        config: optional config; when truthy, ``config.work_dir`` is used for logging
        model_max_length: passed to ``CaptionEmbedder`` as ``token_num``
        qk_norm: forwarded to each block
        kv_compress_config: dict with keys ``sampling``, ``scale_factor`` and
            ``kv_compress_layer``; ``None`` disables KV compression
        kwargs: accepted and ignored
        """
        super().__init__()
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.pe_interpolation = pe_interpolation
        self.depth = depth

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        num_patches = self.x_embedder.num_patches
        self.base_size = input_size // self.patch_size
        # Will use fixed sin-cos embedding:
        self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))

        approx_gelu = lambda: nn.GELU(approximate="tanh")
        # Maps the timestep embedding to the six modulation vectors shared by all blocks.
        self.t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
            act_layer=approx_gelu, token_num=model_max_length)
        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]  # stochastic depth decay rule
        self.kv_compress_config = kv_compress_config
        if kv_compress_config is None:
            # Default: no block compresses its keys/values.
            self.kv_compress_config = {
                'sampling': None,
                'scale_factor': 1,
                'kv_compress_layer': [],
            }
        self.blocks = nn.ModuleList([
            PixArtBlock(
                hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
                input_size=(input_size // patch_size, input_size // patch_size),
                sampling=self.kv_compress_config['sampling'],
                # Only the layers listed in 'kv_compress_layer' get a ratio other than 1.
                sr_ratio=int(
                    self.kv_compress_config['scale_factor']
                ) if i in self.kv_compress_config['kv_compress_layer'] else 1,
                qk_norm=qk_norm,
            )
            for i in range(depth)
        ])
        self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)

        self.initialize_weights()

        if config:
            logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
            logger.warning(f"position embed interpolation: {self.pe_interpolation}, base size: {self.base_size}")
            logger.warning(f"kv compress config: {self.kv_compress_config}")
        else:
            print(f'Warning: position embed interpolation: {self.pe_interpolation}, base size: {self.base_size}')
            print(f"kv compress config: {self.kv_compress_config}")


    def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
        """
        Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timestep: (N,) tensor of diffusion timesteps
        y: (N, 1, L, C) tensor of caption features
        mask: optional caption-token mask; non-zero entries mark tokens to keep
        data_info, kwargs: accepted for interface compatibility and ignored here
        returns: (N, out_channels, H, W) tensor
        """
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)
        pos_embed = self.pos_embed.to(self.dtype)
        self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep.to(x.dtype))  # (N, D)
        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, 1, L, D)
        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                # Tile the mask along the batch axis when fewer masks than captions were given.
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            # Keep only unmasked tokens and pack them into one sequence of shape (1, total, D);
            # y_lens records how many tokens belong to each sample.
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        for block in self.blocks:
            x = auto_grad_checkpoint(block, x, y, t0, y_lens)  # (N, T, D) #support grad checkpoint
        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        return x

    def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
        """
        Forward pass for DPM-Solver, which does not need the variance prediction:
        only the first half of the output channels is returned.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        model_out = self.forward(x, timestep, y, mask)
        return model_out.chunk(2, dim=1)[0]

    def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs):
        """
        Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.

        The first half of ``x`` is duplicated, so the two halves of the batch
        differ only in the conditioning carried by ``y``.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        # Unpack the extra keyword arguments; passing the dict positionally
        # would bind it to forward()'s ``data_info`` parameter.
        model_out = self.forward(combined, timestep, y, mask, **kwargs)
        model_out = model_out['x'] if isinstance(model_out, dict) else model_out
        # NOTE(review): guidance is applied to the first 3 channels only, while
        # in_channels defaults to 4 -- confirm this split is intended.
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C), with T a perfect square
        imgs: (N, C, H, W), where C = out_channels
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def initialize_weights(self):
        """Initialize all weights and fill the fixed sin-cos positional embedding."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize pos_embed (a buffer, so it is not trained) by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5),
            pe_interpolation=self.pe_interpolation, base_size=self.base_size
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)

        # Initialize caption embedding MLP:
        nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
        nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)

        # Zero-out the cross-attention output projection in every PixArt block:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.proj.weight, 0)
            nn.init.constant_(block.cross_attn.proj.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    @property
    def dtype(self):
        """dtype of the first parameter; inputs are cast to it in ``forward``."""
        return next(self.parameters()).dtype
256
+
257
+
258
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0, base_size=16):
    """
    Build a 2D sin-cos positional embedding as a numpy array.

    grid_size: int (square grid) or a pair (height, width)
    pe_interpolation, base_size: coordinates along each axis are computed as
        index / (axis_length / base_size) / pe_interpolation
    return:
    pos_embed: [grid_h*grid_w, embed_dim]; when ``cls_token`` is true and
        ``extra_tokens > 0``, ``extra_tokens`` all-zero rows are prepended.
    """
    if isinstance(grid_size, int):
        grid_size = to_2tuple(grid_size)
    n_rows, n_cols = grid_size[0], grid_size[1]

    row_coords = np.arange(n_rows, dtype=np.float32) / (n_rows/base_size) / pe_interpolation
    col_coords = np.arange(n_cols, dtype=np.float32) / (n_cols/base_size) / pe_interpolation

    # Width coordinates go first, so grid[0] varies along the width axis.
    mesh = np.meshgrid(col_coords, row_coords)
    grid = np.stack(mesh, axis=0).reshape([2, 1, n_cols, n_rows])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not (cls_token and extra_tokens > 0):
        return pos_embed
    padding = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([padding, pos_embed], axis=0)
276
+
277
+
278
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a two-plane coordinate grid into a sin-cos embedding.

    embed_dim: total output width; must be even
    grid: indexable with 0 and 1, each giving one coordinate plane
    out: (M, embed_dim) -- the first half of the channels encodes grid[0],
        the second half encodes grid[1]
    """
    assert embed_dim % 2 == 0

    per_axis_dim = embed_dim // 2
    halves = [
        get_1d_sincos_pos_embed_from_grid(per_axis_dim, grid[axis])  # (H*W, D/2)
        for axis in (0, 1)
    ]
    return np.concatenate(halves, axis=1)  # (H*W, D)
287
+
288
+
289
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode 1D positions with sines and cosines at geometrically spaced frequencies.

    embed_dim: output dimension for each position; must be even
    pos: positions to be encoded -- a numpy array of any shape or any
        array-like (list, tuple); it is flattened to size (M,)
    out: (M, D) float64 array; columns [0, D/2) hold the sines and
        columns [D/2, D) hold the cosines
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    # Coerce first so plain Python sequences work as well as ndarrays.
    pos = np.asarray(pos, dtype=np.float64).reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
308
+
309
+
310
+ #################################################################################
311
+ # PixArt Configs #
312
+ #################################################################################
313
@MODELS.register_module()
def PixArt_XL_2(**kwargs):
    """Build the XL/2 PixArt variant: depth 28, width 1152, patch size 2, 16 heads."""
    model = PixArt(
        patch_size=2,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        **kwargs,
    )
    return model
diffusion/model/nets/PixArtMS.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+ import torch
12
+ import torch.nn as nn
13
+ from timm.models.layers import DropPath
14
+ from timm.models.vision_transformer import Mlp
15
+
16
+ from diffusion.model.builder import MODELS
17
+ from diffusion.model.utils import auto_grad_checkpoint, to_2tuple
18
+ from diffusion.model.nets.PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, SizeEmbedder
19
+ from diffusion.model.nets.PixArt import PixArt, get_2d_sincos_pos_embed
20
+
21
+
22
class PatchEmbed(nn.Module):
    """ 2D image to patch embedding without a fixed input size.

    A strided convolution (kernel and stride both equal to the patch size)
    projects each patch to ``embed_dim`` channels.
    """
    def __init__(
            self,
            patch_size=16,
            in_chans=3,
            embed_dim=768,
            norm_layer=None,
            flatten=True,
            bias=True,
    ):
        """
        patch_size: int or pair; normalized to a 2-tuple
        in_chans: number of input channels
        embed_dim: number of output channels per patch
        norm_layer: optional factory called with ``embed_dim``; identity if falsy
        flatten: when true, ``forward`` returns (B, N, C) instead of (B, C, H', W')
        bias: whether the projection convolution has a bias
        """
        super().__init__()
        patch_hw = to_2tuple(patch_size)
        self.patch_size = patch_hw
        self.flatten = flatten
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw, bias=bias)
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Project ``x`` into patches, optionally flatten the grid, then normalize."""
        patches = self.proj(x)
        if self.flatten:
            patches = patches.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return self.norm(patches)
47
+
48
+
49
class PixArtMSBlock(nn.Module):
    """
    One multi-scale PixArt transformer block: self-attention, cross-attention and an MLP.

    The self-attention and MLP branches are modulated (shift / scale / gate)
    by a learned per-block ``scale_shift_table`` added to the conditioning
    tensor ``t``. The token-grid size ``HW`` is passed on to the
    self-attention module.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
                 sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs):
        """
        hidden_size: channel width of the token features
        num_heads: number of attention heads for self- and cross-attention
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``
        drop_path: stochastic-depth rate; values <= 0 disable it
        input_size: accepted for interface compatibility, not used here
        sampling, sr_ratio, qk_norm: forwarded to ``AttentionKVCompress``
        block_kwargs: forwarded to both attention modules
        """
        super().__init__()

        def _tanh_gelu():
            # Factory passed to Mlp as ``act_layer``; builds the tanh-approximated GELU.
            return nn.GELU(approximate="tanh")

        mlp_hidden = int(hidden_size * mlp_ratio)

        self.hidden_size = hidden_size
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = AttentionKVCompress(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            sampling=sampling,
            sr_ratio=sr_ratio,
            qk_norm=qk_norm,
            **block_kwargs
        )
        self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden, act_layer=_tanh_gelu, drop=0)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        # Six learned rows: shift/scale/gate for attention, then shift/scale/gate for the MLP.
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)

    def forward(self, x, y, t, mask=None, HW=None, **kwargs):
        """
        x: (B, N, C) token features
        y: conditioning tokens handed to the cross-attention module
        t: conditioning tensor, reshaped here to (B, 6, -1)
        mask: forwarded as-is to the cross-attention module
        HW: forwarded to the self-attention module as ``HW``
        Extra keyword arguments are ignored.
        """
        batch, _, _ = x.shape

        modulation = self.scale_shift_table[None] + t.reshape(batch, 6, -1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=1)

        # Gated, modulated self-attention with a residual connection.
        attn_in = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + self.drop_path(gate_msa * self.attn(attn_in, HW=HW))

        # Cross-attention is applied without modulation or drop-path.
        x = x + self.cross_attn(x, y, mask)

        # Gated, modulated MLP with a residual connection.
        mlp_in = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + self.drop_path(gate_mlp * self.mlp(mlp_in))

        return x
80
+
81
+
82
+ #############################################################################
83
+ # Core PixArt Model #
84
+ #################################################################################
85
@MODELS.register_module()
class PixArtMS(PixArt):
    """
    Multi-scale PixArt: a diffusion Transformer that accepts varying input sizes.

    Differences from ``PixArt`` visible here: the positional embedding is
    rebuilt in every forward pass from the actual token grid, the patch
    embedder has no fixed input size, and the timestep embedding can
    optionally be augmented with image-size and aspect-ratio embeddings.
    """

    def __init__(
            self,
            input_size=32,
            patch_size=2,
            in_channels=4,
            hidden_size=1152,
            depth=28,
            num_heads=16,
            mlp_ratio=4.0,
            class_dropout_prob=0.1,
            learn_sigma=True,
            pred_sigma=True,
            drop_path: float = 0.,
            caption_channels=4096,
            pe_interpolation=1.,
            config=None,
            model_max_length=120,
            micro_condition=False,
            qk_norm=False,
            kv_compress_config=None,
            **kwargs,
    ):
        """
        Arguments shared with ``PixArt`` are passed on to its constructor.

        learn_sigma: forwarded to the parent constructor as a keyword argument
        caption_channels: channel width of the incoming caption features
        micro_condition: when true, size and aspect-ratio embedders are created
            and used in ``forward``
        kv_compress_config: dict with keys ``sampling``, ``scale_factor`` and
            ``kv_compress_layer``; ``None`` disables KV compression
        """
        super().__init__(
            input_size=input_size,
            patch_size=patch_size,
            in_channels=in_channels,
            hidden_size=hidden_size,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            class_dropout_prob=class_dropout_prob,
            learn_sigma=learn_sigma,
            pred_sigma=pred_sigma,
            drop_path=drop_path,
            pe_interpolation=pe_interpolation,
            config=config,
            model_max_length=model_max_length,
            qk_norm=qk_norm,
            kv_compress_config=kv_compress_config,
            **kwargs,
        )
        # Token-grid size of the most recent input; set in forward().
        self.h = self.w = 0

        def _tanh_gelu():
            # Factory passed to CaptionEmbedder as ``act_layer``.
            return nn.GELU(approximate="tanh")

        # The modules below replace the ones created by the parent constructor.
        self.t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )
        self.x_embedder = PatchEmbed(patch_size, in_channels, hidden_size, bias=True)
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels,
            hidden_size=hidden_size,
            uncond_prob=class_dropout_prob,
            act_layer=_tanh_gelu,
            token_num=model_max_length,
        )
        self.micro_conditioning = micro_condition
        if self.micro_conditioning:
            self.csize_embedder = SizeEmbedder(hidden_size//3)  # c_size embed
            self.ar_embedder = SizeEmbedder(hidden_size//3)     # aspect ratio embed

        # stochastic depth decay rule
        drop_rates = [rate.item() for rate in torch.linspace(0, drop_path, depth)]

        if kv_compress_config is None:
            # Default: no block compresses its keys/values.
            kv_compress_config = {
                'sampling': None,
                'scale_factor': 1,
                'kv_compress_layer': [],
            }

        grid_side = input_size // patch_size
        block_list = []
        for layer_idx in range(depth):
            # Only the layers listed in 'kv_compress_layer' get a ratio other than 1.
            if layer_idx in kv_compress_config['kv_compress_layer']:
                layer_sr_ratio = int(kv_compress_config['scale_factor'])
            else:
                layer_sr_ratio = 1
            block_list.append(
                PixArtMSBlock(
                    hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_rates[layer_idx],
                    input_size=(grid_side, grid_side),
                    sampling=kv_compress_config['sampling'],
                    sr_ratio=layer_sr_ratio,
                    qk_norm=qk_norm,
                )
            )
        self.blocks = nn.ModuleList(block_list)
        self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)

        self.initialize()

    def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
        """
        Forward pass of PixArtMS.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timestep: (N,) tensor of diffusion timesteps
        y: (N, 1, L, C) tensor of caption features
        mask: optional caption-token mask; non-zero entries mark tokens to keep
        data_info: mapping read only when micro-conditioning is enabled; must
            then provide 'img_hw' and 'aspect_ratio'
        kwargs: forwarded to every block call
        returns: (N, out_channels, H, W) tensor
        """
        bs = x.shape[0]
        model_dtype = self.dtype
        x = x.to(model_dtype)
        timestep = timestep.to(model_dtype)
        y = y.to(model_dtype)

        # Remember the token grid; unpatchify() reads it back.
        self.h = x.shape[-2]//self.patch_size
        self.w = x.shape[-1]//self.patch_size

        # The positional embedding is rebuilt for the current grid on every call.
        pos_embed_np = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1], (self.h, self.w), pe_interpolation=self.pe_interpolation,
            base_size=self.base_size
        )
        pos_embed = torch.from_numpy(pos_embed_np).unsqueeze(0).to(x.device).to(model_dtype)

        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep)  # (N, D)

        if self.micro_conditioning:
            c_size = data_info['img_hw'].to(model_dtype)
            ar = data_info['aspect_ratio'].to(model_dtype)
            csize = self.csize_embedder(c_size, bs)
            ar = self.ar_embedder(ar, bs)
            # Concatenated along dim 1 and added to the timestep embedding,
            # so the combined width must match that of t.
            t = t + torch.cat([csize, ar], dim=1)

        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)

        if mask is None:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        else:
            if mask.shape[0] != y.shape[0]:
                # Tile the mask along the batch axis when fewer masks than captions were given.
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            # Keep only unmasked tokens and pack them into one sequence of shape (1, total, D);
            # y_lens records how many tokens belong to each sample.
            keep = mask.unsqueeze(-1) != 0
            y = y.squeeze(1).masked_select(keep).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()

        for block in self.blocks:
            x = auto_grad_checkpoint(block, x, y, t0, y_lens, (self.h, self.w), **kwargs)  # (N, T, D) #support grad checkpoint

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)

        return x

    def forward_with_dpmsolver(self, x, timestep, y, data_info, **kwargs):
        """
        Forward pass for DPM-Solver, which does not need the variance prediction:
        only the first half of the output channels is returned.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        full_out = self.forward(x, timestep, y, data_info=data_info, **kwargs)
        mean_part = full_out.chunk(2, dim=1)[0]
        return mean_part

    def forward_with_cfg(self, x, timestep, y, cfg_scale, data_info, mask=None, **kwargs):
        """
        Forward pass of PixArtMS, but also batches the unconditional forward pass for classifier-free guidance.

        The first half of ``x`` is duplicated, so the two halves of the batch
        differ only in the conditioning carried by ``y``.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, timestep, y, mask, data_info=data_info, **kwargs)
        if isinstance(model_out, dict):
            model_out = model_out['x']
        # Guidance is applied to the first 3 channels; the rest pass through unchanged.
        eps = model_out[:, :3]
        rest = model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        guided = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([guided, guided], dim=0)
        return torch.cat([eps, rest], dim=1)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C), with T = self.h * self.w
        imgs: (N, C, self.h * patch_size, self.w * patch_size), where C = out_channels
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        assert self.h * self.w == x.shape[1]

        batch = x.shape[0]
        x = x.reshape(shape=(batch, self.h, self.w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        return x.reshape(shape=(batch, c, self.h * p, self.w * p))

    def initialize(self):
        """Initialize the weights of the modules built by this class."""
        # Initialize transformer layers:
        def _basic_init(module):
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Initialize timestep embedding MLP (and, if present, the size /
        # aspect-ratio embedders), then the caption embedding MLP.
        # The order below fixes the order in which random numbers are drawn.
        normal_layers = [
            self.t_embedder.mlp[0],
            self.t_embedder.mlp[2],
            self.t_block[1],
        ]
        if self.micro_conditioning:
            normal_layers += [
                self.csize_embedder.mlp[0],
                self.csize_embedder.mlp[2],
                self.ar_embedder.mlp[0],
                self.ar_embedder.mlp[2],
            ]
        normal_layers += [
            self.y_embedder.y_proj.fc1,
            self.y_embedder.y_proj.fc2,
        ]
        for layer in normal_layers:
            nn.init.normal_(layer.weight, std=0.02)

        # Zero-out the cross-attention output projection in every block:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.proj.weight, 0)
            nn.init.constant_(block.cross_attn.proj.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)
286
+
287
+
288
+ #################################################################################
289
+ # PixArt Configs #
290
+ #################################################################################
291
@MODELS.register_module()
def PixArtMS_XL_2(**kwargs):
    """Build the XL/2 PixArtMS variant: depth 28, width 1152, patch size 2, 16 heads."""
    model = PixArtMS(
        patch_size=2,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        **kwargs,
    )
    return model